Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions test/allowlist_for_publicAPI.json
Original file line number Diff line number Diff line change
Expand Up @@ -795,7 +795,7 @@
"prepare_qat",
"propagate_qconfig_",
"qconfig_equals",
"quant_type_to_str",
"_get_quant_type_to_str",
"quantize",
"quantize_dynamic",
"quantize_dynamic_jit",
Expand Down Expand Up @@ -874,7 +874,7 @@
],
"torch.quantization.quant_type": [
"QuantType",
"quant_type_to_str"
"_get_quant_type_to_str"
],
"torch.quantization.quantization_mappings": [
"get_default_compare_output_module_list",
Expand Down
2 changes: 1 addition & 1 deletion test/quantization/ao_migration/test_quantization.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ def test_package_import_quant_type(self):
def test_function_import_quant_type(self):
function_list = [
'QuantType',
'quant_type_to_str',
'_get_quant_type_to_str',
]
self._test_function_import('quant_type', function_list)

Expand Down
4 changes: 2 additions & 2 deletions test/quantization/fx/test_quantize_fx.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,8 @@

from torch.ao.quantization import (
QuantType,
quant_type_to_str,
)
from torch.ao.quantization.quant_type import _get_quant_type_to_str

from torch.ao.quantization import (
QuantStub,
Expand Down Expand Up @@ -2636,7 +2636,7 @@ def forward(self, x):
}

for quant_type in [QuantType.STATIC, QuantType.DYNAMIC]:
key = quant_type_to_str(quant_type)
key = _get_quant_type_to_str(quant_type)
qconfig, quantized_module_class, num_observers = test_configs[key]
qconfig_dict = {"": qconfig}
if key == "static":
Expand Down
1 change: 0 additions & 1 deletion torch/ao/quantization/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,6 @@
"prepare_qat",
"propagate_qconfig_",
"qconfig_equals",
"quant_type_to_str",
"quantize",
"quantize_dynamic",
"quantize_dynamic_jit",
Expand Down
6 changes: 3 additions & 3 deletions torch/ao/quantization/fx/custom_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from torch.ao.quantization import QConfigMapping
from torch.ao.quantization.backend_config import BackendConfig
from torch.ao.quantization.quant_type import QuantType, _quant_type_from_str, quant_type_to_str
from torch.ao.quantization.quant_type import QuantType, _quant_type_from_str, _get_quant_type_to_str


__all__ = [
Expand Down Expand Up @@ -263,7 +263,7 @@ def _make_tuple(key: Any, e: StandaloneModuleConfigEntry):
for quant_type, float_to_observed_mapping in self.float_to_observed_mapping.items():
if FLOAT_TO_OBSERVED_DICT_KEY not in d:
d[FLOAT_TO_OBSERVED_DICT_KEY] = {}
d[FLOAT_TO_OBSERVED_DICT_KEY][quant_type_to_str(quant_type)] = float_to_observed_mapping
d[FLOAT_TO_OBSERVED_DICT_KEY][_get_quant_type_to_str(quant_type)] = float_to_observed_mapping
if len(self.non_traceable_module_names) > 0:
d[NON_TRACEABLE_MODULE_NAME_DICT_KEY] = self.non_traceable_module_names
if len(self.non_traceable_module_classes) > 0:
Expand Down Expand Up @@ -350,7 +350,7 @@ def to_dict(self) -> Dict[str, Any]:
for quant_type, observed_to_quantized_mapping in self.observed_to_quantized_mapping.items():
if OBSERVED_TO_QUANTIZED_DICT_KEY not in d:
d[OBSERVED_TO_QUANTIZED_DICT_KEY] = {}
d[OBSERVED_TO_QUANTIZED_DICT_KEY][quant_type_to_str(quant_type)] = observed_to_quantized_mapping
d[OBSERVED_TO_QUANTIZED_DICT_KEY][_get_quant_type_to_str(quant_type)] = observed_to_quantized_mapping
if len(self.preserved_attributes) > 0:
d[PRESERVED_ATTRIBUTES_DICT_KEY] = self.preserved_attributes
return d
Expand Down
3 changes: 1 addition & 2 deletions torch/ao/quantization/quant_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

__all__ = [
"QuantType",
"quant_type_to_str",
]

# Quantization type (dynamic quantization, static quantization).
Expand All @@ -21,7 +20,7 @@ class QuantType(enum.IntEnum):
}

# TODO: make this private
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: can remove this comment

def quant_type_to_str(quant_type: QuantType) -> str:
def _get_quant_type_to_str(quant_type: QuantType) -> str:
return _quant_type_to_str[quant_type]

def _quant_type_from_str(name: str) -> QuantType:
Expand Down
2 changes: 1 addition & 1 deletion torch/quantization/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def default_eval_fn(model, calib_data):
# Top level API for graph mode quantization on GraphModule(torch.fx)
# 'fuse_fx', 'quantize_fx', # TODO: add quantize_dynamic_fx
# 'prepare_fx', 'prepare_dynamic_fx', 'convert_fx',
'QuantType', 'quant_type_to_str', # quantization type
'QuantType', # quantization type
# custom module APIs
'get_default_static_quant_module_mappings', 'get_static_quant_module_class',
'get_default_dynamic_quant_module_mappings',
Expand Down
2 changes: 1 addition & 1 deletion torch/quantization/quant_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,4 +8,4 @@
"""

from torch.ao.quantization.quant_type import QuantType
from torch.ao.quantization.quant_type import quant_type_to_str
from torch.ao.quantization.quant_type import _get_quant_type_to_str