[MPS][Inductor] Add _print_Where method to MetalExprPrinter for ReflectionPad support#169648
[MPS][Inductor] Add _print_Where method to MetalExprPrinter for ReflectionPad support#169648lingebeng wants to merge 2 commits intopytorch:mainfrom
Conversation
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/169648
Note: Links to docs will display an error until the docs builds have been completed. ✅ No FailuresAs of commit 32c113e with merge base 7375582 ( This comment was automatically generated by Dr. CI and updates every 15 minutes. |
|
@lingebeng Let's wait for CI, but something tells me you'll have to uncomment one of |
|
@pytorchbot merge -f "Lint + MPS are green" |
Merge startedYour change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes). Please use Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team |
…ctionPad support (pytorch#169648) The `MetalExprPrinter` class was missing a `_print_Where()` method, causing sympy `Where` expressions to be emitted as raw `Where(...)` function calls in generated Metal Shading Language (MSL) code. Since `Where` is not a standard MSL function,compilation failed with "undeclared identifier 'Where'" error. This PR adds the missing `_print_Where()` method to convert `Where(condition, true_val, false_val)` to Metal's C-style ternary operator `condition ? true_val : false_val`, following the same pattern as `CppPrinter`. Fixes pytorch#169643 Pull Request resolved: pytorch#169648 Approved by: https://github.com/malfet Co-authored-by: Nikita Shulga <2453524+malfet@users.noreply.github.com>
…ctionPad support (#169648) The `MetalExprPrinter` class was missing a `_print_Where()` method, causing sympy `Where` expressions to be emitted as raw `Where(...)` function calls in generated Metal Shading Language (MSL) code. Since `Where` is not a standard MSL function,compilation failed with "undeclared identifier 'Where'" error. This PR adds the missing `_print_Where()` method to convert `Where(condition, true_val, false_val)` to Metal's C-style ternary operator `condition ? true_val : false_val`, following the same pattern as `CppPrinter`. Fixes #169643 Pull Request resolved: #169648 Approved by: https://github.com/malfet Co-authored-by: Nikita Shulga <2453524+malfet@users.noreply.github.com>
The
MetalExprPrinterclass was missing a_print_Where()method, causing sympyWhereexpressions to be emitted as rawWhere(...)function calls in generated Metal Shading Language (MSL) code. SinceWhereis not a standard MSL function,compilation failed with "undeclared identifier 'Where'" error.This PR adds the missing
_print_Where()method to convertWhere(condition, true_val, false_val)to Metal's C-style ternary operatorcondition ? true_val : false_val, following the same pattern asCppPrinter.Fixes #169643
cc @kulinseth @malfet @DenisVieriu97 @jhavukainen @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @kadeng @muchulee8 @amjames @chauhang @aakhundov @coconutruben @jataylo