Commit f522bde

ngoldbaum authored and facebook-github-bot committed
Replace references to _DataLoaderIter with _BaseDataLoaderIter (#27105)
Summary: Back in April, malmaud added type annotations for `dataloader.py`. At about the same time, SsnL in #19228 replaced `_DataLoaderIter` with `_BaseDataLoaderIter` and two subclasses, `_SingleProcessDataLoaderIter` and `_MultiProcessingDataLoaderIter`. Probably because these changes happened in parallel, the type stubs and several other references in the codebase were never updated to match the refactoring.

I've updated them to reflect the refactoring in #19228. This fixes the specific type stub/implementation mismatch pointed out in #26673, although not the broader problem that PyTorch doesn't have a test to make sure that the `.pyi` type stub files match the real API defined in `.py` files.

Pull Request resolved: #27105
Differential Revision: D17813641
Pulled By: ezyang
fbshipit-source-id: ed7ac025c8d6ad3f298dd073347ec83bb4b6600c
1 parent d571248 commit f522bde
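
The summary notes that nothing checks the `.pyi` stubs against the real API. As a rough illustration of what such a check might look like (not something this commit adds), the sketch below parses stub source with `ast` and reports top-level classes that a module does not define. The helper names, `fake_module`, and both stub strings are hypothetical stand-ins, not PyTorch code.

```python
import ast
import types


def stub_class_names(stub_source: str) -> set:
    """Return the names of top-level classes declared in stub source text."""
    tree = ast.parse(stub_source)
    return {node.name for node in tree.body if isinstance(node, ast.ClassDef)}


def missing_from_module(stub_source: str, module: types.ModuleType) -> set:
    """Return stub class names that the runtime module does not define."""
    return {name for name in stub_class_names(stub_source)
            if not hasattr(module, name)}


# Hypothetical stand-in for the runtime module after the refactoring.
fake_module = types.ModuleType("fake_dataloader")
fake_module.DataLoader = type("DataLoader", (), {})
fake_module._BaseDataLoaderIter = type("_BaseDataLoaderIter", (), {})

# Hypothetical stub text that still uses the old class name.
stale_stub = """
class DataLoader:
    def __iter__(self) -> '_DataLoaderIter': ...

class _DataLoaderIter:
    def __next__(self): ...
"""

# The same stub after renaming the iterator class.
fixed_stub = stale_stub.replace("_DataLoaderIter", "_BaseDataLoaderIter")

print(missing_from_module(stale_stub, fake_module))
print(missing_from_module(fixed_stub, fake_module))
```

A real check would also need to compare methods and signatures; this sketch only compares class names.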

8 files changed: +16 -14 lines changed

torch/csrc/DataLoader.cpp

Lines changed: 2 additions & 2 deletions
@@ -155,7 +155,7 @@ static PyObject *THPModule_setWorkerPIDs(PyObject *module, PyObject *args) {
   }
   int64_t key = THPUtils_unpackLong(PyTuple_GET_ITEM(args, 0));
   if (worker_pids.find(key) != worker_pids.end()) {
-    throw ValueError("_set_worker_pids should be called only once for each _DataLoaderIter.");
+    throw ValueError("_set_worker_pids should be called only once for each _BaseDataLoaderIter.");
   }
   PyObject *child_pids = PyTuple_GET_ITEM(args, 1);
   if (!PyTuple_Check(child_pids)) {
@@ -182,7 +182,7 @@ static PyObject *THPModule_removeWorkerPIDs(PyObject *module, PyObject *loader_i
   int64_t key = THPUtils_unpackLong(loader_id);
   auto it = worker_pids.find(key);
   if (it == worker_pids.end()) {
-    throw ValueError("Cannot find worker information for _DataLoaderIter with id %ld.", key);
+    throw ValueError("Cannot find worker information for _BaseDataLoaderIter with id %ld.", key);
   }
   worker_pids.erase(it);

torch/utils/data/_utils/collate.py

Lines changed: 1 addition & 1 deletion
@@ -1,4 +1,4 @@
-r""""Contains definitions of the methods used by the _DataLoaderIter workers to
+r""""Contains definitions of the methods used by the _BaseDataLoaderIter workers to
 collate samples fetched from dataset into Tensor(s).

 These **needs** to be in global scope since Py2 doesn't support serializing

torch/utils/data/_utils/fetch.py

Lines changed: 1 addition & 1 deletion
@@ -1,4 +1,4 @@
-r""""Contains definitions of the methods used by the _DataLoaderIter to fetch
+r""""Contains definitions of the methods used by the _BaseDataLoaderIter to fetch
 data from an iterable-style or map-style dataset. This logic is shared in both
 single- and multi-processing data loading.
 """

torch/utils/data/_utils/pin_memory.py

Lines changed: 1 addition & 1 deletion
@@ -1,4 +1,4 @@
-r""""Contains definitions of the methods used by the _DataLoaderIter to put
+r""""Contains definitions of the methods used by the _BaseDataLoaderIter to put
 fetched tensors into pinned memory.

 These **needs** to be in global scope since Py2 doesn't support serializing

torch/utils/data/_utils/signal_handling.py

Lines changed: 2 additions & 2 deletions
@@ -9,8 +9,8 @@
 our best effort to provide some error message to users when such unfortunate
 events happen.

-When a _DataLoaderIter starts worker processes, their pids are registered in a
-defined in `DataLoader.cpp`: id(_DataLoaderIter) => Collection[ Worker pids ]
+When a _BaseDataLoaderIter starts worker processes, their pids are registered in a
+defined in `DataLoader.cpp`: id(_BaseDataLoaderIter) => Collection[ Worker pids ]
 via `_set_worker_pids`.

 When an error happens in a worker process, the main process received a SIGCHLD,

torch/utils/data/_utils/worker.py

Lines changed: 1 addition & 1 deletion
@@ -1,4 +1,4 @@
-r""""Contains definitions of the methods used by the _DataLoaderIter workers.
+r""""Contains definitions of the methods used by the _BaseDataLoaderIter workers.

 These **needs** to be in global scope since Py2 doesn't support serializing
 static methods.

torch/utils/data/dataloader.py

Lines changed: 1 addition & 1 deletion
@@ -1,4 +1,4 @@
-r"""Definition of the DataLoader and it's iterator _DataLoaderIter classes.
+r"""Definition of the DataLoader and associated iterators that subclass _BaseDataLoaderIter

 To support these two classes, in `./_utils` we define many utility methods and
 functions to be run in multiprocessing. E.g., the data loading worker loop is

torch/utils/data/dataloader.pyi

Lines changed: 7 additions & 5 deletions
@@ -28,12 +28,14 @@ class DataLoader(Generic[T_co]):
                  worker_init_fn: _worker_init_fn_t=...) -> None: ...

     def __len__(self) -> int: ...
-    # We quote '_DataLoaderIter' since it isn't defined yet and the definition can't be moved up since
-    # '_DataLoaderIter' references 'DataLoader'. Pending updates of PEP 484 will fix this.
-    def __iter__(self) -> '_DataLoaderIter':...
+    # We quote '_BaseDataLoaderIter' since it isn't defined yet and the definition can't be moved up
+    # since '_BaseDataLoaderIter' references 'DataLoader'. In mypy 0.720 and newer a new semantic
+    # analyzer is used that obviates the need for this but we leave the quoting in to support older
+    # versions of mypy
+    def __iter__(self) -> '_BaseDataLoaderIter':...

-class _DataLoaderIter:
+class _BaseDataLoaderIter:
     def __init__(self, loader: DataLoader) -> None:...
     def __len__(self) -> int: ...
-    def __iter__(self) -> _DataLoaderIter: ...
+    def __iter__(self) -> _BaseDataLoaderIter: ...
     def __next__(self) -> Any: ...
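
The comment in this stub describes a forward reference: an annotation naming a class that is defined later in the file. The sketch below shows the same quoting pattern in an ordinary `.py` module, where the string annotation is resolved later by `typing.get_type_hints`. `Loader` and `LoaderIter` are hypothetical stand-ins for the two classes in the stub. Note that stub files are never executed, so there the quoting matters only to the type checker.

```python
from typing import Any, get_type_hints


class Loader:
    # The return annotation is quoted because `LoaderIter` is defined
    # further down and itself refers back to `Loader`.
    def __iter__(self) -> "LoaderIter":
        return LoaderIter(self)


class LoaderIter:
    def __init__(self, loader: Loader) -> None:
        self.loader = loader

    # Quoted as well: the class name is not bound yet while its own body
    # is being executed.
    def __iter__(self) -> "LoaderIter":
        return self

    def __next__(self) -> Any:
        # This toy iterator yields nothing.
        raise StopIteration


# Once both classes exist, the string annotation resolves to the class.
hints = get_type_hints(Loader.__iter__)
print(hints["return"] is LoaderIter)
```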
