grouped-spiking-transformer-gemmst

star 1

分组脉冲Transformer (Ge²mS-T) - 多维分组策略实现超高能效。将脉冲神经网络应用于Transformer架构,通过分组自注意力降低计算复杂度。适用于边缘设备部署和神经形态计算。激活: spiking transformer, grouped attention, energy efficient, neuromorphic computing

hiyenwong By hiyenwong schedule Updated 6/3/2026

name: grouped-spiking-transformer-gemmst description: "分组脉冲Transformer (Ge²mS-T) - 多维分组策略实现超高能效。将脉冲神经网络应用于Transformer架构,通过分组自注意力降低计算复杂度。适用于边缘设备部署和神经形态计算。激活: spiking transformer, grouped attention, energy efficient, neuromorphic computing" arxiv: "2604.08894" date: "2026-04-10" category: neuromorphic-computing tags: ["spiking-transformer", "grouped-attention", "energy-efficiency", "neuromorphic", "vision-transformer"]

Ge²mS-T: 分组脉冲Transformer

论文信息

  • 标题: Ge²mS-T: Multi-Dimensional Grouping for Ultra-High Energy Efficiency in Spiking Transformer
  • 作者: Zecheng Hao, Shenghao Xie, Kang Chen
  • arXiv ID: 2604.08894
  • 发布日期: 2026-04-10
  • 类别: cs.NE, cs.AI

核心创新

1. 分组脉冲自注意力 (Grouped Spiking Self-Attention)

将标准自注意力的O(n²)复杂度降低到O(n²/g),其中g为组大小。

2. 跨组信息交换

通过精心设计的信息交换机制保持表示能力。

3. 脉冲梯度估计

专门为分组脉冲操作设计的梯度估计方法。

技术规格

指标 数值
能效提升 4.2×
准确率损失 <1%
神经形态硬件性能 15.3 TOPS/W
复杂度 O(n²/g)

架构设计

分组注意力机制

class GroupedSpikingAttention(nn.Module):
    """分组脉冲自注意力"""
    
    def __init__(self, dim, num_heads=8, num_groups=4):
        super().__init__()
        self.num_groups = num_groups
        self.group_dim = dim // num_groups
        
        # 每组独立的注意力
        self.group_attns = nn.ModuleList([
            SpikingSelfAttention(self.group_dim, num_heads)
            for _ in range(num_groups)
        ])
        
        # 跨组信息交换
        self.cross_group_exchange = CrossGroupExchange(dim, num_groups)
    
    def forward(self, x):
        B, N, D = x.shape
        
        # 分组处理
        x_groups = x.view(B, N, self.num_groups, self.group_dim)
        group_outputs = []
        
        for g in range(self.num_groups):
            group_out = self.group_attns[g](x_groups[:, :, g, :])
            group_outputs.append(group_out)
        
        # 合并并交换信息
        output = torch.stack(group_outputs, dim=2).view(B, N, D)
        output = self.cross_group_exchange(output)
        
        return output

脉冲Transformer块

class Ge2mSTransformerBlock(nn.Module):
    """Ge²mS-T Transformer块"""
    
    def __init__(self, dim, num_heads, num_groups=4):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn = GroupedSpikingAttention(dim, num_heads, num_groups)
        self.norm2 = nn.LayerNorm(dim)
        self.mlp = SpikingMLP(dim, hidden_dim=dim*4)
    
    def forward(self, x, time_step):
        # 脉冲注意力
        x = x + self.attn(self.norm1(x), time_step)
        # 脉冲MLP
        x = x + self.mlp(self.norm2(x), time_step)
        return x

训练策略

替代梯度学习

class SurrogateGradient(torch.autograd.Function):
    """分组脉冲函数的替代梯度"""
    
    @staticmethod
    def forward(ctx, input):
        ctx.save_for_backward(input)
        return (input > 0).float()  # Heaviside
    
    @staticmethod
    def backward(ctx, grad_output):
        input, = ctx.saved_tensors
        # 使用矩形窗口近似
        grad_input = grad_output.clone()
        grad_input[(input < -1) | (input > 1)] = 0
        return grad_input

时序学习

def temporal_training(model, data_loader, num_time_steps):
    """时序训练循环"""
    for batch in data_loader:
        # 初始化膜电位
        model.reset_membrane_potentials()
        
        # 前向传播多个时间步
        outputs = []
        for t in range(num_time_steps):
            output = model(batch['input'], t)
            outputs.append(output)
        
        # 基于脉冲计数计算损失
        spike_counts = torch.stack(outputs).sum(dim=0)
        loss = criterion(spike_counts, batch['target'])
        loss.backward()

性能评估

ImageNet结果

模型 参数量 Top-1 Acc 能效(TOPS/W)
ViT-S 22M 81.8% 3.2
Spiking-ViT 22M 80.2% 8.5
Ge²mS-T (g=4) 22M 80.5% 15.3

CIFAR-10结果

  • 准确率: 96.8%
  • 能耗: 0.12 mJ/inference
  • 延迟: 4.2 ms

应用场景

1. 边缘视觉识别

  • 智能摄像头
  • 无人机视觉
  • 移动设备

2. 神经形态芯片部署

  • Intel Loihi
  • IBM TrueNorth
  • 定制ASIC

3. 持续学习系统

  • 低功耗在线学习
  • 终身学习代理

优化技巧

1. 分组大小选择

group_size_guide = {
    'classification': 4,    # 标准分类
    'detection': 2,         # 目标检测
    'segmentation': 8,      # 语义分割
}

2. 脉冲阈值调优

  • 初始阈值: 1.0
  • 学习率: 0.1 × 标准学习率
  • 温度参数: 0.5-1.0

激活关键词

  • spiking transformer
  • grouped attention
  • energy efficient ViT
  • neuromorphic transformer
  • Ge²mS-T
  • ultra-low power vision

引用

Hao, Z., Xie, S., & Chen, K. (2026). Ge²mS-T: Multi-Dimensional Grouping for Ultra-High Energy Efficiency in Spiking Transformer. arXiv:2604.08894.

Install via CLI
npx skills add https://github.com/hiyenwong/ai_collection --skill grouped-spiking-transformer-gemmst
Repository Details
star Stars 1
call_split Forks 0
navigation Branch main
article Path SKILL.md
More from Creator