Skip to content

Commit cc9dc3f

Browse files
ssnlsoumith
authored andcommitted
add lock for SynchronizedSeedDataset; add additional os level close stderr for tests that launch failing process (#4463)
1 parent cc70a33 commit cc9dc3f

File tree

1 file changed

+8
-4
lines changed

1 file changed

+8
-4
lines changed

test/test_dataloader.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import math
22
import sys
3+
import os
34
import ctypes
45
import torch
56
import time
@@ -153,15 +154,16 @@ class SynchronizedSeedDataset(Dataset):
153154

154155
def __init__(self, size, num_workers):
155156
assert size >= num_workers
156-
self.count = multiprocessing.Value('i', 0)
157+
self.count = multiprocessing.Value('i', 0, lock=True)
157158
self.barrier = multiprocessing.Semaphore(0)
158159
self.num_workers = num_workers
159160
self.size = size
160161

161162
def __getitem__(self, idx):
162-
self.count.value += 1
163-
if self.count.value == self.num_workers:
164-
self.barrier.release()
163+
with self.count.get_lock():
164+
self.count.value += 1
165+
if self.count.value == self.num_workers:
166+
self.barrier.release()
165167
self.barrier.acquire()
166168
self.barrier.release()
167169
return torch.initial_seed()
@@ -249,6 +251,7 @@ def test_multiple_dataloaders(self):
249251
@unittest.skipIf(IS_WINDOWS, "TODO: need to fix this test case for Windows")
250252
def test_segfault(self):
251253
def _test_segfault():
254+
os.close(sys.stderr.fileno())
252255
sys.stderr.close()
253256
dataset = SegfaultDataset(10)
254257
dataloader = DataLoader(dataset, batch_size=2, num_workers=2)
@@ -266,6 +269,7 @@ def _test_segfault():
266269
@unittest.skipIf(IS_WINDOWS, "TODO: need to fix this test case for Windows")
267270
def test_timeout(self):
268271
def _test_timeout():
272+
os.close(sys.stderr.fileno())
269273
sys.stderr.close()
270274
dataset = SleepDataset(10, 10)
271275
dataloader = DataLoader(dataset, batch_size=2, num_workers=2, timeout=1)

0 commit comments

Comments
 (0)