neural-dynamics-universal-translator-foundation


Foundation model for neural spiking data using multi-task masking (MtM) to translate across population, region, and single-neuron levels. Enables zero-shot and few-shot brain decoding across multiple brain areas. Activation triggers: neural translator, foundation model spiking, MtM, multi-task masking, IBL dataset, brain decoding.

By hiyenwong · Updated 6/3/2026


Neural Dynamics Universal Translator Foundation

A foundation model for neural spiking data that "translates" across the brain's spatial scales (single neurons, brain regions, and whole populations) through multi-task masking self-supervised learning.

Metadata

  • Source: arXiv:2407.14668
  • Authors: Yizi Zhang, Yanchen Wang, Donato Jiménez Benetó, Zixuan Wang, Mehdi Azabou, Blake Richards, Olivier Winter, Eva Dyer, Liam Paninski, Cole Hurwitz
  • Published: 2024-07
  • Code: https://ibl-mtm.github.io/

Core Methodology

Key Innovation

The Neural Dynamics Universal Translator introduces a Multi-Task Masking (MtM) approach that enables a single foundation model to:

  • Process neural activity across different time steps, neurons, and brain regions
  • Generalize to unseen animals whose neurons have no known correspondence to those in the training data
  • Perform few-shot learning with minimal supervision
  • Bridge the gap between brain region-specific models and whole-brain analysis

Technical Framework

1. Multi-Task Masking Strategy

The model alternates between masking and reconstructing neural activity across three dimensions:

| Masking Type | Description | Prompt Token |
|---|---|---|
| Temporal Masking | Masks time steps | TEMPORAL |
| Neuron Masking | Masks individual neurons | NEURON |
| Region Masking | Masks entire brain regions | REGION |

Each masking objective is associated with a learnable prompt token that enables "mode switching" during evaluation.
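The following small NumPy sketch makes the three masking axes concrete. It builds one boolean mask per mode for a toy `T × N` activity matrix, with True marking a hidden entry. The helper name `make_masks`, the 25% ratio, and the per-neuron region labels are illustrative assumptions and not taken from the paper. The labels loosely echo the visual, hippocampal, and thalamic targets listed under the training dataset.

```python
import numpy as np

def make_masks(T, N, region_of_neuron, ratio=0.25, seed=0):
    """Build one boolean mask (True = hidden) per masking mode for a T x N matrix.

    region_of_neuron: length-N array of region labels (hypothetical input).
    """
    rng = np.random.default_rng(seed)
    masks = {}
    # Temporal: hide a random subset of time bins for every neuron
    t_hidden = rng.choice(T, size=max(1, int(T * ratio)), replace=False)
    masks["temporal"] = np.zeros((T, N), dtype=bool)
    masks["temporal"][t_hidden, :] = True
    # Neuron: hide a random subset of neurons for every time bin
    n_hidden = rng.choice(N, size=max(1, int(N * ratio)), replace=False)
    masks["neuron"] = np.zeros((T, N), dtype=bool)
    masks["neuron"][:, n_hidden] = True
    # Region: hide every neuron recorded in one randomly chosen region
    region = rng.choice(np.unique(region_of_neuron))
    masks["region"] = np.zeros((T, N), dtype=bool)
    masks["region"][:, region_of_neuron == region] = True
    return masks

regions = np.array(["VIS", "VIS", "CA1", "CA1", "LP", "LP"])  # toy region labels
masks = make_masks(T=8, N=6, region_of_neuron=regions)
```

A region mask is a neuron mask whose hidden columns are chosen by anatomy rather than at random. In this sketch, the region labels are consulted only for that mode.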

2. Model Architecture

Input: Neural spike trains (time × neurons)
│
├─► Embedding Layer: Convert the spike counts in each time bin to an embedding (one token per bin)
│
├─► Prompt Token: Prepend the token for the masking objective (TEMPORAL/NEURON/REGION)
│
├─► Transformer Encoder: Model dependencies across the time-bin tokens and the prompt token
│
└─► Output: Reconstruct masked activity
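The diagram can be read as a sequence of tensor shapes. The NumPy sketch below traces one trial through that pipeline. Random matrices stand in for the learned embedding and output layers, and the transformer is replaced by an identity step, so the sketch only shows where the prompt token enters and leaves. All sizes are arbitrary toy values.

```python
import numpy as np

rng = np.random.default_rng(0)
T, N, D = 100, 50, 16                               # time bins, neurons, embedding size (toy values)
x = rng.poisson(0.5, size=(T, N)).astype(float)     # spike counts, time x neurons

W_in = rng.normal(size=(N, D)) / np.sqrt(N)         # stands in for the embedding layer
W_out = rng.normal(size=(D, N)) / np.sqrt(D)        # stands in for the output projection
prompts = {m: rng.normal(size=(1, D)) for m in ("temporal", "neuron", "region")}

tokens = x @ W_in                                   # (T, D): one token per time bin
tokens = np.vstack([prompts["neuron"], tokens])     # (T + 1, D): prompt token prepended
hidden = tokens                                     # a transformer would mix the tokens here
recon = hidden[1:] @ W_out                          # (T, N): drop the prompt, project back
```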

3. Training Dataset

  • International Brain Laboratory (IBL) Repeated Site Dataset
  • 48 animals across multiple experimental sessions
  • Target regions: Secondary visual areas, hippocampus, thalamus
  • Neuropixels recordings with consistent anatomical targeting

Mathematical Formulation

Self-Supervised Objective

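Given a neural activity matrix $X \in \mathbb{R}^{T \times N}$, where $T$ is the number of time steps and $N$ is the number of neurons:

  1. Apply masking based on the selected mode $m$:

    • Temporal: Mask $X_{t_1:t_2, :}$ (a block of time steps, all neurons)
    • Neuron: Mask $X_{:, \mathcal{N}_{\text{masked}}}$, where $\mathcal{N}_{\text{masked}}$ is a subset of neurons
    • Region: Mask $X_{:, \mathcal{R}_{\text{masked}}}$, where $\mathcal{R}_{\text{masked}}$ is the set of all neurons recorded in the masked region(s)
  2. Add the prompt embedding $p_m$ corresponding to mode $m$

  3. Minimize the reconstruction loss over the set $\Omega_m$ of masked (time, neuron) entries: $$\mathcal{L} = \mathbb{E}_{X \sim \mathcal{D}} \left[ \frac{1}{|\Omega_m|} \sum_{(t,n) \in \Omega_m} \left( X_{t,n} - \hat{X}_{t,n} \right)^2 \right]$$

where $\hat{X} = f_\theta(X_{\text{unmasked}}, p_m)$.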
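Here is a worked example of the squared-error objective. It is restricted to the masked entries, as the training loop in Step 3 does ("Compute loss only on masked positions"), and averaged over them, as `nn.MSELoss` does by default. The helper `masked_mse` and the toy numbers are illustrative only.

```python
import numpy as np

def masked_mse(X, X_hat, mask):
    """Mean squared reconstruction error over masked entries only (mask: True = hidden)."""
    diff = (X - X_hat)[mask]
    return float(np.mean(diff ** 2))

X = np.array([[1.0, 0.0, 2.0],
              [0.0, 3.0, 1.0]])         # T = 2 time bins, N = 3 neurons
X_hat = np.array([[0.5, 0.0, 2.0],
                  [1.0, 3.0, 0.0]])
mask = np.array([[True, False, False],
                 [True, False, True]])   # arbitrary mask for illustration

loss = masked_mse(X, X_hat, mask)  # errors 0.5, -1.0, 1.0 -> (0.25 + 1 + 1) / 3 = 0.75
```

Entries outside the mask (such as the perfectly reconstructed 3.0) do not contribute to the loss. The model is only graded on activity it could not see.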

Implementation Guide

Prerequisites

# Required packages
pip install torch numpy scipy pandas
pip install ONE-api  # For IBL dataset access (ONE interface)

Step-by-Step

Step 1: Load and Preprocess Neural Data

import numpy as np

def load_neural_data(session_path, t_start=0.0, bin_size=0.01, duration=2.0):
    """Load sorted spikes from an IBL session folder and bin one window.

    spikes.times.npy / spikes.clusters.npy are plain NumPy arrays, so they are
    read with np.load (scipy.io.loadmat only reads MATLAB .mat files).
    t_start: window onset in seconds (e.g. a trial's stimulus onset).
    """
    spike_times = np.load(f"{session_path}/spikes.times.npy")        # seconds
    spike_clusters = np.load(f"{session_path}/spikes.clusters.npy")  # cluster id per spike

    # Align spikes to the window onset; 10 ms bins over a 2 s window by default
    return bin_spikes(spike_times - t_start, spike_clusters, bin_size, duration)

def bin_spikes(spike_times, clusters, bin_size, duration):
    """Convert spike times (s, relative to window onset) to a (time × neurons) count matrix"""
    n_bins = int(round(duration / bin_size))
    n_neurons = int(clusters.max()) + 1

    # Keep only spikes inside [0, duration)
    in_window = (spike_times >= 0) & (spike_times < duration)
    bin_idx = (spike_times[in_window] / bin_size).astype(int)
    bin_idx = np.minimum(bin_idx, n_bins - 1)  # guard against float round-off at the right edge

    counts = np.zeros((n_bins, n_neurons))
    # np.add.at accumulates repeated (bin, neuron) pairs instead of overwriting them
    np.add.at(counts, (bin_idx, clusters[in_window].astype(int)), 1)
    return counts
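A vectorized alternative uses NumPy's `np.histogram2d` and can double as a cross-check on `bin_spikes`. This is a sketch; `bin_spikes_hist` is a hypothetical helper name. One edge case differs: `histogram2d` closes its last bin on the right, so a spike at exactly `duration` lands in the final bin instead of being dropped.

```python
import numpy as np

def bin_spikes_hist(spike_times, clusters, bin_size=0.01, duration=2.0):
    """Bin spike times (s) into a (n_bins, n_neurons) count matrix with np.histogram2d."""
    n_bins = int(round(duration / bin_size))
    n_neurons = int(clusters.max()) + 1
    time_edges = np.linspace(0.0, duration, n_bins + 1)
    # Integer cluster id c falls in the bin [c - 0.5, c + 0.5)
    cluster_edges = np.arange(n_neurons + 1) - 0.5
    counts, _, _ = np.histogram2d(spike_times, clusters, bins=[time_edges, cluster_edges])
    return counts

spike_times = np.array([0.001, 0.004, 0.015, 0.015, 1.999, 2.5])  # 2.5 s is outside the window
clusters = np.array([0, 2, 1, 1, 0, 2])
counts = bin_spikes_hist(spike_times, clusters)
```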

Step 2: Implement Multi-Task Masking

import torch
import torch.nn as nn

class MultiTaskMasking(nn.Module):
    """Multi-task masking model: one token per time bin plus a mode prompt token"""
    
    def __init__(self, n_neurons, embed_dim=128, n_heads=8, n_layers=6, max_len=512):
        super().__init__()
        self.embed_dim = embed_dim
        
        # Learnable prompt tokens, one per masking objective
        self.prompts = nn.ParameterDict({
            'temporal': nn.Parameter(torch.randn(1, 1, embed_dim)),
            'neuron': nn.Parameter(torch.randn(1, 1, embed_dim)),
            'region': nn.Parameter(torch.randn(1, 1, embed_dim))
        })
        
        # Embedding layer: the N spike counts of each time bin become one token
        self.input_embedding = nn.Linear(n_neurons, embed_dim)
        
        # Learned positional embedding (prompt + up to max_len - 1 time bins);
        # without it, attention cannot tell which time bins were masked
        self.pos_embedding = nn.Parameter(torch.zeros(1, max_len, embed_dim))
        
        # Transformer encoder
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=embed_dim, 
            nhead=n_heads,
            dim_feedforward=512,
            batch_first=True
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=n_layers)
        
        # Output projection back to spike counts
        self.output_proj = nn.Linear(embed_dim, n_neurons)
    
    def forward(self, x, mode, mask_indices=None):
        """
        Args:
            x: Input spike counts (batch, time, neurons)
            mode: 'temporal', 'neuron', or 'region'
            mask_indices: Entries to hide, as a boolean tensor shaped like x or a
                tuple of index tensors; None disables masking
        Returns:
            reconstructed: Reconstructed activity (batch, time, neurons)
        """
        # Zero out the masked entries
        mask = torch.ones_like(x)
        if mask_indices is not None:
            mask[mask_indices] = 0
        
        # Embed masked input: (batch, time, embed_dim)
        x_embed = self.input_embedding(x * mask)
        
        # Prepend the prompt token for the selected mode, then add positions
        prompt = self.prompts[mode].expand(x.size(0), -1, -1)
        tokens = torch.cat([prompt, x_embed], dim=1)
        tokens = tokens + self.pos_embedding[:, :tokens.size(1), :]
        
        # Transform
        h = self.transformer(tokens)
        
        # Project output (skip prompt token)
        return self.output_proj(h[:, 1:, :])
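In `forward`, the line `mask[mask_indices] = 0` relies on advanced indexing. Advanced indexing accepts either a boolean tensor shaped like `x` or a tuple of per-dimension index tensors. PyTorch follows NumPy's semantics for both forms. The NumPy snippet below uses toy values to check that both forms hide the same entries.

```python
import numpy as np

x = np.arange(12, dtype=float).reshape(1, 4, 3)   # (batch, time, neurons)

# Form 1: boolean mask with the same shape as x (True = hidden)
bool_mask = np.zeros_like(x, dtype=bool)
bool_mask[:, 1, :] = True                          # hide time bin 1

# Form 2: tuple of index arrays (batch_idx, time_idx, neuron_idx)
index_tuple = np.nonzero(bool_mask)

masked_a = np.ones_like(x)
masked_a[bool_mask] = 0
masked_b = np.ones_like(x)
masked_b[index_tuple] = 0
```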

Step 3: Training Loop

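import random
import torch
import torch.nn as nn

def train_mtm_model(model, dataloader, region_ids=None, epochs=100, lr=1e-4, mask_ratio=0.15):
    """Train multi-task masking model

    region_ids: optional (neurons,) tensor with one region label per neuron column;
    region masking is only sampled when it is provided.
    """
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    criterion = nn.MSELoss()
    
    modes = ['temporal', 'neuron'] + (['region'] if region_ids is not None else [])
    
    for epoch in range(epochs):
        total_loss = 0.0
        
        for batch in dataloader:
            x = batch['activity'].float()  # (batch, time, neurons)
            batch_size, T, N = x.shape
            
            # Randomly select masking mode
            mode = random.choice(modes)
            
            # Boolean mask, True = hidden
            if mode == 'temporal':
                mask = create_temporal_mask(batch_size, T, N, mask_ratio)
            elif mode == 'neuron':
                mask = create_neuron_mask(batch_size, T, N, mask_ratio)
            else:  # region
                mask = create_region_mask(batch_size, T, N, region_ids)
            
            # Forward pass
            reconstructed = model(x, mode, mask)
            
            # Compute loss only on masked positions
            loss = criterion(reconstructed[mask], x[mask])
            
            # Backward pass
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            
            total_loss += loss.item()
        
        if epoch % 10 == 0:
            print(f"Epoch {epoch}, Loss: {total_loss / len(dataloader):.4f}")
    
    return model

def create_temporal_mask(batch_size, T, N, ratio):
    """Hide a random subset of time bins (all neurons) in each trial"""
    mask = torch.zeros(batch_size, T, N, dtype=torch.bool)
    n_mask = max(1, int(T * ratio))
    for b in range(batch_size):
        mask[b, torch.randperm(T)[:n_mask], :] = True
    return mask

def create_neuron_mask(batch_size, T, N, ratio):
    """Hide a random subset of neurons (all time bins) in each trial"""
    mask = torch.zeros(batch_size, T, N, dtype=torch.bool)
    n_mask = max(1, int(N * ratio))
    for b in range(batch_size):
        mask[b, :, torch.randperm(N)[:n_mask]] = True
    return mask

def create_region_mask(batch_size, T, N, region_ids):
    """Hide every neuron of one randomly chosen brain region in each trial"""
    mask = torch.zeros(batch_size, T, N, dtype=torch.bool)
    regions = torch.unique(region_ids)
    for b in range(batch_size):
        r = regions[torch.randint(len(regions), (1,)).item()]
        mask[b, :, region_ids == r] = True
    return mask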
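The mathematical formulation writes temporal masking as one contiguous block $X_{t_1:t_2,:}$, whereas `create_temporal_mask` samples individual time bins. If contiguous spans are wanted, a minimal NumPy sketch follows. `contiguous_temporal_mask` and its `span` parameter are hypothetical, and the returned mask marks hidden entries with True.

```python
import numpy as np

def contiguous_temporal_mask(batch_size, T, N, span, rng=None):
    """Hide one contiguous block of `span` time bins per trial (True = hidden)."""
    rng = np.random.default_rng() if rng is None else rng
    mask = np.zeros((batch_size, T, N), dtype=bool)
    starts = rng.integers(0, T - span + 1, size=batch_size)  # t1 for each trial
    for b, t1 in enumerate(starts):
        mask[b, t1:t1 + span, :] = True                        # rows t1 .. t1 + span - 1
    return mask

mask = contiguous_temporal_mask(batch_size=4, T=100, N=30, span=15, rng=np.random.default_rng(0))
```

`torch.from_numpy(mask)` converts the result to a boolean tensor for the PyTorch code. Contiguous spans force the model to interpolate across a gap rather than between neighbouring visible bins.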

Step 4: Few-Shot Adaptation

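import torch
import torch.nn as nn
import torch.nn.functional as F

class FewShotAdapter(nn.Module):
    """Linear probe on frozen MtM hidden states for few-shot downstream tasks"""
    
    def __init__(self, encoder, output_dim):
        super().__init__()
        self.encoder = encoder.eval()  # disable dropout in the frozen encoder
        # Freeze encoder
        for param in self.encoder.parameters():
            param.requires_grad = False
        
        self.classifier = nn.Linear(encoder.embed_dim, output_dim)
    
    def extract_features(self, x):
        """Return transformer hidden states (batch, time, embed_dim), prompt token dropped"""
        captured = {}
        # forward() returns reconstructed counts, so read the hidden states with a hook
        handle = self.encoder.transformer.register_forward_hook(
            lambda module, inputs, output: captured.update(h=output))
        no_mask = torch.zeros_like(x, dtype=torch.bool)  # feed the input unmasked
        with torch.no_grad():
            self.encoder(x, mode='neuron', mask_indices=no_mask)
        handle.remove()
        return captured['h'][:, 1:, :]
    
    def forward(self, x):
        # Classify time-pooled features from the frozen encoder
        return self.classifier(self.extract_features(x).mean(dim=1))

def few_shot_adapt(model, support_set, n_shots=5, epochs=50, lr=1e-3):
    """
    Adapt model with few labeled examples
    
    Args:
        model: Pretrained MtM model
        support_set: Dict with 'activity' (trials, time, neurons) and integer 'labels'
        n_shots: Maximum number of examples kept per class
    """
    x = support_set['activity'].float()
    labels = support_set['labels'].long()
    n_classes = int(labels.max()) + 1
    
    # Keep at most n_shots labeled trials per class
    keep = torch.cat([torch.nonzero(labels == c).flatten()[:n_shots] for c in range(n_classes)])
    x, labels = x[keep], labels[keep]
    
    # Create linear probe; the encoder is frozen, so features are computed once
    adapter = FewShotAdapter(model, output_dim=n_classes)
    features = adapter.extract_features(x).mean(dim=1)  # pool over time
    
    # Train only classifier on support set
    optimizer = torch.optim.Adam(adapter.classifier.parameters(), lr=lr)
    for epoch in range(epochs):  # Few epochs for few-shot
        loss = F.cross_entropy(adapter.classifier(features), labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    
    return adapter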
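`few_shot_adapt` trains on a support set. Evaluating the probe also needs held-out query trials. Below is a minimal NumPy sketch of an n-shot split plus an accuracy helper. The names `k_shot_split` and `accuracy` are hypothetical, and the toy labels stand in for something like a binary left/right choice.

```python
import numpy as np

def k_shot_split(labels, n_shots, seed=0):
    """Split trial indices into a support set (n_shots per class) and a query set (the rest)."""
    rng = np.random.default_rng(seed)
    support, query = [], []
    for c in np.unique(labels):
        idx = rng.permutation(np.flatnonzero(labels == c))
        support.extend(idx[:n_shots])
        query.extend(idx[n_shots:])
    return np.sort(np.array(support)), np.sort(np.array(query))

def accuracy(pred, target):
    """Fraction of predictions that match the target labels."""
    return float(np.mean(np.asarray(pred) == np.asarray(target)))

labels = np.array([0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 1, 0])  # toy labels, 6 trials per class
support_idx, query_idx = k_shot_split(labels, n_shots=5)
```

Reporting accuracy on `query_idx` only keeps the few-shot estimate honest, because the probe never sees those trials during fitting.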

Applications

1. Brain-Computer Interfaces (BCI)

  • Zero-shot decoding on new subjects
  • Reduced calibration time for neural prosthetics
  • Cross-subject motor imagery classification

2. Cross-Region Neural Analysis

  • Study information flow between brain regions
  • Identify region-specific neural codes
  • Map distributed neural computation

3. Behavior Prediction

  • Decode behavioral states from neural activity
  • Predict decision-making processes
  • Analyze cognitive task performance

4. Neurological Disorder Research

  • Compare neural dynamics across patient populations
  • Identify biomarkers for brain disorders
  • Track disease progression

Benchmarks

| Task | Metric | Performance |
|---|---|---|
| Single-neuron prediction | | 0.72 |
| Region-level prediction | | 0.68 |
| Forward prediction (200 ms) | MAE | 0.15 |
| Behavior decoding (choice) | Accuracy | 82% |
| Cross-animal generalization | | 0.61 |
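The table does not say exactly how forward prediction was scored. A sketch follows, assuming the task means predicting the final 200 ms of a trial from the preceding activity and scoring MAE on those bins only. With the 10 ms bins from Step 1, the horizon is 200 / 10 = 20 bins. `forward_prediction_mask` and `mae` are hypothetical helpers.

```python
import numpy as np

bin_size_ms = 10                         # bin width used in Step 1
horizon_ms = 200                         # forward-prediction horizon from the table
n_future = horizon_ms // bin_size_ms     # 20 bins

def forward_prediction_mask(T, N, n_future):
    """Hide the last n_future time bins: the model sees the past and predicts the future."""
    mask = np.zeros((T, N), dtype=bool)
    mask[T - n_future:, :] = True
    return mask

def mae(X, X_hat, mask):
    """Mean absolute error over the hidden (future) entries."""
    return float(np.mean(np.abs(X - X_hat)[mask]))

mask = forward_prediction_mask(T=200, N=10, n_future=n_future)  # 2 s trial of 10 ms bins
```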

Pitfalls

  • Data Quality: Model performance heavily depends on spike sorting quality
  • Temporal Resolution: Requires high temporal resolution recordings (≥1kHz sampling)
  • Recording Stability: Assumes consistent electrode placement across sessions
  • Animal Variability: May require fine-tuning for animals with significant anatomical differences
  • Computational Cost: Large transformer models require significant GPU memory

Related Skills

  • brain-dit-fmri-foundation-model
  • neurostorm-fmri-foundation
  • reve-eeg-foundation
  • spike-mllm-multimodal-spiking
  • meta-learning-in-context-brain-decoding

References

@article{zhang2024universal,
  title={Towards a "universal translator" for neural dynamics at single-cell, single-spike resolution},
  author={Zhang, Yizi and Wang, Yanchen and Jim{\'e}nez Benet{\'o}, Donato and Wang, Zixuan and Azabou, Mehdi and Richards, Blake and Winter, Olivier and others},
  journal={arXiv preprint arXiv:2407.14668},
  year={2024}
}
Install via CLI
npx skills add https://github.com/hiyenwong/ai_collection --skill neural-dynamics-universal-translator-foundation