Add support for pre and post linear modules in create_mlp (#1975)
* Add support for pre and post linear modules in `create_mlp`

* Disable mypy for python 3.8

* Reformat toml file

* Update docstring

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>

* Add some comments

---------

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
araffin and qgallouedec committed Jul 22, 2024
1 parent 1a69fc8 commit 000544c
Showing 6 changed files with 157 additions and 61 deletions.
72 changes: 37 additions & 35 deletions .github/workflows/ci.yml
@@ -5,9 +5,9 @@ name: CI
 
 on:
   push:
-    branches: [ master ]
+    branches: [master]
   pull_request:
-    branches: [ master ]
+    branches: [master]
 
 jobs:
   build:
@@ -23,38 +23,40 @@ jobs:
         python-version: ["3.8", "3.9", "3.10", "3.11"]
 
     steps:
-    - uses: actions/checkout@v3
-    - name: Set up Python ${{ matrix.python-version }}
-      uses: actions/setup-python@v4
-      with:
-        python-version: ${{ matrix.python-version }}
-    - name: Install dependencies
-      run: |
-        python -m pip install --upgrade pip
-        # cpu version of pytorch
-        pip install torch==2.3.1 --index-url https://download.pytorch.org/whl/cpu
+      - uses: actions/checkout@v3
+      - name: Set up Python ${{ matrix.python-version }}
+        uses: actions/setup-python@v4
+        with:
+          python-version: ${{ matrix.python-version }}
+      - name: Install dependencies
+        run: |
+          python -m pip install --upgrade pip
+          # cpu version of pytorch
+          pip install torch==2.3.1 --index-url https://download.pytorch.org/whl/cpu
-        # Install Atari Roms
-        pip install autorom
-        wget https://gist.githubusercontent.com/jjshoots/61b22aefce4456920ba99f2c36906eda/raw/00046ac3403768bfe45857610a3d333b8e35e026/Roms.tar.gz.b64
-        base64 Roms.tar.gz.b64 --decode &> Roms.tar.gz
-        AutoROM --accept-license --source-file Roms.tar.gz
+          # Install Atari Roms
+          pip install autorom
+          wget https://gist.githubusercontent.com/jjshoots/61b22aefce4456920ba99f2c36906eda/raw/00046ac3403768bfe45857610a3d333b8e35e026/Roms.tar.gz.b64
+          base64 Roms.tar.gz.b64 --decode &> Roms.tar.gz
+          AutoROM --accept-license --source-file Roms.tar.gz
-        pip install .[extra_no_roms,tests,docs]
-        # Use headless version
-        pip install opencv-python-headless
-    - name: Lint with ruff
-      run: |
-        make lint
-    - name: Build the doc
-      run: |
-        make doc
-    - name: Check codestyle
-      run: |
-        make check-codestyle
-    - name: Type check
-      run: |
-        make type
-    - name: Test with pytest
-      run: |
-        make pytest
+          pip install .[extra_no_roms,tests,docs]
+          # Use headless version
+          pip install opencv-python-headless
+      - name: Lint with ruff
+        run: |
+          make lint
+      - name: Build the doc
+        run: |
+          make doc
+      - name: Check codestyle
+        run: |
+          make check-codestyle
+      - name: Type check
+        run: |
+          make type
+        # Do not run for python 3.8 (mypy internal error)
+        if: matrix.python-version != '3.8'
+      - name: Test with pytest
+        run: |
+          make pytest
3 changes: 2 additions & 1 deletion docs/misc/changelog.rst
@@ -3,14 +3,15 @@
 Changelog
 ==========
 
-Release 2.4.0a5 (WIP)
+Release 2.4.0a6 (WIP)
 --------------------------
 
 Breaking Changes:
 ^^^^^^^^^^^^^^^^^
 
 New Features:
 ^^^^^^^^^^^^^
+- Added support for ``pre_linear_modules`` and ``post_linear_modules`` in ``create_mlp`` (useful for adding normalization layers, like in DroQ or CrossQ)
 
 Bug Fixes:
 ^^^^^^^^^^
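As a quick illustration of the changelog entry above (a minimal sketch, not part of this commit; the layer sizes are arbitrary):

import torch.nn as nn

from stable_baselines3.common.torch_layers import create_mlp

# DroQ-style critic body: LayerNorm after each hidden Linear layer, before the activation
q_net = nn.Sequential(*create_mlp(10, 1, net_arch=[256, 256], post_linear_modules=[nn.LayerNorm]))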
32 changes: 17 additions & 15 deletions pyproject.toml
@@ -13,10 +13,10 @@ ignore = ["B028", "RUF013"]
 
 [tool.ruff.lint.per-file-ignores]
 # Default implementation in abstract methods
-"./stable_baselines3/common/callbacks.py"= ["B027"]
-"./stable_baselines3/common/noise.py"= ["B027"]
+"./stable_baselines3/common/callbacks.py" = ["B027"]
+"./stable_baselines3/common/noise.py" = ["B027"]
 # ClassVar, implicit optional check not needed for tests
-"./tests/*.py"= ["RUF012", "RUF013"]
+"./tests/*.py" = ["RUF012", "RUF013"]
 
 
 [tool.ruff.lint.mccabe]
@@ -37,33 +37,35 @@ exclude = """(?x)(
 
 [tool.pytest.ini_options]
 # Deterministic ordering for tests; useful for pytest-xdist.
-env = [
-    "PYTHONHASHSEED=0"
-]
+env = ["PYTHONHASHSEED=0"]
 
 filterwarnings = [
     # Tensorboard warnings
     "ignore::DeprecationWarning:tensorboard",
     # Gymnasium warnings
     "ignore::UserWarning:gymnasium",
     # tqdm warning about rich being experimental
-    "ignore:rich is experimental"
+    "ignore:rich is experimental",
 ]
 markers = [
-    "expensive: marks tests as expensive (deselect with '-m \"not expensive\"')"
+    "expensive: marks tests as expensive (deselect with '-m \"not expensive\"')",
 ]
 
 [tool.coverage.run]
 disable_warnings = ["couldnt-parse"]
 branch = false
 omit = [
-  "tests/*",
-  "setup.py",
-  # Require graphical interface
-  "stable_baselines3/common/results_plotter.py",
-  # Require ffmpeg
-  "stable_baselines3/common/vec_env/vec_video_recorder.py",
+    "tests/*",
+    "setup.py",
+    # Require graphical interface
+    "stable_baselines3/common/results_plotter.py",
+    # Require ffmpeg
+    "stable_baselines3/common/vec_env/vec_video_recorder.py",
 ]
 
 [tool.coverage.report]
-exclude_lines = [ "pragma: no cover", "raise NotImplementedError()", "if typing.TYPE_CHECKING:"]
+exclude_lines = [
+    "pragma: no cover",
+    "raise NotImplementedError()",
+    "if typing.TYPE_CHECKING:",
+]
53 changes: 44 additions & 9 deletions stable_baselines3/common/torch_layers.py
@@ -1,4 +1,4 @@
-from typing import Dict, List, Tuple, Type, Union
+from typing import Dict, List, Optional, Tuple, Type, Union
 
 import gymnasium as gym
 import torch as th
@@ -14,7 +14,7 @@ class BaseFeaturesExtractor(nn.Module):
     """
     Base class that represents a features extractor.
 
-    :param observation_space:
+    :param observation_space: The observation space of the environment
     :param features_dim: Number of features extracted.
     """
 
@@ -26,6 +26,7 @@ def __init__(self, observation_space: gym.Space, features_dim: int = 0) -> None:
 
     @property
     def features_dim(self) -> int:
+        """The number of features that the extractor outputs."""
        return self._features_dim
 
 
@@ -34,7 +35,7 @@ class FlattenExtractor(BaseFeaturesExtractor):
     Feature extract that flatten the input.
     Used as a placeholder when feature extraction is not needed.
 
-    :param observation_space:
+    :param observation_space: The observation space of the environment
     """
 
     def __init__(self, observation_space: gym.Space) -> None:
@@ -52,7 +53,7 @@ class NatureCNN(BaseFeaturesExtractor):
     "Human-level control through deep reinforcement learning."
     Nature 518.7540 (2015): 529-533.
 
-    :param observation_space:
+    :param observation_space: The observation space of the environment
     :param features_dim: Number of features extracted.
         This corresponds to the number of unit for the last layer.
     :param normalized_image: Whether to assume that the image is already normalized
@@ -113,13 +114,15 @@ def create_mlp(
     activation_fn: Type[nn.Module] = nn.ReLU,
     squash_output: bool = False,
     with_bias: bool = True,
+    pre_linear_modules: Optional[List[Type[nn.Module]]] = None,
+    post_linear_modules: Optional[List[Type[nn.Module]]] = None,
 ) -> List[nn.Module]:
     """
     Create a multi layer perceptron (MLP), which is
     a collection of fully-connected layers each followed by an activation function.
 
     :param input_dim: Dimension of the input vector
-    :param output_dim:
+    :param output_dim: Dimension of the output (last layer, for instance, the number of actions)
     :param net_arch: Architecture of the neural net
         It represents the number of units per layer.
         The length of this list is the number of layers.
@@ -128,20 +131,52 @@ def create_mlp(
     :param squash_output: Whether to squash the output using a Tanh
         activation function
     :param with_bias: If set to False, the layers will not learn an additive bias
-    :return:
+    :param pre_linear_modules: List of nn.Module to add before the linear layers.
+        These modules should maintain the input tensor dimension (e.g. BatchNorm).
+        The number of input features is passed to the module's constructor.
+        Compared to post_linear_modules, they are used before the output layer (output_dim > 0).
+    :param post_linear_modules: List of nn.Module to add after the linear layers
+        (and before the activation function). These modules should maintain the input
+        tensor dimension (e.g. Dropout, LayerNorm). They are not used after the
+        output layer (output_dim > 0). The number of input features is passed to
+        the module's constructor.
+    :return: The list of layers of the neural network
     """
 
+    pre_linear_modules = pre_linear_modules or []
+    post_linear_modules = post_linear_modules or []
+
+    modules = []
     if len(net_arch) > 0:
-        modules = [nn.Linear(input_dim, net_arch[0], bias=with_bias), activation_fn()]
-    else:
-        modules = []
+        # BatchNorm maintains input dim
+        for module in pre_linear_modules:
+            modules.append(module(input_dim))
+
+        modules.append(nn.Linear(input_dim, net_arch[0], bias=with_bias))
+
+        # LayerNorm, Dropout maintain output dim
+        for module in post_linear_modules:
+            modules.append(module(net_arch[0]))
+
+        modules.append(activation_fn())
 
     for idx in range(len(net_arch) - 1):
+        for module in pre_linear_modules:
+            modules.append(module(net_arch[idx]))
+
         modules.append(nn.Linear(net_arch[idx], net_arch[idx + 1], bias=with_bias))
+
+        for module in post_linear_modules:
+            modules.append(module(net_arch[idx + 1]))
+
         modules.append(activation_fn())
 
     if output_dim > 0:
         last_layer_dim = net_arch[-1] if len(net_arch) > 0 else input_dim
+        # Only add BatchNorm before output layer
+        for module in pre_linear_modules:
+            modules.append(module(last_layer_dim))
+
         modules.append(nn.Linear(last_layer_dim, output_dim, bias=with_bias))
     if squash_output:
         modules.append(nn.Tanh())
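For reference, a short usage sketch of the new parameters (mirroring the new test below; the dimensions are arbitrary and this snippet is not part of the diff):

import torch as th
import torch.nn as nn

from stable_baselines3.common.torch_layers import create_mlp

layers = create_mlp(
    input_dim=6,
    output_dim=2,
    net_arch=[8, 12],
    pre_linear_modules=[nn.BatchNorm1d],  # added before every Linear, including the output layer
    post_linear_modules=[nn.LayerNorm],  # added after every hidden Linear, before the activation
)
mlp = nn.Sequential(*layers)
# Resulting order: BatchNorm1d(6) -> Linear(6, 8) -> LayerNorm(8) -> ReLU()
# -> BatchNorm1d(8) -> Linear(8, 12) -> LayerNorm(12) -> ReLU()
# -> BatchNorm1d(12) -> Linear(12, 2)
assert mlp(th.randn(4, 6)).shape == (4, 2)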
2 changes: 1 addition & 1 deletion stable_baselines3/version.txt
@@ -1 +1 @@
-2.4.0a5
+2.4.0a6
56 changes: 56 additions & 0 deletions tests/test_custom_policy.py
@@ -1,8 +1,10 @@
 import pytest
 import torch as th
+import torch.nn as nn
 
 from stable_baselines3 import A2C, DQN, PPO, SAC, TD3
 from stable_baselines3.common.sb2_compat.rmsprop_tf_like import RMSpropTFLike
+from stable_baselines3.common.torch_layers import create_mlp
 
 
 @pytest.mark.parametrize(
@@ -62,3 +64,57 @@ def test_tf_like_rmsprop_optimizer():
 def test_dqn_custom_policy():
     policy_kwargs = dict(optimizer_class=RMSpropTFLike, net_arch=[32])
     _ = DQN("MlpPolicy", "CartPole-v1", policy_kwargs=policy_kwargs, learning_starts=100).learn(300)
+
+
+def test_create_mlp():
+    net = create_mlp(4, 2, net_arch=[16, 8], squash_output=True)
+    # We cannot compare the network directly because the modules have different ids
+    # assert net == [nn.Linear(4, 16), nn.ReLU(), nn.Linear(16, 8), nn.ReLU(),
+    #                nn.Linear(8, 2), nn.Tanh()]
+    assert len(net) == 6
+    assert isinstance(net[0], nn.Linear)
+    assert net[0].in_features == 4
+    assert net[0].out_features == 16
+    assert isinstance(net[1], nn.ReLU)
+    assert isinstance(net[2], nn.Linear)
+    assert isinstance(net[4], nn.Linear)
+    assert net[4].in_features == 8
+    assert net[4].out_features == 2
+    assert isinstance(net[5], nn.Tanh)
+
+    # Linear network
+    net = create_mlp(4, -1, net_arch=[])
+    assert net == []
+
+    # No output layer, with custom activation function
+    net = create_mlp(6, -1, net_arch=[8], activation_fn=nn.Tanh)
+    # assert net == [nn.Linear(6, 8), nn.Tanh()]
+    assert len(net) == 2
+    assert isinstance(net[0], nn.Linear)
+    assert net[0].in_features == 6
+    assert net[0].out_features == 8
+    assert isinstance(net[1], nn.Tanh)
+
+    # Using pre-linear and post-linear modules
+    pre_linear = [nn.BatchNorm1d]
+    post_linear = [nn.LayerNorm]
+    net = create_mlp(6, 2, net_arch=[8, 12], pre_linear_modules=pre_linear, post_linear_modules=post_linear)
+    # assert net == [nn.BatchNorm1d(6), nn.Linear(6, 8), nn.LayerNorm(8), nn.ReLU(),
+    #                nn.BatchNorm1d(8), nn.Linear(8, 12), nn.LayerNorm(12), nn.ReLU(),
+    #                nn.BatchNorm1d(12), nn.Linear(12, 2)]  # Last layer does not have post_linear
+    assert len(net) == 10
+    assert isinstance(net[0], nn.BatchNorm1d)
+    assert net[0].num_features == 6
+    assert isinstance(net[1], nn.Linear)
+    assert isinstance(net[2], nn.LayerNorm)
+    assert isinstance(net[3], nn.ReLU)
+    assert isinstance(net[4], nn.BatchNorm1d)
+    assert isinstance(net[5], nn.Linear)
+    assert net[5].in_features == 8
+    assert net[5].out_features == 12
+    assert isinstance(net[6], nn.LayerNorm)
+    assert isinstance(net[7], nn.ReLU)
+    assert isinstance(net[8], nn.BatchNorm1d)
+    assert isinstance(net[-1], nn.Linear)
+    assert net[-1].in_features == 12
+    assert net[-1].out_features == 2
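To run only the new test locally (standard pytest usage, not specific to this commit):

pytest tests/test_custom_policy.py -k test_create_mlp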
