7 changes: 7 additions & 0 deletions test/test_cuda.py
@@ -661,6 +661,9 @@ def test_generic_stream_event(self):
device_index=stream.device_index,
device_type=stream.device_type,
)
self.assertIsInstance(cuda_stream, torch.Stream)
self.assertTrue(issubclass(type(cuda_stream), torch.Stream))
self.assertTrue(torch.Stream in type(cuda_stream).mro())
self.assertEqual(stream.stream_id, cuda_stream.stream_id)
self.assertNotEqual(stream.stream_id, torch.cuda.current_stream().stream_id)

@@ -683,6 +686,10 @@ def test_generic_stream_event(self):
self.assertNotEqual(event1.event_id, event2.event_id)
self.assertEqual(c_cuda.cpu(), a + b)
self.assertTrue(event1.elapsed_time(event2) > 0)
cuda_event = torch.cuda.Event()
self.assertIsInstance(cuda_event, torch.Event)
self.assertTrue(issubclass(type(cuda_event), torch.Event))
self.assertTrue(torch.Event in type(cuda_event).mro())

def test_record_stream(self):
cycles_per_ms = get_cycles_per_ms()
7 changes: 7 additions & 0 deletions test/test_xpu.py
@@ -237,6 +237,9 @@ def test_generic_stream_event(self):
device_index=stream.device_index,
device_type=stream.device_type,
)
self.assertIsInstance(xpu_stream, torch.Stream)
self.assertTrue(issubclass(type(xpu_stream), torch.Stream))
self.assertTrue(torch.Stream in type(xpu_stream).mro())
self.assertEqual(stream.stream_id, xpu_stream.stream_id)
self.assertNotEqual(stream.stream_id, torch.xpu.current_stream().stream_id)

@@ -262,6 +265,10 @@ def test_generic_stream_event(self):
NotImplementedError, "elapsedTime is not supported by XPU backend."
):
event1.elapsed_time(event2)
xpu_event = torch.xpu.Event()
self.assertIsInstance(xpu_event, torch.Event)
self.assertTrue(issubclass(type(xpu_event), torch.Event))
self.assertTrue(torch.Event in type(xpu_event).mro())

def test_generator(self):
torch.manual_seed(2024)
4 changes: 2 additions & 2 deletions torch/csrc/Event.cpp
@@ -15,7 +15,7 @@
#include <structmember.h>
#include <string>

PyObject* THPEventClass = nullptr;
PyTypeObject* THPEventClass = nullptr;

static PyObject* THPEvent_pynew(
PyTypeObject* type,
@@ -316,7 +316,7 @@ PyTypeObject THPEventType = {
};

void THPEvent_init(PyObject* module) {
THPEventClass = (PyObject*)&THPEventType;
THPEventClass = &THPEventType;
if (PyType_Ready(&THPEventType) < 0) {
throw python_error();
}
4 changes: 2 additions & 2 deletions torch/csrc/Event.h
@@ -7,15 +7,15 @@
struct TORCH_API THPEvent {
PyObject_HEAD c10::Event event;
};
extern PyObject* THPEventClass;
TORCH_API extern PyTypeObject* THPEventClass;
Collaborator: You're adding this to our public API because you need it in third-party C++ code?

Collaborator (Author): Yes, I think the out-of-tree backend needs this to support linear inheritance.

TORCH_API extern PyTypeObject THPEventType;

TORCH_API void THPEvent_init(PyObject* module);
TORCH_API PyObject* THPEvent_new(
c10::DeviceType device_type,
c10::EventFlag flag);
inline bool THPEvent_Check(PyObject* obj) {
return THPEventClass && PyObject_IsInstance(obj, THPEventClass);
return THPEventClass && PyObject_IsInstance(obj, (PyObject*)THPEventClass);
}

#endif // THP_EVENT_INC
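
The exported THPEventClass lets an out-of-tree backend register its own event type as a Python subclass of torch.Event, mirroring the CUDA and XPU changes below. A minimal sketch, where MyBackendEvent and the "my_backend._C._EventBase" name are hypothetical, not part of any real backend:

#include <c10/util/Exception.h>
#include <torch/csrc/Event.h>
#include <torch/csrc/Exceptions.h>
#include <torch/csrc/python_headers.h>

struct MyBackendEvent : THPEvent {
  // Backend-specific event state would live here, e.g. a native event handle.
  void* native_event{nullptr};
};

static PyTypeObject MyBackendEventType = {
    PyVarObject_HEAD_INIT(nullptr, 0)
    "my_backend._C._EventBase", /* tp_name */
    sizeof(MyBackendEvent),     /* tp_basicsize */
    0,                          /* tp_itemsize */
};

void MyBackendEvent_init(PyObject* module) {
  TORCH_CHECK(THPEventClass, "THPEvent has not been initialized yet.");
  Py_INCREF(THPEventClass);
  // Linear inheritance: torch.Event ends up in the type's MRO, so
  // isinstance(ev, torch.Event) holds for backend events.
  MyBackendEventType.tp_base = THPEventClass;
  MyBackendEventType.tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE;
  if (PyType_Ready(&MyBackendEventType) < 0) {
    throw python_error();
  }
  Py_INCREF(&MyBackendEventType);
  if (PyModule_AddObject(
          module, "_EventBase", (PyObject*)&MyBackendEventType) < 0) {
    throw python_error();
  }
}

A real backend would also install its own tp_new/tp_dealloc so the native event member is constructed and destroyed correctly; the slots inherited from the base type only manage the THPEvent portion of the object.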
3 changes: 3 additions & 0 deletions torch/csrc/cuda/Event.cpp
@@ -240,6 +240,9 @@ PyTypeObject THCPEventType = {
};

void THCPEvent_init(PyObject* module) {
TORCH_CHECK(THPEventClass, "THPEvent has not been initialized yet.");
Py_INCREF(THPEventClass);
Collaborator: Please assert that it is non-null here, in case someone changes the init order and this has not been set yet.

Collaborator (Author): Good idea! Updated.

THCPEventType.tp_base = THPEventClass;
THCPEventClass = (PyObject*)&THCPEventType;
if (PyType_Ready(&THCPEventType) < 0) {
throw python_error();
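
The TORCH_CHECK added above only holds if the generic event type is registered first. A minimal sketch of the initialization order the check protects, with initEventTypes standing in for torch's actual module-initialization entry point (an assumed name, not the real call site):

void initEventTypes(PyObject* module) {
  THPEvent_init(module);   // registers torch.Event and sets THPEventClass
  THCPEvent_init(module);  // may now use THPEventClass as tp_base
  THXPEvent_init(module);  // same requirement for the XPU event type
}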
5 changes: 3 additions & 2 deletions torch/csrc/cuda/Event.h
@@ -2,10 +2,11 @@
#define THCP_EVENT_INC

#include <ATen/cuda/CUDAEvent.h>
#include <torch/csrc/Event.h>
#include <torch/csrc/python_headers.h>

struct THCPEvent {
PyObject_HEAD at::cuda::CUDAEvent cuda_event;
struct THCPEvent : THPEvent {
at::cuda::CUDAEvent cuda_event;
};
extern PyObject* THCPEventClass;

3 changes: 3 additions & 0 deletions torch/csrc/xpu/Event.cpp
@@ -167,6 +167,9 @@ PyTypeObject THXPEventType = {
};

void THXPEvent_init(PyObject* module) {
TORCH_CHECK(THPEventClass, "THPEvent has not been initialized yet.");
Py_INCREF(THPEventClass);
THXPEventType.tp_base = THPEventClass;
THXPEventClass = (PyObject*)&THXPEventType;
if (PyType_Ready(&THXPEventType) < 0) {
throw python_error();
5 changes: 3 additions & 2 deletions torch/csrc/xpu/Event.h
@@ -1,10 +1,11 @@
#pragma once

#include <ATen/xpu/XPUEvent.h>
#include <torch/csrc/Event.h>
#include <torch/csrc/python_headers.h>

struct THXPEvent {
PyObject_HEAD at::xpu::XPUEvent xpu_event;
struct THXPEvent : THPEvent {
at::xpu::XPUEvent xpu_event;
};
extern PyObject* THXPEventClass;

Expand Down