eeg-diffusion-visual-reconstruction


Structure-Guided Diffusion Model (SGDM) for EEG-based visual cognition reconstruction. Combines a structurally supervised VAE, a spatiotemporal EEG encoder trained with contrastive learning, and ControlNet-guided diffusion for high-fidelity visual reconstruction from brain signals. Activation: EEG visual reconstruction, brain-computer interface image generation, SGDM, neural decoding visual.

By hiyenwong, updated 6/3/2026


EEG Diffusion Visual Reconstruction

Structure-Guided Diffusion Model (SGDM) for decoding visual information from EEG signals using a multi-stage generative framework with explicit structural guidance.

Metadata

  • Source: arXiv:2604.22649v1
  • Authors: Yongxiang Lian, Yueyang Cang, Pingge Hu, et al.
  • Published: 2026-04-24
  • Categories: cs.NE, cs.CV

Core Methodology

Problem Statement

Traditional EEG-based visual decoding methods have several limitations:

  • They are restricted to natural images
  • They rely on categorical (class-level) representations
  • They capture structural features poorly
  • They do not differentiate objective perception from subjective cognition

SGDM Architecture

Stage 1: Structural Supervision via VAE

Input: Abstract visual objects (Kilogram dataset) / Natural images (THINGS dataset)
↓
Structurally Supervised Variational Autoencoder (VAE)
- Learns explicit structural representations
- Captures geometric and semantic structure
- Provides structural conditioning signals

Stage 2: Spatiotemporal EEG Encoding

Input: EEG signals during visual cognition
↓
Spatiotemporal EEG Encoder
- Temporal dynamics modeling (LSTM/Transformer)
- Spatial feature extraction (CNN/GNN)
- Multi-scale temporal receptive fields
↓
Contrastive Learning Alignment
- EEG embeddings ↔ Visual embeddings
- Joint embedding space for cross-modal alignment
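
Stage 2 lists multi-scale temporal receptive fields, which the LSTM-based encoder in Step 3 does not model explicitly. The sketch below is a minimal illustration that assumes parallel 1D convolutions with different kernel lengths, an Inception-style choice that the skill does not confirm as SGDM's design. `MultiScaleTemporalConv` and its kernel sizes are hypothetical.

```python
import torch
import torch.nn as nn


class MultiScaleTemporalConv(nn.Module):
    """Parallel temporal convolutions with different receptive fields.

    Hypothetical module: kernel sizes and widths are illustrative, not from SGDM.
    Input:  (batch, channels, timepoints)
    Output: (batch, len(kernel_sizes) * out_per_branch, timepoints)
    """

    def __init__(self, n_channels=64, out_per_branch=32, kernel_sizes=(7, 15, 31)):
        super().__init__()
        self.branches = nn.ModuleList([
            nn.Sequential(
                # Odd kernel with padding k // 2 keeps the time axis length unchanged
                nn.Conv1d(n_channels, out_per_branch, kernel_size=k, padding=k // 2),
                nn.BatchNorm1d(out_per_branch),
                nn.GELU(),
            )
            for k in kernel_sizes
        ])

    def forward(self, eeg):
        # Each branch sees a different temporal window; stack results along the feature axis
        return torch.cat([branch(eeg) for branch in self.branches], dim=1)


if __name__ == "__main__":
    x = torch.randn(8, 64, 500)        # 8 trials, 64 electrodes, 500 samples
    y = MultiScaleTemporalConv()(x)
    print(tuple(y.shape))              # (8, 96, 500)
```

Because every branch preserves the time axis, the concatenated output can feed a recurrent or attention-based temporal model in place of the raw signal.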

Stage 3: ControlNet-Guided Diffusion

Input: Aligned EEG features + Structural guidance
↓
ControlNet (based on Stable Diffusion)
- Structural conditioning from VAE
- EEG feature injection at multiple resolutions
- Guided denoising process
↓
Output: Reconstructed visual images
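
The skill does not specify how EEG features enter the Stable Diffusion UNet. One option, assumed here and not confirmed as SGDM's mechanism, is to project the EEG embedding into a token sequence shaped like the CLIP text-encoder output that Stable Diffusion 1.x feeds to cross-attention (77 tokens of 768 dimensions). The UNet's cross-attention layers sit at several resolutions, so tokens injected this way reach every scale. `EEGToPromptEmbeds` is a hypothetical adapter.

```python
import torch
import torch.nn as nn


class EEGToPromptEmbeds(nn.Module):
    """Hypothetical adapter: one EEG embedding -> a sequence of conditioning tokens.

    The default (77, 768) shape matches the CLIP text-encoder output used by
    Stable Diffusion 1.x, so the result can stand in for text embeddings.
    """

    def __init__(self, eeg_dim=512, n_tokens=77, token_dim=768, hidden=1024):
        super().__init__()
        self.n_tokens = n_tokens
        self.token_dim = token_dim
        self.proj = nn.Sequential(
            nn.Linear(eeg_dim, hidden),
            nn.GELU(),
            nn.Linear(hidden, n_tokens * token_dim),
        )
        self.norm = nn.LayerNorm(token_dim)  # normalize each token independently

    def forward(self, eeg_embed):
        # (batch, eeg_dim) -> (batch, n_tokens, token_dim)
        tokens = self.proj(eeg_embed).view(-1, self.n_tokens, self.token_dim)
        return self.norm(tokens)
```

In Hugging Face Diffusers, `StableDiffusionControlNetPipeline` accepts precomputed text-encoder outputs through its `prompt_embeds` argument and the structural control image through `image`, so tokens shaped this way could be passed there during sampling.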

Key Components

1. Structural Guidance Mechanism

  • ControlNet Integration: Injects structural information into diffusion process
  • Multi-scale Conditioning: Structure guidance at different resolutions
  • Geometry-aware Generation: Preserves spatial relationships and object topology
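
The bullets above do not say what form the structure signal takes. The sketch below assumes an edge-magnitude (Sobel) map as a stand-in structural representation and uses bilinear resizing for the multi-scale part. `sobel_structure_map` and `structure_pyramid` are hypothetical helpers. The sizes 64/32/16/8 are illustrative and correspond to the latent resolutions of a 512 x 512 Stable Diffusion 1.x UNet.

```python
import torch
import torch.nn.functional as F


def sobel_structure_map(images):
    """Edge-magnitude map, used here as a stand-in structural representation.

    images: (batch, 3, H, W) in [0, 1] -> returns (batch, 1, H, W) in [0, 1].
    """
    gray = images.mean(dim=1, keepdim=True)                # channel average as grayscale
    kx = torch.tensor([[-1.0, 0.0, 1.0],
                       [-2.0, 0.0, 2.0],
                       [-1.0, 0.0, 1.0]], dtype=images.dtype, device=images.device)
    kx = kx.view(1, 1, 3, 3)                               # horizontal-gradient kernel
    ky = kx.transpose(2, 3)                                # vertical-gradient kernel
    padded = F.pad(gray, (1, 1, 1, 1), mode="replicate")   # avoid fake edges at the border
    gx = F.conv2d(padded, kx)
    gy = F.conv2d(padded, ky)
    magnitude = torch.sqrt(gx ** 2 + gy ** 2)
    # Scale each map to [0, 1]; the epsilon keeps flat images at 0 instead of NaN
    peak = magnitude.amax(dim=(2, 3), keepdim=True)
    return magnitude / (peak + 1e-8)


def structure_pyramid(structure, sizes=(64, 32, 16, 8)):
    """Resize a structure map to several square resolutions (illustrative sizes)."""
    return [F.interpolate(structure, size=(s, s), mode="bilinear", align_corners=False)
            for s in sizes]
```

A stock ControlNet takes a single full-resolution control image (here, the map repeated to three channels) and derives its own multi-scale features from it. The explicit pyramid is only needed if structure is injected at each UNet resolution by hand.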

2. Spatiotemporal EEG Encoding

  • Temporal Modeling: Captures evoked response dynamics
  • Spatial Localization: Electrode-level feature extraction
  • Cross-modal Alignment: Contrastive learning with visual embeddings

3. Two-stage Training

  1. Pre-training: VAE on visual structure + EEG encoder with contrastive loss
  2. Fine-tuning: End-to-end diffusion model with ControlNet

Implementation Guide

Prerequisites

# Core dependencies
torch >= 2.0
torchvision
diffusers  # Hugging Face Diffusers for ControlNet
transformers
einops
numpy

Step-by-Step Implementation

Step 1: Data Preparation

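import numpy as np
import torch
from torch.utils.data import Dataset

class EEGVisualDataset(Dataset):
    """
    Dataset of paired EEG trials and the images viewed during each trial.
    Expected input format:
    - eeg_data: (trials, channels, time_points)
    - images: (trials, H, W, C) RGB, uint8 in [0, 255] or float in [0, 1]
    - labels: optional, (trials,)
    Items are returned as float tensors: EEG (channels, time_points), image (C, H, W).
    """
    def __init__(self, eeg_data, images, labels=None):
        assert len(eeg_data) == len(images), "each EEG trial needs a paired image"
        self.eeg_data = torch.as_tensor(np.asarray(eeg_data), dtype=torch.float32)
        images = torch.as_tensor(np.asarray(images))
        if images.dtype == torch.uint8:
            images = images.float() / 255.0  # match the [0, 1] range of the VAE's Sigmoid output
        # (trials, H, W, C) -> (trials, C, H, W), the layout nn.Conv2d expects
        self.images = images.float().permute(0, 3, 1, 2).contiguous()
        self.labels = None if labels is None else torch.as_tensor(np.asarray(labels))
        
    def __len__(self):
        return len(self.eeg_data)
    
    def __getitem__(self, idx):
        return {
            'eeg': self.eeg_data[idx],
            'image': self.images[idx],
            'label': self.labels[idx] if self.labels is not None else 0
        }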
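
The dataset class assumes the EEG is already clean, and the skill does not describe a preprocessing pipeline. The sketch below is a minimal version that assumes two common generic steps, not SGDM's documented ones: rejecting trials whose peak-to-peak amplitude exceeds a threshold, then z-scoring each channel within each trial. `preprocess_eeg` and `reject_ptp` are hypothetical names.

```python
import numpy as np


def preprocess_eeg(eeg, reject_ptp=None):
    """Optional peak-to-peak trial rejection, then per-trial, per-channel z-scoring.

    eeg: array of shape (trials, channels, time_points).
    reject_ptp: drop any trial whose peak-to-peak amplitude on some channel exceeds
        this value (same units as `eeg`); None keeps every trial.
    Returns (normalized_eeg, kept_indices).
    """
    eeg = np.asarray(eeg, dtype=np.float32)
    keep = np.arange(len(eeg))
    if reject_ptp is not None:
        ptp = eeg.max(axis=-1) - eeg.min(axis=-1)          # (trials, channels)
        keep = np.where((ptp <= reject_ptp).all(axis=1))[0]
        eeg = eeg[keep]
    mean = eeg.mean(axis=-1, keepdims=True)
    std = eeg.std(axis=-1, keepdims=True)
    return (eeg - mean) / (std + 1e-6), keep                # epsilon guards flat channels
```

Use the returned indices to subset the images and labels before building `EEGVisualDataset`, so that EEG trials and images stay paired. Per-trial z-scoring discards absolute amplitude. If amplitude differences carry information in your data, normalize with statistics from the training set instead.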

Step 2: Structural VAE

import torch
import torch.nn as nn

class StructuralVAE(nn.Module):
    """
    VAE with structural supervision for visual features
    """
    def __init__(self, latent_dim=512, structure_dim=256):
        super().__init__()
        self.latent_dim = latent_dim
        self.structure_dim = structure_dim
        
        # Encoder
        self.encoder = nn.Sequential(
            nn.Conv2d(3, 64, 4, 2, 1),  # 128->64
            nn.ReLU(),
            nn.Conv2d(64, 128, 4, 2, 1),  # 64->32
            nn.ReLU(),
            nn.Conv2d(128, 256, 4, 2, 1),  # 32->16
            nn.ReLU(),
            nn.Flatten(),
            nn.Linear(256 * 16 * 16, 1024),
            nn.ReLU()
        )
        
        # Latent space
        self.fc_mu = nn.Linear(1024, latent_dim)
        self.fc_logvar = nn.Linear(1024, latent_dim)
        
        # Structure head
        self.structure_head = nn.Linear(latent_dim, structure_dim)
        
        # Decoder (simplified)
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, 256 * 16 * 16),
            nn.Unflatten(1, (256, 16, 16)),
            nn.ConvTranspose2d(256, 128, 4, 2, 1),
            nn.ReLU(),
            nn.ConvTranspose2d(128, 64, 4, 2, 1),
            nn.ReLU(),
            nn.ConvTranspose2d(64, 3, 4, 2, 1),
            nn.Sigmoid()
        )
    
    def encode(self, x):
        h = self.encoder(x)
        mu = self.fc_mu(h)
        logvar = self.fc_logvar(h)
        return mu, logvar
    
    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std
    
    def decode(self, z):
        return self.decoder(z)
    
    def get_structure(self, z):
        return self.structure_head(z)
    
    def forward(self, x):
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        recon = self.decode(z)
        structure = self.get_structure(z)
        return recon, mu, logvar, structure
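
The encoder hard-codes `nn.Linear(256 * 16 * 16, 1024)`, which only holds for 128 x 128 inputs. The helpers below are a small sketch of the standard PyTorch output-size formulas for `nn.Conv2d` and `nn.ConvTranspose2d` with dilation 1. They show why the sizes line up and what would change for another input size. The function names are illustrative.

```python
def conv2d_out(size, kernel=4, stride=2, padding=1):
    """Output length of nn.Conv2d along one spatial axis (dilation 1)."""
    return (size + 2 * padding - kernel) // stride + 1


def conv_transpose2d_out(size, kernel=4, stride=2, padding=1, output_padding=0):
    """Output length of nn.ConvTranspose2d along one spatial axis (dilation 1)."""
    return (size - 1) * stride - 2 * padding + kernel + output_padding


def encoder_spatial_size(input_size, n_layers=3):
    """Spatial size after the VAE encoder's three stride-2 convolutions."""
    size = input_size
    for _ in range(n_layers):
        size = conv2d_out(size)
    return size


print(encoder_spatial_size(128))   # 16 -> matches nn.Linear(256 * 16 * 16, 1024)
print(encoder_spatial_size(224))   # 28 -> would need nn.Linear(256 * 28 * 28, 1024)
```

Each decoder `ConvTranspose2d(..., 4, 2, 1)` exactly doubles the size (16 -> 32 -> 64 -> 128), so reconstructions match the input only when the encoder starts from 128.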

Step 3: Spatiotemporal EEG Encoder

class SpatiotemporalEEGEncoder(nn.Module):
    """
    EEG encoder with temporal and spatial modeling
    """
    def __init__(self, n_channels=64, n_timepoints=500, latent_dim=512):
        super().__init__()
        self.n_channels = n_channels
        self.n_timepoints = n_timepoints
        
        # Temporal encoder (LSTM)
        self.temporal_encoder = nn.LSTM(
            input_size=n_channels,
            hidden_size=256,
            num_layers=2,
            batch_first=True,
            bidirectional=True
        )
        
        # Spatial encoder: 1D conv along the electrode axis, with each time point
        # treated as an input feature; expects input of shape (batch, timepoints, channels)
        self.spatial_encoder = nn.Sequential(
            nn.Conv1d(n_timepoints, 128, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv1d(128, 256, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.AdaptiveAvgPool1d(1)
        )
        
        # Fusion and projection
        self.fusion = nn.Sequential(
            nn.Linear(512 + 256, 1024),  # 2 x 256 from the bidirectional LSTM + 256 spatial
            nn.ReLU(),
            nn.Linear(1024, latent_dim)
        )
    
    def forward(self, eeg):
        # eeg: (batch, channels, timepoints)
        eeg_t = eeg.transpose(1, 2)  # (batch, timepoints, channels)
        
        # Temporal: final hidden states of the last LSTM layer in both directions.
        # h_n has shape (num_layers * 2, batch, 256); [-2] is forward, [-1] is backward.
        _, (h_n, _) = self.temporal_encoder(eeg_t)
        temporal_feat = torch.cat([h_n[-2], h_n[-1]], dim=-1)  # (batch, 512)
        
        # Spatial: convolve across electrodes, then pool -> (batch, 256)
        spatial_feat = self.spatial_encoder(eeg_t).squeeze(-1)
        
        # Fusion
        combined = torch.cat([temporal_feat, spatial_feat], dim=-1)
        embedding = self.fusion(combined)
        
        return embedding
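
Stage 2 names a Transformer as an alternative temporal model. The following is a minimal sketch of such a branch, assuming one token per time point (the electrode vector projected to `d_model`), learned positional embeddings and mean pooling over time. None of these choices come from the source, and `TransformerTemporalEncoder` and its sizes are hypothetical.

```python
import torch
import torch.nn as nn


class TransformerTemporalEncoder(nn.Module):
    """Hypothetical Transformer alternative to the LSTM temporal branch.

    Input:  (batch, channels, timepoints); timepoints must equal n_timepoints
    Output: (batch, d_model), mean-pooled over time
    """

    def __init__(self, n_channels=64, n_timepoints=500, d_model=256, n_heads=4, n_layers=2):
        super().__init__()
        self.input_proj = nn.Linear(n_channels, d_model)   # electrode vector -> token
        self.pos_embed = nn.Parameter(torch.randn(1, n_timepoints, d_model) * 0.02)
        layer = nn.TransformerEncoderLayer(
            d_model=d_model, nhead=n_heads, dim_feedforward=4 * d_model,
            dropout=0.1, batch_first=True,
        )
        self.encoder = nn.TransformerEncoder(layer, num_layers=n_layers)

    def forward(self, eeg):
        tokens = self.input_proj(eeg.transpose(1, 2)) + self.pos_embed  # (batch, T, d_model)
        encoded = self.encoder(tokens)
        return encoded.mean(dim=1)                                      # pool over time
```

To use this branch inside `SpatiotemporalEEGEncoder`, the fusion layer's input size becomes `d_model + 256` instead of `512 + 256`. Self-attention cost grows quadratically with the number of time points, so downsampling or patching the signal first may be needed for long epochs.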

Step 4: Contrastive Learning Alignment

import torch.nn.functional as F

class ContrastiveAlignment(nn.Module):
    """
    Contrastive learning for EEG-visual alignment
    """
    def __init__(self, temp=0.07):
        super().__init__()
        self.temp = temp
    
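    def forward(self, eeg_embed, visual_embed):
        """
        Symmetric InfoNCE (CLIP-style) loss over a batch of paired embeddings.
        Row i of eeg_embed and row i of visual_embed form a positive pair;
        every other row in the batch serves as a negative.
        """
        # Normalize embeddings so the dot product equals cosine similarity
        eeg_embed = F.normalize(eeg_embed, dim=-1)
        visual_embed = F.normalize(visual_embed, dim=-1)
        
        # Temperature-scaled cosine similarity, shape (batch, batch)
        logits = eeg_embed @ visual_embed.t() / self.temp
        
        # Labels: the diagonal holds the positive pairs
        labels = torch.arange(logits.shape[0], device=logits.device)
        
        # Symmetric loss: EEG -> image and image -> EEG
        loss_eeg2img = F.cross_entropy(logits, labels)
        loss_img2eeg = F.cross_entropy(logits.t(), labels)
        loss = (loss_eeg2img + loss_img2eeg) / 2
        
        return loss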
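
A common way to check the alignment learned by the contrastive loss is top-k retrieval: for each EEG trial, see whether its paired image ranks among the k most similar images in a held-out set. This sketch assumes the same row-i-pairs-with-row-i convention as the loss above. `topk_retrieval_accuracy` is a hypothetical helper, not an evaluation protocol stated in the skill.

```python
import torch
import torch.nn.functional as F


@torch.no_grad()
def topk_retrieval_accuracy(eeg_embed, visual_embed, k=5):
    """Fraction of EEG trials whose paired image is in the top-k by cosine similarity.

    Row i of eeg_embed is assumed to pair with row i of visual_embed.
    """
    sim = F.normalize(eeg_embed, dim=-1) @ F.normalize(visual_embed, dim=-1).t()
    k = min(k, sim.shape[1])
    topk = sim.topk(k, dim=1).indices                       # (n, k) candidate image indices
    targets = torch.arange(sim.shape[0], device=sim.device).unsqueeze(1)
    return (topk == targets).any(dim=1).float().mean().item()
```

Chance level is k / n for n candidates, so report the candidate-set size alongside the accuracy.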

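Step 5: SGDM Pre-training Loop

import torch
import torch.nn.functional as F

def train_sgdm(eeg_encoder, vae, dataloader, epochs=100, lr=1e-4, kl_weight=1e-3, device='cpu'):
    """
    Pre-training loop for SGDM (the first of the two training stages):
    VAE reconstruction + KL regularization + contrastive EEG-visual alignment.
    The ControlNet fine-tuning stage is not covered by this loop.
    Batches need more than one trial so the contrastive loss has negatives.
    """
    eeg_encoder.to(device).train()
    vae.to(device).train()
    optimizer = torch.optim.AdamW(
        list(eeg_encoder.parameters()) + list(vae.parameters()),
        lr=lr
    )
    
    contrastive_loss_fn = ContrastiveAlignment()
    
    for epoch in range(epochs):
        for batch in dataloader:
            eeg = batch['eeg'].to(device)        # (batch, channels, timepoints)
            images = batch['image'].to(device)   # (batch, 3, 128, 128) in [0, 1]
            
            # VAE forward; `structure` is returned but not supervised here,
            # since no structural target is defined in this skill
            recon, mu, logvar, structure = vae(images)
            
            # EEG encoding
            eeg_embed = eeg_encoder(eeg)
            
            # Contrastive alignment, reusing mu as the visual embedding
            align_loss = contrastive_loss_fn(eeg_embed, mu)
            
            # VAE losses: KL summed over latent dims, averaged over the batch
            recon_loss = F.mse_loss(recon, images)
            kl_loss = -0.5 * torch.mean(
                torch.sum(1 + logvar - mu.pow(2) - logvar.exp(), dim=1)
            )
            
            # Total loss
            total_loss = recon_loss + kl_weight * kl_loss + align_loss
            
            optimizer.zero_grad()
            total_loss.backward()
            optimizer.step()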
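
The second training stage (fine-tuning the diffusion model with ControlNet) is described but not implemented above. The sketch below shows the standard epsilon-prediction objective that diffusion fine-tuning typically uses. It assumes SGDM follows this standard objective, which the skill does not state. `denoiser(x_t, t, cond)` is a hypothetical callable standing in for a ControlNet-augmented UNet, and the linear beta schedule follows DDPM defaults. Stable Diffusion ships its own scheduler configuration.

```python
import torch
import torch.nn.functional as F


def make_alphas_cumprod(n_steps=1000, beta_start=1e-4, beta_end=0.02):
    """Linear beta schedule (DDPM defaults); returns cumulative products of (1 - beta)."""
    betas = torch.linspace(beta_start, beta_end, n_steps)
    return torch.cumprod(1.0 - betas, dim=0)


def diffusion_loss(denoiser, x0, cond, alphas_cumprod):
    """Epsilon-prediction loss: noise x0 to a random step t, then regress the noise.

    denoiser(x_t, t, cond) is hypothetical; `cond` would carry the EEG and
    structural conditioning. x0 is a clean (latent) image batch.
    """
    batch = x0.shape[0]
    t = torch.randint(0, len(alphas_cumprod), (batch,), device=x0.device)
    # Broadcast alpha_bar_t over all non-batch dimensions of x0
    a_bar = alphas_cumprod.to(x0.device)[t].view(batch, *([1] * (x0.dim() - 1)))
    noise = torch.randn_like(x0)
    x_t = a_bar.sqrt() * x0 + (1.0 - a_bar).sqrt() * noise  # forward process q(x_t | x0)
    return F.mse_loss(denoiser(x_t, t, cond), noise)
```

In a ControlNet setup the base UNet is usually kept frozen and only the ControlNet copy (plus any EEG adapter) receives gradients from this loss.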

Applications

Brain-Computer Interfaces

  • Visual Prosthetics: Reconstruct perceived visual stimuli
  • Communication Aid: Generate images from imagined visual content
  • Neural Feedback: Real-time visual feedback from EEG

Neuroscience Research

  • Visual Cortex Mapping: Understand EEG-visual representation relationships
  • Perception Studies: Differentiate objective vs. subjective visual experience
  • Cognitive State Monitoring: Track visual attention and cognition

Clinical Applications

  • Visual Pathway Assessment: Diagnose visual processing disorders
  • Locked-in Syndrome Communication: Image-based communication for patients with paralysis
  • Rehabilitation Monitoring: Track visual recovery after brain injury

Pitfalls and Limitations

Data Requirements

  • Large paired datasets: Need EEG-image pairs for training
  • High-quality EEG: Requires clean, artifact-free signals
  • Individual variability: Models may not generalize across subjects

Technical Challenges

  • Spatial resolution: EEG has limited spatial resolution (~cm scale), even though its temporal resolution is on the order of milliseconds
  • Signal noise: EEG is susceptible to artifacts (eye movements, muscle activity)
  • Individual calibration: Per-subject fine-tuning often required

Ethical Considerations

  • Privacy: Reconstructing mental imagery raises privacy concerns
  • Consent: Clear informed consent for visual thought decoding
  • Security: Protecting neural data from unauthorized access

Related Skills

  • eeg-hopfield-emotion-energy
  • brain-inspired-snn-pattern-analysis
  • meta-learning-in-context-brain-decoding
  • neural-population-decoding

References

@article{lian2026sgdm,
  title={Structure-Guided Diffusion Model for EEG-Based Visual Cognition Reconstruction},
  author={Lian, Yongxiang and Cang, Yueyang and Hu, Pingge and others},
  journal={arXiv preprint arXiv:2604.22649},
  year={2026}
}
Install via CLI
npx skills add https://github.com/hiyenwong/ai_collection --skill eeg-diffusion-visual-reconstruction