Skip to content

Commit e98af60

Browse files
alykhantejanisoumith
authored andcommitted
Accept longs in default_collate for dataloader in python 2 (#4001)
1 parent 6338da9 commit e98af60

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

torch/utils/data/dataloader.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import sys
77
import traceback
88
import threading
9-
from torch._six import string_classes
9+
from torch._six import string_classes, int_classes
1010

1111

1212
if sys.version_info[0] == 2:
@@ -106,7 +106,7 @@ def default_collate(batch):
106106
if elem.shape == (): # scalars
107107
py_type = float if elem.dtype.name.startswith('float') else int
108108
return numpy_type_map[elem.dtype.name](list(map(py_type, batch)))
109-
elif isinstance(batch[0], int):
109+
elif isinstance(batch[0], int_classes):
110110
return torch.LongTensor(batch)
111111
elif isinstance(batch[0], float):
112112
return torch.DoubleTensor(batch)

0 commit comments

Comments
 (0)