Skip to content

Commit 2ab3292

Browse files
committed
PatternMatcher supports matching list-typed args
[ghstack-poisoned]
1 parent 0efd4e9 commit 2ab3292

File tree

3 files changed

+43
-9
lines changed

3 files changed

+43
-9
lines changed

test/fx/test_subgraph_rewriter.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -840,3 +840,24 @@ def num_repalcement_node_found(traced):
840840
[second_input_is_scalar])
841841
self.assertEqual(len(matches), 1)
842842
self.assertEqual(num_repalcement_node_found(traced), 1)
843+
844+
def test_matching_pattern_with_list_type_arg(self):
845+
class M(torch.nn.Module):
846+
def forward(self, x):
847+
return torch.ops.aten._reshape_alias_copy.default(x, [1, 2], [3, 4])
848+
849+
def pattern(x, arg0, arg1):
850+
return torch.ops.aten._reshape_alias_copy.default(x, arg0, arg1)
851+
852+
def replacement(x, arg0, arg1):
853+
return torch.ops.aten._reshape_alias_copy.default(x, arg1, arg0)
854+
855+
traced = symbolic_trace(M())
856+
matches = subgraph_rewriter.replace_pattern(traced, pattern, replacement)
857+
858+
self.assertEqual(len(matches), 1)
859+
860+
self.assertExpectedInline(traced.code.strip(), """\
861+
def forward(self, x):
862+
_reshape_alias_copy_default_1 = torch.ops.aten._reshape_alias_copy.default(x, [3, 4], [1, 2]); x = None
863+
return _reshape_alias_copy_default_1""") # noqa: B950

torch/fx/passes/utils/matcher_utils.py

Lines changed: 18 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
from torch.fx.graph import Graph
55
from torch.fx.node import Node
66
from torch.fx._compatibility import compatibility
7-
import torch.utils._pytree as pytree
87
from typing import Dict, List, Set, Any
98
import logging
109
import os
@@ -106,12 +105,11 @@ def _nodes_are_equal(self, pn: Node, gn: Node) -> bool:
106105
def _is_contained(self, nodes_map: Dict[Node, Node]) -> bool:
107106
# `lookup` represents all the nodes in `original_graph`
108107
# that are part of `pattern`
109-
lookup: Dict[Node, Node] = {gn : pn for pn, gn in nodes_map.items()}
110-
for gn, pn in lookup.items():
111-
# Placeholders can be used by other nodes in the graphs
112-
if pn.op == "placeholder":
113-
continue
114108

109+
# Placeholders can be used by other nodes in the graphs
110+
lookup: Dict[Node, Node] = {gn : pn for pn, gn in nodes_map.items() if pn.op != "placeholder"}
111+
112+
for gn, pn in lookup.items():
115113
# nodes returned by output are allowed to be used in other areas of the graph
116114
if pn in self.pattern_returning_nodes:
117115
continue
@@ -188,8 +186,20 @@ def _match_nodes(self, pn: Node, gn: Node, match: InternalMatch) -> bool:
188186
# match for `gn`
189187
match_found = True
190188

191-
pn_flatten_args, _ = pytree.tree_flatten(pn.args)
192-
gn_flatten_args, _ = pytree.tree_flatten(gn.args)
189+
def flatten_args(args) -> List[Any]:
190+
# Recursively flatten args
191+
result : List[Any] = []
192+
for arg in args:
193+
# flatten the list, if only it's a list/tuple of nodes
194+
if isinstance(arg, (list, tuple)) and len(arg) > 0 and isinstance(arg[0], Node):
195+
result.extend(flatten_args(arg))
196+
else:
197+
result.append(arg)
198+
199+
return result
200+
201+
pn_flatten_args = flatten_args(pn.args)
202+
gn_flatten_args = flatten_args(gn.args)
193203

194204
if pn.kwargs.keys() == gn.kwargs.keys():
195205
for key in pn.kwargs.keys():

torch/fx/subgraph_rewriter.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -245,7 +245,10 @@ def _replace_pattern(
245245
assert len(match.placeholder_nodes) == len(replacement_placeholders)
246246
val_map: Dict[Node, Node] = {}
247247
for rn, gn in zip(replacement_placeholders, match.placeholder_nodes):
248-
val_map[rn] = match_changed_node.get(gn, gn)
248+
if isinstance(gn, Node):
249+
val_map[rn] = match_changed_node.get(gn, gn)
250+
else:
251+
val_map[rn] = gn
249252

250253
# Copy the replacement graph over
251254
user_nodes: Set[Node] = set()

0 commit comments

Comments
 (0)