Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions test/inductor/extension_backends/triton/device_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import time

import torch
from torch._dynamo import device_interface # noqa: PLC2701 import-private-name


Expand All @@ -13,9 +14,7 @@ def __init__(self) -> None:


class DeviceInterface(device_interface.DeviceInterface):
class Event(
device_interface._EventBase
): # pyright: ignore [reportPrivateImportUsage]
class Event(torch.Event):
def __init__(
self,
enable_timing: bool = False,
Expand Down
36 changes: 16 additions & 20 deletions torch/_dynamo/device_interface.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,9 @@
# mypy: allow-untyped-defs
import inspect
import time
from dataclasses import dataclass
from typing import Any, Callable, Dict, Iterable, Optional, Tuple, Type, Union

import torch
from torch._streambase import _EventBase, _StreamBase


get_cuda_stream: Optional[Callable[[int], int]]
Expand All @@ -21,21 +19,7 @@
caching_worker_current_devices: Dict[str, int] = {}


class DeviceInterfaceMeta(type):
def __new__(metacls, *args, **kwargs):
class_member = args[2]
if "Event" in class_member:
assert inspect.isclass(class_member["Event"]) and issubclass(
class_member["Event"], _EventBase
), "DeviceInterface member Event should be inherit from _EventBase"
if "Stream" in class_member:
assert inspect.isclass(class_member["Stream"]) and issubclass(
class_member["Stream"], _StreamBase
), "DeviceInterface member Stream should be inherit from _StreamBase"
return super().__new__(metacls, *args, **kwargs)


class DeviceInterface(metaclass=DeviceInterfaceMeta):
class DeviceInterface:
"""
This is a simple device runtime interface for Inductor. It enables custom
backends to be integrated with Inductor in a device-agnostic semantic.
Expand All @@ -45,6 +29,18 @@ class device:
def __new__(cls, device: _device_t):
raise NotImplementedError

class Event:
def __new__(cls, *args, **kwargs):
raise NotImplementedError(
"Event should be inherited from torch.Event, otherwise, it couldn't be captured by dynamo."
)

class Stream:
def __new__(cls, *args, **kwargs):
raise NotImplementedError(
"Stream should be inherited from torch.Stream, otherwise, it couldn't be captured by dynamo."
)

class Worker:
"""
Worker API to query device properties that will work in multi processing
Expand Down Expand Up @@ -161,7 +157,7 @@ class CudaInterface(DeviceInterface):
device = torch.cuda.device

# register Event and Stream class into the backend interface
# make sure Event and Stream are implemented and inherited from the _EventBase and _StreamBase
# make sure Event and Stream are implemented and inherited from the torch.Event and torch.Stream
Event = torch.cuda.Event
Stream = torch.cuda.Stream

Expand Down Expand Up @@ -303,14 +299,14 @@ class CpuDeviceProperties:


class CpuInterface(DeviceInterface):
class Event(_EventBase):
class Event(torch.Event):
def __init__(self, enable_timing=True):
self.time = 0.0

def elapsed_time(self, end_event) -> float:
return (end_event.time - self.time) * 1000

def record(self):
def record(self, stream=None):
self.time = time.perf_counter()

@staticmethod
Expand Down
12 changes: 6 additions & 6 deletions torch/_dynamo/variables/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@
from torch._guards import GuardSource, TracingContext
from torch._higher_order_ops.torchbind import call_torchbind
from torch._ops import HigherOrderOperator
from torch._streambase import _EventBase, _StreamBase
from torch._subclasses.fake_tensor import FakeTensor, is_fake, maybe_get_fake_mode
from torch._subclasses.meta_utils import is_sparse_any, safe_grad
from torch._utils_internal import justknobs_check
Expand Down Expand Up @@ -822,7 +821,7 @@ def build_key_value(i, k, v):
stream_source = AttrSource(self.source, "stream")
stream_var = VariableBuilder(self.tx, stream_source)(value.stream)
return StreamContextVariable.create(self.tx, stream_var)
elif isinstance(value, _StreamBase):
elif isinstance(value, torch.Stream):
self.install_guards(GuardBuilder.ID_MATCH)
stream_proxy = self.tx.output.create_proxy(
"call_function",
Expand All @@ -847,7 +846,7 @@ def build_key_value(i, k, v):
elif isinstance(value, torch._C._SDPBackend):
self.install_guards(GuardBuilder.ID_MATCH)
return ConstantVariable(value)
elif isinstance(value, _EventBase):
elif isinstance(value, torch.Event):
self.install_guards(GuardBuilder.ID_MATCH)
torch._dynamo.utils.store_user_object_weakref(value)
event_proxy = self.tx.output.create_proxy(
Expand Down Expand Up @@ -2265,15 +2264,16 @@ def _clone_input(value):
return SymNodeVariable(proxy, example_value, **options)
elif (
inspect.isclass(proxy.node.target)
and issubclass(proxy.node.target, _StreamBase)
and issubclass(proxy.node.target, torch.Stream)
) or proxy.node.target in [
device_interface.current_stream
for _, device_interface in get_registered_device_interfaces()
]:
set_example_value(proxy.node, example_value)
return StreamVariable(proxy, example_value, example_value.device, **options)
elif (
inspect.isclass(proxy.node.target) and issubclass(proxy.node.target, _EventBase)
inspect.isclass(proxy.node.target)
and issubclass(proxy.node.target, torch.Event)
) or proxy.node.target in [
device_interface.Event
for _, device_interface in get_registered_device_interfaces()
Expand All @@ -2285,7 +2285,7 @@ def _clone_input(value):
return ConstantVariable(example_value, **options)
elif (
example_value is not None
and isinstance(example_value, _EventBase)
and isinstance(example_value, torch.Event)
and proxy.node.target == "record_event"
and proxy.node.op == "call_method"
):
Expand Down
3 changes: 1 addition & 2 deletions torch/_dynamo/variables/torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
import torch.nn
from torch._guards import TracingContext
from torch._logging import warning_once
from torch._streambase import _StreamBase
from torch.utils._python_dispatch import is_traceable_wrapper_subclass_type

from .. import config, polyfills, variables
Expand Down Expand Up @@ -267,7 +266,7 @@ def call_function(
assert len(args) <= 1 and len(kwargs) == 0
inf_mode = args[0].as_python_constant() if len(args) == 1 else True
return InferenceModeVariable.create(tx, inf_mode)
elif inspect.isclass(self.value) and issubclass(self.value, _StreamBase):
elif inspect.isclass(self.value) and issubclass(self.value, torch.Stream):
from torch._dynamo.variables.builder import wrap_fx_proxy_cls

return wrap_fx_proxy_cls(
Expand Down
56 changes: 15 additions & 41 deletions torch/_streambase.py
Original file line number Diff line number Diff line change
@@ -1,46 +1,20 @@
# mypy: allow-untyped-defs
from abc import ABC, abstractmethod
from typing_extensions import deprecated

import torch

class _StreamBase(ABC):
r"""Base stream class abstraction for multi backends Stream to herit from"""

@abstractmethod
def wait_event(self, event) -> None:
raise NotImplementedError
# Preserved only for BC reasons
@deprecated(
"`torch._streambase._StreamBase` is deprecated. Please use `torch.Stream` instead.",
category=FutureWarning,
)
class _StreamBase(torch.Stream):
pass

@abstractmethod
def wait_stream(self, stream) -> None:
raise NotImplementedError

@abstractmethod
def record_event(self, event=None) -> None:
raise NotImplementedError

@abstractmethod
def query(self) -> bool:
raise NotImplementedError

@abstractmethod
def synchronize(self) -> None:
raise NotImplementedError

@abstractmethod
def __eq__(self, stream) -> bool:
raise NotImplementedError


class _EventBase(ABC):
r"""Base Event class abstraction for multi backends Event to herit from"""

@abstractmethod
def wait(self, stream=None) -> None:
raise NotImplementedError

@abstractmethod
def query(self) -> bool:
raise NotImplementedError

@abstractmethod
def synchronize(self) -> None:
raise NotImplementedError
@deprecated(
"`torch._streambase._EventBase` is deprecated. Please use `torch.Event` instead.",
category=FutureWarning,
)
class _EventBase(torch.Event):
pass
5 changes: 2 additions & 3 deletions torch/cuda/streams.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
import ctypes

import torch
from torch._streambase import _EventBase, _StreamBase
from torch._utils import _dummy_type


Expand All @@ -12,7 +11,7 @@
torch._C.__dict__["_CudaEventBase"] = _dummy_type("_CudaEventBase")


class Stream(torch._C._CudaStreamBase, _StreamBase):
class Stream(torch._C._CudaStreamBase):
r"""Wrapper around a CUDA stream.

A CUDA stream is a linear sequence of execution that belongs to a specific
Expand Down Expand Up @@ -138,7 +137,7 @@ def __new__(cls, stream_ptr, device=None, **kwargs):
return super().__new__(cls, stream_ptr=stream_ptr, **kwargs)


class Event(torch._C._CudaEventBase, _EventBase):
class Event(torch._C._CudaEventBase):
r"""Wrapper around a CUDA event.

CUDA events are synchronization markers that can be used to monitor the
Expand Down
8 changes: 3 additions & 5 deletions torch/xpu/streams.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,7 @@
import ctypes

import torch
from torch._streambase import _EventBase, _StreamBase

from .._utils import _dummy_type
from torch._utils import _dummy_type


if not hasattr(torch._C, "_XpuStreamBase"):
Expand All @@ -13,7 +11,7 @@
torch._C.__dict__["_XpuEventBase"] = _dummy_type("_XpuEventBase")


class Stream(torch._C._XpuStreamBase, _StreamBase):
class Stream(torch._C._XpuStreamBase):
r"""Wrapper around a XPU stream.

A XPU stream is a linear sequence of execution that belongs to a specific
Expand Down Expand Up @@ -98,7 +96,7 @@ def __repr__(self):
return f"torch.xpu.Stream(device={self.device} sycl_queue={self.sycl_queue:#x})"


class Event(torch._C._XpuEventBase, _EventBase):
class Event(torch._C._XpuEventBase):
r"""Wrapper around a XPU event.

XPU events are synchronization markers that can be used to monitor the
Expand Down