Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 23 additions & 0 deletions test/test_dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,34 @@
import unittest
from torch import multiprocessing
from torch.utils.data import Dataset, TensorDataset, DataLoader, ConcatDataset
from torch.utils.data.dataset import random_split
from torch.utils.data.dataloader import default_collate
from common import TestCase, run_tests, TEST_NUMPY, IS_WINDOWS
from common_nn import TEST_CUDA


class TestDatasetRandomSplit(TestCase):
def test_lengths_must_equal_datset_size(self):
with self.assertRaises(ValueError):
random_split([1, 2, 3, 4], [1, 2])

def test_splits_have_correct_size(self):
splits = random_split([1, 2, 3, 4, 5, 6], [2, 4])
self.assertEqual(len(splits), 2)
self.assertEqual(len(splits[0]), 2)
self.assertEqual(len(splits[1]), 4)

def test_splits_are_mutually_exclusive(self):
data = [5, 2, 3, 4, 1, 6]
splits = random_split(data, [2, 4])
all_values = []
all_values.extend(list(splits[0]))
all_values.extend(list(splits[1]))
data.sort()
all_values.sort()
self.assertListEqual(data, all_values)


class TestTensorDataset(TestCase):

def test_len(self):
Expand Down
31 changes: 31 additions & 0 deletions torch/utils/data/dataset.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
import bisect
import warnings

from torch._utils import _accumulate
from torch import randperm


class Dataset(object):
"""An abstract class representing a Dataset.
Expand Down Expand Up @@ -85,3 +88,31 @@ def cummulative_sizes(self):
warnings.warn("cummulative_sizes attribute is renamed to "
"cumulative_sizes", DeprecationWarning, stacklevel=2)
return self.cumulative_sizes


class Subset(Dataset):
def __init__(self, dataset, indices):
self.dataset = dataset
self.indices = indices

def __getitem__(self, idx):
return self.dataset[self.indices[idx]]

def __len__(self):
return len(self.indices)


def random_split(dataset, lengths):
"""
Randomly split a dataset into non-overlapping new datasets of given lengths
ds

Arguments:
dataset (Dataset): Dataset to be split
lengths (iterable): lengths of splits to be produced
"""
if sum(lengths) != len(dataset):
raise ValueError("Sum of input lengths does not equal the length of the input dataset!")

indices = randperm(sum(lengths))
return [Subset(dataset, indices[offset - length:offset]) for offset, length in zip(_accumulate(lengths), lengths)]