Skip to content

Commit 43a2a23

Browse files
jaewoosong
authored and pytorchmergebot committed
Support linear/BN fusion and follow the API guideline (#141585)
Current `fuse` function supports conv/BN fusions only. This commit is to support linear/BN fusion as well. Changes to follow the API guidelines are also applied. (This will close the PR #141352 which I created for the same topic and got approval but had lint and API guideline problems.) Pull Request resolved: #141585 Approved by: https://github.com/ezyang
1 parent 9e299b8 commit 43a2a23

File tree

1 file changed

+28
-7
lines changed

1 file changed

+28
-7
lines changed

torch/fx/experimental/optimization.py

Lines changed: 28 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -14,7 +14,23 @@
1414
import torch.utils.mkldnn as th_mkldnn
1515
from torch.fx.node import Argument, Target
1616
from torch.fx.passes.shape_prop import ShapeProp
17-
from torch.nn.utils.fusion import fuse_conv_bn_eval
17+
from torch.nn.utils.fusion import fuse_conv_bn_eval, fuse_linear_bn_eval
18+
19+
20+
__all__ = [
21+
"matches_module_pattern",
22+
"replace_node_module",
23+
"fuse",
24+
"remove_dropout",
25+
"extract_subgraph",
26+
"modules_to_mkldnn",
27+
"reset_modules",
28+
"MklSubgraph",
29+
"gen_mkl_autotuner",
30+
"use_mkl_length",
31+
"UnionFind",
32+
"optimize_for_inference",
33+
]
1834

1935

2036
def _parent_name(target: str) -> Tuple[str, str]:
@@ -58,13 +74,14 @@ def replace_node_module(
5874

5975
def fuse(model: torch.nn.Module, inplace=False, no_trace=False) -> torch.nn.Module:
6076
"""
61-
Fuses convolution/BN layers for inference purposes. Will deepcopy your
62-
model by default, but can modify the model inplace as well.
77+
Fuses convolution/BN and linear/BN layers for inference purposes.
78+
Will deepcopy your model by default, but can modify the model inplace as well.
6379
"""
6480
patterns = [
6581
(nn.Conv1d, nn.BatchNorm1d),
6682
(nn.Conv2d, nn.BatchNorm2d),
6783
(nn.Conv3d, nn.BatchNorm3d),
84+
(nn.Linear, nn.BatchNorm1d),
6885
]
6986
if not inplace:
7087
model = copy.deepcopy(model)
@@ -78,14 +95,18 @@ def fuse(model: torch.nn.Module, inplace=False, no_trace=False) -> torch.nn.Modu
7895
for pattern in patterns:
7996
for node in new_graph.nodes:
8097
if matches_module_pattern(pattern, node, modules):
81-
if len(node.args[0].users) > 1: # Output of conv is used by other nodes
98+
if len(node.args[0].users) > 1:
99+
# Output of conv/linear is used by other nodes
82100
continue
83-
conv = modules[node.args[0].target]
101+
first_layer = modules[node.args[0].target]
84102
bn = modules[node.target]
85103
if not bn.track_running_stats:
86104
continue
87-
fused_conv = fuse_conv_bn_eval(conv, bn)
88-
replace_node_module(node.args[0], modules, fused_conv)
105+
if pattern[0] in [nn.Conv1d, nn.Conv2d, nn.Conv3d]:
106+
fused_layer = fuse_conv_bn_eval(first_layer, bn)
107+
else: # nn.Linear
108+
fused_layer = fuse_linear_bn_eval(first_layer, bn)
109+
replace_node_module(node.args[0], modules, fused_layer)
89110
node.replace_all_uses_with(node.args[0])
90111
new_graph.erase_node(node)
91112
return fx.GraphModule(fx_model, new_graph)

0 commit comments

Comments
 (0)