-
Notifications
You must be signed in to change notification settings - Fork 26.3k
Introduce a match filter for SubgraphRewriter #86430
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
[ghstack-poisoned]
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/86430
Note: Links to docs will display an error until the docs builds have been completed. ✅ No Failures, 1 PendingAs of commit d88b8c8: This comment was automatically generated by Dr. CI and updates every 15 minutes. |
|
|
||
| # Filter out matches that don't match the filter | ||
| if match_filter: | ||
| _matches = [m for m in _matches if match_filter(m, original_graph, pattern_graph)] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
do we need original_graph? I don't know how original_graph will be helpful in filtering the match since it won't know where to look
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I want to keep the function signature general, so that I won't need to come back to this and add another field.
My guess is that original_graph might be useful for filtering out the matches, if it need to check the surrounding nodes of the match.
torch/fx/subgraph_rewriter.py
Outdated
|
|
||
|
|
||
| def _replace_pattern(gm: GraphModule, pattern: Callable, replacement: Callable, | ||
| match_filter: Optional[Callable[["InternalMatch", Graph, Graph], bool]] = None) -> List[Match]: # type: ignore[name-defined] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
please fix the lint
jerryzh168
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
thanks! lg overall, had a question about argument for filter
could you add a Summary and Test for the PR, would be useful for context and future reference, something like this: #86338 (comment) |
[ghstack-poisoned]
Updated. Thanks for reviewing this! |
|
@pytorchbot merge -g |
Merge startedYour change will be merged once all checks on your PR pass since you used the green (-g) flag (ETA: 0-4 Hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team |
|
Hey @SherlockNoMad. |
Summary: This PR introduces an interface for user defined function that filters the matches in SubgraphRewriter. The function will have the following signature. callable(match: InternalMatch, original_graph: Graph, pattern_graph: Graph) -> bool This filter is applied after SubgraphMatcher returns the matches, and before replacement takes place. Pull Request resolved: #86430 Approved by: https://github.com/jerryzh168 Test Plan: contbuild & OSS CI, see https://hud.pytorch.org/commit/pytorch/pytorch/455b873919d928a073eb2d60e07d1c5b2de2d6c6 Reviewed By: seemethere Differential Revision: D40196820 Pulled By: seemethere fbshipit-source-id: 3c854843094f79c6b23501b513fa029e5adf60a8
Stack from ghstack (oldest at bottom):
This PR introduces an interface for user defined function that filters the matches in SubgraphRewriter. The function will have the following signature.
callable(match: InternalMatch, original_graph: Graph, pattern_graph: Graph) -> bool
This filter is applied after SubgraphMatcher returns the matches, and before replacement takes place.