Add a compile-time flag to trigger verbose logging for device-side asserts #166171
drdarshan wants to merge 1 commit into pytorch:main from
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/166171
Note: Links to docs will display an error until the docs builds have been completed. ✅ No Failures as of commit f2d7805 with merge base 2df2c31. This comment was automatically generated by Dr. CI and updates every 15 minutes.
@drdarshan has exported this pull request. If you are a Meta employee, you can view the originating Diff in D85185987.
+1 from me, based on our extensive discussions about this. Just 1 minor suggestion.
@pytorchbot merge |
Merge failed. Reason: This PR has internal changes and must be landed via Phabricator! Please try re-importing/re-exporting the PR! Details for Dev Infra team: raised by workflow job.
Add a compile-time flag to trigger verbose logging for device-side asserts (pytorch#166171)
2842e3e to f2d7805 (Compare)
@pytorchbot merge (Initiating merge automatically since Phabricator Diff has merged)
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Add a compile-time flag to trigger verbose logging for device-side asserts (#166171)
Differential Revision: D85185987
Pull Request resolved: #166171
Approved by: https://github.com/ngimel
Summary:
Using `CUDA_KERNEL_ASSERT_PRINTF` inside kernels allows us to log invalid values to the console (these can in turn be used to surface _hopefully_ clearer error messages). This does have an impact on the number of registers needed for the values being logged (I confirmed by diffing PTX that there is no other impact relative to using `__assert_fail`). To avoid causing perf bottlenecks, this change adds a compile-time switch that enables more verbose errors in some of the common kernels that cause DSAs. There is also a Buck config that can be used to set this switch more conveniently.
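As a rough illustration of the approach, here is a minimal sketch of what such a compile-time gate could look like. The macro names `VERBOSE_KERNEL_ASSERT` and `C10_ENABLE_VERBOSE_ASSERT` below are assumptions for exposition, not the identifiers actually introduced by this PR:

```
// Sketch of a compile-time gate for verbose device-side asserts.
// When the flag is defined (e.g. via the Buck config), the wrapper
// expands to the printf-style assert that logs the offending values;
// otherwise it falls back to the plain assert, avoiding the extra
// register pressure from the format arguments.
#ifdef C10_ENABLE_VERBOSE_ASSERT
#define VERBOSE_KERNEL_ASSERT(cond, fmt, ...) \
  CUDA_KERNEL_ASSERT_PRINTF(cond, fmt, ##__VA_ARGS__)
#else
#define VERBOSE_KERNEL_ASSERT(cond, fmt, ...) CUDA_KERNEL_ASSERT(cond)
#endif
```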
## Alternatives considered

I considered making the behavior of `CUDA_KERNEL_ASSERT_PRINTF` controllable via a compile-time macro instead of writing another wrapper for it, but there are kernels where the extra register pressure is not as severe; in those cases, having more useful error messages by default is worthwhile.

Test Plan:
## Simple Python Driver

```
# scatter_errors.py
import torch

def main() -> None:
    a = torch.rand(128, device="cuda:0")
    idx = torch.randint(0, 128, (100,), device="cuda:0")
    idx[0] = 9999  # out-of-bounds index triggers the device-side assert
    b = torch.scatter(a, 0, idx, 555.0)
    print(b)

if __name__ == "__main__":
    main()
```

When running normally via:

```
$ buck2 run @//mode/opt :scatter_errors
```

we see the following DSA message:

```
fbcode/caffe2/aten/src/ATen/native/cuda/ScatterGatherKernel.cu:410: operator(): block: [0,0,0], thread: [0,0,0] Assertion `idx_dim >= 0 && idx_dim < index_size && "index out of bounds"` failed.
```

Running via:

```
$ buck2 run @//mode/opt -c fbcode.c10_enable_verbose_assert=1 :scatter_errors
```

however produces:

```
[CUDA_KERNEL_ASSERT] fbcode/caffe2/aten/src/ATen/native/cuda/ScatterGatherKernel.cu:410: operator(): block: [0,0,0], thread: [0,0,0]: Assertion failed: `idx_dim >= 0 && idx_dim < index_size && "index out of bounds"`: Expected 0 <= idx_dim < index_size (128), but got idx_dim = 9999
```
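For context on where the extra values in the verbose message come from, here is an illustrative call-site sketch; this is not the exact code in ScatterGatherKernel.cu, and it reuses the hypothetical wrapper name from the sketch above:

```
// Hypothetical call site: the condition matches the assertion text above,
// and the printf-style arguments supply the observed values. These extra
// arguments are what consume additional registers, which is why the
// verbose form sits behind a compile-time switch.
VERBOSE_KERNEL_ASSERT(
    idx_dim >= 0 && idx_dim < index_size && "index out of bounds",
    "Expected 0 <= idx_dim < index_size (%ld), but got idx_dim = %ld",
    (long)index_size,
    (long)idx_dim);
```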
Differential Revision: D85185987