
Commit d6b83a7

blankstatic authored and pytorchmergebot committed
set _prefetch_factor at _MultiProcessingDataLoaderIter and mypy fixes
1 parent 5273e4f commit d6b83a7


1 file changed: +3 −2 lines


torch/utils/data/dataloader.py

Lines changed: 3 additions & 2 deletions
@@ -245,7 +245,7 @@ def __init__(self, dataset: Dataset[T_co], batch_size: Optional[int] = 1,
                              'let num_workers > 0 to enable multiprocessing, otherwise set prefetch_factor to None.')
         elif num_workers > 0 and prefetch_factor is None:
             prefetch_factor = 2
-        elif prefetch_factor < 0:
+        elif prefetch_factor is not None and prefetch_factor < 0:
             raise ValueError('prefetch_factor option should be non-negative')
 
         if persistent_workers and num_workers == 0:
@@ -584,7 +584,6 @@ def __init__(self, loader: DataLoader) -> None:
         ws, rank = _get_distributed_settings()
         self._world_size = ws
         self._rank = rank
-        self._prefetch_factor = loader.prefetch_factor
         # for other backends, pin_memory_device need to set. if not set
         # default behaviour is CUDA device. if pin_memory_device is selected
         # and pin_memory is not set, the default behaviour false.
@@ -994,6 +993,8 @@ class _MultiProcessingDataLoaderIter(_BaseDataLoaderIter):
     def __init__(self, loader):
         super(_MultiProcessingDataLoaderIter, self).__init__(loader)
 
+        self._prefetch_factor = loader.prefetch_factor
+
         assert self._num_workers > 0
         assert self._prefetch_factor > 0
 
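Taken together, the change moves the `_prefetch_factor` assignment off `_BaseDataLoaderIter` and into `_MultiProcessingDataLoaderIter`, where `num_workers > 0` guarantees that `loader.prefetch_factor` has already been resolved to a concrete value, and the added `is not None` guard keeps the constructor-level check valid when prefetch_factor is left unset; that appears to be the mypy fix the commit message refers to. Below is a minimal sketch of the DataLoader constructor behaviour this validation enforces, assuming a PyTorch build that includes this commit; the dataset and sizes are placeholders.

import torch
from torch.utils.data import DataLoader, TensorDataset

dataset = TensorDataset(torch.arange(16, dtype=torch.float32))

# num_workers > 0 with prefetch_factor left as None: the constructor falls
# back to the default of 2 prefetched batches per worker.
loader = DataLoader(dataset, batch_size=4, num_workers=2)

# A prefetch_factor with num_workers == 0 is rejected, since prefetching
# only applies to multiprocessing workers.
try:
    DataLoader(dataset, batch_size=4, num_workers=0, prefetch_factor=2)
except ValueError as err:
    print(err)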