3636from torch ._guards import GuardSource , TracingContext
3737from torch ._higher_order_ops .torchbind import call_torchbind
3838from torch ._ops import HigherOrderOperator
39- from torch ._streambase import _EventBase , _StreamBase
4039from torch ._subclasses .fake_tensor import FakeTensor , is_fake , maybe_get_fake_mode
4140from torch ._subclasses .meta_utils import is_sparse_any , safe_grad
4241from torch ._utils_internal import justknobs_check
@@ -823,7 +822,7 @@ def build_key_value(i, k, v):
823822 stream_source = AttrSource (self .source , "stream" )
824823 stream_var = VariableBuilder (self .tx , stream_source )(value .stream )
825824 return StreamContextVariable .create (self .tx , stream_var )
826- elif isinstance (value , _StreamBase ):
825+ elif isinstance (value , torch . Stream ):
827826 self .install_guards (GuardBuilder .ID_MATCH )
828827 stream_proxy = self .tx .output .create_proxy (
829828 "call_function" ,
@@ -848,7 +847,7 @@ def build_key_value(i, k, v):
848847 elif isinstance (value , torch ._C ._SDPBackend ):
849848 self .install_guards (GuardBuilder .ID_MATCH )
850849 return ConstantVariable (value )
851- elif isinstance (value , _EventBase ):
850+ elif isinstance (value , torch . Event ):
852851 self .install_guards (GuardBuilder .ID_MATCH )
853852 torch ._dynamo .utils .store_user_object_weakref (value )
854853 event_proxy = self .tx .output .create_proxy (
@@ -2267,15 +2266,16 @@ def _clone_input(value):
22672266 return SymNodeVariable (proxy , example_value , ** options )
22682267 elif (
22692268 inspect .isclass (proxy .node .target )
2270- and issubclass (proxy .node .target , _StreamBase )
2269+ and issubclass (proxy .node .target , torch . Stream )
22712270 ) or proxy .node .target in [
22722271 device_interface .current_stream
22732272 for _ , device_interface in get_registered_device_interfaces ()
22742273 ]:
22752274 set_example_value (proxy .node , example_value )
22762275 return StreamVariable (proxy , example_value , example_value .device , ** options )
22772276 elif (
2278- inspect .isclass (proxy .node .target ) and issubclass (proxy .node .target , _EventBase )
2277+ inspect .isclass (proxy .node .target )
2278+ and issubclass (proxy .node .target , torch .Event )
22792279 ) or proxy .node .target in [
22802280 device_interface .Event
22812281 for _ , device_interface in get_registered_device_interfaces ()
@@ -2287,7 +2287,7 @@ def _clone_input(value):
22872287 return ConstantVariable (example_value , ** options )
22882288 elif (
22892289 example_value is not None
2290- and isinstance (example_value , _EventBase )
2290+ and isinstance (example_value , torch . Event )
22912291 and proxy .node .target == "record_event"
22922292 and proxy .node .op == "call_method"
22932293 ):
0 commit comments