PPO doesn't work with MultiDiscrete observation space #1836
Comments
The simplest way around this is to flatten the observation space:

from gymnasium.wrappers import FlattenObservation

env = FlattenObservation(CustomEnv())
Thank you very much for your answer.

env1 = CustomEnv()
env1.observation_space.shape
env1.observation_space.sample()
env2 = FlattenObservation(CustomEnv())
env2.observation_space.shape
env2.observation_space.sample()

The two shapes and the sampled observations differ: env1 has a shape of (2, 2), while env2 has a shape of (10,). A question naturally arises: does an algorithm's performance depend on how I represent the observation (in this case, flattened or not)?
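Where (10,) comes from: for a MultiDiscrete space, Gymnasium's flattening one-hot encodes every element and concatenates the results, so the flat length is the sum of nvec rather than the number of elements. For the docs example nvec = [[1, 2], [3, 4]] that gives 1 + 2 + 3 + 4 = 10. The numpy-only sketch below mimics that behavior next to a plain reshape; the helper names are hypothetical and this is an illustration, not Gymnasium's own code.

```python
import numpy as np

# nvec from the Gymnasium MultiDiscrete docs example: a (2, 2) array of bounds
nvec = np.array([[1, 2], [3, 4]])


def one_hot_flatten(obs, nvec):
    """Hypothetical helper mimicking FlattenObservation on a MultiDiscrete space:
    one-hot encode each element (value < its nvec bound) and concatenate."""
    parts = []
    for value, n in zip(obs.ravel(), nvec.ravel()):
        vec = np.zeros(n, dtype=nvec.dtype)
        vec[value] = 1
        parts.append(vec)
    return np.concatenate(parts)


def reshape_flatten(obs):
    """Plain row-major flattening: keeps one integer per element."""
    return obs.flatten()


obs = np.array([[0, 1], [2, 3]])  # a valid sample: each value is below its bound
print(one_hot_flatten(obs, nvec).shape)  # (10,)
print(reshape_flatten(obs).shape)        # (4,)
```

The one-hot form turns each discrete value into a larger binary input, while the reshape keeps the observation compact as integers; which one suits a given policy network is a separate question this sketch does not answer.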
Indeed, it's different from what I expected too. Flattening seems to work in a very counter-intuitive way for MultiDiscrete spaces (at least for me). As far as I can see, no existing wrapper does this, so you'll have to write your own:

from gymnasium import ObservationWrapper
from gymnasium.spaces import MultiDiscrete


class FlattenMultiDiscrete(ObservationWrapper):
    def __init__(self, env):
        super().__init__(env)
        # Flatten the bounds in row-major order so they match observation.flatten()
        self.observation_space = MultiDiscrete(env.observation_space.nvec.flatten())

    def observation(self, observation):
        # Same row-major order as nvec.flatten(), so each value keeps its own bound
        return observation.flatten()


env = FlattenMultiDiscrete(CustomEnv())
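Why this wrapper is safe: nvec.flatten() and observation.flatten() both use numpy's default row-major order, so element i of the flat observation is still bounded by element i of the flat nvec, and the original array can be recovered with a reshape. A small numpy-only check of that reasoning, assuming the docs example nvec = [[1, 2], [3, 4]]:

```python
import numpy as np

nvec = np.array([[1, 2], [3, 4]])
rng = np.random.default_rng(0)

# Draw a random observation that respects each element's bound (0 <= x < n)
obs = (rng.random(nvec.shape) * nvec).astype(np.int64)

flat_nvec = nvec.flatten()
flat_obs = obs.flatten()

# Element-wise bounds still line up after flattening both arrays the same way
assert np.all((0 <= flat_obs) & (flat_obs < flat_nvec))

# The flattening is lossless: reshape recovers the original observation
restored = flat_obs.reshape(nvec.shape)
assert np.array_equal(restored, obs)
print(flat_nvec)  # [1 2 3 4]
```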
Note: the env checker must be updated to warn users that multi-dimensional MultiDiscrete spaces are not supported, and to propose a fix (the one from @qgallouedec).
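A minimal sketch of the kind of warning the note asks for, assuming the check only needs the space's nvec array. The function name and message are hypothetical; this is not the env checker's actual code.

```python
import warnings

import numpy as np


def warn_if_multidim_multidiscrete(nvec):
    """Hypothetical check: warn when a MultiDiscrete nvec has more than one dimension.

    Returns True if a warning was emitted, False otherwise.
    """
    nvec = np.asarray(nvec)
    if nvec.ndim > 1:
        warnings.warn(
            f"MultiDiscrete observation space with nvec of shape {nvec.shape} is "
            "multi-dimensional, which is not supported. Consider flattening it, "
            "e.g. MultiDiscrete(nvec.flatten()) with a matching ObservationWrapper.",
            UserWarning,
        )
        return True
    return False


warn_if_multidim_multidiscrete([5, 2, 2])         # 1D: no warning
warn_if_multidim_multidiscrete([[1, 2], [3, 4]])  # 2D: emits a UserWarning
```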
Thank you very much for the answer.
Please leave it open until the env checker is updated :)
🐛 Bug
I am implementing a simple custom environment to use PPO with a MultiDiscrete observation space.
It works if I use MultiDiscrete([5, 2, 2]), but it fails when nvec becomes a multi-dimensional array. In the attached code I use the MultiDiscrete observation space given as an example in https://gymnasium.farama.org/api/spaces/fundamental/#gymnasium.spaces.MultiDiscrete .
Code example
Relevant log output / Error message
System Info
Checklist