Skip to content

Commit a2afad2

Browse files
mruberryfacebook-github-bot
authored andcommitted
Improves ATen CUDAEvent (#11293)
Summary: After submitting PR #9726, PR #10581 created a different CUDAEvent class. The CUDAEvent proposed in #9726 was similar to the c10d::CUDAEvent class with additional testing and functionality. In particular, it was movable but not copyable. The CUDAEvent created by #10581 is refcounted and copyable. This PR retains the refcounting of the latter PR while fixing several bugs, adding tests, and extending the functionality to support testing and usage like in PR #8354. In particular, this PR: - Adds set_device() to CUDAContext - Adds three CUDAEvent tests to stream_test.cpp - Fixes three bugs: - Refcounting was broken. Destroying an of the RAIIs holding a particular CUDAEvent would destroy the event UNLESS it was the last RAII (the check was backwards). - Moving an event would cause a segfault. - Events were not destroyed on the device they were created on. See PR #9415 (pietern) - Adds the happened() and recordOnce() functions - Changes the record() functions to not be const - Adds additional assertions to verify correctness This PR does not: - Make c10d use the ATen CUDAEvent (this is appropriate for a separate PR) Whether events should be refcounted is an interesting question. It adds some atomic operations and makes event creation eager. Making events movable but not copyable (like the c10d events) avoids these costs and allows events to be lazily constructed. Lazy construction is preferable when working with containers (like std::array or std::vector) and because the event's device can be set automatically to the first stream it's recorded on. With eager construction the user is required to understand that events have a device and acquire the device of the stream the event will be recorded on upfront. This can be seen here: https://github.com/pytorch/pytorch/blob/542aadd9a7609892e207c1e15de08a975b697752/aten/src/ATen/native/cudnn/RNN.cpp#L1130-L1132 and that file is the only one which currently uses the ATen CUDAEvent. Refcounting does allow single writer multi-reader scenarios, although these scenarios can be also be supported by providing indirect access to the underlying CUDAEvent. I believe all current and planned usage scenarios do not require refcounting, and if desired I can update this PR to remove refcounting and make the ATen event movable but not copyable like the c10d event. I think not refcounting is preferable because it can improve performance, ease usability, and simplify the code (as seen with two of the above bugs). I have decided to separate this from PR #8354 since while it's required for PR #8354 the changes are, clearly, of independent interest. PR #8354 has a new dependency on this one, however. I am closing PR #9726 in favor of this PR. apaszke ezyang pietern Pull Request resolved: #11293 Differential Revision: D9665836 Pulled By: soumith fbshipit-source-id: a1513fa4f9761e2f304d126e402f6b6950e1c1d2
1 parent b3b1e76 commit a2afad2

File tree

7 files changed

+137
-117
lines changed

7 files changed

+137
-117
lines changed

aten/src/ATen/cuda/CUDAContext.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,10 @@ int64_t current_device() {
1616
return cur_device;
1717
}
1818

19+
void set_device(int64_t device) {
20+
AT_CUDA_CHECK(cudaSetDevice((int)device));
21+
}
22+
1923
cudaDeviceProp* getCurrentDeviceProperties() {
2024
return THCState_getCurrentDeviceProperties(at::globalContext().getTHCState());
2125
}

aten/src/ATen/cuda/CUDAContext.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,8 @@ AT_API int64_t getNumGPUs();
3939

4040
AT_API int64_t current_device();
4141

42+
AT_API void set_device(int64_t device);
43+
4244
AT_API cudaDeviceProp* getCurrentDeviceProperties();
4345

4446
AT_API cudaDeviceProp* getDeviceProperties(int64_t device);

aten/src/ATen/cuda/CUDAEvent.cpp

Lines changed: 0 additions & 66 deletions
This file was deleted.

aten/src/ATen/cuda/CUDAEvent.h

Lines changed: 86 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -1,78 +1,116 @@
11
#pragma once
22

3-
#include <cstdint>
4-
#include <utility>
3+
#include "ATen/cuda/ATenCUDAGeneral.h"
4+
#include "ATen/cuda/CUDAStream.h"
5+
#include "ATen/cuda/CUDAContext.h"
6+
#include "ATen/cuda/Exceptions.h"
7+
#include "ATen/core/Error.h"
8+
#include "ATen/DeviceGuard.h"
59

610
#include "cuda_runtime_api.h"
711

8-
#include <ATen/core/ATenGeneral.h>
9-
#include <ATen/Error.h>
12+
#include <cstdint>
13+
#include <utility>
14+
15+
namespace at { namespace cuda {
1016

1117
/*
12-
* A CUDA event interface with no CUDA build dependency.
18+
* CUDAEvents are movable not copyable wrappers around CUDA's events.
1319
*
14-
* Includes the CUDAEvent RAII class and a pointer-based event API.
20+
* CUDAEvents are constructed lazily when recorded on streams. The events
21+
* have a device, and this device is acquired from the first recording stream.
22+
* Later streams that record to the event must share this device, but streams
23+
* on any device can wait on the event.
1524
*/
16-
17-
struct CUDAEventInternals;
18-
19-
namespace at {
20-
namespace cuda {
21-
22-
struct CUDAStream;
23-
24-
namespace detail {
25-
26-
// Pointer-based API (for internal use)
27-
// Note: ATen/Context is preferred to work with streams safely
28-
AT_API CUDAEventInternals* CUDAEvent_create(unsigned int flags);
29-
AT_API void CUDAEvent_retain(CUDAEventInternals* internals);
30-
AT_API void CUDAEvent_uncheckedFree(CUDAEventInternals* internals);
31-
AT_API cudaEvent_t CUDAEvent_event(CUDAEventInternals* internals);
32-
AT_API int64_t CUDAEvent_device(CUDAEventInternals* internals);
33-
34-
} // namespace detail
35-
36-
struct CUDAEvent {
25+
struct AT_CUDA_API CUDAEvent {
3726
// Constants
3827
static constexpr unsigned int DEFAULT_FLAGS = cudaEventDisableTiming;
3928

4029
// Constructors
41-
CUDAEvent(unsigned int flags = DEFAULT_FLAGS)
42-
: internals_(detail::CUDAEvent_create(flags)) {}
43-
44-
~CUDAEvent() { detail::CUDAEvent_uncheckedFree(internals_); }
45-
46-
CUDAEvent(const CUDAEvent& other) {
47-
detail::CUDAEvent_retain(other.internals_);
48-
internals_ = other.internals_;
30+
CUDAEvent(unsigned int flags = DEFAULT_FLAGS)
31+
: flags_{flags} { }
32+
33+
// Note: event destruction done on creating device to avoid creating a
34+
// CUDA context on other devices.
35+
~CUDAEvent() {
36+
try {
37+
if (is_created_) {
38+
at::DeviceGuard device_guard{(int)device_};
39+
cudaEventDestroy(event_);
40+
}
41+
} catch (...) { /* No throw */ }
4942
}
5043

51-
CUDAEvent(CUDAEvent&& other) {
52-
std::swap(internals_, other.internals_);
53-
}
44+
CUDAEvent(const CUDAEvent&) = delete;
45+
CUDAEvent& operator=(const CUDAEvent&) = delete;
5446

55-
CUDAEvent& operator=(CUDAEvent other) noexcept {
56-
std::swap(internals_, other.internals_);
47+
CUDAEvent(CUDAEvent&& other) { moveHelper(std::move(other)); }
48+
CUDAEvent& operator=(CUDAEvent&& other) {
49+
moveHelper(std::move(other));
5750
return *this;
5851
}
5952

60-
operator cudaEvent_t() const { return detail::CUDAEvent_event(internals_); }
53+
operator cudaEvent_t() const { return event(); }
6154

6255
// Less than operator (to allow use in sets)
6356
friend bool operator<(const CUDAEvent& left, const CUDAEvent& right) {
64-
return left.internals_ < right.internals_;
57+
return left.event_ < right.event_;
6558
}
6659

67-
int64_t device() const { return detail::CUDAEvent_device(internals_); }
68-
cudaEvent_t event() const { return detail::CUDAEvent_event(internals_); }
69-
CUDAEventInternals* internals() const { return internals_; }
60+
bool isCreated() const { return is_created_; }
61+
int64_t device() const { return device_; }
62+
cudaEvent_t event() const { return event_; }
63+
64+
bool happened() const {
65+
return (was_recorded_ && cudaEventQuery(event_) == cudaSuccess);
66+
}
67+
68+
void record() { record(getCurrentCUDAStream()); }
69+
70+
void recordOnce(const CUDAStream& stream) {
71+
if (!was_recorded_) record(stream);
72+
}
73+
74+
void record(const CUDAStream& stream) {
75+
if (is_created_) {
76+
AT_ASSERT(device_ == stream.device());
77+
} else {
78+
create(stream.device());
79+
}
80+
81+
AT_CUDA_CHECK(cudaEventRecord(event_, stream));
82+
was_recorded_ = true;
83+
}
7084

71-
void record() const; // Record on the current stream
72-
void record(const CUDAStream& stream) const;
85+
void block (const CUDAStream& stream) {
86+
if (is_created_) {
87+
AT_CUDA_CHECK(cudaStreamWaitEvent(stream, event_, 0));
88+
}
89+
}
90+
7391

7492
private:
75-
CUDAEventInternals* internals_;
93+
unsigned int flags_ = DEFAULT_FLAGS;
94+
bool is_created_ = false;
95+
bool was_recorded_ = false;
96+
int64_t device_ = -1;
97+
cudaEvent_t event_;
98+
99+
void moveHelper(CUDAEvent&& other) {
100+
std::swap(flags_, other.flags_);
101+
std::swap(is_created_, other.is_created_);
102+
std::swap(was_recorded_, other.was_recorded_);
103+
std::swap(device_, other.device_);
104+
std::swap(event_, other.event_);
105+
}
106+
107+
void create(const int64_t device) {
108+
at::DeviceGuard device_guard{(int)device};
109+
AT_CUDA_CHECK(cudaEventCreateWithFlags(&event_, flags_));
110+
111+
is_created_ = true;
112+
device_ = device;
113+
}
76114
};
77115

78116
} // namespace cuda

aten/src/ATen/cuda/CUDAStream.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -209,7 +209,8 @@ int64_t CUDAStream_device(CUDAStreamInternals* ptr) {
209209
}
210210

211211
void CUDAStream_synchronize_with(CUDAStreamInternals* ptr, const CUDAEvent& event) {
212-
AT_CUDA_CHECK(cudaStreamWaitEvent(ptr->stream, event, 0));
212+
if (event.isCreated())
213+
AT_CUDA_CHECK(cudaStreamWaitEvent(ptr->stream, event, 0));
213214
}
214215

215216
} // namespace detail

aten/src/ATen/cuda/CUDAStream.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ AT_CUDA_API int64_t CUDAStream_device(CUDAStreamInternals*);
7777

7878
// RAII for a CUDA stream
7979
// Allows use as a cudaStream_t, copying, moving, and metadata access.
80-
struct CUDAStream {
80+
struct AT_CUDA_API CUDAStream {
8181

8282
// Constructors
8383
CUDAStream() = default;

aten/src/ATen/test/stream_test.cpp

Lines changed: 42 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
#include "ATen/cuda/CUDAContext.h"
55
#include "ATen/cuda/CUDAGuard.h"
6+
#include "ATen/cuda/CUDAEvent.h"
67

78
#include "cuda_runtime.h"
89

@@ -211,7 +212,6 @@ TEST_CASE("Streampool Round Robin") {
211212
REQUIRE(hasDuplicates);
212213
}
213214

214-
// Note: to be expanded once CUDAEvent PR is accepted
215215
TEST_CASE("Multi-GPU") {
216216
if (at::cuda::getNumGPUs() < 2) return;
217217

@@ -226,3 +226,44 @@ TEST_CASE("Multi-GPU") {
226226
at::DeviceGuard device_guard{1};
227227
REQUIRE(s1 == at::cuda::getCurrentCUDAStream());
228228
}
229+
230+
TEST_CASE("CUDAEvent Syncs") {
231+
const auto stream = at::cuda::createCUDAStream();
232+
at::cuda::CUDAEvent event;
233+
234+
REQUIRE(!event.happened());
235+
236+
event.recordOnce(stream);
237+
238+
const auto wait_stream0 = at::cuda::createCUDAStream();
239+
const auto wait_stream1 = at::cuda::createCUDAStream();
240+
241+
wait_stream0.synchronize_with(event);
242+
wait_stream1.synchronize_with(event);
243+
244+
cudaStreamSynchronize(wait_stream0);
245+
REQUIRE(event.happened());
246+
}
247+
248+
TEST_CASE("Cross-Device Events") {
249+
if (at::cuda::getNumGPUs() < 2) return;
250+
251+
const auto stream0 = at::cuda::createCUDAStream();
252+
at::cuda::CUDAEvent event0;
253+
254+
at::cuda::set_device(1);
255+
const auto stream1 = at::cuda::createCUDAStream();
256+
at::cuda::CUDAEvent event1;
257+
258+
event0.record(stream0);
259+
event1.record(stream1);
260+
261+
event0 = std::move(event1);
262+
263+
REQUIRE(event0.device() == 1);
264+
265+
stream0.synchronize_with(event0);
266+
267+
cudaStreamSynchronize(stream0);
268+
REQUIRE(event0.happened());
269+
}

0 commit comments

Comments
 (0)