prism-probabilistic-intention-switching

PRISM (Probabilistic Recurrent Intention Switching Model) methodology for multi-intention inverse reinforcement learning. Uses a lightweight recurrent network for intention switching with a closed-form EM solution. Activation: 多意图 IRL (multi-intention IRL), intention switching, PRISM, 目标切换 (goal switching), recurrent intention, EM algorithm.

By hiyenwong · Updated 6/4/2026

PRISM: Probabilistic Recurrent Intention Switching Model

Research Source

Title: Probabilistic Recurrent Intention Switching Model

arXiv ID: 2605.26998

Authors: Wenyuan Sheng, Hao Zhu, Joschka Boedecker

Published: May 26, 2026

Categories: Machine Learning (cs.LG); Neurons and Cognition (q-bio.NC)

Link: https://arxiv.org/abs/2605.26998

Core Innovation

PRISM models intention switching in inverse reinforcement learning (IRL) with a lightweight recurrent network, replacing the Markov-chain and fixed-history-window switching models used by earlier approaches. Key result: the EM objective decomposes exactly into independent per-intention subproblems, each solved in closed form without variational approximation.

Problem Statement

Traditional IRL limitations:

  • ❌ Assumes a single stationary reward
  • ❌ Cannot capture goal switching within an episode
  • ❌ Existing multi-intention IRL methods model switching with a Markov chain (memoryless) or a fixed history window (limited context)

PRISM solution:

  • ✅ Recurrent network maps observation history to intention distribution
  • ✅ Captures non-Markovian intention transitions
  • ✅ Exact EM decomposition with closed-form solutions

Mathematical Framework

Model Architecture

PRISM Architecture:

Input: Observation sequence o₁, o₂, ..., o_T
Recurrent Network: h_t = RNN(o_t, h_{t-1})
Intention Distribution: π_t = Softmax(W h_t), where π_t(i) is the probability of intention i at step t
Reward Function: r_i(s, a) for each intention i
Behavior: a_t drawn from the policy P(a | s, r_i) of the active intention i

Key Innovation: π_t depends on the FULL observation history through the RNN, not just the current state or a fixed window
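The architecture can also be written compactly in equation form. This is a sketch in the document's own notation; the latent intention variable z_t and the number of intentions K are notational assumptions, chosen to agree with the likelihood used in the proof sketch.

```latex
\begin{aligned}
h_t &= \mathrm{RNN}(o_t, h_{t-1}), \\
\pi_t(i) &= \Pr(z_t = i \mid o_{1:t}) = \mathrm{softmax}(W h_t)_i, \\
P(a_t \mid s_t, o_{1:t}) &= \sum_{i=1}^{K} \pi_t(i)\, P(a_t \mid s_t, r_i).
\end{aligned}
```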

EM Algorithm Decomposition

Main Result: The EM objective decomposes exactly into independent per-intention subproblems.

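E-step:
  γ_t(i) ∝ π_t(i) * P(a_t | s_t, r_i)   (posterior weight of intention i at step t, normalized over i)
  
M-step:
  For each intention i independently:
    r_i* = argmax_r Σ_t γ_t(i) * log P(a_t | s_t, r)
  
  Each per-intention subproblem has a closed-form solution for the reward.

Proof Sketch

  1. Likelihood Function:

    L = Σ_t log Σ_i π_t(i) * P(a_t | s_t, r_i)
    
  2. E-step (Expectation):

    γ_t(i) = π_t(i) * P(a_t | s_t, r_i) / Σ_j π_t(j) * P(a_t | s_t, r_j)
    Q(r_1, ..., r_K) = Σ_i Q_i(r_i) + terms that do not depend on the rewards
    Q_i(r_i) = Σ_t γ_t(i) * log P(a_t | s_t, r_i)
    
  3. M-step (Maximization):

    Key observation: Q_i depends ONLY on intention i's reward
    → Each intention can be optimized independently
    → Closed-form solution: r_i = estimate_reward(D_i), where D_i is the data weighted by γ_t(i)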
    
  4. No Variational Approximation:

    • Standard multi-intention IRL requires variational approximation
    • PRISM achieves exact decomposition through recurrent network structure
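For the mixture likelihood in step 1, a standard EM E-step weights each time step by the posterior probability of each intention. The following is a minimal NumPy sketch of that computation, not the paper's code: the function name `responsibilities` and the (T, K) array layout are assumptions made for illustration.

```python
import numpy as np

def responsibilities(prior, action_loglik):
    """Posterior intention weights for one trajectory.

    prior:         (T, K) intention distribution pi_t(i), entries > 0
    action_loglik: (T, K) log P(a_t | s_t, r_i) under each intention
    returns:       (T, K) weights, each row summing to 1
    """
    log_joint = np.log(prior) + action_loglik
    # Subtract the row maximum before exponentiating for numerical stability
    log_joint = log_joint - log_joint.max(axis=1, keepdims=True)
    joint = np.exp(log_joint)
    return joint / joint.sum(axis=1, keepdims=True)

# Toy numbers, purely illustrative
prior = np.array([[0.5, 0.5], [0.9, 0.1]])
action_loglik = np.log(np.array([[0.8, 0.2], [0.1, 0.9]]))
gamma = responsibilities(prior, action_loglik)
```

In the second row the strong prior for intention 0 is exactly offset by an action that is far more likely under intention 1, so both intentions receive equal weight.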

Implementation

Algorithm

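import torch
import torch.nn as nn

class RewardFunction(nn.Module):
    """Assumed minimal reward model (not specified in the source): linear in [s; a]."""

    def __init__(self, state_dim, action_dim):
        super().__init__()
        self.linear = nn.Linear(state_dim + action_dim, 1, bias=False)

    def forward(self, state, action):
        return self.linear(torch.cat([state, action], dim=-1)).squeeze(-1)
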
class PRISM(nn.Module):
    """
    Probabilistic Recurrent Intention Switching Model
    
    Args:
        obs_dim: observation dimension
        hidden_dim: RNN hidden dimension  
        num_intentions: number of intentions/goals
        state_dim: state dimension
        action_dim: action dimension
    """
    
    def __init__(self, obs_dim, hidden_dim, num_intentions, state_dim, action_dim):
        super().__init__()
        self.num_intentions = num_intentions  # used by e_step and m_step
        
        # Recurrent network for intention inference
        self.rnn = nn.GRU(obs_dim, hidden_dim, batch_first=True)
        
        # Intention distribution layer
        self.intention_layer = nn.Linear(hidden_dim, num_intentions)
        
        # Reward functions for each intention (tabular or parametric)
        self.reward_functions = nn.ModuleList([
            RewardFunction(state_dim, action_dim) 
            for _ in range(num_intentions)
        ])
        
    def forward(self, observations):
        """
        Infer intention distribution from observation history
        
        Args:
            observations: (batch, T, obs_dim)
        
        Returns:
            intention_dist: (batch, T, num_intentions)
        """
        # Run RNN over full observation history
        h_seq, _ = self.rnn(observations)
        
        # Compute intention distribution at each time step
        intention_logits = self.intention_layer(h_seq)
        intention_dist = torch.softmax(intention_logits, dim=-1)
        
        return intention_dist
    
    def e_step(self, trajectories, intention_dist):
        """
        Expectation step: compute the intention-weighted statistics for each intention
        
        Args:
            trajectories: list of (s, a) pairs
            intention_dist: (batch, T, num_intentions)
        
        Returns:
            Q: dict mapping each intention index to its weighted statistics
        """
        Q = {}
        for i in range(self.num_intentions):
            # compute_expected_reward is a placeholder; this skill does not define it
            Q[i] = self.compute_expected_reward(trajectories, intention_dist, i)
        return Q
    
    def m_step(self, Q):
        """
        Maximization step: update rewards (closed-form)
        
        Each intention optimized independently - no coupling!
        
        Args:
            Q: expected rewards from E-step
        
        Returns:
            updated_rewards: dict of reward parameters
        """
        updated_rewards = {}
        for i in range(self.num_intentions):
            # Closed-form solution for intention i (module-level function, see
            # "Closed-Form Reward Estimation")
            updated_rewards[i] = solve_reward_closed_form(Q[i])
        
        return updated_rewards
    
    def em_train(self, trajectories, num_iterations=100):
        """
        EM training loop
        
        Args:
            trajectories: list of trajectories
            num_iterations: EM iterations
        """
        # extract_observations, update_rewards, and compute_log_likelihood are
        # placeholders; this skill does not define them.
        observations = extract_observations(trajectories)  # does not change across iterations
        
        for iteration in range(num_iterations):
            # Infer intention distribution (this loop takes no gradient step on the RNN)
            with torch.no_grad():
                intention_dist = self(observations)
            
            # E-step
            Q = self.e_step(trajectories, intention_dist)
            
            # M-step (closed-form, no variational approximation)
            new_rewards = self.m_step(Q)
            
            # Update reward parameters
            self.update_rewards(new_rewards)
            
            # Compute log-likelihood
            ll = self.compute_log_likelihood(trajectories, intention_dist)
            print(f"Iteration {iteration}, Log-likelihood: {ll}")
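The training loop reports a log-likelihood but leaves its computation undefined. One way to evaluate the mixture likelihood L = Σ_t log Σ_i π_t(i) P(a_t | s_t, r_i) is sketched below in NumPy, kept framework-agnostic on purpose. The function name `mixture_log_likelihood` and the (T, K) inputs are assumptions, and how log P(a_t | s_t, r_i) is obtained depends on the policy model, which this document does not specify.

```python
import numpy as np

def mixture_log_likelihood(prior, action_loglik):
    """Sum over time of log sum_i pi_t(i) * P(a_t | s_t, r_i).

    prior:         (T, K) intention distribution, entries > 0
    action_loglik: (T, K) per-intention action log-likelihoods
    """
    log_joint = np.log(prior) + action_loglik
    # log-sum-exp over intentions, stabilized by the row maximum
    row_max = log_joint.max(axis=1, keepdims=True)
    per_step = row_max[:, 0] + np.log(np.exp(log_joint - row_max).sum(axis=1))
    return float(per_step.sum())

# Toy numbers, purely illustrative
prior = np.array([[0.5, 0.5], [0.9, 0.1]])
action_loglik = np.log(np.array([[0.8, 0.2], [0.1, 0.9]]))
ll = mixture_log_likelihood(prior, action_loglik)
```

Tracking this value across iterations is a useful sanity check, since an exact EM update should never decrease it.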

Closed-Form Reward Estimation

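import numpy as np

def feature_vector(s, a):
    """Assumed feature map φ(s, a): concatenation of state and action (not specified in the source)."""
    return np.concatenate([np.atleast_1d(s), np.atleast_1d(a)]).astype(float)

def solve_reward_closed_form(Q_i, regularization=0.1):
    """
    Closed-form solution for reward function
    
    For linear reward: r(s,a) = θ·φ(s,a)
    
    Args:
        Q_i: iterable of ((s, a), weight) pairs for intention i
        regularization: L2 regularization strength
    
    Returns:
        theta: reward parameters (closed-form solution)
    """
    # Collect features and weights
    features = []
    weights = []
    
    for (s, a), prob in Q_i:
        phi = feature_vector(s, a)  # e.g., (s, a) → φ
        features.append(phi)
        weights.append(prob)
    
    F = np.array(features)               # (N, d) feature matrix
    w = np.array(weights, dtype=float)   # (N,) per-sample weights, the diagonal of W
    
    # Regression targets y. The original uses a vector of ones as a stand-in;
    # replace it with the targets implied by the chosen reward model.
    y = np.ones(len(w))
    
    # Closed-form solution (weighted least squares with regularization)
    # θ = (F^T W F + λI)^{-1} F^T W y, with W = diag(w)
    A = F.T @ (w[:, None] * F) + regularization * np.eye(F.shape[1])
    b = F.T @ (w * y)
    theta = np.linalg.solve(A, b)
    
    return theta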
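The closed form θ = (FᵀWF + λI)⁻¹FᵀWy is the minimizer of the weighted ridge objective ½ Σ_n w_n (φ_nᵀθ − y_n)² + ½ λ‖θ‖². The sketch below checks this numerically on synthetic data by confirming that the gradient vanishes at the returned θ. The helper name `weighted_ridge` and the random data are assumptions for illustration only; they say nothing about which targets y the paper uses.

```python
import numpy as np

def weighted_ridge(F, w, y, lam=0.1):
    """Solve (F^T W F + lam I) theta = F^T W y with W = diag(w)."""
    A = F.T @ (w[:, None] * F) + lam * np.eye(F.shape[1])
    b = F.T @ (w * y)
    return np.linalg.solve(A, b)

# Synthetic regression problem (illustrative only)
rng = np.random.default_rng(0)
F = rng.normal(size=(50, 3))                  # feature matrix
w = rng.uniform(0.1, 1.0, size=50)            # per-sample weights
y = F @ np.array([1.0, -2.0, 0.5]) + 0.01 * rng.normal(size=50)

lam = 0.1
theta = weighted_ridge(F, w, y, lam=lam)

# Gradient of the objective at theta; it should be numerically zero
grad = F.T @ (w * (F @ theta - y)) + lam * theta
```

Scaling the rows of F by the weights avoids building an N×N diagonal matrix, which matters when N is the number of time steps in a large dataset.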

Experimental Validation

Testbeds

| Environment | Type | Key Challenge | Result |
|---|---|---|---|
| Non-Markovian Gridworld | Synthetic | Hidden intention states | Highest log-likelihood |
| Mouse Labyrinth | Biological simulation | Temporal goal coherence | Recovered coherent intentions |
| BridgeData V2 | Robotic manipulation | Large-scale real-world | First large-scale multi-intention IRL |

Performance Metrics

  1. Held-out Log-Likelihood:

    • PRISM > Markov chain baseline
    • PRISM > Fixed history window baseline
  2. Intention Recovery Quality:

    • Recovered intentions are nameable and temporally coherent
    • Human inspection confirms meaningful goal segments
  3. Application Scale:

    • First large-scale robotic application of multi-intention IRL
    • BridgeData V2: real-world manipulation demonstrations
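Temporal coherence can be inspected by collapsing the most likely intention at each step into contiguous segments. The sketch below does this with the standard library. The function name `intention_segments` and the mean segment length are illustrative assumptions, not metrics reported in the paper.

```python
from itertools import groupby

def intention_segments(labels):
    """Collapse per-step intention labels into (intention, start, length) runs."""
    segments = []
    start = 0
    for intention, run in groupby(labels):
        length = len(list(run))
        segments.append((intention, start, length))
        start += length
    return segments

# Toy label sequence, e.g. the argmax of the intention distribution at each step
labels = [0, 0, 0, 2, 2, 1, 1, 1, 1]
segments = intention_segments(labels)
mean_length = sum(length for _, _, length in segments) / len(segments)
```

A sequence that flips intention at almost every step yields a mean length close to 1, which is a quick signal that the recovered intentions are not coherent.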

Research Applications

Use Case 1: Human Behavior Modeling

# Model human decision-making with goal switching

# Load human demonstration data (load_demonstrations is a placeholder loader)
human_trajectories = load_demonstrations('human_data.csv')

# Train PRISM to recover intentions (dimensions are illustrative)
prism = PRISM(obs_dim=10, hidden_dim=32, num_intentions=5, state_dim=10, action_dim=4)
prism.em_train(human_trajectories)

# Result: recovered intentions correspond to human goals
# e.g., "navigate to kitchen", "pick up object", "return to desk"

Use Case 2: Robot Policy Learning

# Learn multi-goal robot policy from demonstrations

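import torch

# Load robot demonstration data (BridgeData V2); load_robot_demonstrations is a placeholder loader
robot_demos = load_robot_demonstrations('bridge_v2')

# Train PRISM (dimensions are illustrative)
prism_robot = PRISM(obs_dim=50, hidden_dim=128, num_intentions=10, state_dim=50, action_dim=7)
prism_robot.em_train(robot_demos)

# Use learned intentions for robot control
def robot_control(current_state, observation_history):
    # observation_history: tensor of shape (1, T, obs_dim)
    with torch.no_grad():
        intention_dist = prism_robot(observation_history)
    # Most likely intention at the latest time step of the single batch element
    dominant_intention = intention_dist[0, -1].argmax().item()
    reward = prism_robot.reward_functions[dominant_intention]
    action = select_action(current_state, reward)  # select_action is a placeholder planner
    return action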

Use Case 3: Animal Behavior Analysis

# Analyze animal behavior with goal switching

# Load mouse labyrinth data (load_mouse_trajectories is a placeholder loader)
mouse_data = load_mouse_trajectories('labyrinth.npy')

# Train PRISM, then infer mouse intentions (dimensions are illustrative)
prism_mouse = PRISM(obs_dim=20, hidden_dim=64, num_intentions=4, state_dim=20, action_dim=4)
prism_mouse.em_train(mouse_data)
intention_sequence = prism_mouse(mouse_data.obs)

# Visualize intention switches (plot_intention_timeline is a placeholder)
plot_intention_timeline(intention_sequence)
# Look for coherent intention segments, e.g. "explore", "seek food", "return home"

Key Advantages

1. No Variational Approximation

Standard Multi-Intention IRL:

  • Requires variational lower bound
  • Optimization may not converge to true solution

PRISM:

  • Exact EM decomposition
  • Closed-form reward estimation
  • Likelihood never decreases across EM iterations; training converges to a stationary point (typically a local optimum)

2. Non-Markovian Capability

Markov Chain Approach:

  • The intention at step t depends only on the intention at step t-1
  • Cannot capture long-range dependencies

PRISM RNN:

  • π_t depends on FULL history via h_t
  • Captures complex temporal patterns
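The limitation of the Markov-chain approach is easy to see numerically: with a fixed transition matrix, the intention distribution at each step is fully determined by the previous step, whatever happened earlier in the episode. The sketch below propagates that distribution. The function name `markov_marginals` and the transition values are assumptions for illustration.

```python
import numpy as np

def markov_marginals(pi0, transition, steps):
    """Propagate pi_t = pi_{t-1} @ transition for a fixed transition matrix."""
    marginals = [np.asarray(pi0, dtype=float)]
    for _ in range(steps):
        marginals.append(marginals[-1] @ transition)
    return np.stack(marginals)

# Two intentions with "sticky" transitions (illustrative values)
transition = np.array([[0.9, 0.1],
                       [0.2, 0.8]])
marginals = markov_marginals([1.0, 0.0], transition, steps=3)
```

Nothing in this recursion can encode a rule such as "switch goals only after both landmarks have been visited". A recurrent hidden state h_t can carry that information forward, which is the motivation for conditioning π_t on the full history.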

3. Scalability

Application Scale:

  • BridgeData V2: about 60,000 demonstration trajectories
  • First multi-intention IRL at this scale
  • Computational efficiency from closed-form solutions

Pitfalls and Solutions

Pitfall 1: Intention Number Selection

Problem: How many intentions exist in the data?

Solution:

  • Use log-likelihood cross-validation
  • Penalize over-segmentation (similar to model selection)
  • Start with small number and increment
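The selection procedure above can be sketched as a small search loop: increase the number of intentions and stop once the held-out log-likelihood no longer improves by a chosen margin. Everything here is a hypothetical illustration. `select_num_intentions`, the `min_gain` threshold, and the scores are made up, and `heldout_ll` stands in for "train PRISM with K intentions and score it on held-out trajectories".

```python
def select_num_intentions(candidates, heldout_ll, min_gain=1.0):
    """Return the largest K reached before the held-out gain drops below min_gain.

    candidates: iterable of candidate intention counts
    heldout_ll: callable K -> held-out log-likelihood of a model trained with K
    min_gain:   minimum improvement required to accept a larger K
    """
    candidates = sorted(candidates)
    best_k = candidates[0]
    best_ll = heldout_ll(best_k)
    for k in candidates[1:]:
        ll = heldout_ll(k)
        if ll - best_ll < min_gain:
            break  # extra intention does not pay for itself; stop growing K
        best_k, best_ll = k, ll
    return best_k

# Made-up held-out scores, for illustration only
scores = {1: -520.0, 2: -410.0, 3: -395.0, 4: -394.6, 5: -394.9}
chosen = select_num_intentions(scores.keys(), scores.get, min_gain=1.0)
```

The margin acts as the over-segmentation penalty: a larger `min_gain` favors fewer, broader intentions.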

Pitfall 2: Reward Function Parameterization

Problem: Tabular rewards don't generalize; parametric may underfit

Solution:

  • Use neural network rewards for complex environments
  • Linear rewards for simple environments (closed-form)
  • Hybrid: linear features + neural network transformations
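One observation that helps when choosing a parameterization: a tabular reward is a linear reward with one-hot features, so a closed form derived for the linear case also covers the tabular case. A minimal sketch for a discrete state and action space follows; the function name `one_hot_features` and the table values are illustrative assumptions.

```python
import numpy as np

def one_hot_features(s, a, num_states, num_actions):
    """One-hot feature vector phi(s, a) for a discrete state-action pair."""
    phi = np.zeros(num_states * num_actions)
    phi[s * num_actions + a] = 1.0
    return phi

num_states, num_actions = 3, 2
# A tabular reward r[s, a] (illustrative values)
table = np.arange(6, dtype=float).reshape(num_states, num_actions)

# The same reward as a linear model: theta is the flattened table
theta = table.reshape(-1)
r = float(theta @ one_hot_features(2, 1, num_states, num_actions))
```

The trade-off is dimensionality: the one-hot parameter vector grows with the number of state-action pairs, which is why richer shared features or neural rewards are preferred in large environments.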

Pitfall 3: RNN Training Stability

Problem: RNN may struggle with long sequences

Solution:

  • Use GRU (gated recurrent unit) for stability
  • Gradient clipping
  • Curriculum learning: start with short sequences
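Two of these remedies are simple enough to sketch without a training framework. Gradient clipping rescales all gradients when their global norm exceeds a threshold, and a curriculum lengthens the training sequences gradually. The names `clip_by_global_norm` and `curriculum_length` and all default values are assumptions for illustration. In PyTorch, `torch.nn.utils.clip_grad_norm_` applies the same kind of rescaling to model parameters.

```python
import numpy as np

def clip_by_global_norm(grads, max_norm):
    """Rescale a list of gradient arrays so their joint L2 norm is at most max_norm."""
    total = float(np.sqrt(sum(np.sum(g ** 2) for g in grads)))
    scale = min(1.0, max_norm / (total + 1e-12))
    return [g * scale for g in grads], total

def curriculum_length(epoch, start=10, step=10, max_len=200):
    """Sequence length used at a given epoch: grows linearly, then saturates."""
    return min(max_len, start + step * epoch)

# Two gradient arrays with joint norm 5 (illustrative values)
grads = [np.array([3.0, 0.0]), np.array([0.0, 4.0])]
clipped, norm = clip_by_global_norm(grads, max_norm=1.0)
lengths = [curriculum_length(epoch) for epoch in (0, 3, 100)]
```

Clipping by the global norm keeps the direction of the update unchanged, which is usually preferable to clipping each parameter independently.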

Key Takeaways

  1. Lightweight Recurrent Network: Maps observation history to intention distribution, replacing Markov/fixed-window approaches

  2. Exact EM Decomposition: Closed-form solutions for each intention independently - no variational approximation needed

  3. Non-Markovian Capability: Captures long-range dependencies in intention switching through RNN

  4. Large-Scale Application: First work to apply multi-intention IRL to large-scale robotic dataset (BridgeData V2)

  5. Biological & Artificial Agents: Discovered goal switching in both domains, suggesting common mechanisms

Comparison with Alternatives

| Method | Intention Memory | Optimization | Scalability | Non-Markovian |
|---|---|---|---|---|
| Markov Chain | Single step | Variational | Medium | ❌ No |
| Fixed Window | Limited history | Variational | Low | Limited |
| PRISM | Full history (RNN) | Closed-form EM | High | ✅ Yes |

Related Skills

  • inverse-reinforcement-learning: Traditional IRL methods
  • multi-intention-irl: Multi-goal behavior modeling
  • recurrent-networks: RNN architectures for sequential data
  • robot-learning: Learning robot policies from demonstrations
  • human-behavior-modeling: Modeling human decision processes

Activation Keywords

  • 多意图 IRL (multi-intention IRL)
  • intention switching (目标切换)
  • PRISM
  • recurrent intention model
  • EM algorithm
  • closed-form solution
  • goal switching
  • multi-goal behavior
  • observation history
  • variational-free

Recommended Model

sonnet4.5 for algorithm implementation and behavior modeling

opus4.5 for large-scale robotic applications and theoretical analysis

Tools Used

  • exec: Run EM training and reward estimation
  • read: Load demonstration datasets
  • write: Save learned intentions and policies

Citation

@article{sheng2026prism,
  title={Probabilistic Recurrent Intention Switching Model},
  author={Sheng, Wenyuan and Zhu, Hao and Boedecker, Joschka},
  journal={arXiv preprint arXiv:2605.26998},
  year={2026},
  categories={cs.LG, q-bio.NC}
}

Further Reading

  • Inverse Reinforcement Learning (Ng & Russell, 2000)
  • Multi-Intention IRL (Babes et al., ICML 2011)
  • Bayesian Nonparametric IRL (Choi & Kim, NIPS 2011)
  • BridgeData V2 (Walke et al., CoRL 2023)