Skip to content

Commit 184b78c

Browse files
jfix71pytorchmergebot
authored andcommitted
[acc_ops] Move slice_tensor to consider single dim at a time (#5906)
Summary: Pull Request resolved: pytorch/glow#5906 Pull Request resolved: #71883 Fixes slice_tensor retracing. Include fix for retrace coverage. Missed in D33760455 (66939e3). Test Plan: CI Reviewed By: wushirong Differential Revision: D33802222 fbshipit-source-id: 4e0e44ae4a4eb70b99d79f0cd582182031b87e25 (cherry picked from commit 98fd23c)
1 parent 082ff25 commit 184b78c

File tree

3 files changed

+50
-36
lines changed

3 files changed

+50
-36
lines changed

test/fx_acc/test_acc_tracer.py

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,7 @@ def forward(self, a: torch.Tensor) -> torch.Tensor:
9494

9595
ref_outputs = m(a)
9696
outputs = traced(a)
97-
traced_again = acc_tracer.trace(m, [a])
97+
traced_again = acc_tracer.trace(traced, [a])
9898
outputs_again = traced_again(a)
9999
if isinstance(ref_outputs, torch.Tensor):
100100
ref_outputs = [ref_outputs]
@@ -1881,6 +1881,27 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
18811881
for i, j in zip(ref_output, output):
18821882
self.assertTrue(torch.equal(i, j))
18831883

1884+
@parameterized.expand(
1885+
[
1886+
("neg_1", -1, 1, 3),
1887+
("neg_2", -2, 1, 3),
1888+
("neg_4", -4, 1, 1),
1889+
]
1890+
)
1891+
def test_negative_slicing(self, _, dim, start, length):
1892+
"""
1893+
Test that slicing with negative dims works.
1894+
"""
1895+
self._make_acc_op_function_test(
1896+
acc_ops.slice_tensor,
1897+
torch.narrow,
1898+
input_shape=(2, 3, 4, 5),
1899+
validate_same_kwargs=False,
1900+
dim=dim,
1901+
start=start,
1902+
length=length,
1903+
)
1904+
18841905
def test_list_input(self):
18851906
"""
18861907
Test that list inputs are traced correctly.

torch/fx/experimental/fx2trt/converters/acc_ops_converters.py

Lines changed: 10 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1422,30 +1422,26 @@ def acc_ops_slice_tensor(
14221422
"of the TensorRT region!")
14231423

14241424
ranks = len(input_val.shape) + (1 if network.has_implicit_batch_dimension else 0)
1425-
dims = [get_positive_dim(dim, ranks) for dim in cast(Sequence[int], kwargs["dims"])]
1425+
dim = get_positive_dim(cast(int, kwargs["dim"]), ranks)
14261426

14271427
if network.has_implicit_batch_dimension:
1428-
if not len(dims):
1429-
raise RuntimeError("dim argument cannot be empty!")
1430-
if any([dim == 0 for dim in dims]):
1428+
if dim == 0:
14311429
raise RuntimeError(
1432-
f"We do not support slice_tensor at batch dim when it's implicit, got {dims}!"
1430+
f"We do not support slice_tensor at batch dim when it's implicit, got {dim}!"
14331431
)
1434-
dims = [d - 1 for d in dims]
1432+
dim = dim - 1
14351433
else:
14361434
raise RuntimeError("We don't support slice_tensor with explicit batch dimension yet!")
14371435

1436+
start_int = cast(int, kwargs["start"])
1437+
stop_int = cast(int, kwargs["stop"])
1438+
step_int = cast(int, kwargs["step"])
14381439
start = [0] * len(input_val.shape)
1440+
start[dim] = start_int
14391441
stride = [1] * len(start)
1442+
stride[dim] = step_int
14401443
output_shape = list(input_val.shape)
1441-
starts = cast(Sequence[int], kwargs["starts"])
1442-
stops = cast(Sequence[int], kwargs["stops"])
1443-
steps = cast(Sequence[int], kwargs["steps"])
1444-
1445-
for i, dim in enumerate(dims):
1446-
start[dim] = starts[i]
1447-
stride[dim] = steps[i]
1448-
output_shape[dim] = (stops[i] - starts[i]) // steps[i]
1444+
output_shape[dim] = (stop_int - start_int) // step_int
14491445

14501446
layer = network.add_slice(input_val, start=start, shape=output_shape, stride=stride)
14511447
set_layer_name(layer, target, name)

torch/fx/experimental/fx_acc/acc_ops.py

Lines changed: 18 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import warnings
44

55
import torch # isort:skip
6-
from typing import Sequence, Optional, List, cast
6+
from typing import Sequence, List, cast
77

88
import torch.fx.experimental.fx_acc.acc_utils as acc_utils
99
import torch.nn as nn
@@ -1310,10 +1310,10 @@ def torch_split_mapper(node: torch.fx.Node, mod: nn.Module) -> torch.fx.Node:
13101310
assert isinstance(i, int)
13111311
new_kwargs = {
13121312
"input": node.kwargs["input"],
1313-
"dims": (node.kwargs["dim"],),
1314-
"starts": (start,),
1315-
"stops": (start + i,),
1316-
"steps": (1,),
1313+
"dim": node.kwargs["dim"],
1314+
"start": start,
1315+
"stop": start + i,
1316+
"step": 1,
13171317
}
13181318
new_node = node.graph.call_function(slice_tensor, kwargs=new_kwargs)
13191319
new_node.meta["type"] = torch.Tensor
@@ -1504,19 +1504,16 @@ def getitem(*, input, idx):
15041504

15051505
@register_acc_op_properties(AccOpProperty.unary)
15061506
@register_acc_op
1507-
def slice_tensor(*, input, dims, starts, stops, steps):
1508-
slices: List[Optional[slice]] = [None for _ in range(input.dim())]
1509-
1510-
# For all provided dims, extract out a slice for starts/stops/steps.
1511-
for idx, dim in enumerate(dims):
1512-
slices[dim] = slice(starts[idx], stops[idx], steps[idx])
1513-
1514-
# For all unspecified dims, default to the full slice.
1515-
for idx, s in enumerate(slices):
1516-
if s is None:
1517-
slices[idx] = slice(None, None, None)
1507+
def slice_tensor(*, input, dim, start, stop, step):
1508+
slc = slice(start, stop, step)
1509+
if dim >= 0:
1510+
slices: List[slice] = [slice(None, None, None) for _ in range(dim)]
1511+
slices.append(slc)
1512+
else:
1513+
slices = [Ellipsis, slc] # type: ignore[list-item]
1514+
slices.extend([slice(None, None, None) for _ in range(-dim - 1)])
15181515

1519-
return input[slices]
1516+
return input[tuple(slices)]
15201517

15211518

15221519
@register_custom_acc_mapper_fn(
@@ -1543,10 +1540,10 @@ def custom_narrow_mapper(node: torch.fx.Node, mod: nn.Module) -> torch.fx.Node:
15431540
)
15441541
kwargs = {
15451542
"input": node.kwargs["input"],
1546-
"dims": (node.kwargs["dim"],),
1547-
"starts": (node.kwargs["start"],),
1548-
"stops": (node.kwargs["start"] + node.kwargs["length"],),
1549-
"steps": (1,),
1543+
"dim": node.kwargs["dim"],
1544+
"start": node.kwargs["start"],
1545+
"stop": node.kwargs["start"] + node.kwargs["length"],
1546+
"step": 1,
15501547
}
15511548
with node.graph.inserting_before(node):
15521549
new_node = node.graph.call_function(slice_tensor, kwargs=kwargs)

0 commit comments

Comments
 (0)