Skip to content

Commit

Permalink
feat(examples): TorchRL - MAML integration (#12)
Browse files Browse the repository at this point in the history
Co-authored-by: Benjamin-eecs <benjaminliu.eecs@gmail.com>
  • Loading branch information
vmoens and Benjamin-eecs authored Sep 1, 2022
1 parent 5242e66 commit 0bc346c
Show file tree
Hide file tree
Showing 10 changed files with 355 additions and 9 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Added

- Add MAML example with TorchRL integration by [@vmoens](https://github.com/vmoens) and [@Benjamin-eecs](https://github.com/Benjamin-eecs) in [#12](https://github.com/metaopt/TorchOpt/pull/12).
- Add optional argument `params` to update function in gradient transformations by [@XuehaiPan](https://github.com/XuehaiPan) in [#65](https://github.com/metaopt/torchopt/pull/65).
- Add option `weight_decay` option to optimizers by [@XuehaiPan](https://github.com/XuehaiPan) in [#65](https://github.com/metaopt/torchopt/pull/65).
- Add option `maximize` option to optimizers by [@XuehaiPan](https://github.com/XuehaiPan) in [#64](https://github.com/metaopt/torchopt/pull/64).
Expand Down
1 change: 1 addition & 0 deletions docs/source/developer/contributor.rst
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,4 @@ Contributor
We always welcome contributions to help make TorchOpt better. Below is an incomplete list of our contributors (find more on `this page <https://github.com/metaopt/torchopt/graphs/contributors>`_).

* Yao Fu (`future-xy <https://github.com/future-xy>`_)
* Vincent Moens (`vmoens <https://github.com/vmoens>`_)
2 changes: 2 additions & 0 deletions examples/MAML-RL/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ Specify the seed to train.
```bash
### Run MAML
python maml.py --seed 1
### Run torchrl MAML implementation
python maml_torchrl.py --seed 1
```

## Results
Expand Down
9 changes: 2 additions & 7 deletions examples/MAML-RL/helpers/policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
# https://github.com/tristandeleu/pytorch-maml-rl
# ==============================================================================

import torch
import torch.nn as nn
from torch.distributions import Categorical

Expand All @@ -27,12 +26,8 @@ class CategoricalMLPPolicy(nn.Module):
with discrete action spaces (eg. `TabularMDPEnv`).
"""

def __init__(
self,
input_size,
output_size,
):
super(CategoricalMLPPolicy, self).__init__()
def __init__(self, input_size, output_size):
super().__init__()
self.torso = nn.Sequential(
nn.Linear(input_size, 32),
nn.ReLU(),
Expand Down
88 changes: 88 additions & 0 deletions examples/MAML-RL/helpers/policy_torchrl.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
# Copyright 2022 MetaOPT Team. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

import torch
import torch.nn as nn
from torch.distributions import Categorical
from torchrl.modules import (
ActorValueOperator,
OneHotCategorical,
ProbabilisticActor,
TensorDictModule,
ValueOperator,
)


class Backbone(nn.Module):
def __init__(self, input_size, output_size):
super().__init__()
self.torso = nn.Sequential(
nn.Linear(input_size, 32),
nn.ReLU(),
nn.Linear(32, 32),
nn.ReLU(),
)

def forward(self, inputs, params=None):
embedding = self.torso(inputs)
return embedding


class CategoricalSubNet(nn.Module):
def __init__(self, input_size, output_size):
super().__init__()
self.policy_head = nn.Linear(32, output_size)

def forward(self, embedding, params=None):
logits = self.policy_head(embedding)
return logits


class ValueSubNet(nn.Module):
def __init__(self, input_size, output_size):
super().__init__()
self.value_head = nn.Linear(32, 1)

def forward(self, embedding, params=None):
value = self.value_head(embedding)
return value


class ActorCritic(ActorValueOperator):
def __init__(self, input_size, output_size):
super().__init__(
TensorDictModule(
spec=None,
module=Backbone(input_size, output_size),
in_keys=['observation'],
out_keys=['hidden'],
),
ProbabilisticActor(
spec=None,
module=TensorDictModule(
CategoricalSubNet(input_size, output_size),
in_keys=['hidden'],
out_keys=['logits'],
),
distribution_class=OneHotCategorical,
return_log_prob=False,
dist_param_keys=['logits'],
out_key_sample=['action'],
),
ValueOperator(
module=ValueSubNet(input_size, output_size),
in_keys=['hidden'],
),
)
7 changes: 5 additions & 2 deletions examples/MAML-RL/helpers/tabular_mdp.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,11 @@ class TabularMDPEnv(gym.Env):
Learning", 2016 (https://arxiv.org/abs/1611.02779)
"""

def __init__(self, num_states, num_actions, max_episode_steps, seed, task={}):
super(TabularMDPEnv, self).__init__()
def __init__(self, num_states, num_actions, max_episode_steps, seed, task=None):
super().__init__()

task = task or {}

self.max_episode_steps = max_episode_steps
self.num_states = num_states
self.num_actions = num_actions
Expand Down
Binary file modified examples/MAML-RL/maml.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added examples/MAML-RL/maml_torchrl.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading

0 comments on commit 0bc346c

Please sign in to comment.