Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 21 additions & 0 deletions test/fx/test_subgraph_rewriter.py
Original file line number Diff line number Diff line change
Expand Up @@ -840,3 +840,24 @@ def num_repalcement_node_found(traced):
[second_input_is_scalar])
self.assertEqual(len(matches), 1)
self.assertEqual(num_repalcement_node_found(traced), 1)

def test_matching_pattern_with_list_type_arg(self):
class M(torch.nn.Module):
def forward(self, x):
return torch.ops.aten._reshape_alias_copy.default(x, [1, 2], [3, 4])

def pattern(x, arg0, arg1):
return torch.ops.aten._reshape_alias_copy.default(x, arg0, arg1)

def replacement(x, arg0, arg1):
return torch.ops.aten._reshape_alias_copy.default(x, arg1, arg0)

traced = symbolic_trace(M())
matches = subgraph_rewriter.replace_pattern(traced, pattern, replacement)

self.assertEqual(len(matches), 1)

self.assertExpectedInline(traced.code.strip(), """\
def forward(self, x):
_reshape_alias_copy_default_1 = torch.ops.aten._reshape_alias_copy.default(x, [3, 4], [1, 2]); x = None
return _reshape_alias_copy_default_1""") # noqa: B950
26 changes: 18 additions & 8 deletions torch/fx/passes/utils/matcher_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
from torch.fx.graph import Graph
from torch.fx.node import Node
from torch.fx._compatibility import compatibility
import torch.utils._pytree as pytree
from typing import Dict, List, Set, Any
import logging
import os
Expand Down Expand Up @@ -106,12 +105,11 @@ def _nodes_are_equal(self, pn: Node, gn: Node) -> bool:
def _is_contained(self, nodes_map: Dict[Node, Node]) -> bool:
# `lookup` represents all the nodes in `original_graph`
# that are part of `pattern`
lookup: Dict[Node, Node] = {gn : pn for pn, gn in nodes_map.items()}
for gn, pn in lookup.items():
# Placeholders can be used by other nodes in the graphs
if pn.op == "placeholder":
continue

# Placeholders can be used by other nodes in the graphs
lookup: Dict[Node, Node] = {gn : pn for pn, gn in nodes_map.items() if pn.op != "placeholder"}

for gn, pn in lookup.items():
# nodes returned by output are allowed to be used in other areas of the graph
if pn in self.pattern_returning_nodes:
continue
Expand Down Expand Up @@ -188,8 +186,20 @@ def _match_nodes(self, pn: Node, gn: Node, match: InternalMatch) -> bool:
# match for `gn`
match_found = True

pn_flatten_args, _ = pytree.tree_flatten(pn.args)
gn_flatten_args, _ = pytree.tree_flatten(gn.args)
def flatten_args(args) -> List[Any]:
# Recursively flatten args
result : List[Any] = []
for arg in args:
# flatten the list, if only it's a list/tuple of nodes
if isinstance(arg, (list, tuple)) and len(arg) > 0 and isinstance(arg[0], Node):
result.extend(flatten_args(arg))
else:
result.append(arg)

return result

pn_flatten_args = flatten_args(pn.args)
gn_flatten_args = flatten_args(gn.args)

if pn.kwargs.keys() == gn.kwargs.keys():
for key in pn.kwargs.keys():
Expand Down
5 changes: 4 additions & 1 deletion torch/fx/subgraph_rewriter.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,7 +245,10 @@ def _replace_pattern(
assert len(match.placeholder_nodes) == len(replacement_placeholders)
val_map: Dict[Node, Node] = {}
for rn, gn in zip(replacement_placeholders, match.placeholder_nodes):
val_map[rn] = match_changed_node.get(gn, gn)
if isinstance(gn, Node):
val_map[rn] = match_changed_node.get(gn, gn)
else:
val_map[rn] = gn

# Copy the replacement graph over
user_nodes: Set[Node] = set()
Expand Down