Skip to content

Commit 06b1045

Browse files
committed
remove useless _StreamBase&_EventBase
ghstack-source-id: ea24873 Pull Request resolved: #134850
1 parent 1a7e696 commit 06b1045

File tree

7 files changed

+43
-78
lines changed

7 files changed

+43
-78
lines changed

test/inductor/extension_backends/triton/device_interface.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import time
44

5+
import torch
56
from torch._dynamo import device_interface # noqa: PLC2701 import-private-name
67

78

@@ -13,9 +14,7 @@ def __init__(self) -> None:
1314

1415

1516
class DeviceInterface(device_interface.DeviceInterface):
16-
class Event(
17-
device_interface._EventBase
18-
): # pyright: ignore [reportPrivateImportUsage]
17+
class Event(torch.Event):
1918
def __init__(
2019
self,
2120
enable_timing: bool = False,

torch/_dynamo/device_interface.py

Lines changed: 14 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,7 @@
11
# mypy: allow-untyped-defs
2-
import inspect
32
from typing import Any, Callable, Dict, Iterable, Optional, Tuple, Type, Union
43

54
import torch
6-
from torch._streambase import _EventBase, _StreamBase
75

86

97
get_cuda_stream: Optional[Callable[[int], int]]
@@ -19,21 +17,7 @@
1917
caching_worker_current_devices: Dict[str, int] = {}
2018

2119

22-
class DeviceInterfaceMeta(type):
23-
def __new__(metacls, *args, **kwargs):
24-
class_member = args[2]
25-
if "Event" in class_member:
26-
assert inspect.isclass(class_member["Event"]) and issubclass(
27-
class_member["Event"], _EventBase
28-
), "DeviceInterface member Event should be inherit from _EventBase"
29-
if "Stream" in class_member:
30-
assert inspect.isclass(class_member["Stream"]) and issubclass(
31-
class_member["Stream"], _StreamBase
32-
), "DeviceInterface member Stream should be inherit from _StreamBase"
33-
return super().__new__(metacls, *args, **kwargs)
34-
35-
36-
class DeviceInterface(metaclass=DeviceInterfaceMeta):
20+
class DeviceInterface:
3721
"""
3822
This is a simple device runtime interface for Inductor. It enables custom
3923
backends to be integrated with Inductor in a device-agnostic semantic.
@@ -43,6 +27,18 @@ class device:
4327
def __new__(cls, device: _device_t):
4428
raise NotImplementedError
4529

30+
class Event:
31+
def __new__(cls, *args, **kwargs):
32+
raise NotImplementedError(
33+
"Event should be inherited from torch.Event, otherwise, it couldn't be captured by dynamo."
34+
)
35+
36+
class Stream:
37+
def __new__(cls, *args, **kwargs):
38+
raise NotImplementedError(
39+
"Stream should be inherited from torch.Stream, otherwise, it couldn't be captured by dynamo."
40+
)
41+
4642
class Worker:
4743
"""
4844
Worker API to query device properties that will work in multi processing
@@ -159,7 +155,7 @@ class CudaInterface(DeviceInterface):
159155
device = torch.cuda.device
160156

161157
# register Event and Stream class into the backend interface
162-
# make sure Event and Stream are implemented and inherited from the _EventBase and _StreamBase
158+
# make sure Event and Stream are implemented and inherited from the torch.Event and torch.Stream
163159
Event = torch.cuda.Event
164160
Stream = torch.cuda.Stream
165161

torch/_dynamo/variables/builder.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,6 @@
3636
from torch._guards import GuardSource, TracingContext
3737
from torch._higher_order_ops.torchbind import call_torchbind
3838
from torch._ops import HigherOrderOperator
39-
from torch._streambase import _EventBase, _StreamBase
4039
from torch._subclasses.fake_tensor import FakeTensor, is_fake, maybe_get_fake_mode
4140
from torch._subclasses.meta_utils import is_sparse_any, safe_grad
4241
from torch._utils_internal import justknobs_check
@@ -823,7 +822,7 @@ def build_key_value(i, k, v):
823822
stream_source = AttrSource(self.source, "stream")
824823
stream_var = VariableBuilder(self.tx, stream_source)(value.stream)
825824
return StreamContextVariable.create(self.tx, stream_var)
826-
elif isinstance(value, _StreamBase):
825+
elif isinstance(value, torch.Stream):
827826
self.install_guards(GuardBuilder.ID_MATCH)
828827
stream_proxy = self.tx.output.create_proxy(
829828
"call_function",
@@ -848,7 +847,7 @@ def build_key_value(i, k, v):
848847
elif isinstance(value, torch._C._SDPBackend):
849848
self.install_guards(GuardBuilder.ID_MATCH)
850849
return ConstantVariable(value)
851-
elif isinstance(value, _EventBase):
850+
elif isinstance(value, torch.Event):
852851
self.install_guards(GuardBuilder.ID_MATCH)
853852
torch._dynamo.utils.store_user_object_weakref(value)
854853
event_proxy = self.tx.output.create_proxy(
@@ -2267,15 +2266,16 @@ def _clone_input(value):
22672266
return SymNodeVariable(proxy, example_value, **options)
22682267
elif (
22692268
inspect.isclass(proxy.node.target)
2270-
and issubclass(proxy.node.target, _StreamBase)
2269+
and issubclass(proxy.node.target, torch.Stream)
22712270
) or proxy.node.target in [
22722271
device_interface.current_stream
22732272
for _, device_interface in get_registered_device_interfaces()
22742273
]:
22752274
set_example_value(proxy.node, example_value)
22762275
return StreamVariable(proxy, example_value, example_value.device, **options)
22772276
elif (
2278-
inspect.isclass(proxy.node.target) and issubclass(proxy.node.target, _EventBase)
2277+
inspect.isclass(proxy.node.target)
2278+
and issubclass(proxy.node.target, torch.Event)
22792279
) or proxy.node.target in [
22802280
device_interface.Event
22812281
for _, device_interface in get_registered_device_interfaces()
@@ -2287,7 +2287,7 @@ def _clone_input(value):
22872287
return ConstantVariable(example_value, **options)
22882288
elif (
22892289
example_value is not None
2290-
and isinstance(example_value, _EventBase)
2290+
and isinstance(example_value, torch.Event)
22912291
and proxy.node.target == "record_event"
22922292
and proxy.node.op == "call_method"
22932293
):

torch/_dynamo/variables/torch.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313
import torch.nn
1414
from torch._guards import TracingContext
1515
from torch._logging import warning_once
16-
from torch._streambase import _StreamBase
1716
from torch.utils._python_dispatch import is_traceable_wrapper_subclass_type
1817

1918
from .. import config, polyfills, variables
@@ -267,7 +266,7 @@ def call_function(
267266
assert len(args) <= 1 and len(kwargs) == 0
268267
inf_mode = args[0].as_python_constant() if len(args) == 1 else True
269268
return InferenceModeVariable.create(tx, inf_mode)
270-
elif inspect.isclass(self.value) and issubclass(self.value, _StreamBase):
269+
elif inspect.isclass(self.value) and issubclass(self.value, torch.Stream):
271270
from torch._dynamo.variables.builder import wrap_fx_proxy_cls
272271

273272
return wrap_fx_proxy_cls(

torch/_streambase.py

Lines changed: 15 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -1,46 +1,20 @@
1-
# mypy: allow-untyped-defs
2-
from abc import ABC, abstractmethod
1+
from typing_extensions import deprecated
32

3+
import torch
44

5-
class _StreamBase(ABC):
6-
r"""Base stream class abstraction for multi backends Stream to herit from"""
75

8-
@abstractmethod
9-
def wait_event(self, event) -> None:
10-
raise NotImplementedError
6+
# Preserved only for BC reasons
7+
@deprecated(
8+
"`torch._streambase._StreamBase` is deprecated. Please use `torch.Stream` instead.",
9+
category=FutureWarning,
10+
)
11+
class _StreamBase(torch.Stream):
12+
pass
1113

12-
@abstractmethod
13-
def wait_stream(self, stream) -> None:
14-
raise NotImplementedError
1514

16-
@abstractmethod
17-
def record_event(self, event=None) -> None:
18-
raise NotImplementedError
19-
20-
@abstractmethod
21-
def query(self) -> bool:
22-
raise NotImplementedError
23-
24-
@abstractmethod
25-
def synchronize(self) -> None:
26-
raise NotImplementedError
27-
28-
@abstractmethod
29-
def __eq__(self, stream) -> bool:
30-
raise NotImplementedError
31-
32-
33-
class _EventBase(ABC):
34-
r"""Base Event class abstraction for multi backends Event to herit from"""
35-
36-
@abstractmethod
37-
def wait(self, stream=None) -> None:
38-
raise NotImplementedError
39-
40-
@abstractmethod
41-
def query(self) -> bool:
42-
raise NotImplementedError
43-
44-
@abstractmethod
45-
def synchronize(self) -> None:
46-
raise NotImplementedError
15+
@deprecated(
16+
"`torch._streambase._EventBase` is deprecated. Please use `torch.Event` instead.",
17+
category=FutureWarning,
18+
)
19+
class _EventBase(torch.Event):
20+
pass

torch/cuda/streams.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
import ctypes
33

44
import torch
5-
from torch._streambase import _EventBase, _StreamBase
65
from torch._utils import _dummy_type
76

87

@@ -12,7 +11,7 @@
1211
torch._C.__dict__["_CudaEventBase"] = _dummy_type("_CudaEventBase")
1312

1413

15-
class Stream(torch._C._CudaStreamBase, _StreamBase):
14+
class Stream(torch._C._CudaStreamBase):
1615
r"""Wrapper around a CUDA stream.
1716
1817
A CUDA stream is a linear sequence of execution that belongs to a specific
@@ -138,7 +137,7 @@ def __new__(cls, stream_ptr, device=None, **kwargs):
138137
return super().__new__(cls, stream_ptr=stream_ptr, **kwargs)
139138

140139

141-
class Event(torch._C._CudaEventBase, _EventBase):
140+
class Event(torch._C._CudaEventBase):
142141
r"""Wrapper around a CUDA event.
143142
144143
CUDA events are synchronization markers that can be used to monitor the

torch/xpu/streams.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,7 @@
22
import ctypes
33

44
import torch
5-
from torch._streambase import _EventBase, _StreamBase
6-
7-
from .._utils import _dummy_type
5+
from torch._utils import _dummy_type
86

97

108
if not hasattr(torch._C, "_XpuStreamBase"):
@@ -13,7 +11,7 @@
1311
torch._C.__dict__["_XpuEventBase"] = _dummy_type("_XpuEventBase")
1412

1513

16-
class Stream(torch._C._XpuStreamBase, _StreamBase):
14+
class Stream(torch._C._XpuStreamBase):
1715
r"""Wrapper around a XPU stream.
1816
1917
A XPU stream is a linear sequence of execution that belongs to a specific
@@ -98,7 +96,7 @@ def __repr__(self):
9896
return f"torch.xpu.Stream(device={self.device} sycl_queue={self.sycl_queue:#x})"
9997

10098

101-
class Event(torch._C._XpuEventBase, _EventBase):
99+
class Event(torch._C._XpuEventBase):
102100
r"""Wrapper around a XPU event.
103101
104102
XPU events are synchronization markers that can be used to monitor the

0 commit comments

Comments
 (0)