Example: Simple RL example using DQN/Lightning (#1232)
* Example: Simple RL example using DQN/Lightning

* DQN RL Agent using Lightning

* Uses Iterable Dataset for Replay Buffer

* Buffer is populated by agent as training is carried out, updating the
dataset

* Applied autopep8 fixes

* Updated line length from 120 to 110

* Update pl_examples/domain_templates/dqn.py

simplify get_device method

Co-Authored-By: Jirka Borovec <Borda@users.noreply.github.com>

* Update pl_examples/domain_templates/dqn.py

Re-ordered imports

Co-Authored-By: Jirka Borovec <Borda@users.noreply.github.com>

* CI: split tests-examples (#990)

* CI: split tests-examples

* tests without template

* comment depends

* CircleCI typo

* add doctest

* update test req.

* CI tests

* setup macOS

* longer train

* lower pred acc

* fix model

* rename default model

* lower tests acc

* typo

* imports

* fix test optimizer

* update calls

* fix Win

* lower Drone image

* fix call

* pytorch image

* fix test

* add dev image

* add dev image

* update image

* drone volume

* lint

* update test notes

* rename tests/models >> tests/base

* group models

* conftest

* optim imports

* typos

* fix import

* fix tests

* install AMP

* tests

* fix import

* Clean up

* added module docstring

* renamed variables to be more descriptive

* Added missing docstrings and type annotations

* Added gym to example requirements

* Added note to changelog

* updated example image

* update types

* rename script

* Update CHANGELOG.md

Co-Authored-By: Jirka Borovec <Borda@users.noreply.github.com>

* another rename

* Disable validation when val_percent_check=0 (#1251)

* fix disable validation

* add test

* update changelog

* update docs for val_percent_check

* make "fast training" docs consistent

* calling self.forward() -> self() (#1211)

* self.forward() -> self()

* update changelog

Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>

* Fix requirements-extra.txt Trains package to release version (#1229)

* Fix requirement-extra use released Trains package

* Update README.md add Trains and links to the external Visualization section

Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>

* Remove unnecessary parameters to super() in documentation and source code (#1240)

Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>

* update deprecation warning (#1258)

* update docs for progress bar values (#1253)

* lower timeouts for inactive issues (#1250)

* update contrib list (#1241)

Co-authored-by: William Falcon <waf2107@columbia.edu>

* Fix outdated docs (#1227)

* Fix typo (#1224)

* drop unused Tox (#1242)

* system info (#1234)

* system info

* update big info

* test script

* update config

* rename script

* import path

* Changed smoothing in tqdm to decrease variability of time remaining between training / eval (#1194)

Co-authored-by: Donal Byrne <Donal.Byrne@xperi.com>
Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
Co-authored-by: William Falcon <waf2107@columbia.edu>
Co-authored-by: Adrian Wälchli <adrian.waelchli@students.unibe.ch>
Co-authored-by: Jeremy Jordan <13970565+jeremyjordan@users.noreply.github.com>
Co-authored-by: Martin.B <51887611+bmartinn@users.noreply.github.com>
Co-authored-by: Tyler Yep <tyep@stanford.edu>
Co-authored-by: Shunta Komatsu <59395084+skmatz@users.noreply.github.com>
Co-authored-by: Jack Pertschuk <jackpertschuk@gmail.com>
10 people committed Mar 28, 2020
1 parent 4e0d0ab commit dab3b96
Showing 3 changed files with 363 additions and 1 deletion.
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -8,6 +8,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ### Added
 
+- Added Reinforcement Learning - Deep Q-network (DQN) lightning example ([#1232](https://github.com/PyTorchLightning/pytorch-lightning/pull/1232))
 - Added support for hierarchical `dict` ([#1152](https://github.com/PyTorchLightning/pytorch-lightning/pull/1152))
 - Added `TrainsLogger` class ([#1122](https://github.com/PyTorchLightning/pytorch-lightning/pull/1122))
 - Added type hints to `pytorch_lightning.core` ([#946](https://github.com/PyTorchLightning/pytorch-lightning/pull/946))
360 changes: 360 additions & 0 deletions pl_examples/domain_templates/reinforse_learn_Qnet.py
@@ -0,0 +1,360 @@
"""
# Deep Reinforcement Learning: Deep Q-network (DQN)
this example is based off https://github.com/PacktPublishing/Deep-Reinforcement-Learning-Hands-On-
Second-Edition/blob/master/Chapter06/02_dqn_pong.py
The template illustrates using Lightning for Reinforcement Learning. The example builds a basic DQN using the
classic CartPole environment.
to run the template just run:
python dqn.py
After ~1500 steps, you will see the total_reward hitting the max score of 200. Open up tensor boards to
see the metrics.
tensorboard --logdir default
"""

import argparse
from collections import OrderedDict, deque, namedtuple
from typing import Iterator, List, Tuple

import gym
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import Optimizer
from torch.utils.data import DataLoader
from torch.utils.data.dataset import IterableDataset

import pytorch_lightning as pl


class DQN(nn.Module):
    """
    Simple MLP network
    Args:
        obs_size: observation/state size of the environment
        n_actions: number of discrete actions available in the environment
        hidden_size: size of hidden layers
    """

    def __init__(self, obs_size: int, n_actions: int, hidden_size: int = 128):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(obs_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, n_actions)
        )

    def forward(self, x):
        return self.net(x.float())


# Named tuple for storing experience steps gathered in training
Experience = namedtuple(
    'Experience', field_names=['state', 'action', 'reward', 'done', 'new_state'])


class ReplayBuffer:
    """
    Replay Buffer for storing past experiences, allowing the agent to learn from them
    Args:
        capacity: size of the buffer
    """

    def __init__(self, capacity: int) -> None:
        self.buffer = deque(maxlen=capacity)

    def __len__(self) -> int:
        return len(self.buffer)

    def append(self, experience: Experience) -> None:
        """
        Add experience to the buffer
        Args:
            experience: tuple (state, action, reward, done, new_state)
        """
        self.buffer.append(experience)

    def sample(self, batch_size: int) -> Tuple:
        """Draw a random batch of experiences from the buffer, without replacement"""
        indices = np.random.choice(len(self.buffer), batch_size, replace=False)
        states, actions, rewards, dones, next_states = zip(*[self.buffer[idx] for idx in indices])

        return (np.array(states), np.array(actions), np.array(rewards, dtype=np.float32),
                np.array(dones, dtype=np.bool), np.array(next_states))


class RLDataset(IterableDataset):
    """
    Iterable Dataset containing the ReplayBuffer,
    which will be updated with new experiences during training
    Args:
        buffer: replay buffer
        sample_size: number of experiences to sample at a time
    """

    def __init__(self, buffer: ReplayBuffer, sample_size: int = 200) -> None:
        self.buffer = buffer
        self.sample_size = sample_size

    def __iter__(self) -> Iterator[Tuple]:
        states, actions, rewards, dones, new_states = self.buffer.sample(self.sample_size)
        for i in range(len(dones)):
            yield states[i], actions[i], rewards[i], dones[i], new_states[i]


class Agent:
    """
    Base Agent class handling the interaction with the environment
    Args:
        env: training environment
        replay_buffer: replay buffer storing experiences
    """

    def __init__(self, env: gym.Env, replay_buffer: ReplayBuffer) -> None:
        self.env = env
        self.replay_buffer = replay_buffer
        self.reset()

    def reset(self) -> None:
        """Resets the environment and updates the state"""
        self.state = self.env.reset()

    def get_action(self, net: nn.Module, epsilon: float, device: str) -> int:
        """
        Using the given network, decide what action to carry out
        using an epsilon-greedy policy
        Args:
            net: DQN network
            epsilon: probability of taking a random action
            device: current device
        Returns:
            action
        """
        if np.random.random() < epsilon:
            action = self.env.action_space.sample()
        else:
            state = torch.tensor([self.state])

            if device not in ['cpu']:
                state = state.cuda(device)

            q_values = net(state)
            _, action = torch.max(q_values, dim=1)
            action = int(action.item())

        return action

    @torch.no_grad()
    def play_step(self, net: nn.Module, epsilon: float = 0.0, device: str = 'cpu') -> Tuple[float, bool]:
        """
        Carries out a single interaction step between the agent and the environment
        Args:
            net: DQN network
            epsilon: probability of taking a random action
            device: current device
        Returns:
            reward, done
        """
        action = self.get_action(net, epsilon, device)

        # do step in the environment
        new_state, reward, done, _ = self.env.step(action)

        exp = Experience(self.state, action, reward, done, new_state)

        self.replay_buffer.append(exp)

        self.state = new_state
        if done:
            self.reset()
        return reward, done


class DQNLightning(pl.LightningModule):
    """ Basic DQN Model """

    def __init__(self, hparams: argparse.Namespace) -> None:
        super().__init__()
        self.hparams = hparams

        self.env = gym.make(self.hparams.env)
        obs_size = self.env.observation_space.shape[0]
        n_actions = self.env.action_space.n

        self.net = DQN(obs_size, n_actions)
        self.target_net = DQN(obs_size, n_actions)

        self.buffer = ReplayBuffer(self.hparams.replay_size)
        self.agent = Agent(self.env, self.buffer)
        self.total_reward = 0
        self.episode_reward = 0
        self.populate(self.hparams.warm_start_steps)

    def populate(self, steps: int = 1000) -> None:
        """
        Carries out several random steps through the environment to initially fill
        up the replay buffer with experiences
        Args:
            steps: number of random steps to populate the buffer with
        """
        for _ in range(steps):
            self.agent.play_step(self.net, epsilon=1.0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Passes in a state x through the network and gets the q_values of each action as an output
        Args:
            x: environment state
        Returns:
            q values
        """
        output = self.net(x)
        return output

    def dqn_mse_loss(self, batch: Tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor:
        """
        Calculates the MSE loss using a mini batch from the replay buffer
        Args:
            batch: current mini batch of replay data
        Returns:
            loss
        """
        states, actions, rewards, dones, next_states = batch

        # Q(s, a) for the actions that were actually taken
        state_action_values = self.net(states).gather(1, actions.unsqueeze(-1)).squeeze(-1)

        with torch.no_grad():
            # Bellman target: y = r + gamma * max_a' Q_target(s', a'), with y = r on terminal steps
            next_state_values = self.target_net(next_states).max(1)[0]
            next_state_values[dones] = 0.0
            next_state_values = next_state_values.detach()

        expected_state_action_values = next_state_values * self.hparams.gamma + rewards

        return nn.MSELoss()(state_action_values, expected_state_action_values)

    def training_step(self, batch: Tuple[torch.Tensor, torch.Tensor], nb_batch) -> OrderedDict:
        """
        Carries out a single step through the environment to update the replay buffer.
        Then calculates loss based on the minibatch received
        Args:
            batch: current mini batch of replay data
            nb_batch: batch number
        Returns:
            Training loss and log metrics
        """
        device = self.get_device(batch)
        # linearly decay epsilon from eps_start towards eps_end over eps_last_frame steps
        epsilon = max(self.hparams.eps_end,
                      self.hparams.eps_start - (self.global_step + 1) / self.hparams.eps_last_frame)

        # step through environment with agent
        reward, done = self.agent.play_step(self.net, epsilon, device)
        self.episode_reward += reward

        # calculates training loss
        loss = self.dqn_mse_loss(batch)

        if self.trainer.use_dp or self.trainer.use_ddp2:
            loss = loss.unsqueeze(0)

        if done:
            self.total_reward = self.episode_reward
            self.episode_reward = 0

        # Hard update of target network weights every sync_rate steps
        if self.global_step % self.hparams.sync_rate == 0:
            self.target_net.load_state_dict(self.net.state_dict())

        log = {'total_reward': torch.tensor(self.total_reward).to(device),
               'reward': torch.tensor(reward).to(device),
               'steps': torch.tensor(self.global_step).to(device)}

        return OrderedDict({'loss': loss, 'log': log, 'progress_bar': log})

    def configure_optimizers(self) -> List[Optimizer]:
        """Initialize Adam optimizer"""
        optimizer = optim.Adam(self.net.parameters(), lr=self.hparams.lr)
        return [optimizer]

    def __dataloader(self) -> DataLoader:
        """Initialize the Replay Buffer dataset used for retrieving experiences"""
        dataset = RLDataset(self.buffer, self.hparams.episode_length)
        dataloader = DataLoader(dataset=dataset,
                                batch_size=self.hparams.batch_size,
                                sampler=None
                                )
        return dataloader

    def train_dataloader(self) -> DataLoader:
        """Get train loader"""
        return self.__dataloader()

    def get_device(self, batch) -> str:
        """Retrieve device currently being used by minibatch"""
        return batch[0].device.index if self.on_gpu else 'cpu'


def main(hparams) -> None:
    model = DQNLightning(hparams)

    trainer = pl.Trainer(
        gpus=1,
        distributed_backend='dp',
        early_stop_callback=False,
        val_check_interval=100
    )

    trainer.fit(model)


if __name__ == '__main__':
    torch.manual_seed(0)
    np.random.seed(0)

    parser = argparse.ArgumentParser()
    parser.add_argument("--batch_size", type=int, default=16, help="size of the batches")
    parser.add_argument("--lr", type=float, default=1e-2, help="learning rate")
    parser.add_argument("--env", type=str, default="CartPole-v0", help="gym environment tag")
    parser.add_argument("--gamma", type=float, default=0.99, help="discount factor")
    parser.add_argument("--sync_rate", type=int, default=10,
                        help="how many frames between syncing the target network with the main network")
    parser.add_argument("--replay_size", type=int, default=1000,
                        help="capacity of the replay buffer")
    parser.add_argument("--warm_start_size", type=int, default=1000,
                        help="how many samples do we use to fill our buffer at the start of training")
    parser.add_argument("--eps_last_frame", type=int, default=1000,
                        help="what frame should epsilon stop decaying at")
    parser.add_argument("--eps_start", type=float, default=1.0, help="starting value of epsilon")
    parser.add_argument("--eps_end", type=float, default=0.01, help="final value of epsilon")
    parser.add_argument("--episode_length", type=int, default=200, help="max length of an episode")
    parser.add_argument("--max_episode_reward", type=int, default=200,
                        help="max episode reward in the environment")
    parser.add_argument("--warm_start_steps", type=int, default=1000,
                        help="number of random steps used to fill the buffer before training starts")

    args = parser.parse_args()

    main(args)
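The core trick of the example is the RLDataset above: a torch IterableDataset that re-samples a shared replay buffer each time it is iterated, so experiences the agent appends during training show up in later batches. The following is a minimal, self-contained sketch of that pattern; the TinyReplayDataset name and the integer "experiences" are illustrative only, not part of the example above.

import collections

import numpy as np
from torch.utils.data import DataLoader
from torch.utils.data.dataset import IterableDataset


class TinyReplayDataset(IterableDataset):
    """Streams a fresh random sample from a shared deque each time it is iterated."""

    def __init__(self, buffer: collections.deque, sample_size: int) -> None:
        self.buffer = buffer
        self.sample_size = sample_size

    def __iter__(self):
        # Sampling happens at iteration time, so anything appended to the
        # shared deque since the previous epoch is visible to later batches.
        indices = np.random.choice(len(self.buffer), self.sample_size, replace=False)
        for idx in indices:
            yield self.buffer[int(idx)]


buffer = collections.deque(maxlen=100)
buffer.extend(range(20))  # stand-in for the agent populating the buffer

loader = DataLoader(TinyReplayDataset(buffer, sample_size=8), batch_size=4)
for batch in loader:
    print(batch)  # two tensors of 4 freshly sampled "experiences" each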
3 changes: 2 additions & 1 deletion pl_examples/requirements.txt
@@ -1 +1,2 @@
-torchvision>=0.4.0
+torchvision>=0.4.0
+gym>=0.17.0
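For anyone trying the new example locally, the added gym dependency is pulled in via the updated requirements file, so from the repository root something like the following should suffice (standard pip usage; paths taken from the diff headers above):

pip install -r pl_examples/requirements.txt
python pl_examples/domain_templates/reinforse_learn_Qnet.py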

1 comment on commit dab3b96

@Borda (Member) commented on dab3b96 on Mar 29, 2020


Could we please correct the commit message next time? This one generated plenty of false cross-references, and all contributors to the referenced issues were notified about this new commit...
