Skip to content

Commit 561bc09

Browse files
xw285cornellfacebook-github-bot
authored andcommitted
Remove CUDNN_BATCHNORM_SPATIAL_PERSISTENT mode for accuracy (#13844)
Summary: Pull Request resolved: #13844 In S163230, we've found CuDNN 7 upgrade causes accuracy drop in training convolution network such as ResNeXt-101 (~0% accuracy), and video R(2+1)D (65 --> 63%). We've fixed this in Caffe2 D9601217, and we should do the same to ATen as well. Reviewed By: ezyang Differential Revision: D13025486 fbshipit-source-id: 04f4f0d9af6287b0400ca1842fb2cdac1f8cdb70
1 parent 0d2762e commit 561bc09

File tree

1 file changed

+8
-8
lines changed

1 file changed

+8
-8
lines changed

aten/src/ATen/native/cudnn/BatchNorm.cpp

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -88,10 +88,10 @@ std::tuple<Tensor, Tensor, Tensor> cudnn_batch_norm(
8888
mode = CUDNN_BATCHNORM_PER_ACTIVATION;
8989
} else {
9090
mode = CUDNN_BATCHNORM_SPATIAL;
91-
#if CUDNN_VERSION >= 7003
92-
if(training)
93-
mode = CUDNN_BATCHNORM_SPATIAL_PERSISTENT;
94-
#endif
91+
// TODO: The new CUDNN_BATCHNORM_SPATIAL_PERSISTENT mode was
92+
// introduced in CuDNN 7 for performance optimization, but it results in
93+
// accuracy losses in convolution models such as ResNeXt-101 and
94+
// video R(2+1)D. We will fall back to the normal CUDNN_BATCHNORM_SPATIAL
9595
}
9696

9797
auto output_t = at::empty(input->sizes(), input->options());
@@ -183,11 +183,11 @@ std::tuple<Tensor, Tensor, Tensor> cudnn_batch_norm_backward(
183183
if (input->dim() == 2) {
184184
mode = CUDNN_BATCHNORM_PER_ACTIVATION;
185185
} else {
186-
#if CUDNN_VERSION >= 7003
187-
mode = CUDNN_BATCHNORM_SPATIAL_PERSISTENT;
188-
#else
186+
// TODO: The new CUDNN_BATCHNORM_SPATIAL_PERSISTENT mode was
187+
// introduced in CuDNN 7 for performance optimization, but it results in
188+
// accuracy losses in convolution models such as ResNeXt-101 and
189+
// video R(2+1)D. We will fall back to the normal CUDNN_BATCHNORM_SPATIAL
189190
mode = CUDNN_BATCHNORM_SPATIAL;
190-
#endif
191191
}
192192

193193
auto grad_input_t = at::empty(input->sizes(), input->options());

0 commit comments

Comments
 (0)