topology-neural-collapse-monitor


Monitoring Neural Training with Topology — Footprint-Predictable Collapse Index using Modular Morse Homology Maintenance (MMHM). Detects representational collapse in neural network embeddings before performance metrics degrade. Activation: neural collapse, representational collapse, topological monitoring, training diagnostics, MMHM, embedding degradation, Morse homology.

By hiyenwong, updated 6/3/2026

---
name: topology-neural-collapse-monitor
description: "Monitoring Neural Training with Topology — Footprint-Predictable Collapse Index using Modular Morse Homology Maintenance (MMHM). Detects representational collapse in neural network embeddings before performance metrics degrade. Activation: neural collapse, representational collapse, topological monitoring, training diagnostics, MMHM, embedding degradation, Morse homology."
category: ai_collection
source:
  paper: "Monitoring Neural Training with Topology: A Footprint-Predictable Collapse Index"
  authors:
    - "Alexander Kalinowski"
  arxiv: "2604.26984"
  date: "2026-04-28"
  fields:
    - cs.LG
    - cs.AI
activation_keywords:
  en:
    - neural collapse
    - representational collapse
    - topological monitoring
    - training diagnostics
    - MMHM
    - embedding degradation
    - Morse homology
    - anisotropic embeddings
    - representation monitoring
    - training collapse detection
  zh:
    - 神经坍缩        # neural collapse
    - 表征坍缩        # representational collapse
    - 拓扑监控        # topological monitoring
    - 训练诊断        # training diagnostics
    - 嵌入退化        # embedding degradation
    - 莫尔斯同调      # Morse homology
    - 各向异性嵌入    # anisotropic embeddings
    - 表征监控        # representation monitoring
version: "1.0.0"
---

Topology-Aware Neural Training Monitor: Collapse Index

Reference: Kalinowski, A. Monitoring Neural Training with Topology: A Footprint-Predictable Collapse Index. arXiv:2604.26984 [cs.LG] (2026).

Overview

Representational collapse, where neural network embeddings become anisotropic and lose multi-scale structure, can erode downstream performance well before standard metrics such as loss and accuracy react. This skill provides a topology-aware monitoring framework that detects early signs of representational degradation.

Core Innovation: MMHM + Collapse Index

Instead of rebuilding topological complexes from scratch each epoch (expensive), the method uses:

  1. Modular Morse Homology Maintenance (MMHM): Applies sparse edits at a fixed scale and maintains a discrete Morse matching incrementally
  2. Composite Collapse Index (CI): Combines topological signatures into a single scalar diagnostic
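To make "sparse edits at a fixed scale" concrete, the sketch below builds the edge set of the Vietoris-Rips 1-skeleton at one fixed scale for two snapshots and reports which edges were added or removed. It is an illustration under that assumption, not the paper's implementation; `rips_edges` and `edge_edit` are hypothetical helper names. When only a few points move, the edit is a small fraction of the full edge set.

```python
import numpy as np
from scipy.spatial.distance import pdist, squareform


def rips_edges(points, scale):
    """Edges (i, j), i < j, of the Vietoris-Rips 1-skeleton at a fixed scale."""
    dist = squareform(pdist(np.asarray(points, dtype=float)))
    i_idx, j_idx = np.nonzero(np.triu(dist < scale, k=1))
    return set(zip(i_idx.tolist(), j_idx.tolist()))


def edge_edit(old_points, new_points, scale):
    """Return (added, removed) edge sets between two snapshots at the same scale."""
    old_edges = rips_edges(old_points, scale)
    new_edges = rips_edges(new_points, scale)
    return new_edges - old_edges, old_edges - new_edges


# Four points on a line, 0.4 apart; then point 3 drifts far away.
old = np.array([[0.0, 0.0], [0.4, 0.0], [0.8, 0.0], [1.2, 0.0]])
new = old.copy()
new[3] = [5.0, 0.0]
added, removed = edge_edit(old, new, scale=0.5)
print(sorted(added), sorted(removed))   # [] [(2, 3)]
```

For clarity this recomputes both edge sets in full; an incremental maintainer would only re-measure distances for points that moved, which is where the savings come from.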

Key Advantages

  • Fast: Incremental updates instead of full recomputation
  • Early detection: Signals collapse before loss/accuracy degradation
  • Footprint-predictable: Computational cost scales predictably with data size

Methodology

Step 1: Build Point Cloud from Embeddings

import numpy as np
import torch
from scipy.spatial.distance import pdist, squareform

def get_embeddings_snapshot(model, dataloader, device):
    """Extract embeddings from a model's penultimate layer.

    Assumes the model exposes a `get_features(x)` method returning the
    penultimate-layer activations; adapt this call to your architecture.
    """
    was_training = model.training
    model.eval()
    embeddings = []
    labels = []

    with torch.no_grad():
        for batch_X, batch_y in dataloader:
            batch_X = batch_X.to(device)

            # Penultimate-layer output, flattened to (batch, features)
            emb = model.get_features(batch_X)
            embeddings.append(emb.flatten(1).cpu().numpy())
            labels.append(batch_y.cpu().numpy())

    # Restore the caller's mode so dropout/BatchNorm keep training behaviour
    model.train(was_training)

    embeddings = np.concatenate(embeddings, axis=0)
    labels = np.concatenate(labels, axis=0)
    return embeddings, labels
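Slicing the first few hundred rows is a biased subsample when the dataloader yields samples grouped by class. A minimal sketch of a seeded, class-stratified subsample follows; the function name and the `max_points=500` default are assumptions, chosen to sit in the 500-1000 sample range suggested under Best Practices.

```python
import numpy as np


def stratified_subsample(embeddings, labels, max_points=500, seed=0):
    """Seeded, class-stratified subsample of an embedding snapshot.

    Each class keeps max(1, floor(max_points * class_fraction)) rows, so the
    total is close to, but not always exactly, `max_points`.
    """
    embeddings = np.asarray(embeddings)
    labels = np.asarray(labels)
    n = len(embeddings)
    if n <= max_points:
        return embeddings, labels
    rng = np.random.default_rng(seed)
    chosen = []
    for cls in np.unique(labels):
        idx = np.flatnonzero(labels == cls)
        k = min(len(idx), max(1, int(max_points * len(idx) / n)))
        chosen.append(rng.choice(idx, size=k, replace=False))
    chosen = np.sort(np.concatenate(chosen))   # keep the original row order
    return embeddings[chosen], labels[chosen]
```

Typical use is `X_sub, y_sub = stratified_subsample(embeddings, labels)` before any persistence computation; a fixed seed keeps successive checks comparable.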

Step 2: Compute Persistence Diagram

import ripser

def compute_persistence(embeddings, max_dim=1, threshold=None):
    """
    Compute persistent homology of an embedding point cloud.
    Returns one persistence diagram per dimension 0..max_dim
    (0 = connected components, 1 = loops/cycles).
    """
    # Pairwise Euclidean distances
    dist_matrix = squareform(pdist(embeddings, metric='euclidean'))

    # Optional truncation: edges longer than `threshold` never enter the
    # filtration, which keeps it sparse (ripser's `thresh` argument).
    # Clipping the distance matrix instead would add every edge at `threshold`.
    thresh = np.inf if threshold is None else threshold

    diagrams = ripser.ripser(
        dist_matrix,
        maxdim=max_dim,
        thresh=thresh,
        distance_matrix=True
    )['diagrams']

    return diagrams

def persistence_lifetime(diagram, dim=1):
    """Compute persistence lifetimes (death - birth) for a dimension."""
    if dim >= len(diagram) or len(diagram[dim]) == 0:
        return np.array([])
    
    births = diagram[dim][:, 0]
    deaths = diagram[dim][:, 1]
    # Filter out infinite deaths
    finite = np.isfinite(deaths)
    return deaths[finite] - births[finite]
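The dimension-0 part of the diagram has a closed form that is handy for testing `compute_persistence`: in the Vietoris-Rips filtration every H0 class is born at 0 and dies at the length of the edge that merges it into another component, so the finite H0 deaths are exactly the edge lengths of a Euclidean minimum spanning tree. A minimal SciPy sketch (the helper name is hypothetical):

```python
import numpy as np
from scipy.sparse.csgraph import minimum_spanning_tree
from scipy.spatial.distance import pdist, squareform


def h0_lifetimes_via_mst(points):
    """Finite H0 lifetimes of the Vietoris-Rips filtration of `points`.

    Births are all 0, so lifetime == death == MST edge length.
    Assumes distinct points: SciPy treats zero distances as missing edges.
    """
    dist = squareform(pdist(np.asarray(points, dtype=float)))
    mst = minimum_spanning_tree(dist)      # sparse matrix with n - 1 stored edges
    return np.sort(mst.data)
```

For distinct points, `np.sort(persistence_lifetime(compute_persistence(X), dim=0))` should agree with `h0_lifetimes_via_mst(X)` up to floating-point precision.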

Step 3: Incremental Morse Homology (MMHM)

from collections import deque

class MorseHomologyMonitor:
    """
    Simplified illustration of Modular Morse Homology Maintenance (MMHM) on
    the 1-skeleton (vertices + edges) of the Vietoris-Rips complex at a fixed
    scale. The discrete gradient pairs every non-root vertex with the edge to
    its parent in a BFS spanning forest, which is an acyclic Morse matching:
    the unpaired (critical) vertices are one per connected component, and each
    unpaired edge closes an independent graph cycle.
    Updates re-measure distances only for points that moved.
    """

    def __init__(self, scale=0.5, change_tol=0.01):
        self.scale = scale
        self.change_tol = change_tol
        self.current_embeddings = None
        self.adjacency = {}            # vertex -> set of neighbours within `scale`
        self.current_matching = {}     # vertex -> paired edge (i, j), i < j
        self.critical_vertices = set()
        self.critical_edges = set()

    def initialize(self, embeddings):
        """Build the fixed-scale 1-skeleton and matching from the first snapshot."""
        self.current_embeddings = np.asarray(embeddings, dtype=float).copy()
        n = len(self.current_embeddings)
        dist_matrix = squareform(pdist(self.current_embeddings, metric='euclidean'))
        self.adjacency = {v: set() for v in range(n)}
        for i, j in np.argwhere(np.triu(dist_matrix < self.scale, k=1)):
            self.adjacency[int(i)].add(int(j))
            self.adjacency[int(j)].add(int(i))
        self._rebuild_matching()

    def update(self, new_embeddings):
        """
        Sparse edit: only points that moved more than `change_tol` get their
        neighbourhoods recomputed (O(changed * n) distance evaluations).
        Returns the number of changed points.
        """
        new_embeddings = np.asarray(new_embeddings, dtype=float)
        moved = np.linalg.norm(new_embeddings - self.current_embeddings, axis=1)
        changed = np.flatnonzero(moved > self.change_tol).tolist()

        # Remove every edge incident to a changed point...
        for idx in changed:
            for nbr in self.adjacency[idx]:
                self.adjacency[nbr].discard(idx)
            self.adjacency[idx] = set()

        # ...then re-add the edges that exist at the new positions
        for idx in changed:
            dists = np.linalg.norm(new_embeddings - new_embeddings[idx], axis=1)
            for nbr in np.flatnonzero(dists < self.scale).tolist():
                if nbr != idx:
                    self.adjacency[idx].add(nbr)
                    self.adjacency[nbr].add(idx)

        self.current_embeddings = new_embeddings.copy()
        if changed:
            # O(V + E) graph traversal; cheap next to the distance computations
            self._rebuild_matching()
        return len(changed)

    def _rebuild_matching(self):
        """BFS spanning forest: pair each non-root vertex with its parent edge."""
        n = len(self.current_embeddings)
        self.current_matching = {}
        self.critical_vertices = set()
        tree_edges = set()
        seen = [False] * n
        for root in range(n):
            if seen[root]:
                continue
            seen[root] = True
            self.critical_vertices.add(root)      # one critical vertex per component
            queue = deque([root])
            while queue:
                u = queue.popleft()
                for v in self.adjacency[u]:
                    if not seen[v]:
                        seen[v] = True
                        edge = (min(u, v), max(u, v))
                        self.current_matching[v] = edge
                        tree_edges.add(edge)
                        queue.append(v)
        all_edges = {(min(u, v), max(u, v))
                     for u, nbrs in self.adjacency.items() for v in nbrs}
        self.critical_edges = all_edges - tree_edges

    def betti_counts(self):
        """(connected components, independent graph cycles) from the critical cells."""
        return len(self.critical_vertices), len(self.critical_edges)
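A useful sanity check for any matching maintained on the fixed-scale 1-skeleton: every acyclic matching leaves at least beta_0 critical vertices and at least (E - V + beta_0) critical edges, the graph's cycle rank, and because matched pairs are vertex-edge pairs, critical vertices minus critical edges always equals V - E. A spanning-forest matching attains both lower bounds. The hypothetical helper below computes those target counts directly with SciPy; note that it ignores the triangles of the full Rips complex, so the cycle rank is only an upper bound on the complex's beta_1.

```python
import numpy as np
from scipy.sparse import csr_matrix
from scipy.sparse.csgraph import connected_components
from scipy.spatial.distance import pdist, squareform


def skeleton_counts(points, scale):
    """Return (beta_0, cycle_rank) of the Rips 1-skeleton at a fixed scale."""
    dist = squareform(pdist(np.asarray(points, dtype=float)))
    adj = np.triu(dist < scale, k=1)            # each undirected edge once
    n_vertices, n_edges = len(dist), int(adj.sum())
    beta_0, _ = connected_components(csr_matrix(adj), directed=False)
    beta_0 = int(beta_0)
    return beta_0, n_edges - n_vertices + beta_0


# A small triangle plus one far-away point
points = [[0.0, 0.0], [0.3, 0.0], [0.15, 0.25], [10.0, 10.0]]
print(skeleton_counts(points, scale=0.5))    # (2, 1)
```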

Step 4: Compute Collapse Index

def compute_collapse_index(embeddings, labels=None, max_tda_points=500, seed=0):
    """
    Composite Collapse Index measuring representational degradation.
    Combines geometric and topological signals. Raw values are returned in
    `ci_components`; each enters the index rescaled to [0, 1] and oriented so
    that higher always means "more collapse-like".
    """
    embeddings = np.asarray(embeddings, dtype=float)
    ci_components = {}   # raw values, for logging
    signals = {}         # [0, 1] terms that enter the index

    # 1. Anisotropy: log ratio of largest to smallest covariance eigenvalue
    eigenvalues = np.linalg.eigvalsh(np.cov(embeddings.T))
    eigenvalues = eigenvalues[eigenvalues > eigenvalues.max() * 1e-12]  # drop numerical zeros
    ci_components['anisotropy'] = np.log1p(eigenvalues[-1] / eigenvalues[0])
    signals['anisotropy'] = normalize_component(ci_components['anisotropy'], 0.0, 10.0)

    # 2. Effective dimensionality: participation ratio vs. the attainable rank
    eigenvalues_norm = eigenvalues / eigenvalues.sum()
    effective_dim = 1.0 / (eigenvalues_norm ** 2).sum()
    max_dim = min(embeddings.shape[1], len(embeddings) - 1)
    ci_components['dim_ratio'] = 1.0 - (effective_dim / max_dim)
    signals['dim_ratio'] = normalize_component(ci_components['dim_ratio'], 0.0, 1.0)

    # 3. H1 persistence entropy on a random subsample (not the first rows,
    #    which may all share a class). An empty diagram has entropy 0.
    rng = np.random.default_rng(seed)
    sample = rng.choice(len(embeddings), size=min(max_tda_points, len(embeddings)), replace=False)
    lifetimes = persistence_lifetime(compute_persistence(embeddings[sample]), dim=1)
    lifetimes = lifetimes[lifetimes > 0]
    persistence_entropy = 0.0
    if len(lifetimes) > 0:
        probs = lifetimes / lifetimes.sum()
        persistence_entropy = float(-(probs * np.log(probs)).sum())
    ci_components['persistence_entropy'] = persistence_entropy
    # Low entropy = lost structure (see Diagnostic Signals), so invert the scale;
    # ln(sample size) is an assumed reference maximum
    signals['persistence_entropy'] = 1.0 - normalize_component(
        persistence_entropy, 0.0, np.log(max(len(sample), 2)))

    # 4. Class separation (if labels available)
    if labels is not None:
        labels = np.asarray(labels)
        unique_labels = np.unique(labels)
        if len(unique_labels) > 1:
            # Within-class (sw) vs. between-class (sb) scatter
            overall_mean = embeddings.mean(axis=0)
            sw = 0.0
            sb = 0.0
            for label in unique_labels:
                mask = labels == label
                class_mean = embeddings[mask].mean(axis=0)
                sw += np.sum((embeddings[mask] - class_mean) ** 2)
                sb += mask.sum() * np.sum((class_mean - overall_mean) ** 2)

            if sw > 0:
                fisher_ratio = sb / sw
                ci_components['fisher_ratio'] = 1.0 / (1.0 + fisher_ratio)
                signals['fisher_ratio'] = ci_components['fisher_ratio']  # already in (0, 1]

    # Composite index (weighted sum over the signals that could be computed)
    weights = {
        'anisotropy': 0.3,
        'dim_ratio': 0.3,
        'persistence_entropy': 0.2,
        'fisher_ratio': 0.2,
    }
    ci = sum(weight * signals[key] for key, weight in weights.items() if key in signals)

    return ci, ci_components

def normalize_component(value, ref_min=0.0, ref_max=10.0):
    """Normalize component to [0, 1] range."""
    return np.clip((value - ref_min) / (ref_max - ref_min + 1e-10), 0, 1)
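A quick synthetic check that the two spectral components move in the collapse direction: isotropic Gaussian embeddings versus embeddings squeezed onto a single direction plus small noise. The helper mirrors the anisotropy and participation-ratio formulas above, using the embedding width as the reference dimension; the data sizes, noise level and the name `spectral_collapse_stats` are arbitrary choices for illustration.

```python
import numpy as np


def spectral_collapse_stats(embeddings):
    """Return (log-anisotropy, dimension ratio) for an (n, d) embedding matrix."""
    embeddings = np.asarray(embeddings, dtype=float)
    eig = np.linalg.eigvalsh(np.cov(embeddings.T))    # ascending order
    eig = eig[eig > eig.max() * 1e-12]                # drop numerical zeros
    log_anisotropy = np.log1p(eig[-1] / eig[0])
    p = eig / eig.sum()
    effective_dim = 1.0 / np.sum(p ** 2)              # participation ratio
    return log_anisotropy, 1.0 - effective_dim / embeddings.shape[1]


rng = np.random.default_rng(0)
isotropic = rng.normal(size=(2000, 16))
direction = rng.normal(size=16)
direction /= np.linalg.norm(direction)
# Every sample lies near one line: a 1-D signal along `direction` plus small noise
collapsed = rng.normal(size=(2000, 1)) * direction + 0.01 * rng.normal(size=(2000, 16))

print("isotropic:", spectral_collapse_stats(isotropic))
print("collapsed:", spectral_collapse_stats(collapsed))
```

The isotropic cloud should give a dimension ratio near 0 and a small anisotropy, while the collapsed cloud should give a ratio near 1 and a much larger anisotropy.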

Training Integration

class CollapseMonitor:
    """Monitor for representational collapse during training."""
    
    def __init__(self, check_every=5, threshold=0.7):
        self.check_every = check_every
        self.threshold = threshold
        self.history = []
        self.morse_monitor = MorseHomologyMonitor(scale=0.5)
        self.is_collapsing = False
        self.collapse_epoch = None
        
    def check(self, model, dataloader, epoch, device):
        """Run collapse check at current epoch."""
        if epoch % self.check_every != 0:
            return
        
        embeddings, labels = get_embeddings_snapshot(model, dataloader, device)
        ci, components = compute_collapse_index(embeddings, labels)
        
        self.history.append({
            'epoch': epoch,
            'collapse_index': ci,
            'components': components,
        })
        
        # Check for collapse trend
        if len(self.history) >= 3:
            recent = [h['collapse_index'] for h in self.history[-3:]]
            if recent[-1] > self.threshold and recent[-1] > recent[-2]:
                self.is_collapsing = True
                self.collapse_epoch = epoch
                print(f"⚠️ COLLAPSE DETECTED at epoch {epoch}! CI={ci:.4f}")
        
        return ci
    
    def get_report(self):
        """Generate monitoring report."""
        if not self.history:
            return "No monitoring data yet."
        
        report = "## Collapse Monitor Report\n\n"
        report += "| Epoch | Collapse Index | Status |\n"
        report += "|-------|---------------|--------|\n"
        
        for h in self.history:
            status = "⚠️ COLLAPSE" if h['collapse_index'] > self.threshold else "✅ OK"
            report += f"| {h['epoch']} | {h['collapse_index']:.4f} | {status} |\n"
        
        return report

Usage in Training Loop

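# Initialize monitor (num_epochs, train_step, val_dataloader and device
# come from your own training script)
monitor = CollapseMonitor(check_every=5, threshold=0.7)

for epoch in range(num_epochs):
    # ... standard training steps ...
    train_step(model, optimizer, dataloader)

    # Check for collapse (returns None on epochs that are skipped)
    ci = monitor.check(model, val_dataloader, epoch, device)

    # Intervene once per detection; is_collapsing stays True afterwards,
    # so without the epoch test the LR would be halved on every later epoch
    if monitor.is_collapsing and monitor.collapse_epoch == epoch:
        print("Intervening: reducing learning rate, adding regularization")
        for param_group in optimizer.param_groups:
            param_group['lr'] *= 0.5
        # Could also add contrastive loss, weight decay, etc.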

Diagnostic Signals

Signs of Representational Collapse

| Signal | What It Means | Action |
|--------|---------------|--------|
| High anisotropy | Embeddings concentrate along a few dimensions | Add contrastive loss, increase regularization |
| Low effective dimensionality | Dimensionality collapse | Use larger hidden layers, dropout |
| Low persistence entropy | Loss of topological structure | Add a topological regularization term |
| Poor class separation | Classes merging in embedding space | Increase class margin, use focal loss |
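When logging the component breakdown, ranking components by their weighted contribution points straight at the matching row of the table above. A minimal sketch, assuming each component has already been scaled to [0, 1] with higher meaning more collapse-like; the weights copy those in Step 4 and the action strings copy the table, while `rank_collapse_drivers` and `SUGGESTED_ACTIONS` are hypothetical names.

```python
DEFAULT_WEIGHTS = {"anisotropy": 0.3, "dim_ratio": 0.3,
                   "persistence_entropy": 0.2, "fisher_ratio": 0.2}

SUGGESTED_ACTIONS = {
    "anisotropy": "Add contrastive loss, increase regularization",
    "dim_ratio": "Use larger hidden layers, dropout",
    "persistence_entropy": "Add topological regularization term",
    "fisher_ratio": "Increase class margin, use focal loss",
}


def rank_collapse_drivers(signals, weights=DEFAULT_WEIGHTS):
    """Sort components by weighted contribution to the CI, largest first.

    `signals` maps component name -> value in [0, 1] (clipped to that range).
    """
    contributions = {
        name: weights[name] * min(max(value, 0.0), 1.0)
        for name, value in signals.items() if name in weights
    }
    return sorted(contributions.items(), key=lambda item: item[1], reverse=True)


for name, share in rank_collapse_drivers({"anisotropy": 0.2, "dim_ratio": 0.9, "fisher_ratio": 0.5}):
    print(f"{name}: {share:.2f} -> {SUGGESTED_ACTIONS[name]}")
```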

Intervention Strategies

  1. Learning rate warm restart — reset and anneal again
  2. Add contrastive regularization — push apart different classes
  3. Increase weight decay — prevent overfitting to training distribution
  4. Spectral normalization — constrain weight matrix singular values
  5. Batch normalization tuning — adjust momentum to preserve diversity
  6. Early stopping — halt training before collapse completes
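Strategy 1 can be sketched as the SGDR schedule (cosine annealing with warm restarts). This framework-free version only computes the rate for a given epoch and its parameter names are illustrative; PyTorch's `torch.optim.lr_scheduler.CosineAnnealingWarmRestarts` (`T_0`, `T_mult`, `eta_min`) implements the same schedule for real optimizers.

```python
import math


def cosine_warm_restart_lr(epoch, lr_max, lr_min=0.0, period=10, period_mult=1):
    """Learning rate at `epoch` under cosine annealing with warm restarts.

    lr = lr_min + 0.5 * (lr_max - lr_min) * (1 + cos(pi * t_cur / t_i)),
    where t_cur counts epochs since the last restart and t_i is the length
    of the current cycle, multiplied by `period_mult` (>= 1) after each restart.
    """
    t_i, t_cur = period, epoch
    while t_cur >= t_i:          # skip over completed cycles
        t_cur -= t_i
        t_i *= period_mult
    return lr_min + 0.5 * (lr_max - lr_min) * (1 + math.cos(math.pi * t_cur / t_i))
```

With `period=10`, the rate decays from `lr_max` over epochs 0-9 and jumps back to `lr_max` at epoch 10, which is the "reset and anneal again" behaviour.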

Computational Efficiency

| Method | Complexity | When to Use |
|--------|------------|-------------|
| Full persistence (all data) | O(n³) | Small datasets (< 1,000 samples) |
| Subsampled persistence | O(k³), where k ≪ n | Medium datasets (1K-100K) |
| MMHM incremental | O(Δn · k²) | Large datasets, frequent checks |
| Eigenvalue-only CI | O(n · d² + d³) (covariance + eigendecomposition) | Quick checks every epoch |
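The cubic rows come from the size of the filtration: computing H1 needs simplices up to dimension 2, and without a distance threshold a Vietoris-Rips complex on n points contains every vertex, edge and triangle. Counting them shows why subsampling pays off. This is a worst-case count (the helper name is hypothetical); ripser's threshold and internal optimizations usually avoid materializing all of these simplices.

```python
from math import comb


def rips_simplex_count(n_points, max_dim=1):
    """Simplices in an unthresholded Vietoris-Rips complex needed for H_0..H_max_dim.

    H_k requires simplices up to dimension k + 1, i.e. vertex subsets of size
    1 .. max_dim + 2; with no threshold every such subset is a simplex.
    """
    return sum(comb(n_points, size) for size in range(1, max_dim + 3))


print(rips_simplex_count(1000))   # 166667500
print(rips_simplex_count(500))    # 20833750
```

Halving the sample from 1,000 to 500 points shrinks the count roughly eightfold, consistent with cubic scaling.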

Best Practices

  1. Check periodically, not every epoch — topological computation is expensive
  2. Use subsampling for large datasets — 500-1000 random samples is usually sufficient
  3. Track trend, not absolute value — increasing CI is more important than absolute threshold
  4. Combine with standard metrics — CI complements loss/accuracy, doesn't replace
  5. Calibrate threshold per task — different tasks have different collapse baselines
  6. Use MMHM for production — incremental updates are much faster than recomputation
  7. Monitor both train and val — collapse can happen differently in each
  8. Log component breakdown — understanding which component drives CI helps debugging
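Practices 3 and 5 can be combined into a trend rule that calibrates itself on the first few checks of a run instead of relying on a fixed cutoff such as `threshold=0.7`. A hedged sketch: the function name, the least-squares slope over a short window, and the z-score cutoff are assumptions, not values from the paper.

```python
import numpy as np


def ci_trend_alert(ci_history, baseline_checks=3, window=3, z_thresh=3.0):
    """Flag collapse from the CI trend rather than a fixed absolute threshold.

    The first `baseline_checks` values calibrate a per-run baseline. An alert
    fires when the least-squares slope over the last `window` checks is
    positive and the latest value lies more than `z_thresh` baseline standard
    deviations above the baseline mean.
    """
    ci = np.asarray(ci_history, dtype=float)
    if len(ci) < max(baseline_checks, window) + 1:
        return False                              # not enough history yet
    base = ci[:baseline_checks]
    spread = max(base.std(), 1e-6)                # avoid dividing by zero
    slope = np.polyfit(np.arange(window), ci[-window:], deg=1)[0]
    z = (ci[-1] - base.mean()) / spread
    return bool(slope > 0 and z > z_thresh)
```

With the `CollapseMonitor` above, the input would be `[h['collapse_index'] for h in monitor.history]`.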

References

  • Kalinowski, A. (2026). Monitoring Neural Training with Topology: A Footprint-Predictable Collapse Index. arXiv:2604.26984 [cs.LG].
  • Related: neural-code-dynamics-analysis, neural-population-dynamics
Install via CLI
npx skills add https://github.com/hiyenwong/ai_collection --skill topology-neural-collapse-monitor
Repository Details
Branch: main
Path: SKILL.md