Skip to content

Commit

Permalink
Model Speedup Refactor (#3462)
Browse files Browse the repository at this point in the history
  • Loading branch information
zheng-ningxin committed Jul 14, 2021
1 parent 5b99b59 commit 7eedec4
Show file tree
Hide file tree
Showing 13 changed files with 2,218 additions and 1,701 deletions.
3 changes: 0 additions & 3 deletions docs/en_US/Compression/CompressionReference.rst
Original file line number Diff line number Diff line change
Expand Up @@ -140,9 +140,6 @@ Topology Utilities
.. autoclass:: nni.compression.pytorch.utils.shape_dependency.GroupDependency
:members:

.. autoclass:: nni.compression.pytorch.utils.mask_conflict.CatMaskPadding
:members:

.. autoclass:: nni.compression.pytorch.utils.mask_conflict.GroupMaskConflict
:members:

Expand Down
14 changes: 12 additions & 2 deletions nni/common/graph_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,11 @@ def __init__(self, model=None, dummy_input=None, traced_model=None):
def _trace(self, model, dummy_input):
training = model.training
model.eval()
self.trace = torch.jit.trace(model, dummy_input)
kw_args = {}
if torch.__version__ >= '1.6.0':
# only pytorch with version greater than 1.6.0 has the strict option
kw_args['strict'] = False
self.trace = torch.jit.trace(model, dummy_input, **kw_args)
torch._C._jit_pass_inline(self.trace.graph)
model.train(training)

Expand Down Expand Up @@ -247,6 +251,7 @@ class TorchModuleGraph(TorchGraph):
def __init__(self, model=None, dummy_input=None, traced_model=None):
super().__init__(model, dummy_input, traced_model)
self.global_count = 0
self.reused_module = set()
self.name_to_node, self.input_to_node, self.output_to_node = self._build_graph()
self._extract_auxiliary_info()

Expand Down Expand Up @@ -390,9 +395,12 @@ def _expand_module_node(self, node, node_name, unique_name, op_type, nodes,
outputs.append(output_name)
else:
outputs.append(output_name)
unique_outputs = list(set(outputs))
# remove the dumplicated output names
unique_outputs.sort(key=outputs.index)

nodepy = NodePyGroup(node_name, unique_name, module_type, op_type,
node_group, inputs=list(inputs), outputs=list(outputs))
node_group, inputs=list(inputs), outputs=unique_outputs)
return nodepy

def _extract_cat_info(self, node_group, cpp_node):
Expand Down Expand Up @@ -724,6 +732,8 @@ def _build_graph(self):
unique_name = module_name
if use_count > 0:
unique_name = module_name + '.%d' % use_count
self.reused_module.add(unique_name)
self.reused_module.add(module_name)
node_group = self._expand_module_node(
node, module_name, unique_name, module_to_type[module_name],
node_cpps, input_to_node, output_to_node, 'module')
Expand Down
Loading

0 comments on commit 7eedec4

Please sign in to comment.