Skip to content

Commit df5bbc0

Browse files
guangyeypytorchmergebot
authored andcommitted
Make device-specific event inherits from torch.Event (#134845)
# Motivation This PR intends to make device-specific Event inherit from the generic torch.Event. The benefit is providing a generic abstract class `torch.Event` for different devices, like `torch.Stream`. This make it easier for Dynamo to capture the Event of different devices, like torch.cuda.Event and torch.xpu.Event. And the next PR would like to remove previous useless base class `_StreamBase` and `_EventBase` to avoid multiple Inheritance. Pull Request resolved: #134845 Approved by: https://github.com/albanD, https://github.com/EikanWang
1 parent 47a78da commit df5bbc0

File tree

8 files changed

+30
-8
lines changed

8 files changed

+30
-8
lines changed

test/test_cuda.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -661,6 +661,9 @@ def test_generic_stream_event(self):
661661
device_index=stream.device_index,
662662
device_type=stream.device_type,
663663
)
664+
self.assertIsInstance(cuda_stream, torch.Stream)
665+
self.assertTrue(issubclass(type(cuda_stream), torch.Stream))
666+
self.assertTrue(torch.Stream in type(cuda_stream).mro())
664667
self.assertEqual(stream.stream_id, cuda_stream.stream_id)
665668
self.assertNotEqual(stream.stream_id, torch.cuda.current_stream().stream_id)
666669

@@ -683,6 +686,10 @@ def test_generic_stream_event(self):
683686
self.assertNotEqual(event1.event_id, event2.event_id)
684687
self.assertEqual(c_cuda.cpu(), a + b)
685688
self.assertTrue(event1.elapsed_time(event2) > 0)
689+
cuda_event = torch.cuda.Event()
690+
self.assertIsInstance(cuda_event, torch.Event)
691+
self.assertTrue(issubclass(type(cuda_event), torch.Event))
692+
self.assertTrue(torch.Event in type(cuda_event).mro())
686693

687694
def test_record_stream(self):
688695
cycles_per_ms = get_cycles_per_ms()

test/test_xpu.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -237,6 +237,9 @@ def test_generic_stream_event(self):
237237
device_index=stream.device_index,
238238
device_type=stream.device_type,
239239
)
240+
self.assertIsInstance(xpu_stream, torch.Stream)
241+
self.assertTrue(issubclass(type(xpu_stream), torch.Stream))
242+
self.assertTrue(torch.Stream in type(xpu_stream).mro())
240243
self.assertEqual(stream.stream_id, xpu_stream.stream_id)
241244
self.assertNotEqual(stream.stream_id, torch.xpu.current_stream().stream_id)
242245

@@ -262,6 +265,10 @@ def test_generic_stream_event(self):
262265
NotImplementedError, "elapsedTime is not supported by XPU backend."
263266
):
264267
event1.elapsed_time(event2)
268+
xpu_event = torch.xpu.Event()
269+
self.assertIsInstance(xpu_event, torch.Event)
270+
self.assertTrue(issubclass(type(xpu_event), torch.Event))
271+
self.assertTrue(torch.Event in type(xpu_event).mro())
265272

266273
def test_generator(self):
267274
torch.manual_seed(2024)

torch/csrc/Event.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
#include <structmember.h>
1616
#include <string>
1717

18-
PyObject* THPEventClass = nullptr;
18+
PyTypeObject* THPEventClass = nullptr;
1919

2020
static PyObject* THPEvent_pynew(
2121
PyTypeObject* type,
@@ -316,7 +316,7 @@ PyTypeObject THPEventType = {
316316
};
317317

318318
void THPEvent_init(PyObject* module) {
319-
THPEventClass = (PyObject*)&THPEventType;
319+
THPEventClass = &THPEventType;
320320
if (PyType_Ready(&THPEventType) < 0) {
321321
throw python_error();
322322
}

torch/csrc/Event.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,15 +7,15 @@
77
struct TORCH_API THPEvent {
88
PyObject_HEAD c10::Event event;
99
};
10-
extern PyObject* THPEventClass;
10+
TORCH_API extern PyTypeObject* THPEventClass;
1111
TORCH_API extern PyTypeObject THPEventType;
1212

1313
TORCH_API void THPEvent_init(PyObject* module);
1414
TORCH_API PyObject* THPEvent_new(
1515
c10::DeviceType device_type,
1616
c10::EventFlag flag);
1717
inline bool THPEvent_Check(PyObject* obj) {
18-
return THPEventClass && PyObject_IsInstance(obj, THPEventClass);
18+
return THPEventClass && PyObject_IsInstance(obj, (PyObject*)THPEventClass);
1919
}
2020

2121
#endif // THP_EVENT_INC

torch/csrc/cuda/Event.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -240,6 +240,9 @@ PyTypeObject THCPEventType = {
240240
};
241241

242242
void THCPEvent_init(PyObject* module) {
243+
TORCH_CHECK(THPEventClass, "THPEvent has not been initialized yet.");
244+
Py_INCREF(THPEventClass);
245+
THCPEventType.tp_base = THPEventClass;
243246
THCPEventClass = (PyObject*)&THCPEventType;
244247
if (PyType_Ready(&THCPEventType) < 0) {
245248
throw python_error();

torch/csrc/cuda/Event.h

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,11 @@
22
#define THCP_EVENT_INC
33

44
#include <ATen/cuda/CUDAEvent.h>
5+
#include <torch/csrc/Event.h>
56
#include <torch/csrc/python_headers.h>
67

7-
struct THCPEvent {
8-
PyObject_HEAD at::cuda::CUDAEvent cuda_event;
8+
struct THCPEvent : THPEvent {
9+
at::cuda::CUDAEvent cuda_event;
910
};
1011
extern PyObject* THCPEventClass;
1112

torch/csrc/xpu/Event.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -167,6 +167,9 @@ PyTypeObject THXPEventType = {
167167
};
168168

169169
void THXPEvent_init(PyObject* module) {
170+
TORCH_CHECK(THPEventClass, "THPEvent has not been initialized yet.");
171+
Py_INCREF(THPEventClass);
172+
THXPEventType.tp_base = THPEventClass;
170173
THXPEventClass = (PyObject*)&THXPEventType;
171174
if (PyType_Ready(&THXPEventType) < 0) {
172175
throw python_error();

torch/csrc/xpu/Event.h

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
11
#pragma once
22

33
#include <ATen/xpu/XPUEvent.h>
4+
#include <torch/csrc/Event.h>
45
#include <torch/csrc/python_headers.h>
56

6-
struct THXPEvent {
7-
PyObject_HEAD at::xpu::XPUEvent xpu_event;
7+
struct THXPEvent : THPEvent {
8+
at::xpu::XPUEvent xpu_event;
89
};
910
extern PyObject* THXPEventClass;
1011

0 commit comments

Comments
 (0)