77 changes: 77 additions & 0 deletions aten/src/ATen/native/mkldnn/Normalization.cpp
@@ -0,0 +1,77 @@
#include <ATen/ATen.h>
#include <ATen/Config.h>
#include <ATen/NativeFunctions.h>
#include <tuple>

#if !AT_MKLDNN_ENABLED()

namespace at {
namespace native {

std::tuple<Tensor, Tensor, Tensor> mkldnn_batch_norm(
    const Tensor& self,
    const Tensor& weight,
    const Tensor& bias,
    const Tensor& running_mean,
    const Tensor& running_var,
    bool train,
    double momentum,
    double eps) {
  AT_ERROR("mkldnn_batch_norm: ATen not compiled with MKLDNN support");
}

} // namespace native
} // namespace at

#else // AT_MKLDNN_ENABLED

#include <ATen/native/mkldnn/MKLDNNCommon.h>

namespace at {
namespace native {

std::tuple<Tensor, Tensor, Tensor> mkldnn_batch_norm(
    const Tensor& input,
    const Tensor& weight,
    const Tensor& bias,
    const Tensor& running_mean,
    const Tensor& running_var,
    bool train,
    double momentum,
    double eps) {
  ideep::tensor& x = itensor_from_mkldnn(input);
  ideep::tensor& w = itensor_from_mkldnn(weight);
  ideep::tensor& b = itensor_from_mkldnn(bias);
  ideep::tensor& m = itensor_from_mkldnn(running_mean);
  ideep::tensor& v = itensor_from_mkldnn(running_var);

  ideep::tensor y;

  if (train) {
@XiaobingSuper (Collaborator), Apr 15, 2019:

@bddppq , I think there should be a check on whether the running stats are used for training and inference, i.e. whether running_mean and running_var are defined or not. Another suggestion: mkldnn only supports batchnorm2d and batchnorm3d, so you should add some checks when calling mkldnn; perhaps you can see that code as a reference. Thanks!
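
A minimal sketch of the checks being suggested here (the helper name, placement, and messages are illustrative only, not part of this PR; AT_ASSERTM is the assertion macro the patch already uses later in this file):

#include <ATen/ATen.h>

// Illustrative helper that could be called at the top of mkldnn_batch_norm.
static void check_mkldnn_batch_norm_inputs(
    const at::Tensor& input,
    const at::Tensor& running_mean,
    const at::Tensor& running_var,
    bool train) {
  // MKLDNN batch norm covers BatchNorm2d / BatchNorm3d only,
  // i.e. 4-d (NCHW) or 5-d (NCDHW) inputs.
  AT_ASSERTM(input.dim() == 4 || input.dim() == 5,
      "mkldnn_batch_norm: currently mkldnn only supports 2d and 3d batchnorm");
  // The inference path reads the running statistics, so they must be defined.
  if (!train) {
    AT_ASSERTM(running_mean.defined() && running_var.defined(),
        "mkldnn_batch_norm: running_mean and running_var must be defined");
  }
}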

Collaborator:

Would mkldnn throw a nice message saying it's not supported? In this case the input is already an mkldnn tensor, so it's better to fail if something is not supported.

Collaborator:

@dzhulgakov I'm afraid not. :-( How about we add an assertion here to guarantee 2d or 3d batchnorm?

Collaborator:

This training support is incomplete. Derivatives.yaml calls native_batch_norm_backward, and that one doesn't know what to do with mkldnn tensors. Thus I'd say either assert that train=false, or implement native_batch_norm_backward for consistency.
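
For concreteness, the first option could be a single assertion near the top of mkldnn_batch_norm (a sketch only; the change below instead raises an error inside the training branch):

// Sketch only: reject training until native_batch_norm_backward
// can handle MKLDNN tensors.
AT_ASSERTM(!train, "mkldnn_batch_norm: training is not supported yet");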

Contributor (Author):

Makes sense, will add an assert.

    // TODO: support training
    AT_ERROR("mkldnn_batch_norm: mkldnn training is not supported yet.");

    // ideep::tensor saved_mean;
    // ideep::tensor saved_var;
    // ideep::batch_normalization_forward_training::compute<AllocForMKLDNN>(
    //     x, w, b, y, saved_mean, saved_var, m, v, momentum, eps);
    // return std::make_tuple(
    //     new_with_itensor_mkldnn(std::move(y), input.options()),
    //     new_with_itensor_mkldnn(std::move(saved_mean), input.options()),
    //     new_with_itensor_mkldnn(std::move(saved_var), input.options()));
  } else {
    AT_ASSERTM(input.dim() == 4 || input.dim() == 5,
        "mkldnn_batch_norm: currently mkldnn only supports 2d and 3d batchnorm");
    ideep::batch_normalization_forward_inference::compute<AllocForMKLDNN>(
        x, m, v, w, b, y, eps);
    return std::make_tuple(
        new_with_itensor_mkldnn(std::move(y), input.options()),
        new_with_itensor_mkldnn(ideep::tensor{}, input.options()),
        new_with_itensor_mkldnn(ideep::tensor{}, input.options()));
  }
}

} // namespace native
} // namespace at

#endif // AT_MKLDNN_ENABLED
1 change: 1 addition & 0 deletions aten/src/ATen/native/native_functions.yaml
@@ -1307,6 +1307,7 @@
  dispatch:
    CPU: batch_norm_cpu
    CUDA: batch_norm_cuda
    MkldnnCPU: mkldnn_batch_norm

- func: batch_norm_stats(Tensor input, float eps) -> (Tensor, Tensor)
  dispatch:
13 changes: 13 additions & 0 deletions test/test_mkldnn.py
@@ -148,6 +148,19 @@ def test_avg_pool2d(self):
                avg_pool2d(x),
                avg_pool2d(x.to_mkldnn()).to_dense())

    def test_batch_norm2d(self):
        N = torch.randint(3, 10, (1,)).item()
        C = torch.randint(3, 100, (1,)).item()
        x = torch.randn(N, C, 35, 45, dtype=torch.float32) * 10

        # TODO: support training
        for train in [False]:
            bn = torch.nn.BatchNorm2d(C).float().train(train)
            mkldnn_bn = mkldnn_utils.to_mkldnn(copy.deepcopy(bn))
            self.assertEqual(
                bn(x),
                mkldnn_bn(x.to_mkldnn()).to_dense())


if __name__ == '__main__':
    run_tests()