Skip to content

Commit f6488d8

Browse files
anijain2305pytorchmergebot
authored andcommitted
[dynamo][user-defined] Remove __getattribute__ checks and add getsetdescriptor (#144173)
Pull Request resolved: #144173 Approved by: https://github.com/jansel
1 parent b01556b commit f6488d8

File tree

2 files changed

+26
-6
lines changed

2 files changed

+26
-6
lines changed

test/dynamo/test_misc.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11837,6 +11837,21 @@ def fn(x, f):
1183711837
opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
1183811838
self.assertEqual(fn(x, get_foo()), opt_fn(x, get_foo()))
1183911839

11840+
def test_dunder_weakref(self):
11841+
class Foo:
11842+
pass
11843+
11844+
def fn(x):
11845+
foo = Foo()
11846+
# tests isgetsetdescriptor
11847+
if foo.__weakref__:
11848+
return torch.cos(x)
11849+
return torch.sin(x)
11850+
11851+
opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
11852+
x = torch.randn(4)
11853+
self.assertEqual(fn(x), opt_fn(x))
11854+
1184011855

1184111856
class TestTracer(JitTestCase):
1184211857
def test_jit_save(self):

torch/_dynamo/variables/user_defined.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -936,10 +936,6 @@ def call_function(
936936

937937
return super().call_function(tx, args, kwargs)
938938

939-
def _check_for_getattribute(self):
940-
if object_has_getattribute(self.value):
941-
unimplemented("UserDefinedObjectVariable with custom __getattribute__")
942-
943939
def _check_for_getattr(self):
944940
return get_custom_getattr(self.value)
945941

@@ -961,7 +957,7 @@ def _getattr_static(self, name):
961957

962958
# In some cases, we have to do dynamic lookup because getattr_static is not enough. For example, threading.local
963959
# has side-effect free __getattribute__ and the attribute is not visible without a dynamic lookup.
964-
if (
960+
if not object_has_getattribute(self.value) and (
965961
subobj is NO_SUCH_SUBOBJ # e.g., threading.local
966962
or isinstance(
967963
subobj, _collections._tuplegetter
@@ -970,15 +966,24 @@ def _getattr_static(self, name):
970966
inspect.ismemberdescriptor(subobj) and name in self.value.__slots__
971967
) # handle memberdecriptor and slots
972968
or self._is_c_defined_property(subobj)
969+
or inspect.isgetsetdescriptor(
970+
subobj
971+
) # handle getsetdescriptor like __dict__
973972
):
974973
# Call __getattribute__, we have already checked that this is not overridden and side-effect free. We don't
975974
# want to call getattr because it can be user-overridden.
976975
subobj = self.value.__getattribute__(name)
976+
elif object_has_getattribute(self.value) and subobj is NO_SUCH_SUBOBJ:
977+
# If the object has an overridden getattribute method, Dynamo has
978+
# already tried tracing it, and encountered an AttributeError. We
979+
# call getattr_static only when the __getattribute__ tracing fails
980+
# (check var_getattr impl). So, it is safe here to raise the
981+
# AttributeError.
982+
raise AttributeError
977983

978984
return subobj
979985

980986
def has_key_in_generic_dict(self, tx: "InstructionTranslator", key):
981-
self._check_for_getattribute()
982987
if tx.output.side_effects.has_pending_mutation_of_attr(self, key):
983988
mutated_attr = tx.output.side_effects.load_attr(self, key, deleted_ok=True)
984989
return not isinstance(mutated_attr, variables.DeletedVariable)

0 commit comments

Comments
 (0)