Commit 12c257c

jerryzh168 authored and pytorchmergebot committed
[quant][pt2e] Support allow_implicit_sharing flag (#112929)
Summary: For a node node1 and an edge (node1, node2): since both observe the same Tensor, we may want to implicitly share observers. This flag allows people to turn off that behavior for the output of the node. See the test_allow_implicit_sharing test for a use case.

Test Plan: python test/test_quantization.py TestQuantizePT2E.test_allow_implicit_sharing

Pull Request resolved: #112929
Approved by: https://github.com/kimishpatel
1 parent 625958d commit 12c257c
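
For context, a minimal sketch of how a backend quantizer might use the new flag when annotating a node. Identifiers such as add_node and act_qspec are illustrative (borrowed from the test below); only the allow_implicit_sharing field is new in this PR:

    # Sketch: opt a node out of implicit observer sharing.
    # Assumes add_node is an aten.add.Tensor fx Node and act_qspec is a QuantizationSpec.
    add_node.meta["quantization_annotation"] = QuantizationAnnotation(
        input_qspec_map={
            add_node.args[0]: act_qspec,
            add_node.args[1]: act_qspec,
        },
        output_qspec=act_qspec,
        # keep this node's observers out of any implicitly inferred sharing group
        allow_implicit_sharing=False,
        _annotated=True,
    )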

File tree

5 files changed: +108 −16 lines


test/quantization/pt2e/test_quantize_pt2e.py

Lines changed: 77 additions & 0 deletions
@@ -1060,6 +1060,83 @@ def validate(self, model: torch.fx.GraphModule) -> None:
 
         self._test_transitive_sharing_with_cat_helper(BackendAQuantizer())
 
+    def test_allow_implicit_sharing(self):
+        """This tests the allow_implicit_sharing flag of QuantizationAnnotation, that is,
+        if a node is configured with allow_implicit_sharing=False, we will not have implicit sharing
+        for node and (node, consumer) even when they refer to the same Tensor
+
+        x1 -> add1 -----> add3
+        x2 -/            /
+        x3 -> add2 -----/
+        x4 -/
+
+        all add nodes have shared input and output, and the second input uses a shared quantization
+        spec pointing to the first input, but we set allow_implicit_sharing to False for all add
+        nodes, so the input and output of add1, add2 and add3 will each belong to their own sharing
+        group, and we'll have:
+
+        x1 -> obs1 -> add1 -> obs1 -> obs3 -> add3 -> obs3
+        x2 -> obs1 -/                        /
+        x3 -> obs2 -> add2 -> obs2 -> obs3 -/
+        x4 -> obs2 -/
+        """
+        # TODO: refactor this to a common util
+        class BackendAQuantizer(Quantizer):
+            def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
+                for node in model.graph.nodes:
+                    if node.target is torch.ops.aten.add.Tensor:
+                        add_node = node
+                        first_input_node = add_node.args[0]
+                        second_input_node = add_node.args[1]
+                        input_qspec_map = {}
+                        act_qspec = QuantizationSpec(
+                            dtype=torch.uint8,
+                            quant_min=0,
+                            quant_max=255,
+                            qscheme=torch.per_tensor_affine,
+                            is_dynamic=False,
+                            observer_or_fake_quant_ctr=observer.default_observer,
+                        )
+                        input_qspec_map[second_input_node] = act_qspec
+                        share_qparams_with_input_act1_qspec = SharedQuantizationSpec((second_input_node, add_node))
+                        input_qspec_map[first_input_node] = share_qparams_with_input_act1_qspec
+
+                        add_node.meta[
+                            "quantization_annotation"
+                        ] = QuantizationAnnotation(
+                            input_qspec_map=input_qspec_map,
+                            output_qspec=share_qparams_with_input_act1_qspec,
+                            allow_implicit_sharing=False,
+                            _annotated=True,
+                        )
+
+            def validate(self, model: torch.fx.GraphModule) -> None:
+                pass
+
+        m = TestHelperModules.ThreeAdd().eval()
+        example_inputs = (torch.randn(1, 3, 5, 5), torch.randn(1, 3, 5, 5), torch.randn(1, 3, 5, 5), torch.randn(1, 3, 5, 5))
+
+        # program capture
+        m = capture_pre_autograd_graph(
+            m,
+            example_inputs,
+        )
+        quantizer = BackendAQuantizer()
+        m = prepare_pt2e(m, quantizer)
+        m(*example_inputs)
+        observers = []
+        for n in m.graph.nodes:
+            if n.target == torch.ops.aten.add.Tensor:
+                input_obs1 = getattr(m, n.args[0].target)
+                input_obs2 = getattr(m, n.args[1].target)
+                output_obs = getattr(m, list(n.users)[0].target)
+                self.assertIs(input_obs1, input_obs2)
+                self.assertIs(input_obs1, output_obs)
+                observers.append(input_obs1)
+        assert len(observers) == 3
+        self.assertIsNot(observers[0], observers[1])
+        self.assertIsNot(observers[0], observers[2])
+        self.assertIsNot(observers[1], observers[2])
+
     def test_int16(self):
         class Int16ActQuantizer(Quantizer):
             def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:

test/quantization/pt2e/test_xnnpack_quantizer.py

Lines changed: 1 addition & 1 deletion
@@ -373,7 +373,7 @@ def test_propagate_annotation(self):
         ]:
             input_act = getattr(m, n.args[0].target)
             output_act = getattr(m, list(n.users)[0].target)
-            self.assertTrue(input_act is output_act)
+            self.assertIs(input_act, output_act)
 
         m = convert_pt2e(m, fold_quantize=True)
         node_occurrence = {

torch/ao/quantization/pt2e/prepare.py

Lines changed: 17 additions & 15 deletions
@@ -29,6 +29,7 @@
     "prepare",
 ]
 
+
 def _find_root(edge_or_node: EdgeOrNode, shared_with_map: Dict[EdgeOrNode, EdgeOrNode]) -> EdgeOrNode:
     """Find the root node for the sharing tree
     Args:
@@ -177,21 +178,22 @@ def _get_edge_or_node_to_group_id(edge_or_node_to_qspec: Dict[EdgeOrNode, Quanti
             # find root_qspec for `arg` Node (the output of previous node)
             assert isinstance(input_edge, tuple)
             arg, n = input_edge
-            arg_as_output_root_qspec = None
-            if arg in edge_or_node_to_qspec:
-                arg_as_output_qspec = edge_or_node_to_qspec[arg]
-                arg_as_output_root_qspec = _find_root_qspec(arg_as_output_qspec, edge_or_node_to_qspec, shared_with_map)
-            # TODO: add assertions for types of root qspecs
-            if (
-                arg_as_output_root_qspec is not None and
-                _has_same_dtype(arg_as_output_root_qspec, input_edge_root_qspec) and
-                _has_same_is_dynamic(arg_as_output_root_qspec, input_edge_root_qspec)
-            ):
-                # the input arg to the node should reuse the existing output observer for arg
-                # since dtype is the same (we may want to extend this to be a more strict check
-                # in the future)
-                # so we point from `input_edge` to `arg` (output of the argument)
-                _union(arg, input_edge, shared_with_map)
+            if n.meta["quantization_annotation"].allow_implicit_sharing:
+                arg_as_output_root_qspec = None
+                if arg in edge_or_node_to_qspec:
+                    arg_as_output_qspec = edge_or_node_to_qspec[arg]
+                    arg_as_output_root_qspec = _find_root_qspec(arg_as_output_qspec, edge_or_node_to_qspec, shared_with_map)
+                # TODO: add assertions for types of root qspecs
+                if (
+                    arg_as_output_root_qspec is not None and
+                    _has_same_dtype(arg_as_output_root_qspec, input_edge_root_qspec) and
+                    _has_same_is_dynamic(arg_as_output_root_qspec, input_edge_root_qspec)
+                ):
+                    # the input arg to the node should reuse the existing output observer for arg
+                    # since dtype is the same (we may want to extend this to be a more strict check
+                    # in the future)
+                    # so we point from `input_edge` to `arg` (output of the argument)
+                    _union(arg, input_edge, shared_with_map)
             _update_shared_with(input_edge, qspec, shared_with_map)
 
     # now that we get the sharing relations between all edges and nodes, we can assign group ids
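
Aside: the sharing bookkeeping that _find_root and _union maintain here is a standard union-find over EdgeOrNode keys. A simplified sketch of the pattern (hypothetical helper names, not the exact implementations in prepare.py):

    from typing import Dict, Hashable

    def find_root(key: Hashable, shared_with_map: Dict[Hashable, Hashable]) -> Hashable:
        # Follow parent pointers until reaching a key that maps to itself
        # (or is absent from the map, i.e. is its own group root).
        while shared_with_map.get(key, key) != key:
            key = shared_with_map[key]
        return key

    def union(parent_key: Hashable, child_key: Hashable, shared_with_map: Dict[Hashable, Hashable]) -> None:
        # Merge two sharing groups by pointing the child's root at the parent's root.
        shared_with_map[find_root(child_key, shared_with_map)] = find_root(parent_key, shared_with_map)

With the guard added above, the implicit union between an input edge (arg, n) and arg's output is skipped whenever n's annotation sets allow_implicit_sharing=False; the unconditional _update_shared_with call that handles explicitly requested sharing still runs.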

torch/ao/quantization/quantizer/quantizer.py

Lines changed: 6 additions & 0 deletions
@@ -114,6 +114,7 @@ class SharedQuantizationSpec(QuantizationSpecBase):
     Quantization spec for the Tensors whose quantization parameters are shared with other Tensors
     """
 
+    # the edge or node to share observer or fake quant instances with
     edge_or_node: EdgeOrNode
 
 
@@ -146,6 +147,11 @@ class QuantizationAnnotation:
     # TODO: change the value to QuantizationSpec in a separate PR
     output_qspec: Optional[QuantizationSpecBase] = None
 
+    # For a node node1 and an edge (node1, node2): since both observe the same
+    # Tensor, we may want to implicitly share observers. This flag allows people
+    # to turn off this behavior for the output of the node.
+    allow_implicit_sharing: bool = True
+
     # whether the node is annotated or not
     _annotated: bool = False
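
A note on how the two mechanisms interact: allow_implicit_sharing only gates the sharing that the prepare pass infers automatically, while explicit sharing is still requested via SharedQuantizationSpec, as in the test above (identifiers borrowed from test_allow_implicit_sharing):

    # Explicitly share quantization parameters with the (second_input_node, add_node)
    # input edge, regardless of the allow_implicit_sharing setting.
    share_qparams_with_input_act1_qspec = SharedQuantizationSpec((second_input_node, add_node))
    input_qspec_map[first_input_node] = share_qparams_with_input_act1_qspec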

torch/testing/_internal/common_quantization.py

Lines changed: 7 additions & 0 deletions
@@ -2585,6 +2585,13 @@ def forward(self, x1, x2, x3, x4):
         w = torch.cat([z, y])
         return w
 
+class ThreeAdd(torch.nn.Module):
+    def forward(self, x1, x2, x3, x4):
+        y = x1 + x2
+        z = x3 + x4
+        w = y + z
+        return w
+
 class EmbeddingModule(torch.nn.Module):
     def __init__(self):
         super().__init__()
