Skip to content

Commit

Permalink
Make MultiOptimizer serializable
Browse files Browse the repository at this point in the history
The current implementation of MultiOptimizer puts grads and vars
inside the optimizer_specs variable. Since the grads and vars are
not serializable, any model that uses MultiOptimizer class cannot
be saved or checkpointed. This PR will only extract the necessary
optimizer_specs that are useful for re-instantiating the
MultiOptimizer instance in get_config() method, leading to a
serializable MultiOptimizer implementation.
  • Loading branch information
JackWindows committed May 31, 2022
1 parent 3bbf711 commit 2529b5c
Showing 1 changed file with 7 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,13 @@ def apply_gradients(self, grads_and_vars, **kwargs):

def get_config(self):
config = super(MultiOptimizer, self).get_config()
config.update({"optimizer_specs": self.optimizer_specs})
optimizer_specs_without_gv = []
for optimizer_spec in self.optimizer_specs:
optimizer_specs_without_gv.append({
"optimizer": optimizer_spec["optimizer"],
"weights": optimizer_spec["weights"]
})
config.update({"optimizer_specs": optimizer_specs_without_gv})
return config

@classmethod
Expand Down

0 comments on commit 2529b5c

Please sign in to comment.