-
Notifications
You must be signed in to change notification settings - Fork 26.3k
[quant] Add a quantized batch_norm operator #33080
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Closed
Closed
Changes from all commits
Commits
Show all changes
11 commits
Select commit
Hold shift + click to select a range
8bff502
[quant] Add a quantized batch_norm operator
supriyar bb9553d
Update on "[quant] Add a quantized batch_norm operator"
supriyar 417085a
Update on "[quant] Add a quantized batch_norm operator"
supriyar c341741
Update on "[quant] Add a quantized batch_norm operator"
supriyar 9c3a2f2
Update on "[quant] Add a quantized batch_norm operator"
supriyar 1bc646a
Update on "[quant] Add a quantized batch_norm operator"
supriyar b20ae07
Update on "[quant] Add a quantized batch_norm operator"
supriyar 4df9826
Update on "[quant] Add a quantized batch_norm operator"
supriyar ab77661
Update on "[quant] Add a quantized batch_norm operator"
supriyar 185a18d
Update on "[quant] Add a quantized batch_norm operator"
supriyar 12507a3
Update on "[quant] Add a quantized batch_norm operator"
supriyar File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,164 @@ | ||
| #include <ATen/ATen.h> | ||
| #include <ATen/NativeFunctions.h> | ||
| #include <ATen/Parallel.h> | ||
| #include <ATen/core/op_registration/op_registration.h> | ||
| #include <ATen/native/quantized/cpu/quantized_ops.h> | ||
|
|
||
| #include <algorithm> | ||
| #include <vector> | ||
|
|
||
| namespace at { | ||
| namespace native { | ||
|
|
||
| DEFINE_DISPATCH(qbatch_norm_stub); | ||
|
|
||
| namespace { | ||
| void compute_fused_params( | ||
| const int64_t channels, | ||
| const float* weight_data, | ||
| const float* bias_data, | ||
| const float* mean_data, | ||
| const float* var_data, | ||
| double eps, | ||
| float input_scale, | ||
| float output_scale, | ||
| float* alpha_data, | ||
| float* beta_data) { | ||
| // Batch Normalization | ||
| // output(n, c, h, w) | ||
| // = (input(n, c, h, w) - mean(c)) / sqrt(var(c) + eps) * weight(c) | ||
| // + bias(c) | ||
| // We factor out inv_sigma(c) = 1 / sqrt(var(c) + eps). | ||
| for (int64_t c = 0; c < channels; c++) { | ||
| float inv_sigma = 1.0 / std::sqrt(var_data[c] + static_cast<float>(eps)); | ||
| float weight_v = weight_data ? weight_data[c] : 1; | ||
| float bias_v = bias_data ? bias_data[c] : 0; | ||
| alpha_data[c] = inv_sigma * weight_v * (input_scale / output_scale); | ||
| beta_data[c] = (bias_v - mean_data[c] * inv_sigma * weight_v) / output_scale; | ||
| } | ||
| } | ||
|
|
||
| template <bool ReluFused> | ||
| Tensor q_batch_norm_impl( | ||
| Tensor qx, | ||
| Tensor weight, | ||
| Tensor bias, | ||
| Tensor mean, | ||
| Tensor var, | ||
| double eps, | ||
| float output_scale, | ||
| int64_t output_zero_point) { | ||
|
|
||
| if (qx.numel() == 0) { | ||
| auto out = qx.clone(); | ||
| return out; | ||
| } | ||
| int64_t ndim = qx.dim(); | ||
| TORCH_CHECK(ndim == 4, "Expecting the input tensor of rank 4."); | ||
| const int64_t N = qx.size(0); | ||
| const int64_t C = qx.size(1); | ||
| const int64_t H = qx.size(2); | ||
| const int64_t W = qx.size(3); | ||
|
|
||
| TORCH_CHECK(weight.numel() == C, "Expect weight size to match C"); | ||
| TORCH_CHECK(bias.numel() == C, "Expect weight size to match C"); | ||
|
|
||
| const float* weight_data = weight.template data<float>(); | ||
| const float* bias_data = bias.template data<float>(); | ||
|
|
||
| TORCH_CHECK(mean.numel() == C, "Mean size must match channel dimension"); | ||
| TORCH_CHECK(var.numel() == C, "Variance size must match channel dimension"); | ||
|
|
||
| Tensor alpha = at::empty_like(mean, LEGACY_CONTIGUOUS_MEMORY_FORMAT); | ||
| Tensor beta = at::empty_like(mean, LEGACY_CONTIGUOUS_MEMORY_FORMAT); | ||
| float* alpha_data = alpha.data_ptr<float>(); | ||
| float* beta_data = beta.data_ptr<float>(); | ||
|
|
||
| const float* mean_data = mean.template data<float>(); | ||
| const float* var_data = var.template data<float>(); | ||
|
|
||
| auto oSizes = qx.sizes(); | ||
| auto qx_nhwc = qx.contiguous(MemoryFormat::ChannelsLast); | ||
| Tensor qy = at::_empty_affine_quantized( | ||
| oSizes, | ||
| at::device(kCPU).dtype(qx_nhwc.scalar_type()), | ||
| output_scale, | ||
| output_zero_point, | ||
| MemoryFormat::ChannelsLast); | ||
|
|
||
| compute_fused_params( | ||
| C, | ||
| weight_data, | ||
| bias_data, | ||
| mean_data, | ||
| var_data, | ||
| eps, | ||
| qx.q_scale(), | ||
| output_scale, | ||
| alpha_data, | ||
| beta_data); | ||
|
|
||
| qbatch_norm_stub( | ||
| qx.device().type(), | ||
| N, | ||
| C, | ||
| H * W, | ||
| qx.q_zero_point(), | ||
| output_zero_point, | ||
| qx_nhwc, | ||
| alpha, | ||
| beta, | ||
| qy); | ||
| return qy; | ||
| } | ||
|
|
||
| } // namespace | ||
|
|
||
| Tensor quantized_batch_norm( | ||
| const Tensor& qx, | ||
| const Tensor& weight /* optional */, | ||
| const Tensor& bias /* optional */, | ||
| const Tensor& mean /* optional */, | ||
| const Tensor& var /* optional */, | ||
| double eps, | ||
| double output_scale, | ||
| int64_t output_zero_point) { | ||
| Tensor qy; | ||
| qy = q_batch_norm_impl<false>( | ||
| qx, weight, bias, mean, var, eps, output_scale, output_zero_point); | ||
| return qy; | ||
| } | ||
|
|
||
| // Keep the registry in the anonymous namespace. | ||
| namespace { | ||
| class QBatchNorm2d final : public torch::OperatorKernel { | ||
| public: | ||
| Tensor operator()( | ||
| Tensor qx, | ||
| Tensor weight, | ||
| Tensor bias, | ||
| Tensor mean, | ||
| Tensor var, | ||
| double eps, | ||
| double output_scale, | ||
| int64_t output_zero_point) { | ||
| return q_batch_norm_impl<false>( | ||
| qx, weight, bias, mean, var, eps, output_scale, output_zero_point); | ||
| } | ||
| }; | ||
|
|
||
| static auto registry = torch::RegisterOperators().op( | ||
| "quantized::batch_norm(Tensor qx, " | ||
| "Tensor weight, " | ||
| "Tensor bias, " | ||
| "Tensor mean, " | ||
| "Tensor var, " | ||
| "float eps, " | ||
| "float output_scale, " | ||
| "int output_zero_point) -> Tensor", | ||
| torch::RegisterOperators::options().kernel<QBatchNorm2d>( | ||
| DispatchKey::QuantizedCPUTensorId)); | ||
|
|
||
| } // namespace | ||
| } // namespace native | ||
| } // namespace at |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.