Skip to content

Commit d290948

Browse files
guangyeypytorchmergebot
authored andcommitted
Use torch.Stream&torch.Event for Dynamo capature (#134850)
# Motivation This PR aims to solve the multiple Inheritance problem. Pull Request resolved: #134850 Approved by: https://github.com/yf225, https://github.com/EikanWang
1 parent bf73af4 commit d290948

File tree

7 files changed

+45
-80
lines changed

7 files changed

+45
-80
lines changed

test/inductor/extension_backends/triton/device_interface.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import time
44

5+
import torch
56
from torch._dynamo import device_interface # noqa: PLC2701 import-private-name
67

78

@@ -13,9 +14,7 @@ def __init__(self) -> None:
1314

1415

1516
class DeviceInterface(device_interface.DeviceInterface):
16-
class Event(
17-
device_interface._EventBase
18-
): # pyright: ignore [reportPrivateImportUsage]
17+
class Event(torch.Event):
1918
def __init__(
2019
self,
2120
enable_timing: bool = False,

torch/_dynamo/device_interface.py

Lines changed: 16 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,9 @@
11
# mypy: allow-untyped-defs
2-
import inspect
32
import time
43
from dataclasses import dataclass
54
from typing import Any, Callable, Dict, Iterable, Optional, Tuple, Type, Union
65

76
import torch
8-
from torch._streambase import _EventBase, _StreamBase
97

108

119
get_cuda_stream: Optional[Callable[[int], int]]
@@ -21,21 +19,7 @@
2119
caching_worker_current_devices: Dict[str, int] = {}
2220

2321

24-
class DeviceInterfaceMeta(type):
25-
def __new__(metacls, *args, **kwargs):
26-
class_member = args[2]
27-
if "Event" in class_member:
28-
assert inspect.isclass(class_member["Event"]) and issubclass(
29-
class_member["Event"], _EventBase
30-
), "DeviceInterface member Event should be inherit from _EventBase"
31-
if "Stream" in class_member:
32-
assert inspect.isclass(class_member["Stream"]) and issubclass(
33-
class_member["Stream"], _StreamBase
34-
), "DeviceInterface member Stream should be inherit from _StreamBase"
35-
return super().__new__(metacls, *args, **kwargs)
36-
37-
38-
class DeviceInterface(metaclass=DeviceInterfaceMeta):
22+
class DeviceInterface:
3923
"""
4024
This is a simple device runtime interface for Inductor. It enables custom
4125
backends to be integrated with Inductor in a device-agnostic semantic.
@@ -45,6 +29,18 @@ class device:
4529
def __new__(cls, device: _device_t):
4630
raise NotImplementedError
4731

32+
class Event:
33+
def __new__(cls, *args, **kwargs):
34+
raise NotImplementedError(
35+
"Event should be inherited from torch.Event, otherwise, it couldn't be captured by dynamo."
36+
)
37+
38+
class Stream:
39+
def __new__(cls, *args, **kwargs):
40+
raise NotImplementedError(
41+
"Stream should be inherited from torch.Stream, otherwise, it couldn't be captured by dynamo."
42+
)
43+
4844
class Worker:
4945
"""
5046
Worker API to query device properties that will work in multi processing
@@ -161,7 +157,7 @@ class CudaInterface(DeviceInterface):
161157
device = torch.cuda.device
162158

163159
# register Event and Stream class into the backend interface
164-
# make sure Event and Stream are implemented and inherited from the _EventBase and _StreamBase
160+
# make sure Event and Stream are implemented and inherited from the torch.Event and torch.Stream
165161
Event = torch.cuda.Event
166162
Stream = torch.cuda.Stream
167163

@@ -303,14 +299,14 @@ class CpuDeviceProperties:
303299

304300

305301
class CpuInterface(DeviceInterface):
306-
class Event(_EventBase):
302+
class Event(torch.Event):
307303
def __init__(self, enable_timing=True):
308304
self.time = 0.0
309305

310306
def elapsed_time(self, end_event) -> float:
311307
return (end_event.time - self.time) * 1000
312308

313-
def record(self):
309+
def record(self, stream=None):
314310
self.time = time.perf_counter()
315311

316312
@staticmethod

torch/_dynamo/variables/builder.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,6 @@
3636
from torch._guards import GuardSource, TracingContext
3737
from torch._higher_order_ops.torchbind import call_torchbind
3838
from torch._ops import HigherOrderOperator
39-
from torch._streambase import _EventBase, _StreamBase
4039
from torch._subclasses.fake_tensor import FakeTensor, is_fake, maybe_get_fake_mode
4140
from torch._subclasses.meta_utils import is_sparse_any, safe_grad
4241
from torch._utils_internal import justknobs_check
@@ -822,7 +821,7 @@ def build_key_value(i, k, v):
822821
stream_source = AttrSource(self.source, "stream")
823822
stream_var = VariableBuilder(self.tx, stream_source)(value.stream)
824823
return StreamContextVariable.create(self.tx, stream_var)
825-
elif isinstance(value, _StreamBase):
824+
elif isinstance(value, torch.Stream):
826825
self.install_guards(GuardBuilder.ID_MATCH)
827826
stream_proxy = self.tx.output.create_proxy(
828827
"call_function",
@@ -847,7 +846,7 @@ def build_key_value(i, k, v):
847846
elif isinstance(value, torch._C._SDPBackend):
848847
self.install_guards(GuardBuilder.ID_MATCH)
849848
return ConstantVariable(value)
850-
elif isinstance(value, _EventBase):
849+
elif isinstance(value, torch.Event):
851850
self.install_guards(GuardBuilder.ID_MATCH)
852851
torch._dynamo.utils.store_user_object_weakref(value)
853852
event_proxy = self.tx.output.create_proxy(
@@ -2265,15 +2264,16 @@ def _clone_input(value):
22652264
return SymNodeVariable(proxy, example_value, **options)
22662265
elif (
22672266
inspect.isclass(proxy.node.target)
2268-
and issubclass(proxy.node.target, _StreamBase)
2267+
and issubclass(proxy.node.target, torch.Stream)
22692268
) or proxy.node.target in [
22702269
device_interface.current_stream
22712270
for _, device_interface in get_registered_device_interfaces()
22722271
]:
22732272
set_example_value(proxy.node, example_value)
22742273
return StreamVariable(proxy, example_value, example_value.device, **options)
22752274
elif (
2276-
inspect.isclass(proxy.node.target) and issubclass(proxy.node.target, _EventBase)
2275+
inspect.isclass(proxy.node.target)
2276+
and issubclass(proxy.node.target, torch.Event)
22772277
) or proxy.node.target in [
22782278
device_interface.Event
22792279
for _, device_interface in get_registered_device_interfaces()
@@ -2285,7 +2285,7 @@ def _clone_input(value):
22852285
return ConstantVariable(example_value, **options)
22862286
elif (
22872287
example_value is not None
2288-
and isinstance(example_value, _EventBase)
2288+
and isinstance(example_value, torch.Event)
22892289
and proxy.node.target == "record_event"
22902290
and proxy.node.op == "call_method"
22912291
):

torch/_dynamo/variables/torch.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313
import torch.nn
1414
from torch._guards import TracingContext
1515
from torch._logging import warning_once
16-
from torch._streambase import _StreamBase
1716
from torch.utils._python_dispatch import is_traceable_wrapper_subclass_type
1817

1918
from .. import config, polyfills, variables
@@ -267,7 +266,7 @@ def call_function(
267266
assert len(args) <= 1 and len(kwargs) == 0
268267
inf_mode = args[0].as_python_constant() if len(args) == 1 else True
269268
return InferenceModeVariable.create(tx, inf_mode)
270-
elif inspect.isclass(self.value) and issubclass(self.value, _StreamBase):
269+
elif inspect.isclass(self.value) and issubclass(self.value, torch.Stream):
271270
from torch._dynamo.variables.builder import wrap_fx_proxy_cls
272271

273272
return wrap_fx_proxy_cls(

torch/_streambase.py

Lines changed: 15 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -1,46 +1,20 @@
1-
# mypy: allow-untyped-defs
2-
from abc import ABC, abstractmethod
1+
from typing_extensions import deprecated
32

3+
import torch
44

5-
class _StreamBase(ABC):
6-
r"""Base stream class abstraction for multi backends Stream to herit from"""
75

8-
@abstractmethod
9-
def wait_event(self, event) -> None:
10-
raise NotImplementedError
6+
# Preserved only for BC reasons
7+
@deprecated(
8+
"`torch._streambase._StreamBase` is deprecated. Please use `torch.Stream` instead.",
9+
category=FutureWarning,
10+
)
11+
class _StreamBase(torch.Stream):
12+
pass
1113

12-
@abstractmethod
13-
def wait_stream(self, stream) -> None:
14-
raise NotImplementedError
1514

16-
@abstractmethod
17-
def record_event(self, event=None) -> None:
18-
raise NotImplementedError
19-
20-
@abstractmethod
21-
def query(self) -> bool:
22-
raise NotImplementedError
23-
24-
@abstractmethod
25-
def synchronize(self) -> None:
26-
raise NotImplementedError
27-
28-
@abstractmethod
29-
def __eq__(self, stream) -> bool:
30-
raise NotImplementedError
31-
32-
33-
class _EventBase(ABC):
34-
r"""Base Event class abstraction for multi backends Event to herit from"""
35-
36-
@abstractmethod
37-
def wait(self, stream=None) -> None:
38-
raise NotImplementedError
39-
40-
@abstractmethod
41-
def query(self) -> bool:
42-
raise NotImplementedError
43-
44-
@abstractmethod
45-
def synchronize(self) -> None:
46-
raise NotImplementedError
15+
@deprecated(
16+
"`torch._streambase._EventBase` is deprecated. Please use `torch.Event` instead.",
17+
category=FutureWarning,
18+
)
19+
class _EventBase(torch.Event):
20+
pass

torch/cuda/streams.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
import ctypes
33

44
import torch
5-
from torch._streambase import _EventBase, _StreamBase
65
from torch._utils import _dummy_type
76

87

@@ -12,7 +11,7 @@
1211
torch._C.__dict__["_CudaEventBase"] = _dummy_type("_CudaEventBase")
1312

1413

15-
class Stream(torch._C._CudaStreamBase, _StreamBase):
14+
class Stream(torch._C._CudaStreamBase):
1615
r"""Wrapper around a CUDA stream.
1716
1817
A CUDA stream is a linear sequence of execution that belongs to a specific
@@ -138,7 +137,7 @@ def __new__(cls, stream_ptr, device=None, **kwargs):
138137
return super().__new__(cls, stream_ptr=stream_ptr, **kwargs)
139138

140139

141-
class Event(torch._C._CudaEventBase, _EventBase):
140+
class Event(torch._C._CudaEventBase):
142141
r"""Wrapper around a CUDA event.
143142
144143
CUDA events are synchronization markers that can be used to monitor the

torch/xpu/streams.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,7 @@
22
import ctypes
33

44
import torch
5-
from torch._streambase import _EventBase, _StreamBase
6-
7-
from .._utils import _dummy_type
5+
from torch._utils import _dummy_type
86

97

108
if not hasattr(torch._C, "_XpuStreamBase"):
@@ -13,7 +11,7 @@
1311
torch._C.__dict__["_XpuEventBase"] = _dummy_type("_XpuEventBase")
1412

1513

16-
class Stream(torch._C._XpuStreamBase, _StreamBase):
14+
class Stream(torch._C._XpuStreamBase):
1715
r"""Wrapper around a XPU stream.
1816
1917
A XPU stream is a linear sequence of execution that belongs to a specific
@@ -98,7 +96,7 @@ def __repr__(self):
9896
return f"torch.xpu.Stream(device={self.device} sycl_queue={self.sycl_queue:#x})"
9997

10098

101-
class Event(torch._C._XpuEventBase, _EventBase):
99+
class Event(torch._C._XpuEventBase):
102100
r"""Wrapper around a XPU event.
103101
104102
XPU events are synchronization markers that can be used to monitor the

0 commit comments

Comments
 (0)