
JAX

JAX is a popular open-source library developed by Google that combines the functionality of Autograd and XLA (Accelerated Linear Algebra): the former is a library for automatic function differentiation (it computes higher-order derivatives automatically and supports both forward- and reverse-mode differentiation), and the latter is a compiler for running Machine Learning (ML) operations efficiently on GPUs and TPUs.[1]

Code is written in plain Python, and JAX exposes its functionality as composable function transformations (jax.grad, jax.jit, jax.vmap, and so on). Developers can call these transformations to differentiate, vectorize, parallelize, and compile functions, and to write SPMD (single-program, multiple-data) programs.
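
For instance, these transformations compose freely. Below is a minimal sketch (the function f is an arbitrary example) that differentiates a function, vectorizes the derivative over a batch, and compiles the result:

import jax
import jax.numpy as jnp

def f(x):
    return jnp.sin(x) * x ** 2          # an arbitrary example function

df = jax.grad(f)                        # differentiate f
batched_df = jax.vmap(df)               # vectorize the derivative over a batch
fast_batched_df = jax.jit(batched_df)   # compile the whole pipeline with XLA

print(fast_batched_df(jnp.arange(4.0)))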

JAX was developed to accelerate Machine Learning code by making NumPy run faster and more efficiently. NumPy is a Python library focused on scientific computation that supports multi-dimensional arrays and “brings the computational power of languages like C and Fortran to Python.”

It is an ongoing project supported by over 450 contributors, with 50+ releases and continuous reporting of issues and bugs. If you find a bug, check JAX’s documentation for how to report it.[2]

JAX Main Features

  • Accelerated NumPy: JAX is a numerical computation library for Python designed to be a drop-in replacement for NumPy. Like NumPy, it allows for the manipulation of arrays and matrices and provides a wide range of mathematical operations and functions. However, JAX also includes several features specifically designed to accelerate numerical computations (e.g., JIT compilation), and it supports hardware acceleration on various platforms, including CPUs, GPUs, and TPUs, which can further improve performance.
  • Just-In-Time Compilation: JIT in JAX refers to the ability to compile Python functions to machine code at runtime, which allows JAX to optimize a function and make it run faster. JIT compilation is performed by the XLA compiler, which can optimize the function’s computation and memory-access patterns to improve performance.
  • Automatic Differentiation: JAX makes it easy to compute higher-order derivatives because its differentiation transformations return functions that are themselves differentiable. It supports both forward- and reverse-mode differentiation of numerical functions through the transformations jax.jacfwd and jax.jacrev (see the short sketch after this list).
  • Automatic Vectorization: JAX provides vectorization through the jax.vmap transformation, which automatically generates a vectorized (batched) implementation of a function.
  • Pseudorandom Numbers: The jax.random package offers multiple methods for the deterministic generation of sequences of pseudorandom numbers. It supports a wide range of probability distributions, including the uniform, normal, and beta distributions, as well as more specialized distributions such as the exponential and Poisson distributions.
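
The short sketch below illustrates the last three bullets; the function g, the batch size, and the random key are arbitrary choices made for illustration:

import jax
import jax.numpy as jnp

# Forward- and reverse-mode Jacobians of a small vector-valued function
def g(x):
    return jnp.array([x[0] * x[1], jnp.sin(x[1])])

x = jnp.array([1.0, 2.0])
print(jax.jacfwd(g)(x))          # forward-mode Jacobian
print(jax.jacrev(g)(x))          # reverse-mode Jacobian (same values)

# Automatic vectorization: apply g to a whole batch of inputs at once
batch = jnp.ones((8, 2))
print(jax.vmap(g)(batch).shape)  # (8, 2)

# Deterministic pseudorandom numbers from explicit keys
key = jax.random.PRNGKey(0)
key, subkey = jax.random.split(key)
print(jax.random.normal(subkey, shape=(3,)))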

Quick Installation Guide

You can install the library through the Python Package Index (PyPI), the official third-party software repository for Python.

pip install --upgrade pip
pip install --upgrade "jax[cpu]"

Note: On Windows, these pip installations may fail without any warning; therefore, it’s recommended to use Windows Subsystem for Linux (WSL) instead.

To import the library and its modules, use the commands illustrated below:

import jax.numpy as jnp
from jax import grad, jit, vmap
from jax import random
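
As a quick sanity check (not part of the official snippet above), you can list the devices JAX has detected to confirm which backend it will use:

import jax
print(jax.devices())   # e.g., [CpuDevice(id=0)] on a CPU-only install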

For more details about the installation procedure, check JAX’s Quickstart guide.

JAX vs. NumPy

JAX is a library for high-performance ML research, while NumPy is a library for scientific computing. JAX mirrors NumPy’s API and can be used as a drop-in replacement for it, but it also includes additional features such as automatic differentiation and GPU/TPU acceleration.

The jax.numpy package offers most of the functionality of NumPy and allows similar operations. For example:

import numpy as np
x_np = np.linspace(0, 10, 1000)
y_np = 2 * np.cos(x_np) ** 2

import jax.numpy as jnp
x_jnp = jnp.linspace(0, 10, 1000)
y_jnp = 2 * jnp.cos(x_jnp) ** 2

One critical difference must be highlighted, however: arrays created through JAX are immutable. In other words, you cannot modify an array’s values in place, as is typically done with NumPy.

Remember that JAX’s primary goal is to speed up array processing. To work around this restriction, JAX provides an indexed-update syntax, which returns an updated copy of the array:

import jax.numpy as jnp
x = jnp.arange(10)
y = x.at[0].set(100)   # returns a new array; x itself is unchanged
print(x)
print(y)
[0 1 2 3 4 5 6 7 8 9]
[100  1  2  3  4  5  6  7  8  9]
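
Beyond set, the same .at syntax supports other out-of-place updates such as add and multiply, each returning a new array (continuing from the example above):

z = x.at[0].add(5)        # [5 1 2 3 4 5 6 7 8 9]
w = x.at[1].multiply(10)  # [0 10 2 3 4 5 6 7 8 9]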

Highlights

Project Background

  • Tool: JAX
  • Author: Google
  • Initial Release: September 2017
  • Type: Accelerator for NumPy on the CPU, GPU, and TPU
  • License: Apache 2.0
  • Languages: Python, C++ and others
  • Github: JAX
  • Hardware: CPUs, GPUs, and TPUs

Differentiates

  • Python and NumPy functions
  • Loops, branches, recursions, and closures
  • Derivatives of derivatives (higher-order differentiation)
  • Supports backpropagation using the grad function (see the sketch after this list)
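
For example, grad works through ordinary Python control flow; a minimal sketch (f is an arbitrary example function):

import jax
import jax.numpy as jnp

def f(x):
    # an ordinary Python branch and loop inside the function
    if x > 0:
        total = 0.0
        for k in range(1, 4):
            total = total + x ** k   # x + x**2 + x**3
        return total
    return jnp.sin(x)

dfdx = jax.grad(f)
print(dfdx(2.0))    # 1 + 2*x + 3*x**2 at x=2 -> 17.0
print(dfdx(-1.0))   # cos(-1.0)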

Main Features

  • Speeds up code significantly using a special compiler called Accelerated Linear Algebra (XLA)
  • Fuses operations together
  • Uses XLA to compile and run NumPy on GPUs and TPUs
  • Just-in-time compiles Python functions into XLA-optimized kernels
  • pmap enables the replication of computations across several GPU cores at once (see the sketch after this list)
  • Supports composable transformations: grad, jit, pmap, and vmap 
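
A minimal pmap sketch (the computation is arbitrary; with only one device, local_device_count() returns 1 and the mapped axis must have size 1):

import jax
import jax.numpy as jnp

n = jax.local_device_count()            # number of available devices
xs = jnp.arange(n * 3.0).reshape(n, 3)  # one shard (row) per device

# replicate the computation across devices, one shard per device
ys = jax.pmap(lambda row: jnp.sum(row ** 2))(xs)
print(ys)                               # one result per device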

Projects and Libraries Using JAX

  • Flax: A deep learning library that takes advantage of JAX’s powerful features to provide a high-level interface for building neural networks. 3.9k+ stars.
  • Haiku: A library that provides a simple and composable set of abstractions for Machine Learning research. DeepMind built it on top of JAX. 2.3k+ stars.
  • NumPyro: A probabilistic programming library that provides a NumPy backend for Pyro. 1.1k+ stars.
  • Fourier Features Let Networks Learn High Frequency Functions: a study that uses JAX and Neural Tangents for training networks and calculating neural tangent kernels.[3]
  • JAXNS: A high-performance nested sampling package built on top of JAX, designed to be several orders of magnitude faster than similar nested sampling packages.[4]

Community Benchmarks

  • 21,500 Stars
  • 2,000 Forks
  • 450+ Code contributors
  • 50+ Releases
  • Source: GitHub

Releases

  • V0.4.1 (1-2023)
  • V0.3.25 (11-15-2022)
  • V0.3.24 (11-4-2022) 
  • V0.3.24 (10-23-2022)
  • Source: GitHub releases

References

[1] Google, JAX Documentation, 2023.

[2] Google, JAX GitHub repository, 2023.

[3] Tancik, M., Srinivasan, P., Mildenhall, B., Fridovich-Keil, S., Raghavan, N., Singhal, U., … & Ng, R. (2020). Fourier features let networks learn high frequency functions in low dimensional domains. Advances in Neural Information Processing Systems, 33, 7537-7547.

[4] Albert, J. G. (2020). JAXNS: a high-performance nested sampling package based on JAX. arXiv preprint arXiv:2012.15286.
