
Commit 700b7db

[ao][fx] fixing public v private qconfig_mapping_utils.py

Pull Request resolved: #88399

Made the following helpers private: _check_is_valid_config_dict, _compare_prepare_convert_qconfig_mappings, _generate_node_name_to_qconfig, _is_qconfig_supported_by_dtype_configs, _maybe_adjust_qconfig_for_module_name_object_type_order, _update_qconfig_for_fusion.

ghstack-source-id: 176062423
Differential Revision: [D41015544](https://our.internmc.facebook.com/intern/diff/D41015544/)
1 parent b8f35ec commit 700b7db

File tree (5 files changed: +35 -42 lines)

  test/quantization/fx/test_quantize_fx.py
  torch/ao/ns/_numeric_suite_fx.py
  torch/ao/quantization/fx/convert.py
  torch/ao/quantization/fx/prepare.py
  torch/ao/quantization/fx/qconfig_mapping_utils.py
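For orientation (this note is not part of the commit), the rename only adds a leading underscore to each helper, so downstream callers in the repo update their imports as sketched below; the sketch assumes a torch build that includes this commit:

```python
# Before this commit, the helpers were imported under their public names, e.g.:
#   from torch.ao.quantization.fx.qconfig_mapping_utils import generate_node_name_to_qconfig
# After this commit, the same helpers are module-private and imported with the underscore prefix:
from torch.ao.quantization.fx.qconfig_mapping_utils import (
    _generate_node_name_to_qconfig,
    _is_qconfig_supported_by_dtype_configs,
    _maybe_adjust_qconfig_for_module_name_object_type_order,
    _update_qconfig_for_fusion,
)
```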

test/quantization/fx/test_quantize_fx.py
Lines changed: 8 additions & 8 deletions

@@ -103,7 +103,7 @@
     _get_object_type_qconfig,
     _get_module_name_qconfig,
     _get_module_name_regex_qconfig,
-    maybe_adjust_qconfig_for_module_name_object_type_order,
+    _maybe_adjust_qconfig_for_module_name_object_type_order,
 )
 
 from torch.ao.quantization.fx.pattern_utils import (
@@ -1959,9 +1959,9 @@ def test_qconfig_mapping_set_module_name_object_type_order(self):
         self.assertEqual(list(qconfig_mapping.module_name_object_type_order_qconfigs)[1], key2)
         self.assertEqual(qconfig_mapping.module_name_object_type_order_qconfigs[key1], qconfig1)
         self.assertEqual(qconfig_mapping.module_name_object_type_order_qconfigs[key2], qconfig2)
-        self.assertEqual(maybe_adjust_qconfig_for_module_name_object_type_order(
+        self.assertEqual(_maybe_adjust_qconfig_for_module_name_object_type_order(
             qconfig_mapping, "mod1", torch.nn.Linear, 0, None), qconfig1)
-        self.assertEqual(maybe_adjust_qconfig_for_module_name_object_type_order(
+        self.assertEqual(_maybe_adjust_qconfig_for_module_name_object_type_order(
             qconfig_mapping, "mod2", torch.nn.ReLU, 1, None), qconfig2)
         # Override existing key
         qconfig_mapping.set_module_name_object_type_order("mod1", torch.nn.Linear, 0, qconfig3)
@@ -1970,16 +1970,16 @@ def test_qconfig_mapping_set_module_name_object_type_order(self):
         self.assertEqual(list(qconfig_mapping.module_name_object_type_order_qconfigs)[1], key2)
         self.assertEqual(qconfig_mapping.module_name_object_type_order_qconfigs[key1], qconfig3)
         self.assertEqual(qconfig_mapping.module_name_object_type_order_qconfigs[key2], qconfig2)
-        self.assertEqual(maybe_adjust_qconfig_for_module_name_object_type_order(
+        self.assertEqual(_maybe_adjust_qconfig_for_module_name_object_type_order(
             qconfig_mapping, "mod1", torch.nn.Linear, 0, None), qconfig3)
-        self.assertEqual(maybe_adjust_qconfig_for_module_name_object_type_order(
+        self.assertEqual(_maybe_adjust_qconfig_for_module_name_object_type_order(
             qconfig_mapping, "mod2", torch.nn.ReLU, 1, None), qconfig2)
         # No match
-        self.assertEqual(maybe_adjust_qconfig_for_module_name_object_type_order(
+        self.assertEqual(_maybe_adjust_qconfig_for_module_name_object_type_order(
             qconfig_mapping, "mod123", torch.nn.Linear, 0, None), None)
-        self.assertEqual(maybe_adjust_qconfig_for_module_name_object_type_order(
+        self.assertEqual(_maybe_adjust_qconfig_for_module_name_object_type_order(
             qconfig_mapping, "mod1", torch.nn.Linear, 35, None), None)
-        self.assertEqual(maybe_adjust_qconfig_for_module_name_object_type_order(
+        self.assertEqual(_maybe_adjust_qconfig_for_module_name_object_type_order(
             qconfig_mapping, "mod2", torch.nn.Conv2d, 1, None), None)
 
     def _get_qconfig_dict_for_qconfig_mapping_test(self, global_qconfig, qconfig1, qconfig2):

torch/ao/ns/_numeric_suite_fx.py
Lines changed: 2 additions & 2 deletions

@@ -122,7 +122,7 @@
 from torch.ao.quantization.backend_config.utils import get_fusion_pattern_to_root_node_getter
 from torch.ao.quantization.backend_config import BackendConfig
 from torch.ao.quantization.fx.match_utils import find_matches
-from torch.ao.quantization.fx.qconfig_mapping_utils import generate_node_name_to_qconfig
+from torch.ao.quantization.fx.qconfig_mapping_utils import _generate_node_name_to_qconfig
 from torch.ao.quantization.fx.quantize_handler import _get_pattern_to_quantize_handlers
 from torch.ao.quantization.qconfig import QConfigAny
 from torch.ao.ns.fx.n_shadows_utils import (
@@ -825,7 +825,7 @@ def prepare_n_shadows_model(
     # TODO(future PR): deduplicate repeating entries
     list_of_node_name_to_qconfig: List[Dict[str, QConfigAny]] = []
     for qconfig_mapping in qconfig_multi_mapping.qconfig_mappings_list:
-        node_name_to_qconfig = generate_node_name_to_qconfig(
+        node_name_to_qconfig = _generate_node_name_to_qconfig(
             mt, modules, mt.graph, qconfig_mapping, tracer.node_name_to_scope)
         list_of_node_name_to_qconfig.append(node_name_to_qconfig)
 

torch/ao/quantization/fx/convert.py
Lines changed: 8 additions & 8 deletions

@@ -24,10 +24,10 @@
 )
 from ..qconfig_mapping import QConfigMapping
 from .qconfig_mapping_utils import (
-    generate_node_name_to_qconfig,
-    compare_prepare_convert_qconfig_mappings,
-    update_qconfig_for_fusion,
-    is_qconfig_supported_by_dtype_configs,
+    _generate_node_name_to_qconfig,
+    _compare_prepare_convert_qconfig_mappings,
+    _update_qconfig_for_fusion,
+    _is_qconfig_supported_by_dtype_configs,
     _update_qconfig_for_qat,
 )
 from torch.ao.quantization.backend_config.utils import (
@@ -691,7 +691,7 @@ def convert_weighted_module(
     # skip converting to reference quantized module if the qconfig is not supported
     pattern_to_dtype_configs = get_pattern_to_dtype_configs(backend_config)
     dtype_configs = pattern_to_dtype_configs.get(type(original_module), [])
-    if not is_qconfig_supported_by_dtype_configs(qconfig, dtype_configs):
+    if not _is_qconfig_supported_by_dtype_configs(qconfig, dtype_configs):
         return
 
     # TODO: rename weight_is_statically_quantized to weight_is_int8_quantized
@@ -928,10 +928,10 @@ def convert(
 
     if model._is_qat:
         _update_qconfig_for_qat(qconfig_mapping, {})
-    update_qconfig_for_fusion(model, qconfig_mapping)
+    _update_qconfig_for_fusion(model, qconfig_mapping)
 
-    compare_prepare_convert_qconfig_mappings(prepare_qconfig_mapping, qconfig_mapping)  # type: ignore[arg-type]
-    convert_node_name_to_qconfig = generate_node_name_to_qconfig(
+    _compare_prepare_convert_qconfig_mappings(prepare_qconfig_mapping, qconfig_mapping)  # type: ignore[arg-type]
+    convert_node_name_to_qconfig = _generate_node_name_to_qconfig(
         model, modules_copy, model.graph, qconfig_mapping, node_name_to_scope)
     # check the convert_node_name_to_qconfig generated and ensure that
     # all the values either match what was set in prepare node_name_to_qconfig

torch/ao/quantization/fx/prepare.py
Lines changed: 7 additions & 7 deletions

@@ -24,8 +24,8 @@
     QConfigMapping,
 )
 from .qconfig_mapping_utils import (
-    generate_node_name_to_qconfig,
-    update_qconfig_for_fusion,
+    _generate_node_name_to_qconfig,
+    _update_qconfig_for_fusion,
     _get_flattened_qconfig_dict,
     _update_qconfig_for_qat,
 )
@@ -1497,8 +1497,8 @@ def prepare(
     root_node_getter_mapping = \
         get_fusion_pattern_to_root_node_getter(backend_config)
 
-    update_qconfig_for_fusion(model, qconfig_mapping)
-    update_qconfig_for_fusion(model, _equalization_config)
+    _update_qconfig_for_fusion(model, qconfig_mapping)
+    _update_qconfig_for_fusion(model, _equalization_config)
     flattened_qconfig_dict = _get_flattened_qconfig_dict(qconfig_mapping)
     # TODO: support regex as well
     propagate_qconfig_(model, flattened_qconfig_dict, prepare_custom_config.to_dict())
@@ -1517,10 +1517,10 @@ def prepare(
     # }
     modules = dict(model.named_modules(remove_duplicate=False))
 
-    # fill node_name_to_qconfig, a map from node name to qconfig, used in find_matches
-    equalization_node_name_to_qconfig = generate_node_name_to_qconfig(
+    # fill node_name_to_qconfig, a map from node name to qconfig, used in _find_matches
+    equalization_node_name_to_qconfig = _generate_node_name_to_qconfig(
         model, modules, model.graph, _equalization_config, node_name_to_scope)
-    node_name_to_qconfig = generate_node_name_to_qconfig(model, modules, model.graph, qconfig_mapping, node_name_to_scope)
+    node_name_to_qconfig = _generate_node_name_to_qconfig(model, modules, model.graph, qconfig_mapping, node_name_to_scope)
 
     # match the patterns that will get quantized
     standalone_module_names = list(prepare_custom_config.standalone_module_names.keys())
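The prepare.py and convert.py edits above touch only internal call sites; the public FX graph mode quantization entry points are unchanged. Below is a minimal sketch (not part of this diff), assuming a torch build with FX graph mode quantization and the fbgemm backend available, of the public flow that exercises these helpers under the hood:

```python
import torch
from torch.ao.quantization import get_default_qconfig_mapping
from torch.ao.quantization.quantize_fx import prepare_fx, convert_fx

class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x):
        return self.linear(x)

example_inputs = (torch.randn(1, 4),)
qconfig_mapping = get_default_qconfig_mapping("fbgemm")

# prepare() calls _update_qconfig_for_fusion and _generate_node_name_to_qconfig
# (see the prepare.py hunks above) while inserting observers.
prepared = prepare_fx(M().eval(), qconfig_mapping, example_inputs)
prepared(*example_inputs)  # calibration pass

# convert() relies on _is_qconfig_supported_by_dtype_configs when deciding whether to
# swap a module to its reference quantized form (see convert_weighted_module above).
quantized = convert_fx(prepared)
```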

torch/ao/quantization/fx/qconfig_mapping_utils.py
Lines changed: 10 additions & 17 deletions

@@ -34,19 +34,12 @@
     get_default_qat_module_mappings,
 )
 
-# TODO: revisit this list. Many helper methods shouldn't be public
-__all__ = [
-    "check_is_valid_config_dict",
-    "compare_prepare_convert_qconfig_mappings",
-    "generate_node_name_to_qconfig",
-    "is_qconfig_supported_by_dtype_configs",
-    "maybe_adjust_qconfig_for_module_name_object_type_order",
-    "update_qconfig_for_fusion",
-]
 
+__all__: List[str] = []
 
 
-def maybe_adjust_qconfig_for_module_name_object_type_order(
+
+def _maybe_adjust_qconfig_for_module_name_object_type_order(
     qconfig_mapping: QConfigMapping,
     cur_module_path: str,
     cur_object_type: Callable,
@@ -63,7 +56,7 @@ def maybe_adjust_qconfig_for_module_name_object_type_order(
     return fallback_qconfig
 
 
-def update_qconfig_for_fusion(model: GraphModule, qconfig_mapping: QConfigMapping):
+def _update_qconfig_for_fusion(model: GraphModule, qconfig_mapping: QConfigMapping):
     """
     Update the QConfigMapping to account for fused modules such as LinearReLU.
     This assumes the QConfigMapping's attributes have already been converted to OrderedDicts.
@@ -100,7 +93,7 @@ def update_qconfig_for_fusion(model: GraphModule, qconfig_mapping: QConfigMappin
         if fused_qconfig is not None:
             object_type_dict[type(maybe_fused_module)] = fused_qconfig
 
-def generate_node_name_to_qconfig(
+def _generate_node_name_to_qconfig(
     root: torch.nn.Module,
     modules: Dict[str, torch.nn.Module],
     input_graph: Graph,
@@ -137,7 +130,7 @@ def generate_node_name_to_qconfig(
             cur_object_type_idx = \
                 submodule_to_object_type_to_cur_idx[module_path][node.target]
             submodule_to_object_type_to_cur_idx[module_path][node.target] += 1
-            qconfig = maybe_adjust_qconfig_for_module_name_object_type_order(
+            qconfig = _maybe_adjust_qconfig_for_module_name_object_type_order(
                 qconfig_mapping, module_path, node.target, cur_object_type_idx, qconfig)
             qconfig_with_device_check = _add_module_to_qconfig_obs_ctr(qconfig, modules.get(node.target, None))
 
@@ -171,7 +164,7 @@ def generate_node_name_to_qconfig(
             cur_object_type_idx = \
                 submodule_to_object_type_to_cur_idx[parent_name][module_type]
             submodule_to_object_type_to_cur_idx[parent_name][module_type] += 1
-            qconfig = maybe_adjust_qconfig_for_module_name_object_type_order(
+            qconfig = _maybe_adjust_qconfig_for_module_name_object_type_order(
                 qconfig_mapping, parent_name, module_type, cur_object_type_idx,
                 qconfig)
             qconfig_with_device_check = _add_module_to_qconfig_obs_ctr(qconfig, modules.get(node.target, None))
@@ -187,7 +180,7 @@ def generate_node_name_to_qconfig(
     return node_name_to_qconfig
 
 
-def check_is_valid_config_dict(config_dict: Any, allowed_keys: Set[str], dict_name: str) -> None:
+def _check_is_valid_config_dict(config_dict: Any, allowed_keys: Set[str], dict_name: str) -> None:
     r""" Checks if the given config_dict has the correct keys
 
     Args:
@@ -202,7 +195,7 @@ def check_is_valid_config_dict(config_dict: Any, allowed_keys: Set[str], dict_na
             '\' instead.')
 
 
-def compare_prepare_convert_qconfig_mappings(
+def _compare_prepare_convert_qconfig_mappings(
     prepare_qconfig_mapping: QConfigMapping,
     convert_qconfig_mapping: QConfigMapping):
     r""" Compare the qconfig_mapping passed in convert to the one from prepare and check the values
@@ -233,7 +226,7 @@ def compare_prepare_convert_qconfig_mappings(
             "Expected convert QConfigMapping to have the same qconfig as prepare for key {} {}; \
             prepare: {}; convert: {}".format(dict_names[i], name, prepare_dicts[i][name], convert_dicts[i][name])
 
-def is_qconfig_supported_by_dtype_configs(qconfig: QConfig, dtype_configs: List[DTypeConfig]):
+def _is_qconfig_supported_by_dtype_configs(qconfig: QConfig, dtype_configs: List[DTypeConfig]):
     for dtype_config in dtype_configs:
         is_dynamic = dtype_config.is_dynamic
         if is_dynamic is None:
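
For illustration only (not part of this diff), a small sketch of how the now-private _maybe_adjust_qconfig_for_module_name_object_type_order resolves a qconfig by (module path, object type, occurrence index), mirroring the updated test above; it assumes a torch build that includes this commit:

```python
import torch
from torch.ao.quantization import get_default_qconfig
from torch.ao.quantization.qconfig_mapping import QConfigMapping
from torch.ao.quantization.fx.qconfig_mapping_utils import (
    _maybe_adjust_qconfig_for_module_name_object_type_order,
)

qconfig = get_default_qconfig("fbgemm")
# Register a qconfig for the first nn.Linear call inside submodule "mod1".
qconfig_mapping = QConfigMapping().set_module_name_object_type_order(
    "mod1", torch.nn.Linear, 0, qconfig)

# A matching (module path, object type, index) triple returns the registered qconfig.
print(_maybe_adjust_qconfig_for_module_name_object_type_order(
    qconfig_mapping, "mod1", torch.nn.Linear, 0, None))
# Any other triple falls back to the last argument (None here), as the test's
# "No match" cases show.
print(_maybe_adjust_qconfig_for_module_name_object_type_order(
    qconfig_mapping, "mod2", torch.nn.Conv2d, 1, None))
```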

0 commit comments
