JAX
JAX is an open-source numerical computing library from Google that combines NumPy-style array programming with automatic differentiation and just-in-time compilation, used to train large-scale machine learning models on GPUs and TPUs.
JAX is an open-source numerical computing library developed by Google and first released in 2018. It pairs a familiar NumPy-style programming interface with three composable transformations: automatic differentiation, just-in-time (JIT) compilation, and automatic vectorisation. JAX has become a widely used framework for training state-of-the-art foundation models, at organisations including Google DeepMind, Anthropic, xAI, and Apple.
JAX is built on a functional programming model. Rather than mutating state in place, programs are written as pure mathematical functions that JAX can analyse and transform. This design is what allows the library to apply its transformations composably and reliably.
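The sketch below shows what purity means in practice, assuming a recent JAX release. JAX arrays cannot be modified in place, so indexed updates use the .at[...] syntax and return a new array. Python side effects inside a jit-compiled function run only while JAX traces it, not on every call. The function impure_scale and the calls list are hypothetical names used only for illustration.

```python
import jax
import jax.numpy as jnp

x = jnp.zeros(3)

# JAX arrays are immutable: writing x[0] = 1.0 would raise a TypeError.
# An indexed update returns a new array and leaves x unchanged.
y = x.at[0].set(1.0)

calls = []  # hypothetical record of Python-level side effects

@jax.jit
def impure_scale(v):
    # This append is a Python side effect: it runs only while JAX traces
    # the function, not when the cached compiled version is re-executed.
    calls.append("traced")
    return 2.0 * v

impure_scale(jnp.ones(3))
impure_scale(jnp.ones(3))  # same shape and dtype, so the compiled code is reused

print(x, y)
print(len(calls))  # 1: the side effect happened once, during tracing
```

This is why JAX programs are written as pure functions. Anything a transformation needs to see must flow through the function's arguments and return values.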
Core transformations
The function grad performs automatic differentiation. Given a scalar-valued function, it returns a new function that computes the gradient of the original, by default with respect to its first argument. Because grad composes with itself, higher-order derivatives are obtained by applying it repeatedly. The function jit compiles Python functions with the XLA (Accelerated Linear Algebra) compiler into optimised code for the target device, often yielding large speed-ups by fusing operations and reducing Python overhead. The function vmap automatically vectorises a function so that it operates over a batch dimension without the developer writing explicit loops. Finally, pmap and related tools distribute computation across multiple devices.
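The following minimal sketch shows grad, jit and vmap composing. It applies grad twice to get a second derivative, then batches a single-example function with vmap and compiles the result with jit. The functions f and predict are hypothetical examples, not part of JAX.

```python
import jax
import jax.numpy as jnp

def f(x):
    # Scalar-valued function of a scalar input: f(x) = x^3.
    return x ** 3

df = jax.grad(f)              # f'(x)  = 3x^2
d2f = jax.grad(jax.grad(f))   # f''(x) = 6x, obtained by applying grad twice

def predict(w, x):
    # Hypothetical single-example linear model: one input vector -> one scalar.
    return jnp.tanh(jnp.dot(x, w))

# vmap maps predict over the leading axis of xs while sharing w (in_axes=None),
# and jit compiles the batched function with XLA.
batched_predict = jax.jit(jax.vmap(predict, in_axes=(None, 0)))

w = jnp.ones(4)
xs = jnp.ones((8, 4))
out = batched_predict(w, xs)

print(df(3.0), d2f(3.0))  # 27.0 18.0
print(out.shape)          # (8,)
```

In predict, the developer writes code for a single example only. vmap supplies the batch dimension, so no hand-written loop or manual reshaping is needed.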
Code that uses these transformations runs unchanged on CPUs, GPUs, and Google's Tensor Processing Units (TPUs), making JAX especially well suited to large-scale training on TPU pods.
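The sketch below illustrates the backend-independence claim and the multi-device side of the API, assuming a standard JAX install. The same script reports whichever backend is available. It then uses pmap with a psum all-reduce over however many local devices exist, which is usually one on a CPU-only machine. The function local_then_global_sum and the axis name "devices" are hypothetical. Recent JAX documentation steers new multi-device code toward jit with explicit sharding, so treat pmap here as one of the "related tools" rather than the only approach.

```python
import jax
import jax.numpy as jnp

# The same program runs on whichever backend this JAX installation targets.
print(jax.default_backend())  # e.g. "cpu", "gpu" or "tpu"
print(jax.devices())

n = jax.local_device_count()  # usually 1 on a CPU-only machine

# pmap needs the leading axis of its input to equal the device count:
# each device receives one row of shape (4,).
xs = jnp.arange(n * 4, dtype=jnp.float32).reshape(n, 4)

def local_then_global_sum(x):
    # Sum this device's shard, then all-reduce across the "devices" axis.
    return jax.lax.psum(jnp.sum(x), axis_name="devices")

total = jax.pmap(local_then_global_sum, axis_name="devices")(xs)
print(total)  # one copy of the global sum per device
```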
The JAX AI stack
JAX itself is deliberately minimal, providing the numerical core while higher-level libraries supply the conveniences needed for full machine-learning workflows. Google has packaged these into the JAX AI Stack, an end-to-end platform co-designed with Cloud TPUs. Flax provides a flexible API for authoring neural-network models. Optax offers composable gradient-processing and optimisation transformations. Orbax handles asynchronous, distributed checkpointing so that long training runs survive hardware failures. Newer additions announced in 2025 include Metrax for efficient evaluation metrics and JAX-Privacy 1.0 for differentially private training pipelines.
Position relative to PyTorch and TensorFlow
JAX occupies a distinct niche. Where PyTorch emphasises an imperative, easy-to-debug style and TensorFlow emphasises production deployment, JAX targets high-performance research and large-scale training where its functional purity and XLA compilation deliver strong performance. The trade-off is a steeper learning curve and a smaller, though rapidly growing, ecosystem. Its functional approach also extends beyond deep learning into scientific computing, physics simulation, and other fields that benefit from composable differentiation.
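Composable differentiation outside deep learning can be as simple as differentiating through a numerical simulation. The minimal sketch below assumes a hypothetical projectile model, integrated with explicit Euler steps inside jax.lax.fori_loop. grad then gives the sensitivity of the final height to the launch speed. For this scheme, the final height depends linearly on v0 with slope steps * dt = 1.0, which makes the result easy to check.

```python
import jax
import jax.numpy as jnp

def simulate(v0, steps=100, dt=0.01, g=9.81):
    # Hypothetical model: explicit Euler integration of a projectile's
    # height, with state = (height, vertical velocity).
    def body(_, state):
        h, v = state
        return (h + dt * v, v - dt * g)

    h, _ = jax.lax.fori_loop(0, steps, body, (jnp.zeros_like(v0), v0))
    return h

# Reverse-mode derivative of the final height with respect to launch speed.
dh_dv0 = jax.grad(simulate)(5.0)
print(dh_dv0)  # close to 1.0 (= steps * dt)
```

The loop uses static Python bounds, so JAX can lower fori_loop to a scan and differentiate through it in reverse mode.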
| Transformation | Purpose |
|----------------|---------|
| grad | Automatic differentiation |
| jit | Just-in-time compilation via XLA |
| vmap | Automatic batching / vectorisation |
| pmap | Parallelism across devices |