Contributing

This example was automatically generated from a Jupyter notebook in the RxInferExamples.jl repository.

We welcome and encourage contributions! You can help by:

  • Improving this example
  • Creating new examples
  • Reporting issues or bugs
  • Suggesting enhancements

Visit our GitHub repository to get started. Together we can make RxInfer.jl even better! 💪


Probit Model

Estimation of pollutant

The mortality $y_t$ of fish in a lake is observed over time. The mortality, distributed as $\text{Ber}(\Phi(x_t))$, is linked to the level of pollutant $x_t$ in the lake through the probit model (see below). The municipality wants to keep track of the pollution. To do so, the pollutant level in the lake is tracked over time through observations of the fish.

Objective

The probit model aims to infer the value of a random process from noisy binary observations of it. RxInfer comes with support for expectation propagation (EP). In this demo we illustrate EP in the context of state estimation in a linear state-space model that combines a Gaussian state-evolution model with a discrete observation model. Here, the probit function links the continuous variable $x_t$ to the discrete variable $y_t$. The model is defined as:

\[\begin{aligned} u &= 0.1 \\ x_0 &\sim \mathcal{N}(0, 100) \\ x_t &\sim \mathcal{N}(x_{t-1}+ u, 0.01) \\ y_t &\sim \mathrm{Ber}(\Phi(x_t)) \end{aligned}\]
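
For reference, $\Phi$ denotes the cumulative distribution function of the standard normal distribution. Written out under that standard definition, the observation model in the last line above reads:

\[\begin{aligned} \Phi(x) &= \int_{-\infty}^{x} \mathcal{N}(z \mid 0, 1) \, \mathrm{d}z \\ p(y_t = 1 \mid x_t) &= \Phi(x_t) \\ p(y_t = 0 \mid x_t) &= 1 - \Phi(x_t) \end{aligned}\]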

Import packages

using RxInfer, GraphPPL, StableRNGs, Random, Plots, Distributions
using StatsFuns: normcdf

Data generation

function generate_data(nr_samples::Int64; seed = 123)
    
    rng = StableRNG(seed)
    
    # hyper parameters
    u = 0.1

    # allocate space for data
    data_x = zeros(nr_samples + 1)
    data_y = zeros(nr_samples)
    
    # initialize data
    data_x[1] = -2
    
    # generate data
    for k in eachindex(data_y)
        
        # calculate new x
        data_x[k+1] = data_x[k] + u + sqrt(0.01)*randn(rng)
        
        # calculate y
        data_y[k] = normcdf(data_x[k+1]) > rand(rng)
        
    end
    
    # return data
    return data_x, data_y
    
end;
n = 40
40
data_x, data_y = generate_data(n);
p = plot(xlabel = "t", ylabel = "x, y")
p = scatter!(p, data_y, label = "y")
p = plot!(p, data_x[2:end], label = "x")
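
The data generator draws each observation by comparing $\Phi(x)$ with a uniform random number. The following minimal sketch checks empirically that this produces Bernoulli draws with success probability $\Phi(x)$. The value of `x`, the seed, and the number of draws are arbitrary choices made for illustration only.

```julia
using StableRNGs
using StatsFuns: normcdf

rng = StableRNG(42)      # illustrative seed, not the one used above
x   = -0.5               # hypothetical pollutant level
p   = normcdf(x)         # success probability Φ(x)

# `p > rand(rng)` is true with probability p, since rand(rng) is uniform on [0, 1)
draws = [p > rand(rng) for _ in 1:100_000]

# the empirical frequency of ones should be close to p
freq = sum(draws) / length(draws)
println("Φ(x) = ", p, ", empirical frequency = ", freq)
```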

Model specification

@model function probit_model(y, prior_x)
    
    # specify uninformative prior
    x_prev ~ prior_x
    
    # create model 
    for k in eachindex(y)
        # state transition: drift u = 0.1, precision 100 (i.e. variance 0.01)
        x[k] ~ Normal(mean = x_prev + 0.1, precision = 100)
        y[k] ~ Probit(x[k]) where {
            # By default, the Probit node uses the RequireMessage pipeline with a vague(NormalMeanPrecision) message as the initial value for the `in` edge.
            # The initial value can also be specified manually, as done here. Changing the initial message may improve stability in some situations.
            dependencies = RequireMessageFunctionalDependencies(in = NormalMeanPrecision(0.0, 0.01))
        }
        x_prev = x[k]
    end
    
end;

Probit Node

The Probit node needs an initial message on its `in` edge because of the way it computes its outgoing message. The message towards `in` is not calculated directly. First the marginal $q(in)$ is computed, and the message is then obtained from it using the marginalisation formula, which holds up to a normalisation constant:

\[\overrightarrow{\mu}(x) \overleftarrow{\mu}(x) \propto q(x)\]

Consequently, an initial message $\overleftarrow{\mu}(in)$ is needed to start the iterations. It can be specified as in the example above. Otherwise, RxInfer initialises it with a default value.
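
To make the two-step computation concrete, here is a minimal sketch of the standard moment-matching step for a probit likelihood. It is not RxInfer's internal implementation. The function names `probit_marginal` and `backward_message` are hypothetical, and the sketch assumes that the Gaussian message arriving at the node is $\mathcal{N}(m, v)$. It first computes the mean and variance of the marginal $q(x) \propto \mathcal{N}(x \mid m, v)\,\Phi(\pm x)$, then divides that marginal by the incoming message to obtain the message sent back towards $x$.

```julia
using StatsFuns: normcdf, normpdf

# Mean and variance of q(x) ∝ N(x | m, v) * Φ(s * x), with s = +1 for y = 1 and s = -1 for y = 0.
function probit_marginal(m::Real, v::Real, y::Integer)
    s  = y == 1 ? 1 : -1
    z  = s * m / sqrt(1 + v)
    r  = normpdf(z) / normcdf(z)              # ratio φ(z) / Φ(z)
    mq = m + s * v * r / sqrt(1 + v)          # mean of q
    vq = v - v^2 * r * (z + r) / (1 + v)      # variance of q
    return mq, vq
end

# Message towards x: the Gaussian with the moments of q divided by the incoming message N(m, v).
# Division of Gaussians is a subtraction of natural parameters.
function backward_message(m::Real, v::Real, y::Integer)
    mq, vq = probit_marginal(m, v, y)
    w  = 1 / vq - 1 / v                       # precision of the resulting message
    xi = mq / vq - m / v                      # precision-weighted mean
    return xi / w, w                          # (mean, precision)
end

mq, vq         = probit_marginal(0.0, 1.0, 1)
m_back, w_back = backward_message(0.0, 1.0, 1)
println("marginal: mean = ", mq, ", variance = ", vq)
println("backward message: mean = ", m_back, ", precision = ", w_back)
```

Because the incoming message appears on the right-hand side of both functions, some value for it must be available before the first iteration, which is why an initial message is required.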

Inference

result = infer(
    model = probit_model(prior_x=Normal(0.0, 100.0)), 
    data  = (y = data_y, ), 
    iterations = 5, 
    returnvars = (x = KeepLast(),),
    free_energy  = true
)
Inference results:
  Posteriors       | available for (x)
  Free Energy:     | Real[25.6698, 18.0157, 17.9199, 17.9194, 17.9194]

Results

mx = result.posteriors[:x]

p = plot(xlabel = "t", ylabel = "x, y", legend = :bottomright)
p = scatter!(p, data_y, label = "y")
p = plot!(p, data_x[2:end], label = "x", lw = 2)
# x[k] in the model corresponds to data_x[k+1], so the posteriors align with data_x[2:end]
p = plot!(p, mean.(mx), ribbon = std.(mx), fillalpha = 0.2, label="x (inferred mean)")

f = plot(xlabel = "t", ylabel = "BFE")
f = plot!(result.free_energy, label = "Bethe Free Energy")

plot(p, f, size = (800, 400))
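
The inferred posteriors over $x_t$ can also be mapped back to a mortality probability. As a sketch, assuming a Gaussian posterior $\mathcal{N}(m, v)$ over $x_t$, the identity $\mathbb{E}[\Phi(x)] = \Phi(m / \sqrt{1 + v})$ gives the predicted probability of $y_t = 1$. The helper `predictive_probability` is a hypothetical name, and the values of `m` and `v` below are illustrative rather than taken from the results above.

```julia
using StableRNGs
using StatsFuns: normcdf

# Expected value of Φ(x) when x ~ N(m, v), with v a variance
predictive_probability(m::Real, v::Real) = normcdf(m / sqrt(1 + v))

m, v = 0.3, 0.5                               # illustrative posterior mean and variance
p_closed = predictive_probability(m, v)

# Monte Carlo check of the closed-form expression
rng = StableRNG(1)
p_mc = sum(normcdf(m + sqrt(v) * randn(rng)) for _ in 1:200_000) / 200_000

println("closed form = ", p_closed, ", Monte Carlo = ", p_mc)
```

With the variables from this example, the same helper could be applied elementwise as `predictive_probability.(mean.(mx), var.(mx))`.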



Environment

This example was executed in a clean, isolated environment. Below are the exact package versions used:

For reproducibility:

  • Use the same package versions when running locally
  • Report any issues with package compatibility
Status `~/work/RxInferExamples.jl/RxInferExamples.jl/docs/src/categories/problem_specific/probit_model/Project.toml`
  [31c24e10] Distributions v0.25.117
  [b3f8163a] GraphPPL v4.6.2
  [91a5bcdd] Plots v1.40.9
  [86711068] RxInfer v4.0.1
  [860ef19b] StableRNGs v1.0.2
  [4c63d2b9] StatsFuns v1.3.2
  [9a3f8284] Random v1.11.0