|
1 | 1 | import torch |
| 2 | +from torch._six import int_classes as _int_classes |
2 | 3 |
|
3 | 4 |
|
4 | 5 | class Sampler(object): |
@@ -82,7 +83,14 @@ class WeightedRandomSampler(Sampler): |
82 | 83 | """ |
83 | 84 |
|
84 | 85 | def __init__(self, weights, num_samples, replacement=True): |
85 | | - self.weights = torch.DoubleTensor(weights) |
| 86 | + if not isinstance(num_samples, _int_classes) or isinstance(num_samples, bool) or \ |
| 87 | + num_samples <= 0: |
| 88 | + raise ValueError("num_samples should be a positive integeral " |
| 89 | + "value, but got num_samples={}".format(num_samples)) |
| 90 | + if not isinstance(replacement, bool): |
| 91 | + raise ValueError("replacement should be a boolean value, but got" |
| 92 | + "got replacement={}".format(replacement)) |
| 93 | + self.weights = torch.tensor(weights, dtype=torch.double) |
86 | 94 | self.num_samples = num_samples |
87 | 95 | self.replacement = replacement |
88 | 96 |
|
@@ -110,6 +118,17 @@ class BatchSampler(object): |
110 | 118 | """ |
111 | 119 |
|
112 | 120 | def __init__(self, sampler, batch_size, drop_last): |
| 121 | + if not isinstance(sampler, Sampler): |
| 122 | + raise ValueError("sampler should be an instance of " |
| 123 | + "torch.utils.data.Sampler, but got sampler={}" |
| 124 | + .format(sampler)) |
| 125 | + if not isinstance(batch_size, _int_classes) or isinstance(batch_size, bool) or \ |
| 126 | + batch_size <= 0: |
| 127 | + raise ValueError("batch_size should be a positive integeral value, " |
| 128 | + "but got batch_size={}".format(batch_size)) |
| 129 | + if not isinstance(drop_last, bool): |
| 130 | + raise ValueError("drop_last should be a boolean value, but got" |
| 131 | + "got drop_last={}".format(drop_last)) |
113 | 132 | self.sampler = sampler |
114 | 133 | self.batch_size = batch_size |
115 | 134 | self.drop_last = drop_last |
|
0 commit comments