JAX

JAX is a popular open-source library that runs NumPy-style code faster by using Accelerated Linear Algebra (XLA) to compile it and run it on GPUs. NumPy is a powerful scientific computing library that supports multi-dimensional arrays and “brings the computational power of languages like C and Fortran to Python.” Its one drawback is that it doesn’t natively support GPUs. That’s where JAX comes in: it supports CPUs, GPUs, and TPUs.
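
To make the NumPy comparison concrete, here is a minimal sketch that assumes JAX is installed (for example with pip install jax). It evaluates the same expression with NumPy and with jax.numpy, lists the devices JAX found, and shows one difference worth knowing: JAX arrays are immutable, so updates go through the .at[...] syntax instead of in-place assignment. The variable names are illustrative.

```python
import numpy as np
import jax
import jax.numpy as jnp

# The same expression written with NumPy and with jax.numpy
x_np = np.linspace(0.0, 1.0, 5)
x_jx = jnp.linspace(0.0, 1.0, 5)

y_np = np.sin(x_np) ** 2 + np.cos(x_np) ** 2
y_jx = jnp.sin(x_jx) ** 2 + jnp.cos(x_jx) ** 2

# Devices JAX can use; it falls back to the CPU when no GPU or TPU is present
print(jax.devices())

# Both results agree up to floating-point tolerance (JAX defaults to float32)
print(np.allclose(y_np, np.asarray(y_jx)))

# JAX arrays are immutable: .at[...].set(...) returns an updated copy
z = x_jx.at[0].set(10.0)
print(x_jx[0], z[0])
```

Because the API mirrors NumPy closely, much existing array code can be ported by swapping the import, with in-place updates being the most common thing that needs rewriting.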

Project Background

  • Tool: JAX
  • Author: Google
  • Initial Release: September 2017
  • Type: NumPy on the CPU, GPU, and TPU
  • License: Apache 2.0
  • GitHub: jax with 14k+ stars
  • Contributors: 300+
  • Hardware: CPUs, GPUs, and TPUs

What JAX Can Differentiate

  • Native Python and NumPy functions
  • Loops, branches, recursion, and closures
  • Derivatives of derivatives (higher-order derivatives)
  • Supports backpropagation (reverse-mode differentiation) using the grad function
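
A small sketch of these points, assuming a recent JAX release: jax.grad differentiates ordinary Python functions that use a loop, a branch, recursion, or a closure, and nesting jax.grad gives a higher-order derivative. The functions f, power, cube, and make_scaled_sin are hypothetical examples written for illustration, not part of JAX.

```python
import jax
import jax.numpy as jnp

def f(x):
    # Python loop: accumulates x + x**2 + x**3
    total = 0.0
    for k in range(1, 4):
        total = total + x ** k
    # Python branch on the input value; this works under grad, but under jit
    # a data-dependent branch would need jax.lax.cond or jnp.where instead
    if x > 0:
        total = 2.0 * total
    return total

def power(x, n):
    # Recursion: x**n computed as x * x**(n - 1)
    return 1.0 if n == 0 else x * power(x, n - 1)

def cube(x):
    return power(x, 3)

def make_scaled_sin(c):
    # Closure: the returned function captures c from the enclosing scope
    return lambda x: c * jnp.sin(x)

df = jax.grad(f)                         # reverse-mode derivative (backpropagation)
dcube = jax.grad(cube)                   # 3x^2
d2cube = jax.grad(jax.grad(cube))        # derivative of a derivative: 6x
dsin3 = jax.grad(make_scaled_sin(3.0))   # 3 cos(x)

print(df(2.0))      # 2 * (1 + 2x + 3x^2) at x = 2, i.e. 34
print(dcube(3.0))   # 3x^2 at x = 3, i.e. 27
print(d2cube(3.0))  # 6x at x = 3, i.e. 18
print(dsin3(0.0))   # 3 cos(0), i.e. 3
```

Note that grad traces the branch actually taken for the given input, so df(-1.0) differentiates the undoubled polynomial.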

Features

  • Speeds up code significantly using a special compiler called Accelerated Linear Algebra (XLA)
  • Fuses operations together
  • Uses XLA to compile and run NumPy on GPUs and TPUs
  • Just-in-time (JIT) compiles Python functions into XLA-optimized kernels
  • pmap replicates computations across multiple devices, such as several GPUs, at once
  • Supports composable transformations: grad, jit, pmap, and vmap
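
The transformations compose, as the sketch below tries to show under a few assumptions: predict and loss are a hypothetical toy model, the batch is made-up data, and the pmap step uses however many local devices are present (a single CPU counts as one device). grad builds a per-example gradient, vmap vectorizes it over a batch, jit compiles the result with XLA, and pmap maps a function across devices.

```python
import jax
import jax.numpy as jnp

def predict(w, x):
    # Hypothetical toy model: a linear layer followed by tanh, for one example
    return jnp.tanh(jnp.dot(w, x))

def loss(w, x, y):
    # Squared error for a single (x, y) pair
    return (predict(w, x) - y) ** 2

# grad: derivative of loss with respect to w (argument 0) for one example
per_example_grad = jax.grad(loss)

# vmap: map over a batch of (x, y) pairs while sharing w (in_axes=None for w)
batched_grad = jax.vmap(per_example_grad, in_axes=(None, 0, 0))

# jit: compile the whole composed function with XLA
fast_batched_grad = jax.jit(batched_grad)

w = jnp.array([0.5, -0.25, 1.0])
xs = jnp.ones((8, 3))
ys = jnp.zeros(8)

grads = fast_batched_grad(w, xs, ys)
print(grads.shape)  # one gradient row per example

# pmap: run the function once per local device; here the leading axis
# is set to the number of local devices
n = jax.local_device_count()
shards = jnp.arange(n * 4, dtype=jnp.float32).reshape(n, 4)
summed = jax.pmap(lambda s: jnp.sum(s ** 2))(shards)
print(summed.shape)
```

On a machine with only one device, pmap still runs but provides no parallelism; the same code spreads the shards across devices when more are available.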