feat(examples): TorchRL - MAML integration #12

Merged 66 commits on Sep 1, 2022
Commits
72c0709
init
vmoens May 10, 2022
c57d607
remove checks
vmoens May 10, 2022
34e24ab
device casting
vmoens May 10, 2022
475fa6c
GAE
vmoens May 12, 2022
0ebd589
GAE bf
vmoens May 12, 2022
6471cc4
TDLambdaEstimate
vmoens May 13, 2022
aa3f0e6
TDLambdaEstimate update
vmoens May 13, 2022
62a0d1c
pbar for regular algo
vmoens May 13, 2022
34fb175
pbar for regular algo
vmoens May 13, 2022
4c88832
SerialEnv
vmoens May 13, 2022
5d32bb1
selected keys
vmoens May 13, 2022
44f6b51
selected keys
vmoens May 13, 2022
aa7ce2e
timeit
vmoens May 13, 2022
55d2a8e
amend
vmoens May 14, 2022
65efa02
amend
vmoens May 14, 2022
ce473f1
fix: resolve conflicts
Benjamin-eecs Jul 20, 2022
87d14aa
merge: resolve conflicts
Benjamin-eecs Jul 20, 2022
7154b5d
merge: resolve conflicts
Benjamin-eecs Jul 20, 2022
eeb815a
merge: update missing files
Benjamin-eecs Jul 20, 2022
71f75cc
fix: correct MAML-RL implementation
Benjamin-eecs Jul 25, 2022
b918213
fix: sync with torchrl newest version
Benjamin-eecs Jul 25, 2022
7d540fd
fix: sync with torchrl newest version
Benjamin-eecs Jul 25, 2022
d493241
merge: resolve conflicts
Benjamin-eecs Jul 25, 2022
35fb0b7
merge: resolve conflicts
Benjamin-eecs Jul 25, 2022
4171196
test: add run.sh
Benjamin-eecs Jul 25, 2022
1d03e62
fix: correct torchrl env device error, working with GPU
Benjamin-eecs Jul 25, 2022
69e43e6
fix: correct torchrl env device error, working with GPU
Benjamin-eecs Jul 25, 2022
7f06f3e
test: add save fn for testing
Benjamin-eecs Jul 25, 2022
7d67c5f
Merge remote-tracking branch 'origin' into torchrl_new
Benjamin-eecs Jul 27, 2022
62722cd
fix: update maml examples inner lr
Benjamin-eecs Jul 27, 2022
4f0379a
Merge remote-tracking branch 'origin' into torchrl_new
Benjamin-eecs Aug 1, 2022
381910f
merge: resolve conflict
Benjamin-eecs Aug 3, 2022
8fe6dfc
Merge remote-tracking branch 'origin' into torchrl_new
Benjamin-eecs Aug 3, 2022
9c6701b
feat(examples): add torchrl MAML-RL results
Benjamin-eecs Aug 3, 2022
30dd7f7
Merge remote-tracking branch 'origin' into torchrl_new
Benjamin-eecs Aug 6, 2022
609df80
fix: update torchrl inner lr to 0.1
Benjamin-eecs Aug 6, 2022
d25a8a3
fix(examples): add device flag for torchrl MAML-RL example
Benjamin-eecs Aug 6, 2022
3f9c8fa
fix(example): update torchrl MAML performance with lr=0.1
Benjamin-eecs Aug 6, 2022
a6f75c2
fix(example): add plot.py for debugging
Benjamin-eecs Aug 6, 2022
d5cfe6b
independent seed for each env
vmoens Aug 6, 2022
53ad3d2
update seeding
vmoens Aug 7, 2022
739544b
docs: update contributing.rst
XuehaiPan Aug 11, 2022
ca6c44d
Merge remote-tracking branch 'origin' into torchrl_new
Benjamin-eecs Aug 11, 2022
8f2de3f
merge: sync with upstream
Benjamin-eecs Aug 11, 2022
fa1856c
fix: revert back to original torchrl results
Benjamin-eecs Aug 11, 2022
f66f717
Merge branch 'main' into torchrl_new
XuehaiPan Aug 11, 2022
b9c79c7
style: reformat code
XuehaiPan Aug 11, 2022
e707135
Merge branch 'main' into torchrl_new
XuehaiPan Aug 11, 2022
d0da29d
docs(CHANGELOG): update CHANGELOG.md
XuehaiPan Aug 11, 2022
e1b25b6
merge: resolve conflicts
Benjamin-eecs Aug 29, 2022
52eb07e
merge: resolve conflicts
Benjamin-eecs Aug 29, 2022
b03424b
fix: update gae
Benjamin-eecs Aug 29, 2022
cf66401
merge: resolve conflicts
Benjamin-eecs Aug 29, 2022
88b98cf
fix: torchrl MAML-RL extract_state_dict
Benjamin-eecs Aug 29, 2022
00878da
fix: delete log
Benjamin-eecs Aug 29, 2022
d726f62
fix: correct exploration mode
Benjamin-eecs Aug 30, 2022
69f71b7
fix: correct exploration mode
Benjamin-eecs Aug 30, 2022
c3e9286
revert: exploration mode
Benjamin-eecs Aug 31, 2022
89e8a83
fix: update torchrl inner lr
Benjamin-eecs Aug 31, 2022
6612d2a
fix: update torchrl inner lr
Benjamin-eecs Aug 31, 2022
a5924af
fix: update torchrl MAML-RL results
Benjamin-eecs Sep 1, 2022
c5ed460
fix: update torchrl MAML-RL results
Benjamin-eecs Sep 1, 2022
6e77107
fix: update torchrl MAML-RL results
Benjamin-eecs Sep 1, 2022
c4c2295
fix: pass lint
Benjamin-eecs Sep 1, 2022
df429bc
fix: pass lint
Benjamin-eecs Sep 1, 2022
64cdaba
doc: update contributor list
Benjamin-eecs Sep 1, 2022
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -13,6 +13,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

 ### Added

+- Add MAML example with TorchRL integration by [@vmoens](https://github.com/vmoens) and [@Benjamin-eecs](https://github.com/Benjamin-eecs) in [#12](https://github.com/metaopt/TorchOpt/pull/12).
 - Add optional argument `params` to update function in gradient transformations by [@XuehaiPan](https://github.com/XuehaiPan) in [#65](https://github.com/metaopt/torchopt/pull/65).
 - Add `weight_decay` option to optimizers by [@XuehaiPan](https://github.com/XuehaiPan) in [#65](https://github.com/metaopt/torchopt/pull/65).
 - Add `maximize` option to optimizers by [@XuehaiPan](https://github.com/XuehaiPan) in [#64](https://github.com/metaopt/torchopt/pull/64).
1 change: 1 addition & 0 deletions docs/source/developer/contributor.rst
@@ -4,3 +4,4 @@ Contributor
 We always welcome contributions to help make TorchOpt better. Below is an incomplete list of our contributors (find more on `this page <https://github.com/metaopt/torchopt/graphs/contributors>`_).

 * Yao Fu (`future-xy <https://github.com/future-xy>`_)
+* Vincent Moens (`vmoens <https://github.com/vmoens>`_)
2 changes: 2 additions & 0 deletions examples/MAML-RL/README.md
@@ -9,6 +9,8 @@ Specify the seed to train.
 ```bash
 ### Run MAML
 python maml.py --seed 1
+### Run torchrl MAML implementation
+python maml_torchrl.py --seed 1
 ```

 ## Results
9 changes: 2 additions & 7 deletions examples/MAML-RL/helpers/policy.py
@@ -16,7 +16,6 @@
 # https://github.com/tristandeleu/pytorch-maml-rl
 # ==============================================================================

-import torch
 import torch.nn as nn
 from torch.distributions import Categorical

@@ -27,12 +26,8 @@ class CategoricalMLPPolicy(nn.Module):
     with discrete action spaces (eg. `TabularMDPEnv`).
     """

-    def __init__(
-        self,
-        input_size,
-        output_size,
-    ):
-        super(CategoricalMLPPolicy, self).__init__()
+    def __init__(self, input_size, output_size):
+        super().__init__()
         self.torso = nn.Sequential(
             nn.Linear(input_size, 32),
             nn.ReLU(),
88 changes: 88 additions & 0 deletions examples/MAML-RL/helpers/policy_torchrl.py
@@ -0,0 +1,88 @@
+# Copyright 2022 MetaOPT Team. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+import torch
+import torch.nn as nn
+from torch.distributions import Categorical
+from torchrl.modules import (
+    ActorValueOperator,
+    OneHotCategorical,
+    ProbabilisticActor,
+    TensorDictModule,
+    ValueOperator,
+)
+
+
+class Backbone(nn.Module):
+    def __init__(self, input_size, output_size):
+        super().__init__()
+        self.torso = nn.Sequential(
+            nn.Linear(input_size, 32),
+            nn.ReLU(),
+            nn.Linear(32, 32),
+            nn.ReLU(),
+        )
+
+    def forward(self, inputs, params=None):
+        embedding = self.torso(inputs)
+        return embedding
+
+
+class CategoricalSubNet(nn.Module):
+    def __init__(self, input_size, output_size):
+        super().__init__()
+        self.policy_head = nn.Linear(32, output_size)
+
+    def forward(self, embedding, params=None):
+        logits = self.policy_head(embedding)
+        return logits
+
+
+class ValueSubNet(nn.Module):
+    def __init__(self, input_size, output_size):
+        super().__init__()
+        self.value_head = nn.Linear(32, 1)
+
+    def forward(self, embedding, params=None):
+        value = self.value_head(embedding)
+        return value
+
+
+class ActorCritic(ActorValueOperator):
+    def __init__(self, input_size, output_size):
+        super().__init__(
+            TensorDictModule(
+                spec=None,
+                module=Backbone(input_size, output_size),
+                in_keys=['observation'],
+                out_keys=['hidden'],
+            ),
+            ProbabilisticActor(
+                spec=None,
+                module=TensorDictModule(
+                    CategoricalSubNet(input_size, output_size),
+                    in_keys=['hidden'],
+                    out_keys=['logits'],
+                ),
+                distribution_class=OneHotCategorical,
+                return_log_prob=False,
+                dist_param_keys=['logits'],
+                out_key_sample=['action'],
+            ),
+            ValueOperator(
+                module=ValueSubNet(input_size, output_size),
+                in_keys=['hidden'],
+            ),
+        )
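
Note on `policy_torchrl.py`: for readers unfamiliar with the TorchRL wrappers, the structure that `ActorCritic` assembles (a shared torso feeding a categorical policy head and a scalar value head) can be sketched in plain PyTorch. This is a rough illustration under assumptions, not the PR's implementation: `PlainActorCritic`, the `hidden_size` argument, and the sizes used in the usage lines are hypothetical, and `torch.distributions.OneHotCategorical` stands in here for the TorchRL distribution class of the same name.

```python
import torch
import torch.nn as nn
from torch.distributions import OneHotCategorical


class PlainActorCritic(nn.Module):
    """Hypothetical plain-PyTorch analogue of a shared-torso actor-critic."""

    def __init__(self, input_size, output_size, hidden_size=32):
        super().__init__()
        # Shared feature extractor, mirroring the two-layer torso in the diff.
        self.torso = nn.Sequential(
            nn.Linear(input_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
        )
        # Two heads reading the same hidden embedding.
        self.policy_head = nn.Linear(hidden_size, output_size)
        self.value_head = nn.Linear(hidden_size, 1)

    def forward(self, observation):
        hidden = self.torso(observation)
        logits = self.policy_head(hidden)
        value = self.value_head(hidden)
        # One-hot encoded discrete actions, parameterised by logits.
        return OneHotCategorical(logits=logits), value


# Usage with made-up sizes: a batch of 4 observations, 5 discrete actions.
torch.manual_seed(0)
model = PlainActorCritic(input_size=10, output_size=5)
observation = torch.randn(4, 10)
dist, value = model(observation)
action = dist.sample()
```

In the TorchRL version the same data flow is expressed through tensordict keys (`observation`, `hidden`, `logits`, `action`) rather than through return values.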
7 changes: 5 additions & 2 deletions examples/MAML-RL/helpers/tabular_mdp.py
@@ -38,8 +38,11 @@ class TabularMDPEnv(gym.Env):
         Learning", 2016 (https://arxiv.org/abs/1611.02779)
     """

-    def __init__(self, num_states, num_actions, max_episode_steps, seed, task={}):
-        super(TabularMDPEnv, self).__init__()
+    def __init__(self, num_states, num_actions, max_episode_steps, seed, task=None):
+        super().__init__()
+
+        task = task or {}
+
         self.max_episode_steps = max_episode_steps
         self.num_states = num_states
         self.num_actions = num_actions
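
Note on the change from `task={}` to `task=None` in `tabular_mdp.py`: Python evaluates a default argument once, when the function is defined, so a `{}` default is a single dict shared by every call that omits the argument. The sketch below illustrates the pitfall with two hypothetical classes (`SharedDefaultEnv` and `NoneDefaultEnv` are illustrative names, not part of this PR). Whether the original default was ever mutated in `TabularMDPEnv` is not visible in this hunk, so treat this as background for the pattern rather than a description of an observed bug.

```python
class SharedDefaultEnv:
    """Hypothetical class using a mutable default argument."""

    def __init__(self, task={}):  # the same dict object is reused across calls
        self.task = task


class NoneDefaultEnv:
    """Hypothetical class using the None-sentinel pattern from the diff."""

    def __init__(self, task=None):
        task = task or {}  # a fresh dict whenever no (or an empty) task is given
        self.task = task


# With the mutable default, a write through one instance shows up in another.
a = SharedDefaultEnv()
b = SharedDefaultEnv()
a.task['goal'] = 1
leaked_with_shared_default = 'goal' in b.task

# With the None sentinel, each instance gets its own dict.
c = NoneDefaultEnv()
d = NoneDefaultEnv()
c.task['goal'] = 1
leaked_with_none_default = 'goal' in d.task
```

One side effect of `task or {}` worth knowing: an explicitly passed empty dict is also replaced by a fresh one, because empty dicts are falsy.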
Binary file modified examples/MAML-RL/maml.png
Binary file added examples/MAML-RL/maml_torchrl.png