This example was automatically generated from a Jupyter notebook in the RxInferExamples.jl repository.
We welcome and encourage contributions! You can help by:
- Improving this example
- Creating new examples
- Reporting issues or bugs
- Suggesting enhancements
Visit our GitHub repository to get started. Together we can make RxInfer.jl even better! 💪
Bayesian Multinomial Regression
This notebook is an introductory tutorial to Bayesian multinomial regression with RxInfer.
using RxInfer, Plots, StableRNGs, Distributions, ExponentialFamily, StatsPlots
import ExponentialFamily: softmax
Model Description
The key innovation in Linderman et al. (2015) is extending the Pólya-gamma augmentation scheme to the multinomial case. This allows us to transform the non-conjugate multinomial likelihood into a conditionally conjugate form by introducing auxiliary Pólya-gamma random variables.
The multinomial regression model with Pólya-gamma augmentation can be written as: $p(y \mid \psi, N) = \text{Multinomial}(y \mid N, \pi_{\text{SB}}(\psi))$
where:
- $y$ is a $K$-dimensional vector of count data with $N$ total counts
- $\psi$ is a $(K-1)$-dimensional Gaussian random variable
- $\pi_{\text{SB}}(\cdot)$ is the logistic stick-breaking map that sends $\psi \in \mathbb{R}^{K-1}$ to a point on the $K$-simplex
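The augmentation rests on the Pólya-gamma integral identity of Polson et al. (2013), quoted here as background: for $b > 0$ and $\kappa = a - b/2$,
$\frac{(e^{\psi})^{a}}{(1 + e^{\psi})^{b}} = 2^{-b} e^{\kappa \psi} \int_0^{\infty} e^{-\omega \psi^{2}/2}\, p_{\text{PG}}(\omega \mid b, 0)\, d\omega$
Conditioned on the auxiliary variable $\omega$, the likelihood becomes Gaussian in $\psi$, which restores conjugacy with the Gaussian prior; Linderman et al. apply this identity to each binomial factor in the stick-breaking decomposition of the multinomial.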
Implementation
In this notebook, we will implement the Pólya-gamma augmented Bayesian multinomial regression model with RxInfer, performing inference by message passing to estimate the posterior distribution of the regression coefficients.
function generate_multinomial_data(rng = StableRNG(123); N = 3, k = 3, nsamples = 1000)
    Ψ = randn(rng, k)                    # linear predictor for each of the k categories
    p = softmax(Ψ)                       # map onto the probability simplex
    X = rand(rng, Multinomial(N, p), nsamples)
    X = [X[:, i] for i in 1:size(X, 2)]  # split into a vector of count vectors
    return X, Ψ, p
end
generate_multinomial_data (generic function with 2 methods)
nsamples = 5000
N = 30
k = 40
X, Ψ, p = generate_multinomial_data(N = N, k = k, nsamples = nsamples);
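As a quick sanity check on the generated data (a minimal sketch, assuming the variables defined above), every observation should be a count vector over the $k$ categories that sums to $N$:
@assert length(X) == nsamples
@assert all(x -> sum(x) == N, X)  # each sample distributes N counts over k categories
@assert sum(p) ≈ 1.0              # softmax output lies on the simplex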
The MultinomialPolya factor node is used to model the likelihood of the multinomial distribution. Because the likelihood and the prior are non-conjugate, a more involved inference algorithm is required: RxInfer provides an Expectation Propagation (EP) [2] algorithm to infer the posterior distribution. Since EP is approximate, we need to specify an inbound message for the regression coefficients when using the MultinomialPolya factor node. This is done through the dependencies keyword argument when the MultinomialPolya factor node is created. ReactiveMP.jl provides a RequireMessageFunctionalDependencies type that specifies the inbound message for the regression coefficients ψ. Refer to the ReactiveMP.jl documentation for more information.
@model function multinomial_model(obs, N, ξ_ψ, W_ψ)
    # Gaussian prior over ψ in the weighted-mean/precision parameterization
    ψ ~ MvNormalWeightedMeanPrecision(ξ_ψ, W_ψ)
    # EP needs an explicit inbound message on ψ, supplied via the dependencies option
    obs .~ MultinomialPolya(N, ψ) where {dependencies = RequireMessageFunctionalDependencies(ψ = MvNormalWeightedMeanPrecision(ξ_ψ, W_ψ))}
end
result = infer(
    model = multinomial_model(ξ_ψ = zeros(k - 1), W_ψ = rand(Wishart(3, diageye(k - 1))), N = N),
    data = (obs = X,),
    iterations = 50,
    free_energy = true,
    showprogress = true,
    options = (
        limit_stack_depth = 100,
    )
)
Inference results:
  Posteriors   | available for (ψ)
  Free Energy: | Real[4.47163e5, 2.93666e5, 2.39374e5, 2.13822e5, 1.99863e5, 1.91485e5, 1.86116e5, 1.82505e5, 1.79982e5, 1.78166e5 … 1.71746e5, 1.71739e5, 1.71733e5, 1.71727e5, 1.71722e5, 1.71718e5, 1.71714e5, 1.71711e5, 1.71708e5, 1.71705e5]
plot(result.free_energy,
    title="Free Energy Over Iterations",
    xlabel="Iteration",
    ylabel="Free Energy",
    linewidth=2,
    legend=false,
    grid=true,
)
# Posterior predictive: compute the outbound message on x given the inferred posterior over ψ
predictive = @call_rule MultinomialPolya(:x, Marginalisation) (q_N = PointMass(N), q_ψ = result.posteriors[:ψ][end], meta = MultinomialPolyaMeta(21))
println("Estimated data generation probabilities: $(predictive.p)")
println("True data generation probabilities: $(p)")
Estimated data generation probabilities: [0.012921996337420428, 0.028270152420452485, 0.0049613670207640514, 0.013283925968017245, 0.014509234581201225, 0.0380318172011463, 0.007886590163156739, 0.00686606245189614, 0.0058307909924652825, 0.004396745569919048, 0.005093967623704321, 0.0041686206895373765, 0.003720427813084725, 0.03627928128811411, 0.10839838075125663, 0.07203077129175312, 0.026653569951859853, 0.02384101664160622, 0.010115078326551239, 0.007645961594354, 0.03939837139614987, 0.004348679172704573, 0.008831453585770564, 0.026889706364971105, 0.007121675017067946, 0.008396945893895818, 0.009799567811715372, 0.0074250693171146005, 0.0175058396104508, 0.00768210939007052, 0.008267731869076266, 0.003728336434200991, 0.01170584953507465, 0.01067192494750194, 0.09382709670283643, 0.043472178899373125, 0.13061227013524632, 0.02809864473049759, 0.03079544991246554, 0.06651534059555546]
True data generation probabilities: [0.012475572764691347, 0.02759115956301153, 0.004030932560100506, 0.013008651265311708, 0.012888510278451618, 0.037656116813111006, 0.007242363105598982, 0.006930069564505769, 0.00538389836228327, 0.0036198124274772225, 0.005212387391120808, 0.003185556887255863, 0.003820168769118259, 0.036849638787622915, 0.109428569898501, 0.07260753875224316, 0.026079268674281158, 0.024477855252934583, 0.010207778995219957, 0.008532295265944583, 0.040242532118754906, 0.005181587450423221, 0.008207391370854009, 0.02741148713822125, 0.006623087410725917, 0.008367702714634162, 0.009668643362989908, 0.007171783607096945, 0.016985615150215773, 0.007080691453323701, 0.008297044496975403, 0.0037359000700039487, 0.011142755810390478, 0.010256554277897088, 0.09528238587772694, 0.04369806970660494, 0.13308101804159636, 0.02665693577960761, 0.030479170124456504, 0.06920149865871575]
mse = mean((predictive.p - p).^2);
println("MSE between estimated and true data generation probabilities: $mse")
MSE between estimated and true data generation probabilities: 7.752661622555736e-7
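For a visual comparison of the two probability vectors, here is a sketch using groupedbar from the already-loaded StatsPlots package; the layout choices are illustrative:
groupedbar(hcat(predictive.p, p),
    bar_position = :dodge,
    label = ["estimated" "true"],
    xlabel = "Category",
    ylabel = "Probability",
    title = "Estimated vs. true probabilities")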
@model function multinomial_regression(obs, N, X, ϕ, ξβ, Wβ)
    # Gaussian prior over the regression coefficients β
    β ~ MvNormalWeightedMeanPrecision(ξβ, Wβ)
    for i in eachindex(obs)
        # Linear predictor through the (possibly nonlinear) feature map ϕ
        Ψ[i] := ϕ(X[i]) * β
        obs[i] ~ MultinomialPolya(N, Ψ[i]) where {dependencies = RequireMessageFunctionalDependencies(ψ = MvNormalWeightedMeanPrecision(zeros(length(obs[i]) - 1), diageye(length(obs[i]) - 1)))}
    end
end
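In generative form, the model above reads as follows, where $\pi_{\text{SB}}$ denotes the logistic stick-breaking map applied internally by MultinomialPolya and the prior is given in weighted-mean/precision form ($\xi_\beta = W_\beta \mu_\beta$):
$\beta \sim \mathcal{N}(W_\beta^{-1} \xi_\beta,\; W_\beta^{-1}), \qquad \psi_i = \phi(X_i)\, \beta, \qquad y_i \sim \text{Multinomial}(N, \pi_{\text{SB}}(\psi_i))$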
function generate_regression_data(rng = StableRNG(123); ϕ = identity, N = 3, k = 5, nsamples = 1000)
    β = randn(rng, k)
    X = randn(rng, nsamples, k, k)
    X = [X[i, :, :] for i in 1:size(X, 1)]  # one k×k design matrix per sample
    Ψ = ϕ.(X)                               # apply the feature map to each design matrix
    p = map(x -> logistic_stick_breaking(x * β), Ψ)
    return map(x -> rand(rng, Multinomial(N, x)), p), X, β, p
end
generate_regression_data (generic function with 2 methods)
ϕ = x -> sin(x)
obs_regression, X_regression, β_regression, p_regression = generate_regression_data(;nsamples = 5000, ϕ = ϕ);
reg_results = infer(
    model = multinomial_regression(N = 3, ϕ = ϕ, ξβ = zeros(5), Wβ = rand(Wishart(5, diageye(5)))),
    data = (obs = obs_regression, X = X_regression),
    iterations = 20,
    free_energy = true,
    showprogress = true,
    returnvars = KeepLast(),
    options = (
        limit_stack_depth = 100,
    )
)
Inference results:
  Posteriors   | available for (Ψ, β)
  Free Energy: | Real[11950.8, 11583.8, 11501.1, 11479.9, 11474.2, 11472.6, 11472.2, 11472.1, 11472.0, 11472.0, 11472.0, 11472.0, 11472.0, 11472.0, 11472.0, 11472.0, 11472.0, 11472.0, 11472.0, 11472.0]
println("estimated β: with mean and covariance: $(mean_cov(reg_results.posteriors[:β]))")
println("true β: $(β_regression)")
estimated β with mean and covariance: ([-0.11473506140511743, 0.663767095725271, -1.2556217371206575, -0.08604219750441625, -0.08016966649970333], [0.00014805005349370298 -2.1056573971431235e-6 3.6316025472036616e-6 -1.5752208754889512e-6 3.2349379005947306e-6; -2.1056573971431235e-6 0.0001517961887508871 -1.9356375635799642e-5 -2.6440506950843055e-7 1.2576682344176783e-6; 3.6316025472036616e-6 -1.9356375635799642e-5 0.00017992554626697165 4.543654346891547e-6 5.239637225283828e-7; -1.5752208754889512e-6 -2.6440506950843055e-7 4.543654346891547e-6 0.00014006406142991353 3.273947125004614e-6; 3.2349379005947306e-6 1.2576682344176783e-6 5.239637225283828e-7 3.273947125004614e-6 0.00013945784752252273])
true β: [-0.12683768965424458, 0.6668851724871252, -1.2566124895590247, -0.08499562516549662, -0.094274004848194]
plot(reg_results.free_energy,
    title="Free Energy Over Iterations",
    xlabel="Iteration",
    ylabel="Free Energy",
    linewidth=2,
    legend=false,
    grid=true,
)
mse_β = mean((mean(reg_results.posteriors[:β]) - β_regression).^2)
println("MSE of β estimate: $mse_β")
MSE of β estimate: 7.144105550656234e-5
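Beyond a point estimate, the posterior covariance lets us check that the true coefficients are consistent with the posterior uncertainty. A minimal sketch, assuming the results above (diag comes from the LinearAlgebra standard library):
using LinearAlgebra
m_β, V_β = mean_cov(reg_results.posteriors[:β])
z = abs.(m_β .- β_regression) ./ sqrt.(diag(V_β))  # distance in posterior standard deviations
println("|error| in posterior standard deviations: $z")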
Since the logistic stick-breaking transformation is invertible, we can visualize how a Gaussian prior over $\psi$ induces a density on the simplex coordinates, and conversely how a density on the simplex pulls back to $\psi$-space.
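Concretely, with $\sigma(x) = 1/(1+e^{-x})$ the logistic function, the stick-breaking map and its inverse (implemented below as ψ_to_π and π_to_ψ) are
$\pi_k = \sigma(\psi_k)\left(1 - \sum_{j<k} \pi_j\right) \text{ for } k = 1, \dots, K-1, \qquad \pi_K = 1 - \sum_{j<K} \pi_j$
$\psi_k = \sigma^{-1}\!\left(\frac{\pi_k}{1 - \sum_{j<k} \pi_j}\right)$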
# Helper functions for the stick-breaking transformation
σ(x) = 1 / (1 + exp(-x))     # logistic function
σ_inv(x) = log(x / (1 - x))  # logit (inverse logistic)

# Jacobian determinant of the map π ↦ ψ, used for the change of variables
function jacobian_det(π)
    K = length(π)
    det = 1.0
    for k in 1:(K-1)
        num = 1 - sum(π[1:(k-1)])
        den = π[k] * (1 - sum(π[1:k]))
        det *= num / den
    end
    return det
end
# Map ψ ∈ ℝ^{K-1} to a point π on the K-simplex
function ψ_to_π(ψ::Vector{Float64})
    K = length(ψ) + 1
    π = zeros(K)
    for k in 1:(K-1)
        π[k] = σ(ψ[k]) * (1 - sum(π[1:(k-1)]))
    end
    π[K] = 1 - sum(π[1:(K-1)])
    return π
end
# Inverse map: recover ψ from a simplex point π
function π_to_ψ(π)
    K = length(π)
    ψ = zeros(K-1)
    ψ[1] = σ_inv(π[1])
    for k in 2:(K-1)
        ψ[k] = σ_inv(π[k] / (1 - sum(π[1:(k-1)])))
    end
    return ψ
end
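A quick round-trip check of invertibility (a sketch using the helpers above):
ψ_test = randn(StableRNG(42), 2)
π_test = ψ_to_π(ψ_test)
@assert π_to_ψ(π_test) ≈ ψ_test  # the two maps are mutual inverses
@assert sum(π_test) ≈ 1.0        # the image lies on the simplex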
# Function to compute the transformed density in simplex coordinates
function compute_simplex_density(x::Float64, y::Float64, Σ::Matrix{Float64})
    # Check if the point is inside the triangle (2-simplex)
    if y < 0 || y > 1 || x < 0 || x > 1 || (x + y) > 1
        return 0.0
    end
    # Convert from simplex coordinates to π
    π1 = x
    π2 = y
    π3 = 1 - x - y
    try
        ψ = π_to_ψ([π1, π2, π3])
        # Gaussian density in ψ-space times |det J| (change of variables)
        dist = MvNormal(zeros(2), Σ)
        return pdf(dist, ψ) * abs(jacobian_det([π1, π2, π3]))
    catch
        return 0.0
    end
end
function plot_transformed_densities()
    # Three different covariance matrices for the Gaussian prior over ψ
    # (for higher variance values, the densities need rescaling for proper visualization)
    σ² = 1.0
    Σ_corr = [σ² 0.9σ²; 0.9σ² σ²]
    Σ_anticorr = [σ² -0.9σ²; -0.9σ² σ²]
    Σ_uncorr = [σ² 0.0; 0.0 σ²]
    # Plot the Gaussian densities in ψ-space
    ψ1, ψ2 = range(-4sqrt(σ²), 4sqrt(σ²), length=500), range(-4sqrt(σ²), 4sqrt(σ²), length=100)
    p1 = contour(ψ1, ψ2, (x, y) -> pdf(MvNormal(zeros(2), Σ_corr), [x, y]),
        title="Correlated Prior", xlabel="ψ₁", ylabel="ψ₂")
    p2 = contour(ψ1, ψ2, (x, y) -> pdf(MvNormal(zeros(2), Σ_anticorr), [x, y]),
        title="Anti-correlated Prior", xlabel="ψ₁", ylabel="ψ₂")
    p3 = contour(ψ1, ψ2, (x, y) -> pdf(MvNormal(zeros(2), Σ_uncorr), [x, y]),
        title="Uncorrelated Prior", xlabel="ψ₁", ylabel="ψ₂")
    # Grid over the simplex
    n_points = 500
    x = range(0, 1, length=n_points)
    y = range(0, 1, length=n_points)
    # Plot the transformed densities on the simplex
    p4 = contour(x, y, (x, y) -> compute_simplex_density(x, y, Σ_corr),
        title="Correlated Simplex")
    plot!(p4, [0, 1, 0, 0], [0, 0, 1, 0], color=:black, label="")  # triangle boundary
    p5 = contour(x, y, (x, y) -> compute_simplex_density(x, y, Σ_anticorr),
        title="Anti-correlated Simplex")
    plot!(p5, [0, 1, 0, 0], [0, 0, 1, 0], color=:black, label="")
    p6 = contour(x, y, (x, y) -> compute_simplex_density(x, y, Σ_uncorr),
        title="Uncorrelated Simplex")
    plot!(p6, [0, 1, 0, 0], [0, 0, 1, 0], color=:black, label="")
    # Combine all plots
    plot(p1, p2, p3, p4, p5, p6, layout=(2, 3), size=(900, 600))
end
# Generate the plots
plot_transformed_densities()
This example was executed in a clean, isolated environment. Below are the exact package versions used:
For reproducibility:
- Use the same package versions when running locally
- Report any issues with package compatibility
Status `~/work/RxInferExamples.jl/RxInferExamples.jl/docs/src/categories/basic_examples/bayesian_multinomial_regression/Project.toml`
[31c24e10] Distributions v0.25.117
[62312e5e] ExponentialFamily v2.0.1
[91a5bcdd] Plots v1.40.9
[86711068] RxInfer v4.1.0
[860ef19b] StableRNGs v1.0.2
[f3b207a7] StatsPlots v0.15.7