We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 6338da9 commit e98af60Copy full SHA for e98af60
torch/utils/data/dataloader.py
@@ -6,7 +6,7 @@
6
import sys
7
import traceback
8
import threading
9
-from torch._six import string_classes
+from torch._six import string_classes, int_classes
10
11
12
if sys.version_info[0] == 2:
@@ -106,7 +106,7 @@ def default_collate(batch):
106
if elem.shape == (): # scalars
107
py_type = float if elem.dtype.name.startswith('float') else int
108
return numpy_type_map[elem.dtype.name](list(map(py_type, batch)))
109
- elif isinstance(batch[0], int):
+ elif isinstance(batch[0], int_classes):
110
return torch.LongTensor(batch)
111
elif isinstance(batch[0], float):
112
return torch.DoubleTensor(batch)
0 commit comments