Skip to content

Commit 1af8ff4

Browse files
committed
Enable torch.is_complex in Dynamo tracing (#103154)
Pull Request resolved: #103154 Approved by: https://github.com/yanboliang ghstack-source-id: 6b16fc3
1 parent 2e8d2a2 commit 1af8ff4

File tree

2 files changed

+27
-0
lines changed

2 files changed

+27
-0
lines changed

test/dynamo/test_functions.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -436,6 +436,23 @@ def test_get_default_dtype(x):
436436
else:
437437
return x - 1
438438

439+
@make_test
440+
def test_get_autocast_gpu_dtype(x):
441+
dtype = torch.get_autocast_gpu_dtype()
442+
return x.type(dtype)
443+
444+
@make_test
445+
def test_get_calculate_correct_fan(x):
446+
fan_in = torch.nn.init._calculate_correct_fan(x, "fan_in")
447+
return x + fan_in
448+
449+
@make_test
450+
def test_is_complex(x):
451+
if torch.is_complex(x):
452+
return x + 1
453+
else:
454+
return x - 1
455+
439456
@make_test
440457
def test_device(x):
441458
if not x.is_cuda:

torch/_dynamo/variables/torch.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,14 @@
1111
import torch.fx
1212
import torch.nn
1313
import torch.onnx.operators
14+
<<<<<<< HEAD
1415
from torch._dynamo.utils import get_fake_value, get_real_value
1516
from torch._dynamo.variables import SymNodeVariable
1617
from torch._dynamo.variables.user_defined import ProcessGroupVariable
18+
=======
19+
from torch._dynamo.utils import get_fake_value, get_real_value, torch_np
20+
from torch._dynamo.variables import SymNodeVariable, UserFunctionVariable
21+
>>>>>>> fecd5f75277... Enable torch.nn.init._calculate_correct_fan in dynamo tracing
1722
from torch._guards import GuardsCheckpointState, Source
1823
from torch.utils import _pytree as pytree
1924

@@ -73,6 +78,7 @@
7378
torch.is_autocast_cache_enabled,
7479
torch.is_autocast_cpu_enabled,
7580
torch.is_autocast_enabled,
81+
torch.is_complex,
7682
torch.is_floating_point,
7783
torch.nn.functional._Reduction.get_enum,
7884
]
@@ -528,6 +534,10 @@ def get_state_from_generator():
528534
assert len(args) == 1, "Expected one arg (pg)"
529535
assert isinstance(args[0], ProcessGroupVariable)
530536
return ConstantVariable(self.value(args[0].as_python_constant()))
537+
elif self.value == torch.nn.init._calculate_correct_fan:
538+
return UserFunctionVariable(
539+
torch.nn.init._calculate_correct_fan, **options
540+
).call_function(tx, args, {})
531541
else:
532542
any_symints_or_symfloats = any(isinstance(x, SymNodeVariable) for x in args)
533543
all_ints_or_floats = all(

0 commit comments

Comments
 (0)