Commit: Add multi-env support
araffin committed Jul 17, 2024
1 parent 5f3cba9 commit b5ce091
Showing 2 changed files with 64 additions and 61 deletions.
89 changes: 38 additions & 51 deletions sbx/common/prioritized_replay_buffer.py
@@ -14,7 +14,6 @@
from gymnasium import spaces
from stable_baselines3.common.buffers import ReplayBuffer
from stable_baselines3.common.type_aliases import ReplayBufferSamples
from stable_baselines3.common.utils import get_linear_fn
from stable_baselines3.common.vec_env.vec_normalize import VecNormalize


@@ -41,6 +40,8 @@ def __init__(self, capacity: int, reduce_op: Callable, neutral_element: float) -
"""
assert capacity > 0 and capacity & (capacity - 1) == 0, f"Capacity must be positive and a power of 2, not {capacity}"
self._capacity = capacity
# First index is the root, leaf nodes are in [capacity, 2 * capacity - 1].
# For each parent node i, left child has index [2 * i], right child [2 * i + 1]
self._values = np.full(2 * capacity, neutral_element)
self._reduce_op = reduce_op
self.neutral_element = neutral_element
@@ -97,8 +98,10 @@ def __setitem__(self, idx: np.ndarray, val: np.ndarray) -> None:
:param idx: index of the value to be updated
:param val: new value
"""
# assert np.all(0 <= idx < self._capacity), f"Trying to set item outside capacity: {idx}"
# Indices of the leafs
indices = idx + self._capacity
# Update the leaf nodes and then the related nodes
self._values[indices] = val
if isinstance(indices, int):
indices = np.array([indices])
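The remainder of this method is cut off by the diff context; a hedged sketch of how such a vectorized upward propagation typically looks (values, reduce_op and the leaf indices are assumed from the class above, with reduce_op applied elementwise):

# Sketch only: recompute each affected parent from its two children,
# walking from the leaves up to the root (index 1).
parents = np.unique(indices // 2)
while len(parents) > 0 and parents[0] >= 1:
    values[parents] = reduce_op(values[2 * parents], values[2 * parents + 1])
    parents = np.unique(parents[parents > 1] // 2)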
@@ -153,8 +156,7 @@ def find_prefixsum_idx(self, prefixsum: np.ndarray) -> np.ndarray:
if isinstance(prefixsum, float):
prefixsum = np.array([prefixsum])
assert 0 <= np.min(prefixsum)
assert np.max(prefixsum) <= self.sum() + 1e-5
assert isinstance(prefixsum[0], float)
# assert np.max(prefixsum) <= self.sum() + 1e-5

indices = np.ones(len(prefixsum), dtype=int)
should_continue = np.ones(len(prefixsum), dtype=bool)
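For reference, the descent this method performs can be written in scalar form as below (a sketch assuming a sum-tree values array with the layout shown earlier; the committed version is vectorized over a batch of prefix sums):

def find_prefixsum_idx_scalar(values, capacity, prefixsum):
    # Walk down from the root: go left if the left subtree's sum covers
    # prefixsum, otherwise subtract that sum and go right.
    idx = 1
    while idx < capacity:            # stop once a leaf is reached
        left = 2 * idx
        if values[left] > prefixsum:
            idx = left
        else:
            prefixsum -= values[left]
            idx = left + 1
    return idx - capacity            # leaf node index -> flat buffer index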
@@ -227,8 +229,6 @@ def __init__(
device: Union[th.device, str] = "auto",
n_envs: int = 1,
alpha: float = 0.5,
beta: float = 0.4,
final_beta: float = 1.0,
optimize_memory_usage: bool = False,
min_priority: float = 1e-6,
):
@@ -238,7 +238,7 @@ def __init__(
assert optimize_memory_usage is False, "PrioritizedReplayBuffer doesn't support optimize_memory_usage=True"

# TODO: add support for multi env
assert n_envs == 1, "PrioritizedReplayBuffer doesn't support n_envs > 1"
# assert n_envs == 1, "PrioritizedReplayBuffer doesn't support n_envs > 1"

# Find the next power of 2 for the buffer size
power_of_two = int(np.ceil(np.log2(buffer_size)))
@@ -249,26 +249,12 @@

self._alpha = alpha

# Track the training progress remaining (from 1 to 0)
# this is used to update beta
self._current_progress_remaining = 1.0

# TODO: move beta schedule to the DQN algorithm
self._inital_beta = beta
self._final_beta = final_beta
self.beta_schedule = get_linear_fn(
self._inital_beta,
self._final_beta,
end_fraction=1.0,
)

self._sum_tree = SumSegmentTree(tree_capacity)
self._min_tree = MinSegmentTree(tree_capacity)

@property
def beta(self) -> float:
# Linear schedule
return self.beta_schedule(self._current_progress_remaining)
# Flatten the indices from the buffer to store them in the sum tree
# Replay buffer: (idx, env_idx)
# Sum tree: idx * self.n_envs + env_idx
self.env_offsets = np.arange(self.n_envs)
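A quick round-trip check of this mapping (arbitrary numbers, assuming n_envs = 4):

n_envs = 4
idx, env_idx = 10, 3
leaf = idx * n_envs + env_idx     # 43: the slot used by the sum/min trees
assert leaf // n_envs == idx      # recover the buffer position
assert leaf % n_envs == env_idx   # recover the environment index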

def add(
self,
@@ -289,14 +275,14 @@ def add(
:param done: Whether the episode was finished after the transition to be stored.
:param infos: Eventual information given by the environment.
"""
# store transition index with maximum priority in sum tree
self._sum_tree[self.pos] = self._max_priority**self._alpha
self._min_tree[self.pos] = self._max_priority**self._alpha
# Store transition index with maximum priority in sum tree
self._sum_tree[self.pos * self.n_envs + self.env_offsets] = self._max_priority**self._alpha
self._min_tree[self.pos * self.n_envs + self.env_offsets] = self._max_priority**self._alpha

# store transition in the buffer
# Store transition in the buffer
super().add(obs, next_obs, action, reward, done, infos)

def sample(self, batch_size: int, env: Optional[VecNormalize] = None) -> ReplayBufferSamples:
def sample(self, batch_size: int, beta: float, env: Optional[VecNormalize] = None) -> ReplayBufferSamples:
"""
Sample elements from the prioritized replay buffer.
@@ -305,41 +291,46 @@ def sample(self, batch_size: int, env: Optional[VecNormalize] = None) -> ReplayB
to normalize the observations/rewards when sampling
:return: a batch of sampled experiences from the buffer.
"""
assert self.buffer_size >= batch_size, "The buffer contains less samples than the batch size requires."
# assert self.buffer_size >= batch_size, "The buffer contains less samples than the batch size requires."

# priorities = np.zeros((batch_size, 1))
# sample_indices = np.zeros(batch_size, dtype=np.uint32)

# TODO: check how things are sampled in the original implementation

sample_indices = self._sample_proportional(batch_size)
leaf_nodes_indices = sample_indices
leaf_nodes_indices = self._sample_proportional(batch_size)
# Convert the leaf nodes indices to buffer indices
# Replay buffer: (idx, env_idx)
# Sum tree: idx * self.n_envs + env_idx
buffer_indices = leaf_nodes_indices // self.n_envs
env_indices = leaf_nodes_indices % self.n_envs

# probability of sampling transition i as P(i) = p_i^alpha / \sum_{k} p_k^alpha
# where p_i > 0 is the priority of transition i.
# probs = priorities / self.tree.total_sum
probabilities = self._sum_tree[sample_indices] / self._sum_tree.sum()
probabilities = self._sum_tree[leaf_nodes_indices] / self._sum_tree.sum()

# Importance sampling weights.
# All weights w_i were scaled so that max_i w_i = 1.
# weights = (self.size() * probs + 1e-7) ** -self.beta
# min_probability = self._min_tree.min() / self._sum_tree.sum()
# max_weight = (min_probability * self.size()) ** (-self.beta)
# weights = (probabilities * self.size()) ** (-self.beta) / max_weight
weights = (probabilities * self.size()) ** (-self.beta)
weights = (probabilities * self.size()) ** (-beta)
weights = weights / weights.max()

# TODO: add proper support for multi env
# env_indices = np.random.randint(0, high=self.n_envs, size=(batch_size,))
env_indices = np.zeros(batch_size, dtype=np.uint32)
next_obs = self._normalize_obs(self.next_observations[sample_indices, env_indices, :], env)
# env_indices = np.zeros(batch_size, dtype=np.uint32)
next_obs = self._normalize_obs(self.next_observations[buffer_indices, env_indices, :], env)

batch = (
self._normalize_obs(self.observations[sample_indices, env_indices, :], env),
self.actions[sample_indices, env_indices, :],
self._normalize_obs(self.observations[buffer_indices, env_indices, :], env),
self.actions[buffer_indices, env_indices, :],
next_obs,
self.dones[sample_indices],
self.rewards[sample_indices],
# Only use dones that are not due to timeouts
# deactivated by default (timeouts is initialized as an array of False)
(self.dones[buffer_indices, env_indices] * (1 - self.timeouts[buffer_indices, env_indices])).reshape(-1, 1),
self._normalize_reward(self.rewards[buffer_indices, env_indices].reshape(-1, 1), env),
weights,
)
return ReplayBufferSamples(*tuple(map(self.to_torch, batch)), leaf_nodes_indices) # type: ignore[arg-type,call-arg]
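As a sanity check of the two formulas used above, P(i) = p_i^alpha / sum_k p_k^alpha and w_i = (N * P(i))^(-beta) normalized by the largest weight, on made-up priorities (alpha already folded in, as in the sum tree):

import numpy as np

priorities = np.array([0.1, 0.4, 0.5])         # p_i ** alpha, as stored in the tree
probabilities = priorities / priorities.sum()  # P(i)
size, beta = 3, 0.4
weights = (probabilities * size) ** (-beta)
weights = weights / weights.max()              # so that max_i w_i == 1
print(probabilities, weights)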
@@ -358,28 +349,24 @@ def _sample_proportional(self, batch_size: int) -> np.ndarray:
return self._sum_tree.find_prefixsum_idx(priorities_sum)

# def update_priorities(self, indices: np.ndarray, priorities: np.ndarray) -> None:
def update_priorities(self, indices: np.ndarray, priorities: np.ndarray, progress_remaining: float) -> None:
def update_priorities(self, leaf_nodes_indices: np.ndarray, priorities: np.ndarray) -> None:
"""
Update priorities of sampled transitions.
:param leaf_nodes_indices: Indices of the sampled transitions.
:param td_errors: New priorities, td error in the case of
:param priorities: New priorities, td error in the case of
proportional prioritized replay buffer.
"""
# TODO: move beta to the DQN algorithm
# Update beta schedule
self._current_progress_remaining = progress_remaining

# TODO: double check that all samples are updated
# priorities = np.abs(td_errors) + self.min_priority
priorities += self._min_priority
# assert len(indices) == len(priorities)
assert np.min(priorities) > 0
assert np.min(indices) >= 0
assert np.max(indices) < self.buffer_size
assert np.min(leaf_nodes_indices) >= 0
assert np.max(leaf_nodes_indices) < self.buffer_size
# TODO: check if we need to add the min_priority here
# priorities = (np.abs(td_errors) + self.min_priority) ** self.alpha
self._sum_tree[indices] = priorities**self._alpha
self._min_tree[indices] = priorities**self._alpha
self._sum_tree[leaf_nodes_indices] = priorities**self._alpha
self._min_tree[leaf_nodes_indices] = priorities**self._alpha
# Update max priority for new samples
self._max_priority = max(self._max_priority, np.max(priorities))
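A hedged usage sketch of how sample() and update_priorities() fit together from the algorithm side (the real call site is in per_dqn.py below; the replay_buffer instance and the per-sample TD errors are assumed):

# Sketch: sample with the current beta, compute new priorities (|TD error|),
# then write them back at the leaf node indices returned by the sampler.
data = replay_buffer.sample(batch_size=256, beta=0.4)
new_priorities = np.abs(td_errors)  # td_errors: hypothetical per-sample TD errors
replay_buffer.update_priorities(data.leaf_nodes_indices, new_priorities)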
36 changes: 26 additions & 10 deletions sbx/per_dqn/per_dqn.py
@@ -5,6 +5,7 @@
import numpy as np
import optax
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
from stable_baselines3.common.utils import get_linear_fn

from sbx.common.prioritized_replay_buffer import PrioritizedReplayBuffer
from sbx.common.type_aliases import ReplayBufferSamplesNp, RLTrainState
@@ -39,6 +40,8 @@ def __init__(
exploration_fraction: float = 0.1,
exploration_initial_eps: float = 1.0,
exploration_final_eps: float = 0.05,
initial_beta: float = 0.4,
final_beta: float = 1.0,
optimize_memory_usage: bool = False, # Note: unused but to match SB3 API
# max_grad_norm: float = 10,
train_freq: Union[int, Tuple[int, str]] = 4,
@@ -77,6 +80,19 @@ def __init__(
_init_setup_model=_init_setup_model,
)

self._inital_beta = initial_beta
self._final_beta = final_beta
self.beta_schedule = get_linear_fn(
self._inital_beta,
self._final_beta,
end_fraction=1.0,
)

@property
def beta(self) -> float:
# Linear schedule
return self.beta_schedule(self._current_progress_remaining)
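get_linear_fn from SB3 returns a function of progress_remaining (1.0 at the start of training, 0.0 at the end), so with end_fraction=1.0 the property above moves linearly from initial_beta to final_beta over the whole run. A quick illustration with the default values:

from stable_baselines3.common.utils import get_linear_fn

beta_schedule = get_linear_fn(0.4, 1.0, end_fraction=1.0)
print(beta_schedule(1.0))  # 0.4, start of training
print(beta_schedule(0.5))  # 0.7, halfway through
print(beta_schedule(0.0))  # 1.0, end of training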

def learn(
self,
total_timesteps: int,
@@ -97,7 +113,7 @@ def learn(

def train(self, batch_size: int, gradient_steps: int) -> None:
# Sample all at once for efficiency (so we can jit the for loop)
data = self.replay_buffer.sample(batch_size * gradient_steps, env=self._vec_normalize_env)
data = self.replay_buffer.sample(batch_size * gradient_steps, self.beta, env=self._vec_normalize_env)
# Convert to numpy
data = ReplayBufferSamplesNp(
data.observations.numpy(),
@@ -121,7 +137,7 @@ def train(self, batch_size: int, gradient_steps: int) -> None:
"info": {
"critic_loss": jnp.array([0.0]),
"qf_mean_value": jnp.array([0.0]),
"td_error": jnp.zeros_like(data.rewards),
"priorities": jnp.zeros_like(data.rewards),
},
}

@@ -137,12 +153,12 @@ def train(self, batch_size: int, gradient_steps: int) -> None:
self.policy.qf_state = update_carry["qf_state"]
qf_loss_value = update_carry["info"]["critic_loss"]
qf_mean_value = update_carry["info"]["qf_mean_value"] / gradient_steps
td_error = update_carry["info"]["td_error"]
priorities = update_carry["info"]["priorities"]

# Update priorities, they will be proportional to the td error
# Note: compared to the original implementation, we update
# the priorities after all the gradient steps
self.replay_buffer.update_priorities(data.leaf_nodes_indices, td_error, self._current_progress_remaining)
self.replay_buffer.update_priorities(data.leaf_nodes_indices, priorities)

self._n_updates += gradient_steps
self.logger.record("train/n_updates", self._n_updates, exclude="tensorboard")
Expand Down Expand Up @@ -179,24 +195,24 @@ def weighted_huber_loss(params):
# Retrieve the q-values for the actions from the replay buffer
current_q_values = jnp.take_along_axis(current_q_values, replay_actions, axis=1)
# TD error in absolute value, to update priorities
td_error = jnp.abs(current_q_values - target_q_values)
priorities = jnp.abs(current_q_values - target_q_values)
# Weighted Huber loss using importance sampling weights
loss = (sampling_weights * optax.huber_loss(current_q_values, target_q_values)).mean()
return loss, (current_q_values.mean(), td_error.flatten())
return loss, (current_q_values.mean(), priorities.flatten())

(qf_loss_value, (qf_mean_value, td_error)), grads = jax.value_and_grad(weighted_huber_loss, has_aux=True)(
(qf_loss_value, (qf_mean_value, priorities)), grads = jax.value_and_grad(weighted_huber_loss, has_aux=True)(
qf_state.params
)
qf_state = qf_state.apply_gradients(grads=grads)

return qf_state, (qf_loss_value, qf_mean_value, td_error)
return qf_state, (qf_loss_value, qf_mean_value, priorities)
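The loss and priority computation above can be checked in isolation with toy tensors (illustration only; the Huber delta is optax's default of 1.0):

import jax.numpy as jnp
import optax

q_values = jnp.array([1.0, 2.0, 3.0])
targets = jnp.array([1.5, 1.0, 3.0])
is_weights = jnp.array([1.0, 0.5, 0.25])           # importance-sampling weights
loss = (is_weights * optax.huber_loss(q_values, targets)).mean()
priorities = jnp.abs(q_values - targets)            # written back to the buffer later
print(loss, priorities)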

@staticmethod
@jax.jit
def _train(carry, indices):
data = carry["data"]

qf_state, (qf_loss_value, qf_mean_value, td_error) = PERDQN.update_qnetwork(
qf_state, (qf_loss_value, qf_mean_value, priorities) = PERDQN.update_qnetwork(
carry["gamma"],
carry["qf_state"],
observations=data.observations[indices],
@@ -210,6 +226,6 @@ def _train(carry, indices):
carry["qf_state"] = qf_state
carry["info"]["critic_loss"] += qf_loss_value
carry["info"]["qf_mean_value"] += qf_mean_value
carry["info"]["td_error"] = carry["info"]["td_error"].at[indices].set(td_error)
carry["info"]["priorities"] = carry["info"]["priorities"].at[indices].set(priorities)

return carry, None
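_train has the (carry, x) -> (carry, y) shape expected by jax.lax.scan, which is how such a body is typically driven over the pre-sampled batches of indices. A self-contained toy with the same shape (the driving call in train() is not shown in this hunk):

import jax
import jax.numpy as jnp

def body(carry, indices):
    carry = carry + indices.sum()   # stands in for one gradient step on one batch
    return carry, None

batches = jnp.arange(12, dtype=jnp.float32).reshape(3, 4)  # 3 steps, batch size 4
final_carry, _ = jax.lax.scan(body, jnp.float32(0.0), batches)
print(final_carry)                  # 66.0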
