Skip to content

Commit f496ea3

Browse files
ssnlfacebook-github-bot
authored andcommitted
DataLoader: add error detection for worker_init_fn (#20150)
Summary: This is an attempt to isolate unrelated changes from #19228 for easier review. Pull Request resolved: #20150 Differential Revision: D15314891 Pulled By: ezyang fbshipit-source-id: 8c429747ba83ad5aca4cdd8f8086bcf65a326921
1 parent 163f0e1 commit f496ea3

File tree

2 files changed

+20
-2
lines changed

2 files changed

+20
-2
lines changed

test/test_dataloader.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -454,6 +454,10 @@ def kill_pid(pid):
454454
def init_fn(worker_id):
455455
torch.manual_seed(12345)
456456

457+
# used with test_error_in_init
458+
def error_worker_init_fn(_):
459+
raise RuntimeError("Error in worker_init_fn")
460+
457461

458462
class TestDataLoader(TestCase):
459463

@@ -509,6 +513,11 @@ def fn():
509513

510514
self.assertRaises(ValueError, fn)
511515

516+
def test_error_in_init(self):
517+
loader = DataLoader(self.dataset, num_workers=2, worker_init_fn=error_worker_init_fn)
518+
with self.assertRaisesRegex(RuntimeError, 'Error in worker_init_fn'):
519+
list(iter(loader))
520+
512521
def test_sequential(self):
513522
self._test_sequential(DataLoader(self.dataset))
514523

torch/utils/data/_utils/worker.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -75,8 +75,13 @@ def _worker_loop(dataset, index_queue, data_queue, done_event, collate_fn, seed,
7575

7676
data_queue.cancel_join_thread()
7777

78+
init_exception = None
79+
7880
if init_fn is not None:
79-
init_fn(worker_id)
81+
try:
82+
init_fn(worker_id)
83+
except Exception:
84+
init_exception = ExceptionWrapper(sys.exc_info())
8085

8186
watchdog = ManagerWatchdog()
8287

@@ -96,7 +101,11 @@ def _worker_loop(dataset, index_queue, data_queue, done_event, collate_fn, seed,
96101
continue
97102
idx, batch_indices = r
98103
try:
99-
samples = collate_fn([dataset[i] for i in batch_indices])
104+
if init_exception is not None:
105+
samples = init_exception
106+
init_exception = None
107+
else:
108+
samples = collate_fn([dataset[i] for i in batch_indices])
100109
except Exception:
101110
# It is important that we don't store exc_info in a variable,
102111
# see NOTE [ Python Traceback Reference Cycle Problem ]

0 commit comments

Comments
 (0)