Skip to content

Commit 8f143f0

Browse files
committed
Add meta impl for topk
ghstack-source-id: 4044cf2 Pull Request resolved: #88694
1 parent cf0b5c7 commit 8f143f0

File tree

2 files changed

+15
-1
lines changed

2 files changed

+15
-1
lines changed

test/test_proxy_tensor.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1352,7 +1352,6 @@ def f(a, b, c, d, e):
13521352
xfail('take_along_dim', ''), # dtype of indices should be Long but got Float
13531353
xfail('take', ''), # aten.take.default - couldn't find symbolic meta function/decomposition
13541354
xfail('tensordot', ''), # aten.size.default - couldn't find symbolic meta function/decomposition
1355-
xfail('topk', ''), # aten.topk.default - couldn't find symbolic meta function/decomposition
13561355
xfail('trapz', ''), # aten.size.default - couldn't find symbolic meta function/decomposition
13571356
xfail('trapezoid', ''), # aten.size.default - couldn't find symbolic meta function/decomposition
13581357
xfail('triangular_solve', ''), # aten.triangular_solve.default - couldn't find symbolic meta function/decomposition

torch/_meta_registrations.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1685,6 +1685,21 @@ def scalar_tensor(s, dtype=None, layout=None, device=None, pin_memory=None):
16851685
(), dtype=dtype, layout=layout, device=device, pin_memory=pin_memory
16861686
)
16871687

@register_meta(aten.topk.default)
def topk_meta(self, k, dim=-1, largest=True, sorted=True):
    """Meta (shape-only) implementation of ``aten.topk``.

    Mirrors the shape/argument checks in aten/src/ATen/native/Sorting.cpp
    and returns empty ``(values, indices)`` tensors of the output shape:
    the input shape with ``dim`` resized to ``k``. ``indices`` is int64.
    ``largest`` and ``sorted`` do not affect output shape, so they are
    accepted but unused here.
    """
    # wrap_scalar=True lets dim be -1 or 0 on a 0-d tensor, as eager does.
    dim = maybe_wrap_dim(dim, self.dim(), wrap_scalar=True)
    # A 0-d tensor is treated as having a single element along `dim`.
    slice_size = self.size(dim) if self.dim() > 0 else 1
    # Single range check; the original performed this identical check twice,
    # so the second ("k not in range for dimension") was unreachable.
    check(
        k >= 0 and k <= slice_size,
        lambda: "selected index k out of range",
    )

    # Output shape: input shape with `dim` shrunk to k (no-op for 0-d input).
    top_k_size = list(self.shape)
    if len(top_k_size) > 0:
        top_k_size[dim] = k
    return self.new_empty(top_k_size), self.new_empty(top_k_size, dtype=torch.int64)
16881703

16891704
# We must also trigger meta registrations from PrimTorch ref
16901705
# decompositions

0 commit comments

Comments
 (0)