Skip to content

Commit 6493e0d

Browse files
committed
improve conv input shape check
1 parent 2821d96 commit 6493e0d

File tree

2 files changed

+13
-12
lines changed

2 files changed

+13
-12
lines changed

aten/src/ATen/native/Convolution.cpp

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -141,14 +141,15 @@ auto ConvParams::is_depthwise(
141141
static void check_input_shape_forward(const at::Tensor& input,
142142
const at::Tensor& weight, const at::Tensor& bias,
143143
int64_t groups, bool transposed) {
144-
int k = input.ndimension();
144+
int64_t k = input.ndimension();
145+
int64_t weight_dim = weight.ndimension();
145146

146-
if (weight.ndimension() != k) {
147-
std::stringstream ss;
148-
ss << "Expected " << k << "-dimensional input for " << k
149-
<< "-dimensional weight " << weight.sizes() << ", but got input of size "
150-
<< input.sizes() << " instead";
151-
throw std::runtime_error(ss.str());
147+
if (weight_dim != k) {
148+
std::stringstream ss;
149+
ss << "Expected " << weight_dim << "-dimensional input for " << weight_dim
150+
<< "-dimensional weight " << weight.sizes() << ", but got input of size "
151+
<< input.sizes() << " instead";
152+
throw std::runtime_error(ss.str());
152153
}
153154
if (weight.size(0) < groups) {
154155
std::stringstream ss;
@@ -266,10 +267,10 @@ at::Tensor convolution(
266267
}
267268

268269
static inline std::vector<int64_t> convolution_expand_param_if_needed(
269-
IntList list_param, const char *param_name, size_t expected_dim) {
270+
IntList list_param, const char *param_name, int64_t expected_dim) {
270271
if (list_param.size() == 1) {
271272
return std::vector<int64_t>(expected_dim, list_param[0]);
272-
} else if (list_param.size() != expected_dim) {
273+
} else if ((int64_t) list_param.size() != expected_dim) {
273274
std::ostringstream ss;
274275
ss << "expected " << param_name << " to be a single integer value or a "
275276
<< "list of " << expected_dim << " values to match the convolution "

torch/utils/data/dataloader.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -313,9 +313,9 @@ def _shutdown_workers(self):
313313
# done_event should be sufficient to exit worker_manager_thread, but
314314
# be safe here and put another None
315315
self.worker_result_queue.put(None)
316-
if self.worker_pids_set:
317-
_remove_worker_pids(id(self))
318-
self.worker_pids_set = False
316+
if self.worker_pids_set:
317+
_remove_worker_pids(id(self))
318+
self.worker_pids_set = False
319319

320320
def __del__(self):
321321
if self.num_workers > 0:

0 commit comments

Comments
 (0)