3636from torch ._guards import GuardSource , TracingContext
3737from torch ._higher_order_ops .torchbind import call_torchbind
3838from torch ._ops import HigherOrderOperator
39- from torch ._streambase import _EventBase , _StreamBase
4039from torch ._subclasses .fake_tensor import FakeTensor , is_fake , maybe_get_fake_mode
4140from torch ._subclasses .meta_utils import is_sparse_any , safe_grad
4241from torch ._utils_internal import justknobs_check
@@ -822,7 +821,7 @@ def build_key_value(i, k, v):
822821 stream_source = AttrSource (self .source , "stream" )
823822 stream_var = VariableBuilder (self .tx , stream_source )(value .stream )
824823 return StreamContextVariable .create (self .tx , stream_var )
825- elif isinstance (value , _StreamBase ):
824+ elif isinstance (value , torch . Stream ):
826825 self .install_guards (GuardBuilder .ID_MATCH )
827826 stream_proxy = self .tx .output .create_proxy (
828827 "call_function" ,
@@ -847,7 +846,7 @@ def build_key_value(i, k, v):
847846 elif isinstance (value , torch ._C ._SDPBackend ):
848847 self .install_guards (GuardBuilder .ID_MATCH )
849848 return ConstantVariable (value )
850- elif isinstance (value , _EventBase ):
849+ elif isinstance (value , torch . Event ):
851850 self .install_guards (GuardBuilder .ID_MATCH )
852851 torch ._dynamo .utils .store_user_object_weakref (value )
853852 event_proxy = self .tx .output .create_proxy (
@@ -2265,15 +2264,16 @@ def _clone_input(value):
22652264 return SymNodeVariable (proxy , example_value , ** options )
22662265 elif (
22672266 inspect .isclass (proxy .node .target )
2268- and issubclass (proxy .node .target , _StreamBase )
2267+ and issubclass (proxy .node .target , torch . Stream )
22692268 ) or proxy .node .target in [
22702269 device_interface .current_stream
22712270 for _ , device_interface in get_registered_device_interfaces ()
22722271 ]:
22732272 set_example_value (proxy .node , example_value )
22742273 return StreamVariable (proxy , example_value , example_value .device , ** options )
22752274 elif (
2276- inspect .isclass (proxy .node .target ) and issubclass (proxy .node .target , _EventBase )
2275+ inspect .isclass (proxy .node .target )
2276+ and issubclass (proxy .node .target , torch .Event )
22772277 ) or proxy .node .target in [
22782278 device_interface .Event
22792279 for _ , device_interface in get_registered_device_interfaces ()
@@ -2285,7 +2285,7 @@ def _clone_input(value):
22852285 return ConstantVariable (example_value , ** options )
22862286 elif (
22872287 example_value is not None
2288- and isinstance (example_value , _EventBase )
2288+ and isinstance (example_value , torch . Event )
22892289 and proxy .node .target == "record_event"
22902290 and proxy .node .op == "call_method"
22912291 ):
0 commit comments