h2learn-snn-accelerator


High-efficiency hardware accelerator for BPTT-based Spiking Neural Network training. Covers LUT-based processing elements, a dual-sparsity-aware backward engine, and pipeline optimization. Reports 7.38x area saving and up to 10.20x speedup vs an NVIDIA V100 GPU.

By hiyenwong · Updated 6/3/2026


H2Learn: SNN Training Accelerator

Description

H2Learn is a hardware architecture that achieves high efficiency for BPTT-based Spiking Neural Network (SNN) learning while maintaining high accuracy. It exploits binary spike computation and gradient sparsity to build specialized processing engines, and reports 7.38x area saving, 5.74-10.20x speedup, and 5.25-7.12x energy saving compared with an NVIDIA V100 GPU.

Source: arXiv:2107.11746v1 Utility: 0.91

Activation Keywords

  • SNN accelerator
  • BPTT spiking neural network
  • neuromorphic hardware training
  • LUT-based SNN processing
  • dual-sparsity backward engine
  • spike-based gradient computation
  • SNN hardware optimization
  • efficient SNN training

Core Concepts

1. BPTT for SNNs Challenge

Problem:

  • Local synaptic plasticity rules → low accuracy
  • BPTT-based SNN learning → high accuracy but computationally expensive
  • General-purpose processors → low efficiency (ANN-tailored)
  • Neuromorphic chips → typically support only local learning rules, not BPTT

Solution: H2Learn bridges the gap with specialized hardware for BPTT-based SNN training.
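To make the BPTT cost concrete, here is a minimal NumPy sketch of BPTT through one leaky integrate-and-fire (LIF) layer. The soft-reset LIF model, the rectangular surrogate gradient, and the helper name `lif_bptt` are assumptions chosen for illustration, not H2Learn's exact formulation. What it shows is that the membrane potential of every time step has to be stored and then revisited in reverse order.

```python
import numpy as np


def lif_bptt(x, W, grad_s, beta=0.9, theta=1.0, width=0.5):
    """BPTT through one LIF layer (hypothetical helper: soft reset, rectangular surrogate).

    x:      (T, B, n_in) binary input spikes
    W:      (n_in, n_out) weights
    grad_s: (T, B, n_out) upstream gradient dL/ds_t for every output spike
    Returns dL/dW and the list of stored membrane potentials.
    """
    T, B, _ = x.shape
    u = np.zeros((B, W.shape[1]))
    s = np.zeros_like(u)
    stored_u = []  # BPTT must keep the state of every time step

    # Forward in time
    for t in range(T):
        u = beta * u + x[t] @ W - theta * s  # soft reset; the reset path is not differentiated
        s = (u >= theta).astype(float)
        stored_u.append(u)

    # Backward in time: dL/du_t = dL/ds_t * surrogate(u_t) + beta * dL/du_{t+1}
    dW = np.zeros_like(W, dtype=float)
    delta_next = np.zeros_like(u)
    for t in reversed(range(T)):
        surrogate = (np.abs(stored_u[t] - theta) < width).astype(float)  # stands in for dH/du
        delta = grad_s[t] * surrogate + beta * delta_next
        dW += x[t].T @ delta
        delta_next = delta
    return dW, stored_u


rng = np.random.default_rng(0)
x = (rng.random((5, 4, 8)) < 0.3).astype(float)
W = rng.standard_normal((8, 3)) * 0.5
grad_s = rng.standard_normal((5, 4, 3))
dW, stored_u = lif_bptt(x, W, grad_s)
print(dW.shape, len(stored_u))  # (8, 3) 5
```

Memory for `stored_u` grows linearly with the number of time steps, and the backward loop cannot start until the forward loop has finished, which is part of why BPTT is expensive on general-purpose hardware.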

2. SNN BPTT Behavior Analysis

Key Observations:

Phase          | Computation               | Characteristics
Forward pass   | Binary spike-based        | Sparse, binary operations
Backward pass  | Gradient computation      | Rich sparsity in gradients
Weight update  | Spike-based accumulation  | Binary operations

Exploited Properties:

  1. Binary spike computation → LUT-based processing
  2. Gradient sparsity → Dual-sparsity-aware design
  3. Weight update → Implicit accumulation
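Property 1 can be checked directly: with binary inputs, a weighted sum needs no multiplications, only the accumulation of the weight rows selected by the spikes. A small NumPy check (the helper name `spike_accumulate` is hypothetical):

```python
import numpy as np


def spike_accumulate(spikes, W):
    """Weighted sum for binary spikes using additions only (hypothetical helper).

    spikes: (B, n_in) array of 0/1; W: (n_in, n_out).
    """
    out = np.zeros((spikes.shape[0], W.shape[1]))
    for b in range(spikes.shape[0]):
        active = np.flatnonzero(spikes[b])  # indices of inputs that fired
        out[b] = W[active].sum(axis=0)      # add the selected rows, no multiply
    return out


rng = np.random.default_rng(0)
spikes = (rng.random((4, 16)) < 0.2).astype(float)
W = rng.standard_normal((16, 5))
print(np.allclose(spike_accumulate(spikes, W), spikes @ W))  # True
```

The work per sample scales with the number of spikes rather than the number of inputs, which is the sparsity the forward engine can exploit.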

3. Three-Engine Architecture

Engine 1: Forward Engine

Input spikes (binary) → LUT-based PE → Membrane potential → Output spikes

Key features:
- LUT-based processing elements (implicit accumulation)
- Fused computation for multiple input points
- Binary spike operations
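A minimal sketch of how a LUT can make accumulation implicit, assuming the inputs are split into groups of `k`: a group of `k` binary spikes has only `2**k` possible patterns, so every partial weight sum for that group can be precomputed once and then selected by the spike pattern. Reusing the same tables for every input point (batch sample or time step) is one reading of "fused computation for multiple input points". The grouping scheme, `k`, and the function names are assumptions, not the paper's exact PE design.

```python
import numpy as np


def build_group_luts(W, k):
    """For each group of k input rows, tabulate the sum of every subset of rows.

    Returns shape (n_groups, 2**k, n_out); entry [g, p] sums the rows of group g
    whose bit is set in pattern p (bit j <-> row g*k + j). Assumes n_in % k == 0.
    """
    n_in, n_out = W.shape
    groups = W.reshape(n_in // k, k, n_out)
    patterns = (np.arange(2 ** k)[:, None] >> np.arange(k)) & 1  # (2**k, k) bit table
    return np.einsum('pk,gkn->gpn', patterns.astype(W.dtype), groups)


def lut_forward(spikes, luts, k):
    """Weighted sums for binary spikes: one table lookup per group, then add the groups."""
    B, n_in = spikes.shape
    bits = spikes.reshape(B, n_in // k, k).astype(int)
    idx = bits @ (1 << np.arange(k))           # (B, n_groups) pattern index per group
    g = np.arange(n_in // k)
    return luts[g, idx].sum(axis=1)            # luts[g, idx] has shape (B, n_groups, n_out)


rng = np.random.default_rng(1)
W = rng.standard_normal((16, 3))
spikes = (rng.random((6, 16)) < 0.3).astype(float)
luts = build_group_luts(W, k=4)  # built once, reused for all 6 input points
print(np.allclose(lut_forward(spikes, luts, k=4), spikes @ W))  # True
```

Table size grows as `2**k` per group, so `k` trades LUT storage and build cost against the number of additions per output.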

Engine 2: Backward Engine

Gradient input → Dual-sparsity-aware processing → Gradient output

Key features:
- Exploit both input spike sparsity and output gradient sparsity
- Skip zero-gradient computations
- Efficient sparse matrix operations
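The skipping logic can be sketched with index sets: the weight gradient `x^T @ delta` only needs rows where an input spiked and columns where an output gradient is nonzero, and because the inputs are binary each remaining term is an addition. The function name and the addition counter are illustrative assumptions, not H2Learn's actual dataflow.

```python
import numpy as np


def dual_sparse_weight_grad(x, delta):
    """Weight gradient x^T @ delta, skipping silent inputs and zero gradients.

    x: (P, n_in) binary input spikes; delta: (P, n_out) membrane-potential gradients.
    Returns the gradient and the number of additions actually performed.
    """
    dW = np.zeros((x.shape[1], delta.shape[1]))
    adds = 0
    for p in range(x.shape[0]):                   # one input point at a time
        rows = np.flatnonzero(x[p])               # input spike sparsity
        cols = np.flatnonzero(delta[p])           # output gradient sparsity
        dW[np.ix_(rows, cols)] += delta[p, cols]  # accumulate, no multiply
        adds += rows.size * cols.size
    return dW, adds


rng = np.random.default_rng(2)
x = (rng.random((8, 10)) < 0.2).astype(float)
delta = rng.standard_normal((8, 4)) * (rng.random((8, 4)) < 0.3)
dW, adds = dual_sparse_weight_grad(x, delta)
print(np.allclose(dW, x.T @ delta), adds, "of", x.size * delta.shape[1])
```

The dense product would touch every (point, input, output) triple; the count printed next to it shows how many survive both masks.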

Engine 3: Weight Update Engine

Gradient + Spike trace → LUT-based PE → Weight delta

Key features:
- LUT-based processing (similar to Forward Engine)
- Implicit accumulation
- Binary spike-based update
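Because the weight-update step also combines binary spikes with real values (here gradients), a table trick can apply with the roles swapped: tabulate subset sums of the gradient rows of `k` input points (for example `k` time steps), then let each input neuron's spike pattern over those points select one entry. This is a sketch under that assumption; the function name `lut_weight_grad` is hypothetical and the mapping onto H2Learn's weight update engine is not taken from the paper.

```python
import numpy as np


def lut_weight_grad(x, delta, k):
    """dW = x^T @ delta using subset-sum tables over groups of k input points.

    x: (P, n_in) binary spikes; delta: (P, n_out); assumes P is a multiple of k.
    """
    P, n_in = x.shape
    G = P // k
    # For each group of k points, the sum of every subset of their gradient rows
    patterns = (np.arange(2 ** k)[:, None] >> np.arange(k)) & 1  # (2**k, k)
    luts = np.einsum('qk,gkn->gqn', patterns.astype(float), delta.reshape(G, k, -1))
    # Each input neuron's spike pattern over the k points selects one table entry
    bits = x.reshape(G, k, n_in).astype(int).transpose(0, 2, 1)  # (G, n_in, k)
    idx = bits @ (1 << np.arange(k))                              # (G, n_in)
    return luts[np.arange(G)[:, None], idx].sum(axis=0)           # (n_in, n_out)


rng = np.random.default_rng(3)
x = (rng.random((12, 7)) < 0.4).astype(float)
delta = rng.standard_normal((12, 5))
print(np.allclose(lut_weight_grad(x, delta, k=4), x.T @ delta))  # True
```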

4. Pipeline Optimization

End-to-End Pipeline:

Forward Engine → Backward Engine → Weight Update Engine
     ↓              ↓                 ↓
  Pipeline stage 1  Pipeline stage 2  Pipeline stage 3

Optimization:
- Minimize pipeline bubbles
- Overlap computation across layers
- Efficient memory access pattern
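The benefit of overlap can be estimated with the standard flow-shop recurrence: an item (here a layer, handled from last to first during backpropagation) enters a stage once that stage is free and the item has left the previous stage. The two-stage backward/update split and the cycle counts are hypothetical numbers for illustration, not figures from the paper.

```python
def pipelined_makespan(stage_times):
    """Finish time of the last item in a linear pipeline (flow shop, unlimited buffers).

    stage_times[s][i] is the time stage s needs for item i; items enter in order.
    """
    n_stages, n_items = len(stage_times), len(stage_times[0])
    finish = [[0] * n_items for _ in range(n_stages)]
    for s in range(n_stages):
        for i in range(n_items):
            ready = max(finish[s - 1][i] if s > 0 else 0,  # item left the previous stage
                        finish[s][i - 1] if i > 0 else 0)  # stage finished the previous item
            finish[s][i] = ready + stage_times[s][i]
    return finish[-1][-1]


# Hypothetical cycle counts for 4 layers, in backpropagation order (last layer first)
backward = [40, 60, 60, 50]
update = [30, 45, 45, 40]
sequential = sum(backward) + sum(update)
pipelined = pipelined_makespan([backward, update])
print(sequential, pipelined)  # 370 250
```

Here the update of a layer runs while the backward engine moves on to the next layer, so only the last update is left exposed.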

5. Performance Results

Metric          | vs NVIDIA V100 GPU
Area saving     | 7.38x
Speedup         | 5.74-10.20x
Energy saving   | 5.25-7.12x

Benchmark datasets:

  • MNIST
  • CIFAR-10
  • DVS-Gesture
  • N-MNIST

Step-by-Step Instructions

1. LUT-Based Processing Element Design

import numpy as np
from typing import Tuple, Optional

class LUTProcessingElement:
    """
    LUT-based processing element for spike computation (simplified software model).
    
    In this model the LUT quantizes and saturates the membrane potential to
    [-threshold, threshold]; the accumulation of binary input spikes is an
    explicit masked sum. The hardware PE instead makes accumulation implicit
    by looking up precomputed partial sums.
    
    Args:
        lut_size: Number of LUT entries (quantization levels)
        threshold: Spike threshold
    """
    def __init__(self, lut_size: int = 256, threshold: float = 1.0):
        self.lut_size = lut_size
        self.threshold = threshold
        
        # Pre-compute LUT of quantized membrane potentials
        self.lut = self._build_lut()
        
    def _build_lut(self) -> np.ndarray:
        """
        Build a lookup table of evenly spaced membrane-potential levels.
        
        Returns:
            lut: Array of lut_size values spanning [-threshold, threshold]
        """
        return np.linspace(-self.threshold, self.threshold, self.lut_size)
    
    def _lut_index(self, potential):
        """Map potentials in [-threshold, threshold] onto indices 0..lut_size-1 (saturating)."""
        scaled = (potential + self.threshold) / (2 * self.threshold) * (self.lut_size - 1)
        return np.clip(np.rint(scaled), 0, self.lut_size - 1).astype(int)
    
    def forward_pass(
        self,
        input_spikes: np.ndarray,
        weights: np.ndarray,
        membrane_potential: float
    ) -> Tuple[float, float]:
        """
        Single-neuron forward step.
        
        Args:
            input_spikes: Binary input spikes (0 or 1), shape (n_in,)
            weights: Synaptic weights, shape (n_in,)
            membrane_potential: Current membrane potential
        
        Returns:
            new_potential: Updated (quantized, saturated) membrane potential
            output_spike: 1.0 if the neuron fires, else 0.0
        """
        # Binary inputs: the weighted sum is the sum of the weights whose input spiked
        weight_sum = weights[input_spikes.astype(bool)].sum()
        
        # LUT lookup quantizes the updated membrane potential
        new_potential = float(self.lut[self._lut_index(membrane_potential + weight_sum)])
        
        # Spike generation
        output_spike = float(new_potential >= self.threshold)
        
        return new_potential, output_spike
    
    def fused_computation(
        self,
        input_batch: np.ndarray,
        weights: np.ndarray
    ) -> np.ndarray:
        """
        Fused computation for multiple input points that share one weight matrix.
        
        Args:
            input_batch: Binary input spikes, shape (batch, n_in)
            weights: Synaptic weights, shape (n_in, n_out)
        
        Returns:
            membrane_potentials: Quantized potentials, shape (batch, n_out)
        """
        # All input points reuse the same weights, so the lookup is batched
        weight_sums = input_batch @ weights
        return self.lut[self._lut_index(weight_sums)]

2. Dual-Sparsity-Aware Backward Engine

class DualSparsityBackwardEngine:
    """
    Dual-sparsity-aware backward engine (simplified, single time step).
    
    Exploits:
    - Input spike sparsity: input neurons that never spiked contribute
      nothing to the weight gradient, so their rows are skipped
    - Output gradient sparsity: output neurons whose gradient is zero for
      every input point are skipped in both products
    
    Args:
        gradient_threshold: Gradient magnitudes at or below this value are treated as zero
        spike_threshold: Input values at or below this value are treated as "no spike"
    """
    def __init__(
        self,
        gradient_threshold: float = 0.0,
        spike_threshold: float = 0.0
    ):
        self.gradient_threshold = gradient_threshold
        self.spike_threshold = spike_threshold
        
    def backward_pass(
        self,
        gradient_input: np.ndarray,
        weights: np.ndarray,
        spike_traces: np.ndarray
    ) -> Tuple[np.ndarray, np.ndarray]:
        """
        Dual-sparsity-aware backward pass for one layer.
        
        Args:
            gradient_input: Gradient w.r.t. this layer's membrane potential, shape (batch, n_out)
            weights: Synaptic weights, shape (n_in, n_out)
            spike_traces: This layer's input spikes from the forward pass, shape (batch, n_in)
        
        Returns:
            gradient_output: Gradient w.r.t. the layer input, shape (batch, n_in)
            weight_gradients: Gradient w.r.t. the weights, shape (n_in, n_out)
        """
        # Output gradient sparsity: drop tiny gradients, keep columns with any nonzero entry
        grad_mask = np.abs(gradient_input) > self.gradient_threshold
        sparse_grad = np.where(grad_mask, gradient_input, 0.0)
        active_out = grad_mask.any(axis=0)
        
        # Input spike sparsity: keep input neurons that spiked at least once
        spike_mask = spike_traces > self.spike_threshold
        active_in = spike_mask.any(axis=0)
        
        # Gradient to the previous layer: only active output columns take part
        gradient_output = sparse_grad[:, active_out] @ weights[:, active_out].T
        
        # Weight gradient x^T @ delta over (spiking input, active output) pairs only;
        # binary spikes turn each term into an accumulation of gradient values
        weight_gradients = np.zeros(weights.shape)
        weight_gradients[np.ix_(active_in, active_out)] = (
            spike_mask[:, active_in].T.astype(float) @ sparse_grad[:, active_out]
        )
        
        return gradient_output, weight_gradients
    
    def compute_sparsity_ratio(self, gradient_input: np.ndarray) -> float:
        """
        Compute the fraction of gradients treated as zero.
        
        Args:
            gradient_input: Input gradients
        
        Returns:
            sparsity_ratio: Fraction of entries with magnitude <= gradient_threshold
        """
        zero_count = np.sum(np.abs(gradient_input) <= self.gradient_threshold)
        return float(zero_count / gradient_input.size)

3. Weight Update Engine

class WeightUpdateEngine:
    """
    LUT-based weight update engine (simplified software model).
    
    The LUT holds evenly spaced weight steps; each gradient-descent step
    -learning_rate * gradient is rounded to the nearest entry and saturated
    at +/- max_delta.
    
    Args:
        learning_rate: Learning rate for weight update
        lut_size: Number of LUT entries (odd, so that a zero step is representable)
        max_delta: Largest weight change a single update can apply (kept small
            so that typical steps are resolved by the LUT)
    """
    def __init__(self, learning_rate: float = 0.01, lut_size: int = 257, max_delta: float = 0.01):
        self.learning_rate = learning_rate
        self.lut_size = lut_size
        self.max_delta = max_delta
        
        # LUT for weight delta computation
        self.delta_lut = self._build_delta_lut()
        
    def _build_delta_lut(self) -> np.ndarray:
        """
        Build LUT for weight delta computation.
        
        Returns:
            delta_lut: lut_size evenly spaced steps in [-max_delta, max_delta]
        """
        return np.linspace(-self.max_delta, self.max_delta, self.lut_size)
    
    def update_weights(
        self,
        weights: np.ndarray,
        weight_gradients: np.ndarray,
        spike_traces: Optional[np.ndarray] = None
    ) -> np.ndarray:
        """
        LUT-based gradient-descent weight update.
        
        Args:
            weights: Current weights
            weight_gradients: Weight gradients (same shape as weights)
            spike_traces: Unused in this simplified model; kept for interface compatibility
        
        Returns:
            new_weights: Updated weights
        """
        # Descent step (note the minus sign), mapped onto the nearest LUT index
        step = -self.learning_rate * weight_gradients
        scaled = (step + self.max_delta) / (2 * self.max_delta) * (self.lut_size - 1)
        delta_indices = np.clip(np.rint(scaled), 0, self.lut_size - 1).astype(int)
        
        # Apply the quantized step
        return weights + self.delta_lut[delta_indices]

4. Pipeline Controller

class H2LearnPipeline:
    """
    End-to-end pipeline controller for H2Learn.
    
    Coordinates:
    - Forward Engine
    - Backward Engine
    - Weight Update Engine
    
    Args:
        n_layers: Number of layers
        batch_size: Batch size
    """
    def __init__(self, n_layers: int = 3, batch_size: int = 32):
        self.n_layers = n_layers
        self.batch_size = batch_size
        
        # Initialize engines for each layer
        self.forward_engines = [
            LUTProcessingElement() for _ in range(n_layers)
        ]
        self.backward_engines = [
            DualSparsityBackwardEngine() for _ in range(n_layers)
        ]
        self.weight_update_engines = [
            WeightUpdateEngine() for _ in range(n_layers)
        ]
        
        # Pipeline state
        self.pipeline_state = {
            'forward_complete': False,
            'backward_complete': False,
            'update_complete': False
        }
        
    def forward_pipeline(
        self,
        input_batch: np.ndarray,
        weights: list
    ) -> Tuple[list, list]:
        """
        Execute forward pass pipeline.
        
        Args:
            input_batch: Input spike batch
            weights: List of weight matrices
        
        Returns:
            membrane_potentials: List of membrane potentials per layer
            spike_traces: List of spike traces per layer
        """
        membrane_potentials = []
        spike_traces = []
        
        current_input = input_batch
        
        for i, engine in enumerate(self.forward_engines):
            # Forward pass for layer i
            potentials = engine.fused_computation(current_input, weights[i])
            spikes = (potentials >= engine.threshold).astype(float)
            
            membrane_potentials.append(potentials)
            spike_traces.append(spikes)
            
            current_input = spikes
        
        self.pipeline_state['forward_complete'] = True
        
        return membrane_potentials, spike_traces
    
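    def backward_pipeline(
        self,
        gradient_input: np.ndarray,
        weights: list,
        spike_traces: list,
        membrane_potentials: list,
        input_batch: np.ndarray,
        surrogate_width: float = 0.5
    ) -> Tuple[list, list]:
        """
        Execute backward pass pipeline.
        
        Args:
            gradient_input: Gradient of the loss w.r.t. the last layer's output spikes
            weights: List of weight matrices
            spike_traces: Output spikes of every layer from the forward pass
            membrane_potentials: Membrane potentials of every layer from the forward pass
            input_batch: Network input spikes (the input of layer 0)
            surrogate_width: Half-width of the rectangular surrogate gradient window
        
        Returns:
            layer_gradients: Gradient w.r.t. each layer's input, in layer order
            weight_gradients: Gradient w.r.t. each weight matrix, in layer order
        """
        layer_gradients = [None] * self.n_layers
        weight_gradients = [None] * self.n_layers
        
        current_gradient = gradient_input
        
        for i in reversed(range(self.n_layers)):
            # Spikes are non-differentiable: use a rectangular surrogate for d(spike)/d(potential)
            threshold = self.forward_engines[i].threshold
            surrogate = (np.abs(membrane_potentials[i] - threshold) < surrogate_width).astype(float)
            delta = current_gradient * surrogate
            
            # Layer i consumes the previous layer's spikes (layer 0 consumes the network input)
            layer_input = spike_traces[i - 1] if i > 0 else input_batch
            
            grad_out, w_grad = self.backward_engines[i].backward_pass(
                delta, weights[i], layer_input
            )
            
            # Store in layer order so update_pipeline pairs gradients with the right weights
            layer_gradients[i] = grad_out
            weight_gradients[i] = w_grad
            
            current_gradient = grad_out
        
        self.pipeline_state['backward_complete'] = True
        
        return layer_gradients, weight_gradients
    
    def update_pipeline(
        self,
        weights: list,
        weight_gradients: list,
        spike_traces: list
    ) -> list:
        """
        Execute weight update pipeline.
        
        Args:
            weights: Current weights
            weight_gradients: Weight gradients, in layer order
            spike_traces: Spike traces
        
        Returns:
            new_weights: Updated weights
        """
        new_weights = []
        
        for i, engine in enumerate(self.weight_update_engines):
            updated = engine.update_weights(
                weights[i], weight_gradients[i], spike_traces[i]
            )
            new_weights.append(updated)
        
        self.pipeline_state['update_complete'] = True
        
        return new_weights
    
    def full_pipeline(
        self,
        input_batch: np.ndarray,
        weights: list,
        gradient_input: np.ndarray
    ) -> list:
        """
        Execute full training pipeline.
        
        Args:
            input_batch: Input spike batch
            weights: Current weights
            gradient_input: Gradient of the loss w.r.t. the network output spikes
        
        Returns:
            new_weights: Updated weights
        """
        # Forward pass (recomputes potentials and spikes needed by the backward pass)
        potentials, traces = self.forward_pipeline(input_batch, weights)
        
        # Backward pass
        _, weight_grads = self.backward_pipeline(
            gradient_input, weights, traces, potentials, input_batch
        )
        
        # Weight update
        new_weights = self.update_pipeline(weights, weight_grads, traces)
        
        return new_weights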

5. Complete Training Loop

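def h2learn_training_loop(
    train_data: np.ndarray,
    train_labels: np.ndarray,
    n_layers: int = 3,
    epochs: int = 10,
    learning_rate: float = 0.01,
    hidden_size: int = 100,
    batch_size: int = 32,
    seed: int = 0
) -> dict:
    """
    Complete H2Learn training loop (functional model, MSE loss on output spikes).
    
    Args:
        train_data: Training data (binary spikes), shape (n_samples, n_in)
        train_labels: Training targets, shape (n_samples, n_out)
        n_layers: Number of layers
        epochs: Training epochs
        learning_rate: Learning rate passed to every weight update engine
        hidden_size: Width of each hidden layer
        batch_size: Mini-batch size
        seed: Seed for weight initialization
    
    Returns:
        results: Training results
    """
    pipeline = H2LearnPipeline(n_layers, batch_size)
    for engine in pipeline.weight_update_engines:
        engine.learning_rate = learning_rate
    
    # Layer widths input -> hidden ... -> output, so the last layer matches the labels
    sizes = [train_data.shape[1]] + [hidden_size] * (n_layers - 1) + [train_labels.shape[1]]
    rng = np.random.default_rng(seed)
    weights = [
        rng.standard_normal((sizes[i], sizes[i + 1])) * 0.1 for i in range(n_layers)
    ]
    
    losses = []
    
    for epoch in range(epochs):
        epoch_loss = 0.0
        n_batches = 0
        
        for batch_idx in range(0, len(train_data), batch_size):
            batch = train_data[batch_idx:batch_idx + batch_size]
            labels = train_labels[batch_idx:batch_idx + batch_size]
            
            # Forward pass (to compute the loss)
            _, traces = pipeline.forward_pipeline(batch, weights)
            
            # Compute loss (simple MSE on output spikes)
            output = traces[-1]
            loss = np.mean((output - labels) ** 2)
            epoch_loss += loss
            n_batches += 1
            
            # Gradient of the MSE w.r.t. the output spikes
            gradient = 2 * (output - labels) / labels.size
            
            # Forward + backward + update (full_pipeline re-runs the forward pass)
            weights = pipeline.full_pipeline(batch, weights, gradient)
        
        losses.append(epoch_loss / n_batches)
        print(f"Epoch {epoch}: Loss = {losses[-1]:.4f}")
    
    results = {
        'final_loss': losses[-1],
        'loss_history': losses,
        'final_weights': weights
    }
    
    return results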

Tools Used

  • numpy - Numerical computations
  • typing - Type annotations
  • exec - Run simulation scripts

Example Use Cases

1. Basic Forward Pass

# Create LUT-based PE
pe = LUTProcessingElement(lut_size=256, threshold=1.0)

# Binary input spikes
input_spikes = np.array([1, 0, 1, 1, 0])
weights = np.random.randn(5) * 0.1

# Forward pass
membrane, output = pe.forward_pass(input_spikes, weights, 0.0)
print(f"Output spike: {output}")

2. Sparsity Analysis

# Analyze gradient sparsity
backward_engine = DualSparsityBackwardEngine()
gradient = np.random.randn(100) * 0.1
gradient[gradient < 0.05] = 0  # Make sparse

sparsity = backward_engine.compute_sparsity_ratio(gradient)
print(f"Sparsity ratio: {sparsity:.2%}")

3. Full Pipeline Training

# Generate synthetic data
train_data = np.random.randint(0, 2, (1000, 100)).astype(float)
train_labels = np.random.randint(0, 2, (1000, 10)).astype(float)

# Train with H2Learn pipeline
results = h2learn_training_loop(train_data, train_labels, epochs=10)
print(f"Final loss: {results['final_loss']}")

Hardware Implementation Notes

Area Optimization:

  • LUT-based PE → contributes to the overall 7.38x area saving vs GPU
  • Implicit accumulation → Reduced hardware complexity
  • Fused computation → Fewer memory accesses
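A rough addition count helps reason about when the LUT path pays off. The sketch below is a simplified model and an assumption, not the paper's cost analysis: direct accumulation adds one weight per spiking input per output, while a grouped subset-sum LUT with group size `k` pays about `2**k - 1` additions per group and output to build its tables once, then about one addition per group, output, and input point. Lookups and memory traffic are ignored, and the function name is hypothetical.

```python
def accumulation_adds(n_in, n_out, n_points, k, spike_rate):
    """Approximate addition counts for one layer (hypothetical cost model).

    Returns (direct, lut): direct spike-driven accumulation vs grouped-LUT accumulation.
    Assumes n_in is a multiple of k.
    """
    n_groups = n_in // k
    # Direct: every spiking input contributes one addition per output neuron
    direct = n_points * n_out * spike_rate * n_in
    # LUT: build the tables once (about one add per non-empty subset), then combine one entry per group
    build = n_groups * (2 ** k - 1) * n_out
    combine = n_points * n_out * n_groups
    return direct, build + combine


direct, lut = accumulation_adds(n_in=256, n_out=128, n_points=256, k=4, spike_rate=0.5)
print(int(direct), lut)  # 4194304 2220032
```

Under this model the LUT path needs fewer additions per point when `spike_rate * k > 1`, and the one-time build cost shrinks per point as more input points are fused onto the same tables; at low spike rates plain sparse accumulation is cheaper.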

Speedup Optimization:

  • Pipeline overlap → contributes to the overall 5.74-10.20x speedup
  • Dual-sparsity skip → Reduced computation
  • LUT lookup → one table lookup replaces the additions for a whole group of binary inputs

Energy Optimization:

  • Sparse computation → contributes to the overall 5.25-7.12x energy saving
  • Binary operations → Reduced power consumption
  • Pipeline efficiency → Minimized idle cycles

Instructions for Agents

Follow these steps when applying this skill:

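Step 1: LUT-Based Processing Element Design
Step 2: Dual-Sparsity-Aware Backward Engine
Step 3: Weight Update Engine
Step 4: Pipeline Controller
Step 5: Complete Training Loop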

Examples

Example 1: Basic Application

User: I need to apply H2Learn: SNN Training Accelerator to my analysis.

Agent: I'll help you apply h2learn-snn-accelerator. First, let me understand your specific use case...

Context: Apply the methodology

Example 2: Advanced Scenario

User: Complex analysis scenario

Agent: Based on the methodology, I'll guide you through the advanced application...

Example 3: Advanced Application

User: What are the key considerations for h2learn-snn-accelerator?

Agent: Let me search for the latest research and best practices...

Related Skills

  • spikingjelly-framework - SpikingJelly framework
  • multi-plasticity-snn-training - Multi-plasticity SNN training
  • decolle-snn-learning - DECOLLE SNN learning

References

  • Liang, L. et al. (2021). "H2Learn: High-Efficiency Learning Accelerator for High-Accuracy Spiking Neural Networks" arXiv:2107.11746v1 [cs.NE]

Created: 2026-03-30 00:05 Author: Aerial (from arXiv:2107.11746v1)

Install via CLI
npx skills add https://github.com/hiyenwong/ai_collection --skill h2learn-snn-accelerator
Repository Details
Branch: main
Path: SKILL.md