88from typing import Callable , Dict , List , NamedTuple , Optional , Set
99import torch
1010
11- __all__ = ['Match' , 'replace_pattern' , 'replace_pattern_with_filters ' ]
11+ __all__ = ['Match' , 'replace_pattern' , 'replace_pattern_with_filter ' ]
1212
1313@compatibility (is_backward_compatible = True )
1414class Match (NamedTuple ):
@@ -185,11 +185,11 @@ def forward(self, x, w1, w2):
185185
186186# Experimental API, not backward compatible
187187@compatibility (is_backward_compatible = False )
188- def replace_pattern_with_filters (
188+ def replace_pattern_with_filter (
189189 gm : GraphModule ,
190190 pattern : Callable ,
191191 replacement : Callable ,
192- match_filters : List [ Callable [["InternalMatch" , Graph , Graph ], bool ] ], # type: ignore[name-defined]
192+ match_filter : Callable [["InternalMatch" , Graph , Graph ], bool ], # type: ignore[name-defined]
193193) -> List [Match ]:
194194 """
195195 See replace_pattern for documentation. This function is an overload with an additional match_filter argument.
@@ -200,21 +200,18 @@ def replace_pattern_with_filters(
200200 definition of InternalMatch.
201201 """
202202
203- return _replace_pattern (gm , pattern , replacement , match_filters )
203+ return _replace_pattern (gm , pattern , replacement , match_filter )
204204
205205
206206def _replace_pattern (
207207 gm : GraphModule ,
208208 pattern : Callable ,
209209 replacement : Callable ,
210- match_filters : List [Callable [["InternalMatch" , Graph , Graph ], bool ]] = None # type: ignore[name-defined]
210+ match_filter : Optional [Callable [["InternalMatch" , Graph , Graph ], bool ]] = None # type: ignore[name-defined]
211211) -> List [Match ]:
212212
213213 from torch .fx .passes .utils .matcher_utils import SubgraphMatcher , InternalMatch
214214
215- if match_filters is None :
216- match_filters = []
217-
218215 # Get the graphs for `gm`, `pattern`, `replacement`
219216 original_graph : Graph = gm .graph
220217 pattern_graph : Graph = symbolic_trace (pattern ).graph
@@ -225,11 +222,8 @@ def _replace_pattern(
225222 _matches : List [InternalMatch ] = matcher .match (original_graph )
226223
227224 # Filter out matches that don't match the filter
228- _matches = [
229- m for m in _matches
230- if all (match_filter (m , original_graph , pattern_graph )
231- for match_filter in match_filters )
232- ]
225+ if match_filter :
226+ _matches = [m for m in _matches if match_filter (m , original_graph , pattern_graph )]
233227
234228 replacement_placeholders = [n for n in replacement_graph .nodes if n .op == "placeholder" ]
235229
0 commit comments