Skip to content

Commit

Permalink
Fix issues in pytorch#132 (pytorch#150)
Browse files Browse the repository at this point in the history
  • Loading branch information
jansel committed Apr 19, 2022
1 parent dedd9fa commit 222696b
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 3 deletions.
26 changes: 26 additions & 0 deletions tests/test_repros.py
Original file line number Diff line number Diff line change
Expand Up @@ -925,3 +925,29 @@ def test_reformer_sorting(self):
)
self.assertEqual(cnt.frame_count, 1)
self.assertEqual(cnt.op_count, ifdyn(28, 14))

def test_recursive_map(self):
# https://github.com/facebookresearch/torchdynamo/issues/132
def _recursive_map(struct, batch_dim=0):
for k, v in struct.items():
if v is not None:
if isinstance(v, dict):
_recursive_map(v)
else:
struct[k] = v

def toy_example(a, b, v):
x = a / (torch.abs(a) + 1)
if v is not None:
_recursive_map(v)
return x * b

cnt = torchdynamo.testing.CompileCounter()
with torchdynamo.optimize(cnt):
toy_example(
torch.randn(10),
torch.randn(10),
{"layer0": {"memory_keys": torch.randn(10)}},
)
self.assertEqual(cnt.frame_count, 1)
self.assertEqual(cnt.op_count, 4)
8 changes: 5 additions & 3 deletions torchdynamo/symbolic_convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -501,6 +501,7 @@ def COMPARE_OP(self, inst):
BaseListVariable,
UserDefinedVariable,
BaseUserFunctionVariable,
ConstDictVariable,
),
)
and isinstance(right, ConstantVariable)
Expand Down Expand Up @@ -1031,7 +1032,9 @@ def __init__(

# TODO(jansel): figure out why the following is needed for detectron2_maskrcnn
for val in self.symbolic_locals.values():
if isinstance(val, (ListIteratorVariable, BaseListVariable)):
if isinstance(
val, (ListIteratorVariable, BaseListVariable, ConstDictVariable)
):
self.output.guards.update(val.guards)

self._freevars_ids = dict()
Expand Down Expand Up @@ -1160,6 +1163,7 @@ def inline_call_(parent, func, args, kwargs):
assert tracer.symbolic_result.as_python_constant() is None
return ListIteratorVariable(
tracer.generated_items,
mutable_local=MutableLocal(),
**VariableTracker.propagate(tracer.symbolic_result),
)
else:
Expand Down Expand Up @@ -1188,8 +1192,6 @@ def __init__(
)
self.symbolic_result = None
self.closure_cells = closure_cells
# self.funcvar = funcvar
# self.parent = parent

def STORE_DEREF(self, inst):
if inst.argval in self.closure_cells:
Expand Down

0 comments on commit 222696b

Please sign in to comment.