16 changes: 10 additions & 6 deletions torch/utils/data/dataloader.py
@@ -217,7 +217,7 @@ class DataLoader(Generic[T_co]):
     timeout: float
     sampler: Union[Sampler, Iterable]
     pin_memory_device: str
-    prefetch_factor: int
+    prefetch_factor: Optional[int]
     _iterator : Optional['_BaseDataLoaderIter']
     __initialized = False
 
@@ -228,7 +228,7 @@ def __init__(self, dataset: Dataset[T_co], batch_size: Optional[int] = 1,
                  pin_memory: bool = False, drop_last: bool = False,
                  timeout: float = 0, worker_init_fn: Optional[_worker_init_fn_t] = None,
                  multiprocessing_context=None, generator=None,
-                 *, prefetch_factor: int = 2,
+                 *, prefetch_factor: Optional[int] = None,
                  persistent_workers: bool = False,
                  pin_memory_device: str = ""):
         torch._C._log_api_usage_once("python.data_loader")
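For readers skimming the hunk above, a minimal sketch of the call-site effect, assuming a build with this patch (the toy range dataset is illustrative, not part of this PR):

from torch.utils.data import DataLoader

# prefetch_factor now defaults to None instead of 2, so a single-process
# loader no longer advertises a prefetch setting it never uses.
single = DataLoader(range(10), batch_size=2)                  # num_workers=0, prefetch_factor stays None
workers = DataLoader(range(10), batch_size=2, num_workers=2)  # prefetch_factor resolved to 2 in __init__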
@@ -240,10 +240,13 @@ def __init__(self, dataset: Dataset[T_co], batch_size: Optional[int] = 1,
         if timeout < 0:
             raise ValueError('timeout option should be non-negative')
 
-        if num_workers == 0 and prefetch_factor != 2:
+        if num_workers == 0 and prefetch_factor is not None:
             raise ValueError('prefetch_factor option could only be specified in multiprocessing.'
-                             'let num_workers > 0 to enable multiprocessing.')
-        assert prefetch_factor > 0
+                             'let num_workers > 0 to enable multiprocessing, otherwise set prefetch_factor to None.')
+        elif num_workers > 0 and prefetch_factor is None:
+            prefetch_factor = 2
+        elif prefetch_factor is not None and prefetch_factor < 0:
+            raise ValueError('prefetch_factor option should be non-negative')
 
         if persistent_workers and num_workers == 0:
             raise ValueError('persistent_workers option needs num_workers > 0')
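Restated outside the class, the new resolution logic amounts to the following sketch (the helper name resolve_prefetch_factor is hypothetical, not part of the PR):

from typing import Optional

def resolve_prefetch_factor(num_workers: int, prefetch_factor: Optional[int]) -> Optional[int]:
    # prefetch_factor is only meaningful when worker processes exist.
    if num_workers == 0 and prefetch_factor is not None:
        raise ValueError('prefetch_factor requires num_workers > 0')
    if num_workers > 0 and prefetch_factor is None:
        return 2  # the old hard-coded default, now applied only when workers are in play
    if prefetch_factor is not None and prefetch_factor < 0:
        raise ValueError('prefetch_factor should be non-negative')
    return prefetch_factor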
@@ -581,7 +584,6 @@ def __init__(self, loader: DataLoader) -> None:
         ws, rank = _get_distributed_settings()
         self._world_size = ws
         self._rank = rank
-        self._prefetch_factor = loader.prefetch_factor
         # for other backends, pin_memory_device need to set. if not set
         # default behaviour is CUDA device. if pin_memory_device is selected
         # and pin_memory is not set, the default behaviour false.
@@ -991,6 +993,8 @@ class _MultiProcessingDataLoaderIter(_BaseDataLoaderIter):
     def __init__(self, loader):
         super(_MultiProcessingDataLoaderIter, self).__init__(loader)
 
+        self._prefetch_factor = loader.prefetch_factor
+
         assert self._num_workers > 0
         assert self._prefetch_factor > 0
 
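Since _prefetch_factor now lives only on the multiprocessing iterator, both loader configurations below should iterate cleanly; a quick smoke test, assuming a local PyTorch install carrying this patch:

import torch
from torch.utils.data import DataLoader, TensorDataset

if __name__ == "__main__":  # guard needed for worker processes on spawn platforms
    ds = TensorDataset(torch.arange(8))

    # Single-process path: the base iterator no longer reads prefetch_factor.
    print([batch[0].tolist() for batch in DataLoader(ds, batch_size=4)])

    # Multi-process path: _MultiProcessingDataLoaderIter picks up
    # loader.prefetch_factor (defaulted to 2 in __init__) and asserts it is positive.
    print([batch[0].tolist() for batch in DataLoader(ds, batch_size=4, num_workers=2)])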