Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion torch/lib/c10d/CUDAUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,20 @@
namespace c10d {

CUDAEvent CUDAEvent::create(unsigned int flags) {
CUDAEvent event;
int current_device;
C10D_CUDA_CHECK(cudaGetDevice(&current_device));
CUDAEvent event(nullptr, current_device);

C10D_CUDA_CHECK(cudaEventCreateWithFlags(&event.event_, flags));
return event;
}

CUDAEvent::~CUDAEvent() {
if (event_ != nullptr) {
// cudaEventDestroy must run on the same device of the event,
// otherwise it creates a context on default device as well.
at::DeviceGuard guard(device_);

C10D_CUDA_CHECK(cudaEventDestroy(event_));
}
}
Expand Down
11 changes: 9 additions & 2 deletions torch/lib/c10d/CUDAUtils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,9 @@ namespace c10d {
// RAII wrapper for CUDA events.
class CUDAEvent {
public:
CUDAEvent(cudaEvent_t event) : event_(event) {}
CUDAEvent(cudaEvent_t event, int device) : device_(device), event_(event) {}

CUDAEvent() : CUDAEvent(nullptr) {}
CUDAEvent() : CUDAEvent(nullptr, 0) {}

~CUDAEvent();

Expand All @@ -27,19 +27,26 @@ class CUDAEvent {
// Must be move constructable.
CUDAEvent(CUDAEvent&& other) {
std::swap(event_, other.event_);
std::swap(device_, other.device_);
}

// Must be move assignable.
CUDAEvent& operator=(CUDAEvent&& other) {
std::swap(event_, other.event_);
std::swap(device_, other.device_);
return *this;
}

cudaEvent_t getEvent() const {
return event_;
}

int getDevice() const {
return device_;
}

protected:
int device_;
cudaEvent_t event_;
};

Expand Down