
Optimizing SNR, CPA, and template attacks performance using JAX and sedpack

In the previous SNR tutorial we saw that JAX can improve execution speed. The question of this tutorial is just how far these improvements can go. We start by revisiting SNR. Then we focus on CPA and template attacks and explain why they leave even more room for improvement. We give the reader the background knowledge about GPUs needed to argue when these accelerators are likely to provide a speed-up. After finishing this tutorial you should be able to implement and optimize additional methods based on your needs (e.g., higher-order attacks). For side-channel attack background we link the relevant tutorials and assume some proficiency there.

We note that the idea of using GPU acceleration to speed up CPA or template attacks is far from new. See, for instance, GPU Assisted Side-Channel-Evaluation and First-Order and Higher-Order Power Analysis: Computational Approaches and Aspects.

You can find the resulting Python scripts in the SCA tutorials directory on GitHub.

For our experiments we will use JAX (see Quickstart: How to think in JAX), a NumPy-inspired library for array manipulation. Two of its advantages are high-performance array operations and automatic differentiation. In this tutorial we will rely heavily on the former. Please check the installation instructions and the list of supported platforms (JAX — supported platforms), which includes NVIDIA CUDA, AMD GPU (ROCm), TPU, Intel GPU, and CPU.

We will use the sedpack dataset format (see Sedpack - Scalable and efficient data packing).

In the previous SNR tutorial, the winning code ran in 15 seconds on our tinyAES dataset and was as follows:

```python
import jax
import jax.numpy as jnp

# https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance
# Welford's algorithm
@jax.jit
def jax_update(existing_aggregate, new_trace):
    """For a given value `new_trace`, compute the new `count`, new `mean`, and
    new `squared_deltas`. The variables have the following meaning:
    - `mean` accumulates the mean of the entire dataset,
    - `squared_deltas` aggregates the squared distance from the mean,
    - `count` aggregates the number of samples seen so far.
    """
    (count, mean, squared_deltas) = existing_aggregate
    count += 1
    delta = new_trace - mean
    mean += delta / count
    updated_delta = new_trace - mean
    squared_deltas += delta * updated_delta
    return (count, mean, squared_deltas)


def jax_get_initial_aggregate(trace_len: int):
    dtype = jnp.float32
    count = jnp.array(0, dtype=dtype)
    mean = jnp.zeros(trace_len, dtype=dtype)
    squared_deltas = jnp.zeros(trace_len, dtype=dtype)
    return (count, mean, squared_deltas)


def jax_finalize(existing_aggregate):
    """Retrieve the mean and variance from an aggregate.
    """
    (count, mean, squared_deltas) = existing_aggregate
    assert count >= 2
    (mean, variance) = (mean, squared_deltas / count)
    return (mean, variance)
```

The SNR loop was as follows:

```python
from pathlib import Path

import numpy as np
from sedpack.io import Dataset
from tqdm import tqdm


def snr_jax(dataset_path: Path, ap_name: str) -> np.ndarray:
    """Compute SNR using JAX.
    """
    # Load the dataset
    dataset = Dataset(dataset_path)
    # We know that trace1 is the first.
    trace_len: int = dataset.dataset_structure.saved_data_description[0].shape[0]
    leakage_to_aggregate = {
        i: jax_get_initial_aggregate(trace_len=trace_len) for i in range(9)
    }
    split = "test"
    for example in tqdm(
            dataset.as_numpy_iterator(
                split=split,
                repeat=False,
                shuffle=0,
            ),
            desc=f"[JAX] Computing SNR over {split}",
            total=dataset._dataset_info.splits[split].number_of_examples,
    ):
        # Hamming weight of the attack point byte (0 to 8).
        current_leakage = int(example[ap_name][0]).bit_count()
        leakage_to_aggregate[current_leakage] = jax_update(
            leakage_to_aggregate[current_leakage],
            example["trace1"],
        )
    results = {
        leakage: jax_finalize(aggregate)
        for leakage, aggregate in leakage_to_aggregate.items()
    }
    # Find out which class is the most common.
    most_common_leakage = 0
    most_common_count = 0
    for leakage, (count, _mean, _squared_deltas) in leakage_to_aggregate.items():
        if count >= most_common_count:
            most_common_leakage = leakage
            most_common_count = count
    signals = np.array([mean for mean, _variance in results.values()])
    return 20 * np.log(np.var(signals, axis=0) / results[most_common_leakage][1])
```

Recently we implemented Rust bindings for the FlatBuffers shard file type. These bindings bring speed improvements, and we plan to continue in this direction.

Just by changing the iteration from as_numpy_iterator to as_numpy_iterator_rust, we go from 15s to 6s. This is thanks to a recent improvement in the sedpack library. Writing modular software and guiding the improvements by profiling allows us to benefit from library improvements (here in sedpack, later in JAX). We could have chosen to write everything in C++ (or any other low-level language) and optimize each and every piece. In theory that would give us the best performance. However, moving our computation onto a GPU would then be much harder. Here we tell the story of our modular optimization.

A side note on API naming: we plan to simplify the user-facing API by letting as_numpy_iterator choose the best implementation automatically. The lower-level functions will then either remain available or be gradually deprecated before version 1.0.

A CPU is usually faster for sequential, single-threaded processing. GPUs, on the other hand, excel when the same operation must be applied to a large amount of data, e.g., when computing with matrices. An often overlooked aspect of using GPUs is the cost of memory transfers from the CPU (RAM) to the dedicated GPU memory. The situation is a little different with integrated GPUs, which share RAM with the CPU. Some of the optimizations mentioned in the following text will definitely benefit integrated GPUs as well, but we do not have an integrated GPU to run our tests on. If you run such experiments, feel free to let us know your results.

How do we tackle the communication overhead? One way is operation fusion combined with smart placement of data. Imagine we want to compute the Euclidean norm of a vector. A very naive way to do this is:

```python
## Very naive Euclidean norm computation:
# load our array to GPU memory
# square each element
# retrieve squared results to RAM
# load it back to GPU memory
# compute sum
# retrieve back single number from GPU memory
# compute square root (on CPU)
```

Ideally we would not move the data back and forth, and better still, we would compute the sum right after squaring. The creators of JAX could have added a dedicated operation for the Euclidean norm, then a couple more for activation functions (see for instance Activation functions), and so on. The more scalable approach is to compile the Python code with XLA (see the OpenXLA Project). JAX does this and much more for us automatically, which lets us focus on writing Python code.
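To see this in action, you can jit the whole norm computation and ask JAX for the optimized program XLA produced. The sketch below assumes a recent JAX version in which `jax.jit(...).lower(...).compile().as_text()` is available. The HLO text differs between backends and versions, so we only print it.

```python
import jax
import jax.numpy as jnp


@jax.jit
def euclidean_norm_fused(vec):
    # Squaring, summation, and square root are traced into one XLA program,
    # so no intermediate result has to travel back to the host.
    return jnp.sqrt(jnp.sum(vec ** 2))


vec = jnp.arange(100, dtype=jnp.float32)
result = euclidean_norm_fused(vec)
print(result)

# Inspect the optimized program XLA produced for this input shape. On GPU the
# element-wise square and the reduction are typically fused into one kernel.
hlo_text = euclidean_norm_fused.lower(vec).compile().as_text()
print(hlo_text[:500])  # only the beginning, the text can be long
```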

With this knowledge, we should not expect a large speedup for SNR, since the amount of computation is asymptotically proportional to the amount of data transferred to the GPU. The situation is different for CPA and template attacks, where the amount of computation exceeds the amount of data. Still, we try our best with SNR before moving on.

Batching examples is another way to reduce the cost of memory transfers. The idea is to transfer larger chunks of data at once, so that the fixed overhead of each transfer is amortized. We explore this below.
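One rough way to observe this amortization is to time many small host-to-device transfers against a single transfer of the same data stacked into one batch. This is a minimal sketch, and the sizes are arbitrary assumptions. The timings depend heavily on the hardware and may differ little on a CPU-only machine. We call `jax.block_until_ready` because JAX dispatches work asynchronously, and without it we would time only the dispatch.

```python
import time

import jax
import numpy as np

trace_len = 10_000  # arbitrary, for illustration only
batch_size = 64
traces = [np.random.rand(trace_len).astype(np.float32)
          for _ in range(batch_size)]


def per_example_transfer(traces):
    # One host-to-device transfer per trace.
    return [jax.device_put(t) for t in traces]


def batched_transfer(traces):
    # Stack on the host, then do a single host-to-device transfer.
    return jax.device_put(np.stack(traces))


def timed(fn, *args):
    start = time.perf_counter()
    out = jax.block_until_ready(fn(*args))
    return out, time.perf_counter() - start


single, t_single = timed(per_example_transfer, traces)
batched, t_batched = timed(batched_transfer, traces)
print(f"per example: {t_single:.4f}s, batched: {t_batched:.4f}s")
```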

If you are curious, see the documentation for your accelerator of choice. To name just a couple of resources: from JAX, How to Think About TPUs | How To Scale Your Model and How to Think About GPUs | How To Scale Your Model; from NVIDIA, An Even Easier Introduction to CUDA (Updated) | NVIDIA Technical Blog and CUDA C++ Programming Guide.

The sedpack library provides as_numpy_iterator_rust_batched (again, the API is subject to renaming), which is used as follows:

```python
for batch in dataset.as_numpy_iterator_rust_batched(
        split=split,
        repeat=False,
        shuffle=0,
        batch_size=batch_size,
):
    print(batch)
```

Each batch yielded by the previous code is a dictionary of the form {attribute_name: batched_values}, where the first dimension is the batch dimension. Thus, if our trace length is 80,000 and our batch size is 100, then batch["trace1"] is a NumPy array of shape (100, 80_000). The same holds for the other attributes. The only exception is the last batch, which is smaller when the batch size does not divide the total number of examples.
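To make this layout concrete without a sedpack dataset at hand, here is a pure-NumPy stand-in. `iterate_in_batches` is a hypothetical helper written for illustration, not a sedpack function. It yields dictionaries of the same form, including the shorter last batch.

```python
from collections.abc import Iterator

import numpy as np


def iterate_in_batches(
        data: dict[str, np.ndarray],
        batch_size: int,
) -> Iterator[dict[str, np.ndarray]]:
    """Hypothetical helper mimicking the batched iteration described above.

    All arrays in `data` share the same first (example) dimension.
    """
    num_examples = len(next(iter(data.values())))
    for start in range(0, num_examples, batch_size):
        # The last batch is shorter when batch_size does not divide
        # num_examples.
        yield {
            name: values[start:start + batch_size]
            for name, values in data.items()
        }


# 250 toy examples with trace length 1_000 and a 16 byte plaintext.
data = {
    "trace1": np.zeros((250, 1_000), dtype=np.float32),
    "plaintext": np.zeros((250, 16), dtype=np.uint8),
}
shapes = [batch["trace1"].shape for batch in iterate_in_batches(data, 100)]
print(shapes)  # [(100, 1000), (100, 1000), (50, 1000)]
```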

Our choice of representation (one aggregate per leakage value) is not well suited for batch processing on a GPU. In NumPy we would just add a dimension corresponding to the possible leakage values, index into it, and update. For JAX to offer the capabilities it does, its designers had to make certain decisions. The one we run into now concerns in-place updates: JAX arrays are immutable. Luckily, there are well-documented ways to work around this limitation (see The Sharp Bits 🔪 — JAX documentation).

```python
import numpy as np
import jax.numpy as jnp

np_array = np.arange(10)
np_array[3] += 2
print(np_array)  # [0 1 2 5 4 5 6 7 8 9]

jax_array = jnp.arange(10)
# jax_array[3] += 2  # Illegal expression -> TypeError
jax_array = jax_array.at[3].add(2)
print(jax_array)  # [0 1 2 5 4 5 6 7 8 9]
```
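One property of `.at[...].add(...)` is handy for per-class statistics. Repeated indices accumulate (similar to `np.add.at`), whereas NumPy's fancy-indexed `+=` applies only one of the duplicate updates. The toy leakage values below are made up for illustration. This is enough for per-class counts and sums. The Welford mean update, however, depends on the previous mean, so it cannot be expressed as a single scatter.

```python
import jax.numpy as jnp
import numpy as np

leakage_values = np.array([3, 1, 3, 3, 0, 8])  # toy Hamming weights

# NumPy fancy indexing: duplicates are written once, so class 3 counts 1.
np_counts = np.zeros(9, dtype=np.int32)
np_counts[leakage_values] += 1

# JAX scatter-add: duplicates accumulate, so class 3 counts 3.
jax_counts = jnp.zeros(9, dtype=jnp.int32).at[leakage_values].add(1)

print(np_counts)   # [1 1 0 1 0 0 0 0 1]
print(jax_counts)  # [1 1 0 3 0 0 0 0 1]
```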

We have already seen just-in-time compilation (Just-in-time compilation — JAX documentation) in action, where it made our code run faster. From implementing neural networks you might remember automatic differentiation (Automatic differentiation — JAX documentation). The last one is automatic vectorization (Automatic vectorization — JAX documentation), which lets a function written for a single example run on arguments with additional (batch) dimensions. The beauty of these and other JAX transformations is that they compose: one can, for example, jit a vmap-ed function and then take its gradient.
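Here is a small illustration of composition. The sketch takes the derivative of a scalar function, vectorizes it over a batch with vmap, and jit-compiles the result. The function `f` is an arbitrary toy choice.

```python
import jax
import jax.numpy as jnp


def f(x):
    # Scalar toy function f(x) = x^3 + 2x with derivative 3x^2 + 2.
    return x ** 3 + 2.0 * x


# grad works on scalar inputs, vmap lifts it to a batch, jit compiles it.
batched_derivative = jax.jit(jax.vmap(jax.grad(f)))

xs = jnp.array([0.0, 1.0, 2.0])
print(batched_derivative(xs))  # [ 2.  5. 14.]
```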

For our batching example we will use the jax.lax.scan function, because we need to update our intermediate results one example at a time. scan expects a function of two parameters: the first is the carry, and the second is the current element along the leading axis of the scanned input. This function returns a pair: the next carry value and a per-step output. scan itself returns the final carry together with all per-step outputs stacked along the leading axis.

```python
# A toy example of using jax.lax.scan to compute Euclidean norm.
import jax
import jax.numpy as jnp
import numpy as np


def f(carry: jnp.float32, x: jnp.float32) -> tuple[jnp.float32, jnp.float32]:
    x_squared = x ** 2
    carry += x_squared
    return (carry, x_squared)


@jax.jit
def euclidean_norm(vec) -> jnp.float32:
    norm_squared, _squared_coordinates = jax.lax.scan(
        f,  # the function
        0.0,  # the initial carry
        vec,  # scan over the first dimension of this vector
    )
    # We drop _squared_coordinates
    return jnp.sqrt(norm_squared)


vec = jnp.arange(10)
print(euclidean_norm(vec))  # 16.881943
# Make sure the computation is correct (up to floating point precision).
v = np.array(vec)
assert np.isclose(euclidean_norm(vec), np.sqrt((v ** 2).sum()))
```
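The second return value of jax.lax.scan stacks the per-step outputs along a new leading axis. If the step function returns its running carry as the per-step output, the same pattern becomes a prefix sum, which we can compare with `jnp.cumsum`. A minimal sketch:

```python
import jax
import jax.numpy as jnp


def running_sum_step(carry, x):
    carry = carry + x
    # Return the new carry twice: as the carry and as the per-step output.
    return carry, carry


vec = jnp.arange(5, dtype=jnp.float32)
final, prefix_sums = jax.lax.scan(running_sum_step, jnp.float32(0.0), vec)
print(final)        # 10.0
print(prefix_sums)  # [ 0.  1.  3.  6. 10.]
```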

Can you see the similarity with an RNN layer: its recurrent cell, its last state, and the returned states of the Base RNN layer?

Pytree intermezzo to keep the code more readable


We have seen that jax.lax.scan takes a function of two parameters, whereas we want to pass three (the aggregate, the leakage value, and the trace). We could group two of them into a tuple passed as a single argument. Instead we are going to use NamedTuple instances (see Named Tuples — typing documentation). We could also pass a dictionary or a completely custom class, provided we satisfy JAX's very mild expectations of their API. For more information see Pytrees — JAX documentation.

Our data structures look as follows:

```python
from typing import NamedTuple

import jax.numpy as jnp
from jax.typing import ArrayLike


class SnrAggregate(NamedTuple):
    """A pytree representing state of online SNR computation for all possible
    leakage values together.

    Attributes:
      count (ArrayLike): Number of seen traces, shape
        (different_leakage_values,).
      mean (ArrayLike): Running mean of the trace, shape
        (different_leakage_values, trace_len).
      squared_deltas (ArrayLike): Running sum of squared distances from the
        mean, shape (different_leakage_values, trace_len).
    """
    count: ArrayLike
    mean: ArrayLike
    squared_deltas: ArrayLike


class UpdateData(NamedTuple):
    """A pytree representing the current update.

    Attributes:
      leakage_value (jnp.int32): The leakage value. Assumed to be in
        range(different_leakage_values), see SnrAggregate.
      trace (ArrayLike): The trace for this example, shape (trace_len,).
    """
    leakage_value: jnp.int32
    trace: ArrayLike
```
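Because NamedTuple instances are pytrees, JAX utilities such as `jax.tree_util.tree_map` see through them to the arrays inside and rebuild the same NamedTuple type. The small sketch below uses a toy SnrAggregate-like tuple (`ToyAggregate`, our illustration) with made-up shapes: 9 leakage values and trace length 4.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp
from jax.typing import ArrayLike


class ToyAggregate(NamedTuple):
    count: ArrayLike
    mean: ArrayLike
    squared_deltas: ArrayLike


agg = ToyAggregate(
    count=jnp.zeros(9),
    mean=jnp.zeros((9, 4)),
    squared_deltas=jnp.zeros((9, 4)),
)

# tree_map applies the function to every leaf and rebuilds a ToyAggregate.
shapes = jax.tree_util.tree_map(lambda leaf: leaf.shape, agg)
print(shapes)
# ToyAggregate(count=(9,), mean=(9, 4), squared_deltas=(9, 4))
```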

One of the advantages of using jax.lax.scan is that we can write the update function almost without thinking about batching. We only need to return an additional per-step output, which can be anything, for instance the value zero. Using scan then gives us the batched version for free. We can see this in jax_update, which we write for a single update (one trace and its leakage value) and then scan over a batch of examples:

```python
def jax_update(
    aggregate: SnrAggregate,
    data: UpdateData,
) -> tuple[SnrAggregate, ArrayLike]:
    """For a given update of trace and leakage_value update the aggregate
    (single example, not batched). Returns the updated aggregate and the
    updated counts, the latter only to be directly usable with jax.lax.scan.
    """
    count = aggregate.count.at[data.leakage_value].add(1)
    delta = data.trace - aggregate.mean[data.leakage_value]
    mean = aggregate.mean.at[data.leakage_value].add(delta /
                                                     count[data.leakage_value])
    updated_delta = data.trace - mean[data.leakage_value]
    squared_deltas = aggregate.squared_deltas.at[data.leakage_value].add(
        delta * updated_delta)
    return (
        SnrAggregate(
            count=count,
            mean=mean,
            squared_deltas=squared_deltas,
        ),
        count,  # To be usable with jax.lax.scan.
    )


@jax.jit
def jax_update_b(
    aggregate: SnrAggregate,
    leakage_values: ArrayLike,
    new_traces: ArrayLike,
) -> SnrAggregate:
    """Batched version of jax_update without returning the count.

    Args:
      aggregate (SnrAggregate): The current state.
      leakage_values (ArrayLike): The leakage values of shape (batch_size,).
        Each of those is in range(different_leakage_values).
      new_traces (ArrayLike): The batch of traces of shape (batch_size,
        trace_len).

    Returns: the final SnrAggregate as if updating batch_size times using
    jax_update (and forgetting count).
    """
    new_aggregate, _ = jax.lax.scan(
        jax_update,
        aggregate,
        UpdateData(
            leakage_value=leakage_values,
            trace=new_traces,
        ),
    )
    return new_aggregate
```
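The batched SNR loop calls `jax_get_initial_aggregate(trace_len=..., different_leakage_values=9)` and `jax_finalize` on an SnrAggregate, but these batched helpers are not listed in this section. Below is a minimal sketch of what they could look like. It is an assumption, not the script's code. The definitions above are repeated in compact form so that the block runs on its own. Note the `count[:, None]`: the count has shape (different_leakage_values,) and must broadcast against the (different_leakage_values, trace_len) arrays. We then check the result against NumPy on random toy data split into two batches.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp
import numpy as np
from jax.typing import ArrayLike


class SnrAggregate(NamedTuple):
    count: ArrayLike           # (different_leakage_values,)
    mean: ArrayLike            # (different_leakage_values, trace_len)
    squared_deltas: ArrayLike  # (different_leakage_values, trace_len)


class UpdateData(NamedTuple):
    leakage_value: ArrayLike   # scalar integer
    trace: ArrayLike           # (trace_len,)


def jax_update(aggregate, data):
    # Same Welford update as above, for a single example.
    count = aggregate.count.at[data.leakage_value].add(1)
    delta = data.trace - aggregate.mean[data.leakage_value]
    mean = aggregate.mean.at[data.leakage_value].add(
        delta / count[data.leakage_value])
    updated_delta = data.trace - mean[data.leakage_value]
    squared_deltas = aggregate.squared_deltas.at[data.leakage_value].add(
        delta * updated_delta)
    return SnrAggregate(count, mean, squared_deltas), count


@jax.jit
def jax_update_b(aggregate, leakage_values, new_traces):
    new_aggregate, _ = jax.lax.scan(
        jax_update, aggregate, UpdateData(leakage_values, new_traces))
    return new_aggregate


# Hypothetical batched helpers (not shown in the original text).
def jax_get_initial_aggregate(trace_len: int, different_leakage_values: int):
    dtype = jnp.float32
    return SnrAggregate(
        count=jnp.zeros(different_leakage_values, dtype=dtype),
        mean=jnp.zeros((different_leakage_values, trace_len), dtype=dtype),
        squared_deltas=jnp.zeros((different_leakage_values, trace_len),
                                 dtype=dtype),
    )


def jax_finalize(aggregate: SnrAggregate):
    # Per-class mean and (population) variance. count[:, None] broadcasts the
    # per-class count over the time dimension.
    return aggregate.mean, aggregate.squared_deltas / aggregate.count[:, None]


# Toy data: 200 random traces of length 5, processed in two batches of 100.
rng = np.random.default_rng(0)
traces = rng.normal(size=(200, 5)).astype(np.float32)
leakages = rng.integers(0, 9, size=200).astype(np.int32)
aggregate = jax_get_initial_aggregate(trace_len=5, different_leakage_values=9)
for start in (0, 100):
    aggregate = jax_update_b(aggregate, leakages[start:start + 100],
                             traces[start:start + 100])
means, variances = jax_finalize(aggregate)
```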

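Then we combine our results as in the previous tutorial. Note that jax_get_initial_aggregate and jax_finalize here are the batched counterparts of the functions above, working on an SnrAggregate with one row per leakage value: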

```python
# SNR computation for batched examples
def snr_jax_batched(
    dataset_path: Path,
    ap_name: str,
) -> npt.NDArray[np.float32]:
    """Compute SNR using JAX, processing the examples in batches.
    """
    # Load the dataset
    dataset = Dataset(dataset_path)
    # We know that trace1 is the first.
    trace_len: int = dataset.dataset_structure.saved_data_description[0].shape[0]
    leakage_aggregate = jax_get_initial_aggregate(
        trace_len=trace_len,
        different_leakage_values=9,  # Hamming weight
    )
    split: SplitT = "test"
    byte_index: int = 0
    batch_size: int = 64
    for example in tqdm(
            dataset.as_numpy_iterator_rust_batched(
                split=split,
                repeat=False,
                shuffle=0,
                batch_size=batch_size,
            ),
            desc=f"[JAX] Computing SNR in batches over {split}",
            total=dataset.dataset_info.splits[split].number_of_examples //
            batch_size,
    ):
        # Hamming weight of the chosen byte for every example in the batch.
        current_leakage = jnp.bitwise_count(example[ap_name][:, byte_index])
        leakage_aggregate = jax_update_b(
            leakage_aggregate,
            current_leakage,
            example["trace1"],
        )
    finalized = jax_finalize(leakage_aggregate)
    results = {
        int(i): (finalized[0][i], finalized[1][i])
        for i in range(finalized[0].shape[0])
    }
    # Find out which class is the most common.
    most_common_leakage = int(jnp.argmax(leakage_aggregate.count))
    signals = np.array([mean for mean, _ in results.values()])
    return np.array(
        20 * np.log(np.var(signals, axis=0) / results[most_common_leakage][1]),
        dtype=np.float32,
    )
```

You might notice that the second run of the SNR computation is faster than the first. One explanation could be the file-system cache. Another is that JIT stands for just-in-time compilation: the first time a jit-ted function is called, JAX traces and compiles it. This adds an overhead that grows with the size of the traced computation. More precisely, the function is traced and compiled again each time it is called with arguments of a new shape (or dtype). Luckily, our batches always have the same size, possibly with the exception of the last one, so the compilation cost can often be neglected. This is not the case for benchmarking. We have chosen to count the cost of JIT compilation in the overall time, since that is the more realistic scenario in practice. Even so, the whole SNR computation takes under 3s for the batched version, roughly three times faster than the non-batched version.
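A quick way to observe retracing is to put a Python side effect inside a jitted function. The side effect only runs while JAX traces the function: on the first call, and whenever a new input shape appears. A minimal sketch follows. The counter is our own illustration, not a JAX feature.

```python
import jax
import jax.numpy as jnp

trace_count = 0


@jax.jit
def batch_sum(batch):
    global trace_count
    trace_count += 1  # Python code runs only during tracing.
    return batch.sum(axis=0)


batch_sum(jnp.ones((64, 100)))  # traced and compiled
batch_sum(jnp.ones((64, 100)))  # same shape: cached, no retracing
batch_sum(jnp.ones((17, 100)))  # shorter last batch: traced again
print(trace_count)  # 2
```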

Please read Tutorial — Correlation Power Analysis (NewAE wiki), which we use as the basis for our experiments. We generally follow that tutorial and keep the following state for each of the 16 byte indices:

```python
dtype = jnp.float32
# different_target_secrets (256 key byte guesses) and trace_len are defined
# by the surrounding script.
single_byte_index_state = {
    "d": jnp.zeros(1, dtype=jnp.int64),
    "sum_h_t": jnp.zeros((different_target_secrets, trace_len), dtype=dtype),
    "sum_h": jnp.zeros(different_target_secrets, dtype=dtype),
    "sum_hh": jnp.zeros(different_target_secrets, dtype=dtype),
    "sum_t": jnp.zeros(trace_len, dtype=dtype),
    "sum_tt": jnp.zeros(trace_len, dtype=dtype),
}
```

You might wonder why we keep our state in an untyped dictionary instead of a named tuple. One reason is that this is a tutorial and we want to show more of the possibilities. The other is that we will later make the update function work with all byte indices at once without touching its code. That changes the data representation, and the author of this tutorial finds a dictionary nicer for this purpose. But first, let us begin with the update for a single byte index:

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp
from jax.typing import ArrayLike


class UpdateData(NamedTuple):
    """A single CPA update (fields as used throughout this section)."""
    trace: ArrayLike  # shape (trace_len,)
    hypothesis: ArrayLike  # shape (different_target_secrets,)


@jax.jit
def r_update(
    state: dict[str, ArrayLike],
    data: UpdateData,
) -> tuple[dict[str, ArrayLike], ArrayLike]:
    """Update the CPA aggregate state.
    """
    # Check the dimensions if debugging. This will work even across vmaps, jit,
    # scan, etc.
    assert data.trace.shape == state["sum_t"].shape
    assert data.hypothesis.shape == state["sum_h"].shape
    # D (so far)
    d = state["d"] + 1
    # i indexes the hypothesis possible values
    # j indexes the time dimension
    # \sum_{d=1}^{D} h_{d,i} t_{d,j}
    sum_h_t = state["sum_h_t"] + jnp.einsum("i,j->ij", data.hypothesis,
                                            data.trace)
    # \sum_{d=1}^{D} h_{d, i}
    sum_h = state["sum_h"] + data.hypothesis
    # \sum_{d=1}^{D} t_{d, j}
    sum_t = state["sum_t"] + data.trace
    # \sum_{d=1}^{D} h_{d, i}^2
    sum_hh = state["sum_hh"] + data.hypothesis**2
    # \sum_{d=1}^{D} t_{d, j}^2
    sum_tt = state["sum_tt"] + data.trace**2
    return (
        {
            "d": d,
            "sum_h_t": sum_h_t,
            "sum_h": sum_h,
            "sum_hh": sum_hh,
            "sum_t": sum_t,
            "sum_tt": sum_tt,
        },
        d,  # jax.lax.scan trick known from SNR
    )
```

This code closely resembles NumPy code. All we need is to supply the traces and the hypothesis values, which are set to the bit count (Hamming weight) of the plaintext XOR each key guess. Note that our dataset does not have a constant key. We therefore simulate an all-zeros key by pretending that the plaintext equals the real plaintext XOR the real key. We also provide an experiment with 256 examples for which we are sure the key is constant (due to the way the dataset was captured). This is enough to see significant leakage but too fast to benchmark.
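The update only accumulates sums, so the correlation itself has to be computed at the end. Pearson's correlation coefficient between hypothesis i and time j can be written in terms of exactly these sums. Below is a minimal sketch of such a finalization for the single-byte state. `r_finalize` is a hypothetical name, not necessarily what the tutorial script uses. To keep the block short, the toy sums are computed directly with NumPy instead of through r_update, and the result is checked against `np.corrcoef`.

```python
import jax.numpy as jnp
import numpy as np


def r_finalize(state: dict) -> jnp.ndarray:
    """Hypothetical helper: Pearson correlation from the accumulated sums.

    r[i, j] = (D * sum_h_t[i, j] - sum_h[i] * sum_t[j]) / sqrt(
        (D * sum_hh[i] - sum_h[i]**2) * (D * sum_tt[j] - sum_t[j]**2))
    """
    d = state["d"].astype(jnp.float32)  # shape (1,), broadcasts below
    numerator = d * state["sum_h_t"] - jnp.outer(state["sum_h"],
                                                 state["sum_t"])
    var_h = d * state["sum_hh"] - state["sum_h"] ** 2
    var_t = d * state["sum_tt"] - state["sum_t"] ** 2
    return numerator / jnp.sqrt(jnp.outer(var_h, var_t))


# Toy data: 500 examples, 4 key guesses, traces of length 3.
rng = np.random.default_rng(1)
D, num_guesses, trace_len = 500, 4, 3
h = rng.integers(0, 9, size=(D, num_guesses)).astype(np.float32)
t = rng.normal(size=(D, trace_len)).astype(np.float32)
t[:, 0] += h[:, 2]  # plant a leak of guess 2 at time 0
state = {
    "d": jnp.array([D]),
    "sum_h_t": jnp.asarray(h.T @ t),
    "sum_h": jnp.asarray(h.sum(axis=0)),
    "sum_hh": jnp.asarray((h ** 2).sum(axis=0)),
    "sum_t": jnp.asarray(t.sum(axis=0)),
    "sum_tt": jnp.asarray((t ** 2).sum(axis=0)),
}
r = r_finalize(state)
print(r.shape)  # (4, 3)
```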

How much of the code can we reuse to attack all byte indices at once? And if we are able to do that, would we not keep 16 copies of sum_t and sum_tt instead of one? First of all, asymptotically the most important part of the state is sum_h_t, since it has shape (256, trace_len). Nevertheless, we store sum_t and sum_tt only once, just to show the available tools. We start by changing our initial state to accommodate all 16 byte indices:

```python
dtype = jnp.float32
multibyte_state = {
    "d": jnp.zeros(1, dtype=jnp.int64),
    "sum_h_t": jnp.zeros(
        (num_byte_indexes, different_target_secrets, trace_len),
        dtype=dtype),
    "sum_h": jnp.zeros(
        (num_byte_indexes, different_target_secrets),
        dtype=dtype),
    "sum_hh": jnp.zeros(
        (num_byte_indexes, different_target_secrets),
        dtype=dtype),
    "sum_t": jnp.zeros(trace_len, dtype=dtype),
    "sum_tt": jnp.zeros(trace_len, dtype=dtype),
}
```

Now one might expect that we need to change the r_update function. An easier way is to use automatic vectorization (Automatic vectorization — JAX documentation) to our advantage. Namely, the vmap transform allows us to add a batch dimension to a function. You are strongly encouraged to read the official jax.vmap documentation. The main trick is that we can not only vectorize a function over its parameters but also specify which inputs or outputs should not be vectorized because we do not need that (such as trace or sum_t). The final code might look scary, but the official documentation starts with simpler examples:

```python
state_vmap = {
    "d": None,  # counts the number of seen examples
    "sum_h_t": 0,  # each byte index separately
    "sum_h": 0,  # each byte index separately
    "sum_hh": 0,  # each byte index separately
    "sum_t": None,  # would be the same
    "sum_tt": None,  # would be the same
}
r_update_multi_index = jax.jit(
    jax.vmap(
        r_update,  # the function to be vectorized
        in_axes=(
            state_vmap,
            UpdateData(
                trace=None,  # single trace for all byte indices
                hypothesis=0,  # byte indices times hypothesis
            ),
        ),
        out_axes=(
            state_vmap,  # same shape input output
            None,  # jax.lax.scan does not matter
        ),
    ))
```
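To see what in_axes and out_axes with None do, here is a toy version of the same pattern with made-up shapes. A per-byte accumulator is vectorized (axis 0) next to a shared accumulator that is not (None). This is an illustration only, not the CPA code itself.

```python
import jax
import jax.numpy as jnp


def toy_update(state, x):
    # state["per_byte"]: (num_guesses,), state["shared"]: (trace_len,)
    # x["hypothesis"]: (num_guesses,), x["trace"]: (trace_len,)
    return {
        "per_byte": state["per_byte"] + x["hypothesis"],
        "shared": state["shared"] + x["trace"],
    }


state_axes = {"per_byte": 0, "shared": None}
multi_byte_update = jax.vmap(
    toy_update,
    in_axes=(state_axes, {"hypothesis": 0, "trace": None}),
    out_axes=state_axes,
)

num_bytes, num_guesses, trace_len = 16, 256, 10
state = {
    "per_byte": jnp.zeros((num_bytes, num_guesses)),
    "shared": jnp.zeros(trace_len),  # stored once, not 16 times
}
x = {
    "hypothesis": jnp.ones((num_bytes, num_guesses)),
    "trace": jnp.ones(trace_len),
}
new_state = multi_byte_update(state, x)
print(jax.tree_util.tree_map(jnp.shape, new_state))
# {'per_byte': (16, 256), 'shared': (10,)}
```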

And finally we use our r_update_multi_index inside a scan:

```python
for example in tqdm(
        dataset.as_numpy_iterator_rust_batched(
            split=split,
            repeat=False,
            shuffle=0,
            batch_size=batch_size,
        ),
        desc=f"[JAX] Computing CPA over batches of {split}",
        total=dataset.dataset_info.splits[split].number_of_examples //
        batch_size,
):
    # Since the dataset was not captured with a constant key, we need to
    # simulate one. Our leakage model is the Hamming weight of the S-BOX
    # inputs, which had high SNR. With a constant secret key,
    # simulated_plaintext would simply equal the real plaintext (which we
    # assume to know). Since the key changes, we simulate an all-zero key by
    # setting simulated_plaintext to plaintext ^ key.
    simulated_plaintext = example["plaintext"] ^ example["key"]
    # Shape (batch_size, num_byte_indexes, 256): Hamming weight of each
    # plaintext byte XOR each key guess.
    hypothesis = jnp.bitwise_count(
        simulated_plaintext.reshape(-1, num_byte_indexes, 1) ^
        jnp.arange(256, dtype=jnp.uint8).reshape(1, 1, 256))
    aggregate, _ = jax.lax.scan(
        r_update_multi_index,
        aggregate,
        UpdateData(
            trace=example["trace1"][:, :trace_len],
            hypothesis=hypothesis,
        ),
    )
```
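The reshape trick in the hypothesis computation broadcasts (batch_size, 16, 1) against (1, 1, 256). The result has one row of 256 key-guess Hamming weights per byte index. Here is a small check on random toy plaintexts, assuming (as the code above does) a JAX version that provides `jnp.bitwise_count`:

```python
import jax.numpy as jnp
import numpy as np

num_byte_indexes = 16
rng = np.random.default_rng(2)
simulated_plaintext = rng.integers(0, 256, size=(4, num_byte_indexes),
                                   dtype=np.uint8)

hypothesis = jnp.bitwise_count(
    simulated_plaintext.reshape(-1, num_byte_indexes, 1) ^
    jnp.arange(256, dtype=jnp.uint8).reshape(1, 1, 256))
print(hypothesis.shape)  # (4, 16, 256)

# Entry [b, i, k] is the Hamming weight of plaintext byte i of example b
# XOR key guess k.
b, i, k = 1, 5, 0xA7
expected = bin(int(simulated_plaintext[b, i]) ^ k).count("1")
assert int(hypothesis[b, i, k]) == expected
```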

Overall, computing CPA for all 16 byte indices with batching seems to take roughly two to nine times as long as computing it for a single index without batching. That is still less than running 16 separate single-index attacks would take. The ratio seems to depend on the trace length. You might wonder what happens for a single index with batching, or for all indices without batching. Those are great questions that can only be answered by experiment.

Template attacks are a popular form of profiled side-channel attacks, introduced by [1]. Among other things, one needs to compute a covariance matrix whose size is quadratic in the number of points of interest (time points from a trace). The memory transfers are thus small compared to the amount of computation needed.

[1] Chari, Suresh, Josyula R. Rao, and Pankaj Rohatgi. “Template attacks.” International workshop on cryptographic hardware and embedded systems. Berlin, Heidelberg: Springer Berlin Heidelberg, 2002.

Our code is based on Tutorial B8 Profiling Attacks (Manual Template Attack) - ChipWhisperer Wiki and Template Attacks - ChipWhisperer Wiki. Template attacks assume that a trace can be expressed as an ideal trace (a function of the inputs and intermediates only) plus Gaussian noise. Thus we need to estimate the parameters (mean and covariance matrix) of a multivariate normal distribution (see Multivariate normal distribution - Wikipedia). We refer the reader to the NewAE tutorials for more information.

Templates usually use few points of interest (POIs), typically 20 to 50, maybe 100, and the more POIs, the more traces are needed. We scale way beyond that, to ten thousand points of interest. This is still computationally feasible on a consumer-grade GPU. The feasible number depends on the amount of GPU memory; ours is for an NVIDIA RTX 4090. At the same time, such large templates are hardly of any practical use, since we would need tremendous amounts of profiling data (at least quadratic in the number of points of interest, times linear in the number of possible leakage values). Nevertheless, the speedup will still be significant for batches with a smaller number of points of interest. In our example we use 10 points of interest, since that is what the amount of data in our tinyAES dataset allows.

We compute the covariance matrix online. Let $t_d$ be the vector of points of interest we cut from the $d$-th trace. In our example we take ten points around the peak of SNR from our previous tutorial for the Hamming weight of the first byte of the S-BOX input. That is, $t_d$ has length 10. For each leakage value (0 to 8) we keep the following: the number $n$ of traces with that leakage value (a single integer), their mean $\mu$ (a vector of length 10), and the covariance accumulator $C$ (a 10 × 10 matrix). We update them as follows:

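```python
# Profiling phase for a single leakage value template:
n = prev_n + 1
dm = trace - prev_mean  # distance from the previous mean
mean = prev_mean + (dm / n)
# Welford-style update: add the outer product of dm with the distance from
# the updated mean (cf. delta * updated_delta in the SNR code).
# C / (n - 1) is then the sample covariance matrix.
C = prev_C + jnp.einsum("i,j->ij", dm, trace - mean)
```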
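Below is a self-contained sketch of this profiling update for a single leakage value, scanned over a batch of toy traces and compared with `np.cov`. It uses the Welford cross-product form `C += outer(dm, trace - mean)`, that is, the distance to the previous mean times the distance to the updated mean. C therefore accumulates the sum of outer products of deviations, and dividing by n - 1 gives the sample covariance matrix. The sizes and random data are assumptions for illustration.

```python
import jax
import jax.numpy as jnp
import numpy as np


def template_update(carry, trace):
    prev_n, prev_mean, prev_C = carry
    n = prev_n + 1
    dm = trace - prev_mean  # distance from the previous mean
    mean = prev_mean + dm / n
    C = prev_C + jnp.einsum("i,j->ij", dm, trace - mean)
    return (n, mean, C), None


@jax.jit
def profile(traces):
    num_poi = traces.shape[1]
    init = (jnp.float32(0.0), jnp.zeros(num_poi),
            jnp.zeros((num_poi, num_poi)))
    (n, mean, C), _ = jax.lax.scan(template_update, init, traces)
    return mean, C / (n - 1)  # sample covariance matrix


rng = np.random.default_rng(3)
traces = rng.normal(size=(1_000, 10)).astype(np.float32)
mean, cov = profile(traces)
print(cov.shape)  # (10, 10)
```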

Tutorial B8 Profiling Attacks (Manual Template Attack) - ChipWhisperer Wiki uses SciPy to evaluate the probability density function of a multivariate normal distribution. We compute it manually here to avoid an additional dependency. In production code it is probably a good idea to use the logpdf function from SciPy (scipy.stats.multivariate_normal — SciPy v1.16.2 Manual), since we are unlikely to get any meaningful speedup in the attack phase.
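Here is a minimal manual log-density in JAX. It uses a Cholesky factorization instead of an explicit matrix inverse, and `mvn_logpdf` is our illustrative name. We check it against `jax.scipy.stats.multivariate_normal.logpdf`, which JAX also provides and which likewise avoids the SciPy dependency. The covariance below is an arbitrary, well-conditioned toy example. In the attack phase one would sum such log-likelihoods over the attack traces for each key guess and pick the highest.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.scipy.linalg import solve_triangular
from jax.scipy.stats import multivariate_normal


@jax.jit
def mvn_logpdf(x, mean, cov):
    """Log-density of N(mean, cov) at x (x may be a batch of shape (n, k))."""
    k = mean.shape[-1]
    chol = jnp.linalg.cholesky(cov)                # cov = L L^T
    diff = (x - mean).T                            # (k,) or (k, n)
    z = solve_triangular(chol, diff, lower=True)   # solves L z = diff
    maha = jnp.sum(z ** 2, axis=0)                 # squared Mahalanobis distance
    log_det = 2.0 * jnp.sum(jnp.log(jnp.diag(chol)))
    return -0.5 * (k * jnp.log(2.0 * jnp.pi) + log_det + maha)


rng = np.random.default_rng(4)
a = rng.normal(size=(10, 10))
cov = jnp.asarray(a @ a.T + 10 * np.eye(10), dtype=jnp.float32)
mean = jnp.asarray(rng.normal(size=10), dtype=jnp.float32)
x = jnp.asarray(rng.normal(size=(5, 10)), dtype=jnp.float32)

ours = mvn_logpdf(x, mean, cov)
reference = multivariate_normal.logpdf(x, mean, cov)
print(ours.shape)  # (5,)
```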

We do not provide formal performance measurements in this section. One reason is that the profiling phase finishes in under 5 seconds. When we tried roughly 10,000 points of interest, the profiling phase still finished in two minutes. However, the covariance matrix was singular, and evaluating the probability density with a singular covariance matrix is outside the scope of this tutorial. With 20,000 points of interest, JAX raised an out-of-memory exception, due to allocating an array of 400,000,000 (20,000 × 20,000) floats on our GPU.

Empirically, batching speeds up the profiling phase roughly twofold compared to profiling example by example (for the default values provided by the script, i.e., 10 points of interest).

Our implementation in the example script builds templates only for the first byte of the S-BOX input. One could easily vmap it to obtain an implementation attacking all byte indices at once. This requires that the leakage is roughly in the same place for all bytes, or that the traces are cut inside the update function. We invite the interested reader to attempt this.
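One way to cut the traces inside the update function is sketched below, under stated assumptions. Each byte index gets its own hypothetical window start, `jax.lax.dynamic_slice` cuts the points of interest, and vmap runs the single-byte update for all byte indices with one shared trace. For brevity only a running mean is updated. The same pattern would apply to the per-class counts and covariance.

```python
import jax
import jax.numpy as jnp

NUM_POI = 10  # points of interest per byte index


def mean_update(mean, count, trace, start):
    # Cut NUM_POI points starting at `start`. A traced start index is fine
    # for dynamic_slice as long as the slice size is static.
    poi = jax.lax.dynamic_slice(trace, (start,), (NUM_POI,))
    return mean + (poi - mean) / count


all_bytes_update = jax.jit(jax.vmap(
    mean_update,
    # Per-byte mean and window start; count and trace are shared.
    in_axes=(0, None, None, 0),
))

num_bytes, trace_len = 16, 1_000
starts = jnp.arange(num_bytes) * 50  # hypothetical POI window per byte index
means = jnp.zeros((num_bytes, NUM_POI))
trace = jnp.arange(trace_len, dtype=jnp.float32)
means = all_bytes_update(means, 1.0, trace, starts)
print(means.shape)  # (16, 10)
```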

Running multiple evaluations in parallel on a single GPU


Frameworks such as JAX or TensorFlow often allocate (almost) all of the GPU memory upfront and use their own allocators. While this has a lot of benefits, it prevents us from running several experiments in parallel, even when the hardware would be capable of it. Luckily, this behaviour is rather easy to change using a flag. One can read more about it in the official GPU memory allocation — JAX documentation.
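For JAX, these settings are environment variables that must be set before JAX initializes its backend. For example, you can set them at the very top of the script or export them in the shell. The values below are examples; see the GPU memory allocation page for details.

```python
import os

# Allocate GPU memory on demand instead of preallocating a large fraction.
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
# Alternatively, keep preallocation but cap it, e.g., at 30% of GPU memory:
# os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = ".30"

import jax  # noqa: E402  (imported after setting the environment variables)

print(jax.devices())
```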

If you are using TensorFlow shard files ("tfrec"), you may also benefit from iterating inside a CPU device context, with tf.device("/CPU"):, as noted in Use a GPU | TensorFlow Core.

One particularity of JAX is that it uses float32 instead of float64 by default, for speed and compatibility reasons. You can read more about this decision, and how to enable float64, in the official documentation: Default dtypes and the X64 flag — JAX documentation. Another option is to use more numerically precise methods, such as Kahan summation (see Kahan summation algorithm - Wikipedia). Changing a config and then profiling is very easy. Changing the algorithm is a little more complicated, but also a good exercise in using the JAX ecosystem.
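As an example of changing the algorithm rather than the config, here is a sketch of Kahan (compensated) summation written with jax.lax.scan. It is compared with a plain float32 running sum on a contrived input: 1.0 followed by many values too small to change a float32 sum of 1.0 one at a time. The compensation relies on the compiler not reassociating floating-point operations. The config route would instead be a single `jax.config.update("jax_enable_x64", True)` call made before creating arrays.

```python
import jax
import jax.numpy as jnp
import numpy as np


def naive_step(total, x):
    return total + x, None


def kahan_step(carry, x):
    total, compensation = carry
    y = x - compensation            # re-inject previously lost low-order bits
    t = total + y                   # low-order bits of y may be lost here...
    compensation = (t - total) - y  # ...and are recovered here
    return (t, compensation), None


@jax.jit
def naive_sum(xs):
    total, _ = jax.lax.scan(naive_step, jnp.float32(0.0), xs)
    return total


@jax.jit
def kahan_sum(xs):
    init = (jnp.float32(0.0), jnp.float32(0.0))
    (total, _compensation), _ = jax.lax.scan(kahan_step, init, xs)
    return total


# Contrived input: 1.0 followed by values below half an ulp of 1.0 in float32.
xs = np.full(10_001, 1e-8, dtype=np.float32)
xs[0] = 1.0
exact = xs.astype(np.float64).sum()  # roughly 1.0001
naive, kahan = float(naive_sum(xs)), float(kahan_sum(xs))
print(naive, kahan, exact)
```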

Throughout our tutorial series we focused on the input of the S-BOX in the first round of AES. Both that and its Hamming weight can be computed very easily from the plaintext and key. Since our representation is based on NumPy arrays, one can also supply very efficient (vectorized) implementations to obtain more complicated attack points.
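As an illustration, here is a sketch of one such attack point: the Hamming weight of the first-round S-BOX output, computed by a vectorized table lookup. The table is generated from its definition (multiplicative inverse in GF(2^8) followed by the AES affine transform) to keep the block short and self-contained. `build_aes_sbox` and `sbox_output_hw` are our illustrative names, not functions from the tutorial scripts.

```python
import jax.numpy as jnp
import numpy as np


def gf_mul(a: int, b: int) -> int:
    """Multiplication in GF(2^8) modulo the AES polynomial x^8+x^4+x^3+x+1."""
    product = 0
    while b:
        if b & 1:
            product ^= a
        a <<= 1
        if a & 0x100:
            a ^= 0x11B
        b >>= 1
    return product


def rotl8(value: int, shift: int) -> int:
    return ((value << shift) | (value >> (8 - shift))) & 0xFF


def build_aes_sbox() -> np.ndarray:
    """AES S-box: multiplicative inverse followed by the affine transform."""
    sbox = np.zeros(256, dtype=np.uint8)
    for x in range(256):
        inv = 0 if x == 0 else next(
            y for y in range(1, 256) if gf_mul(x, y) == 1)
        sbox[x] = (inv ^ rotl8(inv, 1) ^ rotl8(inv, 2) ^ rotl8(inv, 3)
                   ^ rotl8(inv, 4) ^ 0x63)
    return sbox


SBOX = jnp.asarray(build_aes_sbox())


def sbox_output_hw(plaintext, key):
    """Hamming weight of the first-round S-box output, vectorized over any
    array shape, e.g., (batch_size, 16)."""
    return jnp.bitwise_count(SBOX[plaintext ^ key])


print([hex(v) for v in SBOX[jnp.array([0x00, 0x01, 0x53])].tolist()])
# ['0x63', '0x7c', '0xed']
```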

We have not implemented any higher-order attacks in this tutorial. These might be added later based on our needs. If you want them sooner, feel free to open a pull request on the SCAAML: Side Channel Attacks Assisted with Machine Learning repository.

Using tuples to represent several values is convenient for a small number of such variables. To keep our code clean, we could use dictionaries or many other structures, as described in Pytrees — JAX documentation. We have chosen NamedTuple subclasses and dictionaries. However, one may ask whether these data structures add overhead. We leave this as an exercise for the interested reader. If you choose to pursue this path, make sure to try with and without batching and with and without JIT compilation to get the full picture. And do not forget to call jax.block_until_ready before stopping the timer, since JAX dispatches computations asynchronously.

When profiling your own use cases, also keep an eye on the overall system load (e.g., using htop) and on GPU utilization (e.g., using nvidia-smi -l). And use the resources in Profiling computation — JAX documentation.