![]() | |
開發者 | Google, Nvidia[1] |
---|---|
首次发布 | 2019年10月31日[2] |
当前版本 |
|
预览版本 | v0.3.13(2022年5月16日 | )
源代码库 | github |
编程语言 | Python, C++ |
操作系统 | Linux, macOS, Windows |
平台 | Python, NumPy |
类型 | 机器学习 |
许可协议 | Apache 2.0 |
网站 | docs |
JAX,是用于变换数值函数的Python机器学习框架,它由Google开发并具有来自Nvidia的一些贡献[4][5][6]。它结合了修改版本的Autograd(自动通过函数的微分获得其梯度函数)[7],和OpenXLA的XLA(加速线性代数)[8]。它被设计为尽可能的遵从NumPy的结构和工作流程,并协同工作于各种现存的框架如TensorFlow和PyTorch[9][10]。
JAX的主要功能是[4]:
下面的代码演示grad
函数的自动微分。
# 导入库
from jax import grad
import jax.numpy as jnp
# 定义logistic函数
def logistic(x):
return jnp.exp(x) / (jnp.exp(x) + 1)
# 获得logistic函数的梯度函数
grad_logistic = grad(logistic)
# 求值logistic函数在x = 1处的梯度
grad_log_out = grad_logistic(1.0)
print(grad_log_out)
最终的输出为:
0.19661194
下面的代码演示jit
函数的优化。
# 导入库
from jax import jit
import jax.numpy as jnp
# 定义cube函数
def cube(x):
return x * x * x
# 生成数据
x = jnp.ones((10000, 10000))
# 创建cube函数的jit版本
jit_cube = jit(cube)
# 应用cube函数和jit_cube函数于相同数据来比较其速度
cube(x)
jit_cube(x)
可见jit_cube
的运行时间显著的短于cube
。
下面的代码展示vmap
函数的通过SIMD的向量化。
# 导入库
from functools import partial
from jax import vmap
import jax.numpy as jnp
# 定义函数
def grads(self, inputs):
in_grad_partial = partial(self._net_grads, self._net_params)
grad_vmap = vmap(in_grad_partial)
rich_grads = grad_vmap(inputs)
flat_grads = np.asarray(self._flatten_batch(rich_grads))
assert flat_grads.ndim == 2 and flat_grads.shape[0] == inputs.shape[0]
return flat_grads
下面的代码展示pmap
函数的对矩阵乘法的并行化。
# 从JAX导入pmap和random;导入JAX NumPy
from jax import pmap, random
import jax.numpy as jnp
# 生成2个维度为5000 x 6000的随机数矩阵,每设备一个
random_keys = random.split(random.PRNGKey(0), 2)
matrices = pmap(lambda key: random.normal(key, (5000, 6000)))(random_keys)
# 没有数据传输,并行的在每个CPU/GPU上进行局部矩阵乘法
outputs = pmap(lambda x: jnp.dot(x, x.T))(matrices)
# 没有数据传输,并行的在每个CPU/GPU上分别求取这两个矩阵的均值
means = pmap(jnp.mean)(outputs)
print(means)
最终的输出为:
[1.1566595 1.1805978]
一些Python库使用JAX作为后端,这包括: