Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion docs/source/data.rst
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,13 @@ torch.utils.data
.. autoclass:: Dataset
.. autoclass:: TensorDataset
.. autoclass:: ConcatDataset
.. autoclass:: Subset
.. autoclass:: DataLoader
.. autofunction:: torch.utils.data.random_split
.. autoclass:: torch.utils.data.Sampler
.. autoclass:: torch.utils.data.SequentialSampler
.. autoclass:: torch.utils.data.RandomSampler
.. autoclass:: torch.utils.data.SubsetRandomSampler
.. autoclass:: torch.utils.data.WeightedRandomSampler
.. autoclass:: torch.utils.data.distributed.DistributedSampler
.. autoclass:: torch.utils.data.distributed.BatchSampler
.. autofunction:: torch.utils.data.dataset.random_split
2 changes: 1 addition & 1 deletion torch/utils/data/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@

from .sampler import Sampler, SequentialSampler, RandomSampler, SubsetRandomSampler, WeightedRandomSampler, BatchSampler
from .dataset import Dataset, TensorDataset, ConcatDataset
from .dataset import Dataset, TensorDataset, ConcatDataset, Subset, random_split
from .dataloader import DataLoader
14 changes: 10 additions & 4 deletions torch/utils/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ class ConcatDataset(Dataset):
on-the-fly manner.

Arguments:
datasets (iterable): List of datasets to be concatenated
datasets (sequence): List of datasets to be concatenated
"""

@staticmethod
Expand Down Expand Up @@ -88,6 +88,13 @@ def cummulative_sizes(self):


class Subset(Dataset):
"""
Subset of a dataset at specified indices.

Arguments:
dataset (Dataset): The whole Dataset
indices (sequence): Indices in the whole set selected for subset
"""
def __init__(self, dataset, indices):
self.dataset = dataset
self.indices = indices
Expand All @@ -101,12 +108,11 @@ def __len__(self):

def random_split(dataset, lengths):
"""
Randomly split a dataset into non-overlapping new datasets of given lengths
ds
Randomly split a dataset into non-overlapping new datasets of given lengths.

Arguments:
dataset (Dataset): Dataset to be split
lengths (iterable): lengths of splits to be produced
lengths (sequence): lengths of splits to be produced
"""
if sum(lengths) != len(dataset):
raise ValueError("Sum of input lengths does not equal the length of the input dataset!")
Expand Down
4 changes: 2 additions & 2 deletions torch/utils/data/sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ class SubsetRandomSampler(Sampler):
r"""Samples elements randomly from a given list of indices, without replacement.

Arguments:
indices (list): a list of indices
indices (sequence): a sequence of indices
"""

def __init__(self, indices):
Expand All @@ -75,7 +75,7 @@ class WeightedRandomSampler(Sampler):
r"""Samples elements from [0,..,len(weights)-1] with given probabilities (weights).

Arguments:
weights (list) : a list of weights, not necessary summing up to one
weights (sequence) : a sequence of weights, not necessary summing up to one
num_samples (int): number of samples to draw
replacement (bool): if ``True``, samples are drawn with replacement.
If not, they are drawn without replacement, which means that when a
Expand Down