
Conversation

@peterbell10 (Collaborator)

Fixes #21821

This follows @ngimel's suggestion to manually synchronize MAGMA calls with the current stream. The synchronization is handled automatically by MagmaStreamSyncGuard.
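
For readers without the diff, here is a minimal sketch of what such an RAII guard can look like. The name MagmaStreamSyncGuard comes from this PR, but the body below is an illustrative assumption (MAGMA launches its work on the default/NULL stream, so the guard syncs the current stream before the call and the default stream afterwards), not the code actually added here:

#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/Exceptions.h>

// Sketch only, not the implementation from this PR.
struct MagmaStreamSyncGuard {
  MagmaStreamSyncGuard() {
    // Make work queued on the current PyTorch stream visible to MAGMA,
    // which launches on the default (NULL) stream.
    auto stream = at::cuda::getCurrentCUDAStream();
    if (stream != at::cuda::getDefaultCUDAStream()) {
      AT_CUDA_CHECK(cudaStreamSynchronize(stream));
    }
  }

  ~MagmaStreamSyncGuard() {
    // Wait for the MAGMA work on the default stream before later ops run on
    // the current stream. (A real guard would also surface the CUDA error.)
    if (at::cuda::getCurrentCUDAStream() != at::cuda::getDefaultCUDAStream()) {
      cudaStreamSynchronize(at::cuda::getDefaultCUDAStream());
    }
  }
};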

I think for the functions with _batched variants we could possibly avoid synchronisation by using a batch of size 1 since these have a magma_queue_t argument. However, I presume there's a reason it wasn't written like that in the first place.

I also figured out why porting to ATen "magically fixed" torch.svd. The magma functions for svd all take host arrays as input and output. The ATen port uses blocking copy_ calls, which fully synchronize the operation. On the other hand, the THC functions use cudaMemcpy, which doesn't synchronize with streams created with cudaStreamNonBlocking (which ATen uses). The fix is to use cudaMemcpyAsync followed by cudaStreamSynchronize, the same as copy_ does internally:

AT_CUDA_CHECK(cudaMemcpyAsync(dst, src, nbytes, kind, stream));
AT_CUDA_CHECK(cudaStreamSynchronize(stream));
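
In context, the same pattern is what you'd use when moving MAGMA's host-side results back into device memory; a hypothetical helper (names are mine, not from the patch) would look like:

#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/Exceptions.h>

// Hypothetical helper, not from the patch: copy a MAGMA host-side result to
// device memory, then block on the current (possibly cudaStreamNonBlocking)
// stream so the host buffer can be freed safely immediately afterwards.
void copy_host_result_to_device(void* device_dst, const void* host_src, size_t nbytes) {
  auto stream = at::cuda::getCurrentCUDAStream();
  AT_CUDA_CHECK(cudaMemcpyAsync(device_dst, host_src, nbytes,
                                cudaMemcpyHostToDevice, stream));
  AT_CUDA_CHECK(cudaStreamSynchronize(stream));
}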

I'm not sure how to test these changes as I wasn't able to reproduce any of the stream sync issues. That's possibly a mixture of non-determinism and the fact that some of these functions are implicitly synchronous anyway.

@dr-ci bot commented Apr 14, 2020

💊 Build failures summary and remediations

As of commit 428bdde (more details on the Dr. CI page):


  • 3/3 failures possibly* introduced in this PR
    • 1/3 non-CircleCI failure(s)

🕵️ 2 new failures recognized by patterns

The following build failures do not appear to be due to upstream breakages:

See CircleCI build pytorch_xla_linux_bionic_py3_6_clang9_build (1/2)

Step: "Build" (full log | pattern match details | 🔁 rerun) <confirmed not flaky by 2 failures>

Apr 22 23:32:45 torch_xla/csrc/aten_xla_type_default.cpp:5:10: fatal error: 'torch/library.h' file not found
Apr 22 23:32:45 #include <torch/library.h>
Apr 22 23:32:45          ^~~~~~~~~~~~~~~~~
Apr 22 23:32:48 1 error generated.

See CircleCI build pytorch_macos_10_13_py3_test (2/2)

Step: "Test" (full log | pattern match details | 🔁 rerun) <confirmed not flaky by 2 failures>

Apr 22 16:36:15 Generating XML reports... 
Apr 22 16:36:15 Generated XML report: test-reports/python-unittest/TEST-TestAutograd-20200422163233.xml 
Apr 22 16:36:15 Generated XML report: test-reports/python-unittest/TEST-TestAutogradDeviceTypeCPU-20200422163233.xml 
Apr 22 16:36:15 Generated XML report: test-reports/python-unittest/TEST-TestAutogradFunctional-20200422163233.xml 
Apr 22 16:36:15 Generated XML report: test-reports/python-unittest/TEST-TestMultithreadAutograd-20200422163233.xml 
Apr 22 16:36:16 Traceback (most recent call last): 
Apr 22 16:36:16   File "test/run_test.py", line 699, in <module> 
Apr 22 16:36:16     main() 
Apr 22 16:36:16   File "test/run_test.py", line 692, in main 
Apr 22 16:36:16     raise RuntimeError(message) 
Apr 22 16:36:16 RuntimeError: test_autograd failed! 
Apr 22 16:36:16 + cleanup 
Apr 22 16:36:16 + retcode=1 
Apr 22 16:36:16 + set +x 

Extra GitHub checks: 1 failed



@ngimel (Collaborator) left a comment

There probably wasn't any reason why the _batched routines with the magma_queue_t argument were not used. Likely when these bindings were written, magma did not yet support a queue argument. cc @vishwakftw in case he knows.

Collaborator

If the error returned by magma kernels is non-sticky, it won't be checked (because cudaStreamSynchronize will likely succeed). Also, could you use getDefaultCUDAStream instead of 0 here and above?
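
For illustration only (this is not the exact diff under review), the suggestion amounts to spelling out the default stream instead of passing the literal 0:

// Sketch of the suggestion, not the exact lines under review.
AT_CUDA_CHECK(cudaStreamSynchronize(0));                                  // before: literal NULL stream
AT_CUDA_CHECK(cudaStreamSynchronize(at::cuda::getDefaultCUDAStream()));  // after: named default stream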

Collaborator

This is functionally correct, and in most cases magma performance is pretty bad anyway so we probably should not micro-optimize, but if pytorch is currently on the default stream you don't need a synchronization here, right?

Collaborator Author

This needs to complete before the host memory is freed immediately afterwards.

@peterbell10 (Collaborator Author) Apr 15, 2020

See its use on line 296.

@vishwakftw (Contributor)

Re: _batched variants, there is a MAGMAQueue RAII wrapper which implicitly creates a magma_queue_t object with the appropriate device index and the relevant streams. This was introduced in #9949, and can be found at aten/src/ATen/native/cuda/MiscUtils.h.
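
Roughly, such a wrapper ties a MAGMA queue to ATen's current stream and handles. The sketch below is an approximation of the idea, not a verbatim copy of MiscUtils.h:

#include <ATen/cuda/CUDAContext.h>
#include <magma_v2.h>

// Approximate sketch of a MAGMAQueue-style RAII wrapper; see MiscUtils.h for the real one.
struct MAGMAQueueSketch {
  explicit MAGMAQueueSketch(magma_device_t device_id) {
    magma_queue_create_from_cuda(
        device_id,
        at::cuda::getCurrentCUDAStream(),
        at::cuda::getCurrentCUDABlasHandle(),
        nullptr,  // cuSPARSE handle; the real wrapper wires this up as well
        &magma_queue_);
  }
  magma_queue_t get_queue() const { return magma_queue_; }
  ~MAGMAQueueSketch() { magma_queue_destroy(magma_queue_); }
 private:
  magma_queue_t magma_queue_;
};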

@peterbell10 (Collaborator Author)

> There probably wasn't any reason why _batched routines with magma_queue_t argument were not used. Likely when these bindings were written magma did not yet support queue argument.

If you look at apply_solve for example, both variants are used in the same function, depending on whether the input is batched.

@ezyang removed their request for review April 15, 2020 14:20
@ezyang (Contributor) commented Apr 15, 2020

I'm going to let @ngimel field this one, lmk if there is anything specific I should look at

@ngimel added the triaged label (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module) Apr 18, 2020
@peterbell10 (Collaborator Author)

Okay, I've tried out substituting single matrix calls with batched calls. The results are not good. Many tests now fail because of illegal memory accesses (#26996). That issue suggests fixing this would require batch sizes to be rounded up, which would mean doing a lot of unnecessary work.

I've also had a look into the magma source code and it strongly warns against using the _batched versions with large matrices:
https://github.com/maxhutch/magma/blob/e63a6c0fde9170896125de3b6437f7ce0f96837a/src/dpotrf_batched.cpp#L97-L100

So, I don't think it makes sense to use the batched versions here.

@vishwakftw (Contributor)

@peterbell10 just FYI, the latest magma source is available on bitbucket (icl/magma)

@ngimel (Collaborator) commented Apr 22, 2020

Ok, let's not use the batched versions. We recently had another issue where the batched versions were buggy with large matrices. Can you please rebase so that the Windows tests have a chance to run? Once CI is green we can merge.

@ngimel (Collaborator) left a comment

Thanks!

@peterbell10 (Collaborator Author)

Rebased on viable/strict. The Windows tests have run and passed; the remaining failures are all unrelated to MAGMA.

@facebook-github-bot (Contributor) left a comment

@ngimel has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@ngimel (Collaborator) commented Apr 23, 2020

Thank you!

@facebook-github-bot (Contributor) left a comment

@ngimel has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@facebook-github-bot (Contributor)

@ngimel merged this pull request in a51f047.


Labels: Merged, open source, triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)

Successfully merging this pull request may close these issues: Stream safety of MAGMA functions

6 participants