Soft Actor Critic (SAC) Model (#627)
blahBlahhhJ authored Sep 8, 2021
1 parent 3f6b122 commit f6a7e98
Showing 11 changed files with 761 additions and 5 deletions.
7 changes: 3 additions & 4 deletions CHANGELOG.md
@@ -8,18 +8,17 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Added

- Added `EMNISTDataModule`, `BinaryEMNISTDataModule`, and `BinaryEMNIST` dataset ([#676](https://github.com/PyTorchLightning/lightning-bolts/pull/676))

- Added Soft Actor Critic (SAC) Model ([#627](https://github.com/PyTorchLightning/lightning-bolts/pull/627))

- Added Advantage Actor-Critic (A2C) Model ([#598](https://github.com/PyTorchLightning/lightning-bolts/pull/598))

- Added Torch ORT Callback ([#720](https://github.com/PyTorchLightning/lightning-bolts/pull/720))

- Added SparseML Callback ([#724](https://github.com/PyTorchLightning/lightning-bolts/pull/724))


### Changed

- Changed datamodules' default values from `pin_memory=False`, `shuffle=False` and `num_workers=16` to `pin_memory=True`, `shuffle=True` and `num_workers=0` ([#701](https://github.com/PyTorchLightning/lightning-bolts/pull/701))
76 changes: 76 additions & 0 deletions docs/source/reinforce_learn.rst
@@ -764,3 +764,79 @@ Example::

.. autoclass:: pl_bolts.models.rl.AdvantageActorCritic
:noindex:

--------------


Soft Actor Critic (SAC)
^^^^^^^^^^^^^^^^^^^^^^^

Soft Actor Critic model introduced in `Soft Actor-Critic: Off-Policy Maximum Entropy Deep Reinforcement Learning with a Stochastic Actor <https://arxiv.org/abs/1801.01290>`__
Paper authors: Tuomas Haarnoja, Aurick Zhou, Pieter Abbeel, Sergey Levine

Original implementation by: `Jason Wang <https://github.com/blahBlahhhJ>`_

Soft Actor Critic (SAC) is a powerful actor-critic algorithm in reinforcement learning. Unlike A2C, SAC's policy outputs a
tanh-squashed Gaussian distribution over continuous actions, and its critic estimates the Q value instead of the state value,
which means the critic takes in not only states but also actions. The new actor lets SAC handle continuous-action tasks such
as controlling robots, and the new critic lets SAC learn off-policy, which is more sample efficient.
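
Concretely, the policy head outputs the mean and log standard deviation of a Gaussian whose sample is squashed through a ``tanh`` and rescaled to the environment's action bounds. A minimal, self-contained sketch of drawing one such action (illustrative values; Pendulum's actions lie in ``[-2, 2]``)::

    import torch
    from torch.distributions import Normal

    mean, std = torch.zeros(1), torch.ones(1)            # outputs of the policy head
    action_scale, action_bias = 2.0, 0.0                 # rescale tanh(z) in (-1, 1) to [-2, 2]
    z = Normal(mean, std).rsample()                      # reparameterized Gaussian sample
    action = action_scale * torch.tanh(z) + action_bias  # bounded continuous action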

The actor's objective maximizes the expected reward while also maximizing entropy to encourage exploration.
The critic uses two separate Q functions to "mitigate positive bias" during training, taking the minimum of the
two as the predicted Q value.

Since SAC is off-policy, its training step is quite similar to DQN's (a minimal code sketch follows the list below):

1. Initialize one policy network, two Q networks, and two corresponding target Q networks.
2. Run one step in the environment using an action sampled from the policy and store the transition in the replay buffer.

   .. math::
       a \sim \tanh\big(\mathcal{N}(\mu_\pi(s), \sigma_\pi(s))\big)

3. Sample transitions (states, actions, rewards, dones, next states) from the replay buffer.

   .. math::
       s, a, r, d, s' \sim B

4. Compute the actor loss and update the policy network.

   .. math::
       J_\pi = \frac{1}{n}\sum_i\big(\log\pi(\pi(a | s_i) | s_i) - Q_{\min}(s_i, \pi(a | s_i))\big)

5. Compute the Q target.

   .. math::
       \text{target}_i = r_i + (1 - d_i) \gamma \big(\min_j Q_{\text{target},j}(s'_i, \pi(a' | s'_i)) - \log\pi(\pi(a' | s'_i) | s'_i)\big)

6. Compute the critic loss and update both Q networks.

   .. math::
       J_{Q_j} = \frac{1}{n} \sum_i\big(Q_j(s_i, a_i) - \text{target}_i\big)^2

7. Soft update the target Q networks using a weighted sum of themselves and the Q networks.

   .. math::
       Q_{\text{target},j} := \tau Q_{\text{target},j} + (1-\tau) Q_j

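The whole update can be written in a few lines of PyTorch. The following is an illustrative sketch rather than the Bolts implementation; ``policy``, ``q_nets``, ``target_q_nets``, ``buffer``, ``batch_size``, ``gamma`` and ``tau`` are assumed to exist and are not part of the Bolts API::

    import torch
    from torch.nn import functional as F

    # step 3: sample a batch of transitions from the replay buffer
    states, actions, rewards, dones, next_states = buffer.sample(batch_size)

    # step 4: actor loss = log-probability (negative entropy) minus the minimum of the two critics
    dist = policy(states)
    new_actions, logprobs = dist.rsample_and_log_prob()
    q_min = torch.min(q_nets[0](states, new_actions), q_nets[1](states, new_actions))
    actor_loss = (logprobs - q_min).mean()

    # step 5: shared soft Q target computed from the target networks
    with torch.no_grad():
        next_dist = policy(next_states)
        next_actions, next_logprobs = next_dist.rsample_and_log_prob()
        target_q = torch.min(
            target_q_nets[0](next_states, next_actions),
            target_q_nets[1](next_states, next_actions),
        )
        target = rewards + (1 - dones) * gamma * (target_q - next_logprobs)

    # step 6: both Q networks regress onto the same target
    critic_losses = [F.mse_loss(q(states, actions), target) for q in q_nets]

    # step 7: soft update of the target networks, matching the formula above
    for q, q_target in zip(q_nets, target_q_nets):
        for p, p_target in zip(q.parameters(), q_target.parameters()):
            p_target.data.copy_(tau * p_target.data + (1 - tau) * p.data)
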
SAC Benefits
~~~~~~~~~~~~~~~~~~~

- More sample efficient due to off-policy training

- Supports continuous action space

SAC Results
~~~~~~~~~~~~~~~~

.. image:: _images/rl_benchmark/pendulum_sac_results.jpg
:width: 300
:alt: SAC Results

Example::

    from pytorch_lightning import Trainer
    from pl_bolts.models.rl import SAC

    sac = SAC("Pendulum-v0")
    trainer = Trainer()
    trainer.fit(sac)

.. autoclass:: pl_bolts.models.rl.SAC
:noindex:
2 changes: 2 additions & 0 deletions pl_bolts/models/rl/__init__.py
@@ -5,6 +5,7 @@
from pl_bolts.models.rl.noisy_dqn_model import NoisyDQN
from pl_bolts.models.rl.per_dqn_model import PERDQN
from pl_bolts.models.rl.reinforce_model import Reinforce
from pl_bolts.models.rl.sac_model import SAC
from pl_bolts.models.rl.vanilla_policy_gradient_model import VanillaPolicyGradient

__all__ = [
@@ -15,5 +16,6 @@
"NoisyDQN",
"PERDQN",
"Reinforce",
"SAC",
"VanillaPolicyGradient",
]
45 changes: 45 additions & 0 deletions pl_bolts/models/rl/common/agents.py
@@ -161,3 +161,48 @@ def __call__(self, states: Tensor, device: str) -> List[int]:
actions = [np.random.choice(len(prob), p=prob) for prob in prob_np]

return actions


class SoftActorCriticAgent(Agent):
"""Actor-Critic based agent that returns a continuous action based on the policy."""

def __call__(self, states: Tensor, device: str) -> List[float]:
"""Takes in the current state and returns the action based on the agents policy.
Args:
states: current state of the environment
device: the device used for the current batch
Returns:
action defined by policy
"""
if not isinstance(states, list):
states = [states]

if not isinstance(states, Tensor):
states = torch.tensor(states, device=device)

dist = self.net(states)
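        # draw one stochastic action per state from the tanh-squashed Gaussian returned by the policy network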
actions = [a for a in dist.sample().cpu().numpy()]

return actions

def get_action(self, states: Tensor, device: str) -> List[float]:
"""Get the action greedily (without sampling)
Args:
states: current state of the environment
device: the device used for the current batch
Returns:
action defined by policy
"""
if not isinstance(states, list):
states = [states]

if not isinstance(states, Tensor):
states = torch.tensor(states, device=device)

actions = [self.net.get_action(states).cpu().numpy()]

return actions
62 changes: 62 additions & 0 deletions pl_bolts/models/rl/common/distributions.py
@@ -0,0 +1,62 @@
"""Distributions used in some continuous RL algorithms."""
import torch


class TanhMultivariateNormal(torch.distributions.MultivariateNormal):
"""The distribution of X is an affine of tanh applied on a normal distribution.
X = action_scale * tanh(Z) + action_bias
Z ~ Normal(mean, variance)
"""

def __init__(self, action_bias, action_scale, **kwargs):
super().__init__(**kwargs)

self.action_bias = action_bias
self.action_scale = action_scale

def rsample_with_z(self, sample_shape=torch.Size()):
"""Samples X using reparametrization trick with the intermediate variable Z.
Returns:
Sampled X and Z
"""
z = super().rsample()
return self.action_scale * torch.tanh(z) + self.action_bias, z

def log_prob_with_z(self, value, z):
"""Computes the log probability of a sampled X.
        Refer to equations (20) and (21) of the original SAC paper for more details.
Args:
value: the value of X
z: the value of Z
Returns:
Log probability of the sample
"""
value = (value - self.action_bias) / self.action_scale
z_logprob = super().log_prob(z)
correction = torch.log(self.action_scale * (1 - value ** 2) + 1e-7).sum(1)
return z_logprob - correction

def rsample_and_log_prob(self, sample_shape=torch.Size()):
"""Samples X and computes the log probability of the sample.
Returns:
Sampled X and log probability
"""
z = super().rsample()
z_logprob = super().log_prob(z)
value = torch.tanh(z)
correction = torch.log(self.action_scale * (1 - value ** 2) + 1e-7).sum(1)
return self.action_scale * value + self.action_bias, z_logprob - correction

def rsample(self, sample_shape=torch.Size()):
fz, z = self.rsample_with_z(sample_shape)
return fz

def log_prob(self, value):
value = (value - self.action_bias) / self.action_scale
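        # invert the tanh squashing: atanh(x) = (log(1 + x) - log(1 - x)) / 2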
z = torch.log(1 + value) / 2 - torch.log(1 - value) / 2
return self.log_prob_with_z(value, z)
62 changes: 61 additions & 1 deletion pl_bolts/models/rl/common/networks.py
@@ -4,10 +4,12 @@

import numpy as np
import torch
from torch import Tensor, nn
from torch import FloatTensor, Tensor, nn
from torch.distributions import Categorical, Normal
from torch.nn import functional as F

from pl_bolts.models.rl.common.distributions import TanhMultivariateNormal


class CNN(nn.Module):
"""Simple MLP network."""
@@ -84,6 +86,64 @@ def forward(self, input_x):
return self.net(input_x.float())


class ContinuousMLP(nn.Module):
"""MLP network that outputs continuous value via Gaussian distribution."""

def __init__(
self,
input_shape: Tuple[int],
n_actions: int,
hidden_size: int = 128,
action_bias: int = 0,
action_scale: int = 1,
):
"""
Args:
input_shape: observation shape of the environment
n_actions: dimension of actions in the environment
hidden_size: size of hidden layers
action_bias: the center of the action space
action_scale: the scale of the action space
"""
super().__init__()
self.action_bias = action_bias
self.action_scale = action_scale

self.shared_net = nn.Sequential(
nn.Linear(input_shape[0], hidden_size), nn.ReLU(), nn.Linear(hidden_size, hidden_size), nn.ReLU()
)
self.mean_layer = nn.Linear(hidden_size, n_actions)
self.logstd_layer = nn.Linear(hidden_size, n_actions)

def forward(self, x: FloatTensor) -> TanhMultivariateNormal:
"""Forward pass through network. Calculates the action distribution.
Args:
x: input to network
Returns:
action distribution
"""
x = self.shared_net(x.float())
batch_mean = self.mean_layer(x)
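        # clamp the log-std for numerical stability, then build a diagonal covariance via its Cholesky factor (scale_tril)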
logstd = torch.clamp(self.logstd_layer(x), -20, 2)
batch_scale_tril = torch.diag_embed(torch.exp(logstd))
return TanhMultivariateNormal(
action_bias=self.action_bias, action_scale=self.action_scale, loc=batch_mean, scale_tril=batch_scale_tril
)

def get_action(self, x: FloatTensor) -> Tensor:
"""Get the action greedily (without sampling)
Args:
x: input to network
Returns:
mean action
"""
x = self.shared_net(x.float())
batch_mean = self.mean_layer(x)
return self.action_scale * torch.tanh(batch_mean) + self.action_bias


class ActorCriticMLP(nn.Module):
"""MLP network with heads for actor and critic."""

