This example was automatically generated from a Jupyter notebook in the RxInferExamples.jl repository.
We welcome and encourage contributions! You can help by:
- Improving this example
- Creating new examples
- Reporting issues or bugs
- Suggesting enhancements
Visit our GitHub repository to get started. Together we can make RxInfer.jl even better! 💪
Probit Model
Estimation of pollutant
The mortality $y_t$ of fish in a lake is observed over time. Each observation is binary, $y_t \sim \text{Ber}(\Phi(x_t))$, where the mortality probability $\Phi(x_t)$ is linked to the level of pollutant $x_t$ in the lake through the probit model (see below). The municipality wants to keep track of the pollution, so the level of pollutant in the lake is inferred over time from the observations of the fish.
Objective
The probit model aims to infer the value of a random process from noisy binary observations of it. RxInfer comes with support for expectation propagation (EP). In this demo we illustrate EP in the context of state estimation in a linear state-space model that combines a Gaussian state-evolution model with a discrete observation model. Here, the probit function links the continuous variable $x_t$ to the discrete variable $y_t$. The model is defined as:
\[\begin{aligned} u &= 0.1 \\ x_0 &\sim \mathcal{N}(0, 100) \\ x_t &\sim \mathcal{N}(x_{t-1}+ u, 0.01) \\ y_t &\sim \mathrm{Ber}(\Phi(x_t)) \end{aligned}\]
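For reference, $\Phi$ denotes the standard normal cumulative distribution function (the same function that `normcdf` evaluates in the data-generation code), so the observation model can be written out as:
\[\begin{aligned} \Phi(x) &= \int_{-\infty}^{x} \mathcal{N}(z \mid 0, 1)\,\mathrm{d}z \\ p(y_t = 1 \mid x_t) &= \Phi(x_t) \\ p(y_t = 0 \mid x_t) &= 1 - \Phi(x_t) = \Phi(-x_t) \end{aligned}\]
For example, $x_t = 0$ gives a mortality probability of $0.5$, while a large negative $x_t$ pushes it towards $0$ and a large positive $x_t$ towards $1$.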
Import packages
using RxInfer, GraphPPL, StableRNGs, Random, Plots, Distributions
using StatsFuns: normcdf
Data generation
function generate_data(nr_samples::Int64; seed = 123)
    rng = StableRNG(seed)

    # hyper parameters
    u = 0.1

    # allocate space for data
    data_x = zeros(nr_samples + 1)
    data_y = zeros(nr_samples)

    # initialize data
    data_x[1] = -2

    # generate data
    for k in eachindex(data_y)
        # calculate new x
        data_x[k+1] = data_x[k] + u + sqrt(0.01)*randn(rng)

        # calculate y
        data_y[k] = normcdf(data_x[k+1]) > rand(rng)
    end

    # return data
    return data_x, data_y
end;
n = 40
40
data_x, data_y = generate_data(n);
p = plot(xlabel = "t", ylabel = "x, y")
p = scatter!(p, data_y, label = "y")
p = plot!(p, data_x[2:end], label = "x")
Model specification
@model function probit_model(y, prior_x)
    # specify uninformative prior
    x_prev ~ prior_x

    # create model
    for k in eachindex(y)
        x[k] ~ Normal(mean = x_prev + 0.1, precision = 100)
        y[k] ~ Probit(x[k]) where {
            # By default the Probit node uses a RequireMessage pipeline with a vague(NormalMeanPrecision) message as the initial value for the `in` edge.
            # The initial value can also be specified manually, as done here. Changing the initial message may improve stability in some situations.
            dependencies = RequireMessageFunctionalDependencies(in = NormalMeanPrecision(0.0, 0.01))
        }
        x_prev = x[k]
    end
end;
Probit Node
The Probit node needs an initial message on the 'in' edge because of how it computes its messages. The message towards the input is not calculated directly. First the marginal $q(in)$ is computed, and then the outgoing message is obtained from it using the marginalisation formula.
\[\overrightarrow{\mu}(x) \overleftarrow{\mu}(x) = q(x)\]
Consequently, an initial message $\overleftarrow{\mu}(in)$ is needed to start the iterations. It can be specified as in the above example. Otherwise, RxInfer will initialise it with a default value.
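The marginal-then-message computation can be sketched for a single probit factor. This is a minimal illustration that assumes standard EP moment matching of a Gaussian message multiplied by a probit likelihood, followed by a Gaussian division. The function names `probit_moment_match` and `divide_gaussians` are hypothetical, and the sketch is not RxInfer's internal implementation.

```julia
using StatsFuns: normcdf, normpdf

# Moment-match q(x) ∝ N(x | m, v) * p(y | x), where p(y = 1 | x) = Φ(x)
# and p(y = 0 | x) = Φ(-x). Returns the mean and variance of q.
# (Hypothetical helper, assuming standard EP moment matching.)
function probit_moment_match(m, v, y)
    s = y == 1 ? 1.0 : -1.0              # sign encodes the binary observation
    z = s * m / sqrt(1 + v)
    ratio = normpdf(z) / normcdf(z)
    mean_q = m + s * v * ratio / sqrt(1 + v)
    var_q = v - v^2 * ratio * (z + ratio) / (1 + v)
    return mean_q, var_q
end

# Recover the other message from the marginal: q(x) / N(x | m, v).
# In natural parameters, dividing Gaussians subtracts precisions
# and precision-weighted means. (Hypothetical helper.)
function divide_gaussians(mean_q, var_q, m, v)
    precision = 1 / var_q - 1 / v
    weighted_mean = mean_q / var_q - m / v
    return weighted_mean / precision, 1 / precision
end

# Example: a standard normal incoming message and an observation y = 1
mean_q, var_q = probit_moment_match(0.0, 1.0, 1)
mean_back, var_back = divide_gaussians(mean_q, var_q, 0.0, 1.0)
```

With a standard normal incoming message and $y = 1$, the matched marginal has mean $1/\sqrt{\pi} \approx 0.564$ and variance $1 - 1/\pi \approx 0.682$. The division step is why a message must already be present on the edge before the first iteration can run.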
Inference
result = infer(
    model = probit_model(prior_x = Normal(0.0, 100.0)),
    data = (y = data_y, ),
    iterations = 5,
    returnvars = (x = KeepLast(),),
    free_energy = true
)
Inference results:
Posteriors | available for (x)
Free Energy: | Real[25.6698, 18.0157, 17.9199, 17.9194, 17.9194]
Results
mx = result.posteriors[:x]
p = plot(xlabel = "t", ylabel = "x, y", legend = :bottomright)
p = scatter!(p, data_y, label = "y")
p = plot!(p, data_x[2:end], label = "x", lw = 2)
# x[k] corresponds to data_x[k+1], so the posteriors align with data_x[2:end] without further slicing
p = plot!(p, mean.(mx), ribbon = std.(mx), fillalpha = 0.2, label = "x (inferred mean)")
f = plot(xlabel = "t", ylabel = "BFE")
f = plot!(result.free_energy, label = "Bethe Free Energy")
plot(p, f, size = (800, 400))
This example was executed in a clean, isolated environment. Below are the exact package versions used:
For reproducibility:
- Use the same package versions when running locally
- Report any issues with package compatibility
Status `~/work/RxInferExamples.jl/RxInferExamples.jl/docs/src/categories/problem_specific/probit_model/Project.toml`
[31c24e10] Distributions v0.25.117
[b3f8163a] GraphPPL v4.6.2
[91a5bcdd] Plots v1.40.9
[86711068] RxInfer v4.0.1
[860ef19b] StableRNGs v1.0.2
[4c63d2b9] StatsFuns v1.3.2
[9a3f8284] Random v1.11.0