What is torchax

torchax is a PyTorch frontend for JAX. It gives JAX the ability to author JAX programs using familiar PyTorch syntax. It also provides JAX-Pytorch interoperability, meaning, one can mix JAX & Pytorch syntax together when authoring ML programs, and run it in every hardware JAX can run.

With torchax, you can:

  • Call JAX functions from PyTorch, passing in jax.Arrays.
  • Call PyTorch functions from JAX, passing in torch.Tensors.
  • Use JAX features like jax.grad, optax, and Auto-Parallelisation to train PyTorch models.
  • Use a PyTorch model as a feature extractor with a JAX model.
  • Run PyTorch code on hardwares where JAX is supported, such as Google TPUs, with minimal code changes.

Installation

First install torch CPU:

# On Linux.
pip install torch --index-url https://download.pytorch.org/whl/cpu

# Or on Mac.
pip install torch

Then install JAX for the accelerator you want to use:

# On Google Cloud TPU.
pip install -U jax[tpu]

# Or, on GPU machines.
pip install -U jax[cuda12]

# Or, on Linux CPU machines or Macs.
pip install -U jax

Or, follow a method on https://docs.jax.dev/en/latest/installation.html to install Jax.

Finally install torchax:

pip install torchax

You can also install from source if you prefer the lastest torchax:

pip install git+https://github.com/google/torchax.git@main

Note that for now we don't automatically install torch or jax when installing torchax because we want to expose the option of picking a version to the users.