File tree Expand file tree Collapse file tree 1 file changed +4
-2
lines changed
torch/csrc/autograd/functions Expand file tree Collapse file tree 1 file changed +4
-2
lines changed Original file line number Diff line number Diff line change @@ -55,7 +55,8 @@ auto BatchNormForward::apply(const variable_list& inputs) -> variable_list {
5555 bool use_cudnn = false ;
5656#ifdef WITH_CUDNN
5757 use_cudnn = (input.type ().isCuda ()
58- && input.type ().scalarType () != at::kHalf
58+ && (input.type ().scalarType () != at::kHalf
59+ || weight.type ().scalarType () == at::kFloat )
5960 && weight.defined () && bias.defined ()
6061 && input.size (0 ) <= 131070
6162 && cudnn_enabled && CUDNN_VERSION >= 5110L );
@@ -115,7 +116,8 @@ auto BatchNormBackward::apply(const variable_list& grad_outputs) -> variable_lis
115116 bool use_cudnn = false ;
116117#ifdef WITH_CUDNN
117118 use_cudnn = (input.type ().backend () == at::kCUDA
118- && input.type ().scalarType () != at::kHalf
119+ && (input.type ().scalarType () != at::kHalf
120+ || weight.type ().scalarType () == at::kFloat )
119121 && weight.defined () && bias.defined () && training
120122 && input.size (0 ) <= 131070
121123 && cudnn_enabled && CUDNN_VERSION >= 5110L );
You can’t perform that action at this time.
0 commit comments