---
name: topology-neural-collapse-monitor
description: "Monitoring Neural Training with Topology: Footprint-Predictable Collapse Index using Modular Morse Homology Maintenance (MMHM). Detects representational collapse in neural network embeddings before performance metrics degrade. Activation: neural collapse, representational collapse, topological monitoring, training diagnostics, MMHM, embedding degradation, Morse homology."
category: ai_collection
source:
  paper: "Monitoring Neural Training with Topology: A Footprint-Predictable Collapse Index"
  authors:
    - "Alexander Kalinowski"
  arxiv: "2604.26984"
  date: "2026-04-28"
  fields:
    - cs.LG
    - cs.AI
activation_keywords:
  en:
    - neural collapse
    - representational collapse
    - topological monitoring
    - training diagnostics
    - MMHM
    - embedding degradation
    - Morse homology
    - anisotropic embeddings
    - representation monitoring
    - training collapse detection
  zh:
    - 神经坍缩
    - 表征坍缩
    - 拓扑监控
    - 训练诊断
    - 嵌入退化
    - 莫尔斯同调
    - 各向异性嵌入
    - 表征监控
version: "1.0.0"
---
Topology-Aware Neural Training Monitor: Collapse Index
Reference: Kalinowski, A. Monitoring Neural Training with Topology: A Footprint-Predictable Collapse Index. arXiv:2604.26984 [cs.LG] (2026).
Overview
Representational collapse, in which neural network embeddings become anisotropic and lose multi-scale structure, can erode downstream performance long before standard metrics such as loss and accuracy react. This skill provides a topology-aware monitoring framework for detecting early signs of representational degradation.
Core Innovation: MMHM + Collapse Index
Instead of rebuilding topological complexes from scratch each epoch (expensive), the method uses:
- Modular Morse Homology Maintenance (MMHM): Applies sparse edits at a fixed scale and maintains a discrete Morse matching incrementally
- Composite Collapse Index (CI): Combines topological signatures into a single scalar diagnostic
Key Advantages
- Fast: Incremental updates instead of full recomputation
- Early detection: Can signal collapse before loss or accuracy degrade
- Footprint-predictable: Computational cost scales predictably with data size
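As a back-of-the-envelope illustration of why incremental updates help, the hypothetical helper below counts the pairwise distances a full rebuild recomputes versus the distances that touch at least one changed point. The cost model (distance evaluations only) is our simplifying assumption, not the paper's footprint analysis.

```python
def distance_evaluations(n_points, n_changed):
    """Count pairwise distances recomputed by a full rebuild vs. a sparse update.

    Simplified cost model (an assumption): a full rebuild recomputes all
    n*(n-1)/2 pairs, while a sparse update recomputes only the pairs that
    touch at least one changed point.
    """
    if not 0 <= n_changed <= n_points:
        raise ValueError("n_changed must be between 0 and n_points")
    full = n_points * (n_points - 1) // 2
    unchanged = n_points - n_changed
    # Pairs whose endpoints are both unchanged keep their old distance.
    incremental = full - unchanged * (unchanged - 1) // 2
    return full, incremental


full, incremental = distance_evaluations(n_points=10_000, n_changed=100)
print(full, incremental)
```

With 10,000 points and 100 changed, this prints `49995000 994950`, roughly 2% of the full recomputation under this simplified model.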
Methodology
Step 1: Build Point Cloud from Embeddings
```python
import numpy as np
import torch
from scipy.spatial.distance import pdist, squareform  # used in Steps 2-3


def get_embeddings_snapshot(model, dataloader, device):
    """Extract embeddings from a model's penultimate layer.

    Assumes the model exposes a `get_features(x)` method that returns
    penultimate-layer activations of shape (batch, d).
    """
    was_training = model.training
    model.eval()
    embeddings = []
    labels = []
    with torch.no_grad():
        for batch_X, batch_y in dataloader:
            batch_X = batch_X.to(device)
            # Get embeddings (penultimate layer output)
            emb = model.get_features(batch_X)
            embeddings.append(emb.cpu().numpy())
            labels.append(batch_y.cpu().numpy())
    model.train(was_training)  # restore the caller's mode (dropout, batch norm)
    embeddings = np.concatenate(embeddings, axis=0)
    labels = np.concatenate(labels, axis=0)
    return embeddings, labels
```
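The Best Practices section suggests working on 500 to 1000 random samples for large datasets. A minimal sketch of a hypothetical `subsample_snapshot` helper (our name, not from the paper) that draws a uniform random subset without replacement and keeps embeddings and labels aligned. The fixed seed assumes the validation loader is not shuffled, so the same indices refer to the same examples across epochs.

```python
import numpy as np


def subsample_snapshot(embeddings, labels, k=1000, seed=0):
    """Draw k rows uniformly at random (without replacement), keeping labels aligned.

    A fixed seed yields the same indices each epoch (for a fixed dataset size),
    which keeps epoch-to-epoch comparisons and incremental updates meaningful.
    """
    n = len(embeddings)
    if k >= n:
        return embeddings, labels
    rng = np.random.default_rng(seed)
    idx = np.sort(rng.choice(n, size=k, replace=False))
    return embeddings[idx], labels[idx]
```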
Step 2: Compute Persistence Diagram
```python
import numpy as np
from ripser import ripser
from scipy.spatial.distance import pdist, squareform


def compute_persistence(embeddings, max_dim=1, threshold=None):
    """
    Compute persistent homology of the embedding point cloud.
    Returns persistence diagrams for dimensions 0 (connected components)
    and 1 (loops/cycles).
    """
    # Compute pairwise distances
    dist_matrix = squareform(pdist(embeddings, metric='euclidean'))
    # Truncate the filtration at `threshold` (sparse computation). Clipping the
    # distances instead would wrongly connect every pair at the threshold.
    thresh = np.inf if threshold is None else threshold
    diagrams = ripser(
        dist_matrix,
        maxdim=max_dim,
        thresh=thresh,
        distance_matrix=True,
    )['diagrams']
    return diagrams


def persistence_lifetime(diagram, dim=1):
    """Compute persistence lifetimes (death - birth) for one dimension."""
    if dim >= len(diagram) or len(diagram[dim]) == 0:
        return np.array([])
    births = diagram[dim][:, 0]
    deaths = diagram[dim][:, 1]
    # Filter out infinite deaths (features still alive at the threshold)
    finite = np.isfinite(deaths)
    return deaths[finite] - births[finite]
```
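A dependency-free sanity check for the dimension-0 part of the diagram: in a Vietoris-Rips filtration every H0 feature is born at 0, and the finite death times are exactly the edge lengths of a minimum spanning tree of the point cloud (one component never dies). The sketch below uses Prim's algorithm; `h0_death_times` is a hypothetical helper, and with no threshold its output should match `persistence_lifetime(diagrams, dim=0)` up to floating-point error and ordering.

```python
import numpy as np


def h0_death_times(points):
    """Finite H0 death times of the Vietoris-Rips filtration (MST edge lengths)."""
    points = np.asarray(points, dtype=float)
    n = len(points)
    if n < 2:
        return np.array([])
    dist = np.linalg.norm(points[:, None, :] - points[None, :, :], axis=-1)
    in_tree = np.zeros(n, dtype=bool)
    in_tree[0] = True
    best = dist[0].copy()  # cheapest connection from each point to the tree
    deaths = []
    for _ in range(n - 1):
        candidate = np.where(in_tree, np.inf, best)
        j = int(np.argmin(candidate))
        deaths.append(candidate[j])  # attaching j merges two components here
        in_tree[j] = True
        best = np.minimum(best, dist[j])
    return np.sort(np.array(deaths))
```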
Step 3: Incremental Morse Homology (MMHM)
```python
import numpy as np
from scipy.spatial.distance import pdist, squareform


class MorseHomologyMonitor:
    """
    Modular Morse Homology Maintenance for incremental topological monitoring.
    Instead of rebuilding the complex each epoch, maintains a matching on the
    fixed-scale neighborhood graph and applies sparse edits.

    Simplified sketch: this pairs vertices with vertices greedily. A true
    discrete Morse matching pairs each cell with one of its cofacets
    (e.g. a vertex with an incident edge) and must be acyclic.
    """

    def __init__(self, scale=0.5, max_dim=1):
        self.scale = scale
        self.max_dim = max_dim
        self.current_matching = {}  # symmetric: i -> j and j -> i
        self.current_embeddings = None

    def initialize(self, embeddings):
        """Build the initial matching from the first snapshot."""
        self.current_embeddings = np.array(embeddings, copy=True)
        self._build_initial_complex(self.current_embeddings)

    def _build_initial_complex(self, embeddings):
        """Build the Vietoris-Rips 1-skeleton at a fixed scale and a greedy matching."""
        dist_matrix = squareform(pdist(embeddings, metric='euclidean'))
        edges = np.argwhere(dist_matrix < self.scale).tolist()
        self.current_matching = {}
        for i, j in edges:
            # Each unordered edge appears twice; i < j also skips self-loops
            if i < j and i not in self.current_matching and j not in self.current_matching:
                self._pair(i, j)

    def _pair(self, i, j):
        # Store both directions so "is matched" checks are O(1)
        self.current_matching[i] = j
        self.current_matching[j] = i

    def _unpair(self, i):
        j = self.current_matching.pop(i, None)
        if j is not None:
            self.current_matching.pop(j, None)

    def update(self, new_embeddings, old_embeddings=None):
        """
        Incremental update: only process changed embeddings.
        Returns the indices that were re-matched.
        """
        if old_embeddings is None:
            old_embeddings = self.current_embeddings
        changed = np.flatnonzero(self._find_changes(new_embeddings, old_embeddings))
        for idx in changed.tolist():
            self._local_update(idx, new_embeddings)
        self.current_embeddings = np.array(new_embeddings, copy=True)
        return changed

    def _find_changes(self, new, old, threshold=0.01):
        """Flag embeddings whose L2 displacement exceeds `threshold`."""
        return np.linalg.norm(new - old, axis=1) > threshold

    def _local_update(self, idx, embeddings):
        """Re-match a changed point within its new scale neighborhood."""
        self._unpair(idx)  # the old partner may no longer be within `scale`
        dists = np.linalg.norm(embeddings - embeddings[idx], axis=1)
        for n in np.flatnonzero(dists < self.scale).tolist():
            if n != idx and n not in self.current_matching:
                self._pair(idx, n)
                break
```
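For contrast with the simplified vertex-to-vertex pairing above, here is a sketch of a genuine (acyclic) discrete Morse matching on the fixed-scale 1-skeleton, using the standard spanning-forest construction: every non-root vertex is paired with the edge to its BFS parent. This is a textbook construction chosen for illustration, not necessarily the one MMHM maintains, and `graph_morse_matching` is a hypothetical name. The critical vertices (one root per component) count the connected components, and the critical edges (non-tree edges) count the graph's independent cycles, E - V + C.

```python
from collections import deque

import numpy as np


def graph_morse_matching(points, scale):
    """Acyclic vertex-edge matching on the fixed-scale neighborhood graph."""
    points = np.asarray(points, dtype=float)
    n = len(points)
    dist = np.linalg.norm(points[:, None, :] - points[None, :, :], axis=-1)
    # Edges of the Vietoris-Rips 1-skeleton at this scale, as (i, j) with i < j
    edges = [(i, j) for i in range(n) for j in range(i + 1, n) if dist[i, j] < scale]
    adjacency = {i: [] for i in range(n)}
    for i, j in edges:
        adjacency[i].append(j)
        adjacency[j].append(i)

    matching = {}           # vertex -> the edge it is paired with
    critical_vertices = []  # one unmatched root per connected component
    visited = [False] * n
    for root in range(n):
        if visited[root]:
            continue
        visited[root] = True
        critical_vertices.append(root)
        queue = deque([root])
        while queue:
            u = queue.popleft()
            for v in adjacency[u]:
                if not visited[v]:
                    visited[v] = True
                    # Pair v with its tree edge to the parent (a cofacet of v);
                    # gradient paths then lead toward the root, so no cycles.
                    matching[v] = (min(u, v), max(u, v))
                    queue.append(v)

    tree_edges = set(matching.values())
    critical_edges = [e for e in edges if e not in tree_edges]
    return matching, critical_vertices, critical_edges
```

For an equilateral triangle plus one distant point, this yields two critical vertices (two components) and one critical edge (the triangle's loop).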
Step 4: Compute Collapse Index
```python
import numpy as np
# Uses compute_persistence and persistence_lifetime from Step 2.


def compute_collapse_index(embeddings, labels=None, n_topo_samples=500, seed=0):
    """
    Composite Collapse Index measuring representational degradation.
    Combines topological and geometric signals, each oriented so that
    larger values mean more collapse.
    """
    ci_components = {}
    ambient_dim = embeddings.shape[1]

    # 1. Anisotropy: ratio of largest to smallest covariance eigenvalue.
    # Keep near-zero eigenvalues (only clip tiny negatives): dropping them would
    # hide exactly the collapsed directions. Meaningful only when n > d.
    eigenvalues = np.clip(np.linalg.eigvalsh(np.cov(embeddings.T)), 0.0, None)
    anisotropy = eigenvalues[-1] / (eigenvalues[0] + 1e-10)
    ci_components['anisotropy'] = np.log1p(anisotropy)

    # 2. Effective dimensionality (participation ratio) vs. the ambient dimension
    total_var = eigenvalues.sum()
    if total_var > 0:
        eigenvalues_norm = eigenvalues / total_var
        effective_dim = 1.0 / (eigenvalues_norm ** 2).sum()
    else:
        effective_dim = 0.0  # all embeddings identical
    ci_components['dim_ratio'] = 1.0 - effective_dim / ambient_dim

    # 3. Persistence entropy of H1 lifetimes, on a random (not leading) subsample
    rng = np.random.default_rng(seed)
    k = min(n_topo_samples, len(embeddings))
    sample = embeddings[rng.choice(len(embeddings), size=k, replace=False)]
    lifetimes = persistence_lifetime(compute_persistence(sample), dim=1)
    lifetimes = lifetimes[lifetimes > 0]
    if len(lifetimes) > 1:
        probs = lifetimes / lifetimes.sum()
        entropy = -(probs * np.log(probs)).sum() / np.log(len(lifetimes))
        # Low entropy signals lost structure, so invert the normalized entropy
        ci_components['persistence_entropy'] = 1.0 - entropy

    # 4. Class separation (if labels available)
    if labels is not None and len(np.unique(labels)) > 1:
        overall_mean = embeddings.mean(axis=0)
        sw = 0.0  # within-class scatter
        sb = 0.0  # between-class scatter
        for label in np.unique(labels):
            mask = labels == label
            class_mean = embeddings[mask].mean(axis=0)
            sw += np.sum((embeddings[mask] - class_mean) ** 2)
            sb += mask.sum() * np.sum((class_mean - overall_mean) ** 2)
        if sw > 0:
            ci_components['fisher_ratio'] = 1.0 / (1.0 + sb / sw)

    # Composite index: weighted mean over the components actually computed
    weights = {'anisotropy': 0.3, 'dim_ratio': 0.3,
               'persistence_entropy': 0.2, 'fisher_ratio': 0.2}
    # Assumed reference range for log-anisotropy; the other components are in [0, 1]
    ref_max = {'anisotropy': 10.0}
    present = [key for key in weights if key in ci_components]
    ci = sum(weights[key] * normalize_component(ci_components[key], ref_max=ref_max.get(key, 1.0))
             for key in present) / sum(weights[key] for key in present)
    return ci, ci_components


def normalize_component(value, ref_min=0.0, ref_max=1.0):
    """Linearly rescale a component to [0, 1], clipping out-of-range values."""
    return float(np.clip((value - ref_min) / (ref_max - ref_min + 1e-10), 0.0, 1.0))
```
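Persistence entropy treats the H1 lifetimes as a probability distribution, p_i = l_i / sum_j l_j, and computes H = -sum_i p_i log p_i. The sketch below also divides by log N so the value lies in [0, 1]; that normalization is an assumption made for illustration, and `normalized_persistence_entropy` is a hypothetical name.

```python
import numpy as np


def normalized_persistence_entropy(lifetimes):
    """Shannon entropy of the finite, positive lifetimes divided by log(N)."""
    lifetimes = np.asarray(lifetimes, dtype=float)
    lifetimes = lifetimes[np.isfinite(lifetimes) & (lifetimes > 0)]
    if len(lifetimes) < 2:
        return 0.0  # zero or one bar: no spread to measure
    p = lifetimes / lifetimes.sum()
    return float(-(p * np.log(p)).sum() / np.log(len(lifetimes)))


print(normalized_persistence_entropy([1.0, 1.0, 1.0, 1.0]))  # equal bars
print(normalized_persistence_entropy([3.0, 1.0]))            # one dominant bar
```

Equal lifetimes give 1.0 (maximal spread), while the `[3, 1]` case gives about 0.81, reflecting that one bar carries most of the persistence.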
Training Integration
```python
class CollapseMonitor:
    """Monitor for representational collapse during training."""

    def __init__(self, check_every=5, threshold=0.7):
        self.check_every = check_every
        self.threshold = threshold
        self.history = []
        # Incremental topology monitor; check() below does not use it yet and
        # recomputes the Collapse Index from scratch.
        self.morse_monitor = MorseHomologyMonitor(scale=0.5)
        self.is_collapsing = False
        self.collapse_epoch = None

    def check(self, model, dataloader, epoch, device):
        """Run a collapse check; returns the CI, or None on skipped epochs."""
        if epoch % self.check_every != 0:
            return None
        embeddings, labels = get_embeddings_snapshot(model, dataloader, device)
        ci, components = compute_collapse_index(embeddings, labels)
        self.history.append({
            'epoch': epoch,
            'collapse_index': ci,
            'components': components,
        })
        # Flag collapse when the latest CI is above threshold and still rising
        # (needs at least three checks of history)
        if len(self.history) >= 3:
            recent = [h['collapse_index'] for h in self.history[-3:]]
            if recent[-1] > self.threshold and recent[-1] > recent[-2]:
                self.is_collapsing = True
                self.collapse_epoch = epoch
                print(f"⚠️ COLLAPSE DETECTED at epoch {epoch}! CI={ci:.4f}")
        return ci

    def get_report(self):
        """Generate a Markdown monitoring report."""
        if not self.history:
            return "No monitoring data yet."
        report = "## Collapse Monitor Report\n\n"
        report += "| Epoch | Collapse Index | Status |\n"
        report += "|-------|----------------|--------|\n"
        for h in self.history:
            status = "⚠️ COLLAPSE" if h['collapse_index'] > self.threshold else "✅ OK"
            report += f"| {h['epoch']} | {h['collapse_index']:.4f} | {status} |\n"
        return report
```
Usage in Training Loop
```python
# Initialize monitor
monitor = CollapseMonitor(check_every=5, threshold=0.7)

for epoch in range(num_epochs):
    # ... standard training steps (train_step is a placeholder) ...
    train_step(model, optimizer, dataloader)

    # Check for collapse (returns None on epochs that are skipped)
    ci = monitor.check(model, val_dataloader, epoch, device)

    # Intervene only on the epoch where collapse was flagged: is_collapsing
    # stays True afterwards, so testing it alone would halve the learning
    # rate on every remaining epoch.
    if monitor.is_collapsing and monitor.collapse_epoch == epoch:
        print("Intervening: reducing learning rate")
        for param_group in optimizer.param_groups:
            param_group['lr'] *= 0.5
        # Could also add contrastive loss, weight decay, etc.
```
Diagnostic Signals
Signs of Representational Collapse
| Signal | What It Means | Action |
|---|---|---|
| High anisotropy | Embeddings concentrate along few dimensions | Add contrastive loss, increase regularization |
| Low effective dim | Dimensionality collapse | Use larger hidden layers, dropout |
| Low persistence entropy | Loss of topological structure | Add topological regularization term |
| Poor class separation | Classes merging in embedding space | Increase class margin, use focal loss |
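The table can be turned into a small lookup that reports which signals dominate a given check, in line with the advice to log the component breakdown. A sketch, assuming the component values have already been scaled to [0, 1] with larger meaning worse; the `diagnose` name and the 0.5 cut-off are hypothetical.

```python
# Suggested actions copied from the Diagnostic Signals table
ACTIONS = {
    'anisotropy': "High anisotropy: add contrastive loss, increase regularization",
    'dim_ratio': "Low effective dim: use larger hidden layers, dropout",
    'persistence_entropy': "Low persistence entropy: add topological regularization term",
    'fisher_ratio': "Poor class separation: increase class margin, use focal loss",
}


def diagnose(scaled_components, cutoff=0.5):
    """Return (component, suggestion) pairs above `cutoff`, worst first."""
    flagged = [(name, value) for name, value in scaled_components.items()
               if name in ACTIONS and value > cutoff]
    flagged.sort(key=lambda item: item[1], reverse=True)
    return [(name, ACTIONS[name]) for name, _ in flagged]


for name, action in diagnose({'anisotropy': 0.9, 'dim_ratio': 0.7, 'fisher_ratio': 0.2}):
    print(name, "->", action)
```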
Intervention Strategies
- Learning rate warm restart — reset and anneal again
- Add contrastive regularization — push apart different classes
- Increase weight decay — prevent overfitting to training distribution
- Spectral normalization — constrain weight matrix singular values
- Batch normalization tuning — adjust momentum to preserve diversity
- Early stopping — halt training before collapse completes
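Spectral normalization divides a weight matrix by an estimate of its largest singular value, so the normalized matrix has spectral norm close to 1. A NumPy sketch using power iteration; frameworks such as PyTorch's `spectral_norm` run one iteration per forward pass by default, whereas this illustration runs many iterations on a fixed matrix. `spectral_normalize` is a hypothetical helper.

```python
import numpy as np


def spectral_normalize(weight, n_iters=50, seed=0):
    """Return (W / sigma_max, sigma_max), with sigma_max estimated by power iteration."""
    rng = np.random.default_rng(seed)
    u = rng.standard_normal(weight.shape[0])
    for _ in range(n_iters):
        v = weight.T @ u
        v /= np.linalg.norm(v)
        u = weight @ v
        u /= np.linalg.norm(u)
    sigma = float(u @ weight @ v)  # u^T W v converges to the top singular value
    return weight / sigma, sigma
```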
Computational Efficiency
| Method | Complexity | When to Use |
|---|---|---|
| Full persistence (all data) | O(m³) worst case for m simplices; m = O(n³) when computing H1 | Small datasets (< 1000 samples) |
| Subsampled persistence | Same, on k << n sampled points | Medium datasets (1K-100K) |
| MMHM incremental | O(Δn · k²) | Large datasets, frequent checks |
| Eigenvalue-only CI | O(n · d²) for the covariance, plus O(d³) for its eigendecomposition | Quick checks every epoch |
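The eigenvalue-only row can be sketched with a single SVD of the centered embeddings, which yields the covariance spectrum without any persistence computation. A sketch with the hypothetical name `spectral_collapse_stats`, using the ambient dimension d as the reference for the dimension ratio.

```python
import numpy as np


def spectral_collapse_stats(embeddings, eps=1e-10):
    """Log-anisotropy and 1 - (participation ratio / d) from the covariance spectrum."""
    x = np.asarray(embeddings, dtype=float)
    n, d = x.shape
    x = x - x.mean(axis=0)
    # Squared singular values of the centered data / (n - 1) = covariance eigenvalues
    s = np.linalg.svd(x, compute_uv=False)
    eig = np.zeros(d)
    eig[:len(s)] = s ** 2 / (n - 1)  # zero-pad: rank is at most min(n, d)
    total = eig.sum()
    participation = total ** 2 / np.sum(eig ** 2) if total > 0 else 0.0
    log_anisotropy = np.log1p(eig.max() / (eig.min() + eps))
    return float(log_anisotropy), float(1.0 - participation / d)
```

Isotropic data gives a dimension ratio near 0, while data lying on a single direction in 3D gives 2/3 and a very large log-anisotropy.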
Best Practices
- Check periodically, not every epoch: topological computation is expensive
- Use subsampling for large datasets: 500-1000 random samples are usually sufficient
- Track the trend, not the absolute value: a rising CI matters more than crossing a fixed threshold
- Combine with standard metrics: CI complements loss and accuracy; it does not replace them
- Calibrate the threshold per task: different tasks have different collapse baselines
- Use MMHM in production: incremental updates are much faster than full recomputation
- Monitor both train and validation sets: collapse can unfold differently in each
- Log the component breakdown: knowing which component drives CI helps debugging
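To act on the advice to track the trend rather than the absolute value, one option is to fit a least-squares slope to the last few CI readings and flag a sustained rise. A minimal sketch; the window size, slope cut-off, and `ci_trend` name are hypothetical and would need calibration per task.

```python
import numpy as np


def ci_trend(epochs, ci_values, window=4, min_slope=0.0):
    """Least-squares slope of the last `window` CI readings, and whether it exceeds min_slope."""
    if len(ci_values) < 2:
        return 0.0, False
    x = np.asarray(epochs[-window:], dtype=float)
    y = np.asarray(ci_values[-window:], dtype=float)
    slope = float(np.polyfit(x, y, deg=1)[0])  # CI change per epoch
    return slope, slope > min_slope
```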
References
- Kalinowski, A. (2026). Monitoring Neural Training with Topology: A Footprint-Predictable Collapse Index. arXiv:2604.26984 [cs.LG].
- Related: neural-code-dynamics-analysis, neural-population-dynamics