@@ -243,8 +243,6 @@ def __init__(self, overloadpacket, op, op_dk, schema, tags):
243243 op .__module__ = overloadpacket .__module__
244244 self .__qualname__ = self ._name
245245 self .__annotations__ = {}
246- # NB: This name is hard-coded in torch/csrc/autograd/python_variable.cpp
247- self ._dispatch_cache = {}
248246
249247 # it's a no-op since OpOverload object is immutable and must be unique for a given op overload.
250248 def __deepcopy__ (self , memo = None ):
@@ -291,7 +289,6 @@ def inner(fn):
291289 assert mode not in self .python_key_mode_table
292290 # TODO(voz): Should we replace setting torch._C.DispatchKey.Python entirely with setting mode keys?
293291 self .python_key_mode_table [mode ] = fn
294- self ._dispatch_cache .clear ()
295292 return fn
296293
297294 assert isinstance (dispatch_key_or_mode , torch ._C .DispatchKey )
@@ -304,19 +301,23 @@ def inner(fn):
304301 f"Trying to override a python impl for { dispatch_key_or_mode } on operator { self ._name } "
305302 )
306303 self .py_kernels [dispatch_key_or_mode ] = fn
307- self ._dispatch_cache .clear ()
308304 return fn
309305
310306 return inner
311307
312308 # This implements the pre-computation logic for the Python dispatcher.
313- def _get_dispatch (self , key ):
314- # This is only called upon a cache miss
315- assert key not in self ._dispatch_cache
309+ def __getattr__ (self , attr ):
310+ if len (attr ) == 0 or not attr [0 ].isupper ():
311+ raise AttributeError ()
312+
313+ try :
314+ key = torch ._C ._dispatch_key_parse (attr )
315+ except Exception as e :
316+ raise AttributeError ()
316317
317318 if key == torch ._C .DispatchKey .Python :
318319 if not self .python_key_mode_table :
319- self . _dispatch_cache [ key ] = key
320+ setattr ( self , attr , key )
320321 return key
321322
322323 def handler (* args , ** kwargs ):
@@ -335,12 +336,12 @@ def handler(*args, **kwargs):
335336 # TODO(voz): The idea behind this is that we do not yet support dispatch by key + mode, only key.
336337 return self .python_key_mode_table [curr_mode ](* args , ** kwargs )
337338
338- self . _dispatch_cache [ key ] = handler
339+ setattr ( self , attr , handler )
339340 return handler
340341
341342 key = resolve_key (self , key )
342343 r = self .py_kernels .get (key , key )
343- self . _dispatch_cache [ key ] = r
344+ setattr ( self , attr , r )
344345 return r
345346
346347 def name (self ):
0 commit comments