braindinobrain-mri-foundation


BrainDINO: Brain MRI Foundation Model for generalizable clinical representation learning. Self-distilled foundation model trained on 6.6M unlabeled axial slices from diverse brain MRI datasets, generalizing across heterogeneous endpoints without task-specific fine-tuning. Trigger words: BrainDINO, brain MRI foundation model, self-supervised MRI, clinical representation learning, neuroimaging foundation, DINO brain MRI, medical imaging foundation model, brain MRI self-distillation.

By hiyenwong, updated 6/3/2026


BrainDINO: Brain MRI Foundation Model

Overview

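BrainDINO is a self-supervised foundation model for brain MRI that learns generalizable representations from ~6.6 million unlabeled axial slices. Its frozen features transfer to diverse clinical endpoints without task-specific fine-tuning (e.g. via linear probing), and the backbone can also be fine-tuned with only a few labeled examples.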

Key Results

  • Training data: ~6.6M unlabeled axial brain MRI slices
  • Method: Self-distillation (DINO-style) with multi-crop augmentation
  • Generalization: Frozen representations transfer across heterogeneous brain MRI endpoints without task-specific fine-tuning
  • Clinical relevance: No task-specific labeled data needed for representation learning

Architecture

DINO Framework for Brain MRI

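Input MRI slice ──→ Multi-crop augmentation ──→ 2 global views + several local views
                                                        │
              ┌─────────────────────────────────────────┴────────────────┐
              ↓ all views                                                 ↓ global views only
   [Student Encoder (ViT)]                                     [Teacher Encoder (ViT)]
              ↓                                                           ↓
   Student probabilities (temp τ_s)                  Centered teacher probabilities (temp τ_t)
              └──────────────────→ Cross-entropy loss ←───────────────────┘
                        (student matches teacher on a different view)

Gradients update the student only; the teacher is updated as an EMA of the student weights.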
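Written out, the objective in the diagram is roughly the following. This is a sketch that follows the standard DINO formulation and the defaults used elsewhere on this page (τ_s = 0.1, τ_t = 0.04, m_0 = 0.996); the symbols g (encoder plus projection head), c (center), λ (center momentum), k and K (current and total epochs), and G and L (global and local view index sets) are notation introduced here, and this page does not state whether BrainDINO modifies any term.

```latex
P_s(x) = \mathrm{softmax}\!\left(\frac{g_{\theta_s}(x)}{\tau_s}\right), \qquad
P_t(x) = \mathrm{softmax}\!\left(\frac{g_{\theta_t}(x) - c}{\tau_t}\right)

\mathcal{L} = \frac{1}{|\mathcal{P}|} \sum_{(i,j) \in \mathcal{P}} -\,P_t(x_i)^{\top} \log P_s(x_j),
\qquad \mathcal{P} = \{(i,j) : i \in G,\ j \in G \cup L,\ j \neq i\}

\theta_t \leftarrow m\,\theta_t + (1 - m)\,\theta_s, \qquad
m = 1 - (1 - m_0)\,\frac{1 + \cos(\pi k / K)}{2}

c \leftarrow \lambda\, c + (1 - \lambda)\,\operatorname{mean}_{b}\; g_{\theta_t}(x_b)
```

The last line is the centering step from the Best Practices list: the mean runs over all teacher (global-view) outputs in the current batch, and the center is only ever applied on the teacher side.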

Key Components

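  1. Vision Transformer (ViT) backbone: patch-based self-attention
  2. Multi-crop augmentation: global and local views of each brain slice
  3. Teacher-student distillation: the teacher is an exponential moving average (EMA) of the student and receives no gradients
  4. Centering and sharpening: prevent collapsed representations (centering keeps one output dimension from dominating; a low teacher temperature keeps outputs from becoming uniform)
  5. Positional encoding: adapted for 2D medical image slices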
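The interplay of centering and sharpening (component 4) can be seen on toy numbers. The sketch below is a NumPy illustration, not BrainDINO code: the logits are hypothetical, with one output dimension inflated for every sample to mimic the bias that leads to collapse, and the center momentum of 0.9 is an assumption borrowed from common DINO settings.

```python
import numpy as np

def softmax(z, temp):
    """Row-wise softmax with temperature; a lower temp gives a sharper distribution."""
    z = z / temp
    z = z - z.max(axis=-1, keepdims=True)  # subtract row max for numerical stability
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

def mean_entropy(p):
    """Mean Shannon entropy (nats) over the rows of p."""
    return float(-(p * np.log(p + 1e-12)).sum(axis=-1).mean())

def update_center(center, teacher_logits, momentum=0.9):
    """EMA of the batch-mean teacher logits (the 'center')."""
    return momentum * center + (1 - momentum) * teacher_logits.mean(axis=0)

rng = np.random.default_rng(0)
# Hypothetical teacher logits: 8 samples x 16 output dims, with dim 0
# inflated for every sample (the bias centering is meant to remove)
logits = rng.normal(size=(8, 16))
logits[:, 0] += 10.0

center = np.zeros(16)
for _ in range(50):  # repeated batches drive the center toward the batch mean
    center = update_center(center, logits)

p_uncentered = softmax(logits, temp=0.04)             # dim 0 wins for every sample
p_centered = softmax(logits - center, temp=0.04)      # centered and sharpened
p_centered_soft = softmax(logits - center, temp=1.0)  # centered, not sharpened

print(p_uncentered[:, 0].mean(), p_centered[:, 0].mean())
print(mean_entropy(p_centered), mean_entropy(p_centered_soft))
```

Without centering, every sample puts nearly all mass on dimension 0; after centering, that dimension no longer dominates, and the low teacher temperature keeps each distribution peaked instead of drifting toward uniform.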

Pretraining Procedure

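import math

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms

class BrainDINOTrainer:
    def __init__(self, student, teacher, augment, out_dim, momentum=0.996,
                 center_momentum=0.9, warmup_epochs=10, total_epochs=300):
        self.student = student
        self.teacher = teacher
        # augment: callable mapping a batch to (global_views, local_views),
        # each a list of batched tensors (e.g. built from get_brain_mri_augmentation)
        self.augment = augment
        self.momentum = momentum              # base EMA momentum, annealed to 1.0
        self.center_momentum = center_momentum
        self.warmup_epochs = warmup_epochs    # for an LR / teacher-temp warmup (not applied here)
        self.total_epochs = total_epochs
        # Running mean of teacher outputs, subtracted before the teacher softmax
        self.center = torch.zeros(1, out_dim)

        # Initialize teacher with student weights
        self._init_teacher()

    @torch.no_grad()
    def _init_teacher(self):
        """Copy student weights into the teacher and freeze it."""
        self.teacher.load_state_dict(self.student.state_dict())
        for param_t in self.teacher.parameters():
            param_t.requires_grad = False

    @torch.no_grad()
    def update_teacher(self, epoch):
        """EMA update of teacher from student (in place, no autograd)."""
        m = self._get_momentum(epoch)
        for param_s, param_t in zip(
            self.student.parameters(),
            self.teacher.parameters()
        ):
            param_t.mul_(m).add_(param_s.detach(), alpha=1 - m)

    def _get_momentum(self, epoch):
        """Cosine schedule: momentum rises from self.momentum to 1.0."""
        # math.cos, not torch.cos: the argument is a Python float, not a tensor
        return 1 - (1 - self.momentum) * (
            (1 + math.cos(math.pi * epoch / self.total_epochs)) / 2
        )

    def compute_loss(self, student_out, teacher_out, temp_student=0.1,
                     temp_teacher=0.04):
        """DINO loss: cross-entropy between centered, sharpened teacher and student."""
        center = self.center.to(teacher_out.device)
        teacher_probs = F.softmax((teacher_out - center) / temp_teacher, dim=-1)
        student_log_probs = F.log_softmax(student_out / temp_student, dim=-1)
        return -(teacher_probs * student_log_probs).sum(dim=-1).mean()

    @torch.no_grad()
    def update_center(self, teacher_out):
        """EMA of the batch-mean teacher output over all global views."""
        batch_center = torch.cat(teacher_out).mean(dim=0, keepdim=True)
        self.center = (self.center.to(batch_center.device) * self.center_momentum
                       + batch_center * (1 - self.center_momentum))

    def train_step(self, batch, optimizer, epoch):
        """Single training step."""
        # Generate multi-crop views
        global_views, local_views = self.augment(batch)

        # Student forward (all views)
        student_out = [self.student(v) for v in global_views + local_views]

        # Teacher forward (global views only)
        with torch.no_grad():
            teacher_out = [self.teacher(v) for v in global_views]

        # Match every student view (global and local) to every teacher view,
        # skipping pairs where both networks see the same view; average over pairs
        loss, n_terms = 0.0, 0
        for t_idx, t_out in enumerate(teacher_out):
            for s_idx, s_out in enumerate(student_out):
                if s_idx == t_idx:
                    continue
                loss = loss + self.compute_loss(s_out, t_out)
                n_terms += 1
        loss = loss / n_terms

        # Backward (student only)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Update center, then teacher
        self.update_center(teacher_out)
        self.update_teacher(epoch)

        return loss.item()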
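The momentum schedule in `_get_momentum` is easy to check by hand. This stdlib-only sketch mirrors that formula and applies the EMA rule to plain floats standing in for tensors; it assumes epoch-level updates as in the trainer above (the original DINO recipe anneals the momentum per iteration instead), and `ema_momentum` / `ema_update` are hypothetical helper names.

```python
import math

def ema_momentum(epoch, total_epochs=300, base=0.996):
    """Cosine schedule mirroring BrainDINOTrainer._get_momentum: base -> 1.0."""
    return 1 - (1 - base) * (1 + math.cos(math.pi * epoch / total_epochs)) / 2

def ema_update(teacher, student, m):
    """Element-wise EMA on plain-float 'weights' (stand-ins for tensors)."""
    return [m * t + (1 - m) * s for t, s in zip(teacher, student)]

for epoch in (0, 150, 300):
    print(epoch, round(ema_momentum(epoch), 6))
# 0 0.996
# 150 0.998
# 300 1.0
```

At the end of training the momentum reaches 1.0, so the teacher stops moving; early on it still absorbs 0.4% of the student's weights per update.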

Augmentation Pipeline for Brain MRI

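def get_brain_mri_augmentation():
    """Multi-crop augmentation for brain MRI slices.

    Expects a single-channel (mode "L") PIL image; returns the transforms for
    global (224 px) and local (96 px) crops. Apply global_transform twice and
    local_transform several times per slice.
    """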
    global_transform = transforms.Compose([
        transforms.RandomResizedCrop(224, scale=(0.4, 1.0)),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(15),
        transforms.GaussianBlur(kernel_size=3, sigma=(0.1, 2.0)),
        transforms.ColorJitter(brightness=0.2, contrast=0.2),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5], std=[0.5])
    ])
    
    local_transform = transforms.Compose([
        transforms.RandomResizedCrop(96, scale=(0.05, 0.4)),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(30),
        transforms.GaussianBlur(kernel_size=3, sigma=(0.1, 2.0)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5], std=[0.5])
    ])
    
    return global_transform, local_transform
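Once the crops exist, the loss pairs them as in `train_step`: the teacher sees only global crops, the student sees all crops, and identical views are never paired. This stdlib sketch enumerates those (teacher view, student view) pairs for the 2 global + 6 local setting recommended under Best Practices; the indexing convention (globals first, then locals) and the helper name are assumptions for illustration.

```python
from itertools import product

def multicrop_pairs(n_global=2, n_local=6):
    """(teacher_view, student_view) index pairs for the multi-crop DINO loss.

    Views are indexed globals first (0..n_global-1), then locals. The teacher
    sees only global views, the student sees all, and identical views are skipped.
    """
    n_views = n_global + n_local
    return [(t, s) for t, s in product(range(n_global), range(n_views)) if t != s]

pairs = multicrop_pairs()
print(len(pairs))  # 14
```

With 2 global and 6 local crops there are 2 × 8 − 2 = 14 loss terms per image, which is why the trainer averages over pairs rather than summing.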

Downstream Tasks

Linear Probing

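class LinearProbe(nn.Module):
    """Linear classifier on top of frozen BrainDINO features.

    Assumes backbone(x) returns pooled features of shape (batch, feature_dim).
    """
    def __init__(self, backbone, num_classes, feature_dim=768):
        super().__init__()
        self.backbone = backbone
        for param in self.backbone.parameters():
            param.requires_grad = False
        self.backbone.eval()
        self.head = nn.Linear(feature_dim, num_classes)

    def train(self, mode=True):
        # Keep the frozen backbone in eval mode (no dropout / norm-stat updates)
        super().train(mode)
        self.backbone.eval()
        return self

    def forward(self, x):
        with torch.no_grad():
            features = self.backbone(x)
        return self.head(features)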
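A cheap alternative for a quick feature-quality check is to fit the linear head in closed form instead of training `nn.Linear` with SGD. The NumPy sketch below swaps in ridge regression on one-hot targets (a different technique from the probe above); the "frozen embeddings" are synthetic Gaussian clusters and the regularization strength is an arbitrary assumption, so it only illustrates the mechanics.

```python
import numpy as np

def fit_ridge_probe(features, labels, num_classes, l2=1e-3):
    """Closed-form ridge-regression probe on frozen features (one-hot targets)."""
    n, d = features.shape
    X = np.hstack([features, np.ones((n, 1))])  # append a bias column
    Y = np.eye(num_classes)[labels]             # one-hot targets, shape (n, C)
    A = X.T @ X + l2 * np.eye(d + 1)
    return np.linalg.solve(A, X.T @ Y)          # weights, shape (d + 1, C)

def predict_ridge_probe(W, features):
    """Predicted class = argmax of the linear scores."""
    X = np.hstack([features, np.ones((features.shape[0], 1))])
    return (X @ W).argmax(axis=1)

rng = np.random.default_rng(0)
# Hypothetical frozen embeddings: 3 well-separated classes in 32-D
num_classes, dim, per_class = 3, 32, 30
means = rng.normal(scale=5.0, size=(num_classes, dim))
labels = np.repeat(np.arange(num_classes), per_class)
feats = means[labels] + rng.normal(size=(num_classes * per_class, dim))

W = fit_ridge_probe(feats, labels, num_classes)
acc = (predict_ridge_probe(W, feats) == labels).mean()
print(f"train accuracy: {acc:.3f}")
```

In practice the features would come from the frozen backbone (e.g. 768-D ViT-Base outputs) and accuracy would be measured on a held-out split.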

Few-Shot Fine-Tuning

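def few_shot_finetune(backbone, train_loader, val_loader,
                      num_classes, epochs=50, lr=1e-4, feature_dim=768):
    """Fine-tune backbone + linear head end-to-end with limited labeled data.

    Assumes backbone(x) returns pooled features of shape (batch, feature_dim).
    """
    model = nn.Sequential(
        backbone,
        nn.Linear(feature_dim, num_classes)
    )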
    
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()
    
    for epoch in range(epochs):
        model.train()
        for batch_x, batch_y in train_loader:
            optimizer.zero_grad()
            logits = model(batch_x)
            loss = criterion(logits, batch_y)
            loss.backward()
            optimizer.step()
        
        # Validate
        model.eval()
        correct = 0
        total = 0
        with torch.no_grad():
            for batch_x, batch_y in val_loader:
                logits = model(batch_x)
                _, predicted = logits.max(1)
                correct += predicted.eq(batch_y).sum().item()
                total += batch_y.size(0)
        
        acc = correct / total
        print(f"Epoch {epoch}: val_acc={acc:.4f}")
    
    return model
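Few-shot experiments need a labeled subset with a fixed number of examples per class. A stdlib sketch of such a sampler is shown here; `sample_k_shot` is a hypothetical helper, and the per-image labels are made up. The returned indices could then be wrapped with `torch.utils.data.Subset(train_dataset, indices)` to build `train_loader`.

```python
import random
from collections import defaultdict

def sample_k_shot(labels, k, seed=0):
    """Return sorted dataset indices with exactly k examples per class."""
    by_class = defaultdict(list)
    for idx, y in enumerate(labels):
        by_class[y].append(idx)
    rng = random.Random(seed)  # fixed seed -> reproducible few-shot splits
    chosen = []
    for y, idxs in sorted(by_class.items()):
        if len(idxs) < k:
            raise ValueError(f"class {y} has only {len(idxs)} examples, need {k}")
        chosen.extend(rng.sample(idxs, k))
    return sorted(chosen)

labels = [0] * 10 + [1] * 7 + [2] * 5  # hypothetical per-image labels
indices = sample_k_shot(labels, k=5)
print(len(indices))  # 15
```

Repeating the experiment over several seeds gives a sense of how much few-shot results vary with the particular examples drawn.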

Clinical Applications

| Task | Input | Output | Notes |
|---|---|---|---|
| Tumor segmentation | MRI slice | Pixel mask | Use feature maps for a U-Net decoder |
| Disease classification | MRI volume | Diagnosis | Pool slice features |
| Age prediction | MRI slice | Age (years) | Regression head |
| Atrophy quantification | MRI slice | Volume loss | Feature-based regression |
| Anomaly detection | MRI slice | Anomaly score | Reconstruction error or feature-space distance |
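Two rows of the table work directly on embeddings: volume-level classification pools per-slice features, and anomaly detection can score a scan by its distance to healthy references. The NumPy sketch below assumes mean (or max) pooling and a k-nearest-neighbour cosine distance; both are illustrative choices, not the paper's stated method, and all embeddings are synthetic.

```python
import numpy as np

def pool_volume(slice_features, mode="mean"):
    """Aggregate per-slice embeddings (S, D) into one volume embedding (D,)."""
    if mode == "mean":
        return slice_features.mean(axis=0)
    if mode == "max":
        return slice_features.max(axis=0)
    raise ValueError(f"unknown pooling mode: {mode}")

def knn_anomaly_score(query, reference, k=5):
    """Mean cosine distance from query (D,) to its k nearest reference rows (N, D)."""
    q = query / np.linalg.norm(query)
    ref = reference / np.linalg.norm(reference, axis=1, keepdims=True)
    dist = 1.0 - ref @ q
    return float(np.sort(dist)[:k].mean())

rng = np.random.default_rng(0)
normal_dir = rng.normal(size=64)
# Hypothetical embeddings: healthy reference volumes cluster around normal_dir
reference = normal_dir + 0.1 * rng.normal(size=(100, 64))
volume = normal_dir + 0.1 * rng.normal(size=(30, 64))  # 30 slices of one scan

healthy_score = knn_anomaly_score(pool_volume(volume), reference)
odd_score = knn_anomaly_score(pool_volume(-volume), reference)
print(healthy_score, odd_score)
```

A pooled volume embedding can also feed a linear head for disease classification; mean pooling is order-invariant, so slice ordering does not matter.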

Advantages Over Task-Specific Models

  1. Data efficiency: Pretrained on unlabeled data, fine-tune with few labeled examples
  2. Generalization: Single model works across diverse clinical endpoints
  3. Transfer learning: Features transfer to unseen tasks
  4. Scalability: Training scales with unlabeled data availability
  5. Robustness: Pretraining on slices pooled from diverse datasets encourages features that are less sensitive to scanner and protocol differences

Best Practices

  1. Multi-crop augmentation: Use 2 global + 6 local crops per image
  2. Momentum scheduling: Cosine schedule from 0.996 to 1.0
  3. Temperature tuning: Student temp ~0.1, teacher temp ~0.04
  4. Centering: Apply centering to teacher output to prevent collapse
  5. Batch size: Large batches (256+) improve representation quality
  6. Patch size: 16×16 patches work well for 224×224 brain MRI slices
  7. Positional encoding: Use learnable 2D positional embeddings
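The settings above can be collected in one place with a few sanity checks. This dataclass is a hypothetical config object, not part of any published BrainDINO code; the validation rules (image size divisible by patch size, teacher temperature below the student's) follow directly from the list.

```python
from dataclasses import dataclass

@dataclass(frozen=True)
class BrainDINOConfig:
    """Hypothetical config collecting the best-practice values listed above."""
    n_global_crops: int = 2
    n_local_crops: int = 6
    momentum_start: float = 0.996  # cosine-annealed to 1.0
    temp_student: float = 0.1
    temp_teacher: float = 0.04
    batch_size: int = 256
    img_size: int = 224
    patch_size: int = 16

    def __post_init__(self):
        if self.img_size % self.patch_size != 0:
            raise ValueError("img_size must be divisible by patch_size")
        if not self.temp_teacher < self.temp_student:
            raise ValueError("teacher temperature must be lower (sharper) than the student's")

    @property
    def num_patches(self) -> int:
        """Patch tokens per global crop, excluding any [CLS] token."""
        return (self.img_size // self.patch_size) ** 2

cfg = BrainDINOConfig()
print(cfg.num_patches)  # 196
```

With a [CLS] token, a learnable positional embedding table for 224×224 inputs and 16×16 patches would need num_patches + 1 = 197 rows.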

Reference

arXiv: 2604.27277 (2026-04-30) Authors: Wu, Wang, Li, et al. URL: https://arxiv.org/abs/2604.27277

Install via CLI
npx skills add https://github.com/hiyenwong/ai_collection --skill braindinobrain-mri-foundation