[Feature Request] Add a next_observations field to RolloutBufferSamples #1328

Closed
1 task done
euanong opened this issue Feb 12, 2023 · 1 comment · May be fixed by #1329
Labels
duplicate This issue or pull request already exists enhancement New feature or request

Comments

@euanong

euanong commented Feb 12, 2023

🚀 Feature

When sampling from a RolloutBuffer, we return RolloutBufferSamples containing tensors of observations, actions etc.

    def _get_samples(
        self,
        batch_inds: np.ndarray,
        env: Optional[VecNormalize] = None,
    ) -> RolloutBufferSamples:  # type: ignore[signature-mismatch] #FIXME
        data = (
            self.observations[batch_inds],

It would be nice if RolloutBufferSamples could also contain a batch of next observations (alongside a mask that, for each observation, tells us whether that observation has a successor).
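To make the request concrete, here is a minimal NumPy sketch of how next observations and the successor mask could be derived from per-step data. It is not the code from the PR. The helper name `compute_next_observations` is hypothetical, and it assumes observations are stored as `(n_steps, n_envs, *obs_shape)` alongside an `episode_starts` array of shape `(n_steps, n_envs)` that is nonzero wherever a step begins a new episode.

```python
import numpy as np


def compute_next_observations(observations, episode_starts):
    """Hypothetical helper (not part of Stable-Baselines3).

    observations:   array of shape (n_steps, n_envs, *obs_shape)
    episode_starts: array of shape (n_steps, n_envs), nonzero where the
                    observation at that step is the first of a new episode
    Returns (next_observations, has_next), where has_next is a boolean mask
    of shape (n_steps, n_envs).
    """
    next_observations = np.zeros_like(observations)
    has_next = np.zeros(episode_starts.shape, dtype=bool)

    # The successor of step t is step t + 1, so the final step in the
    # buffer never has a stored successor.
    next_observations[:-1] = observations[1:]

    # The successor is only valid if step t + 1 does not start a new episode.
    has_next[:-1] = ~episode_starts[1:].astype(bool)

    # Zero out entries without a valid successor so they cannot be mistaken
    # for real data.
    next_observations[~has_next] = 0
    return next_observations, has_next


# Usage: one environment, four steps, a new episode starting at step 2.
obs = np.array([10.0, 11.0, 12.0, 13.0]).reshape(4, 1, 1)
starts = np.array([1, 0, 1, 0]).reshape(4, 1)
next_obs, mask = compute_next_observations(obs, starts)
```

Both arrays could then be indexed with the same `batch_inds` as the observations, so each sampled observation keeps its successor and mask entry.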

Motivation

I'm implementing an RL pipeline in which I extend PPO with a custom loss. For this custom loss, I need access to (observation, next observation) pairs.

In the PPO implementation

    # train for n_epochs epochs
    for epoch in range(self.n_epochs):
        approx_kl_divs = []
        # Do a complete pass on the rollout buffer
        for rollout_data in self.rollout_buffer.get(self.batch_size):
            actions = rollout_data.actions

each batch of rollout data over which we compute the PPO loss is a RolloutBufferSamples instance. Because each batch is a random subset of the observations in the RolloutBuffer, it does not carry enough information to recover the next observation for each observation in the batch.

Pitch

I have already implemented this feature and submitted it as a PR [to be linked after submission].

Alternatives

Alternatively, we could return the indices of the sampled elements with respect to the original buffer. While this may allow for more general buffer manipulation, this feels less pleasant to use.
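For comparison, a rough sketch of what the index-based alternative would ask of the caller. The helper `next_observations_from_indices` is hypothetical. It assumes the buffer has been flattened env-major, so that flat index `i` maps to step `i % n_steps` of environment `i // n_steps`; that layout is an assumption of this sketch.

```python
import numpy as np


def next_observations_from_indices(flat_observations, flat_episode_starts,
                                   batch_inds, n_steps):
    """Hypothetical helper: look up successors from sampled flat indices.

    Assumes an env-major flattening, i.e. flat index i corresponds to
    step i % n_steps of environment i // n_steps.
    """
    step = batch_inds % n_steps
    # The last step of each environment has no successor in the buffer.
    in_range = step + 1 < n_steps
    # Use the index itself as a harmless placeholder when out of range.
    next_inds = np.where(in_range, batch_inds + 1, batch_inds)
    # A successor is valid only if it does not begin a new episode.
    has_next = in_range & ~flat_episode_starts[next_inds].astype(bool)
    return flat_observations[next_inds], has_next


# Usage: 2 environments with 3 steps each, flattened to 6 entries.
flat_obs = np.arange(6, dtype=float).reshape(6, 1)
flat_starts = np.array([1, 0, 0, 1, 0, 1])
batch_inds = np.array([0, 2, 4, 3])
next_obs, has_next = next_observations_from_indices(
    flat_obs, flat_starts, batch_inds, n_steps=3
)
```

The caller has to know the flattening layout and handle episode boundaries, which is the bookkeeping a built-in field would hide.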

Additional context

No response

Checklist

  • I have checked that there is no similar issue in the repo
@araffin
Member

araffin commented Feb 13, 2023

I have checked that there is no similar issue in the repo

Duplicate of #1273

@araffin closed this as not planned (duplicate) on Feb 13, 2023
@araffin added the duplicate label on Feb 13, 2023