
Commit 31b8dc7

Revert "[JIT] Frozen Graph Linear-BatchNormNd Folding (#86706)"
This reverts commit e585156. Reverted #86706 on behalf of https://github.com/davidberard98 due to possibly causing internal build failures; will revert and investigate later.
1 parent 535b0e3 commit 31b8dc7
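
For context, the reverted pass rewrote frozen graphs so that an eval-mode BatchNormNd applied to the output of a Linear is absorbed into the Linear's weight and bias, removing the aten::batch_norm node. Below is a minimal eager-mode PyTorch sketch of that algebra only, not the deleted C++ implementation; the fused module is a made-up name used purely for illustration.

    import torch
    import torch.nn as nn

    linear = nn.Linear(32, 32)
    bn = nn.BatchNorm1d(32, eps=0.001).eval()
    with torch.no_grad():  # give the BN non-trivial affine params and running stats
        bn.weight.uniform_(0.5, 1.5)
        bn.bias.uniform_(-1.0, 1.0)
        bn.running_mean.uniform_(-1.0, 1.0)
        bn.running_var.uniform_(0.5, 1.5)

    # Eval-mode BN after Linear:
    #   y = (W x + b - mean) / sqrt(var + eps) * gamma + beta
    #     = (scale * W) x + (b - mean) * scale + beta,  scale = gamma / sqrt(var + eps)
    scale = bn.weight / torch.sqrt(bn.running_var + bn.eps)
    fused = nn.Linear(32, 32)
    with torch.no_grad():
        fused.weight.copy_(linear.weight * scale.unsqueeze(1))  # scale each output row
        fused.bias.copy_((linear.bias - bn.running_mean) * scale + bn.bias)

    x = torch.rand(3, 32)
    with torch.no_grad():
        assert torch.allclose(bn(linear(x)), fused(x), atol=1e-5)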

9 files changed: 0 additions & 306 deletions


build_variables.bzl

Lines changed: 0 additions & 2 deletions
@@ -296,12 +296,10 @@ core_sources_full_mobile_no_backend_interface_xplat = [
     "torch/csrc/jit/passes/remove_mutation.cpp",
     "torch/csrc/jit/passes/prepack_folding.cpp",
     "torch/csrc/jit/passes/fold_conv_bn.cpp",
-    "torch/csrc/jit/passes/fold_linear_bn.cpp",
     "torch/csrc/jit/passes/dbr_quantization/remove_redundant_aliases.cpp",
     "torch/csrc/jit/passes/frozen_concat_linear.cpp",
     "torch/csrc/jit/passes/frozen_conv_add_relu_fusion.cpp",
     "torch/csrc/jit/passes/frozen_conv_folding.cpp",
-    "torch/csrc/jit/passes/frozen_linear_folding.cpp",
     "torch/csrc/jit/passes/frozen_linear_transpose.cpp",
     "torch/csrc/jit/passes/frozen_ops_to_mkldnn.cpp",
     "torch/csrc/jit/passes/frozen_graph_optimizations.cpp",

test/jit/test_freezing.py

Lines changed: 0 additions & 101 deletions
@@ -2223,107 +2223,6 @@ def forward(self, x):
         FileCheck().check("conv").check_not("aten::batch_norm").run(traced_model.graph)
         FileCheck().check("conv").check_not("aten::add").run(traced_model.graph)
 
-    def test_linear_bn_folding(self):
-        module_pairs = [(nn.Linear, nn.BatchNorm1d), (nn.Linear, nn.BatchNorm2d), (nn.Linear, nn.BatchNorm3d)]
-        use_tracing = [True, False]
-        bn_running_stats = [True, False]
-
-        for modules, tracing, track_stats in product(module_pairs, use_tracing, bn_running_stats):
-            class LinearBN(torch.nn.Module):
-                def __init__(self, in_features, out_features):
-                    super(LinearBN, self).__init__()
-                    self.linear = modules[0](in_features, out_features)
-                    self.bn = modules[1](out_features, eps=0.001, track_running_stats=track_stats)
-
-                def forward(self, x):
-                    x = self.linear(x)
-                    return self.bn(x)
-
-            mod_eager = LinearBN(32, 32).eval()
-
-            inps = [3, 32]
-            if modules[1] == nn.BatchNorm2d:
-                inps.append(inps[-1])
-                inps.append(inps[-1])
-            if modules[1] == nn.BatchNorm3d:
-                inps.append(inps[-1])
-                inps.append(inps[-1])
-                inps.append(inps[-1])
-
-            inp = torch.rand(inps)
-
-            if tracing:
-                scripted_mod = torch.jit.trace(mod_eager, (inp))
-            else:
-                scripted_mod = torch.jit.script(mod_eager)
-
-            self.run_pass("inline", scripted_mod.graph)
-            self.run_pass("peephole", scripted_mod.graph)
-            self.run_pass("constant_propagation", scripted_mod.graph)
-
-            FileCheck().check("linear").check("batch").run(scripted_mod.graph)
-            # successfully no-ops with non-const inputs
-            self.run_pass("fold_frozen_linear_bn", scripted_mod.graph)
-            FileCheck().check("linear").check("aten::batch_norm").run(scripted_mod.graph)
-
-            scripted_mod = torch.jit.freeze(scripted_mod)
-            self.run_pass("fold_frozen_linear_bn", scripted_mod.graph)
-            if track_stats:
-                FileCheck().check("linear").check_not("aten::batch_norm").run(scripted_mod.graph)
-            else:
-                FileCheck().check("linear").check("aten::batch_norm").run(scripted_mod.graph)
-
-            self.assertEqual(mod_eager(inp), scripted_mod(inp))
-            self.assertEqual(mod_eager(inp), scripted_mod(inp))
-
-    @skipCUDAMemoryLeakCheckIf(True)
-    @unittest.skipIf(not TEST_CUDA, "Optimization currently only run for GPU")
-    def test_linear_bn_folding_autocast_scenario_cuda(self):
-        module_pairs = [(nn.Linear, nn.BatchNorm1d), (nn.Linear, nn.BatchNorm2d), (nn.Linear, nn.BatchNorm3d)]
-        use_tracing = [True, False]
-        bn_running_stats = [True, False]
-
-        for modules, tracing, track_stats in product(module_pairs, use_tracing, bn_running_stats):
-            class LinearBN(torch.nn.Module):
-                def __init__(self, in_features, out_features):
-                    super(LinearBN, self).__init__()
-                    self.linear = modules[0](in_features, out_features, bias=False, dtype=torch.half)
-                    self.bn = modules[1](out_features, eps=0.001, dtype=torch.float)
-
-                def forward(self, x):
-                    x = self.linear(x)
-                    return self.bn(x)
-
-            mod_eager = LinearBN(32, 32).cuda().eval()
-
-            inps = [3, 32]
-            if modules[1] == nn.BatchNorm2d:
-                inps.append(inps[-1])
-                inps.append(inps[-1])
-            if modules[1] == nn.BatchNorm3d:
-                inps.append(inps[-1])
-                inps.append(inps[-1])
-                inps.append(inps[-1])
-
-            x = torch.rand(inps, dtype=torch.half).cuda()
-
-            if tracing:
-                scripted_mod = torch.jit.trace(mod_eager, (x))
-            else:
-                scripted_mod = torch.jit.script(mod_eager)
-            scripted_mod = torch.jit.freeze(scripted_mod)
-            FileCheck().check("linear").check_not("aten::batch_norm").run(scripted_mod.graph)
-            lin_node = scripted_mod.graph.findNode("aten::linear", True)
-            self.assertTrue(lin_node is not None)
-            weight_input = lin_node.namedInput("weight")
-            bias_input = lin_node.namedInput("bias")
-            self.assertTrue(bias_input is not None)
-            self.assertTrue(weight_input.type().dtype() == torch.half)
-            self.assertTrue(bias_input.type().dtype() == torch.half)
-
-            self.assertEqual(mod_eager(x), scripted_mod(x), atol=1e-2, rtol=1e-2)
-            self.assertEqual(mod_eager(x), scripted_mod(x), atol=1e-2, rtol=1e-2)
-
     @unittest.skipIf(not TEST_CUDA, "Optimization currently only run for GPU")
     def test_linear_concat(self):
         out_dimms = [[5, 10], [1, 5]]

torch/csrc/jit/passes/fold_linear_bn.cpp

Lines changed: 0 additions & 28 deletions
This file was deleted.

torch/csrc/jit/passes/fold_linear_bn.h

Lines changed: 0 additions & 29 deletions
This file was deleted.

torch/csrc/jit/passes/frozen_graph_optimizations.cpp

Lines changed: 0 additions & 2 deletions
@@ -4,7 +4,6 @@
 #include <torch/csrc/jit/passes/frozen_concat_linear.h>
 #include <torch/csrc/jit/passes/frozen_conv_folding.h>
 #include <torch/csrc/jit/passes/frozen_graph_optimizations.h>
-#include <torch/csrc/jit/passes/frozen_linear_folding.h>
 #include <torch/csrc/jit/passes/remove_dropout.h>
 #include <torch/csrc/jit/runtime/graph_executor.h>
 #include <torch/csrc/utils/memory.h>
@@ -25,7 +24,6 @@ void OptimizeFrozenGraph(
       changed |= FoldFrozenConvBatchnorm(graph);
       changed |= FoldFrozenConvAddOrSub(graph);
       changed |= FoldFrozenConvMulOrDiv(graph);
-      changed |= FoldFrozenLinearBatchnorm(graph);
     } while (changed);
   }
 }
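
Because FoldFrozenLinearBatchnorm is dropped from this fixed-point loop, freezing alone no longer removes a BatchNorm that follows a Linear. A hedged sketch of the entry point, assuming torch.jit.freeze with its default optimize_numerics=True reaches OptimizeFrozenGraph on the frozen graph:

    import torch
    import torch.nn as nn

    model = torch.jit.script(nn.Sequential(nn.Linear(32, 32), nn.BatchNorm1d(32)).eval())
    frozen = torch.jit.freeze(model, optimize_numerics=True)
    # Before this revert (with running stats tracked, the BatchNorm default), the fold
    # removed the aten::batch_norm node during freezing; after the revert it remains,
    # and only the Conv folds shown in the loop above still run.
    print(frozen.graph)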

torch/csrc/jit/passes/frozen_graph_optimizations.h

Lines changed: 0 additions & 1 deletion
@@ -8,7 +8,6 @@
  * - FoldFrozenConvBatchnorm
  * - FoldFrozenConvAddOrSub
  * - FoldFrozenConvMulOrDiv
- * - FoldFrozenLinearBatchnorm
  */
 
 namespace torch {

torch/csrc/jit/passes/frozen_linear_folding.cpp

Lines changed: 0 additions & 127 deletions
This file was deleted.

torch/csrc/jit/passes/frozen_linear_folding.h

Lines changed: 0 additions & 14 deletions
This file was deleted.

torch/csrc/jit/python/init.cpp

Lines changed: 0 additions & 2 deletions
@@ -40,7 +40,6 @@
 #include <torch/csrc/jit/passes/frozen_conv_add_relu_fusion.h>
 #include <torch/csrc/jit/passes/frozen_conv_folding.h>
 #include <torch/csrc/jit/passes/frozen_graph_optimizations.h>
-#include <torch/csrc/jit/passes/frozen_linear_folding.h>
 #include <torch/csrc/jit/passes/frozen_linear_transpose.h>
 #include <torch/csrc/jit/passes/frozen_ops_to_mkldnn.h>
 #include <torch/csrc/jit/passes/fuse_linear.h>
@@ -400,7 +399,6 @@ void initJITBindings(PyObject* module) {
       .def("_jit_pass_fold_frozen_conv_bn", &FoldFrozenConvBatchnorm)
      .def("_jit_pass_fold_frozen_conv_add_or_sub", &FoldFrozenConvAddOrSub)
      .def("_jit_pass_fold_frozen_conv_mul_or_div", &FoldFrozenConvMulOrDiv)
-      .def("_jit_pass_fold_frozen_linear_bn", &FoldFrozenLinearBatchnorm)
      .def("_jit_pass_convert_frozen_ops_to_mkldnn", &ConvertFrozenOpsToMKLDNN)
      .def("_jit_pass_fuse_frozen_conv_add_relu", &FuseFrozenConvAddRelu)
      .def("_jit_pass_transpose_frozen_linear", &FrozenLinearTranspose)

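Prior to this revert, the binding removed above could also be invoked directly on a frozen graph, which is how the deleted test drives it via run_pass("fold_frozen_linear_bn", ...). A hedged sketch, guarded with hasattr since the attribute no longer exists after this commit:

    import torch
    import torch.nn as nn

    mod = torch.jit.script(nn.Sequential(nn.Linear(8, 8), nn.BatchNorm1d(8)).eval())
    mod = torch.jit.freeze(mod, optimize_numerics=False)  # keep the aten::batch_norm node
    if hasattr(torch._C, "_jit_pass_fold_frozen_linear_bn"):
        torch._C._jit_pass_fold_frozen_linear_bn(mod.graph)  # folds BN into the linear's params
    print(mod.graph)
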
0 commit comments
