2 files changed, +24 -1 lines changed
@@ -116,6 +116,15 @@ def __len__(self):
         return 10


+class RandomDatasetMock(object):
+
+    def __getitem__(self, index):
+        return torch.tensor([torch.rand(1).item(), random.uniform(0, 1)])
+
+    def __len__(self):
+        return 1000
+
+
 class TestCheckpoint(TestCase):

     # Test whether checkpoint is being triggered or not. For this, we check
@@ -233,6 +242,20 @@ def setUp(self):
         self.dataset = torch.randn(5, 3, 3, 2)
         self.batch_size = 3

+    def test_random_seed(self):
+        def run():
+            dataloader = torch.utils.data.DataLoader(RandomDatasetMock(),
+                                                     batch_size=2,
+                                                     num_workers=4,
+                                                     shuffle=True)
+            return next(iter(dataloader))
+
+        torch.manual_seed(2018)
+        x1 = run()
+        torch.manual_seed(2018)
+        x2 = run()
+        self.assertEqual(x1, x2)
+
     def test_single_keep(self):
         dataloader = torch.utils.data.DataLoader(self.dataset,
                                                  batch_size=self.batch_size,
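Note: the new test asserts that seeding the global torch RNG once before constructing the DataLoader makes multi-worker loading reproducible, even though each sample draws from both the torch RNG and Python's random module inside worker processes. A self-contained sketch of the same check, runnable outside the test suite, is below; the class name RandomDataset, the helper first_batch, and the seed value are illustrative only, and the final assert assumes the worker-seeding behavior that this test exercises.

import random
import torch
from torch.utils.data import DataLoader


class RandomDataset(object):
    # Each item mixes the torch RNG and Python's random module, so both
    # must be seeded consistently in the workers for two runs to match.
    def __getitem__(self, index):
        return torch.tensor([torch.rand(1).item(), random.uniform(0, 1)])

    def __len__(self):
        return 1000


def first_batch(seed):
    # Seed the global torch RNG, then read one shuffled batch from 4 workers.
    torch.manual_seed(seed)
    loader = DataLoader(RandomDataset(), batch_size=2, num_workers=4, shuffle=True)
    return next(iter(loader))


assert torch.equal(first_batch(2018), first_batch(2018))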
@@ -246,7 +246,7 @@ def __init__(self, loader):

         self.sample_iter = iter(self.batch_sampler)

-        base_seed = torch.LongTensor(1).random_()[0]
+        base_seed = torch.LongTensor(1).random_().item()

         if self.num_workers > 0:
             self.worker_init_fn = loader.worker_init_fn
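Note: the one-line change makes base_seed a plain Python integer. Indexing the freshly filled LongTensor with [0] can yield a zero-dimensional tensor rather than a number (depending on the version's indexing semantics), while .item() always returns an int, so the value can be safely combined with each worker's index when the workers seed their RNGs. A common companion pattern, sketched below under the assumption that each worker's torch seed is set to base_seed plus its worker id, is to reseed Python's random module per worker via worker_init_fn; seed_worker is an illustrative name, not part of this patch.

import random
import torch
from torch.utils.data import DataLoader


def seed_worker(worker_id):
    # torch.initial_seed() inside a worker returns the seed the DataLoader
    # assigned to that worker; reuse it for Python's random module so that
    # non-torch randomness is reproducible across runs as well.
    random.seed(torch.initial_seed() % (2 ** 32))


loader = DataLoader(list(range(10)), batch_size=5, num_workers=2,
                    worker_init_fn=seed_worker)
print(next(iter(loader)))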