Training a PyTorch model with JAX
Introduction
This tutorial notebook is adapted from https://docs.pytorch.org/tutorials/beginner/introyt/trainingyt.html
It keeps most of the PyTorch code unchanged (especially the model definition) and replaces the standard PyTorch train loop (the loss.backward() + optimizer.step() pattern) with a JAX train loop (jax.grad followed by optax.apply_updates).
The rest of the tutorial, such as data loading and loss reporting, is kept as close to the original as possible.
Dataset and DataLoader
The Dataset and DataLoader classes encapsulate the process of pulling your data from storage and exposing it to your training loop in batches.
The Dataset is responsible for accessing and processing single instances of data.
The DataLoader pulls instances of data from the Dataset (either automatically or with a sampler that you define), collects them in batches, and returns them for consumption by your training loop. The DataLoader works with all kinds of datasets, regardless of the type of data they contain.
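To make that division of labor concrete, here is a minimal pure-Python sketch: the dataset only knows how to return one instance, and the loader only knows how to order indices and group instances into batches. SquaresDataset and simple_loader are hypothetical names for illustration; the real DataLoader additionally collates samples into tensors and supports samplers, worker processes, and more.

```python
import random


class SquaresDataset:
    """Toy map-style dataset: instance i is the pair (i, i * i)."""

    def __init__(self, n):
        self.n = n

    def __len__(self):
        return self.n

    def __getitem__(self, idx):
        # Access and process exactly one instance.
        return idx, idx * idx


def simple_loader(dataset, batch_size, shuffle=False, seed=0):
    """Yield (inputs, labels) batches, mimicking the DataLoader's role."""
    indices = list(range(len(dataset)))
    if shuffle:
        random.Random(seed).shuffle(indices)
    for start in range(0, len(indices), batch_size):
        batch = [dataset[i] for i in indices[start:start + batch_size]]
        inputs = [x for x, _ in batch]
        labels = [y for _, y in batch]
        yield inputs, labels


batches = list(simple_loader(SquaresDataset(10), batch_size=4))
print(len(batches))  # 3 batches: sizes 4, 4 and 2
print(batches[0])    # ([0, 1, 2, 3], [0, 1, 4, 9])
```

As with the real DataLoader, the last batch is smaller when the dataset size is not a multiple of the batch size.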
For this tutorial, we'll be using the Fashion-MNIST dataset provided by TorchVision. We use torchvision.transforms.Normalize() to zero-center and normalize the distribution of the image tile content, and download both training and validation data splits.
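The two transforms amount to simple per-pixel arithmetic. A sketch, assuming 8-bit grayscale input (Fashion-MNIST pixels range from 0 to 255): ToTensor scales each pixel to [0.0, 1.0], and Normalize((0.5,), (0.5,)) then computes (x - 0.5) / 0.5, so values land in [-1.0, 1.0]. The img / 2 + 0.5 line in matplotlib_imshow below undoes the normalization for display. The helper names are illustrative only.

```python
MEAN, STD = 0.5, 0.5  # the values passed to transforms.Normalize((0.5,), (0.5,))


def to_tensor_scale(pixel):
    """ToTensor-style scaling of an 8-bit pixel value to [0.0, 1.0]."""
    return pixel / 255.0


def normalize(x):
    """Per-channel normalization: (x - mean) / std."""
    return (x - MEAN) / STD


def unnormalize(y):
    """The inverse used for display in matplotlib_imshow: img / 2 + 0.5."""
    return y / 2 + 0.5


for pixel in (0, 128, 255):
    x = to_tensor_scale(pixel)
    y = normalize(x)
    # 0 maps to -1.0 and 255 maps to 1.0; unnormalize recovers x
    print(pixel, round(y, 4), round(unnormalize(y), 4))
```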
# Optional: install dependencies
!pip install matplotlib torch torchvision torchax jax optax orbax-checkpoint tensorboard
# For tips on running notebooks in Google Colab, see
# https://docs.pytorch.org/tutorials/beginner/colab
%matplotlib inline
import torch
import torchvision
import torchvision.transforms as transforms
# PyTorch TensorBoard support
from torch.utils.tensorboard import SummaryWriter
from datetime import datetime
transform = transforms.Compose(
[transforms.ToTensor(),
transforms.Normalize((0.5,), (0.5,))])
# Create datasets for training & validation, download if necessary
training_set = torchvision.datasets.FashionMNIST('./data', train=True, transform=transform, download=True)
validation_set = torchvision.datasets.FashionMNIST('./data', train=False, transform=transform, download=True)
# Create data loaders for our datasets; shuffle for training, not for validation
training_loader = torch.utils.data.DataLoader(training_set, batch_size=4, shuffle=True)
validation_loader = torch.utils.data.DataLoader(validation_set, batch_size=4, shuffle=False)
# Class labels
classes = ('T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat',
'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle Boot')
# Report split sizes
print('Training set has {} instances'.format(len(training_set)))
print('Validation set has {} instances'.format(len(validation_set)))
Training set has 60000 instances
Validation set has 10000 instances
As always, let's visualize the data as a sanity check:
import matplotlib.pyplot as plt
import numpy as np
# Helper function for inline image display
def matplotlib_imshow(img, one_channel=False):
    if one_channel:
        img = img.mean(dim=0)
    img = img / 2 + 0.5     # unnormalize
    npimg = img.numpy()
    if one_channel:
        plt.imshow(npimg, cmap="Greys")
    else:
        plt.imshow(np.transpose(npimg, (1, 2, 0)))
dataiter = iter(training_loader)
images, labels = next(dataiter)
# Create a grid from the images and show them
img_grid = torchvision.utils.make_grid(images)
matplotlib_imshow(img_grid, one_channel=True)
print(' '.join(classes[labels[j]] for j in range(4)))
Coat Pullover Bag Bag
The Model
The model we'll use in this example is a variant of LeNet-5; it should be familiar if you've watched the previous videos in this series.
import torch.nn as nn
import torch.nn.functional as F
# PyTorch models inherit from torch.nn.Module
class GarmentClassifier(nn.Module):
    def __init__(self):
        super(GarmentClassifier, self).__init__()
        self.conv1 = nn.Conv2d(1, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 4 * 4, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 4 * 4)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x
model = GarmentClassifier()
model(images)
tensor([[ 0.0147, -0.0986, -0.0197, -0.0880, -0.0228, -0.0379,  0.0820, -0.0990, -0.0172, -0.0708],
        [ 0.0056, -0.0992, -0.0195, -0.0978, -0.0150, -0.0213,  0.0677, -0.0953, -0.0219, -0.0732],
        [ 0.0166, -0.1112, -0.0197, -0.1034, -0.0170, -0.0254,  0.0676, -0.1010, -0.0152, -0.0621],
        [ 0.0213, -0.1105, -0.0365, -0.1026, -0.0205, -0.0142,  0.0628, -0.1015, -0.0243, -0.0444]],
       grad_fn=<AddmmBackward0>)
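The 16 * 4 * 4 input size of fc1 follows from the layer arithmetic on a 28 x 28 Fashion-MNIST image. A small sketch using the standard output-size formula for Conv2d and MaxPool2d (stride 1 and no padding for the convolutions, as in the model above); conv_out and pool_out are illustrative helper names:

```python
def conv_out(size, kernel, stride=1, padding=0):
    """Spatial output size of nn.Conv2d with dilation 1."""
    return (size + 2 * padding - kernel) // stride + 1


def pool_out(size, kernel=2, stride=2):
    """Spatial output size of nn.MaxPool2d(kernel, stride) without padding."""
    return (size - kernel) // stride + 1


side = 28                            # Fashion-MNIST images are 1 x 28 x 28
side = pool_out(conv_out(side, 5))   # conv1 + pool: 28 -> 24 -> 12
side = pool_out(conv_out(side, 5))   # conv2 + pool: 12 -> 8 -> 4
flat_features = 16 * side * side     # conv2 produces 16 channels
print(side, flat_features)           # 4 256
```

This is also why forward reshapes with x.view(-1, 16 * 4 * 4) before the first linear layer.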
Loss Function
For this example, we'll be using a cross-entropy loss. For demonstration purposes, we'll create batches of dummy output and label values, run them through the loss function, and examine the result.
loss_fn = torch.nn.CrossEntropyLoss()
# NB: Loss functions expect data in batches, so we're creating batches of 4
# Represents the model's confidence in each of the 10 classes for a given input
dummy_outputs = torch.rand(4, 10)
# Represents the correct class among the 10 being tested
dummy_labels = torch.tensor([1, 5, 3, 7])
print(dummy_outputs)
print(dummy_labels)
loss = loss_fn(dummy_outputs, dummy_labels)
print('Total loss for this batch: {}'.format(loss.item()))
tensor([[0.0188, 0.1162, 0.7957, 0.0443, 0.0191, 0.1256, 0.8698, 0.1108, 0.4106, 0.8890],
        [0.6096, 0.3726, 0.2039, 0.0199, 0.5399, 0.1214, 0.0381, 0.5662, 0.7744, 0.9680],
        [0.2135, 0.7618, 0.2944, 0.8146, 0.8824, 0.8905, 0.8575, 0.8906, 0.0485, 0.0220],
        [0.4641, 0.1055, 0.9248, 0.7413, 0.3650, 0.5146, 0.3190, 0.7306, 0.5251, 0.5850]])
tensor([1, 5, 3, 7])
Total loss for this batch: 2.370086431503296
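To see what CrossEntropyLoss computes, here is a pure-Python sketch of its default behavior on raw scores: a log-softmax over the classes, the negative log-probability of the correct class, and an average over the batch (reduction='mean'). The cross_entropy helper is hypothetical, and options such as class weights, label smoothing, and ignore_index are left out.

```python
import math


def cross_entropy(logits_batch, labels):
    """Mean over the batch of -log(softmax(logits)[label])."""
    total = 0.0
    for logits, label in zip(logits_batch, labels):
        m = max(logits)  # subtract the max for numerical stability
        log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
        total += log_sum_exp - logits[label]
    return total / len(labels)


# With identical scores for all 10 classes, each class gets probability 1/10,
# so the loss is log(10), about 2.3026, whatever the labels are.
uniform = [[0.0] * 10 for _ in range(4)]
print(cross_entropy(uniform, [1, 5, 3, 7]))
```

Random scores in [0, 1) like dummy_outputs above are close to uniform after the softmax, which is why their loss lands near log(10).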
Move model to 'jax' device
import torchax
torchax.enable_globally()
model.to('jax')
images = images.to('jax')
dummy_labels = dummy_labels.to('jax')
WARNING:absl:Tensorflow library not found, tensorflow.io.gfile operations will use native shim calls. GCS paths (i.e. 'gs://...') cannot be accessed. WARNING:root:Duplicate op registration for aten.__and__
import optax
start_learning_rate = 1e-3
optimizer = optax.adam(start_learning_rate)
print(optimizer)
GradientTransformationExtraArgs(init=<function chain.<locals>.init_fn at 0x14234ca40>, update=<function chain.<locals>.update_fn at 0x14234cae0>)
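The printed object shows that an optax optimizer is essentially a pair of pure functions, init and update, that thread an explicit optimizer state. The following pure-Python sketch mimics that structure for Adam on a flat dict of floats. It is an illustration under simplifying assumptions (no nested pytrees, and update omits the optional params argument that optax's update accepts), not optax's implementation; adam and apply_updates here are local stand-ins.

```python
import math


def adam(learning_rate, b1=0.9, b2=0.999, eps=1e-8):
    """Return a toy (init, update) pair in the spirit of optax.adam."""

    def init(params):
        # Optimizer state: step count plus first and second moment estimates.
        return {"count": 0,
                "mu": {k: 0.0 for k in params},
                "nu": {k: 0.0 for k in params}}

    def update(grads, state):
        count = state["count"] + 1
        mu = {k: b1 * state["mu"][k] + (1 - b1) * g for k, g in grads.items()}
        nu = {k: b2 * state["nu"][k] + (1 - b2) * g * g for k, g in grads.items()}
        updates = {}
        for k in grads:
            mu_hat = mu[k] / (1 - b1 ** count)  # bias correction
            nu_hat = nu[k] / (1 - b2 ** count)
            updates[k] = -learning_rate * mu_hat / (math.sqrt(nu_hat) + eps)
        return updates, {"count": count, "mu": mu, "nu": nu}

    return init, update


def apply_updates(params, updates):
    """Same role as optax.apply_updates: add the updates leaf by leaf."""
    return {k: params[k] + updates[k] for k in params}


init, update = adam(1e-3)
params = {"w": 1.0, "b": -2.0}
state = init(params)
updates, state = update({"w": 0.5, "b": -4.0}, state)
params = apply_updates(params, updates)
print(params)  # each parameter moves by about 1e-3 against its gradient's sign
```

On the first step the bias-corrected moments equal g and g squared, so every parameter moves by roughly the learning rate, regardless of the gradient's magnitude.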
The Training Loop
Below, we have a function that performs one training epoch.
First, let's articulate what the training step does.
At each training step, we first evaluate the model. The model is a function that maps (weights, input data) to a prediction:
$$ \text{model}: (\text{weights}, \text{input}) \mapsto \text{pred} $$
In PyTorch, we can use torch.func.functional_call to call a model with the weights passed in as a parameter.
The loss is a function that maps the prediction and the label to a real number representing the loss:
$$ \text{loss}: (\text{pred}, \text{label}) \mapsto \text{loss} $$
To train the model, we use a glorified gradient descent (in this case Adam), so we also need a function that computes the gradient of the loss with respect to the weights:
$$ \frac{\partial\, \text{loss}}{\partial\, \text{weights}} $$
Finally, the train_step itself is a function that maps (weights, optimizer_state, input_data) to (updated weights, updated optimizer_state).
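The decomposition above can be sketched without any framework. This toy version assumes a 1-D linear model, a squared-error loss, and a hand-derived gradient, with plain gradient descent in place of Adam; in the notebook, torch.func.functional_call plays the role of model(weights, x) and JAX computes the gradient. All names here are illustrative.

```python
def model(weights, x):
    """model: (weights, input) -> pred; a toy 1-D linear model w * x + b."""
    return weights["w"] * x + weights["b"]


def loss(pred, label):
    """loss: (pred, label) -> real number; squared error here."""
    return (pred - label) ** 2


def grad_of_loss(weights, x, label):
    """d loss / d weights for this toy model, derived by hand."""
    residual = model(weights, x) - label
    return {"w": 2 * residual * x, "b": 2 * residual}


def train_step(weights, opt_state, x, label, lr=0.05):
    """(weights, optimizer state, data) -> (updated weights, updated state).

    Plain gradient descent; opt_state is just a step counter in this sketch.
    """
    grads = grad_of_loss(weights, x, label)
    new_weights = {k: weights[k] - lr * grads[k] for k in weights}
    return new_weights, opt_state + 1


weights, opt_state = {"w": 0.0, "b": 0.0}, 0
for _ in range(50):
    weights, opt_state = train_step(weights, opt_state, x=2.0, label=4.0)
print(model(weights, 2.0), opt_state)  # prediction very close to 4.0, 50 steps
```

Because every piece is a pure function of its inputs, the whole step can be handed to a transformation such as jax.jit, which is what the rest of this tutorial builds toward.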
We can spell out the individual components of a train loop, and use Python to assemble them together:
weights = model.state_dict()
def run_model_and_loss(weights, inputs, labels):
    # First call the model with passed in weights
    output = torch.func.functional_call(model, weights, args=(inputs, ))
    loss = loss_fn(output, labels)
    return loss
run_model_and_loss(model.state_dict(), images, dummy_labels)
Tensor(<class 'jaxlib._jax.ArrayImpl'> 2.3484843)
Now let's define its gradient function. In JAX, one would use jax.grad.
However, jax.grad needs to take a JAX function (a function that takes jax.Array values as inputs and returns them as outputs) as its argument, whereas run_model_and_loss takes and returns torch.Tensor values.
One way to solve this is to use jax_view from the torchax.interop module: jax_view converts a torch function into a JAX function.
torchax also wraps common JAX transformations in the torchax.interop module so that they work with torch functions as well; later in this tutorial, we will use jax_value_and_grad.
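Conceptually, a view adapter just converts a function's arguments on the way in and its results on the way out. The toy sketch below models "torch-side" values as Python lists and "JAX-side" values as tuples; it is an analogy for what jax_view does, not how torchax implements it (torchax converts between torch.Tensor and jax.Array, including inside containers such as the state_dict below). All names are hypothetical.

```python
import functools


# Toy stand-ins: "torch-side" values are lists, "JAX-side" values are tuples.
def to_jax_side(x):
    return tuple(x) if isinstance(x, list) else x


def to_torch_side(x):
    return list(x) if isinstance(x, tuple) else x


def view(fn, convert_in, convert_out):
    """Wrap fn so arguments pass through convert_in and the result through convert_out."""
    @functools.wraps(fn)
    def wrapped(*args):
        return convert_out(fn(*[convert_in(a) for a in args]))
    return wrapped


def torch_style_double(xs):
    # Only accepts torch-side values, like run_model_and_loss only takes tensors.
    assert isinstance(xs, list), "expects torch-side values"
    return [2 * v for v in xs]


# A "JAX-side" function built from the torch-side one: tuples in, tuples out.
jax_style_double = view(torch_style_double, to_torch_side, to_jax_side)
print(jax_style_double((1, 2, 3)))  # (2, 4, 6)
```

Swapping the two converters gives the opposite direction, which is the role torch_view plays.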
from torchax.interop import jax_view
import jax
grad_fn_jax = jax.grad( jax_view(run_model_and_loss))
grad_fn_jax(jax_view(weights), jax_view(images), jax_view(dummy_labels)).keys()
odict_keys(['conv1.weight', 'conv1.bias', 'conv2.weight', 'conv2.bias', 'fc1.weight', 'fc1.bias', 'fc2.weight', 'fc2.bias', 'fc3.weight', 'fc3.bias'])
Note that grad_fn_jax above is the gradient of jax_view(run_model_and_loss) and is a JAX function.
If we instead wish to make it into a torch function, we can apply torch_view to it; it then becomes a function that takes torch tensors and returns torch tensors.
In fact, the pattern of calling jax_view + jax.value_and_grad + torch_view is common enough that we provide this very wrapper as torchax.interop.jax_value_and_grad, used below:
grad_fn = torchax.interop.jax_value_and_grad(run_model_and_loss)
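jax.value_and_grad returns both the function's value and its gradient with respect to the first argument, and that gradient has the same structure as the argument, which is why the gradient above has the same keys as the state_dict. A framework-free sketch, using central finite differences as a stand-in for automatic differentiation; value_and_grad and toy_model_and_loss here are hypothetical helpers, not the JAX or torchax functions.

```python
def value_and_grad(f, eps=1e-6):
    """Return g(params, *args) -> (f(params, *args), {key: df/dparams[key]})."""
    def g(params, *args):
        value = f(params, *args)
        grads = {}
        for k in params:
            plus = {**params, k: params[k] + eps}
            minus = {**params, k: params[k] - eps}
            # Central difference approximates the partial derivative for key k.
            grads[k] = (f(plus, *args) - f(minus, *args)) / (2 * eps)
        return value, grads
    return g


def toy_model_and_loss(weights, x, label):
    pred = weights["w"] * x + weights["b"]
    return (pred - label) ** 2


loss, grads = value_and_grad(toy_model_and_loss)({"w": 1.0, "b": 0.0}, 2.0, 5.0)
print(loss, grads)  # 9.0, with gradients close to {'w': -12.0, 'b': -6.0}
```

Returning the loss together with the gradients avoids a second forward pass just to report the loss, which the training loop does on every batch.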
Now let's assemble the train loop:
from torchax.interop import call_jax
# Initialize the optimizer state. optimizer.init is a JAX function (it takes
# jax arrays as input), so we use call_jax to pass it torch values:
opt_state = call_jax(optimizer.init, weights)
def train_one_epoch(epoch_index, tb_writer):
    global weights
    global opt_state
    running_loss = 0.
    last_loss = 0.
    # Here, we use enumerate(training_loader) instead of
    # iter(training_loader) so that we can track the batch
    # index and do some intra-epoch reporting
    for i, data in enumerate(training_loader):
        # Every data instance is an input + label pair
        inputs, labels = data
        inputs = inputs.to('jax')
        labels = labels.to('jax')
        # Compute the loss and gradients
        loss, gradients = grad_fn(weights, inputs, labels)
        # Compute updates
        updates, opt_state = call_jax(optimizer.update, gradients, opt_state)
        # Apply updates
        weights = call_jax(optax.apply_updates, weights, updates)
        # Gather data and report
        running_loss += loss.cpu().item()
        if i % 1000 == 999:
            last_loss = running_loss / 1000  # loss per batch
            print(' batch {} loss: {}'.format(i + 1, last_loss))
            tb_x = epoch_index * len(training_loader) + i + 1
            tb_writer.add_scalar('Loss/train', last_loss, tb_x)
            running_loss = 0.
        # NOTE: stop early to make it run faster for CI
        if i > 2000:
            break
    return last_loss
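The reporting arithmetic in train_one_epoch can be checked by hand. With 60,000 training instances and batch_size=4, len(training_loader) is 15,000, so tb_x puts every report on one global step axis across epochs. Because of the early break after batch index 2000, losses are reported only at batches 1000 and 2000, and the returned last_loss is the batch-2000 average; that is why "LOSS train" equals the "batch 2000 loss" line in the training output below. A sketch with illustrative helper names:

```python
import math

num_train = 60000   # "Training set has 60000 instances"
batch_size = 4      # batch_size passed to the training DataLoader
batches_per_epoch = math.ceil(num_train / batch_size)  # len(training_loader)


def tb_x(epoch_index, i):
    """Global step used for TensorBoard, as computed in train_one_epoch."""
    return epoch_index * batches_per_epoch + i + 1


def reported_batches(stop_after=2000, report_every=1000):
    """Batch numbers at which a loss is printed, given the early break."""
    reported = []
    for i in range(batches_per_epoch):
        if i % report_every == report_every - 1:
            reported.append(i + 1)
        if i > stop_after:
            break
    return reported


print(batches_per_epoch, tb_x(1, 999), reported_batches())  # 15000 16000 [1000, 2000]
```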
The loop above works. However, the grad / optimizer update / apply update sequence is pretty standard, so torchax provides a helper that does exactly that: make_train_step. Let's use it instead.
Having a variable for the function of one training step also allows us to compile it with jax.jit. Here we use interop.jax_jit, which just wraps jax.jit with torch_view and passes kwargs_for_jax_jit verbatim to the underlying jax.jit, as below.
We can optionally donate the weights and the optimizer state, so that XLA can update those two in place.
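To make the division of work explicit, here is a hypothetical, framework-free sketch of what a make_train_step-style helper assembles: a loss-of-weights function, its value and gradient, an optimizer update, and the application of that update. It is not torchax's implementation; the finite-difference gradient and plain SGD are stand-ins for JAX autodiff and Adam. The returned step uses the same (weights, buffers, opt_state, inputs, labels) calling convention as the loop below, and only weights receives gradients. Donation has no plain-Python equivalent: donate_argnums=(0, 2) simply tells jax.jit that the buffers of positional arguments 0 (weights) and 2 (opt_state) may be reused for the outputs.

```python
def make_train_step_sketch(model_fn, loss_fn, value_and_grad, opt_update, apply_updates):
    """Hypothetical stand-in for a make_train_step-style helper."""
    def loss_of_weights(weights, buffers, inputs, labels):
        return loss_fn(model_fn(weights, buffers, inputs), labels)

    # Differentiates with respect to the first argument (weights) only.
    value_and_grads = value_and_grad(loss_of_weights)

    def step(weights, buffers, opt_state, inputs, labels):
        loss, grads = value_and_grads(weights, buffers, inputs, labels)
        updates, opt_state = opt_update(grads, opt_state)
        weights = apply_updates(weights, updates)
        return loss, weights, opt_state

    return step


def fd_value_and_grad(f, eps=1e-6):
    """Finite-difference stand-in for jax.value_and_grad (first argument only)."""
    def g(params, *args):
        grads = {}
        for k in params:
            hi = f({**params, k: params[k] + eps}, *args)
            lo = f({**params, k: params[k] - eps}, *args)
            grads[k] = (hi - lo) / (2 * eps)
        return f(params, *args), grads
    return g


def sgd_update(grads, state, lr=0.05):
    """Plain SGD; the state is just a step counter."""
    return {k: -lr * g for k, g in grads.items()}, state + 1


def apply_updates(weights, updates):
    return {k: weights[k] + updates[k] for k in weights}


def model_fn(weights, buffers, x):
    # `buffers` holds non-trainable values, so no gradient is taken for it.
    return weights["w"] * x + buffers["offset"]


def loss_fn(pred, label):
    return (pred - label) ** 2


one_step = make_train_step_sketch(model_fn, loss_fn, fd_value_and_grad,
                                  sgd_update, apply_updates)
weights, buffers, opt_state = {"w": 0.0}, {"offset": 1.0}, 0
for _ in range(100):
    loss, weights, opt_state = one_step(weights, buffers, opt_state, 2.0, 5.0)
print(round(weights["w"], 4), opt_state)  # w approaches 2.0, since 2.0 * 2 + 1 = 5
```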
import functools
from torchax.train import make_train_step
# The calling convention of make_train_step is that model_fn takes weights
# (trainable params) and buffers (non-trainable params) as separate arguments,
# because gradients are computed with respect to the first argument only.
def model_fn(weights, buffers, data):
    return torch.func.functional_call(model, (weights, buffers), data)
one_step = make_train_step(
    model_fn=model_fn,
    loss_fn=loss_fn,
    optax_optimizer=optimizer)
# For comparison, a hand-written step in the style of the previous section:
# def one_step(weights, opt_state, inputs, labels):
#     # compute gradients
#     loss, gradients = grad_fn(weights, inputs, labels)
#     # compute updates
#     updates, opt_state = call_jax(optimizer.update, gradients, opt_state)
#     # apply updates
#     weights = call_jax(optax.apply_updates, weights, updates)
#     return loss, weights, opt_state
one_step = torchax.interop.jax_jit(one_step, kwargs_for_jax_jit={'donate_argnums': (0, 2)})
def train_one_epoch(epoch_index, tb_writer):
    global weights
    global opt_state
    running_loss = 0.
    last_loss = 0.
    # Here, we use enumerate(training_loader) instead of
    # iter(training_loader) so that we can track the batch
    # index and do some intra-epoch reporting
    for i, data in enumerate(training_loader):
        # Every data instance is an input + label pair
        inputs, labels = data
        inputs = inputs.to('jax')
        labels = labels.to('jax')
        loss, weights, opt_state = one_step(weights, {}, opt_state, inputs, labels)
        # Gather data and report
        running_loss += loss.cpu().item()
        if i % 1000 == 999:
            last_loss = running_loss / 1000  # loss per batch
            print(' batch {} loss: {}'.format(i + 1, last_loss))
            tb_x = epoch_index * len(training_loader) + i + 1
            tb_writer.add_scalar('Loss/train', last_loss, tb_x)
            running_loss = 0.
        # NOTE: stop early to make it run faster for CI
        if i > 2000:
            break
    return last_loss
Per-Epoch Activity
There are a couple of things we'll want to do once per epoch:
- Perform validation by checking our relative loss on a set of data that was not used for training, and report this
- Save a copy of the model
Here, we'll do our reporting in TensorBoard. This will require going to the command line to start TensorBoard, and opening it in another browser tab.
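best_vloss is initialized in the loop below but never updated. A common way to "save a copy of the model" is to checkpoint only when the validation loss improves. A minimal sketch with made-up loss values and a hypothetical save_fn callback (in this notebook, that callback could call a checkpointing function such as torchax.save_checkpoint):

```python
def track_best(vlosses, save_fn):
    """Call save_fn(epoch, vloss) whenever the validation loss improves."""
    best_vloss = 1_000_000.0  # same sentinel as best_vloss in the loop below
    for epoch, vloss in enumerate(vlosses, start=1):
        if vloss < best_vloss:
            best_vloss = vloss
            save_fn(epoch, vloss)
    return best_vloss


saved = []
best = track_best([0.64, 0.55, 0.58], lambda epoch, v: saved.append(epoch))
print(best, saved)  # 0.55 [1, 2]
```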
# Initializing in a separate cell so we can easily add more epochs to the same run
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
writer = SummaryWriter('runs/fashion_trainer_{}'.format(timestamp))
epoch_number = 0
EPOCHS = 2
best_vloss = 1_000_000.
for epoch in range(EPOCHS):
    print('EPOCH {}:'.format(epoch_number + 1))
    # Put the model in training mode (relevant for layers such as dropout
    # and batch norm), and do a pass over the data
    model.train(True)
    avg_loss = train_one_epoch(epoch_number, writer)
    running_vloss = 0.0
    # Set the model to evaluation mode, disabling dropout and using population
    # statistics for batch normalization.
    model.eval()
    # Put the trained weights back into the model once, so we can evaluate it
    model.load_state_dict(weights)
    # Disable gradient computation and reduce memory consumption.
    with torch.no_grad():
        for i, vdata in enumerate(validation_loader):
            vinputs, vlabels = vdata
            vinputs = vinputs.to('jax')
            vlabels = vlabels.to('jax')
            voutputs = model(vinputs)
            vloss = loss_fn(voutputs, vlabels)
            running_vloss += vloss
            if i > 1000:
                break
    avg_vloss = running_vloss / (i + 1)
    print('LOSS train {} valid {}'.format(avg_loss, avg_vloss))
    # Log the running loss averaged per batch
    # for both training and validation
    writer.add_scalars('Training vs. Validation Loss',
                       { 'Training' : avg_loss, 'Validation' : avg_vloss },
                       epoch_number + 1)
    writer.flush()
    epoch_number += 1
EPOCH 1:
/Users/hanq/homebrew/Caskroom/miniconda/base/envs/py13/lib/python3.13/site-packages/jax/_src/numpy/lax_numpy.py:5943: UserWarning: Explicitly requested dtype int64 requested in arange is not available, and will be truncated to dtype int32. To enable more dtypes, set the jax_enable_x64 configuration option or the JAX_ENABLE_X64 shell environment variable. See https://github.com/jax-ml/jax#current-gotchas for more. return _arange(start, stop=stop, step=step, dtype=dtype, /Users/hanq/git/qihqi/torchax/torchax/ops/mappings.py:83: UserWarning: The given NumPy array is not writable, and PyTorch does not support non-writable tensors. This means writing to this tensor will result in undefined behavior. You may want to copy the array to protect its data or make it writable before converting it to a tensor. This type of warning will be suppressed for the rest of this program. (Triggered internally at /Users/runner/work/pytorch/pytorch/pytorch/torch/csrc/utils/tensor_numpy.cpp:209.) res = torch.from_numpy(numpy.asarray(x))
batch 1000 loss: 1.0007757039815188
batch 2000 loss: 0.6775612101629377
LOSS train 0.6775612101629377 valid Tensor(<class 'jaxlib._jax.ArrayImpl'> 0.6449958)
EPOCH 2:
batch 1000 loss: 0.6143077593529597
batch 2000 loss: 0.543319463224354
LOSS train 0.543319463224354 valid Tensor(<class 'jaxlib._jax.ArrayImpl'> 0.55330473)
Save the model checkpoint
Currently torch.save (which is based on pickle) is not able to save tensors on the 'jax' device, because JAX arrays cannot be pickled.
So we have two strategies for saving:
- Convert the tensors on the jax device to plain JAX arrays, then use flax.checkpoint to save the data. You will get a JAX-style checkpoint (a directory) if you do so.
- Convert the tensors on the jax device to CPU torch.Tensor, then use torch.save; you will get a regular pickle-based checkpoint if you do so.
We recommend the first option, and we provide a wrapper, torchax.save_checkpoint, that does exactly this.
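Both strategies boil down to "map every leaf to a saveable value, then write it out". The pure-Python sketch below models the second strategy in the situation the text describes: a hypothetical DeviceArray class that refuses to pickle, and a minimal tree_map over nested dicts standing in for jax.tree.map. It illustrates the pattern only; the names are not part of torchax or JAX.

```python
import pickle


class DeviceArray:
    """Hypothetical stand-in for a device-resident array that refuses to pickle."""

    def __init__(self, values):
        self.values = list(values)

    def __reduce__(self):
        raise TypeError("device arrays cannot be pickled; move them to host first")

    def to_host(self):
        return list(self.values)


def tree_map(fn, tree):
    """Apply fn to every leaf of a nested dict."""
    if isinstance(tree, dict):
        return {k: tree_map(fn, v) for k, v in tree.items()}
    return fn(tree)


state = {"fc": {"weight": DeviceArray([0.1, 0.2]), "bias": DeviceArray([0.0])}}
try:
    pickle.dumps(state)
    direct_pickle_ok = True
except TypeError as err:
    direct_pickle_ok = False
    print("direct pickle failed:", err)

# Convert every leaf to a host-side value first, then pickling succeeds.
host_state = tree_map(lambda a: a.to_host(), state)
restored = pickle.loads(pickle.dumps(host_state))
print(restored)  # {'fc': {'weight': [0.1, 0.2], 'bias': [0.0]}}
```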
import os
import orbax.checkpoint as ocp
ckpt_dir = ocp.test_utils.erase_and_create_empty('/tmp/my-checkpoints/')
model_path = ckpt_dir / 'state'
torchax.save_checkpoint(weights, model_path, step=1)
WARNING:absl:[process=0][thread=MainThread][operation_id=1] _SignalingThread.join() waiting for signals ([]) blocking the main thread will slow down blocking save times. This is likely due to main thread calling result() on a CommitFuture.
!find /tmp/my-checkpoints/
/tmp/my-checkpoints/
/tmp/my-checkpoints/state
/tmp/my-checkpoints/state/checkpoint_1
/tmp/my-checkpoints/state/checkpoint_1/_sharding
/tmp/my-checkpoints/state/checkpoint_1/_METADATA
/tmp/my-checkpoints/state/checkpoint_1/_CHECKPOINT_METADATA
/tmp/my-checkpoints/state/checkpoint_1/array_metadatas
/tmp/my-checkpoints/state/checkpoint_1/array_metadatas/process_0
/tmp/my-checkpoints/state/checkpoint_1/manifest.ocdbt
/tmp/my-checkpoints/state/checkpoint_1/d
/tmp/my-checkpoints/state/checkpoint_1/d/a3ddd2e5f397d67e91789ba879d4dd7f
/tmp/my-checkpoints/state/checkpoint_1/ocdbt.process_0
/tmp/my-checkpoints/state/checkpoint_1/ocdbt.process_0/manifest.ocdbt
/tmp/my-checkpoints/state/checkpoint_1/ocdbt.process_0/d
/tmp/my-checkpoints/state/checkpoint_1/ocdbt.process_0/d/d93c12ce6141615627e44d574d6c7277
/tmp/my-checkpoints/state/checkpoint_1/ocdbt.process_0/d/7ea645b957f87cd9df883e2d0f5205bf
/tmp/my-checkpoints/state/checkpoint_1/ocdbt.process_0/d/da7cbbf95d8fa1f2b7ae34daa713bef3
/tmp/my-checkpoints/state/checkpoint_1/ocdbt.process_0/d/d1c362243563d8b2b92c1bcceee0ddaf
/tmp/my-checkpoints/state/checkpoint_1/ocdbt.process_0/d/bcb90d6f6ab75dc2b0cfbf48b9301ed1
/Users/hanq/homebrew/Caskroom/miniconda/base/envs/py13/lib/python3.13/pty.py:95: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock. pid, fd = os.forkpty()
You can also produce a torch pickle-based checkpoint by moving the state_dict to CPU. You can do so with:
cpu_state_dict = jax.tree.map(lambda a: a.cpu(), weights)
torch.save(cpu_state_dict, ckpt_dir / 'torch_checkpoint.pkl')
!ls /tmp/my-checkpoints/
state torch_checkpoint.pkl