[SysML'18] Compiling Machine Learning Programs via High-Level Tracing (JAX) 论文阅读
本文介绍了 JAX:一个面向领域的 tracing JIT 编译器,能够从纯 Python 和 NumPy 编写的机器学习程序生成高性能的加速器代码。JAX 借助 XLA 编译基础设施为"最适合加速"的子程序生成优化代码,这些优化后的子程序可以被任意 Python 代码调用和编排。由于 JAX 与 Autograd 完全兼容,它支持对 Python 函数进行任意阶的前向和反向自动微分。由于 JAX 支持结构化控制流,它能够为复杂机器学习算法生成高性能代码。将 JAX 与 Autograd 和 NumPy 结合,可以得到一个既易于编程、又高度高性能的 ML 系统,能够同时面向 CPU、GPU 和 TPU,并可扩展到多核 Cloud TPU。