Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 0 additions & 3 deletions mypy.ini
Original file line number Diff line number Diff line change
Expand Up @@ -302,9 +302,6 @@ ignore_errors = True
[mypy-torch.utils.data._utils.worker]
ignore_errors = True

[mypy-torch.utils.data.dataset]
ignore_errors = True

[mypy-torch.utils.data.distributed]
ignore_errors = True

Expand Down
11 changes: 11 additions & 0 deletions test/test_dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -821,6 +821,16 @@ def test_error_in_init(self):
with self.assertRaisesRegex(RuntimeError, 'Error in worker_init_fn'):
list(iter(loader))

def test_typing(self):
from typing import List
# Make sure there is no TypeError

class SomeDatasetClass(Dataset[List[torch.Tensor]]):
pass

def _create_dataloader(is_train: bool) -> DataLoader[List[torch.Tensor]]:
pass

@unittest.skipIf(IS_SANDCASTLE, "subprocess doesn't work in FB internal CI")
@unittest.skipIf(IS_WINDOWS, "No 'resource' module on Windows")
def test_fd_limit_exceeded(self):
Expand Down Expand Up @@ -2019,5 +2029,6 @@ def test_set_affinity_in_worker_init(self):
self.assertEqual(sample, [2])



if __name__ == '__main__':
run_tests()
79 changes: 54 additions & 25 deletions torch/utils/data/dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,26 +8,35 @@
import threading
import itertools
import warnings
from typing import Any, Callable, TypeVar, Generic, Sequence, List, Optional

import multiprocessing as python_multiprocessing
import torch
import torch.multiprocessing as multiprocessing
from torch._utils import ExceptionWrapper
from torch._six import queue, string_classes

from . import IterableDataset, Sampler, SequentialSampler, RandomSampler, BatchSampler
from . import IterableDataset, Sampler, SequentialSampler, RandomSampler, BatchSampler, Dataset
from . import _utils

T_co = TypeVar('T_co', covariant=True)
T = TypeVar('T')
_worker_init_fn_t = Callable[[int], None]

# Ideally we would parameterize `DataLoader` by the return type of `collate_fn`, but there is currently no way to have that
# type parameter set to a default value if the user doesn't pass in a custom 'collate_fn'.
# See https://github.com/python/mypy/issues/3737.
_collate_fn_t = Callable[[List[T]], Any]

get_worker_info = _utils.worker.get_worker_info

# This function used to be defined in this file. However, it was moved to
# _utils/collate.py. Although it is rather hard to access this from user land
# (one has to explicitly directly `import torch.utils.data.dataloader`), there
# probably is user code out there using it. This aliasing maintains BC in this
# aspect.
default_collate = _utils.collate.default_collate
default_collate: _collate_fn_t = _utils.collate.default_collate

get_worker_info = _utils.worker.get_worker_info

class _DatasetKind(object):
Map = 0
Expand Down Expand Up @@ -57,7 +66,7 @@ def __iter__(self):
yield None


class DataLoader(object):
class DataLoader(Generic[T_co]):
r"""
Data loader. Combines a dataset and a sampler, and provides an iterable over
the given dataset.
Expand Down Expand Up @@ -116,15 +125,24 @@ class DataLoader(object):
details on these two types of datasets and how
:class:`~torch.utils.data.IterableDataset` interacts with `Multi-process data loading`_.
"""
dataset: Dataset[T_co]
batch_size: Optional[int]
num_workers: int
pin_memory: bool
drop_last: bool
timeout: float
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it possible to list sampler: Sampler here? Some internal code attempts to access dataloader.sampler and the type appears was inferred to be Optional[Sampler]. However, after __init__, we're sure that sampler is a Sampler

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(I verified that adding sampler: Sampler makes the problem go away, but I'm not sure if there was a reason why it wasn't here)

sampler: Sampler

__initialized = False

def __init__(self, dataset, batch_size=1, shuffle=False, sampler=None,
batch_sampler=None, num_workers=0, collate_fn=None,
pin_memory=False, drop_last=False, timeout=0,
worker_init_fn=None, multiprocessing_context=None,
generator=None):
torch._C._log_api_usage_once("python.data_loader")
def __init__(self, dataset: Dataset[T_co], batch_size: Optional[int] = 1,
shuffle: bool = False, sampler: Optional[Sampler[int]] = None,
batch_sampler: Optional[Sampler[Sequence[int]]] = None,
num_workers: int = 0, collate_fn: _collate_fn_t = None,
pin_memory: bool = False, drop_last: bool = False,
timeout: float = 0, worker_init_fn: _worker_init_fn_t = None,
multiprocessing_context=None, generator=None):
torch._C._log_api_usage_once("python.data_loader") # type: ignore

if num_workers < 0:
raise ValueError('num_workers option should be non-negative; '
Expand All @@ -146,7 +164,7 @@ def __init__(self, dataset, batch_size=1, shuffle=False, sampler=None,
# after spending time fixing the custom sampler errors.
if isinstance(dataset, IterableDataset):
self._dataset_kind = _DatasetKind.Iterable
# NOTE [ Custom Samplers and `IterableDataset` ]
# NOTE [ Custom Samplers and IterableDataset ]
#
# `IterableDataset` does not support custom `batch_sampler` or
# `sampler` since the key is irrelevant (unless we support
Expand Down Expand Up @@ -212,7 +230,9 @@ def __init__(self, dataset, batch_size=1, shuffle=False, sampler=None,
sampler = _InfiniteConstantSampler()
else: # map-style
if shuffle:
sampler = RandomSampler(dataset, generator=generator)
# Cannot statically verify that dataset is Sized
# Somewhat related: see NOTE [ Lack of Default `__len__` in Python Abstract Base Classes ]
sampler = RandomSampler(dataset, generator=generator) # type: ignore
else:
sampler = SequentialSampler(dataset)

Expand Down Expand Up @@ -253,9 +273,10 @@ def multiprocessing_context(self, multiprocessing_context):
if multiprocessing_context not in valid_start_methods:
raise ValueError(
('multiprocessing_context option '
'should specify a valid start method in {}, but got '
'multiprocessing_context={}').format(valid_start_methods, multiprocessing_context))
multiprocessing_context = multiprocessing.get_context(multiprocessing_context)
'should specify a valid start method in {!r}, but got '
'multiprocessing_context={!r}').format(valid_start_methods, multiprocessing_context))
# error: Argument 1 to "get_context" has incompatible type "Union[str, bytes]"; expected "str" [arg-type]
multiprocessing_context = multiprocessing.get_context(multiprocessing_context) # type: ignore

if not isinstance(multiprocessing_context, python_multiprocessing.context.BaseContext):
raise TypeError(('multiprocessing_context option should be a valid context '
Expand All @@ -275,7 +296,9 @@ def __setattr__(self, attr, val):

super(DataLoader, self).__setattr__(attr, val)

def __iter__(self):
# We quote '_BaseDataLoaderIter' since it isn't defined yet and the definition can't be moved up
# since '_BaseDataLoaderIter' references 'DataLoader'.
def __iter__(self) -> '_BaseDataLoaderIter':
if self.num_workers == 0:
return _SingleProcessDataLoaderIter(self)
else:
Expand All @@ -297,7 +320,7 @@ def _index_sampler(self):
else:
return self.sampler

def __len__(self):
def __len__(self) -> int:
if self._dataset_kind == _DatasetKind.Iterable:
# NOTE [ IterableDataset and __len__ ]
#
Expand All @@ -313,7 +336,9 @@ def __len__(self):
# To provide a further warning, we track if `__len__` was called on the
# `DataLoader`, save the returned value in `self._len_called`, and warn
# if the iterator ends up yielding more than this number of samples.
length = self._IterableDataset_len_called = len(self.dataset)

# Cannot statically verify that dataset is Sized
length = self._IterableDataset_len_called = len(self.dataset) # type: ignore
if self.batch_size is not None:
from math import ceil
if self.drop_last:
Expand All @@ -326,7 +351,7 @@ def __len__(self):


class _BaseDataLoaderIter(object):
def __init__(self, loader):
def __init__(self, loader: DataLoader) -> None:
self._dataset = loader.dataset
self._dataset_kind = loader._dataset_kind
self._IterableDataset_len_called = loader._IterableDataset_len_called
Expand All @@ -341,7 +366,7 @@ def __init__(self, loader):
self._base_seed = torch.empty((), dtype=torch.int64).random_(generator=loader.generator).item()
self._num_yielded = 0

def __iter__(self):
def __iter__(self) -> '_BaseDataLoaderIter':
return self

def _next_index(self):
Expand All @@ -350,7 +375,7 @@ def _next_index(self):
def _next_data(self):
raise NotImplementedError

def __next__(self):
def __next__(self) -> Any:
data = self._next_data()
self._num_yielded += 1
if self._dataset_kind == _DatasetKind.Iterable and \
Expand All @@ -368,7 +393,7 @@ def __next__(self):

next = __next__ # Python 2 compatibility

def __len__(self):
def __len__(self) -> int:
return len(self._index_sampler)

def __getstate__(self):
Expand Down Expand Up @@ -690,7 +715,8 @@ def __init__(self, loader):

self._worker_init_fn = loader.worker_init_fn
self._worker_queue_idx_cycle = itertools.cycle(range(self._num_workers))
self._worker_result_queue = multiprocessing_context.Queue()
# No certainty which module multiprocessing_context is
self._worker_result_queue = multiprocessing_context.Queue() # type: ignore
self._worker_pids_set = False
self._shutdown = False
self._send_idx = 0 # idx of the next task to be sent to workers
Expand All @@ -710,7 +736,8 @@ def __init__(self, loader):
# (i.e., if kind != Iterable).
self._workers_status = []
for i in range(self._num_workers):
index_queue = multiprocessing_context.Queue()
# No certainty which module multiprocessing_context is
index_queue = multiprocessing_context.Queue() # type: ignore
# index_queue.cancel_join_thread()
w = multiprocessing_context.Process(
target=_utils.worker._worker_loop,
Expand All @@ -732,7 +759,9 @@ def __init__(self, loader):

if self._pin_memory:
self._pin_memory_thread_done_event = threading.Event()
self._data_queue = queue.Queue()

# Queue is not type-annotated
self._data_queue = queue.Queue() # type: ignore
pin_memory_thread = threading.Thread(
target=_utils.pin_memory._pin_memory_loop,
args=(self._worker_result_queue, self._data_queue,
Expand Down
46 changes: 0 additions & 46 deletions torch/utils/data/dataloader.pyi

This file was deleted.

Loading