Training a PyTorch model with JAX

Introduction

This tutorial notebook is adapted from https://docs.pytorch.org/tutorials/beginner/introyt/trainingyt.html

It keeps most of the PyTorch code unchanged (especially the model definition) and replaces the standard PyTorch train loop (the loss.backward() + optimizer.step() pattern) with a JAX train loop (jax.grad followed by optax.apply_updates).

The rest of the tutorial, such as data loading and loss printing, is kept as close to the original as possible.

Dataset and DataLoader

The Dataset and DataLoader classes encapsulate the process of pulling your data from storage and exposing it to your training loop in batches.

The Dataset is responsible for accessing and processing single instances of data.

The DataLoader pulls instances of data from the Dataset (either automatically or with a sampler that you define), collects them in batches, and returns them for consumption by your training loop. The DataLoader works with all kinds of datasets, regardless of the type of data they contain.
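
As a rough illustration of the batching idea described here, the sketch below groups items from a sequence into fixed-size batches. It is a conceptual stand-in written in plain Python, not how torch.utils.data.DataLoader is implemented; the names iter_batches and toy_dataset are made up for this example.

```python
from typing import Iterator, List, Sequence, TypeVar

T = TypeVar("T")


def iter_batches(dataset: Sequence[T], batch_size: int) -> Iterator[List[T]]:
    """Yield consecutive batches; the last one may be smaller."""
    for start in range(0, len(dataset), batch_size):
        yield list(dataset[start:start + batch_size])


toy_dataset = list(range(10))  # hypothetical stand-in for a real Dataset
batches = list(iter_batches(toy_dataset, batch_size=4))
print(batches)  # [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]
```

A real DataLoader additionally handles shuffling, samplers, and collating samples into tensors.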

For this tutorial, we'll be using the Fashion-MNIST dataset provided by TorchVision. We use torchvision.transforms.Normalize() to zero-center and normalize the distribution of the image tile content, and download both training and validation data splits.
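
To make the normalization concrete: ToTensor() scales pixel values to the range [0, 1], and Normalize((0.5,), (0.5,)) then applies (x - mean) / std per channel. The sketch below reproduces that arithmetic on plain floats; the helper names normalize and unnormalize are illustrative only.

```python
def normalize(pixel: float, mean: float = 0.5, std: float = 0.5) -> float:
    """Same arithmetic as Normalize((0.5,), (0.5,)) for one pixel."""
    return (pixel - mean) / std


def unnormalize(value: float) -> float:
    """Inverse mapping, the same `img / 2 + 0.5` used when displaying images."""
    return value / 2 + 0.5


lo, mid, hi = normalize(0.0), normalize(0.5), normalize(1.0)
print(lo, mid, hi)  # -1.0 0.0 1.0
print(unnormalize(normalize(0.25)))  # 0.25
```

So pixel values in [0, 1] end up in [-1, 1], centered on zero.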

In [1]:
# Optional: install dependencies
!pip install matplotlib torch torchax jax optax tensorboard
Requirement already satisfied: matplotlib in /Users/hanq/homebrew/Caskroom/miniconda/base/envs/py13/lib/python3.13/site-packages (3.10.7)
Requirement already satisfied: torch in /Users/hanq/homebrew/Caskroom/miniconda/base/envs/py13/lib/python3.13/site-packages (2.8.0)
Requirement already satisfied: torchax in /Users/hanq/homebrew/Caskroom/miniconda/base/envs/py13/lib/python3.13/site-packages (0.0.6)
Requirement already satisfied: jax in /Users/hanq/homebrew/Caskroom/miniconda/base/envs/py13/lib/python3.13/site-packages (0.7.2)
Requirement already satisfied: optax in /Users/hanq/homebrew/Caskroom/miniconda/base/envs/py13/lib/python3.13/site-packages (0.2.6)
Requirement already satisfied: tensorboard in /Users/hanq/homebrew/Caskroom/miniconda/base/envs/py13/lib/python3.13/site-packages (2.20.0)
In [2]:
# For tips on running notebooks in Google Colab, see
# https://docs.pytorch.org/tutorials/beginner/colab
%matplotlib inline
In [3]:
import torch
import torchvision
import torchvision.transforms as transforms

# PyTorch TensorBoard support
from torch.utils.tensorboard import SummaryWriter
from datetime import datetime


transform = transforms.Compose(
    [transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))])

# Create datasets for training & validation, download if necessary
training_set = torchvision.datasets.FashionMNIST('./data', train=True, transform=transform, download=True)
validation_set = torchvision.datasets.FashionMNIST('./data', train=False, transform=transform, download=True)

# Create data loaders for our datasets; shuffle for training, not for validation
training_loader = torch.utils.data.DataLoader(training_set, batch_size=4, shuffle=True)
validation_loader = torch.utils.data.DataLoader(validation_set, batch_size=4, shuffle=False)

# Class labels
classes = ('T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat',
        'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle Boot')

# Report split sizes
print('Training set has {} instances'.format(len(training_set)))
print('Validation set has {} instances'.format(len(validation_set)))
Training set has 60000 instances
Validation set has 10000 instances

As always, let's visualize the data as a sanity check:

In [4]:
import matplotlib.pyplot as plt
import numpy as np

# Helper function for inline image display
def matplotlib_imshow(img, one_channel=False):
    if one_channel:
        img = img.mean(dim=0)
    img = img / 2 + 0.5     # unnormalize
    npimg = img.numpy()
    if one_channel:
        plt.imshow(npimg, cmap="Greys")
    else:
        plt.imshow(np.transpose(npimg, (1, 2, 0)))

dataiter = iter(training_loader)
images, labels = next(dataiter)

# Create a grid from the images and show them
img_grid = torchvision.utils.make_grid(images)
matplotlib_imshow(img_grid, one_channel=True)
print('  '.join(classes[labels[j]] for j in range(4)))
Coat  Pullover  Bag  Bag
[Image: grid of the four sample training images from the batch]

The Model

The model we'll use in this example is a variant of LeNet-5 - it should be familiar if you've watched the previous videos in this series.

In [5]:
import torch.nn as nn
import torch.nn.functional as F

# PyTorch models inherit from torch.nn.Module
class GarmentClassifier(nn.Module):
    def __init__(self):
        super(GarmentClassifier, self).__init__()
        self.conv1 = nn.Conv2d(1, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 4 * 4, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 4 * 4)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x


model = GarmentClassifier()
In [6]:
model(images)
Out[6]:
tensor([[ 0.0147, -0.0986, -0.0197, -0.0880, -0.0228, -0.0379,  0.0820, -0.0990,
         -0.0172, -0.0708],
        [ 0.0056, -0.0992, -0.0195, -0.0978, -0.0150, -0.0213,  0.0677, -0.0953,
         -0.0219, -0.0732],
        [ 0.0166, -0.1112, -0.0197, -0.1034, -0.0170, -0.0254,  0.0676, -0.1010,
         -0.0152, -0.0621],
        [ 0.0213, -0.1105, -0.0365, -0.1026, -0.0205, -0.0142,  0.0628, -0.1015,
         -0.0243, -0.0444]], grad_fn=<AddmmBackward0>)
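
The 16 * 4 * 4 input size of fc1 follows from the layer shapes. The sketch below traces the spatial size of a 28x28 Fashion-MNIST image through the two conv + pool stages using the standard output-size formula for unpadded, stride-1 convolutions; the helper names conv_out and pool_out are illustrative, not PyTorch APIs.

```python
def conv_out(size: int, kernel: int, stride: int = 1, padding: int = 0) -> int:
    """Output side length of a convolution along one spatial dimension."""
    return (size + 2 * padding - kernel) // stride + 1


def pool_out(size: int, kernel: int, stride: int) -> int:
    """Output side length of max pooling along one spatial dimension."""
    return (size - kernel) // stride + 1


side = 28                                 # Fashion-MNIST images are 28x28
side = pool_out(conv_out(side, 5), 2, 2)  # conv1: 28 -> 24, pool: 24 -> 12
side = pool_out(conv_out(side, 5), 2, 2)  # conv2: 12 -> 8,  pool: 8 -> 4
flat_features = 16 * side * side          # 16 channels after conv2
print(side, flat_features)  # 4 256
```

The batch of four images therefore yields a (4, 10) output, one row of class scores per image, as in Out[6].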

Loss Function

For this example, we'll be using a cross-entropy loss. For demonstration purposes, we'll create batches of dummy output and label values, run them through the loss function, and examine the result.

In [7]:
loss_fn = torch.nn.CrossEntropyLoss()

# NB: Loss functions expect data in batches, so we're creating batches of 4
# Represents the model's confidence in each of the 10 classes for a given input
dummy_outputs = torch.rand(4, 10)
# Represents the correct class among the 10 being tested
dummy_labels = torch.tensor([1, 5, 3, 7])

print(dummy_outputs)
print(dummy_labels)

loss = loss_fn(dummy_outputs, dummy_labels)
print('Total loss for this batch: {}'.format(loss.item()))
tensor([[0.0188, 0.1162, 0.7957, 0.0443, 0.0191, 0.1256, 0.8698, 0.1108, 0.4106,
         0.8890],
        [0.6096, 0.3726, 0.2039, 0.0199, 0.5399, 0.1214, 0.0381, 0.5662, 0.7744,
         0.9680],
        [0.2135, 0.7618, 0.2944, 0.8146, 0.8824, 0.8905, 0.8575, 0.8906, 0.0485,
         0.0220],
        [0.4641, 0.1055, 0.9248, 0.7413, 0.3650, 0.5146, 0.3190, 0.7306, 0.5251,
         0.5850]])
tensor([1, 5, 3, 7])
Total loss for this batch: 2.370086431503296
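
For intuition about the number printed above, cross-entropy can be worked out by hand. The sketch below is a plain-Python version of the usual definition (log-softmax of the scores, negated at the correct class, averaged over the batch, which matches the default mean reduction of CrossEntropyLoss). The function names are illustrative.

```python
import math


def cross_entropy(logits, label):
    """Negative log-softmax of `logits` at index `label` (numerically stable)."""
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum_exp - logits[label]


def batch_loss(batch_logits, labels):
    """Mean cross-entropy over a batch."""
    losses = [cross_entropy(row, y) for row, y in zip(batch_logits, labels)]
    return sum(losses) / len(losses)


# A model with no preference among 10 classes scores ln(10), about 2.303.
uniform = [[0.0] * 10 for _ in range(4)]
baseline = batch_loss(uniform, [1, 5, 3, 7])
print(baseline, math.log(10))
```

This is why both the random dummy outputs (about 2.37) and the untrained model land close to 2.3: neither carries real information about the labels yet.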

Move model to 'jax' device

In [8]:
import torchax
torchax.enable_globally()
model.to('jax')
images = images.to('jax')
dummy_labels = dummy_labels.to('jax')
WARNING:absl:Tensorflow library not found, tensorflow.io.gfile operations will use native shim calls. GCS paths (i.e. 'gs://...') cannot be accessed.
WARNING:root:Duplicate op registration for aten.__and__

Optimizer

For this example, we'll be using a simple optax optimizer (Adam).

In [9]:
import optax
start_learning_rate = 1e-3
optimizer = optax.adam(start_learning_rate) 
In [10]:
print(optimizer)
GradientTransformationExtraArgs(init=<function chain.<locals>.init_fn at 0x14234ca40>, update=<function chain.<locals>.update_fn at 0x14234cae0>)
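
As the printed object shows, an optax optimizer is a pair of functions, init and update, rather than an object that owns the weights. To illustrate that shape, here is a minimal plain-Python sketch of the textbook Adam rule for a single scalar weight, using the usual default hyperparameters (b1=0.9, b2=0.999, eps=1e-8). It is not optax's code, and adam_init / adam_update are hypothetical names.

```python
import math


def adam_init():
    """Fresh optimizer state: step count and the two moment estimates."""
    return {"count": 0, "mu": 0.0, "nu": 0.0}


def adam_update(grad, state, lr=1e-3, b1=0.9, b2=0.999, eps=1e-8):
    """Return (update, new_state); the update is then added to the weight."""
    count = state["count"] + 1
    mu = b1 * state["mu"] + (1 - b1) * grad
    nu = b2 * state["nu"] + (1 - b2) * grad * grad
    mu_hat = mu / (1 - b1 ** count)  # bias correction
    nu_hat = nu / (1 - b2 ** count)
    update = -lr * mu_hat / (math.sqrt(nu_hat) + eps)
    return update, {"count": count, "mu": mu, "nu": nu}


state = adam_init()
update, state = adam_update(grad=0.25, state=state)
weight = 1.0 + update  # "apply updates" is just an addition
print(update, weight)
```

On the first step the bias-corrected moments cancel the gradient's magnitude, so the step size is close to the learning rate itself.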

The Training Loop

Below, we have a function that performs one training epoch.

First, let's articulate what the training step does.

At each training step, we first evaluate the model. The model is a function that maps (weights, input data) to a prediction.

$$ model: (weights, input) \mapsto pred $$

In PyTorch, we can use torch.func.functional_call to call a model with the weights passed in as a parameter.

The loss is a function that maps the prediction and the label to a real number representing the loss:

$$ loss: (pred, label) \mapsto loss $$

To train the model, we use a glorified gradient descent (in this case Adam), so we need another function that represents the gradient of the loss with respect to the weights.

$$ \frac {d loss} {d weights}$$

Finally, the train_step itself is a function that maps (weights, optimizer state, input data) to (updated weights, updated optimizer state).
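
The four functions just described can be spelled out on a toy problem. The sketch below uses a one-dimensional linear model, a squared-error loss, a hand-derived gradient, and plain gradient descent instead of Adam, all in ordinary Python. Every name here is illustrative; the point is only that a train step is a pure function from old state to new state.

```python
def model(weights, x):
    """model: (weights, input) -> pred"""
    return weights["w"] * x + weights["b"]


def loss(pred, label):
    """loss: (pred, label) -> real number"""
    return (pred - label) ** 2


def grad_of_loss(weights, x, label):
    """d loss / d weights, derived by hand with the chain rule."""
    err = 2.0 * (model(weights, x) - label)
    return {"w": err * x, "b": err}


def train_step(weights, opt_state, x, label, lr=0.05):
    """(weights, opt_state, data) -> (updated weights, updated opt_state)"""
    grads = grad_of_loss(weights, x, label)
    new_weights = {k: weights[k] - lr * grads[k] for k in weights}
    new_opt_state = {"step": opt_state["step"] + 1}
    return new_weights, new_opt_state


weights, opt_state = {"w": 0.0, "b": 0.0}, {"step": 0}
before = loss(model(weights, 2.0), 4.0)
for _ in range(20):
    weights, opt_state = train_step(weights, opt_state, 2.0, 4.0)
after = loss(model(weights, 2.0), 4.0)
print(before, after)
```

Nothing is mutated in place: each call returns fresh weights and state, which is the style that JAX transformations such as jax.grad and jax.jit expect.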

We can spell out the individual components of a train loop, and use Python to assemble them together:

In [11]:
weights = model.state_dict()

def run_model_and_loss(weights, inputs, labels):
    # First call the model with passed in weights
    output = torch.func.functional_call(model, weights, args=(inputs, ))
    loss = loss_fn(output, labels)
    return loss
In [12]:
run_model_and_loss(model.state_dict(), images, dummy_labels)
Out[12]:
Tensor(<class 'jaxlib._jax.ArrayImpl'> 2.3484843)

Now let's define its gradient function. In JAX, one would use jax.grad. However, jax.grad needs to take a JAX function (a function that takes jax.Array values as inputs and returns them as outputs) as its argument, whereas run_model_and_loss takes and returns torch.Tensor values.

One way to solve this issue is to use jax_view from the torchax.interop module

jax_view converts a torch function to a jax function.

torchax has common JAX functions wrapped in the torchax.interop module so that they work with torch functions as well. In this case, we will use jax_value_and_grad.

In [13]:
from torchax.interop import jax_view
import jax

grad_fn_jax = jax.grad( jax_view(run_model_and_loss))

grad_fn_jax(jax_view(weights), jax_view(images), jax_view(dummy_labels)).keys()
Out[13]:
odict_keys(['conv1.weight', 'conv1.bias', 'conv2.weight', 'conv2.bias', 'fc1.weight', 'fc1.bias', 'fc2.weight', 'fc2.bias', 'fc3.weight', 'fc3.bias'])

Note that above grad_fn_jax is the gradient of jax_view(run_model_and_loss) and is a jax function.
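
The keys printed in Out[13] match the keys of the state_dict because JAX differentiates with respect to the whole first argument and returns gradients with the same tree structure. The sketch below shows this in plain JAX, without torchax, on a hypothetical two-parameter linear model; toy_loss and the toy arrays are made up for this example.

```python
import jax
import jax.numpy as jnp


def toy_loss(weights, x, y):
    """Mean squared error of a linear model whose weights live in a dict."""
    pred = x @ weights["w"] + weights["b"]
    return jnp.mean((pred - y) ** 2)


toy_weights = {"w": jnp.ones((3,)), "b": jnp.zeros(())}
x = jnp.ones((4, 3))
y = jnp.zeros((4,))

# value_and_grad returns the loss and the gradient w.r.t. the first argument.
loss_value, grads = jax.value_and_grad(toy_loss)(toy_weights, x, y)
print(float(loss_value))
print({name: g.shape for name, g in grads.items()})
```

jax.grad returns only the second element; jax.value_and_grad also returns the loss, which avoids a second forward pass when the loss needs to be logged.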

If instead we wish to turn grad_fn_jax into a torch function, we can apply torch_view to it; it then becomes a function that takes torch tensors and returns torch tensors.

In fact, the pattern of calling jax_view + jax.value_and_grad + torch_view is common enough that we provide this very wrapper as torchax.interop.jax_value_and_grad, used below.

In [14]:
grad_fn = torchax.interop.jax_value_and_grad(run_model_and_loss)

Now let's assemble the train loop:

In [15]:
# Initialize optimizer
from torchax.interop import call_jax

# Initialize optimizer, we need to call optimizer.init, but
# it is a JAX-function (function that takes jax arrays as input),
# so we use call_jax to pass it torch values:

opt_state = call_jax(optimizer.init, weights)


def train_one_epoch(epoch_index, tb_writer):
    global weights
    global opt_state
    running_loss = 0.
    last_loss = 0.

    # Here, we use enumerate(training_loader) instead of
    # iter(training_loader) so that we can track the batch
    # index and do some intra-epoch reporting
    for i, data in enumerate(training_loader):
        # Every data instance is an input + label pair
        inputs, labels = data
        inputs = inputs.to('jax')
        labels = labels.to('jax')

        # compute gradients
        loss, gradients = grad_fn(weights, inputs, labels)
        # compute updates
        updates, opt_state = call_jax(optimizer.update, gradients, opt_state)
        #apply updates
        weights = call_jax(optax.apply_updates, weights, updates)
        
        # Gather data and report
        running_loss += loss.cpu().item()
        if i % 1000 == 999:
            last_loss = running_loss / 1000 # loss per batch
            print('  batch {} loss: {}'.format(i + 1, last_loss))
            tb_x = epoch_index * len(training_loader) + i + 1
            tb_writer.add_scalar('Loss/train', last_loss, tb_x)
            running_loss = 0.
        if i > 2000: 
            break
            # NOTE: make it run faster for CI

    return last_loss

The above will work. However, the sequence of computing gradients, computing the optimizer update, and applying the update is pretty standard, so we have a helper that does exactly that: make_train_step.

Now let's use that instead.

Having a variable for the function of one training step also allows us to compile it with jax.jit. Here we use interop.jax_jit, which just wraps jax.jit with torch_view and passes kwargs_for_jax_jit verbatim to the underlying jax.jit, as below.

We can optionally donate the weights and optimizer state, so XLA can issue in-place updates for those two.
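
The donate_argnums setting refers to positional argument indices of the jitted function. Assuming the compiled step is called as one_step(weights, buffers, opt_state, inputs, labels), which is how this tutorial calls it, the small sketch below spells out which arguments the indices (0, 2) select. The tuple names are only for illustration.

```python
# Positional parameters of the train step, in call order.
ARG_NAMES = ("weights", "buffers", "opt_state", "inputs", "labels")

# Indices passed to jax.jit as donate_argnums.
DONATE_ARGNUMS = (0, 2)

donated = [ARG_NAMES[i] for i in DONATE_ARGNUMS]
print(donated)  # ['weights', 'opt_state']
```

Donated inputs should not be read again after the call, since their buffers may be reused for the outputs. The training loop respects this by immediately rebinding weights and opt_state to the values returned by the step.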

In [16]:
import functools
from torchax.train import make_train_step


# The calling convention for make_train_step is that model_fn takes
# weights (trainable params) and buffers (non-trainable params) separately,
# because gradients are computed with respect to the first argument.
def model_fn(weights, buffers, data):
    return torch.func.functional_call(model, (weights, buffers), data)


one_step = make_train_step(
    model_fn=model_fn,
    loss_fn=loss_fn,
    optax_optimizer=optimizer)


# def one_step(weights, opt_state, inputs, labels):
#             # compute gradients
#     loss, gradients = grad_fn(weights, inputs, labels)
#         # compute updates
#     updates, opt_state = call_jax(optimizer.update, gradients, opt_state)
#         #apply updates
#     weights = call_jax(optax.apply_updates, weights, updates)
#     return loss, weights, opt_state

one_step = torchax.interop.jax_jit(one_step, kwargs_for_jax_jit={'donate_argnums': (0, 2)})


def train_one_epoch(epoch_index, tb_writer):
    global weights
    global opt_state
    running_loss = 0.
    last_loss = 0.

    # Here, we use enumerate(training_loader) instead of
    # iter(training_loader) so that we can track the batch
    # index and do some intra-epoch reporting
    for i, data in enumerate(training_loader):
        # Every data instance is an input + label pair
        inputs, labels = data
        inputs = inputs.to('jax')
        labels = labels.to('jax')

        loss, weights, opt_state = one_step(weights, {}, opt_state, inputs, labels) 
        # Gather data and report
        running_loss += loss.cpu().item()
        if i % 1000 == 999:
            last_loss = running_loss / 1000 # loss per batch
            print('  batch {} loss: {}'.format(i + 1, last_loss))
            tb_x = epoch_index * len(training_loader) + i + 1
            tb_writer.add_scalar('Loss/train', last_loss, tb_x)
            running_loss = 0.
        if i > 2000: 
            break
            # NOTE: make it run faster for CI

    return last_loss

Per-Epoch Activity

There are a couple of things we'll want to do once per epoch:

  • Perform validation by checking our relative loss on a set of data that was not used for training, and report this
  • Save a copy of the model

Here, we'll do our reporting in TensorBoard. This will require going to the command line to start TensorBoard, and opening it in another browser tab.

In [17]:
# Initializing in a separate cell so we can easily add more epochs to the same run
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
writer = SummaryWriter('runs/fashion_trainer_{}'.format(timestamp))
epoch_number = 0

EPOCHS = 2

best_vloss = 1_000_000.

for epoch in range(EPOCHS):
    print('EPOCH {}:'.format(epoch_number + 1))

    # Make sure gradient tracking is on, and do a pass over the data
    model.train(True)
    avg_loss = train_one_epoch(epoch_number, writer)


    running_vloss = 0.0
    # Set the model to evaluation mode, disabling dropout and using population
    # statistics for batch normalization.
    model.eval()

    # Disable gradient computation and reduce memory consumption.
    with torch.no_grad():
        for i, vdata in enumerate(validation_loader):
            vinputs, vlabels = vdata
            vinputs = vinputs.to('jax')
            vlabels = vlabels.to('jax')
            model.load_state_dict(weights) # put the trained weight back to test it
            voutputs = model(vinputs)
            vloss = loss_fn(voutputs, vlabels)
            running_vloss += vloss

            if i > 1000:
                break

    avg_vloss = running_vloss / (i + 1)
    print('LOSS train {} valid {}'.format(avg_loss, avg_vloss))

    # Log the running loss averaged per batch
    # for both training and validation
    writer.add_scalars('Training vs. Validation Loss',
                    { 'Training' : avg_loss, 'Validation' : avg_vloss },
                    epoch_number + 1)
    writer.flush()


    epoch_number += 1
EPOCH 1:
/Users/hanq/homebrew/Caskroom/miniconda/base/envs/py13/lib/python3.13/site-packages/jax/_src/numpy/lax_numpy.py:5943: UserWarning: Explicitly requested dtype int64 requested in arange is not available, and will be truncated to dtype int32. To enable more dtypes, set the jax_enable_x64 configuration option or the JAX_ENABLE_X64 shell environment variable. See https://github.com/jax-ml/jax#current-gotchas for more.
  return _arange(start, stop=stop, step=step, dtype=dtype,
/Users/hanq/git/qihqi/torchax/torchax/ops/mappings.py:83: UserWarning: The given NumPy array is not writable, and PyTorch does not support non-writable tensors. This means writing to this tensor will result in undefined behavior. You may want to copy the array to protect its data or make it writable before converting it to a tensor. This type of warning will be suppressed for the rest of this program. (Triggered internally at /Users/runner/work/pytorch/pytorch/pytorch/torch/csrc/utils/tensor_numpy.cpp:209.)
  res = torch.from_numpy(numpy.asarray(x))
  batch 1000 loss: 1.0007757039815188
  batch 2000 loss: 0.6775612101629377
LOSS train 0.6775612101629377 valid Tensor(<class 'jaxlib._jax.ArrayImpl'> 0.6449958)
EPOCH 2:
  batch 1000 loss: 0.6143077593529597
  batch 2000 loss: 0.543319463224354
LOSS train 0.543319463224354 valid Tensor(<class 'jaxlib._jax.ArrayImpl'> 0.55330473)
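
The log above shows exactly two reports per epoch. That follows from the loop's two conditions: `i % 1000 == 999` triggers a report, and `i > 2000` stops the epoch early. The sketch below replays only that control flow in plain Python, with no model involved, to show how many batches are processed and which TensorBoard step values result. simulate_epoch is a made-up helper for this illustration.

```python
def simulate_epoch(num_batches, report_every=1000, stop_after=2000):
    """Replay the loop's reporting and early-stop logic without any training."""
    processed, reports = 0, []
    for i in range(num_batches):
        processed += 1
        if i % report_every == report_every - 1:
            reports.append(i + 1)  # batch number printed in the log
        if i > stop_after:
            break
    return processed, reports


num_batches = 60000 // 4  # 60000 training images, batch_size=4
processed, reports = simulate_epoch(num_batches)
print(processed, reports)  # 2002 [1000, 2000]

# TensorBoard x-values for the second epoch (epoch_index == 1).
epoch_index = 1
tb_steps = [epoch_index * num_batches + b for b in reports]
print(tb_steps)  # [16000, 17000]
```

Each epoch therefore trains on 2002 of the 15000 available batches, and the loss accumulated over the last two batches after the second report is never printed.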

Save the model checkpoint

Currently torch.save (which is based on pickle) is not able to save tensors on the 'jax' device, because JAX arrays cannot be pickled.

So we have two strategies for saving:

  1. Convert the tensors on the jax device to plain JAX arrays, then use flax.checkpoint to save the data. You will get a JAX-style checkpoint (a directory) if you do so.
  2. Convert the tensors on the jax device to CPU torch.Tensor values, then use torch.save. You will get a regular pickle-based checkpoint if you do so.

We recommend strategy 1, and we provide a wrapper, torchax.save_checkpoint, that does exactly this.

In [18]:
import os
import orbax.checkpoint as ocp
ckpt_dir = ocp.test_utils.erase_and_create_empty('/tmp/my-checkpoints/')
model_path = ckpt_dir / 'state'
torchax.save_checkpoint(weights, model_path, step=1)
WARNING:absl:[process=0][thread=MainThread][operation_id=1] _SignalingThread.join() waiting for signals ([]) blocking the main thread will slow down blocking save times. This is likely due to main thread calling result() on a CommitFuture.
In [19]:
!find /tmp/my-checkpoints/
/tmp/my-checkpoints/
/tmp/my-checkpoints/state
/tmp/my-checkpoints/state/checkpoint_1
/tmp/my-checkpoints/state/checkpoint_1/_sharding
/tmp/my-checkpoints/state/checkpoint_1/_METADATA
/tmp/my-checkpoints/state/checkpoint_1/_CHECKPOINT_METADATA
/tmp/my-checkpoints/state/checkpoint_1/array_metadatas
/tmp/my-checkpoints/state/checkpoint_1/array_metadatas/process_0
/tmp/my-checkpoints/state/checkpoint_1/manifest.ocdbt
/tmp/my-checkpoints/state/checkpoint_1/d
/tmp/my-checkpoints/state/checkpoint_1/d/a3ddd2e5f397d67e91789ba879d4dd7f
/tmp/my-checkpoints/state/checkpoint_1/ocdbt.process_0
/tmp/my-checkpoints/state/checkpoint_1/ocdbt.process_0/manifest.ocdbt
/tmp/my-checkpoints/state/checkpoint_1/ocdbt.process_0/d
/tmp/my-checkpoints/state/checkpoint_1/ocdbt.process_0/d/d93c12ce6141615627e44d574d6c7277
/tmp/my-checkpoints/state/checkpoint_1/ocdbt.process_0/d/7ea645b957f87cd9df883e2d0f5205bf
/tmp/my-checkpoints/state/checkpoint_1/ocdbt.process_0/d/da7cbbf95d8fa1f2b7ae34daa713bef3
/tmp/my-checkpoints/state/checkpoint_1/ocdbt.process_0/d/d1c362243563d8b2b92c1bcceee0ddaf
/tmp/my-checkpoints/state/checkpoint_1/ocdbt.process_0/d/bcb90d6f6ab75dc2b0cfbf48b9301ed1
/Users/hanq/homebrew/Caskroom/miniconda/base/envs/py13/lib/python3.13/pty.py:95: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.
  pid, fd = os.forkpty()

You can also produce a torch pickle-based checkpoint by moving the state_dict to CPU, as follows:

In [20]:
# Move every tensor off the 'jax' device to a plain CPU torch.Tensor
cpu_state_dict = jax.tree.map(lambda a: a.cpu(), weights)
In [21]:
torch.save(cpu_state_dict, ckpt_dir / 'torch_checkpoint.pkl')
In [1]:
!ls /tmp/my-checkpoints/
state                torch_checkpoint.pkl