Skip to content

Commit abd6cff

Browse files
ifedan authored and facebook-github-bot committed
Added some extra tests for std_mean and var_mean for multiple dims. (#20650)
Summary: Added some extra tests for std_mean and var_mean for multiple dims. Some refactoring of previously created tests based on PR comments: #18731. Pull Request resolved: #20650. Differential Revision: D15396101. Pulled By: ifedan. fbshipit-source-id: d15c3c2c7084a24d6cfea4018173552fcc9c03a9
1 parent fa5263a commit abd6cff

File tree

4 files changed

+66
-31
lines changed

4 files changed

+66
-31
lines changed

test/common_methods_invocations.py

Lines changed: 20 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -399,16 +399,16 @@ def method_tests():
399399
('std', (S, S, S), (1, True, True), 'keepdim_dim', (True,), [0]),
400400
('std', (S,), (0,), 'dim_1d', (True,), [0]),
401401
('std', (S,), (0, True, True), 'keepdim_dim_1d', (True,), [0]),
402-
('__var_mean__', (S, S, S), NO_ARGS, ''),
403-
('__var_mean__', (S, S, S), (1,), 'dim', [0]),
404-
('__var_mean__', (S, S, S), (1, True, True), 'keepdim_dim', [0]),
405-
('__var_mean__', (S,), (0,), 'dim_1d', [0]),
406-
('__var_mean__', (S,), (0, True, True), 'keepdim_dim_1d', [0]),
407-
('__std_mean__', (S, S, S), NO_ARGS, ''),
408-
('__std_mean__', (S, S, S), (1,), 'dim', [0]),
409-
('__std_mean__', (S, S, S), (1, True, True), 'keepdim_dim', [0]),
410-
('__std_mean__', (S,), (0,), 'dim_1d', [0]),
411-
('__std_mean__', (S,), (0, True, True), 'keepdim_dim_1d', [0]),
402+
('var_mean', (S, S, S), NO_ARGS, ''),
403+
('var_mean', (S, S, S), (1,), 'dim', [0]),
404+
('var_mean', (S, S, S), (1, True, True), 'keepdim_dim', [0]),
405+
('var_mean', (S,), (0,), 'dim_1d', [0]),
406+
('var_mean', (S,), (0, True, True), 'keepdim_dim_1d', [0]),
407+
('std_mean', (S, S, S), NO_ARGS, ''),
408+
('std_mean', (S, S, S), (1,), 'dim', [0]),
409+
('std_mean', (S, S, S), (1, True, True), 'keepdim_dim', [0]),
410+
('std_mean', (S,), (0,), 'dim_1d', [0]),
411+
('std_mean', (S,), (0, True, True), 'keepdim_dim_1d', [0]),
412412
('renorm', (S, S, S), (2, 1, 0.5), 'dim', (), [1]),
413413
('renorm', (S, S, S), (1, 2, 3), 'norm_1'),
414414
('renorm', (S, S, S), (inf, 2, 0.5), 'norm_inf'),
@@ -1086,6 +1086,16 @@ def exclude_tensor_method(name, test_name):
10861086
'test_where_scalar',
10871087
'test_where_scalar_broadcast_mask',
10881088
'test_where_scalar_broadcast_non_mask',
1089+
'test_var_mean_keepdim_dim_1d',
1090+
'test_var_mean_keepdim_dim',
1091+
'test_var_mean_dim_1d',
1092+
'test_var_mean_dim',
1093+
'test_var_mean',
1094+
'test_std_mean_keepdim_dim_1d',
1095+
'test_std_mean_keepdim_dim',
1096+
'test_std_mean_dim_1d',
1097+
'test_std_mean_dim',
1098+
'test_std_mean',
10891099
}
10901100
# there are no out-of-place tensor equivalents for these
10911101
exclude_outplace_tensor_method = {

test/test_autograd.py

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3314,21 +3314,21 @@ def check(name):
33143314
args_variable, kwargs_variable = create_input(args, requires_grad=not is_inplace, call_kwargs=kwargs)
33153315
self_tensor = deepcopy(self_variable.data)
33163316
args_tensor = deepcopy(unpack_variables(args_variable))
3317-
output_variable = getattr(self_variable, name)(*args_variable, **kwargs_variable)
33183317
if not exclude_tensor_method(name, test_name):
3318+
output_variable = getattr(self_variable, name)(*args_variable, **kwargs_variable)
33193319
output_tensor = getattr(self_tensor, name)(*args_tensor, **kwargs_variable)
33203320
if not isinstance(output_tensor, torch.Tensor) and not istuple(output_tensor):
33213321
output_tensor = torch.DoubleTensor((output_tensor,))
33223322
self.assertEqual(unpack_variables(output_variable), output_tensor)
33233323
# TODO: check that both have changed after adding all inplace ops
33243324

3325-
def fn(*inputs):
3326-
output = getattr(inputs[0], name)(*inputs[1:], **kwargs)
3327-
return output_process_fn(output)
3325+
def fn(*inputs):
3326+
output = getattr(inputs[0], name)(*inputs[1:], **kwargs)
3327+
return output_process_fn(output)
33283328

3329-
if not is_inplace and name not in EXCLUDE_GRADCHECK:
3330-
run_grad_and_gradgrad_checks(self, name, test_name, fn,
3331-
output_variable, (self_variable,) + args_variable)
3329+
if not is_inplace and name not in EXCLUDE_GRADCHECK:
3330+
run_grad_and_gradgrad_checks(self, name, test_name, fn,
3331+
output_variable, (self_variable,) + args_variable)
33323332

33333333
# functional interface tests
33343334
if hasattr(torch, name) and name not in EXCLUDE_FUNCTIONAL:
@@ -3339,14 +3339,19 @@ def fn(*inputs):
33393339
f_args_variable = (self_variable,) + args_variable
33403340
f_args_tensor = (self_tensor,) + args_tensor
33413341
# could run the gradchecks again, but skip since we did it for the methods above.
3342+
run_gradcheck = exclude_tensor_method(name, test_name) and not is_inplace and name not in EXCLUDE_GRADCHECK
33423343
run_functional_checks(self, test_name, name, fn,
3343-
False, f_args_variable, f_args_tensor)
3344+
run_gradcheck, f_args_variable, f_args_tensor)
33443345

33453346
# check for correct type of input.data and input.grad.data
33463347
if not is_inplace:
33473348
self_variable = create_input((self_size,), requires_grad=True)[0][0]
33483349
args_variable, kwargs_variable = create_input(args, requires_grad=False, call_kwargs=kwargs)
3349-
output_variable = getattr(self_variable, name)(*args_variable, **kwargs_variable)
3350+
if hasattr(self_variable, name):
3351+
output_variable = getattr(self_variable, name)(*args_variable, **kwargs_variable)
3352+
else:
3353+
self_and_args_variable = (self_variable,) + args_variable
3354+
output_variable = getattr(torch, name)(*self_and_args_variable, **kwargs_variable)
33503355
if isinstance(output_variable, torch.autograd.Variable):
33513356
if output_variable.is_sparse:
33523357
rand = randn_like(output_variable.to_dense()).to_sparse()

test/test_torch.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2656,6 +2656,38 @@ def test_var_mean_all_dims(self):
26562656
self.assertEqual(var1, var2)
26572657
self.assertEqual(mean1, mean2)
26582658

2659+
def test_std_mean_some_dims(self):
2660+
sizes = (4, 6, 7, 5, 3)
2661+
dims = len(sizes)
2662+
for device in torch.testing.get_all_device_types():
2663+
x = torch.rand(sizes, device=device)
2664+
for num_of_dims in range(2, dims):
2665+
dim_list = list(combinations(list(range(dims)), r=num_of_dims))
2666+
for dim in dim_list:
2667+
for unbiased in [False, True]:
2668+
for keepdim in [False, True]:
2669+
std1, mean1 = torch.std_mean(x, dim=dim, unbiased=unbiased, keepdim=keepdim)
2670+
std2 = x.std(dim=dim, unbiased=unbiased, keepdim=keepdim)
2671+
mean2 = x.mean(dim=dim, keepdim=keepdim)
2672+
self.assertEqual(std1, std2)
2673+
self.assertEqual(mean1, mean2)
2674+
2675+
def test_var_mean_some_dims(self):
2676+
sizes = (4, 6, 7, 5, 3)
2677+
dims = len(sizes)
2678+
for device in torch.testing.get_all_device_types():
2679+
x = torch.rand(sizes, device=device)
2680+
for num_of_dims in range(2, dims):
2681+
dim_list = list(combinations(list(range(dims)), r=num_of_dims))
2682+
for dim in dim_list:
2683+
for unbiased in [False, True]:
2684+
for keepdim in [False, True]:
2685+
var1, mean1 = torch.var_mean(x, dim=dim, unbiased=unbiased, keepdim=keepdim)
2686+
var2 = x.var(dim=dim, unbiased=unbiased, keepdim=keepdim)
2687+
mean2 = x.mean(dim=dim, keepdim=keepdim)
2688+
self.assertEqual(var1, var2)
2689+
self.assertEqual(mean1, mean2)
2690+
26592691
def test_zeros_like(self):
26602692
expected = torch.zeros(100, 100)
26612693

torch/tensor.py

Lines changed: 0 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -425,18 +425,6 @@ def __rfloordiv__(self, other):
425425
__ge__ = _C._TensorBase.ge
426426
__abs__ = _C._TensorBase.abs
427427

428-
def __std_mean__(self, dim=None, unbiased=True, keepdim=False):
429-
if dim is None:
430-
return _C._VariableFunctions.std_mean(self, unbiased)
431-
else:
432-
return _C._VariableFunctions.std_mean(self, dim, unbiased, keepdim)
433-
434-
def __var_mean__(self, dim=None, unbiased=True, keepdim=False):
435-
if dim is None:
436-
return _C._VariableFunctions.var_mean(self, unbiased)
437-
else:
438-
return _C._VariableFunctions.var_mean(self, dim, unbiased, keepdim)
439-
440428
def __len__(self):
441429
if self.dim() == 0:
442430
raise TypeError("len() of a 0-d tensor")

0 commit comments

Comments (0)