Contributing

This example was automatically generated from a Jupyter notebook in the RxInferExamples.jl repository.

We welcome and encourage contributions! You can help by:

  • Improving this example
  • Creating new examples
  • Reporting issues or bugs
  • Suggesting enhancements

Visit our GitHub repository to get started. Together we can make RxInfer.jl even better! 💪


Drone Dynamics

Note: These examples demonstrate the use of RxInfer for motion planning. The animations show the inferred trajectories from probabilistic inference, rather than simulated executions. For more realistic simulations, especially in the 3D drone example, the model would need to be extended with a reactive environment that responds to the drone's actions during plan execution. If you're interested in collaborating on a more realistic implementation, please open a discussion and let's work on it together!

using RxInfer, LinearAlgebra

Defining structures

"""
    Environment(; gravitational_constant::Float64 = 9.81)

This structure contains the properties of the environment.
"""
Base.@kwdef struct Environment
    # gravitational acceleration; enters the dynamics as the weight `m * g`
    gravitational_constant::Float64 = 9.81
end

# Accessor for the gravitational acceleration stored in `env`.
function get_gravity(env::Environment)
    return env.gravitational_constant
end
get_gravity (generic function with 1 method)
"""
    Drone(mass, inertia, radius, force_limit)

This structure contains the properties of the drone.
"""
Base.@kwdef struct Drone
    mass::Float64         # used for the weight `m * g` and for linear accelerations
    inertia::Float64      # divides the torque to give the angular acceleration
    radius::Float64       # lever arm of each engine when computing the torque
    force_limit::Float64  # upper bound applied to each engine force
end

# Accessor for the drone mass.
function get_mass(drone::Drone)
    return drone.mass
end

# All drone properties as a tuple `(mass, inertia, radius, force_limit)`.
function get_properties(drone::Drone)
    return (drone.mass, drone.inertia, drone.radius, drone.force_limit)
end
get_properties (generic function with 1 method)
"""
    State(x, y, vx, vy, 𝜃, 𝜔)

This structure contains the state of the drone. It contains the position, velocity, and orientation of the drone.
"""
struct State
    x::Float64   # position
    y::Float64
    vx::Float64  # velocity
    vy::Float64
    𝜃::Float64   # orientation angle
    𝜔::Float64   # angular velocity
end

# The state as a tuple `(x, y, vx, vy, 𝜃, 𝜔)`, in field order.
function get_state(state::State)
    return (state.x, state.y, state.vx, state.vy, state.𝜃, state.𝜔)
end
get_state (generic function with 1 method)

Model specification

"""
    state_transition(state, actions, drone, environment, dt)

This function computes the next state of the drone given the current state, the actions, the drone properties and the environment properties.
"""
function state_transition(state, actions, drone::Drone, environment::Environment, dt)
    mass, inertia, arm, force_limit = get_properties(drone)
    gravity = get_gravity(environment)

    # engines can only push, and never harder than the drone's force limit
    left, right = clamp.(actions, 0, force_limit)

    x, y, vx, vy, θ, ω = state

    # net force: thrust along the tilted body axis, weight straight down
    weight = mass * gravity
    force_y = (left + right) * cos(θ) - weight
    force_x = (left + right) * sin(θ)

    # unequal engine forces rotate the drone
    torque = (left - right) * arm

    ax = force_x / mass
    ay = force_y / mass
    α = torque / inertia

    # constant-acceleration update over one step of length `dt`
    advance(position, rate, acceleration) = position + rate * dt + acceleration * dt^2 / 2

    next_x = advance(x, vx, ax)
    next_y = advance(y, vy, ay)
    next_vx = vx + ax * dt
    next_vy = vy + ay * dt
    next_θ = advance(θ, ω, α)
    next_ω = ω + α * dt

    return [next_x, next_y, next_vx, next_vy, next_θ, next_ω]
end
Main.anonymous.state_transition
@model function drone_model(drone, environment, initial_state, goal, horizon, dt)
    # force each of the two engines must deliver to cancel the drone's weight
    hover = get_mass(drone) * get_gravity(environment) / 2

    # the first state is pinned (tightly) to the observed initial state
    s[1] ~ MvNormal(mean = initial_state, covariance = 1e-5 * I)

    for k in 1:horizon
        # actions are expected to stay close to the hover forces
        u[k] ~ MvNormal(μ = [hover, hover], Σ = diageye(2))

        # the next state follows the (almost noise-free) drone dynamics
        s[k + 1] ~ MvNormal(
            μ = state_transition(s[k], u[k], drone, environment, dt),
            Σ = 1e-10 * I
        )
    end

    # the last state is pinned (tightly) to the observed goal
    s[end] ~ MvNormal(mean = goal, covariance = 1e-5 * diageye(6))
end

Probabilistic inference

@meta function drone_meta()

	# `state_transition` is a nonlinear function of the state and the actions;
	# approximate the state transition function using the Unscented transform
	state_transition() -> Unscented()

end
drone_meta (generic function with 1 method)
# Plan a trajectory from `start` to the position `target` over `horizon` steps of
# length `dt`, and return the inference result (posteriors over states `s` and
# actions `u`).
function move_to_target(drone::Drone, env::Environment, start::State, target, horizon, dt)
    # observed quantities: the current state and the goal state, i.e. the target
    # position with zero velocity, angle and angular velocity
    observations = (
        initial_state = collect(get_state(start)),
        goal = [target[1], target[2], 0, 0, 0, 0],
    )

    planner = drone_model(
        drone = drone,
        environment = env,
        horizon = horizon,
        dt = dt
    )

    return infer(
        model = planner,
        data = observations,
        meta = drone_meta(),
        returnvars = (s = KeepLast(), u = KeepLast())
    )
end
move_to_target (generic function with 1 method)
# Planar drone for the 2D example; the integer literals are converted to the
# `Float64` fields of `Drone` by its constructor.
drone = Drone(
    mass = 1,
    inertia = 1,
    radius = 0.2,
    force_limit = 15.0
)

# Environment with the default gravitational constant.
env = Environment()

# Start at the origin with all velocities and the angle set to zero.
start = State(0.0, 0.0, 0.0, 0.0, 0.0, 0.0)

# Target position (x, y); `move_to_target` pads it with zeros to form the goal state.
target = [-0.8, 0.6]

# Plan over a horizon of 40 steps with a time step of 0.05.
results = move_to_target(drone, env, start, target, 40, 0.05)
Inference results:
  Posteriors       | available for (s, u)

Plotting

using Plots
# Draw the drone (body marker, two engine markers and the connecting bar) at the
# pose given by `state` onto the plot `p`, and return `p`.
function plot_drone!(p, drone::Drone, state::State; color = :black)
    x, y, _, _, θ, _ = get_state(state)
    _, _, radius, _ = get_properties(drone)

    # offsets of the engines from the drone centre
    dx = radius * cos(θ)
    dy = radius * sin(θ)

    # left engine, centre, right engine
    xs = [ x - dx, x, x + dx ]
    ys = [ y + dy, y, y - dy ]

    # marker outlines, rotated by -θ so they tilt together with the drone
    c, s = cos(-θ), sin(-θ)
    rotation = [ c -s; s c ]
    engine_outline = rotation * [ -1 0 1; 1 -1 1 ]
    body_outline   = rotation * [ -2 -2 2 2 ; -1 1 1 -1 ]
    engine_marker  = Shape(engine_outline[1, :], engine_outline[2, :])
    body_marker    = Shape(body_outline[1, :], body_outline[2, :])

    scatter!(p, [ x ], [ y ]; color = color, label = false, marker = body_marker)
    scatter!(p, xs[[1, 3]], ys[[1, 3]]; color = color, label = false, marker = engine_marker, ms = 10)
    plot!(p, (xs, ys); color = color, label = false)

    return p
end
plot_drone! (generic function with 1 method)
# Animate the posterior mean trajectory stored in `results` and save it as
# "drone.gif". Returns `nothing`.
function animate_drone(drone::Drone, target, results::InferenceResult)
    # posterior means of the states, one column per time step
    trajectory = hcat(mean.(results.posteriors[:s])...)

    frames = @animate for step in axes(trajectory, 2)
        # target first, then the drone on top of it
        canvas = scatter([target[1]], [target[2]], label = "target"; color = :red)
        plot_drone!(canvas, drone, State(trajectory[:, step]...))

        # fixed axes so the view does not jump between frames
        xlims!(-1.5, 1.5)
        ylims!(-1.5, 1.5)
    end

    gif(frames, "drone.gif", show_msg = false)

    return nothing
end
animate_drone (generic function with 1 method)
# Render the planned 2D trajectory to "drone.gif".
animate_drone(drone, target, results)

let
    # Posterior mean and marginal standard deviation of the angle (5th state component).
    # The standard deviation is taken from the per-component variance `var(p)`:
    # indexing `std(p)` directly is unreliable for a multivariate posterior, because
    # `std` may return a matrix (a square root of the covariance) rather than a
    # vector of marginal standard deviations — in which case `std(p)[5]` would pick
    # an unrelated matrix entry.
    inferred_angle_mean = map(p -> mean(p)[5], results.posteriors[:s])
    inferred_angle_std  = map(p -> sqrt(var(p)[5]), results.posteriors[:s])
    plot(inferred_angle_mean; ribbon = inferred_angle_std, fillalpha = 0.2, label = "inferred angle", size=(600,300))
end

let
    force_posteriors = results.posteriors[:u]

    # one row per time step, one column per engine
    force_mean = hcat(mean.(force_posteriors)...)'
    force_std  = hcat(map(p -> sqrt.(var(p)), force_posteriors)...)'

    # force per engine that exactly balances the drone's weight
    hover = get_mass(drone) * get_gravity(env) / 2

    plot(force_mean[:,1]; ribbon = force_std[:,1], fillalpha = 0.2, label = "Fl", size=(600,300))
    plot!(force_mean[:,2]; ribbon = force_std[:,2], fillalpha = 0.2, label = "Fr", size=(600,300))
    hline!([hover], label = "Fg/2")
end

3D Drone Extension

This section extends our 2D drone model into three-dimensional space, allowing for full spatial navigation. The model includes:

  • 6 degrees of freedom (position and orientation)
  • Four-motor configuration
  • Basic rigid-body forces (motor thrust and gravity; no aerodynamic drag)

Note: This implementation is a simplified model intended for educational purposes. While it captures the fundamental dynamics of a quadcopter, it omits advanced aerodynamic effects and motor dynamics for clarity.

# Extended Drone structure for 4 motors.
# Constructed with keyword arguments (`Base.@kwdef`); no field has a default.
Base.@kwdef struct Drone3D
    mass::Float64             # used for linear accelerations and the hover force
    inertia::Matrix{Float64}  # 3x3 inertia matrix
    radius::Float64           # scales the yaw torque in `state_transition_3d`
    arm_length::Float64       # scales the roll and pitch torques; also the drawn arm length
    force_limit::Float64      # upper bound applied to each motor force
end
Main.anonymous.Drone3D
# All properties of the 3D drone as a tuple
# `(mass, inertia, radius, arm_length, force_limit)`.
get_properties(drone::Drone3D) =
    (drone.mass, drone.inertia, drone.radius, drone.arm_length, drone.force_limit)
get_properties (generic function with 2 methods)
# State of the 3D drone: position, velocity, orientation angles and angular velocity.
struct State3D
    # position
    x::Float64
    y::Float64
    z::Float64
    # velocity
    vx::Float64
    vy::Float64
    vz::Float64
    # orientation: roll, pitch, yaw
    ϕ::Float64
    θ::Float64
    ψ::Float64
    # angular velocity
    ωx::Float64
    ωy::Float64
    ωz::Float64
end

# The state as a 12-tuple, in field order.
get_state(state::State3D) = (
    state.x, state.y, state.z,
    state.vx, state.vy, state.vz,
    state.ϕ, state.θ, state.ψ,
    state.ωx, state.ωy, state.ωz,
)
get_state (generic function with 2 methods)
"""
    rotation_matrix(ψ, θ, ϕ)

Create a 3D rotation matrix from yaw (ψ), pitch (θ), and roll (ϕ) angles.
"""
function rotation_matrix(ψ, θ, ϕ)
    cψ, sψ = cos(ψ), sin(ψ)
    cθ, sθ = cos(θ), sin(θ)
    cϕ, sϕ = cos(ϕ), sin(ϕ)

    # elementary rotations about the z (yaw), y (pitch) and x (roll) axes
    yaw   = [cψ -sψ 0; sψ cψ 0; 0 0 1]
    pitch = [cθ 0 sθ; 0 1 0; -sθ 0 cθ]
    roll  = [1 0 0; 0 cϕ -sϕ; 0 sϕ cϕ]

    # composed in ZYX order
    return yaw * pitch * roll
end
Main.anonymous.rotation_matrix
"""
    state_transition_3d(state, actions, drone, environment, dt)

Compute the next state of the 3D drone given current state and four motor forces.
"""
function state_transition_3d(state, actions, drone::Drone3D, environment::Environment, dt)
    mass, inertia, r, arm, limit = get_properties(drone)
    g = get_gravity(environment)

    # motors can only push, and never harder than the force limit
    F1, F2, F3, F4 = clamp.(actions, 0, limit)

    x, y, z, vx, vy, vz, ϕ, θ, ψ, ωx, ωy, ωz = state

    # total thrust acts along the body z-axis; rotate it into the world frame
    thrust_world = rotation_matrix(ψ, θ, ϕ) * [0, 0, sum([F1, F2, F3, F4])]

    # linear accelerations (gravity only acts on z)
    ax = thrust_world[1] / mass
    ay = thrust_world[2] / mass
    az = thrust_world[3] / mass - g

    # roll, pitch and yaw torques produced by the motor force differences
    torque = [arm * (F2 - F4), arm * (F1 - F3), (F1 + F3 - F2 - F4) * r]

    # angular accelerations from the rigid-body rotation equation
    ω = [ωx, ωy, ωz]
    α = inertia \ (torque - cross(ω, inertia * ω))

    # constant-acceleration update over one step of length `dt`
    advance(value, rate, acceleration) = value + rate * dt + acceleration * dt^2 / 2

    return [
        advance(x, vx, ax), advance(y, vy, ay), advance(z, vz, az),
        vx + ax * dt, vy + ay * dt, vz + az * dt,
        advance(ϕ, ωx, α[1]), advance(θ, ωy, α[2]), advance(ψ, ωz, α[3]),
        ωx + α[1] * dt, ωy + α[2] * dt, ωz + α[3] * dt
    ]
end
Main.anonymous.state_transition_3d
@model function drone_model_3d(drone, environment, initial_state, goal, horizon, dt)
    # force each of the four motors must deliver to cancel the drone's weight
    hover = drone.mass * get_gravity(environment) / 4

    # the first state is pinned (tightly) to the observed initial state
    s[1] ~ MvNormal(mean = initial_state, covariance = 1e-5 * I)

    for k in 1:horizon
        # motor actions are expected to stay close to the hover forces
        u[k] ~ MvNormal(μ = [hover, hover, hover, hover], Σ = diageye(4))

        # the next state follows the (almost noise-free) drone dynamics
        s[k + 1] ~ MvNormal(
            μ = state_transition_3d(s[k], u[k], drone, environment, dt),
            Σ = 1e-10 * I
        )
    end

    # the last state is pinned (tightly) to the observed goal
    s[end] ~ MvNormal(mean = goal, covariance = 1e-5 * diageye(12))
end
@meta function drone_meta_3d()
    # `state_transition_3d` is nonlinear; approximate it with the Unscented transform
    state_transition_3d() -> Unscented()
end
drone_meta_3d (generic function with 1 method)
# Plan a 3D trajectory from `start` to the position `target` over `horizon` steps
# of length `dt`, and return the inference result (posteriors over `s` and `u`).
function move_to_target_3d(drone::Drone3D, env::Environment, start::State3D, target, horizon, dt)
    # observed quantities: the current state and the goal state, i.e. the target
    # position with all velocities, angles and angular velocities equal to zero
    observations = (
        initial_state = collect(get_state(start)),
        goal = [target[1], target[2], target[3], 0, 0, 0, 0, 0, 0, 0, 0, 0],
    )

    planner = drone_model_3d(
        drone = drone,
        environment = env,
        horizon = horizon,
        dt = dt
    )

    return infer(
        model = planner,
        data = observations,
        meta = drone_meta_3d(),
        returnvars = (s = KeepLast(), u = KeepLast())
    )
end
move_to_target_3d (generic function with 1 method)
"""
    move_through_waypoints(drone, env, start, waypoints, steps_per_segment=40, dt=0.05)

Plan a flight that visits `waypoints` in order, one inference problem per waypoint.

Each segment is planned with `move_to_target_3d`, starting from the posterior mean
of the final state of the previous segment (or from `start` for the first one).
A progress line is printed for every waypoint.

Returns `(combined_states, waypoints)`, where `combined_states` is a matrix with one
column per planned state (rows in the field order of `State3D`) and the segments
concatenated column-wise. If `waypoints` is empty, `combined_states` contains the
single column `start`, so the result can still be animated.
"""
function move_through_waypoints(drone::Drone3D, env::Environment, start::State3D, waypoints, steps_per_segment=40, dt=0.05)
    current_state = start
    segment_states = Matrix{Float64}[]

    for (i, target) in enumerate(waypoints)
        println("Moving to waypoint $i: $target")

        results = move_to_target_3d(drone, env, current_state, target, steps_per_segment, dt)

        # posterior means of this segment's states, one column per time step
        states = hcat(map(mean, results.posteriors[:s])...)
        push!(segment_states, states)

        # the next segment starts where this one ended
        current_state = State3D(states[:, end]...)
    end

    # without waypoints there is nothing to plan: report the start state only
    combined_states = if isempty(segment_states)
        reshape(collect(get_state(start)), :, 1)
    else
        reduce(hcat, segment_states)
    end

    return combined_states, waypoints
end
move_through_waypoints (generic function with 3 methods)
# Draw the 3D drone (centre, four arms and a motor at each arm tip) at the pose
# given by `state` onto the plot `p`. Returns `nothing`.
function plot_drone_3d!(p, drone::Drone3D, state::State3D; color=:black)
    x, y, z, _, _, _, ϕ, θ, ψ, _, _, _ = get_state(state)
    _, _, _, arm_length, _ = get_properties(drone)

    center = [x, y, z]
    R = rotation_matrix(ψ, θ, ϕ)

    # arm tips in the body frame: one arm along each of +x, +y, -x and -y
    body_tips = (
        [arm_length, 0, 0],
        [0, arm_length, 0],
        [-arm_length, 0, 0],
        [0, -arm_length, 0],
    )

    # drone centre
    scatter!(p, [x], [y], [z], color=color, label=false, markersize=5)

    for tip in body_tips
        # rotate into the world frame and move to the drone position
        wx, wy, wz = R * tip + center

        # arm, then the motor at its tip
        plot!(p, [x, wx], [y, wy], [z, wz],
              color=color, label=false, linewidth=2)
        scatter!(p, [wx], [wy], [wz],
                color=color, label=false, markersize=3)
    end

    return nothing
end
plot_drone_3d! (generic function with 1 method)
"""
    animate_drone_3d_multi(drone, states, targets; fps=30)

Animate a multi-waypoint flight and save it as `drone_3d_multi.gif`.

`states` holds one column per frame, with rows in the field order of `State3D`;
`targets` is the list of waypoints, each indexable as `[x, y, z]`. Every frame shows
the waypoints (connected by dashed lines), the drone drawn with `plot_drone_3d!`,
a trace of up to the last 100 positions and a decorative rain effect.

Returns `nothing`.
"""
function animate_drone_3d_multi(drone::Drone3D, states, targets; fps=30)
    # Note: The rain animation is purely for visualization aesthetics
    # and does not affect the drone's dynamics or trajectory planning

    # Create `n` random rain streaks as (x, y, z_top, z_bottom)
    function generate_raindrops(n=50)
        x = 4 * rand(n) .- 2  # range [-2, 2]
        y = 4 * rand(n) .- 2
        z1 = 2 .+ 4 * rand(n)  # start higher up to have some offscreen
        z2 = z1 .- 0.3  # fixed length rain streaks
        return (x, y, z1, z2)
    end

    # Initialize raindrops
    raindrops = generate_raindrops()

    animation = @animate for k in 1:size(states,2)
        # Update raindrop positions
        fall_speed = 0.1
        z1 = raindrops[3] .- fall_speed
        z2 = raindrops[4] .- fall_speed

        # Regenerate raindrops that have fallen below view
        below_view = findall(z2 .<= -2)
        if !isempty(below_view)
            new_drops = generate_raindrops(length(below_view))
            raindrops[1][below_view] = new_drops[1]
            raindrops[2][below_view] = new_drops[2]
            z1[below_view] = new_drops[3]
            z2[below_view] = new_drops[4]
        end

        # Update raindrops state
        raindrops = (raindrops[1], raindrops[2], z1, z2)

        p = plot3d(
            xlims=(-2, 2), ylims=(-2, 2), zlims=(-2, 2),
            xlabel="X", ylabel="Y", zlabel="Z",
            camera=(45, 30),
            title="Multi-Waypoint Drone Flight",
            background=:white
        )

        # Draw rain streaks
        for i in eachindex(raindrops[1])
            if raindrops[4][i] > -2  # only draw if in view
                plot!(p, [raindrops[1][i], raindrops[1][i]],
                        [raindrops[2][i], raindrops[2][i]],
                        [raindrops[3][i], raindrops[4][i]],
                     color=:grey,
                     linestyle=:dash,
                     alpha=0.6,
                     legend=false,
                     linewidth=1)
            end
        end

        # Plot all targets
        for (i, target) in enumerate(targets)
            scatter!(p, [target[1]], [target[2]], [target[3]],
                    label=i == 1 ? "waypoints" : false,
                    color=:red,
                    markersize=i == 1 ? 5 : 3)

            # Connect waypoints with lines
            if i > 1
                prev_target = targets[i-1]
                plot!(p, [prev_target[1], target[1]],
                         [prev_target[2], target[2]],
                         [prev_target[3], target[3]],
                     color=:red, linestyle=:dash, label=false, linewidth=1)
            end
        end

        # Plot the drone that was passed in (not a global drone instance)
        plot_drone_3d!(p, drone, State3D(states[:, k]...))

        # Add trajectory trace (last 100 points)
        trace_start = max(1, k-100)
        if k > 1
            plot!(p, states[1,trace_start:k], states[2,trace_start:k], states[3,trace_start:k],
                  color=:blue, label=false, linewidth=1, linealpha=0.5)
        end
    end

    gif(animation, "drone_3d_multi.gif", fps=fps, show_msg = false)

    nothing
end
animate_drone_3d_multi (generic function with 1 method)
# Create drone instance
drone_3d = Drone3D(
    mass = 1.0,
    inertia = diagm([0.1, 0.1, 0.15]),  # 3×3 diagonal inertia matrix
    radius = 0.1,
    arm_length = 0.2,
    force_limit = 15.0
)

# Environment with the default gravitational constant
env = Environment()

# Initial state: at the origin, at rest and level.
# Note: this rebinds the global `start`, which held a 2D `State` earlier in the file.
start = State3D(
    0.0, 0.0, 0.0,  # position (x, y, z)
    0.0, 0.0, 0.0,  # velocity (vx, vy, vz)
    0.0, 0.0, 0.0,  # orientation (ϕ, θ, ψ)
    0.0, 0.0, 0.0   # angular velocity (ωx, ωy, ωz)
)

# Define a sequence of waypoints for a square pattern at height z = 1,
# followed by a return to the origin
waypoints = [
    [1.0, 1.0, 1.0],    # Front-right corner
    [-1.0, 1.0, 1.0],   # Front-left corner
    [-1.0, -1.0, 1.0],  # Back-left corner
    [1.0, -1.0, 1.0],   # Back-right corner
    [0.0, 0.0, 0.0]     # Land at center
]

# Plan the flight through all waypoints (default 40 steps per segment, dt = 0.05);
# `targets` is the waypoint list handed back unchanged
combined_states, targets = move_through_waypoints(drone_3d, env, start, waypoints);
Moving to waypoint 1: [1.0, 1.0, 1.0]
Moving to waypoint 2: [-1.0, 1.0, 1.0]
Moving to waypoint 3: [-1.0, -1.0, 1.0]
Moving to waypoint 4: [1.0, -1.0, 1.0]
Moving to waypoint 5: [0.0, 0.0, 0.0]
# Render the planned multi-waypoint flight to "drone_3d_multi.gif".
animate_drone_3d_multi(drone_3d, combined_states, targets)


Contributing

This example was automatically generated from a Jupyter notebook in the RxInferExamples.jl repository.

We welcome and encourage contributions! You can help by:

  • Improving this example
  • Creating new examples
  • Reporting issues or bugs
  • Suggesting enhancements

Visit our GitHub repository to get started. Together we can make RxInfer.jl even better! 💪


Environment

This example was executed in a clean, isolated environment. Below are the exact package versions used:

For reproducibility:

  • Use the same package versions when running locally
  • Report any issues with package compatibility
Status `~/work/RxInferExamples.jl/RxInferExamples.jl/docs/src/categories/advanced_examples/drone_dynamics/Project.toml`
  [91a5bcdd] Plots v1.40.9
  [86711068] RxInfer v4.0.1
  [37e2e46d] LinearAlgebra v1.11.0