-
Notifications
You must be signed in to change notification settings - Fork 26.3k
Add aten mkldnn batch_norm operator #19206
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
cbe728c
d9ce6d8
8de6797
c294bd3
f50385c
0afce4a
1f9ddaf
ca1e90e
28d059f
670eba5
2b93218
4bb28e4
40542eb
0909c90
d606a52
8ea288c
64717b6
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,77 @@ | ||
| #include <ATen/ATen.h> | ||
| #include <ATen/Config.h> | ||
| #include <ATen/NativeFunctions.h> | ||
| #include <tuple> | ||
|
|
||
| #if !AT_MKLDNN_ENABLED() | ||
|
|
||
| namespace at { | ||
| namespace native { | ||
|
|
||
| // Stub for builds where AT_MKLDNN_ENABLED() is false. It has the same | ||
| // signature as the MKLDNN-enabled mkldnn_batch_norm, inspects none of its | ||
| // arguments, and always raises through AT_ERROR. | ||
| std::tuple<Tensor, Tensor, Tensor> mkldnn_batch_norm( | ||
|     const Tensor& input, | ||
|     const Tensor& weight, | ||
|     const Tensor& bias, | ||
|     const Tensor& running_mean, | ||
|     const Tensor& running_var, | ||
|     bool train, | ||
|     double momentum, | ||
|     double eps) { | ||
|   AT_ERROR("mkldnn_batch_norm: ATen not compiled with MKLDNN support"); | ||
| } | ||
|
|
||
| } // namespace native | ||
| } // namespace at | ||
|
|
||
| #else // AT_MKLDNN_ENABLED | ||
|
|
||
| #include <ATen/native/mkldnn/MKLDNNCommon.h> | ||
|
|
||
| namespace at { | ||
| namespace native { | ||
|
|
||
| // MKLDNN implementation of batch normalization; only inference is implemented. | ||
| // | ||
| // Parameters: | ||
| //   input, weight, bias, running_mean, running_var: each is unwrapped with | ||
| //     itensor_from_mkldnn, so all five must be defined tensors that helper accepts. | ||
| //   train: must be false; a true value raises via AT_ERROR. | ||
| //   momentum: not used on the inference path. | ||
| //   eps: forwarded unchanged to the ideep inference kernel. | ||
| // | ||
| // Returns a tuple whose first element wraps the normalized output and whose | ||
| // second and third elements wrap default-constructed ideep tensors. | ||
| // | ||
| // Raises (AT_ERROR / AT_ASSERTM) when train is true, when input is not 4-d or | ||
| // 5-d, or when any of weight/bias/running_mean/running_var is undefined. | ||
| std::tuple<Tensor, Tensor, Tensor> mkldnn_batch_norm( | ||
|     const Tensor& input, | ||
|     const Tensor& weight, | ||
|     const Tensor& bias, | ||
|     const Tensor& running_mean, | ||
|     const Tensor& running_var, | ||
|     bool train, | ||
|     double momentum, | ||
|     double eps) { | ||
|   if (train) { | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This training support is incomplete. Derivatives.yaml calls native_batch_norm_backward and that one doesn't know what to do with mkldnn tensors. Thus I'd say either assert that train=false or implement native_batch_norm_backward for consistency
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Makes senses will add an assert |
||
|     // TODO: support training | ||
|     // Checked first so an unsupported training call fails with this message | ||
|     // before any tensor is unwrapped. | ||
|     AT_ERROR("mkldnn_batch_norm: mkldnn training is not supported in yet."); | ||
|
|
||
|     // Sketch of the intended training path, kept for reference (x, w, b, m, v | ||
|     // and y are the ideep tensors set up on the inference path below): | ||
|     // ideep::tensor saved_mean; | ||
|     // ideep::tensor saved_var; | ||
|     // ideep::batch_normalization_forward_training::compute<AllocForMKLDNN>( | ||
|     // x, w, b, y, saved_mean, saved_var, m, v, momentum, eps); | ||
|     // return std::make_tuple( | ||
|     // new_with_itensor_mkldnn(std::move(y), input.options()), | ||
|     // new_with_itensor_mkldnn(std::move(saved_mean), input.options()), | ||
|     // new_with_itensor_mkldnn(std::move(saved_var), input.options())); | ||
|   } | ||
|
|
||
|   // Validate the arguments before unwrapping so that failures carry a | ||
|   // batch-norm specific message. | ||
|   AT_ASSERTM(input.dim() == 4 || input.dim() == 5, | ||
|       "mkldnn_batch_norm: currently mkldnn only support 2d and 3d batchnorm"); | ||
|   // All four are passed to the inference kernel unconditionally below. | ||
|   AT_ASSERTM( | ||
|       weight.defined() && bias.defined() && | ||
|           running_mean.defined() && running_var.defined(), | ||
|       "mkldnn_batch_norm: weight, bias, running_mean and running_var must be defined"); | ||
|
|
||
|   ideep::tensor& x = itensor_from_mkldnn(input); | ||
|   ideep::tensor& w = itensor_from_mkldnn(weight); | ||
|   ideep::tensor& b = itensor_from_mkldnn(bias); | ||
|   ideep::tensor& m = itensor_from_mkldnn(running_mean); | ||
|   ideep::tensor& v = itensor_from_mkldnn(running_var); | ||
|
|
||
|   ideep::tensor y; | ||
|   ideep::batch_normalization_forward_inference::compute<AllocForMKLDNN>( | ||
|       x, m, v, w, b, y, eps); | ||
|   // No statistics are produced here, so the last two slots wrap | ||
|   // default-constructed ideep tensors. | ||
|   return std::make_tuple( | ||
|       new_with_itensor_mkldnn(std::move(y), input.options()), | ||
|       new_with_itensor_mkldnn(ideep::tensor{}, input.options()), | ||
|       new_with_itensor_mkldnn(ideep::tensor{}, input.options())); | ||
| } | ||
|
|
||
| } // namespace native | ||
| } // namespace at | ||
|
|
||
| #endif // AT_MKLDNN_ENABLED | ||
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@bddppq, I think we should check whether running statistics are used for training and inference, i.e. whether running_mean and running_var are defined. Another suggestion: mkldnn only supports batchnorm2d and batchnorm3d, so you should add some checks before calling mkldnn; perhaps you can see the code as a reference. Thanks!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Would mkldnn throw with a nice message that it's not supported? In this case the input is already an mkldnn tensor, so it's better to fail if something is not supported.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@dzhulgakov I'm afraid not. :-( How about we add an assertion here to guarantee 2d or 3d batchnorm?