Skip to content

Commit c64ae60

Browse files
janselpytorchmergebot
authored andcommitted
[dynamo] Fix support for classmethod(property(...)) (#134968)
Fixes #134451 Pull Request resolved: #134968 Approved by: https://github.com/yanboliang
1 parent 7f5abb4 commit c64ae60

File tree

3 files changed

+84
-1
lines changed

3 files changed

+84
-1
lines changed

test/dynamo/test_repros.py

Lines changed: 78 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,13 +14,14 @@
1414
import itertools
1515
import os
1616
import random
17+
import sys
1718
import unittest
1819
import warnings
1920
import weakref
2021
from abc import ABC
2122
from collections import namedtuple
2223
from copy import deepcopy
23-
from enum import Enum
24+
from enum import Enum, IntEnum
2425
from functools import wraps
2526
from typing import Any, Dict, Iterator, List, Tuple
2627
from unittest import mock
@@ -4546,6 +4547,82 @@ def f(*args):
45464547
f(*args)
45474548
self.assertEqual(num_compiles, 1)
45484549

4550+
@unittest.skipIf(sys.version_info < (3, 9), "requires python 3.9+")
4551+
def test_issue134451(self):
4552+
class BoundingBox2DIndex(IntEnum):
4553+
_X = 0
4554+
_Y = 1
4555+
_HEADING = 2
4556+
_LENGTH = 3
4557+
_WIDTH = 4
4558+
4559+
@classmethod
4560+
def size(cls):
4561+
return 5
4562+
4563+
@classmethod
4564+
@property
4565+
def X(cls):
4566+
return cls._X
4567+
4568+
@classmethod
4569+
@property
4570+
def Y(cls):
4571+
return cls._Y
4572+
4573+
@classmethod
4574+
@property
4575+
def HEADING(cls):
4576+
return cls._HEADING
4577+
4578+
@classmethod
4579+
@property
4580+
def LENGTH(cls):
4581+
return cls._LENGTH
4582+
4583+
@classmethod
4584+
@property
4585+
def WIDTH(cls):
4586+
return cls._WIDTH
4587+
4588+
@classmethod
4589+
@property
4590+
def POINT(cls):
4591+
# assumes X, Y have subsequent indices
4592+
return slice(cls._X, cls._Y + 1)
4593+
4594+
@classmethod
4595+
@property
4596+
def STATE_SE2(cls):
4597+
# assumes X, Y, HEADING have subsequent indices
4598+
return slice(cls._X, cls._HEADING + 1)
4599+
4600+
class SimpleModel(nn.Module):
4601+
def __init__(self):
4602+
super().__init__()
4603+
self._mlp_states = nn.Sequential(
4604+
nn.Linear(10, 20),
4605+
nn.ReLU(),
4606+
nn.Linear(20, BoundingBox2DIndex.size()),
4607+
)
4608+
4609+
def forward(self, x):
4610+
agent_states = self._mlp_states(x)
4611+
agent_states[..., BoundingBox2DIndex.POINT] = (
4612+
agent_states[..., BoundingBox2DIndex.POINT].tanh() * 32
4613+
)
4614+
agent_states[..., BoundingBox2DIndex.HEADING] = (
4615+
agent_states[..., BoundingBox2DIndex.HEADING].tanh() * torch.pi
4616+
)
4617+
return agent_states
4618+
4619+
model = SimpleModel().eval()
4620+
input_tensor = torch.randn(1, 10, dtype=torch.float32)
4621+
opt = torch.compile(model.eval(), backend="eager", fullgraph=True)
4622+
actual = opt(input_tensor)
4623+
expected = model(input_tensor)
4624+
self.assertEqual(actual, expected)
4625+
45494626
def test_invalid_seq_unpack(self):
45504627
def myfn(arg):
45514628
(a, b) = arg

torch/_dynamo/variables/constant.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -222,6 +222,8 @@ def create(cls, cls_type, value_vt, options):
222222
unimplemented("Enum variable is constructed with non constant values")
223223

224224
def as_proxy(self):
225+
if isinstance(self.value, int):
226+
return int(self.value) # convert IntEnum to a normal int
225227
return self.value
226228

227229
def __str__(self) -> str:

torch/_dynamo/variables/user_defined.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -193,6 +193,10 @@ def var_getattr(self, tx: "InstructionTranslator", name: str) -> "VariableTracke
193193
else:
194194
return SourcelessBuilder.create(tx, func)
195195
elif isinstance(obj, classmethod):
196+
if isinstance(obj.__func__, property):
197+
return variables.UserFunctionVariable(obj.__func__.fget).call_function(
198+
tx, [self], {}
199+
)
196200
return variables.UserMethodVariable(obj.__func__, self, source=source)
197201
elif isinstance(obj, types.ClassMethodDescriptorType):
198202
# e.g.: inspect.getattr_static(dict, "fromkeys")

0 commit comments

Comments
 (0)