Skip to content

Commit fda0340

Browse files
Ailing Zhang and facebook-github-bot
authored and committed
add device to CUDAEvent (#9415)
Summary: This PR adds a device_ member to CUDAEvent. This is necessary because if we create a cudaEvent on one device but destroy it from another, the destroy call also creates an additional context on that second device. The device information is therefore needed to guard the cudaEventDestroy call. (cc: ngimel, is this expected behavior? I can provide a simple .cu script to reproduce it.) The c10d tests are probably not in CI yet; please let me know how the tests are run so I can double-check. Thanks to pietern and apaszke for help debugging! Pull Request resolved: #9415 Reviewed By: apaszke Differential Revision: D8839688 Pulled By: ailzhang fbshipit-source-id: b950ba37d57b9e3c5fe71726ec92f6a9601c4d0e
1 parent a4f6357 commit fda0340

File tree

2 files changed

+17
-3
lines changed

2 files changed

+17
-3
lines changed

torch/lib/c10d/CUDAUtils.cpp

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,20 @@
55
namespace c10d {
66

77
CUDAEvent CUDAEvent::create(unsigned int flags) {
8-
CUDAEvent event;
8+
int current_device;
9+
C10D_CUDA_CHECK(cudaGetDevice(&current_device));
10+
CUDAEvent event(nullptr, current_device);
11+
912
C10D_CUDA_CHECK(cudaEventCreateWithFlags(&event.event_, flags));
1013
return event;
1114
}
1215

1316
CUDAEvent::~CUDAEvent() {
1417
if (event_ != nullptr) {
18+
// cudaEventDestroy must run on the same device as the event,
19+
// otherwise it creates a context on the default device as well.
20+
at::DeviceGuard guard(device_);
21+
1522
C10D_CUDA_CHECK(cudaEventDestroy(event_));
1623
}
1724
}

torch/lib/c10d/CUDAUtils.hpp

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,9 @@ namespace c10d {
1212
// RAII wrapper for CUDA events.
1313
class CUDAEvent {
1414
public:
15-
CUDAEvent(cudaEvent_t event) : event_(event) {}
15+
CUDAEvent(cudaEvent_t event, int device) : device_(device), event_(event) {}
1616

17-
CUDAEvent() : CUDAEvent(nullptr) {}
17+
CUDAEvent() : CUDAEvent(nullptr, 0) {}
1818

1919
~CUDAEvent();
2020

@@ -27,19 +27,26 @@ class CUDAEvent {
2727
// Must be move constructable.
2828
CUDAEvent(CUDAEvent&& other) {
2929
std::swap(event_, other.event_);
30+
std::swap(device_, other.device_);
3031
}
3132

3233
// Must be move assignable.
3334
CUDAEvent& operator=(CUDAEvent&& other) {
3435
std::swap(event_, other.event_);
36+
std::swap(device_, other.device_);
3537
return *this;
3638
}
3739

3840
cudaEvent_t getEvent() const {
3941
return event_;
4042
}
4143

44+
int getDevice() const {
45+
return device_;
46+
}
47+
4248
protected:
49+
int device_;
4350
cudaEvent_t event_;
4451
};
4552

0 commit comments

Comments
 (0)