This example was automatically generated from a Jupyter notebook in the RxInferExamples.jl repository.
We welcome and encourage contributions! You can help by:
- Improving this example
- Creating new examples
- Reporting issues or bugs
- Suggesting enhancements
Visit our GitHub repository to get started. Together we can make RxInfer.jl even better! 💪
ODE Parameter Estimation
In this notebook we explore how to solve an ODE and learn its parameters simultaneously using RxInfer. To illustrate the approach, we take the Lotka-Volterra equations as an example and compare three alternatives for parameter estimation. The first alternative demonstrates how to minimize the free energy to obtain point estimates. The second alternative places a prior distribution on the unknown ODE parameters to obtain a posterior estimate, and does so in two stages: the first stage obtains initial values for the prior hyper-parameters, and the second stage uses these initial values to compute the posterior by message passing. The third alternative relies purely on message passing.
using RxInfer, Optim, LinearAlgebra, Plots, SeeToDee, StaticArrays, StableRNGs
Introduction to Lotka-Volterra Equations
The Lotka-Volterra equations are a pair of first-order nonlinear differential equations frequently used to describe the dynamics of biological systems in which two species interact: one as a predator and the other as prey. The equations are defined as follows:
Prey Population Dynamics: $\frac{dx}{dt} = \alpha x - \beta xy$
Predator Population Dynamics: $\frac{dy}{dt} = -\gamma y + \delta xy$
In this ODE, $x$ is the population of the prey (e.g., rabbits), $y$ is the population of the predator (e.g., foxes), $\alpha$ represents the maximum growth rate of the prey, $\beta$ is the rate of predation, $\gamma$ is the predator's per capita death rate and $\delta$ is the growth rate of the predator population based on the availability of prey.
function lotka_volterra(u, z, p, t)
# p = [α, β, γ, δ]; z is the (unused) control input expected by SeeToDee
α, β, γ, δ = p[SA[1,2,3,4]]
x, y = u[SA[1, 2]]
du1 = α * x - β * x * y
du2 = -γ * y + δ * x * y
return [du1, du2]
end;
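As a quick sanity check on the equations above, the system has a non-trivial equilibrium at $x^* = \gamma/\delta$, $y^* = \alpha/\beta$, where both derivatives vanish. The sketch below re-implements the right-hand side in plain Julia and confirms this numerically. The helper name lv_rhs and the parameter values are illustrative assumptions, not part of the original notebook.

```julia
# Plain-Julia right-hand side of the Lotka-Volterra equations (illustrative helper).
# p = (α, β, γ, δ), matching the symbols in the equations above.
function lv_rhs(u, p)
    α, β, γ, δ = p
    x, y = u
    return [α * x - β * x * y, -γ * y + δ * x * y]
end

p_demo = (1.0, 1.5, 3.0, 1.0)                           # illustrative parameter values
u_star = [p_demo[3] / p_demo[4], p_demo[1] / p_demo[2]] # equilibrium (γ/δ, α/β)
du_star = lv_rhs(u_star, p_demo)                        # both components should be numerically zero
println("equilibrium: ", u_star, "  derivative there: ", du_star)
```

Trajectories started away from this point cycle around it, which is the oscillating behaviour typically associated with predator-prey data.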
The Runge-Kutta 4th Order (RK4) Method
The Runge-Kutta 4th order method is one of the most widely used numerical techniques for solving ordinary differential equations (ODEs). For a system of the form:
\[\frac{dx}{dt} = f(x, t)\]
where $x$ can be a scalar or vector-valued function, RK4 provides a numerical approximation with local truncation error of order $O(dt^5)$ and global error of order $O(dt^4)$ for step size $dt$.
Algorithm
Given the current state $x_n$ at time $t_n$, RK4 computes the state at $t_{n+1} = t_n + dt$ using four intermediate evaluations:
\[\begin{aligned} k_1 &= f(x_n, t_n) \\ k_2 &= f(x_n + \frac{dt}{2}k_1, t_n + \frac{dt}{2}) \\ k_3 &= f(x_n + \frac{dt}{2}k_2, t_n + \frac{dt}{2}) \\ k_4 &= f(x_n + dt\,k_3, t_n + dt) \end{aligned}\]
The solution is then advanced using a weighted average of these evaluations:
\[x_{n+1} = x_n + \frac{dt}{6}(k_1 + 2k_2 + 2k_3 + k_4)\]
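The update above can be sketched in a few lines of plain Julia. This is a hand-written illustration of the formulas, not SeeToDee's implementation; the names rk4_step, rk4_integrate and the test problem decay are assumptions made for this example.

```julia
# One RK4 step for dx/dt = f(x, t); a hand-written sketch of the formulas above.
function rk4_step(f, x, t, dt)
    k1 = f(x, t)
    k2 = f(x + dt / 2 * k1, t + dt / 2)
    k3 = f(x + dt / 2 * k2, t + dt / 2)
    k4 = f(x + dt * k3, t + dt)
    return x + dt / 6 * (k1 + 2k2 + 2k3 + k4)
end

# Integrate from t = 0 to t = T with n equal steps.
function rk4_integrate(f, x0, T, n)
    x, t, dt = x0, 0.0, T / n
    for _ in 1:n
        x = rk4_step(f, x, t, dt)
        t += dt
    end
    return x
end

decay(x, t) = -x                          # test problem with exact solution exp(-t)
err_coarse = abs(rk4_integrate(decay, 1.0, 1.0, 10) - exp(-1.0))
err_fine   = abs(rk4_integrate(decay, 1.0, 1.0, 20) - exp(-1.0))
println("error ratio when halving dt: ", err_coarse / err_fine)
```

Halving the step size should shrink the error by roughly a factor of $2^4 = 16$, which is what a fourth-order global error implies.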
For this implementation, we will use the SeeToDee package to define the RK4 method. This is necessary to create a non-linear deterministic node for the RxInfer model. The SeeToDee package requires the dynamics function to be defined as f(x, u, θ, t), where u is the control input. Since we don't have any control input, we will set u = 0.
NOTE: There are many improved versions of solvers that can be more appropriate. For this simple problem though RK4 will be enough to convey the message. However, in practice for real world problems adaptive or implicit variants of solvers are preferred.
dt = 0.1 # sample_interval
function lotka_volterra_rk4(x, θ, t, dt)
lotka_volterra_dynamics = SeeToDee.Rk4(lotka_volterra, dt)
return lotka_volterra_dynamics(x, 0, θ, t)
end
lotka_volterra_rk4 (generic function with 1 method)
Data Generation
Lotka-Volterra data is generated using the RK4 method. The data is then corrupted with noise to simulate real-world observations.
DISCLAIMER: Since the Lotka-Volterra equations model prey and predator populations, adding Gaussian noise is not realistic (it can, for instance, produce negative observed populations). Other noise models are possible, but they would complicate the inference process. Therefore, we use Gaussian noise for instructive purposes.
function generate_data(θ; x = ones(2), t =0.0, dt = 0.001, n = 1000, v = 1, seed = 123)
rng = StableRNG(seed)
data = Vector{Vector{Float64}}(undef, n)
ts = Vector{Float64}(undef, n)
for i in 1:n
data[i] = lotka_volterra_rk4(x, θ, t, dt)
x = data[i]
t += dt
ts[i] = t
end
noisy_data = map(data) do d
noise = sqrt(v) * [randn(rng), randn(rng)]
d + noise
end
return data, noisy_data, ts
end
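To make the role of v in generate_data concrete: it is a variance, so the noise is scaled by sqrt(v). The sketch below checks this empirically and also estimates how often an observation turns negative for a small population. The population level 0.5 is a hypothetical value chosen for illustration, and MersenneTwister is used here only to keep the snippet free of extra dependencies.

```julia
using Random, Statistics

rng_demo = MersenneTwister(1)             # any seeded RNG works for this illustration
noise_var = 0.35                          # same value this example assigns to noisev
true_value = 0.5                          # hypothetical small population level
samples = true_value .+ sqrt(noise_var) .* randn(rng_demo, 100_000)

println("sample variance: ", var(samples))
println("fraction of negative observations: ", mean(samples .< 0))
```

A noticeable share of negative observations at low population levels is the reason the disclaimer above calls Gaussian noise unrealistic for this kind of data.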
noisev = 0.35
n = 10000
true_params = [1.0, 1.5, 3.0, 1.0]
data_long, noisy_data_long, ts_long = generate_data(true_params,dt = dt, n = n, v = noisev);
## We create a smaller dataset for the global parameter optimization. Utilizing the entire dataset for the global optimization will take too much time.
n_train = 100
data = data_long[1:n_train]
noisy_data = noisy_data_long[1:n_train]
ts = ts_long[1:n_train];
Data Visualization
p = plot(layout=(2,1))
plot!(subplot=1, ts, [d[1] for d in data], label="True x₁", color=:blue)
plot!(subplot=1, ts, [d[1] for d in noisy_data], seriestype=:scatter, label="Noisy x₁", color=:blue, alpha=0.3, markersize=1.3)
plot!(subplot=2, ts, [d[2] for d in data], label="True x₂", color=:red)
plot!(subplot=2, ts, [d[2] for d in noisy_data], seriestype=:scatter, label="Noisy x₂", color=:red, alpha=0.3, markersize=1.3)
xlabel!("Time")
ylabel!(subplot=1, "Prey Population")
ylabel!(subplot=2, "Predator Population")
First Alternative: Global Parameter Optimization
In the first alternative we will construct one time-segment of the Lotka-Volterra equations. We will use the lotka_volterra_rk4 function to create a non-linear node. This function was defined earlier to numerically solve the Lotka-Volterra equations using the 4th order Runge-Kutta method.
@model function lotka_volterra_model_without_prior(obs, mprev, Vprev, dt, t, θ)
xprev ~ MvNormalMeanCovariance(mprev, Vprev)
x := lotka_volterra_rk4(xprev, θ, t, dt)
obs ~ MvNormalMeanCovariance(x, noisev * diageye(length(mprev)))
end
Non-linear deterministic nodes require a meta specification that determines the type of message approximation to be used. In this case, we can use the Linearization method, which triggers an Extended Kalman Filter (EKF) type of approximation, or the Unscented method, which triggers an Unscented Kalman Filter (UKF) type of approximation. Moreover, because we are using RxInfer in an online setting, we need to specify how the mean and covariance of the Gaussian distribution will be updated. We do this with the @autoupdates macro, and we set the initial values with the @initialization macro.
delta_meta = @meta begin
lotka_volterra_rk4() -> Linearization()
end
autoupdates_without_prior = @autoupdates begin
mprev, Vprev= mean_cov(q(x))
end
@initialization function initialize_without_prior(mx, Vx)
q(x) = MvNormalMeanCovariance(mx, Vx)
end;
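To give some intuition for what an EKF-type approximation does, the sketch below propagates a Gaussian through a non-linear function by pushing the mean through the function and the covariance through its Jacobian. It is a minimal illustration of the idea only: the helper names, the toy map g and the finite-difference Jacobian are assumptions made for this example, and nothing here describes how RxInfer implements Linearization internally.

```julia
using LinearAlgebra

# Forward-difference Jacobian of f at m (illustrative helper; step size is a rough choice).
function fd_jacobian(f, m; h = 1e-6)
    f0 = f(m)
    J = zeros(length(f0), length(m))
    for j in eachindex(m)
        shift = zeros(length(m))
        shift[j] = h
        J[:, j] = (f(m + shift) - f0) / h
    end
    return J
end

# EKF-style propagation: mean goes through f, covariance through the Jacobian.
function linearized_moments(f, m, V)
    J = fd_jacobian(f, m)
    return f(m), J * V * J'
end

g(x) = [x[1] * x[2], x[1] + x[2]^2]       # toy non-linear map
m_out, V_out = linearized_moments(g, [1.0, 2.0], 0.01 * I(2))
println("approx. mean: ", m_out)
println("approx. covariance: ", V_out)
```

For a linear map the result is exact; for a non-linear one the quality depends on how well the function is approximated by its tangent over the spread of the input distribution.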
Free Energy Computation
We will now define the free energy function that will be minimized to infer the parameters of the model. The optimizer searches over unconstrained real values, whereas the ODE parameters should be positive, so we apply the exp function to map the optimization variables to the positive domain. We set free_energy = true so that the free energy values are tracked.
function compute_free_energy_without_prior(θ ; mx = ones(2), Vx = 1e-6 * diageye(2))
θ = exp.(θ)
result = infer(
model = lotka_volterra_model_without_prior(dt = dt, θ = θ),
data = (obs = noisy_data, t= ts),
initialization = initialize_without_prior(mx, Vx),
meta = delta_meta,
autoupdates = autoupdates_without_prior,
keephistory = length(noisy_data),
free_energy = true
)
return sum(result.free_energy_final_only_history)
end;
Now we are ready to perform the parameter inference by minimizing the free energy function. We will use the optimize function from the Optim package with the NelderMead method as the optimizer, as it doesn't require gradient information and is faster in this setting.
res_without_prior = optimize(compute_free_energy_without_prior, zeros(4), NelderMead(), Optim.Options(show_trace = true, show_every = 300));
Iter Function value √(Σ(yᵢ-ȳ)²)/n
------ -------------- --------------
0 1.274436e+03 9.650180e+00
* time: 5.602836608886719e-5
θ_minimizer_without_prior = exp.(res_without_prior.minimizer)
println("\nEstimated point mass valued parameters:")
for (i, (name, val)) in enumerate(zip(["α", "β", "γ", "δ"], θ_minimizer_without_prior))
println(" * $name: $(round(val, digits=3))")
end
println("\nActual parameters used to generate data:")
for (i, (name, val)) in enumerate(zip(["α", "β", "γ", "δ"], true_params))
println(" * $name: $(round(val, digits=3))")
end
Estimated point mass valued parameters:
* α: 0.994
* β: 1.491
* γ: 3.054
* δ: 0.997
Actual parameters used to generate data:
* α: 1.0
* β: 1.5
* γ: 3.0
* δ: 1.0
Second Alternative: RxInfer Model with Prior on the Parameters
We will now define the corresponding RxInfer model with a prior distribution on the parameters. For this, we will use the @model macro to create a time segment for the ODE, using the deterministic ODE solver lotka_volterra_rk4 as a non-linear node in the RxInfer model. For the prior distribution of the parameters, we will use a multivariate Gaussian distribution with mean mθ and covariance Vθ, whose starting values are set through the @initialization macro.
@model function lotka_volterra_model(obs, mprev, Vprev, dt, t, mθ, Vθ)
θ ~ MvNormalMeanCovariance(mθ, Vθ)
xprev ~ MvNormalMeanCovariance(mprev, Vprev)
x := lotka_volterra_rk4(xprev, θ, t, dt)
obs ~ MvNormalMeanCovariance(x, noisev * diageye(length(mprev)))
end
autoupdates = @autoupdates begin
mprev, Vprev= mean_cov(q(x))
mθ, Vθ = mean_cov(q(θ))
end
@initialization function initialize(mx, Vx, mθ, Vθ)
q(x) = MvNormalMeanCovariance(mx, Vx)
q(θ) = MvNormalMeanCovariance(mθ, Vθ)
end;
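The @autoupdates block above feeds each step's posterior back in as the prior for the next step. The sketch below shows the same pattern on a scalar Gaussian toy problem, where the update is exact, and checks that processing observations one at a time gives the same result as processing them all at once. The function names and observations are made up for illustration; in the notebook's non-linear model the updates are approximate rather than exact.

```julia
# Conjugate update of a Gaussian belief N(m, v) about θ after observing y ~ N(θ, s).
function gaussian_update(m, v, y, s)
    v_post = 1 / (1 / v + 1 / s)
    m_post = v_post * (m / v + y / s)
    return m_post, v_post
end

# Online: each posterior becomes the prior for the next observation.
function online_posterior(ys, m0, v0, s)
    m, v = m0, v0
    for y in ys
        m, v = gaussian_update(m, v, y, s)
    end
    return m, v
end

# Batch: all observations processed in a single update.
function batch_posterior(ys, m0, v0, s)
    v_post = 1 / (1 / v0 + length(ys) / s)
    m_post = v_post * (m0 / v0 + sum(ys) / s)
    return m_post, v_post
end

ys_demo = [1.2, 0.7, 1.1, 0.9]            # made-up observations
println("online: ", online_posterior(ys_demo, 0.0, 10.0, 0.35))
println("batch:  ", batch_posterior(ys_demo, 0.0, 10.0, 0.35))
```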
Prior Initialization by means of Free Energy Minimization
We will now define the free energy function that will be minimized to infer the initial hyper-parameters of the prior distribution. Since we have 4 parameters, the optimization vector has 8 entries: the 4 elements of the prior mean followed by the 4 diagonal elements of the prior covariance matrix. Again, we use the exp function to map these values to the positive domain.
function compute_free_energy(θ ; mx = ones(2), Vx = 1e-6 * diageye(2))
θ = exp.(θ)
mθ = θ[1:4]
Vθ = Diagonal(θ[5:end])
result = infer(
model = lotka_volterra_model(dt = dt,),
data = (obs = noisy_data, t = ts),
initialization = initialize(mx, Vx, mθ, Vθ),
meta = delta_meta,
autoupdates = autoupdates,
keephistory = length(noisy_data),
free_energy = true
)
return sum(result.free_energy_final_only_history)
end;
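The unpacking done in the first lines of compute_free_energy can be tried in isolation. The sketch below uses a hypothetical helper unpack_prior and applies it to the starting point that this example passes to the optimizer.

```julia
using LinearAlgebra

# Split an unconstrained 8-vector into a positive mean and a positive diagonal covariance,
# mirroring the first lines of compute_free_energy above.
function unpack_prior(φ)
    θ = exp.(φ)
    return θ[1:4], Diagonal(θ[5:end])
end

φ_start = [zeros(4); 0.1 * ones(4)]       # starting point used for the optimizer in this example
mθ_start, Vθ_start = unpack_prior(φ_start)
println("initial prior mean: ", mθ_start)
println("initial prior variances: ", diag(Vθ_start))
```

Because of the exp transform, the starting value 0.1 for the last four entries corresponds to an initial variance of exp(0.1) ≈ 1.105 rather than 0.1, and the zeros correspond to a prior mean of one for every parameter.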
Parameter Inference
We will now perform the parameter inference by minimizing the free energy function. We will use the optimize function from the Optim package to perform the optimization. We will use the NelderMead method as the optimizer as it doesn't require gradient information and is faster.
res = optimize(compute_free_energy, [zeros(4); 0.1ones(4)], NelderMead(), Optim.Options(show_trace = true, show_every = 300));
Iter Function value √(Σ(yᵢ-ȳ)²)/n
------ -------------- --------------
0 2.814831e+02 2.192538e-01
* time: 6.413459777832031e-5
300 2.714613e+02 1.554900e-03
* time: 4.297810077667236
600 2.712411e+02 6.805549e-07
* time: 8.695860147476196
θ_minimizer = exp.(res.minimizer)
mθ_init = θ_minimizer[1:4]
Vθ_init = Diagonal(θ_minimizer[5:end])
println("\nEstimated initialization parameters for the prior distribution:")
for (i, (name, val, var)) in enumerate(zip(["α", "β", "γ", "δ"], mθ_init, θ_minimizer[5:8]))
println(" * $name: $(round(val, digits=3)) ± $(round(sqrt(var), digits=3))")
end
println("\nActual parameters used to generate data:")
for (i, (name, val)) in enumerate(zip(["α", "β", "γ", "δ"], true_params))
println(" * $name: $(round(val, digits=3))")
end
Estimated initialization parameters for the prior distribution:
* α: 2.279 ± 1.275
* β: 1.91 ± 2.119
* γ: 1.636 ± 2.228
* δ: 0.962 ± 0.693
Actual parameters used to generate data:
* α: 1.0
* β: 1.5
* γ: 3.0
* δ: 1.0
Having estimated the initial hyper-parameters of the prior distribution, we can now perform the parameter inference by online message passing. We will use the infer function to perform the inference.
result = infer(
model = lotka_volterra_model(dt = dt,),
data = (obs = noisy_data_long, t= ts_long),
initialization = initialize(ones(2), 1e-6 * diageye(2), mθ_init, Vθ_init),
meta = delta_meta,
autoupdates = autoupdates,
keephistory = length(noisy_data_long),
free_energy = true
);
mθ_posterior = mean.(result.history[:θ])
Vθ_posterior = var.(result.history[:θ])
p = plot(layout=(4,1), size=(800,1000), legend=:right)
param_names = ["α", "β", "γ", "δ"]
for i in 1:4
means = [m[i] for m in mθ_posterior]
stds = [2sqrt(v[i]) for v in Vθ_posterior]
plot!(p[i], means, ribbon=stds, label="Posterior", subplot=i)
hline!(p[i], [true_params[i]], label="True value", linestyle=:dash, color=:red, subplot=i)
title!(p[i], param_names[i], subplot=i)
if i == 4
xlabel!(p[i], "Time step", subplot=i)
end
end
# Place legend at top right for all subplots
plot!(p, legend=:topright)
display(p)
final_means = last(mθ_posterior)
final_vars = last(Vθ_posterior)
final_stds = sqrt.(final_vars)
# Print results
println("\nFinal Parameter Estimates:")
for (param, m, s) in zip(param_names, final_means, final_stds)
println("$param: $m ± $s")
end
# Get final covariance matrix
final_cov = cov(last(result.history[:θ]))
println("\nFinal Parameter Covariance Matrix:")
display(final_cov)
Final Parameter Estimates:
α: 0.9873125810391418 ± 0.02268319328262911
β: 1.4928568988039093 ± 0.026900704209026304
γ: 3.032995698910327 ± 0.12553964480472204
δ: 1.01924256539008 ± 0.033832412596848514
Final Parameter Covariance Matrix:
4×4 Matrix{Float64}:
0.000514527 0.00038494 8.38569e-5 3.78402e-5
0.00038494 0.000723648 -8.1867e-5 -4.44924e-5
8.38569e-5 -8.1867e-5 0.0157602 0.00373872
3.78402e-5 -4.44924e-5 0.00373872 0.00114463
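A covariance matrix is easier to read once it is converted to correlations. The sketch below does this for the matrix printed above; the numbers are copied from that output, and the variable names are chosen for this illustration.

```julia
using LinearAlgebra

# Final covariance of (α, β, γ, δ), copied from the printed output above.
Σ_demo = [0.000514527  0.00038494   8.38569e-5   3.78402e-5;
          0.00038494   0.000723648 -8.1867e-5   -4.44924e-5;
          8.38569e-5  -8.1867e-5    0.0157602    0.00373872;
          3.78402e-5  -4.44924e-5   0.00373872   0.00114463]

σ_demo = sqrt.(diag(Σ_demo))              # marginal standard deviations
R_demo = Σ_demo ./ (σ_demo * σ_demo')     # correlation matrix
println("corr(α, β) = ", round(R_demo[1, 2], digits = 2))
println("corr(γ, δ) = ", round(R_demo[3, 4], digits = 2))
```

The pairs (α, β) and (γ, δ) come out clearly positively correlated, the latter at about 0.88, while the cross-correlations between the two pairs are small. This suggests the data constrain each pair jointly better than it constrains the individual parameters.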
from = 1
skip = 1
to = 500
# Get state estimates and variances
mx = mean.(result.history[:x])
Vx = var.(result.history[:x])
# Plot state estimates with uncertainty bands
p1 = plot(ts_long[from:skip:to] , getindex.(mx, 1)[from:skip:to], ribbon=2*sqrt.(getindex.(Vx, 1)[from:skip:to]),
label="Prey estimate", legend=:topright)
scatter!(p1, ts_long[from:skip:to], getindex.(noisy_data_long, 1)[from:skip:to], label="Noisy prey observations", alpha=0.5,ms=1)
plot!(p1, ts_long[from:skip:to], getindex.(data_long, 1)[from:skip:to], label="True prey", linestyle=:dash)
title!(p1, "Prey Population")
p2 = plot(ts_long[from:skip:to], getindex.(mx, 2)[from:skip:to], ribbon=2*sqrt.(getindex.(Vx, 2)[from:skip:to]),
label="Predator estimate", legend=:topright)
scatter!(p2, ts_long[from:skip:to], getindex.(noisy_data_long, 2)[from:skip:to], label="Noisy predator observations", alpha=0.5, ms=1)
plot!(p2, ts_long[from:skip:to], getindex.(data_long, 2)[from:skip:to] , label="True predator", linestyle=:dash)
title!(p2, "Predator Population")
plot(p1, p2, layout=(2,1), size=(1000,600))
Third Alternative: RxInfer Model with Exponential Transformation on the Parameters
So far we have used the exp function to transform the parameters to the positive domain and computed free energy. This transformation was done outside of the @model macro. In this approach, we apply the exp function inside the @model macro instead. We then use the Unscented method to approximate the non-linear deterministic nodes. This approach is more computationally efficient than the previous one; however, it may suffer from accuracy issues, as we may not have a good hyper-parameter initialization.
NOTE: We cannot use exp.() inside the @model macro, as the model macro doesn't support broadcasting yet. So we need to define a function that applies exp to the parameters.
expf(θ) = exp.(θ) ## This function is used to apply the exp function to the parameters within the @model macro
@model function lotka_volterra_model2(obs, mprev, Vprev, dt, t, mθ, Vθ)
θ ~ MvNormalMeanCovariance(mθ, Vθ)
xprev ~ MvNormalMeanCovariance(mprev, Vprev)
θ_exp := expf(θ)
x := lotka_volterra_rk4(xprev, θ_exp, t, dt)
obs ~ MvNormalMeanCovariance(x, noisev * diageye(length(mprev)))
end
delta_meta2 = @meta begin
lotka_volterra_rk4() -> Unscented()
expf() -> Unscented()
end
autoupdates2 = @autoupdates begin
mprev, Vprev= mean_cov(q(x))
mθ, Vθ = mean_cov(q(θ))
end
@initialization function initialize2(mx, Vx, mθ, Vθ)
q(x) = MvNormalMeanCovariance(mx, Vx)
q(θ) = MvNormalMeanCovariance(mθ, Vθ)
end
result2 = infer(
model = lotka_volterra_model2(dt = dt,),
data = (obs = noisy_data_long, t= ts_long),
initialization = initialize2(ones(2), 1e-6 * diageye(2), zeros(4), 0.1 * diageye(4)),
meta = delta_meta2,
autoupdates = autoupdates2,
keephistory = length(noisy_data_long),
free_energy = true
)
RxInferenceEngine:
Posteriors stream | enabled for (θ_exp, θ, xprev, x)
Free Energy stream | enabled
Posteriors history | available for (θ_exp, θ, xprev, x)
Free Energy history | available
Enabled events | [ ]
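To illustrate why a sigma-point approximation is a reasonable choice for the exp node, the sketch below compares three ways of computing the mean of exp(θ) for a scalar Gaussian θ whose mean and variance match one component of the initial prior used above. The sigma points and weights follow one common textbook parameterization (dimension 1, κ = 2); this is an assumption made for the example, and the defaults of RxInfer's Unscented method may differ.

```julia
# Mean of exp(θ) for θ ~ N(m0, v0): exact value vs. two approximations.
# Sigma points and weights are one textbook choice (n = 1, κ = 2), not RxInfer's defaults.
m0, v0 = 0.0, 0.1                         # matches one component of the initial prior on θ
κ = 2.0
spread = sqrt((1 + κ) * v0)
points  = [m0, m0 + spread, m0 - spread]
weights = [κ / (1 + κ), 1 / (2 * (1 + κ)), 1 / (2 * (1 + κ))]

mean_unscented  = sum(weights .* exp.(points))
mean_linearized = exp(m0)                 # first-order (EKF-style) approximation
mean_exact      = exp(m0 + v0 / 2)        # mean of a log-normal distribution

println("unscented:  ", mean_unscented)
println("linearized: ", mean_linearized)
println("exact:      ", mean_exact)
```

In this toy case the sigma-point estimate lands very close to the exact log-normal mean, whereas the first-order estimate ignores the variance of θ entirely.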
mθ_exp = mean.(result2.history[:θ_exp])
Vθ_exp = var.(result2.history[:θ_exp])
# Plot the inferred parameters with uncertainty
p1 = plot(ts_long, getindex.(mθ_exp, 1), ribbon=2*sqrt.(getindex.(Vθ_exp, 1)), label="α", legend=:topright)
plot!(p1, ts_long, fill(true_params[1], length(ts_long)), label="True α", linestyle=:dash)
title!(p1, "Parameter α")
p2 = plot(ts_long, getindex.(mθ_exp, 2), ribbon=2*sqrt.(getindex.(Vθ_exp, 2)), label="β", legend=:topright)
plot!(p2, ts_long, fill(true_params[2], length(ts_long)), label="True β", linestyle=:dash)
title!(p2, "Parameter β")
p3 = plot(ts_long, getindex.(mθ_exp, 3), ribbon=2*sqrt.(getindex.(Vθ_exp, 3)), label="γ", legend=:topright)
plot!(p3, ts_long, fill(true_params[3], length(ts_long)), label="True γ", linestyle=:dash)
title!(p3, "Parameter γ")
p4 = plot(ts_long, getindex.(mθ_exp, 4), ribbon=2*sqrt.(getindex.(Vθ_exp, 4)), label="δ", legend=:topright)
plot!(p4, ts_long, fill(true_params[4], length(ts_long)), label="True δ", linestyle=:dash)
title!(p4, "Parameter δ")
plot(p1, p2, p3, p4, layout=(4,1), size=(1000,800))
# Print final parameter estimates and covariance
final_means = last(mθ_exp)
final_vars = last(Vθ_exp)
final_stds = sqrt.(final_vars)
# Print results
println("\nFinal Parameter Estimates:")
for (param, m, s) in zip(param_names, final_means, final_stds)
println("$param: $m ± $s")
end
println("\nFinal parameter covariance matrix:")
display(cov(last(result2.history[:θ_exp])))
Final Parameter Estimates:
α: 1.1262630068329251 ± 0.025551178790336644
β: 1.559153042713237 ± 0.02790701016695101
γ: 2.493178031020226 ± 0.11291883296479856
δ: 0.8461062937663921 ± 0.030319628433966395
Final parameter covariance matrix:
4×4 Matrix{Float64}:
0.000652863 0.000508431 8.64717e-5 3.83381e-5
0.000508431 0.000778801 -7.76106e-5 -5.9575e-5
8.64717e-5 -7.76106e-5 0.0127507 0.00295032
3.83381e-5 -5.9575e-5 0.00295032 0.00091928
# Get state estimates and variances
mx = mean.(result2.history[:x])
Vx = var.(result2.history[:x])
# Plot state estimates with uncertainty bands
p1 = plot(ts_long[from:skip:to] , getindex.(mx, 1)[from:skip:to], ribbon=2*sqrt.(getindex.(Vx, 1)[from:skip:to]),
label="Prey estimate", legend=:topright)
scatter!(p1, ts_long[from:skip:to], getindex.(noisy_data_long, 1)[from:skip:to], label="Noisy prey observations", alpha=0.5,ms=1)
plot!(p1, ts_long[from:skip:to], getindex.(data_long, 1)[from:skip:to], label="True prey", linestyle=:dash)
title!(p1, "Prey Population")
p2 = plot(ts_long[from:skip:to], getindex.(mx, 2)[from:skip:to], ribbon=2*sqrt.(getindex.(Vx, 2)[from:skip:to]),
label="Predator estimate", legend=:topright)
scatter!(p2, ts_long[from:skip:to], getindex.(noisy_data_long, 2)[from:skip:to], label="Noisy predator observations", alpha=0.5, ms=1)
plot!(p2, ts_long[from:skip:to], getindex.(data_long, 2)[from:skip:to] , label="True predator", linestyle=:dash)
title!(p2, "Predator Population")
plot(p1, p2, layout=(2,1), size=(1000,600))
This example was executed in a clean, isolated environment. Below are the exact package versions used:
For reproducibility:
- Use the same package versions when running locally
- Report any issues with package compatibility
Status `~/work/RxInferExamples.jl/RxInferExamples.jl/docs/src/categories/problem_specific/ode_parameter_estimation/Project.toml`
[429524aa] Optim v1.11.0
[91a5bcdd] Plots v1.40.9
[86711068] RxInfer v4.0.1
⌃ [1c904df7] SeeToDee v1.2.1
[860ef19b] StableRNGs v1.0.2
[90137ffa] StaticArrays v1.9.12
[37e2e46d] LinearAlgebra v1.11.0
Info Packages marked with ⌃ have new versions available and may be upgradable.