
Commit d7bd3b9

ngimel authored and soumith committed
allow cudnn for fp16 batch norm (#4021)
1 parent 7763c6f commit d7bd3b9

File tree

1 file changed (+4, -2 lines)

torch/csrc/autograd/functions/batch_normalization.cpp

Lines changed: 4 additions & 2 deletions
@@ -55,7 +55,8 @@ auto BatchNormForward::apply(const variable_list& inputs) -> variable_list {
   bool use_cudnn = false;
 #ifdef WITH_CUDNN
   use_cudnn = (input.type().isCuda()
-              && input.type().scalarType() != at::kHalf
+              && (input.type().scalarType() != at::kHalf
+              || weight.type().scalarType() == at::kFloat)
               && weight.defined() && bias.defined()
               && input.size(0) <= 131070
               && cudnn_enabled && CUDNN_VERSION >= 5110L);
@@ -115,7 +116,8 @@ auto BatchNormBackward::apply(const variable_list& grad_outputs) -> variable_list {
   bool use_cudnn = false;
 #ifdef WITH_CUDNN
   use_cudnn = (input.type().backend() == at::kCUDA
-              && input.type().scalarType() != at::kHalf
+              && (input.type().scalarType() != at::kHalf
+              || weight.type().scalarType() == at::kFloat)
               && weight.defined() && bias.defined() && training
               && input.size(0) <= 131070
               && cudnn_enabled && CUDNN_VERSION >= 5110L);
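
Before this change, the cuDNN path was rejected outright whenever the input was at::kHalf. The widened condition keeps cuDNN eligible when the input is half precision provided the affine parameters are float, which is what cuDNN's batch norm expects for fp16 data and matches the usual mixed-precision training recipe. A minimal usage sketch of the pattern this enables, written against the current PyTorch Python API rather than the one contemporaneous with this 2017 commit:

    import torch
    import torch.nn as nn

    # fp16 activations through a batch norm whose weight/bias stay fp32:
    # exactly the (input == kHalf, weight == kFloat) case the new check allows.
    bn = nn.BatchNorm2d(64).cuda()                         # parameters remain float32
    x = torch.randn(8, 64, 32, 32, device="cuda").half()  # half-precision input
    y = bn(x)              # forward may now dispatch to cuDNN
    y.sum().backward()     # backward is gated by the same widened check

Keeping the affine parameters and statistics in float avoids accumulating the reductions in half precision, which is why the check requires the weight to be kFloat rather than simply lifting the kHalf restriction.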
