Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: autoreset wrappers #223

Merged
merged 5 commits into from
Mar 8, 2024
Merged

fix: autoreset wrappers #223

merged 5 commits into from
Mar 8, 2024

Conversation

sash-a
Copy link
Collaborator

@sash-a sash-a commented Feb 12, 2024

There was an issue with the autoreset wrappers: they never showed you the final timestep. This is an issue if the final timestep is a truncation (discount = 1, timestep.last() = true), then we'd want this observation in order to get the next value for our value target, but currently this would not be possible.

Gym fixes does it in the same way I am proposing as can be seen here they place the final observation in infos. What this PR does is always place either the current observation or the terminal observation in timestep.extras["real_next_obs"]

Here's how I'm currently using for SAC (and it's working well there):

    def step(
        action: Array, obs: Observation, env_state: State, buffer_state: BufferState
    ) -> Tuple[Array, State, BufferState, Dict]:
        """Given an action, step the environment and add to the buffer."""
        env_state, timestep = jax.vmap(env.step)(env_state, action)
        next_obs = timestep.observation
        rewards = timestep.reward
        terms = ~(timestep.discount).astype(bool)
        infos = timestep.extras

        real_next_obs = infos["real_next_obs"]

        transition = Transition(obs, action, rewards, terms, real_next_obs)
        buffer_state = rb.add(buffer_state, transition)

        return next_obs, env_state, buffer_state, infos["episode_metrics"]

@sash-a sash-a self-assigned this Feb 12, 2024
jumanji/wrappers.py Outdated Show resolved Hide resolved
@clement-bonnet
Copy link
Collaborator

Hi Sasha,
This issue seems a bit related to #106.
Returning the reset state/observation instead of the terminal state/observation when auto-resetting has always been the desired feature. This is because none of the Jumanji environments uses truncation, so one does not need the terminal state to train an actor-critic agent.
Now, if a user implements a new jumanji environments using the Environment abstraction and other tools from Jumanji, including truncation, one may want to use the truncated state/observation in their own training loop, which seems to be your use case, right? Passing it to the extras seems legit to me. 🙌

@sash-a
Copy link
Collaborator Author

sash-a commented Mar 5, 2024

Hi Sasha, This issue seems a bit related to #106. Returning the reset state/observation instead of the terminal state/observation when auto-resetting has always been the desired feature. This is because none of the Jumanji environments uses truncation, so one does not need the terminal state to train an actor-critic agent. Now, if a user implements a new jumanji environments using the Environment abstraction and other tools from Jumanji, including truncation, one may want to use the truncated state/observation in their own training loop, which seems to be your use case, right? Passing it to the extras seems legit to me. 🙌

Yup this is exactly the use case!

jumanji/wrappers.py Outdated Show resolved Hide resolved
jumanji/wrappers.py Outdated Show resolved Hide resolved
@sash-a sash-a merged commit ce8b873 into instadeepai:main Mar 8, 2024
3 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants