4 changes: 2 additions & 2 deletions test/quantization/fx/test_numeric_suite_fx.py
@@ -85,7 +85,7 @@
 )
 from torch.ao.ns.fx.qconfig_multi_mapping import QConfigMultiMapping
 from torch.ao.quantization.backend_config import get_native_backend_config
-from torch.ao.quantization.fx.backend_config_utils import get_pattern_to_quantize_handlers
+from torch.ao.quantization.fx.quantization_patterns import _get_pattern_to_quantize_handlers


 # Note: these models are not for use outside of this file. While it's good
@@ -299,7 +299,7 @@ def get_all_quant_patterns():
     all_quant_patterns = get_default_quant_patterns()
     # some of the patterns are moved to (native) backend_config_dict so we need to
     # add them back here
-    for pattern, quantize_handler in get_pattern_to_quantize_handlers(get_native_backend_config()).items():
+    for pattern, quantize_handler in _get_pattern_to_quantize_handlers(get_native_backend_config()).items():
         all_quant_patterns[pattern] = quantize_handler
     return all_quant_patterns

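As a hedged sketch of what get_all_quant_patterns() above computes, rewritten as a plain dict merge outside the test (the import location of get_default_quant_patterns is an assumption based on where it lives elsewhere in this code base; the renamed helper is private, so this is illustrative rather than supported usage):

from torch.ao.quantization.backend_config import get_native_backend_config
# Assumption: get_default_quant_patterns is importable from fx.pattern_utils,
# which is where it lives at this point in the code base.
from torch.ao.quantization.fx.pattern_utils import get_default_quant_patterns
from torch.ao.quantization.fx.quantization_patterns import _get_pattern_to_quantize_handlers

# Backend-config-derived handlers override the legacy defaults on key clashes,
# mirroring the loop in get_all_quant_patterns() above.
all_quant_patterns = dict(get_default_quant_patterns())
all_quant_patterns.update(_get_pattern_to_quantize_handlers(get_native_backend_config()))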
4 changes: 2 additions & 2 deletions torch/ao/ns/_numeric_suite_fx.py
@@ -121,9 +121,9 @@
 )
 from torch.ao.quantization.backend_config.utils import get_fusion_pattern_to_root_node_getter
 from torch.ao.quantization.backend_config import BackendConfig
-from torch.ao.quantization.fx.backend_config_utils import get_pattern_to_quantize_handlers
 from torch.ao.quantization.fx.match_utils import find_matches
 from torch.ao.quantization.fx.qconfig_mapping_utils import generate_node_name_to_qconfig
+from torch.ao.quantization.fx.quantization_patterns import _get_pattern_to_quantize_handlers
 from torch.ao.quantization.qconfig import QConfigAny
 from torch.ao.ns.fx.n_shadows_utils import (
     OutputProp,
@@ -803,7 +803,7 @@ def prepare_n_shadows_model(
     # Find the set of subgraphs in the original graph which we need to
     # consider.
     modules = dict(mt.named_modules(remove_duplicate=False))
-    patterns = get_pattern_to_quantize_handlers(backend_config)
+    patterns = _get_pattern_to_quantize_handlers(backend_config)
     root_node_getter_mapping = \
         get_fusion_pattern_to_root_node_getter(backend_config)
     standalone_module_names: List[str] = []
5 changes: 3 additions & 2 deletions torch/ao/ns/fx/pattern_utils.py
@@ -6,9 +6,10 @@
 from torch.fx import GraphModule
 from torch.fx.graph import Node

+from torch.ao.quantization.backend_config import get_native_backend_config
+from torch.ao.quantization.fx.quantization_patterns import _get_pattern_to_quantize_handlers
 from torch.ao.quantization.utils import getattr_from_fqn
 from .ns_types import NSNodeTargetType
-from torch.ao.quantization.fx.backend_config_utils import get_native_quant_patterns
 from torch.ao.quantization import (
     ObserverBase,
     FakeQuantizeBase,
@@ -66,7 +67,7 @@ def get_reversed_fusions() -> List[Tuple[NSFusionType, int]]:
     # * multiple ops: (torch.nn.ReLU, torch.nn.Conv2d)
     # For fusions, we only care about patterns composed of multiple ops.
     # TODO(future PR): allow customizations from default patterns.
-    all_quant_patterns = get_native_quant_patterns()
+    all_quant_patterns = _get_pattern_to_quantize_handlers(get_native_backend_config())

     default_base_op_idx = 0
     for quant_pattern, _quant_handler in all_quant_patterns.items():
113 changes: 0 additions & 113 deletions torch/ao/quantization/fx/backend_config_utils.py

This file was deleted.
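Since the deleted file's contents are not shown here, the following is a hedged summary of where its helpers went, reconstructed only from the import changes visible in the surrounding hunks:

# Old imports (removed throughout this PR):
#   from torch.ao.quantization.fx.backend_config_utils import (
#       get_pattern_to_quantize_handlers,
#       get_fusion_pattern_to_fuse_handler_cls,
#       get_native_quant_patterns,
#   )
# New, private equivalents introduced by this PR:
from torch.ao.quantization.fx.quantization_patterns import _get_pattern_to_quantize_handlers
from torch.ao.quantization.fx.fusion_patterns import _get_fusion_pattern_to_fuse_handler_cls
# get_native_quant_patterns has no one-to-one replacement in this diff; callers
# now inline _get_pattern_to_quantize_handlers(get_native_backend_config()).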

8 changes: 5 additions & 3 deletions torch/ao/quantization/fx/fuse.py
@@ -24,11 +24,13 @@
     get_fusion_pattern_to_root_node_getter,
     get_fusion_pattern_to_extra_inputs_getter,
 )
-from .backend_config_utils import get_fusion_pattern_to_fuse_handler_cls

 from .custom_config import FuseCustomConfig

-from .fusion_patterns import * # noqa: F401,F403
+from .fusion_patterns import (
+    _get_fusion_pattern_to_fuse_handler_cls,
+    FuseHandler,
+)

 from typing import Any, Callable, Dict, List, Tuple, Union
 import warnings
@@ -69,7 +71,7 @@ def fuse(
     if backend_config is None:
         backend_config = get_native_backend_config()

-    fusion_pattern_to_fuse_handler_cls = sorted_patterns_dict(get_fusion_pattern_to_fuse_handler_cls(backend_config))
+    fusion_pattern_to_fuse_handler_cls = sorted_patterns_dict(_get_fusion_pattern_to_fuse_handler_cls(backend_config))
     fuser_method_mapping = get_fuser_method_mapping(backend_config)
     fusion_pattern_to_root_node_getter = get_fusion_pattern_to_root_node_getter(backend_config)
     fusion_pattern_to_extra_inputs_getter = get_fusion_pattern_to_extra_inputs_getter(backend_config)
10 changes: 10 additions & 0 deletions torch/ao/quantization/fx/fusion_patterns.py
@@ -1,4 +1,5 @@
 import torch
+from torch.ao.quantization.backend_config import BackendConfig
 from torch.fx.graph import Node, Graph
 from ..utils import _parent_name, NodePattern, Pattern
 from ..fuser_method_mappings import get_fuser_method_new
@@ -108,3 +109,12 @@ def get_matched_types(m):
         args.extend(extra_args)
         node.args = tuple(args)
         return node
+
+def _get_fusion_pattern_to_fuse_handler_cls(
+        backend_config: BackendConfig) -> Dict[Pattern, Callable]:
+    fusion_pattern_to_fuse_handlers: Dict[Pattern, Callable] = {}
+    for pattern, config in backend_config.configs.items():
+        if config.fuser_method is not None:
+            # TODO: is this logic right?
+            fusion_pattern_to_fuse_handlers[pattern] = DefaultFuseHandler
+    return fusion_pattern_to_fuse_handlers
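A quick sketch of what the new helper returns, using a toy one-pattern BackendConfig. The stub fuser method and the "toy" backend name are illustrative assumptions; BackendPatternConfig.set_fuser_method and BackendConfig.set_backend_pattern_config are the existing builder APIs from torch.ao.quantization.backend_config:

import torch
from torch.ao.quantization.backend_config import BackendConfig, BackendPatternConfig
from torch.ao.quantization.fx.fusion_patterns import _get_fusion_pattern_to_fuse_handler_cls

# Any non-None fuser_method makes the pattern eligible; the body is a stub here.
pattern_config = BackendPatternConfig((torch.nn.ReLU, torch.nn.Conv2d)) \
    .set_fuser_method(lambda is_qat, *mods: torch.nn.Identity())
backend_config = BackendConfig("toy").set_backend_pattern_config(pattern_config)

handlers = _get_fusion_pattern_to_fuse_handler_cls(backend_config)
# Every fusable pattern maps to the same DefaultFuseHandler class:
# {(torch.nn.ReLU, torch.nn.Conv2d): <class '...DefaultFuseHandler'>}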
7 changes: 2 additions & 5 deletions torch/ao/quantization/fx/prepare.py
@@ -34,6 +34,7 @@
 )

 from .quantization_patterns import (
+    _get_pattern_to_quantize_handlers,
     QuantizeHandler,
 )

@@ -98,10 +99,6 @@
     DTypeConfig,
     get_native_backend_config,
 )
-from .backend_config_utils import (
-    get_pattern_to_quantize_handlers,
-)
-
 from .custom_config import (
     PrepareCustomConfig,
     StandaloneModuleConfigEntry,
@@ -1520,7 +1517,7 @@ def prepare(
     pattern_to_quantize_handler: Dict[Pattern, QuantizeHandler] = {}
     if backend_config is None:
         backend_config = get_native_backend_config()
-    pattern_to_quantize_handler = get_pattern_to_quantize_handlers(backend_config)
+    pattern_to_quantize_handler = _get_pattern_to_quantize_handlers(backend_config)
     pattern_to_quantize_handler = sorted_patterns_dict(pattern_to_quantize_handler)

     root_node_getter_mapping = \
71 changes: 69 additions & 2 deletions torch/ao/quantization/fx/quantization_patterns.py
@@ -6,10 +6,19 @@
 from .utils import (
     all_node_args_have_no_tensors,
 )
-from torch.ao.quantization.utils import NodePattern
+from torch.ao.quantization.backend_config import (
+    BackendConfig,
+    DTypeConfig,
+    ObservationType,
+)
+from torch.ao.quantization.utils import (
+    NodePattern,
+    Pattern,
+    QuantizerCls,
+)

 from abc import ABC
-from typing import Callable, Dict
+from typing import Callable, Dict, List, Type

 __all__ = [
     "QuantizeHandler",
@@ -101,6 +110,64 @@ def is_custom_module(self):
     def is_standalone_module(self):
         return self.is_standalone_module_

+def _get_quantize_handler_cls(
+        observation_type: ObservationType,
+        dtype_configs: List[DTypeConfig],
+        num_tensor_args_to_observation_type: Dict[int, ObservationType],
+        input_output_observed: bool) -> Type[QuantizeHandler]:
+    """
+    Return a configurable QuantizeHandler that matches the given specifications from the backend.
+    """
+
+    class ConfigurableQuantizeHandler(QuantizeHandler):
+        def __init__(
+                self,
+                node_pattern: NodePattern,
+                modules: Dict[str, torch.nn.Module],
+                root_node_getter: Callable = None):
+            super().__init__(node_pattern, modules, root_node_getter)
+            if num_tensor_args_to_observation_type:
+                assert self.num_tensor_args in num_tensor_args_to_observation_type, \
+                    f"Must provide observation_type config for tensor number {self.num_tensor_args}" \
+                    f" in num_tensor_args_to_observation_type for {node_pattern}"
+                self.observation_type = num_tensor_args_to_observation_type[self.num_tensor_args]
+            else:
+                self.observation_type = observation_type
+            self.dtype_configs = dtype_configs
+            self.input_output_observed_ = input_output_observed

+        def is_general_tensor_value_op(self) -> bool:
+            return self.observation_type == ObservationType.OUTPUT_SHARE_OBSERVER_WITH_INPUT

+        # This is temporary, and will be removed soon
+        def input_output_observed(self):
+            return self.input_output_observed_

+    return ConfigurableQuantizeHandler

+def _get_pattern_to_quantize_handlers(backend_config: BackendConfig) -> Dict[Pattern, QuantizerCls]:
+    """
+    Note: Quantize handler is just a holder for some check methods like
+    (should_insert_observer_for_output), maybe this can be a enum as well,
+    we can refactor this after we convert the path for fbgemm/qnnpack fully to the
+    new path, this is not exposed to backend developers
+    """
+    pattern_to_quantize_handlers = {}
+    for pattern, config in backend_config.configs.items():
+        observation_type = config.observation_type
+        dtype_configs = config.dtype_configs
+        num_tensor_args_to_observation_type = config._num_tensor_args_to_observation_type
+        input_output_observed = config._input_output_observed
+        if input_output_observed is None:
+            input_output_observed = True
+        pattern_to_quantize_handlers[pattern] = \
+            _get_quantize_handler_cls(
+                observation_type,
+                dtype_configs,
+                num_tensor_args_to_observation_type,
+                input_output_observed)
+    return pattern_to_quantize_handlers

 # TODO: remove this class, this is still exposed in torch.quantization
 # but we should be able to break bc
 class BinaryOpQuantizeHandler(QuantizeHandler):
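To make the closure behavior of the factory concrete, here is a small hedged sketch (assuming torch.nn.Linear and torch.nn.Conv2d are registered patterns, as they are in the native backend config): every value in the returned dict is a distinct class object specialized for one pattern's observation type and dtype configs.

import torch
from torch.ao.quantization.backend_config import get_native_backend_config
from torch.ao.quantization.fx.quantization_patterns import _get_pattern_to_quantize_handlers

handlers = _get_pattern_to_quantize_handlers(get_native_backend_config())
linear_cls = handlers[torch.nn.Linear]
conv_cls = handlers[torch.nn.Conv2d]
print(linear_cls.__name__)     # ConfigurableQuantizeHandler
print(linear_cls is conv_cls)  # False: each pattern gets its own subclass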
Expand Down