
Commit

Merge pull request werner-duvaud#29 from xuxiyang1993/master
IS weights for prioritized replay
werner-duvaud committed Mar 28, 2020
2 parents bdbf703 + 43688f7 commit 4a20f90
Showing 3 changed files with 22 additions and 15 deletions.
1 change: 1 addition & 0 deletions games/lunarlander.py
@@ -76,6 +76,7 @@ def __init__(self):
         # Prioritized Replay
         self.PER = True
         self.PER_alpha = 0.5
+        self.PER_beta = 1.0

         # Exponential learning rate schedule
         self.lr_init = 0.005  # Initial learning rate
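For context: PER_alpha is the prioritization exponent and the new PER_beta is the importance-sampling exponent from Schaul et al. (2016); beta = 1.0 fully corrects the bias introduced by non-uniform sampling. The sketch below shows only the conventional formulas these two hyperparameters come from; the helper names are illustrative and are not functions of this repository.

# Illustrative sketch only: how alpha and beta enter prioritized replay in
# general, not this repository's exact code path.
import numpy

def per_probabilities(td_errors, alpha=0.5, eps=1e-6):
    # Priorities grow with the error magnitude; alpha in [0, 1] interpolates
    # between uniform sampling (0) and fully greedy prioritization (1).
    priorities = (numpy.abs(td_errors) + eps) ** alpha
    return priorities / priorities.sum()

def is_weights(probabilities, beta=1.0):
    # Importance-sampling weights (N * P(i))^(-beta) correct the sampling bias;
    # dividing by the maximum keeps them in (0, 1].
    weights = (len(probabilities) * probabilities) ** (-beta)
    return weights / weights.max()

probs = per_probabilities(numpy.array([0.1, 0.5, 2.0, 0.05]))
weights = is_weights(probs, beta=1.0)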
20 changes: 13 additions & 7 deletions replay_buffer.py
@@ -37,15 +37,18 @@ def get_batch(self):
             [],
             []
         )
+        total_samples = sum((len(game_history.priorities) for game_history in self.buffer))
+        weight_batch = []
         for _ in range(self.config.batch_size):
-            game_index, game_history = self.sample_game(self.buffer)
-            game_pos = self.sample_position(game_history)
+            game_index, game_history, game_prob = self.sample_game(self.buffer)
+            game_pos, pos_prob = self.sample_position(game_history)
+            index_batch.append([game_index, game_pos])
+            weight_batch.append((total_samples * game_prob * pos_prob) ** (-self.config.PER_beta))

             values, rewards, policies, actions = self.make_target(
                 game_history, game_pos
             )

-            index_batch.append([game_index, game_pos])
             observation_batch.append(game_history.observation_history[game_pos])
             action_batch.append(actions)
             value_batch.append(values)
@@ -57,7 +60,8 @@ def get_batch(self):
         # value_batch: batch, num_unroll_steps+1
         # reward_batch: batch, num_unroll_steps+1
         # policy_batch: batch, num_unroll_steps+1, len(action_space)
-        return index_batch, (observation_batch, action_batch, value_batch, reward_batch, policy_batch)
+        weight_batch = numpy.array(weight_batch) / max(weight_batch)
+        return index_batch, (weight_batch, observation_batch, action_batch, value_batch, reward_batch, policy_batch)

     def sample_game(self, buffer):
         """
@@ -68,8 +72,9 @@
         game_probs = numpy.array(self.game_priorities) / sum(self.game_priorities)
         game_index_candidates = numpy.arange(0, len(self.buffer), dtype=int)
         game_index = numpy.random.choice(game_index_candidates, p=game_probs)
+        game_prob = game_probs[game_index]

-        return game_index, self.buffer[game_index]
+        return game_index, self.buffer[game_index], game_prob

     def sample_position(self, game_history):
         """
@@ -79,8 +84,9 @@
         position_probs = numpy.array(game_history.priorities) / sum(game_history.priorities)
         position_index_candidates = numpy.arange(0, len(position_probs), dtype=int)
         position_index = numpy.random.choice(position_index_candidates, p=position_probs)
+        position_prob = position_probs[position_index]

-        return position_index
+        return position_index, position_prob

     def update_priorities(self, priorities, index_info):

@@ -94,7 +100,7 @@ def update_priorities(self, priorities, index_info):
             numpy.put(self.buffer[game_index].priorities, range(start_index, end_index), priority)

             # update game priorities
-            self.game_priorities[game_index] = numpy.mean(self.buffer[game_index].priorities)
+            self.game_priorities[game_index] = numpy.max(self.buffer[game_index].priorities)  # option: mean, sum, max

         self.max_recorded_game_priority = numpy.max(self.game_priorities)

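For readers following the change: get_batch now records the probability of each draw at both levels (game, then position within the game), so the effective sampling probability of a position is game_prob * pos_prob, its importance-sampling weight is (total_samples * game_prob * pos_prob)^(-PER_beta), and the whole batch is divided by its largest weight before being returned. The standalone sketch below reproduces that computation under an assumed buffer layout (a list of game objects exposing a priorities array); the names mirror, but are not, the repository's ReplayBuffer internals.

# Standalone sketch of the two-level prioritized draw and its IS weight.
import numpy

def sample_one(buffer, game_priorities, per_beta):
    # Total number of stored positions across all games.
    total_samples = sum(len(game.priorities) for game in buffer)

    # Level 1: choose a game proportionally to its aggregated priority.
    game_probs = numpy.array(game_priorities) / sum(game_priorities)
    game_index = numpy.random.choice(len(buffer), p=game_probs)

    # Level 2: choose a position inside that game proportionally to its priority.
    pos_probs = numpy.array(buffer[game_index].priorities) / sum(buffer[game_index].priorities)
    game_pos = numpy.random.choice(len(pos_probs), p=pos_probs)

    # w = (N * P(i))^(-beta); get_batch later divides the batch by max(w).
    weight = (total_samples * game_probs[game_index] * pos_probs[game_pos]) ** (-per_beta)
    return game_index, game_pos, weight

Dividing the batch by its maximum weight only rescales the gradients, so it acts like an adjustment of the effective learning rate rather than changing the relative weighting of samples.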
16 changes: 8 additions & 8 deletions trainer.py
@@ -67,6 +67,7 @@ def update_weights(self, batch):
         """

         (
+            weight_batch,
             observation_batch,
             action_batch,
             target_value,
@@ -78,6 +79,7 @@ def update_weights(self, batch):
         priorities = numpy.zeros_like(target_value_scalar)

         device = next(self.model.parameters()).device
+        weight_batch = torch.tensor(weight_batch).float().to(device)
         observation_batch = torch.tensor(observation_batch).float().to(device)
         action_batch = torch.tensor(action_batch).float().to(device).unsqueeze(-1)
         target_value = torch.tensor(target_value).float().to(device)
@@ -119,6 +121,7 @@ def update_weights(self, batch):
             _,
             current_policy_loss,
         ) = self.loss_function(
+            weight_batch,
             value.squeeze(-1),
             reward.squeeze(-1),
             policy_logits,
@@ -137,6 +140,7 @@ def update_weights(self, batch):
                 current_reward_loss,
                 current_policy_loss,
             ) = self.loss_function(
+                weight_batch,
                 value.squeeze(-1),
                 reward.squeeze(-1),
                 policy_logits,
@@ -205,16 +209,12 @@ def scalar_to_support(x, support_size):

     @staticmethod
     def loss_function(
-        value, reward, policy_logits, target_value, target_reward, target_policy
+        weight_batch, value, reward, policy_logits, target_value, target_reward, target_policy
     ):
         # Cross-entropy seems to have a better convergence than MSE
-        value_loss = (-target_value * torch.nn.LogSoftmax(dim=1)(value)).sum(1).mean()
-        reward_loss = (
-            (-target_reward * torch.nn.LogSoftmax(dim=1)(reward)).sum(1).mean()
-        )
-        policy_loss = (
-            (-target_policy * torch.nn.LogSoftmax(dim=1)(policy_logits)).sum(1).mean()
-        )
+        value_loss = ((-target_value * torch.nn.LogSoftmax(dim=1)(value)).sum(1) * weight_batch).mean()
+        reward_loss = ((-target_reward * torch.nn.LogSoftmax(dim=1)(reward)).sum(1) * weight_batch).mean()
+        policy_loss = ((-target_policy * torch.nn.LogSoftmax(dim=1)(policy_logits)).sum(1) * weight_batch).mean()
         return value_loss, reward_loss, policy_loss

     @staticmethod
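The reworked loss_function applies the IS weight per sample: each cross-entropy term is summed over the support dimension, multiplied by its weight, and only then averaged over the batch. Below is a minimal, self-contained sketch of that pattern; the tensors and shapes are random stand-ins for the network outputs and targets, not values from the repository.

# Sketch of the weighted cross-entropy pattern used above; tensors are stand-ins.
import torch

batch_size, support_size = 4, 21
weight_batch = torch.rand(batch_size)                  # IS weights, already divided by their max
value = torch.randn(batch_size, support_size)          # network logits over the value support
target_value = torch.softmax(torch.randn(batch_size, support_size), dim=1)  # target distribution

per_sample_ce = (-target_value * torch.nn.LogSoftmax(dim=1)(value)).sum(1)  # shape: (batch_size,)
value_loss = (per_sample_ce * weight_batch).mean()     # weight each sample before averaging

With weight_batch set to all ones, this reduces to the previous unweighted .sum(1).mean() formulation.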
