Google JAX是一个用于高性能数值计算的机器学习Python库,拥有和NumPy一样易用的接口,但却支持GPU加速。
Google JAX是谷歌内部开发的一个项目,它不是一个机器学习或者深度学习库,而是一个高性能的数值计算库。它的API接口是基于NumPy的,因此JAX本身简单、灵活且易于使用。
但是JAX本身支持基于GPU或者TPU的加速。因此,相比较NumPy,其速度要快很多。
除了可以完成类似NumPy的数值计算,JAX还包括一个可扩展的可组合函数转换系统,有助于机器学习的研究,包括:
虽然JAX本身不是一个深度学习框架,但它肯定为深度学习的目的提供了一个更充分的基础。有许多建立在JAX之上的库,旨在建立深度学习能力,包括Flax、Haiku和Elegy。JAX对Hessians的高效计算也与深度学习有关,因为它们使高阶优化技术更加可行。
尽管JAX在逐渐成长,但它目前被运用于一系列项目中,比如除了深度学习之外,还有贝叶斯方法和机器人。DeepMind宣布了四个新的库,将加入他们的生态系统。Mctx提供AlphaZero和MuZero蒙特卡洛树搜索,KFAC-JAX是一个用于神经网络二阶优化和计算可扩展曲率近似值的库,DM_AUX是用于JAX的音频信号处理,提供频谱提取和SpecAug增强的工具,TF2JAX是一个用于将TensorFlow函数和图形转换为JAX函数的库。
总之,关于JAX的相关生态正在快速发展。甚至在某些报道中,它被当做是TensorFlow的接任者。
是否开源: 是
许可协议: Apache-2.0 license
官方地址: https://jax.readthedocs.io/en/latest/index.html
GitHub地址: https://github.com/google/jax
初始贡献者: Google内部人员
官方指南:https://jax.readthedocs.io/en/latest/notebooks/quickstart.html
TensorFlow - 深度学习
MindSpore - 深度学习