 import threading
 import itertools
 import warnings
+from typing import Any, Callable, TypeVar, Generic, Sequence, List, Optional
 
 import multiprocessing as python_multiprocessing
 import torch
 import torch.multiprocessing as multiprocessing
 from torch._utils import ExceptionWrapper
 from torch._six import queue, string_classes
 
-from . import IterableDataset, Sampler, SequentialSampler, RandomSampler, BatchSampler
+from . import IterableDataset, Sampler, SequentialSampler, RandomSampler, BatchSampler, Dataset
 from . import _utils
 
+T_co = TypeVar('T_co', covariant=True)
+T = TypeVar('T')
+_worker_init_fn_t = Callable[[int], None]
+
+# Ideally we would parameterize `DataLoader` by the return type of `collate_fn`, but there is currently no way to have that
+# type parameter set to a default value if the user doesn't pass in a custom 'collate_fn'.
+# See https://github.com/python/mypy/issues/3737.
+_collate_fn_t = Callable[[List[T]], Any]
 
-get_worker_info = _utils.worker.get_worker_info
 
 # This function used to be defined in this file. However, it was moved to
 # _utils/collate.py. Although it is rather hard to access this from user land
 # (one has to explicitly directly `import torch.utils.data.dataloader`), there
 # probably is user code out there using it. This aliasing maintains BC in this
 # aspect.
-default_collate = _utils.collate.default_collate
+default_collate: _collate_fn_t = _utils.collate.default_collate
 
+get_worker_info = _utils.worker.get_worker_info
 
 class _DatasetKind(object):
     Map = 0
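
For illustration, user-supplied callables are expected to match the two aliases introduced above. A minimal sketch, assuming nothing beyond the aliases themselves (the function names here are hypothetical, not part of the patch):

    import random
    from typing import Any, List
    import torch

    def stack_collate(batch: List[Any]) -> Any:
        # Conforms to _collate_fn_t: a list of samples in, one batch out.
        return torch.stack(batch, 0)

    def seed_worker(worker_id: int) -> None:
        # Conforms to _worker_init_fn_t: the worker index in, nothing out.
        random.seed(worker_id)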
@@ -57,7 +66,7 @@ def __iter__(self):
             yield None
 
 
-class DataLoader(object):
+class DataLoader(Generic[T_co]):
     r"""
     Data loader. Combines a dataset and a sampler, and provides an iterable over
     the given dataset.
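
Since `DataLoader` is now generic in its element type `T_co`, call sites can optionally annotate what a loader yields. A hedged sketch, assuming a map-style dataset of tensor tuples:

    from typing import Tuple
    import torch
    from torch.utils.data import DataLoader, TensorDataset

    ds = TensorDataset(torch.randn(10, 3))
    # The type parameter tracks the dataset's element type.
    loader: DataLoader[Tuple[torch.Tensor, ...]] = DataLoader(ds, batch_size=2)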
@@ -116,15 +125,24 @@ class DataLoader(object):
                 details on these two types of datasets and how
                 :class:`~torch.utils.data.IterableDataset` interacts with `Multi-process data loading`_.
     """
+    dataset: Dataset[T_co]
+    batch_size: Optional[int]
+    num_workers: int
+    pin_memory: bool
+    drop_last: bool
+    timeout: float
+    sampler: Sampler
 
     __initialized = False
 
-    def __init__(self, dataset, batch_size=1, shuffle=False, sampler=None,
-                 batch_sampler=None, num_workers=0, collate_fn=None,
-                 pin_memory=False, drop_last=False, timeout=0,
-                 worker_init_fn=None, multiprocessing_context=None,
-                 generator=None):
-        torch._C._log_api_usage_once("python.data_loader")
+    def __init__(self, dataset: Dataset[T_co], batch_size: Optional[int] = 1,
+                 shuffle: bool = False, sampler: Optional[Sampler[int]] = None,
+                 batch_sampler: Optional[Sampler[Sequence[int]]] = None,
+                 num_workers: int = 0, collate_fn: _collate_fn_t = None,
+                 pin_memory: bool = False, drop_last: bool = False,
+                 timeout: float = 0, worker_init_fn: _worker_init_fn_t = None,
+                 multiprocessing_context=None, generator=None):
+        torch._C._log_api_usage_once("python.data_loader")  # type: ignore
 
         if num_workers < 0:
             raise ValueError('num_workers option should be non-negative; '
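
One practical payoff of the annotated signature: a checker such as mypy can now reject mis-typed arguments before runtime. A hypothetical call site:

    import torch
    from torch.utils.data import DataLoader, TensorDataset

    ds = TensorDataset(torch.randn(8, 2))
    loader = DataLoader(ds, batch_size=None, num_workers=2)  # OK: batch_size is Optional[int]
    # DataLoader(ds, num_workers="2")  # mypy error: expected "int", got "str"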
@@ -146,7 +164,7 @@ def __init__(self, dataset, batch_size=1, shuffle=False, sampler=None,
         # after spending time fixing the custom sampler errors.
         if isinstance(dataset, IterableDataset):
             self._dataset_kind = _DatasetKind.Iterable
-            # NOTE [ Custom Samplers and `IterableDataset` ]
+            # NOTE [ Custom Samplers and IterableDataset ]
             #
             # `IterableDataset` does not support custom `batch_sampler` or
             # `sampler` since the key is irrelevant (unless we support
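
As the note says, index-based samplers are meaningless for iterable-style datasets, and the constructor rejects the combination. A small sketch of the failure mode (the Stream class is hypothetical):

    import torch
    from torch.utils.data import DataLoader, IterableDataset, SequentialSampler

    class Stream(IterableDataset):
        def __iter__(self):
            return iter(range(10))

    DataLoader(Stream())  # fine: no sampler involved
    # DataLoader(Stream(), shuffle=True)  # raises ValueError
    # DataLoader(Stream(), sampler=SequentialSampler(range(10)))  # raises ValueError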
@@ -212,7 +230,9 @@ def __init__(self, dataset, batch_size=1, shuffle=False, sampler=None,
                 sampler = _InfiniteConstantSampler()
         else:  # map-style
             if shuffle:
-                sampler = RandomSampler(dataset, generator=generator)
+                # Cannot statically verify that dataset is Sized
+                # Somewhat related: see NOTE [ Lack of Default `__len__` in Python Abstract Base Classes ]
+                sampler = RandomSampler(dataset, generator=generator)  # type: ignore
             else:
                 sampler = SequentialSampler(dataset)
 
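
For a map-style dataset this branch only picks a default; the two constructions below are interchangeable (a sketch, reusing the public samplers):

    import torch
    from torch.utils.data import DataLoader, RandomSampler, TensorDataset

    ds = TensorDataset(torch.arange(6))
    a = DataLoader(ds, shuffle=True)               # shorthand for the line below
    b = DataLoader(ds, sampler=RandomSampler(ds))  # explicit equivalent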
@@ -253,9 +273,10 @@ def multiprocessing_context(self, multiprocessing_context):
                 if multiprocessing_context not in valid_start_methods:
                     raise ValueError(
                         ('multiprocessing_context option '
-                         'should specify a valid start method in {}, but got '
-                         'multiprocessing_context={}').format(valid_start_methods, multiprocessing_context))
-                multiprocessing_context = multiprocessing.get_context(multiprocessing_context)
+                         'should specify a valid start method in {!r}, but got '
+                         'multiprocessing_context={!r}').format(valid_start_methods, multiprocessing_context))
+                # error: Argument 1 to "get_context" has incompatible type "Union[str, bytes]"; expected "str"  [arg-type]
+                multiprocessing_context = multiprocessing.get_context(multiprocessing_context)  # type: ignore
 
             if not isinstance(multiprocessing_context, python_multiprocessing.context.BaseContext):
                 raise TypeError(('multiprocessing_context option should be a valid context '
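
Both spellings the setter accepts, side by side; the string form is resolved through `multiprocessing.get_context` exactly as above (sketch):

    import torch
    import torch.multiprocessing as mp
    from torch.utils.data import DataLoader, TensorDataset

    ds = TensorDataset(torch.arange(6))
    a = DataLoader(ds, num_workers=2, multiprocessing_context="spawn")
    b = DataLoader(ds, num_workers=2, multiprocessing_context=mp.get_context("spawn"))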
@@ -275,7 +296,9 @@ def __setattr__(self, attr, val):
 
         super(DataLoader, self).__setattr__(attr, val)
 
-    def __iter__(self):
+    # We quote '_BaseDataLoaderIter' since it isn't defined yet and the definition can't be moved up
+    # since '_BaseDataLoaderIter' references 'DataLoader'.
+    def __iter__(self) -> '_BaseDataLoaderIter':
         if self.num_workers == 0:
             return _SingleProcessDataLoaderIter(self)
         else:
@@ -297,7 +320,7 @@ def _index_sampler(self):
         else:
             return self.sampler
 
-    def __len__(self):
+    def __len__(self) -> int:
         if self._dataset_kind == _DatasetKind.Iterable:
             # NOTE [ IterableDataset and __len__ ]
             #
@@ -313,7 +336,9 @@ def __len__(self):
             # To provide a further warning, we track if `__len__` was called on the
             # `DataLoader`, save the returned value in `self._len_called`, and warn
             # if the iterator ends up yielding more than this number of samples.
-            length = self._IterableDataset_len_called = len(self.dataset)
+
+            # Cannot statically verify that dataset is Sized
+            length = self._IterableDataset_len_called = len(self.dataset)  # type: ignore
             if self.batch_size is not None:
                 from math import ceil
                 if self.drop_last:
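
Concretely, the arithmetic these lines set up: with a dataset reporting length 10 and `batch_size=3`, `len(loader)` is `ceil(10 / 3) == 4`, and `10 // 3 == 3` once `drop_last=True` discards the partial batch:

    from math import ceil

    length, batch_size = 10, 3
    assert ceil(length / batch_size) == 4  # drop_last=False: partial final batch counts
    assert length // batch_size == 3       # drop_last=True: partial final batch dropped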
@@ -326,7 +351,7 @@ def __len__(self):
 
 
 class _BaseDataLoaderIter(object):
-    def __init__(self, loader):
+    def __init__(self, loader: DataLoader) -> None:
         self._dataset = loader.dataset
         self._dataset_kind = loader._dataset_kind
         self._IterableDataset_len_called = loader._IterableDataset_len_called
@@ -341,7 +366,7 @@ def __init__(self, loader):
         self._base_seed = torch.empty((), dtype=torch.int64).random_(generator=loader.generator).item()
         self._num_yielded = 0
 
-    def __iter__(self):
+    def __iter__(self) -> '_BaseDataLoaderIter':
         return self
 
     def _next_index(self):
@@ -350,7 +375,7 @@ def _next_index(self):
     def _next_data(self):
         raise NotImplementedError
 
-    def __next__(self):
+    def __next__(self) -> Any:
         data = self._next_data()
         self._num_yielded += 1
         if self._dataset_kind == _DatasetKind.Iterable and \
@@ -368,7 +393,7 @@ def __next__(self):
 
     next = __next__  # Python 2 compatibility
 
-    def __len__(self):
+    def __len__(self) -> int:
         return len(self._index_sampler)
 
     def __getstate__(self):
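
These annotations spell out the ordinary iterator protocol the class implements; driving it by hand looks like this (sketch):

    import torch
    from torch.utils.data import DataLoader, TensorDataset

    ds = TensorDataset(torch.arange(6))
    it = iter(DataLoader(ds, batch_size=2))  # some _BaseDataLoaderIter subclass
    first = next(it)                         # __next__ -> Any: one collated batch
    assert len(it) == 3                      # __len__ -> int: batches per epoch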
@@ -690,7 +715,8 @@ def __init__(self, loader):
 
         self._worker_init_fn = loader.worker_init_fn
         self._worker_queue_idx_cycle = itertools.cycle(range(self._num_workers))
-        self._worker_result_queue = multiprocessing_context.Queue()
+        # No certainty which module multiprocessing_context is
+        self._worker_result_queue = multiprocessing_context.Queue()  # type: ignore
         self._worker_pids_set = False
         self._shutdown = False
         self._send_idx = 0  # idx of the next task to be sent to workers
@@ -710,7 +736,8 @@ def __init__(self, loader):
         # (i.e., if kind != Iterable).
         self._workers_status = []
         for i in range(self._num_workers):
-            index_queue = multiprocessing_context.Queue()
+            # No certainty which module multiprocessing_context is
+            index_queue = multiprocessing_context.Queue()  # type: ignore
             # index_queue.cancel_join_thread()
             w = multiprocessing_context.Process(
                 target=_utils.worker._worker_loop,
@@ -732,7 +759,9 @@ def __init__(self, loader):
 
         if self._pin_memory:
             self._pin_memory_thread_done_event = threading.Event()
-            self._data_queue = queue.Queue()
+
+            # Queue is not type-annotated
+            self._data_queue = queue.Queue()  # type: ignore
             pin_memory_thread = threading.Thread(
                 target=_utils.pin_memory._pin_memory_loop,
                 args=(self._worker_result_queue, self._data_queue,
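
From the caller's side, the effect of the pin-memory thread wired up here is that batches arrive in page-locked memory, enabling asynchronous host-to-GPU copies. A sketch, assuming a CUDA device is available:

    import torch
    from torch.utils.data import DataLoader, TensorDataset

    ds = TensorDataset(torch.randn(8, 2))
    loader = DataLoader(ds, batch_size=2, num_workers=2, pin_memory=True)
    for (x,) in loader:
        assert x.is_pinned()           # placed in page-locked memory by the thread
        x = x.cuda(non_blocking=True)  # safe async copy from pinned memory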