Distributed arrays and automatic parallelization

This tutorial is based on the JAX tutorial on distributed arrays and automatic parallelization: https://docs.jax.dev/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html

Most cells map one-to-one to cells in the JAX tutorial.

In [1]:
!pip install termcolor
Requirement already satisfied: termcolor in /Users/hanq/homebrew/Caskroom/miniconda/base/envs/py13/lib/python3.13/site-packages (3.1.0)
In [2]:
from typing import Optional

import numpy as np

import jax
import jax.numpy as jnp

jax.config.update('jax_num_cpu_devices', 8)
print(jax.devices())
[CpuDevice(id=0), CpuDevice(id=1), CpuDevice(id=2), CpuDevice(id=3), CpuDevice(id=4), CpuDevice(id=5), CpuDevice(id=6), CpuDevice(id=7)]
In [3]:
import torchax as tx
tx.enable_globally()
import torchax.interop
import torch
WARNING:absl:Tensorflow library not found, tensorflow.io.gfile operations will use native shim calls. GCS paths (i.e. 'gs://...') cannot be accessed.
WARNING:root:Duplicate op registration for aten.__and__

⚠️ WARNING: The notebook requires 8 devices to run.

In [4]:
if len(jax.local_devices()) < 8:
  raise Exception("Notebook requires 8 devices to run")

Intro and a quick example

By reading this tutorial notebook, you'll learn about jax.Array, a unified datatype for representing arrays, even with physical storage spanning multiple devices. You'll also learn about how using jax.Arrays together with jax.jit can provide automatic compiler-based parallelization.

Before we walk through the details step by step, here's a quick example. First, we'll create a jax.Array sharded across multiple devices:

In [5]:
from jax.sharding import PartitionSpec as P, NamedSharding
In [6]:
# Create a Sharding object to distribute a value across devices:
mesh = jax.make_mesh((4, 2), ('x', 'y'))
In [7]:
# Create an array of random values:
# x = jax.random.normal(jax.random.key(0), (8192, 8192))
x = torch.randn((8192, 8192), device='jax')


# and use jax.device_put to distribute it across devices:
# apply_jax applies a function to the inner jax Array and returns a new Tensor
y = x.apply_jax(jax.device_put, NamedSharding(mesh, P('x', 'y')))

# This line makes visualize_array_sharding into a torch function, so it can take torchax's Tensor
# instead of jax arrays
visualize_array_sharding = tx.interop.torch_view(jax.debug.visualize_array_sharding)

visualize_array_sharding(y)
                        
   CPU 0       CPU 1    
                        
                        
   CPU 2       CPU 3    
                        
                        
   CPU 4       CPU 5    
                        
                        
   CPU 6       CPU 7    
                        

Next, we'll apply a computation to it and visualize how the result values are stored across multiple devices too:

In [8]:
z = torch.sin(y)
visualize_array_sharding(z)
                        
   CPU 0       CPU 1    
                        
                        
   CPU 2       CPU 3    
                        
                        
   CPU 4       CPU 5    
                        
                        
   CPU 6       CPU 7    
                        

The evaluation of the torch.sin application was automatically parallelized across the devices on which the input values (and output values) are stored:

In [9]:
# `x` is present on a single device
# .jax() returns the inner jax array of the tensor
%timeit -n 5 -r 5 torch.sin(x).jax().block_until_ready()
144 ms ± 13.5 ms per loop (mean ± std. dev. of 5 runs, 5 loops each)
In [10]:
# `y` is sharded across 8 devices.
%timeit -n 5 -r 5 torch.sin(y).jax().block_until_ready()
68.6 ms ± 1.25 ms per loop (mean ± std. dev. of 5 runs, 5 loops each)
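
To check programmatically how many devices hold shards of each tensor, you can read the sharding off the inner jax array. A small sketch reusing the x and y from above (device_set is a standard attribute of JAX Sharding objects):

print(len(x.jax().sharding.device_set))  # expected: 1, x lives on a single device
print(len(y.jax().sharding.device_set))  # expected: 8, y is sharded across the mesh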

Now let's look at each of these pieces in more detail!

Sharding describes how array values are laid out in memory across devices

Sharding basics, and the NamedSharding subclass

To parallelize computation across multiple devices, we first must lay out input data across multiple devices.

In JAX, Sharding objects describe distributed memory layouts. They can be used with jax.device_put to produce a value with distributed layout.

For example, here's a value with a single-device Sharding:

In [11]:
import jax
x = jax.random.normal(jax.random.key(0), (8192, 8192))
x = tx.interop.torch_view(x) # this is another way to create a torchax Tensor
In [12]:
visualize_array_sharding(x)
                         
                         
                         
                         
                         
          CPU 0          
                         
                         
                         
                         
                         

Here, we're using the jax.debug.visualize_array_sharding function to show where the value x is stored in memory. All of x is stored on a single device, so the visualization is pretty boring!

But we can shard x across multiple devices by using jax.device_put and a Sharding object. First, we create a Mesh of devices using jax.make_mesh, which takes hardware topology into account for the Device order:

In [13]:
from jax.sharding import Mesh, PartitionSpec, NamedSharding

P = PartitionSpec

mesh = jax.make_mesh((4, 2), ('a', 'b'))
y = x.apply_jax(jax.device_put, NamedSharding(mesh, P('a', 'b')))
visualize_array_sharding(y)
                        
   CPU 0       CPU 1    
                        
                        
   CPU 2       CPU 3    
                        
                        
   CPU 4       CPU 5    
                        
                        
   CPU 6       CPU 7    
                        

We can define a helper function to make things simpler:

In [14]:
default_mesh = jax.make_mesh((4, 2), ('a', 'b'))

def mesh_sharding(
    pspec: PartitionSpec, mesh: Optional[Mesh] = None,
  ) -> NamedSharding:
  if mesh is None:
    mesh = default_mesh
  return NamedSharding(mesh, pspec)
In [15]:
y = x.apply_jax(jax.device_put, mesh_sharding(P('a', 'b')))
visualize_array_sharding(y)
                        
   CPU 0       CPU 1    
                        
                        
   CPU 2       CPU 3    
                        
                        
   CPU 4       CPU 5    
                        
                        
   CPU 6       CPU 7    
                        

Here, we use P('a', 'b') to express that the first and second axes of x should be sharded over the device mesh axes 'a' and 'b', respectively. We can easily switch to P('b', 'a') to shard the axes of x over different devices:

In [16]:
y = x.apply_jax(jax.device_put, mesh_sharding(P('b', 'a')))
visualize_array_sharding(y)
                                    
                                    
  CPU 0    CPU 2    CPU 4    CPU 6  
                                    
                                    
                                    
                                    
                                    
  CPU 1    CPU 3    CPU 5    CPU 7  
                                    
                                    
                                    
In [17]:
# This `None` means that `x` is not sharded on its second dimension,
# and since the Mesh axis name 'b' is not mentioned, shards are
# replicated across it.
y = x.apply_jax(jax.device_put, mesh_sharding(P('a', None)))
visualize_array_sharding(y)
                         
         CPU 0,1         
                         
                         
         CPU 2,3         
                         
                         
         CPU 4,5         
                         
                         
         CPU 6,7         
                         

Here, because P('a', None) doesn't mention the Mesh axis name 'b', we get replication over the axis 'b'. The None here is just acting as a placeholder to line up against the second axis of the value x, without expressing sharding over any mesh axis. (As a shorthand, trailing Nones can be omitted, so that P('a', None) means the same thing as P('a'). But it doesn't hurt to be explicit!)
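
As a quick check of that shorthand, here is a small sketch reusing the x, mesh_sharding, and visualize_array_sharding defined above; it should reproduce the same layout as the P('a', None) example:

y_short = x.apply_jax(jax.device_put, mesh_sharding(P('a')))  # trailing None omitted
visualize_array_sharding(y_short)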

To shard only over the second axis of x, we can use a None placeholder in the PartitionSpec:

In [18]:
y = x.apply_jax(jax.device_put, mesh_sharding(P(None, 'b')))
visualize_array_sharding(y)
                        
                        
                        
                        
                        
CPU 0,2,4,6 CPU 1,3,5,7 
                        
                        
                        
                        
                        
In [19]:
y = x.apply_jax(jax.device_put, mesh_sharding(P(None, 'a')))
visualize_array_sharding(y)
                                    
                                    
                                    
                                    
                                    
 CPU 0,1  CPU 2,3  CPU 4,5  CPU 6,7 
                                    
                                    
                                    
                                    
                                    

For a fixed mesh, we can even partition one logical axis of x over multiple device mesh axes:

In [20]:
y = x.apply_jax(jax.device_put, mesh_sharding(P(('a', 'b'), None)))
visualize_array_sharding(y)
          CPU 0          
                         
          CPU 1          
                         
          CPU 2          
                         
          CPU 3          
                         
          CPU 4          
                         
          CPU 5          
                         
          CPU 6          
                         
          CPU 7          
                         

Using NamedSharding makes it easy to define a device mesh once and give its axes names, then just refer to those names in PartitionSpecs for each device_put as needed.
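
For instance, here is a small sketch (assuming the default_mesh and mesh_sharding helper defined above): the same named mesh can drive different layouts for different tensors just by changing the PartitionSpec.

a = torch.randn((8192, 8192), device='jax')
b = torch.randn((8192, 8192), device='jax')
a = a.apply_jax(jax.device_put, mesh_sharding(P('a', None)))  # rows sharded over mesh axis 'a'
b = b.apply_jax(jax.device_put, mesh_sharding(P(None, 'b')))  # columns sharded over mesh axis 'b'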

Computation follows data sharding and is automatically parallelized

With sharded input data, the compiler can give us parallel computation. In particular, functions decorated with jax.jit can operate over sharded arrays without copying data onto a single device. Instead, computation follows sharding: based on the sharding of the input data, the compiler decides shardings for intermediates and output values, and parallelizes their evaluation, even inserting communication operations as necessary.

For example, the simplest computation is an elementwise one:

In [21]:
mesh = jax.make_mesh((4, 2), ('a', 'b'))
In [22]:
# apply_jax_ is like apply_jax but operates in place (it reuses the same tensor wrapper instead of creating a new one)
x.apply_jax_(jax.device_put, NamedSharding(mesh, P('a', 'b')))
print('input sharding:')
visualize_array_sharding(x)

y = torch.sin(x)
print('output sharding:')
visualize_array_sharding(y)
input sharding:
                        
   CPU 0       CPU 1    
                        
                        
   CPU 2       CPU 3    
                        
                        
   CPU 4       CPU 5    
                        
                        
   CPU 6       CPU 7    
                        
output sharding:
                        
   CPU 0       CPU 1    
                        
                        
   CPU 2       CPU 3    
                        
                        
   CPU 4       CPU 5    
                        
                        
   CPU 6       CPU 7    
                        

Here for the elementwise operation torch.sin the compiler chose the output sharding to be the same as the input. Moreover, the compiler automatically parallelized the computation, so that each device computed its output shard from its input shard in parallel.

In other words, even though we wrote the torch.sin computation as if a single machine were to execute it, the compiler splits up the computation for us and executes it on multiple devices.

We can do the same for more than just elementwise operations too. Consider a matrix multiplication with sharded inputs:

In [23]:
y = x.apply_jax(jax.device_put, NamedSharding(mesh, P('a', None)))
z = x.apply_jax(jax.device_put, NamedSharding(mesh, P(None, 'b')))
print('lhs sharding:')
visualize_array_sharding(y)
print('rhs sharding:')
visualize_array_sharding(z)

w = torch.matmul(y, z)
print('out sharding:')
visualize_array_sharding(w)
lhs sharding:
                         
         CPU 0,1         
                         
                         
         CPU 2,3         
                         
                         
         CPU 4,5         
                         
                         
         CPU 6,7         
                         
rhs sharding:
                        
                        
                        
                        
                        
CPU 0,2,4,6 CPU 1,3,5,7 
                        
                        
                        
                        
                        
out sharding:
                        
   CPU 0       CPU 1    
                        
                        
   CPU 2       CPU 3    
                        
                        
   CPU 4       CPU 5    
                        
                        
   CPU 6       CPU 7    
                        

Here the compiler chose the output sharding so that it could maximally parallelize the computation: without needing communication, each device already has the input shards it needs to compute its output shard.
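
You can also read the chosen sharding off the result directly instead of visualizing it; a quick sketch reusing the w computed above (.sharding is a standard attribute of jax arrays):

print(w.jax().sharding)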

How can we be sure it's actually running in parallel? We can do a simple timing experiment:

In [24]:
x_single = x.apply_jax(jax.device_put, jax.devices()[0])
visualize_array_sharding(x_single)
                         
                         
                         
                         
                         
          CPU 0          
                         
                         
                         
                         
                         
In [25]:
torch.allclose(torch.matmul(x_single, x_single).cpu(),
            torch.matmul(y, z).cpu())
/Users/hanq/git/qihqi/torchax/torchax/ops/mappings.py:83: UserWarning: The given NumPy array is not writable, and PyTorch does not support non-writable tensors. This means writing to this tensor will result in undefined behavior. You may want to copy the array to protect its data or make it writable before converting it to a tensor. This type of warning will be suppressed for the rest of this program. (Triggered internally at /Users/runner/work/pytorch/pytorch/pytorch/torch/csrc/utils/tensor_numpy.cpp:209.)
  res = torch.from_numpy(numpy.asarray(x))
Out[25]:
True
In [26]:
%timeit -n 5 -r 5 torch.matmul(x_single, x_single).jax().block_until_ready()
1.89 s ± 20.5 ms per loop (mean ± std. dev. of 5 runs, 5 loops each)
In [27]:
%timeit -n 5 -r 5 torch.matmul(y, z).apply_jax(jax.block_until_ready)
1.93 s ± 120 ms per loop (mean ± std. dev. of 5 runs, 5 loops each)

Even copying a sharded Array produces a result with the sharding of the input:

In [28]:
w_copy = w.apply_jax(jnp.copy)
visualize_array_sharding(w_copy)
                        
   CPU 0       CPU 1    
                        
                        
   CPU 2       CPU 3    
                        
                        
   CPU 4       CPU 5    
                        
                        
   CPU 6       CPU 7    
                        

So computation follows data placement: when we explicitly shard data with jax.device_put, and apply functions to that data, the compiler attempts to parallelize the computation and decide the output sharding. This policy for sharded data is a generalization of JAX's policy of following explicit device placement.

When explicit shardings disagree, JAX errors

But what if two arguments to a computation are explicitly placed on different sets of devices, or with incompatible device orders? In these ambiguous cases, an error is raised:

In [29]:
import textwrap
from termcolor import colored

def print_exception(e):
  name = colored(f'{type(e).__name__}', 'red', force_color=True)
  print(textwrap.fill(f'{name}: {str(e)}'))
In [30]:
sharding1 = NamedSharding(Mesh(jax.devices()[:4], 'x'), P('x'))
sharding2 = NamedSharding(Mesh(jax.devices()[4:], 'x'), P('x'))

y = x.apply_jax(jax.device_put, sharding1)
z = x.apply_jax(jax.device_put, sharding2)
try: y + z
except ValueError as e: print_exception(e)
ValueError: Received incompatible devices for jitted
computation. Got argument x of add with shape float32[8192,8192] and
device ids [0, 1, 2, 3] on platform CPU and argument y of add with
shape float32[8192,8192] and device ids [4, 5, 6, 7] on platform CPU
In [31]:
devices = jax.devices()
permuted_devices = [devices[i] for i in [0, 1, 2, 3, 6, 7, 4, 5]]

sharding1 = NamedSharding(Mesh(devices, 'x'), P('x'))
sharding2 = NamedSharding(Mesh(permuted_devices, 'x'), P('x'))

y = x.apply_jax(jax.device_put, sharding1)
z = x.apply_jax(jax.device_put, sharding2)
try: y + z
except ValueError as e: print_exception(e)
ValueError: Received incompatible devices for jitted
computation. Got argument x of add with shape float32[8192,8192] and
device ids [0, 1, 2, 3, 4, 5, 6, 7] on platform CPU and argument y of
add with shape float32[8192,8192] and device ids [0, 1, 2, 3, 6, 7, 4,
5] on platform CPU

We say arrays that have been explicitly placed or sharded with jax.device_put are committed to their device(s), and so won't be automatically moved. See the device placement FAQ for more information.

When arrays are not explicitly placed or sharded with jax.device_put, they are placed uncommitted on the default device. Unlike committed arrays, uncommitted arrays can be moved and resharded automatically: that is, uncommitted arrays can be arguments to a computation even if other arguments are explicitly placed on different devices.

For example, the outputs of factory functions such as torch.ones_like and torch.arange on the 'jax' device (the analogues of jnp.zeros, jnp.arange, and jnp.array) are uncommitted:

In [32]:
y = x.apply_jax(jax.device_put, sharding1)
y + torch.ones_like(y)
y + torch.arange(0, y.nelement(), device='jax').reshape(y.shape)
print('no error!')
no error!

Constraining shardings of intermediates in jitted code

While the compiler will attempt to decide how a function's intermediate values and outputs should be sharded, we can also give it hints using jax.lax.with_sharding_constraint. Using jax.lax.with_sharding_constraint is much like jax.device_put, except we use it inside staged-out (i.e. jit-decorated) functions:

In [33]:
mesh = jax.make_mesh((4, 2), ('x', 'y'))
In [34]:
x = torch.randn((8192, 8192), device='jax')
In [35]:
@tx.interop.jax_jit
def f(x):
  x = x + 1
  #y = jax.lax.with_sharding_constraint(x, NamedSharding(mesh, P('y', 'x')))
  # x.shard_ is shorthand for x.apply_jax_(jax.lax.with_sharding_constraint, ...)
  x.shard_(NamedSharding(mesh, P('y', 'x')))
  return x
In [36]:
visualize_array_sharding(x)
y = f(x)
visualize_array_sharding(y)
                         
                         
                         
                         
                         
          CPU 0          
                         
                         
                         
                         
                         
                                    
                                    
  CPU 0    CPU 2    CPU 4    CPU 6  
                                    
                                    
                                    
                                    
                                    
  CPU 1    CPU 3    CPU 5    CPU 7  
                                    
                                    
                                    
In [37]:
@tx.interop.jax_jit
def f(x):
  x = x + 1
  y = x.shard_(NamedSharding(mesh, P()))
  return y
In [38]:
visualize_array_sharding(x)
y = f(x)
visualize_array_sharding(y)
                         
                         
                         
                         
                         
          CPU 0          
                         
                         
                         
                         
                         

By adding with_sharding_constraint, we've constrained the sharding of the output. In addition to respecting the annotation on a particular intermediate, the compiler will use annotations to decide shardings for other values.

It's often a good practice to annotate the outputs of computations, for example based on how the values are ultimately consumed.
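
As a small sketch of that practice (assuming the mesh defined above and torch.relu as an arbitrary supported op), you could annotate the output with the sharding its eventual consumer expects:

@tx.interop.jax_jit
def g(x):
  out = torch.relu(x + 1)
  # hypothetical consumer wants the first dimension sharded over mesh axis 'x'
  out.shard_(NamedSharding(mesh, P('x', None)))
  return out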

Examples: neural networks

⚠️ WARNING: The following is meant to be a simple demonstration of automatic sharding propagation with jax.Array, but it may not reflect best practices for real examples. For instance, real examples may require more use of with_sharding_constraint.

We can use jax.device_put and jax.jit's computation-follows-sharding features to parallelize computation in neural networks. Here are some simple examples, based on this basic neural network:

In [44]:
import jax
import jax.numpy as jnp
In [45]:
# def predict(params, inputs):
#   for W, b in params:
#     outputs = jnp.dot(inputs, W) + b
#     inputs = jnp.maximum(outputs, 0)
#   return outputs

# def loss(params, batch):
#   inputs, targets = batch
#   predictions = predict(params, inputs)
#   return jnp.mean(jnp.sum((predictions - targets)**2, axis=-1))

class Model(torch.nn.Module):

    def __init__(self, layer_sizes):
        super().__init__()

        self.layers = torch.nn.ModuleList(
            [torch.nn.Linear(in_, out)
             for in_, out in zip(layer_sizes[:-1], layer_sizes[1:])])

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
            x = torch.maximum(x, torch.tensor(0, device=x.device))
        return x
In [46]:
# loss_jit = jax.jit(loss)
# gradfun = jax.jit(jax.grad(loss))
In [47]:
# def init_layer(key, n_in, n_out):
#   k1, k2 = jax.random.split(key)
#   W = jax.random.normal(k1, (n_in, n_out)) / jnp.sqrt(n_in)
#   b = jax.random.normal(k2, (n_out,))
#   return W, b

# def init_model(key, layer_sizes, batch_size):
#   key, *keys = jax.random.split(key, len(layer_sizes))
#   params = list(map(init_layer, keys, layer_sizes[:-1], layer_sizes[1:]))

#   key, *keys = jax.random.split(key, 3)
#   inputs = jax.random.normal(keys[0], (batch_size, layer_sizes[0]))
#   targets = jax.random.normal(keys[1], (batch_size, layer_sizes[-1]))

#   return params, (inputs, targets)

layer_sizes = [784, 1024, 1024, 1024, 10]
batch_size = 8192

# params, batch = init_model(jax.random.key(0), layer_sizes, batch_size)
In [48]:
model = Model(layer_sizes)
model.to('jax')
inputs = torch.randn((batch_size, layer_sizes[0]), device='jax')
In [49]:
model(inputs)
/Users/hanq/git/qihqi/torchax/torchax/ops/jtorch.py:71: UserWarning: Explicitly requested dtype <class 'jax.numpy.int64'> requested in array is not available, and will be truncated to dtype int32. To enable more dtypes, set the jax_enable_x64 configuration option or the JAX_ENABLE_X64 shell environment variable. See https://github.com/jax-ml/jax#current-gotchas for more.
  return jnp.array(
Out[49]:
Tensor(<class 'jaxlib._jax.ArrayImpl'> [[0.         0.01995892 0.04525738 ... 0.10564147 0.01692204 0.        ]
 [0.         0.         0.02429873 ... 0.02905991 0.         0.02695976]
 [0.         0.00898021 0.0093673  ... 0.05738227 0.02055857 0.02600054]
 ...
 [0.         0.03838623 0.0349848  ... 0.08870234 0.0068823  0.        ]
 [0.         0.         0.         ... 0.07888345 0.03313747 0.        ]
 [0.         0.         0.         ... 0.05596108 0.01192383 0.        ]])
In [50]:
params = model.state_dict() 

# make a function that takes weight as input
def pure_model_fun(weights, inputs):
    return torch.func.functional_call(model, weights, args=(inputs, ))

torch_loss = torch.nn.MSELoss()

def loss_fun(weight, batch):
    inputs, label = batch
    res = pure_model_fun(weight, inputs)
    return torch_loss(res, label)


grad_fn = tx.interop.jax_value_and_grad(loss_fun)
grad_fn_jit = tx.interop.jax_jit(grad_fn)
loss_jit = tx.interop.jax_jit(loss_fun)

8-way batch data parallelism

In [51]:
mesh = jax.make_mesh((8,), ('batch',))
In [52]:
sharding = NamedSharding(mesh, P('batch'))
replicated_sharding = NamedSharding(mesh, P())
In [53]:
batch = torch.randn((batch_size, layer_sizes[0]), device='jax'), torch.randn((batch_size, layer_sizes[-1]), device='jax')

# jax device put also works on pytrees
jax_device_put = tx.interop.torch_view(jax.device_put)

batch = jax_device_put(batch, sharding)
params = jax_device_put(params, replicated_sharding)
In [54]:
loss_jit(params, batch)
Out[54]:
Tensor(<class 'jaxlib._jax.ArrayImpl'> 1.0034581)
In [55]:
step_size = 1e-3

import optax

optimizer = optax.sgd(step_size)

opt_state = tx.interop.call_jax(optimizer.init, params)


for i in range(5):
  loss, grads = grad_fn_jit(params, batch)
  updates, opt_state = tx.interop.call_jax(optimizer.update, grads, opt_state)
  params = tx.interop.call_jax(optax.apply_updates, params, updates)
  print(i, 'loss is', loss)

print(loss_jit(params, batch))
0 loss is Tensor(<class 'jaxlib._jax.ArrayImpl'> 1.0034581)
1 loss is Tensor(<class 'jaxlib._jax.ArrayImpl'> 1.0034552)
2 loss is Tensor(<class 'jaxlib._jax.ArrayImpl'> 1.0034525)
3 loss is Tensor(<class 'jaxlib._jax.ArrayImpl'> 1.0034498)
4 loss is Tensor(<class 'jaxlib._jax.ArrayImpl'> 1.0034469)
Tensor(<class 'jaxlib._jax.ArrayImpl'> 1.0034441)
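
To confirm that the gradients follow the (replicated) parameter sharding, you could visualize one of them; a quick sketch assuming the grads left over from the loop above:

visualize_array_sharding(grads['layers.0.weight'])
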
In [56]:
%timeit -n 5 -r 5 grad_fn_jit(params, batch)[0].jax().block_until_ready()
263 ms ± 7.32 ms per loop (mean ± std. dev. of 5 runs, 5 loops each)
In [57]:
batch_single = jax_device_put(batch, jax.devices()[0])
params_single = jax_device_put(params, jax.devices()[0])
In [58]:
%timeit -n 5 -r 5 grad_fn_jit(params_single, batch_single)[0].jax().block_until_ready()
248 ms ± 6.91 ms per loop (mean ± std. dev. of 5 runs, 5 loops each)

4-way batch data parallelism and 2-way model tensor parallelism

In [59]:
mesh = jax.make_mesh((4, 2), ('batch', 'model'))
In [60]:
batch = jax_device_put(batch, NamedSharding(mesh, P('batch', None)))
visualize_array_sharding(batch[0])
visualize_array_sharding(batch[1])
         
 CPU 0,1 
         
         
 CPU 2,3 
         
         
 CPU 4,5 
         
         
 CPU 6,7 
         
         
 CPU 0,1 
         
         
 CPU 2,3 
         
         
 CPU 4,5 
         
         
 CPU 6,7 
         
In [61]:
replicated_sharding = NamedSharding(mesh, P())
In [62]:
params.keys()
Out[62]:
odict_keys(['layers.0.weight', 'layers.0.bias', 'layers.1.weight', 'layers.1.bias', 'layers.2.weight', 'layers.2.bias', 'layers.3.weight', 'layers.3.bias'])
In [63]:
# (W1, b1), (W2, b2), (W3, b3), (W4, b4) = params

# W1 = jax.device_put(W1, replicated_sharding)
# b1 = jax.device_put(b1, replicated_sharding)

# W2 = jax.device_put(W2, NamedSharding(mesh, P(None, 'model')))
# b2 = jax.device_put(b2, NamedSharding(mesh, P('model')))

# W3 = jax.device_put(W3, NamedSharding(mesh, P('model', None)))
# b3 = jax.device_put(b3, replicated_sharding)

# W4 = jax.device_put(W4, replicated_sharding)
# b4 = jax.device_put(b4, replicated_sharding)

# params = (W1, b1), (W2, b2), (W3, b3), (W4, b4)

name_to_sharding = {
    'layers.0.weight': replicated_sharding, 
    'layers.0.bias': replicated_sharding, 
    'layers.1.weight': NamedSharding(mesh, P('model')), # column parallel 
    'layers.1.bias': NamedSharding(mesh, P('model')), 
    'layers.2.weight': NamedSharding(mesh, P(None, 'model')),
    'layers.2.bias': replicated_sharding, 
    'layers.3.weight': replicated_sharding, 
    'layers.3.bias': replicated_sharding
}

for name, tensor in params.items():
    tensor.apply_jax_(jax.device_put, name_to_sharding[name])
    
In [64]:
visualize_array_sharding(params['layers.1.weight'])
                         
                         
       CPU 0,2,4,6       
                         
                         
                         
                         
                         
       CPU 1,3,5,7       
                         
                         
                         
In [65]:
visualize_array_sharding(params['layers.2.weight'])
                        
                        
                        
                        
                        
CPU 0,2,4,6 CPU 1,3,5,7 
                        
                        
                        
                        
                        
In [66]:
print(loss_jit(params, batch))
Tensor(<class 'jaxlib._jax.ArrayImpl'> 1.0034441)
In [67]:
step_size = 1e-3

import optax

optimizer = optax.sgd(step_size)

opt_state = tx.interop.call_jax(optimizer.init, params)


for i in range(5):
  loss, grads = grad_fn_jit(params, batch)
  updates, opt_state = tx.interop.call_jax(optimizer.update, grads, opt_state)
  params = tx.interop.call_jax(optax.apply_updates, params, updates)
  print(i, 'loss is', loss)

print(loss_jit(params, batch))
0 loss is Tensor(<class 'jaxlib._jax.ArrayImpl'> 1.0034441)
1 loss is Tensor(<class 'jaxlib._jax.ArrayImpl'> 1.0034413)
2 loss is Tensor(<class 'jaxlib._jax.ArrayImpl'> 1.0034385)
3 loss is Tensor(<class 'jaxlib._jax.ArrayImpl'> 1.0034357)
4 loss is Tensor(<class 'jaxlib._jax.ArrayImpl'> 1.0034331)
Tensor(<class 'jaxlib._jax.ArrayImpl'> 1.0034302)
In [68]:
print(loss_jit(params, batch))
Tensor(<class 'jaxlib._jax.ArrayImpl'> 1.0034302)
In [69]:
visualize_array_sharding(params['layers.1.weight'])
visualize_array_sharding(params['layers.2.weight'])
                         
                         
       CPU 0,2,4,6       
                         
                         
                         
                         
                         
       CPU 1,3,5,7       
                         
                         
                         
                        
                        
                        
                        
                        
CPU 0,2,4,6 CPU 1,3,5,7 
                        
                        
                        
                        
                        
In [70]:
%timeit -n 5 -r 5 grad_fn_jit(params, batch)[0].jax().block_until_ready()
320 ms ± 8.59 ms per loop (mean ± std. dev. of 5 runs, 5 loops each)