neural-manifold-dynamics-learning


Neural Manifold Learning Dynamics methodology for analyzing population activity in high-dimensional neural state spaces. Extracts low-dimensional structure from neural recordings to understand computation and behavior. Combines dimensionality reduction with dynamical systems analysis for neural population decoding. Activation: neural manifold, latent dynamics, population activity, dimensionality reduction, neural state space, behavior decoding, jPCA, dPCA, GPFA.

By hiyenwong, updated 6/3/2026

name: neural-manifold-dynamics-learning
description: "Neural Manifold Learning Dynamics methodology for analyzing population activity in high-dimensional neural state spaces. Extracts low-dimensional structure from neural recordings to understand computation and behavior. Combines dimensionality reduction with dynamical systems analysis for neural population decoding. Activation: neural manifold, latent dynamics, population activity, dimensionality reduction, neural state space, behavior decoding, jPCA, dPCA, GPFA."
tags: ["neural-manifold", "population-dynamics", "dimensionality-reduction", "jPCA", "dPCA", "GPFA", "latent-dynamics", "neural-decoding"]

Neural Manifold Learning Dynamics

Overview

Neural manifold learning dynamics is a framework for understanding how populations of neurons coordinate their activity to generate behavior. The key insight is that although a circuit may contain thousands or millions of neurons, its activity often occupies a low-dimensional manifold within the high-dimensional neural state space (one axis per neuron).

Core Concepts

Neural Manifold Hypothesis

High-Dimensional Neural Space (N neurons)
    ↓
Low-Dimensional Manifold (d << N dimensions)
    ↓
Behavioral Output

Key Properties:

  • Neural activity is largely confined to a low-dimensional surface
  • The shape of the manifold reflects the computation the circuit performs
  • Dynamics on the manifold generate behavior
  • The manifold is often relatively stable across trials and conditions
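To make the manifold hypothesis concrete, the sketch below builds a hypothetical population of 100 neurons whose activity is a linear mixture of 3 latent signals plus small private noise, then checks how much variance PCA captures. Everything here (the latent signals, the loading matrix, the noise level) is synthetic and chosen for illustration; real recordings are noisier and their manifolds need not be linear.

```python
import numpy as np
from sklearn.decomposition import PCA

rng = np.random.default_rng(0)
n_neurons, n_latent, n_time = 100, 3, 500

# Hypothetical example: 3 smooth latent signals drive 100 neurons
t = np.linspace(0, 10, n_time)
latents = np.column_stack([np.sin(t), np.cos(1.7 * t), np.sin(0.5 * t + 1.0)])  # (n_time, 3)
loading = rng.normal(size=(n_latent, n_neurons))                               # (3, n_neurons)
activity = latents @ loading + 0.1 * rng.normal(size=(n_time, n_neurons))      # private noise

# Cumulative fraction of variance captured by the leading principal components
pca = PCA(n_components=10).fit(activity)
cumulative = np.cumsum(pca.explained_variance_ratio_)
print(np.round(cumulative[:4], 3))
```

The first three components capture nearly all of the variance and the fourth adds almost nothing: the signature of a 3-dimensional (linear) manifold embedded in a 100-dimensional state space.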

Dimensionality Reduction Methods

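import numpy as np
from sklearn.decomposition import PCA

def extract_neural_manifold(spike_trains, method='PCA', n_components=10, labels=None):
    """
    Extract low-dimensional neural manifold from spike data.
    
    Args:
        spike_trains: (n_trials, n_neurons, n_timepoints) array
            (for 'dPCA': trial-averaged data in the layout dpca_analysis expects)
        method: 'PCA', 'jPCA', 'dPCA', 'GPFA'
        n_components: Dimensionality of manifold
        labels: dPCA axis labels (required for method='dPCA')
    
    Returns:
        manifold_projection: (n_trials, n_timepoints, n_components) trajectories
        components: Manifold basis vectors
        explained_variance: Variance explained ('jPCA' returns rotation
            eigenvalues instead; 'GPFA' returns None)
        For 'dPCA', the per-marginalization results dict is returned instead.
    """
    if method == 'dPCA':
        # Demixed PCA needs data arranged by task parameter, not by trial
        if labels is None:
            raise ValueError("method='dPCA' requires labels")
        return dpca_analysis(spike_trains, labels, n_components)
    
    n_trials, n_neurons, n_timepoints = spike_trains.shape
    
    if method == 'PCA':
        # Move time before neurons so each row is one (trial, time) sample:
        # (n_trials, n_neurons, n_timepoints) -> (n_trials * n_timepoints, n_neurons)
        data = spike_trains.transpose(0, 2, 1).reshape(-1, n_neurons)
        model = PCA(n_components=n_components)
        projection = model.fit_transform(data)
        components = model.components_
        variance = model.explained_variance_ratio_
        # Reshape back to trials
        projection = projection.reshape(n_trials, n_timepoints, n_components)
        
    elif method == 'jPCA':
        # jPCA for rotational dynamics (trials treated as conditions);
        # already returns (n_trials, n_timepoints, n_components)
        projection, components, variance = jpca_analysis(spike_trains, n_components)
        
    elif method == 'GPFA':
        # Gaussian Process Factor Analysis for smooth trajectories (GPFA class below)
        trials = list(spike_trains)
        model = GPFA(n_latent_dimensions=n_components).fit(trials)
        projection = np.stack(model.transform(trials))
        components, variance = model.C.T, None
        
    else:
        raise ValueError(f"Unknown method: {method!r}")
    
    return projection, components, variance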
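The extraction function takes n_components as a fixed argument. One simple heuristic for choosing it is the smallest number of principal components whose cumulative explained variance reaches a threshold. The sketch below assumes that heuristic, a hypothetical helper name `choose_n_components`, and a 90% default threshold; none of these come from the original text.

```python
import numpy as np
from sklearn.decomposition import PCA

def choose_n_components(spike_trains, threshold=0.9):
    """Smallest number of PCs whose cumulative explained variance reaches `threshold`."""
    n_trials, n_neurons, n_timepoints = spike_trains.shape
    data = spike_trains.transpose(0, 2, 1).reshape(-1, n_neurons)  # (samples, neurons)
    ratios = PCA().fit(data).explained_variance_ratio_
    # First index where the cumulative sum reaches the threshold, as a count
    n = int(np.searchsorted(np.cumsum(ratios), threshold) + 1)
    return min(n, len(ratios))  # guard against rounding when threshold is ~1.0

# Hypothetical data: 20 trials, 60 neurons, 50 bins, driven by 4 latent signals
rng = np.random.default_rng(1)
latents = rng.normal(size=(20, 4, 50))          # (trials, latent dims, time)
loading = rng.normal(size=(60, 4))              # (neurons, latent dims)
spikes = np.einsum('nk,rkt->rnt', loading, latents) + 0.05 * rng.normal(size=(20, 60, 50))
print(choose_n_components(spikes, threshold=0.95))
```

On this synthetic population a 95% threshold selects the 4 underlying dimensions. Variance thresholds are sensitive to the noise level, so on real recordings cross-validated choices of dimensionality are often preferred.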

jPCA: Rotational Dynamics

jPCA (Churchland et al., 2012) finds planes in which population activity rotates over time, a pattern reported prominently in motor cortex during reaching.
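The model jPCA fits can be written compactly. Here x_c(t) is the PCA-reduced population state for condition c at time bin t, pooled over all conditions as in Churchland et al. (2012); the discrete forward-difference form is an assumption that matches a finite-difference implementation.

```latex
\dot{\mathbf{x}}(t) \approx M\,\mathbf{x}(t), \qquad M^\top = -M

\hat{M} = \arg\min_{M^\top = -M} \sum_{c,\,t} \bigl\lVert \mathbf{x}_c(t+1) - \mathbf{x}_c(t) - M\,\mathbf{x}_c(t) \bigr\rVert^2

\operatorname{eig}(\hat{M}) = \{\pm i\,\omega_1,\ \pm i\,\omega_2,\ \dots\}
```

Because a real skew-symmetric matrix has purely imaginary eigenvalues in conjugate pairs, each pair ±iω_k defines a plane in which the state rotates at angular frequency ω_k with no growth or decay.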

import numpy as np
from scipy.linalg import schur
from sklearn.decomposition import PCA

def jpca_analysis(spike_trains, n_components=6):
    """
    jPCA for extracting rotational dynamics.
    
    Based on: Churchland et al. (2012) Neural population dynamics during reaching.
    
    Args:
        spike_trains: (n_conditions, n_neurons, n_timepoints) condition-averaged rates
        n_components: Must be even (pairs of components)
    
    Returns:
        jPC_trajectories: (n_conditions, n_timepoints, n_components) projected trajectories
        jPCs: (n_neurons, n_components) jPCA components; column pairs (0, 1), (2, 3), ...
            typically span the rotation planes
        eigenvalues: Eigenvalues of M, purely imaginary; magnitudes are rotation
            frequencies in radians per time bin
    """
    if n_components % 2:
        raise ValueError("n_components must be even")
    n_conditions, n_neurons, n_timepoints = spike_trains.shape
    
    # Step 1: Remove the cross-condition mean at each time point, then reduce
    # the remaining condition-specific activity with standard PCA
    residual = spike_trains - spike_trains.mean(axis=0, keepdims=True)
    data = residual.transpose(0, 2, 1).reshape(-1, n_neurons)  # (n_cond * n_time, n_neurons)
    pca = PCA(n_components=n_components)
    pca_proj = pca.fit_transform(data).reshape(n_conditions, n_timepoints, n_components)
    
    # Step 2: Finite-difference derivative, pooled over conditions
    dX = np.diff(pca_proj, axis=1).reshape(-1, n_components).T  # (n_components, samples)
    X = pca_proj[:, :-1, :].reshape(-1, n_components).T          # (n_components, samples)
    
    # Step 3: Fit skew-symmetric M (M = -M^T) such that dX ≈ M @ X;
    # the constraint enforces purely rotational dynamics
    M = fit_skew_symmetric(X, dX)
    
    # Step 4: The real Schur form of M is block diagonal with 2x2 rotation blocks;
    # the orthogonal Z maps PCA space onto the rotation planes
    T, Z = schur(M, output='real')
    jPCs = pca.components_.T @ Z  # Back to neuron space: (n_neurons, n_components)
    
    # Project the PCA-space trajectories onto the jPC basis
    jPC_trajectories = pca_proj @ Z  # (n_conditions, n_timepoints, n_components)
    
    # Eigenvalues indicate rotation frequencies
    eigenvalues = np.linalg.eigvals(M)
    
    return jPC_trajectories, jPCs, eigenvalues


def fit_skew_symmetric(X, dX):
    """
    Fit skew-symmetric matrix M such that dX ≈ M @ X (least squares).
    
    Skew-symmetric: M = -M^T
    This ensures purely rotational dynamics (no expansion or contraction).
    
    The problem is linear in the n(n-1)/2 entries above the diagonal, so it is
    solved exactly with linear least squares instead of an iterative optimizer.
    """
    n = X.shape[0]
    iu = np.triu_indices(n, k=1)
    
    # Each free entry m_k multiplies a basis matrix E_k = e_i e_j^T - e_j e_i^T,
    # so dX ≈ sum_k m_k (E_k @ X); build the design matrix column by column
    A = np.empty((dX.size, len(iu[0])))
    for k, (i, j) in enumerate(zip(*iu)):
        E = np.zeros((n, n))
        E[i, j], E[j, i] = 1.0, -1.0
        A[:, k] = (E @ X).ravel()
    
    m, *_ = np.linalg.lstsq(A, dX.ravel(), rcond=None)
    
    # Assemble M from its upper-triangular entries
    M = np.zeros((n, n))
    M[iu] = m
    return M - M.T
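A quick way to see what a skew-symmetric fit recovers is a 2-D trajectory rotating at a known rate. In two dimensions M has a single free entry w, so the least-squares fit has a one-line closed form. This is a hypothetical check (the helper `fit_rotation_speed` and the 0.2 rad/bin rotation are illustrative), not part of the jPCA procedure itself.

```python
import numpy as np

def fit_rotation_speed(X, dX):
    """Least-squares w for the 2-D skew-symmetric model dX ≈ [[0, w], [-w, 0]] @ X."""
    JX = np.vstack([X[1], -X[0]])         # [[0, 1], [-1, 0]] @ X
    return np.sum(dX * JX) / np.sum(JX * JX)

omega = 0.2                               # true rotation, radians per time bin
t = np.arange(100)
traj = np.vstack([np.cos(omega * t), -np.sin(omega * t)])  # (2, n_timepoints)

# Pair each state with its forward difference
X, dX = traj[:, :-1], np.diff(traj, axis=1)
w = fit_rotation_speed(X, dX)
M = np.array([[0.0, w], [-w, 0.0]])
eigenvalues = np.linalg.eigvals(M)        # purely imaginary pair ±i*w

print(round(w, 4), np.abs(eigenvalues.imag))
```

Because the derivative is a forward difference, the fitted w equals sin(0.2) ≈ 0.1987 rather than 0.2. The bias is small when rotations are slow relative to the bin width, and in this idealized noise-free case arcsin(w) recovers the per-bin angle exactly.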

dPCA: Demixed Principal Component Analysis

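dPCA decomposes population activity into components that each depend mainly on one task parameter (for example stimulus, decision, or time), while still capturing most of the variance.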

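def dpca_analysis(spike_trains, labels, n_components_per_marginal=3, single_trials=None):
    """
    Demixed PCA for separating task parameter contributions.
    
    Assumes the `dPCA` Python package (machenslab/dPCA), whose README puts
    neurons on the FIRST axis; check attribute names against your installed version.
    
    Args:
        spike_trains: Trial-averaged activity (n_neurons, n_stimuli, n_decisions, n_timepoints)
        labels: One character per non-neuron axis, e.g. 'sdt'
        n_components_per_marginal: Components per marginalization
        single_trials: Optional (n_trials, n_neurons, n_stimuli, n_decisions, n_timepoints);
            needed for regularizer='auto', which cross-validates the regularizer
    
    Returns:
        results: Dict keyed by marginalization (e.g. 't', 's', 'st'), each holding
            components, explained variance and projected trajectory
    """
    from dPCA import dPCA  # dPCA package; the class lives in module dPCA.dPCA
    
    # Initialize dPCA model; 'auto' regularization requires single-trial data
    regularizer = 'auto' if single_trials is not None else None
    dpca = dPCA.dPCA(labels=labels, n_components=n_components_per_marginal,
                     regularizer=regularizer)
    
    # Fit and transform; returns a dict of projections keyed by marginalization
    transformed = dpca.fit_transform(spike_trains, single_trials)
    
    # Results organized by marginalization (including interaction terms)
    results = {}
    for marginal, trajectory in transformed.items():
        results[marginal] = {
            'components': dpca.P[marginal],
            'variance': dpca.explained_variance_ratio_[marginal],
            'trajectory': trajectory
        }
    
    return results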
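dPCA builds on a decomposition of trial-averaged activity into marginalizations, each obtained by averaging over the other task parameters. The sketch below shows that averaging step for a simple two-parameter layout (neurons, stimuli, time) using random data. The helper name `marginalize` and the layout are assumptions for illustration; the sketch covers only the decomposition, not the dPCA encoder/decoder fit.

```python
import numpy as np

def marginalize(X):
    """
    Split trial-averaged activity X (n_neurons, n_stimuli, n_timepoints) into
    additive parts: time-only, stimulus-only, and stimulus-time interaction.
    """
    mu = X.mean(axis=(1, 2), keepdims=True)      # per-neuron grand mean
    X_t = X.mean(axis=1, keepdims=True) - mu     # depends on time only
    X_s = X.mean(axis=2, keepdims=True) - mu     # depends on stimulus only
    X_st = X - mu - X_t - X_s                    # what neither explains alone
    return {'t': X_t, 's': X_s, 'st': X_st}

rng = np.random.default_rng(2)
X = rng.normal(size=(30, 4, 25))                 # hypothetical trial averages
parts = marginalize(X)

# The parts are orthogonal, so their sums of squares add up to the total
total = np.sum((X - X.mean(axis=(1, 2), keepdims=True)) ** 2)
fractions = {k: np.sum(np.broadcast_to(v, X.shape) ** 2) / total for k, v in parts.items()}
print({k: round(v, 3) for k, v in fractions.items()})
```

Because the marginalizations are mutually orthogonal, their variance fractions sum to exactly 1, which is what makes a per-parameter variance breakdown well defined.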

GPFA: Gaussian Process Factor Analysis

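GPFA (Yu et al., 2009) extracts smooth single-trial latent trajectories from noisy spike counts by combining factor analysis with a Gaussian-process prior over time.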
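The generative model behind GPFA, as described by Yu et al. (2009), can be summarized as follows. Here y_t is the vector of (binned) activity across neurons at time t, x_t the latent state, C the loading matrix, d a per-neuron offset, and R a diagonal noise covariance; each latent dimension i has its own squared-exponential GP with timescale τ_i.

```latex
\mathbf{y}_t = C\,\mathbf{x}_t + \mathbf{d} + \boldsymbol{\varepsilon}_t,
\qquad \boldsymbol{\varepsilon}_t \sim \mathcal{N}(\mathbf{0}, R), \quad R \text{ diagonal}

x_{i,:} \sim \mathcal{GP}(0, k_i), \qquad
k_i(t_1, t_2) = \sigma_{f,i}^2 \exp\!\left(-\frac{(t_1 - t_2)^2}{2\tau_i^2}\right) + \sigma_{n,i}^2\,\delta_{t_1 t_2}
```

The timescale τ_i sets how smooth latent i is: bins closer than roughly τ_i are strongly correlated, which is what lets the model average out spiking noise across neighboring bins.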

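import numpy as np
from sklearn.decomposition import PCA

class GPFA:
    """
    Gaussian Process Factor Analysis for neural population data.
    
    Combines factor analysis (linear dimensionality reduction)
    with Gaussian process priors (temporal smoothing).
    
    Simplified sketch of Yu et al. (2009): the E-step is the exact GP
    posterior and the M-step updates C, d and R in closed form, but the
    GP timescales tau are kept fixed rather than optimized.
    """
    
    def __init__(self, n_latent_dimensions=3, em_max_iters=100, gp_noise_var=1e-3):
        self.n_latent = n_latent_dimensions
        self.em_max_iters = em_max_iters
        self.gp_noise_var = gp_noise_var  # GP noise variance; keeps K well conditioned
        self.C = None  # Loading matrix (n_neurons, n_latent)
        self.d = None  # Mean activity per neuron (n_neurons,)
        self.R = None  # Diagonal noise covariance, stored as a vector (n_neurons,)
        self.tau = None  # GP timescales, in time bins
        
    def fit(self, spike_trains, bin_width=20, tau_ms=100.0, verbose=False):
        """
        Fit GPFA model using EM algorithm.
        
        Args:
            spike_trains: List of (n_neurons, n_timepoints) arrays
            bin_width: Bin width in ms
            tau_ms: GP timescale in ms
        """
        # Concatenate all trials
        Y = np.concatenate([y.T for y in spike_trains], axis=0)  # (total_time, n_neurons)
        
        # Initialize with PCA
        pca = PCA(n_components=self.n_latent).fit(Y)
        self.C = pca.components_.T
        self.d = Y.mean(axis=0)
        self.R = np.maximum(Y.var(axis=0), 1e-3)
        self.tau = np.full(self.n_latent, tau_ms / bin_width)  # ms -> time bins
        
        # EM algorithm
        for iteration in range(self.em_max_iters):
            # E-step: Infer latent trajectories
            posteriors, log_likelihood = self._e_step(spike_trains)
            
            # M-step: Update parameters
            self._m_step(posteriors, spike_trains)
            
            if verbose:
                print(f"Iteration {iteration}, Log-likelihood: {log_likelihood:.2f}")
        
        return self
    
    def _e_step(self, spike_trains):
        """Infer latent posteriors (means and per-bin covariances) given parameters."""
        posteriors = []
        total_ll = 0.0
        
        for Y in spike_trains:
            mean, cov_blocks, ll = self._gp_posterior(Y)
            posteriors.append((mean, cov_blocks))
            total_ll += ll
        
        return posteriors, total_ll
    
    def _build_gp_covariance(self, n_time, tau):
        """Squared-exponential GP covariance (as in Yu et al., 2009), in time bins."""
        times = np.arange(n_time)
        dt = times[:, None] - times[None, :]
        K = np.exp(-dt ** 2 / (2 * tau ** 2))
        return (1 - self.gp_noise_var) * K + self.gp_noise_var * np.eye(n_time)
    
    def _gp_posterior(self, Y):
        """
        Exact Gaussian posterior over one trial's latents. (Kalman smoothing does
        not apply directly: the squared-exponential prior is not Markov.)
        Latents are stacked time-major, index t * n_latent + i.
        """
        n_neurons, n_time = Y.shape
        q = self.n_latent
        Ks = [self._build_gp_covariance(n_time, tau) for tau in self.tau]
        unit = np.eye(q)
        # Prior precision: block-diagonal over latent dimensions
        K_inv = sum(np.kron(np.linalg.inv(K), np.outer(unit[i], unit[i]))
                    for i, K in enumerate(Ks))
        # Posterior precision P = K^-1 + I_T (x) C^T R^-1 C
        CtRinvC = self.C.T @ (self.C / self.R[:, None])
        P = K_inv + np.kron(np.eye(n_time), CtRinvC)
        resid = Y - self.d[:, None]                            # (n_neurons, n_time)
        b = (self.C.T @ (resid / self.R[:, None])).T.ravel()  # C^T R^-1 (y_t - d), stacked
        P_inv = np.linalg.inv(P)
        m = P_inv @ b
        mean = m.reshape(n_time, q)
        cov_blocks = np.stack([P_inv[t * q:(t + 1) * q, t * q:(t + 1) * q]
                               for t in range(n_time)])
        # Marginal log-likelihood via the matrix determinant lemma and Woodbury identity
        logdet = (np.linalg.slogdet(P)[1] + sum(np.linalg.slogdet(K)[1] for K in Ks)
                  + n_time * np.sum(np.log(self.R)))
        quad = np.sum(resid ** 2 / self.R[:, None]) - b @ m
        ll = -0.5 * (n_neurons * n_time * np.log(2 * np.pi) + logdet + quad)
        return mean, cov_blocks, ll
    
    def _m_step(self, posteriors, spike_trains):
        """Closed-form updates of C, d and R; GP timescales tau stay fixed."""
        q = self.n_latent
        n_neurons = self.d.size
        Sxx = np.zeros((q + 1, q + 1))      # sum of E[x~ x~^T], with x~ = [x; 1]
        Syx = np.zeros((n_neurons, q + 1))  # sum of y E[x~]^T
        Syy = np.zeros(n_neurons)           # sum of y^2 per neuron
        n_total = 0
        for (mean, cov_blocks), Y in zip(posteriors, spike_trains):
            n_time = Y.shape[1]
            Xa = np.hstack([mean, np.ones((n_time, 1))])
            Sxx += Xa.T @ Xa
            Sxx[:q, :q] += cov_blocks.sum(axis=0)
            Syx += Y @ Xa
            Syy += np.sum(Y ** 2, axis=1)
            n_total += n_time
        Cd = np.linalg.solve(Sxx, Syx.T).T  # [C d] = Syx Sxx^-1
        self.C, self.d = Cd[:, :q], Cd[:, q]
        self.R = np.maximum((Syy - np.sum(Cd * Syx, axis=1)) / n_total, 1e-6)
    
    def transform(self, spike_trains):
        """Transform spike data to latent trajectories (posterior means, (n_timepoints, n_latent))."""
        posteriors, _ = self._e_step(spike_trains)
        return [mean for mean, _ in posteriors]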
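The key computation in the GPFA E-step is the Gaussian posterior over a latent trajectory given the observations. The sketch below isolates it for a single latent dimension with known loadings c and noise variances r, so no EM is involved. The helper names, the squared-exponential kernel with a small noise term, and all numbers are illustrative assumptions.

```python
import numpy as np

def se_kernel(n_time, tau, noise_var=1e-3):
    """Squared-exponential GP covariance over time bins, with a small diagonal term."""
    t = np.arange(n_time)
    K = np.exp(-(t[:, None] - t[None, :]) ** 2 / (2 * tau ** 2))
    return (1 - noise_var) * K + noise_var * np.eye(n_time)

def gp_posterior_mean(Y, c, r, K):
    """
    Posterior mean of one latent trajectory x, assuming Y = outer(c, x) + noise.
    Y: (n_neurons, n_time); c: (n_neurons,) loadings; r: (n_neurons,) noise variances.
    """
    precision = np.linalg.inv(K) + np.sum(c ** 2 / r) * np.eye(K.shape[0])
    b = (c / r) @ Y                       # c^T R^-1 y_t for every time bin
    return np.linalg.solve(precision, b)

rng = np.random.default_rng(3)
n_neurons, n_time, tau = 20, 100, 8.0
K = se_kernel(n_time, tau)
x_true = rng.multivariate_normal(np.zeros(n_time), K)       # smooth hypothetical latent
c = rng.normal(size=n_neurons)
r = np.full(n_neurons, 4.0)                                  # heavy private noise
Y = np.outer(c, x_true) + rng.normal(scale=np.sqrt(r)[:, None], size=(n_neurons, n_time))

x_gp = gp_posterior_mean(Y, c, r, K)
x_ls = (c / r) @ Y / np.sum(c ** 2 / r)                      # per-bin estimate, no smoothing
mse_gp = np.mean((x_gp - x_true) ** 2)
mse_ls = np.mean((x_ls - x_true) ** 2)
print(mse_gp, mse_ls)
```

The GP posterior should have clearly lower error than the unsmoothed per-bin estimate, because the prior pools evidence from neighboring bins over roughly the timescale tau.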

Manifold Dynamics Analysis

Trajectory Analysis

def analyze_manifold_trajectory(projection, time_points):
    """
    Analyze dynamics on neural manifold.
    
    Args:
        projection: (n_timepoints, n_dimensions) trajectory
        time_points: (n_timepoints,) time vector
    
    Returns:
        dynamics_metrics: Dictionary of dynamical properties
    """
    # Velocity and acceleration by finite differences (handles uneven sampling)
    velocity = np.gradient(projection, time_points, axis=0)
    speed = np.linalg.norm(velocity, axis=1)
    acceleration = np.gradient(velocity, time_points, axis=0)
    
    # Curvature in any dimension: |v|^2 |a|^2 - (v.a)^2 equals |v x a|^2 in 3-D
    v_dot_a = np.sum(velocity * acceleration, axis=1)
    cross_sq = speed ** 2 * np.sum(acceleration ** 2, axis=1) - v_dot_a ** 2
    curvature = np.sqrt(np.maximum(cross_sq, 0.0)) / (speed ** 3 + 1e-10)
    
    # Tangling: high when similar neural states have different derivatives
    tangling = compute_tangling(projection, velocity)
    
    return {
        'velocity': velocity,
        'speed': speed,
        'acceleration': acceleration,
        'curvature': curvature,
        'tangling': tangling
    }


def compute_tangling(trajectory, velocity, eps_scale=0.1):
    """
    Compute trajectory tangling metric.
    
    Q(t) = max over t' of ||v_t - v_t'||^2 / (||x_t - x_t'||^2 + eps)
    
    Tangling is high when nearby neural states have very different
    derivatives. Low tangling means a smooth flow field, which Russo et al.
    argue makes the dynamics robust to noise.
    
    Based on: Russo et al. (2018) Motor cortex embeds muscle-like commands
    in an untangled population response
    
    Args:
        trajectory: (n_timepoints, n_dimensions) states; concatenate
            conditions to compare across them
        velocity: (n_timepoints, n_dimensions) derivatives of the states
        eps_scale: eps = eps_scale * total variance of the states (assumed convention)
    
    Returns:
        (n_timepoints,) tangling values Q(t)
    """
    eps = eps_scale * np.sum(np.var(trajectory, axis=0))
    
    # Pairwise squared distances between states and between derivatives
    state_dist = np.sum((trajectory[:, None, :] - trajectory[None, :, :]) ** 2, axis=-1)
    deriv_dist = np.sum((velocity[:, None, :] - velocity[None, :, :]) ** 2, axis=-1)
    
    return np.max(deriv_dist / (state_dist + eps), axis=1)
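Tangling is easiest to see on two toy trajectories: a circle, where nearby states always move in similar directions, and a figure-eight, which passes through the origin twice with different velocities. The sketch below assumes the Russo et al. (2018) form of Q(t) with eps set to 0.1 times the total state variance (an assumed convention); the curves and helper name are illustrative.

```python
import numpy as np

def tangling(states, derivs, eps_scale=0.1):
    """Q(t) = max_t' ||dx_t - dx_t'||^2 / (||x_t - x_t'||^2 + eps) for every t."""
    eps = eps_scale * np.sum(np.var(states, axis=0))
    dx2 = np.sum((states[:, None, :] - states[None, :, :]) ** 2, axis=-1)
    dv2 = np.sum((derivs[:, None, :] - derivs[None, :, :]) ** 2, axis=-1)
    return np.max(dv2 / (dx2 + eps), axis=1)

theta = np.linspace(0, 2 * np.pi, 200, endpoint=False)

# Circle: velocity is the state rotated by 90 degrees, so it never tangles
circle = np.column_stack([np.cos(theta), np.sin(theta)])
circle_vel = np.column_stack([-np.sin(theta), np.cos(theta)])

# Figure-eight (lemniscate of Gerono): crosses the origin with two velocities
eight = np.column_stack([np.sin(theta), np.sin(theta) * np.cos(theta)])
eight_vel = np.column_stack([np.cos(theta), np.cos(2 * theta)])

q_circle = tangling(circle, circle_vel)
q_eight = tangling(eight, eight_vel)
print(q_circle.max(), q_eight.max())
```

On the circle Q stays below 1, while at the figure-eight's crossing point the state distance is near zero but the velocities differ, so Q becomes very large.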

Applications

1. Motor Cortex Decoding

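import numpy as np
from sklearn.linear_model import Ridge
from sklearn.metrics import r2_score

def decode_movement_from_manifold(spike_trains, behavior, decoder_type='velocity',
                                  test_fraction=0.2):
    """
    Decode movement kinematics from neural manifold.
    
    Args:
        spike_trains: Neural activity (n_trials, n_neurons, n_time)
        behavior: Movement data (n_trials, n_dims, n_time) for the variable
            named by decoder_type
        decoder_type: 'velocity', 'position', or 'force' (descriptive label)
        test_fraction: Fraction of trials held out for evaluation
    
    Returns:
        decoder: Trained decoder model
        predictions: Decoded behavior on held-out trials
        accuracy: R^2 on held-out trials
    """
    # Extract manifold (fit on all trials for simplicity, so the held-out
    # score is slightly optimistic)
    manifold, components, _ = extract_neural_manifold(
        spike_trains, method='jPCA', n_components=6
    )
    
    # Hold out the last trials, then flatten (trial, time) into samples
    n_trials, _, n_latent = manifold.shape
    n_test = max(1, int(round(test_fraction * n_trials)))
    y_all = behavior.transpose(0, 2, 1)  # (n_trials, n_time, n_dims)
    X_train = manifold[:-n_test].reshape(-1, n_latent)
    y_train = y_all[:-n_test].reshape(-1, y_all.shape[-1])
    X_test = manifold[-n_test:].reshape(-1, n_latent)
    y_test = y_all[-n_test:].reshape(-1, y_all.shape[-1])
    
    # Linear (ridge) decoder; Kalman or Wiener filters are common alternatives
    decoder = Ridge(alpha=1.0)
    decoder.fit(X_train, y_train)
    
    # Evaluate on held-out trials
    predictions = decoder.predict(X_test)
    accuracy = r2_score(y_test, predictions)
    
    return decoder, predictions, accuracy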
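The decoder above maps each time bin's latent state to behavior at the same bin. A Wiener-filter-style decoder instead uses a short history of latent states, which helps when behavior lags neural activity. The sketch below approximates that with ridge regression on lagged features; the helper name `lagged_features`, the 2-bin lag, and the synthetic data are assumptions for illustration.

```python
import numpy as np
from sklearn.linear_model import Ridge
from sklearn.metrics import r2_score

def lagged_features(latents, n_lags):
    """
    Row t holds [z_t, z_{t-1}, ..., z_{t-n_lags+1}]; the first n_lags - 1 rows
    (incomplete history) are dropped.
    """
    n_time = latents.shape[0]
    cols = [latents[n_lags - 1 - lag: n_time - lag] for lag in range(n_lags)]
    return np.hstack(cols)

rng = np.random.default_rng(4)
z = rng.normal(size=(400, 6))          # hypothetical latent trajectory (time, dims)
w = rng.normal(size=6)
velocity = np.zeros(400)
velocity[2:] = z[:-2] @ w              # behavior lags the latents by 2 bins

n_lags = 3
X_lag = lagged_features(z, n_lags)     # rows correspond to t = 2 .. 399
y = velocity[n_lags - 1:]

# Chronological split: fit on early bins, score on later ones
split = 300
model = Ridge(alpha=1e-3).fit(X_lag[:split], y[:split])
score = r2_score(y[split:], model.predict(X_lag[split:]))
print(round(score, 4))
```

Because the 2-bin lag is covered by the feature window, the held-out R^2 is close to 1 here; with a single-bin (zero-lag) decoder this relationship would be invisible.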

2. Preparatory Activity Analysis

def analyze_preparatory_manifold(spike_trains, cue_onset, movement_onset, behavior):
    """
    Analyze preparatory activity in neural manifold.
    
    Key question: How is movement prepared before execution?
    
    Args:
        spike_trains: (n_trials, n_neurons, n_timepoints)
        cue_onset, movement_onset: Bin indices shared by all trials
        behavior: Movement data passed to the correlation helper
    """
    # Extract preparatory period activity (cue to movement onset) on every trial
    prep_spikes = np.asarray(spike_trains)[:, :, cue_onset:movement_onset]
    
    # Analyze manifold during preparation. PCA is used here because dPCA
    # needs trial-averaged data arranged by task parameter (see dpca_analysis)
    prep_manifold, _, _ = extract_neural_manifold(
        prep_spikes, method='PCA', n_components=4
    )
    
    # Relate preparatory states to the upcoming movement.
    # compute_prep_movement_correlation is not defined in this document;
    # supply your own implementation.
    prep_movement_correlation = compute_prep_movement_correlation(
        prep_manifold, behavior
    )
    
    return prep_manifold, prep_movement_correlation
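The function above uses a single cue and movement index for all trials. Reaction times usually vary from trial to trial, so a common alternative is to align each trial to its own movement onset and keep a fixed number of bins before it. The sketch below assumes per-trial onset indices; the helper name `extract_prep_windows` is hypothetical.

```python
import numpy as np

def extract_prep_windows(spike_trains, movement_onsets, n_bins):
    """
    Cut the n_bins time bins immediately before movement onset on every trial.

    spike_trains: (n_trials, n_neurons, n_timepoints); movement_onsets: one bin
    index per trial. Returns (n_trials, n_neurons, n_bins), aligned to onset.
    """
    onsets = np.asarray(movement_onsets)
    if np.any(onsets < n_bins):
        raise ValueError("every onset must leave n_bins of pre-movement data")
    return np.stack([trial[:, onset - n_bins:onset]
                     for trial, onset in zip(spike_trains, onsets)])

# Toy data: values encode (trial, neuron, bin) so the slicing is easy to check
spikes = np.arange(2 * 3 * 10).reshape(2, 3, 10)
windows = extract_prep_windows(spikes, movement_onsets=[5, 8], n_bins=4)
print(windows[:, 0, :])
```

The aligned array has the same (trials, neurons, time) layout as the original data, so it can be passed to the manifold extraction step unchanged.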

3. Learning-Related Manifold Changes

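def track_manifold_learning(spike_trains_early, spike_trains_late, behavior):
    """
    Track how neural manifold changes with learning.
    
    Relies on three helpers that are not defined in this document and must be
    supplied: estimate_manifold_dimensionality, compute_behavior_alignment,
    and compute_trajectory_consistency.
    """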
    # Extract manifolds
    man_early, comp_early, _ = extract_neural_manifold(spike_trains_early)
    man_late, comp_late, _ = extract_neural_manifold(spike_trains_late)
    
    # Compare manifold structure
    # 1. Dimensionality
    dim_early = estimate_manifold_dimensionality(spike_trains_early)
    dim_late = estimate_manifold_dimensionality(spike_trains_late)
    
    # 2. Alignment with behavior
    align_early = compute_behavior_alignment(comp_early, behavior)
    align_late = compute_behavior_alignment(comp_late, behavior)
    
    # 3. Trajectory consistency
    consistency_early = compute_trajectory_consistency(man_early)
    consistency_late = compute_trajectory_consistency(man_late)
    
    return {
        'dimensionality_change': dim_late - dim_early,
        'behavior_alignment_change': align_late - align_early,
        'consistency_change': consistency_late - consistency_early
    }
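Two common ingredients for comparing manifolds before and after learning are a dimensionality estimate and the principal angles between the two manifolds' bases. The sketch below assumes the participation ratio as the dimensionality measure (one possible choice for the undefined estimate_manifold_dimensionality) and uses `scipy.linalg.subspace_angles`; the helper names and synthetic bases are hypothetical.

```python
import numpy as np
from scipy.linalg import subspace_angles

def participation_ratio(data):
    """PR = (sum of eigenvalues)^2 / sum of squared eigenvalues; data is (samples, neurons)."""
    eigvals = np.linalg.eigvalsh(np.cov(data, rowvar=False))
    return eigvals.sum() ** 2 / np.sum(eigvals ** 2)

def manifold_angles_deg(components_a, components_b):
    """Principal angles (degrees) between manifolds given (n_components, n_neurons) bases."""
    return np.degrees(subspace_angles(components_a.T, components_b.T))

rng = np.random.default_rng(5)
Q = np.linalg.qr(rng.normal(size=(50, 6)))[0]    # 6 orthonormal directions, 50 neurons
early_basis, late_basis = Q[:, :3].T, Q[:, 3:].T  # shaped like PCA components_
data = rng.normal(size=(2000, 3)) @ early_basis   # activity confined to a 3-D manifold

pr = participation_ratio(data)
same = manifold_angles_deg(early_basis, early_basis)
orthogonal = manifold_angles_deg(early_basis, late_basis)
print(round(pr, 2), same.round(3), orthogonal.round(1))
```

Isotropic activity in a 3-D subspace gives a participation ratio near 3; identical manifolds have principal angles near 0 degrees, and orthogonal ones have all angles at 90 degrees.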

Visualization

def plot_neural_manifold(projection, conditions=None, colors=None,
                         method='3d_trajectory'):
    """
    Visualize neural manifold.
    
    Args:
        projection: (n_timepoints, n_dimensions) array, or a list of such
            trajectories for '3d_trajectory'
        conditions: Label per trajectory ('3d_trajectory') or numeric value
            per time point ('heat_map')
        colors: Color for each trajectory
        method: '3d_trajectory' or 'heat_map'
    """
    import matplotlib.pyplot as plt
    from mpl_toolkits.mplot3d import Axes3D  # noqa: F401 (registers '3d' on old Matplotlib)
    
    fig = plt.figure(figsize=(6, 5))
    
    if method == '3d_trajectory':
        ax = fig.add_subplot(111, projection='3d')
        
        if isinstance(projection, list):
            for i, traj in enumerate(projection):
                color = colors[i] if colors is not None else None
                label = conditions[i] if conditions is not None else None
                ax.plot(traj[:, 0], traj[:, 1], traj[:, 2],
                        color=color, label=label, alpha=0.6)
            if conditions is not None:
                ax.legend()
        else:
            ax.plot(projection[:, 0], projection[:, 1], projection[:, 2])
        
        ax.set_xlabel('Dim 1')
        ax.set_ylabel('Dim 2')
        ax.set_zlabel('Dim 3')
        
    elif method == 'heat_map':
        # First two manifold dimensions, colored by condition value
        ax = fig.add_subplot(111)
        scatter = ax.scatter(projection[:, 0], projection[:, 1],
                             c=conditions, cmap='viridis', alpha=0.5)
        fig.colorbar(scatter, ax=ax, label='Condition')
        ax.set_xlabel('Dim 1')
        ax.set_ylabel('Dim 2')
    
    else:
        raise ValueError(f"Unknown method: {method!r}")
    
    fig.tight_layout()
    return fig

References

  1. Churchland, M. M., et al. (2012). Neural population dynamics during reaching. Nature, 487(7405), 51-56.
  2. Cunningham, J. P., & Yu, B. M. (2014). Dimensionality reduction for large-scale neural recordings. Nature Neuroscience, 17(11), 1500-1509.
  3. Gallego, J. A., et al. (2017). Neural manifolds for the control of movement. Neuron, 94(5), 978-984.
  4. Kaufman, M. T., et al. (2014). The role of premotor cortex in reaching movements. Neuron, 84(2), 468-482.
  5. Russo, A. A., et al. (2018). Motor cortex embeds muscle-like commands in an untangled population response. Neuron, 97(4), 953-966.
  6. Yu, B. M., et al. (2009). Gaussian-process factor analysis for low-dimensional single-trial analysis of neural population activity. Journal of Neurophysiology, 102(1), 614-635.

Activation Keywords

  • neural manifold
  • population dynamics
  • latent dynamics
  • jPCA rotation
  • dPCA demixing
  • GPFA trajectory
  • neural state space
  • dimensionality reduction neuroscience
  • neural decoding manifold
  • behavior encoding
Install via CLI
npx skills add https://github.com/hiyenwong/ai_collection --skill neural-manifold-dynamics-learning
Repository Details
Stars: 2
Forks: 0
Branch: main
Path: SKILL.md