GitHub|2022年,我该用JAX吗?GitHub 1.6万星,这个年轻的工具并不完美

机器之心报道
编辑:杜伟、陈萍

近年来 , 谷歌于 2018 年推出的 JAX 迎来了迅猛发展 , 很多研究者对其寄予厚望 , 希望它可以取代 TensorFlow 等众多深度学习框架 。 但 JAX 是否真的适合所有人使用呢?这篇文章对 JAX 的方方面面展开了深入探讨 , 希望可以给研究者选择深度学习框架时提供有益的参考 。
自 2018 年底推出以来 , JAX 的受欢迎程度一直在稳步提升 。 2020 年 , DeepMind 宣布使用 JAX 来加速其研究 。 越来越多来自谷歌大脑(Google Brain)和其他机构的项目也都在使用 JAX 。
目前 , 在 JAX 的 GitHub 项目主页 , Star 量已经达到了 16.3k 。
GitHub|2022年,我该用JAX吗?GitHub 1.6万星,这个年轻的工具并不完美
文章图片

项目地址:https://github.com/google/jax
JAX 是一个非常有前途的项目 , 并且用户一直在稳步增长 。 JAX 已经在深度学习、机器人 / 控制系统、贝叶斯方法和科学模拟等诸多领域得到了广泛应用 。
GitHub|2022年,我该用JAX吗?GitHub 1.6万星,这个年轻的工具并不完美
文章图片

如此 , 是否意味着 JAX 也将成为下一个大型深度学习框架?近日 , 发表在 AssemblyAI 博客上的文章《Why You Should (or Shouldn't) Be Using JAX in 2022》中 , 作者 Ryan O'Connor 为我们深入解读了 JAX 的概念、使用 JAX 的理由以及是否应该使用 JAX 等 。
JAX 简介
JAX 不是一个深度学习框架或库 , 其设计初衷也不是成为一个深度学习框架或库 。 简而言之 , JAX 是一个包含可组合函数转换的数值计算库 。 正如我们所看到的 , 深度学习只是 JAX 功能的一小部分:
GitHub|2022年,我该用JAX吗?GitHub 1.6万星,这个年轻的工具并不完美
文章图片

JAX 的定位科学计算(Scientific Computing)和函数转换(Function Transformations)的交叉融合 , 具有除训练深度学习模型以外的一系列能力 , 包括如下:
  • 即时编译(Just-in-Time Compilation)
  • 自动并行化(Automatic Parallelization)
  • 自动向量化(Automatic Vectorization)
  • 自动微分(Automatic Differentiation)
使用 JAX 的原因有哪些?
简而言之 , 是速度 。 这是 JAX 与任何用例相关的一种通用能力 。 让我们使用 NumPy 和 JAX 对矩阵的前三个幂求和(按元素) 。
首先是 NumPy 实现 。 我们发现 , 该计算大约需要 851 毫秒 。
GitHub|2022年,我该用JAX吗?GitHub 1.6万星,这个年轻的工具并不完美
文章图片

然后使用 JAX 实现该计算:
GitHub|2022年,我该用JAX吗?GitHub 1.6万星,这个年轻的工具并不完美
文章图片

JAX 仅在 5.54 毫秒内执行完成该计算 , 速度是 NumPy 的 150 倍以上 。
GitHub|2022年,我该用JAX吗?GitHub 1.6万星,这个年轻的工具并不完美
文章图片

JAX 的速度比 NumPy 快了 N 个数量级 。 需要注意 , JAX 使用的是 TPU , NumPy 使用了 CPU , 以此强调 JAX 的速度上限远高于 NumPy 。
作者列出了以下六条可能想要使用 JAX 的理由:
  • NumPy 加速器 。 NumPy 是使用 Python 进行科学计算的基础包之一 , 但它仅与 CPU 兼容 。 JAX 提供了 NumPy 的实现(具有几乎相同的 API) , 可以非常轻松地在 GPU 和 TPU 上运行 。 对于许多用户而言 , 仅此一项功能就足以证明使用 JAX 的合理性;
  • XLA 。 XLA(Accelerated Linear Algebra)是专为线性代数设计的全程序优化编译器 。 JAX 建立在 XLA 之上 , 显著提高了计算速度上限;
  • JIT 。 JAX 允许用户使用 XLA 将自己的函数转换为即时编译(JIT)版本 。 这意味着可以通过在计算函数中添加一个简单的函数装饰器(decorator)来将计算速度提高几个数量级;
  • Auto-differentiation 。 JAX 将 Autograd(自动区分原生 Python 代码和 NumPy 代码)和 XLA 结合在一起 , 它的自动微分能力在科学计算的许多领域都至关重要 。 JAX 提供了几个强大的自动微分工具;
  • 深度学习 。 虽然 JAX 本身不是深度学习框架 , 但它的确为深度学习提供了一个很好的基础 。 很多构建在 JAX 之上的库旨在提供深度学习功能 , 包括 Flax、Haiku 和 Elegy 。 甚至在最近的一些 PyTorch 与 TensorFlow 文章中强调了 JAX 作为一个值得关注的「框架」 , 并推荐其用于基于 TPU 的深度学习研究 。 JAX 对 Hessians 的高效计算也与深度学习相关 , 因为它们使高阶优化技术更加可行;
  • 通用可微分编程范式(General Differentiable Programming Paradigm ) 。 虽然我们可以使用 JAX 来构建和训练深度学习模型 , 但它也为通用可微编程提供了一个框架 。 这意味着 JAX 可以通过使用基于模型的机器学习方法来解决问题 , 从而可以利用数十年研究建立起的给定领域的先验知识 。
JAX 转换
到目前为止 , 我们已经讨论了 XLA 以及它如何允许 JAX 在加速器上实现 NumPy;但请记住 , 这只是 JAX 定义的一半 。 JAX 不仅为强大的科学计算提供了工具 , 而且还为可组合的函数转换提供了工具 。
举例来说如果我们对标量值函数 f(x) 使用梯度函数转换 , 那么我们将得到一个向量值函数 f'(x) , 它给出了函数在 f(x) 域中任意点的梯度 。
GitHub|2022年,我该用JAX吗?GitHub 1.6万星,这个年轻的工具并不完美
文章图片

在函数上使用 grad() 可以让我们得到域中任意点的梯度
JAX 包含了一个可扩展系统来实现这样的函数转换 , 有四种典型方式:
  • Grad() 进行自动微分;
  • Vmap() 自动向量化;
  • Pmap() 并行化计算;
  • Jit() 将函数转换为即时编译版本 。
使用 grad() 进行自动微分
训练机器学习模型需要反向传播 。 在 JAX 中 , 就像在 Autograd 中一样 , 用户可以使用 grad() 函数来计算梯度 。
举例来说 , 如下是对函数 f(x) = abs(x^3) 求导 。 我们可以看到 , 当求 x=2 和 x=-3 处的函数及其导数时 , 我们得到了预期的结果 。
GitHub|2022年,我该用JAX吗?GitHub 1.6万星,这个年轻的工具并不完美
文章图片

那么 grad() 能微分到什么程度?JAX 通过重复应用 grad() 使得微分变得很容易 , 如下程序我们可以看到 , 输出函数的三阶导数给出了 f'''(x)=6 的恒定预期输出 。
GitHub|2022年,我该用JAX吗?GitHub 1.6万星,这个年轻的工具并不完美
文章图片

可能有人会问 , grad() 可以用在哪些方面?标量值函数:grad() 采用标量值函数的梯度 , 将标量 / 向量映射到标量函数 。 此外还有向量值函数:对于将向量映射到向量的向量值函数 , 梯度的类似物是雅可比矩阵 。 使用 jacfwd() 和 jacrev() , JAX 返回一个函数 , 该函数在域中的某个点求值时产生雅可比矩阵 。
GitHub|2022年,我该用JAX吗?GitHub 1.6万星,这个年轻的工具并不完美
文章图片

从深度学习角度来看 , JAX 使得计算 Hessians 变得非常简单和高效 。 由于 XLA , JAX 可以比 PyTorch 更快地计算 Hessians , 这使得实现诸如 AdaHessian 这样的高阶优化更加快速 。
下面代码是在 PyTorch 中对一个简单的输入总和进行 Hessian:
GitHub|2022年,我该用JAX吗?GitHub 1.6万星,这个年轻的工具并不完美
文章图片

正如我们所看到的 , 上述计算大约需要 16.3 ms , 在 JAX 中尝试相同的计算:
GitHub|2022年,我该用JAX吗?GitHub 1.6万星,这个年轻的工具并不完美
文章图片

使用 JAX , 计算仅需 1.55 毫秒 , 比 PyTorch 快 10 倍以上:
GitHub|2022年,我该用JAX吗?GitHub 1.6万星,这个年轻的工具并不完美
文章图片

JAX 可以非常快速地计算 Hessians , 使得高阶优化更加可行 。
使用 vmap() 自动向量化
JAX 在其 API 中还有另一种变换:vmap() 自动向量化 。 以下是矢量化向量加法展示:
GitHub|2022年,我该用JAX吗?GitHub 1.6万星,这个年轻的工具并不完美
文章图片

使用 pmap() 实现自动并行化
分布式计算变得越来越重要 , 在深度学习中尤其如此 , 如下图所示 , SOTA 模型已经发展到超大规模 。
GitHub|2022年,我该用JAX吗?GitHub 1.6万星,这个年轻的工具并不完美
文章图片

得益于 XLA , JAX 可以轻松地在加速器上进行计算 , 但 JAX 也可以轻松地使用多个加速器进行计算 , 即使用单个命令 - pmap() 执行 SPMD 程序的分布式训练 。
我们以向量矩阵乘法为例 , 如下为非并行向量矩阵乘法:
GitHub|2022年,我该用JAX吗?GitHub 1.6万星,这个年轻的工具并不完美
文章图片

使用 JAX , 我们可以轻松地将这些计算分布在 4 个 TPU 上 , 只需将操作包装在 pmap() 中即可 。 这允许用户在每个 TPU 上同时执行一个点积 , 显着提高了计算速度(对于大型计算而言) 。
GitHub|2022年,我该用JAX吗?GitHub 1.6万星,这个年轻的工具并不完美
文章图片

使用 jit() 加快功能
JIT 编译是一种执行代码的方法 , 介于解释(interpretation)和 AoT(ahead-of-time)编译之间 。 重要的是 , JIT 编译器在运行时将代码编译成快速的可执行文件 , 但代价是首次运行速度较慢 。
JIT 不是一次将一个操作分配给 GPU 内核 , 而是使用 XLA 将一系列操作编译成一个内核 , 从而为函数提供端到端编译的高效 XLA 实现 。
以下图为例 , 代码定义了一个函数:用三种方式计算 5000 x 5000 矩阵——一次使用 NumPy , 一次使用 JAX , 还有一次在 JIT 编译的函数版本上使用 JAX 。 我们首先在 CPU 上进行实验:
GitHub|2022年,我该用JAX吗?GitHub 1.6万星,这个年轻的工具并不完美
文章图片

GitHub|2022年,我该用JAX吗?GitHub 1.6万星,这个年轻的工具并不完美
文章图片

JAX 对于逐元素计算明显更快 , 尤其是在使用 jit 时 。
我们看到 JAX 比 NumPy 快 2.3 倍以上 , 当我们 JIT 函数时 , JAX 比 NumPy 快 30 倍 。 这些结果已经令人印象深刻 , 但让我们继续看 , 让 JAX 在 TPU 上进行计算:
GitHub|2022年,我该用JAX吗?GitHub 1.6万星,这个年轻的工具并不完美
文章图片

当 JAX 在 TPU 上执行相同的计算时 , 它的相对性能会进一步提升(NumPy 计算仍在 CPU 上执行 , 因为它不支持 TPU 计算)在这种情况下 , 我们可以看到 JAX 比 NumPy 快了惊人的 13 倍 , 如果我们同时在 TPU 上 JIT 函数和计算 , 我们会发现 JAX 比 NumPy 快 80 倍 。
当然 , 这种速度的大幅提升是有代价的 。 JAX 对 JIT 允许的函数进行了限制 , 尽管通常允许仅涉及上述 NumPy 操作的函数 。 此外 , 通过 Python 控制流进行 JIT 处理存在一些限制 , 因此在编写函数时须牢记这一点 。
2022 年了 , 我该用 JAX 吗?
很遗憾 , 这个问题的答案还是「视情况而定」 。 是否迁移到 JAX 取决于你的情况和目标 。 为具体分析是否应该(或不应该)在 2022 年使用 JAX , 这里将建议汇总到下面的流程图中 , 并针对不同的兴趣领域提供不同的图表 。
科学计算
GitHub|2022年,我该用JAX吗?GitHub 1.6万星,这个年轻的工具并不完美
文章图片

如果你对 JAX 在通用计算感兴趣 , 首先要问的问题就是——是否只尝试在加速器上运行 NumPy?如果答案是肯定的 , 那么你显然应该开始迁移到 JAX 。
如果你不只处理数字而是参与动态计算建模 , 那么是否应该使用 JAX 将取决于具体用例 。 如果大部分工作是在 Python 中使用大量自定义代码完成的 , 那么开始学习 JAX 以增强工作流程是值得的 。
如果大部分工作不在 Python 中 , 但你想构建的是某种基于模型 / 神经网络的混合系统 , 那么使用 JAX 可能是值得的 。
如果大部分工作不使用 Python , 或者你正在使用一些专门的软件进行研究(热力学、半导体等) , 那么 JAX 可能是不合适的工具 , 除非你想从这些程序中导出数据 , 用来做自定义计算 。 如果你感兴趣的领域更接近物理 / 数学并包含计算方法(动力系统、微分几何、统计物理)并且大部分工作都在例如 Mathematica 上 , 那么坚持使用目前的工具才是值得的 , 特别是在已有大型自定义代码库的情形下 。
深度学习
GitHub|2022年,我该用JAX吗?GitHub 1.6万星,这个年轻的工具并不完美
文章图片

虽然我们已经强调过 , JAX 不是专为深度学习构建的通用框架 , 但 JAX 速度很快且具有自动微分功能 , 你肯定想知道使用 JAX 进行深度学习是什么样的 。
若想在 TPU 上进行训练 , 那么你应该开始使用 JAX , 尤其是如果当前正在使用的是 PyTorch 。 虽然有 PyTorch-XLA 存在 , 但使用 JAX 进行 TPU 训练绝对是更好的体验 。 如果你正在研究的是「非标准」架构 / 建模 , 例如 SDE-Nets , 那么也绝对应该尝试一下 JAX 。 此外 , 如果你想利用高阶优化技术 , JAX 也是要尝试的东西 。
如果你不是在构建特殊的架构 , 只是在 GPU 上训练常见的架构 , 那么你现在可能应该坚持使用 PyTorch 或 TensorFlow 。 然而 , 这个建议可能会在未来一两年内快速发生变化 。 虽然 PyTorch 仍然在研究领域占据主导地位 , 但使用 JAX 的论文数量一直在稳步增长 。 随着 DeepMind 和谷歌重量级玩家不断开发用于 JAX 的高级深度学习 API , 在几年内 JAX 可能会出现爆炸性的增长率 。
这意味着你至少应该稍微熟悉一下 JAX , 如果你是研究人员的话更应如此 。
深度学习初学者
GitHub|2022年,我该用JAX吗?GitHub 1.6万星,这个年轻的工具并不完美
文章图片

但如果我只是个初学者呢?情况会有些不一样 。
如果你有兴趣了解深度学习并实现一些想法 , 你应该使用 JAX 或 PyTorch 。 如果你想自上而下学习深度学习 , 或有一些 Python 软件的经验 , 则应该从 PyTorch 入手 。 如果你想自下而上地学习深度学习 , 或具有数学背景 , 你可能会发现 JAX 很直观 。 在这种情况下 , 在进行任何大型项目之前 , 请确保了解如何使用 JAX 。
如果你对深度学习感兴趣 , 又想转行相关的职位 , 那么你需要使用 PyTorch 或 TensorFlow 。 尽管最好是同时熟悉两个框架 , 但你必须知道 TensorFlow 被普遍认为是「行业」框架 , 不同框架的职位发布数量证明了这一点:
GitHub|2022年,我该用JAX吗?GitHub 1.6万星,这个年轻的工具并不完美
文章图片

如果你是一个没有数学或软件背景但想学习深度学习的初学者 , 那么你不会想使用 JAX 。 相反 , Keras 是更好的选择 。
不该使用 JAX 的四条理由
虽然上文已经讨论了很多 JAX 的正面反馈 , 它有潜力极大地提升用户程序的性能 。 但作者同时列举了以下四条不该使用 JAX 的理由:
  • JAX 仍然被官方认为是一个实验性框架 。 JAX 是一个相对「年轻」的项目 。 目前 , JAX 仍被视为一个研究项目 , 而不是成熟的谷歌产品 , 因此如果用户正在考虑迁移到 JAX , 请记住这一点;
  • 使用 JAX 一定要勤勉 。 调试的时间成本 , 或者更严重的是 , 未跟踪副作用(untracked side effects)的风险可能导致那些没有扎实掌握函数式编程的用户不适用 JAX 。 在开始将它用于正式项目之前 , 请确保自己了解使用 JAX 的常见缺陷;
  • JAX 没有针对 CPU 计算进行优化 。 鉴于 JAX 是以「加速器优先」的方式开发的 , 因此每个操作的分派并未针对 JAX 进行完全优化 。 在某些情况下 , NumPy 实际上可能比 JAX 更快 , 尤其是对于小型程序而言 , 这是因为 JAX 引入了开销;
  • JAX 与 Windows 不兼容 。 目前在 Windows 上不支持 JAX 。 如果用户使用 Windows 系统但仍想尝试 JAX , 可以使用 Colab 或将其安装在虚拟机(VM)上 。
【GitHub|2022年,我该用JAX吗?GitHub 1.6万星,这个年轻的工具并不完美】原文链接:https://www.assemblyai.com/blog/why-you-should-or-shouldnt-be-using-jax-in-2022/

    推荐阅读