@@ -208,7 +208,7 @@ class ErrorInput(object):
208208
209209 __slots__ = ['sample_input', 'error_type', 'error_regex']
210210
211- def __init__(self, sample_input, *, error_type, error_regex):
211+ def __init__(self, sample_input, *, error_type=RuntimeError , error_regex):
212212 self.sample_input = sample_input
213213 self.error_type = error_type
214214 self.error_regex = error_regex
@@ -1474,8 +1474,8 @@ def error_inputs_hsplit(op_info, device, **kwargs):
14741474 dtype=torch.float32,
14751475 device=device),
14761476 args=(0,),)
1477- return (ErrorInput(si1, error_type=RuntimeError, error_regex=err_msg1),
1478- ErrorInput(si2, error_type=RuntimeError, error_regex=err_msg2),)
1477+ return (ErrorInput(si1, error_regex=err_msg1),
1478+ ErrorInput(si2, error_regex=err_msg2),)
14791479
14801480def error_inputs_vsplit(op_info, device, **kwargs):
14811481 err_msg1 = ("torch.vsplit requires a tensor with at least 2 dimension, "
@@ -1491,8 +1491,8 @@ def error_inputs_vsplit(op_info, device, **kwargs):
14911491 dtype=torch.float32,
14921492 device=device),
14931493 args=(0,),)
1494- return (ErrorInput(si1, error_type=RuntimeError, error_regex=err_msg1),
1495- ErrorInput(si2, error_type=RuntimeError, error_regex=err_msg2),)
1494+ return (ErrorInput(si1, error_regex=err_msg1),
1495+ ErrorInput(si2, error_regex=err_msg2),)
14961496
14971497def error_inputs_dsplit(op_info, device, **kwargs):
14981498 err_msg1 = ("torch.dsplit requires a tensor with at least 3 dimension, "
@@ -1508,8 +1508,8 @@ def error_inputs_dsplit(op_info, device, **kwargs):
15081508 dtype=torch.float32,
15091509 device=device),
15101510 args=(0,),)
1511- return (ErrorInput(si1, error_type=RuntimeError, error_regex=err_msg1),
1512- ErrorInput(si2, error_type=RuntimeError, error_regex=err_msg2),)
1511+ return (ErrorInput(si1, error_regex=err_msg1),
1512+ ErrorInput(si2, error_regex=err_msg2),)
15131513
15141514def sample_inputs_linalg_multi_dot(op_info, device, dtype, requires_grad, **kwargs):
15151515 # Each test case consists of the sizes in the chain of multiplications
@@ -3060,12 +3060,12 @@ def error_inputs_gather(op_info, device, **kwargs):
30603060
30613061 # Index should be smaller than self except on dimesion 1
30623062 bad_src = make_tensor((1, 1), device=device, dtype=torch.float32)
3063- yield ErrorInput(SampleInput(bad_src, args=(1, idx,)), error_type=RuntimeError,
3063+ yield ErrorInput(SampleInput(bad_src, args=(1, idx,)),
30643064 error_regex="Size does not match at dimension 0")
30653065
30663066 # Index must have long dtype
30673067 bad_idx = idx.to(torch.int32)
3068- yield ErrorInput(SampleInput(src, args=(1, bad_idx)), error_type=RuntimeError,
3068+ yield ErrorInput(SampleInput(src, args=(1, bad_idx)),
30693069 error_regex="Expected dtype int64 for index")
30703070
30713071 # TODO: FIXME
@@ -3074,28 +3074,28 @@ def error_inputs_gather(op_info, device, **kwargs):
30743074 src = torch.tensor(((1, 2), (3, 4)), device=device, dtype=torch.float32)
30753075 idx = torch.tensor(((0, 0), (1, 0)), device=device, dtype=torch.long)
30763076 out = torch.empty((2, 2), device=device, dtype=torch.float64)
3077- yield ErrorInput(SampleInput(src, args=(1, idx), kwargs={'out': out}), error_type=RuntimeError,
3077+ yield ErrorInput(SampleInput(src, args=(1, idx), kwargs={'out': out}),
30783078 error_regex="Expected out tensor to have dtype")
30793079
30803080 # src and index tensors must have the same # of dimensions
30813081 # idx too few dimensions
30823082 src = torch.tensor(((1, 2), (3, 4)), device=device, dtype=torch.float32)
30833083 idx = torch.tensor((0, 0), device=device, dtype=torch.long)
3084- yield ErrorInput(SampleInput(src, args=(1, idx)), error_type=RuntimeError,
3084+ yield ErrorInput(SampleInput(src, args=(1, idx)),
30853085 error_regex="Index tensor must have the same number of dimensions")
30863086
30873087 # src too few dimensions
30883088 src = torch.tensor((1, 2), device=device, dtype=torch.float32)
30893089 idx = torch.tensor(((0, 0), (1, 0)), device=device, dtype=torch.long)
3090- yield ErrorInput(SampleInput(src, args=(0, idx)), error_type=RuntimeError,
3090+ yield ErrorInput(SampleInput(src, args=(0, idx)),
30913091 error_regex="Index tensor must have the same number of dimensions")
30923092
30933093 # index out of bounds
30943094 # NOTE: this ErrorInput is guarded because bounds checking does not occur on CUDA devices
30953095 if torch.device(device).type == 'cpu':
30963096 src = torch.tensor(((1, 2), (3, 4)), device=device, dtype=torch.float32)
30973097 idx = torch.tensor(((0, 23), (1, 0)), device=device, dtype=torch.long)
3098- yield ErrorInput(SampleInput(src, args=(1, idx,)), error_type=RuntimeError,
3098+ yield ErrorInput(SampleInput(src, args=(1, idx,)),
30993099 error_regex="index 23 is out of bounds for dimension")
31003100
31013101# Error inputs for scatter
@@ -3104,28 +3104,28 @@ def error_inputs_scatter_and_scatter_add(op_info, device, **kwargs):
31043104 src = make_tensor((2, 5), device=device, dtype=torch.float32)
31053105 idx = torch.tensor(((0, 1), (1, 2)), device=device, dtype=torch.long)
31063106 dst = torch.zeros((3, 5), device=device, dtype=torch.double)
3107- yield ErrorInput(SampleInput(dst, args=(0, idx, src)), error_type=RuntimeError,
3107+ yield ErrorInput(SampleInput(dst, args=(0, idx, src)),
31083108 error_regex="Expected self.dtype to be equal to src.dtype")
31093109
31103110 # Index dtype must be long
31113111 src = make_tensor((2, 5), device=device, dtype=torch.float32)
31123112 idx = torch.tensor(((0, 1), (1, 2)), device=device, dtype=torch.int32)
31133113 dst = torch.zeros((3, 5), device=device, dtype=torch.float32)
3114- yield ErrorInput(SampleInput(dst, args=(0, idx, src)), error_type=RuntimeError,
3114+ yield ErrorInput(SampleInput(dst, args=(0, idx, src)),
31153115 error_regex="Expected dtype int64 for index")
31163116
31173117 # Index and destination must have the same number of dimensions
31183118 src = make_tensor((2, 5), device=device, dtype=torch.float32)
31193119 idx = torch.tensor(((0, 1), (1, 2)), device=device, dtype=torch.long)
31203120 dst = torch.zeros((3, 5, 3), device=device, dtype=torch.float32)
3121- yield ErrorInput(SampleInput(dst, args=(0, idx, src)), error_type=RuntimeError,
3121+ yield ErrorInput(SampleInput(dst, args=(0, idx, src)),
31223122 error_regex="Index tensor must have the same number of dimensions as self tensor")
31233123
31243124 # Index and src must have the same number of dimensions when src is not a scalar
31253125 src = make_tensor((2, 5, 2), device=device, dtype=torch.float32)
31263126 idx = torch.tensor(((34, 1), (1, 2)), device=device, dtype=torch.long)
31273127 dst = torch.zeros((3, 5), device=device, dtype=torch.float32)
3128- yield ErrorInput(SampleInput(dst, args=(0, idx, src)), error_type=RuntimeError,
3128+ yield ErrorInput(SampleInput(dst, args=(0, idx, src)),
31293129 error_regex="Index tensor must have the same number of dimensions as src tensor")
31303130
31313131 # Index out of bounds
@@ -3134,7 +3134,7 @@ def error_inputs_scatter_and_scatter_add(op_info, device, **kwargs):
31343134 src = make_tensor((2, 5), device=device, dtype=torch.float32)
31353135 idx = torch.tensor(((34, 1), (1, 2)), device=device, dtype=torch.long)
31363136 dst = torch.zeros((3, 5), device=device, dtype=torch.float32)
3137- yield ErrorInput(SampleInput(dst, args=(0, idx, src)), error_type=RuntimeError,
3137+ yield ErrorInput(SampleInput(dst, args=(0, idx, src)),
31383138 error_regex="index 34 is out of bounds for dimension 0 with size 3")
31393139
31403140def sample_inputs_take_along_dim(op_info, device, dtype, requires_grad, **kwargs):
@@ -5532,6 +5532,39 @@ def sample_inputs_cov(op_info, device, dtype, requires_grad, **kwargs):
55325532 return inputs
55335533
55345534
5535+ def error_inputs_cov(op_info, device, **kwargs):
5536+ a = torch.rand(S, device=device)
5537+ error_inputs = []
5538+ error_inputs.append(ErrorInput(
5539+ SampleInput(torch.rand(S, S, S, device=device)),
5540+ error_regex="expected input to have two or fewer dimensions"))
5541+ error_inputs.append(ErrorInput(
5542+ SampleInput(a, kwargs={'fweights': torch.rand(S, S, device=device)}),
5543+ error_regex="expected fweights to have one or fewer dimensions"))
5544+ error_inputs.append(ErrorInput(
5545+ SampleInput(a, kwargs={'aweights': torch.rand(S, S, device=device)}),
5546+ error_regex="expected aweights to have one or fewer dimensions"))
5547+ error_inputs.append(ErrorInput(
5548+ SampleInput(a, kwargs={'fweights': torch.rand(S, device=device)}),
5549+ error_regex="expected fweights to have integral dtype"))
5550+ error_inputs.append(ErrorInput(
5551+ SampleInput(a, kwargs={'aweights': torch.tensor([1, 1], device=device)}),
5552+ error_regex="expected aweights to have floating point dtype"))
5553+ error_inputs.append(ErrorInput(
5554+ SampleInput(a, kwargs={'fweights': torch.tensor([1], device=device)}),
5555+ error_regex="expected fweights to have the same numel"))
5556+ error_inputs.append(ErrorInput(
5557+ SampleInput(a, kwargs={'aweights': torch.rand(1, device=device)}),
5558+ error_regex="expected aweights to have the same numel"))
5559+ error_inputs.append(ErrorInput(
5560+ SampleInput(a, kwargs={'fweights': torch.tensor([-1, -2, -3, -4 , -5], device=device)}),
5561+ error_regex="fweights cannot be negative"))
5562+ error_inputs.append(ErrorInput(
5563+ SampleInput(a, kwargs={'aweights': torch.tensor([-1., -2., -3., -4., -5.], device=device)}),
5564+ error_regex="aweights cannot be negative"))
5565+ return error_inputs
5566+
5567+
55355568def sample_inputs_svd(op_info, device, dtype, requires_grad=False, **kwargs):
55365569 make_fullrank = make_fullrank_matrices_with_distinct_singular_values
55375570 make_arg = partial(make_fullrank, dtype=dtype, device=device, requires_grad=requires_grad)
@@ -5967,7 +6000,7 @@ def error_inputs_neg(op_info, device, **kwargs):
59676000 msg = ("Negation, the `\\-` operator, on a bool tensor is not supported."
59686001 " If you are trying to invert a mask, use the `\\~` or"
59696002 " `logical_not\\(\\)` operator instead.")
5970- return (ErrorInput(si, error_type=RuntimeError, error_regex=msg),)
6003+ return (ErrorInput(si, error_regex=msg),)
59716004
59726005def sample_inputs_nextafter(op_info, device, dtype, requires_grad, **kwargs):
59736006 make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad)
@@ -7104,7 +7137,7 @@ def error_inputs_where(op_info, device, **kwargs):
71047137 si = SampleInput(make_tensor(shape, device=devices[0], dtype=torch.float32),
71057138 args=(make_tensor(shape, dtype=torch.bool, device=devices[1]),
71067139 make_tensor(shape, device=devices[2], dtype=torch.float32)))
7107- yield ErrorInput(si, error_type=RuntimeError, error_regex=err_msg)
7140+ yield ErrorInput(si, error_regex=err_msg)
71087141
71097142def sample_inputs_nonzero(op_info, device, dtype, requires_grad, **kwargs):
71107143 make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad)
@@ -7164,13 +7197,13 @@ def error_inputs_kthvalue(op_info, device, **kwargs):
71647197 si = SampleInput(t, args=(5,), kwargs={'out': (t, indices)})
71657198
71667199 k_out_of_range_err = "selected number k out of range for dimension"
7167- return (ErrorInput(si, error_type=RuntimeError, error_regex="unsupported operation"),
7200+ return (ErrorInput(si, error_regex="unsupported operation"),
71687201 ErrorInput(SampleInput(torch.randn(2, 2, device=device), args=(3, 0)),
7169- error_type=RuntimeError, error_regex=k_out_of_range_err),
7202+ error_regex=k_out_of_range_err),
71707203 ErrorInput(SampleInput(torch.randn(2, 2, device=device), args=(3,)),
7171- error_type=RuntimeError, error_regex=k_out_of_range_err),
7204+ error_regex=k_out_of_range_err),
71727205 ErrorInput(SampleInput(torch.tensor(2, device=device), args=(3,)),
7173- error_type=RuntimeError, error_regex=k_out_of_range_err),)
7206+ error_regex=k_out_of_range_err),)
71747207
71757208def sample_inputs_dropout(op_info, device, dtype, requires_grad, *,
71767209 train=None, valid_input_dim=None, **kwargs):
@@ -9087,6 +9120,7 @@ def ref_pairwise_distance(input1, input2):
90879120 backward_dtypesIfCUDA=all_types_and_complex_and(torch.half, *[torch.bfloat16]
90889121 if (CUDA11OrLater or TEST_WITH_ROCM) else []),
90899122 sample_inputs_func=sample_inputs_cov,
9123+ error_inputs_func=error_inputs_cov,
90909124 supports_out=False,
90919125 supports_forward_ad=True,
90929126 supports_fwgrad_bwgrad=True,
0 commit comments