[WIP] Adding Distributed RL Model Example #692
base: main
Conversation
In this example, the RL model is distributed across one agent and multiple observers. Each observer has a replicated submodel of the RL model, and all of these connect to the submodel on the agent. During training, this example uses distributed autograd to set gradients for all submodels. Then, it uses RPC calls to collect gradients from all observers onto the agent, sums those gradients, applies the summed gradient to the local dummy model on the agent, and then broadcasts the model parameters back to the observers to update their models.
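The whole training step can be sketched framework-free. The `Observer` class, `get_gradients`, `update_model`, and `train_step` below are hypothetical stand-ins for the RPC calls in the example, not the PR's actual API:

```python
# Hypothetical framework-free sketch of one training step: collect grads
# from all observers, sum them, apply to the agent's local copy of the
# params, then broadcast the updated params back. Names are illustrative.
class Observer:
    def __init__(self, params):
        self.params = list(params)

    def get_gradients(self):
        # stand-in for dist_autograd.get_gradients(ctx_id) on the observer
        return [0.1 * p for p in self.params]

    def update_model(self, params):
        self.params = list(params)

def train_step(observers, lr=1.0):
    # 1. collect per-observer grads (the example does this via async RPC)
    grads = [ob.get_gradients() for ob in observers]
    # 2. sum grads elementwise across observers
    summed = [sum(g) for g in zip(*grads)]
    # 3. apply summed grads to the agent's local copy of the params
    params = [p - lr * g for p, g in zip(observers[0].params, summed)]
    # 4. broadcast updated params back to every observer
    for ob in observers:
        ob.update_model(params)
    return params

observers = [Observer([1.0, 2.0]) for _ in range(3)]
new_params = train_step(observers)
```

With three observers each reporting a tenth of the param value as its grad, the step above moves `[1.0, 2.0]` to `[0.7, 1.4]` and every observer ends up with the same params.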
cc @jiajunshen
This is a quick example showing how to use RPC to build distributed RL models. I didn't aim for efficiency or code simplicity in this WIP version, but hopefully it demonstrates the ideas.
```python
x = self.affine1(x)
x = self.dropout(x)
x = F.relu(x)
return self.affine2(x)
```
Each observer applies four layers (affine, dropout, relu, affine) in the forward pass.
```python
self.rewards = []

def forward(self, action_scores):
    return F.softmax(action_scores, dim=1)
```
The agent only applies a softmax in its forward pass.
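Putting the two halves together locally (without RPC) gives the full policy. This is a sketch: the diff only shows the forward bodies, so the layer sizes below (4 → 128 → 2, matching the classic reinforce example) and the dropout rate are my assumptions:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Sketch of the split policy run locally, without RPC.
# Layer sizes and dropout rate are assumed for illustration.
class ObserverPolicy(nn.Module):
    def __init__(self):
        super().__init__()
        self.affine1 = nn.Linear(4, 128)
        self.dropout = nn.Dropout(p=0.6)
        self.affine2 = nn.Linear(128, 2)

    def forward(self, x):
        x = self.affine1(x)
        x = self.dropout(x)
        x = F.relu(x)
        return self.affine2(x)

class AgentPolicy(nn.Module):
    def forward(self, action_scores):
        return F.softmax(action_scores, dim=1)

ob, agent = ObserverPolicy(), AgentPolicy()
probs = agent(ob(torch.randn(1, 4)))  # each row sums to 1
```

In the PR, the observer half runs on the observer workers and only the raw action scores travel over RPC to the agent's softmax.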
```python
self.agent_rref = RRef(self)
self.rewards = {}
self.saved_log_probs = {}
self.ob_policy = ObserverPolicy()
```
The agent also creates a dummy ObserverPolicy so that it can use the same optimizer to update all model parameters. Note that this ObserverPolicy never participates in the forward or backward pass. It is only used to apply the summed gradients.
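The dummy-model trick can be sketched without torch. The toy `Param` class and `sgd_step` below are my stand-ins, not the PR's code; the point is that one optimizer step covers both the agent's own params and the dummy observer copy:

```python
# Toy sketch of the dummy-model pattern (no torch). The dummy observer
# copy never runs forward/backward; its .grad fields are filled with the
# summed observer grads, and one optimizer step updates both models.
class Param:
    def __init__(self, value):
        self.value = value
        self.grad = 0.0

agent_params = [Param(1.0)]
dummy_ob_params = [Param(0.5)]   # mirrors ObserverPolicy on the agent

def sgd_step(params, lr=0.1):
    for p in params:
        p.value -= lr * p.grad

agent_params[0].grad = 1.0       # from the local backward pass
dummy_ob_params[0].grad = 3.0    # summed grads collected from observers
sgd_step(agent_params + dummy_ob_params)
```

After the step, both the agent param and the dummy observer param have moved, even though the dummy never computed anything itself.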
```python
grads = [fut.wait() for fut in futs]
grads = [*zip(*grads)]
grads = [sum(grad) for grad in grads]
```
The lines above just sum the grads from all observers.
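Concretely, with two observers and three parameters (toy numbers), `zip(*grads)` regroups per-observer grad lists into per-parameter tuples, which are then summed:

```python
# grads[i] is observer i's gradient list, one entry per parameter.
grads = [[1.0, 2.0, 3.0],     # observer 0
         [10.0, 20.0, 30.0]]  # observer 1
per_param = [*zip(*grads)]            # [(1.0, 10.0), (2.0, 20.0), (3.0, 30.0)]
summed = [sum(g) for g in per_param]  # [11.0, 22.0, 33.0]
```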
```python
# set grads for agent model
ctx_grads = dist_autograd.get_gradients(ctx_id)
for p in self.policy.parameters():
    p.grad = ctx_grads[p]
```
Then we set the grad field for all parameters. (Note that ob_policy is a dummy model, used only for gradient updates.)
```python
for ob_rref in self.ob_rrefs:
    futs.append(_async_remote_method(Observer.update_model, ob_rref, ob_params))
for fut in futs:
    fut.wait()
```
Params on both the dummy model and the AgentPolicy are updated. Now broadcast the dummy model's params to all observers so they can apply the update locally.
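The fire-all-then-wait broadcast pattern can be sketched with `concurrent.futures` in place of RPC futures; the `Observer` stand-in here is hypothetical:

```python
from concurrent.futures import ThreadPoolExecutor

# Sketch of the broadcast step: launch update_model on every observer
# asynchronously, then wait on all futures, mirroring rpc_async + wait.
class Observer:
    def __init__(self):
        self.params = None

    def update_model(self, params):
        self.params = list(params)

observers = [Observer() for _ in range(4)]
new_params = [0.9, 0.2]
with ThreadPoolExecutor() as pool:
    futs = [pool.submit(ob.update_model, new_params) for ob in observers]
    for fut in futs:
        fut.result()  # analogous to fut.wait() in the example
```

Launching all calls before waiting on any lets the updates run concurrently instead of serializing one round trip per observer.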
```python
for fut in futs:
    fut.wait()
```
I believe the steps above can be replaced by c10d gather/scatter and wrapped into your own version of a distributed optimizer. That would be more efficient and cleaner.
```python
self.policy = ObserverPolicy()

def get_gradients(self, ctx_id):
    all_grads = dist_autograd.get_gradients(ctx_id)
```
After the distributed backward pass, the grads for the ObserverPolicy model live in dist_autograd.get_gradients(ctx_id) on the observer. Retrieve them there and hand them to the agent.
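For the later summation on the agent to line up, every observer must flatten its grad dict in the same parameter order. A minimal sketch (the `ordered_grads` helper and the string keys are hypothetical; the real dict is keyed by parameter tensors):

```python
# Sketch: flattening each observer's grad dict in a fixed parameter order
# keeps the later zip/sum on the agent aligned across observers.
def ordered_grads(ctx_grads, param_order):
    return [ctx_grads[name] for name in param_order]

param_order = ["affine1.weight", "affine1.bias", "affine2.weight"]
ob0 = {"affine1.bias": 0.2, "affine1.weight": 0.1, "affine2.weight": 0.3}
ob1 = {"affine2.weight": 3.0, "affine1.weight": 1.0, "affine1.bias": 2.0}

grads = [ordered_grads(ob, param_order) for ob in (ob0, ob1)]
summed = [sum(g) for g in zip(*grads)]
```

Even though the two dicts list their keys in different orders, both flatten to the same parameter order, so the elementwise sum pairs the right grads.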