Skip to content

Commit 66939e3

Browse files
jfix71pytorchmergebot
authored andcommitted
[acc_tracer] Add test coverage for retracing (#71752)
Summary: Pull Request resolved: #71752 Added coverage for reshape specifically which required a fix. The problem for `acc_ops.reshape` as best as I understand: - `torch.reshape` requires the `shape` arg to be a `tuple` of `ints` - If `torch.reshape` is passed a `tuple` where the first element is not an `int`, it throws a TypeError e.g. `TypeError: reshape(): argument 'shape' (position 2) must be tuple of ints, not tuple` - If the `shape` we're reshaping to is an FX Proxy then this type error will be thrown. This happens when the first element of the `shape` tuple is a Proxy because it's input-dependent. - As a workaround we use `tensor.reshape` instead of `torch.reshape`, which doesn't do equivalent type checking for a `tuple` of `ints`. Also remove unnecessary `acc_utils.get_field_from_acc_out_ty()` with cast to `TensorMetadata`. Test Plan: Added test coverage Reviewed By: yinghai Differential Revision: D33760455 fbshipit-source-id: bff5563bf9e3d9e9318901b56211151d2c0e4eb2 (cherry picked from commit d5c1b97)
1 parent b36b11c commit 66939e3

File tree

4 files changed

+83
-48
lines changed

4 files changed

+83
-48
lines changed

test/fx_acc/test_acc_tracer.py

Lines changed: 59 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,11 @@ def _make_model_unit_test(
3737
torch.testing.assert_allclose(model(input), traced(input))
3838
else:
3939
self.assertTrue(torch.equal(model(input), traced(input)))
40+
traced_again = acc_tracer.trace(traced, [input])
41+
if enable_allclose:
42+
torch.testing.assert_allclose(model(input), traced_again(input))
43+
else:
44+
self.assertTrue(torch.equal(model(input), traced_again(input)))
4045

4146
def _make_acc_op_function_test(
4247
self,
@@ -89,30 +94,47 @@ def forward(self, a: torch.Tensor) -> torch.Tensor:
8994

9095
ref_outputs = m(a)
9196
outputs = traced(a)
97+
traced_again = acc_tracer.trace(m, [a])
98+
outputs_again = traced_again(a)
9299
if isinstance(ref_outputs, torch.Tensor):
93100
ref_outputs = [ref_outputs]
94101
outputs = [outputs]
102+
outputs_again = [outputs_again]
95103

96-
for ref_output, output in zip(ref_outputs, outputs):
104+
for ref_output, output, output_again in zip(
105+
ref_outputs, outputs, outputs_again
106+
):
97107
if enable_allclose:
98108
torch.testing.assert_allclose(
99109
torch.nan_to_num(ref_output), torch.nan_to_num(output)
100110
)
111+
torch.testing.assert_allclose(
112+
torch.nan_to_num(ref_output), torch.nan_to_num(output_again)
113+
)
101114
else:
102115
self.assertTrue(
103116
torch.equal(torch.nan_to_num(ref_output), torch.nan_to_num(output))
104117
)
118+
self.assertTrue(
119+
torch.equal(
120+
torch.nan_to_num(ref_output), torch.nan_to_num(output_again)
121+
)
122+
)
105123

106124
def test_sum(self):
107125
self._make_acc_op_function_test(acc_ops.sum, torch.sum)
108126
self._make_acc_op_function_test(acc_ops.sum, torch.sum, dim=(1,), keepdim=True)
109127

110128
def test_mean(self):
111129
self._make_acc_op_function_test(acc_ops.mean, torch.mean)
112-
self._make_acc_op_function_test(acc_ops.mean, torch.mean, dim=(1,), keepdim=True)
130+
self._make_acc_op_function_test(
131+
acc_ops.mean, torch.mean, dim=(1,), keepdim=True
132+
)
113133

114134
def test_pad(self):
115-
self._make_acc_op_function_test(acc_ops.pad, torch.nn.functional.pad, pad=(2, 0))
135+
self._make_acc_op_function_test(
136+
acc_ops.pad, torch.nn.functional.pad, pad=(2, 0)
137+
)
116138

117139
def test_max(self):
118140
def torch_max(x, *args, **kwargs):
@@ -504,7 +526,9 @@ def forward(self, a: torch.Tensor) -> torch.Tensor:
504526
self.assertEqual(node.kwargs["running_mean"], bn_mean)
505527
self.assertEqual(node.kwargs["running_var"], bn_var)
506528
self.assertEqual(node.kwargs["acc_out_ty"][6]["scale"], bn_scale)
507-
self.assertEqual(node.kwargs["acc_out_ty"][6]["zero_point"], bn_zero_point)
529+
self.assertEqual(
530+
node.kwargs["acc_out_ty"][6]["zero_point"], bn_zero_point
531+
)
508532
bn = node
509533
elif node.op == "output":
510534
self.assertEqual(bn, node.args[0])
@@ -1230,7 +1254,9 @@ def test_dropout(self):
12301254
def test_stochastic_depth(self):
12311255
self._make_acc_op_function_test(
12321256
None,
1233-
lambda x, p, mode, training: torchvision.ops.stochastic_depth(x, p=p, mode=mode, training=training),
1257+
lambda x, p, mode, training: torchvision.ops.stochastic_depth(
1258+
x, p=p, mode=mode, training=training
1259+
),
12341260
input_shape=(1, 2, 3),
12351261
p=0.5,
12361262
mode="row",
@@ -1387,7 +1413,9 @@ def test_relu(self):
13871413
self._make_acc_op_function_test(acc_ops.relu, torch.relu)
13881414

13891415
def test_leaky_relu(self):
1390-
self._make_acc_op_function_test(acc_ops.leaky_relu, torch.nn.functional.leaky_relu)
1416+
self._make_acc_op_function_test(
1417+
acc_ops.leaky_relu, torch.nn.functional.leaky_relu
1418+
)
13911419

13921420
def test_elu(self):
13931421
self._make_acc_op_function_test(acc_ops.elu, torch.nn.functional.elu)
@@ -1472,11 +1500,17 @@ def test_div(self):
14721500
self._make_acc_op_function_test(acc_ops.div, lambda x: x / 2)
14731501

14741502
def test_floor_div(self):
1475-
self._make_acc_op_function_test(acc_ops.floor_div, lambda x: torch.div(x, 2, rounding_mode="floor"))
1503+
self._make_acc_op_function_test(
1504+
acc_ops.floor_div, lambda x: torch.div(x, 2, rounding_mode="floor")
1505+
)
14761506

14771507
def test_trunc_div(self):
1478-
self._make_acc_op_function_test(acc_ops.trunc_div, lambda x: torch.div(x, 2, rounding_mode="trunc"))
1479-
self._make_acc_op_function_test(acc_ops.trunc_div, lambda x: torch.floor_divide(x, 2))
1508+
self._make_acc_op_function_test(
1509+
acc_ops.trunc_div, lambda x: torch.div(x, 2, rounding_mode="trunc")
1510+
)
1511+
self._make_acc_op_function_test(
1512+
acc_ops.trunc_div, lambda x: torch.floor_divide(x, 2)
1513+
)
14801514

14811515
def test_view(self):
14821516
"""
@@ -1907,6 +1941,22 @@ def test_cumsum(self):
19071941
def test_chunk(self):
19081942
self._make_acc_op_function_test(acc_ops.chunk, torch.chunk, chunks=2, dim=0)
19091943

1944+
def test_retrace_reshape(self):
1945+
"""
1946+
Retrace reshape to verify it's retraceable.
1947+
"""
1948+
1949+
class TestModule(torch.nn.Module):
1950+
def forward(self, a: torch.Tensor) -> torch.Tensor:
1951+
return a.reshape(a.size()[0], 1, 2)
1952+
1953+
m = TestModule()
1954+
a = torch.randn(2, 2)
1955+
gm = acc_tracer.trace(m, [a])
1956+
self.assertTrue(torch.equal(m(a), gm(a)))
1957+
gm_retrace = acc_tracer.trace(gm, [a])
1958+
self.assertTrue(torch.equal(m(a), gm_retrace(a)))
1959+
19101960
def test_all_acc_ops_registered(self):
19111961
self.assertEqual(
19121962
acc_normalizer._acc_ops,

torch/fx/experimental/fx2trt/converters/acc_ops_converters.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
import tensorrt as trt
88
import torch
99
import torch.fx.experimental.fx_acc.acc_ops as acc_ops
10-
import torch.fx.experimental.fx_acc.acc_utils as acc_utils
1110
from torch.fx.experimental.fx2trt.converter_registry import tensorrt_converter
1211
from torch.fx.experimental.fx2trt.types import * # noqa: F403
1312
from torch.fx.experimental.fx2trt.utils import (
@@ -16,6 +15,7 @@
1615
)
1716
from torch.fx.immutable_collections import immutable_list
1817
from torch.fx.node import Target, Argument
18+
from torch.fx.passes.shape_prop import TensorMetadata
1919

2020
from .converter_utils import * # noqa: F403
2121

@@ -1376,7 +1376,7 @@ def acc_ops_reshape(
13761376
"of the TensorRT region!"
13771377
)
13781378

1379-
shape = acc_utils.get_field_from_acc_out_ty(kwargs["acc_out_ty"], "shape") # type: ignore[arg-type]
1379+
shape = TensorMetadata(*kwargs["acc_out_ty"]).shape # type: ignore[misc]
13801380
if network.has_implicit_batch_dimension:
13811381
shape = shape[1:]
13821382

@@ -1887,10 +1887,10 @@ def acc_ops_quantize_per_tensor(
18871887
raise RuntimeError(f"{name} received input {input_val} that is not part "
18881888
"of the TensorRT region!")
18891889

1890-
qparams = acc_utils.get_field_from_acc_out_ty(kwargs["acc_out_ty"], "qparams") # type: ignore[arg-type]
1890+
qparams = TensorMetadata(*kwargs["acc_out_ty"]).qparams # type: ignore[misc]
18911891
q_scale = qparams["scale"]
18921892
q_zero_point = qparams["zero_point"]
1893-
dtype = acc_utils.get_field_from_acc_out_ty(kwargs["acc_out_ty"], "dtype") # type: ignore[arg-type]
1893+
dtype = TensorMetadata(*kwargs["acc_out_ty"]).dtype # type: ignore[misc]
18941894
if dtype not in (torch.quint8, torch.qint8, torch.qint32):
18951895
raise RuntimeError("Only support (torch.quint8, torch.qint8, torch.qint32) "
18961896
f"quantized type in quantize_per_tensor, get {dtype}.")
@@ -1923,11 +1923,11 @@ def acc_ops_quantize_per_channel(
19231923
raise RuntimeError(f"{name} received input {input_val} that is not part "
19241924
"of the TensorRT region!")
19251925

1926-
qparams = acc_utils.get_field_from_acc_out_ty(kwargs["acc_out_ty"], "qparams") # type: ignore[arg-type]
1926+
qparams = TensorMetadata(*kwargs["acc_out_ty"]).qparams # type: ignore[misc]
19271927
q_per_channel_scales = qparams["scale"]
19281928
q_per_channel_zero_points = qparams["zero_point"]
19291929
q_per_channel_axis = qparams["axis"]
1930-
dtype = acc_utils.get_field_from_acc_out_ty(kwargs["acc_out_ty"], "dtype") # type: ignore[arg-type]
1930+
dtype = TensorMetadata(*kwargs["acc_out_ty"]).dtype # type: ignore[misc]
19311931
if dtype not in (torch.quint8, torch.qint8, torch.qint32):
19321932
raise RuntimeError("Only support (torch.quint8, torch.qint8, torch.qint32) "
19331933
f"quantized type in quantize_per_tensor, get {dtype}.")
@@ -1970,7 +1970,7 @@ def acc_ops_dequantize(
19701970
raise RuntimeError(f"{name} received input {input_val} that is not part "
19711971
"of the TensorRT region!")
19721972

1973-
qparams = acc_utils.get_field_from_acc_out_ty(input_val_tensor_meta, "qparams") # type: ignore[arg-type]
1973+
qparams = TensorMetadata(*input_val_tensor_meta).qparams # type: ignore[misc]
19741974
qscheme = qparams["qscheme"]
19751975
if qscheme == torch.per_tensor_affine:
19761976
q_scale = qparams["scale"]
@@ -1990,7 +1990,7 @@ def acc_ops_dequantize(
19901990
else:
19911991
raise RuntimeError("Unsupported qscheme in dequantize: {qscheme}")
19921992

1993-
dtype = acc_utils.get_field_from_acc_out_ty(input_val_tensor_meta, "dtype") # type: ignore[arg-type]
1993+
dtype = TensorMetadata(*input_val_tensor_meta).dtype # type: ignore[misc]
19941994

19951995
if dtype not in (torch.quint8, torch.qint8, torch.qint32):
19961996
raise RuntimeError("Only support (torch.quint8, torch.qint8, torch.qint32) "

torch/fx/experimental/fx_acc/acc_ops.py

Lines changed: 16 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
AccOpProperty,
1717
register_acc_op_properties,
1818
)
19-
from torch.fx.passes.shape_prop import _extract_tensor_metadata
19+
from torch.fx.passes.shape_prop import _extract_tensor_metadata, TensorMetadata
2020

2121
this_arg_is_optional = True
2222
move_to_qparams = True
@@ -33,7 +33,7 @@ def linear(*, input, weight, bias):
3333
@register_acc_op
3434
def quantized_linear(*, input, weight, bias, acc_out_ty=None):
3535
assert acc_out_ty is not None
36-
qparams = acc_utils.get_field_from_acc_out_ty(acc_out_ty, "qparams")
36+
qparams = TensorMetadata(*acc_out_ty).qparams
3737
return nn.quantized.functional.linear(
3838
input,
3939
weight,
@@ -462,11 +462,13 @@ def dropout_mapper(node: torch.fx.Node, mod: nn.Module):
462462
"""
463463
return node.kwargs["input"]
464464

465+
465466
try:
466467
from torchvision.ops import stochastic_depth
467468
except Exception as e:
468469
warnings.warn(f"Unable to import torchvision related libraries.: {e}")
469470
else:
471+
470472
@register_custom_acc_mapper_fn(
471473
op_and_target=("call_function", stochastic_depth),
472474
arg_replacement_tuples=[("input", "input")],
@@ -477,6 +479,7 @@ def stochastic_depth_mapper(node: torch.fx.Node, mod: nn.Module):
477479
"""
478480
return node.kwargs["input"]
479481

482+
480483
@register_acc_op_properties(AccOpProperty.pointwise, AccOpProperty.unary)
481484
@register_acc_op_mapping(
482485
op_and_target=("call_function", nn.functional.hardtanh),
@@ -502,9 +505,7 @@ def hardsigmoid(*, input):
502505
def silu(node: torch.fx.Node, _: nn.Module) -> torch.fx.Node:
503506
input_node = node.kwargs["input"]
504507
with node.graph.inserting_before(node):
505-
sigmoid_node = node.graph.call_function(
506-
sigmoid, kwargs={"input": input_node}
507-
)
508+
sigmoid_node = node.graph.call_function(sigmoid, kwargs={"input": input_node})
508509
sigmoid_node.meta = node.meta.copy()
509510
new_node = node.graph.call_function(
510511
mul, kwargs={"input": sigmoid_node, "other": input_node}
@@ -550,7 +551,7 @@ def hardswish_mapper(node: torch.fx.Node, _: nn.Module) -> torch.fx.Node:
550551
@register_acc_op
551552
def quantized_add(*, input, other, acc_out_ty=None):
552553
assert acc_out_ty is not None
553-
qparams = acc_utils.get_field_from_acc_out_ty(acc_out_ty, "qparams")
554+
qparams = TensorMetadata(*acc_out_ty).qparams
554555
return torch.ops.quantized.add(
555556
input,
556557
other,
@@ -576,7 +577,7 @@ def quantized_add(*, input, other, acc_out_ty=None):
576577
@register_acc_op
577578
def quantized_mul(*, input, other, acc_out_ty=None):
578579
assert acc_out_ty is not None
579-
qparams = acc_utils.get_field_from_acc_out_ty(acc_out_ty, "qparams")
580+
qparams = TensorMetadata(*acc_out_ty).qparams
580581
return torch.ops.quantized.mul(
581582
input,
582583
other,
@@ -604,8 +605,8 @@ def quantized_mul(*, input, other, acc_out_ty=None):
604605
@register_acc_op
605606
def quantize_per_tensor(*, input, acc_out_ty=None):
606607
assert acc_out_ty is not None
607-
qparams = acc_utils.get_field_from_acc_out_ty(acc_out_ty, "qparams")
608-
dtype = acc_utils.get_field_from_acc_out_ty(acc_out_ty, "dtype")
608+
qparams = TensorMetadata(*acc_out_ty).qparams
609+
dtype = TensorMetadata(*acc_out_ty).dtype
609610
return torch.quantize_per_tensor(
610611
input, qparams["scale"], qparams["zero_point"], dtype
611612
)
@@ -631,8 +632,8 @@ def quantize_per_tensor(*, input, acc_out_ty=None):
631632
@register_acc_op
632633
def quantize_per_channel(*, input, acc_out_ty=None):
633634
assert acc_out_ty is not None
634-
qparams = acc_utils.get_field_from_acc_out_ty(acc_out_ty, "qparams")
635-
dtype = acc_utils.get_field_from_acc_out_ty(acc_out_ty, "dtype")
635+
qparams = TensorMetadata(*acc_out_ty).qparams
636+
dtype = TensorMetadata(*acc_out_ty).dtype
636637
return torch.quantize_per_channel(
637638
input,
638639
torch.tensor(qparams["scale"]),
@@ -1152,7 +1153,7 @@ def quantized_conv2d(
11521153
padding_mode,
11531154
acc_out_ty,
11541155
):
1155-
qparams = acc_utils.get_field_from_acc_out_ty(acc_out_ty, "qparams")
1156+
qparams = TensorMetadata(*acc_out_ty).qparams
11561157
return torch.nn.quantized.functional.conv2d(
11571158
input=input,
11581159
weight=weight,
@@ -1359,7 +1360,7 @@ def tuple_construct(*, tensors):
13591360
def quantized_batch_norm2d(
13601361
*, input, running_mean, running_var, weight, bias, eps, acc_out_ty
13611362
):
1362-
qparams = acc_utils.get_field_from_acc_out_ty(acc_out_ty, "qparams")
1363+
qparams = TensorMetadata(*acc_out_ty).qparams
13631364
return torch.ops.quantized.batch_norm2d(
13641365
input,
13651366
weight,
@@ -1573,9 +1574,7 @@ def custom_narrow_mapper(node: torch.fx.Node, mod: nn.Module) -> torch.fx.Node:
15731574
@register_acc_op
15741575
def reshape(*, input, acc_out_ty=None):
15751576
assert acc_out_ty is not None
1576-
return torch.reshape(
1577-
input, tuple(acc_utils.get_field_from_acc_out_ty(acc_out_ty, "shape"))
1578-
)
1577+
return input.reshape(TensorMetadata(*acc_out_ty).shape)
15791578

15801579

15811580
@register_custom_acc_mapper_fn(
@@ -1615,7 +1614,7 @@ def custom_tensor_reshape_mapper(node: torch.fx.Node, _: nn.Module) -> torch.fx.
16151614
@register_acc_op
16161615
def to_dtype(input, acc_out_ty=None):
16171616
assert acc_out_ty is not None
1618-
return input.to(dtype=acc_utils.get_field_from_acc_out_ty(acc_out_ty, "dtype"))
1617+
return input.to(dtype=TensorMetadata(*acc_out_ty).dtype)
16191618

16201619

16211620
@register_custom_acc_mapper_fn(

torch/fx/experimental/fx_acc/acc_utils.py

Lines changed: 0 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -78,20 +78,6 @@ def is_acc_op_with_kwarg(
7878
return kwarg in inspect.signature(inspect.unwrap(target)).parameters
7979

8080

81-
def get_field_from_acc_out_ty(
82-
acc_out_ty_or_dict: Union[Tuple, Dict[str, Any]], field: str
83-
):
84-
"""
85-
After tracing NamedTuple inputs are converted to standard tuples, so we cannot
86-
access them by name directly. Use this helper instead.
87-
"""
88-
if isinstance(acc_out_ty_or_dict, dict):
89-
acc_out_ty = acc_out_ty_or_dict["acc_out_ty"]
90-
else:
91-
acc_out_ty = acc_out_ty_or_dict
92-
return acc_out_ty[TensorMetadata._fields.index(field)]
93-
94-
9581
def serialize_module_json_to_file(fx_module: GraphModule, fname: str):
9682
weights: Dict = {}
9783
serialized_json = json.dumps(serialize_module(fx_module, weights), indent=2)

0 commit comments

Comments
 (0)