
Commit 6cf9ed4

jerryzh168 authored and facebook-github-bot committed
ConvBn2d/ConvBnReLU2d (#23357)
Summary: Added _intrinsic.qat.ConvBn2d and _intrinsic.qat.ConvBnReLU2d.
Pull Request resolved: #23357
ghstack-source-id: 87519573
Differential Revision: D16295500
fbshipit-source-id: 81e6d1d10d05bf6e343721fc5701d3d6bd7e07e6
1 parent 029c8e7 commit 6cf9ed4

File tree

9 files changed: +412, -60 lines changed
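For orientation, here is a minimal usage sketch of the new fused QAT modules. It is not part of the commit; the channel counts, input shape, and hyperparameters are illustrative assumptions, and the constructor argument order follows the call exercised in test/test_qat.py below.

# Usage sketch (not part of this commit); sizes are illustrative only.
import torch
from torch.nn._intrinsic.qat import ConvBn2d, ConvBnReLU2d
from torch.quantization.QConfig import default_qat_qconfig

qat_conv_bn = ConvBn2d(
    3, 16,                            # in_channels, out_channels
    (3, 3),                           # kernel size
    (1, 1),                           # stride
    (1, 1),                           # padding
    (1, 1),                           # dilation
    1,                                # groups
    'zeros',                          # padding_mode
    1e-5,                             # BatchNorm eps
    0.1,                              # BatchNorm momentum
    False,                            # freeze_bn
    default_qat_qconfig.activation,   # activation fake-quant from the default QAT qconfig
    default_qat_qconfig.weight        # weight fake-quant from the default QAT qconfig
)
out = qat_conv_bn(torch.randn(2, 3, 32, 32))  # forward pass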

test/common_quantization.py

Lines changed: 3 additions & 3 deletions
@@ -228,7 +228,7 @@ def forward(self, x):
 class SubModForFusion(torch.nn.Module):
     def __init__(self):
         super(SubModForFusion, self).__init__()
-        self.conv = torch.nn.Conv2d(20, 20, 1)
+        self.conv = torch.nn.Conv2d(20, 20, 1, bias=None)
         self.bn = torch.nn.BatchNorm2d(20)

     def forward(self, x):
@@ -239,9 +239,9 @@ def forward(self, x):
 class ModForFusion(torch.nn.Module):
     def __init__(self):
         super(ModForFusion, self).__init__()
-        self.conv1 = torch.nn.Conv2d(10, 20, 5)
+        self.conv1 = torch.nn.Conv2d(10, 20, 5, bias=None)
         self.bn1 = torch.nn.BatchNorm2d(20)
-        self.relu1 = torch.nn.ReLU()
+        self.relu1 = torch.nn.ReLU(inplace=False)
         self.sub1 = SubModForFusion()
         self.sub2 = SubModForFusion()
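
The bias=None change above reflects how Conv+BatchNorm fusion folds BN into the conv: the BN shift becomes the folded bias, so the conv itself is expected to be bias-free. The following sketch of that standard folding (the helper name is hypothetical, not from this commit) matches the frozen-BN reference used in test_qat.py below.

import torch

def fold_conv_bn_weights(conv_w, bn_gamma, bn_beta, bn_mean, bn_var, bn_eps):
    # Per-output-channel scale from the BatchNorm statistics.
    scale = bn_gamma / torch.sqrt(bn_var + bn_eps)
    # Scale each conv filter; the BN shift becomes the folded bias
    # (assuming the conv has no bias of its own, hence bias=None above).
    folded_w = conv_w * scale.reshape(-1, 1, 1, 1)
    folded_b = bn_beta - bn_mean * scale
    return folded_w, folded_b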

test/test_qat.py

Lines changed: 134 additions & 0 deletions
@@ -0,0 +1,134 @@ (new file)
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

import torch
from torch.nn import Conv2d, BatchNorm2d, ReLU
from torch.nn._intrinsic.qat import ConvBn2d, ConvBnReLU2d
from torch.quantization.QConfig import default_qat_qconfig
from torch.nn import Parameter
from common_utils import TestCase, run_tests
from hypothesis import given
from hypothesis import strategies as st
from functools import reduce


class IntrinsicQATModuleTest(TestCase):

    @given(batch_size=st.integers(1, 3),
           input_channels_per_group=st.sampled_from([2, 4, 5, 8, 16, 32]),
           height=st.integers(10, 16),
           width=st.integers(7, 14),
           output_channels_per_group=st.sampled_from([2, 4, 5, 8, 16, 32]),
           groups=st.integers(1, 3),
           kernel_h=st.integers(1, 7),
           kernel_w=st.integers(1, 7),
           stride_h=st.integers(1, 2),
           stride_w=st.integers(1, 2),
           pad_h=st.integers(0, 2),
           pad_w=st.integers(0, 2),
           dilation=st.integers(1, 1),
           padding_mode=st.sampled_from(['zeros', 'circular']),
           use_relu=st.booleans(),
           eps=st.sampled_from([1e-5, 1e-4, 1e-3, 0.01, 0.1]),
           momentum=st.sampled_from([0.1, 0.2, 0.3]),
           freeze_bn=st.booleans())
    def test_conv_bn_relu(
            self,
            batch_size,
            input_channels_per_group,
            height,
            width,
            output_channels_per_group,
            groups,
            kernel_h,
            kernel_w,
            stride_h,
            stride_w,
            pad_h,
            pad_w,
            dilation,
            padding_mode,
            use_relu,
            eps,
            momentum,
            freeze_bn
    ):
        input_channels = input_channels_per_group * groups
        output_channels = output_channels_per_group * groups
        dilation_h = dilation_w = dilation

        conv_op = Conv2d(
            input_channels,
            output_channels,
            (kernel_h, kernel_w),
            (stride_h, stride_w),
            (pad_h, pad_w),
            (dilation_h, dilation_w),
            groups,
            False,  # No bias
            padding_mode
        ).to(dtype=torch.float)
        bn_op = BatchNorm2d(output_channels, eps, momentum).to(dtype=torch.float)
        relu_op = ReLU()

        cls = ConvBnReLU2d if use_relu else ConvBn2d
        qat_op = cls(
            input_channels,
            output_channels,
            (kernel_h, kernel_w),
            (stride_h, stride_w),
            (pad_h, pad_w),
            (dilation_h, dilation_w),
            groups,
            padding_mode,
            eps,
            momentum,
            freeze_bn,
            default_qat_qconfig.activation,
            default_qat_qconfig.weight
        ).to(dtype=torch.float).disable_fake_quant()

        # align inputs and internal parameters
        input = torch.randn(batch_size, input_channels, height, width, dtype=torch.float)
        input.requires_grad_()
        conv_op.weight = Parameter(qat_op.weight)
        bn_op.running_mean = qat_op.running_mean
        bn_op.running_var = qat_op.running_var
        bn_op.weight = qat_op.gamma
        bn_op.bias = qat_op.beta

        def compose(functions):
            # functions are reversed for natural reading order
            return reduce(lambda f, g: lambda x: f(g(x)), functions[::-1], lambda x: x)

        if not use_relu:
            def relu_op(x):
                return x

        if freeze_bn:
            def ref_op(x):
                x = conv_op(x)
                x = (x - bn_op.running_mean.reshape([1, -1, 1, 1])) * \
                    (bn_op.weight / torch.sqrt(bn_op.running_var + bn_op.eps)) \
                    .reshape([1, -1, 1, 1]) + bn_op.bias.reshape([1, -1, 1, 1])
                x = relu_op(x)
                return x
        else:
            ref_op = compose([conv_op, bn_op, relu_op])

        result_ref = ref_op(input)
        result_actual = qat_op(input)
        self.assertEqual(result_ref, result_actual)

        # backward
        dout = torch.randn(result_ref.size(), dtype=torch.float)
        result_actual.backward(dout, retain_graph=True)
        grad_ref = input.grad.cpu()
        result_actual.backward(dout)
        grad_actual = input.grad.cpu()
        self.assertEqual(grad_ref, grad_actual)

if __name__ == '__main__':
    run_tests()
Lines changed: 3 additions & 1 deletion
@@ -1,9 +1,11 @@
 from __future__ import absolute_import, division, print_function, unicode_literals

 from .linear_relu import LinearReLU
-from .conv_relu import ConvReLU2d
+from .conv_fused import ConvBn2d, ConvBnReLU2d, ConvReLU2d

 __all__ = [
     'LinearReLU',
     'ConvReLU2d',
+    'ConvBn2d',
+    'ConvBnReLU2d'
 ]
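
After this change, the fused QAT modules are exported alongside the existing ConvReLU2d. A quick import check (a sketch, not part of the commit):

# Sketch only: mirrors the import used in test_qat.py above.
from torch.nn._intrinsic.qat import ConvBn2d, ConvBnReLU2d
print(ConvBn2d, ConvBnReLU2d)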
