diff --git a/CHANGELOG.md b/CHANGELOG.md
index e6486a5de5cf9..d4a848bd4d822 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -8,6 +8,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ### Added
 
+- Added Reinforcement Learning - Deep Q-network (DQN) Lightning example ([#1232](https://github.com/PyTorchLightning/pytorch-lightning/pull/1232))
 - Added support for hierarchical `dict` ([#1152](https://github.com/PyTorchLightning/pytorch-lightning/pull/1152))
 - Added `TrainsLogger` class ([#1122](https://github.com/PyTorchLightning/pytorch-lightning/pull/1122))
 - Added type hints to `pytorch_lightning.core` ([#946](https://github.com/PyTorchLightning/pytorch-lightning/pull/946))
diff --git a/pl_examples/domain_templates/reinforse_learn_Qnet.py b/pl_examples/domain_templates/reinforse_learn_Qnet.py
new file mode 100644
index 0000000000000..4585c108d5cfb
--- /dev/null
+++ b/pl_examples/domain_templates/reinforse_learn_Qnet.py
@@ -0,0 +1,360 @@
+"""
+Deep Reinforcement Learning: Deep Q-network (DQN)
+
+This example is based on https://github.com/PacktPublishing/Deep-Reinforcement-Learning-Hands-On-
+Second-Edition/blob/master/Chapter06/02_dqn_pong.py
+
+The template illustrates using Lightning for Reinforcement Learning. The example builds a basic DQN using the
+classic CartPole environment.
+
+To run the template, just run:
+python reinforse_learn_Qnet.py
+
+After ~1500 steps, you will see the total_reward hitting the max score of 200. Open up TensorBoard to
+see the metrics:
+
+tensorboard --logdir default
+"""
+
+import argparse
+from collections import OrderedDict, deque, namedtuple
+from typing import List, Tuple
+
+import gym
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.optim as optim
+from torch.optim import Optimizer
+from torch.utils.data import DataLoader
+from torch.utils.data.dataset import IterableDataset
+
+import pytorch_lightning as pl
+
+
+class DQN(nn.Module):
+    """
+    Simple MLP network
+
+    Args:
+        obs_size: observation/state size of the environment
+        n_actions: number of discrete actions available in the environment
+        hidden_size: size of hidden layers
+    """
+
+    def __init__(self, obs_size: int, n_actions: int, hidden_size: int = 128):
+        super(DQN, self).__init__()
+        self.net = nn.Sequential(
+            nn.Linear(obs_size, hidden_size),
+            nn.ReLU(),
+            nn.Linear(hidden_size, n_actions)
+        )
+
+    def forward(self, x):
+        return self.net(x.float())
+
+
+# Named tuple for storing experience steps gathered in training
+Experience = namedtuple(
+    'Experience', field_names=['state', 'action', 'reward',
+                               'done', 'new_state'])
+
+
+class ReplayBuffer:
+    """
+    Replay Buffer for storing past experiences, allowing the agent to learn from them
+
+    Args:
+        capacity: size of the buffer
+    """
+
+    def __init__(self, capacity: int) -> None:
+        self.buffer = deque(maxlen=capacity)
+
+    def __len__(self) -> int:
+        return len(self.buffer)
+
+    def append(self, experience: Experience) -> None:
+        """
+        Add experience to the buffer
+
+        Args:
+            experience: tuple (state, action, reward, done, new_state)
+        """
+        self.buffer.append(experience)
+
+    def sample(self, batch_size: int) -> Tuple:
+        """Draw a random batch of experiences from the buffer"""
+        indices = np.random.choice(len(self.buffer), batch_size, replace=False)
+        states, actions, rewards, dones, next_states = zip(*[self.buffer[idx] for idx in indices])
+
+        return (np.array(states), np.array(actions), np.array(rewards, dtype=np.float32),
+                np.array(dones, dtype=bool), np.array(next_states))
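+
+
+# For reference: the Agent defined below records every environment step as an
+# Experience and appends it to the ReplayBuffer, e.g.
+#     buffer = ReplayBuffer(capacity=1000)
+#     buffer.append(Experience(state, action, reward, done, new_state))
+# RLDataset then draws random samples from that buffer during training.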
+
+
+class RLDataset(IterableDataset):
+    """
+    Iterable Dataset containing the ExperienceBuffer
+    which will be updated with new experiences during training
+
+    Args:
+        buffer: replay buffer
+        sample_size: number of experiences to sample at a time
+    """
+
+    def __init__(self, buffer: ReplayBuffer, sample_size: int = 200) -> None:
+        self.buffer = buffer
+        self.sample_size = sample_size
+
+    def __iter__(self) -> Tuple:
+        states, actions, rewards, dones, new_states = self.buffer.sample(self.sample_size)
+        for i in range(len(dones)):
+            yield states[i], actions[i], rewards[i], dones[i], new_states[i]
+
+
+class Agent:
+    """
+    Base Agent class handling the interaction with the environment
+
+    Args:
+        env: training environment
+        replay_buffer: replay buffer storing experiences
+    """
+
+    def __init__(self, env: gym.Env, replay_buffer: ReplayBuffer) -> None:
+        self.env = env
+        self.replay_buffer = replay_buffer
+        self.reset()
+        self.state = self.env.reset()
+
+    def reset(self) -> None:
+        """Resets the environment and updates the state"""
+        self.state = self.env.reset()
+
+    def get_action(self, net: nn.Module, epsilon: float, device: str) -> int:
+        """
+        Using the given network, decide what action to carry out
+        using an epsilon-greedy policy
+
+        Args:
+            net: DQN network
+            epsilon: value to determine likelihood of taking a random action
+            device: current device
+
+        Returns:
+            action
+        """
+        if np.random.random() < epsilon:
+            action = self.env.action_space.sample()
+        else:
+            state = torch.tensor([self.state])
+
+            if device not in ['cpu']:
+                state = state.cuda(device)
+
+            q_values = net(state)
+            _, action = torch.max(q_values, dim=1)
+            action = int(action.item())
+
+        return action
+
+    @torch.no_grad()
+    def play_step(self, net: nn.Module, epsilon: float = 0.0, device: str = 'cpu') -> Tuple[float, bool]:
+        """
+        Carries out a single interaction step between the agent and the environment
+
+        Args:
+            net: DQN network
+            epsilon: value to determine likelihood of taking a random action
+            device: current device
+
+        Returns:
+            reward, done
+        """
+        action = self.get_action(net, epsilon, device)
+
+        # do step in the environment
+        new_state, reward, done, _ = self.env.step(action)
+
+        exp = Experience(self.state, action, reward, done, new_state)
+
+        self.replay_buffer.append(exp)
+
+        self.state = new_state
+        if done:
+            self.reset()
+        return reward, done
+
+
+class DQNLightning(pl.LightningModule):
+    """ Basic DQN Model """
+
+    def __init__(self, hparams: argparse.Namespace) -> None:
+        super().__init__()
+        self.hparams = hparams
+
+        self.env = gym.make(self.hparams.env)
+        obs_size = self.env.observation_space.shape[0]
+        n_actions = self.env.action_space.n
+
+        self.net = DQN(obs_size, n_actions)
+        self.target_net = DQN(obs_size, n_actions)
+
+        self.buffer = ReplayBuffer(self.hparams.replay_size)
+        self.agent = Agent(self.env, self.buffer)
+        self.total_reward = 0
+        self.episode_reward = 0
+        self.populate(self.hparams.warm_start_steps)
+
+    def populate(self, steps: int = 1000) -> None:
+        """
+        Carries out several random steps through the environment to initially fill
+        up the replay buffer with experiences
+
+        Args:
+            steps: number of random steps to populate the buffer with
+        """
+        for i in range(steps):
+            self.agent.play_step(self.net, epsilon=1.0)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """
+        Passes in a state x through the network and gets the q_values of each action as an output
+
+        Args:
+            x: environment state
+
+        Returns:
+            q values
+        """
+        output = self.net(x)
+        return output
+
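+    # For reference: dqn_mse_loss below regresses Q(s, a) onto the one-step DQN target
+    #     y = r + gamma * max_a' Q_target(s', a')   (y = r when the episode is done),
+    # where the bootstrap term comes from the periodically synced target network.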
+    def dqn_mse_loss(self, batch: Tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor:
+        """
+        Calculates the mse loss using a mini batch from the replay buffer
+
+        Args:
+            batch: current mini batch of replay data
+
+        Returns:
+            loss
+        """
+        states, actions, rewards, dones, next_states = batch
+
+        state_action_values = self.net(states).gather(1, actions.unsqueeze(-1)).squeeze(-1)
+
+        with torch.no_grad():
+            next_state_values = self.target_net(next_states).max(1)[0]
+            next_state_values[dones] = 0.0
+            next_state_values = next_state_values.detach()
+
+        expected_state_action_values = next_state_values * self.hparams.gamma + rewards
+
+        return nn.MSELoss()(state_action_values, expected_state_action_values)
+
+    def training_step(self, batch: Tuple[torch.Tensor, torch.Tensor], nb_batch) -> OrderedDict:
+        """
+        Carries out a single step through the environment to update the replay buffer.
+        Then calculates loss based on the minibatch received
+
+        Args:
+            batch: current mini batch of replay data
+            nb_batch: batch number
+
+        Returns:
+            Training loss and log metrics
+        """
+        device = self.get_device(batch)
+        # linearly anneal epsilon from eps_start down to eps_end over eps_last_frame steps
+        epsilon = max(self.hparams.eps_end,
+                      self.hparams.eps_start - (self.global_step + 1) / self.hparams.eps_last_frame)
+
+        # step through environment with agent
+        reward, done = self.agent.play_step(self.net, epsilon, device)
+        self.episode_reward += reward
+
+        # calculates training loss
+        loss = self.dqn_mse_loss(batch)
+
+        if self.trainer.use_dp or self.trainer.use_ddp2:
+            loss = loss.unsqueeze(0)
+
+        if done:
+            self.total_reward = self.episode_reward
+            self.episode_reward = 0
+
+        # periodically sync the target network with the main network
+        if self.global_step % self.hparams.sync_rate == 0:
+            self.target_net.load_state_dict(self.net.state_dict())
+
+        log = {'total_reward': torch.tensor(self.total_reward).to(device),
+               'reward': torch.tensor(reward).to(device),
+               'steps': torch.tensor(self.global_step).to(device)}
+
+        return OrderedDict({'loss': loss, 'log': log, 'progress_bar': log})
+
+    def configure_optimizers(self) -> List[Optimizer]:
+        """Initialize Adam optimizer"""
+        optimizer = optim.Adam(self.net.parameters(), lr=self.hparams.lr)
+        return [optimizer]
+
+    def __dataloader(self) -> DataLoader:
+        """Initialize the Replay Buffer dataset used for retrieving experiences"""
+        dataset = RLDataset(self.buffer, self.hparams.episode_length)
+        dataloader = DataLoader(dataset=dataset,
+                                batch_size=self.hparams.batch_size,
+                                sampler=None
+                                )
+        return dataloader
+
+    def train_dataloader(self) -> DataLoader:
+        """Get train loader"""
+        return self.__dataloader()
+
+    def get_device(self, batch) -> str:
+        """Retrieve device currently being used by minibatch"""
+        return batch[0].device.index if self.on_gpu else 'cpu'
+
+
+def main(hparams) -> None:
+    model = DQNLightning(hparams)
+
+    trainer = pl.Trainer(
+        gpus=1,
+        distributed_backend='dp',
+        early_stop_callback=False,
+        val_check_interval=100
+    )
+
+    trainer.fit(model)
+
+
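+# Example invocation (illustrative values; all flags are defined by the parser below):
+#     python reinforse_learn_Qnet.py --lr 1e-3 --batch_size 32 --replay_size 2000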
+if __name__ == '__main__':
+    torch.manual_seed(0)
+    np.random.seed(0)
+
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--batch_size", type=int, default=16, help="size of the batches")
+    parser.add_argument("--lr", type=float, default=1e-2, help="learning rate")
+    parser.add_argument("--env", type=str, default="CartPole-v0", help="gym environment tag")
+    parser.add_argument("--gamma", type=float, default=0.99, help="discount factor")
+    parser.add_argument("--sync_rate", type=int, default=10,
+                        help="how many frames between updates of the target network")
+    parser.add_argument("--replay_size", type=int, default=1000,
+                        help="capacity of the replay buffer")
+    parser.add_argument("--warm_start_size", type=int, default=1000,
+                        help="how many samples do we use to fill our buffer at the start of training")
+    parser.add_argument("--eps_last_frame", type=int, default=1000,
+                        help="what frame should epsilon stop decaying")
+    parser.add_argument("--eps_start", type=float, default=1.0, help="starting value of epsilon")
+    parser.add_argument("--eps_end", type=float, default=0.01, help="final value of epsilon")
+    parser.add_argument("--episode_length", type=int, default=200, help="max length of an episode")
+    parser.add_argument("--max_episode_reward", type=int, default=200,
+                        help="max episode reward in the environment")
+    parser.add_argument("--warm_start_steps", type=int, default=1000,
+                        help="how many random steps to take to fill the replay buffer before training")
+
+    args = parser.parse_args()
+
+    main(args)
diff --git a/pl_examples/requirements.txt b/pl_examples/requirements.txt
index d9f4c0d808165..24506bbba7964 100644
--- a/pl_examples/requirements.txt
+++ b/pl_examples/requirements.txt
@@ -1 +1,2 @@
-torchvision>=0.4.0
\ No newline at end of file
+torchvision>=0.4.0
+gym>=0.17.0
\ No newline at end of file