tianshou.data.data_buffer

Base class

Batch set

class tianshou.data.data_buffer.batch_set.BatchSet(nstep=1)

Bases: tianshou.data.data_buffer.base.DataBufferBase

Class for a batched dataset as used in on-policy algorithms: a batch of data is first collected with the current policy, several optimization steps are run on that batch, and the data are then discarded and a new batch is collected.

Parameters: nstep – An int defaulting to 1. The number of timesteps to look ahead for temporal difference computation. Only continuous data pieces longer than this number or already terminated ones are considered valid data points.

add(frame)

Adds one frame of data to the buffer.

Parameters: frame – A tuple of (observation, action, reward, done_flag).
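
The nstep rule and the frame layout described above can be illustrated with a minimal sketch. This is not the library's implementation: the helper name valid_indexes and the way an episode is stored (a plain list of frame tuples) are assumptions made for this example. It reads the rule as follows: every frame of a terminated episode is a valid data point, while an unfinished episode only offers the frames that already have nstep successors.

```python
def valid_indexes(episode, nstep=1):
    """Return the frame indexes usable as data points (illustrative helper).

    `episode` is assumed to be a list of
    (observation, action, reward, done_flag) tuples in temporal order.
    """
    if not episode:
        return []
    episode_terminated = episode[-1][3]  # done_flag of the last frame
    if episode_terminated:
        # A finished episode needs no further lookahead.
        return list(range(len(episode)))
    # An unfinished episode: a frame needs nstep successors to look ahead.
    return list(range(max(len(episode) - nstep, 0)))


finished = [("o0", 0, 1.0, False), ("o1", 1, 0.0, False), ("o2", 0, 1.0, True)]
ongoing = [("o0", 0, 1.0, False), ("o1", 1, 0.0, False), ("o2", 0, 1.0, False)]

print(valid_indexes(finished, nstep=2))  # [0, 1, 2]
print(valid_indexes(ongoing, nstep=2))   # [0]
```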

clear()

Empties the data buffer and prepares to collect a new batch of data.

sample(batch_size)

Performs uniform random sampling on self.index. For simplicity, sampling is currently done with replacement, which takes O(batch_size) time. The fastest approach to sampling without replacement appears to require O(batch_size * log(num_episodes)) time.

Parameters: batch_size – An int. The size of the minibatch.
Returns: A list of lists of the sampled indexes, one sub-list per episode. Episodes without sampled data points correspond to empty sub-lists.
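
The return format can be sketched as follows. This is a simplified illustration, not the library's code: the function name sample_indexes, the rng argument, and the assumption that the index is a list holding one list of valid frame indexes per episode are all hypothetical. The flat list of (episode, frame) pairs is rebuilt on every call for clarity only; the loop that draws the samples is the O(batch_size) part.

```python
import random


def sample_indexes(index, batch_size, rng=random):
    """Sample with replacement from a per-episode index (illustrative).

    `index[e]` is assumed to list the valid frame indexes of episode `e`.
    Returns one sub-list of sampled frame indexes per episode.
    """
    flat = [(e, i) for e, valid in enumerate(index) for i in valid]
    sampled = [[] for _ in index]
    if not flat:
        return sampled  # nothing to sample from
    for _ in range(batch_size):
        episode, frame = rng.choice(flat)  # uniform, with replacement
        sampled[episode].append(frame)
    return sampled


index = [[0, 1, 2], [], [0, 1]]  # the second episode has no valid data points
batch = sample_indexes(index, batch_size=4, rng=random.Random(0))
print(batch)
```

Because sampling is with replacement, the same frame index may appear more than once in a sub-list, and the episode with no valid data points always maps to an empty sub-list.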

statistics(discount_factor=0.99)

Computes and prints out the statistics (e.g., discounted returns, undiscounted returns) in the batch set. This is useful when policies are optimized by on-policy algorithms, so the current data in the batch set directly reflect the performance of the current policy.

Parameters: discount_factor – Optional. A float in the range [0, 1] defaulting to 0.99. The discount factor used to compute discounted returns.
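
As a rough illustration of the two quantities named above, the sketch below computes the discounted and undiscounted return of a single episode from its reward sequence. The function name episode_returns is hypothetical, and how the actual method aggregates and prints results over all episodes is not shown here.

```python
def episode_returns(rewards, discount_factor=0.99):
    """Return (discounted_return, undiscounted_return) for one episode."""
    undiscounted = sum(rewards)
    discounted = 0.0
    # Accumulate backwards: G_t = r_t + discount_factor * G_{t+1}.
    for reward in reversed(rewards):
        discounted = reward + discount_factor * discounted
    return discounted, undiscounted


discounted, undiscounted = episode_returns([1.0, 1.0, 1.0], discount_factor=0.5)
print(discounted, undiscounted)  # 1.75 3.0
```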

Replay buffer base

Vanilla replay buffer

class tianshou.data.data_buffer.vanilla.VanillaReplayBuffer(capacity, nstep=1)

Bases: tianshou.data.data_buffer.replay_buffer_base.ReplayBufferBase

Class for the vanilla replay buffer as used in (Mnih, et al., 2015). Frames are always continuous in temporal order: they are only removed from the beginning and added at the tail. Only the vanilla replay buffer can exploit this continuity in self.data.

Parameters:
  • capacity – An int. The capacity of the buffer.
  • nstep – An int defaulting to 1. The number of timesteps to lookahead for temporal difference computation. Only continuous data pieces longer than this number or already terminated ones are considered valid data points.

add(frame)

Adds one frame of data to the buffer.

Parameters: frame – A tuple of (observation, action, reward, done_flag).

clear()

Empties the data buffer. This is typically used with a batch set rather than a replay buffer.

remove()

Removes data from the buffer until self.size <= self.capacity.
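
The first-in, first-out behaviour described for this class can be sketched with a double-ended queue. The class below is a toy stand-in, not the library's implementation: the name TinyReplayBuffer, storing frames directly in a deque, and calling remove() from add() are all assumptions made for illustration.

```python
from collections import deque


class TinyReplayBuffer:
    """Toy FIFO buffer: frames are added at the tail, removed from the head."""

    def __init__(self, capacity):
        self.capacity = capacity
        self.data = deque()

    @property
    def size(self):
        return len(self.data)

    def add(self, frame):
        self.data.append(frame)  # newest frame goes to the tail
        self.remove()            # assumed here: trim after every insertion

    def remove(self):
        # Drop the oldest frames until the buffer fits its capacity again.
        while self.size > self.capacity:
            self.data.popleft()


buffer = TinyReplayBuffer(capacity=3)
for t in range(5):
    buffer.add((t, 0, 0.0, False))  # (observation, action, reward, done_flag)

print([frame[0] for frame in buffer.data])  # [2, 3, 4]
```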

sample(batch_size)

Performs uniform random sampling on self.index. For simplicity, sampling is currently done with replacement, which takes O(batch_size) time. The fastest approach to sampling without replacement appears to require O(batch_size * log(num_episodes)) time.

Parameters: batch_size – An int. The size of the minibatch.
Returns: A list of lists of the sampled indexes, one sub-list per episode. Episodes without sampled data points correspond to empty sub-lists.