Skip to content

Commit da32bf4

Browse files
rgommersfacebook-github-bot
authored andcommitted
Move type annotations for remaining torch.utils stub files inline (#43406)
Summary: Pull Request resolved: #43406 Reviewed By: mruberry Differential Revision: D23319736 Pulled By: malfet fbshipit-source-id: e25fbb49f27aa4893590b022441303d6d98263a9
1 parent 6022097 commit da32bf4

File tree

7 files changed

+32
-41
lines changed

7 files changed

+32
-41
lines changed

mypy.ini

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -227,6 +227,9 @@ ignore_errors = True
227227
[mypy-torch.utils.data._utils.worker]
228228
ignore_errors = True
229229

230+
[mypy-torch.utils.data.distributed]
231+
ignore_errors = True
232+
230233
[mypy-torch.nn.utils.prune]
231234
ignore_errors = True
232235

torch/utils/data/__init__.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,11 @@
11
from .sampler import Sampler, SequentialSampler, RandomSampler, SubsetRandomSampler, WeightedRandomSampler, BatchSampler
2-
from .distributed import DistributedSampler
32
from .dataset import Dataset, IterableDataset, TensorDataset, ConcatDataset, ChainDataset, Subset, random_split
3+
from .distributed import DistributedSampler
44
from .dataloader import DataLoader, _DatasetKind, get_worker_info
5+
6+
7+
__all__ = ['Sampler', 'SequentialSampler', 'RandomSampler',
8+
'SubsetRandomSampler', 'WeightedRandomSampler', 'BatchSampler'
9+
'DistributedSampler' 'Dataset', 'IterableDataset', 'TensorDataset',
10+
'ConcatDataset', 'ChainDataset', 'Subset', 'random_split'
11+
'DataLoader', '_DatasetKind', 'get_worker_info']

torch/utils/data/__init__.pyi

Lines changed: 0 additions & 7 deletions
This file was deleted.

torch/utils/data/distributed.py

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,15 @@
11
import math
2+
from typing import TypeVar, Optional, Iterator
3+
24
import torch
3-
from . import Sampler
5+
from . import Sampler, Dataset
46
import torch.distributed as dist
57

68

7-
class DistributedSampler(Sampler):
9+
T_co = TypeVar('T_co', covariant=True)
10+
11+
12+
class DistributedSampler(Sampler[T_co]):
813
r"""Sampler that restricts data loading to a subset of the dataset.
914
1015
It is especially useful in conjunction with
@@ -51,7 +56,9 @@ class DistributedSampler(Sampler):
5156
... train(loader)
5257
"""
5358

54-
def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True, seed=0, drop_last=False):
59+
def __init__(self, dataset: Dataset, num_replicas: Optional[int] = None,
60+
rank: Optional[int] = None, shuffle: bool = True,
61+
seed: int = 0, drop_last: bool = False) -> None:
5562
if num_replicas is None:
5663
if not dist.is_available():
5764
raise RuntimeError("Requires distributed package to be available")
@@ -80,7 +87,7 @@ def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True, seed=0,
8087
self.shuffle = shuffle
8188
self.seed = seed
8289

83-
def __iter__(self):
90+
def __iter__(self) -> Iterator[T_co]:
8491
if self.shuffle:
8592
# deterministically shuffle based on epoch and seed
8693
g = torch.Generator()
@@ -89,7 +96,6 @@ def __iter__(self):
8996
else:
9097
indices = list(range(len(self.dataset)))
9198

92-
9399
if not self.drop_last:
94100
# add extra samples to make it evenly divisible
95101
indices += indices[:(self.total_size - len(indices))]
@@ -104,10 +110,10 @@ def __iter__(self):
104110

105111
return iter(indices)
106112

107-
def __len__(self):
113+
def __len__(self) -> int:
108114
return self.num_samples
109115

110-
def set_epoch(self, epoch):
116+
def set_epoch(self, epoch: int) -> None:
111117
r"""
112118
Sets the epoch for this sampler. When :attr:`shuffle=True`, this ensures all replicas
113119
use a different random ordering for each epoch. Otherwise, the next iteration of this

torch/utils/data/distributed.pyi

Lines changed: 0 additions & 9 deletions
This file was deleted.

torch/utils/hooks.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,27 +2,29 @@
22
from collections import OrderedDict
33
import weakref
44
import warnings
5+
from typing import Any
56

67

78
class RemovableHandle(object):
89
"""A handle which provides the capability to remove a hook."""
910

10-
next_id = 0
11+
id: int
12+
next_id: int = 0
1113

12-
def __init__(self, hooks_dict):
14+
def __init__(self, hooks_dict: Any) -> None:
1315
self.hooks_dict_ref = weakref.ref(hooks_dict)
1416
self.id = RemovableHandle.next_id
1517
RemovableHandle.next_id += 1
1618

17-
def remove(self):
19+
def remove(self) -> None:
1820
hooks_dict = self.hooks_dict_ref()
1921
if hooks_dict is not None and self.id in hooks_dict:
2022
del hooks_dict[self.id]
2123

2224
def __getstate__(self):
2325
return (self.hooks_dict_ref(), self.id)
2426

25-
def __setstate__(self, state):
27+
def __setstate__(self, state) -> None:
2628
if state[0] is None:
2729
# create a dead reference
2830
self.hooks_dict_ref = weakref.ref(OrderedDict())
@@ -31,10 +33,10 @@ def __setstate__(self, state):
3133
self.id = state[1]
3234
RemovableHandle.next_id = max(RemovableHandle.next_id, self.id + 1)
3335

34-
def __enter__(self):
36+
def __enter__(self) -> 'RemovableHandle':
3537
return self
3638

37-
def __exit__(self, type, value, tb):
39+
def __exit__(self, type: Any, value: Any, tb: Any) -> None:
3840
self.remove()
3941

4042

torch/utils/hooks.pyi

Lines changed: 0 additions & 11 deletions
This file was deleted.

0 commit comments

Comments
 (0)