Skip to content

Commit c263bd4

Browse files
atalmanNatalia Gimelshein
andauthored
[inductor] use triu ref instead of lowering (#96040) (#96462)
Fixes #95958 Generated code is functionally identical with ref and lowering, only minor differences Pull Request resolved: #96040 Approved by: https://github.com/jansel Co-authored-by: Natalia Gimelshein <ngimel@fb.com>
1 parent c9913cf commit c263bd4

File tree

3 files changed

+2
-24
lines changed

3 files changed

+2
-24
lines changed

test/inductor/test_torchinductor_opinfo.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -448,6 +448,7 @@ def wrapper_set_seed(op, *args, **kwargs):
448448
"mT",
449449
"mH",
450450
"rsub",
451+
"triu",
451452
}
452453

453454

torch/_decomp/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -308,6 +308,7 @@ def core_aten_decompositions() -> Dict[OpOverload, Callable]:
308308
aten.trace,
309309
aten.transpose.int,
310310
aten.tril.default,
311+
aten.triu.default,
311312
aten.unfold,
312313
aten.unfold_backward,
313314
aten.upsample_bilinear2d,

torch/_inductor/lowering.py

Lines changed: 0 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1505,30 +1505,6 @@ def fn(index):
15051505
)
15061506

15071507

1508-
@register_lowering(aten.triu)
1509-
def triu(x, diagonal=0):
1510-
x_loader = x.make_loader()
1511-
dtype = x.get_dtype()
1512-
1513-
def inner_fn(index):
1514-
*_, i, j = index
1515-
return ops.where(
1516-
ops.ge(
1517-
ops.index_expr(j - i - diagonal, torch.int32),
1518-
ops.constant(0, torch.int32),
1519-
),
1520-
x_loader(index),
1521-
ops.constant(0, dtype),
1522-
)
1523-
1524-
return Pointwise.create(
1525-
device=x.get_device(),
1526-
dtype=dtype,
1527-
inner_fn=inner_fn,
1528-
ranges=list(x.get_size()),
1529-
)
1530-
1531-
15321508
@register_lowering(aten.select_scatter, type_promotion_kind=None)
15331509
def select_scatter(x, src, dim: int, index: int):
15341510
assert x.get_dtype() == src.get_dtype()

0 commit comments

Comments
 (0)