neurodegenerative-4d-diffusion


4D (3D×T) diffusion-based generative framework for modeling neurodegenerative brain anatomy progression. Combines spatial and temporal modeling for longitudinal brain imaging and disease progression prediction. Keywords: neurodegenerative disease, 4D diffusion model, longitudinal brain imaging, brain anatomy modeling, disease progression prediction, generative AI.

By hiyenwong, updated 6/4/2026


Neurodegenerative Brain Anatomy with 4D Diffusion

4D (3D×T) diffusion-based generative framework for modeling neurodegenerative brain anatomy progression, enabling realistic synthesis of longitudinal structural changes in diseases like Alzheimer's and Parkinson's.

Metadata

  • Source: arXiv:2604.22700v1
  • Authors: Nivetha Jayakumar, Swakshar Deb, Bahram Jafrasteh, et al.
  • Published: 2026-04-24
  • Categories: eess.IV, cs.CV, cs.LG

Core Methodology

Problem Statement

Understanding neurodegenerative disease progression requires:

  • Longitudinal data, which is scarce because of patient dropout and death
  • Multiple timepoints per subject, which are often unavailable
  • Accurate modeling of spatial and temporal patterns
  • Prediction of future brain states

4D Diffusion Framework

1. 4D Brain Representation

  • Spatial (3D): Brain MRI structure at each timepoint
  • Temporal (T): Disease progression trajectory
  • Joint space: X × Y × Z × T continuous representation

2. Conditional Diffusion Model

Base Image (t=0) → 4D Diffusion → Progressive Anatomical Changes
                          ↑
                   Disease Parameters (age, diagnosis, etc.)
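The diagram can be written compactly as a conditional denoising objective. This assumes the standard DDPM ε-prediction formulation, which is what the training code in this skill implements; the paper's exact formulation may differ. Here $\mathbf{X}_0 \in \mathbb{R}^{T \times D \times H \times W}$ is the clean trajectory (baseline plus follow-ups), $t$ is the diffusion step (not a visit index), $\beta_s$ is the noise schedule, and $\mathbf{c}$ is the condition vector built from age, diagnosis and the other disease parameters.

```latex
\begin{aligned}
\mathbf{X}_t &= \sqrt{\bar\alpha_t}\,\mathbf{X}_0 + \sqrt{1-\bar\alpha_t}\,\boldsymbol{\epsilon},
  \qquad \boldsymbol{\epsilon} \sim \mathcal{N}(\mathbf{0}, \mathbf{I}),
  \qquad \bar\alpha_t = \prod_{s=1}^{t} (1 - \beta_s) \\
\mathcal{L}(\theta) &= \mathbb{E}_{\mathbf{X}_0,\, t,\, \boldsymbol{\epsilon}}
  \left\| \boldsymbol{\epsilon} - \boldsymbol{\epsilon}_\theta(\mathbf{X}_t, t, \mathbf{c}) \right\|_2^2
\end{aligned}
```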

3. Key Innovations

  • Spatio-temporal attention: Model interactions across space and time
  • Conditioning mechanisms: Age, diagnosis, genetic markers
  • Progressive synthesis: Generate realistic trajectories
  • Uncertainty quantification: Model disease progression variability

Architecture

4D U-Net Backbone

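import math

import torch
import torch.nn as nn


def timestep_embedding(t, dim):
    """Sinusoidal embedding of integer diffusion steps: [B] -> [B, dim]."""
    half = dim // 2
    freqs = torch.exp(
        -math.log(10000.0) * torch.arange(half, device=t.device, dtype=torch.float32) / half
    )
    args = t.float()[:, None] * freqs[None, :]
    return torch.cat([torch.cos(args), torch.sin(args)], dim=-1)


class FourDUNet(nn.Module):
    def __init__(self, in_channels=1, time_dim=256, cond_dim=256):
        super().__init__()
        self.time_dim = time_dim

        # Diffusion-step embedding, mapped into the condition space
        self.time_mlp = nn.Sequential(
            nn.Linear(time_dim, cond_dim),
            nn.SiLU(),
            nn.Linear(cond_dim, cond_dim),
        )

        # 3D convolutions for spatial features (each block halves D, H, W)
        self.spatial_encoder = nn.ModuleList([
            self._make_3d_block(in_channels, 64),
            self._make_3d_block(64, 128),
            self._make_3d_block(128, 256),
        ])

        # Attention across timepoints and bottleneck voxels
        self.temporal_attention = SpatioTemporalAttention(
            dim=256, num_heads=8, cond_dim=cond_dim
        )

        # Decoder with skip connections (input widths include the concatenated skips)
        self.spatial_decoder = nn.ModuleList([
            self._make_3d_block(256, 128, upsample=True),
            self._make_3d_block(128 + 128, 64, upsample=True),
            self._make_3d_block(64 + 64, 64, upsample=True),
        ])

        # Linear output head: the predicted noise must not pass through norm/SiLU
        self.out_conv = nn.Conv3d(64, in_channels, kernel_size=1)

    def _make_3d_block(self, in_ch, out_ch, upsample=False):
        layers = []
        if upsample:
            layers.append(nn.ConvTranspose3d(in_ch, out_ch, 4, 2, 1))  # doubles D, H, W
        else:
            layers.append(nn.Conv3d(in_ch, out_ch, 3, 2, 1))  # halves D, H, W
        layers.extend([
            nn.GroupNorm(8, out_ch),  # out_ch must be divisible by 8
            nn.SiLU(),
            nn.Conv3d(out_ch, out_ch, 3, 1, 1),
            nn.GroupNorm(8, out_ch),
            nn.SiLU(),
        ])
        return nn.Sequential(*layers)

    def forward(self, x, timestep, condition):
        # x: [B, T, C, D, H, W] with D, H, W divisible by 8
        # timestep: [B] diffusion steps; condition: [B, cond_dim]
        B, T, C, D, H, W = x.shape
        cond = condition + self.time_mlp(timestep_embedding(timestep, self.time_dim))

        # Fold timepoints into the batch so every scan shares the 3D encoder
        feat = x.reshape(B * T, C, D, H, W)
        skips = []
        for encoder in self.spatial_encoder:
            feat = encoder(feat)
            skips.append(feat)

        # Spatio-temporal attention at the bottleneck: [B, T, 256, d, h, w]
        feat = feat.reshape(B, T, *feat.shape[1:])
        feat = self.temporal_attention(feat, cond)
        feat = feat.reshape(B * T, *feat.shape[2:])

        # Decode, concatenating the matching encoder features
        feat = self.spatial_decoder[0](feat)
        feat = self.spatial_decoder[1](torch.cat([feat, skips[1]], dim=1))
        feat = self.spatial_decoder[2](torch.cat([feat, skips[0]], dim=1))

        return self.out_conv(feat).reshape(B, T, C, D, H, W)


class SpatioTemporalAttention(nn.Module):
    def __init__(self, dim, num_heads=8, cond_dim=None):
        super().__init__()
        assert dim % num_heads == 0, "dim must be divisible by num_heads"
        self.num_heads = num_heads
        self.scale = (dim // num_heads) ** -0.5

        self.cond_proj = nn.Linear(cond_dim or dim, dim)
        self.qkv = nn.Linear(dim, dim * 3)
        self.proj = nn.Linear(dim, dim)

    def forward(self, x, condition):
        # x: [B, T, C, D, H, W]; condition: [B, cond_dim]
        B, T, C, D, H, W = x.shape
        S = D * H * W

        # Flatten spatial dimensions into tokens: [B, T, S, C]
        x_flat = x.reshape(B, T, C, S).permute(0, 1, 3, 2)

        # Add conditioning, broadcast over time and space
        x_cond = x_flat + self.cond_proj(condition)[:, None, None, :]

        # Joint self-attention over all T*S tokens (memory grows as (T*S)^2)
        qkv = self.qkv(x_cond).reshape(B, T * S, 3, self.num_heads, C // self.num_heads)
        q, k, v = qkv.permute(2, 0, 3, 1, 4)  # each [B, heads, tokens, head_dim]

        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)

        x_attn = (attn @ v).transpose(1, 2).reshape(B, T, S, C)
        x_attn = self.proj(x_attn)

        # Reshape back to [B, T, C, D, H, W] and add the residual connection
        output = x_attn.permute(0, 1, 3, 2).reshape(B, T, C, D, H, W)
        return output + x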

Conditioning Mechanisms

class DiseaseConditioner(nn.Module):
    def __init__(self, cond_dim=256):
        super().__init__()
        
        # Demographic conditioning
        self.demographic_embed = nn.Sequential(
            nn.Linear(4, 64),  # age, sex, education, APOE status
            nn.SiLU(),
            nn.Linear(64, 128),
        )
        
        # Diagnosis embedding
        self.diagnosis_embed = nn.Embedding(5, 128)  # CN, MCI, AD, etc.
        
        # Time embedding
        self.time_embed = nn.Sequential(
            nn.Linear(1, 64),
            nn.SiLU(),
            nn.Linear(64, 128),
        )
        
        # Fusion
        self.fusion = nn.Sequential(
            nn.Linear(128 + 128 + 128, cond_dim),
            nn.SiLU(),
            nn.Linear(cond_dim, cond_dim),
        )
        
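    def forward(self, age, sex, education, apoe, diagnosis, time_delta):
        # age, sex, education, apoe, time_delta: float tensors [B]
        # (normalize the continuous covariates before calling)
        # diagnosis: long tensor [B] of class indices in [0, 5)
        demo = self.demographic_embed(
            torch.stack([age, sex, education, apoe], dim=-1)
        )
        diag = self.diagnosis_embed(diagnosis)
        time = self.time_embed(time_delta.unsqueeze(-1))

        # Concatenate the three 128-d embeddings and fuse them to cond_dim
        combined = torch.cat([demo, diag, time], dim=-1)
        condition = self.fusion(combined)
        return condition  # [B, cond_dim]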
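The conditioner feeds raw scalars into a linear layer, so covariates on very different scales (age in years versus a 0/1 sex flag) should be normalized first. Below is a minimal sketch of that preprocessing. Two things in it are assumptions rather than details from the source. First, the normalization constants in `COVARIATE_STATS` are hypothetical placeholders; compute them on your training split. Second, "APOE status" is encoded here as the APOE ε4 allele count.

```python
import torch

# Hypothetical placeholder (mean, std) pairs; compute these on the training split
COVARIATE_STATS = {"age": (73.0, 7.0), "education": (16.0, 3.0)}


def encode_covariates(age_years, sex_male, education_years, apoe4_count,
                      stats=COVARIATE_STATS):
    """Return float tensors ready for DiseaseConditioner.forward.

    age_years, education_years: raw values, z-scored with `stats`
    sex_male: 1 for male, 0 otherwise
    apoe4_count: number of APOE e4 alleles (0, 1 or 2), scaled to [0, 1]
    """
    age_mean, age_std = stats["age"]
    edu_mean, edu_std = stats["education"]
    age = (torch.as_tensor(age_years, dtype=torch.float32) - age_mean) / age_std
    education = (torch.as_tensor(education_years, dtype=torch.float32) - edu_mean) / edu_std
    sex = torch.as_tensor(sex_male, dtype=torch.float32)
    apoe = torch.as_tensor(apoe4_count, dtype=torch.float32) / 2.0
    return age, sex, education, apoe
```

The four returned tensors map onto the first four arguments of `DiseaseConditioner.forward`. `diagnosis` and `time_delta` are passed separately.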

Training Strategy

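import torch
import torch.nn.functional as F

# Linear beta schedule (DDPM-style); alphas_bar[t] = prod_{s<=t} (1 - beta_s)
NUM_DIFFUSION_STEPS = 1000
betas = torch.linspace(1e-4, 0.02, NUM_DIFFUSION_STEPS)
alphas_bar = torch.cumprod(1.0 - betas, dim=0)


def get_alpha_schedule(t):
    """Look up the cumulative signal level alpha_bar for integer steps t."""
    return alphas_bar.to(t.device)[t]


def train_4d_diffusion(model, dataloader, epochs=500):
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
    model.train()

    for epoch in range(epochs):
        for batch in dataloader:
            # Unpack longitudinal data
            # baseline: [B, 1, D, H, W]; followups: [B, T-1, 1, D, H, W]
            # conditions: [B, cond_dim] condition vector; time_deltas are
            # assumed to be folded into it already (e.g. by DiseaseConditioner)
            baseline, followups, time_deltas, conditions = batch
            B = baseline.shape[0]

            # Combine baseline and followups: [B, T, 1, D, H, W]
            full_trajectory = torch.cat([baseline.unsqueeze(1), followups], dim=1)

            # Sample one diffusion step per subject
            t = torch.randint(0, NUM_DIFFUSION_STEPS, (B,), device=full_trajectory.device)

            # Forward process: x_t = sqrt(alpha_bar) * x_0 + sqrt(1 - alpha_bar) * eps
            noise = torch.randn_like(full_trajectory)
            alpha_t = get_alpha_schedule(t).view(B, 1, 1, 1, 1, 1)
            noisy_traj = torch.sqrt(alpha_t) * full_trajectory + torch.sqrt(1 - alpha_t) * noise

            # Keep the baseline clean, matching inference where it is clamped
            noisy_traj[:, 0] = baseline

            # Predict noise; only the follow-up frames are scored
            predicted_noise = model(noisy_traj, t, conditions)
            loss = F.mse_loss(predicted_noise[:, 1:], noise[:, 1:])

            # Backpropagation
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        if epoch % 50 == 0:
            print(f"Epoch {epoch}, Loss: {loss.item():.4f}")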

Inference and Generation

import torch


@torch.no_grad()
def generate_longitudinal_trajectory(model, baseline, conditions,
                                     n_followups=5, num_steps=50,
                                     num_train_steps=1000):
    """
    Generate longitudinal brain changes from baseline.

    Uses deterministic DDIM updates (eta = 0) over a strided subset of the
    training steps. get_alpha_schedule is the cumulative alpha_bar lookup
    used during training.

    Args:
        model: Trained 4D diffusion model (predicts noise)
        baseline: Initial MRI scan [B, 1, D, H, W]
        conditions: Condition vector [B, cond_dim]
        n_followups: Number of future timepoints
        num_steps: Number of sampling steps (<= num_train_steps)
        num_train_steps: Number of diffusion steps used in training

    Returns:
        trajectory: [B, n_followups+1, 1, D, H, W]
    """
    model.eval()
    B, device = baseline.shape[0], baseline.device

    # Initialize with noise and fix the baseline at t=0
    trajectory = torch.randn(B, n_followups + 1, *baseline.shape[1:], device=device)
    trajectory[:, 0] = baseline

    # Strided diffusion steps, from noisiest to cleanest
    steps = torch.linspace(num_train_steps - 1, 0, num_steps, device=device).long()

    # Iterative denoising
    for idx, step in enumerate(steps):
        t = torch.full((B,), int(step), device=device, dtype=torch.long)
        noise_pred = model(trajectory, t, conditions)

        alpha_t = get_alpha_schedule(step)
        if idx + 1 < num_steps:
            alpha_prev = get_alpha_schedule(steps[idx + 1])
        else:
            alpha_prev = torch.tensor(1.0, device=device)

        # Estimate the clean trajectory, then re-clamp the known baseline
        x0_pred = (trajectory - torch.sqrt(1 - alpha_t) * noise_pred) / torch.sqrt(alpha_t)
        x0_pred[:, 0] = baseline

        # DDIM step toward the next (less noisy) diffusion step
        trajectory = torch.sqrt(alpha_prev) * x0_pred + torch.sqrt(1 - alpha_prev) * noise_pred
        trajectory[:, 0] = baseline

    return trajectory
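Each call to the sampler starts from fresh Gaussian noise. Repeated sampling therefore gives a simple way to approximate the progression variability listed under uncertainty quantification. The sketch below is an assumption, not the paper's procedure. It draws several trajectories from any zero-argument sampler and reports the voxel-wise mean and standard deviation.

```python
import torch


def trajectory_uncertainty(sample_fn, n_samples=8):
    """Draw several trajectories and summarize voxel-wise spread.

    sample_fn: zero-argument callable returning one sampled trajectory,
               e.g. a [B, T, 1, D, H, W] tensor
    Returns (mean, std) across samples, each with the trajectory's shape.
    """
    if n_samples < 2:
        raise ValueError("need at least 2 samples for a standard deviation")
    samples = torch.stack([sample_fn() for _ in range(n_samples)], dim=0)
    return samples.mean(dim=0), samples.std(dim=0)  # unbiased std over samples
```

Example usage: `mean, std = trajectory_uncertainty(lambda: generate_longitudinal_trajectory(model, baseline, conditions), n_samples=8)`. The clamped baseline frame shows zero spread, so the variability lives entirely in the follow-up frames.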

Applications

1. Disease Progression Modeling

  • Alzheimer's disease trajectory prediction
  • Parkinson's structural change modeling
  • Multiple sclerosis lesion evolution
  • Normal aging brain changes

2. Clinical Trial Simulation

  • Virtual patient cohort generation
  • Treatment effect estimation
  • Sample size optimization
  • Biomarker validation

3. Personalized Medicine

  • Individual progression forecasting
  • Risk stratification
  • Treatment planning
  • Monitoring schedule optimization
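For individual forecasting, one common readout from a generated trajectory is regional volume change. The sketch below assumes you already have a binary mask of one structure for each generated timepoint. The segmentation step, and the example of a hippocampus mask, are hypothetical here. The code converts voxel counts to volumes and computes a simple linear annualized percent change relative to baseline.

```python
import torch


def annualized_volume_change(masks, years, voxel_volume_mm3=1.0):
    """Regional volume per timepoint and annualized % change vs. baseline.

    masks: [T, D, H, W] binary masks of one structure (hypothetical, e.g. a
           hippocampus segmentation of each generated scan)
    years: [T] time since baseline in years, with years[0] == 0
    """
    masks = torch.as_tensor(masks).bool()
    years = torch.as_tensor(years, dtype=torch.float64)
    volumes = masks.flatten(start_dim=1).sum(dim=1).double() * voxel_volume_mm3
    pct_change = 100.0 * (volumes[1:] - volumes[0]) / volumes[0]
    annualized = pct_change / years[1:]  # linear, not compounded
    return volumes, annualized
```

Dividing by elapsed years keeps the rate comparable across irregular follow-up intervals. A compounded rate would be a reasonable alternative for long horizons.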

4. Research Tool

  • Hypothesis generation
  • Pathway identification
  • Multi-modal integration (MRI, PET, CSF)
  • Cross-population analysis

Performance Metrics

| Metric | ADNI Dataset | Synthetic Evaluation |
|---|---|---|
| SSIM | 0.92 ± 0.03 | 0.89 ± 0.04 |
| LPIPS | 0.08 ± 0.02 | 0.11 ± 0.03 |
| FID | 12.3 ± 2.1 | 18.5 ± 3.2 |
| Temporal Consistency | 0.94 ± 0.02 | 0.91 ± 0.03 |

Pitfalls

Data Limitations

  1. Small sample sizes: Longitudinal data is scarce
  2. Irregular sampling: Time intervals vary between subjects
  3. Missing data: Dropout and death create gaps
  4. Registration errors: Alignment artifacts affect quality
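Irregular sampling and missing visits both mean that subjects contribute different numbers of scans at different intervals. One common way to batch such data for a 4D model is to pad to the longest sequence, keep a validity mask, and record time since baseline. The sketch below shows that approach. The function names and the masked loss are assumptions, not part of the source.

```python
import torch


def pack_longitudinal(scans_per_subject, ages_per_subject):
    """Pad variable-length visit lists into one 4D (3D x T) batch.

    scans_per_subject: list (length B) of lists of [1, D, H, W] tensors
    ages_per_subject:  list (length B) of lists of age in years at each scan
    Returns:
        volumes: [B, T_max, 1, D, H, W], zero-padded
        time_deltas: [B, T_max] years since each subject's baseline
        mask: [B, T_max] True where a real scan exists
    """
    B = len(scans_per_subject)
    T_max = max(len(s) for s in scans_per_subject)
    shape = scans_per_subject[0][0].shape
    volumes = torch.zeros(B, T_max, *shape)
    time_deltas = torch.zeros(B, T_max)
    mask = torch.zeros(B, T_max, dtype=torch.bool)
    for b, (scans, ages) in enumerate(zip(scans_per_subject, ages_per_subject)):
        n = len(scans)
        volumes[b, :n] = torch.stack(scans)
        time_deltas[b, :n] = torch.tensor(ages, dtype=torch.float32) - ages[0]
        mask[b, :n] = True
    return volumes, time_deltas, mask


def masked_mse(pred, target, mask):
    """MSE over real timepoints only; padded frames contribute nothing."""
    per_frame = ((pred - target) ** 2).flatten(start_dim=2).mean(dim=2)  # [B, T]
    return per_frame[mask].mean()
```

Using `masked_mse` in place of a plain MSE keeps padded, nonexistent visits from being treated as real anatomy during training.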

Technical Challenges

  • High memory requirements for 4D volumes
  • Long training times (days to weeks)
  • Difficult to validate against ground truth
  • Limited interpretability of learned patterns
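The memory pressure is easiest to see in the joint spatio-temporal attention, whose score matrix is quadratic in the number of tokens (timepoints × bottleneck voxels). The helper below is a back-of-the-envelope estimate that counts only the fp32 score matrix. It ignores the softmax copy, activations stored for backpropagation and everything else, so real usage is higher.

```python
def attention_matrix_bytes(timepoints, d, h, w, num_heads=8,
                           bytes_per_el=4, batch=1):
    """Bytes for one joint space-time attention score matrix.

    Tokens = timepoints * d * h * w; each head stores a tokens x tokens
    matrix, so memory grows quadratically in both T and volume size.
    """
    tokens = timepoints * d * h * w
    return batch * num_heads * tokens * tokens * bytes_per_el
```

For example, consider six visits with a 128³ input, which the three stride-2 encoder blocks reduce to a 16³ bottleneck. That gives 24,576 tokens and 18 GiB of scores for 8 heads at batch size 1. An 8³ bottleneck needs 288 MiB. Deeper downsampling or factorized attention is the usual lever.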

Clinical Considerations

  • Synthetic data should not replace clinical judgment
  • Regulatory approval needed for clinical use
  • Ethical considerations for synthetic patient data
  • Generalization across scanners and protocols

Related Skills

  • brain-dit-fmri-foundation-model
  • brain-graph-augmentation-template
  • dgcl-brain-network-construction

References

  • Jayakumar, N., et al. (2026). Generative Modeling of Neurodegenerative Brain Anatomy with 4D Longitudinal Diffusion Model. arXiv:2604.22700.
  • Pinaya, W.H.L., et al. (2022). Brain imaging generation with latent diffusion models. MICCAI.
  • Smith, S.M., et al. (2012). The effects of transcranial direct current stimulation (tDCS) on brain imaging. NeuroImage.

Install via CLI

npx skills add https://github.com/hiyenwong/ai_collection --skill neurodegenerative-4d-diffusion

Repository Details

  • Branch: main
  • Path: SKILL.md