JAX
JAX is a popular open-source library that brings NumPy-style array programming to accelerators. It uses the Accelerated Linear Algebra (XLA) compiler to compile and run NumPy code on GPUs and TPUs. NumPy is a powerful scientific computing library that supports multi-dimensional arrays and “brings the computational power of languages like C and Fortran to Python.” Its one limitation is that it doesn’t natively support GPUs. That’s where JAX comes in: JAX runs on CPUs, GPUs, and TPUs.
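The following minimal sketch illustrates the "NumPy on accelerators" idea. It evaluates the same expression with NumPy and with jax.numpy, JAX's NumPy-compatible module. It assumes JAX is installed. The jax.numpy version runs on whatever default device JAX finds, which is the CPU if no GPU or TPU is available.

```python
import numpy as np
import jax
import jax.numpy as jnp

# The same computation written twice: once with NumPy, once with jax.numpy
x_np = np.linspace(0.0, 1.0, 5)
x_jx = jnp.linspace(0.0, 1.0, 5)

y_np = np.sin(x_np) ** 2 + np.cos(x_np) ** 2
# jax.numpy mirrors the NumPy API; the result is a JAX array on the default device
y_jx = jnp.sin(x_jx) ** 2 + jnp.cos(x_jx) ** 2

print(jax.devices())       # devices JAX can see (CPU, GPU, or TPU)
print(np.asarray(y_jx))    # convert the JAX array back to a NumPy array
```

One difference to keep in mind: JAX arrays are immutable, so an in-place NumPy update such as `x[0] = 1.0` is written in JAX as `x = x.at[0].set(1.0)`.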
Project Background
- Tool: JAX
- Author: Google
- Initial Release: September 2017
- Type: NumPy on the CPU, GPU, and TPU
- License: Apache 2.0
- GitHub: jax, with 14k+ stars
- Contributors: 300+
- Hardware: CPUs, GPUs, and TPUs
Differentiates
- Python and NumPy functions
- Loops, branches, recursion, and closures
- Derivatives of derivatives (higher-order derivatives)
- Supports reverse-mode differentiation (backpropagation) using the grad function
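Here is a small sketch of these differentiation capabilities using jax.grad. The functions `cube`, `piecewise`, and `make_scaled_sin` are hypothetical examples written for illustration. The sketch shows grad tracing through a Python loop, a Python branch, and a closure, and then nesting grad to take a third derivative. Branching on the input's value works under plain grad but would fail under jit, where the input is an abstract tracer.

```python
import jax
import jax.numpy as jnp

def cube(x):
    # x**3 built with a Python loop; grad traces through each iteration
    result = 1.0
    for _ in range(3):
        result = result * x
    return result

def piecewise(x):
    # Python branch on the input's value: fine under grad, not under jit
    if x > 0:
        return cube(x)
    return -x

def make_scaled_sin(a):
    # Returns a closure that captures `a`
    return lambda x: a * jnp.sin(x)

d_piecewise = jax.grad(piecewise)
print(d_piecewise(2.0))    # derivative of x**3 at 2.0, i.e. 3 * 2.0**2
print(d_piecewise(-1.0))   # derivative of -x

print(jax.grad(make_scaled_sin(3.0))(0.0))   # 3 * cos(0)

# Derivatives of derivatives: the third derivative of x**3 is the constant 6
d3_cube = jax.grad(jax.grad(jax.grad(cube)))
print(d3_cube(2.0))
```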
Features
- Speeds up code significantly using the Accelerated Linear Algebra (XLA) compiler
- Fuses operations together to reduce memory traffic
- Uses XLA to compile and run NumPy code on GPUs and TPUs
- Just-in-time (JIT) compiles Python functions into XLA-optimized kernels with jit
- pmap replicates computations across multiple devices (for example, several GPUs) at once
- Supports composable transformations: grad, jit, pmap, and vmap
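The sketch below shows how the transformations compose. It uses a hypothetical toy loss for a 1-D linear model. grad takes the gradient with respect to the weights for one example, vmap vectorizes that over a batch, and jit compiles the result with XLA. The final lines apply pmap across whatever local devices JAX sees. On a machine with one device, that pmap call simply runs on that single device.

```python
import jax
import jax.numpy as jnp

def loss(w, x, y):
    # Squared error of a toy 1-D linear model for a single example (illustrative only)
    pred = w[0] * x + w[1]
    return (pred - y) ** 2

# Compose: grad w.r.t. w, vmap over a batch of (x, y) with w shared, then jit-compile
per_example_grads = jax.jit(jax.vmap(jax.grad(loss), in_axes=(None, 0, 0)))

w = jnp.array([2.0, 1.0])
xs = jnp.array([0.0, 1.0, 2.0])
ys = jnp.array([1.0, 2.0, 3.0])
g = per_example_grads(w, xs, ys)   # one gradient row per example, shape (3, 2)
print(g)

# pmap: map a function over the leading axis, one slice per local device
n_dev = jax.local_device_count()
doubled = jax.pmap(lambda v: v * 2.0)(jnp.arange(n_dev, dtype=jnp.float32))
print(n_dev, doubled)
```

Because each transformation takes a function and returns a function, the order can be rearranged. For example, wrapping only the batched gradient in jit, as above, compiles the whole vectorized computation into a single XLA program.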