JAX is a popular open-source library that provides a NumPy-compatible API and uses Accelerated Linear Algebra (XLA) to compile that code and run it on GPUs. NumPy is a powerful scientific computing library that supports multi-dimensional arrays and “brings the computational power of languages like C and Fortran to Python.” The only problem is that it doesn’t natively support GPUs. That’s where JAX comes in: it runs on CPUs, GPUs, and TPUs.
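Because jax.numpy mirrors much of the NumPy API, the same array code can often be written against either library. The following is a minimal sketch, assuming a standard `pip install jax` setup; the input values are arbitrary. JAX uses 32-bit floats by default, so its result matches NumPy's only up to float32 rounding.

```python
import numpy as np
import jax
import jax.numpy as jnp

# The same computation, once with NumPy (CPU only) and once with jax.numpy.
x_np = np.linspace(0.0, 1.0, 5)
x_jx = jnp.linspace(0.0, 1.0, 5)

np_result = np.sum(np.sin(x_np) ** 2)
jx_result = jnp.sum(jnp.sin(x_jx) ** 2)  # compiled by XLA, run on JAX's default backend

print(jax.default_backend())        # e.g. "cpu" on a CPU-only install
print(np_result, float(jx_result))  # agree up to float32 rounding
```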

Project Background

  • Tool: JAX
  • Author: Google
  • Initial Release: September 2017
  • Type: NumPy on the CPU, GPU, and TPU
  • License: Apache 2.0
  • GitHub: jax, with 14k+ stars
  • Contributors: 300+ 
  • Hardware: CPUs, GPUs, and TPUs

Automatic Differentiation

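  • Automatically differentiates native Python and NumPy functions
  • Differentiates through loops, branches, recursion, and closures
  • Computes higher-order derivatives (derivatives of derivatives)
  • Supports reverse-mode differentiation (backpropagation) through the grad function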
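The sketch below shows grad following ordinary Python control flow (a for-loop and an if-branch) and being applied twice to get a second derivative. It is a minimal example in the spirit of the JAX documentation, not code from the project; `f` is a hypothetical toy function. Plain Python branching on the input works here because grad is used without jit.

```python
import jax

def f(x):
    # Hypothetical toy function.
    # Python for-loop: after it runs, y == x ** 3.
    y = 1.0
    for _ in range(3):
        y = y * x
    # Python branch: grad differentiates whichever path executes for this x.
    if x < 3.0:
        return y          # x ** 3
    return -4.0 * x

df = jax.grad(f)              # first derivative
d2f = jax.grad(jax.grad(f))   # derivative of the derivative

print(df(1.5))   # 3 * 1.5**2 = 6.75
print(d2f(1.5))  # 6 * 1.5    = 9.0
print(df(4.0))   # other branch: -4.0
```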

Compilation and Parallelism

  • Speeds up code by compiling it with XLA, a domain-specific compiler for linear algebra
  • Fuses operations together so that intermediate results need not be written back to memory
  • Uses XLA to compile and run NumPy programs on GPUs and TPUs
  • Just-in-time compiles Python functions into XLA-optimized kernels with jit
  • pmap replicates a computation across multiple devices, such as several GPUs, at once
  • Supports composable transformations: grad, jit, pmap, and vmap
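The transformations listed above compose freely. The following is a minimal sketch, assuming a recent JAX release; `loss`, `w`, and `xs` are hypothetical toy values. It builds per-example gradients with grad and vmap, compiles the result with jit, and then uses pmap to run one slice of an array on each available device (on a CPU-only machine that is a single device).

```python
import jax
import jax.numpy as jnp

def loss(w, x):
    # Hypothetical toy model: squared error of a linear prediction against 1.0.
    return (jnp.dot(w, x) - 1.0) ** 2

# grad w.r.t. w, vectorized over a batch of x (w is shared), then compiled by jit.
per_example_grads = jax.jit(jax.vmap(jax.grad(loss), in_axes=(None, 0)))

w = jnp.array([0.5, -0.25])
xs = jnp.array([[1.0, 2.0],
                [3.0, 4.0],
                [0.0, 1.0]])
g = per_example_grads(w, xs)
print(g)  # one gradient row per example: [[-2, -4], [-3, -4], [0, -2.5]]

# pmap maps over the leading axis, placing one slice on each local device.
n = jax.local_device_count()
sharded = jnp.arange(n * 4.0).reshape(n, 4)
row_sums = jax.pmap(lambda row: jnp.sum(row))(sharded)
print(row_sums)  # one sum per device
```

Wrapping the whole vectorized gradient in jit lets XLA compile it as a single program rather than dispatching each operation separately.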