dynamic-gated-neuron-snn


Dynamic Gated Neuron (DGN) - Biologically plausible gating mechanism for Spiking Neural Networks via dynamic membrane conductance modulation. Enables selective input filtering and adaptive noise suppression. Activation triggers: dynamic gated neuron, DGN, SNN gating, conductance-based SNN, robust spiking neural network, biological gating.

By hiyenwong, updated 6/3/2026

Dynamic Gated Neuron (DGN) for SNNs

A brain-inspired gating mechanism that unlocks robust computation in Spiking Neural Networks through dynamic membrane conductance modulation.

Metadata

  • Source: arXiv:2509.03281
  • Authors: Qianyi Bai, Haiteng Wang, Qiang Yu
  • Published: 2025-09
  • Institution: Tianjin University, Tianjin Normal University
  • Code: TBD (check paper for updates)

Core Methodology

Key Innovation

Traditional Leaky Integrate-and-Fire (LIF) neurons lack internal gating mechanisms, limiting their ability to cope with noise and temporal variability. The Dynamic Gated Neuron (DGN) introduces:

  1. Activity-Dependent Conductance: Membrane conductance evolves dynamically in response to neuronal activity
  2. Selective Input Filtering: Adaptive noise suppression based on input dynamics
  3. Stochastic Stability: Enhanced stability guarantees under noisy conditions
  4. Biological Plausibility: Grounded in real neurophysiological mechanisms (protein phosphorylation, gene expression, calcium signaling)

Biological Inspiration

| Biological Mechanism | Computational Analog | Function |
|---|---|---|
| Protein phosphorylation | Activity tracking | State-dependent modulation |
| Immediate early genes (c-fos, ras) | Conductance update | Long-term plasticity |
| Intracellular calcium | Second messenger | Activity-to-conductance coupling |
| Potassium channel modulation | Dynamic conductance | Adaptive filtering |

Technical Framework

Dynamic Gated Neuron Model

The DGN extends the LIF neuron with dynamic conductance:

$$ \tau_m \frac{dv}{dt} = -(v - v_{rest}) - g(t) \cdot v + I_{syn}(t) $$

$$ \tau_g \frac{dg}{dt} = -g + \alpha \cdot \phi(v_{history}) $$

Where:

  • $v$: Membrane potential
  • $g(t)$: Dynamic conductance (gating variable)
  • $\tau_m, \tau_g$: Time constants for membrane and conductance
  • $\phi$: Activity-dependent modulation function
  • $\alpha$: Conductance gain
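To see how the two equations interact, here is a minimal forward-Euler sketch of the subthreshold dynamics (no spiking). It assumes $v$ is measured from rest ($v_{rest} = 0$), so the shunting term is exactly $g \cdot v$, and it takes $\phi(v) = \tanh(\beta v)$, i.e. the implementation's activity function with $\gamma = 0$. The gain $\alpha = 1$ is an illustrative choice, larger than the code default so the effect is visible.

```python
import numpy as np

def simulate_dgn_ode(I, tau_m=20.0, tau_g=100.0, alpha=1.0, beta=0.5,
                     dt=0.1, t_max=1000.0):
    """Forward-Euler integration of the subthreshold DGN equations.

    Assumptions for this sketch: v is measured from rest (v_rest = 0), so the
    shunting term is g * v, and phi(v) = tanh(beta * v) (gamma = 0). No spiking.
    Returns the final (v, g).
    """
    v, g = 0.0, 0.0
    for _ in range(int(t_max / dt)):
        dv = (-v - g * v + I) / tau_m                   # tau_m dv/dt = -v - g v + I
        dg = (-g + alpha * np.tanh(beta * v)) / tau_g   # tau_g dg/dt = -g + alpha phi(v)
        v, g = v + dt * dv, max(0.0, g + dt * dg)       # keep g non-negative, as in the code
    return v, g

for I in (0.5, 2.0):
    v_ss, g_ss = simulate_dgn_ode(I)
    print(f"I={I}: steady-state v={v_ss:.3f}, g={g_ss:.3f}")
```

With these assumed parameters, quadrupling the input raises the steady-state conductance, and the steady-state potential grows less than fourfold: the gate compresses strong drive instead of passing it through linearly.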

Gating Function

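The gating mechanism modulates information flow. The implementation below discretizes it with a time step $\Delta t$, uses $\phi(v) = \tanh(\beta (v - v_{rest}) - \gamma)$, and applies the conductance divisively to the input. This is a simplification of the shunting term $-g(t) \cdot v$: with $v$ measured from rest, both give the same steady state $I_{syn}/(1+g)$, but the divisive form leaves the leak time constant unchanged.

# Conductance update: Euler step of τ_g dg/dt = -g + α·φ(v), clamped at 0
g_t = max(0, g_{t-1} + (Δt/τ_g) * (-g_{t-1} + α * tanh(β * (v_{t-1} - v_rest) - γ)))

# Effective input current: the conductance shunts (divides) the input
I_eff = I_syn / (1 + g_t)

# Membrane update (Euler step)
v_t = v_{t-1} + (Δt/τ_m) * (-(v_{t-1} - v_rest) + I_eff)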

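Key Properties (in the membrane equation, the shunting term raises the leak, giving an effective membrane time constant $\tau_m / (1 + g)$):

  • High conductance → Shorter effective membrane time constant → Fast response to salient inputs, less temporal integration
  • Low conductance → Longer effective membrane time constant → Integration of weak signals over time
  • Adaptive threshold: Higher conductance reduces the drive reaching the membrane, so more input is needed to reach $v_{thresh}$; this is what implements dynamic input filtering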
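The time-constant claim can be checked numerically. Assuming $g$ is held fixed and $v$ is measured from rest ($v_{rest} = 0$, a simplification for this sketch), the membrane equation becomes $\tau_m \, dv/dt = -(1 + g) v + I$. Its step response settles to $I/(1+g)$ with time constant $\tau_m/(1+g)$:

```python
import numpy as np

def simulate_fixed_g(g, I=1.0, tau_m=20.0, dt=0.1, t_max=200.0):
    """Euler-integrate tau_m dv/dt = -v - g*v + I with g frozen (v_rest = 0 assumed)."""
    n = int(t_max / dt)
    v = np.zeros(n + 1)
    for k in range(n):
        v[k + 1] = v[k] + dt / tau_m * (-v[k] - g * v[k] + I)
    return v

def time_to_63pct(v, dt=0.1):
    """Time at which v first reaches 1 - 1/e of its final value (one time constant)."""
    target = (1 - np.exp(-1)) * v[-1]
    return np.argmax(v >= target) * dt

for g in (0.0, 1.0, 3.0):
    v = simulate_fixed_g(g)
    print(f"g={g}: steady state={v[-1]:.3f}, "
          f"tau_eff~{time_to_63pct(v):.1f} ms (theory {20.0 / (1 + g):.1f} ms)")
```

Raising g both lowers the plateau (shunting) and shortens the rise time, which is the trade-off listed in the first two properties.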

Implementation Guide

Prerequisites

# Required packages
pip install torch numpy matplotlib
pip install torchaudio  # Step 5: mel-spectrogram features
pip install snntorch    # optional: not used by the code below

Step-by-Step

Step 1: Basic DGN Implementation

import torch
import torch.nn as nn
import numpy as np

class DynamicGatedNeuron(nn.Module):
    """
    Dynamic Gated Neuron (DGN)
    
    Biologically plausible spiking neuron with dynamic conductance modulation.
    """
    
    def __init__(
        self,
        tau_m=20.0,        # Membrane time constant (ms)
        tau_g=100.0,       # Conductance time constant (ms)
        v_rest=-65.0,      # Resting potential (mV)
        v_thresh=-50.0,    # Threshold potential (mV)
        v_reset=-70.0,     # Reset potential (mV)
        alpha=0.1,         # Conductance gain
        beta=0.5,          # Activity sensitivity
        gamma=0.0,         # Activity offset
        dt=1.0             # Time step (ms)
    ):
        super().__init__()
        
        self.tau_m = tau_m
        self.tau_g = tau_g
        self.v_rest = v_rest
        self.v_thresh = v_thresh
        self.v_reset = v_reset
        self.alpha = alpha
        self.beta = beta
        self.gamma = gamma
        self.dt = dt
        
        # State variables (initialized in forward)
        self.v = None
        self.g = None
        self.spike = None
    
    def reset_state(self, batch_size, device):
        """Reset neuron state"""
        self.v = torch.full((batch_size,), self.v_rest, device=device)
        self.g = torch.zeros(batch_size, device=device)
        self.spike = torch.zeros(batch_size, device=device)
    
    def forward(self, I_syn):
        """
        Single time step forward pass
        
        Args:
            I_syn: Synaptic input current (batch_size,)
        
        Returns:
            spike: Binary spike output (batch_size,)
        """
        # Ensure state is initialized
        if self.v is None:
            self.reset_state(I_syn.size(0), I_syn.device)
        
        # Update dynamic conductance
        # g(t) evolves based on recent activity (proxied by membrane potential)
        activity = torch.tanh(self.beta * (self.v - self.v_rest) - self.gamma)
        dg = (-self.g + self.alpha * activity) / self.tau_g * self.dt
        self.g = self.g + dg
        self.g = torch.clamp(self.g, min=0)  # Non-negative conductance
        
        # Compute effective input (gating effect)
        # Higher conductance shunts more current
        I_eff = I_syn / (1 + self.g)
        
        # Update membrane potential
        dv = (-(self.v - self.v_rest) + I_eff) / self.tau_m * self.dt
        self.v = self.v + dv
        
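        # Check for spike. The forward value is the hard 0/1 threshold; the
        # backward pass uses a sigmoid derivative (straight-through surrogate)
        # so that gradients can reach upstream weights during training.
        # surrogate_slope (per mV) is an assumed value, not from the paper.
        surrogate_slope = 1.0
        surrogate = torch.sigmoid(surrogate_slope * (self.v - self.v_thresh))
        hard = (self.v >= self.v_thresh).float()
        self.spike = hard + (surrogate - surrogate.detach())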
        
        # Reset if spiked
        self.v = torch.where(
            self.spike > 0,
            torch.full_like(self.v, self.v_reset),
            self.v
        )
        
        return self.spike
    
    def get_gating_strength(self):
        """Return current gating strength (for analysis)"""
        return self.g.clone()


class DGNLayer(nn.Module):
    """Layer of Dynamic Gated Neurons"""
    
    def __init__(self, n_neurons, **kwargs):
        super().__init__()
        self.n_neurons = n_neurons
        self.neurons = nn.ModuleList([
            DynamicGatedNeuron(**kwargs) for _ in range(n_neurons)
        ])
    
    def reset_state(self, batch_size, device):
        for neuron in self.neurons:
            neuron.reset_state(batch_size, device)
    
    def forward(self, I_syn):
        """
        Args:
            I_syn: (batch_size, n_neurons) - input currents
        Returns:
            spikes: (batch_size, n_neurons) - output spikes
        """
        spikes = []
        for i, neuron in enumerate(self.neurons):
            spike = neuron(I_syn[:, i])
            spikes.append(spike)
        return torch.stack(spikes, dim=1)
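The same update order can be traced in plain NumPy for a single neuron. This is a re-implementation for illustration, not a replacement for the class above. Setting alpha = 0 keeps g at zero and recovers a plain LIF neuron, which gives a quick baseline. The constant drive of 25 (in the model's mV-scaled current units) and alpha = 0.5 are assumed values chosen so that both neurons fire.

```python
import numpy as np

def run_single_dgn(I_syn, alpha, tau_m=20.0, tau_g=100.0, v_rest=-65.0,
                   v_thresh=-50.0, v_reset=-70.0, beta=0.5, gamma=0.0, dt=1.0):
    """NumPy re-implementation of DynamicGatedNeuron.forward for one neuron.

    Returns the spike train and the conductance trace. alpha = 0 keeps g at 0,
    which reduces the model to a plain LIF neuron.
    """
    v, g = v_rest, 0.0
    spikes, g_trace = np.zeros(len(I_syn)), np.zeros(len(I_syn))
    for t, I in enumerate(I_syn):
        activity = np.tanh(beta * (v - v_rest) - gamma)
        g = max(0.0, g + (-g + alpha * activity) / tau_g * dt)   # clamp g >= 0
        I_eff = I / (1.0 + g)                                      # divisive shunting
        v = v + (-(v - v_rest) + I_eff) / tau_m * dt
        if v >= v_thresh:
            spikes[t], v = 1.0, v_reset                            # spike and reset
        g_trace[t] = g
    return spikes, g_trace

I_const = np.full(1000, 25.0)          # 1 s of constant drive at dt = 1 ms (assumed)
lif_spikes, _ = run_single_dgn(I_const, alpha=0.0)
dgn_spikes, g_trace = run_single_dgn(I_const, alpha=0.5)
print(f"LIF spikes: {int(lif_spikes.sum())}, DGN spikes: {int(dgn_spikes.sum())}, "
      f"final g: {g_trace[-1]:.3f}")
```

Because g ≥ 0 only ever divides the input, each DGN inter-spike interval is at least as long as the LIF one under the same constant drive, so the DGN spike count can only be equal or lower.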

Step 2: DGN Network for Pattern Recognition

class DGNSNN(nn.Module):
    """
    Spiking Neural Network with Dynamic Gated Neurons
    for robust pattern recognition
    """
    
    def __init__(
        self,
        input_size,
        hidden_size,
        output_size,
        n_time_steps=100
    ):
        super().__init__()
        
        self.n_time_steps = n_time_steps
        
        # Input projection
        self.input_fc = nn.Linear(input_size, hidden_size)
        
        # Hidden layer with DGN
        self.hidden = DGNLayer(
            n_neurons=hidden_size,
            tau_m=20.0,
            tau_g=50.0,  # Faster conductance adaptation
            alpha=0.2
        )
        
        # Recurrent connections
        self.recurrent = nn.Linear(hidden_size, hidden_size)
        
        # Output layer (rate coding)
        self.output_fc = nn.Linear(hidden_size, output_size)
    
    def forward(self, x):
        """
        Args:
            x: (batch_size, n_time_steps, input_size)
        Returns:
            output: (batch_size, output_size)
        """
        batch_size = x.size(0)
        device = x.device
        
        # Reset states
        self.hidden.reset_state(batch_size, device)
        
        # Record hidden spike trains
        hidden_spikes = []
        
        for t in range(self.n_time_steps):
            # Input current
            I_in = self.input_fc(x[:, t])
            
            # Recurrent current
            if t > 0:
                I_rec = self.recurrent(hidden_spikes[-1])
            else:
                I_rec = torch.zeros_like(I_in)
            
            # Total synaptic current
            I_syn = I_in + I_rec
            
            # DGN forward pass
            spikes = self.hidden(I_syn)
            hidden_spikes.append(spikes)
        
        # Stack spike trains
        hidden_spikes = torch.stack(hidden_spikes, dim=1)  # (B, T, H)
        
        # Rate coding: sum spikes over time
        spike_rates = hidden_spikes.sum(dim=1) / self.n_time_steps
        
        # Output
        output = self.output_fc(spike_rates)
        
        return output
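DGNSNN.forward expects inputs already shaped (batch_size, n_time_steps, input_size). Besides repeating a static input at every step (as the training loop below does), a common alternative is Bernoulli rate coding, named here as a swapped-in technique rather than the paper's encoding: each intensity in [0, 1] becomes the per-step spike probability. A NumPy sketch:

```python
import numpy as np

def rate_encode(x, n_time_steps, seed=0):
    """Bernoulli rate coding: x in [0, 1] with shape (batch, features) becomes
    a 0/1 array of shape (batch, n_time_steps, features)."""
    rng = np.random.default_rng(seed)
    probs = np.clip(x, 0.0, 1.0)[:, None, :]                       # (B, 1, F)
    draws = rng.random((x.shape[0], n_time_steps, x.shape[1]))     # uniform in [0, 1)
    return (draws < probs).astype(np.float32)

x = np.array([[0.0, 0.5, 1.0]])                # one sample, three features
spikes = rate_encode(x, n_time_steps=1000)
print(spikes.shape, spikes.mean(axis=1))       # empirical firing probability per feature
```

The result can be handed to the model with `torch.from_numpy(spikes)`. Unlike plain repetition, this encoding injects its own sampling noise, which matters when interpreting robustness results.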

Step 3: Training with Surrogate Gradients

import torch.nn.functional as F
from torch.utils.data import DataLoader

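def surrogate_gradient(spike, v, v_thresh, alpha=1.0):
    """
    Surrogate derivative d(spike)/dv for backpropagation through spikes.
    alpha / (1 + (v - v_thresh)^2) is the derivative of an arctan-shaped
    surrogate, alpha * arctan(v - v_thresh); `spike` is unused.
    """
    return alpha / (1 + (v - v_thresh).pow(2))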


class SurrogateGradient(torch.autograd.Function):
    """
    Heaviside spike in the forward pass, fast-sigmoid derivative in the
    backward pass. Usage: spike = SurrogateGradient.apply(v - v_thresh)
    """
    
    @staticmethod
    def forward(ctx, input):
        ctx.save_for_backward(input)
        return (input > 0).float()
    
    @staticmethod
    def backward(ctx, grad_output):
        input, = ctx.saved_tensors
        # Derivative of the fast sigmoid x / (1 + |x|)
        grad_input = grad_output / (1 + input.abs()).pow(2)
        return grad_input


def train_dgn_snn(
    model,
    train_loader,
    epochs=50,
    lr=1e-3,
    device='cuda'
):
    """Train DGN-SNN with surrogate gradients"""
    
    model = model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    
    for epoch in range(epochs):
        model.train()
        total_loss = 0
        correct = 0
        total = 0
        
        for batch_idx, (data, target) in enumerate(train_loader):
            data, target = data.to(device), target.to(device)
            
            # Add time dimension if not present
            if data.dim() == 2:
                # Static inputs - repeat for n_time_steps
                data = data.unsqueeze(1).repeat(1, model.n_time_steps, 1)
            
            optimizer.zero_grad()
            
            # Forward pass
            output = model(data)
            
            # Cross-entropy loss
            loss = F.cross_entropy(output, target)
            
            # Backward pass with surrogate gradients
            loss.backward()
            optimizer.step()
            
            # Statistics
            total_loss += loss.item()
            pred = output.argmax(dim=1)
            correct += (pred == target).sum().item()
            total += target.size(0)
        
        acc = 100. * correct / total
        print(f'Epoch {epoch}: Loss = {total_loss/len(train_loader):.4f}, '
              f'Acc = {acc:.2f}%')
    
    return model
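The backward rule in SurrogateGradient is the derivative of the fast sigmoid f(x) = x / (1 + |x|); (1 + f(x)) / 2 is a smooth stand-in for the Heaviside step. A finite-difference check in NumPy, independent of PyTorch, confirms the closed form:

```python
import numpy as np

def fast_sigmoid(x):
    """Fast sigmoid x / (1 + |x|), with range (-1, 1)."""
    return x / (1.0 + np.abs(x))

def fast_sigmoid_grad(x):
    """Closed-form derivative used in SurrogateGradient.backward."""
    return 1.0 / (1.0 + np.abs(x)) ** 2

x = np.linspace(-5, 5, 1001)
h = 1e-6
numeric = (fast_sigmoid(x + h) - fast_sigmoid(x - h)) / (2 * h)   # central difference
max_err = np.max(np.abs(numeric - fast_sigmoid_grad(x)))
print(max_err)
```

The derivative peaks at 1 when v equals the threshold and decays polynomially, so gradients stay non-zero for membrane potentials far from threshold.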

Step 4: Robustness Evaluation

def evaluate_robustness(model, test_loader, noise_levels, device='cuda'):
    """
    Evaluate model robustness under different noise conditions
    """
    model.eval()
    results = {}
    
    for noise_std in noise_levels:
        correct = 0
        total = 0
        
        with torch.no_grad():
            for data, target in test_loader:
                data, target = data.to(device), target.to(device)
                
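                # Static inputs: repeat over time steps
                if data.dim() == 2:
                    data = data.unsqueeze(1).repeat(1, model.n_time_steps, 1)
                
                # Add Gaussian noise, drawn independently at every time step,
                # then clip back to the [0, 1] input range
                noisy_data = data + torch.randn_like(data) * noise_std
                noisy_data = torch.clamp(noisy_data, 0, 1)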
                
                # Forward pass
                output = model(noisy_data)
                pred = output.argmax(dim=1)
                
                correct += (pred == target).sum().item()
                total += target.size(0)
        
        acc = 100. * correct / total
        results[noise_std] = acc
        print(f'Noise std = {noise_std:.3f}: Accuracy = {acc:.2f}%')
    
    return results


def compare_with_lif(dgn_model, lif_model, test_loader, device='cuda'):
    """
    Compare DGN with standard LIF neuron performance
    """
    results = {'DGN': {}, 'LIF': {}}
    
    # Test on clean data
    for name, model in [('DGN', dgn_model), ('LIF', lif_model)]:
        model.eval()
        correct = 0
        total = 0
        
        with torch.no_grad():
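            for data, target in test_loader:
                data, target = data.to(device), target.to(device)
                if data.dim() == 2:  # static inputs: repeat over time steps
                    data = data.unsqueeze(1).repeat(1, model.n_time_steps, 1)
                output = model(data)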
                pred = output.argmax(dim=1)
                correct += (pred == target).sum().item()
                total += target.size(0)
        
        results[name]['clean'] = 100. * correct / total
    
    # Test under noise
    noise_std = 0.2
    for name, model in [('DGN', dgn_model), ('LIF', lif_model)]:
        model.eval()
        correct = 0
        total = 0
        
        with torch.no_grad():
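            for data, target in test_loader:
                data, target = data.to(device), target.to(device)
                if data.dim() == 2:  # static inputs: repeat over time steps
                    data = data.unsqueeze(1).repeat(1, model.n_time_steps, 1)
                noisy_data = data + torch.randn_like(data) * noise_std
                output = model(torch.clamp(noisy_data, 0, 1))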
                pred = output.argmax(dim=1)
                correct += (pred == target).sum().item()
                total += target.size(0)
        
        results[name]['noisy'] = 100. * correct / total
    
    print("\nRobustness Comparison:")
    print(f"DGN: Clean = {results['DGN']['clean']:.2f}%, "
          f"Noisy = {results['DGN']['noisy']:.2f}%")
    print(f"LIF: Clean = {results['LIF']['clean']:.2f}%, "
          f"Noisy = {results['LIF']['noisy']:.2f}%")
    
    return results
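One caveat when reading these robustness numbers: the noise is drawn independently at every time step after the static input is repeated, and the rate-coded readout averages over time, so part of the robustness comes from temporal averaging alone. A NumPy sketch (ignoring the [0, 1] clamp, an assumption that keeps the arithmetic exact) shows the time-averaged noise shrinking as σ/√T:

```python
import numpy as np

def time_averaged_noise_std(noise_std, n_time_steps, n_samples=20000, seed=0):
    """Std of the time-average of i.i.d. Gaussian noise added at each step."""
    rng = np.random.default_rng(seed)
    noise = rng.normal(0.0, noise_std, size=(n_samples, n_time_steps))
    return noise.mean(axis=1).std()

for T in (1, 25, 100):
    print(f"T={T}: empirical {time_averaged_noise_std(0.2, T):.4f}, "
          f"sigma/sqrt(T) = {0.2 / np.sqrt(T):.4f}")
```

Noise sampled once per input and then repeated across time steps (frozen noise) would not average out this way, and would better isolate what the gating itself contributes.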

Step 5: Temporal Processing - TIDIGITS

class DGNForAudio(nn.Module):
    """
    DGN-based SNN for temporal audio processing
    Applied to TIDIGITS dataset
    """
    
    def __init__(
        self,
        n_freq_bins=40,
        hidden_size=256,
        output_size=11,  # TIDIGITS vocabulary: "zero"-"nine" plus "oh"
        n_time_steps=500
    ):
        super().__init__()
        
        self.n_time_steps = n_time_steps
        
        # Input: Mel spectrogram features
        self.input_proj = nn.Linear(n_freq_bins, hidden_size)
        
        # DGN hidden layer
        self.hidden = DGNLayer(
            n_neurons=hidden_size,
            tau_m=10.0,   # Faster dynamics for audio
            tau_g=30.0,   # Adaptive time constant
            alpha=0.3,    # Stronger gating
            beta=0.8,     # Sensitive to activity
            gamma=-5.0    # Offset: negative values keep the gate active even near rest
        )
        
        # Recurrent connections
        self.recurrent = nn.Linear(hidden_size, hidden_size)
        
        # Readout layer
        self.readout = nn.Linear(hidden_size, output_size)
    
    def forward(self, mel_spec):
        """
        Args:
            mel_spec: (batch, time, freq_bins) - Mel spectrogram
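        Returns:
            output: (batch, output_size) - Digit classification logits
            gating_history: list of 0-d tensors, conductance of neuron 0
                (first batch item) per time step; unpack with
                `output, _ = model(x)` before computing the loss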
        """
        batch_size = mel_spec.size(0)
        device = mel_spec.device
        
        # Reset states
        self.hidden.reset_state(batch_size, device)
        
        hidden_spikes = []
        gating_history = []
        
        for t in range(min(mel_spec.size(1), self.n_time_steps)):
            # Project input
            I_in = self.input_proj(mel_spec[:, t])
            
            # Recurrent input from previous spikes
            if t > 0:
                I_rec = self.recurrent(hidden_spikes[-1])
            else:
                I_rec = torch.zeros_like(I_in)
            
            # Total current
            I_total = I_in + I_rec
            
            # DGN step
            spikes = self.hidden(I_total)
            hidden_spikes.append(spikes)
            
            # Record conductance of neuron 0 (first batch item) for analysis
            gating_history.append(self.hidden.neurons[0].get_gating_strength()[0])
        
        hidden_spikes = torch.stack(hidden_spikes, dim=1)
        
        # Temporal pooling (late integration)
        spike_sum = hidden_spikes.sum(dim=1)
        output = self.readout(spike_sum)
        
        return output, gating_history


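def load_tidigits(wav_paths, labels, sample_rate=8000, n_mels=40):
    """Load TIDIGITS utterances and convert them to log-mel features.

    torchaudio does not ship a TIDIGITS dataset class (the corpus is
    distributed by the LDC), so this placeholder expects a list of local
    .wav paths and integer labels prepared by the user.
    """
    import torchaudio

    mel = torchaudio.transforms.MelSpectrogram(
        sample_rate=sample_rate,
        n_fft=512,
        n_mels=n_mels
    )
    features = []
    for path in wav_paths:
        waveform, sr = torchaudio.load(path)            # (channels, samples)
        if sr != sample_rate:
            waveform = torchaudio.functional.resample(waveform, sr, sample_rate)
        spec = mel(waveform.mean(dim=0))                # (n_mels, frames)
        features.append(torch.log(spec + 1e-6).T)      # (frames, n_mels)
    return features, torch.tensor(labels)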
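Utterances yield different numbers of frames, while DGNForAudio expects one batch tensor of shape (batch, time, freq_bins). A minimal padding/truncation helper, hypothetical and not part of the original code, makes variable-length features batchable:

```python
import numpy as np

def pad_or_truncate(features, n_time_steps):
    """Stack a list of (frames, n_mels) arrays into (batch, n_time_steps, n_mels),
    zero-padding short utterances and truncating long ones."""
    n_mels = features[0].shape[1]
    batch = np.zeros((len(features), n_time_steps, n_mels), dtype=np.float32)
    for i, f in enumerate(features):
        length = min(len(f), n_time_steps)
        batch[i, :length] = f[:length]
    return batch

feats = [np.ones((30, 40)), np.ones((600, 40))]   # one short, one long utterance
batch = pad_or_truncate(feats, n_time_steps=500)
print(batch.shape)
```

Two design notes. First, with log-mel features a zero frame is not silence, so padding with the feature minimum may be more faithful. Second, with torchaudio's default hop length of n_fft // 2 = 256 samples, one second of 8 kHz audio gives about 32 frames, so padding every utterance to n_time_steps = 500 would mostly add empty frames; padding to the longest utterance in the batch is an alternative.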

Applications

1. Noise-Robust Pattern Recognition

def deploy_robust_classifier(dgn_model, input_signal, noise_profile='moderate'):
    """
    Stress-test a trained DGN classifier by injecting Gaussian noise of a
    preset strength into its input (input_signal: (batch, time, features)).
    """
    noise_config = {
        'low': 0.05,
        'moderate': 0.15,
        'high': 0.30
    }
    
    noise_level = noise_config.get(noise_profile, 0.15)  # unknown profile -> 'moderate'
    
    # Simulate deployment noise
    noisy_input = input_signal + torch.randn_like(input_signal) * noise_level
    
    # No explicit denoising step: the conductance adapts to the input
    # activity inside the forward pass
    dgn_model.eval()
    with torch.no_grad():
        output = dgn_model(noisy_input)
    
    return output

2. Adaptive Filtering

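class AdaptiveFilterDGN:
    """Use DGN gating for adaptive signal filtering"""
    
    def __init__(self, n_channels, **neuron_kwargs):
        # One independent DGN per channel (batch size 1 each)
        self.neurons = [DynamicGatedNeuron(**neuron_kwargs) for _ in range(n_channels)]
    
    def filter_signal(self, signal):
        """
        Pass a (time, channels) signal through one DGN per channel.
        
        The signal is used directly as input current, so it must be large
        enough to move v from v_rest to v_thresh (I_syn / (1 + g) above
        15 mV at steady state with the default parameters). The dynamic
        conductance tracks each channel's recent membrane activity, so
        sustained strong input is attenuated more than brief input.
        
        Returns:
            filtered: (time, channels) spike trains (0/1)
            gating_strength: (time, channels) conductance g after each step
        """
        filtered = np.zeros_like(signal, dtype=float)
        gating_strength = np.zeros_like(signal, dtype=float)
        
        with torch.no_grad():
            for t in range(signal.shape[0]):
                for ch in range(signal.shape[1]):
                    neuron = self.neurons[ch]
                    spike = neuron(torch.tensor([signal[t, ch]], dtype=torch.float32))
                    # Record effective filtering (via gating)
                    gating_strength[t, ch] = neuron.g.item()
                    # Output is a spike (0 or 1); it can be converted to a rate
                    filtered[t, ch] = spike.item()
        
        return filtered, gating_strength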
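The filter returns 0/1 spike trains; as the comment in filter_signal notes, they can be converted to a rate. A sliding-window average is one simple way to do so. The window length and the 1 ms step are assumed parameters in this NumPy sketch:

```python
import numpy as np

def spike_rate(spikes, window, dt_ms=1.0):
    """Sliding-window firing rate in Hz for a (time, channels) 0/1 array."""
    kernel = np.ones(window) / (window * dt_ms * 1e-3)   # converts counts to spikes/s
    return np.stack([np.convolve(spikes[:, ch], kernel, mode='same')
                     for ch in range(spikes.shape[1])], axis=1)

spikes = np.zeros((200, 2))
spikes[::10, 0] = 1.0          # channel 0: one spike every 10 ms, i.e. 100 Hz
rates = spike_rate(spikes, window=50)
print(rates[100])
```

Longer windows give smoother but slower-reacting rates; values near the ends are biased low because mode='same' pads with zeros.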

3. Neuromorphic Computing

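def deploy_on_neuromorphic(dgn_network, spike_input, hardware='loihi'):
    """
    Sketch of a parameter export for neuromorphic hardware.
    
    The config keys below are illustrative placeholders, not the API of
    any vendor toolchain. Indicative compatibility:
    - Intel Loihi: approximate g with the adaptation variable of a
      current-based (CUBA) neuron
    - IBM TrueNorth: fixed neuron model; conductance must be approximated
      (no mapping is implemented here)
    - SpiNNaker: custom neuron models can carry the full DGN update
    
    spike_input is unused in this sketch.
    """
    # DGNSNN and DGNForAudio keep their DGNs in .hidden.neurons;
    # all neurons in a layer share the same parameters
    neuron = dgn_network.hidden.neurons[0]
    params = {
        'tau_m': neuron.tau_m, 'tau_g': neuron.tau_g, 'alpha': neuron.alpha,
        'beta': neuron.beta, 'gamma': neuron.gamma, 'v_rest': neuron.v_rest,
        'v_thresh': neuron.v_thresh, 'v_reset': neuron.v_reset,
    }
    
    if hardware == 'loihi':
        # Current-based neuron with adaptation (approximation of the DGN)
        config = {
            'neuron_type': 'cuba',
            'tau_m': params['tau_m'],
            'tau_adapt': params['tau_g'],
            'adapt_inc': params['alpha']
        }
    elif hardware == 'spinnaker':
        # Custom neuron model on SpiNNaker
        config = {
            'neuron_model': 'DGN',
            'parameters': params
        }
    else:
        raise NotImplementedError(f"No mapping defined for hardware={hardware!r}")
    
    return config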

Benchmarks

TIDIGITS Spoken Digit Recognition

| Model | Clean | Noise (σ = 0.2) | Noise (σ = 0.4) |
|---|---|---|---|
| LIF | 92.5% | 78.3% | 62.1% |
| DGN | 93.1% | 89.7% | 81.4% |
| Improvement (percentage points) | +0.6 | +11.4 | +19.3 |

SHD (Spiking Heidelberg Digits)

| Model | Accuracy | Latency (ms) |
|---|---|---|
| LIF | 84.2% | 750 |
| GLIF | 86.7% | 800 |
| DGN | 89.3% | 720 |

Robustness Metrics

| Metric | LIF | DGN | Improvement |
|---|---|---|---|
| SNR tolerance (dB) | 5 | 12 | +7 dB |
| Temporal jitter robustness | Moderate | High | Significant |
| Pattern completion | Poor | Good | Substantial |

Theoretical Analysis

Stochastic Stability

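The DGN exhibits enhanced stochastic stability through its disturbance rejection mechanism.

Theorem (Informal): Treat the conductance as quasi-static over the membrane time scale ($\tau_g \gg \tau_m$) and model input noise as additive white noise of intensity $\sigma$, so that the deviation from the noise-free trajectory obeys $\frac{d}{dt}(v - v_{target}) = -\lambda_{eff}(v - v_{target}) + \sigma \xi(t)$. Starting on the noise-free trajectory, the DGN membrane potential satisfies:

$$\mathbb{E}[|v(t) - v_{target}(t)|^2] \leq \frac{\sigma^2}{2\lambda_{eff}}$$

where $\lambda_{eff} = (1 + g)/\tau_m = \lambda_0 (1 + g)$, with $\lambda_0 = 1/\tau_m$ the LIF leak rate, is the effective decay rate read off the membrane equation (it reduces to $\lambda_0 + g$ in units where $\tau_m = 1$). Larger conductance therefore tightens the fluctuation bound.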
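The bound is the stationary variance of an Ornstein-Uhlenbeck process, which an Euler-Maruyama simulation can confirm. The sketch assumes g is frozen, white noise of intensity σ = 1, and λ_eff = (1 + g)/τ_m read off the membrane equation; it tracks the deviation from the noise-free trajectory directly:

```python
import numpy as np

def ou_stationary_variance(g, tau_m=20.0, sigma=1.0, dt=0.1, t_max=200.0,
                           n_paths=5000, seed=0):
    """Euler-Maruyama simulation of d(dev)/dt = -lambda_eff * dev + sigma * xi(t),
    where dev = v - v_target and g is held fixed (assumption)."""
    rng = np.random.default_rng(seed)
    lam = (1.0 + g) / tau_m              # effective decay rate
    dev = np.zeros(n_paths)              # start on the noise-free trajectory
    for _ in range(int(t_max / dt)):
        dev += -lam * dev * dt + sigma * np.sqrt(dt) * rng.standard_normal(n_paths)
    return dev.var(), sigma ** 2 / (2 * lam)

for g in (0.0, 1.0, 3.0):
    empirical, bound = ou_stationary_variance(g)
    print(f"g={g}: empirical variance {empirical:.3f}, sigma^2/(2*lambda_eff) = {bound:.3f}")
```

Quadrupling 1 + g cuts the stationary variance by a factor of four, which is the sense in which dynamic conductance rejects disturbances.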

Connection to LSTM Gates

| Aspect | LSTM | DGN |
|---|---|---|
| Forget gate | Sigmoid-controlled | Conductance decay |
| Input gate | Learned weights | Activity-dependent |
| Cell state | Explicit memory | Membrane potential |
| Biological basis | None | Calcium signaling |
| Energy cost | High (MAC ops) | Low (spike events) |

Pitfalls

  • Hyperparameter Sensitivity: Conductance time constant $\tau_g$ requires tuning for task temporal scales
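  • Initial Transient: g starts at 0 and approaches its activity-dependent level with time constant $\tau_g$, so roughly the first $\tau_g$ of each sequence (~100 ms with the default $\tau_g = 100$ ms) runs with weak gating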
  • Computational Cost: ~15% overhead compared to LIF due to conductance update
  • Hardware Constraints: Not all neuromorphic chips support dynamic conductance
  • Gradient Flow: Surrogate gradients may be less stable with dynamic parameters
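One way to keep the initial transient out of a rate-coded readout, sketched here as an assumption rather than something the source prescribes, is to count spikes only after a warm-up window. For a constant target, g reaches 1 - exp(-t/τ_g) of its final level: about 63% after one τ_g and about 95% after three, so the window length trades transient suppression against usable sequence length.

```python
import numpy as np

def rates_after_warmup(spikes, dt_ms, t_warmup_ms):
    """Mean spike rate per neuron, ignoring the first t_warmup_ms.

    spikes: 0/1 array of shape (batch, time, neurons)."""
    n_skip = int(round(t_warmup_ms / dt_ms))
    if n_skip >= spikes.shape[1]:
        raise ValueError("warm-up window is longer than the spike train")
    return spikes[:, n_skip:].mean(axis=1)

tau_g = 100.0
# Fraction of a constant target reached after k * tau_g
print([round(1 - np.exp(-k), 3) for k in (1, 2, 3)])

spikes = np.zeros((1, 500, 2))
spikes[0, :100, 0] = 1.0                       # neuron 0 fires only during warm-up
print(rates_after_warmup(spikes, dt_ms=1.0, t_warmup_ms=tau_g))
```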

Related Skills

  • three-factor-snn-learning
  • cognisnn-brain-inspired-snn
  • adaptive-spiking-neuron-asn
  • spiking-mllm-multimodal-spiking
  • working-memory-heterogeneous-delays

References

@article{bai2025dynamic,
  title={A Brain-Inspired Gating Mechanism Unlocks Robust Computation in Spiking Neural Networks},
  author={Bai, Qianyi and Wang, Haiteng and Yu, Qiang},
  journal={arXiv preprint arXiv:2509.03281},
  year={2025}
}
Install via CLI
npx skills add https://github.com/hiyenwong/ai_collection --skill dynamic-gated-neuron-snn