Skip to content

Commit

Permalink
fix: Bugfix for str_to_net(...) (#12)
Browse files Browse the repository at this point in the history
When extracting the child modules of an `nn.Sequential` instance, the parser used by `str_to_net(...)` was traversing in a recursive manner (which is undesired). The newly proposed code here uses the iterable nature of `nn.Sequential` to correctly extract the child modules of `nn.Sequential`.

When working with modules which are composed of other modules, `str_to_net(...)` was erroneously adding the children modules twice into the resulting `nn.Sequential`. For example, when the network string contains `RecurrentNet` (`RecurrentNet` being a module that wraps a child `RNN` module) the resulting `nn.Sequential` was ending up with the `RecurrentNet` _and_ its child `RNN` appended erroneously. This proposal aims to fix this issue.
  • Loading branch information
engintoklu committed Aug 31, 2022
1 parent c6471fc commit 6efb628
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion src/evotorch/neuroevolution/net/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ def __repr__(self) -> str:

def submodules(a: nn.Module) -> list:
if isinstance(a, nn.Sequential):
return [module for i, module in enumerate(a.modules()) if i >= 1]
return [module for module in a]
else:
return [a]

Expand Down

0 comments on commit 6efb628

Please sign in to comment.