Skip to content

Commit

Permalink
Merge pull request #184 from reginald-mclean/API_Update
Browse files: browse the repository at this point in the history
Updating Generic Wrappers to the new Gym API
  • Loading branch information
jjshoots authored Sep 19, 2022
2 parents f885b9e + 5fae23f commit 0635181
Showing 1 changed file with 38 additions and 16 deletions.
54 changes: 38 additions & 16 deletions supersuit/generic_wrappers/frame_skip.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -16,12 +16,12 @@ def step(self, action):
total_reward = 0.0

for x in range(num_skips):
obs, rew, done, info = super().step(action)
obs, rew, term, trunc, info = super().step(action)
total_reward += rew
if done:
if term or trunc:
break

return obs, total_reward, done, info
return obs, total_reward, term, trunc, info


class StepAltWrapper(BaseWrapper):
Expand All @@ -45,7 +45,8 @@ def __init__(self, env, num_frames):
def reset(self, seed=None, return_info=False, options=None):
super().reset(seed=seed, options=options)
self.agents = self.env.agents[:]
self.dones = make_defaultdict({agent: False for agent in self.agents})
self.terminations = make_defaultdict({agent: False for agent in self.agents})
self.truncations = make_defaultdict({agent: False for agent in self.agents})
self.rewards = make_defaultdict({agent: 0.0 for agent in self.agents})
self._cumulative_rewards = make_defaultdict(
{agent: 0.0 for agent in self.agents}
Expand All @@ -63,7 +64,10 @@ def observe(self, agent):

def step(self, action):
self._has_updated = True
if self.dones[self.agent_selection]:
if (
self.terminations[self.agent_selection]
or self.truncations[self.agent_selection]
):
if self.env.agents and self.agent_selection == self.env.agent_selection:
self.env.step(None)
self._was_done_step(action)
Expand All @@ -75,20 +79,28 @@ def step(self, action):
self.old_actions[cur_agent] = action
while self.old_actions[self.env.agent_selection] is not None:
step_agent = self.env.agent_selection
if step_agent in self.env.dones:
if step_agent in self.env.terminations or self.env.trunctations:
# reward = self.env.rewards[step_agent]
# done = self.env.dones[step_agent]
# info = self.env.infos[step_agent]
observe, reward, done, info = self.env.last(observe=False)
observe, reward, term, trunc, info = self.env.last(observe=False)
action = self.old_actions[step_agent]
self.env.step(action)

for agent in self.env.agents:
self.rewards[agent] += self.env.rewards[agent]
self.infos[self.env.agent_selection] = info
while self.env.agents and self.env.dones[self.env.agent_selection]:
while self.env.agents and (
self.env.trunctations[self.env.agent_selection]
or self.env.trunctations[self.env.agent_selection]
):
done_agent = self.env.agent_selection
self.dones[done_agent] = True
self.truncations[done_agent] = self.env.trunctations[
self.env.trunctations
]
self.terminations[done_agent] = self.env.terminations[
self.env.terminations
]
self._final_observations[done_agent] = self.env.observe(done_agent)
self.env.step(None)
step_agent = self.env.agent_selection
Expand All @@ -99,7 +111,8 @@ def step(self, action):

my_agent_set = set(self.agents)
for agent in self.env.agents:
self.dones[agent] = self.env.dones[agent]
self.terminations[agent] = self.env.terminations[agent]
self.truncations[agent] = self.env.trunctations[agent]
self.infos[agent] = self.env.infos[agent]
if agent not in my_agent_set:
self.agents.append(agent)
Expand All @@ -122,16 +135,18 @@ def step(self, action):
orig_agents = set(action.keys())

total_reward = make_defaultdict({agent: 0.0 for agent in self.agents})
total_dones = {}
total_terminations = {}
total_truncations = {}
total_infos = {}
total_obs = {}

for x in range(num_skips):
obs, rews, done, info = super().step(action)
obs, rews, term, trunc, info = super().step(action)

for agent, rew in rews.items():
total_reward[agent] += rew
total_dones[agent] = done[agent]
total_truncations[agent] = trunc[agent]
total_terminations[agent] = term[agent]
total_infos[agent] = info[agent]
total_obs[agent] = obs[agent]

Expand All @@ -142,7 +157,7 @@ def step(self, action):
), "parallel environments that use frame_skip_v0 must provide a `default_action` argument for steps between an agent being generated and an agent taking its first step"
action[agent] = self.default_action

if all(done.values()):
if all(term.values()) or all(trunc.values()):
break

# delete any values created by agents which were
Expand All @@ -151,12 +166,19 @@ def step(self, action):
for agent in list(total_reward):
if agent not in final_agents and agent not in orig_agents:
del total_reward[agent]
del total_dones[agent]
del total_truncations[agent]
del total_terminations[agent]
del total_infos[agent]
del total_obs[agent]

self.agents = self.env.agents[:]
return total_obs, total_reward, total_dones, total_infos
return (
total_obs,
total_reward,
total_terminations,
total_truncations,
total_infos,
)


frame_skip_v0 = WrapperChooser(
Expand Down

0 comments on commit 0635181

Please sign in to comment.