File tree Expand file tree Collapse file tree 8 files changed +30
-8
lines changed
Expand file tree Collapse file tree 8 files changed +30
-8
lines changed Original file line number Diff line number Diff line change @@ -661,6 +661,9 @@ def test_generic_stream_event(self):
661661 device_index = stream .device_index ,
662662 device_type = stream .device_type ,
663663 )
664+ self .assertIsInstance (cuda_stream , torch .Stream )
665+ self .assertTrue (issubclass (type (cuda_stream ), torch .Stream ))
666+ self .assertTrue (torch .Stream in type (cuda_stream ).mro ())
664667 self .assertEqual (stream .stream_id , cuda_stream .stream_id )
665668 self .assertNotEqual (stream .stream_id , torch .cuda .current_stream ().stream_id )
666669
@@ -683,6 +686,10 @@ def test_generic_stream_event(self):
683686 self .assertNotEqual (event1 .event_id , event2 .event_id )
684687 self .assertEqual (c_cuda .cpu (), a + b )
685688 self .assertTrue (event1 .elapsed_time (event2 ) > 0 )
689+ cuda_event = torch .cuda .Event ()
690+ self .assertIsInstance (cuda_event , torch .Event )
691+ self .assertTrue (issubclass (type (cuda_event ), torch .Event ))
692+ self .assertTrue (torch .Event in type (cuda_event ).mro ())
686693
687694 def test_record_stream (self ):
688695 cycles_per_ms = get_cycles_per_ms ()
Original file line number Diff line number Diff line change @@ -237,6 +237,9 @@ def test_generic_stream_event(self):
237237 device_index = stream .device_index ,
238238 device_type = stream .device_type ,
239239 )
240+ self .assertIsInstance (xpu_stream , torch .Stream )
241+ self .assertTrue (issubclass (type (xpu_stream ), torch .Stream ))
242+ self .assertTrue (torch .Stream in type (xpu_stream ).mro ())
240243 self .assertEqual (stream .stream_id , xpu_stream .stream_id )
241244 self .assertNotEqual (stream .stream_id , torch .xpu .current_stream ().stream_id )
242245
@@ -262,6 +265,10 @@ def test_generic_stream_event(self):
262265 NotImplementedError , "elapsedTime is not supported by XPU backend."
263266 ):
264267 event1 .elapsed_time (event2 )
268+ xpu_event = torch .xpu .Event ()
269+ self .assertIsInstance (xpu_event , torch .Event )
270+ self .assertTrue (issubclass (type (xpu_event ), torch .Event ))
271+ self .assertTrue (torch .Event in type (xpu_event ).mro ())
265272
266273 def test_generator (self ):
267274 torch .manual_seed (2024 )
Original file line number Diff line number Diff line change 1515#include < structmember.h>
1616#include < string>
1717
18- PyObject * THPEventClass = nullptr ;
18+ PyTypeObject * THPEventClass = nullptr ;
1919
2020static PyObject* THPEvent_pynew (
2121 PyTypeObject* type,
@@ -316,7 +316,7 @@ PyTypeObject THPEventType = {
316316};
317317
318318void THPEvent_init (PyObject* module ) {
319- THPEventClass = (PyObject*) &THPEventType;
319+ THPEventClass = &THPEventType;
320320 if (PyType_Ready (&THPEventType) < 0 ) {
321321 throw python_error ();
322322 }
Original file line number Diff line number Diff line change 77struct TORCH_API THPEvent {
88 PyObject_HEAD c10 ::Event event ;
99};
10- extern PyObject * THPEventClass ;
10+ TORCH_API extern PyTypeObject * THPEventClass ;
1111TORCH_API extern PyTypeObject THPEventType ;
1212
1313TORCH_API void THPEvent_init (PyObject * module );
1414TORCH_API PyObject * THPEvent_new (
1515 c10 ::DeviceType device_type ,
1616 c10 ::EventFlag flag );
1717inline bool THPEvent_Check (PyObject * obj ) {
18- return THPEventClass && PyObject_IsInstance (obj , THPEventClass );
18+ return THPEventClass && PyObject_IsInstance (obj , ( PyObject * ) THPEventClass );
1919}
2020
2121#endif // THP_EVENT_INC
Original file line number Diff line number Diff line change @@ -240,6 +240,9 @@ PyTypeObject THCPEventType = {
240240};
241241
242242void THCPEvent_init (PyObject* module ) {
243+ TORCH_CHECK (THPEventClass, " THPEvent has not been initialized yet." );
244+ Py_INCREF (THPEventClass);
245+ THCPEventType.tp_base = THPEventClass;
243246 THCPEventClass = (PyObject*)&THCPEventType;
244247 if (PyType_Ready (&THCPEventType) < 0 ) {
245248 throw python_error ();
Original file line number Diff line number Diff line change 22#define THCP_EVENT_INC
33
44#include <ATen/cuda/CUDAEvent.h>
5+ #include <torch/csrc/Event.h>
56#include <torch/csrc/python_headers.h>
67
7- struct THCPEvent {
8- PyObject_HEAD at ::cuda ::CUDAEvent cuda_event ;
8+ struct THCPEvent : THPEvent {
9+ at ::cuda ::CUDAEvent cuda_event ;
910};
1011extern PyObject * THCPEventClass ;
1112
Original file line number Diff line number Diff line change @@ -167,6 +167,9 @@ PyTypeObject THXPEventType = {
167167};
168168
169169void THXPEvent_init (PyObject* module ) {
170+ TORCH_CHECK (THPEventClass, " THPEvent has not been initialized yet." );
171+ Py_INCREF (THPEventClass);
172+ THXPEventType.tp_base = THPEventClass;
170173 THXPEventClass = (PyObject*)&THXPEventType;
171174 if (PyType_Ready (&THXPEventType) < 0 ) {
172175 throw python_error ();
Original file line number Diff line number Diff line change 11#pragma once
22
33#include <ATen/xpu/XPUEvent.h>
4+ #include <torch/csrc/Event.h>
45#include <torch/csrc/python_headers.h>
56
6- struct THXPEvent {
7- PyObject_HEAD at ::xpu ::XPUEvent xpu_event ;
7+ struct THXPEvent : THPEvent {
8+ at ::xpu ::XPUEvent xpu_event ;
89};
910extern PyObject * THXPEventClass ;
1011
You can’t perform that action at this time.
0 commit comments