learning-neuron-dynamics-deep-snn


Learning rich neuron dynamics within deep Spiking Neural Networks (SNNs). Moves beyond simple LIF neuron models to capture complex temporal dynamics in deep SNN architectures. Enables more biologically realistic and computationally powerful spiking networks. Applicable to deep SNN training, advanced neuron modeling, temporal feature learning. Trigger: deep SNN neuron dynamics, LIF limitations, complex neuron models SNN, temporal dynamics spiking networks, neuron model learning

By hiyenwong, updated 6/3/2026

name: learning-neuron-dynamics-deep-snn
version: v1.0.0
last_updated: 2026-04-19

Learning Neuron Dynamics within Deep Spiking Neural Networks

Description

A methodology for learning neuron dynamics within deep Spiking Neural Networks (SNNs) instead of fixing every neuron to the simple Leaky Integrate-and-Fire (LIF) model. Deep SNNs restricted to simple neuron models cannot capture the rich temporal dynamics observed in biological neurons, which limits their computational power.

Based on: "Learning Neuron Dynamics within Deep Spiking Neural Networks" (arXiv:2510.07341, October 2025)

Problem

  • Standard deep SNNs use simple LIF neuron models
  • LIF cannot capture the rich temporal dynamics of biological neurons, such as spike-frequency adaptation or processing on several timescales at once
  • Limited expressivity restricts SNN performance on complex tasks
  • Need trainable neuron dynamics that adapt to task requirements

Approach

Learn neuron dynamics as part of the network training process, rather than fixing them to predefined models. Key idea: parameterize neuron dynamics flexibly and optimize them end-to-end alongside synaptic weights.
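Written as discrete-time update equations, the trainable dynamics look roughly as follows. This is a sketch under stated assumptions, not necessarily the paper's exact formulation: the synaptic and membrane updates use the forward-Euler decay factor 1 - 1/τ that appears in the code below (so τ must stay ≥ 1), and the threshold is modeled as a learnable baseline ϑ₀ plus an adaptation variable a_t that jumps by 1 per spike and decays at rate β.

```latex
\begin{aligned}
I_t &= \Bigl(1 - \frac{1}{\tau_{\mathrm{syn}}}\Bigr) I_{t-1} + x_t \\
\tilde{V}_t &= \Bigl(1 - \frac{1}{\tau_{\mathrm{mem}}}\Bigr) V_{t-1} + I_t \\
s_t &= \Theta\bigl(\tilde{V}_t - \vartheta_0 - a_{t-1}\bigr) \\
a_t &= (1 - \beta)\, a_{t-1} + s_t \\
V_t &= (1 - s_t)\, \tilde{V}_t + s_t\, V_{\mathrm{reset}}
\end{aligned}
```

Here x_t is the input current, Θ is the Heaviside step, and the trainable parameters are τ_syn, τ_mem, ϑ₀, β and V_reset, updated by backpropagation through time with a surrogate in place of the derivative of Θ.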

Trainable Neuron Dynamics

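import torch
import torch.nn as nn


class TrainableNeuron(nn.Module):
    """Neuron with learnable dynamics parameters (shared across one layer)."""

    def __init__(self, surrogate_slope=5.0):
        super().__init__()
        # Learnable time constants (in time steps); clamped to >= 1 in forward
        # so the Euler decay factor 1 - 1/tau stays in [0, 1)
        self.tau_mem = nn.Parameter(torch.tensor(10.0))
        self.tau_syn = nn.Parameter(torch.tensor(5.0))

        # Learnable threshold dynamics: baseline threshold plus an adaptation
        # term that jumps by 1 on each spike and decays at rate threshold_decay
        self.threshold = nn.Parameter(torch.tensor(1.0))
        self.threshold_decay = nn.Parameter(torch.tensor(0.5))

        # Learnable reset mechanism
        self.reset_value = nn.Parameter(torch.tensor(0.0))

        # Steepness of the sigmoid surrogate gradient (fixed hyperparameter)
        self.surrogate_slope = surrogate_slope

    def forward(self, input_current, prev_state):
        # prev_state = (membrane, adaptation, syn_current), each shaped like input
        membrane, adaptation, syn_current = prev_state
        tau_syn = self.tau_syn.clamp(min=1.0)
        tau_mem = self.tau_mem.clamp(min=1.0)

        # Synaptic dynamics (learnable)
        syn_current = syn_current * (1 - 1 / tau_syn) + input_current

        # Membrane dynamics (learnable)
        membrane = membrane * (1 - 1 / tau_mem) + syn_current

        # Spike generation: exact Heaviside step in the forward pass, sigmoid
        # surrogate gradient in the backward pass (the hard step alone has
        # zero gradient, so no upstream parameter would learn)
        v = membrane - (self.threshold + adaptation)
        surrogate = torch.sigmoid(self.surrogate_slope * v)
        spike = (v > 0).float() + (surrogate - surrogate.detach())

        # Adaptive threshold: adaptation decays toward 0 and grows with each spike
        decay = self.threshold_decay.clamp(0.0, 1.0)
        adaptation = adaptation * (1 - decay) + spike

        # Reset to a learnable value wherever a spike occurred
        membrane = membrane * (1 - spike) + self.reset_value * spike

        return spike, (membrane, adaptation, syn_current)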
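To see what an adaptive threshold adds over a plain LIF neuron, the short simulation below drives a single neuron with a constant input and records its spike times. It is a plain-Python sketch with hand-picked, hypothetical values (no synaptic filter, membrane decay 0.9, adaptation jump 0.5 that decays by 10% per step, reset to 0), chosen only to make the effect visible; the function names `simulate` and `intervals` are illustrative.

```python
def simulate(steps=60, current=0.3, mem_decay=0.9, base_threshold=1.0,
             adapt_jump=0.5, adapt_decay=0.1, adaptive=True):
    """Return the spike times of one neuron driven by a constant current."""
    membrane, adaptation = 0.0, 0.0
    spike_times = []
    for t in range(steps):
        # Leaky integration (forward-Euler form, decay = 1 - 1/tau_mem)
        membrane = mem_decay * membrane + current
        # Effective threshold = baseline + adaptation
        if membrane > base_threshold + adaptation:
            spike_times.append(t)
            membrane = 0.0  # reset to 0
            spiked = 1.0
        else:
            spiked = 0.0
        if adaptive:
            # Adaptation decays toward 0 and jumps on every spike
            adaptation = adaptation * (1 - adapt_decay) + adapt_jump * spiked
    return spike_times


def intervals(times):
    """Inter-spike intervals from a list of spike times."""
    return [b - a for a, b in zip(times, times[1:])]


print("fixed-threshold ISIs:", intervals(simulate(adaptive=False)))
print("adaptive ISIs:       ", intervals(simulate(adaptive=True)))
```

The fixed-threshold neuron fires at a constant interval, while the adaptive neuron's intervals lengthen after its first spike (spike-frequency adaptation), a behavior a fixed-threshold LIF neuron cannot produce.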

Key Innovations

  1. Learnable time constants: Membrane and synaptic time constants optimized per layer or per neuron
  2. Adaptive thresholds: Dynamic thresholds that adjust based on activity
  3. Flexible reset mechanisms: Beyond simple reset-to-zero
  4. Multi-timescale dynamics: Capture both fast and slow neural processes
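Innovations 1 and 4 can be combined by giving every neuron its own learnable time constant and spreading the initial values over a range, so that one layer holds both fast and slow units. The sketch below is an illustration under assumptions, not the paper's layer: the class name `MultiTimescaleLIF`, the log-uniform initialization between `tau_min` and `tau_max`, and the exp(-1/τ) decay (an exact-exponential alternative to the Euler factor 1 - 1/τ) are all hypothetical choices. Storing log τ keeps τ positive without clamping.

```python
import math

import torch
import torch.nn as nn


class MultiTimescaleLIF(nn.Module):
    """LIF layer with one learnable membrane time constant per neuron (hypothetical)."""

    def __init__(self, num_neurons, tau_min=2.0, tau_max=50.0, slope=5.0):
        super().__init__()
        # Log-uniform spread of initial time constants: fast and slow neurons
        log_tau = torch.linspace(math.log(tau_min), math.log(tau_max), num_neurons)
        self.log_tau = nn.Parameter(log_tau)
        self.threshold = nn.Parameter(torch.ones(num_neurons))
        self.slope = slope

    def forward(self, x):
        # x: (time, batch, num_neurons) input currents
        decay = torch.exp(-1.0 / self.log_tau.exp())  # in (0, 1) for any log_tau
        membrane = torch.zeros_like(x[0])
        spikes = []
        for x_t in x:  # iterate over the time dimension
            membrane = decay * membrane + x_t
            v = membrane - self.threshold
            # Heaviside forward, sigmoid surrogate gradient backward
            surrogate = torch.sigmoid(self.slope * v)
            spike = (v > 0).float() + (surrogate - surrogate.detach())
            membrane = membrane * (1 - spike)  # reset to zero
            spikes.append(spike)
        return torch.stack(spikes)  # (time, batch, num_neurons)


layer = MultiTimescaleLIF(num_neurons=8)
out = layer(torch.rand(20, 4, 8))
out.sum().backward()
print(out.shape, layer.log_tau.grad.shape)
```

Because every neuron has its own `log_tau`, gradient descent can move individual neurons toward faster or slower integration instead of forcing one time constant on the whole layer.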

Training Strategy

Surrogate Gradient Approach

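import torch


class MultiTimescaleSurrogate(torch.autograd.Function):
    """Heaviside spike with a Gaussian surrogate gradient.

    Usage: spike = MultiTimescaleSurrogate.apply(membrane, threshold)
    """

    @staticmethod
    def forward(ctx, input_val, threshold):
        ctx.save_for_backward(input_val, threshold)
        return (input_val > threshold).float()

    @staticmethod
    def backward(ctx, grad_output):
        input_val, threshold = ctx.saved_tensors
        # Smooth gradient around threshold
        diff = input_val - threshold
        sigma = 0.5
        grad_input = grad_output * torch.exp(-diff**2 / (2 * sigma**2))
        # The threshold is learnable too: its gradient is -grad_input,
        # summed over the dimensions it was broadcast along
        grad_threshold = -grad_input
        while grad_threshold.dim() > threshold.dim():
            grad_threshold = grad_threshold.sum(0)
        for dim, size in enumerate(threshold.shape):
            if size == 1:
                grad_threshold = grad_threshold.sum(dim, keepdim=True)
        return grad_input, grad_threshold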
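The same Gaussian surrogate can also be obtained without writing a custom autograd.Function: add a smooth function whose derivative is the Gaussian and subtract its detached value, so the forward pass stays an exact 0/1 step while the backward pass sees only the smooth part. Since the antiderivative of exp(-d²/(2σ²)) is σ√(π/2)·erf(d/(σ√2)), the gradient matches exp(-(v - threshold)²/(2σ²)) exactly, and autograd computes the threshold's gradient without any manual broadcast handling. The helper name `gaussian_spike` is hypothetical.

```python
import math

import torch


def gaussian_spike(v, threshold, sigma=0.5):
    """Heaviside spike with a Gaussian surrogate gradient (hypothetical helper)."""
    diff = v - threshold
    # d(smooth)/d(diff) = exp(-diff^2 / (2 sigma^2)), the Gaussian surrogate
    smooth = sigma * math.sqrt(math.pi / 2) * torch.erf(diff / (sigma * math.sqrt(2)))
    # Forward value: exact 0/1 step (smooth - smooth.detach() is exactly 0)
    # Backward: only `smooth` carries gradient
    return (diff > 0).float() + (smooth - smooth.detach())


v = torch.tensor([0.2, 0.9, 1.0, 1.6], requires_grad=True)
threshold = torch.tensor(1.0, requires_grad=True)
spikes = gaussian_spike(v, threshold)
spikes.sum().backward()
print(spikes)          # 0/1 spikes
print(v.grad)          # exp(-(v - 1)^2 / 0.5)
print(threshold.grad)  # minus the sum of v.grad
```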

End-to-End Training

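import torch
import torch.nn as nn


def train_deep_snn_with_learnable_neurons(model, loader, epochs=100, lr=1e-3):
    """
    Train deep SNN with learnable neuron dynamics.

    All neuron parameters (time constants, thresholds, reset values) are
    nn.Parameters, so model.parameters() optimizes them alongside the
    synaptic weights.
    """
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()

    for epoch in range(epochs):
        for inputs, targets in loader:
            optimizer.zero_grad()

            # Forward through SNN with learnable neurons; assumed to return
            # output activity of shape (time, batch, num_classes)
            spikes = model(inputs)

            # Compute loss on activity summed over time (rate-coded readout)
            loss = criterion(spikes.sum(dim=0), targets)

            # Backprop through time via the surrogate gradients
            loss.backward()
            optimizer.step()

    return model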
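A tiny self-contained smoke test of the training pattern above: a hypothetical two-layer network (`TinySNN`, not from the paper) whose hidden LIF neurons have learnable per-neuron time constants and thresholds, trained for a few steps on random synthetic data. It only checks the plumbing, namely that neuron parameters receive gradients and are changed by the optimizer; it says nothing about accuracy.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)


class TinySNN(nn.Module):
    """Linear -> LIF (learnable tau, threshold) -> linear readout, over T steps."""

    def __init__(self, n_in=10, n_hidden=32, n_out=3, slope=5.0):
        super().__init__()
        self.fc_in = nn.Linear(n_in, n_hidden)
        self.fc_out = nn.Linear(n_hidden, n_out)
        self.log_tau = nn.Parameter(torch.full((n_hidden,), 2.3))  # tau ~ 10 steps
        self.threshold = nn.Parameter(torch.ones(n_hidden))
        self.slope = slope

    def forward(self, x):
        # x: (time, batch, n_in); returns readout of shape (time, batch, n_out)
        decay = torch.exp(-1.0 / self.log_tau.exp())
        membrane = torch.zeros(x.shape[1], self.threshold.shape[0])
        outputs = []
        for x_t in x:
            membrane = decay * membrane + self.fc_in(x_t)
            v = membrane - self.threshold
            surrogate = torch.sigmoid(self.slope * v)
            spike = (v > 0).float() + (surrogate - surrogate.detach())
            membrane = membrane * (1 - spike)  # reset to zero
            outputs.append(self.fc_out(spike))
        return torch.stack(outputs)


model = TinySNN()
inputs = torch.rand(16, 8, 10)        # (time, batch, features), synthetic
targets = torch.randint(0, 3, (8,))   # synthetic labels
loader = [(inputs, targets)]          # one-batch stand-in for a DataLoader
optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)
criterion = nn.CrossEntropyLoss()

tau_before = model.log_tau.detach().clone()
for epoch in range(5):
    for x, y in loader:
        optimizer.zero_grad()
        loss = criterion(model(x).sum(dim=0), y)
        loss.backward()
        optimizer.step()

print("loss:", loss.item())
print("tau changed:", not torch.equal(tau_before, model.log_tau.detach()))
```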

Benefits Over Fixed LIF

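| Aspect | Fixed LIF | Learnable neuron dynamics |
|---|---|---|
| Time constants | Fixed by hand | Optimized per layer or per neuron |
| Threshold | Static | Adaptive during inference (rises after spikes, then decays) |
| Reset | Hard reset to a fixed value | Learnable reset value |
| Expressivity | Limited to a single leaky integrator | Richer temporal dynamics (adaptation, multiple timescales) |
| Task adaptation | None; neuron parameters are hyperparameters | Neuron parameters tuned by gradient descent on the task loss |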

Applications

  • Temporal sequence processing: Better capture of temporal patterns
  • Speech/audio recognition: Exploit multi-timescale dynamics
  • Neuromorphic deployment: Hardware-aware neuron optimization
  • Biological plausibility: More realistic neuron behavior

Design Guidelines

  1. Start with LIF parameters as initialization
  2. Allow dynamics to specialize per layer (early layers faster, later layers slower)
  3. Use surrogate gradients compatible with learnable dynamics
  4. Regularize extreme parameter values to maintain stability
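Guidelines 1, 2 and 4 can be sketched with two small helpers. They are hypothetical, not taken from the paper: `layer_tau_schedule` initializes each layer with LIF-style time constants spaced geometrically from fast (early layers) to slow (later layers), and `tau_range_penalty` adds a squared penalty whenever a learned τ leaves an assumed safe interval [tau_lo, tau_hi]. All default values are illustrative assumptions.

```python
import torch


def layer_tau_schedule(num_layers, tau_first=2.0, tau_last=20.0):
    """Geometrically spaced initial time constants, fast early and slow late."""
    if num_layers == 1:
        return torch.tensor([tau_first])
    ratio = (tau_last / tau_first) ** (1.0 / (num_layers - 1))
    return torch.tensor([tau_first * ratio**i for i in range(num_layers)])


def tau_range_penalty(taus, tau_lo=1.0, tau_hi=100.0):
    """Zero inside [tau_lo, tau_hi], quadratic outside it."""
    below = torch.relu(tau_lo - taus)
    above = torch.relu(taus - tau_hi)
    return (below**2 + above**2).sum()


init_taus = layer_tau_schedule(4)  # about 2.0, 4.31, 9.28, 20.0
penalty = tau_range_penalty(torch.tensor([0.5, 10.0, 150.0]))
print(init_taus, penalty)  # penalty = 0.5**2 + 50**2 = 2500.25
```

During training, one could add `lam * tau_range_penalty(tau)` to the task loss, with `lam` a small, hand-chosen weight (hypothetical); the penalty is zero for time constants that stay in range, so it only acts on parameters drifting toward unstable or degenerate values.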

Install via CLI

npx skills add https://github.com/hiyenwong/ai_collection --skill learning-neuron-dynamics-deep-snn