Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 8 additions & 8 deletions aten/src/ATen/native/cudnn/BatchNorm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -88,10 +88,10 @@ std::tuple<Tensor, Tensor, Tensor> cudnn_batch_norm(
mode = CUDNN_BATCHNORM_PER_ACTIVATION;
} else {
mode = CUDNN_BATCHNORM_SPATIAL;
#if CUDNN_VERSION >= 7003
if(training)
mode = CUDNN_BATCHNORM_SPATIAL_PERSISTENT;
#endif
// TODO: The new CUDNN_BATCHNORM_SPATIAL_PERSISTENT mode was
// introduced in CuDNN 7 for performance optimization, but it results in
// accuracy losses in convolution models such as ResNeXt-101 and
// video R(2+1)D. We will fall back to the normal CUDNN_BATCHNORM_SPATIAL
}

auto output_t = at::empty(input->sizes(), input->options());
Expand Down Expand Up @@ -183,11 +183,11 @@ std::tuple<Tensor, Tensor, Tensor> cudnn_batch_norm_backward(
if (input->dim() == 2) {
mode = CUDNN_BATCHNORM_PER_ACTIVATION;
} else {
#if CUDNN_VERSION >= 7003
mode = CUDNN_BATCHNORM_SPATIAL_PERSISTENT;
#else
// TODO: The new CUDNN_BATCHNORM_SPATIAL_PERSISTENT mode was
// introduced in CuDNN 7 for performance optimization, but it results in
// accuracy losses in convolution models such as ResNeXt-101 and
// video R(2+1)D. We will fall back to the normal CUDNN_BATCHNORM_SPATIAL
mode = CUDNN_BATCHNORM_SPATIAL;
#endif
}

auto grad_input_t = at::empty(input->sizes(), input->options());
Expand Down