Skip to content

Commit

Permalink
Dont pairwise check unfusable nodes in scheduler (pytorch#136682)
Browse files Browse the repository at this point in the history
Gives 8% wall time speedup on n=1000 benchmark in pytorch#136429

Pull Request resolved: pytorch#136682
Approved by: https://github.com/ezyang, https://github.com/jansel, https://github.com/shunting314
  • Loading branch information
eellison authored and pytorchmergebot committed Sep 26, 2024
1 parent 0b62ebf commit aa56f80
Showing 1 changed file with 11 additions and 0 deletions.
11 changes: 11 additions & 0 deletions torch/_inductor/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -2672,6 +2672,8 @@ def check_all_pairs(nodes: List[BaseSchedulerNode]) -> None:

buffer_names_grouping = collections.defaultdict(list)
for node in nodes:
if self.unfusable_node(node):
continue
for buf in node.used_buffer_names():
buffer_names_grouping[buf].append(node)
for node_grouping in buffer_names_grouping.values():
Expand Down Expand Up @@ -2895,6 +2897,15 @@ def has_shared_data_after_reordering_loop(

return self.score_fusion_memory(node1, node2) > 0

def unfusable_node(self, node: BaseSchedulerNode) -> bool:
"""
Is this node unfusable under any conditions.
"""
return (
isinstance(node, (ExternKernelSchedulerNode, NopKernelSchedulerNode))
and not node.is_template()
)

def can_fuse(self, node1: BaseSchedulerNode, node2: BaseSchedulerNode) -> bool:
"""
Determine if it is possible to combine node1 and node2 into a
Expand Down

0 comments on commit aa56f80

Please sign in to comment.