[JIT] Frozen Graph Linear-BatchNormNd Folding #86706
Changes from all commits: 4699b58, 76472af, 823f5c4, d3753c4, ef4d634, 879df37
**torch/csrc/jit/passes/fold_linear_bn.cpp** (new file, +28 lines)

```cpp
#include <torch/csrc/jit/passes/fold_linear_bn.h>

#include <ATen/TensorOperators.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#else
#include <ATen/ops/rsqrt.h>
#endif

namespace torch {
namespace jit {

std::tuple<at::Tensor, at::Tensor> computeUpdatedLinearWeightAndBias(
    const LinearBNParameters& p) {
  at::Tensor bn_scale = p.bn_w * at::rsqrt(p.bn_rv + p.bn_eps);
  at::Tensor fused_w = p.linear_w * bn_scale.unsqueeze(-1);
  at::Tensor fused_b = (p.linear_b - p.bn_rm) * bn_scale + p.bn_b;

  auto linear_w_dtype = p.linear_w.dtype();
  auto linear_b_dtype = p.linear_b.dtype();

  return std::make_tuple(
      fused_w.to(linear_w_dtype), fused_b.to(linear_b_dtype));
}

} // namespace jit
} // namespace torch
```
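For reference, here is a small Python sketch of the same folding math (mirroring torch/nn/utils/fusion.py, which the header comment cites as the source); the layer sizes and input are arbitrary, so treat it as an illustration rather than part of the PR.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Fold an eval-mode BatchNorm1d into the preceding Linear, using the same
# math as computeUpdatedLinearWeightAndBias above.
linear = nn.Linear(8, 4)
bn = nn.BatchNorm1d(4).eval()  # eval mode: BN applies its running statistics

with torch.no_grad():
    bn_scale = bn.weight * torch.rsqrt(bn.running_var + bn.eps)
    fused_w = linear.weight * bn_scale.unsqueeze(-1)
    fused_b = (linear.bias - bn.running_mean) * bn_scale + bn.bias

    x = torch.randn(2, 8)
    # The fused Linear reproduces Linear -> BatchNorm up to rounding error.
    print(torch.allclose(bn(linear(x)), F.linear(x, fused_w, fused_b), atol=1e-6))  # True
```

The identity holds because an eval-mode BatchNorm is just a per-channel affine transform, so it can be absorbed into the Linear's weight and bias.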
**torch/csrc/jit/passes/fold_linear_bn.h** (new file, +29 lines)

```cpp
#pragma once

#include <torch/csrc/jit/api/module.h>

namespace torch {
namespace jit {

struct TORCH_API LinearBNParameters {
  at::Tensor linear_w;
  at::Tensor linear_b;
  at::Tensor bn_rm;
  at::Tensor bn_rv;
  double bn_eps = 0.0;
  at::Tensor bn_w;
  at::Tensor bn_b;
};

/**
 * Given the current weight and bias tensors of a Linear module and parameters
 * of the BatchNorm module we're folding with, compute the updated values
 * for the weight and bias.
 *
 * The function is basically copied from torch/nn/utils/fusion.py
 */
TORCH_API std::tuple<at::Tensor, at::Tensor> computeUpdatedLinearWeightAndBias(
    const LinearBNParameters& p);

} // namespace jit
} // namespace torch
```
**torch/csrc/jit/passes/frozen_linear_folding.cpp** (new file, +127 lines)

```cpp
#include <torch/csrc/jit/ir/constants.h>
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/passes/dead_code_elimination.h>
#include <torch/csrc/jit/passes/fold_linear_bn.h>
#include <torch/csrc/jit/passes/frozen_linear_folding.h>
#include <torch/csrc/jit/passes/utils/optimization_utils.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#else
#include <ATen/ops/ones_like.h>
#include <ATen/ops/zeros_like.h>
#endif

namespace torch {
namespace jit {

namespace {

using Tensor = at::Tensor;

bool supportedLinearNode(Node* n) {
  if (n->kind() == aten::linear) {
    return true;
  } else {
    return false;
  }
}

bool FoldFrozenLinearBatchnorm(Block* b) {
  bool graph_modified = false;
  for (Node* n : b->nodes()) {
    for (Block* block : n->blocks()) {
      graph_modified |= FoldFrozenLinearBatchnorm(block);
    }

    if (n->kind() == aten::batch_norm &&
        supportedLinearNode(n->inputs().at(0)->node())) {
      auto linear = n->inputs().at(0)->node();
      auto bn = n;

      if (nonConstantParameters(linear) || nonConstantParameters(bn)) {
        continue;
      }

      auto bn_rm_ivalue = bn->namedInput("running_mean");
      auto bn_rv_ivalue = bn->namedInput("running_var");

      // Check that running_mean and running_var have a value; if they are
      // None (track_running_stats=False), skip the folding path.
      if (bn_rm_ivalue->type() == NoneType::get() &&
          bn_rv_ivalue->type() == NoneType::get()) {
        continue;
      }

      auto bn_rm = constant_as<Tensor>(bn->namedInput("running_mean")).value();
      auto bn_rv = constant_as<Tensor>(bn->namedInput("running_var")).value();
      auto bn_eps = constant_as<double>(bn->namedInput("eps")).value();
      auto linear_w = constant_as<Tensor>(linear->namedInput("weight")).value();
```
**Contributor:**
This looks like it's mostly copied from frozen_conv_folding.cpp, is that accurate? Recently we had some issues with autocasting, see #77617. Can you add some tests like this to confirm that it's not an issue in this case?

**Collaborator (Author):**
That's correct, this is mostly copied from frozen_conv_folding.cpp.

Thanks for the notice. I verified that a similar issue occurs with JIT autocasting, for Linear-BatchNorm1d only and on CUDA only (inputs are only cast to half on CUDA). The dtype mismatch is raised in addmm with a {2,3}-d input tensor, which is the correct usage of BatchNorm1d; the issue will also occur when scripting (not tracing) an incorrect usage of BatchNorm{2,3}d with a {2,3}-d input tensor (incorrect because the expected input dims are 4-d and 5-d for BatchNorm2d and BatchNorm3d, respectively).
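The sketch below is a hedged illustration (not the author's repro; the shapes and dtypes are made up) of why the hunk that continues below creates the synthetic zero bias in the weight's dtype when the Linear has no bias.

```python
import torch

# Scenario: the frozen Linear weight constant is half precision (e.g. a model
# run in reduced precision on CUDA), the Linear has no bias, and the BatchNorm
# running stats are float32.
linear_w = torch.randn(4, 8, dtype=torch.half)   # fp16 weight
bn_rm = torch.zeros(4)                           # fp32 running_mean
bn_rv = torch.ones(4)                            # fp32 running_var
bn_w, bn_b, eps = torch.ones(4), torch.zeros(4), 1e-5

bn_scale = bn_w * torch.rsqrt(bn_rv + eps)

# Naive choice: synthesize the missing bias in the running-stat dtype (fp32).
# The fused bias then stays fp32 while the weight is fp16, and the fused
# aten::linear hits a dtype mismatch in addmm at runtime.
naive_b = torch.zeros_like(bn_rm)
fused_b_naive = ((naive_b - bn_rm) * bn_scale + bn_b).to(naive_b.dtype)
print(fused_b_naive.dtype)  # torch.float32 -- does not match the fp16 weight

# Choice made in the hunk below: create the zero bias in the weight's dtype
# instead, so the fused bias comes out fp16 and matches the fused weight.
fixed_b = torch.zeros_like(bn_rm, dtype=linear_w.dtype)
fused_b_fixed = ((fixed_b - bn_rm) * bn_scale + bn_b).to(fixed_b.dtype)
print(fused_b_fixed.dtype)  # torch.float16 -- matches
```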
*frozen_linear_folding.cpp, continued:*

```cpp
      // implementation taken from torch/nn/utils/fusion.py
      Tensor linear_b;
      if (linear->namedInput("bias")->type() == NoneType::get()) {
        at::ScalarType bias_dtype = bn_rm.scalar_type();
        at::ScalarType weight_dtype = linear_w.scalar_type();
        at::DeviceType weight_device = linear_w.device().type();
        if (weight_device == at::kCUDA &&
            (weight_dtype == at::kHalf || weight_dtype == at::kBFloat16) &&
            bias_dtype == at::kFloat) {
          bias_dtype = weight_dtype;
        }
        linear_b = at::zeros_like(bn_rm, at::TensorOptions().dtype(bias_dtype));
      } else {
        linear_b = constant_as<Tensor>(linear->namedInput("bias")).value();
      }
      Tensor bn_w;
      if (bn->namedInput("weight")->type() == NoneType::get()) {
        bn_w = at::ones_like(bn_rm);
      } else {
        bn_w = constant_as<Tensor>(bn->namedInput("weight")).value();
      }
      Tensor bn_b;
      if (n->namedInput("bias")->type() == NoneType::get()) {
        bn_b = at::zeros_like(bn_rm);
      } else {
        bn_b = constant_as<Tensor>(bn->namedInput("bias")).value();
      }

      LinearBNParameters params;
      params.linear_w = linear_w;
      params.linear_b = linear_b;
      params.bn_rm = bn_rm;
      params.bn_rv = bn_rv;
      params.bn_eps = bn_eps;
      params.bn_w = bn_w;
      params.bn_b = bn_b;
      std::tuple<Tensor, Tensor> out =
          computeUpdatedLinearWeightAndBias(params);
      WithInsertPoint guard(linear);
      auto fused_linear_w = b->owningGraph()->insertConstant(std::get<0>(out));
      auto fused_linear_b = b->owningGraph()->insertConstant(std::get<1>(out));
      auto linear_w_value = linear->namedInput("weight");
      auto linear_b_value = linear->namedInput("bias");

      fused_linear_w->setDebugName(linear_w_value->debugName() + "_fused_bn");
      fused_linear_b->setDebugName(linear_b_value->debugName() + "_fused_bn");

      linear->replaceInputWith(linear_w_value, fused_linear_w);
      linear->replaceInputWith(linear_b_value, fused_linear_b);

      bn->output()->replaceAllUsesWith(linear->output());
      graph_modified = true;
    }
  }
  return graph_modified;
}

} // namespace

bool FoldFrozenLinearBatchnorm(std::shared_ptr<Graph>& graph) {
  bool graph_modified = FoldFrozenLinearBatchnorm(graph->block());
  EliminateDeadCode(graph);
  return graph_modified;
}

} // namespace jit
} // namespace torch
```
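End to end, the pass is meant to run on frozen graphs. Below is a hedged usage sketch, assuming (as with the existing conv-bn folding) that this pass is invoked by the frozen-graph optimizations that torch.jit.freeze applies; the module is invented for illustration.

```python
import torch
import torch.nn as nn

class LinearBN(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(16, 32)
        self.bn = nn.BatchNorm1d(32)

    def forward(self, x):
        return self.bn(self.linear(x))

mod = torch.jit.script(LinearBN().eval())
frozen = torch.jit.freeze(mod)

x = torch.randn(4, 16)
# Folding should not change numerics beyond rounding error...
assert torch.allclose(frozen(x), mod(x), atol=1e-5)
# ...and if the pass ran, the frozen graph no longer contains a batch_norm node.
print("aten::batch_norm" in str(frozen.graph))  # expected: False
```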
**torch/csrc/jit/passes/frozen_linear_folding.h** (new file, +14 lines)

```cpp
#pragma once

#include <torch/csrc/jit/ir/ir.h>

namespace torch {
namespace jit {

// Fuses Linear -> BatchNormNd into a single Linear by
// folding batchnorm weights into linear weights.
// This pass only works on Frozen Graphs; otherwise it is a No-Op.
TORCH_API bool FoldFrozenLinearBatchnorm(std::shared_ptr<Graph>& graph);

} // namespace jit
} // namespace torch
```
**Reviewer:**
This is a super minor nit, but from what I can tell the contents of fold_linear_bn.(h|cpp) are essentially implementation details of frozen_linear_folding.cpp, is that correct? If so, we can probably move this into the anonymous namespace of frozen_linear_folding.cpp, right? (This is mostly just a matter of preference, tbh.)

**Author:**
That's correct. Actually, fold_linear_bn.h is mostly copied from fold_conv_bn.h, so I thought I would follow its style. I'm fine with either option -- what do you think?

**Reviewer:**
I guess it's fine as it is.