Skip to content

Commit 146b951

Browse files
thuyensoumith
authored and committed
Fix seeding random module in DataLoader (#7886)
* Fix seeding of the `random` module in DataLoader workers
* Make the base seed an int
* Follow the 0.4 idiom (`.item()` instead of indexing)
* Add a test for random seeding
1 parent 65f8465 commit 146b951

File tree

2 files changed

+24
-1
lines changed

2 files changed

+24
-1
lines changed

test/test_utils.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,15 @@ def __len__(self):
116116
return 10
117117

118118

119+
class RandomDatasetMock(object):
    """Dataset stub whose items combine a torch RNG draw with a Python
    ``random`` module draw, so a test can detect whether both generators
    are seeded consistently across DataLoader workers.
    """

    def __getitem__(self, index):
        # Draw from torch's generator first, then the stdlib generator,
        # and pack both samples into a single 1-D tensor of length 2.
        torch_draw = torch.rand(1).item()
        stdlib_draw = random.uniform(0, 1)
        return torch.tensor([torch_draw, stdlib_draw])

    def __len__(self):
        # Fixed size; large enough for shuffled multi-worker batching.
        return 1000
126+
127+
119128
class TestCheckpoint(TestCase):
120129

121130
# Test whether checkpoint is being triggered or not. For this, we check
@@ -233,6 +242,20 @@ def setUp(self):
233242
self.dataset = torch.randn(5, 3, 3, 2)
234243
self.batch_size = 3
235244

245+
def test_random_seed(self):
    """After ``torch.manual_seed``, a shuffled multi-worker DataLoader
    must yield identical first batches across runs — i.e. worker
    processes derive their RNG state (torch and stdlib ``random``)
    from the base seed deterministically.
    """
    def fetch_first_batch():
        # Fresh loader each run so worker seeding happens from scratch.
        loader = torch.utils.data.DataLoader(RandomDatasetMock(),
                                             batch_size=2,
                                             num_workers=4,
                                             shuffle=True)
        return next(iter(loader))

    torch.manual_seed(2018)
    first = fetch_first_batch()
    torch.manual_seed(2018)
    second = fetch_first_batch()
    self.assertEqual(first, second)
258+
236259
def test_single_keep(self):
237260
dataloader = torch.utils.data.DataLoader(self.dataset,
238261
batch_size=self.batch_size,

torch/utils/data/dataloader.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -246,7 +246,7 @@ def __init__(self, loader):
246246

247247
self.sample_iter = iter(self.batch_sampler)
248248

249-
base_seed = torch.LongTensor(1).random_()[0]
249+
base_seed = torch.LongTensor(1).random_().item()
250250

251251
if self.num_workers > 0:
252252
self.worker_init_fn = loader.worker_init_fn

0 commit comments

Comments
 (0)