
Commit 18a866a

alykhantejani authored and apaszke committed
Add random_split to torch.utils.data.dataset (#4435)
1 parent 57f9db9 · commit 18a866a

2 files changed, +54 −0 lines


test/test_dataloader.py

Lines changed: 23 additions & 0 deletions
@@ -7,11 +7,34 @@
 import unittest
 from torch import multiprocessing
 from torch.utils.data import Dataset, TensorDataset, DataLoader, ConcatDataset
+from torch.utils.data.dataset import random_split
 from torch.utils.data.dataloader import default_collate
 from common import TestCase, run_tests, TEST_NUMPY, IS_WINDOWS
 from common_nn import TEST_CUDA


+class TestDatasetRandomSplit(TestCase):
+    def test_lengths_must_equal_datset_size(self):
+        with self.assertRaises(ValueError):
+            random_split([1, 2, 3, 4], [1, 2])
+
+    def test_splits_have_correct_size(self):
+        splits = random_split([1, 2, 3, 4, 5, 6], [2, 4])
+        self.assertEqual(len(splits), 2)
+        self.assertEqual(len(splits[0]), 2)
+        self.assertEqual(len(splits[1]), 4)
+
+    def test_splits_are_mutually_exclusive(self):
+        data = [5, 2, 3, 4, 1, 6]
+        splits = random_split(data, [2, 4])
+        all_values = []
+        all_values.extend(list(splits[0]))
+        all_values.extend(list(splits[1]))
+        data.sort()
+        all_values.sort()
+        self.assertListEqual(data, all_values)
+
+
 class TestTensorDataset(TestCase):

     def test_len(self):
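
The tests above exercise random_split on plain Python lists, which is enough because the function only relies on len() and indexing. A minimal interactive sketch of the behavior being verified, assuming a PyTorch build that already contains this commit:

from torch.utils.data.dataset import random_split

data = [5, 2, 3, 4, 1, 6]
first, second = random_split(data, [2, 4])   # two Subset objects (see dataset.py below)

print(len(first), len(second))               # 2 4
# Together the splits cover every element of the input exactly once.
print(sorted(list(first) + list(second)))    # [1, 2, 3, 4, 5, 6]

# Lengths that do not sum to len(data) raise ValueError, as the first test checks:
#   random_split(data, [1, 2])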

torch/utils/data/dataset.py

Lines changed: 31 additions & 0 deletions
@@ -1,6 +1,9 @@
 import bisect
 import warnings

+from torch._utils import _accumulate
+from torch import randperm
+

 class Dataset(object):
     """An abstract class representing a Dataset.
@@ -85,3 +88,31 @@ def cummulative_sizes(self):
         warnings.warn("cummulative_sizes attribute is renamed to "
                       "cumulative_sizes", DeprecationWarning, stacklevel=2)
         return self.cumulative_sizes
+
+
+class Subset(Dataset):
+    def __init__(self, dataset, indices):
+        self.dataset = dataset
+        self.indices = indices
+
+    def __getitem__(self, idx):
+        return self.dataset[self.indices[idx]]
+
+    def __len__(self):
+        return len(self.indices)
+
+
+def random_split(dataset, lengths):
+    """
+    Randomly split a dataset into non-overlapping new datasets of given lengths
+    ds
+
+    Arguments:
+        dataset (Dataset): Dataset to be split
+        lengths (iterable): lengths of splits to be produced
+    """
+    if sum(lengths) != len(dataset):
+        raise ValueError("Sum of input lengths does not equal the length of the input dataset!")
+
+    indices = randperm(sum(lengths))
+    return [Subset(dataset, indices[offset - length:offset]) for offset, length in zip(_accumulate(lengths), lengths)]
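
For context, a sketch of how the new function could be consumed downstream; the TensorDataset contents, split sizes, and batch sizes here are made-up illustration values, not part of the commit. Wrapping the returned Subset objects in DataLoader works because they expose __len__ and __getitem__, as defined above:

import torch
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.dataset import random_split

# A toy dataset of 10 (input, target) pairs; sizes chosen arbitrarily.
dataset = TensorDataset(torch.randn(10, 3), torch.arange(10))
train_set, val_set = random_split(dataset, [8, 2])

train_loader = DataLoader(train_set, batch_size=4, shuffle=True)
val_loader = DataLoader(val_set, batch_size=2)

for inputs, targets in train_loader:
    pass  # a training step would consume the shuffled 8-sample split here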
