Commit cf0b5c7 (1 parent: ee28b86)

Add meta support for scalar_tensor and argmax

ghstack-source-id: 784b69e
Pull Request resolved: #88590

4 files changed: +82 −4 lines
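A meta registration lets PyTorch propagate output shapes and dtypes for an operator without allocating storage or running a kernel, which is what tracing-based systems need. A minimal sketch of what this commit enables, assuming a build that includes these registrations:

import torch

# Meta tensors carry only shape/dtype/device metadata; no data is
# allocated and no compute kernel runs.
x = torch.empty(3, 4, device="meta")

# With the new argmax meta function, the output shape and dtype are
# inferred without executing the reduction:
out = torch.argmax(x, dim=1)
print(out.shape, out.dtype, out.device)  # torch.Size([3]) torch.int64 meta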

test/functorch/test_vmap.py (1 addition, 0 deletions)

@@ -3229,6 +3229,7 @@ def test():
         xfail('linspace', ''),  # test runner can't handle factory functions
         xfail('arange', ''),  # test runner can't handle factory functions
         xfail('logspace', ''),  # test runner can't handle factory functions
+        xfail('scalar_tensor'),  # test runner can't handle factory functions
         xfail('empty', ''),  # test runner can't handle factory functions
         xfail('ones', ''),  # test runner can't handle factory functions
         xfail('zeros', ''),  # test runner can't handle factory functions

test/test_proxy_tensor.py (2 additions, 4 deletions)

@@ -1117,8 +1117,8 @@ def f(a, b, c, d, e):
         skip('masked.logsumexp', ''),  # Tensors of type TensorImpl do not have numel
         xfail('masked.amax', ''),  # aten._to_copy.default - couldn't find symbolic meta function/decomposition
         xfail('masked.amin', ''),  # aten._to_copy.default - couldn't find symbolic meta function/decomposition
-        xfail('masked.argmax', ''),  # aten.argmax.default - couldn't find symbolic meta function/decomposition
-        xfail('masked.argmin', ''),  # aten.argmin.default - couldn't find symbolic meta function/decomposition
+        xfail('masked.argmax', ''),  # broadcast_to(): argument 'size' (position 2) must be tuple of ints, but found ...
+        xfail('masked.argmin', ''),  # broadcast_to(): argument 'size' (position 2) must be tuple of ints, but found ...
         xfail('masked.cumprod', ''),  # aten._to_copy.default - couldn't find symbolic meta function/decomposition
         xfail('masked.cumsum', ''),  # aten._to_copy.default - couldn't find symbolic meta function/decomposition
         xfail('masked.log_softmax', ''),  # aten._to_copy.default - couldn't find symbolic meta function/decomposition
@@ -1135,8 +1135,6 @@ def f(a, b, c, d, e):
         xfail('addmv', ''),  # aten.addmv.default - couldn't find symbolic meta function/decomposition
         xfail('addr', ''),  # aten.size.default - couldn't find symbolic meta function/decomposition
         xfail('aminmax', ''),  # aten.aminmax.default - couldn't find symbolic meta function/decomposition
-        xfail('argmax', ''),  # aten.argmax.default - couldn't find symbolic meta function/decomposition
-        xfail('argmin', ''),  # aten.argmin.default - couldn't find symbolic meta function/decomposition
         xfail('argsort', ''),  # aten.sort.default - couldn't find symbolic meta function/decomposition
         xfail('argwhere', ''),  # aten.nonzero.default - couldn't find symbolic meta function/decomposition
         xfail('baddbmm', ''),  # aten.baddbmm.default - couldn't find symbolic meta function/decomposition
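Dropping the plain argmax/argmin xfails reflects that symbolic tracing no longer fails with "couldn't find symbolic meta function/decomposition" for these ops (the masked.* variants still xfail, but now for an unrelated broadcast_to reason). A hedged sketch of the kind of trace that should now succeed, using the make_fx entry point this test file is built around:

import torch
from torch.fx.experimental.proxy_tensor import make_fx

def f(x):
    return torch.argmax(x, dim=1)

# Before this commit, symbolic tracing of argmax raised
# "aten.argmax.default - couldn't find symbolic meta function/decomposition".
gm = make_fx(f, tracing_mode="symbolic")(torch.randn(3, 4))
print(gm.code)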

torch/_meta_registrations.py (50 additions, 0 deletions)

@@ -1636,6 +1636,56 @@ def upsample_nearest2d_vec(input, output_size, scale_factors):
     )
 
 
+def zero_numel_check_dims(self, dim, fn_name):
+    if self.ndim == 0:
+        check(
+            dim == 0 or dim == -1,
+            lambda: f"{fn_name}: Expected reduction dim -1 or 0 for scalar but got {dim}",
+            IndexError,
+        )
+    else:
+        check(
+            self.size(dim) != 0,
+            lambda: f"{fn_name}: Expected reduction dim {dim} to have non-zero size.",
+            IndexError,
+        )
+
+
+# From aten/src/ATen/native/ReduceOps.cpp
+def check_argmax_argmin(name, self, dim):
+    if dim is not None:
+        dim = maybe_wrap_dim(dim, self.dim())
+        zero_numel_check_dims(self, dim, name)
+    else:
+        check(
+            self.numel() != 0,
+            lambda: f"{name}: Expected reduction dim to be specified for input.numel() == 0.",
+        )
+
+
+@register_meta(aten.argmax.default)
+def argmax_meta(self, dim=None, keepdim=False):
+    check_argmax_argmin("argmax", self, dim)
+    dims = utils.reduction_dims(self.shape, (dim,) if dim is not None else None)
+    shape = _compute_reduction_shape(self, dims, keepdim)
+    return self.new_empty(shape, dtype=torch.int64)
+
+
+@register_meta(aten.argmin.default)
+def argmin_meta(self, dim=None, keepdim=False):
+    check_argmax_argmin("argmin", self, dim)
+    dims = utils.reduction_dims(self.shape, (dim,) if dim is not None else None)
+    shape = _compute_reduction_shape(self, dims, keepdim)
+    return self.new_empty(shape, dtype=torch.int64)
+
+
+@register_meta(aten.scalar_tensor.default)
+def scalar_tensor(s, dtype=None, layout=None, device=None, pin_memory=None):
+    return torch.empty(
+        (), dtype=dtype, layout=layout, device=device, pin_memory=pin_memory
+    )
+
+
 # We must also trigger meta registrations from PrimTorch ref
 # decompositions
 import torch._refs
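These functions port eager mode's input validation along with the shape logic, so misuse fails at trace time with the same error text as real execution. A small illustration of the checked paths, again assuming a build that includes this commit:

import torch

# Reducing over a zero-sized dim trips zero_numel_check_dims, mirroring
# the checks in aten/src/ATen/native/ReduceOps.cpp:
try:
    torch.argmax(torch.empty(0, 4, device="meta"), dim=0)
except IndexError as e:
    print(e)  # argmax: Expected reduction dim 0 to have non-zero size.

# dim=None on a non-empty input reduces over all dims to a 0-dim result:
print(torch.argmax(torch.empty(2, 3, device="meta")).shape)  # torch.Size([])

# The scalar_tensor meta kernel just builds an empty 0-dim tensor with
# the requested dtype:
s = torch.scalar_tensor(1, dtype=torch.float64, device="meta")
print(s.shape, s.dtype)  # torch.Size([]) torch.float64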

torch/testing/_internal/common_methods_invocations.py (29 additions, 0 deletions)

@@ -1357,6 +1357,14 @@ def sample_inputs_empty(op, device, dtype, requires_grad, **kwargs):
     for case in cases:
         yield SampleInput(case, device=device, dtype=dtype, requires_grad=requires_grad)
 
+def sample_inputs_scalar_tensor(op, device, dtype, requires_grad, **kwargs):
+    # scalar values to wrap, including a negative, zero, and a 0-dim tensor
+    vals = (-5, 0, 1, torch.tensor(2))
+
+    for item in vals:
+        _kwargs = {'device': device, 'dtype': dtype, 'requires_grad': requires_grad}
+        yield SampleInput(item, args=(), kwargs=_kwargs)
+
 def sample_inputs_eye(op, device, dtype, requires_grad, **kwargs):
     # only ints >= 0 are allowed for both arguments, unless m is omitted
     sizes = (None, 0, 1, 2, 3, 4, 7, L, M, S)
@@ -14440,6 +14448,27 @@ def reference_flatten(input, start_dim=0, end_dim=-1):
            # UserWarning not triggered : Resized a non-empty tensor but did not warn about it.
            DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out_warning'),
        )),
+    OpInfo('scalar_tensor',
+           dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16, torch.chalf),
+           sample_inputs_func=sample_inputs_scalar_tensor,
+           supports_autograd=False,
+           supports_out=False,
+           skips=(
+               DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'),
+               # TODO: same as this?
+               # https://github.com/pytorch/pytorch/issues/81774
+               # also see: arange, new_full
+               # fails to match any schemas despite working in the interpreter
+               DecorateInfo(unittest.expectedFailure, 'TestOperatorSignatures', 'test_get_torch_func_signature_exhaustive'),
+               # fails to match any schemas despite working in the interpreter
+               DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'),
+               # skip these tests since we have non-tensor input
+               DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_noncontiguous_samples'),
+               DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_variant_consistency_eager'),
+               DecorateInfo(unittest.skip("Skipped!"), 'TestMathBits', 'test_conj_view'),
+               DecorateInfo(unittest.skip("Skipped!"), 'TestMathBits', 'test_neg_conj_view'),
+               DecorateInfo(unittest.skip("Skipped!"), 'TestMathBits', 'test_neg_view'),
+           )),
     OpInfo('new_full',
            op=lambda x, *args, **kwargs: x.new_full(*args, **kwargs),
            dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16, torch.chalf),
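Adding the OpInfo is what makes the rest of the suite (vmap, proxy tensor, meta tests) exercise scalar_tensor automatically. To see what the new sample function generates, one can iterate the entry directly; op_db and OpInfo.sample_inputs are existing test-infra entry points, and the fields printed below are illustrative:

import torch
from torch.testing._internal.common_methods_invocations import op_db

# Locate the new entry and enumerate its samples.
info = next(op for op in op_db if op.name == "scalar_tensor")
for sample in info.sample_inputs("cpu", torch.float32):
    # Each SampleInput wraps one value from (-5, 0, 1, torch.tensor(2))
    # plus device/dtype/requires_grad kwargs.
    print(sample.input, sample.kwargs)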
