This example was automatically generated from a Jupyter notebook in the RxInferExamples.jl repository.
We welcome and encourage contributions! You can help by:
- Improving this example
- Creating new examples
- Reporting issues or bugs
- Suggesting enhancements
Visit our GitHub repository to get started. Together we can make RxInfer.jl even better! 💪
Feature Functions in Bayesian Regression
This notebook demonstrates how we can use probabilistic methods to learn and predict continuous functions from noisy data.
This example is inspired by Chapter 4.1 Regression from the excellent book Probabilistic Numerics by Philipp Hennig, Michael A. Osborne, and Hans P. Kersting. We'll take their theoretical foundations and bring them to life with practical code examples.
The code and narrative in this notebook are written by Dmitry Bagaev (GitHub, LinkedIn). While some explanations draw from the book's content, we'll focus on building intuition through interactive examples and visualizations.
By the end of this notebook, you'll understand:
- The power of linear regression with basis functions
- How to handle uncertainty in your predictions
- Practical implementation using Julia and RxInfer.jl
We start by importing all required packages for this example - the most important of which is, of course, RxInfer!
using RxInfer, StableRNGs, LinearAlgebra, Plots, DataFrames
Gaussian distributions (multivariate) assign probability density to vectors of real numbers - think of them as sophisticated probability maps for multiple variables at once. In numerical applications, we often encounter real-valued functions $f : \mathbb{X} \rightarrow \mathbb{R}$ over some input domain $\mathbb{X}$ (imagine predicting house prices based on features like size and location).
An interesting way to use the Gaussian inference framework is to assume that $f$ can be written as a weighted sum over a finite number $F$ of feature functions $[\phi_i : \mathbb{X} \rightarrow \mathbb{R}]_{i=1,\ldots,F}$ (much like how a house price might be a weighted combination of its features, e.g. size, number of floors, number of rooms, etc.):
\[\begin{align} f(x) = \sum_{i=1}^{F} \phi_i(x)\omega_i =: \Phi^T_x \omega \,\,\, \mathrm{where} \,\, \omega \in \mathbb{R}^F \end{align}\]
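To make the notation concrete, here is a minimal sketch of the weighted-sum idea in plain Julia. The basis functions and weights below are hypothetical, chosen only for illustration - they are not the ones used later in this notebook.
# Hypothetical example: three feature functions and a fixed weight vector
ϕs_demo = [x -> 1.0, x -> x, x -> x^2]
ω_demo = [0.5, -1.0, 2.0]
# f(x) = Φₓᵀ ω, i.e. the dot product of the feature vector at x with the weights
f_demo(x) = sum(ϕ(x) * ωᵢ for (ϕ, ωᵢ) in zip(ϕs_demo, ω_demo))
f_demo(2.0) # 0.5 - 2.0 + 8.0 = 6.5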
As discussed in the Probabilistic Numerics book, uncertainty is a fundamental aspect of numerical computations. When we perform regression, we are essentially solving an inverse problem - trying to infer the underlying function from noisy observations. This inherently involves uncertainty for several reasons:
- Our observations usually contain noise and measurement errors
- We have a finite number of samples, leaving gaps in our knowledge
- The true function may be more complex than our model can capture
By modeling uncertainty explicitly through a Gaussian distribution over the weights $\omega$, we can:
- Quantify our confidence in predictions
- Make more robust decisions by accounting for uncertainty
- Detect when we're extrapolating beyond our data
- Propagate uncertainty to the next step in our machine learning pipeline
Mathematically, we express this uncertainty as:
\[\begin{align} p(\omega) = \mathcal{N}(\omega \vert \mu, \Sigma) \end{align}\]
Where $\mu$ represents our best estimate of the weights and $\Sigma$ captures our uncertainty about them.
Dataset: Noisy Observations in the Real World
In real-world scenarios, we rarely have access to perfect measurements. Instead, we collect observations $Y := [y_1, \cdots , y_N ] \in \mathbb{R}^N$ that are corrupted by Gaussian noise - a common and mathematically convenient way to model measurement uncertainty. These noisy samples of our target function $f$ are taken at specific input locations $X$, with the noise characterized by a covariance matrix $\Lambda \in \mathbb{R}^{N \times N}$. This setup mirrors many practical applications, from sensor measurements to experimental data collection.
Let's assume we have collected noisy measurements $Y$ at locations $X$:
N = 40
Λ = I
X = range(-8, 8, length=N)
rng = StableRNG(42)
# Arbitrary non-linear function, which is hidden
f(x) = -((-x / 3)^3 - (-x / 2)^2 + x + 10)
Y = rand(rng, MvNormalMeanCovariance(f.(X), Λ))
# Can be loaded from a file or a database
df = DataFrame(X = X, Y = Y)
40×2 DataFrame
 Row │ X         Y
     │ Float64   Float64
─────┼─────────────────────
   1 │ -8.0      -5.63321
   2 │ -7.58974  -3.75472
   3 │ -7.17949  -2.26681
   4 │ -6.76923  -1.95387
   5 │ -6.35897  -2.92934
   6 │ -5.94872  -2.31714
   7 │ -5.53846  -4.10432
   8 │ -5.12821  -4.08565
  ⋮  │    ⋮         ⋮
  34 │  5.53846  -1.5234
  35 │  5.94872   0.98319
  36 │  6.35897   3.86051
  37 │  6.76923   4.48039
  38 │  7.17949   8.71697
  39 │  7.58974  12.703
  40 │  8.0      19.0635
             25 rows omitted
Train & Test Dataset Configurations
To thoroughly evaluate our model's performance and robustness, we'll create three distinct train-test splits of our data. This approach helps us understand how well our model generalizes to different regions of the input space and whether it can effectively capture the underlying patterns regardless of which portions of the data it learns from.
We'll explore the following configurations:
- Forward Split: Uses the first half for training and second half for testing, evaluating the model's ability to extrapolate to higher x-values
- Reverse Split: Uses the first half for testing and second half for training, testing extrapolation to lower x-values
- Interleaved Split: Uses first and last quarters for training and middle portion for testing, assessing interpolation capabilities
These diverse splits will help reveal any biases in our model and ensure it performs consistently across different regions of the input space. They also allow us to evaluate both interpolation (predicting within the training range) and extrapolation (predicting outside the training range) capabilities.
# Split data into train/test sets
# Forward split - first half train, second half test
dataset_1 = let mid = N ÷ 2
(
y_train = Y[1:mid], x_train = X[1:mid],
y_test = Y[mid+1:end], x_test = X[mid+1:end]
)
end
# Reverse split - first half test, second half train
dataset_2 = let mid = N ÷ 2
(
y_test = Y[1:mid], x_test = X[1:mid],
y_train = Y[mid+1:end], x_train = X[mid+1:end]
)
end
# Interleaved split - first/last quarters train, middle half test
dataset_3 = let q1 = N ÷ 4, q3 = 3N ÷ 4
(
y_train = [Y[1:q1]..., Y[q3+1:end]...],
x_train = [X[1:q1]..., X[q3+1:end]...],
y_test = Y[q1+1:q3],
x_test = X[q1+1:q3]
)
end
datasets = [dataset_1, dataset_2, dataset_3]
# Create visualization for each dataset split
ps = map(enumerate(datasets)) do (i, dataset)
p = plot(
xlim = (-10, 10),
ylim = (-30, 30),
title = "Dataset $i",
xlabel = "x",
ylabel = "y"
)
scatter!(p,
dataset[:x_train], dataset[:y_train],
yerror = Λ,
label = "Train dataset",
color = :blue,
markersize = 4
)
scatter!(p,
dataset[:x_test], dataset[:y_test],
yerror = Λ,
label = "Test dataset",
color = :red,
markersize = 4
)
return p
end
plot(ps..., size = (1200, 400), layout = @layout([a b c]))
The datasets above provide nonlinear data with independent and identically distributed (i.i.d.) Gaussian observation noise, where we set the noise covariance $\Lambda = I$ (identity matrix).
Bayesian Inference with RxInfer
Our exciting challenge is to uncover the probability distribution over the parameter vector $\omega$, given our basis functions $\phi$ and observed data points $(X,Y)$. To tackle this, we'll harness the power of probabilistic programming by constructing an elegant generative model using RxInfer's @model macro. The beauty of this approach lies in its simplicity - we can express our entire model in just a few lines of code:
@model function parametric_regression(ϕs, x, y, μ, Σ, Λ)
# Prior distribution over parameters ω
ω ~ MvNormal(mean = μ, covariance = Σ)
# Design matrix Φₓ where each element is ϕᵢ(xⱼ)
Φₓ = [ϕ(xᵢ) for xᵢ in x, ϕ in ϕs]
# Likelihood of observations y given parameters ω
y ~ MvNormal(mean = Φₓ * ω, covariance = Λ)
end
Let's break down the key components of our probabilistic model:
- ϕs contains our basis functions $\phi_i$ - these are the building blocks of our model
- x holds the input locations $X$ where we've made observations. Think of these as the points along the x-axis where we've collected data, like timestamps or spatial coordinates.
- y contains our noisy measurements at each location in $X$.
- μ defines our prior beliefs about the average values of the parameters $\omega$. Setting $\mu = 0$ indicates we believe the parameters are centered around zero before seeing any data.
- Σ encodes our uncertainty about $\omega$ before seeing data. A larger $\Sigma$ means we're more uncertain, while smaller values indicate stronger prior beliefs.
- Λ represents the noise in our observations. For example, $\Lambda = 0.1I$ suggests our measurements have small, independent Gaussian noise, while larger values indicate noisier data.
To put this model to work, we'll use RxInfer's powerful infer function. Here's how:
function infer_ω(; ϕs, x, y)
# Create probabilistic model,
# RxInfer will construct the graph of this model automatically
model = parametric_regression(
ϕs = ϕs,
μ = zeros(length(ϕs)),
Σ = I,
Λ = I,
x = x
)
# Let RxInfer do all the math for you
result = infer(
model = model,
data = (y = y,)
)
# Return posterior over ω
return result.posteriors[:ω]
end
infer_ω (generic function with 1 method)
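Before experimenting with different bases, here is a quick sanity check of infer_ω - a minimal sketch that assumes the datasets defined above and a cubic polynomial basis picked arbitrarily for illustration. The returned posterior is Gaussian, so we can inspect its mean and covariance:
# Hypothetical quick check: posterior over ω for a cubic basis on the first dataset
ω_posterior = infer_ω(
ϕs = [ (x) -> x ^ i for i in 0:3 ],
x = dataset_1[:x_train],
y = dataset_1[:y_train]
)
mean(ω_posterior) # posterior mean of the weights
cov(ω_posterior)  # posterior covariance, i.e. our remaining uncertainty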
How to choose basis functions?
Just like how choosing between pizza toppings can make or break your dinner, the choice of basis functions $\phi_i$ can dramatically impact our results! Think of basis functions as the building blocks of our mathematical LEGO set - pick the wrong pieces and your model might end up looking more like abstract art than a useful predictor.
Why does this matter? Because these functions are the "vocabulary" our model uses to describe the patterns in our data. Choose a too-simple vocabulary and your model will sound like a caveman ("data go up, data go down"). Choose one that's too complex and it might start speaking mathematical gibberish!
Let's embark on a thrilling journey through different datasets with various basis function choices. We'll create a handy function that will:
- Take our basis functions for a test drive
- Run inference on multiple datasets defined above
- Create beautiful plots that would make any statistician swoon
(RxInfer makes it so easy to perform Bayesian inference so I have more time to make beautiful plots!)
function plot_inference_results_for(; ϕs, datasets, title = "", rng = StableRNG(42))
# Create main plot showing basis functions
p1 = plot(
title = "Basis functions: $(title)",
xlabel = "x",
ylabel = "y",
xlim = (-5, 5),
ylim = (-10, 10),
legend = :outertopleft,
grid = true,
fontfamily = "Computer Modern"
)
# Plot basis functions in gray
plot_ϕ!(p1, ϕs, color = :gray, alpha = 0.5,
labels = ["ϕ$i" for _ in 1:1, i in 1:length(ϕs)])
# Add examples with random ω values
plot_ϕ!(p1, ϕs, randn(rng, length(ϕs), 3),
linewidth = 2)
# Create subplot for each dataset
ps = map(enumerate(datasets)) do (i, dataset)
p2 = plot(
title = "Dataset #$(i): $(title)",
xlabel = "x",
ylabel = "y",
xlim = (-10, 10),
ylim = (-25, 25),
grid = true,
fontfamily = "Computer Modern"
)
# Infer posterior over ω
ωs = infer_ω(
ϕs = ϕs,
x = dataset[:x_train],
y = dataset[:y_train]
)
# Plot posterior mean
plot_ϕ!(p2, ϕs, mean(ωs),
linewidth = 3,
color = :green,
labels = "Posterior mean")
# Plot posterior samples
plot_ϕ!(p2, ϕs, rand(ωs, 15),
linewidth = 1,
color = :gray,
alpha = 0.4,
labels = nothing)
# Add data points
scatter!(p2, dataset[:x_train], dataset[:y_train],
yerror = Λ,
label = "Training data",
color = :royalblue,
markersize = 4)
scatter!(p2, dataset[:x_test], dataset[:y_test],
yerror = Λ,
label = "Test data",
color = :crimson,
markersize = 4)
return p2
end
# Combine all plots
plot(p1, ps...,
size = (1000, 800),
margin = 5Plots.mm,
layout = (2,2))
end
# Helper function to plot basis functions
function plot_ϕ!(p, ϕs; rl = -10, rr = 10, kwargs...)
xs = range(rl, rr, length = 200)
ys = [ϕ(x) for x in xs, ϕ in ϕs]
plot!(p, xs, ys; kwargs...)
end
# Helper function to plot function with given weights
function plot_ϕ!(p, ϕs, ωs; rl = -10, rr = 10, kwargs...)
xs = range(rl, rr, length = 200)
ys = [ϕ(x) for x in xs, ϕ in ϕs]
yr = ys * ωs
labels = ["Sample $i" for _ in 1:1, i in 1:size(ωs,2)]
plot!(p, xs, yr, labels = labels; kwargs...)
end
plot_ϕ! (generic function with 2 methods)
Phew! We've finally escaped the plotting purgatory - you know it's bad when the visualization code is longer than the actual inference code! But fear not, dear reader, for we're about to dive into the juicy stuff. Grab your statistical popcorn, because the real fun is about to begin!
Polynomials: The Building Blocks of Function Approximation
Let's start our exploration with one of the most fundamental and elegant choices for basis functions: polynomials. These simple yet powerful functions form the backbone of many approximation techniques in mathematics and machine learning.
For our polynomial basis functions $\phi_i$, we'll use the classic form:
\[\begin{align} \phi_i(x) = x^i \end{align}\]
where $i$ represents the degree of each polynomial term. This gives us a sequence of increasingly complex functions: constant ($x^0 = 1$), linear ($x^1$), quadratic ($x^2$), cubic ($x^3$), and so on. When combined with appropriate weights $\omega$, these basis functions can approximate a wide variety of smooth functions - a result famously known as the Weierstrass approximation theorem.
Let's witness the magic of RxInfer as it efficiently infers the posterior distribution over the weights $\omega$ using these polynomial basis functions. The beauty of this approach lies in how it automatically determines the contribution of each polynomial term to fit our data.
plot_inference_results_for(
title = "polynomials",
datasets = datasets,
ϕs = [ (x) -> x ^ i for i in 0:5 ],
)
Let's break down what we're seeing in these fascinating plots! The first plot (in gray) reveals our polynomial basis functions in their raw form - from constant to quintic terms. Overlaid on these are some example functions generated by combining these basis functions with random weights ω, giving us a glimpse of the expressive power of polynomial approximation.
The subsequent plots demonstrate how our model performs inference on different datasets. Notice how the posterior distribution (represented by the gray posterior samples) adapts to capture the uncertainty in different regions of the input space. It's particularly interesting to observe how the model's predictions change when faced with different training and test sets - a beautiful illustration of how the learning process is influenced by the data we feed it.
While polynomials have served us well here, they're just one tool in our mathematical toolbox. Ready to explore some alternative basis functions that might capture different aspects of our target function? Let's dive into some exciting alternatives!
Trigonometric Functions: Catch Some Waves
While polynomials are great (and we love them dearly), sometimes life isn't just about going up and down in straight-ish lines. Sometimes, we need to embrace our inner surfer and catch some waves! Enter trigonometric functions - the mathematical world's answer to the question "What if everything just went round and round?"
Trigonometric functions, particularly sin and cos, have been the backbone of mathematical analysis since ancient times. From describing planetary motions to analyzing sound waves, these periodic functions have a special place in the mathematician's heart. Their ability to represent cyclic patterns makes them particularly powerful for approximating periodic phenomena - something our polynomial friends from earlier might struggle with (imagine a polynomial trying to do the wave at a sports event - awkward!).
For our basis functions, we'll use scaled versions of sine and cosine:
\[\begin{align} \phi_i(x) = \sin\left(\frac{x}{i}\right) \quad \mathrm{and} \quad \phi_i(x) = \cos\left(\frac{x}{i}\right) \end{align}\]
where $i$ acts as a frequency scaling factor. As $i$ increases, our waves become more stretched out, giving us different frequencies to work with. Think of it as having an orchestra where each instrument plays the same tune but at different tempos!
Let's start by riding the sine wave alone (no cosine jealousy please!) for $i = 1,\ldots,8$. Will these wavy functions give our polynomial predecessors a run for their money? Let's find out!
plot_inference_results_for(
title = "trigonometric sin",
datasets = datasets,
ϕs = [ (x) -> sin(x / i) for i in 1:8 ],
)
Now let's examine the results using cosine basis functions
plot_inference_results_for(
title = "trigonometric cos",
datasets = datasets,
ϕs = [ (x) -> cos(x / i) for i in 1:8 ],
)
And for our grand finale, let's combine both sin and cos - because two waves are better than one! (Just don't tell that to particle-wave duality...)
plot_inference_results_for(
title = "trigonometric sin & cos",
datasets = datasets,
ϕs = [
[ (x) -> sin(x / i) for i in 1:4 ]...,
[ (x) -> cos(x / i) for i in 1:4 ]...,
],
)
Incredible! RxInfer proved to be quite the adaptable fellow - it handled these different basis functions without missing a beat. The results speak for themselves: our sine and cosine tag team performed remarkably well for this example. I guess you could say they really found their wavelength!
Comparing Model Performance via Log-Evidence
Now that we've explored different basis functions, let's quantitatively evaluate their performance using the Free Energy, which equals the negative Evidence Lower BOund (ELBO) and serves as an upper bound on the negative log-evidence. RxInfer can compute Free Energy values when requested, and they provide a principled way to compare different models.
Free Energy has several important properties:
- It acts as a proxy for the negative log model evidence $-\log p(y \mid \mathrm{model})$
- Lower values indicate better model fit, balancing complexity and data fit
- It automatically implements Occam's Razor by penalizing overly complex models
For example, if we have:
- Model A: Free Energy = 100
- Model B: Free Energy = 50
Then Model B provides a better explanation of the data, as exp(-50) > exp(-100).
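Spelled out with these hypothetical numbers, and treating Free Energy as a stand-in for the negative log-evidence, the implied evidence ratio in favour of Model B is
\[\begin{align} \frac{p(y \mid \mathrm{B})}{p(y \mid \mathrm{A})} \approx \frac{e^{-F_\mathrm{B}}}{e^{-F_\mathrm{A}}} = e^{F_\mathrm{A} - F_\mathrm{B}} = e^{100 - 50} = e^{50} \end{align}\]
so even modest absolute differences in Free Energy correspond to overwhelming differences in model evidence.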
Let's analyze the Free Energy values for our polynomial and trigonometric basis functions to determine which model class provides the best explanation of our data. We'll check:
- Pure sine basis functions
- Pure cosine basis functions
- Combined sine and cosine basis functions
This will help us quantitatively validate our earlier visual assessments.
# Combine the function definition with the usage
function infer_ω_but_return_free_energy(; ϕs, x, y)
result = infer(
model = parametric_regression(
ϕs = ϕs,
μ = zeros(length(ϕs)),
Σ = I,
Λ = I,
x = x
),
data = (y = y,),
free_energy = true
)
return first(result.free_energy)
end
dfs = map(enumerate(datasets)) do (i, dataset)
# Generate basis functions
sin_bases = [(x) -> sin(x / i) for i in 1:8]
cos_bases = [(x) -> cos(x / i) for i in 1:8]
combined_bases = [
[(x) -> sin(x / i) for i in 1:4]...,
[(x) -> cos(x / i) for i in 1:4]...
]
# Calculate free energy for each basis
energies = [
infer_ω_but_return_free_energy(ϕs=sin_bases, x=dataset[:x_train], y=dataset[:y_train]),
infer_ω_but_return_free_energy(ϕs=cos_bases, x=dataset[:x_train], y=dataset[:y_train]),
infer_ω_but_return_free_energy(ϕs=combined_bases, x=dataset[:x_train], y=dataset[:y_train])
]
# Create DataFrame row
DataFrame(
dataset = fill(i, 3),
fns = [:sin, :cos, :sin_cos],
free_energy = energies
)
end
vcat(dfs...)
9×3 DataFrame
 Row │ dataset  fns      free_energy
     │ Int64    Symbol   Float64
─────┼───────────────────────────────
   1 │       1  sin          74.4889
   2 │       1  cos          47.0473
   3 │       1  sin_cos      95.6676
   4 │       2  sin         366.386
   5 │       2  cos         349.413
   6 │       2  sin_cos       5.96904e10
   7 │       3  sin         104.042
   8 │       3  cos         163.047
   9 │       3  sin_cos      84.0258
The results demonstrate that the choice of basis functions plays a significant role, as evidenced by the varying values of the Free Energy function. For dataset 1, cosine-based basis functions perform better than both sine-based and combined sine-cosine basis functions. Meanwhile, for dataset 3, the combination of sine and cosine basis functions yields superior results.
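If you prefer to extract the winner for each dataset programmatically rather than reading it off the table, here is a small DataFrames sketch. It assumes we store the concatenated table in a variable (the original cell above does not), so fe_table and best_per_dataset are hypothetical names:
# Store the combined table and pick the basis with the lowest Free Energy per dataset
fe_table = vcat(dfs...)
best_per_dataset = combine(groupby(fe_table, :dataset),
[:fns, :free_energy] => ((fns, fe) -> fns[argmin(fe)]) => :best_basis)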
However, why limit ourselves to polynomials and trigonometric functions? Let's explore other possibilities!
Switch Functions: A Binary Approach to Basis Functions
Let's explore an intriguing and perhaps unconventional choice for basis functions: the switch functions. These functions, despite their simplicity, can be remarkably effective in certain scenarios.
A switch function essentially divides the input space into two regions, outputting either +1 or -1 based on which side of a threshold the input falls. Mathematically, we define it as:
\[\begin{align} \phi_i(x) = \mathrm{sign}(x - i) \end{align}\]
where $i$ serves as the threshold point. The function returns +1 when $x > i$ and -1 when $x < i$. This creates a sharp "switch" at $x = i$, hence the name.
What makes these functions particularly interesting is their ability to capture discontinuities and sharp transitions in the data. By combining multiple switch functions with different threshold points, we can approximate complex patterns through a series of binary decisions.
Let's see how these switch functions perform on our datasets!
plot_inference_results_for(
title = "switches",
datasets = datasets,
ϕs = [ (x) -> sign(x - i) for i in -8:8 ],
)
Step Functions: A Binary Leap Forward
Let's explore another fascinating class of basis functions: step functions. Also known as Heaviside functions, these elegant mathematical constructs make a dramatic jump from 0 to 1 at a specific threshold point.
Mathematically, we define our step basis functions as:
\[\begin{align} \phi_i(x) = \mathbb{I}(x - i > 0) \end{align}\]
where $\mathbb{I}$ is the indicator function that equals 1 when its argument is true and 0 otherwise. Unlike the switch functions we saw earlier, step functions provide a unidirectional transition, making them particularly useful for modeling data with distinct regimes or threshold effects.
plot_inference_results_for(
title = "steps",
datasets = datasets,
ϕs = [ (x) -> ifelse(x - i > 0, 1.0, 0.0) for i in -8:8 ],
)
Linear Basis Functions: A Classic Twist
Here's an intriguing proposition: what if we used linear functions as our basis functions in linear regression? While it might sound redundant at first, this approach offers a fascinating perspective. By centering linear functions at different points, we create a rich set of features that can capture both local and global trends in our data.
The basis functions take the form:
\[\begin{align} \phi_i(x) = \vert x - i \vert \end{align}\]
where each function measures the absolute distance from a reference point $i$. This creates a V-shaped function centered at each point, allowing us to model both increasing and decreasing trends with remarkable flexibility.
plot_inference_results_for(
title = "linears",
datasets = datasets,
ϕs = [ (x) -> abs(x - i) for i in -8:8 ],
)
Absolute Exponential Functions: Elegantly Decaying Distance
Let's venture into the realm of absolute exponential functions, a fascinating class of basis functions that elegantly capture the notion of distance-based influence. These functions, also known as Laplace kernels in some contexts, decay exponentially with the absolute distance from their center points.
The mathematical formulation reveals their elegant simplicity:
\[\begin{align} \phi_i(x) = e^{-\vert x - i \vert} \end{align}\]
This expression creates a peaked function that reaches its maximum of 1 at x = i and smoothly decays in both directions, providing a natural way to model localized influences that diminish with distance.
plot_inference_results_for(
title = "abs exps",
datasets = datasets,
ϕs = [ (x) -> exp(-abs(x - i)) for i in -8:8 ],
)
Squared Exponential Functions: The Gaussian Bell Curves
Let's explore another one of the most elegant and widely-used basis functions in machine learning - the squared exponential, also known as the Gaussian or radial basis function. These functions create perfect bell curves that smoothly decay in all directions from their centers.
The mathematical form reveals their graceful symmetry:
\[\begin{align} \phi_i(x) = e^{-(x - i)^2} \end{align}\]
These functions have remarkable properties - they're infinitely differentiable and create ultra-smooth interpolations between points. Their rapid decay also provides natural localization, making them excellent choices for capturing both local and global patterns in data.
plot_inference_results_for(
title = "sqrt exps",
datasets = datasets,
ฯs = [ (x) -> exp(-(x - i) ^ 2) for i in -8:8 ],
)
Sigmoid Functions: The Neural Network's Activation
The sigmoid function, a cornerstone of neural network architectures, offers another fascinating basis for our exploration. This S-shaped curve elegantly transitions between two asymptotic values, creating a smooth, differentiable "step" that's invaluable in modeling transitions and decision boundaries.
The mathematical elegance of the sigmoid reveals itself in its formula:
\[\begin{align} \phi_i(x) = \frac{1}{1 + e^{-3(x - i)}} \end{align}\]
This function's graceful transition from 0 to 1 makes it particularly well-suited for capturing threshold phenomena and modeling probability-like quantities. Its bounded nature also provides natural regularization, preventing the explosive growth that can plague polynomial bases.
plot_inference_results_for(
title = "sigmoids",
datasets = datasets,
ϕs = [ (x) -> 1 / (1 + exp(-3 * (x - i))) for i in -8:8 ],
)
The Power of Combination: Using Different Classes of Basis Functions Together
What if we could harness the unique strengths of different basis functions we've explored? By combining polynomials, trigonometric functions, squared exponentials, and sigmoids, we can create an incredibly flexible and expressive basis that captures both global trends and local patterns. The polynomials can handle overall growth patterns, trigonometric functions can capture periodic behavior, squared exponentials can provide smooth local interpolation, and sigmoids can model sharp transitions. This combined approach leverages the best of each basis function family, potentially leading to more robust and accurate predictions. And here's how easy it is to do so!
# Combine all basis functions we've explored into one powerful basis
combined_basis = vcat(
# Polynomials (up to degree 5)
[ (x) -> x ^ i for i in 0:5 ],
# Trigonometric functions (sine and cosine at several frequencies)
[ (x) -> sin(i*x) for i in 1:3 ],
[ (x) -> cos(i*x) for i in 1:3 ],
# Squared exponentials (Gaussian bumps centered at -8..8)
[ (x) -> exp(-(x - i)^2) for i in -8:8 ],
# Sigmoids (smooth steps centered at -8..8)
[ (x) -> 1 / (1 + exp(-3 * (x - i))) for i in -8:8 ]
)
plot_inference_results_for(
title = "combined",
datasets = datasets,
ϕs = combined_basis,
)
Now that we've combined these different basis functions, it's interesting to explore how this powerful ensemble performs on our complete dataset. By visualizing the posterior distribution over functions induced by this combined basis, we can see how it leverages the unique characteristics of each basis type - the global trends captured by polynomials, the periodic patterns from trigonometric functions, the local smoothness from squared exponentials, and the sharp transitions enabled by sigmoids. Let's plot the results to see this rich expressiveness in action.
combined_basis_ωs_all_data = infer_ω(ϕs = combined_basis, x = X, y = Y)
# Left plot - local region
p1 = plot(
title = "Local region",
xlabel = "x",
ylabel = "y",
xlim = (-10, 10),
ylim = (-20, 20),
grid = true
)
# Plot posterior mean
plot_ϕ!(p1, combined_basis, mean(combined_basis_ωs_all_data),
rl = -10,
rr = 10,
linewidth = 3,
color = :green,
labels = "Posterior mean"
)
# Plot posterior samples (in gray)
plot_ϕ!(p1, combined_basis, rand(combined_basis_ωs_all_data, 50),
rl = -10,
rr = 10,
linewidth = 1,
color = :gray,
alpha = 0.4,
labels = nothing
)
# Plot data points
scatter!(p1, X, Y,
yerror = Λ,
label = "Data",
color = :royalblue,
markersize = 4
)
# Right plot - bigger region
p2 = plot(
title = "Extended region",
xlabel = "x",
ylabel = "y",
xlim = (-30, 30),
ylim = (-75, 75),
grid = true
)
# Plot posterior mean
plot_ϕ!(p2, combined_basis, mean(combined_basis_ωs_all_data),
rl = -30,
rr = 30,
linewidth = 3,
color = :green,
labels = "Posterior mean"
)
# Plot posterior samples (in gray)
plot_ϕ!(p2, combined_basis, rand(combined_basis_ωs_all_data, 50),
rl = -30,
rr = 30,
linewidth = 1,
color = :gray,
alpha = 0.4,
labels = nothing
)
# Plot data points
scatter!(p2, X, Y,
label = "Data",
color = :royalblue,
markersize = 2
)
p = plot(p1, p2, layout=(1,2), size=(1000,400), fontfamily = "Computer Modern")
The plot above beautifully demonstrates the expressive power of combining multiple basis functions. The posterior mean (shown in green) captures both the global trend and local variations in the data with remarkable accuracy. The gray lines, representing samples from the posterior distribution, illustrate the model's uncertainty - tighter in regions with more data points and wider in sparse regions. This combined basis approach leverages the strengths of each basis type: polynomials handle the overall trend, trigonometric functions capture periodic components, and localized basis functions manage fine details. The result is a flexible and robust model that adapts well to the complex patterns in our dataset.
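Because the posterior over $\omega$ is Gaussian and $f(x) = \Phi^T_x \omega$ is linear in $\omega$, the predictive uncertainty can also be computed in closed form instead of via samples. Here is a minimal sketch under the assumption that combined_basis and combined_basis_ωs_all_data from the cells above are in scope and that the posterior exposes mean and cov (as RxInfer's Gaussian posteriors do):
# Closed-form predictive band: mean Φ*μ_ω and std from diag(Φ*Σ_ω*Φ')
xs_pred = range(-10, 10, length = 200)
Φ_pred = [ϕ(x) for x in xs_pred, ϕ in combined_basis]
μ_ω = mean(combined_basis_ωs_all_data)
Σ_ω = cov(combined_basis_ωs_all_data)
f_mean = Φ_pred * μ_ω
f_std = sqrt.(max.(diag(Φ_pred * Σ_ω * Φ_pred'), 0.0))
plot(xs_pred, f_mean, ribbon = f_std, label = "Posterior mean ± 1 std", color = :green, xlabel = "x", ylabel = "y")
scatter!(X, Y, label = "Data", color = :royalblue, markersize = 2)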
Performance: The Need for Speed! 🏎️
Alright, we've been having a blast playing with different basis functions, and RxInfer has been crunching those posterior calculations faster than you can say "Bayesian inference". But just how zippy is it really? Let's put our mathematical hot rod through its paces with our trusty polynomial basis functions and see what kind of speed records we can break!
using BenchmarkTools
In Julia, benchmarking is made easy with the BenchmarkTools package. The @benchmark macro runs the given expression multiple times to get statistically meaningful results. It provides detailed statistics about execution time, memory allocations, and garbage collection overhead. The output shows minimum, maximum, median and mean execution times, along with a nice histogram visualization of the timing distribution.
@benchmark infer_ω(ϕs = $([ (x) -> x ^ i for i in 0:5 ]), x = $(datasets[1][:x_train]), y = $(datasets[1][:y_train]))
BenchmarkTools.Trial: 10000 samples with 1 evaluation per sample.
 Range (min … max):  167.224 μs … 78.003 ms   ┊ GC (min … max): 0.00% … 99.35%
 Time  (median):     193.007 μs               ┊ GC (median):    0.00%
 Time  (mean ± σ):   204.711 μs ± 778.548 μs  ┊ GC (mean ± σ):  3.79% ± 0.99%
  167 μs          Histogram: frequency by time          256 μs <
 Memory estimate: 86.23 KiB, allocs estimate: 1499.
Let's benchmark inference on a larger dataset with 10,000 datapoints to test scalability.
N_benchmark = 10_000
X_benchmark = range(-8, 8, length=N_benchmark)
Y_benchmark = rand(rng, MvNormalMeanCovariance(f.(X_benchmark), Λ));
@benchmark infer_ω(ϕs = $([ (x) -> x ^ i for i in 0:5 ]), x = $(X_benchmark), y = $(Y_benchmark))
BenchmarkTools.Trial: 9 samples with 1 evaluation per sample.
 Range (min … max):  573.370 ms … 596.298 ms  ┊ GC (min … max): 0.00% … 0.00%
 Time  (median):     580.360 ms               ┊ GC (median):    0.00%
 Time  (mean ± σ):   583.695 ms ± 8.592 ms    ┊ GC (mean ± σ):  0.00% ± 0.00%
  573 ms          Histogram: frequency by time          596 ms <
 Memory estimate: 1.15 MiB, allocs estimate: 1503.
And that's a wrap! From exploring different basis functions (polynomials, trigonometric functions, and even those fancy sigmoids) to performing lightning-fast Bayesian inference, we've seen how RxInfer handles parametric Gaussian regression with style. The benchmarks don't lie - processing 10,000 datapoints in about half a second while keeping memory usage lean? That's not just fast, that's "blink and you'll miss it" fast!
Throughout this notebook, we've gone from basic data generation to sophisticated model inference, all while keeping things both mathematically rigorous and computationally efficient. Whether you're a Bayesian enthusiast or just someone who appreciates elegant mathematical machinery, this journey through parametric Gaussian regression shows that probabilistic programming doesn't have to be slow or memory-hungry.
Thanks again to the authors of "Probabilistic Numerics: Computation as Machine Learning" for providing the theoretical foundations and inspiration for this notebook!
This example was executed in a clean, isolated environment. Below are the exact package versions used:
For reproducibility:
- Use the same package versions when running locally
- Report any issues with package compatibility
Status `~/work/RxInferExamples.jl/RxInferExamples.jl/docs/src/categories/basic_examples/feature_functions_in_bayesian_regression/Project.toml`
[6e4b80f9] BenchmarkTools v1.6.0
[a93c6f00] DataFrames v1.7.0
[91a5bcdd] Plots v1.40.9
[86711068] RxInfer v4.0.1
[860ef19b] StableRNGs v1.0.2
[37e2e46d] LinearAlgebra v1.11.0
[9a3f8284] Random v1.11.0