Getting Started

Installation

First install torch CPU:

# On Linux.
pip install torch --index-url https://download.pytorch.org/whl/cpu

# Or on Mac.
pip install torch

Then install JAX for the accelerator you want to use:

# On Google Cloud TPU.
pip install -U jax[tpu]

# Or, on GPU machines.
pip install -U jax[cuda12]

# Or, on Linux CPU machines or Macs.
pip install -U jax

Finally install torchax:

pip install torchax

You can also install from source if you prefer the latest torchax:

pip install git+https://github.com/google/torchax.git@main
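After installing, a quick sanity check is to import all three packages and ask JAX which devices it can see (a minimal sketch; the exact device list depends on your machine):

import torch
import jax
import torchax

print(torch.__version__)
print(jax.devices())  # e.g. [CpuDevice(id=0)], or a list of TPU/GPU devices on an accelerator VM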

Adopt JAX with ease

Suppose we have this toy model in PyTorch, and we want to run this model with JAX.

import torch
import torch.nn as nn

class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(28 * 28, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = x.view(-1, 28 * 28)
        x = torch.nn.functional.relu(self.fc1(x))
        x = torch.nn.functional.relu(self.fc2(x))
        x = self.fc3(x)
        return x

m = MyModel()
inputs = torch.randn(3, 3, 28, 28, device='cuda')
res = m(inputs)

Instead of rewriting the above model in JAX (say, with neural-network libraries like Flax or Equinox), you can run it in JAX with these changes:

import torch
import torch.nn as nn
+ import torchax
+ torchax.enable_globally()

class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(28 * 28, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = x.view(-1, 28 * 28)
        x = torch.nn.functional.relu(self.fc1(x))
        x = torch.nn.functional.relu(self.fc2(x))
        x = self.fc3(x)
        return x

m = MyModel()
- inputs = torch.randn(3, 3, 28, 28, device='cuda')
+ inputs = torch.randn(3, 3, 28, 28, device='jax')
+ m.to('jax')
res = m(inputs)

That is it! You are now using JAX as the backend to execute the above model.

Now a bit of explanation:

import torchax
torchax.enable_globally()

The two lines above enable torchax, allowing it to intercept the PyTorch operators that we run.

Then, you can use a jax device:

inputs = torch.randn(3, 3, 28, 28, device='jax')
m = MyModel().to('jax')

Here, jax behaves like just another PyTorch device, much like the cuda device on GPUs.
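For example, ordinary tensor operations work on the jax device the same way they would on cuda. A minimal sketch (using only the APIs shown above; the printed type is torchax's own tensor wrapper, backed by a JAX array):

import torch
import torchax
torchax.enable_globally()

a = torch.randn(2, 3, device='jax')
b = torch.randn(3, 4, device='jax')
c = torch.relu(a @ b)  # these ops are dispatched to JAX under the hood

print(c.shape)  # torch.Size([2, 4])
print(type(c))  # torchax's tensor type wrapping a JAX array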

Is it really running on JAX?

To see for yourself that it is really running on JAX, you can capture a JAX profiler trace and look at the JAX operations being executed:

import torch
import torch.nn as nn
import torchax
torchax.enable_globally()

class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(28 * 28, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = x.view(-1, 28 * 28)
        x = torch.nn.functional.relu(self.fc1(x))
        x = torch.nn.functional.relu(self.fc2(x))
        x = self.fc3(x)
        return x

m = MyModel()
inputs = torch.randn(3, 3, 28, 28, device='jax')
m.to('jax')

import jax
with jax.profiler.trace('/tmp/jax-trace', create_perfetto_link=True):
    res = m(inputs)

Running the above (using a Google TPU VM that I happen to have) yields:

[profiler trace screenshot]

We can see we have JAX ops running on the accelerator.
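Another quick check, independent of the profiler, is to ask JAX directly which backend it is using (plain JAX APIs; the exact output depends on your machine):

import jax

print(jax.default_backend())  # e.g. 'tpu', 'gpu', or 'cpu'
print(jax.devices())          # the devices that back tensors on the 'jax' device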

Compiling with jax.jit

Running JAX through the torchax front end this way works fine, but it is not very fast. From the profiler trace above, we can see that a lot of time is spent compiling XLA operations.

For people familiar with JAX, most of JAX's performance benefit comes from compilation. In fact, as pointed out in this GitHub post, the performance difference is night and day. So even though we succeeded in running the model on JAX, we are still running it in JAX's eager mode. Now let's compile the model to unlock JAX's performance benefits:

import torch
import torch.nn as nn
import torchax
torchax.enable_globally()

class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(28 * 28, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = x.view(-1, 28 * 28)
        x = torch.nn.functional.relu(self.fc1(x))
        x = torch.nn.functional.relu(self.fc2(x))
        x = self.fc3(x)
        return x

m = MyModel()
inputs = torch.randn(3, 3, 28, 28, device='jax')
m.to('jax')

m_jitted = torchax.compile(m)

import jax
import time

for i in range(3):
    start = time.perf_counter()
    res = m_jitted(inputs)  # call the compiled model
    res.apply_jax_(jax.block_until_ready)
    end = time.perf_counter()
    print(f'iteration {i} took {end - start}s')

with jax.profiler.trace('/tmp/jax-trace', create_perfetto_link=True):
    res = m_jitted(inputs)
    res.apply_jax_(jax.block_until_ready)

The output is

iteration 0 took 0.06629770500876475s
iteration 1 took 0.00022945000091567636s
iteration 2 took 0.00016760900325607508s

We can see that the first call to the model is slower because it includes both compile time and runtime. The second and third runs involve only the runtime and are much faster.
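The same compile-once-then-cache pattern shows up in plain JAX as well. Here is a standalone sketch (not torchax-specific) that times a jitted function the same way:

import time
import jax
import jax.numpy as jnp

@jax.jit
def f(x):
    return jnp.tanh(x) @ x.T

x = jnp.ones((512, 512))
for i in range(3):
    start = time.perf_counter()
    jax.block_until_ready(f(x))  # first call traces and compiles; later calls reuse the cached executable
    end = time.perf_counter()
    print(f'iteration {i} took {end - start}s')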

Looking at the profiler output, we can also see that the ops are now fused.

[profiler trace screenshot showing fused ops]

This is because compiling the model (using jax.jit, which uses the XLA compiler under the hood) allows XLA to fuse operations for better performance.
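To build intuition for why fusion happens, note that jax.jit traces the entire forward pass into a single program before handing it to XLA, which can then fuse across operation boundaries. You can inspect such a traced program with jax.make_jaxpr; the sketch below is a standalone JAX analogue of the model's forward pass, not the torchax internals:

import jax
import jax.numpy as jnp

def forward(x, w1, w2):
    # Same shape of computation as MyModel's forward: two matmuls with a ReLU in between.
    return jax.nn.relu(x @ w1) @ w2

x = jnp.ones((3, 784))
w1 = jnp.ones((784, 120))
w2 = jnp.ones((120, 84))

# One jaxpr for the whole forward pass; XLA compiles and fuses it as a unit.
print(jax.make_jaxpr(forward)(x, w1, w2))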

For a more serious model, here is a three-part blog series on running HuggingFace's version of the Llama-3 7B model on TPUs, using tensor parallelism across 8 chips: