Commit 57e05e8

blankstatic authored and pytorchmergebot committed
Issue 68576 prefetch factor (#88972)
Fixes #68576. This PR allows setting `prefetch_factor=None`, making it truly optional, as the documentation already describes. Pull Request resolved: #88972 Approved by: https://github.com/kit1980
1 parent 2b3ac87 commit 57e05e8
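
For context, a minimal usage sketch of the behaviour this change enables (the toy dataset and variable names are illustrative only, not part of the commit):

import torch
from torch.utils.data import DataLoader, TensorDataset

# Toy dataset, purely for illustration.
dataset = TensorDataset(torch.arange(100, dtype=torch.float32))

# After this change prefetch_factor defaults to None, so a single-process
# loader can be created without touching the option at all.
single_process_loader = DataLoader(dataset, batch_size=10, num_workers=0)

# With workers, leaving prefetch_factor as None falls back to the previous
# default of 2; an explicit value still works as before.
multi_process_loader = DataLoader(dataset, batch_size=10, num_workers=2,
                                  prefetch_factor=4)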

File tree

1 file changed (+10, -6)


torch/utils/data/dataloader.py

Lines changed: 10 additions & 6 deletions
@@ -217,7 +217,7 @@ class DataLoader(Generic[T_co]):
     timeout: float
     sampler: Union[Sampler, Iterable]
     pin_memory_device: str
-    prefetch_factor: int
+    prefetch_factor: Optional[int]
     _iterator : Optional['_BaseDataLoaderIter']
     __initialized = False
 
@@ -228,7 +228,7 @@ def __init__(self, dataset: Dataset[T_co], batch_size: Optional[int] = 1,
                  pin_memory: bool = False, drop_last: bool = False,
                  timeout: float = 0, worker_init_fn: Optional[_worker_init_fn_t] = None,
                  multiprocessing_context=None, generator=None,
-                 *, prefetch_factor: int = 2,
+                 *, prefetch_factor: Optional[int] = None,
                  persistent_workers: bool = False,
                  pin_memory_device: str = ""):
         torch._C._log_api_usage_once("python.data_loader")
@@ -240,10 +240,13 @@ def __init__(self, dataset: Dataset[T_co], batch_size: Optional[int] = 1,
         if timeout < 0:
             raise ValueError('timeout option should be non-negative')
 
-        if num_workers == 0 and prefetch_factor != 2:
+        if num_workers == 0 and prefetch_factor is not None:
             raise ValueError('prefetch_factor option could only be specified in multiprocessing.'
-                             'let num_workers > 0 to enable multiprocessing.')
-        assert prefetch_factor > 0
+                             'let num_workers > 0 to enable multiprocessing, otherwise set prefetch_factor to None.')
+        elif num_workers > 0 and prefetch_factor is None:
+            prefetch_factor = 2
+        elif prefetch_factor is not None and prefetch_factor < 0:
+            raise ValueError('prefetch_factor option should be non-negative')
 
         if persistent_workers and num_workers == 0:
             raise ValueError('persistent_workers option needs num_workers > 0')
@@ -581,7 +584,6 @@ def __init__(self, loader: DataLoader) -> None:
         ws, rank = _get_distributed_settings()
         self._world_size = ws
         self._rank = rank
-        self._prefetch_factor = loader.prefetch_factor
         # for other backends, pin_memory_device need to set. if not set
         # default behaviour is CUDA device. if pin_memory_device is selected
         # and pin_memory is not set, the default behaviour false.
@@ -991,6 +993,8 @@ class _MultiProcessingDataLoaderIter(_BaseDataLoaderIter):
     def __init__(self, loader):
         super(_MultiProcessingDataLoaderIter, self).__init__(loader)
 
+        self._prefetch_factor = loader.prefetch_factor
+
         assert self._num_workers > 0
         assert self._prefetch_factor > 0
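
As a quick sanity check of the new validation path (a sketch against the patched behaviour; the toy dataset is illustrative), passing an explicit prefetch_factor without workers now raises a ValueError that points at the fix:

import torch
from torch.utils.data import DataLoader, TensorDataset

dataset = TensorDataset(torch.zeros(4))

try:
    # num_workers=0 combined with an explicit prefetch_factor is rejected after this patch.
    DataLoader(dataset, num_workers=0, prefetch_factor=2)
except ValueError as err:
    print(err)  # message suggests num_workers > 0 or prefetch_factor=None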
