
Custom actor and critic network #1985

Closed · 4 tasks done
krishdotn1 opened this issue Aug 2, 2024 · 4 comments
Labels: question (Further information is requested), RTFM (Answer is the documentation), check the checklist (You have checked the required items in the checklist but you didn't do what is written...)

Comments

@krishdotn1

❓ Question

Can anyone explain how I can replace the default actor and critic networks with my own network?
Here is what I have done, step by step:

  1. Created a custom network.
  2. Overrode the extractor builder with my custom network:

     def _build_mlp_extractor(self) -> None:
         self.mlp_extractor = CustomNetwork(self.features_dim)

Is this enough to run stable-baselines3 with its default setup?


krishdotn1 added the question (Further information is requested) label Aug 2, 2024
araffin added the RTFM (Answer is the documentation) and check the checklist labels Aug 2, 2024
@fracapuano (Contributor)

Hey @krishdotn1 👋

Check out the snippet below ⬇️ -- it directly contains the answer to your question 😊

from typing import Callable, Dict, List, Optional, Tuple, Type, Union

from gymnasium import spaces
import torch as th
from torch import nn

from stable_baselines3 import PPO
from stable_baselines3.common.policies import ActorCriticPolicy


class CustomNetwork(nn.Module):
    """
    Custom network for policy and value function.
    It receives as input the features extracted by the features extractor.

    :param feature_dim: dimension of the features extracted with the features_extractor (e.g. features from a CNN)
    :param last_layer_dim_pi: (int) number of units for the last layer of the policy network
    :param last_layer_dim_vf: (int) number of units for the last layer of the value network
    """

    def __init__(
        self,
        feature_dim: int,
        last_layer_dim_pi: int = 64,
        last_layer_dim_vf: int = 64,
    ):
        super().__init__()

        # IMPORTANT:
        # Save output dimensions, used to create the distributions
        self.latent_dim_pi = last_layer_dim_pi
        self.latent_dim_vf = last_layer_dim_vf

        # Policy network
        self.policy_net = nn.Sequential(
            nn.Linear(feature_dim, last_layer_dim_pi), nn.ReLU()
        )
        # Value network
        self.value_net = nn.Sequential(
            nn.Linear(feature_dim, last_layer_dim_vf), nn.ReLU()
        )

    def forward(self, features: th.Tensor) -> Tuple[th.Tensor, th.Tensor]:
        """
        :return: (th.Tensor, th.Tensor) latent_policy, latent_value of the specified network.
            If all layers are shared, then ``latent_policy == latent_value``
        """
        return self.forward_actor(features), self.forward_critic(features)

    def forward_actor(self, features: th.Tensor) -> th.Tensor:
        return self.policy_net(features)

    def forward_critic(self, features: th.Tensor) -> th.Tensor:
        return self.value_net(features)


class CustomActorCriticPolicy(ActorCriticPolicy):
    def __init__(
        self,
        observation_space: spaces.Space,
        action_space: spaces.Space,
        lr_schedule: Callable[[float], float],
        *args,
        **kwargs,
    ):
        # Disable orthogonal initialization
        kwargs["ortho_init"] = False
        super().__init__(
            observation_space,
            action_space,
            lr_schedule,
            # Pass remaining arguments to base class
            *args,
            **kwargs,
        )


    def _build_mlp_extractor(self) -> None:
        self.mlp_extractor = CustomNetwork(self.features_dim)


model = PPO(CustomActorCriticPolicy, "CartPole-v1", verbose=1)
model.learn(5000)

You should just subclass ActorCriticPolicy (or any other policy you might need, depending on your problem -- hard to tell without more context) and add your custom network as the mlp_extractor. Feel free to share more details if anything is unclear, happy to help further 🤗

@krishdotn1 (Author)

Thank you @fracapuano.
I have done the same thing and it works well, but when I use SubprocVecEnv it never starts training; it just shows "Using cuda device" and keeps loading.
P.S. I'm using an LNN model for both the actor and the critic.
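
A minimal sketch of this kind of setup, assuming the CustomActorCriticPolicy from the snippet above, the make_vec_env helper, and a CartPole-v1 placeholder env with four workers (the actual env and LNN modules are not shown):

from stable_baselines3 import PPO
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.vec_env import SubprocVecEnv

if __name__ == "__main__":
    # SubprocVecEnv starts worker processes; with the spawn/forkserver
    # start methods the training code must live under this __main__ guard,
    # otherwise multiprocessing can fail before training begins.
    vec_env = make_vec_env("CartPole-v1", n_envs=4, vec_env_cls=SubprocVecEnv)
    model = PPO(CustomActorCriticPolicy, vec_env, verbose=1)
    model.learn(5000)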

@fracapuano (Contributor) commented Aug 5, 2024

Great to hear that you are following the docs!
Would you be able to upload a minimal example that reproduces your issue here?

  • I have never used SubprocVecEnv with a custom model, but my understanding is that it should not change much (mind confirming, @araffin?).

Thank you!

@araffin (Member) commented Aug 13, 2024

@fracapuano thanks for helping out =)

I have never used SubprocVecEnv with a custom model, but my understanding is that it should not change much

Yes, SubprocVecEnv should not influence the result; it only spawns multiple processes to collect data in parallel.
If something goes wrong at that step, it's probably a limitation of the env.
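
A quick way to isolate the cause (a sketch, with CartPole-v1 standing in for the actual env id): train the built-in MlpPolicy on the same SubprocVecEnv setup. If that also hangs, the limitation is in the env, for example something in it that cannot be pickled and sent to the worker processes; if it trains, the custom policy is the more likely cause.

import gymnasium as gym
from stable_baselines3 import PPO
from stable_baselines3.common.vec_env import SubprocVecEnv

def make_env():
    # Each factory is pickled and called inside its own worker process,
    # so the env (and anything it closes over) must be picklable.
    return gym.make("CartPole-v1")

if __name__ == "__main__":
    vec_env = SubprocVecEnv([make_env for _ in range(4)])
    model = PPO("MlpPolicy", vec_env, verbose=1)
    model.learn(2000)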

araffin closed this as completed Aug 13, 2024