Skip to content

Commit 0e09511

Browse files
Baranowskifacebook-github-bot
authored andcommitted
type annotations for dataloader, dataset, sampler (#39392)
Summary: Fixes #38913 Pull Request resolved: #39392 Reviewed By: anjali411 Differential Revision: D22102489 Pulled By: zou3519 fbshipit-source-id: acb68d9521145f0b047214d62b5bdc5a0d1b9be4
1 parent a6b703c commit 0e09511

File tree

8 files changed

+133
-182
lines changed

8 files changed

+133
-182
lines changed

mypy.ini

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -302,9 +302,6 @@ ignore_errors = True
302302
[mypy-torch.utils.data._utils.worker]
303303
ignore_errors = True
304304

305-
[mypy-torch.utils.data.dataset]
306-
ignore_errors = True
307-
308305
[mypy-torch.utils.data.distributed]
309306
ignore_errors = True
310307

test/test_dataloader.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -821,6 +821,16 @@ def test_error_in_init(self):
821821
with self.assertRaisesRegex(RuntimeError, 'Error in worker_init_fn'):
822822
list(iter(loader))
823823

824+
def test_typing(self):
825+
from typing import List
826+
# Make sure there is no TypeError
827+
828+
class SomeDatasetClass(Dataset[List[torch.Tensor]]):
829+
pass
830+
831+
def _create_dataloader(is_train: bool) -> DataLoader[List[torch.Tensor]]:
832+
pass
833+
824834
@unittest.skipIf(IS_SANDCASTLE, "subprocess doesn't work in FB internal CI")
825835
@unittest.skipIf(IS_WINDOWS, "No 'resource' module on Windows")
826836
def test_fd_limit_exceeded(self):
@@ -2019,5 +2029,6 @@ def test_set_affinity_in_worker_init(self):
20192029
self.assertEqual(sample, [2])
20202030

20212031

2032+
20222033
if __name__ == '__main__':
20232034
run_tests()

torch/utils/data/dataloader.py

Lines changed: 54 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -8,26 +8,35 @@
88
import threading
99
import itertools
1010
import warnings
11+
from typing import Any, Callable, TypeVar, Generic, Sequence, List, Optional
1112

1213
import multiprocessing as python_multiprocessing
1314
import torch
1415
import torch.multiprocessing as multiprocessing
1516
from torch._utils import ExceptionWrapper
1617
from torch._six import queue, string_classes
1718

18-
from . import IterableDataset, Sampler, SequentialSampler, RandomSampler, BatchSampler
19+
from . import IterableDataset, Sampler, SequentialSampler, RandomSampler, BatchSampler, Dataset
1920
from . import _utils
2021

22+
T_co = TypeVar('T_co', covariant=True)
23+
T = TypeVar('T')
24+
_worker_init_fn_t = Callable[[int], None]
25+
26+
# Ideally we would parameterize `DataLoader` by the return type of `collate_fn`, but there is currently no way to have that
27+
# type parameter set to a default value if the user doesn't pass in a custom 'collate_fn'.
28+
# See https://github.com/python/mypy/issues/3737.
29+
_collate_fn_t = Callable[[List[T]], Any]
2130

22-
get_worker_info = _utils.worker.get_worker_info
2331

2432
# This function used to be defined in this file. However, it was moved to
2533
# _utils/collate.py. Although it is rather hard to access this from user land
2634
# (one has to explicitly directly `import torch.utils.data.dataloader`), there
2735
# probably is user code out there using it. This aliasing maintains BC in this
2836
# aspect.
29-
default_collate = _utils.collate.default_collate
37+
default_collate: _collate_fn_t = _utils.collate.default_collate
3038

39+
get_worker_info = _utils.worker.get_worker_info
3140

3241
class _DatasetKind(object):
3342
Map = 0
@@ -57,7 +66,7 @@ def __iter__(self):
5766
yield None
5867

5968

60-
class DataLoader(object):
69+
class DataLoader(Generic[T_co]):
6170
r"""
6271
Data loader. Combines a dataset and a sampler, and provides an iterable over
6372
the given dataset.
@@ -116,15 +125,24 @@ class DataLoader(object):
116125
details on these two types of datasets and how
117126
:class:`~torch.utils.data.IterableDataset` interacts with `Multi-process data loading`_.
118127
"""
128+
dataset: Dataset[T_co]
129+
batch_size: Optional[int]
130+
num_workers: int
131+
pin_memory: bool
132+
drop_last: bool
133+
timeout: float
134+
sampler: Sampler
119135

120136
__initialized = False
121137

122-
def __init__(self, dataset, batch_size=1, shuffle=False, sampler=None,
123-
batch_sampler=None, num_workers=0, collate_fn=None,
124-
pin_memory=False, drop_last=False, timeout=0,
125-
worker_init_fn=None, multiprocessing_context=None,
126-
generator=None):
127-
torch._C._log_api_usage_once("python.data_loader")
138+
def __init__(self, dataset: Dataset[T_co], batch_size: Optional[int] = 1,
139+
shuffle: bool = False, sampler: Optional[Sampler[int]] = None,
140+
batch_sampler: Optional[Sampler[Sequence[int]]] = None,
141+
num_workers: int = 0, collate_fn: _collate_fn_t = None,
142+
pin_memory: bool = False, drop_last: bool = False,
143+
timeout: float = 0, worker_init_fn: _worker_init_fn_t = None,
144+
multiprocessing_context=None, generator=None):
145+
torch._C._log_api_usage_once("python.data_loader") # type: ignore
128146

129147
if num_workers < 0:
130148
raise ValueError('num_workers option should be non-negative; '
@@ -146,7 +164,7 @@ def __init__(self, dataset, batch_size=1, shuffle=False, sampler=None,
146164
# after spending time fixing the custom sampler errors.
147165
if isinstance(dataset, IterableDataset):
148166
self._dataset_kind = _DatasetKind.Iterable
149-
# NOTE [ Custom Samplers and `IterableDataset` ]
167+
# NOTE [ Custom Samplers and IterableDataset ]
150168
#
151169
# `IterableDataset` does not support custom `batch_sampler` or
152170
# `sampler` since the key is irrelevant (unless we support
@@ -212,7 +230,9 @@ def __init__(self, dataset, batch_size=1, shuffle=False, sampler=None,
212230
sampler = _InfiniteConstantSampler()
213231
else: # map-style
214232
if shuffle:
215-
sampler = RandomSampler(dataset, generator=generator)
233+
# Cannot statically verify that dataset is Sized
234+
# Somewhat related: see NOTE [ Lack of Default `__len__` in Python Abstract Base Classes ]
235+
sampler = RandomSampler(dataset, generator=generator) # type: ignore
216236
else:
217237
sampler = SequentialSampler(dataset)
218238

@@ -253,9 +273,10 @@ def multiprocessing_context(self, multiprocessing_context):
253273
if multiprocessing_context not in valid_start_methods:
254274
raise ValueError(
255275
('multiprocessing_context option '
256-
'should specify a valid start method in {}, but got '
257-
'multiprocessing_context={}').format(valid_start_methods, multiprocessing_context))
258-
multiprocessing_context = multiprocessing.get_context(multiprocessing_context)
276+
'should specify a valid start method in {!r}, but got '
277+
'multiprocessing_context={!r}').format(valid_start_methods, multiprocessing_context))
278+
# error: Argument 1 to "get_context" has incompatible type "Union[str, bytes]"; expected "str" [arg-type]
279+
multiprocessing_context = multiprocessing.get_context(multiprocessing_context) # type: ignore
259280

260281
if not isinstance(multiprocessing_context, python_multiprocessing.context.BaseContext):
261282
raise TypeError(('multiprocessing_context option should be a valid context '
@@ -275,7 +296,9 @@ def __setattr__(self, attr, val):
275296

276297
super(DataLoader, self).__setattr__(attr, val)
277298

278-
def __iter__(self):
299+
# We quote '_BaseDataLoaderIter' since it isn't defined yet and the definition can't be moved up
300+
# since '_BaseDataLoaderIter' references 'DataLoader'.
301+
def __iter__(self) -> '_BaseDataLoaderIter':
279302
if self.num_workers == 0:
280303
return _SingleProcessDataLoaderIter(self)
281304
else:
@@ -297,7 +320,7 @@ def _index_sampler(self):
297320
else:
298321
return self.sampler
299322

300-
def __len__(self):
323+
def __len__(self) -> int:
301324
if self._dataset_kind == _DatasetKind.Iterable:
302325
# NOTE [ IterableDataset and __len__ ]
303326
#
@@ -313,7 +336,9 @@ def __len__(self):
313336
# To provide a further warning, we track if `__len__` was called on the
314337
# `DataLoader`, save the returned value in `self._len_called`, and warn
315338
# if the iterator ends up yielding more than this number of samples.
316-
length = self._IterableDataset_len_called = len(self.dataset)
339+
340+
# Cannot statically verify that dataset is Sized
341+
length = self._IterableDataset_len_called = len(self.dataset) # type: ignore
317342
if self.batch_size is not None:
318343
from math import ceil
319344
if self.drop_last:
@@ -326,7 +351,7 @@ def __len__(self):
326351

327352

328353
class _BaseDataLoaderIter(object):
329-
def __init__(self, loader):
354+
def __init__(self, loader: DataLoader) -> None:
330355
self._dataset = loader.dataset
331356
self._dataset_kind = loader._dataset_kind
332357
self._IterableDataset_len_called = loader._IterableDataset_len_called
@@ -341,7 +366,7 @@ def __init__(self, loader):
341366
self._base_seed = torch.empty((), dtype=torch.int64).random_(generator=loader.generator).item()
342367
self._num_yielded = 0
343368

344-
def __iter__(self):
369+
def __iter__(self) -> '_BaseDataLoaderIter':
345370
return self
346371

347372
def _next_index(self):
@@ -350,7 +375,7 @@ def _next_index(self):
350375
def _next_data(self):
351376
raise NotImplementedError
352377

353-
def __next__(self):
378+
def __next__(self) -> Any:
354379
data = self._next_data()
355380
self._num_yielded += 1
356381
if self._dataset_kind == _DatasetKind.Iterable and \
@@ -368,7 +393,7 @@ def __next__(self):
368393

369394
next = __next__ # Python 2 compatibility
370395

371-
def __len__(self):
396+
def __len__(self) -> int:
372397
return len(self._index_sampler)
373398

374399
def __getstate__(self):
@@ -690,7 +715,8 @@ def __init__(self, loader):
690715

691716
self._worker_init_fn = loader.worker_init_fn
692717
self._worker_queue_idx_cycle = itertools.cycle(range(self._num_workers))
693-
self._worker_result_queue = multiprocessing_context.Queue()
718+
# No certainty which module multiprocessing_context is
719+
self._worker_result_queue = multiprocessing_context.Queue() # type: ignore
694720
self._worker_pids_set = False
695721
self._shutdown = False
696722
self._send_idx = 0 # idx of the next task to be sent to workers
@@ -710,7 +736,8 @@ def __init__(self, loader):
710736
# (i.e., if kind != Iterable).
711737
self._workers_status = []
712738
for i in range(self._num_workers):
713-
index_queue = multiprocessing_context.Queue()
739+
# No certainty which module multiprocessing_context is
740+
index_queue = multiprocessing_context.Queue() # type: ignore
714741
# index_queue.cancel_join_thread()
715742
w = multiprocessing_context.Process(
716743
target=_utils.worker._worker_loop,
@@ -732,7 +759,9 @@ def __init__(self, loader):
732759

733760
if self._pin_memory:
734761
self._pin_memory_thread_done_event = threading.Event()
735-
self._data_queue = queue.Queue()
762+
763+
# Queue is not type-annotated
764+
self._data_queue = queue.Queue() # type: ignore
736765
pin_memory_thread = threading.Thread(
737766
target=_utils.pin_memory._pin_memory_loop,
738767
args=(self._worker_result_queue, self._data_queue,

torch/utils/data/dataloader.pyi

Lines changed: 0 additions & 46 deletions
This file was deleted.

0 commit comments

Comments
 (0)