Skip to content

Commit 8ccff2e

Browse files
committed
[ao][fx] fixing public vs. private in graph_module.py
Pull Request resolved: #88395. Made `is_observed_module` and `is_observed_standalone_module` private (renamed to `_is_observed_module` and `_is_observed_standalone_module`). ghstack-source-id: 175622130. Differential Revision: [D41015545](https://our.internmc.facebook.com/intern/diff/D41015545/)
1 parent 9efc9c6 commit 8ccff2e

File tree

5 files changed

+12
-14
lines changed

5 files changed

+12
-14
lines changed

test/quantization/ao_migration/test_quantization_fx.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,9 +46,9 @@ def test_function_import_fx_graph_module(self):
4646
function_list = [
4747
'FusedGraphModule',
4848
'ObservedGraphModule',
49-
'is_observed_module',
49+
'_is_observed_module',
5050
'ObservedStandaloneGraphModule',
51-
'is_observed_standalone_module',
51+
'_is_observed_standalone_module',
5252
'QuantizedGraphModule'
5353
]
5454
self._test_function_import('fx.graph_module', function_list)

torch/ao/quantization/fx/convert.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -42,8 +42,8 @@
4242
)
4343
from .graph_module import (
4444
QuantizedGraphModule,
45-
is_observed_module,
46-
is_observed_standalone_module,
45+
_is_observed_module,
46+
_is_observed_standalone_module,
4747
)
4848
from ._equalize import update_obs_for_equalization, convert_eq_obs
4949
from torch.nn.utils.parametrize import type_before_parametrizations
@@ -450,7 +450,7 @@ def _restore_state(
450450
) -> Tuple[Dict[str, Tuple[str, type]],
451451
PrepareCustomConfig,
452452
Set[str]]:
453-
assert is_observed_module(observed), \
453+
assert _is_observed_module(observed), \
454454
'incoming model must be produced by prepare_fx'
455455
prepare_custom_config: PrepareCustomConfig = observed._prepare_custom_config # type: ignore[assignment]
456456
node_name_to_scope: Dict[str, Tuple[str, type]] = observed._node_name_to_scope # type: ignore[assignment]
@@ -1017,7 +1017,7 @@ def convert(
10171017
node_name_to_qconfig)
10181018
elif isinstance(mod, DeQuantStub):
10191019
_replace_observer_or_dequant_stub_with_dequantize_node(node, model.graph)
1020-
elif is_observed_standalone_module(mod):
1020+
elif _is_observed_standalone_module(mod):
10211021
convert_standalone_module(
10221022
node, modules, model, is_reference, backend_config)
10231023
# below this point `type_before_parametrizations` is used

torch/ao/quantization/fx/graph_module.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,7 @@
77
__all__ = [
88
"FusedGraphModule",
99
"ObservedGraphModule",
10-
"is_observed_module",
1110
"ObservedStandaloneGraphModule",
12-
"is_observed_standalone_module",
1311
"QuantizedGraphModule",
1412
]
1513

@@ -56,7 +54,7 @@ def __deepcopy__(self, memo):
5654
fake_mod.__dict__ = copy.deepcopy(self.__dict__)
5755
return ObservedGraphModule(fake_mod, copy.deepcopy(self.graph), copy.deepcopy(self.preserved_attr_names))
5856

59-
def is_observed_module(module: Any) -> bool:
57+
def _is_observed_module(module: Any) -> bool:
6058
return isinstance(module, ObservedGraphModule)
6159

6260
class ObservedStandaloneGraphModule(ObservedGraphModule):
@@ -71,7 +69,7 @@ def __deepcopy__(self, memo):
7169
fake_mod.__dict__ = copy.deepcopy(self.__dict__)
7270
return ObservedStandaloneGraphModule(fake_mod, copy.deepcopy(self.graph), copy.deepcopy(self.preserved_attr_names))
7371

74-
def is_observed_standalone_module(module: Any) -> bool:
72+
def _is_observed_standalone_module(module: Any) -> bool:
7573
return isinstance(module, ObservedStandaloneGraphModule)
7674

7775
def _save_packed_weight(self, destination, prefix, keep_vars):

torch/ao/quantization/fx/match_utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
MatchAllNode
1616
)
1717
from .graph_module import (
18-
is_observed_standalone_module,
18+
_is_observed_standalone_module,
1919
)
2020
from torch.nn.utils.parametrize import type_before_parametrizations
2121
from typing import Any, Dict, List, Callable, Optional, Tuple, Type, Set, Iterable
@@ -232,7 +232,7 @@ def is_standalone_module(node_target: str, modules: Dict[str, torch.nn.Module]):
232232
for node in graph.nodes:
233233
if node.op == 'call_module' and \
234234
(is_standalone_module(node.target, modules) or
235-
is_observed_standalone_module(modules[node.target])):
235+
_is_observed_standalone_module(modules[node.target])):
236236
# add node to matched nodes
237237
match_map[node.name] = (
238238
node, node, None,

torch/quantization/fx/graph_module.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,8 @@
1010
GraphModule,
1111
FusedGraphModule,
1212
ObservedGraphModule,
13-
is_observed_module,
13+
_is_observed_module,
1414
ObservedStandaloneGraphModule,
15-
is_observed_standalone_module,
15+
_is_observed_standalone_module,
1616
QuantizedGraphModule
1717
)

0 commit comments

Comments (0)