Skip to content

Commit 709eee7

Browse files
committed
Enable torch.get_autocast_gpu_dtype in Dynamo tracing
ghstack-source-id: c343dcf Pull Request resolved: #103166
1 parent 84a8988 commit 709eee7

File tree

2 files changed

+6
-0
lines changed

2 files changed

+6
-0
lines changed

test/dynamo/test_functions.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -399,6 +399,11 @@ def test_get_default_dtype(x):
399399
else:
400400
return x - 1
401401

402+
@make_test
403+
def test_get_autocast_gpu_dtype(x):
404+
dtype = torch.get_autocast_gpu_dtype()
405+
return x.type(dtype)
406+
402407
@make_test
403408
def test_is_complex(x):
404409
if torch.is_complex(x):

torch/_dynamo/variables/torch.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@
6464
torch.device,
6565
torch.distributed.is_available,
6666
torch.finfo,
67+
torch.get_autocast_gpu_dtype,
6768
torch.get_default_dtype,
6869
torch.iinfo,
6970
torch.is_autocast_cache_enabled,

0 commit comments

Comments
 (0)