Commit ad637a4

ani300 authored and pytorchmergebot committed

Add support for index_put_ in NT (#135722)

Pull Request resolved: #135722
Approved by: https://github.com/jbschlosser

1 parent f14f245 · commit ad637a4

3 files changed (+221, -1 lines)
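
For context, here is a minimal sketch of the behavior this commit adds, adapted from the new test and sample inputs below (run on CPU for illustration, and assuming a build that includes this change):

import torch

# Build a jagged-layout NestedTensor (NJT) over a flat (7, 3) values buffer.
offsets = torch.tensor([0, 2, 5, 7])
lengths = torch.tensor([2, 2, 2])
nt = torch.nested.nested_tensor_from_jagged(torch.zeros(7, 3), offsets, lengths)

# Advanced-indexing assignment now routes to the jagged index_put_ kernel:
# component 0 gets position (0, 0) set, component 1 gets (1, 0), component 2 gets (1, 0).
indices = [
    torch.tensor([0, 1, 2]),  # batch dimension
    torch.tensor([0, 1, 1]),  # ragged dimension (must stay within each component's length)
    torch.tensor([0, 0, 0]),  # trailing dense dimension
]
nt[indices] = 1.0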

test/test_nestedtensor.py

Lines changed: 64 additions & 0 deletions
@@ -6194,6 +6194,38 @@ def test_copy_(self, device):
         ):
             a.copy_(b)
 
+    # This can't happen in the opinfo tests due to subprocess creation
+    @unittest.skipIf(
+        TEST_WITH_ROCM,
+        "In ROCm, kernel asserts are disabled due to performance overhead",
+    )
+    def test_index_put_error(self, device):
+        import subprocess
+
+        with self.subTest():
+            r = subprocess.call(
+                [
+                    sys.executable,
+                    "-c",
+                    """\
+import torch
+offsets = torch.tensor([0, 2, 5, 7], device='cuda')
+lengths = torch.tensor([2, 2, 2], device='cuda')
+indices = [
+    torch.tensor([0, 1, 2], device='cuda'),
+    torch.tensor([0, 2, 1], device='cuda'),
+    torch.tensor([0, 0, 0], device='cuda'),
+]
+a = torch.nested.nested_tensor_from_jagged(
+    torch.zeros(7, 3, device='cuda'), offsets, lengths
+)
+a[indices] = 1.0
+torch.cuda.synchronize()
+""",
+                ]
+            )
+            self.assertTrue(r != 0)
+
     @skipIfTorchDynamo("Dynamo doesn't know how to trace prof.events()")
     def test_profiler_sequence_nr(self):
         with torch.profiler.profile() as prof:

@@ -7915,6 +7947,12 @@ def test_forward(self, device, dtype, op):
             out_ref = op.ref(op, sample)
             self.assertEqualIgnoringNestedInts(out, out_ref)
 
+            # TODO: Revisit once https://github.com/pytorch/pytorch/pull/138369 lands
+            # TODO: Add xfails for other inplace ops instead of hardcoding
+            if op.inplace_variant and "index_put" in op.full_name:
+                op.inplace_variant(sample.input, *sample.args, **sample.kwargs)
+                self.assertEqualIgnoringNestedInts(sample.input, out_ref)
+
     @withXFails(BACKWARD_FAILURES)
     @ops(
         [op for op in njt_op_db if op.supports_njt and op.supports_autograd],

@@ -7970,6 +8008,32 @@ def f(*args, **kwargs):
             else:
                 self.assertEqual(out_compile, out_ref)
 
+            # TODO: Revisit once https://github.com/pytorch/pytorch/pull/138369 lands
+            # TODO: Add xfails for other inplace ops instead of hardcoding
+            if op.inplace_variant and "index_put" in op.full_name:
+                op_fn = op.inplace_variant
+
+                def in_f(*args, **kwargs):
+                    return op_fn(*args, **kwargs)
+
+                compiled_in_f = torch.compile(
+                    in_f, fullgraph=True, backend="aot_eager_decomp_partition"
+                )
+
+                if sample.input.is_contiguous():
+                    compiled_in_f(sample.input, *sample.args, **sample.kwargs)
+                    if op.full_name in COMPARE_TENSOR_COMPONENT_EQUALITY:
+                        self.assertEqualIgnoringNestedInts(sample.input, out_ref)
+                    else:
+                        self.assertEqual(sample.input, out_ref)
+                else:
+                    # see https://github.com/pytorch/pytorch/issues/106456
+                    with self.assertRaisesRegex(
+                        RuntimeError,
+                        "Mutations on non-contiguous inputs are currently not allowed on tensor subclasses",
+                    ):
+                        compiled_in_f(sample.input, *sample.args, **sample.kwargs)
+
     @withXFails(COMPILE_BACKWARD_FAILURES)
     @ops(
         [op for op in njt_op_db if op.supports_njt and op.supports_autograd],

torch/nested/_internal/ops.py

Lines changed: 93 additions & 0 deletions
@@ -1558,6 +1558,99 @@ def slice_tensor(func, *args, **kwargs):
     return NestedTensor(func(inp._values, **new_kwargs), **extract_kwargs(inp))
 
 
+@register_jagged_func(
+    torch.ops.aten.index_put.default,
+    "input: jt_all, indices: any, values: t, accumulate: any?",
+)
+@register_jagged_func(
+    torch.ops.aten.index_put_.default,
+    "input: jt_all, indices: any, values: t, accumulate: any?",
+)
+def index_put_(func, *args, **kwargs):
+    _, new_kwargs = normalize_function(  # type: ignore[misc]
+        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
+    )
+
+    inp: NestedTensor = new_kwargs.pop("input")
+
+    # For index_put_ to work, we add together the indices of the ragged dimension
+    # and the batch dimension, adding the offsets of each ragged dimension to its
+    # indices
+
+    indices = new_kwargs.pop("indices")
+
+    assert len(indices) <= inp.dim()
+
+    if len(indices) < inp._ragged_idx + 1:
+        if not inp.is_contiguous():
+            raise RuntimeError(
+                "index_put(): If ragged dimension is not part of indices, this only works on contiguous NJTs"
+            )
+        # Ragged dim is NOT part of indices, we need to pad the nested tensor to apply func
+        from .nested_tensor import nested_from_padded
+
+        min_seqlen = inp._maybe_min_seqlen
+        max_seqlen = inp._maybe_max_seqlen
+        padded_max_S = max_seqlen
+        total_L = inp._values.shape[inp._ragged_idx - 1]
+        if padded_max_S is None:
+            # use upper bound on max seqlen if it's not present
+            padded_max_S = total_L
+
+        padded_shape = (
+            *inp.shape[: inp._ragged_idx],
+            padded_max_S,
+            *inp.shape[inp._ragged_idx + 1 :],
+        )
+        padded_inp = inp.to_padded_tensor(0.0, output_size=padded_shape)
+        new_njt = nested_from_padded(
+            func(padded_inp, indices, **new_kwargs),
+            offsets=inp._offsets,
+            ragged_idx=inp._ragged_idx,
+            sum_S=total_L,
+            min_seqlen=min_seqlen,
+            max_seqlen=max_seqlen,
+        )
+
+        if func == torch.ops.aten.index_put_.default:
+            inp._values.copy_(new_njt.values())
+            return inp
+        return new_njt
+
+    # We can run on the underlying values directly
+
+    # Validate indices
+    if inp.lengths() is None:
+        lengths = inp.offsets().diff()
+    else:
+        lengths = inp.lengths()
+    torch._assert_async(
+        torch.all(indices[inp._ragged_idx] < lengths),
+        "Some indices in the ragged dimension are out of bounds!",
+    )
+
+    # Recompute indices for _values
+    ragged_indices = inp.offsets()[indices[0]] + indices[inp._ragged_idx]
+    func_indices = (
+        # before ragged dim
+        indices[1 : inp._ragged_idx]
+        # ragged dim (combined with batch)
+        + [ragged_indices]
+        # after ragged dim
+        + indices[inp._ragged_idx + 1 :]
+    )
+
+    if func == torch.ops.aten.index_put_.default:
+        inp._values = func(inp._values, func_indices, **new_kwargs)
+        return inp
+
+    return NestedTensor(
+        func(inp._values, func_indices, **new_kwargs),
+        **extract_kwargs(inp),
+        lengths=inp.lengths(),
+    )
+
+
 @register_jagged_func(
     torch.ops.aten.convolution.default,
     "input: jt, weight: t, bias: t?, stride: any, padding: any, "

torch/testing/_internal/opinfo/definitions/nested.py

Lines changed: 64 additions & 1 deletion
@@ -106,6 +106,29 @@ def _slice_input(t, i=i, inp=nt_inp):
         args = tree_map(_slice_input, sample.args)
         kwargs = tree_map(_slice_input, sample.kwargs)
 
+        # Handle indices in index_put
+        if "index_put" in op.full_name and "indices" in kwargs:
+            if len(kwargs["indices"]) > 1:
+                # If after unrolling we still have indices left, use them
+                kwargs["indices"] = [t[i] for t in kwargs["indices"][1:]]
+            else:
+                # If no indices are left, create them so they match the NJT implementation
+                sequence_put = kwargs["indices"][0].tolist()
+                if i in sequence_put:
+                    kwargs["indices"] = [
+                        torch.tensor(
+                            list(range(inp.shape[0])),
+                            dtype=torch.int32,
+                            device=kwargs["indices"][0].device,
+                        )
+                    ]
+                else:
+                    kwargs["indices"] = [
+                        torch.tensor(
+                            [], dtype=torch.int32, device=kwargs["indices"][0].device
+                        )
+                    ]
+
         from torch._prims_common import canonicalize_dims
 
         # Need to adjust dim to apply on NJT component

@@ -115,7 +138,6 @@ def _slice_input(t, i=i, inp=nt_inp):
 
         # TODO: handle this
         assert "dims" not in kwargs
-
         out_ref_component = op.op(inp, *args, **kwargs)
 
         # TODO: handle list / tuple / non-NJT outputs

@@ -449,6 +471,46 @@ def sample_inputs_nn_functional_embedding(
     )
 
 
+def sample_inputs_index_put(
+    op_info, device, dtype, requires_grad, op_kwargs=None, **kwargs
+):
+    for njt in _sample_njts(
+        device=device, dtype=dtype, requires_grad=requires_grad, dims=[2, 3, 4]
+    ):
+        for dim in range(njt.dim()):
+            indices = [
+                torch.tensor(list(range(njt.size(0))), device=njt.device),
+                *[
+                    torch.tensor([0] * njt.size(0), device=njt.device)
+                    for _ in range(dim - 1)
+                ],
+            ]
+            yield SampleInput(
+                njt.clone().detach(),
+                kwargs={
+                    "indices": indices,
+                    "values": torch.tensor(1.0, device=njt.device),
+                },
+            )
+
+    # Non-cont NJT for completeness
+    offsets = torch.tensor([0, 2, 5, 7], device=device)
+    lengths = torch.tensor([2, 2, 2], device=device)
+    indices = [
+        torch.tensor([0, 1, 2], device=device),
+        torch.tensor([0, 1, 1], device=device),
+        torch.tensor([0, 0, 0], device=device),
+    ]
+    a = torch.nested.nested_tensor_from_jagged(
+        torch.zeros(7, 3, device=device), offsets, lengths
+    )
+
+    yield SampleInput(
+        a.clone().detach(),
+        kwargs={"indices": indices, "values": torch.tensor(1.0, device=a.device)},
+    )
+
+
 def sample_inputs_nn_functional_embedding_bag(
     op_info, device, dtype, requires_grad, **kwargs
 ):

@@ -591,6 +653,7 @@ def sample_inputs_nn_functional_rms_norm(
     "to": sample_inputs_to,
     "matmul": sample_inputs_matmul,
     "masked_select": sample_inputs_masked_select,
+    "index_put": sample_inputs_index_put,
 }
 
 njt_references = {
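
As a rough illustration (hypothetical values, not from the patch) of how the reference path above maps whole-NJT indices to per-component indices when unbinding, the batch index tensor is dropped and entry i of each remaining index tensor is kept:

import torch

indices = [
    torch.tensor([0, 1, 2]),  # batch dimension
    torch.tensor([0, 1, 1]),  # ragged dimension
    torch.tensor([0, 0, 0]),  # trailing dense dimension
]

# Mirrors `[t[i] for t in kwargs["indices"][1:]]` for each unbound component i.
for i in range(3):
    print(i, [t[i] for t in indices[1:]])
# 0 [tensor(0), tensor(0)]  -> component 0: element (0, 0) gets the value
# 1 [tensor(1), tensor(0)]  -> component 1: element (1, 0)
# 2 [tensor(1), tensor(0)]  -> component 2: element (1, 0)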
