
Replay Buffer

smart_control.reinforcement_learning.replay_buffer.replay_buffer

Reinforcement learning replay buffers.

ReplayBufferManager

ReplayBufferManager(data_spec, capacity, checkpoint_dir, sequence_length=2)

Manager for creating and interacting with Reverb replay buffers.

This class simplifies the setup, interaction, and checkpointing of Reverb replay buffers for reinforcement learning agents. It provides methods to create a new buffer, add data, sample from the buffer, and save/restore buffer state.
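The `capacity` argument bounds how many frames the buffer can hold. The sketch below is a hypothetical, pure-Python stand-in that uses no Reverb. `ToyReplayBuffer` and its `add` and `contents` methods are invented for illustration. It mimics only the documented `num_frames()` and `clear()` behaviour. It also assumes oldest-first (FIFO) eviction once capacity is reached, which this page does not specify for the real class.

```python
from collections import deque
from typing import Any, Deque, List


class ToyReplayBuffer:
    """Hypothetical in-memory stand-in for a Reverb-backed replay buffer.

    Not the smart_control implementation. FIFO eviction is an assumption.
    """

    def __init__(self, capacity: int) -> None:
        if capacity <= 0:
            raise ValueError("capacity must be positive")
        # A deque with maxlen silently drops the oldest frame once full.
        self._frames: Deque[Any] = deque(maxlen=capacity)

    def add(self, frame: Any) -> None:
        self._frames.append(frame)

    def num_frames(self) -> int:
        return len(self._frames)

    def clear(self) -> None:
        self._frames.clear()

    def contents(self) -> List[Any]:
        return list(self._frames)


buffer = ToyReplayBuffer(capacity=3)
for step in range(5):
    buffer.add(step)
print(buffer.num_frames())  # 3
print(buffer.contents())    # [2, 3, 4]
buffer.clear()
print(buffer.num_frames())  # 0
```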

clear

clear() -> None

Clear all data from the replay buffer.

close

close() -> None

Close the replay buffer server and clean up resources.
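The class exposes a `close()` method, so the standard-library `contextlib.closing` can guarantee that the server is shut down even if training raises an exception. The minimal sketch below uses a hypothetical `StubManager` in place of a real `ReplayBufferManager`, so it runs without Reverb.

```python
import contextlib


class StubManager:
    """Hypothetical stand-in that exposes only close(), like ReplayBufferManager."""

    def __init__(self) -> None:
        self.closed = False

    def close(self) -> None:
        # The real close() shuts down the Reverb server; here we only record the call.
        self.closed = True


manager = StubManager()
try:
    # closing() calls manager.close() on exit, even when the body raises.
    with contextlib.closing(manager):
        raise RuntimeError("training loop failed")
except RuntimeError:
    pass
print(manager.closed)  # True
```

With the real class this would presumably read `with contextlib.closing(ReplayBufferManager(data_spec, capacity, checkpoint_dir)) as manager:`, using the constructor arguments from the signature above.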

create_replay_buffer

create_replay_buffer()

Create the replay buffer.

get_dataset

get_dataset(
    batch_size: int = 64, num_steps: Optional[int] = None
) -> tf.data.Dataset

Get a TensorFlow dataset for sampling from the replay buffer.

Parameters:

batch_size (int): Number of sequences to sample in each batch. Default: 64.
num_steps (Optional[int]): Number of steps to sample for each sequence. If None, defaults to the sequence_length passed to the constructor. Default: None.

Returns:

tf.data.Dataset: A TensorFlow dataset that samples from the replay buffer.

Raises:

RuntimeError: If the replay buffer has not been initialized yet.
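The sketch below is a hypothetical, pure-Python model of the contract described above. It returns `batch_size` sequences of `num_steps` consecutive frames. `num_steps=None` falls back to `sequence_length`, and it raises `RuntimeError` when the buffer is uninitialized (modelled here as `frames is None`). It does not build a `tf.data.Dataset`. Uniform sampling of start positions is an assumption, and `sample_windows` is an invented name.

```python
import random
from typing import List, Optional, Sequence


def sample_windows(
    frames: Optional[Sequence[int]],
    batch_size: int = 64,
    num_steps: Optional[int] = None,
    sequence_length: int = 2,
    seed: int = 0,
) -> List[List[int]]:
    """Hypothetical model of get_dataset's sampling contract (not the real API)."""
    if frames is None:
        # Mirrors the documented RuntimeError for an uninitialized buffer.
        raise RuntimeError("Replay buffer has not been initialized.")
    # Documented default: num_steps=None means "use sequence_length".
    steps = sequence_length if num_steps is None else num_steps
    if len(frames) < steps:
        # Toy choice: fail fast instead of waiting for more data.
        raise ValueError("Not enough frames for one sequence.")
    rng = random.Random(seed)
    # Valid start indices leave room for `steps` consecutive frames.
    starts = [rng.randrange(len(frames) - steps + 1) for _ in range(batch_size)]
    return [list(frames[s:s + steps]) for s in starts]


batch = sample_windows(list(range(10)), batch_size=4)
print(len(batch), [len(seq) for seq in batch])  # 4 [2, 2, 2, 2]
```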

get_replay_buffer_and_observer

get_replay_buffer_and_observer() -> Tuple[
    reverb_replay_buffer.ReverbReplayBuffer,
    reverb_utils.ReverbAddTrajectoryObserver,
]

Get the replay buffer and observer, creating them if they have not been initialized yet.

Returns:

Tuple[ReverbReplayBuffer, ReverbAddTrajectoryObserver]: A tuple of (replay_buffer, observer).
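"Creates them if not already initialized" describes lazy initialization. The first call builds the pair, and later calls return the same objects. The page does not show exactly how the real method caches them. The sketch below uses a hypothetical `LazyPair` class, with plain `object()` placeholders standing in for `ReverbReplayBuffer` and `ReverbAddTrajectoryObserver`.

```python
from typing import Optional, Tuple


class LazyPair:
    """Hypothetical sketch of create-on-first-use; not the smart_control code."""

    def __init__(self) -> None:
        self._pair: Optional[Tuple[object, object]] = None
        self.creations = 0  # counts how many times the pair was built

    def _create(self) -> Tuple[object, object]:
        self.creations += 1
        # Placeholders for (replay_buffer, observer).
        return object(), object()

    def get_replay_buffer_and_observer(self) -> Tuple[object, object]:
        if self._pair is None:
            self._pair = self._create()
        return self._pair


mgr = LazyPair()
first = mgr.get_replay_buffer_and_observer()
second = mgr.get_replay_buffer_and_observer()
print(mgr.creations, first is second)  # 1 True
```

In TF-Agents, the observer is typically passed in a driver's `observers` list so that each collected trajectory is written to the buffer. The buffer side is then read for training.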

load_replay_buffer

load_replay_buffer() -> Tuple[
    reverb_replay_buffer.ReverbReplayBuffer,
    reverb_utils.ReverbAddTrajectoryObserver,
]

Load an existing replay buffer from a saved checkpoint.

This method reconstructs the replay buffer, server, and observer based on the saved state in the checkpoint directory.

Returns:

Tuple[ReverbReplayBuffer, ReverbAddTrajectoryObserver]: A tuple of (replay_buffer, observer).
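Loading depends on state previously written to `checkpoint_dir`. A common calling pattern is "restore if a checkpoint exists, otherwise start fresh". The sketch below illustrates only that round trip, using hypothetical helpers (`save_frames`, `restore_or_create`) and a made-up JSON file. It is not Reverb's checkpoint format and does not call `load_replay_buffer` or `create_replay_buffer`.

```python
import json
import os
import tempfile
from typing import List

CHECKPOINT_FILE = "frames.json"  # hypothetical file name, not Reverb's layout


def save_frames(frames: List[int], checkpoint_dir: str) -> None:
    """Write toy buffer contents to checkpoint_dir as JSON."""
    os.makedirs(checkpoint_dir, exist_ok=True)
    with open(os.path.join(checkpoint_dir, CHECKPOINT_FILE), "w") as f:
        json.dump(frames, f)


def restore_or_create(checkpoint_dir: str) -> List[int]:
    """Load saved contents if a checkpoint exists, else start with an empty buffer."""
    path = os.path.join(checkpoint_dir, CHECKPOINT_FILE)
    if os.path.exists(path):
        with open(path) as f:
            return json.load(f)
    return []


with tempfile.TemporaryDirectory() as tmp:
    fresh = restore_or_create(tmp)      # no checkpoint yet
    save_frames([1, 2, 3], tmp)
    restored = restore_or_create(tmp)   # reads the saved state
print(fresh, restored)  # [] [1, 2, 3]
```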

num_frames

num_frames() -> int

Get the current number of frames in the replay buffer.

Returns:

int: The number of frames currently in the buffer.
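A typical use of `num_frames()` is to gate training: keep collecting experience until the buffer holds enough frames to sample from. The sketch below uses a hypothetical `CountingBuffer` exposing `num_frames()`. `warm_up` and its `min_frames` threshold are illustrative names, not parameters of the documented class.

```python
from typing import List


class CountingBuffer:
    """Hypothetical stand-in exposing num_frames(), like ReplayBufferManager."""

    def __init__(self) -> None:
        self._frames: List[int] = []

    def add(self, frame: int) -> None:
        self._frames.append(frame)

    def num_frames(self) -> int:
        return len(self._frames)


def warm_up(buffer: CountingBuffer, min_frames: int) -> int:
    """Collect until the buffer holds at least min_frames; return steps taken."""
    steps = 0
    while buffer.num_frames() < min_frames:
        buffer.add(steps)  # stands in for one environment step written via the observer
        steps += 1
    return steps


buf = CountingBuffer()
print(warm_up(buf, min_frames=5), buf.num_frames())  # 5 5
```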