4 changes: 2 additions & 2 deletions test/allowlist_for_publicAPI.json
@@ -786,7 +786,7 @@
 "get_quantized_operator",
 "get_static_quant_module_class",
 "get_unique_devices_",
-"is_activation_post_process",
+"_is_activation_post_process",
 "load_observer_state_dict",
 "no_observer_set",
 "prepare",
@@ -894,7 +894,7 @@
 "convert",
 "get_observer_dict",
 "get_unique_devices_",
-"is_activation_post_process",
+"_is_activation_post_process",
 "prepare",
 "prepare_qat",
 "propagate_qconfig_",
2 changes: 1 addition & 1 deletion test/quantization/ao_migration/test_ao_migration.py
@@ -19,7 +19,7 @@ def test_function_import_quantize(self):
 'convert',
 'get_observer_dict',
 'get_unique_devices_',
-'is_activation_post_process',
+'_is_activation_post_process',
 'prepare',
 'prepare_qat',
 'propagate_qconfig_',
2 changes: 1 addition & 1 deletion test/quantization/ao_migration/test_quantization.py
@@ -22,7 +22,7 @@ def test_function_import_quantize(self):
 'convert',
 'get_observer_dict',
 'get_unique_devices_',
-'is_activation_post_process',
+'_is_activation_post_process',
 'prepare',
 'prepare_qat',
 'propagate_qconfig_',
6 changes: 3 additions & 3 deletions test/quantization/fx/test_quantize_fx.py
@@ -56,7 +56,6 @@
 get_default_qat_qconfig,
 get_default_qconfig_mapping,
 get_default_qat_qconfig_mapping,
-is_activation_post_process,
 fuse_modules,
 fuse_modules_qat,
 prepare,
@@ -146,6 +145,7 @@
 default_fixed_qparams_range_0to1_observer,
 default_fixed_qparams_range_neg1to1_observer,
 MinMaxObserver,
+_is_activation_post_process,
 )

 # test utils
@@ -3292,7 +3292,7 @@ def _check_node_not_observed(model, arg_node, node):
 _check_node_not_observed(model, new_node, node)
 elif arg_node.op == "call_module":
 self.assertTrue(
-not is_activation_post_process(getattr(model, arg_node.target)),
+not _is_activation_post_process(getattr(model, arg_node.target)),
 "Arg: {0} of node: {1} is observed but is not a float tensor".format(
 arg_node, node
 ),
@@ -5051,7 +5051,7 @@ def forward(self, x):
 qconfig_dict = func(backend)
 m = prepare_fx(m, qconfig_dict, example_inputs=(torch.randn(1, 1, 1, 1)))
 for name, mod in m.named_modules():
-if is_activation_post_process(mod) and mod.dtype == torch.quint8:
+if _is_activation_post_process(mod) and mod.dtype == torch.quint8:
 if backend == "fbgemm":
 lower_bnd = 0
 upper_bnd = 127
4 changes: 2 additions & 2 deletions torch/ao/ns/fx/graph_passes.py
@@ -24,7 +24,7 @@
 from torch.ao.ns.fx.mappings import (
 get_node_type_to_io_type_map,
 )
-from torch.ao.quantization.quantize import is_activation_post_process
+from torch.ao.quantization.observer import _is_activation_post_process

 from typing import Dict, Tuple, Callable, List, Any, Union, Optional, Set

@@ -38,7 +38,7 @@ def _maybe_get_fqn(node: Node, gm: GraphModule) -> Optional[str]:
 if node.op == 'call_module':
 assert isinstance(node.target, str)
 module = getattr_from_fqn(gm, node.target)
-if is_activation_post_process(module):
+if _is_activation_post_process(module):
 node_to_use_for_fqn = get_normalized_nth_input(node, gm, 0)
 fqn = gm._node_name_to_scope[node_to_use_for_fqn.name][0] # type: ignore[index]
 return fqn # type: ignore[return-value]
6 changes: 3 additions & 3 deletions torch/ao/ns/fx/utils.py
@@ -16,7 +16,7 @@
 FakeQuantizeBase,
 )
 from torch.ao.quantization.utils import getattr_from_fqn
-from torch.ao.quantization.quantize import is_activation_post_process
+from torch.ao.quantization.observer import _is_activation_post_process

 from .ns_types import NSNodeTargetType, NSResultsType

@@ -256,14 +256,14 @@ def return_first_non_observer_node(
 """
 if node.op == "call_module":
 node_obj = getattr_from_fqn(gm, node.target) # type: ignore[arg-type]
-if is_activation_post_process(node_obj):
+if _is_activation_post_process(node_obj):
 assert len(node.args) == 1
 assert isinstance(node.args[0], Node)
 node = node.args[0]
 # code duplication intended, not worth refactoring
 assert isinstance(node.target, str)
 node_obj = getattr_from_fqn(gm, node.target)
-if is_activation_post_process(node_obj):
+if _is_activation_post_process(node_obj):
 assert len(node.args) == 1
 assert isinstance(node.args[0], Node)
 node = node.args[0]
1 change: 0 additions & 1 deletion torch/ao/quantization/__init__.py
@@ -114,7 +114,6 @@
 "get_quantized_operator",
 "get_static_quant_module_class",
 "get_unique_devices_",
-"is_activation_post_process",
 "load_observer_state_dict",
 "no_observer_set",
 "per_channel_weight_observer_range_neg_127_to_127",
4 changes: 2 additions & 2 deletions torch/ao/quantization/fx/_model_report/detector.py
@@ -23,7 +23,7 @@
 default_equalization_qconfig,
 EqualizationQConfig,
 )
-from torch.ao.quantization.quantize import is_activation_post_process
+from torch.ao.quantization.observer import _is_activation_post_process

 # Names for observer insert keys
 DETECTOR_TARGET_NODE_KEY = "target_node"
@@ -1273,7 +1273,7 @@ def _supports_insertion(self, module: nn.Module) -> bool:
 # case for insertion of module
 # check if the module has any children and isn't observer
 num_children = len(list(module.children()))
-return num_children == 0 and not is_activation_post_process(module)
+return num_children == 0 and not _is_activation_post_process(module)

 def get_qconfig_info(self, model) -> Dict[str, DetectorQConfigInfo]:
 r""" Returns the DetectorQConfigInfo for each module_fqn relavent
6 changes: 3 additions & 3 deletions torch/ao/quantization/fx/convert.py
@@ -40,6 +40,7 @@
 BackendConfig,
 get_native_backend_config,
 )
+from torch.ao.quantization.observer import _is_activation_post_process
 from .graph_module import (
 QuantizedGraphModule,
 is_observed_module,
@@ -62,7 +63,6 @@
 )
 from torch.ao.quantization.quantize import (
 _remove_qconfig,
-is_activation_post_process,
 )
 from torch.ao.quantization.stubs import DeQuantStub
 from .custom_config import (
@@ -582,7 +582,7 @@ def maybe_get_observer_for_node(
 for maybe_obs_node, _ in node.users.items():
 if maybe_obs_node.op == 'call_module':
 maybe_obs = modules[str(maybe_obs_node.target)]
-if is_activation_post_process(maybe_obs):
+if _is_activation_post_process(maybe_obs):
 return maybe_obs
 return None

@@ -1010,7 +1010,7 @@ def convert(
 elif node.op == "call_module":
 mod = _get_module(node, modules)
 assert mod is not None
-if is_activation_post_process(mod):
+if _is_activation_post_process(mod):
 observed_node = node.args[0]
 if observed_node in statically_quantized_custom_module_nodes:
 _replace_observer_or_dequant_stub_with_dequantize_node(node, model.graph)
4 changes: 2 additions & 2 deletions torch/ao/quantization/fx/prepare.py
@@ -16,6 +16,7 @@
 )
 from ..observer import (
 ObserverBase,
+_is_activation_post_process
 )
 from ..qconfig import (
 _is_reuse_input_qconfig,
@@ -78,7 +79,6 @@
 )

 from torch.ao.quantization.quantize import (
-is_activation_post_process,
 convert
 )

@@ -144,7 +144,7 @@
 
 def is_activation_post_process_node(node: Node, modules: Dict[str, torch.nn.Module]) -> bool:
 return isinstance(node, torch.fx.Node) and node.op == "call_module" and \
-is_activation_post_process(modules[str(node.target)])
+_is_activation_post_process(modules[str(node.target)])

 def is_input_arg_dtype_supported_by_backend(
 arg: Argument,
6 changes: 3 additions & 3 deletions torch/ao/quantization/fx/qconfig_mapping_utils.py
@@ -4,8 +4,8 @@
 from typing import Callable, Any, Dict, Tuple, Set, List, Union
 from torch.ao.quantization import QConfig
 from torch.ao.quantization.qconfig import _add_module_to_qconfig_obs_ctr, QConfigAny, qconfig_equals
-from torch.ao.quantization.quantize import (
-is_activation_post_process,
+from torch.ao.quantization.observer import (
+_is_activation_post_process,
 )
 from torch.ao.quantization.backend_config import (
 DTypeConfig,
@@ -158,7 +158,7 @@ def generate_node_name_to_qconfig(

 elif node.op == 'call_module':
 # if the node is an observer, just continue - don't add it to the qconfig_map
-if is_activation_post_process(modules[node.target]):
+if _is_activation_post_process(modules[node.target]):
 continue
 qconfig = _maybe_adjust_qconfig_for_module_type_or_name(
 qconfig_mapping, type(modules[node.target]), node.target, global_qconfig)
10 changes: 6 additions & 4 deletions torch/ao/quantization/fx/utils.py
@@ -23,8 +23,10 @@
 qconfig_equals,
 )
 from torch.ao.quantization.stubs import DeQuantStub
-from torch.ao.quantization.utils import activation_is_statically_quantized
-from torch.ao.quantization.quantize import is_activation_post_process
+from torch.ao.quantization.utils import (
+activation_is_statically_quantized,
+)
+from torch.ao.quantization.observer import _is_activation_post_process

 from torch.fx import GraphModule, map_arg

@@ -248,7 +250,7 @@ def all_node_args_have_no_tensors(node: Node, modules: Dict[str, torch.nn.Module
 result = False
 elif node.op == 'call_module':
 assert isinstance(node.target, str)
-if is_activation_post_process(modules[node.target]):
+if _is_activation_post_process(modules[node.target]):
 result = all_node_args_have_no_tensors(node.args[0], modules, cache) # type: ignore[arg-type]
 elif node.op == 'call_module':
 result = False
@@ -825,7 +827,7 @@ def _activation_post_process_satisfies_dtype_config_constraints(
 satisfies_constraints = True
 if activation_post_process_ctr is not None:
 activation_post_process = activation_post_process_ctr()
-assert is_activation_post_process(activation_post_process)
+assert _is_activation_post_process(activation_post_process)
 # If dtypes don't match, don't check the activation_post_process and return True early
 if activation_post_process.dtype != dtype_with_constraints.dtype:
 return True
2 changes: 1 addition & 1 deletion torch/ao/quantization/observer.py
@@ -1449,7 +1449,7 @@ def _is_observer_script_module(mod, obs_type_name):
 def _is_activation_post_process(module):
 return (
 isinstance(module, torch.ao.quantization.ObserverBase)
-or isinstance(module, torch.ao.quantization.FakeQuantize)
+or isinstance(module, torch.ao.quantization.FakeQuantizeBase)
 or _is_observer_script_module(module, "quantization.observer")
 )

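Note: widening the isinstance check from FakeQuantize to FakeQuantizeBase means the predicate now also recognizes fake-quantize modules that extend FakeQuantizeBase directly rather than subclassing FakeQuantize. A minimal sketch of the relocated helper's behavior (MyFakeQuant is a hypothetical subclass used only for illustration):

```python
import torch
from torch.ao.quantization import FakeQuantize, FakeQuantizeBase, MinMaxObserver
from torch.ao.quantization.observer import _is_activation_post_process

# Plain observers were already recognized before this change.
assert _is_activation_post_process(MinMaxObserver())

# FakeQuantize subclasses FakeQuantizeBase, so it still matches.
assert _is_activation_post_process(FakeQuantize())

# A hypothetical custom fake-quant module that extends FakeQuantizeBase directly,
# without going through FakeQuantize, is what the widened isinstance check now covers.
class MyFakeQuant(FakeQuantizeBase):
    def forward(self, x):
        return x

    def calculate_qparams(self):
        return torch.tensor([1.0]), torch.tensor([0])

assert _is_activation_post_process(MyFakeQuant())

# Ordinary modules are not activation post-processes.
assert not _is_activation_post_process(torch.nn.Linear(2, 2))
```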
14 changes: 7 additions & 7 deletions torch/ao/quantization/quantize.py
@@ -27,10 +27,15 @@
 float_qparams_weight_only_qconfig_4bit,
 _activation_is_memoryless)
 from torch.nn.utils.parametrize import type_before_parametrizations
+from torch.ao.quantization.observer import _is_activation_post_process
+
+# TODO remove this once BC is no longer required to avoid a SEV
+from torch.ao.quantization.observer import ( # noqa: F401
+_is_activation_post_process as is_activation_post_process
+)

 __all__ = [
 "get_default_custom_config_dict",
-"is_activation_post_process",
 "propagate_qconfig_",
 "register_activation_post_process_hook",
 "add_observer_",
@@ -62,11 +67,6 @@ def get_default_custom_config_dict():
 """
 return _DEFAULT_CUSTOM_CONFIG_DICT

-def is_activation_post_process(module):
-return (isinstance(module, torch.ao.quantization.ObserverBase) or
-isinstance(module, torch.ao.quantization.FakeQuantizeBase))
-
-
 def _propagate_qconfig_helper(module, qconfig_dict,
 qconfig_parent=None, prefix='', prepare_custom_config_dict=None):
 r"""This is a helper function for `propagate_qconfig_`
@@ -324,7 +324,7 @@ def _remove_activation_post_process(module):
 # TODO: maybe we should change activation_post_process to _activation_post_process
 # to prevent it from being used by user
 if hasattr(module, 'activation_post_process') and \
-is_activation_post_process(module.activation_post_process):
+_is_activation_post_process(module.activation_post_process):
 delattr(module, 'activation_post_process')

 # remove activation_post_proceess pre and post hooks
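The aliased re-import kept above (with `# noqa: F401`) is what preserves backward compatibility while the canonical definition moves to observer.py. A minimal sketch of what callers see after this change, using only names that appear in this diff:

```python
# New, preferred spelling: the helper now lives in the observer module.
from torch.ao.quantization.observer import _is_activation_post_process

# The old spelling still resolves through the aliased import retained for BC,
# even though it has been removed from torch.ao.quantization.quantize.__all__.
from torch.ao.quantization.quantize import is_activation_post_process

# Both names refer to the same function object.
assert is_activation_post_process is _is_activation_post_process
```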
2 changes: 1 addition & 1 deletion torch/quantization/quantize.py
@@ -17,7 +17,7 @@
 from torch.ao.quantization.quantize import convert
 from torch.ao.quantization.quantize import get_observer_dict
 from torch.ao.quantization.quantize import get_unique_devices_
-from torch.ao.quantization.quantize import is_activation_post_process
+from torch.ao.quantization.quantize import _is_activation_post_process
 from torch.ao.quantization.quantize import prepare
 from torch.ao.quantization.quantize import prepare_qat
 from torch.ao.quantization.quantize import propagate_qconfig_