This example was automatically generated from a Jupyter notebook in the RxInferExamples.jl repository.
We welcome and encourage contributions! You can help by:
- Improving this example
- Creating new examples
- Reporting issues or bugs
- Suggesting enhancements
Visit our GitHub repository to get started. Together we can make RxInfer.jl even better! 💪
POMDP Control with Reactive Inference
This notebook demonstrates how to perform control in Partially Observable Markov Decision Processes (POMDPs) using reactive message passing and variational inference in RxInfer.jl.
We will cover:
- Setting up a simple POMDP model
- Defining the state transition and observation models
- Implementing the control policy
- Performing inference and control using message passing
- Visualizing the results
using RxInfer
using Distributions
using Plots
using Random
using ProgressMeter
Environment Setup
For this example, we will implement the Windy Gridworld environment using RxEnvironments.jl. The Windy Gridworld is a simple gridworld environment with deterministic transitions and observations. This code is adapted from the RxEnvironments.jl documentation, where a more elaborate explanation of the environment can be found.
The environment consists of:
- A grid with wind values for each column
- An agent with a current position
- A goal position to reach
The agent can:
- Move in cardinal directions (one step at a time)
- Observe its current position
- Be affected by wind when moving
The wind is applied as part of each movement. The agent is pushed upward by the wind strength of the column it occupies when it moves, which is 0 or 1 positions for the wind values (0, 1, 1, 1, 0) used below.
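The movement rule can be written as a pure function, independent of RxEnvironments.jl. The sketch below mirrors the receive! method defined further down, including its 0 < coordinate < N bounds check. The name windy_step is hypothetical and is used only for illustration.

```julia
# Hypothetical helper: one step of the windy gridworld dynamics.
# wind[x] is the upward push in column x. A move that would leave the
# region 0 < coordinate < N is rejected, and the agent stays where it is.
function windy_step(wind::NTuple{N,Int}, pos::Tuple{Int,Int}, action::Tuple{Int,Int}) where {N}
    @assert action[1] == 0 || action[2] == 0 "Only one of the two actions can be non-zero"
    # The wind of the column the agent starts from is added to the vertical move
    new_pos = (pos[1] + action[1], pos[2] + action[2] + wind[pos[1]])
    return all(c -> 0 < c < N, new_pos) ? new_pos : pos
end

wind = (0, 1, 1, 1, 0)
windy_step(wind, (1, 1), (1, 0))   # (2, 1): column 1 has no wind
windy_step(wind, (2, 1), (1, 0))   # (3, 2): column 2 pushes the agent up by one
windy_step(wind, (1, 1), (0, -1))  # (1, 1): moving down would leave the grid
```

Starting from (1, 1), three moves to the right already reach the goal at (4, 3), because the wind in columns 2 and 3 supplies the vertical displacement.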
First we will define the environment and the agent.
using RxEnvironments
using Plots
struct WindyGridWorld{N}
wind::NTuple{N,Int}
agents::Vector
goal::Tuple{Int,Int}
end
mutable struct WindyGridWorldAgent
position::Tuple{Int,Int}
end
RxEnvironments.update!(env::WindyGridWorld, dt) = nothing # The environment has no "internal" updating process over time
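function RxEnvironments.receive!(env::WindyGridWorld{N}, agent::WindyGridWorldAgent, action::Tuple{Int,Int}) where {N}
# Only axis-aligned moves are allowed: at most one component may be non-zero
if action[1] != 0
@assert action[2] == 0 "Only one of the two actions can be non-zero"
elseif action[2] != 0
@assert action[1] == 0 "Only one of the two actions can be non-zero"
end
# Apply the move plus the upward wind of the column the agent currently occupies
new_position = (agent.position[1] + action[1], agent.position[2] + action[2] + env.wind[agent.position[1]])
# Accept the move only if both coordinates satisfy 0 < coordinate < N (1 to 4 for N = 5); otherwise the agent stays put
if all(elem -> 0 < elem < N, new_position)
agent.position = new_position
end
end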
function RxEnvironments.what_to_send(env::WindyGridWorld, agent::WindyGridWorldAgent)
return agent.position
end
function RxEnvironments.what_to_send(agent::WindyGridWorldAgent, env::WindyGridWorld)
return agent.position
end
function RxEnvironments.add_to_state!(env::WindyGridWorld, agent::WindyGridWorldAgent)
push!(env.agents, agent)
end
function reset_env!(environment::RxEnvironments.RxEntity{<:WindyGridWorld,T,S,A}) where {T,S,A}
env = environment.decorated
for agent in env.agents
agent.position = (1, 1)
end
for subscriber in RxEnvironments.subscribers(environment)
send!(subscriber, environment, (1, 1))
end
end
function plot_environment(environment::RxEnvironments.RxEntity{<:WindyGridWorld,T,S,A}) where {T,S,A}
env = environment.decorated
p1 = scatter([env.goal[1]], [env.goal[2]], color=:blue, label="Goal", xlims=(0, 6), ylims=(0, 6))
for agent in env.agents
p1 = scatter!([agent.position[1]], [agent.position[2]], color=:red, label="Agent")
end
return p1
end
env = RxEnvironment(WindyGridWorld((0, 1, 1, 1, 0), [], (4, 3)))
agent = add!(env, WindyGridWorldAgent((1, 1)))
plot_environment(env)
Model Setup
First, we'll define our POMDP model structure. We will use the DiscreteTransition node in RxInfer to define the state transition model. The DiscreteTransition node is a special node that accepts any number of Categorical distributions as input and outputs a Categorical distribution. This means that we can use it to define a state transition model that accepts the previous state and the control as Categorical random variables, but we can also use it to define our observation model! Furthermore, the DiscreteTransition node can be used both for parameter inference and for inference-as-planning, isn't that neat?
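Conceptually, a DiscreteTransition node with a parameter tensor B encodes a conditional probability table. We assume the output is indexed along the first dimension, so that p(next = i | previous = j, control = k) = B[i, j, k]. The sketch below shows only the forward prediction this table implies, written with plain Julia arrays. It is not RxInfer's message-passing implementation, it treats B as a known tensor rather than a Dirichlet-distributed parameter, and the function name predict_next is hypothetical.

```julia
# Hypothetical sketch: predictive distribution of the next state given
# beliefs over the current state (p_s) and the control (p_u), where
# B[i, j, k] = p(next = i | current = j, control = k).
function predict_next(B::Array{Float64,3}, p_s::Vector{Float64}, p_u::Vector{Float64})
    p_next = zeros(size(B, 1))
    # Marginalize over the current state j and the control k
    for k in axes(B, 3), j in axes(B, 2)
        p_next .+= B[:, j, k] .* (p_s[j] * p_u[k])
    end
    return p_next
end

# Toy example with 2 states and 2 controls: control 1 stays, control 2 switches.
B = zeros(2, 2, 2)
B[:, :, 1] = [1.0 0.0; 0.0 1.0]
B[:, :, 2] = [0.0 1.0; 1.0 0.0]
p_next = predict_next(B, [1.0, 0.0], [0.25, 0.75])  # [0.25, 0.75]
```

The same table with only two arguments, p(observation | state), is how the node doubles as an observation model.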
@model function pomdp_model(p_A, p_B, p_goal, p_control, previous_control, p_previous_state, current_y, future_y, T, m_A, m_B)
# Instantiate all model parameters with priors
A ~ p_A
B ~ p_B
previous_state ~ p_previous_state
# Parameter inference
current_state ~ DiscreteTransition(previous_state, B, previous_control)
current_y ~ DiscreteTransition(current_state, A)
prev_state = current_state
# Inference-as-planning
for t in 1:T
controls[t] ~ p_control
s[t] ~ DiscreteTransition(prev_state, m_B, controls[t])
future_y[t] ~ DiscreteTransition(s[t], m_A)
prev_state = s[t]
end
# Goal prior initialization
s[end] ~ p_goal
end
The constraints below factorize the posterior over the parameters A and B from the posterior over the states. Because of this, inference proceeds by iterating variational updates, and we have to initialize q(A) and q(B) before the first iteration. Furthermore, RxInfer does not support learning a joint probability distribution over the parameters and the states here, so we have to supply the model with variational constraints that reflect this:
init = @initialization begin
q(A) = DirichletCollection(diageye(25) .+ 0.1)
q(B) = DirichletCollection(ones(25, 25, 4))
end
constraints = @constraints begin
q(previous_state, previous_control, current_state, B) = q(previous_state, previous_control, current_state)q(B)
q(current_state, current_y, A) = q(current_state, current_y)q(A)
q(current_state, s, controls, B) = q(current_state, s, controls)q(B)
q(s, future_y, A) = q(s, future_y)q(A)
end
Constraints:
q(previous_state, previous_control, current_state, B) = q(previous_state, previous_control, current_state)q(B)
q(current_state, current_y, A) = q(current_state, current_y)q(A)
q(current_state, s, controls, B) = q(current_state, s, controls)q(B)
q(s, future_y, A) = q(s, future_y)q(A)
Now, in order to use this model, we have to define the priors for the model parameters. The Windy Gridworld environment has a 5-by-5 grid, giving 25 states, so we need a prior over a 25-by-25 transition matrix for every control. We will give our agent a control space of 4 actions, so we need 4 such transition matrices, for 25 × 25 × 4 = 2500 transition parameters in total. That's quite a lot of parameters, but as we will see, RxInfer will handle this just fine. Furthermore, we have to convert the positions returned by the environment to one of 25 state indices (which we then one-hot encode), and the controls from one of 4 indices to a direction tuple.
The prior on our observation model encodes the belief that the agent can trust its observations, while still allowing the model to deviate from this belief. However, in this example the observation model is deterministic and noise-free, so our agent has no reason to deviate from the prior.
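The strength of these priors can be read off their means, since the mean of a Dirichlet is its concentration vector divided by its sum. The sketch below uses only the standard library instead of DirichletCollection. It assumes each column of the concentration array parameterizes one Dirichlet distribution. For the symmetric observation prior, normalizing rows instead would give the same numbers.

```julia
using LinearAlgebra

# Concentration parameters of the observation-model prior, mirroring
# diageye(25) .+ 0.1 with the standard-library identity matrix.
α_A = Matrix{Float64}(I, 25, 25) .+ 0.1

# Dirichlet mean: normalize each column by its sum (1.1 + 24 * 0.1 = 3.5).
mean_A = α_A ./ sum(α_A, dims = 1)

mean_A[1, 1]  # ≈ 0.314 (= 1.1 / 3.5): probability of observing the true position
mean_A[2, 1]  # ≈ 0.029 (= 0.1 / 3.5): probability of any particular other position

# Transition-model prior: all-ones concentrations give a uniform mean of 1/25.
α_B = ones(25, 25, 4)
mean_B = α_B ./ sum(α_B, dims = 1)
```

The diagonal is the most likely outcome under the prior, but it is far from certain, so the learned posterior over A still matters.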
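# Dirichlet prior on the observation model, with extra mass on the diagonal
p_A = DirichletCollection(diageye(25) .+ 0.1)
# Uniform Dirichlet prior on the transition model: a 25-by-25 matrix for each of the 4 controls
p_B = DirichletCollection(ones(25, 25, 4))
# Map a grid position (x, y) to a state index in 1:25
function grid_location_to_index(pos::Tuple{Int, Int})
return (pos[2] - 1) * 5 + pos[1]
end
# Inverse of grid_location_to_index: map a state index in 1:25 back to (x, y)
function index_to_grid_location(index::Int)
return ((index - 1) % 5 + 1, (index - 1) ÷ 5 + 1)
end
# One-hot encoding of a state index over the 25 grid cells
function index_to_one_hot(index::Int)
return [i == index ? 1.0 : 0.0 for i in 1:25]
end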
goal = Categorical(index_to_one_hot(grid_location_to_index((4, 3))))
Distributions.Categorical{Float64, Vector{Float64}}(
support: Base.OneTo(25)
p: [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0 … 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
)
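The mapping from positions to state indices must be invertible, otherwise beliefs over states cannot be translated back into grid cells. The standalone check below assumes the layout used by grid_location_to_index, where index = (y - 1) * 5 + x. Its inverse subtracts one before taking the remainder, so that cells in column 5 map back correctly.

```julia
# Standalone copies of the index conversions for a 5-by-5 grid.
grid_location_to_index(pos::Tuple{Int,Int}) = (pos[2] - 1) * 5 + pos[1]
index_to_grid_location(index::Int) = ((index - 1) % 5 + 1, (index - 1) ÷ 5 + 1)

# Every grid cell should survive a round trip through its index.
all_positions = [(x, y) for y in 1:5 for x in 1:5]
round_trip_ok = all(p -> index_to_grid_location(grid_location_to_index(p)) == p, all_positions)

grid_location_to_index((4, 3))  # 14, the index of the goal state
```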
RxEnvironments.jl is a package that allows us to easily communicate between our agent and our environment. We can send actions to the environment, and the environment will automatically respond with the corresponding observations. In order to access these in our inference loop, we subscribe to the observations and then use the data function to access the last observation.
Now for our main control loop, we will use a receding horizon control strategy. At every step we take an action, observe the environment, update our beliefs, and plan the remaining steps, executing only the first planned action at the next step. Each experiment lasts at most T = 4 steps, and the planning horizon shrinks to max(T - t, 1) steps as the experiment progresses. In order to learn the parameters of our model, we repeat this experiment 100 times. We use the infer function from RxInfer to perform inference on our model.
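The receding-horizon idea itself does not depend on active inference. As an illustration, the sketch below replaces the POMDP machinery with a different technique: exhaustive search over move sequences under known, deterministic dynamics, scored by Manhattan distance to the goal. It keeps the two ingredients of the loop below. It replans at every step with horizon max(T - t, 1), and it executes only the first planned move. One difference is that here planning happens before the first move, whereas the loop below starts with a fixed "down" action. The names windy_step, plan_first_move and run_receding_horizon are hypothetical.

```julia
# Hypothetical known dynamics: the same rule as the environment's receive! method.
function windy_step(wind::NTuple{N,Int}, pos::Tuple{Int,Int}, action::Tuple{Int,Int}) where {N}
    new_pos = (pos[1] + action[1], pos[2] + action[2] + wind[pos[1]])
    return all(c -> 0 < c < N, new_pos) ? new_pos : pos
end

const MOVES = ((0, 1), (1, 0), (0, -1), (-1, 0))  # up, right, down, left

# Exhaustive search over all length-h move sequences; return the first move
# of the sequence whose final position is closest (Manhattan) to the goal.
function plan_first_move(wind, pos, goal, h)
    best_cost, best_move = typemax(Int), MOVES[1]
    for seq in Iterators.product(ntuple(_ -> MOVES, h)...)
        p = pos
        for m in seq
            p = windy_step(wind, p, m)
        end
        cost = abs(p[1] - goal[1]) + abs(p[2] - goal[2])
        if cost < best_cost
            best_cost, best_move = cost, seq[1]
        end
    end
    return best_move
end

# Receding horizon loop: replan with horizon max(T - t, 1), execute only the
# first planned move, then "observe" the new position and repeat.
function run_receding_horizon(wind, start, goal, T)
    pos = start
    for t in 1:T
        pos == goal && break
        move = plan_first_move(wind, pos, goal, max(T - t, 1))
        pos = windy_step(wind, pos, move)
    end
    return pos
end

final_pos = run_receding_horizon((0, 1, 1, 1, 0), (1, 1), (4, 3), 4)  # (4, 3)
```

The POMDP agent differs in that it does not know the transition tensor in advance. It plans with the learned mean transition model and a goal prior instead of a hand-written distance cost.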
# Number of times to run the experiment
n_experiments = 100
# Number of steps in each experiment
T = 4
observations = keep(Any)
# Subscribe the agent to receive observations
RxEnvironments.subscribe_to_observations!(agent, observations)
successes = []
@showprogress for i in 1:n_experiments
# Reset environment to initial state and initialize state belief to starting position (1,1)
reset_env!(env)
p_s = Categorical(index_to_one_hot(grid_location_to_index((1, 1))))
# Initialize previous action as "down": from the starting position (1, 1) a downward move is rejected, so the agent stays in place
policy = [Categorical([0.0, 0.0, 1.0, 0.0])]
prev_u = [0.0, 0.0, 1.0, 0.0]
# Run for at most T steps in each experiment (stopping early once the goal is reached)
for t in 1:T
# Convert policy to actual movement in environment
current_action = mode(first(policy))
if current_action == 1
send!(env, agent, (0, 1)) # Move up
prev_u = [1.0, 0.0, 0.0, 0.0]
elseif current_action == 2
send!(env, agent, (1, 0)) # Move right
prev_u = [0.0, 1.0, 0.0, 0.0]
elseif current_action == 3
send!(env, agent, (0, -1)) # Move down
prev_u = [0.0, 0.0, 1.0, 0.0]
elseif current_action == 4
send!(env, agent, (-1, 0)) # Move left
prev_u = [0.0, 0.0, 0.0, 1.0]
end
# Get last observation and convert to one-hot encoding
last_observation = index_to_one_hot(grid_location_to_index(RxEnvironments.data(last(observations))))
# Perform inference using the POMDP model
inference_result = infer(
model = pomdp_model(
p_A = p_A, # prior on observation model parameters
p_B = p_B, # prior on transition model parameters
T = max(T - t, 1), # remaining time steps
p_previous_state = p_s, # posterior belief on previous state
p_goal = goal, # prior on goal state
p_control = vague(Categorical, 4), # prior over controls
m_A = mean(p_A),
m_B = mean(p_B)
),
# Provide data for inference
data = (
previous_control = UnfactorizedData(prev_u),
current_y = UnfactorizedData(last_observation),
future_y = UnfactorizedData(fill(missing, max(T - t, 1)))
),
constraints = constraints,
initialization = init,
iterations = 10
)
# Update beliefs based on inference results
p_s = last(inference_result.posteriors[:current_state]) # Update state belief
policy = last(inference_result.posteriors[:controls]) # Get policy
# Update model parameters globally for the entire notebook
global p_A = last(inference_result.posteriors[:A]) # Update observation model
global p_B = last(inference_result.posteriors[:B]) # Update transition model
if RxEnvironments.data(last(observations)) == (4, 3)
break
end
end
if RxEnvironments.data(last(observations)) == (4, 3)
push!(successes, true)
else
push!(successes, false)
end
end
Now, in this example, we have used a trick: for the predictions about the future that are used to infer the controls, we supplied the model with the means of p_A and p_B (m_A and m_B). We did this because we do not want messages from the future to influence the model parameters; the parameters should be learned from past data only. This is a simple way to achieve that, but it is not the only way. We could have supplied the full distributions p_A and p_B to the model and used A and B in the predictive step as well, but then we would need a separate mechanism to make sure that future messages do not influence the model parameters.
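The following sketch shows what the posterior over B accumulates and why its mean is a reasonable point estimate for planning. Under the conjugate Dirichlet-Categorical update, each fully observed transition adds one count to the corresponding entry of the concentration array. The posterior mean then moves toward the observed transition frequencies, and that mean is what m_B = mean(p_B) hands to the planner. This minimal sketch with plain arrays assumes one-hot, fully observed states. In the notebook the states are inferred and the update is variational, so the added counts are beliefs rather than exact counts.

```julia
# Prior concentration for a single control: 25 next states by 25 current states, all ones.
α = ones(25, 25)

# Hypothetical data: under this control, state 1 was observed to lead to state 2 eight times.
for _ in 1:8
    α[2, 1] += 1.0  # conjugate update: one count for (next = 2, current = 1)
end

# Posterior mean of the Dirichlet over next states, given current state 1.
mean_col = α[:, 1] ./ sum(α[:, 1])
mean_col[2]  # 9 / 33 ≈ 0.273, up from the prior mean 1 / 25 = 0.04
```

With a uniform prior spread over 25 outcomes, even several observations of the same transition leave substantial probability on other states. This is one reason repeated experiments are needed before the mean transition model becomes sharp.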
mean(successes)
0.85
We see that our agent learns a policy for this environment that reaches the goal state in 85% of the 100 experiments!
plot_environment(env)
This example was executed in a clean, isolated environment. Below are the exact package versions used:
For reproducibility:
- Use the same package versions when running locally
- Report any issues with package compatibility
Status `~/work/RxInferExamples.jl/RxInferExamples.jl/docs/src/categories/basic_examples/pomdp_control/Project.toml`
[31c24e10] Distributions v0.25.117
[91a5bcdd] Plots v1.40.9
[92933f4c] ProgressMeter v1.10.2
[5ea003d0] RxEnvironments v0.2.15
[86711068] RxInfer v4.0.1