Skip to content

Conversation

@r-barnes
Copy link
Contributor

@r-barnes r-barnes commented Sep 6, 2022

Summary:
This diff introduces a set of changes that makes it possible for the host to get assertions from CUDA devices. This includes the introduction of

CUDA_KERNEL_ASSERT2

A preprocessor macro to be used within a CUDA kernel that, upon an assertion failure, writes the assertion message, file, line number, and possibly other information to UVM (Managed memory). Once this is done, the original assertion is triggered, which places the GPU in a Bad State requiring recovery. In my tests, data written to UVM appears there before the GPU reaches the Bad State and is still accessible from the host after the GPU is in this state.

Messages are written to a multi-message buffer which can, in theory, hold many assertion failures. I've done this as a precaution in case there are several, but I don't actually know whether that is possible and a simpler design which holds only a single message may well be all that is necessary.

TORCH_DSA_KERNEL_ARGS

This preprocess macro is added as an argument to a kernel function's signature. It expands to supply the standardized names of all the arguments needed by C10_CUDA_COMMUNICATING_KERNEL_ASSERTION to handle device-side assertions. This includes, eg, the name of the pointer to the UVM memory the assertion would be written to. This macro abstracts the arguments so there is a single point of change if the system needs to be modified.

c10::cuda::get_global_cuda_kernel_launch_registry()

This host-side function returns a singleton object that manages the host's part of the device-side assertions. Upon allocation, the singleton allocates sufficient UVM (Managed) memory to hold information about several device-side assertion failures. The singleton also provides methods for getting the current traceback (used to identify when a kernel was launched). To avoid consuming all the host's memory the singleton stores launches in a circular buffer; a unique "generation number" is used to ensure that kernel launch failures map to their actual launch points (in the case that the circular buffer wraps before the failure is detected).

TORCH_DSA_KERNEL_LAUNCH

This host-side preprocessor macro replaces the standard

kernel_name<<<blocks, threads, shmem, stream>>>(args)

invocation with

TORCH_DSA_KERNEL_LAUNCH(blocks, threads, shmem, stream, args);

Internally, it fetches the UVM (Managed) pointer and generation number from the singleton and append these to the standard argument list. It also checks to ensure the kernel launches correctly. This abstraction on kernel launches can be modified to provide additional safety/logging.

c10::cuda::c10_retrieve_device_side_assertion_info
This host-side function checks, when called, that no kernel assertions have occurred. If one has. It then raises an exception with:

  1. Information (file, line number) of what kernel was launched.
  2. Information (file, line number, message) about the device-side assertion
  3. Information (file, line number) about where the failure was detected.

Checking for device-side assertions

Device-side assertions are most likely to be noticed by the host when a CUDA API call such as cudaDeviceSynchronize is made and fails with a cudaError_t indicating

CUDA error: device-side assert triggered CUDA kernel errors

Therefore, we rewrite C10_CUDA_CHECK() to include a call to c10_retrieve_device_side_assertion_info(). To make the code cleaner, most of the logic of C10_CUDA_CHECK() is now contained within a new function c10_cuda_check_implementation() to which C10_CUDA_CHECK passes the preprocessor information about filenames, function names, and line numbers. (In C++20 we can use std::source_location to eliminate macros entirely!)

Notes on special cases

  • Multiple assertions from the same block are recorded
  • Multiple assertions from different blocks are recorded
  • Launching kernels from many threads on many streams seems to be handled correctly
  • If two process are using the same GPU and one of the processes fails with a device-side assertion the other process continues without issue
  • X Multiple assertions from separate kernels on different streams seem to be recorded, but we can't reproduce the test condition
  • X Multiple assertions from separate devices should be all be shown upon exit, but we've been unable to generate a test that produces this condition

Differential Revision: D37621532

@facebook-github-bot
Copy link
Contributor

facebook-github-bot commented Sep 6, 2022

🔗 Helpful links

❌ 6 New Failures, 2 Base Failures

As of commit 0c1508614c (more details on the Dr. CI page):

Expand to see more
  • 6/8 failures introduced in this PR
  • 2/8 broken upstream at merge base 166dec7 on Sep 06 from 3:08pm to 6:23pm

🕵️ 6 new failures recognized by patterns

The following CI failures do not appear to be due to upstream breakages

See GitHub Actions build pull / win-vs2019-cuda11.6-py3 / build (1/6)

Step: "Build" (full log | diagnosis details)

2022-09-06T23:10:29.1742659Z FAILED: bin/c10_cuda.dll lib/c10_cuda.lib
2022-09-06T23:10:28.2144359Z [4534/6419] Linking CXX executable bin\c10_typeid_test.exe
2022-09-06T23:10:28.2554166Z [4535/6419] Running C++/Python protocol buffer compiler on C:/actions-runner/_work/pytorch/pytorch/caffe2/proto/caffe2.proto
2022-09-06T23:10:28.2678865Z [4536/6419] Generating ../../../torch/utils/data/datapipes/datapipe.pyi
2022-09-06T23:10:28.2679448Z Generating Python interface file 'datapipe.pyi'...
2022-09-06T23:10:28.3039669Z [4537/6419] Stringify NVFUSER runtime source file
2022-09-06T23:10:28.3768595Z [4538/6419] Linking CXX executable bin\c10_intrusive_ptr_benchmark.exe
2022-09-06T23:10:28.3906830Z [4539/6419] Stringify NVFUSER runtime source file
2022-09-06T23:10:28.6191543Z [4540/6419] Building CXX object caffe2\proto\CMakeFiles\Caffe2_PROTO.dir\torch.pb.cc.obj
2022-09-06T23:10:28.6299852Z [4541/6419] Building CXX object caffe2\proto\CMakeFiles\Caffe2_PROTO.dir\caffe2.pb.cc.obj
2022-09-06T23:10:29.1742346Z [4542/6419] Linking CXX shared library bin\c10_cuda.dll
2022-09-06T23:10:29.1742659Z FAILED: bin/c10_cuda.dll lib/c10_cuda.lib 
2022-09-06T23:10:29.1744225Z cmd.exe /C "cd . && C:\Jenkins\Miniconda3\Library\bin\cmake.exe -E vs_link_dll --intdir=c10\cuda\CMakeFiles\c10_cuda.dir --rc=C:\PROGRA~2\WI3CF2~1\10\bin\100190~1.0\x64\rc.exe --mt=C:\PROGRA~2\WI3CF2~1\10\bin\100190~1.0\x64\mt.exe --manifests  -- C:\PROGRA~2\MICROS~2\2019\BUILDT~1\VC\Tools\MSVC\1428~1.293\bin\Hostx64\x64\link.exe /nologo c10\cuda\CMakeFiles\c10_cuda.dir\CUDAStream.cpp.obj c10\cuda\CMakeFiles\c10_cuda.dir\CUDAFunctions.cpp.obj c10\cuda\CMakeFiles\c10_cuda.dir\CUDAMiscFunctions.cpp.obj c10\cuda\CMakeFiles\c10_cuda.dir\CUDACachingAllocator.cpp.obj c10\cuda\CMakeFiles\c10_cuda.dir\impl\CUDAGuardImpl.cpp.obj c10\cuda\CMakeFiles\c10_cuda.dir\impl\CUDATest.cpp.obj  /out:bin\c10_cuda.dll /implib:lib\c10_cuda.lib /pdb:bin\c10_cuda.pdb /dll /version:0.0 /machine:x64 /ignore:4049 /ignore:4217 /ignore:4099 /INCREMENTAL:NO  lib\c10.lib  "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.6\lib\x64\cudart_static.lib"  kernel32.lib user32.lib gdi32.lib winspool.lib shell32.lib ole32.lib oleaut32.lib uuid.lib comdlg32.lib advapi32.lib  && cd ."
2022-09-06T23:10:29.1747094Z LINK: command "C:\PROGRA~2\MICROS~2\2019\BUILDT~1\VC\Tools\MSVC\1428~1.293\bin\Hostx64\x64\link.exe /nologo c10\cuda\CMakeFiles\c10_cuda.dir\CUDAStream.cpp.obj c10\cuda\CMakeFiles\c10_cuda.dir\CUDAFunctions.cpp.obj c10\cuda\CMakeFiles\c10_cuda.dir\CUDAMiscFunctions.cpp.obj c10\cuda\CMakeFiles\c10_cuda.dir\CUDACachingAllocator.cpp.obj c10\cuda\CMakeFiles\c10_cuda.dir\impl\CUDAGuardImpl.cpp.obj c10\cuda\CMakeFiles\c10_cuda.dir\impl\CUDATest.cpp.obj /out:bin\c10_cuda.dll /implib:lib\c10_cuda.lib /pdb:bin\c10_cuda.pdb /dll /version:0.0 /machine:x64 /ignore:4049 /ignore:4217 /ignore:4099 /INCREMENTAL:NO lib\c10.lib C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.6\lib\x64\cudart_static.lib kernel32.lib user32.lib gdi32.lib winspool.lib shell32.lib ole32.lib oleaut32.lib uuid.lib comdlg32.lib advapi32.lib /MANIFEST /MANIFESTFILE:bin\c10_cuda.dll.manifest" failed (exit code 1120) with the following output:
2022-09-06T23:10:29.1748516Z    Creating library lib\c10_cuda.lib and object lib\c10_cuda.exp
2022-09-06T23:10:29.1749244Z CUDATest.cpp.obj : error LNK2001: unresolved external symbol "void __cdecl c10::cuda::c10_cuda_check_implementation(class std::basic_string<char,struct std::char_traits<char>,class std::allocator<char> >,class std::basic_string<char,struct std::char_traits<char>,class std::allocator<char> >,int,bool)" (?c10_cuda_check_implementation@cuda@c10@@YAXV?$basic_string@DU?$char_traits@D@std@@V?$allocator@D@2@@std@@0H_N@Z)
2022-09-06T23:10:29.1750293Z CUDAStream.cpp.obj : error LNK2001: unresolved external symbol "void __cdecl c10::cuda::c10_cuda_check_implementation(class std::basic_string<char,struct std::char_traits<char>,class std::allocator<char> >,class std::basic_string<char,struct std::char_traits<char>,class std::allocator<char> >,int,bool)" (?c10_cuda_check_implementation@cuda@c10@@YAXV?$basic_string@DU?$char_traits@D@std@@V?$allocator@D@2@@std@@0H_N@Z)
2022-09-06T23:10:29.1751340Z CUDAFunctions.cpp.obj : error LNK2001: unresolved external symbol "void __cdecl c10::cuda::c10_cuda_check_implementation(class std::basic_string<char,struct std::char_traits<char>,class std::allocator<char> >,class std::basic_string<char,struct std::char_traits<char>,class std::allocator<char> >,int,bool)" (?c10_cuda_check_implementation@cuda@c10@@YAXV?$basic_string@DU?$char_traits@D@std@@V?$allocator@D@2@@std@@0H_N@Z)
2022-09-06T23:10:29.1752430Z CUDACachingAllocator.cpp.obj : error LNK2001: unresolved external symbol "void __cdecl c10::cuda::c10_cuda_check_implementation(class std::basic_string<char,struct std::char_traits<char>,class std::allocator<char> >,class std::basic_string<char,struct std::char_traits<char>,class std::allocator<char> >,int,bool)" (?c10_cuda_check_implementation@cuda@c10@@YAXV?$basic_string@DU?$char_traits@D@std@@V?$allocator@D@2@@std@@0H_N@Z)
2022-09-06T23:10:29.1753501Z CUDAGuardImpl.cpp.obj : error LNK2001: unresolved external symbol "void __cdecl c10::cuda::c10_cuda_check_implementation(class std::basic_string<char,struct std::char_traits<char>,class std::allocator<char> >,class std::basic_string<char,struct std::char_traits<char>,class std::allocator<char> >,int,bool)" (?c10_cuda_check_implementation@cuda@c10@@YAXV?$basic_string@DU?$char_traits@D@std@@V?$allocator@D@2@@std@@0H_N@Z)
2022-09-06T23:10:29.1754248Z bin\c10_cuda.dll : fatal error LNK1120: 1 unresolved externals
2022-09-06T23:10:29.2640303Z [4543/6419] Linking CXX shared library bin\caffe2_nvrtc.dll

See GitHub Actions build pull / linux-bionic-cuda11.3-py3.7-clang9 / build (2/6)

Step: "Build" (full log | diagnosis details)

2022-09-06T23:04:01.6956908Z ../../../lib/libc1..._traits, std::allocator >, int, bool)'
2022-09-06T23:04:01.5186006Z Writing /var/lib/jenkins/workspace/build/third_party/onnx/onnx/onnx_onnx_torch-ml.proto3
2022-09-06T23:04:01.5197373Z [ 43%] �[34m�[1mRunning C++/Python protocol buffer compiler on /var/lib/jenkins/workspace/caffe2/proto/caffe2.proto�[0m
2022-09-06T23:04:01.5206567Z Writing /var/lib/jenkins/workspace/build/third_party/onnx/onnx/onnx-ml.pb.h
2022-09-06T23:04:01.5207154Z generating /var/lib/jenkins/workspace/build/third_party/onnx/onnx/onnx_pb.py
2022-09-06T23:04:01.5353851Z [ 43%] �[34m�[1mRunning C++ protocol buffer compiler on /var/lib/jenkins/workspace/build/third_party/onnx/onnx/onnx_onnx_torch-ml.proto�[0m
2022-09-06T23:04:01.5684950Z [ 43%] �[32mBuilding C object confu-deps/XNNPACK/CMakeFiles/all_microkernels.dir/src/qc8-dwconv/gen/up8x9-minmax-fp32-sse2-mul16-add16.c.o�[0m
2022-09-06T23:04:01.5739613Z [ 43%] �[32m�[1mLinking CXX executable ../../../bin/c10_cuda_CUDATest�[0m
2022-09-06T23:04:01.6147116Z [ 43%] �[32mBuilding CXX object caffe2/proto/CMakeFiles/Caffe2_PROTO.dir/torch.pb.cc.o�[0m
2022-09-06T23:04:01.6702451Z [ 43%] �[32mBuilding CXX object third_party/tensorpipe/tensorpipe/CMakeFiles/tensorpipe_cuda.dir/common/ibv.cc.o�[0m
2022-09-06T23:04:01.6869722Z [ 43%] �[32mBuilding C object confu-deps/XNNPACK/CMakeFiles/all_microkernels.dir/src/qc8-dwconv/gen/up8x9-minmax-fp32-sse2-mul16.c.o�[0m
2022-09-06T23:04:01.6956908Z ../../../lib/libc10_cuda.so: undefined reference to `c10::cuda::c10_cuda_check_implementation(std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >, int, bool)'
2022-09-06T23:04:01.6983089Z clang: �[0;1;31merror: �[0mlinker command failed with exit code 1 (use -v to see invocation)�[0m
2022-09-06T23:04:01.7018864Z c10/cuda/test/CMakeFiles/c10_cuda_CUDATest.dir/build.make:101: recipe for target 'bin/c10_cuda_CUDATest' failed
2022-09-06T23:04:01.7019425Z make[2]: *** [bin/c10_cuda_CUDATest] Error 1
2022-09-06T23:04:01.7041156Z CMakeFiles/Makefile2:5719: recipe for target 'c10/cuda/test/CMakeFiles/c10_cuda_CUDATest.dir/all' failed
2022-09-06T23:04:01.7041854Z make[1]: *** [c10/cuda/test/CMakeFiles/c10_cuda_CUDATest.dir/all] Error 2
2022-09-06T23:04:01.7051770Z make[1]: *** Waiting for unfinished jobs....
2022-09-06T23:04:01.7137388Z [ 43%] �[32mBuilding C object confu-deps/XNNPACK/CMakeFiles/all_microkernels.dir/src/qc8-dwconv/gen/up8x25-minmax-fp32-sse2-mul16-add16.c.o�[0m
2022-09-06T23:04:01.7503298Z [ 43%] Built target gen_onnx_proto
2022-09-06T23:04:01.7593766Z [ 43%] �[32mBuilding C object confu-deps/XNNPACK/CMakeFiles/all_microkernels.dir/src/qc8-dwconv/gen/up8x25-minmax-fp32-sse2-mul16.c.o�[0m
2022-09-06T23:04:01.8093084Z [ 43%] �[32mBuilding CXX object caffe2/proto/CMakeFiles/Caffe2_PROTO.dir/caffe2.pb.cc.o�[0m

See GitHub Actions build pull / linux-bionic-cuda11_6-py3_10-gcc7-deploy / build (3/6)

Step: "Build" (full log | diagnosis details)

2022-09-06T23:06:21.9223712Z �[0m�[1m�[31mERROR...eof ((socklen_t)))\n ^\n" }
2022-09-06T23:06:21.9217981Z �[0m�[1m�[31mERROR�[0m 2022-09-06T23:03:34Z: sccache::server: Compilation failed: Output { status: ExitStatus(unix_wait_status(256)), stdout: "", stderr: "conftest.c: In function 'main':\nconftest.c:340:2: error: 'struct sockaddr' has no member named 'sa_len'\n x.sa_len = 0;\n  ^\n" }
2022-09-06T23:06:21.9218312Z 
2022-09-06T23:06:21.9219116Z �[0m�[1m�[31mERROR�[0m 2022-09-06T23:03:37Z: sccache::server: Compilation failed: Output { status: ExitStatus(unix_wait_status(256)), stdout: "", stderr: "conftest.c: In function 'main':\nconftest.c:374:10: error: 'RTLD_MEMBER' undeclared (first use in this function); did you mean 'RTLD_NEXT'?\n   (void) RTLD_MEMBER;\n          ^~~~~~~~~~~\n          RTLD_NEXT\nconftest.c:374:10: note: each undeclared identifier is reported only once for each function it appears in\n" }
2022-09-06T23:06:21.9219603Z 
2022-09-06T23:06:21.9220412Z �[0m�[1m�[31mERROR�[0m 2022-09-06T23:03:37Z: sccache::server: Compilation failed: Output { status: ExitStatus(unix_wait_status(256)), stdout: "", stderr: "conftest.c:369:9: error: unknown type name 'not'\n         not a universal capable compiler\n         ^~~\nconftest.c:369:15: error: expected '=', ',', ';', 'asm' or '__attribute__' before 'universal'\n         not a universal capable compiler\n               ^~~~~~~~~\nconftest.c:369:15: error: unknown type name 'universal'\n" }
2022-09-06T23:06:21.9220895Z 
2022-09-06T23:06:21.9221698Z �[0m�[1m�[31mERROR�[0m 2022-09-06T23:03:37Z: sccache::server: Compilation failed: Output { status: ExitStatus(unix_wait_status(256)), stdout: "", stderr: "conftest.c: In function 'main':\nconftest.c:375:4: error: unknown type name 'not'; did you mean 'ino_t'?\n    not big endian\n    ^~~\n    ino_t\nconftest.c:375:12: error: expected '=', ',', ';', 'asm' or '__attribute__' before 'endian'\n    not big endian\n            ^~~~~~\n" }
2022-09-06T23:06:21.9222141Z 
2022-09-06T23:06:21.9222769Z �[0m�[1m�[31mERROR�[0m 2022-09-06T23:03:38Z: sccache::server: Compilation failed: Output { status: ExitStatus(unix_wait_status(256)), stdout: "", stderr: "conftest.c: In function 'main':\nconftest.c:386:4: error: 'struct stat' has no member named 'st_mtimespec'; did you mean 'st_mtim'?\n st.st_mtimespec.tv_nsec = 1;\n    ^~~~~~~~~~~~\n    st_mtim\n" }
2022-09-06T23:06:21.9223153Z 
2022-09-06T23:06:21.9223712Z �[0m�[1m�[31mERROR�[0m 2022-09-06T23:03:40Z: sccache::server: Compilation failed: Output { status: ExitStatus(unix_wait_status(256)), stdout: "", stderr: "conftest.c: In function 'main':\nconftest.c:410:24: error: expected expression before ')' token\n if (sizeof ((socklen_t)))\n                        ^\n" }
2022-09-06T23:06:21.9224054Z 
2022-09-06T23:06:21.9224307Z + echo '=========== If your build fails, please take a look at the log above for possible reasons ==========='
2022-09-06T23:06:21.9224681Z =========== If your build fails, please take a look at the log above for possible reasons ===========
2022-09-06T23:06:21.9224977Z + sccache --show-stats
2022-09-06T23:06:21.9240356Z Compile requests                   8738
2022-09-06T23:06:21.9240704Z Compile requests executed          6657
2022-09-06T23:06:21.9241021Z Cache hits                         6395
2022-09-06T23:06:21.9241293Z Cache hits (C/C++)                 6394
2022-09-06T23:06:21.9241585Z Cache hits (CUDA)                     1
2022-09-06T23:06:21.9241863Z Cache misses                        191

See GitHub Actions build pull / linux-bionic-cuda11.6-py3.10-gcc7 / build (4/6)

Step: "Build" (full log | diagnosis details)

2022-09-06T23:03:09.3369818Z collect2: error: ld returned 1 exit status
2022-09-06T23:03:08.8902588Z [ 45%] �[32mBuilding C object confu-deps/XNNPACK/CMakeFiles/all_microkernels.dir/src/math/roundu-sse41.c.o�[0m
2022-09-06T23:03:08.8920834Z [ 45%] �[32mBuilding C object confu-deps/XNNPACK/CMakeFiles/all_microkernels.dir/src/math/roundz-sse41.c.o�[0m
2022-09-06T23:03:08.9749111Z [ 45%] �[32mBuilding C object confu-deps/XNNPACK/CMakeFiles/all_microkernels.dir/src/qc8-dwconv/gen/up8x9-minmax-fp32-sse41-mul16-add16.c.o�[0m
2022-09-06T23:03:08.9815961Z [ 45%] �[32mBuilding CXX object c10/cuda/test/CMakeFiles/c10_cuda_CUDATest.dir/impl/CUDATest.cpp.o�[0m
2022-09-06T23:03:09.0236862Z [ 45%] �[32mBuilding C object confu-deps/XNNPACK/CMakeFiles/all_microkernels.dir/src/qc8-dwconv/gen/up8x9-minmax-fp32-sse41-mul16.c.o�[0m
2022-09-06T23:03:09.0723333Z Compiling  all_gather.cu                       > /var/lib/jenkins/workspace/build/nccl/obj/collectives/device/all_gather_max_bf16.o
2022-09-06T23:03:09.1225877Z [ 45%] �[32mBuilding C object confu-deps/XNNPACK/CMakeFiles/all_microkernels.dir/src/qc8-dwconv/gen/up8x9-minmax-fp32-sse41-mul32.c.o�[0m
2022-09-06T23:03:09.2386452Z [ 45%] �[32m�[1mLinking CXX executable ../../../bin/c10_cuda_CUDATest�[0m
2022-09-06T23:03:09.2706783Z [ 45%] �[32mBuilding C object confu-deps/XNNPACK/CMakeFiles/all_microkernels.dir/src/qc8-dwconv/gen/up8x25-minmax-fp32-sse41-mul16-add16.c.o�[0m
2022-09-06T23:03:09.3368878Z ../../../lib/libc10_cuda.so: undefined reference to `c10::cuda::c10_cuda_check_implementation(std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >, int, bool)'
2022-09-06T23:03:09.3369818Z collect2: error: ld returned 1 exit status
2022-09-06T23:03:09.3391383Z c10/cuda/test/CMakeFiles/c10_cuda_CUDATest.dir/build.make:101: recipe for target 'bin/c10_cuda_CUDATest' failed
2022-09-06T23:03:09.3391946Z make[2]: *** [bin/c10_cuda_CUDATest] Error 1
2022-09-06T23:03:09.3395665Z CMakeFiles/Makefile2:5719: recipe for target 'c10/cuda/test/CMakeFiles/c10_cuda_CUDATest.dir/all' failed
2022-09-06T23:03:09.3396280Z make[1]: *** [c10/cuda/test/CMakeFiles/c10_cuda_CUDATest.dir/all] Error 2
2022-09-06T23:03:09.3396557Z make[1]: *** Waiting for unfinished jobs....
2022-09-06T23:03:09.3478126Z [ 45%] �[32mBuilding C object confu-deps/XNNPACK/CMakeFiles/all_microkernels.dir/src/qc8-dwconv/gen/up8x25-minmax-fp32-sse41-mul16.c.o�[0m
2022-09-06T23:03:09.3835655Z Compiling  all_gather.cu                       > /var/lib/jenkins/workspace/build/nccl/obj/collectives/device/all_gather_premulsum_i8.o
2022-09-06T23:03:09.4275360Z [ 45%] �[32mBuilding C object confu-deps/XNNPACK/CMakeFiles/all_microkernels.dir/src/qc8-dwconv/gen/up8x25-minmax-fp32-sse41-mul32.c.o�[0m
2022-09-06T23:03:09.5784648Z [ 45%] �[32mBuilding C object confu-deps/XNNPACK/CMakeFiles/all_microkernels.dir/src/qc8-dwconv/gen/up16x9-minmax-fp32-sse41-mul16-add16.c.o�[0m
2022-09-06T23:03:09.6666691Z [ 45%] �[32mBuilding C object confu-deps/XNNPACK/CMakeFiles/all_microkernels.dir/src/qc8-dwconv/gen/up16x9-minmax-fp32-sse41-mul16.c.o�[0m

See GitHub Actions build pull / linux-jammy-cuda11.6-cudnn8-py3.8-clang12 / build (5/6)

Step: "Build" (full log | diagnosis details)

2022-09-06T23:03:44.5861408Z /usr/bin/ld: ../....._traits, std::allocator >, int, bool)'
2022-09-06T23:03:44.0890021Z [ 46%] Built target tensorpipe_cuda
2022-09-06T23:03:44.1006508Z [ 46%] �[32mBuilding C object confu-deps/XNNPACK/CMakeFiles/all_microkernels.dir/src/f16-f32-vcvt/gen/vcvt-sse41-int16-x32.c.o�[0m
2022-09-06T23:03:44.1277425Z [ 46%] �[32m�[1mLinking CXX static library ../../../lib/libkineto.a�[0m
2022-09-06T23:03:44.1709011Z [ 46%] Built target kineto
2022-09-06T23:03:44.1797342Z [ 46%] �[32mBuilding C object confu-deps/XNNPACK/CMakeFiles/all_microkernels.dir/src/f16-f32-vcvt/gen/vcvt-sse41-int32-x8.c.o�[0m
2022-09-06T23:03:44.2216405Z [ 46%] �[32mBuilding CXX object c10/cuda/test/CMakeFiles/c10_cuda_CUDATest.dir/impl/CUDATest.cpp.o�[0m
2022-09-06T23:03:44.2896540Z [ 46%] �[32mBuilding C object confu-deps/XNNPACK/CMakeFiles/all_microkernels.dir/src/f16-f32-vcvt/gen/vcvt-sse41-int32-x16.c.o�[0m
2022-09-06T23:03:44.4029772Z [ 46%] �[32mBuilding C object confu-deps/XNNPACK/CMakeFiles/all_microkernels.dir/src/f16-f32-vcvt/gen/vcvt-sse41-int32-x24.c.o�[0m
2022-09-06T23:03:44.4338906Z [ 46%] �[32m�[1mLinking CXX executable ../../../bin/c10_cuda_CUDATest�[0m
2022-09-06T23:03:44.5113890Z [ 46%] �[32mBuilding C object confu-deps/XNNPACK/CMakeFiles/all_microkernels.dir/src/f16-f32-vcvt/gen/vcvt-sse41-int32-x32.c.o�[0m
2022-09-06T23:03:44.5861408Z /usr/bin/ld: ../../../lib/libc10_cuda.so: undefined reference to `c10::cuda::c10_cuda_check_implementation(std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >, int, bool)'
2022-09-06T23:03:44.5895544Z clang: �[0;1;31merror: �[0m�[1mlinker command failed with exit code 1 (use -v to see invocation)�[0m
2022-09-06T23:03:44.5940014Z gmake[2]: *** [c10/cuda/test/CMakeFiles/c10_cuda_CUDATest.dir/build.make:102: bin/c10_cuda_CUDATest] Error 1
2022-09-06T23:03:44.5941716Z gmake[1]: *** [CMakeFiles/Makefile2:5666: c10/cuda/test/CMakeFiles/c10_cuda_CUDATest.dir/all] Error 2
2022-09-06T23:03:44.5943071Z gmake[1]: *** Waiting for unfinished jobs....
2022-09-06T23:03:44.6031127Z [ 46%] �[32mBuilding C object confu-deps/XNNPACK/CMakeFiles/all_microkernels.dir/src/f32-f16-vcvt/gen/vcvt-sse41-x8.c.o�[0m
2022-09-06T23:03:44.6132336Z Compiling  reduce.cu                           > /var/lib/jenkins/workspace/build/nccl/obj/collectives/device/reduce_min_i64.o
2022-09-06T23:03:44.7005662Z [ 46%] �[32mBuilding C object confu-deps/XNNPACK/CMakeFiles/all_microkernels.dir/src/f32-f16-vcvt/gen/vcvt-sse41-x16.c.o�[0m
2022-09-06T23:03:44.7885108Z [ 46%] �[32mBuilding C object confu-deps/XNNPACK/CMakeFiles/all_microkernels.dir/src/f32-f16-vcvt/gen/vcvt-sse41-x24.c.o�[0m
2022-09-06T23:03:44.8762706Z [ 46%] �[32mBuilding C object confu-deps/XNNPACK/CMakeFiles/all_microkernels.dir/src/f32-f16-vcvt/gen/vcvt-sse41-x32.c.o�[0m
2022-09-06T23:03:44.9733030Z [ 46%] �[32mBuilding C object confu-deps/XNNPACK/CMakeFiles/all_microkernels.dir/src/f32-prelu/gen/sse41-2x4.c.o�[0m

See GitHub Actions build pull / linux-bionic-cuda11.6-py3.10-gcc7-bazel-test / build-and-test (6/6)

Step: "Build" (full log | diagnosis details)

2022-09-06T23:01:08.9008050Z gcc: error: language cu not recognized
2022-09-06T23:01:08.4294890Z Unzipped bazel-out/k8-fastbuild/bin/mnist/t10k-images-idx3-ubyte.gz ...
2022-09-06T23:01:08.4295275Z Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz ...
2022-09-06T23:01:08.4295300Z 
2022-09-06T23:01:08.4295446Z 0% |                                                                | 0%
2022-09-06T23:01:08.4295609Z 0% |################################################################| 100%
2022-09-06T23:01:08.4295985Z Unzipped bazel-out/k8-fastbuild/bin/mnist/t10k-labels-idx1-ubyte.gz ...
2022-09-06T23:01:08.9006693Z �[31m�[1mERROR: �[0m/var/lib/jenkins/workspace/c10/cuda/test/BUILD.bazel:4:15: Compiling c10/cuda/test/impl/CUDAAssertionsTest.cu failed: (Exit 1): gcc failed: error executing command /opt/cache/bin/gcc -U_FORTIFY_SOURCE -fstack-protector -Wall -Wunused-but-set-parameter -Wno-free-nonheap-object -fno-omit-frame-pointer '-std=c++11' -MD -MF ... (remaining 65 argument(s) skipped)
2022-09-06T23:01:08.9007384Z 
2022-09-06T23:01:08.9007582Z Use --sandbox_debug to see verbose messages from the sandbox
2022-09-06T23:01:08.9007842Z gcc: error: language cu not recognized
2022-09-06T23:01:08.9008050Z gcc: error: language cu not recognized
2022-09-06T23:01:09.0709835Z �[32mINFO: �[0mElapsed time: 150.384s, Critical Path: 28.32s
2022-09-06T23:01:09.0710534Z �[32mINFO: �[0m1878 processes: 302 internal, 1 local, 1575 processwrapper-sandbox.
2022-09-06T23:01:09.0711144Z �[31m�[1mFAILED:�[0m Build did NOT complete successfully
2022-09-06T23:01:09.0750384Z �[31m�[1mFAILED:�[0m Build did NOT complete successfully
2022-09-06T23:01:09.0833916Z �[0m+ sccache_epilogue
2022-09-06T23:01:09.0834486Z + echo '::group::Sccache Compilation Log'
2022-09-06T23:01:09.0835426Z ##[group]Sccache Compilation Log
2022-09-06T23:01:09.0835991Z + echo '=================== sccache compilation log ==================='
2022-09-06T23:01:09.0836443Z =================== sccache compilation log ===================
2022-09-06T23:01:09.0837211Z + python /var/lib/jenkins/workspace/.jenkins/pytorch/print_sccache_log.py /var/lib/jenkins/sccache_error.log

🚧 2 fixed upstream failures:

These were probably caused by upstream breakages that were already fixed.

Please rebase on the viable/strict branch (expand for instructions)

If your commit is older than viable/strict, run these commands:

git fetch https://github.com/pytorch/pytorch viable/strict
git rebase FETCH_HEAD

This comment was automatically generated by Dr. CI (expand for details).

Please report bugs/suggestions to the (internal) Dr. CI Users group.

Click here to manually regenerate this comment.

@facebook-github-bot
Copy link
Contributor

This pull request was exported from Phabricator. Differential Revision: D37621532

@facebook-github-bot
Copy link
Contributor

This pull request was exported from Phabricator. Differential Revision: D37621532

@pytorch-bot
Copy link

pytorch-bot bot commented Sep 7, 2022

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/84609

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 19934bd:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@soumith
Copy link
Contributor

soumith commented Sep 8, 2022

this is pretty cool!

@facebook-github-bot
Copy link
Contributor

This pull request was exported from Phabricator. Differential Revision: D37621532

@facebook-github-bot
Copy link
Contributor

This pull request was exported from Phabricator. Differential Revision: D37621532

@facebook-github-bot
Copy link
Contributor

This pull request was exported from Phabricator. Differential Revision: D37621532

@facebook-github-bot
Copy link
Contributor

@r-barnes has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

@facebook-github-bot
Copy link
Contributor

This pull request was exported from Phabricator. Differential Revision: D37621532

r-barnes added a commit to r-barnes/pytorch that referenced this pull request Dec 6, 2022
Summary:
This diff introduces a set of changes that makes it possible for the host to get assertions from CUDA devices. This includes the introduction of

**`CUDA_KERNEL_ASSERT2`**

A preprocessor macro to be used within a CUDA kernel that, upon an assertion failure, writes the assertion message, file, line number, and possibly other information to UVM (Managed memory). Once this is done, the original assertion is triggered, which places the GPU in a Bad State requiring recovery. In my tests, data written to UVM appears there before the GPU reaches the Bad State and is still accessible from the host after the GPU is in this state.

Messages are written to a multi-message buffer which can, in theory, hold many assertion failures. I've done this as a precaution in case there are several, but I don't actually know whether that is possible and a simpler design which holds only a single message may well be all that is necessary.

**`TORCH_DSA_KERNEL_ARGS`**

This preprocess macro is added as an _argument_ to a kernel function's signature. It expands to supply the standardized names of all the arguments needed by `C10_CUDA_COMMUNICATING_KERNEL_ASSERTION` to handle device-side assertions. This includes, eg, the name of the pointer to the UVM memory the assertion would be written to. This macro abstracts the arguments so there is a single point of change if the system needs to be modified.

**`c10::cuda::get_global_cuda_kernel_launch_registry()`**

This host-side function returns a singleton object that manages the host's part of the device-side assertions. Upon allocation, the singleton allocates sufficient UVM (Managed) memory to hold information about several device-side assertion failures. The singleton also provides methods for getting the current traceback (used to identify when a kernel was launched). To avoid consuming all the host's memory the singleton stores launches in a circular buffer; a unique "generation number" is used to ensure that kernel launch failures map to their actual launch points (in the case that the circular buffer wraps before the failure is detected).

**`TORCH_DSA_KERNEL_LAUNCH`**

This host-side preprocessor macro replaces the standard
```
kernel_name<<<blocks, threads, shmem, stream>>>(args)
```
invocation with
```
TORCH_DSA_KERNEL_LAUNCH(blocks, threads, shmem, stream, args);
```
Internally, it fetches the UVM (Managed) pointer and generation number from the singleton and append these to the standard argument list. It also checks to ensure the kernel launches correctly. This abstraction on kernel launches can be modified to provide additional safety/logging.

**`c10::cuda::c10_retrieve_device_side_assertion_info`**
This host-side function checks, when called, that no kernel assertions have occurred. If one has. It then raises an exception with:
1. Information (file, line number) of what kernel was launched.
2. Information (file, line number, message) about the device-side assertion
3. Information (file, line number) about where the failure was detected.

**Checking for device-side assertions**

Device-side assertions are most likely to be noticed by the host when a CUDA API call such as `cudaDeviceSynchronize` is made and fails with a `cudaError_t` indicating
> CUDA error: device-side assert triggered CUDA kernel errors

Therefore, we rewrite `C10_CUDA_CHECK()` to include a call to `c10_retrieve_device_side_assertion_info()`. To make the code cleaner, most of the logic of `C10_CUDA_CHECK()` is now contained within a new function `c10_cuda_check_implementation()` to which `C10_CUDA_CHECK` passes the preprocessor information about filenames, function names, and line numbers. (In C++20 we can use `std::source_location` to eliminate macros entirely!)

# Notes on special cases

* Multiple assertions from the same block are recorded
* Multiple assertions from different blocks are recorded
* Launching kernels from many threads on many streams seems to be handled correctly
* If two process are using the same GPU and one of the processes fails with a device-side assertion the other process continues without issue
* X Multiple assertions from separate kernels on different streams seem to be recorded, but we can't reproduce the test condition
* X Multiple assertions from separate devices should be all be shown upon exit, but we've been unable to generate a test that produces this condition

Pull Request resolved: pytorch#84609

Reviewed By: ezyang

Differential Revision: D37621532

Pulled By: r-barnes

fbshipit-source-id: eacd53618c190f6d76caf2ab3928dfd68d92a85e
@facebook-github-bot
Copy link
Contributor

This pull request was exported from Phabricator. Differential Revision: D37621532

r-barnes added a commit to r-barnes/pytorch that referenced this pull request Dec 7, 2022
Summary:
This diff introduces a set of changes that makes it possible for the host to get assertions from CUDA devices. This includes the introduction of

**`CUDA_KERNEL_ASSERT2`**

A preprocessor macro to be used within a CUDA kernel that, upon an assertion failure, writes the assertion message, file, line number, and possibly other information to UVM (Managed memory). Once this is done, the original assertion is triggered, which places the GPU in a Bad State requiring recovery. In my tests, data written to UVM appears there before the GPU reaches the Bad State and is still accessible from the host after the GPU is in this state.

Messages are written to a multi-message buffer which can, in theory, hold many assertion failures. I've done this as a precaution in case there are several, but I don't actually know whether that is possible and a simpler design which holds only a single message may well be all that is necessary.

**`TORCH_DSA_KERNEL_ARGS`**

This preprocess macro is added as an _argument_ to a kernel function's signature. It expands to supply the standardized names of all the arguments needed by `C10_CUDA_COMMUNICATING_KERNEL_ASSERTION` to handle device-side assertions. This includes, eg, the name of the pointer to the UVM memory the assertion would be written to. This macro abstracts the arguments so there is a single point of change if the system needs to be modified.

**`c10::cuda::get_global_cuda_kernel_launch_registry()`**

This host-side function returns a singleton object that manages the host's part of the device-side assertions. Upon allocation, the singleton allocates sufficient UVM (Managed) memory to hold information about several device-side assertion failures. The singleton also provides methods for getting the current traceback (used to identify when a kernel was launched). To avoid consuming all the host's memory the singleton stores launches in a circular buffer; a unique "generation number" is used to ensure that kernel launch failures map to their actual launch points (in the case that the circular buffer wraps before the failure is detected).

**`TORCH_DSA_KERNEL_LAUNCH`**

This host-side preprocessor macro replaces the standard
```
kernel_name<<<blocks, threads, shmem, stream>>>(args)
```
invocation with
```
TORCH_DSA_KERNEL_LAUNCH(blocks, threads, shmem, stream, args);
```
Internally, it fetches the UVM (Managed) pointer and generation number from the singleton and append these to the standard argument list. It also checks to ensure the kernel launches correctly. This abstraction on kernel launches can be modified to provide additional safety/logging.

**`c10::cuda::c10_retrieve_device_side_assertion_info`**
This host-side function checks, when called, that no kernel assertions have occurred. If one has. It then raises an exception with:
1. Information (file, line number) of what kernel was launched.
2. Information (file, line number, message) about the device-side assertion
3. Information (file, line number) about where the failure was detected.

**Checking for device-side assertions**

Device-side assertions are most likely to be noticed by the host when a CUDA API call such as `cudaDeviceSynchronize` is made and fails with a `cudaError_t` indicating
> CUDA error: device-side assert triggered CUDA kernel errors

Therefore, we rewrite `C10_CUDA_CHECK()` to include a call to `c10_retrieve_device_side_assertion_info()`. To make the code cleaner, most of the logic of `C10_CUDA_CHECK()` is now contained within a new function `c10_cuda_check_implementation()` to which `C10_CUDA_CHECK` passes the preprocessor information about filenames, function names, and line numbers. (In C++20 we can use `std::source_location` to eliminate macros entirely!)

# Notes on special cases

* Multiple assertions from the same block are recorded
* Multiple assertions from different blocks are recorded
* Launching kernels from many threads on many streams seems to be handled correctly
* If two process are using the same GPU and one of the processes fails with a device-side assertion the other process continues without issue
* X Multiple assertions from separate kernels on different streams seem to be recorded, but we can't reproduce the test condition
* X Multiple assertions from separate devices should be all be shown upon exit, but we've been unable to generate a test that produces this condition

Pull Request resolved: pytorch#84609

Reviewed By: ezyang

Differential Revision: D37621532

Pulled By: r-barnes

fbshipit-source-id: cfa5410f58773c6a88f236c827da3a32b4295c96
@facebook-github-bot
Copy link
Contributor

This pull request was exported from Phabricator. Differential Revision: D37621532

Summary:
This diff introduces a set of changes that makes it possible for the host to get assertions from CUDA devices. This includes the introduction of

**`CUDA_KERNEL_ASSERT2`**

A preprocessor macro to be used within a CUDA kernel that, upon an assertion failure, writes the assertion message, file, line number, and possibly other information to UVM (Managed memory). Once this is done, the original assertion is triggered, which places the GPU in a Bad State requiring recovery. In my tests, data written to UVM appears there before the GPU reaches the Bad State and is still accessible from the host after the GPU is in this state.

Messages are written to a multi-message buffer which can, in theory, hold many assertion failures. I've done this as a precaution in case there are several, but I don't actually know whether that is possible and a simpler design which holds only a single message may well be all that is necessary.

**`TORCH_DSA_KERNEL_ARGS`**

This preprocess macro is added as an _argument_ to a kernel function's signature. It expands to supply the standardized names of all the arguments needed by `C10_CUDA_COMMUNICATING_KERNEL_ASSERTION` to handle device-side assertions. This includes, eg, the name of the pointer to the UVM memory the assertion would be written to. This macro abstracts the arguments so there is a single point of change if the system needs to be modified.

**`c10::cuda::get_global_cuda_kernel_launch_registry()`**

This host-side function returns a singleton object that manages the host's part of the device-side assertions. Upon allocation, the singleton allocates sufficient UVM (Managed) memory to hold information about several device-side assertion failures. The singleton also provides methods for getting the current traceback (used to identify when a kernel was launched). To avoid consuming all the host's memory the singleton stores launches in a circular buffer; a unique "generation number" is used to ensure that kernel launch failures map to their actual launch points (in the case that the circular buffer wraps before the failure is detected).

**`TORCH_DSA_KERNEL_LAUNCH`**

This host-side preprocessor macro replaces the standard
```
kernel_name<<<blocks, threads, shmem, stream>>>(args)
```
invocation with
```
TORCH_DSA_KERNEL_LAUNCH(blocks, threads, shmem, stream, args);
```
Internally, it fetches the UVM (Managed) pointer and generation number from the singleton and append these to the standard argument list. It also checks to ensure the kernel launches correctly. This abstraction on kernel launches can be modified to provide additional safety/logging.

**`c10::cuda::c10_retrieve_device_side_assertion_info`**
This host-side function checks, when called, that no kernel assertions have occurred. If one has. It then raises an exception with:
1. Information (file, line number) of what kernel was launched.
2. Information (file, line number, message) about the device-side assertion
3. Information (file, line number) about where the failure was detected.

**Checking for device-side assertions**

Device-side assertions are most likely to be noticed by the host when a CUDA API call such as `cudaDeviceSynchronize` is made and fails with a `cudaError_t` indicating
> CUDA error: device-side assert triggered CUDA kernel errors

Therefore, we rewrite `C10_CUDA_CHECK()` to include a call to `c10_retrieve_device_side_assertion_info()`. To make the code cleaner, most of the logic of `C10_CUDA_CHECK()` is now contained within a new function `c10_cuda_check_implementation()` to which `C10_CUDA_CHECK` passes the preprocessor information about filenames, function names, and line numbers. (In C++20 we can use `std::source_location` to eliminate macros entirely!)

# Notes on special cases

* Multiple assertions from the same block are recorded
* Multiple assertions from different blocks are recorded
* Launching kernels from many threads on many streams seems to be handled correctly
* If two process are using the same GPU and one of the processes fails with a device-side assertion the other process continues without issue
* X Multiple assertions from separate kernels on different streams seem to be recorded, but we can't reproduce the test condition
* X Multiple assertions from separate devices should be all be shown upon exit, but we've been unable to generate a test that produces this condition

Pull Request resolved: pytorch#84609

Reviewed By: ezyang

Differential Revision: D37621532

Pulled By: r-barnes

fbshipit-source-id: efdfc57d6a1fa6dadfe30e693157e4cc040c7191
@facebook-github-bot
Copy link
Contributor

This pull request was exported from Phabricator. Differential Revision: D37621532

@facebook-github-bot
Copy link
Contributor

@r-barnes has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

@r-barnes
Copy link
Contributor Author

r-barnes commented Dec 8, 2022

@pytorchbot merge -f "Internal changes is incorrect."

@pytorchmergebot
Copy link
Collaborator

Merge started

Your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

@pytorchmergebot
Copy link
Collaborator

Merge failed

Reason: This PR has internal changes and must be landed via Phabricator

Details for Dev Infra team Raised by workflow job

@r-barnes
Copy link
Contributor Author

r-barnes commented Dec 8, 2022

@pytorchbot merge -f "No internal changes"

@pytorchmergebot
Copy link
Collaborator

Merge started

Your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

@pytorchmergebot
Copy link
Collaborator

Merge failed

Reason: This PR has internal changes and must be landed via Phabricator

Details for Dev Infra team Raised by workflow job

@r-barnes
Copy link
Contributor Author

r-barnes commented Dec 8, 2022

@pytorchbot merge

@pytorchmergebot
Copy link
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

kulinseth pushed a commit to kulinseth/pytorch that referenced this pull request Dec 10, 2022
Summary:
This diff introduces a set of changes that makes it possible for the host to get assertions from CUDA devices. This includes the introduction of

**`CUDA_KERNEL_ASSERT2`**

A preprocessor macro to be used within a CUDA kernel that, upon an assertion failure, writes the assertion message, file, line number, and possibly other information to UVM (Managed memory). Once this is done, the original assertion is triggered, which places the GPU in a Bad State requiring recovery. In my tests, data written to UVM appears there before the GPU reaches the Bad State and is still accessible from the host after the GPU is in this state.

Messages are written to a multi-message buffer which can, in theory, hold many assertion failures. I've done this as a precaution in case there are several, but I don't actually know whether that is possible and a simpler design which holds only a single message may well be all that is necessary.

**`TORCH_DSA_KERNEL_ARGS`**

This preprocess macro is added as an _argument_ to a kernel function's signature. It expands to supply the standardized names of all the arguments needed by `C10_CUDA_COMMUNICATING_KERNEL_ASSERTION` to handle device-side assertions. This includes, eg, the name of the pointer to the UVM memory the assertion would be written to. This macro abstracts the arguments so there is a single point of change if the system needs to be modified.

**`c10::cuda::get_global_cuda_kernel_launch_registry()`**

This host-side function returns a singleton object that manages the host's part of the device-side assertions. Upon allocation, the singleton allocates sufficient UVM (Managed) memory to hold information about several device-side assertion failures. The singleton also provides methods for getting the current traceback (used to identify when a kernel was launched). To avoid consuming all the host's memory the singleton stores launches in a circular buffer; a unique "generation number" is used to ensure that kernel launch failures map to their actual launch points (in the case that the circular buffer wraps before the failure is detected).

**`TORCH_DSA_KERNEL_LAUNCH`**

This host-side preprocessor macro replaces the standard
```
kernel_name<<<blocks, threads, shmem, stream>>>(args)
```
invocation with
```
TORCH_DSA_KERNEL_LAUNCH(blocks, threads, shmem, stream, args);
```
Internally, it fetches the UVM (Managed) pointer and generation number from the singleton and append these to the standard argument list. It also checks to ensure the kernel launches correctly. This abstraction on kernel launches can be modified to provide additional safety/logging.

**`c10::cuda::c10_retrieve_device_side_assertion_info`**
This host-side function checks, when called, that no kernel assertions have occurred. If one has. It then raises an exception with:
1. Information (file, line number) of what kernel was launched.
2. Information (file, line number, message) about the device-side assertion
3. Information (file, line number) about where the failure was detected.

**Checking for device-side assertions**

Device-side assertions are most likely to be noticed by the host when a CUDA API call such as `cudaDeviceSynchronize` is made and fails with a `cudaError_t` indicating
> CUDA error: device-side assert triggered CUDA kernel errors

Therefore, we rewrite `C10_CUDA_CHECK()` to include a call to `c10_retrieve_device_side_assertion_info()`. To make the code cleaner, most of the logic of `C10_CUDA_CHECK()` is now contained within a new function `c10_cuda_check_implementation()` to which `C10_CUDA_CHECK` passes the preprocessor information about filenames, function names, and line numbers. (In C++20 we can use `std::source_location` to eliminate macros entirely!)

# Notes on special cases

* Multiple assertions from the same block are recorded
* Multiple assertions from different blocks are recorded
* Launching kernels from many threads on many streams seems to be handled correctly
* If two process are using the same GPU and one of the processes fails with a device-side assertion the other process continues without issue
* X Multiple assertions from separate kernels on different streams seem to be recorded, but we can't reproduce the test condition
* X Multiple assertions from separate devices should be all be shown upon exit, but we've been unable to generate a test that produces this condition

Differential Revision: D37621532

Pull Request resolved: pytorch#84609
Approved by: https://github.com/ezyang, https://github.com/malfet
@r-barnes r-barnes mentioned this pull request Jun 6, 2023
@ppwwyyxx
Copy link
Collaborator

Why is this feature behind a compile-time flag TORCH_USE_CUDA_DSA that's default disabled? Does it have a significant runtime overhead even when not enabled at runtime, but just built in compile-time ?

The release not enabling this flag by default means that most users are not benefiting from this great feature.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

ciflow/trunk Trigger trunk jobs on your pull request cla signed fb-exported Merged release notes: cuda release notes category

Projects

None yet

Development

Successfully merging this pull request may close these issues.