JAX 是 Google 和 NVIDIA 联合开发的高性能数值计算库,这两年 JAX 生态快速发展,周边工具链也日益完善了。如果你用过 NumPy 或 PyTorch,但还没接触过 JAX,这篇文章能帮助你快速上手。
围绕 JAX 已经涌现出一批好用的库:Flax 用来搭神经网络,Optax 处理梯度和优化,Equinox 提供类似 PyTorch 的接口,Haiku 则是简洁的函数式 API,Jraph 用于图神经网络,RLax 是强化学习工具库,Chex 提供测试和调试工具,Orbax 负责模型检查点和持久化。

JAX 对函数有个基本要求:必须是纯函数。这意味着函数不能有副作用,对同样的输入必须总是返回同样的输出。
这个约束来自函数式编程范式。JAX 内部做各种变换(编译、自动微分等)依赖纯函数的特性,用不纯的函数可能导致错误或静默失败,结果完全不对。
# 纯函数,没问题def pure_addition(a, b): return a + b # 不纯的函数,JAX 不接受counter = 0 def impure_addition(a, b): global counter counter += 1 return a + b
JAX NumPy 与原生 NumPyJAX 提供了类 NumPy 的接口,核心优势在于能自动高效地在 CPU、GPU 甚至 TPU 上运行,支持本地或分布式执行。这套能力来自 XLA(Accelerated Linear Algebra) 编译器,它把 JAX 代码翻译成针对不同硬件优化的机器码。
NumPy 默认只在 CPU 上跑,JAX NumPy 则不同。用法上两者很相似,这也是 JAX 容易上手的原因。
# JAX 也差不多import jax.numpy as jnp print(jnp.sqrt(4))# NumPy 的写法import numpy as np print(np.sqrt(4))# JAX 也差不多import jax.numpy as jnp print(jnp.sqrt(4))
常见的操作两者看起来基本一样:
import numpy as np import jax.numpy as jnp # 创建数组np_a = np.array([1.0, 2.0, 3.0]) jnp_a = jnp.array([1.0, 2.0, 3.0]) # 元素级操作print(np_a + 2) print(jnp_a + 2) # 广播np_b = np.array([[1, 2, 3]]) jnp_b = jnp.array([[1, 2, 3]]) print(np_b + np.arange(3)) print(jnp_b + jnp.arange(3)) # 求和print(np.sum(np_a)) print(jnp.sum(jnp_a)) # 平均值print(np.mean(np_a)) print(jnp.mean(jnp_a)) # 点积print(np.dot(np_a, np_a)) print(jnp.dot(jnp_a, jnp_a))
但有个重要差异需要注意:
JAX 数组是不可变的,对数组的修改操作会返回新数组而不是改变原数组。
NumPy 数组则可以直接修改:
import numpy as np x = np.array([1, 2, 3]) x[0] = 10 # 直接修改,没问题
JAX 这边就不行了:
import jax.numpy as jnp x = jnp.array([1, 2, 3]) x[0] = 10 # 报错
但是JAX 提供了专门的 API 来处理这种情况,通过返回一个新数组的方式实现"修改":
z = x.at[idx].set(y)
完整的例子:
x = jnp.array([1, 2, 3]) y = x.at[0].set(10) print(y) # [10, 2, 3] print(x) # [1, 2, 3](没变)
JIT 编译加速即时编译(JIT)是 JAX 里一个核心特性,通过 XLA 把 Python/JAX 代码编译成优化后的机器码。
直接用 Python 解释器跑函数会很慢。加上 @jit 装饰器后,函数会被编译成快速的原生代码:
from jax import jit # 不编译的版本def square(x): return x * x # 编译过的版本@jit def jit_square(x): return x * x
jit_square 快好几个数量级。函数首次调用时,JIT 引擎会:
追踪函数逻辑,构建计算图
把图编译成优化的 XLA 代码
缓存编译结果
后续调用直接用缓存的版本
自动微分JAX 的 grad 模块能自动计算函数的导数。
import jax.numpy as jnp from jax import grad # 定义函数:f(x) = x² + 2x + 2def f(x): return x**2 + 2 * x + 2 # 计算导数df_dx = grad(f) # 在 x = 2.0 处求值print(df_dx(2.0)) # 6.0
随机数处理NumPy 用全局随机状态生成随机数。每次调用 np.random.random() 时,NumPy 会更新隐藏的内部状态:
import numpy as np np.random.random() # 0.9539264374520571
JAX 的做法完全不同。作为纯函数库,它不能维护全局状态,所以要求显式传入一个伪随机数生成器(PRNG)密钥。每次生成随机数前要先分割密钥:
from jax import random # 初始化密钥key = random.PRNGKey(0) # 每次生成前分割key, subkey = random.split(key) # 从正态分布采样x = random.normal(subkey, ()) print(x) # -2.4424558 # 从均匀分布采样key, subkey = random.split(key) u = random.uniform(subkey, (), minval=0.0, maxval=1.0) print(u) # 0.104290366
一个常见的坑:同一个密钥生成的随机数始终相同。
# 用同一个 subkey,结果重复x = random.normal(subkey, ()) print(x) # -2.4424558 x = random.normal(subkey, ()) print(x) # -2.4424558(还是这个值)
所以要记住总是用新密钥。
向量化:vmapvmap 自动把函数转换成能处理批量数据的版本。逻辑上就像循环遍历每个样本,但执行效率远高于 Python 循环。
import jax.numpy as jnp from jax import vmap def f(x): return x * x + 1 arr = jnp.array([1., 2., 3., 4.]) # Python 循环(慢)outputs_loop = jnp.array([f(x) for x in arr]) # vmap 版本(快)f_vectorized = vmap(f) outputs_vmap = f_vectorized(arr)
并行化:pmappmap 不同于 vmap。vmap 在单个设备上做批处理,pmap 把计算分散到多个设备(GPU/TPU 核心),每个设备处理输入的一部分。
VMAP:单设备批处理向量化
PMAP:跨多设备并行执行
import jax.numpy as jnp from jax import pmap # 查看可用设备print(jax.devices()) # [TpuDevice(id=0), TpuDevice(id=1), ..., TpuDevice(id=7)] def f(x): return x * x + 1 arr = jnp.array([1., 2., 3., 4.]) # pmap 会把数组分配到不同设备ys = pmap(f)(arr)
PyTreesPyTree 在 JAX 里是个常见的概念:任何嵌套的 Python 容器(列表、字典、元组等)加上基本类型的组合。JAX 里用它来组织模型参数、优化器状态、梯度等。
import jax.numpy as jnp from jax import tree_util as tu# 构建 PyTreepytree = { "a": jnp.array([1, 2]), "b": [jnp.array([3, 4]), 5] } # 获取所有叶子节点leaves = tu.tree_leaves(pytree) # 对每个叶子应用函数doubled = tu.tree_map(lambda x: x * 2, pytree)
Optax:梯度处理和优化Optax 是 JAX 生态里的优化库。它包含损失函数、优化器、梯度变换、学习率调度等一套工具。
损失函数:
logits = jnp.array([[2.0, -1.0]]) labels_onehot = jnp.array([[1.0, 0.0]]) labels_int = jnp.array([0]) # Softmax 交叉熵(独热编码)loss_softmax_onehot = optax.softmax_cross_entropy(logits, labels_onehot).mean() # Softmax 交叉熵(整数标签)loss_softmax_int = optax.softmax_cross_entropy_with_integer_labels(logits, labels_int).mean() # 二元交叉熵loss_bce = optax.sigmoid_binary_cross_entropy(logits, labels_onehot).mean() # L2 损失loss_l2 = optax.l2_loss(jnp.array([1., 2.]), jnp.array([0., 1.])).mean() # Huber 损失loss_huber = optax.huber_loss(jnp.array([1.,2.]), jnp.array([0.,1.])).mean()
优化器:
# SGDopt_sgd = optax.sgd(learning_rate=1e-2) # SGD with momentumopt_momentum = optax.sgd(learning_rate=1e-2, momentum=0.9) # RMSPropopt_rmsprop = optax.rmsprop(1e-3) # Adafactoropt_adafactor = optax.adafactor(learning_rate=1e-3) # Adamopt_adam = optax.adam(1e-3) # AdamWopt_adamw = optax.adamw(1e-3, weight_decay=1e-4)
梯度变换:
# 梯度裁剪tx_clip = optax.clip(1.0) # 全局梯度范数裁剪tx_clip_global = optax.clip_by_global_norm(1.0) # 权重衰减(L2)tx_weight_decay = optax.add_decayed_weights(1e-4) # 添加梯度噪声tx_noise = optax.add_noise(0.01)
学习率调度:
# 指数衰减lr_exp = optax.exponential_decay(init_value=1e-3, transition_steps=1000, decay_rate=0.99) # 余弦衰减lr_cos = optax.cosine_decay_schedule(init_value=1e-3, decay_steps=10_000) # 线性预热lr_linear = optax.linear_schedule(init_value=0.0, end_value=1e-3, transition_steps=500)
更新步骤:
# 计算梯度loss, grads = jax.value_and_grad(loss_fn)(params) # 生成优化器更新updates, opt_state = optimizer.update(grads, opt_state) # 应用更新params = optax.apply_updates(params, updates)
链式组合:
# 把多个操作链起来optimizer = optax.chain( optax.clip_by_global_norm(1.0), # 梯度裁剪 optax.add_decayed_weights(1e-4), # 权重衰减 optax.adam(1e-3) # Adam 优化)
Flax 与神经网络JAX 本身只是数值计算库,Flax 在其基础上提供了神经网络定义和训练的高级 API。Flax 代码风格接近 PyTorch,如果你用过 PyTorch 会很快上手。
Flax 提供了丰富的层和操作。基础层 包括全连接层 Dense、卷积 Conv、嵌入 Embed、多头注意力 MultiHeadDotProductAttention 等:
flax.linen.Dense(features=128) flax.linen.Conv(features=64, kernel_size=(3, 3)) flax.linen.Embed(num_embeddings=10000, features=256) flax.linen.MultiHeadDotProductAttention(num_heads=8) flax.linen.SelfAttention(num_heads=8)
归一化 支持多种方式:
flax.linen.BatchNorm() flax.linen.LayerNorm() flax.linen.GroupNorm(num_groups=32) flax.linen.RMSNorm()
激活和 Dropout:
flax.linen.relu(x) flax.linen.gelu(x) flax.linen.sigmoid(x) flax.linen.tanh(x) flax.linen.Dropout(rate=0.1)
池化:
flax.linen.avg_pool(x, window_shape=(2,2), strides=(2,2)) flax.linen.max_pool(x, window_shape=(2,2), strides=(2,2))
循环层:
flax.linen.LSTMCell() flax.linen.GRUCell() flax.linen.OptimizedLSTMCell()
下面是一个简单的多层感知机(MLP)例子:
import jax import jax.numpy as jnp from flax import linen as nn class MLP(nn.Module): features: list @nn.compact def __call__(self, x): for f in self.features[:-1]: x = nn.Dense(f)(x) x = nn.relu(x) x = nn.Dense(self.features[-1])(x) return x model = MLP([32, 16, 10]) key = jax.random.PRNGKey(0) # 输入:batch_size=1, 特征数=4x = jnp.ones((1, 4)) # 初始化参数params = model.init(key, x) # 前向传播y = model.apply(params, x) print("Input:", x) # Input: [[1. 1. 1. 1.]] print("Input shape:", x.shape) # Input shape: (1, 4) print("Output:", y) # Output: [[ 0.51415515 0.36979797 0.6212194 -0.74496573 -0.8318489 0.6590691 0.89224255 0.00737424 0.33062232 0.34577468]] print("Output shape:", y.shape) # Output shape: (1, 10)
Flax 用 @nn.compact 装饰器,让你在 __call__ 方法里直接定义层。参数是独立于模型对象存储的,需要通过 init 方法显式初始化,然后在 apply 方法中使用。
总结JAX 的出现解决了一个长期存在的问题:如何让 Python 科学计算既保持灵活性,又能获得接近 C/CUDA 的性能。
不过 JAX 的学习曲线确实比 PyTorch 陡。纯函数的约束、不可变数组的特性、显式密钥管理等细节起初会有些别扭。但一旦习惯会发现它带来的优雅和灵活性。
https://avoid.overfit.cn/post/a16194fdc3ea450f858515d7cb3d49c4
作者:Ashish Bamania