bsvit-burst-spiking-vision-transformer


BSViT: Burst Spiking Vision Transformer with Dual-Channel Burst Spiking Self-Attention for expressive and efficient visual representation learning. Addresses binary spike capacity limits with burst encoding and local attention masking. Activation: bsvit, burst spiking, vision transformer, spiking vit, attention masking, neuromorphic.

By hiyenwong, updated 6/3/2026


BSViT: Burst Spiking Vision Transformer

A Burst Spiking Vision Transformer (BSViT) featuring Dual-Channel Burst Spiking Self-Attention (DBSSA) that addresses the restricted information capacity of binary spike coding through burst encoding and patch adjacency masking.

Metadata

  • Source: arXiv:2604.23165v1
  • Authors: Hongxiang Peng, Dewei Bai, Hong Qu, Zhanghui Kuang, Jian Sun, Xinghao Chen, Yunhe Wang
  • Published: 2026-04-25
  • Category: cs.CV, cs.NE

Core Methodology

Problem Statement

Spiking Vision Transformers (S-ViTs) face two fundamental limitations:

  1. Restricted Information Capacity: Binary spike coding (1 bit per timestep) limits information throughput
  2. Dense Token Interactions: Global self-attention introduces excessive computation and spike activity

Key Innovation

BSViT introduces Dual-Channel Burst Spiking Self-Attention (DBSSA):

  1. Dual-Channel Encoding:

    • Query: Binary spikes (Q ∈ {0, 1})
    • Key: Burst spikes (K ∈ {0, 1, 2, ..., B_max})
    • Value: Dual excitatory/inhibitory binary channels (V ∈ {-1, 0, +1})
  2. Addition-Only Computation: The entire attention operation uses only additions and subtractions (no multiplications), ensuring compatibility with neuromorphic hardware

  3. Patch Adjacency Masking: Restricts attention to spatially adjacent patches, reducing spike activity and incorporating spatial priors

Technical Framework

1. Dual-Channel Burst Spiking Self-Attention

Traditional Attention:
Attention(Q, K, V) = softmax(QK^T / √d) · V

BSViT DBSSA:
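- Query Q ∈ {0, 1}^(N×d)            (binary spikes)
- Key K ∈ {0, 1, ..., B_max}^(N×d)  (burst spikes)
- Value V ∈ {-1, 0, +1}^(N×d)       (signed binary)

Attention Score (Q_ik is 0 or 1, so each term either adds the burst count K_jk or skips it):
S_ij = Σ_k Q_ik · K_jk = Σ_{k : Q_ik = 1} K_jk

Output (V_jk is -1, 0 or +1, so each term adds S_ij, subtracts it, or skips it):
O_ik = Σ_j S_ij · V_jk = Σ_{j : V_jk = +1} S_ij - Σ_{j : V_jk = -1} S_ij

Note: No multiplication is required, only addition and subtraction. With patch adjacency masking, j ranges only over the spatial neighbors of patch i.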
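This is a small numeric check of the two identities in plain Python. The toy Q, K and V values are made up for illustration and do not come from the paper. Each quantity is computed twice: once with products and once using only selective additions and subtractions. The two results agree.

```python
# Toy sizes: N = 3 tokens, d = 4 channels (illustrative values only)
Q = [[1, 0, 1, 1], [0, 1, 0, 0], [1, 1, 0, 1]]      # binary spikes {0, 1}
K = [[2, 0, 1, 3], [0, 4, 1, 0], [1, 1, 0, 2]]      # burst counts {0, ..., B_max}
V = [[1, 0, -1, 1], [-1, 1, 0, 0], [0, -1, 1, 1]]   # signed spikes {-1, 0, +1}
N, d = len(Q), len(Q[0])

def scores_multiply(Q, K):
    # Reference: S_ij = sum_k Q_ik * K_jk
    return [[sum(Q[i][k] * K[j][k] for k in range(d)) for j in range(N)] for i in range(N)]

def scores_add_only(Q, K):
    # Addition-only: add K_jk wherever Q_ik fired
    S = [[0] * N for _ in range(N)]
    for i in range(N):
        for j in range(N):
            for k in range(d):
                if Q[i][k]:
                    S[i][j] += K[j][k]
    return S

def output_multiply(S, V):
    # Reference: O_ik = sum_j S_ij * V_jk
    return [[sum(S[i][j] * V[j][k] for j in range(N)) for k in range(d)] for i in range(N)]

def output_add_only(S, V):
    # Addition-only: add S_ij for +1, subtract it for -1, skip 0
    O = [[0] * d for _ in range(N)]
    for i in range(N):
        for j in range(N):
            for k in range(d):
                if V[j][k] == 1:
                    O[i][k] += S[i][j]
                elif V[j][k] == -1:
                    O[i][k] -= S[i][j]
    return O

S = scores_add_only(Q, K)
O = output_add_only(S, V)
assert S == scores_multiply(Q, K) and O == output_multiply(S, V)
print(S)  # [[6, 1, 3], [0, 4, 1], [5, 4, 4]]
print(O)  # [[5, -2, -3, 9], [-4, 3, 1, 1], [1, 0, -1, 9]]
```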

2. Spike Encoding

Binary Spike Encoding (Query):

import torch

def binary_spike(u, theta=1.0):
    """
    u: membrane potential tensor
    theta: firing threshold
    Returns 1.0 where u >= theta, else 0.0
    """
    return (u >= theta).float()

Burst Spike Encoding (Key):

def burst_spike(u, theta=1.0, B_max=8):
    """
    Burst firing: emit floor(u / theta) spikes in one timestep,
    clipped to the range [0, B_max]
    """
    burst_count = torch.floor(u / theta).clamp(0, B_max)
    return burst_count.int()

Signed Binary Encoding (Value):

def signed_binary_spike(u, theta_pos=1.0, theta_neg=-1.0):
    """
    Excitatory/Inhibitory channels
    """
    excitatory = (u >= theta_pos).float()
    inhibitory = (u <= theta_neg).float()
    return excitatory - inhibitory  # ∈ {-1, 0, +1}
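The three encoders can also be mirrored on plain Python floats, which makes their value sets easy to see without any dependencies. These scalar versions are illustrative stand-ins for the tensor functions and use the same default thresholds. The last two lines quantify the capacity gap. With B_max = 8, a burst count has 9 possible levels, which is about 3.17 bits per timestep and needs 4 bits of storage. A binary spike carries 1 bit.

```python
import math

def binary_spike(u, theta=1.0):
    # 1 if the membrane potential reaches threshold, else 0
    return 1 if u >= theta else 0

def burst_spike(u, theta=1.0, B_max=8):
    # Number of thresholds crossed, clipped to [0, B_max]
    return int(min(max(math.floor(u / theta), 0), B_max))

def signed_binary_spike(u, theta_pos=1.0, theta_neg=-1.0):
    # +1 (excitatory), -1 (inhibitory) or 0
    return int(u >= theta_pos) - int(u <= theta_neg)

potentials = [-2.5, -0.4, 0.7, 1.3, 3.9, 12.0]
print([binary_spike(u) for u in potentials])         # [0, 0, 0, 1, 1, 1]
print([burst_spike(u) for u in potentials])          # [0, 0, 0, 1, 3, 8]
print([signed_binary_spike(u) for u in potentials])  # [-1, 0, 0, 1, 1, 1]

B_max = 8
print(round(math.log2(B_max + 1), 2))    # 3.17 bits of capacity per burst count
print(math.ceil(math.log2(B_max + 1)))   # 4 bits of storage per burst count
```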

3. Patch Adjacency Masking

Standard Vision Transformer:
┌─────────────────────────────────┐
│  P1 P2 P3 P4                    │
│  P5 P6 P7 P8    ← All patches attend to all others
│  P9 P10 P11 P12                 │
└─────────────────────────────────┘

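BSViT with Patch Adjacency (3×3 window; [ ] marks the patches P6 attends to):
┌─────────────────────────────────┐
│ [P1] [P2] [P3]  P4              │
│ [P5] [P6] [P7]  P8    ← P6 attends only to its neighbors
│ [P9] [P10][P11] P12             │
└─────────────────────────────────┘

Adjacency is defined by spatial proximity on the patch grid. Each patch attends to the patches inside an adjacency_window × adjacency_window neighborhood centered on it (3 × 3 by default), clipped at the grid border.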
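The mask can be built with the same window logic as create_adjacency_mask. It is sketched here in plain Python, and adjacency_mask is a hypothetical helper name. For the 3 × 4 grid in the diagram, it selects P6's 3 × 3 neighborhood. A 224 × 224 image with 16 × 16 patches gives a 14 × 14 grid, and there the mask keeps 1,600 of the 38,416 patch pairs, about 4.2%.

```python
def adjacency_mask(H, W, window=3):
    """N x N boolean mask (N = H*W): patch a may attend to patch b when b lies
    inside the window x window neighborhood of a, clipped at the grid border."""
    r = window // 2
    N = H * W
    mask = [[False] * N for _ in range(N)]
    for i in range(H):
        for j in range(W):
            for ii in range(max(0, i - r), min(H, i + r + 1)):
                for jj in range(max(0, j - r), min(W, j + r + 1)):
                    mask[i * W + j][ii * W + jj] = True
    return mask

# 3 x 4 grid from the diagram: P6 sits at row 1, column 1 (index 5)
small = adjacency_mask(3, 4)
print([f"P{n + 1}" for n in range(12) if small[5][n]])
# ['P1', 'P2', 'P3', 'P5', 'P6', 'P7', 'P9', 'P10', 'P11']

# 14 x 14 patch grid (224 x 224 image, 16 x 16 patches)
mask = adjacency_mask(14, 14)
kept = sum(map(sum, mask))
print(kept, 196 * 196, f"{kept / 196**2:.1%}")  # 1600 38416 4.2%
```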

Implementation Guide

Prerequisites

  • Python 3.8+
  • PyTorch 1.10+
  • SpikingJelly or custom SNN framework
  • Understanding of Vision Transformer architecture

Step-by-Step Implementation

Step 1: Burst Spiking Neuron

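import torch
import torch.nn as nn
import torch.nn.functional as F

class BurstLIFNeuron(nn.Module):
    """
    Leaky Integrate-and-Fire neuron with three output modes:
      'binary' -> {0, 1}, 'burst' -> {0, ..., burst_max}, 'signed' -> {-1, 0, +1}
    Discrete values are emitted in the forward pass. Gradients flow through a
    straight-through surrogate, which is an assumed choice: the hard threshold
    and floor alone have zero gradient. Call reset() between independent inputs.
    """
    def __init__(self, tau=20.0, v_th=1.0, v_reset=0.0,
                 burst_max=8, spike_type='binary', alpha=4.0):
        super().__init__()
        if spike_type not in ('binary', 'burst', 'signed'):
            raise ValueError(f"unknown spike_type: {spike_type!r}")
        self.tau = tau
        self.v_th = v_th
        self.v_reset = v_reset
        self.burst_max = burst_max
        self.spike_type = spike_type
        self.alpha = alpha  # surrogate sharpness (assumed hyperparameter)

        # Membrane potential: runtime state, kept out of the state_dict
        self.register_buffer('v', None, persistent=False)

    @staticmethod
    def _ste(hard, soft):
        # Forward value of `hard`, backward gradient of `soft`
        return soft + (hard - soft).detach()

    def forward(self, x):
        """
        Args:
            x: input current [batch, neurons]
        Returns:
            float spike tensor whose values depend on spike_type
        """
        if self.v is None or self.v.shape != x.shape:
            self.v = torch.zeros_like(x)

        # Update membrane potential (leaky integration toward x)
        v = self.v + (x - self.v) / self.tau

        if self.spike_type == 'binary':
            hard = (v >= self.v_th).float()
            spike = self._ste(hard, torch.sigmoid(self.alpha * (v - self.v_th)))
            # Hard reset where the neuron fired
            self.v = v * (1 - hard) + self.v_reset * hard
            return spike

        elif self.spike_type == 'burst':
            # Burst count = thresholds crossed, capped at burst_max
            soft = torch.clamp(v / self.v_th, 0, self.burst_max)
            count = torch.floor(soft)
            # Subtractive reset: remove the charge spent on the emitted spikes
            self.v = v - count * self.v_th
            return self._ste(count, soft)  # float, so autograd and einsum work

        else:  # 'signed'
            exc = (v >= self.v_th).float()    # excitatory channel
            inh = (v <= -self.v_th).float()   # inhibitory channel
            spike = self._ste(exc - inh, torch.tanh(v / self.v_th))
            # Reset neurons that fired on either channel
            fired = exc + inh
            self.v = v * (1 - fired) + self.v_reset * fired
            return spike

    def reset(self):
        self.v = None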
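The binary and burst modes diverge once a neuron is driven over several timesteps. This plain-Python sketch mirrors the hard reset of the binary branch and the subtractive reset of the burst branch. simulate is a hypothetical helper, and tau = 2.0 and the constant input of 5.0 are illustrative values, not taken from the paper. With the same input, the burst neuron emits 2 spikes per step where the binary neuron emits 1. The sub-threshold remainder also carries over to the next step instead of being discarded.

```python
def simulate(x, steps, mode, tau=2.0, v_th=1.0, v_reset=0.0, burst_max=8):
    """Spike counts emitted at each timestep for a constant input current x."""
    v, out = 0.0, []
    for _ in range(steps):
        v = v + (x - v) / tau                 # leaky integration
        if mode == "binary":
            s = 1 if v >= v_th else 0
            v = v_reset if s else v           # hard reset after a spike
        else:                                 # "burst"
            s = min(max(int(v // v_th), 0), burst_max)
            v = v - s * v_th                  # subtract only the emitted charge
        out.append(s)
    return out

print(simulate(5.0, 4, "binary"))  # [1, 1, 1, 1]
print(simulate(5.0, 4, "burst"))   # [2, 2, 2, 2]
```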

Step 2: Dual-Channel Burst Spiking Self-Attention

class DBSSA(nn.Module):
    """
    Dual-Channel Burst Spiking Self-Attention
    """
    def __init__(self, dim, num_heads=8, burst_max=8, 
                 adjacency_window=3):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.burst_max = burst_max
        self.adjacency_window = adjacency_window
        
        # Linear projections (can be spiking or standard)
        self.q_linear = nn.Linear(dim, dim)
        self.k_linear = nn.Linear(dim, dim)
        self.v_linear = nn.Linear(dim, dim)
        
        # Spiking neurons for each pathway
        self.q_neuron = BurstLIFNeuron(
            spike_type='binary', burst_max=1
        )
        self.k_neuron = BurstLIFNeuron(
            spike_type='burst', burst_max=burst_max
        )
        self.v_neuron = BurstLIFNeuron(
            spike_type='signed', burst_max=1
        )
        
        self.out_proj = nn.Linear(dim, dim)
        
    def create_adjacency_mask(self, H, W, device):
        """
        Create spatial adjacency mask for patches
        
        Args:
            H, W: Grid dimensions (e.g., 8x8 for 64 patches)
        Returns:
            mask: [H*W, H*W] boolean mask
        """
        N = H * W
        mask = torch.zeros(N, N, dtype=torch.bool, device=device)
        
        for i in range(H):
            for j in range(W):
                idx = i * W + j
                # Define local window
                i_min = max(0, i - self.adjacency_window//2)
                i_max = min(H, i + self.adjacency_window//2 + 1)
                j_min = max(0, j - self.adjacency_window//2)
                j_max = min(W, j + self.adjacency_window//2 + 1)
                
                for ii in range(i_min, i_max):
                    for jj in range(j_min, j_max):
                        neighbor_idx = ii * W + jj
                        mask[idx, neighbor_idx] = True
        
        return mask
    
    def forward(self, x, H, W):
        """
        Args:
            x: [batch, N, dim] where N = H*W patches
            H, W: spatial dimensions of patch grid
        Returns:
            out: [batch, N, dim]
        """
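        B, N, _ = x.shape

        # Each call is one independent timestep: clear membrane state so that
        # batch-size changes and graphs from earlier iterations are not reused
        for neuron in (self.q_neuron, self.k_neuron, self.v_neuron):
            neuron.reset()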
        
        # Linear projections
        q = self.q_linear(x)  # [B, N, dim]
        k = self.k_linear(x)
        v = self.v_linear(x)
        
        # Reshape for multi-head
        q = q.reshape(B, N, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
        k = k.reshape(B, N, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
        v = v.reshape(B, N, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
        
        # Apply spiking neurons
        q_spike = self.q_neuron(q.reshape(-1, self.head_dim)).reshape(B, self.num_heads, N, self.head_dim)
        k_burst = self.k_neuron(k.reshape(-1, self.head_dim)).reshape(B, self.num_heads, N, self.head_dim)
        v_signed = self.v_neuron(v.reshape(-1, self.head_dim)).reshape(B, self.num_heads, N, self.head_dim)
        
        # Addition-only attention, reference version with explicit loops (slow;
        # EfficientDBSSA in Step 4 computes the same scores with einsum).
        # S[i, j] = sum_d q[i, d] * k[j, d]: q is binary, so each product just
        # selects or skips a burst count.
        attn_weights = torch.zeros(B, self.num_heads, N, N, device=x.device)

        for h in range(self.num_heads):
            for i in range(N):
                for j in range(N):
                    # q_spike: binary {0, 1}, k_burst: {0, ..., burst_max}
                    match = q_spike[:, h, i, :] * k_burst[:, h, j, :]
                    attn_weights[:, h, i, j] = match.sum(dim=-1)
        
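        # Apply adjacency mask. BSViT prepends a cls token, so N = H*W + 1.
        # Assumption: extra leading tokens attend to, and are attended by, all tokens.
        patch_mask = self.create_adjacency_mask(H, W, x.device)
        n_extra = N - H * W
        mask = torch.ones(N, N, dtype=torch.bool, device=x.device)
        mask[n_extra:, n_extra:] = patch_mask
        mask = mask.unsqueeze(0).unsqueeze(0)  # [1, 1, N, N]
        attn_weights = attn_weights.masked_fill(~mask, 0)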
        
        # Output = S · V (signed accumulation)
        out = torch.zeros(B, self.num_heads, N, self.head_dim, device=x.device)
        for h in range(self.num_heads):
            for i in range(N):
                for j in range(N):
                    if mask[0, 0, i, j]:
                        # Accumulate signed values
                        out[:, h, i, :] += attn_weights[:, h, i, j].unsqueeze(-1) * v_signed[:, h, j, :]
        
        # Reshape and project
        out = out.permute(0, 2, 1, 3).reshape(B, N, self.dim)
        out = self.out_proj(out)
        
        return out
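DBSSA.forward materializes a dense N × N score tensor and then zeroes the entries outside the mask. An event-driven implementation can instead visit only each patch's window. The plain-Python sketch below does this for one head and one sample, ignoring the cls token for brevity. local_dbssa_head and dense_masked_reference are hypothetical names. The window-only version uses only additions and subtractions and matches the dense, masked computation on random spike inputs.

```python
import random

def local_dbssa_head(q, k, v, H, W, window=3):
    """q: binary, k: burst counts, v: signed {-1, 0, +1}; each is H*W rows of d ints."""
    r, d = window // 2, len(q[0])
    out = []
    for i in range(H):
        for j in range(W):
            qi = q[i * W + j]
            o = [0] * d
            # Visit only the neighbors inside the border-clipped window
            for ii in range(max(0, i - r), min(H, i + r + 1)):
                for jj in range(max(0, j - r), min(W, j + r + 1)):
                    n = ii * W + jj
                    s = sum(kc for qc, kc in zip(qi, k[n]) if qc)  # S_ij by addition
                    for c, vc in enumerate(v[n]):
                        if vc == 1:
                            o[c] += s      # excitatory: add
                        elif vc == -1:
                            o[c] -= s      # inhibitory: subtract
            out.append(o)
    return out

def dense_masked_reference(q, k, v, H, W, window=3):
    """Dense S = q k^T, zeroed outside the window, then O = S v (uses products)."""
    N, d, r = H * W, len(q[0]), window // 2
    def near(a, b):
        return abs(a // W - b // W) <= r and abs(a % W - b % W) <= r
    S = [[sum(x * y for x, y in zip(q[a], k[b])) if near(a, b) else 0
          for b in range(N)] for a in range(N)]
    return [[sum(S[a][b] * v[b][c] for b in range(N)) for c in range(d)]
            for a in range(N)]

rng = random.Random(0)
H, W, d = 4, 5, 6
q = [[rng.randint(0, 1) for _ in range(d)] for _ in range(H * W)]
k = [[rng.randint(0, 8) for _ in range(d)] for _ in range(H * W)]
v = [[rng.randint(-1, 1) for _ in range(d)] for _ in range(H * W)]
assert local_dbssa_head(q, k, v, H, W) == dense_masked_reference(q, k, v, H, W)
```

The work per patch scales with window² rather than with N, which is the spatial-efficiency argument behind patch adjacency masking.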

Step 3: BSViT Architecture

class BSViTBlock(nn.Module):
    """
    BSViT Transformer Block with DBSSA
    """
    def __init__(self, dim, num_heads, mlp_ratio=4, burst_max=8):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn = DBSSA(dim, num_heads, burst_max)
        self.norm2 = nn.LayerNorm(dim)
        
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = nn.Sequential(
            nn.Linear(dim, mlp_hidden_dim),
            nn.GELU(),
            nn.Linear(mlp_hidden_dim, dim)
        )
        
    def forward(self, x, H, W):
        # Attention with residual
        x = x + self.attn(self.norm1(x), H, W)
        # MLP with residual
        x = x + self.mlp(self.norm2(x))
        return x

class BSViT(nn.Module):
    """
    Burst Spiking Vision Transformer
    """
    def __init__(self, img_size=224, patch_size=16, in_chans=3,
                 num_classes=1000, embed_dim=768, depth=12,
                 num_heads=12, mlp_ratio=4, burst_max=8):
        super().__init__()
        self.patch_size = patch_size
        self.num_patches = (img_size // patch_size) ** 2
        
        # Patch embedding
        self.patch_embed = nn.Conv2d(
            in_chans, embed_dim, 
            kernel_size=patch_size, stride=patch_size
        )
        
        # Position embedding
        self.pos_embed = nn.Parameter(
            torch.zeros(1, self.num_patches + 1, embed_dim)
        )
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        
        # Transformer blocks
        self.blocks = nn.ModuleList([
            BSViTBlock(embed_dim, num_heads, mlp_ratio, burst_max)
            for _ in range(depth)
        ])
        
        self.norm = nn.LayerNorm(embed_dim)
        self.head = nn.Linear(embed_dim, num_classes)
        
    def forward(self, x):
        B = x.shape[0]
        
        # Patch embedding
        x = self.patch_embed(x)  # [B, embed_dim, H', W']
        H, W = x.shape[2], x.shape[3]
        x = x.flatten(2).transpose(1, 2)  # [B, N, embed_dim]
        
        # Add cls token
        cls_tokens = self.cls_token.expand(B, -1, -1)
        x = torch.cat([cls_tokens, x], dim=1)
        
        # Add position embedding
        x = x + self.pos_embed
        
        # Apply transformer blocks
        for block in self.blocks:
            x = block(x, H, W)
        
        # Classification
        x = self.norm(x)
        x = x[:, 0]  # cls token
        x = self.head(x)
        
        return x

Step 4: Efficient Addition-Only Implementation

class EfficientDBSSA(DBSSA):
    """
    Optimized DBSSA with vectorized operations
    """
    def forward(self, x, H, W):
        B, N, _ = x.shape
        
        # Projections and spiking
        q = self.q_linear(x)
        k = self.k_linear(x)
        v = self.v_linear(x)
        
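        # Stateless spiking activations (one timestep). Discrete values in the
        # forward pass, straight-through surrogate gradients in the backward pass
        # (an assumed choice; hard thresholds alone give zero gradient).
        def ste(hard, soft):
            return soft + (hard - soft).detach()

        q_spike = ste((q >= 0).float(), torch.sigmoid(q))        # binary {0, 1}
        k_soft = torch.clamp(torch.relu(k), 0, self.burst_max)
        k_burst = ste(torch.floor(k_soft), k_soft)               # burst {0, ..., burst_max}
        v_signed = ste((v >= 1.0).float() - (v <= -1.0).float(),
                       torch.tanh(v))                            # signed {-1, 0, +1}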
        
        # Reshape for heads
        q_spike = q_spike.reshape(B, N, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
        k_burst = k_burst.reshape(B, N, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
        v_signed = v_signed.reshape(B, N, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
        
        # Vectorized attention: S[i, j] = sum_d q[i, d] * k[j, d]
        # einsum runs dense multiply-accumulates on a GPU; the addition-only form
        # applies when the same computation runs on event-driven hardware.
        attn = torch.einsum('bhid,bhjd->bhij', q_spike, k_burst)
        
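        # Apply adjacency mask, letting extra leading (cls) tokens attend globally
        patch_mask = self.create_adjacency_mask(H, W, x.device)
        n_extra = N - H * W
        mask = torch.ones(N, N, dtype=torch.bool, device=x.device)
        mask[n_extra:, n_extra:] = patch_mask
        mask = mask.unsqueeze(0).unsqueeze(0)
        attn = attn.masked_fill(~mask, 0)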
        
        # Output: O[i] = sum_j S[i,j] * v[j]
        out = torch.einsum('bhij,bhjd->bhid', attn, v_signed)
        
        # Reshape
        out = out.permute(0, 2, 1, 3).reshape(B, N, self.dim)
        out = self.out_proj(out)
        
        return out

Training Pipeline

def train_bsvit(model, train_loader, val_loader, epochs=100, device='cuda'):
    """
    Training pipeline for BSViT
    """
    model = model.to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=0.05)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, epochs)
    criterion = nn.CrossEntropyLoss()
    
    for epoch in range(epochs):
        model.train()
        train_loss = 0.0
        train_acc = 0.0
        
        for batch_idx, (data, target) in enumerate(train_loader):
            data, target = data.to(device), target.to(device)
            
            optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()
            
            train_loss += loss.item()
            train_acc += (output.argmax(dim=1) == target).float().mean().item()
        
        scheduler.step()
        
        # Validation
        model.eval()
        val_acc = 0.0
        with torch.no_grad():
            for data, target in val_loader:
                data, target = data.to(device), target.to(device)
                output = model(data)
                val_acc += (output.argmax(dim=1) == target).float().mean().item()
        
        print(f"Epoch {epoch}: Train Loss={train_loss/len(train_loader):.4f}, "
              f"Train Acc={train_acc/len(train_loader):.2%}, "
              f"Val Acc={val_acc/len(val_loader):.2%}")

# Example usage
if __name__ == "__main__":
    model = BSViT(
        img_size=224,
        patch_size=16,
        num_classes=1000,
        embed_dim=768,
        depth=12,
        num_heads=12,
        burst_max=8
    )
    
    # Count parameters
    total_params = sum(p.numel() for p in model.parameters())
    print(f"Total parameters: {total_params / 1e6:.2f}M")
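The spiking neurons hold only a membrane-potential buffer and no learnable weights, so the count printed by the example can be cross-checked by hand. The plain-Python tally below uses hypothetical helper names. It assumes the modules exactly as defined in Steps 1 to 3 with the default ViT-Base settings, and it gives 86,567,656 parameters, the same as a standard ViT-B/16.

```python
def linear(n_in, n_out):
    return n_in * n_out + n_out  # weight + bias

def bsvit_param_count(img_size=224, patch_size=16, in_chans=3,
                      num_classes=1000, embed_dim=768, depth=12, mlp_ratio=4):
    D = embed_dim
    hidden = int(D * mlp_ratio)
    n_patches = (img_size // patch_size) ** 2
    patch_embed = in_chans * patch_size * patch_size * D + D  # Conv2d weight + bias
    pos_and_cls = (n_patches + 1) * D + D                     # pos_embed + cls_token
    layer_norm = 2 * D                                        # gain + bias
    # DBSSA: q, k, v and output projections; spiking neurons add no parameters
    block = 2 * layer_norm + 4 * linear(D, D) + linear(D, hidden) + linear(hidden, D)
    head = layer_norm + linear(D, num_classes)                # final norm + classifier
    return patch_embed + pos_and_cls + depth * block + head

total = bsvit_param_count()
print(f"Total parameters: {total / 1e6:.2f}M")  # Total parameters: 86.57M
```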

Benchmarks

ImageNet Classification

Model                 Top-1 Acc   Top-5 Acc   Energy (J)   Spikes/Image
S-ViT (baseline)      72.1%       90.5%       2.45         1,250K
Spike-driven ViT      74.3%       91.8%       1.82         890K
BSViT                 77.8%       93.4%       0.95         420K
ANN ViT (reference)   79.2%       94.5%       8.50         N/A
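The relative savings implied by the ImageNet table can be recomputed directly. The snippet reuses only the numbers listed in the table and adds no new measurements.

```python
# Energy (J) and spikes per image, copied from the ImageNet table
energy = {"S-ViT (baseline)": 2.45, "Spike-driven ViT": 1.82, "BSViT": 0.95}
spikes = {"S-ViT (baseline)": 1_250_000, "Spike-driven ViT": 890_000, "BSViT": 420_000}

savings = {}
for name in ("S-ViT (baseline)", "Spike-driven ViT"):
    # Fractional reduction of BSViT relative to each spiking baseline
    savings[name] = (1 - energy["BSViT"] / energy[name],
                     1 - spikes["BSViT"] / spikes[name])
    print(f"vs {name}: {savings[name][0]:.1%} less energy, "
          f"{savings[name][1]:.1%} fewer spikes")
# vs S-ViT (baseline): 61.2% less energy, 66.4% fewer spikes
# vs Spike-driven ViT: 47.8% less energy, 52.8% fewer spikes
```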

CIFAR-10/100

Model               CIFAR-10   CIFAR-100   Spikes/Image
Spiking ResNet-18   93.2%      70.5%       320K
S-ViT-Ti            94.1%      73.8%       450K
BSViT-Ti            96.3%      78.2%       180K

Energy Efficiency

Component                Effect
Attention Computation    65% fewer spikes
Information Capacity     Up to 8 spikes per timestep vs 1 (B_max = 8: 9 levels, about 3.17 bits vs 1 bit)
Memory Access            40% reduction
Hardware Compatibility   Addition-only ✅

Applications

1. Edge Vision Systems

  • Low-power image classification
  • Event-based camera processing
  • Mobile visual AI

2. Neuromorphic Robotics

  • Real-time visual perception
  • Energy-constrained navigation
  • Autonomous systems

3. Surveillance Systems

  • Always-on monitoring
  • Anomaly detection
  • Person identification

4. IoT Devices

  • Battery-powered cameras
  • Smart sensors
  • Wearable devices

Advantages

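  • Higher Information Capacity: A burst count carries up to log2(B_max + 1) bits per timestep, versus 1 bit for a binary spike
  • Addition-Only: Compatible with neuromorphic hardware
  • Spatial Efficiency: Adjacency masking restricts each patch to a local window, reducing computation
  • Strong Performance: Within 1.4 points Top-1 of the ANN ViT reference on ImageNet (77.8% vs 79.2%)
  • Energy Efficient: About 61% less energy than the S-ViT baseline (0.95 J vs 2.45 J) and about 48% less than Spike-driven ViT (0.95 J vs 1.82 J) in the ImageNet table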

Pitfalls

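  1. Training Stability: Threshold and floor operations are non-differentiable, so burst spiking needs surrogate gradients and can still produce unstable gradients
  2. Hardware Support: Limited neuromorphic hardware supports burst encoding
  3. Hyperparameter Sensitivity: burst_max and adjacency_window need tuning
  4. Memory Overhead: A burst count needs ⌈log2(B_max + 1)⌉ bits (4 bits for B_max = 8) instead of the 1 bit of a binary spike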

Related Skills

  • qb-lif-quantized-burst-neurons
  • spiking-neural-network-analysis
  • neuromorphic-hardware-design
  • snn-fpga-hardware-software-codesign
  • energy-efficient-snn

References

@article{peng2026bsvit,
  title={BSViT: A Burst Spiking Vision Transformer for Expressive and Efficient Visual Representation Learning},
  author={Peng, Hongxiang and Bai, Dewei and Qu, Hong and Kuang, Zhanghui and Sun, Jian and Chen, Xinghao and Wang, Yunhe},
  journal={arXiv preprint arXiv:2604.23165},
  year={2026}
}

Activation Triggers

  • bsvit, burst spiking
  • vision transformer, spiking vit
  • dual-channel attention
  • patch adjacency masking
  • addition-only computation
  • neuromorphic vision
Install via CLI
npx skills add https://github.com/hiyenwong/ai_collection --skill bsvit-burst-spiking-vision-transformer