[Feature Request] Temporal Convolutional network #1984

Open · 2 tasks done
tty666 opened this issue Jul 30, 2024 · 1 comment
Labels
enhancement New feature or request

tty666 commented Jul 30, 2024

🚀 Feature

Hello guys,
After watching this video: https://www.youtube.com/watch?v=WoLlZLdoEQk
I had the idea to extend NatureCNN into a NatureTCN1D, like this:

# Imports assumed for this snippet (SB3 2.x with gymnasium)
from typing import Any, Dict, List, Optional, Type, Union

import gymnasium as gym
import torch as th
from gymnasium import spaces
from torch import nn

from stable_baselines3.common.policies import ActorCriticPolicy
from stable_baselines3.common.preprocessing import get_flattened_obs_dim
from stable_baselines3.common.torch_layers import BaseFeaturesExtractor
from stable_baselines3.common.type_aliases import Schedule


class Chomp1d(nn.Module):
    """Crop the trailing padding so each convolution stays causal."""

    def __init__(self, chomp_size: int):
        super().__init__()
        self.chomp_size = chomp_size

    def forward(self, x: th.Tensor) -> th.Tensor:
        return x[:, :, :-self.chomp_size].contiguous()

class TemporalBlock(nn.Module):
    """Two dilated causal conv layers with a residual connection (standard TCN block)."""

    def __init__(self, n_inputs, n_outputs, kernel_size, stride, dilation, padding, activation_fn, dropout=0.2):
        super().__init__()
        self.conv1 = nn.Conv1d(n_inputs, n_outputs, kernel_size,
                                stride=stride, padding=padding, dilation=dilation)
        self.chomp1 = Chomp1d(padding)
        self.activation1 = activation_fn()
        self.dropout1 = nn.Dropout(dropout)

        self.conv2 = nn.Conv1d(n_outputs, n_outputs, kernel_size,
                                stride=stride, padding=padding, dilation=dilation)
        self.chomp2 = Chomp1d(padding)
        self.activation2 = activation_fn()
        self.dropout2 = nn.Dropout(dropout)

        self.net = nn.Sequential(self.conv1, self.chomp1, self.activation1, self.dropout1,
                                    self.conv2, self.chomp2, self.activation2, self.dropout2)
        self.downsample = nn.Conv1d(n_inputs, n_outputs, 1) if n_inputs != n_outputs else None
        self.activation = activation_fn()

    def forward(self, x):
        out = self.net(x)
        res = x if self.downsample is None else self.downsample(x)
        return self.activation(out + res)

class NatureTCN1D(BaseFeaturesExtractor):
    """TCN features extractor for Box observations of shape (length, channels).

    Keeps the first ``dataset_dim`` channels and permutes to channels-first
    before the dilated 1D convolutions.
    """

    def __init__(
        self,
        observation_space: gym.spaces.Box,
        features_dim: int = 256,
        dataset_dim: int = 21,
        activation_fn=nn.SiLU,
        dropout=0.2,
    ) -> None:
        super().__init__(observation_space, features_dim)
        self.dataset_dim = dataset_dim
        n_input_channels = self.dataset_dim
        
        self.tcn = nn.Sequential(
            TemporalBlock(n_input_channels, 32, kernel_size=5, stride=1, dilation=1, padding=(5-1) * 1, activation_fn=activation_fn, dropout=dropout),
            TemporalBlock(32, 64, kernel_size=7, stride=1, dilation=2, padding=(7-1) * 2, activation_fn=activation_fn, dropout=dropout),
            TemporalBlock(64, 128, kernel_size=3, stride=1, dilation=4, padding=(3-1) * 4, activation_fn=activation_fn, dropout=dropout),
            nn.Flatten(),
        )

        # Compute the flattened output size of the TCN with one forward pass
        with th.no_grad():
            sample_observation = th.as_tensor(observation_space.sample()[None, :, :self.dataset_dim]).float()
            sample_observation = sample_observation.permute(0, 2, 1)
            n_flatten = self.tcn(sample_observation).shape[1]

        self.linear = nn.Sequential(nn.Linear(n_flatten, features_dim), nn.SiLU())

    def forward(self, observations: th.Tensor) -> th.Tensor:
        observations = observations[:, :, :self.dataset_dim]
        observations = observations.permute(0, 2, 1)
        return self.linear(self.tcn(observations))

class CombinedExtractorT1D(BaseFeaturesExtractor):
    """
    Combined features extractor for Box observation spaces.
    The input is fed through two separate submodules (NatureTCN1D and Flatten)
    and the resulting features are concatenated.

    :param observation_space: The observation space
    :param cnn_output_dim: Number of features to output from the NatureTCN1D submodule. Defaults to
        256 to avoid exploding network sizes.
    """

    def __init__(
        self,
        observation_space: spaces.Box,
        cnn_output_dim: int = 256,
        dataset_dim: int = 5,
    ) -> None:
        assert isinstance(observation_space, spaces.Box), (
            "CombinedExtractorT1D must be used with a gym.spaces.Box "
            f"observation space, not {observation_space}"
        )
        super().__init__(observation_space, features_dim=1)
        # We assume (length, channels) inputs; NatureTCN1D permutes them to channels-first
        assert len(observation_space.shape) == 2, (
            "You should use CombinedExtractorT1D only with 2D inputs (length, channels)"
        )
        self.cnn_extractor = NatureTCN1D(observation_space, features_dim=cnn_output_dim, dataset_dim=dataset_dim)
        self.raw_extractor = nn.Flatten()

        cnn_output_size = cnn_output_dim
        raw_output_size = get_flattened_obs_dim(observation_space)

        # Update the features dim manually
        self._features_dim = cnn_output_size + raw_output_size

    def forward(self, observations: th.Tensor) -> th.Tensor:
        cnn_encoded = self.cnn_extractor(observations)
        raw_encoded = self.raw_extractor(observations)

        return th.cat([raw_encoded, cnn_encoded], dim=1)

class ActorCriticTCN1DPolicy(ActorCriticPolicy):
    """
    TCN policy class for actor-critic algorithms (has both policy and value prediction).
    Used by A2C, PPO and the likes.

    :param observation_space: Observation space
    :param action_space: Action space
    :param lr_schedule: Learning rate schedule (could be constant)
    :param net_arch: The specification of the policy and value networks.
    :param activation_fn: Activation function
    :param ortho_init: Whether to use or not orthogonal initialization
    :param use_sde: Whether to use State Dependent Exploration or not
    :param log_std_init: Initial value for the log standard deviation
    :param full_std: Whether to use (n_features x n_actions) parameters
        for the std instead of only (n_features,) when using gSDE
    :param use_expln: Use ``expln()`` function instead of ``exp()`` to ensure
        a positive standard deviation (cf paper). It allows to keep variance
        above zero and prevent it from growing too fast. In practice, ``exp()`` is usually enough.
    :param squash_output: Whether to squash the output using a tanh function,
        this allows to ensure boundaries when using gSDE.
    :param features_extractor_class: Features extractor to use.
    :param features_extractor_kwargs: Keyword arguments
        to pass to the features extractor.
    :param share_features_extractor: If True, the features extractor is shared between the policy and value networks.
    :param normalize_images: Whether to normalize images or not,
        dividing by 255.0 (True by default)
    :param optimizer_class: The optimizer to use,
        ``th.optim.Adam`` by default
    :param optimizer_kwargs: Additional keyword arguments,
        excluding the learning rate, to pass to the optimizer
    """

    def __init__(
        self,
        observation_space: spaces.Space,
        action_space: spaces.Space,
        lr_schedule: Schedule,
        net_arch: Optional[Union[List[int], Dict[str, List[int]]]] = None,
        activation_fn: Type[nn.Module] = nn.Tanh,
        ortho_init: bool = True,
        use_sde: bool = False,
        log_std_init: float = 0.0,
        full_std: bool = True,
        use_expln: bool = False,
        squash_output: bool = False,
        features_extractor_class: Type[BaseFeaturesExtractor] = NatureTCN1D,
        features_extractor_kwargs: Optional[Dict[str, Any]] = None,
        share_features_extractor: bool = True,
        normalize_images: bool = True,
        optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
        optimizer_kwargs: Optional[Dict[str, Any]] = None,
    ):
        # Optionally give optimizer_kwargs a default value when it is None, e.g.:
        # if optimizer_kwargs is None:
        #     optimizer_kwargs = {"eps": 1e-5}
        super().__init__(
            observation_space,
            action_space,
            lr_schedule,
            net_arch,
            activation_fn,
            ortho_init,
            use_sde,
            log_std_init,
            full_std,
            use_expln,
            squash_output,
            features_extractor_class,
            features_extractor_kwargs,
            share_features_extractor,
            normalize_images,
            optimizer_class,
            optimizer_kwargs,
        )

class MultiInputActorCriticPolicyT1D(ActorCriticPolicy):
    """
    Multi-input actor-critic policy class (has both policy and value prediction).
    Used by A2C, PPO and the likes.

    :param observation_space: Observation space
    :param action_space: Action space
    :param lr_schedule: Learning rate schedule (could be constant)
    :param net_arch: The specification of the policy and value networks.
    :param activation_fn: Activation function
    :param ortho_init: Whether to use or not orthogonal initialization
    :param use_sde: Whether to use State Dependent Exploration or not
    :param log_std_init: Initial value for the log standard deviation
    :param full_std: Whether to use (n_features x n_actions) parameters
        for the std instead of only (n_features,) when using gSDE
    :param use_expln: Use ``expln()`` function instead of ``exp()`` to ensure
        a positive standard deviation (cf paper). It allows to keep variance
        above zero and prevent it from growing too fast. In practice, ``exp()`` is usually enough.
    :param squash_output: Whether to squash the output using a tanh function,
        this allows to ensure boundaries when using gSDE.
    :param features_extractor_class: Features extractor to use (``CombinedExtractorT1D`` by default)
    :param features_extractor_kwargs: Keyword arguments
        to pass to the features extractor.
    :param share_features_extractor: If True, the features extractor is shared between the policy and value networks.
    :param normalize_images: Whether to normalize images or not,
        dividing by 255.0 (True by default)
    :param optimizer_class: The optimizer to use,
        ``th.optim.Adam`` by default
    :param optimizer_kwargs: Additional keyword arguments,
        excluding the learning rate, to pass to the optimizer
    """

    def __init__(
        self,
        observation_space: spaces.Box,
        action_space: spaces.Space,
        lr_schedule: Schedule,
        net_arch: Optional[Union[List[int], Dict[str, List[int]]]] = None,
        activation_fn: Type[nn.Module] = nn.Tanh,
        ortho_init: bool = True,
        use_sde: bool = False,
        log_std_init: float = 0.0,
        full_std: bool = True,
        use_expln: bool = False,
        squash_output: bool = False,
        features_extractor_class: Type[BaseFeaturesExtractor] = CombinedExtractorT1D,
        features_extractor_kwargs: Optional[Dict[str, Any]] = None,
        share_features_extractor: bool = True,
        normalize_images: bool = True,
        optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
        optimizer_kwargs: Optional[Dict[str, Any]] = None,
    ):
        super().__init__(
            observation_space,
            action_space,
            lr_schedule,
            net_arch,
            activation_fn,
            ortho_init,
            use_sde,
            log_std_init,
            full_std,
            use_expln,
            squash_output,
            features_extractor_class,
            features_extractor_kwargs,
            share_features_extractor,
            normalize_images,
            optimizer_class,
            optimizer_kwargs,
        )

It's a quick addition and I'm pretty sure I could work more on it...
But maybe it's a good addition and sometimes a replacement for LSTM/RNN?
I am using SiLU in my context, but as a more general default, ReLU could be used as the activation function.
What do you think? Should I propose it as a pull request for the contrib repo, or does it not make sense for you?
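
For context, here is roughly how I wire it into PPO via policy_kwargs (a minimal sketch: MyTimeSeriesEnv is a placeholder, and the dataset_dim value depends on your observation layout):

from stable_baselines3 import PPO

# Placeholder: any env whose Box observations have shape (length, channels)
env = MyTimeSeriesEnv()

policy_kwargs = dict(
    features_extractor_class=NatureTCN1D,
    features_extractor_kwargs=dict(features_dim=256, dataset_dim=21),
)

model = PPO("MlpPolicy", env, policy_kwargs=policy_kwargs, verbose=1)
model.learn(total_timesteps=100_000)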

Motivation

RecurrentPPO didn't fit my needs, so I did some research on other possibilities.

Pitch

No response

Alternatives

No response

Additional context

No response

Checklist

  • I have checked that there is no similar issue in the repo
  • If I'm requesting a new feature, I have proposed alternatives
tty666 added the enhancement (New feature or request) label on Jul 30, 2024
araffin (Member) commented Jul 31, 2024

> But maybe it's a good addition and sometimes a replacement for LSTM/RNN?

That would be more for SB3 contrib, I guess.
And without any benchmark, it's hard to say whether it's a good addition.
For instance, for recurrent PPO: https://wandb.ai/sb3/no-vel-envs/reports/PPO-vs-RecurrentPPO-aka-PPO-LSTM-on-environments-with-masked-velocity--VmlldzoxOTI4NjE4

(the gain is marginal with respect to frame stacking on several envs, but it is substantial on others like LunarLander without velocity)
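
(For anyone wanting to try that kind of baseline: frame stacking in SB3 is a one-liner. A minimal sketch, using plain LunarLander rather than the masked-velocity envs from the report:)

from stable_baselines3 import PPO
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.vec_env import VecFrameStack

# Stack the last 4 observations so the policy can infer velocities from differences
vec_env = make_vec_env("LunarLander-v2", n_envs=8)
vec_env = VecFrameStack(vec_env, n_stack=4)

model = PPO("MlpPolicy", vec_env, verbose=1)
model.learn(total_timesteps=100_000)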
