TPU 训练的真实效率往往取决于两个核心要素:Shape 的稳定性与算子的融合度。
很多时候,JAX 任务之所以出现严重的性能瓶颈,并非算法本身设计有问题,而是忽视了 XLA 编译器与底层硬件对“确定性”的极度偏好。基于大量实战调优经验,本文总结了八条能让 JAX 训练任务从“甚至跑不通”蜕变为“跑满 TPU 算力”的工程经验。

TPU 喜欢静态 Shape,JAX 也是,所以动态 Shape 是性能杀手,它会触发重新编译(Recompile)。一旦发生重编译,Step time 和内存占用都会直接炸裂。所以解决方法也很简单,选定几个规范的尺寸,剩下的全填(Pad)满。
全局 Batch Size 要能被 TPU 核心数整除,然后就是对于变长序列,别指望它原本多长就多长,把它 Pad 到几个固定的“桶(Bucket)”里,比如 128、256 或 512,这步工作最好在输入(Input Pipeline)里就做完。
Python层面的条件判断尽量别依赖 Shape,真要分支逻辑,就老老实实让 lax.cond 或 lax.switch 来接管。
# Example: bucketing & padding (conceptual) def pad_to_length(arr, L): pad = L - arr.shape[0] return jnp.pad(arr, ((0, pad), (0, 0)), mode='constant') bucket_sizes = [128, 256, 512] def bucket_len(n): return next(b for b in bucket_sizes if n <= b) def preprocess_batch(batch): L = bucket_len(batch["tokens"].shape[1]) batch["tokens"] = pad_to_length(batch["tokens"], L) batch["mask"] = pad_to_length(batch["mask"], L) return batch
每个 Step 喂给 TPU 的 Shape 只要是固定的,XLA 编译器就不会找麻烦。
2、激活值默认用 bfloat16,主权重要 FP32在 TPU 上bfloat16 (bf16) 是个好东西,兼顾了速度、内存和数值稳定性。
工程上的常规操作是:激活(Activations)和梯度(Gradients)存成 bf16。但是,优化器状态里的权重必须保留一份 FP32 的“主副本”,不然跑久了数值就会漂移。所欲需要在模型边界做类型转换(Cast)的时候小心点。
class MLP(nn.Module): features: int @nn.compact def __call__(self, x): x = x.astype(jnp.bfloat16) # fast path on TPUs x = nn.Dense(self.features, dtype=jnp.bfloat16)(x) x = nn.gelu(x) x = nn.Dense(self.features, dtype=jnp.bfloat16)(x) return x # Optimizer state stays in FP32 (conceptual) params_fp32 = params.astype(jnp.float32) grads_bf16 = compute_grads_bf16(...) updates_fp32 = opt.update(grads_bf16.astype(jnp.float32), opt_state, params_fp32)
3、pjit和命名网格:切分要明确,别靠猜JAX 在 TPU 上最强的一点就是通过 pjit 实现了 GSPMD。你通过 PartitionSpecs 告诉它想要什么切分方式,XLA 负责搞定如何在设备间搬运数据。
在 TPU 核心上建个命名网格(Mesh)。做数据并行(Data Parallelism)时,用 PartitionSpec('data', None) 这种模式。如果模型太大需要张量并行(Tensor Model Parallelism),就加个 'model' 轴。
import numpy as np import jax import jax.numpy as jnp from jax.sharding import Mesh, PartitionSpec as P from jax.experimental import pjit devices = np.array(jax.devices()).reshape(1, -1) # 1 x N mesh mesh = Mesh(devices, ('data',)) def loss_fn(params, batch): logits = model_apply(params, batch['x']) return cross_entropy(logits, batch['y']) @pjit.pjit( in_shardings=(P(None), P('data')), # params replicated, batch sharded on 'data' out_shardings=P(None), # scalar loss replicated ) def step(params, batch): grads = jax.grad(loss_fn)(params, batch) # aggregate grads across cores grads = jax.tree.map(lambda g: jax.lax.pmean(g, axis_name='data'), grads) return grads with mesh: grads = step(params, sharded_batch)
切分(Sharding)这事必须显式。如果偷懒依赖自动推导,等到后期 debug 那些悄无声息的跨设备数据传输时,绝对会很痛苦。
4、jit, vmap, scan 三件套TPU 喜欢大块头的 Kernel,讨厌成千上万个细碎的小算子。训练 Step 和任何中大型计算逻辑,必须用 jit 包起来。遇到 Python 循环,如果是时间步逻辑就换成 lax.scan,如果是批次并行就用 vmap。
把 Loss 计算、梯度计算和参数更新塞进同一个 jitted 函数里,这样编译器才有机会把它们融合成一个大算子。
import optax import jax optimizer = optax.adamw(3e-4) def loss_and_grads(params, batch): def _loss(p): logits = model_apply(p, batch['x']) return cross_entropy(logits, batch['y']) loss, grads = jax.value_and_grad(_loss)(params) return loss, grads @jax.jit def train_step(state, batch): loss, grads = loss_and_grads(state.params, batch) grads = jax.lax.pmean(grads, axis_name='data') updates, new_opt_state = optimizer.update(grads, state.opt_state, state.params) new_params = optax.apply_updates(state.params, updates) return state.replace(params=new_params, opt_state=new_opt_state), loss
5、别让输入管道拖后腿Host 到 Device 的数据传输一旦停顿,吞吐量就掉下来了,所以永远别让计算单元等数据。
用 tf.data 或者高效的 NumPy loader 配合 prefetch。数据预取到设备(Stage to device) 最好做双重缓冲。全局 Batch 尽量大(当然要能被核心数整除),数据增强这种脏活累活在 Host 上一次性做完。
# tf.data pipeline (conceptual) ds = (tf.data.TFRecordDataset(files) .map(parse_example, num_parallel_calls=tf.data.AUTOTUNE) .batch(global_batch_size, drop_remainder=True) .prefetch(tf.data.AUTOTUNE)) # Convert to NumPy and prefetch onto devices from flax.jax_utils import prefetch_to_device it = prefetch_to_device(map(npify, ds.as_numpy_iterator()), size=2) with mesh: for step_i in range(num_steps): batch = next(it) # already sharded/prefetched state, loss = train_step(state, batch)
6、PRNG要Fold 进 Step 和 Device IDJAX 的 PRNG 是无状态的,这意味如果不小心,很容易在不同 Step 或者不同设备上用了一样的随机数 Key。
每个 Step 都要 Split 一次绝对别复用。所以说为了保证独立性必须把 Global Step 和 Device Index 都 Fold 进去。数据增强/Dropout 的 Key 和参数初始化的 Key 得分开管理。
def make_step_rng(rng, step): step_key = jax.random.fold_in(rng, step) dev_key = jax.random.fold_in(step_key, jax.lax.axis_index('data')) return jax.random.split(dev_key, 1)[0] @jax.jit def train_step(state, batch, base_rng): rng = make_step_rng(base_rng, state.step) logits = model_apply(state.params, batch['x'], rngs={'dropout': rng}) ...
7、Remat,智能 Checkpoint,梯度累积TPU 内存看着大,模型一跑起来就不够用。深层网络可以直接用 Activation Checkpointing(jax.checkpoint 或 nn.remat),用计算换显存。想跑大 Batch 但显存不够,就用梯度累积(Gradient Accumulation) 把它切成小的 micro-step。
存盘的时候,推荐用 Orbax 做异步、分片(Sharded)的 Checkpoint,稳。
from flax import linen as nn class DeepBlock(nn.Module): @nn.compact def __call__(self, x): # recompute on backward to trim activation memory f = nn.remat(lambda y: nn.gelu(nn.Dense(x.shape[-1])(y))) return f(x) # Gradient accumulation (conceptual) @jax.jit def accum_step(state, batch_slices): def body(carry, micro): state, grad_sum = carry _, grads = loss_and_grads(state.params, micro) return (state, jax.tree_util.tree_map(jnp.add, grad_sum, grads)), None init_grads = jax.tree_util.tree_map(jnp.zeros_like, state.params) (state, grad_sum), _ = jax.lax.scan(body, (state, init_grads), batch_slices) grads = jax.tree_map(lambda g: g / len(batch_slices), grad_sum) ...
8、一定要跑 Profiler把关键代码段用 Profiler Annotations 包起来,看 Step Timeline。重点找 Host Waits、Recompiles 和那些没融合好的细碎算子(Small op soup)。
稳态运行的时候,盯着 Tokens/sec 或者Images/sec,还有硬件利用率。
from jax.experimental import host_callback as hcb from jax import profiler def tagged(name, fn, *a, **k): profiler.annotate_function(name=name) return fn(*a, **k) @jax.jit def train_step(state, batch): profiler.annotate_function(name="train_step") # do work... return state, loss
一定要在锁定 Shape 并且 JIT 完热点路径之后再做 Profile,不然全是噪音,根本看不到真正的瓶颈。
极简 TPU 训练示例这基本包含了上面所有的内容
# Pseudo-skeleton (Flax + JAX + TPU) mesh = Mesh(np.array(jax.devices()).reshape(1, -1), ('data',)) @pjit.pjit(in_shardings=(P(None), P('data'), P(None)), out_shardings=(P(None), P(None))) def train_step(state, batch, base_rng): rng = jax.random.fold_in(base_rng, state.step) rng = jax.random.fold_in(rng, jax.lax.axis_index('data')) def loss_fn(p): logits = model_apply(p, batch['x'].astype(jnp.bfloat16), rngs={'dropout': rng}) return cross_entropy(logits, batch['y']) loss, grads = jax.value_and_grad(loss_fn)(state.params) grads = jax.tree_map(lambda g: jax.lax.pmean(g, 'data'), grads) updates, opt_state = optimizer.update(grads, state.opt_state, state.params) params = optax.apply_updates(state.params, updates) return state.replace(params=params, opt_state=opt_state, step=state.step+1), loss with mesh: for step_i, batch in enumerate(prefetched_iterator): state, loss = train_step(state, batch, base_rng) if step_i % log_every == 0: # Pull back just tiny scalars; keep big tensors on device host_loss = jax.device_get(loss) print(f"[{step_i}] loss={host_loss:.4f}")
总结TPU 需要的是 一致性:稳定的 Shape,融合的 Kernel,目的明确的切分,不掉链子的数据管道,把上面的这八件事做好,写 JAX 训练循环就非常顺畅了。
https://avoid.overfit.cn/post/16b582a493ba4eca8333314859665dd2
作者:Modexa