Add support for MultiDiscrete/MultiBinary action spaces #5

Closed
araffin opened this issue May 9, 2020 · 7 comments · Fixed by #13

@araffin
Member

araffin commented May 9, 2020

  • distributions.py needs to be updated (and maybe ppo/a2c) with MultiCategorical and Bernoulli distributions

  • the envs from identity_env.py should help to create tests
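
For the MultiBinary case, a Bernoulli distribution treats each bit as an independent coin flip driven by one logit. The following is a dependency-free sketch of the arithmetic only, not the library's implementation; the function names `softplus`, `bernoulli_log_prob` and `bernoulli_mode` are hypothetical and chosen for illustration.

```python
import math
from typing import List


def softplus(x: float) -> float:
    """Numerically stable log(1 + exp(x))."""
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def bernoulli_log_prob(logits: List[float], actions: List[int]) -> float:
    """Joint log-probability of a binary action vector.

    Bits are assumed independent, so the per-bit log-probabilities add up.
    log(sigmoid(l)) = -softplus(-l) and log(1 - sigmoid(l)) = -softplus(l).
    """
    return sum(
        -softplus(-logit) if action == 1 else -softplus(logit)
        for logit, action in zip(logits, actions)
    )


def bernoulli_mode(logits: List[float]) -> List[int]:
    """Deterministic action: choose 1 wherever p > 0.5, i.e. logit > 0."""
    return [1 if logit > 0 else 0 for logit in logits]
```

With all logits at zero every bit has probability 0.5, so a three-bit action has log-probability 3 * log(0.5).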

@rolandgvc is working on it

@araffin araffin added the enhancement New feature or request label May 9, 2020
@araffin araffin added this to the v1.0 milestone May 9, 2020
@araffin
Member Author

araffin commented May 9, 2020

Flattened seems the easiest and cleanest way, no?

@Miffyli
Collaborator

Miffyli commented May 9, 2020

Hmm, where did the question disappear to? o: I would like to comment on this after taking a good look at how it could be done. A multi-discrete distribution returning flattened output seems a bit convoluted at first sight, but I will comment better once there is an example code suggestion.
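
To make the "flattened" idea concrete: the policy head would emit a single vector whose length is the sum of the sub-action sizes, and that vector is then cut into one block of logits per action dimension. A minimal pure-Python sketch, assuming this layout; `split_flat_logits` is a hypothetical helper name, not something from the thread.

```python
from typing import List, Sequence


def split_flat_logits(
    flat_logits: Sequence[float], action_dims: Sequence[int]
) -> List[List[float]]:
    """Cut a flat logit vector into one chunk per discrete sub-action."""
    if len(flat_logits) != sum(action_dims):
        raise ValueError("flat_logits must have sum(action_dims) entries")
    chunks: List[List[float]] = []
    start = 0
    for dim in action_dims:
        chunks.append(list(flat_logits[start:start + dim]))
        start += dim
    return chunks


# A space like MultiDiscrete([2, 3]) needs 2 + 3 = 5 logits in total.
per_dim = split_flat_logits([0.1, 0.9, -1.0, 0.0, 1.0], [2, 3])
```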

@rolandgvc
Contributor

rolandgvc commented May 9, 2020

@Miffyli I deleted it before @araffin answered, as I thought it was obvious =) Here is the suggestion:

from typing import List

import torch as th
from torch import nn
from torch.distributions import Categorical

# `Distribution` is the base class already defined in distributions.py.


class MultiCategoricalDistribution(Distribution):
    """
    MultiCategorical distribution for multi discrete actions.

    :param action_dims: ([int]) List of sizes of discrete action spaces.
    """
    def __init__(self, action_dims: List[int]):
        super(MultiCategoricalDistribution, self).__init__()
        self.action_dims = action_dims
        self.distributions = None

    def proba_distribution_net(self, latent_dim: int) -> nn.Module:
        """
        Create the layer that represents the distribution:
        it will be the flattened logits of the Categorical distributions
        (one block of logits per action dimension).
        You can then get probabilities using a softmax on each block.

        :param latent_dim: (int) Dimension of the last layer
            of the policy network (before the action layer)
        :return: (nn.Linear)
        """
        # int(...) so that the layer size is a plain Python int
        action_logits = nn.Linear(latent_dim, int(sum(self.action_dims)))
        return action_logits

    def proba_distribution(self, action_logits: th.Tensor) -> 'MultiCategoricalDistribution':
        # Split along the feature axis (dim=1); dim=0 is the batch axis
        reshaped_logits = th.split(action_logits, self.action_dims, dim=1)
        self.distributions = [Categorical(logits=logits) for logits in reshaped_logits]
        return self

@rolandgvc
Contributor

rolandgvc commented May 9, 2020

Progress so far (tested):

from typing import List, Tuple

import torch as th
from torch import nn
from torch.distributions import Categorical

# `Distribution` is the base class already defined in distributions.py.


class MultiCategoricalDistribution(Distribution):
    """
    MultiCategorical distribution for multi discrete actions.

    :param action_dims: ([int]) List of sizes of discrete action spaces.
    """
    def __init__(self, action_dims: List[int]):
        super(MultiCategoricalDistribution, self).__init__()
        self.action_dims = action_dims
        self.distributions = None

    def proba_distribution_net(self, latent_dim: int) -> nn.Module:
        """
        Create the layer that represents the distribution:
        it will be the flattened logits of the Categorical distributions
        (one block of logits per action dimension).
        You can then get probabilities using a softmax on each block.

        :param latent_dim: (int) Dimension of the last layer
            of the policy network (before the action layer)
        :return: (nn.Linear)
        """
        # int(...) so that the layer size is a plain Python int
        action_logits = nn.Linear(latent_dim, int(sum(self.action_dims)))
        return action_logits

    def proba_distribution(self, action_logits: th.Tensor) -> 'MultiCategoricalDistribution':
        # Split along the feature axis (dim=1); dim=0 is the batch axis
        reshaped_logits = th.split(action_logits, self.action_dims, dim=1)
        self.distributions = [Categorical(logits=logits) for logits in reshaped_logits]
        return self

    def mode(self) -> th.Tensor:
        # Stack on dim=1 so actions have shape (batch_size, n_action_dims)
        return th.stack([th.argmax(d.probs, dim=1) for d in self.distributions], dim=1)

    def sample(self) -> th.Tensor:
        return th.stack([d.sample() for d in self.distributions], dim=1)

    def entropy(self) -> th.Tensor:
        # Sub-actions are independent, so the entropies add up
        return th.stack([d.entropy() for d in self.distributions], dim=1).sum(dim=1)

    def actions_from_params(self, action_logits: th.Tensor,
                            deterministic: bool = False) -> th.Tensor:
        # Update the proba distribution
        self.proba_distribution(action_logits)
        return self.get_actions(deterministic=deterministic)

    def log_prob_from_params(self, action_logits: th.Tensor) -> Tuple[th.Tensor, th.Tensor]:
        actions = self.actions_from_params(action_logits)
        log_prob = self.log_prob(actions)
        return actions, log_prob

    def log_prob(self, actions: th.Tensor) -> th.Tensor:
        # Unbind on dim=1 to match the (batch_size, n_action_dims) layout above
        return th.stack([d.log_prob(x) for d, x in
                         zip(self.distributions, th.unbind(actions, dim=1))], dim=1).sum(dim=1)

Let me know if you have any design suggestions @Miffyli @araffin
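
The `log_prob` and `entropy` methods above rely on the sub-actions being independent, so both quantities are plain sums over the per-dimension categoricals. Below is a small torch-free sketch of that arithmetic for a single sample, assuming the logits have already been split per dimension; `log_softmax`, `joint_log_prob` and `joint_entropy` are hypothetical names used only for illustration.

```python
import math
from typing import List, Sequence


def log_softmax(logits: Sequence[float]) -> List[float]:
    """Log-probabilities of one categorical, computed stably."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_z for x in logits]


def joint_log_prob(
    per_dim_logits: Sequence[Sequence[float]], actions: Sequence[int]
) -> float:
    """Sum of the log-probabilities of each chosen sub-action."""
    return sum(
        log_softmax(logits)[action]
        for logits, action in zip(per_dim_logits, actions)
    )


def joint_entropy(per_dim_logits: Sequence[Sequence[float]]) -> float:
    """Sum of the entropies of the independent categoricals."""
    total = 0.0
    for logits in per_dim_logits:
        log_p = log_softmax(logits)
        total -= sum(math.exp(lp) * lp for lp in log_p)
    return total
```

For uniform logits over sub-actions of sizes 2 and 3, any joint action has probability 1/6, so the log-probability is -log(6) and the entropy is log(6).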

@araffin
Member Author

araffin commented May 9, 2020

I think it would be better if you open a draft pull request ;)

@rolandgvc
Contributor

rolandgvc commented May 10, 2020

@araffin can we just use the .shape attribute for multi spaces here?

def get_obs_shape(observation_space: spaces.Space) -> Tuple[int, ...]:
    """
    Get the shape of the observation (useful for the buffers).

    :param observation_space: (spaces.Space)
    :return: (Tuple[int, ...])
    """
    if isinstance(observation_space, spaces.Box):
        return observation_space.shape
    elif isinstance(observation_space, spaces.Discrete):
        # Observation is an int, so the shape is a 1-tuple
        return (1,)
    elif isinstance(observation_space, spaces.MultiDiscrete):
        return observation_space.shape
    elif isinstance(observation_space, spaces.MultiBinary):
        return observation_space.shape
    else:
        raise NotImplementedError()

@araffin
Member Author

araffin commented May 10, 2020

@rolandgvc this is for another issue (this one: #4), no?

Looking at the source, this won't work: obs.shape is not defined for MultiBinary, and the same goes for MultiDiscrete.

As Gym is not documented, I really recommend reading the source code.
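
One way to avoid depending on `.shape` is to derive the shape from the attributes these spaces do expose in the Gym source: `nvec` for MultiDiscrete and `n` for MultiBinary. The sketch below illustrates that idea under those assumptions. To stay runnable without Gym installed, it uses two hypothetical stand-in classes (`FakeMultiDiscrete`, `FakeMultiBinary`) that mimic only those attributes, and `multi_space_shape` is likewise a made-up helper name.

```python
from dataclasses import dataclass
from typing import List, Tuple


@dataclass
class FakeMultiDiscrete:
    """Stand-in for spaces.MultiDiscrete: one category count per sub-action."""
    nvec: List[int]


@dataclass
class FakeMultiBinary:
    """Stand-in for spaces.MultiBinary: number of binary entries."""
    n: int


def multi_space_shape(space) -> Tuple[int, ...]:
    """Shape of one sample, derived without relying on `.shape`."""
    if isinstance(space, FakeMultiDiscrete):
        # One integer per sub-action
        return (len(space.nvec),)
    if isinstance(space, FakeMultiBinary):
        # One bit per binary entry
        return (int(space.n),)
    raise NotImplementedError(f"Unsupported space: {space!r}")
```

With real Gym spaces, the `isinstance` checks would target `spaces.MultiDiscrete` and `spaces.MultiBinary` instead of the stand-ins.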

Shunian-Chen pushed a commit to Shunian-Chen/AIPI530 that referenced this issue Nov 14, 2021
Off-Policy State Dependent Exploration
araffin added a commit that referenced this issue Feb 11, 2023
3 participants