
Commit efc91d8

ssnl authored and ezyang committed
Add arg checks in torch.utils.data.Sampler classes (#6249)
Fixes #6168

* add arg checks in torch.utils.data.Sampler
* add check for positive-ness
1 parent 0016dad commit efc91d8

File tree

1 file changed: +20 -1 lines changed


torch/utils/data/sampler.py

Lines changed: 20 additions & 1 deletion
@@ -1,4 +1,5 @@
 import torch
+from torch._six import int_classes as _int_classes


 class Sampler(object):
@@ -82,7 +83,14 @@ class WeightedRandomSampler(Sampler):
     """

     def __init__(self, weights, num_samples, replacement=True):
-        self.weights = torch.DoubleTensor(weights)
+        if not isinstance(num_samples, _int_classes) or isinstance(num_samples, bool) or \
+                num_samples <= 0:
+            raise ValueError("num_samples should be a positive integer "
+                             "value, but got num_samples={}".format(num_samples))
+        if not isinstance(replacement, bool):
+            raise ValueError("replacement should be a boolean value, but got "
+                             "replacement={}".format(replacement))
+        self.weights = torch.tensor(weights, dtype=torch.double)
         self.num_samples = num_samples
         self.replacement = replacement

@@ -110,6 +118,17 @@ class BatchSampler(object):
     """

     def __init__(self, sampler, batch_size, drop_last):
+        if not isinstance(sampler, Sampler):
+            raise ValueError("sampler should be an instance of "
+                             "torch.utils.data.Sampler, but got sampler={}"
+                             .format(sampler))
+        if not isinstance(batch_size, _int_classes) or isinstance(batch_size, bool) or \
+                batch_size <= 0:
+            raise ValueError("batch_size should be a positive integer value, "
+                             "but got batch_size={}".format(batch_size))
+        if not isinstance(drop_last, bool):
+            raise ValueError("drop_last should be a boolean value, but got "
+                             "drop_last={}".format(drop_last))
         self.sampler = sampler
         self.batch_size = batch_size
         self.drop_last = drop_last
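
Below is a minimal sketch of how the added checks behave, assuming a PyTorch build that includes this commit; the sampler classes are imported from torch.utils.data.sampler. Valid arguments construct the samplers exactly as before, while a non-positive or non-integer num_samples or batch_size, a non-boolean replacement or drop_last, or a sampler that is not a Sampler instance now raises ValueError at construction time.

# Sketch of the new argument validation, assuming a PyTorch build that
# includes this commit.
from torch.utils.data.sampler import (
    BatchSampler,
    SequentialSampler,
    WeightedRandomSampler,
)

# Valid arguments work as before.
weighted = WeightedRandomSampler(weights=[0.1, 0.9, 0.4], num_samples=5, replacement=True)
batched = BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=False)
print(list(batched))  # [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]

# Invalid arguments now fail fast with ValueError instead of surfacing later.
bad_constructions = [
    lambda: WeightedRandomSampler([0.1, 0.9], num_samples=0),           # not positive
    lambda: WeightedRandomSampler([0.1, 0.9], num_samples=2.5),         # not an integer
    lambda: WeightedRandomSampler([0.1, 0.9], num_samples=True),        # bool rejected even though it subclasses int
    lambda: WeightedRandomSampler([0.1, 0.9], 2, replacement="yes"),    # not a bool
    lambda: BatchSampler(SequentialSampler(range(10)), -1, False),      # batch_size not positive
    lambda: BatchSampler(SequentialSampler(range(10)), 3, drop_last=1), # drop_last not a bool
    lambda: BatchSampler([0, 1, 2], 3, False),                          # sampler is not a Sampler instance (as of this commit)
]
for construct in bad_constructions:
    try:
        construct()
    except ValueError as exc:
        print("ValueError:", exc)

The exact error messages may differ between releases, and later versions may relax some of these checks, so treat this as illustrative rather than a compatibility guarantee.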
