Skip to content

Commit d13b678

Browse files
Revert "[fx][subgraph_rewriter] Change match_filter to be a List in replace_pattern_with_filters (#87257)"
This reverts commit 5865083. Reverted #87257 on behalf of https://github.com/weiwangmeta due to breaking internal builds/BC-breaking change
1 parent fc21b9d commit d13b678

File tree

2 files changed

+10
-16
lines changed

2 files changed

+10
-16
lines changed

test/fx/test_subgraph_rewriter.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -773,7 +773,7 @@ def gemm_bias_mul_replacement_with_c(a, b, bias, c):
773773

774774
self.assertEqual(repalcement_node_found, 2)
775775

776-
def test_replace_pattern_with_filters(self):
776+
def test_replace_pattern_with_filter(self):
777777
class M(torch.nn.Module):
778778
def __init__(self):
779779
super().__init__()
@@ -833,10 +833,10 @@ def num_repalcement_node_found(traced):
833833

834834
# match with filter, should find 1 match
835835
traced = symbolic_trace(M())
836-
matches = subgraph_rewriter.replace_pattern_with_filters(
836+
matches = subgraph_rewriter.replace_pattern_with_filter(
837837
traced,
838838
BinaryOpScalarReLUPattern,
839839
BinaryOpScalarReLUReplacement,
840-
[second_input_is_scalar])
840+
second_input_is_scalar)
841841
self.assertEqual(len(matches), 1)
842842
self.assertEqual(num_repalcement_node_found(traced), 1)

torch/fx/subgraph_rewriter.py

Lines changed: 7 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from typing import Callable, Dict, List, NamedTuple, Optional, Set
99
import torch
1010

11-
__all__ = ['Match', 'replace_pattern', 'replace_pattern_with_filters']
11+
__all__ = ['Match', 'replace_pattern', 'replace_pattern_with_filter']
1212

1313
@compatibility(is_backward_compatible=True)
1414
class Match(NamedTuple):
@@ -185,11 +185,11 @@ def forward(self, x, w1, w2):
185185

186186
# Experimental API, not backward compatible
187187
@compatibility(is_backward_compatible=False)
188-
def replace_pattern_with_filters(
188+
def replace_pattern_with_filter(
189189
gm: GraphModule,
190190
pattern: Callable,
191191
replacement: Callable,
192-
match_filters: List[Callable[["InternalMatch", Graph, Graph], bool]], # type: ignore[name-defined]
192+
match_filter: Callable[["InternalMatch", Graph, Graph], bool], # type: ignore[name-defined]
193193
) -> List[Match]:
194194
"""
195195
See replace_pattern for documentation. This function is an overload with an additional match_filter argument.
@@ -200,21 +200,18 @@ def replace_pattern_with_filters(
200200
definition of InternalMatch.
201201
"""
202202

203-
return _replace_pattern(gm, pattern, replacement, match_filters)
203+
return _replace_pattern(gm, pattern, replacement, match_filter)
204204

205205

206206
def _replace_pattern(
207207
gm: GraphModule,
208208
pattern: Callable,
209209
replacement: Callable,
210-
match_filters: List[Callable[["InternalMatch", Graph, Graph], bool]] = None # type: ignore[name-defined]
210+
match_filter: Optional[Callable[["InternalMatch", Graph, Graph], bool]] = None # type: ignore[name-defined]
211211
) -> List[Match]:
212212

213213
from torch.fx.passes.utils.matcher_utils import SubgraphMatcher, InternalMatch
214214

215-
if match_filters is None:
216-
match_filters = []
217-
218215
# Get the graphs for `gm`, `pattern`, `replacement`
219216
original_graph: Graph = gm.graph
220217
pattern_graph: Graph = symbolic_trace(pattern).graph
@@ -225,11 +222,8 @@ def _replace_pattern(
225222
_matches: List[InternalMatch] = matcher.match(original_graph)
226223

227224
# Filter out matches that don't match the filter
228-
_matches = [
229-
m for m in _matches
230-
if all(match_filter(m, original_graph, pattern_graph)
231-
for match_filter in match_filters)
232-
]
225+
if match_filter:
226+
_matches = [m for m in _matches if match_filter(m, original_graph, pattern_graph)]
233227

234228
replacement_placeholders = [n for n in replacement_graph.nodes if n.op == "placeholder"]
235229

0 commit comments

Comments
 (0)