name: kaggle-tpu-accelerate-guide description: "Kaggle TPU v5e-8 + Accelerate 分布式训练实战指南。涵盖环境配置、核心 API、常见陷阱及解决方案。Triggers: TPU训练, Kaggle TPU, Accelerate, 分布式训练, TPU补丁, TPU环境配置。"
Kaggle TPU v5e-8 + Accelerate 分布式训练实战指南
本文总结了在 Kaggle TPU v5e-8 上使用 PyTorch + Accelerate 进行分布式深度学习训练的完整经验。涵盖环境配置、核心 API、常见陷阱及解决方案。适用于任意模型架构(Transformer、CNN、MLP 等)。
1. 环境准备
1.1 安装依赖
Kaggle TPU VM 预装了 PyTorch 和 torch_xla,但需要额外安装 accelerate:
!pip install accelerate tiktoken -q
关键:必须卸载 tensorflow 并换为 tensorflow-cpu。Kaggle TPU 环境默认安装的 tensorflow 会占用 TPU 资源,与 PyTorch XLA 冲突:
!pip uninstall -y -q tensorflow
!pip install -q tensorflow-cpu
1.2 TPU 环境补丁(必须最先执行)
这段代码必须在任何 import torch 之前执行,否则会报错或行为异常:
import os
# 告诉 XLA 使用 TPU 作为设备
os.environ["PJRT_DEVICE"] = "TPU"
# 限制 CPU 线程数,避免多进程竞争
os.environ["OMP_NUM_THREADS"] = "1"
# 删除这两个环境变量,它们会干扰 accelerate 的多进程启动
if "TPU_PROCESS_ADDRESSES" in os.environ:
del os.environ["TPU_PROCESS_ADDRESSES"]
if "CLOUD_TPU_TASK_ID" in os.environ:
del os.environ["CLOUD_TPU_TASK_ID"]
# TPU 内存优化:增大 tensor 分配器上限
os.environ['XLA_TENSOR_ALLOCATOR_MAXSIZE'] = '100000000'
为什么需要这些?
| 环境变量 | 作用 | 不设会怎样 |
|---|---|---|
PJRT_DEVICE=TPU |
告诉 PyTorch/XLA 使用 PJRT 运行时连接 TPU | 报错找不到设备 |
OMP_NUM_THREADS=1 |
限制 OpenMP 线程数 | 8 个进程各抢 CPU 线程,性能下降甚至死锁 |
删除 TPU_PROCESS_ADDRESSES |
避免 XLA 尝试连接不存在的多节点 TPU | 启动报错 |
删除 CLOUD_TPU_TASK_ID |
清除遗留的多节点标识 | 进程标识混乱 |
XLA_TENSOR_ALLOCATOR_MAXSIZE |
增大 XLA 张量分配器缓存上限 | 大张量分配失败或频繁重新分配 |
1.3 Kaggle Notebook 启动方式
from accelerate import notebook_launcher
def train_worker():
# 所有训练逻辑写在这里
# 每个 TPU core 运行一个此函数的实例
...
# 启动 8 个进程,对应 TPU v5e-8 的 8 个 core
notebook_launcher(train_worker, num_processes=8)
notebook_launcher 会 fork 出 8 个进程,每个进程绑定一个 TPU core。所有进程执行同一个 train_worker 函数,通过 accelerator.process_index 区分身份。
2. 核心概念速览
TPU v5e-8 的硬件模型
┌─────────────────────────────────────────────┐
│ Kaggle TPU v5e-8 │
│ │
│ Core 0 Core 1 Core 2 Core 3 │
│ Core 4 Core 5 Core 6 Core 7 │
│ │
│ 每个 core 有独立的 HBM(约 16GB) │
│ core 间通过 ICI 高速互联(all-reduce 等) │
└─────────────────────────────────────────────┘
- 8 个 core = 8 个进程,每个进程独立运行模型的完整副本
- 每个进程处理不同的数据 batch(数据并行)
- 梯度通过 all-reduce 在 8 个进程间同步
- Accelerate 帮你封装了这些分布式细节
Accelerate 的角色
Accelerate 是 HuggingFace 出品的分布式训练封装库,它的核心价值:
- 统一 API:同一套代码适配单 GPU、多 GPU、TPU、多节点
- 自动处理分布式细节:模型包装、梯度同步、混合精度、checkpoint 保存/加载
notebook_launcher:唯一靠谱的在 Kaggle Notebook 里启动 TPU 多进程的方式
3. Accelerate 核心 API 详解
3.1 创建 Accelerator
from accelerate import Accelerator
accelerator = Accelerator(
mixed_precision="bf16", # TPU 支持 bf16,不支持 fp16
gradient_accumulation_steps=4, # 梯度累积步数,必须在构造时传入
)
重要:gradient_accumulation_steps 必须在构造时指定,不能在训练过程中修改。accumulate() 上下文管理器依赖此值决定何时做梯度同步。
3.2 准备模型和优化器
model, optimizer = accelerator.prepare(model, optimizer)
prepare() 会:
- 将模型包装为分布式模型(DDP 或 XLA 分布式)
- 确保优化器状态在多进程间正确同步
- 在 TPU 上,模型参数会被移到 XLA 设备
顺序很重要:先创建模型和优化器,再 prepare()。
3.3 前向 / 反向 / 优化器步进
with accelerator.autocast(): # 自动混合精度上下文
logits, loss = model(x, y)
accelerator.backward(loss) # 替代 loss.backward()
accelerator.clip_grad_norm_(model.parameters(), 1.0) # 梯度裁剪
optimizer.step()
3.4 数据移到设备
x = x.to(accelerator.device) # 自动移到当前进程的 TPU core
y = y.to(accelerator.device)
3.5 多进程同步
accelerator.wait_for_everyone() # 阻塞直到所有进程都到达此处
用于在保存 checkpoint、采样等操作前确保所有进程同步。
3.6 收集各进程的数据
# 收集所有进程的 loss 值,求全局平均
local_loss = loss.detach()
global_loss = accelerator.gather(local_loss).mean().item()
3.7 获取原始模型
raw_model = accelerator.unwrap_model(model)
# 用于访问原始模型的属性(如 config),或保存 state_dict
3.8 保存 Checkpoint
accelerator.save(ckpt, save_path) # 安全的多进程保存
不要用 torch.save(),在 TPU 上可能引发多进程互锁。accelerator.save() 内部会处理 XLA 张量到 CPU 的转换。
3.9 设置随机种子
from accelerate.utils import set_seed
# 每个进程用不同的种子,保证数据 shuffle 不同但可复现
process_seed = 1337 + accelerator.process_index
set_seed(process_seed)
4. TPU 上的 bf16 混合精度
4.1 为什么用 bf16 而不是 fp16?
TPU 原生支持 bf16(Brain Float 16),不支持 fp16 的快速运算。bf16 的特点是:
- 与 fp32 相同的指数位(8 位),动态范围大,不容易溢出
- 尾数位较少(7 位 vs fp32 的 23 位),精度略低
4.2 必须强制 fp32 的关键路径
bf16 精度不够的地方会产生 NaN。以下路径必须手动转 fp32:
LayerNorm
# ❌ 错误:bf16 下方差计算会 catastrophic cancellation → NaN
x = self.ln_1(x)
# ✅ 正确:先转 fp32,归一化后再转回
x = self.ln_1(x.float()).to(x.dtype)
原因:LayerNorm 内部计算方差时需要 E[x²] - E[x]²,两个相近的大数相减,bf16 精度不够会导致结果为 0 或负数,开方后 NaN。
Cross-Entropy Loss
# ❌ 错误:bf16 下 log_softmax 会溢出为 -inf → NaN
loss = F.cross_entropy(logits.view(-1, vocab_size), targets.view(-1))
# ✅ 正确:logits 转 fp32 再算
loss = F.cross_entropy(logits.float().view(-1, vocab_size), targets.view(-1))
4.3 QK-Norm:防止 Attention Logit 爆炸
除了手动 fp32,还可以加 QK-Norm 从根本上防止 attention logit 数值过大:
self.q_norm = nn.LayerNorm(head_dim)
self.k_norm = nn.LayerNorm(head_dim)
# 在计算 attention score 之前归一化 Q 和 K
q = self.q_norm(q.float()).to(q.dtype)
k = self.k_norm(k.float()).to(k.dtype)
5. XLA 计算图与同步机制
5.1 XLA 的惰性执行模型
这是 TPU 编程与 GPU 最大的区别:
- GPU:CUDA 操作基本是立即执行的
- TPU (XLA):操作被记录到一个计算图中,只有在
sync()时才真正执行
import torch_xla
import torch_xla.core.xla_model as xm
a = torch.tensor([1.0], device="xla:0")
b = a + 1 # 此时 b 还没真正计算!只是一个图节点
c = b * 2 # 同上
torch_xla.sync() # 现在 a→b→c 的计算图被编译并执行
print(c) # 现在才能拿到真实值
5.2 什么时候必须 sync?
训练循环末尾
# 每个 optimizer step 结束后同步
optimizer.step()
optimizer.zero_grad()
if use_xla_sync:
torch_xla.sync()
验证循环中每步之后
for _ in range(val_steps):
_, loss = model(x, y)
# ❌ 危险:不 sync 的话,20 步的计算图会累积
# logits (B, T, vocab_size) ≈ 3.1GB/步 × 20 步 → OOM
losses.append(loss.detach())
# ✅ 正确:每步立即同步,释放中间张量
if use_xla_sync:
torch_xla.sync()
losses.append(loss.detach())
从 CPU 加载权重到 TPU 后
model, optimizer = accelerator.prepare(model, optimizer)
# 强制同步:确保从 CPU 加载的权重真正物化到 TPU 内存
if use_xla_sync:
xm.mark_step()
5.3 sync() vs mark_step()
torch_xla.sync():触发计算图执行,更常用xm.mark_step():标记一个计算图的边界,通常用于训练循环中分割计算图
在实践中两者都可以用于强制同步,sync() 语义更明确。
6. 梯度累积
6.1 为什么要梯度累积?
TPU 每个 core 的 HBM 有限(约 16GB),大 batch 可能放不下。梯度累积允许用小 batch 多次前向-反向,累积梯度后再做一次 optimizer step,等效于大 batch。
等效 batch size = micro_batch_size × seq_len × grad_accum_steps × num_processes
524288 = 32 × 1024 × 2 × 8
6.2 Accelerate 的 accumulate 上下文管理器
grad_accum_steps = TOTAL_BATCH_SIZE // (MICRO_BATCH_SIZE * SEQ_LEN * world_size)
# 构造 Accelerator 时传入
accelerator = Accelerator(
gradient_accumulation_steps=grad_accum_steps,
)
# 训练循环中使用
with accelerator.accumulate(model):
for micro_step in range(grad_accum_steps):
x, y = train_loader.next_batch()
x = x.to(device)
y = y.to(device)
with accelerator.autocast():
_, loss = model(x, y)
accelerator.backward(loss)
# accumulate 上下文会自动处理:
# - 前 N-1 步:跳过梯度同步(all-reduce)
# - 最后一步:自动做梯度同步
# 比手动调用少了 N-1 次 all-reduce 通信
optimizer.step()
optimizer.zero_grad()
6.3 手动梯度累积 vs accumulate 上下文
# ❌ 手动方式:每次 backward 都会触发 all-reduce,通信开销大
for micro_step in range(grad_accum_steps):
loss = model(x, y)
(loss / grad_accum_steps).backward() # 每步都 all-reduce
# ✅ accumulate 上下文:只在最后一步 all-reduce
with accelerator.accumulate(model):
for micro_step in range(grad_accum_steps):
loss = model(x, y)
accelerator.backward(loss) # 前 N-1 步跳过同步
7. 梯度检查点(Gradient Checkpointing)
7.1 原理
正常前向传播会保存所有中间激活值用于反向传播。梯度检查点只保存每层的输入,反向传播时重新计算中间值。用计算时间换内存。
7.2 XLA 版本的 checkpoint
TPU 上必须用 torch_xla.utils.checkpoint,不能直接用 PyTorch 的版本:
try:
from torch_xla.utils.checkpoint import checkpoint as xla_checkpoint
HAS_XLA_CHECKPOINT = True
except ImportError:
HAS_XLA_CHECKPOINT = False
# 在模型 forward 中使用(以 Transformer block 列表为例)
for block in self.layers:
if self.gradient_checkpointing and self.training:
if x.device.type == "xla" and HAS_XLA_CHECKPOINT:
x = xla_checkpoint(block, x) # XLA 版本
else:
x = torch.utils.checkpoint.checkpoint( # PyTorch 版本(GPU 回退)
block, x, use_reentrant=False
)
else:
x = block(x)
为什么不能用 PyTorch 原版? PyTorch 的 checkpoint 实现依赖 CUDA 特定的 autograd hook,在 XLA 设备上会报错或行为异常。XLA 版本做了适配。
7.3 开启方式
model.gradient_checkpointing = True
建议在 accelerator.prepare() 之前设置。
8. 模型设计注意事项
8.1 Weight Tying 被 prepare() 破坏(Transformer 模型通用问题)
这是 TPU + Accelerate 最隐蔽的坑之一。
很多 Transformer 模型(GPT、LLaMA、BERT 等)会把输入 embedding 和输出投影层的权重绑定(weight tying),即 lm_head.weight = wte.weight。但在 TPU 上:
# ❌ 问题:accelerator.prepare() 会破坏 weight tying
model = MyTransformer(config) # lm_head.weight 和 wte.weight 指向同一块内存
model, optimizer = accelerator.prepare(model, optimizer)
# 此时 lm_head.weight 和 wte.weight 可能已经是不同的张量了!
# prepare() 内部会为分布式训练重新组织参数,weight tying 的引用关系丢失
解决方案:不用 lm_head,直接用 F.linear
class MyTransformer(nn.Module):
def __init__(self, config):
super().__init__()
self.embedding = nn.Embedding(config.vocab_size, config.n_embd)
self.layers = nn.ModuleList([MyBlock(config) for _ in range(config.n_layer)])
self.ln_f = nn.LayerNorm(config.n_embd)
# 故意不定义 lm_head!
def forward(self, idx, targets=None):
x = self.embedding(idx)
# ... transformer blocks ...
x = self.ln_f(x.float()).to(x.dtype)
# 直接用 F.linear,手动传入 embedding.weight 作为转置权重
# 这样无论 prepare() 怎么处理参数,都能保证是同一份权重
logits = F.linear(x, self.embedding.weight)
return logits, loss
8.2 SDPA(Scaled Dot-Product Attention)在 TPU 上不可用
PyTorch 的 F.scaled_dot_product_attention 是为 CUDA 优化的,在 XLA 设备上要么不支持,要么走很慢的 fallback 路径。
# 需要检测设备类型,选择不同的 attention 实现
if self.use_sdpa and x.device.type != "xla":
# GPU 上走高效的 fused attention
y = F.scaled_dot_product_attention(q, k, v, is_causal=True)
else:
# TPU 上走手写 attention
att = (q * self.scale) @ k.transpose(-2, -1)
att = att.masked_fill(~mask_slice, -1e4)
att = F.softmax(att, dim=-1)
y = att @ v
8.3 torch.compile 在 TPU 上不要用
# ❌ TPU 上 torch.compile 几乎没有加速效果,还可能引入 bug
model = torch.compile(model)
# ✅ 直接用原生模型
model = model
XLA 本身就是一个编译器(将 PyTorch 操作编译为 TPU 可执行的 HLO IR),再套一层 torch.compile 没有意义。
8.4 残差投影层零初始化(可选,通用技巧)
对于使用残差连接的深层网络(如 Transformer),可以将残差分支的最后一层初始化为零,使得训练初期残差流为恒等映射,有助于稳定训练:
# 用自定义属性标记需要零初始化的层
class MyBlock(nn.Module):
def __init__(self, config):
self.c_proj = nn.Linear(config.n_embd, config.n_embd)
self.c_proj.ZERO_INIT = 1 # 标记为需要零初始化
def _init_weights(self, module):
if isinstance(module, nn.Linear):
if hasattr(module, "ZERO_INIT"):
torch.nn.init.zeros_(module.weight) # 残差流初期为恒等映射
else:
torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
9. Checkpoint 保存与加载
9.1 保存
def save_checkpoint(accelerator, model, optimizer, train_loader, val_loader,
step, best_val_loss, model_config, save_path):
# 1. 等待所有进程同步
accelerator.wait_for_everyone()
# 2. 获取去除分布式包装后的 state_dict
state_dict = accelerator.get_state_dict(model)
# 3. 构建 checkpoint 字典
ckpt = {
"model": state_dict,
"optimizer": optimizer.state_dict(),
"step": step,
"best_val_loss": best_val_loss,
"model_config": asdict(model_config),
"train_loader_state": train_loader.state_dict(),
"val_loader_state": val_loader.state_dict(),
}
# 4. 安全保存(处理 XLA 张量到 CPU 的转换)
accelerator.save(ckpt, save_path)
# 5. 再次等待同步
accelerator.wait_for_everyone()
绝对不要手动 .cpu() 转换 XLA 张量,会引发多进程互锁死锁。用 accelerator.save() 统一处理。
9.2 加载
# 1. 在 CPU 上加载 checkpoint
checkpoint = torch.load(RESUME_PATH, map_location="cpu")
# 2. 创建模型,加载权重(在 prepare() 之前)
model = MyModel(model_config)
state_dict = checkpoint["model"]
# 3. 清理可能的前缀
for prefix in ['_orig_mod.', 'module.']:
for k in list(state_dict.keys()):
if k.startswith(prefix):
state_dict[k[len(prefix):]] = state_dict.pop(k)
model.load_state_dict(state_dict, strict=False)
# 4. 确保参数为 fp32(checkpoint 可能是 bf16 保存的)
for name, param in model.named_parameters():
if param.dtype != torch.float32:
param.data = param.data.float()
# 5. 创建优化器
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
# 6. prepare() 之后再加载 optimizer state
model, optimizer = accelerator.prepare(model, optimizer)
if use_xla_sync:
xm.mark_step() # 确保模型权重物化到 TPU
# 7. optimizer state 必须在 prepare() 之后加载
optimizer.load_state_dict(checkpoint["optimizer"])
if use_xla_sync:
xm.mark_step()
9.3 加载顺序总结
torch.load (CPU) → 创建模型 → load_state_dict → 创建优化器
→ accelerator.prepare(model, optimizer) → xm.mark_step()
→ optimizer.load_state_dict → xm.mark_step()
optimizer state 必须在 prepare() 之后加载,因为 prepare() 会改变优化器的内部结构。
10. 多进程同步陷阱大全
陷阱 1:evaluate 中不同步导致 OOM
# ❌ 错误:多步的计算图累积在内存中,中间张量无法释放 → OOM
for _ in range(20):
_, loss = model(x, y)
losses.append(loss.detach())
# ✅ 正确:每步同步,立即释放中间张量
for _ in range(20):
_, loss = model(x, y)
torch_xla.sync() # 立即执行计算图,释放中间张量
losses.append(loss.detach())
陷阱 2:部分进程执行导致死锁
# ❌ 致命:只有 rank 0 做推理,其他 7 个 core 空闲等待
# XLA 的 all-reduce 需要所有 core 参与,否则永久死锁!
if accelerator.process_index == 0:
run_evaluation(model)
# ✅ 正确:所有 8 个 core 都参与前向计算
run_evaluation(model) # 所有进程都执行,只有 rank 0 打印/保存结果
# 在函数内部:
def run_evaluation(model, accelerator, ...):
# 所有进程都做推理(XLA 要求计算图同步)
outputs = model(eval_data)
...
# 只有 rank 0 打印/保存结果
if accelerator.process_index == 0:
print(results)
陷阱 3:保存 checkpoint 时不同步
# ❌ 错误:只有 rank 0 保存,其他进程不等待
if accelerator.process_index == 0:
torch.save(ckpt, path) # 其他进程可能还在计算 → 竞态
# ✅ 正确:保存前后都等待
accelerator.wait_for_everyone()
accelerator.save(ckpt, path)
accelerator.wait_for_everyone()
陷阱 4:手动 .cpu() 转换 XLA 张量
# ❌ 危险:手动 .cpu() 在多进程环境下会互锁死锁
state_dict = {k: v.cpu() for k, v in model.state_dict().items()}
# ✅ 正确:用 accelerator.save() 自动处理
accelerator.save(ckpt, path)
陷阱 5:optimizer state 加载时机错误
# ❌ 错误:在 prepare() 之前加载 optimizer state
optimizer.load_state_dict(checkpoint["optimizer"]) # 此时 optimizer 还没被包装
model, optimizer = accelerator.prepare(model, optimizer) # prepare 会重置 optimizer 结构!
# ✅ 正确:prepare() 之后再加载
model, optimizer = accelerator.prepare(model, optimizer)
optimizer.load_state_dict(checkpoint["optimizer"])
11. 性能优化技巧
11.1 选择合适的 Micro Batch Size
TPU v5e 每个 core 约 16GB HBM。需要在内存允许的范围内选择最大的 micro batch size:
- 1 亿参数级别的模型:micro_batch_size = 32(seq_len=1024)比较安全
- 更大的模型需要调小,或开启梯度检查点
- 实际取值需要在 TPU 上试跑确认不 OOM
11.2 bf16 vs fp32
# bf16 训练吞吐量约为 fp32 的 1.5-2x
accelerator = Accelerator(mixed_precision="bf16")
11.3 减少 all-reduce 次数
使用 accelerator.accumulate(model) 上下文管理器,只在累积的最后一步做梯度同步。
11.4 DataLoader 的 num_workers
在 TPU 上,数据加载通常不是瓶颈(TPU 计算比 GPU 慢,数据准备跟得上)。如果数据加载成为瓶颈,可以用多 worker 的 DataLoader,但要注意:
# Kaggle TPU 环境下,简单的方式往往最稳定
# 直接在主进程中加载 numpy shard 文件通常够用
arr = np.load(filename).astype(np.int32)
tokens = torch.tensor(arr, dtype=torch.long)
11.5 监控吞吐量
t0 = time.time()
# ... 训练一步 ...
t1 = time.time()
tokens_processed = MICRO_BATCH_SIZE * SEQ_LEN * grad_accum_steps * world_size
tok_per_sec = tokens_processed / (t1 - t0)
print(f"tok/s: {tok_per_sec:.0f}")
在 TPU v5e-8 上,1 亿参数级别的语言模型可以达到约 40 万 tokens/s 的训练吞吐量。
12. 完整最小示例
一个可以在 Kaggle TPU 上运行的最小训练循环:
# Cell 1: 环境准备
import os
os.environ["PJRT_DEVICE"] = "TPU"
os.environ["OMP_NUM_THREADS"] = "1"
if "TPU_PROCESS_ADDRESSES" in os.environ:
del os.environ["TPU_PROCESS_ADDRESSES"]
if "CLOUD_TPU_TASK_ID" in os.environ:
del os.environ["CLOUD_TPU_TASK_ID"]
os.environ['XLA_TENSOR_ALLOCATOR_MAXSIZE'] = '100000000'
# Cell 2: 写入训练脚本
%%writefile train_minimal.py
import os
os.environ['XLA_TENSOR_ALLOCATOR_MAXSIZE'] = '100000000'
import torch
import torch.nn as nn
from torch.nn import functional as F
from accelerate import Accelerator, notebook_launcher
from accelerate.utils import set_seed
def train_worker():
# 1. 创建 Accelerator
accelerator = Accelerator(
mixed_precision="bf16",
gradient_accumulation_steps=1,
)
use_xla_sync = accelerator.device.type == "xla"
# 2. 设置种子
set_seed(1337 + accelerator.process_index)
device = accelerator.device
# 3. 创建模型(简单示例)
model = nn.Sequential(
nn.Linear(100, 256),
nn.ReLU(),
nn.Linear(256, 10),
)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
# 4. prepare
model, optimizer = accelerator.prepare(model, optimizer)
if use_xla_sync:
import torch_xla.core.xla_model as xm
xm.mark_step()
# 5. 训练循环
for step in range(100):
# 合成数据
x = torch.randn(32, 100, device=device)
y = torch.randint(0, 10, (32,), device=device)
with accelerator.autocast():
logits = model(x)
loss = F.cross_entropy(logits, y)
accelerator.backward(loss)
optimizer.step()
optimizer.zero_grad()
if use_xla_sync:
torch_xla.sync()
if step % 10 == 0 and accelerator.process_index == 0:
global_loss = accelerator.gather(loss.detach()).mean().item()
print(f"step {step} | loss {global_loss:.4f}")
if accelerator.process_index == 0:
print("training done")
def launch():
notebook_launcher(train_worker, num_processes=8)
if __name__ == "__main__":
launch()
# Cell 3: 启动
from train_minimal import launch
launch()
13. 常见报错与解决方案
13.1 RuntimeError: PJRT runtime not found
原因:没有设置 PJRT_DEVICE=TPU。
解决:在所有 import torch 之前加上 os.environ["PJRT_DEVICE"] = "TPU"。
13.2 RuntimeError: Bad StatusOr access 或 TPU 初始化失败
原因:Kaggle 环境中的 TPU_PROCESS_ADDRESSES 或 CLOUD_TPU_TASK_ID 干扰了 XLA 的进程发现。
解决:在环境补丁中删除这两个变量。
13.3 NaN loss
原因:bf16 精度不够,LayerNorm / softmax / cross_entropy 中出现数值溢出。
解决:参考第 4.2 节,在关键路径强制 fp32。
13.4 OOM(Out of Memory)
原因:中间张量太大,或计算图累积未释放。
解决:
- 调小
MICRO_BATCH_SIZE - 开启梯度检查点
- 验证循环中每步
torch_xla.sync()释放中间张量
13.5 训练死锁(所有进程卡住不动)
原因:某个分支只在部分进程执行(如 if rank == 0: 做推理),XLA all-reduce 等待所有进程参与。
解决:确保所有 TPU core 执行相同的计算图路径。只在打印/保存时区分 rank。
13.6 accelerator.prepare() 后权重值变了
原因:weight tying 被破坏,或 optimizer state 被重置。
解决:
- 如果模型有 weight tying,改用
F.linear(x, embedding.weight)代替独立的输出层 - optimizer state 在
prepare()之后加载
13.7 resume 后 loss 突然跳变
原因:optimizer state 没有正确加载,或数据位置不对。
解决:
- 确认 optimizer state 在
prepare()之后加载 - 确认数据加载器的
state_dict被正确保存和恢复 - 加载后打印第一个参数的前几个值验证
13.8 scaled_dot_product_attention 报错或很慢
原因:SDPA 是 CUDA 优化的算子,在 XLA 上不支持或走 fallback。
解决:检测 x.device.type != "xla" 时才用 SDPA,否则用手写 attention。