Skip to content

Commit d7c32df

Browse files
zasdfgbnm authored and apaszke committed
move Subset, random_split to data, use sequence at some places. (#7816)
1 parent ce1a65b commit d7c32df

File tree

4 files changed

+15
-8
lines changed

4 files changed

+15
-8
lines changed

docs/source/data.rst

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,13 @@ torch.utils.data
55
.. autoclass:: Dataset
66
.. autoclass:: TensorDataset
77
.. autoclass:: ConcatDataset
8+
.. autoclass:: Subset
89
.. autoclass:: DataLoader
10+
.. autofunction:: torch.utils.data.random_split
911
.. autoclass:: torch.utils.data.Sampler
1012
.. autoclass:: torch.utils.data.SequentialSampler
1113
.. autoclass:: torch.utils.data.RandomSampler
1214
.. autoclass:: torch.utils.data.SubsetRandomSampler
1315
.. autoclass:: torch.utils.data.WeightedRandomSampler
1416
.. autoclass:: torch.utils.data.distributed.DistributedSampler
1517
.. autoclass:: torch.utils.data.distributed.BatchSampler
16-
.. autofunction:: torch.utils.data.dataset.random_split

torch/utils/data/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
11

22
from .sampler import Sampler, SequentialSampler, RandomSampler, SubsetRandomSampler, WeightedRandomSampler, BatchSampler
3-
from .dataset import Dataset, TensorDataset, ConcatDataset
3+
from .dataset import Dataset, TensorDataset, ConcatDataset, Subset, random_split
44
from .dataloader import DataLoader

torch/utils/data/dataset.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ class ConcatDataset(Dataset):
5151
on-the-fly manner.
5252
5353
Arguments:
54-
datasets (iterable): List of datasets to be concatenated
54+
datasets (sequence): List of datasets to be concatenated
5555
"""
5656

5757
@staticmethod
@@ -88,6 +88,13 @@ def cummulative_sizes(self):
8888

8989

9090
class Subset(Dataset):
91+
"""
92+
Subset of a dataset at specified indices.
93+
94+
Arguments:
95+
dataset (Dataset): The whole Dataset
96+
indices (sequence): Indices in the whole set selected for subset
97+
"""
9198
def __init__(self, dataset, indices):
9299
self.dataset = dataset
93100
self.indices = indices
@@ -101,12 +108,11 @@ def __len__(self):
101108

102109
def random_split(dataset, lengths):
103110
"""
104-
Randomly split a dataset into non-overlapping new datasets of given lengths
105-
ds
111+
Randomly split a dataset into non-overlapping new datasets of given lengths.
106112
107113
Arguments:
108114
dataset (Dataset): Dataset to be split
109-
lengths (iterable): lengths of splits to be produced
115+
lengths (sequence): lengths of splits to be produced
110116
"""
111117
if sum(lengths) != len(dataset):
112118
raise ValueError("Sum of input lengths does not equal the length of the input dataset!")

torch/utils/data/sampler.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ class SubsetRandomSampler(Sampler):
5858
r"""Samples elements randomly from a given list of indices, without replacement.
5959
6060
Arguments:
61-
indices (list): a list of indices
61+
indices (sequence): a sequence of indices
6262
"""
6363

6464
def __init__(self, indices):
@@ -75,7 +75,7 @@ class WeightedRandomSampler(Sampler):
7575
r"""Samples elements from [0,..,len(weights)-1] with given probabilities (weights).
7676
7777
Arguments:
78-
weights (list) : a list of weights, not necessary summing up to one
78+
weights (sequence) : a sequence of weights, not necessarily summing up to one
7979
num_samples (int): number of samples to draw
8080
replacement (bool): if ``True``, samples are drawn with replacement.
8181
If not, they are drawn without replacement, which means that when a

0 commit comments

Comments
 (0)