|
11 | 11 | import torch.fx |
12 | 12 | import torch.nn |
13 | 13 | import torch.onnx.operators |
| 14 | +<<<<<<< HEAD |
14 | 15 | from torch._dynamo.utils import get_fake_value, get_real_value |
15 | 16 | from torch._dynamo.variables import SymNodeVariable |
16 | 17 | from torch._dynamo.variables.user_defined import ProcessGroupVariable |
| 18 | +======= |
| 19 | +from torch._dynamo.utils import get_fake_value, get_real_value, torch_np |
| 20 | +from torch._dynamo.variables import SymNodeVariable, UserFunctionVariable |
| 21 | +>>>>>>> fecd5f75277... Enable torch.nn.init._calculate_correct_fan in dynamo tracing |
17 | 22 | from torch._guards import GuardsCheckpointState, Source |
18 | 23 | from torch.utils import _pytree as pytree |
19 | 24 |
|
|
73 | 78 | torch.is_autocast_cache_enabled, |
74 | 79 | torch.is_autocast_cpu_enabled, |
75 | 80 | torch.is_autocast_enabled, |
| 81 | + torch.is_complex, |
76 | 82 | torch.is_floating_point, |
77 | 83 | torch.nn.functional._Reduction.get_enum, |
78 | 84 | ] |
@@ -528,6 +534,10 @@ def get_state_from_generator(): |
528 | 534 | assert len(args) == 1, "Expected one arg (pg)" |
529 | 535 | assert isinstance(args[0], ProcessGroupVariable) |
530 | 536 | return ConstantVariable(self.value(args[0].as_python_constant())) |
| 537 | + elif self.value == torch.nn.init._calculate_correct_fan: |
| 538 | + return UserFunctionVariable( |
| 539 | + torch.nn.init._calculate_correct_fan, **options |
| 540 | + ).call_function(tx, args, {}) |
531 | 541 | else: |
532 | 542 | any_symints_or_symfloats = any(isinstance(x, SymNodeVariable) for x in args) |
533 | 543 | all_ints_or_floats = all( |
|
0 commit comments