2 changes: 2 additions & 0 deletions build_variables.bzl
@@ -296,10 +296,12 @@ core_sources_full_mobile_no_backend_interface_xplat = [
"torch/csrc/jit/passes/remove_mutation.cpp",
"torch/csrc/jit/passes/prepack_folding.cpp",
"torch/csrc/jit/passes/fold_conv_bn.cpp",
"torch/csrc/jit/passes/fold_linear_bn.cpp",
"torch/csrc/jit/passes/dbr_quantization/remove_redundant_aliases.cpp",
"torch/csrc/jit/passes/frozen_concat_linear.cpp",
"torch/csrc/jit/passes/frozen_conv_add_relu_fusion.cpp",
"torch/csrc/jit/passes/frozen_conv_folding.cpp",
"torch/csrc/jit/passes/frozen_linear_folding.cpp",
"torch/csrc/jit/passes/frozen_linear_transpose.cpp",
"torch/csrc/jit/passes/frozen_ops_to_mkldnn.cpp",
"torch/csrc/jit/passes/frozen_graph_optimizations.cpp",
101 changes: 101 additions & 0 deletions test/jit/test_freezing.py
@@ -2223,6 +2223,107 @@ def forward(self, x):
FileCheck().check("conv").check_not("aten::batch_norm").run(traced_model.graph)
FileCheck().check("conv").check_not("aten::add").run(traced_model.graph)

def test_linear_bn_folding(self):
module_pairs = [(nn.Linear, nn.BatchNorm1d), (nn.Linear, nn.BatchNorm2d), (nn.Linear, nn.BatchNorm3d)]
use_tracing = [True, False]
bn_running_stats = [True, False]

for modules, tracing, track_stats in product(module_pairs, use_tracing, bn_running_stats):
class LinearBN(torch.nn.Module):
def __init__(self, in_features, out_features):
super(LinearBN, self).__init__()
self.linear = modules[0](in_features, out_features)
self.bn = modules[1](out_features, eps=0.001, track_running_stats=track_stats)

def forward(self, x):
x = self.linear(x)
return self.bn(x)

mod_eager = LinearBN(32, 32).eval()

inps = [3, 32]
if modules[1] == nn.BatchNorm2d:
inps.append(inps[-1])
inps.append(inps[-1])
if modules[1] == nn.BatchNorm3d:
inps.append(inps[-1])
inps.append(inps[-1])
inps.append(inps[-1])

inp = torch.rand(inps)

if tracing:
scripted_mod = torch.jit.trace(mod_eager, (inp))
else:
scripted_mod = torch.jit.script(mod_eager)

self.run_pass("inline", scripted_mod.graph)
self.run_pass("peephole", scripted_mod.graph)
self.run_pass("constant_propagation", scripted_mod.graph)

FileCheck().check("linear").check("batch").run(scripted_mod.graph)
# successfully no-ops with non-const inputs
self.run_pass("fold_frozen_linear_bn", scripted_mod.graph)
FileCheck().check("linear").check("aten::batch_norm").run(scripted_mod.graph)

scripted_mod = torch.jit.freeze(scripted_mod)
self.run_pass("fold_frozen_linear_bn", scripted_mod.graph)
if track_stats:
FileCheck().check("linear").check_not("aten::batch_norm").run(scripted_mod.graph)
else:
FileCheck().check("linear").check("aten::batch_norm").run(scripted_mod.graph)

self.assertEqual(mod_eager(inp), scripted_mod(inp))
self.assertEqual(mod_eager(inp), scripted_mod(inp))

@skipCUDAMemoryLeakCheckIf(True)
@unittest.skipIf(not TEST_CUDA, "Optimization currently only run for GPU")
def test_linear_bn_folding_autocast_scenario_cuda(self):
module_pairs = [(nn.Linear, nn.BatchNorm1d), (nn.Linear, nn.BatchNorm2d), (nn.Linear, nn.BatchNorm3d)]
use_tracing = [True, False]
bn_running_stats = [True, False]

for modules, tracing, track_stats in product(module_pairs, use_tracing, bn_running_stats):
class LinearBN(torch.nn.Module):
def __init__(self, in_features, out_features):
super(LinearBN, self).__init__()
self.linear = modules[0](in_features, out_features, bias=False, dtype=torch.half)
self.bn = modules[1](out_features, eps=0.001, dtype=torch.float)

def forward(self, x):
x = self.linear(x)
return self.bn(x)

mod_eager = LinearBN(32, 32).cuda().eval()

inps = [3, 32]
if modules[1] == nn.BatchNorm2d:
inps.append(inps[-1])
inps.append(inps[-1])
if modules[1] == nn.BatchNorm3d:
inps.append(inps[-1])
inps.append(inps[-1])
inps.append(inps[-1])

x = torch.rand(inps, dtype=torch.half).cuda()

if tracing:
scripted_mod = torch.jit.trace(mod_eager, (x))
else:
scripted_mod = torch.jit.script(mod_eager)
scripted_mod = torch.jit.freeze(scripted_mod)
FileCheck().check("linear").check_not("aten::batch_norm").run(scripted_mod.graph)
lin_node = scripted_mod.graph.findNode("aten::linear", True)
self.assertTrue(lin_node is not None)
weight_input = lin_node.namedInput("weight")
bias_input = lin_node.namedInput("bias")
self.assertTrue(bias_input is not None)
self.assertTrue(weight_input.type().dtype() == torch.half)
self.assertTrue(bias_input.type().dtype() == torch.half)

self.assertEqual(mod_eager(x), scripted_mod(x), atol=1e-2, rtol=1e-2)
self.assertEqual(mod_eager(x), scripted_mod(x), atol=1e-2, rtol=1e-2)

@unittest.skipIf(not TEST_CUDA, "Optimization currently only run for GPU")
def test_linear_concat(self):
out_dimms = [[5, 10], [1, 5]]
28 changes: 28 additions & 0 deletions torch/csrc/jit/passes/fold_linear_bn.cpp
@@ -0,0 +1,28 @@
#include <torch/csrc/jit/passes/fold_linear_bn.h>

#include <ATen/TensorOperators.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#else
#include <ATen/ops/rsqrt.h>
#endif

namespace torch {
namespace jit {

std::tuple<at::Tensor, at::Tensor> computeUpdatedLinearWeightAndBias(
const LinearBNParameters& p) {
at::Tensor bn_scale = p.bn_w * at::rsqrt(p.bn_rv + p.bn_eps);
at::Tensor fused_w = p.linear_w * bn_scale.unsqueeze(-1);
at::Tensor fused_b = (p.linear_b - p.bn_rm) * bn_scale + p.bn_b;

auto linear_w_dtype = p.linear_w.dtype();
auto linear_b_dtype = p.linear_b.dtype();

return std::make_tuple(
fused_w.to(linear_w_dtype), fused_b.to(linear_b_dtype));
}

} // namespace jit
} // namespace torch
29 changes: 29 additions & 0 deletions torch/csrc/jit/passes/fold_linear_bn.h
@@ -0,0 +1,29 @@
#pragma once

#include <torch/csrc/jit/api/module.h>

namespace torch {
namespace jit {

struct TORCH_API LinearBNParameters {
Contributor: this is a super minor nit - but from what I can tell the contents of fold_linear_bn.(h|cpp) are essentially implementation details of frozen_linear_folding.cpp, is that correct? if so we can probably move this into the anonymous namespace of frozen_linear_folding.cpp, right? (this is mostly just a matter of preference, tbh)

Collaborator Author: That's correct. Actually, fold_linear_bn.h is mostly copied from fold_conv_bn.h, so I thought I would follow its style. I'm fine with either option -- what do you think?

Contributor: I guess it's fine as it is.

at::Tensor linear_w;
at::Tensor linear_b;
at::Tensor bn_rm;
at::Tensor bn_rv;
double bn_eps = 0.0;
at::Tensor bn_w;
at::Tensor bn_b;
};

/**
* Given the current weight and bias tensors of a Linear module and parameters
* of the BatchNorm module we're folding with, compute the updated values
* for the weight and bias.
*
* The function is basically copied from torch/nn/utils/fusion.py
*/
TORCH_API std::tuple<at::Tensor, at::Tensor> computeUpdatedLinearWeightAndBias(
const LinearBNParameters& p);

} // namespace jit
} // namespace torch
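
For reference, here is a minimal Python sketch of the same folding math (following torch/nn/utils/fusion.py, which the comment above refers to). The module and tensor names are illustrative only, not part of this PR; it simply checks that the fused linear matches linear followed by batch norm in eval mode:

import torch

# A minimal sketch mirroring computeUpdatedLinearWeightAndBias above.
# All names below are illustrative, not part of the PR.
torch.manual_seed(0)
linear = torch.nn.Linear(32, 32).eval()
bn = torch.nn.BatchNorm1d(32, eps=0.001).eval()

# bn_scale = bn_w / sqrt(running_var + eps)
bn_scale = bn.weight * torch.rsqrt(bn.running_var + bn.eps)
fused_w = linear.weight * bn_scale.unsqueeze(-1)  # scale each output row
fused_b = (linear.bias - bn.running_mean) * bn_scale + bn.bias

fused = torch.nn.Linear(32, 32).eval()
with torch.no_grad():
    fused.weight.copy_(fused_w)
    fused.bias.copy_(fused_b)

x = torch.rand(4, 32)
print(torch.allclose(bn(linear(x)), fused(x), atol=1e-6))  # expect True
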
2 changes: 2 additions & 0 deletions torch/csrc/jit/passes/frozen_graph_optimizations.cpp
@@ -4,6 +4,7 @@
#include <torch/csrc/jit/passes/frozen_concat_linear.h>
#include <torch/csrc/jit/passes/frozen_conv_folding.h>
#include <torch/csrc/jit/passes/frozen_graph_optimizations.h>
#include <torch/csrc/jit/passes/frozen_linear_folding.h>
#include <torch/csrc/jit/passes/remove_dropout.h>
#include <torch/csrc/jit/runtime/graph_executor.h>
#include <torch/csrc/utils/memory.h>
@@ -24,6 +25,7 @@ void OptimizeFrozenGraph(
changed |= FoldFrozenConvBatchnorm(graph);
changed |= FoldFrozenConvAddOrSub(graph);
changed |= FoldFrozenConvMulOrDiv(graph);
changed |= FoldFrozenLinearBatchnorm(graph);
} while (changed);
}
}
1 change: 1 addition & 0 deletions torch/csrc/jit/passes/frozen_graph_optimizations.h
@@ -8,6 +8,7 @@
* - FoldFrozenConvBatchnorm
* - FoldFrozenConvAddOrSub
* - FoldFrozenConvMulOrDiv
* - FoldFrozenLinearBatchnorm
*/

namespace torch {
127 changes: 127 additions & 0 deletions torch/csrc/jit/passes/frozen_linear_folding.cpp
@@ -0,0 +1,127 @@
#include <torch/csrc/jit/ir/constants.h>
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/passes/dead_code_elimination.h>
#include <torch/csrc/jit/passes/fold_linear_bn.h>
#include <torch/csrc/jit/passes/frozen_linear_folding.h>
#include <torch/csrc/jit/passes/utils/optimization_utils.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#else
#include <ATen/ops/ones_like.h>
#include <ATen/ops/zeros_like.h>
#endif

namespace torch {
namespace jit {

namespace {

using Tensor = at::Tensor;

bool supportedLinearNode(Node* n) {
if (n->kind() == aten::linear) {
return true;
} else {
return false;
}
}

bool FoldFrozenLinearBatchnorm(Block* b) {
bool graph_modified = false;
for (Node* n : b->nodes()) {
for (Block* block : n->blocks()) {
graph_modified |= FoldFrozenLinearBatchnorm(block);
}

if (n->kind() == aten::batch_norm &&
supportedLinearNode(n->inputs().at(0)->node())) {
auto linear = n->inputs().at(0)->node();
auto bn = n;

if (nonConstantParameters(linear) || nonConstantParameters(bn)) {
continue;
}

auto bn_rm_ivalue = bn->namedInput("running_mean");
auto bn_rv_ivalue = bn->namedInput("running_var");

// check that running_mean and running_var have values; if they are
// None (track_running_stats=False), skip the folding path.
if (bn_rm_ivalue->type() == NoneType::get() &&
bn_rv_ivalue->type() == NoneType::get()) {
continue;
}

auto bn_rm = constant_as<Tensor>(bn->namedInput("running_mean")).value();
auto bn_rv = constant_as<Tensor>(bn->namedInput("running_var")).value();
auto bn_eps = constant_as<double>(bn->namedInput("eps")).value();
auto linear_w = constant_as<Tensor>(linear->namedInput("weight")).value();

// implementation taken from torch/nn/utils/fusion.py
Tensor linear_b;
if (linear->namedInput("bias")->type() == NoneType::get()) {
Contributor: this looks like it's mostly copied from frozen_conv_folding.cpp, is that accurate? recently we had some issues with autocasting, see #77617. Can you add some tests like this to confirm that it's not an issue in this case?

Collaborator Author (min-jean-cho, Nov 3, 2022):

"this looks like it's mostly copied from frozen_conv_folding.cpp, is that accurate?"

That's correct, this is mostly copied from frozen_conv_folding.cpp.

"recently we had some issues with autocasting, see #77617"

Thanks for the notice. I verified that a similar issue occurs with JIT autocasting, for lin-bn1d only and on CUDA only (inputs are only cast to half on CUDA).

The dtype mismatch is caught in addmm with a {2,3}d input tensor, which is the correct usage of bn1d; however, the issue will also occur for scripting (not tracing) with an incorrect usage of bn{2,3}d on a {2,3}d input tensor (the expected input dims are 4d and 5d for bn2d and bn3d respectively). Below, replace nn.BatchNorm1d(out_features) with nn.BatchNorm{2,3}d(out_features).

To reproduce RuntimeError: self and mat2 must have the same dtype:

import torch
import torch.nn as nn

class LinearBN(torch.nn.Module):
    def __init__(self, in_features, out_features):
        super(LinearBN, self).__init__()
        self.linear = nn.Linear(in_features, out_features, bias=False)
        self.bn = nn.BatchNorm1d(out_features)
    
    def forward(self, x):
        x = self.linear(x)
        return self.bn(x)

net = LinearBN(32, 32).cuda().eval()  
print(net.linear.weight.dtype) #torch.float32
img = torch.rand((1, 32, 32)).cuda()
with torch.cuda.amp.autocast(True):
    net = torch.jit.script(net)
    net = torch.jit.freeze(net)
print(net.graph)
lin_node = net.graph.findNode("aten::linear", True)
print(lin_node.namedInput("weight").type().dtype()) #torch.float16
print(lin_node.namedInput("bias").type().dtype()) #torch.float32
net(img)

at::ScalarType bias_dtype = bn_rm.scalar_type();
at::ScalarType weight_dtype = linear_w.scalar_type();
at::DeviceType weight_device = linear_w.device().type();
if (weight_device == at::kCUDA &&
(weight_dtype == at::kHalf || weight_dtype == at::kBFloat16) &&
bias_dtype == at::kFloat) {
bias_dtype = weight_dtype;
}
linear_b = at::zeros_like(bn_rm, at::TensorOptions().dtype(bias_dtype));
} else {
linear_b = constant_as<Tensor>(linear->namedInput("bias")).value();
}
Tensor bn_w;
if (bn->namedInput("weight")->type() == NoneType::get()) {
bn_w = at::ones_like(bn_rm);
} else {
bn_w = constant_as<Tensor>(bn->namedInput("weight")).value();
}
Tensor bn_b;
if (n->namedInput("bias")->type() == NoneType::get()) {
bn_b = at::zeros_like(bn_rm);
} else {
bn_b = constant_as<Tensor>(bn->namedInput("bias")).value();
}

LinearBNParameters params;
params.linear_w = linear_w;
params.linear_b = linear_b;
params.bn_rm = bn_rm;
params.bn_rv = bn_rv;
params.bn_eps = bn_eps;
params.bn_w = bn_w;
params.bn_b = bn_b;
std::tuple<Tensor, Tensor> out =
computeUpdatedLinearWeightAndBias(params);
WithInsertPoint guard(linear);
auto fused_linear_w = b->owningGraph()->insertConstant(std::get<0>(out));
auto fused_linear_b = b->owningGraph()->insertConstant(std::get<1>(out));
auto linear_w_value = linear->namedInput("weight");
auto linear_b_value = linear->namedInput("bias");

fused_linear_w->setDebugName(linear_w_value->debugName() + "_fused_bn");
fused_linear_b->setDebugName(linear_b_value->debugName() + "_fused_bn");

linear->replaceInputWith(linear_w_value, fused_linear_w);
linear->replaceInputWith(linear_b_value, fused_linear_b);

bn->output()->replaceAllUsesWith(linear->output());
graph_modified = true;
}
}
return graph_modified;
}

} // namespace

bool FoldFrozenLinearBatchnorm(std::shared_ptr<Graph>& graph) {
bool graph_modified = FoldFrozenLinearBatchnorm(graph->block());
EliminateDeadCode(graph);
return graph_modified;
}

} // namespace jit
} // namespace torch
14 changes: 14 additions & 0 deletions torch/csrc/jit/passes/frozen_linear_folding.h
@@ -0,0 +1,14 @@
#pragma once

#include <torch/csrc/jit/ir/ir.h>

namespace torch {
namespace jit {

// Fuses Linear -> BatchNormNd into a single Linear by
// folding batchnorm weights into linear weights.
// This pass only works on Frozen Graphs; otherwise it is a No-Op.
TORCH_API bool FoldFrozenLinearBatchnorm(std::shared_ptr<Graph>& graph);

} // namespace jit
} // namespace torch
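
As a usage sketch (the module, sizes, and input below are illustrative, mirroring test_linear_bn_folding above), the pass only has an effect once the graph is frozen, i.e. once the linear and batch norm parameters have become constants:

import torch
import torch.nn as nn

# Illustrative module; not part of the PR.
class LinearBN(nn.Module):
    def __init__(self, n):
        super().__init__()
        self.linear = nn.Linear(n, n)
        self.bn = nn.BatchNorm1d(n)

    def forward(self, x):
        return self.bn(self.linear(x))

mod = torch.jit.script(LinearBN(32).eval())
frozen = torch.jit.freeze(mod)  # OptimizeFrozenGraph now runs the fold as well

# The pass can also be invoked directly on a frozen graph (a no-op if the
# batch norm was already folded during freezing).
torch._C._jit_pass_fold_frozen_linear_bn(frozen.graph)
assert frozen.graph.findNode("aten::batch_norm", True) is None

x = torch.rand(4, 32)
print(torch.allclose(mod(x), frozen(x), atol=1e-5))  # expect True
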
2 changes: 2 additions & 0 deletions torch/csrc/jit/python/init.cpp
@@ -40,6 +40,7 @@
#include <torch/csrc/jit/passes/frozen_conv_add_relu_fusion.h>
#include <torch/csrc/jit/passes/frozen_conv_folding.h>
#include <torch/csrc/jit/passes/frozen_graph_optimizations.h>
#include <torch/csrc/jit/passes/frozen_linear_folding.h>
#include <torch/csrc/jit/passes/frozen_linear_transpose.h>
#include <torch/csrc/jit/passes/frozen_ops_to_mkldnn.h>
#include <torch/csrc/jit/passes/fuse_linear.h>
@@ -399,6 +400,7 @@ void initJITBindings(PyObject* module) {
.def("_jit_pass_fold_frozen_conv_bn", &FoldFrozenConvBatchnorm)
.def("_jit_pass_fold_frozen_conv_add_or_sub", &FoldFrozenConvAddOrSub)
.def("_jit_pass_fold_frozen_conv_mul_or_div", &FoldFrozenConvMulOrDiv)
.def("_jit_pass_fold_frozen_linear_bn", &FoldFrozenLinearBatchnorm)
.def("_jit_pass_convert_frozen_ops_to_mkldnn", &ConvertFrozenOpsToMKLDNN)
.def("_jit_pass_fuse_frozen_conv_add_relu", &FuseFrozenConvAddRelu)
.def("_jit_pass_transpose_frozen_linear", &FrozenLinearTranspose)