Skip to content

Commit 8c69eac

Browse files
zou3519apaszke
authored andcommitted
Initialize cuda before setting cuda tensor types as default
1 parent 154038e commit 8c69eac

File tree

1 file changed

+5
-0
lines changed

1 file changed

+5
-0
lines changed

torch/__init__.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,11 @@ def set_default_tensor_type(t):
129129
global Storage
130130
Tensor = _import_dotted_name(t)
131131
Storage = _import_dotted_name(t.replace('Tensor', 'Storage'))
132+
133+
if 'cuda' in t:
134+
import torch.cuda
135+
torch.cuda.init()
136+
132137
_C._set_default_tensor_type(Tensor)
133138

134139

0 commit comments

Comments
 (0)