gflowstate-training-visualization


GFlowState visual analytics system for illuminating Generative Flow Network training. Provides interactive visualizations of training states, flow distributions, and mode coverage for diagnosing mode collapse and reward hacking. Activation: GFlowState, GFlowNet visualization, generative flow networks, training visualization.

By hiyenwong, updated 6/3/2026


GFlowState: GFlowNet Training Visualization

Visual analytics system for illuminating the training process of Generative Flow Networks (GFlowNets), enabling diagnosis of training issues like mode collapse, insufficient exploration, and reward hacking.

Metadata

  • Source: arXiv:2604.21830
  • Authors: Thomas Weber, Anna Kowalski, James Liu
  • Published: 2026-04-23
  • Category: cs.LG

Core Methodology

GFlowNets Background

Generative Flow Networks (GFlowNets) are generative models that:

  • Learn to sample objects with probability proportional to a given reward function
  • Are trained with flow-consistency objectives such as flow matching, detailed balance, or trajectory balance
  • Have been applied to molecule generation, causal discovery, and combinatorial optimization
  • Exhibit a complex interplay between flow matching and exploratory behavior
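Trajectory balance (Malkin et al., 2022) is one widely used objective of this kind. For a complete trajectory s_0 -> s_1 -> ... -> s_n = x it penalizes the squared residual (log Z + sum_t log P_F(s_{t+1} | s_t) - log R(x) - sum_t log P_B(s_t | s_{t+1}))^2; if this residual is zero for every complete trajectory, the forward policy samples x with probability proportional to R(x). The sketch below evaluates the loss from precomputed per-step log-probabilities. The function name and plain-list inputs are illustrative; a real implementation would obtain these quantities from a policy network in PyTorch or JAX.

```python
import math


def trajectory_balance_loss(log_z, log_pf, log_pb, reward):
    """Squared trajectory-balance residual for one complete trajectory.

    log_z:  learned estimate of log Z (scalar)
    log_pf: forward log-probabilities log P_F(s_{t+1} | s_t), one per step
    log_pb: backward log-probabilities log P_B(s_t | s_{t+1}), one per step
    reward: terminal reward R(x), must be > 0
    """
    if len(log_pf) != len(log_pb):
        raise ValueError("need one forward and one backward log-prob per step")
    residual = log_z + sum(log_pf) - math.log(reward) - sum(log_pb)
    return residual ** 2


# Toy check: a 2-step trajectory where each forward choice has probability 0.5,
# backward moves are deterministic (log P_B = 0), and R(x) = 2.
# Balance requires Z * 0.25 = 2, i.e. Z = 8.
balanced = trajectory_balance_loss(math.log(8.0), [math.log(0.5)] * 2, [0.0, 0.0], 2.0)
unbalanced = trajectory_balance_loss(math.log(4.0), [math.log(0.5)] * 2, [0.0, 0.0], 2.0)
```

With Z = 8 the residual vanishes; halving Z leaves a residual of log(0.5), so the loss becomes (log 2)^2.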

Training Challenges

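  1. Mode collapse: The sampler concentrates on a few high-reward modes and misses others, losing diversity
  2. Insufficient exploration: The policy gets stuck in local optima and fails to explore other regions of the state space
  3. Reward hacking: The model exploits flaws in the reward function rather than learning the intended distribution
  4. Training instability: Aggregate losses alone make it difficult to diagnose when and why training goes wrong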
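Mode collapse can be quantified directly when the space of terminal objects is small enough to enumerate: compare the empirical sampling distribution with the target R(x)/Z. The sketch below uses total variation distance for this; the function name and the toy reward table are illustrative assumptions. For realistic spaces, where Z cannot be computed, one falls back on proxies such as the mode counts and diversity scores computed by ModeCoverageAnalyzer.

```python
from collections import Counter


def tv_to_reward_distribution(samples, reward_fn, support):
    """Total variation distance between empirical samples and R(x)/Z.

    Only practical when `support` (all terminal objects) can be enumerated,
    so that Z = sum of R(x) over the support is computed exactly.
    Samples outside `support` are ignored.
    """
    rewards = {x: reward_fn(x) for x in support}
    z = sum(rewards.values())
    counts = Counter(samples)
    n = len(samples)
    return 0.5 * sum(abs(counts.get(x, 0) / n - rewards[x] / z) for x in support)


# Toy space with two high-reward modes ("a", "b") and two low-reward objects.
support = ["a", "b", "c", "d"]
reward_table = {"a": 4.0, "b": 4.0, "c": 1.0, "d": 1.0}

collapsed = tv_to_reward_distribution(["a"] * 10, reward_table.get, support)  # about 0.6
matched = tv_to_reward_distribution(
    ["a"] * 4 + ["b"] * 4 + ["c", "d"], reward_table.get, support
)  # 0.0
```

A sampler that only ever returns "a" scores about 0.6 even though every sample is high-reward, which is exactly the failure that mean reward alone hides.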

GFlowState Visualization Components

1. Training State Visualization

  • Real-time monitoring of training metrics
  • Flow consistency over time
  • Loss decomposition (flow loss, policy loss, value loss)

2. Flow Distribution Tracking

  • Distribution of flow values across states
  • Comparison of learned vs target flows
  • Flow mismatch visualization
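Flow mismatch can be made concrete with the flow-matching condition of Bengio et al. (2021): at every non-initial state, total inflow should equal total outflow plus the state's reward (zero for non-terminal states). The sketch below computes the log-space residual per state on a small explicit graph. The dictionary encoding of edge flows and the `eps` guard are assumptions for illustration; in practice edge flows are predicted by a neural network.

```python
import math
from collections import defaultdict


def flow_matching_residuals(edge_flows, rewards, initial_state, eps=1e-8):
    """Per-state log(inflow) minus log(outflow + reward).

    edge_flows:    dict mapping (parent, child) -> positive flow F(parent -> child)
    rewards:       dict mapping terminal state -> R(x); missing states get 0
    initial_state: s_0, which has no inflow and is skipped
    A perfectly consistent flow gives residual 0 at every non-initial state.
    """
    inflow = defaultdict(float)
    outflow = defaultdict(float)
    for (parent, child), flow in edge_flows.items():
        outflow[parent] += flow
        inflow[child] += flow
    states = (set(inflow) | set(outflow)) - {initial_state}
    return {
        s: math.log(eps + inflow[s]) - math.log(eps + outflow[s] + rewards.get(s, 0.0))
        for s in states
    }


# Toy DAG: s0 -> a -> x and s0 -> b -> y, with R(x) = 2 and R(y) = 1.
edges = {("s0", "a"): 2.0, ("s0", "b"): 1.0, ("a", "x"): 2.0, ("b", "y"): 1.0}
consistent = flow_matching_residuals(edges, {"x": 2.0, "y": 1.0}, "s0")
# Raising R(x) to 4 without changing the flows breaks balance at x only.
inconsistent = flow_matching_residuals(edges, {"x": 4.0, "y": 1.0}, "s0")
```

Working in log space keeps residuals comparable across states whose flows differ by orders of magnitude, which also suits a color-coded mismatch plot.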

3. Mode Coverage Analysis

  • Coverage of reward modes over training
  • Diversity metrics for generated samples
  • Mode discovery timeline
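A mode discovery timeline needs a way to tell which mode a sample belongs to. When modes are known in advance (for example, a fixed set of target structures in a synthetic benchmark), it reduces to recording the first training step at which each mode is hit. In the sketch below, `mode_of` is a hypothetical user-supplied function and the toy batches are made up.

```python
def mode_discovery_timeline(batches, mode_of, reward_threshold):
    """Record the first training step at which each high-reward mode is sampled.

    batches:          iterable of (step, samples, rewards) tuples, in step order
    mode_of:          function mapping a sample to a hashable mode identifier
    reward_threshold: only samples with reward >= threshold count as mode hits
    Returns ({mode_id: first_step}, [(step, cumulative_mode_count), ...]).
    """
    first_seen = {}
    cumulative = []
    for step, samples, rewards in batches:
        for sample, reward in zip(samples, rewards):
            if reward >= reward_threshold:
                # setdefault keeps the earliest step for each mode
                first_seen.setdefault(mode_of(sample), step)
        cumulative.append((step, len(first_seen)))
    return first_seen, cumulative


# Toy run: the "mode" of a string sample is its first character.
batches = [
    (0, ["aa", "bb"], [0.9, 0.2]),
    (100, ["ab", "cc"], [0.95, 0.8]),
    (200, ["bb"], [0.85]),
]
first_seen, cumulative = mode_discovery_timeline(batches, lambda s: s[0], 0.8)
```

Plotting `cumulative` gives the discovery curve; a curve that plateaus early while known modes remain undiscovered is a sign of insufficient exploration.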

Implementation Guide

Prerequisites

  • GFlowNet implementation (PyTorch/JAX)
  • Visualization library (Plotly, D3.js, or similar)
  • Real-time data streaming capability

Visualization Components

Component 1: Training Metrics Dashboard

import numpy as np  # also used by the flow and mode-coverage components below
import plotly.graph_objects as go
from plotly.subplots import make_subplots

class GFlowTrainingDashboard:
    """Real-time training metrics visualization."""
    
    def __init__(self):
        self.metrics_history = {
            'flow_loss': [],
            'policy_loss': [],
            'value_loss': [],
            'reward_mean': [],
            'reward_std': [],
            'step': []
        }
        
    def update(self, step, metrics):
        """Update dashboard with new training metrics."""
        self.metrics_history['step'].append(step)
        self.metrics_history['flow_loss'].append(metrics['flow_loss'])
        self.metrics_history['policy_loss'].append(metrics['policy_loss'])
        self.metrics_history['value_loss'].append(metrics['value_loss'])
        self.metrics_history['reward_mean'].append(metrics['reward_mean'])
        self.metrics_history['reward_std'].append(metrics['reward_std'])
        
    def plot(self):
        """Generate comprehensive training plot."""
        fig = make_subplots(
            rows=2, cols=2,
            subplot_titles=('Flow Loss', 'Policy Loss', 
                          'Value Loss', 'Reward Distribution')
        )
        
        steps = self.metrics_history['step']
        
        # Flow Loss
        fig.add_trace(
            go.Scatter(x=steps, y=self.metrics_history['flow_loss'],
                      mode='lines', name='Flow Loss'),
            row=1, col=1
        )
        
        # Policy Loss
        fig.add_trace(
            go.Scatter(x=steps, y=self.metrics_history['policy_loss'],
                      mode='lines', name='Policy Loss'),
            row=1, col=2
        )
        
        # Value Loss
        fig.add_trace(
            go.Scatter(x=steps, y=self.metrics_history['value_loss'],
                      mode='lines', name='Value Loss'),
            row=2, col=1
        )
        
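        # Reward mean with a +/- 1 std band; per-sample rewards needed for a
        # full heatmap are not stored in metrics_history
        mean = self.metrics_history['reward_mean']
        std = self.metrics_history['reward_std']
        fig.add_trace(
            go.Scatter(x=steps, y=[m + s for m, s in zip(mean, std)],
                      mode='lines', line=dict(width=0), showlegend=False),
            row=2, col=2
        )
        fig.add_trace(
            go.Scatter(x=steps, y=[m - s for m, s in zip(mean, std)],
                      mode='lines', line=dict(width=0), fill='tonexty',
                      showlegend=False),
            row=2, col=2
        )
        fig.add_trace(
            go.Scatter(x=steps, y=mean, mode='lines', name='Reward Mean'),
            row=2, col=2
        )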
        
        fig.update_layout(height=800, title_text="GFlowNet Training Metrics")
        return fig

Component 2: Flow Distribution Visualization

class FlowDistributionVisualizer:
    """Visualize learned vs target flow distributions."""
    
    def __init__(self):
        self.state_flows = {}
        self.target_flows = {}
        
    def add_state(self, state_id, learned_flow, target_flow):
        """Record flow values for a state."""
        self.state_flows[state_id] = learned_flow
        self.target_flows[state_id] = target_flow
        
    def visualize_mismatch(self):
        """Create visualization of flow mismatch."""
        state_ids = list(self.state_flows.keys())
        learned = [self.state_flows[s] for s in state_ids]
        target = [self.target_flows[s] for s in state_ids]
        mismatch = [l - t for l, t in zip(learned, target)]
        
        fig = go.Figure()
        
        # Scatter plot: learned vs target
        fig.add_trace(go.Scatter(
            x=target, y=learned,
            mode='markers',
            marker=dict(
                size=10,
                color=mismatch,
                colorscale='RdBu',
                colorbar=dict(title='Mismatch'),
                showscale=True
            ),
            name='States'
        ))
        
        # Perfect match line
        fig.add_trace(go.Scatter(
            x=[min(target), max(target)],
            y=[min(target), max(target)],
            mode='lines',
            line=dict(dash='dash', color='gray'),
            name='Perfect Match'
        ))
        
        fig.update_layout(
            title='Flow Distribution: Learned vs Target',
            xaxis_title='Target Flow',
            yaxis_title='Learned Flow'
        )
        
        return fig
    
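    def flow_consistency_over_time(self, flow_history, steps=None):
        """Track how flow consistency evolves during training.

        flow_history: list of evaluation rounds, each a list of
            {'learned': float, 'target': float} dicts.
        steps: optional training step of each round (defaults to round index).
        """
        consistency_scores = []
        for flows_at_step in flow_history:
            # Consistency: Pearson correlation between learned and target flows
            learned = [f['learned'] for f in flows_at_step]
            target = [f['target'] for f in flows_at_step]
            corr = np.corrcoef(learned, target)[0, 1]
            consistency_scores.append(corr)
        if steps is None:
            steps = list(range(len(consistency_scores)))
        
        fig = go.Figure()
        fig.add_trace(go.Scatter(
            x=steps,
            y=consistency_scores,
            mode='lines',
            name='Flow Consistency'
        ))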
        fig.update_layout(
            title='Flow Consistency Over Training',
            yaxis_title='Correlation (Learned vs Target)',
            xaxis_title='Training Step'
        )
        return fig
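Flow values often span many orders of magnitude, so a Pearson correlation on raw flows, as in flow_consistency_over_time, is dominated by the largest flows. Correlating in log space is a common alternative. The helper below is a sketch of that option; its name and the `eps` guard are choices made here, not part of GFlowState.

```python
import numpy as np


def log_flow_correlation(learned, target, eps=1e-12):
    """Pearson correlation between log learned flows and log target flows."""
    log_learned = np.log(np.asarray(learned, dtype=float) + eps)
    log_target = np.log(np.asarray(target, dtype=float) + eps)
    if log_learned.std() == 0 or log_target.std() == 0:
        return float("nan")  # correlation is undefined for constant inputs
    return float(np.corrcoef(log_learned, log_target)[0, 1])


# The three small flows are learned in reversed order; only the largest is right.
target = [1.0, 10.0, 100.0, 1000.0]
learned = [100.0, 10.0, 1.0, 1000.0]
raw_corr = float(np.corrcoef(learned, target)[0, 1])  # about 0.99
log_corr = log_flow_correlation(learned, target)      # about 0.2
```

The raw correlation looks nearly perfect because a single large flow dominates, while the log-space score exposes that most states are badly fit.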

Component 3: Mode Coverage Analyzer

class ModeCoverageAnalyzer:
    """Analyze and visualize mode coverage during training."""
    
    def __init__(self, reward_thresholds=None):
        self.thresholds = reward_thresholds or [0.1, 0.5, 0.9]
        self.mode_discovery_timeline = {t: [] for t in self.thresholds}
        self.sample_diversity = []
        
    def analyze_samples(self, samples, rewards, step):
        """Analyze a batch of generated samples."""
        # Identify modes (high-reward clusters)
        for threshold in self.thresholds:
            high_reward_samples = [s for s, r in zip(samples, rewards) 
                                  if r >= threshold]
            unique_modes = self.identify_unique_modes(high_reward_samples)
            self.mode_discovery_timeline[threshold].append({
                'step': step,
                'count': len(unique_modes),
                'modes': unique_modes
            })
        
        # Diversity metric
        diversity = self.compute_diversity(samples)
        self.sample_diversity.append({'step': step, 'diversity': diversity})
        
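    def samples_to_features(self, samples):
        """Convert samples to a 2-D float matrix (n_samples, n_features).

        Hypothetical default featurizer: assumes each sample is already a
        numeric vector. Override with a domain-specific featurizer
        (e.g., molecular fingerprints).
        """
        features = np.asarray(samples, dtype=float)
        return features.reshape(len(features), -1)

    def identify_unique_modes(self, samples, similarity_threshold=0.8):
        """Cluster samples and return the list of mode (cluster) labels found."""
        # DBSCAN on cosine distance, so eps = 1 - similarity_threshold acts as
        # a cosine-similarity cutoff; hierarchical clustering is an alternative
        from sklearn.cluster import DBSCAN
        
        if len(samples) == 0:
            return []
            
        # Convert samples to feature vectors for clustering
        features = self.samples_to_features(samples)
        
        clustering = DBSCAN(eps=1 - similarity_threshold, min_samples=2,
                            metric='cosine')
        labels = clustering.fit_predict(features)
        
        # Unique cluster labels, excluding noise (label = -1)
        return sorted(int(label) for label in set(labels) if label != -1)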
    
    def compute_diversity(self, samples):
        """Compute diversity as 1 - mean pairwise correlation between samples."""
        if len(samples) < 2:
            return 0.0
            
        # Pairwise similarity matrix: rows of `features` are samples
        features = self.samples_to_features(samples)
        similarities = np.corrcoef(features)
        
        # Average only off-diagonal entries (the diagonal is always 1);
        # nanmean skips pairs involving constant feature vectors
        off_diagonal = similarities[~np.eye(len(features), dtype=bool)]
        diversity = 1 - np.nanmean(off_diagonal)
        return float(diversity)
    
    def plot_mode_discovery(self):
        """Visualize mode discovery over training."""
        fig = go.Figure()
        
        for threshold in self.thresholds:
            timeline = self.mode_discovery_timeline[threshold]
            steps = [t['step'] for t in timeline]
            counts = [t['count'] for t in timeline]
            
            fig.add_trace(go.Scatter(
                x=steps, y=counts,
                mode='lines+markers',
                name=f'Reward >= {threshold}'
            ))
        
        fig.update_layout(
            title='Mode Discovery Over Training',
            xaxis_title='Training Step',
            yaxis_title='Number of Unique Modes Discovered',
            legend_title='Reward Threshold'
        )
        
        return fig
    
    def plot_diversity(self):
        """Plot sample diversity over training."""
        steps = [d['step'] for d in self.sample_diversity]
        diversities = [d['diversity'] for d in self.sample_diversity]
        
        fig = go.Figure()
        fig.add_trace(go.Scatter(
            x=steps, y=diversities,
            mode='lines',
            name='Sample Diversity'
        ))
        
        # Add trend line
        if len(steps) > 1:
            z = np.polyfit(steps, diversities, 1)
            p = np.poly1d(z)
            fig.add_trace(go.Scatter(
                x=steps, y=p(steps),
                mode='lines',
                line=dict(dash='dash'),
                name='Trend'
            ))
        
        fig.update_layout(
            title='Sample Diversity Over Training',
            xaxis_title='Training Step',
            yaxis_title='Diversity Score'
        )
        
        return fig
    
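    def detect_mode_collapse(self, window_size=10):
        """Detect potential mode collapse events.

        Flags any analysis round whose diversity drops below 50% of the mean
        over the preceding `window_size` rounds (below 30% counts as severe).
        `window_size` counts calls to analyze_samples, not training steps.
        """
        if len(self.sample_diversity) < window_size:
            return []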
            
        collapse_events = []
        diversities = [d['diversity'] for d in self.sample_diversity]
        
        for i in range(window_size, len(diversities)):
            prev_mean = np.mean(diversities[i-window_size:i])
            curr = diversities[i]
            
            # Significant drop in diversity
            if curr < prev_mean * 0.5:
                collapse_events.append({
                    'step': self.sample_diversity[i]['step'],
                    'diversity': curr,
                    'previous_mean': prev_mean,
                    'severity': 'severe' if curr < prev_mean * 0.3 else 'moderate'
                })
        
        return collapse_events

Integration with GFlowNet Training

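def train_gflow_with_visualization(gfn, env, n_iterations=10000):
    """Train a GFlowNet with GFlowState-style visualization.

    Assumes a hypothetical interface: gfn.train_step(env) returns the metric
    dict read by GFlowTrainingDashboard.update; gfn.evaluate_flows(env)
    returns (states, flows) with each state exposing `.id`;
    gfn.sample_batch(n) returns (samples, rewards); and
    env.compute_target_flow(state) returns a state's target flow.
    """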
    
    dashboard = GFlowTrainingDashboard()
    flow_viz = FlowDistributionVisualizer()
    mode_analyzer = ModeCoverageAnalyzer(reward_thresholds=[0.5, 0.8, 0.95])
    reported_steps = set()  # collapse events already reported
    
    for step in range(n_iterations):
        # Training step
        metrics = gfn.train_step(env)
        
        # Update visualizations
        dashboard.update(step, metrics)
        
        # Periodic flow analysis
        if step % 100 == 0:
            states, flows = gfn.evaluate_flows(env)
            for s, f in zip(states, flows):
                target = env.compute_target_flow(s)
                flow_viz.add_state(s.id, f, target)
            
            # Generate samples for mode analysis
            samples, rewards = gfn.sample_batch(n=100)
            mode_analyzer.analyze_samples(samples, rewards, step)
        
        # Check for mode collapse. Samples are analyzed every 100 steps, so
        # window_size counts analysis rounds; report only new events.
        if step % 500 == 0:
            collapse_events = mode_analyzer.detect_mode_collapse(window_size=5)
            new_events = [e for e in collapse_events
                          if e['step'] not in reported_steps]
            if new_events:
                reported_steps.update(e['step'] for e in new_events)
                print(f"Warning: Mode collapse detected at steps: "
                      f"{[e['step'] for e in new_events]}")
    
    return dashboard, flow_viz, mode_analyzer
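The training loop above relies on methods of `gfn` and `env` that no specific library provides. One way to pin that contract down is a `typing.Protocol`; the sketch below is a hypothetical interface inferred from how the loop calls these objects, plus trivial dummy classes that satisfy it for wiring tests.

```python
from typing import Any, Dict, List, Protocol, Sequence, Tuple, runtime_checkable

# Metric keys that GFlowTrainingDashboard.update reads from each train_step result
REQUIRED_METRIC_KEYS = {"flow_loss", "policy_loss", "value_loss",
                        "reward_mean", "reward_std"}


@runtime_checkable
class TrainableGFlowNet(Protocol):
    """Hypothetical model interface assumed by train_gflow_with_visualization."""

    def train_step(self, env: Any) -> Dict[str, float]: ...

    def evaluate_flows(self, env: Any) -> Tuple[Sequence[Any], Sequence[float]]: ...

    def sample_batch(self, n: int) -> Tuple[List[Any], List[float]]: ...


@runtime_checkable
class FlowTargetEnv(Protocol):
    """Hypothetical environment interface: exposes a target flow per state."""

    def compute_target_flow(self, state: Any) -> float: ...


class DummyGFlowNet:
    """Trivial stand-in for smoke-testing the visualization wiring."""

    def train_step(self, env):
        return {key: 0.0 for key in REQUIRED_METRIC_KEYS}

    def evaluate_flows(self, env):
        return [], []  # no states evaluated

    def sample_batch(self, n):
        # Non-constant vectors so correlation-based diversity is defined
        samples = [[1.0, float(i), 0.0] for i in range(n)]
        return samples, [0.0] * n


class DummyEnv:
    def compute_target_flow(self, state):
        return 1.0
```

Note that `runtime_checkable` only checks that the methods exist, not their signatures or return types, so it catches missing methods but not a `train_step` that returns the wrong keys.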

Applications

  • Generative Flow Networks: Training monitoring and debugging
  • Molecule Generation: Tracking chemical diversity
  • Neural Architecture Search: Exploring architecture space
  • Training Visualization: Understanding complex training dynamics
  • Combinatorial Optimization: Monitoring solution diversity

Pitfalls

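  • Visualization overhead may slow training; compute expensive diagnostics (flow evaluation, sampling, clustering) periodically rather than every step
  • Mode identification depends on the choice of similarity metric and clustering parameters (e.g., DBSCAN eps)
  • Real-time visualization requires significant compute
  • Interpreting the plots requires an understanding of GFlowNet theory
  • Threshold-based mode collapse detection can produce false positives, e.g., when diversity is noisy between analysis rounds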
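The dependence of mode counts on the similarity cutoff is easy to demonstrate. The sketch below counts modes on the same made-up 2-D points under several distance cutoffs, using single-linkage grouping as a dependency-free stand-in for DBSCAN with min_samples=1 (where every point is a core point, so clusters are the connected components of the eps-neighborhood graph).

```python
import math


def count_modes(points, eps):
    """Count connected components when points within distance eps are linked."""
    parent = list(range(len(points)))

    def find(i):
        # Union-find root lookup with path halving
        while parent[i] != i:
            parent[i] = parent[parent[i]]
            i = parent[i]
        return i

    for i in range(len(points)):
        for j in range(i + 1, len(points)):
            if math.dist(points[i], points[j]) <= eps:
                parent[find(i)] = find(j)
    return len({find(i) for i in range(len(points))})


# Two tight pairs and one isolated point on a line.
points = [(0.0, 0.0), (0.3, 0.0), (1.0, 0.0), (1.3, 0.0), (3.0, 0.0)]
counts = {eps: count_modes(points, eps) for eps in (0.2, 0.5, 0.8, 2.0)}
```

The same samples yield anywhere from 5 modes down to 1 depending on the cutoff, so mode-discovery curves are only comparable across runs that share the same featurizer and clustering parameters.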

Related Skills

  • rad-2-scaling-reinforcement-learning-generator-discriminator-fra
  • coral-open-ended-discovery
  • agent-rl-benchmark

References

  • Weber et al. (2026). GFlowState: Visualizing the Training of Generative Flow Networks Beyond the Reward. arXiv:2604.21830
  • Bengio et al. (2021). Flow Network based Generative Models for Non-Iterative Diverse Candidate Generation
  • Malkin et al. (2022). Trajectory Balance: Improved Credit Assignment in GFlowNets
Install via CLI
npx skills add https://github.com/hiyenwong/ai_collection --skill gflowstate-training-visualization