Closed
Changes from all commits
29 commits
8f06492
[ao][fx] fixing public v private qconfig_mapping_utils.py
HDCharles Nov 3, 2022
958f96b
Update on "[ao][fx] fixing public v private qconfig_mapping_utils.py"
HDCharles Nov 4, 2022
1c34569
Update on "[ao][fx] fixing public v private qconfig_mapping_utils.py"
HDCharles Nov 4, 2022
ad52cd1
Update on "[ao][fx] fixing public v private qconfig_mapping_utils.py"
HDCharles Nov 5, 2022
321fcea
Update on "[ao][fx] fixing public v private qconfig_mapping_utils.py"
HDCharles Nov 7, 2022
5b37282
Update on "[ao][fx] fixing public v private qconfig_mapping_utils.py"
HDCharles Nov 8, 2022
6ebbceb
Update on "[ao][fx] fixing public v private qconfig_mapping_utils.py"
HDCharles Nov 11, 2022
c4f8180
Update on "[ao][fx] fixing public v private qconfig_mapping_utils.py"
HDCharles Nov 11, 2022
f19125c
Update on "[ao][fx] fixing public v private qconfig_mapping_utils.py"
HDCharles Nov 11, 2022
f72f4d8
Update on "[ao][fx] fixing public v private qconfig_mapping_utils.py"
HDCharles Nov 11, 2022
2ee8f78
Update on "[ao][fx] fixing public v private qconfig_mapping_utils.py"
HDCharles Nov 15, 2022
72d2c0a
Update on "[ao][fx] fixing public v private qconfig_mapping_utils.py"
HDCharles Nov 15, 2022
90db66f
Update on "[ao][fx] fixing public v private qconfig_mapping_utils.py"
HDCharles Nov 15, 2022
6297b08
Update on "[ao][fx] fixing public v private qconfig_mapping_utils.py"
HDCharles Nov 16, 2022
611a23c
Update on "[ao][fx] fixing public v private qconfig_mapping_utils.py"
HDCharles Nov 17, 2022
b242bc9
Update on "[ao][fx] fixing public v private qconfig_mapping_utils.py"
HDCharles Dec 7, 2022
74d8cf9
Update on "[ao][fx] fixing public v private qconfig_mapping_utils.py"
HDCharles Dec 7, 2022
f614d61
Update on "[ao][fx] fixing public v private qconfig_mapping_utils.py"
HDCharles Dec 7, 2022
ee02bf9
Update on "[ao][fx] fixing public v private qconfig_mapping_utils.py"
HDCharles Dec 7, 2022
1de4288
Update on "[ao][fx] fixing public v private qconfig_mapping_utils.py"
HDCharles Dec 8, 2022
5b588b8
Update on "[ao][fx] fixing public v private qconfig_mapping_utils.py"
HDCharles Dec 9, 2022
637cb1f
Update on "[ao][fx] fixing public v private qconfig_mapping_utils.py"
HDCharles Dec 9, 2022
8709942
Update on "[ao][fx] fixing public v private qconfig_mapping_utils.py"
HDCharles Dec 12, 2022
1afe31c
Update on "[ao][fx] fixing public v private qconfig_mapping_utils.py"
HDCharles Dec 13, 2022
96adafb
Update on "[ao][fx] fixing public v private qconfig_mapping_utils.py"
HDCharles Dec 13, 2022
efd2bab
Update on "[ao][fx] fixing public v private qconfig_mapping_utils.py"
HDCharles Dec 13, 2022
f57ae08
Update on "[ao][fx] fixing public v private qconfig_mapping_utils.py"
HDCharles Dec 13, 2022
42bfa11
Update on "[ao][fx] fixing public v private qconfig_mapping_utils.py"
HDCharles Dec 14, 2022
a1b9897
Update on "[ao][fx] fixing public v private qconfig_mapping_utils.py"
HDCharles Dec 15, 2022
16 changes: 8 additions & 8 deletions test/quantization/fx/test_quantize_fx.py
@@ -103,7 +103,7 @@
_get_object_type_qconfig,
_get_module_name_qconfig,
_get_module_name_regex_qconfig,
-    maybe_adjust_qconfig_for_module_name_object_type_order,
+    _maybe_adjust_qconfig_for_module_name_object_type_order,
)

from torch.ao.quantization.fx.pattern_utils import (
@@ -1963,9 +1963,9 @@ def test_qconfig_mapping_set_module_name_object_type_order(self):
self.assertEqual(list(qconfig_mapping.module_name_object_type_order_qconfigs)[1], key2)
self.assertEqual(qconfig_mapping.module_name_object_type_order_qconfigs[key1], qconfig1)
self.assertEqual(qconfig_mapping.module_name_object_type_order_qconfigs[key2], qconfig2)
-        self.assertEqual(maybe_adjust_qconfig_for_module_name_object_type_order(
+        self.assertEqual(_maybe_adjust_qconfig_for_module_name_object_type_order(
qconfig_mapping, "mod1", torch.nn.Linear, 0, None), qconfig1)
-        self.assertEqual(maybe_adjust_qconfig_for_module_name_object_type_order(
+        self.assertEqual(_maybe_adjust_qconfig_for_module_name_object_type_order(
qconfig_mapping, "mod2", torch.nn.ReLU, 1, None), qconfig2)
# Override existing key
qconfig_mapping.set_module_name_object_type_order("mod1", torch.nn.Linear, 0, qconfig3)
@@ -1974,16 +1974,16 @@ def test_qconfig_mapping_set_module_name_object_type_order(self):
self.assertEqual(list(qconfig_mapping.module_name_object_type_order_qconfigs)[1], key2)
self.assertEqual(qconfig_mapping.module_name_object_type_order_qconfigs[key1], qconfig3)
self.assertEqual(qconfig_mapping.module_name_object_type_order_qconfigs[key2], qconfig2)
-        self.assertEqual(maybe_adjust_qconfig_for_module_name_object_type_order(
+        self.assertEqual(_maybe_adjust_qconfig_for_module_name_object_type_order(
qconfig_mapping, "mod1", torch.nn.Linear, 0, None), qconfig3)
-        self.assertEqual(maybe_adjust_qconfig_for_module_name_object_type_order(
+        self.assertEqual(_maybe_adjust_qconfig_for_module_name_object_type_order(
qconfig_mapping, "mod2", torch.nn.ReLU, 1, None), qconfig2)
# No match
-        self.assertEqual(maybe_adjust_qconfig_for_module_name_object_type_order(
+        self.assertEqual(_maybe_adjust_qconfig_for_module_name_object_type_order(
qconfig_mapping, "mod123", torch.nn.Linear, 0, None), None)
-        self.assertEqual(maybe_adjust_qconfig_for_module_name_object_type_order(
+        self.assertEqual(_maybe_adjust_qconfig_for_module_name_object_type_order(
qconfig_mapping, "mod1", torch.nn.Linear, 35, None), None)
-        self.assertEqual(maybe_adjust_qconfig_for_module_name_object_type_order(
+        self.assertEqual(_maybe_adjust_qconfig_for_module_name_object_type_order(
qconfig_mapping, "mod2", torch.nn.Conv2d, 1, None), None)

def _get_qconfig_dict_for_qconfig_mapping_test(self, global_qconfig, qconfig1, qconfig2):
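Note (not part of the diff): a minimal sketch of the lookup behavior these tests exercise, assuming a PyTorch build that already includes the rename. A qconfig registered under a (module path, object type, index) triple is returned on an exact match; any other triple falls through to the fallback argument.

import torch
from torch.ao.quantization import QConfigMapping, get_default_qconfig
from torch.ao.quantization.fx.qconfig_mapping_utils import (
    _maybe_adjust_qconfig_for_module_name_object_type_order,
)

qconfig = get_default_qconfig("fbgemm")
qconfig_mapping = QConfigMapping().set_module_name_object_type_order(
    "mod1", torch.nn.Linear, 0, qconfig)

# An exact (path, type, index) match returns the registered qconfig ...
assert _maybe_adjust_qconfig_for_module_name_object_type_order(
    qconfig_mapping, "mod1", torch.nn.Linear, 0, None) is qconfig
# ... while any non-matching triple falls back to the last argument (here None).
assert _maybe_adjust_qconfig_for_module_name_object_type_order(
    qconfig_mapping, "mod1", torch.nn.Linear, 35, None) is None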
4 changes: 2 additions & 2 deletions torch/ao/ns/_numeric_suite_fx.py
@@ -122,7 +122,7 @@
from torch.ao.quantization.backend_config.utils import get_fusion_pattern_to_root_node_getter
from torch.ao.quantization.backend_config import BackendConfig
from torch.ao.quantization.fx.match_utils import _find_matches
-from torch.ao.quantization.fx.qconfig_mapping_utils import generate_node_name_to_qconfig
+from torch.ao.quantization.fx.qconfig_mapping_utils import _generate_node_name_to_qconfig
from torch.ao.quantization.fx.quantize_handler import _get_pattern_to_quantize_handlers
from torch.ao.quantization.qconfig import QConfigAny
from torch.ao.ns.fx.n_shadows_utils import (
@@ -825,7 +825,7 @@ def prepare_n_shadows_model(
# TODO(future PR): deduplicate repeating entries
list_of_node_name_to_qconfig: List[Dict[str, QConfigAny]] = []
for qconfig_mapping in qconfig_multi_mapping.qconfig_mappings_list:
-        node_name_to_qconfig = generate_node_name_to_qconfig(
+        node_name_to_qconfig = _generate_node_name_to_qconfig(
mt, modules, mt.graph, qconfig_mapping, tracer.node_name_to_scope)
list_of_node_name_to_qconfig.append(node_name_to_qconfig)

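For downstream code that imported the old public name, the migration is the underscore spelling only, as in this hypothetical caller; the underscore marks the helper as internal, with no API stability guarantee.

# Before this PR (name no longer exported):
# from torch.ao.quantization.fx.qconfig_mapping_utils import generate_node_name_to_qconfig
# After this PR (underscore-prefixed, i.e. internal API):
from torch.ao.quantization.fx.qconfig_mapping_utils import _generate_node_name_to_qconfig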
16 changes: 8 additions & 8 deletions torch/ao/quantization/fx/convert.py
@@ -24,10 +24,10 @@
)
from ..qconfig_mapping import QConfigMapping
from .qconfig_mapping_utils import (
-    generate_node_name_to_qconfig,
-    compare_prepare_convert_qconfig_mappings,
-    update_qconfig_for_fusion,
-    is_qconfig_supported_by_dtype_configs,
+    _generate_node_name_to_qconfig,
+    _compare_prepare_convert_qconfig_mappings,
+    _update_qconfig_for_fusion,
+    _is_qconfig_supported_by_dtype_configs,
_update_qconfig_for_qat,
)
from torch.ao.quantization.backend_config.utils import (
@@ -683,7 +683,7 @@ def convert_weighted_module(
# skip converting to reference quantized module if the qconfig is not supported
pattern_to_dtype_configs = get_pattern_to_dtype_configs(backend_config)
dtype_configs = pattern_to_dtype_configs.get(type(original_module), [])
-    if not is_qconfig_supported_by_dtype_configs(qconfig, dtype_configs):
+    if not _is_qconfig_supported_by_dtype_configs(qconfig, dtype_configs):
return

# TODO: rename weight_is_statically_quantized to weight_is_int8_quantized
@@ -920,10 +920,10 @@ def convert(

if model._is_qat:
_update_qconfig_for_qat(qconfig_mapping, {})
-        update_qconfig_for_fusion(model, qconfig_mapping)
+        _update_qconfig_for_fusion(model, qconfig_mapping)

-        compare_prepare_convert_qconfig_mappings(prepare_qconfig_mapping, qconfig_mapping)  # type: ignore[arg-type]
-        convert_node_name_to_qconfig = generate_node_name_to_qconfig(
+        _compare_prepare_convert_qconfig_mappings(prepare_qconfig_mapping, qconfig_mapping)  # type: ignore[arg-type]
+        convert_node_name_to_qconfig = _generate_node_name_to_qconfig(
model, modules_copy, model.graph, qconfig_mapping, node_name_to_scope)
# check the convert_node_name_to_qconfig generated and ensure that
# all the values either match what was set in prepare node_name_to_qconfig
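As context for the convert_weighted_module hunk above, a minimal sketch of the gating pattern, assuming the native backend config; names outside the diff (get_native_backend_config, the fbgemm qconfig) are illustrative, not taken from this PR.

import torch
from torch.ao.quantization import get_default_qconfig
from torch.ao.quantization.backend_config import get_native_backend_config
from torch.ao.quantization.backend_config.utils import get_pattern_to_dtype_configs
from torch.ao.quantization.fx.qconfig_mapping_utils import (
    _is_qconfig_supported_by_dtype_configs,
)

qconfig = get_default_qconfig("fbgemm")

# Look up the dtype configs the backend registers for a module type ...
pattern_to_dtype_configs = get_pattern_to_dtype_configs(get_native_backend_config())
dtype_configs = pattern_to_dtype_configs.get(torch.nn.Linear, [])

# ... and skip reference-module conversion when the qconfig is unsupported,
# mirroring the early return in convert_weighted_module.
if not _is_qconfig_supported_by_dtype_configs(qconfig, dtype_configs):
    print("qconfig not supported by the backend; module left unconverted")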
14 changes: 7 additions & 7 deletions torch/ao/quantization/fx/prepare.py
@@ -24,8 +24,8 @@
QConfigMapping,
)
from .qconfig_mapping_utils import (
-    generate_node_name_to_qconfig,
-    update_qconfig_for_fusion,
+    _generate_node_name_to_qconfig,
+    _update_qconfig_for_fusion,
_get_flattened_qconfig_dict,
_update_qconfig_for_qat,
)
@@ -1497,8 +1497,8 @@ def prepare(
root_node_getter_mapping = \
get_fusion_pattern_to_root_node_getter(backend_config)

-    update_qconfig_for_fusion(model, qconfig_mapping)
-    update_qconfig_for_fusion(model, _equalization_config)
+    _update_qconfig_for_fusion(model, qconfig_mapping)
+    _update_qconfig_for_fusion(model, _equalization_config)
flattened_qconfig_dict = _get_flattened_qconfig_dict(qconfig_mapping)
# TODO: support regex as well
propagate_qconfig_(model, flattened_qconfig_dict, prepare_custom_config.to_dict())
@@ -1517,10 +1517,10 @@
# }
modules = dict(model.named_modules(remove_duplicate=False))

-    # fill node_name_to_qconfig, a map from node name to qconfig, used in find_matches
-    equalization_node_name_to_qconfig = generate_node_name_to_qconfig(
+    # fill node_name_to_qconfig, a map from node name to qconfig, used in _find_matches
+    equalization_node_name_to_qconfig = _generate_node_name_to_qconfig(
model, modules, model.graph, _equalization_config, node_name_to_scope)
-    node_name_to_qconfig = generate_node_name_to_qconfig(model, modules, model.graph, qconfig_mapping, node_name_to_scope)
+    node_name_to_qconfig = _generate_node_name_to_qconfig(model, modules, model.graph, qconfig_mapping, node_name_to_scope)

# match the patterns that will get quantized
standalone_module_names = list(prepare_custom_config.standalone_module_names.keys())
27 changes: 10 additions & 17 deletions torch/ao/quantization/fx/qconfig_mapping_utils.py
@@ -34,19 +34,12 @@
get_default_qat_module_mappings,
)

-# TODO: revisit this list. Many helper methods shouldn't be public
-__all__ = [
-    "check_is_valid_config_dict",
-    "compare_prepare_convert_qconfig_mappings",
-    "generate_node_name_to_qconfig",
-    "is_qconfig_supported_by_dtype_configs",
-    "maybe_adjust_qconfig_for_module_name_object_type_order",
-    "update_qconfig_for_fusion",
-]
+__all__: List[str] = []


-def maybe_adjust_qconfig_for_module_name_object_type_order(
+def _maybe_adjust_qconfig_for_module_name_object_type_order(
qconfig_mapping: QConfigMapping,
cur_module_path: str,
cur_object_type: Callable,
@@ -63,7 +56,7 @@ def maybe_adjust_qconfig_for_module_name_object_type_order(
return fallback_qconfig


-def update_qconfig_for_fusion(model: GraphModule, qconfig_mapping: QConfigMapping):
+def _update_qconfig_for_fusion(model: GraphModule, qconfig_mapping: QConfigMapping):
"""
Update the QConfigMapping to account for fused modules such as LinearReLU.
This assumes the QConfigMapping's attributes have already been converted to OrderedDicts.
@@ -100,7 +93,7 @@ def update_qconfig_for_fusion(model: GraphModule, qconfig_mapping: QConfigMapping):
if fused_qconfig is not None:
object_type_dict[type(maybe_fused_module)] = fused_qconfig

-def generate_node_name_to_qconfig(
+def _generate_node_name_to_qconfig(
root: torch.nn.Module,
modules: Dict[str, torch.nn.Module],
input_graph: Graph,
@@ -137,7 +130,7 @@ def generate_node_name_to_qconfig(
cur_object_type_idx = \
submodule_to_object_type_to_cur_idx[module_path][node.target]
submodule_to_object_type_to_cur_idx[module_path][node.target] += 1
-            qconfig = maybe_adjust_qconfig_for_module_name_object_type_order(
+            qconfig = _maybe_adjust_qconfig_for_module_name_object_type_order(
qconfig_mapping, module_path, node.target, cur_object_type_idx, qconfig)
qconfig_with_device_check = _add_module_to_qconfig_obs_ctr(qconfig, modules.get(node.target, None))

@@ -171,7 +164,7 @@
cur_object_type_idx = \
submodule_to_object_type_to_cur_idx[parent_name][module_type]
submodule_to_object_type_to_cur_idx[parent_name][module_type] += 1
-                qconfig = maybe_adjust_qconfig_for_module_name_object_type_order(
+                qconfig = _maybe_adjust_qconfig_for_module_name_object_type_order(
qconfig_mapping, parent_name, module_type, cur_object_type_idx,
qconfig)
qconfig_with_device_check = _add_module_to_qconfig_obs_ctr(qconfig, modules.get(node.target, None))
@@ -187,7 +180,7 @@
return node_name_to_qconfig


-def check_is_valid_config_dict(config_dict: Any, allowed_keys: Set[str], dict_name: str) -> None:
+def _check_is_valid_config_dict(config_dict: Any, allowed_keys: Set[str], dict_name: str) -> None:
r""" Checks if the given config_dict has the correct keys

Args:
@@ -202,7 +195,7 @@
'\' instead.')


-def compare_prepare_convert_qconfig_mappings(
+def _compare_prepare_convert_qconfig_mappings(
prepare_qconfig_mapping: QConfigMapping,
convert_qconfig_mapping: QConfigMapping):
r""" Compare the qconfig_mapping passed in convert to the one from prepare and check the values
@@ -233,7 +226,7 @@ def compare_prepare_convert_qconfig_mappings(
"Expected convert QConfigMapping to have the same qconfig as prepare for key {} {}; \
prepare: {}; convert: {}".format(dict_names[i], name, prepare_dicts[i][name], convert_dicts[i][name])

-def is_qconfig_supported_by_dtype_configs(qconfig: QConfig, dtype_configs: List[DTypeConfig]):
+def _is_qconfig_supported_by_dtype_configs(qconfig: QConfig, dtype_configs: List[DTypeConfig]):
for dtype_config in dtype_configs:
is_dynamic = dtype_config.is_dynamic
if is_dynamic is None:
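The net effect of this last file's hunks is a standard Python privacy idiom, sketched below on a hypothetical module (illustrative, not code from the PR): an empty __all__ plus underscore-prefixed names.

from typing import List

# With an empty __all__, `from some_module import *` exports nothing.
__all__: List[str] = []

def _internal_helper() -> int:
    # The leading underscore is a convention, not enforcement: the function
    # can still be imported explicitly, but carries no stability guarantee.
    return 42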