@@ -217,7 +217,7 @@ class DataLoader(Generic[T_co]):
217217 timeout : float
218218 sampler : Union [Sampler , Iterable ]
219219 pin_memory_device : str
220- prefetch_factor : int
220+ prefetch_factor : Optional [ int ]
221221 _iterator : Optional ['_BaseDataLoaderIter' ]
222222 __initialized = False
223223
@@ -228,7 +228,7 @@ def __init__(self, dataset: Dataset[T_co], batch_size: Optional[int] = 1,
228228 pin_memory : bool = False , drop_last : bool = False ,
229229 timeout : float = 0 , worker_init_fn : Optional [_worker_init_fn_t ] = None ,
230230 multiprocessing_context = None , generator = None ,
231- * , prefetch_factor : int = 2 ,
231+ * , prefetch_factor : Optional [ int ] = None ,
232232 persistent_workers : bool = False ,
233233 pin_memory_device : str = "" ):
234234 torch ._C ._log_api_usage_once ("python.data_loader" )
@@ -240,10 +240,13 @@ def __init__(self, dataset: Dataset[T_co], batch_size: Optional[int] = 1,
240240 if timeout < 0 :
241241 raise ValueError ('timeout option should be non-negative' )
242242
243- if num_workers == 0 and prefetch_factor != 2 :
243+ if num_workers == 0 and prefetch_factor is not None :
244244 raise ValueError ('prefetch_factor option could only be specified in multiprocessing.'
245- 'let num_workers > 0 to enable multiprocessing.' )
246- assert prefetch_factor > 0
245+ 'let num_workers > 0 to enable multiprocessing, otherwise set prefetch_factor to None.' )
246+ elif num_workers > 0 and prefetch_factor is None :
247+ prefetch_factor = 2
248+ elif prefetch_factor is not None and prefetch_factor < 0 :
249+ raise ValueError ('prefetch_factor option should be non-negative' )
247250
248251 if persistent_workers and num_workers == 0 :
249252 raise ValueError ('persistent_workers option needs num_workers > 0' )
@@ -581,7 +584,6 @@ def __init__(self, loader: DataLoader) -> None:
581584 ws , rank = _get_distributed_settings ()
582585 self ._world_size = ws
583586 self ._rank = rank
584- self ._prefetch_factor = loader .prefetch_factor
585587 # for other backends, pin_memory_device need to set. if not set
586588 # default behaviour is CUDA device. if pin_memory_device is selected
587589 # and pin_memory is not set, the default behaviour false.
@@ -991,6 +993,8 @@ class _MultiProcessingDataLoaderIter(_BaseDataLoaderIter):
991993 def __init__ (self , loader ):
992994 super (_MultiProcessingDataLoaderIter , self ).__init__ (loader )
993995
996+ self ._prefetch_factor = loader .prefetch_factor
997+
994998 assert self ._num_workers > 0
995999 assert self ._prefetch_factor > 0
9961000
0 commit comments