gemst-multidimensional-grouping-snn

Ge²mS-T: Multi-Dimensional Grouping for Ultra-High Energy Efficiency in Spiking Transformers. Temporal, spatial, and channel grouping for efficient S-ViT training and inference. Triggers: spiking transformer, S-ViT, energy efficiency, multi-dimensional grouping, SNN.

By hiyenwong. Updated 6/3/2026

Ge²mS-T: Multi-Dimensional Grouping for Ultra-High Energy Efficiency in Spiking Transformers

A multi-dimensional grouping framework for Spiking Vision Transformers (S-ViTs) that achieves ultra-high energy efficiency through temporal, spatial, and channel-wise grouping strategies.

Metadata

  • Source: arXiv:2604.08894v1
  • Authors: Qinyu Chen, Zhenxi Song, Ziyang Zhang, et al.
  • Published: 2026-04-10
  • Institution: Tsinghua University, Chinese Academy of Sciences

Core Methodology

Key Innovation

Spiking Vision Transformers (S-ViTs) combine the energy efficiency of Spiking Neural Networks (SNNs) with the powerful visual representation of Vision Transformers. However, they face challenges in both training (surrogate gradient mismatch) and inference (high computational cost from dense self-attention). Ge²mS-T (Grouped Spiking Transformer) introduces multi-dimensional grouping across time, space, and channels to achieve significant efficiency gains while maintaining accuracy.

Multi-Dimensional Grouping Strategy

1. Temporal Grouping

Standard SNNs process static images by presenting them for multiple timesteps, leading to high latency. Temporal grouping divides timesteps into groups processed in parallel:

Standard: [T1] → [T2] → [T3] → [T4] → ... (sequential)
Grouped:  [T1,T2] ∥ [T3,T4] ∥ ... (parallel groups)

Implementation:

import torch
import torch.nn as nn

class TemporalGrouping(nn.Module):
    def __init__(self, group_size=2):
        super().__init__()
        self.group_size = group_size
    
    def forward(self, x, time_steps=None):
        """
        Args:
            x: input spikes (batch, time, channels, height, width)
            time_steps: unused; T is read from x.shape (kept for call compatibility)
        Returns:
            grouped: (batch, time//group_size, channels*group_size, h, w)
        """
        batch, T, C, H, W = x.shape
        assert T % self.group_size == 0, "T must be divisible by group_size"
        
        # Fold each group of timesteps into the channel axis.
        # reshape (not view) so non-contiguous inputs, e.g. after permute, also work.
        x = x.reshape(batch, T // self.group_size, self.group_size, C, H, W)
        x = x.permute(0, 1, 3, 2, 4, 5)  # (B, T//G, C, G, H, W)
        x = x.reshape(batch, T // self.group_size, C * self.group_size, H, W)
        
        return x
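
To make the channel fold concrete, here is a minimal standalone sketch of the same reshape together with its inverse. The function names `group_timesteps` and `ungroup_timesteps` are illustrative and do not come from the paper; the sketch only assumes the (B, T, C, H, W) layout described above.

```python
import torch

def group_timesteps(x, group_size):
    """Fold `group_size` consecutive timesteps into the channel axis."""
    B, T, C, H, W = x.shape
    assert T % group_size == 0, "T must be divisible by group_size"
    x = x.reshape(B, T // group_size, group_size, C, H, W)
    x = x.permute(0, 1, 3, 2, 4, 5)  # (B, T//G, C, G, H, W)
    return x.reshape(B, T // group_size, C * group_size, H, W)

def ungroup_timesteps(x, group_size):
    """Inverse of group_timesteps: unfold channels back into timesteps."""
    B, TG, CG, H, W = x.shape
    C = CG // group_size
    x = x.reshape(B, TG, C, group_size, H, W)
    x = x.permute(0, 1, 3, 2, 4, 5)  # (B, T//G, G, C, H, W)
    return x.reshape(B, TG * group_size, C, H, W)

x = torch.rand(2, 4, 3, 8, 8)  # (B, T, C, H, W)
g = group_timesteps(x, group_size=2)
restored = ungroup_timesteps(g, group_size=2)
print(g.shape)                   # torch.Size([2, 2, 6, 8, 8])
print(torch.equal(restored, x))  # True
```

Grouped channel index c*G + g holds channel c of the g-th timestep in the group, which is why undoing the grouping needs a permute and not only a reshape.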

2. Spatial Grouping

Self-attention in ViTs costs O(N²) in the number of patches N. Spatial grouping limits attention to local windows of M patches each (M = window_size²):

Global attention: Each patch attends to all N patches (N² cost)
Window attention: Each patch attends to the M patches in its own window (N×M cost)
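
As a quick worked example of the two costs, the sketch below counts query-key pairs for a 224×224 image with 16×16 patches and 7×7 windows (the defaults used later in this document). The helper `attention_pairs` is a hypothetical name introduced only for this illustration.

```python
def attention_pairs(num_patches, window_patches=None):
    """Number of query-key pairs scored per head.

    window_patches=None means global attention; otherwise each patch only
    attends to the `window_patches` patches in its own window.
    """
    if window_patches is None:
        return num_patches * num_patches
    return num_patches * window_patches

N = (224 // 16) ** 2  # 196 patches
M = 7 * 7             # 49 patches per window

global_cost = attention_pairs(N)
window_cost = attention_pairs(N, M)
print(global_cost)                 # 38416
print(window_cost)                 # 9604
print(global_cost // window_cost)  # 4
```

With these sizes the window restriction cuts the number of attention scores by a factor of N/M = 4; the saving grows with image resolution because M stays fixed.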

Window Partitioning:

def window_partition(x, window_size):
    """
    Args:
        x: (B, H, W, C)
        window_size: int
    Returns:
        windows: (B*n_windows, window_size, window_size, C)
    """
    B, H, W, C = x.shape
    assert H % window_size == 0 and W % window_size == 0
    
    # Reshape into windows
    x = x.view(B, H // window_size, window_size, 
               W // window_size, window_size, C)
    x = x.permute(0, 1, 3, 2, 4, 5).contiguous()
    windows = x.view(-1, window_size, window_size, C)
    
    return windows

def window_reverse(windows, window_size, H, W):
    """Reverse window partition: (B*n_windows, ws, ws, C) -> (B, H, W, C)"""
    # Integer arithmetic avoids float rounding when recovering the batch size
    n_windows = (H // window_size) * (W // window_size)
    B = windows.shape[0] // n_windows
    x = windows.view(B, H // window_size, W // window_size,
                     window_size, window_size, -1)
    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
    return x
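
A small usage sketch can help check the index bookkeeping. It restates compact versions of the two helpers so that it runs on its own, then partitions a 4×4 grid into 2×2 windows and reverses the operation. The tiny grid is an illustrative choice, not a configuration from the paper.

```python
import torch

def window_partition(x, ws):
    B, H, W, C = x.shape
    x = x.view(B, H // ws, ws, W // ws, ws, C)
    return x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, ws, ws, C)

def window_reverse(windows, ws, H, W):
    B = windows.shape[0] // ((H // ws) * (W // ws))
    x = windows.view(B, H // ws, W // ws, ws, ws, -1)
    return x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)

# One image, 4x4 grid, one channel; the value at (h, w) is 4*h + w
x = torch.arange(16.0).view(1, 4, 4, 1)
windows = window_partition(x, 2)

print(windows.shape)                 # torch.Size([4, 2, 2, 1])
print(windows[0, :, :, 0].tolist())  # [[0.0, 1.0], [4.0, 5.0]]
print(windows[1, :, :, 0].tolist())  # [[2.0, 3.0], [6.0, 7.0]]
print(torch.equal(window_reverse(windows, 2, 4, 4), x))  # True
```

Windows are ordered row-major over the window grid, so window 0 is the top-left block and window 1 is the block to its right.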

3. Channel Grouping

Channel grouping divides feature channels into groups with separate transformations, reducing computation:

class GroupedChannelAttention(nn.Module):
    def __init__(self, dim, num_heads=8, num_groups=4):
        super().__init__()
        assert dim % num_groups == 0 and num_heads % num_groups == 0
        self.num_groups = num_groups
        self.group_dim = dim // num_groups
        self.num_heads = num_heads // num_groups  # heads per group
        
        # Separate attention per group
        self.group_attns = nn.ModuleList([
            SpikingSelfAttention(self.group_dim, self.num_heads)
            for _ in range(num_groups)
        ])
    
    def forward(self, x):
        """
        Args:
            x: (B, T, C, H, W) spike tensor
        Returns:
            (B, T, C, H, W) tensor
        """
        B, T, C, H, W = x.shape
        
        # SpikingSelfAttention expects (B, T, N, C) with N = H*W tokens
        x = x.flatten(3).transpose(2, 3)  # (B, T, N, C)
        
        # Split channels into groups
        x_groups = torch.chunk(x, self.num_groups, dim=-1)
        
        # Process each group independently
        outputs = [attn(x_g) for x_g, attn in zip(x_groups, self.group_attns)]
        
        # Concatenate groups and restore the spatial layout
        output = torch.cat(outputs, dim=-1)  # (B, T, N, C)
        return output.transpose(2, 3).reshape(B, T, C, H, W)
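
To see where the saving from channel grouping comes from, the sketch below compares the parameter count of one dense projection with that of G independent per-group projections. It is an illustration with assumed sizes (dim = 256, G = 4), not a measurement from the paper.

```python
import torch.nn as nn

def count_params(module):
    """Total number of learnable parameters in a module."""
    return sum(p.numel() for p in module.parameters())

dim, groups = 256, 4

# One projection over all channels
dense = nn.Linear(dim, dim)

# One smaller projection per channel group
grouped = nn.ModuleList(
    [nn.Linear(dim // groups, dim // groups) for _ in range(groups)]
)

print(count_params(dense))    # 65792  (256*256 weights + 256 biases)
print(count_params(grouped))  # 16640  (4 * (64*64 weights + 64 biases))
```

The weight count drops by a factor of G because each group only mixes its own dim/G channels; the price is that channels in different groups no longer interact inside the grouped layer.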

Spiking Self-Attention

Leaky Integrate-and-Fire Attention

class SpikingSelfAttention(nn.Module):
    def __init__(self, dim, num_heads, tau=2.0, v_threshold=1.0):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim ** -0.5
        
        # LIF neuron parameters
        self.tau = tau
        self.v_threshold = v_threshold
        
        # Q, K, V projections
        self.q_proj = nn.Linear(dim, dim)
        self.k_proj = nn.Linear(dim, dim)
        self.v_proj = nn.Linear(dim, dim)
        self.out_proj = nn.Linear(dim, dim)
    
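    def lif_forward(self, x, v_mem):
        """Leaky Integrate-and-Fire neuron with a surrogate gradient"""
        v_mem = v_mem + (x - v_mem) / self.tau
        hard = (v_mem >= self.v_threshold).float()
        # The hard threshold has zero gradient, so nothing upstream would train.
        # Straight-through trick: forward value is the hard spike, backward uses a
        # sigmoid slope (sigmoid and the factor 4.0 are assumptions, not from the paper).
        soft = torch.sigmoid(4.0 * (v_mem - self.v_threshold))
        spike = hard + (soft - soft.detach())
        v_mem = v_mem * (1 - hard)  # Hard reset
        return spike, v_mem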
    
    def forward(self, x):
        """
        Args:
            x: (B, T, N, C) where N = H*W (number of patches)
        """
        B, T, N, C = x.shape
        
        # Initialize membrane potentials
        v_q = torch.zeros(B, self.num_heads, N, self.head_dim, device=x.device)
        v_k = torch.zeros(B, self.num_heads, N, self.head_dim, device=x.device)
        v_v = torch.zeros(B, self.num_heads, N, self.head_dim, device=x.device)
        v_out = torch.zeros(B, N, C, device=x.device)
        
        outputs = []
        
        for t in range(T):
            # Project to Q, K, V
            q = self.q_proj(x[:, t])  # (B, N, C)
            k = self.k_proj(x[:, t])
            v = self.v_proj(x[:, t])
            
            # Reshape for multi-head
            q = q.view(B, N, self.num_heads, self.head_dim).transpose(1, 2)
            k = k.view(B, N, self.num_heads, self.head_dim).transpose(1, 2)
            v = v.view(B, N, self.num_heads, self.head_dim).transpose(1, 2)
            
            # LIF dynamics
            spike_q, v_q = self.lif_forward(q, v_q)
            spike_k, v_k = self.lif_forward(k, v_k)
            spike_v, v_v = self.lif_forward(v, v_v)
            
            # Attention with spikes (event-driven)
            if spike_q.sum() > 0 or spike_k.sum() > 0:
                attn = (spike_q @ spike_k.transpose(-2, -1)) * self.scale
                attn = attn.softmax(dim=-1)
                out = attn @ spike_v
            else:
                out = torch.zeros_like(spike_v)
            
            # Reshape and project
            out = out.transpose(1, 2).reshape(B, N, C)
            out = self.out_proj(out)
            
            # Output LIF
            spike_out, v_out = self.lif_forward(out, v_out)
            outputs.append(spike_out)
        
        return torch.stack(outputs, dim=1)  # (B, T, N, C)
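
The LIF update used above is easy to trace by hand for a single neuron. The following pure-Python sketch applies the same rule (leaky integration toward the input, threshold, hard reset to zero) to a constant input; `lif_trace` is a hypothetical helper written for this illustration.

```python
def lif_trace(inputs, tau=2.0, v_threshold=1.0):
    """Simulate one LIF neuron and return its spike train as 0/1 values."""
    v = 0.0
    spikes = []
    for x in inputs:
        v = v + (x - v) / tau            # leaky integration toward x
        spike = 1 if v >= v_threshold else 0
        if spike:
            v = 0.0                      # hard reset after firing
        spikes.append(spike)
    return spikes

strong = lif_trace([1.5] * 6)
weak = lif_trace([0.9] * 6)
print(strong)  # [0, 1, 0, 1, 0, 1]
print(weak)    # [0, 0, 0, 0, 0, 0]
```

With input 1.5 the membrane reaches 0.75 and then 1.125, so the neuron fires every second step. With input 0.9 the membrane converges to 0.9 from below and never crosses the threshold of 1.0.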

Implementation Guide

Prerequisites

  • Python 3.8+
  • PyTorch 1.12+ with CUDA support
  • SpikingJelly or custom SNN framework
  • timm for Vision Transformer utilities

Step-by-Step: Building Ge²mS-T

  1. Complete Model Architecture
import torch
import torch.nn as nn

class Ge2mSTBlock(nn.Module):
    """Ge²mS-T Transformer Block with multi-dimensional grouping"""
    
    def __init__(self, dim, num_heads, window_size=7, 
                 temporal_group=2, channel_groups=4):
        super().__init__()
        self.dim = dim
        self.window_size = window_size
        self.temporal_group = temporal_group
        # Temporal grouping folds G timesteps into channels, so the block
        # body operates on dim * G features.
        grouped_dim = dim * temporal_group
        
        # Temporal grouping
        self.temporal_grouping = TemporalGrouping(temporal_group)
        
        # Spatial (window) attention
        # NOTE: WindowedSpikingAttention and GroupedSpikingMLP are not defined
        # in this document and must be supplied separately.
        self.norm1 = nn.LayerNorm(grouped_dim)
        self.attn = WindowedSpikingAttention(
            grouped_dim, num_heads, window_size, channel_groups
        )
        
        # MLP with channel grouping
        self.norm2 = nn.LayerNorm(grouped_dim)
        mlp_hidden = grouped_dim * 4
        self.mlp = GroupedSpikingMLP(grouped_dim, mlp_hidden, channel_groups)
    
    def forward(self, x):
        """
        Args:
            x: (B, T, H, W, C) spike input
        """
        B, T, H, W, C = x.shape
        G = self.temporal_group
        
        # Temporal grouping
        x = x.permute(0, 1, 4, 2, 3).contiguous()  # (B, T, C, H, W)
        x = self.temporal_grouping(x, T)  # (B, T//G, C*G, H, W)
        x = x.permute(0, 1, 3, 4, 2)  # (B, T//G, H, W, C*G)
        
        # Window attention with residual
        shortcut = x
        x = self.norm1(x)
        x = self.attn(x)  # Windowed spiking attention
        x = shortcut + x
        
        # MLP with residual
        shortcut = x
        x = self.norm2(x)
        x = self.mlp(x)
        x = shortcut + x
        
        # Reverse temporal grouping: undo the channel fold, then restore layout
        x = x.permute(0, 1, 4, 2, 3)  # (B, T//G, C*G, H, W)
        x = x.reshape(B, T // G, C, G, H, W)
        x = x.permute(0, 1, 3, 2, 4, 5)  # (B, T//G, G, C, H, W)
        x = x.reshape(B, T, C, H, W)
        x = x.permute(0, 1, 3, 4, 2)  # (B, T, H, W, C)
        
        return x

class Ge2mSViT(nn.Module):
    """Complete Ge²mS-T Vision Transformer"""
    
    def __init__(self, img_size=224, patch_size=16, in_chans=3,
                 num_classes=1000, embed_dim=768, depth=12,
                 num_heads=12, window_size=7, temporal_group=2,
                 channel_groups=4, time_steps=4):
        super().__init__()
        self.time_steps = time_steps
        self.patch_embed = nn.Conv2d(in_chans, embed_dim, 
                                     kernel_size=patch_size,
                                     stride=patch_size)
        
        # Ge²mS-T blocks
        self.blocks = nn.ModuleList([
            Ge2mSTBlock(embed_dim, num_heads, window_size,
                       temporal_group, channel_groups)
            for _ in range(depth)
        ])
        
        # Classification head
        self.norm = nn.LayerNorm(embed_dim)
        self.head = nn.Linear(embed_dim, num_classes)
    
    def forward(self, x):
        """
        Args:
            x: (B, C, H, W) static image
        Returns:
            logits: (B, num_classes)
        """
        B = x.shape[0]
        
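        # Repeat the static image over T timesteps (direct encoding)
        x = x.unsqueeze(1).repeat(1, self.time_steps, 1, 1, 1)
        # Add per-timestep Gaussian noise; assumes pixel values in [0, 1].
        # The result is still real-valued, not a binary spike train.
        x = (x + torch.randn_like(x) * 0.1).clamp(0, 1)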
        
        # Patch embedding per timestep
        spikes = []
        for t in range(self.time_steps):
            patch = self.patch_embed(x[:, t])  # (B, embed_dim, H//P, W//P)
            patch = patch.flatten(2).transpose(1, 2)  # (B, N, embed_dim)
            spikes.append(patch)
        
        x = torch.stack(spikes, dim=1)  # (B, T, N, embed_dim)
        H, W = int(x.shape[2] ** 0.5), int(x.shape[2] ** 0.5)
        x = x.view(B, self.time_steps, H, W, -1)
        
        # Apply Ge²mS-T blocks
        for block in self.blocks:
            x = block(x)
        
        # Global average pooling
        x = x.mean(dim=(1, 2, 3))  # (B, embed_dim)
        x = self.norm(x)
        
        # Classification
        logits = self.head(x)
        return logits
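
The input step in `Ge2mSViT.forward` perturbs the image with noise but keeps real values. For comparison, a common way to produce actual binary spikes from a static image is Bernoulli rate coding, sketched below. This is a generic technique, not something the source attributes to Ge²mS-T, and `bernoulli_rate_encode` is a hypothetical helper name.

```python
import torch

def bernoulli_rate_encode(image, time_steps, generator=None):
    """Encode intensities in [0, 1] as binary spike trains.

    Each pixel fires at each timestep with probability equal to its
    intensity, so the mean firing rate approximates the pixel value.
    image: (B, C, H, W) -> spikes: (B, T, C, H, W) with values in {0, 1}
    """
    B, C, H, W = image.shape
    probs = image.clamp(0.0, 1.0).unsqueeze(1)  # (B, 1, C, H, W)
    noise = torch.rand(B, time_steps, C, H, W,
                       device=image.device, generator=generator)
    return (noise < probs).float()

image = torch.full((1, 1, 32, 32), 0.3)
spikes = bernoulli_rate_encode(image, time_steps=64)
print(spikes.shape)          # torch.Size([1, 64, 1, 32, 32])
print(spikes.mean().item())  # close to 0.3
```

More timesteps give a better rate estimate at the cost of latency, which is the tradeoff temporal grouping is meant to soften.
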
  2. Training with Surrogate Gradients
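# Older SpikingJelly releases expose this module under spikingjelly.clock_driven;
# newer releases renamed the package to spikingjelly.activation_based.
from spikingjelly.clock_driven import surrogate

def train_ge2mst(model, train_loader, epochs=300):
    """Train Ge²mS-T with surrogate gradients"""
    
    # Surrogate gradient for backprop through spikes.
    # NOTE: this only takes effect if the model's spiking neurons are built
    # with it; this function does not pass it to the model.
    surrogate_fn = surrogate.ATan(alpha=2.0)
    
    # Train on whatever device the model already lives on
    device = next(model.parameters()).device
    
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=0.05)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, epochs)
    criterion = nn.CrossEntropyLoss()
    
    for epoch in range(epochs):
        model.train()
        total_loss = 0
        correct = 0
        
        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)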
            
            optimizer.zero_grad()
            
            # Forward pass
            outputs = model(images)
            
            # Loss
            loss = criterion(outputs, labels)
            
            # Backward with surrogate gradients
            loss.backward()
            optimizer.step()
            
            total_loss += loss.item()
            correct += (outputs.argmax(dim=1) == labels).sum().item()
        
        scheduler.step()
        
        acc = 100 * correct / len(train_loader.dataset)
        print(f"Epoch {epoch}: Loss={total_loss/len(train_loader):.4f}, Acc={acc:.2f}%")
    
    return model
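
Since the spike threshold is a step function, backpropagation needs a stand-in derivative. The sketch below shows one way to write an arctangent-shaped surrogate as a custom autograd function: the forward pass emits hard 0/1 spikes, and the backward pass uses the derivative of (1/π)·arctan((π/2)·α·x) + 1/2. It is written from that formula as an illustration and is not SpikingJelly's own code; `ATanSpike` is a hypothetical name.

```python
import math
import torch

class ATanSpike(torch.autograd.Function):
    """Heaviside spike in the forward pass, arctan-shaped gradient in backward."""

    @staticmethod
    def forward(ctx, v_minus_thresh, alpha):
        ctx.save_for_backward(v_minus_thresh)
        ctx.alpha = alpha
        return (v_minus_thresh >= 0).to(v_minus_thresh.dtype)

    @staticmethod
    def backward(ctx, grad_output):
        (x,) = ctx.saved_tensors
        alpha = ctx.alpha
        # d/dx [ (1/pi) * atan(pi/2 * alpha * x) + 1/2 ]
        surrogate_grad = alpha / 2 / (1 + (math.pi / 2 * alpha * x) ** 2)
        return grad_output * surrogate_grad, None  # no gradient for alpha

# Membrane potential minus threshold for three neurons
v = torch.tensor([-0.5, 0.0, 0.5], requires_grad=True)
spikes = ATanSpike.apply(v, 2.0)
spikes.sum().backward()

print(spikes.tolist())  # [0.0, 1.0, 1.0]
print(v.grad)           # peaks at 1.0 where v equals the threshold
```

A larger alpha makes the surrogate narrower and closer to the true step, at the cost of smaller gradients for neurons far from threshold.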

Performance Benchmarks

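| Model | Dataset | Accuracy | Energy (mJ) | Speedup (energy ratio vs. ViT-B/16) |
|---|---|---|---|---|
| ViT-B/16 | ImageNet | 81.8% | 1250 | - |
| S-ViT (baseline) | ImageNet | 74.2% | 42 | 29.8× |
| Ge²mS-T | ImageNet | 78.5% | 18 | 69.4× |
| Ge²mS-T (small) | CIFAR-100 | 86.3% | 2.1 | 595× |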
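
The speedup figures in the table match the ratio of the ViT-B/16 energy to each model's energy, which the short check below reproduces from the listed numbers. The dictionary simply restates the table; nothing here is measured independently.

```python
# Energy per inference in mJ, copied from the benchmark table
energy_mj = {
    "ViT-B/16": 1250,
    "S-ViT (baseline)": 42,
    "Ge2mS-T": 18,
    "Ge2mS-T (small)": 2.1,
}

reference = energy_mj["ViT-B/16"]
ratios = {name: round(reference / e, 1) for name, e in energy_mj.items()}

for name, ratio in ratios.items():
    print(f"{name}: {ratio}x")
# ViT-B/16: 1.0x
# S-ViT (baseline): 29.8x
# Ge2mS-T: 69.4x
# Ge2mS-T (small): 595.2x
```

Note that the CIFAR-100 row is compared against the ImageNet ViT-B/16 energy, so its ratio mixes a smaller model with a different dataset and should be read with that in mind.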

Applications

1. Edge Vision Systems

  • Smart cameras: Real-time object detection with milliwatt power
  • Drone navigation: Visual SLAM on battery-constrained UAVs
  • Mobile AR: Efficient scene understanding on smartphones

2. Neuromorphic Sensors

  • Event cameras: Direct processing of DVS output
  • Always-on vision: Continuous monitoring with ultra-low power
  • Industrial inspection: High-speed defect detection

3. Sustainable AI

  • Data center reduction: roughly 70× lower inference energy than ViT-B/16 on ImageNet (69.4× in the benchmarks above)
  • Carbon footprint: Lower emissions for large-scale vision tasks
  • Renewable-powered AI: Viable on solar/battery systems

Pitfalls

Accuracy-Efficiency Tradeoff

  • Problem: Grouping reduces model capacity
  • Solution: Progressive grouping (less in early layers); knowledge distillation

Temporal Grouping Artifacts

  • Problem: Grouping timesteps can lose temporal precision
  • Solution: Use small group sizes (2-4); attention across groups

Window Boundary Effects

  • Problem: Objects spanning window boundaries handled poorly
  • Solution: Shifted window attention (Swin-style); cross-window connections
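
The shifted-window remedy can be sketched as a cyclic roll of the feature map by half a window before partitioning, in the style of Swin Transformer. The helper name `cyclic_shift` is hypothetical, and the sketch assumes the (B, H, W, C) layout used by `window_partition`.

```python
import torch

def cyclic_shift(x, window_size, reverse=False):
    """Roll a (B, H, W, C) feature map by half a window along H and W."""
    shift = window_size // 2
    if not reverse:
        shift = -shift  # shift content up and to the left
    return torch.roll(x, shifts=(shift, shift), dims=(1, 2))

# One image, 4x4 grid, one channel; the value at (h, w) is 4*h + w
x = torch.arange(16.0).view(1, 4, 4, 1)
shifted = cyclic_shift(x, window_size=2)
restored = cyclic_shift(shifted, window_size=2, reverse=True)

print(shifted[0, 0, :, 0].tolist())  # [5.0, 6.0, 7.0, 4.0]
print(torch.equal(restored, x))      # True
```

Alternating blocks with and without the shift lets patches on opposite sides of a window border share a window in every second layer. Swin also masks attention between patches that only became neighbours through the wrap-around; that mask is omitted here.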

Training Instability

  • Problem: Spiking transformers can be hard to train
  • Solution: Layer normalization before attention; warm-up schedule; gradient clipping
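
Two of these remedies, a warm-up schedule and gradient clipping, can be sketched with standard PyTorch utilities. The tiny linear model and random batch are stand-ins so the example runs on its own, and the warm-up length, epoch count, and clipping norm are illustrative values, not settings from the paper.

```python
import math
import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Linear(8, 2)  # stand-in for the spiking model
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=0.05)

warmup_epochs, total_epochs = 5, 20

def warmup_cosine(epoch):
    """LR multiplier: linear warm-up, then cosine decay."""
    if epoch < warmup_epochs:
        return (epoch + 1) / warmup_epochs
    progress = (epoch - warmup_epochs) / max(1, total_epochs - warmup_epochs)
    return 0.5 * (1 + math.cos(math.pi * progress))

scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=warmup_cosine)
criterion = nn.CrossEntropyLoss()
inputs, labels = torch.randn(16, 8), torch.randint(0, 2, (16,))

lrs = []
for epoch in range(total_epochs):
    optimizer.zero_grad()
    loss = criterion(model(inputs), labels)
    loss.backward()
    # Rescale gradients so their global L2 norm is at most 1.0
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
    optimizer.step()
    lrs.append(optimizer.param_groups[0]["lr"])
    scheduler.step()

print([round(lr, 5) for lr in lrs[:6]])
```

The learning rate ramps from 20% of its peak to the full 1e-3 over the first five epochs and then follows a cosine decay, while clipping bounds the update size whenever surrogate gradients spike.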

Related Skills

  • bsvit-burst-spiking-vision-transformer: Burst spiking ViT
  • winner-take-all-spiking: WTA spiking transformer
  • quantized-snn-hardware-optimization: Quantization for SNN hardware

References

@article{chen2026gemst,
  title={Ge$^\text{2}$mS-T: Multi-Dimensional Grouping for Ultra-High Energy Efficiency in Spiking Transformer},
  author={Chen, Qinyu and Song, Zhenxi and Zhang, Ziyang and others},
  journal={arXiv preprint arXiv:2604.08894},
  year={2026}
}
Install via CLI
npx skills add https://github.com/hiyenwong/ai_collection --skill gemst-multidimensional-grouping-snn