Skip to content

Commit 2e8d2a2

Browse files
andrewor14pytorchmergebot
authored andcommitted
[quant][pt2] Add test for inplace add (#102867)
Summary: This was broken after the recent partitioner refactors. Test Plan: python test/test_quantization.py TestQuantizePT2E.test_qat_inplace_add_relu Differential Revision: D46402378 Pull Request resolved: #102867 Approved by: https://github.com/jerryzh168
1 parent 28f43c7 commit 2e8d2a2

File tree

1 file changed

+22
-0
lines changed

1 file changed

+22
-0
lines changed

test/quantization/pt2e/test_quantize_pt2e.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1208,6 +1208,28 @@ def test_prepare_qat_conv_bn_relu_fusion(self):
12081208
m1, example_inputs, is_per_channel=True, has_relu=True
12091209
)
12101210

1211+
def test_qat_inplace_add_relu(self):
1212+
class M(torch.nn.Module):
1213+
def __init__(self):
1214+
super().__init__()
1215+
self.conv = torch.nn.Conv2d(1, 1, 1)
1216+
self.relu = torch.nn.ReLU(inplace=True)
1217+
1218+
def forward(self, x):
1219+
x0 = x
1220+
x = self.conv(x)
1221+
x += x0
1222+
x = self.relu(x)
1223+
return x
1224+
1225+
example_inputs = (torch.randn(1, 1, 3, 3),)
1226+
self._verify_symmetric_qnnpack_qat_numerics(
1227+
M(), example_inputs, is_per_channel=False, verify_convert=True,
1228+
)
1229+
self._verify_symmetric_qnnpack_qat_numerics(
1230+
M(), example_inputs, is_per_channel=True, verify_convert=True,
1231+
)
1232+
12111233
def test_prepare_qat_conv_bn_fusion_getitem_placeholder(self):
12121234
"""
12131235
Test this special case seen in resnet18:

0 commit comments

Comments
 (0)