Sedpack with the MNIST Dataset
Find full code samples at docs/tutorials/quick_start/.
Let us give an overview of working with the MNIST dataset. This is a toy example in the sense that the dataset is small enough to be loaded into RAM in its entirety. On the other hand, it is familiar to many people interested in machine learning and lets us demonstrate how sedpack is used.
Download the Dataset
Downloading the dataset is supported by many machine learning frameworks.
import keras
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()
Create a Sedpack Dataset
We walk through the process of saving MNIST data in the Sedpack format. This allows us to showcase basic features of Sedpack.
Define Metadata
Information about the dataset is an important part of the dataset itself. We represent our metadata using Pydantic objects which are then serialized into the JSON format. In most cases, wherever there is some metadata there is also a custom_metadata field which can contain arbitrary information that will be serialized to and deserialized from JSON.
from sedpack.io import Metadata
# General info about the dataset
metadata = Metadata(
    description="MNIST dataset in the sedpack format",
    dataset_license="""
    Yann LeCun and Corinna Cortes hold the copyright of MNIST dataset,
    which is a derivative work from original NIST datasets. MNIST
    dataset is made available under the terms of the Creative Commons
    Attribution-Share Alike 3.0 license.
    """,
    custom_metadata={
        "list of authors": ["Yann LeCun", "Corinna Cortes"],
    },
)
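Since Metadata is a Pydantic object, it can also be inspected or serialized by hand. A minimal sketch, assuming sedpack uses Pydantic v2 (which provides model_dump_json); sedpack does this serialization for you when creating the dataset:

# Serialize the metadata to JSON manually for inspection.
# model_dump_json is the Pydantic v2 API (assumption: Pydantic v2 is in use).
print(metadata.model_dump_json(indent=2))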
Define Attributes
A dataset consists of examples. All examples have the same attributes. An attribute is a piece of data being saved. In the case of the MNIST dataset there is the image of 28 by 28 grayscale pixels and the corresponding class (0 to 9, denoting which digit the image represents). An Attribute saves the name, shape, and data type of each of these pieces of information. A list of Attributes gives us the saved_data_description.
from sedpack.io import DatasetStructure, Attribute
dataset_structure = DatasetStructure(
    saved_data_description=[
        Attribute(
            name="input",
            shape=(28, 28),
            dtype="float32",
        ),
        Attribute(
            name="digit",
            shape=(),
            dtype="uint8",
        ),
    ],
    examples_per_shard=256,
    compression="LZ4",
    shard_file_type="fb",
    hash_checksum_algorithms=("sha256",),
)
We have specified some more properties of the dataset structure. A shard is a file containing one or more examples. By setting examples_per_shard we limit how many examples go into one shard. A common practice is targeting files on the order of 100MB. This is clearly out of reach for us here since the whole dataset will be around 12MB. For larger datasets the tradeoff is as follows: the fewer examples there are in a shard, the easier it is to iterate in a shuffled way; on the other hand, to utilize the storage medium (HDD, SSD, NAS, …) well we prefer larger files. Note that the exact tradeoff depends on the storage medium characteristics and is not discovered automatically.
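As a back-of-envelope illustration (the numbers below are rough assumptions about the uncompressed size, not measurements), one can estimate how many examples of this dataset would fit into a shard of a desired size:

# Rough estimate of examples_per_shard for a ~100MB target shard size.
# One example: 28 * 28 float32 pixels plus one uint8 label (before compression).
bytes_per_example = 28 * 28 * 4 + 1
target_shard_bytes = 100 * 1024 * 1024
print(target_shard_bytes // bytes_per_example)  # roughly 33,000 examples

Compression reduces the on-disk size substantially, so the estimate is only a starting point for the tuning described above.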
Another potential benefit of storing multiple examples in a single file is compression. The compression parameter sets which compression algorithm to use for the whole shard file. If the examples within a shard are similar, the compression ratio is likely to be higher.
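To see how a given compression setting behaves on your data, one simple approach is to write the dataset and measure the resulting directory size. A minimal sketch using only the standard library; the path matches the one used later in this tutorial:

from pathlib import Path

# Total size in bytes of everything written under the dataset directory.
total_bytes = sum(
    f.stat().st_size for f in Path("Datasets/MNIST").rglob("*") if f.is_file())
print(f"{total_bytes / 1024**2:.1f} MiB on disk")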
All shards are stored in a given file type. The library supports the following:
- “tfrec”: stands for TFRecord, a Protocol Buffer based storage format supported by TensorFlow.
- “fb”: stands for FlatBuffers, a more lightweight alternative to the Protocol Buffer format.
- “npz”: stands for the NumPy archive format.
All of these formats can be parsed from other languages (although the level of support varies).
Finally, we can specify the set of hash_checksum_algorithms which are used to compute checksums of all saved files, both the shard (data) files and the JSON (metadata) files, effectively forming a Merkle tree over the whole dataset.
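For illustration, this is what computing a SHA-256 checksum of a single file looks like with the standard library. This is only a sketch of the underlying idea; sedpack computes and stores these checksums for you and its verification API is not shown here:

import hashlib
from pathlib import Path

def sha256_of_file(path: Path) -> str:
    """Hex digest of the file contents, read in chunks."""
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(1 << 20), b""):
            digest.update(chunk)
    return digest.hexdigest()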
Create the Dataset
Now actually creating the dataset is straightforward.
from sedpack.io import Dataset
dataset = Dataset.create(
    path="Datasets/MNIST",  # All files are stored here
    metadata=metadata,
    dataset_structure=dataset_structure,
)
The previous command creates the directory “MNIST” under “Datasets” and puts a “dataset_info.json” file in it. In case this file already exists, an exception sedpack.io.errors.DatasetExistsError is raised so that your older dataset is not overwritten by accident.
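If you prefer to handle the existing-dataset case explicitly instead of letting the exception propagate, a sketch could look like this (assuming DatasetExistsError is importable from sedpack.io.errors as named above):

from sedpack.io.errors import DatasetExistsError

try:
    dataset = Dataset.create(
        path="Datasets/MNIST",  # All files are stored here
        metadata=metadata,
        dataset_structure=dataset_structure,
    )
except DatasetExistsError:
    # The dataset already exists; load it instead of overwriting it.
    dataset = Dataset("Datasets/MNIST")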
Writing the Examples
We have created the dataset. Now we need to write the data in. When downloading the MNIST data we notice that there are x_train, y_train, x_test, and y_test:
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()
The x or y part denotes whether these are the inputs or outputs (x is an array of images, y is an array of the digit classes). The train and test parts denote which split the data belongs to. For the purpose of building a model we often split the data into:
- train set: used to build the model (e.g., determine the weights of a neural network),
- validation set (also called test set): used to observe the building process (e.g., observe how well the network generalizes during training),
- test set (also called holdout set): used at the very end of the process to determine the generalization performance of our model.
We choose the train / test / holdout names (as opposed to train / validation / test). The important part is that one needs to be very careful not to mix data belonging to different sets, because training and evaluating on the same data would confuse memorization for generalization. For more information on this see any classical book, e.g., The Elements of Statistical Learning by Trevor Hastie, Robert Tibshirani, and Jerome H. Friedman.
An additional detail is that we do not call write_example on the dataset object itself but on a DatasetFiller object. This acts as a context manager which ensures that all data gets saved. A benefit of using a DatasetFiller is that we can create several of them to write in a multithreaded way if we choose to do so.
The full code is as follows:
import random

from tqdm import tqdm

# DatasetFiller makes sure that all shard files are written properly
# when exiting the context.
with dataset.filler() as dataset_filler:
    # Determine which data are in the holdout (test)
    for i in tqdm(range(len(x_test)), desc="holdout"):
        dataset_filler.write_example(
            values={
                "input": x_test[i],
                "digit": y_test[i],
            },
            split="holdout",
        )

    # Randomly assign 10% to validation and the rest to training
    assert len(x_train) == len(y_train)
    train_indices: list[int] = list(range(len(x_train)))
    random.shuffle(train_indices)
    validation_split_position: int = int(len(x_train) * 0.1)
    for index_position, index in enumerate(
            tqdm(train_indices, desc="train and val")):
        split = "test" if index_position < validation_split_position else "train"
        dataset_filler.write_example(
            values={
                "input": x_train[index],
                "digit": y_train[index],
            },
            split=split,
        )
Intermezzo to Iterate
One can load the dataset and iterate over the examples easily. The loading step could be omitted here since we just created the dataset.
# Load the dataset (can be omitted since we just created it)
dataset = Dataset("Datasets/MNIST")
for example in dataset.as_numpy_iterator(split="train", repeat=False, shuffle=0):
    print(example)
We can see that example is a dictionary mapping attribute_name to attribute_data.
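For instance, with the attributes defined above the individual values can be accessed by name; the keys are exactly the attribute names "input" and "digit":

for example in dataset.as_numpy_iterator(split="train", repeat=False, shuffle=0):
    image = example["input"]   # shape (28, 28)
    digit = example["digit"]   # scalar class 0 to 9
    print(image.shape, digit)
    break  # inspect just the first example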
Training a Neural Network
The neural network can be written following your favourite framework's tutorial (or your own), for example the MNIST tutorial of Flax.
Load
One can load the dataset and use as_tfdataset for Keras/TensorFlow training. It is often the case that your training loop needs the data in a different format than an iterator of examples.
from sedpack.io import Dataset
dataset = Dataset(args.dataset_directory) # Load the dataset
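The args object is assumed to come from a standard argparse setup; the following is just an illustration matching args.dataset_directory used above, with the default path being an assumption:

import argparse

parser = argparse.ArgumentParser(description="Train on a sedpack MNIST dataset")
parser.add_argument("--dataset_directory", type=str, default="Datasets/MNIST",
                    help="Path to the dataset created earlier")
args = parser.parse_args()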
Batching examples together is a common practice during stochastic gradient descent. Instead of computing the gradient for the whole dataset or for just one example, we compute it for several randomly chosen examples at a time (for a batch or a mini-batch). This allows better utilization of parallelism and thus of the hardware (GPU / TPU), and it also gives us more robust gradient estimates. We thus set the parameter batch_size to whatever we wish. The batch dimension is always the first dimension.
batch_size: int = 128
The data might not have been saved in the exact format the training loop needs. For instance, we have saved the classes as small integers whereas the training loop prefers them one-hot encoded. Moreover, it might be useful to clearly distinguish which are the inputs and which are the outputs (return a tuple of (inputs, outputs)). All of this and more can be done by a simple transformation function which takes an example (a dictionary of values) and returns whatever form one pleases.
# ExampleT: TypeAlias of dict[str, sedpack.io.types.AttributeValueT]
def process_record(rec: ExampleT) -> Any:
    # One hot encoding
    output = tf.one_hot(rec["digit"], 10)

    # Return tuple (input, output)
    return rec["input"], output
Similarly, for the JAX training loop the data might not have been saved in the exact format needed. For instance, we have saved the classes as small integers (np.uint8) whereas the training loop prefers jnp.int32. Moreover, it could be beneficial to reshape the inputs. Again, all of this and more can be done by a simple transformation function which takes an example (a dictionary of values) and returns whatever form one pleases.
# ExampleT: TypeAlias of dict[str, sedpack.io.types.AttributeValueT]
def process_batch(d):
    """Turn the NumPy arrays into JAX arrays and reshape the input to
    have a channel.
    """
    batch_size: int = d["input"].shape[0]
    return {
        "input": jnp.array(d["input"]).reshape(batch_size, 28, 28, 1),
        "digit": jnp.array(d["digit"], jnp.int32),
    }
Loading the train and test (validation) split is then as follows:
# Load train and validation splits of the dataset
train_data = dataset.as_tfdataset(
    "train",
    batch_size=batch_size,
    process_record=process_record,
)
validation_data = dataset.as_tfdataset(
    "test",  # validation split
    batch_size=batch_size,
    process_record=process_record,
)
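To verify that batching and the process_record transformation behave as expected, one can peek at a single batch. The shapes below assume the batch_size of 128 set earlier (the very last batch of a split may be smaller):

# Inspect the first batch: inputs and one-hot encoded outputs.
for inputs, outputs in train_data:
    print(inputs.shape)   # (128, 28, 28) -- the batch dimension comes first
    print(outputs.shape)  # (128, 10)     -- one-hot encoded digits
    break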
For the JAX training loop we let as_tfdataset handle shuffling and batching and apply process_batch later inside the loop:

train_data = dataset.as_tfdataset(
    "train",
    batch_size=batch_size,
    shuffle=1_000,
)
validation_data = dataset.as_tfdataset(
    "test",  # validation split
    batch_size=batch_size,
    shuffle=1_000,
    repeat=False,
)
This has been the historical API. We are benchmarking our own data pipeline written in Rust; once that API matures it will become the default choice.
Train
And we train:
model = get_model()
steps_per_epoch = 100
epochs = 10
history = model.fit(
    train_data,
    steps_per_epoch=steps_per_epoch,
    epochs=epochs,
    validation_data=validation_data,
    validation_steps=steps_per_epoch // 10,
)
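The get_model helper used above is not part of sedpack. A minimal Keras sketch compatible with the (28, 28) inputs and one-hot encoded outputs produced by process_record could be:

import keras

def get_model() -> keras.Model:
    """A small fully connected classifier for 28x28 MNIST images."""
    model = keras.Sequential([
        keras.layers.Input(shape=(28, 28)),
        keras.layers.Flatten(),
        keras.layers.Dense(128, activation="relu"),
        keras.layers.Dense(10, activation="softmax"),
    ])
    model.compile(
        optimizer="adam",
        loss="categorical_crossentropy",  # outputs are one-hot encoded
        metrics=["accuracy"],
    )
    return model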
For the JAX version we followed the Flax tutorial and wrote a custom training loop:
for step, batch in enumerate(tqdm(train_data)):
    batch = process_batch(batch)
    train_step(model, optimizer, metrics, batch)