Changes from all commits (15 commits)
739105b
Implement correction argument in torch.masked.{std,var}
peterbell10 Oct 17, 2022
77076b4
Update on "Implement correction argument in torch.masked.{std,var}"
peterbell10 Oct 17, 2022
bc303ba
Update on "Implement correction argument in torch.masked.{std,var}"
peterbell10 Oct 19, 2022
73fac76
Fix lint on "Implement correction argument in torch.masked.{std,var}"
peterbell10 Oct 19, 2022
c4e0b92
Update on "Implement correction argument in torch.masked.{std,var}"
peterbell10 Oct 20, 2022
904d8e2
Rebase and fix merge conflicts on "Implement correction argument in t…
peterbell10 Oct 21, 2022
f434ae1
Split std/var opinfos into two on "Implement correction argument in t…
peterbell10 Nov 1, 2022
d02106e
Update on "Implement correction argument in torch.masked.{std,var}"
peterbell10 Nov 1, 2022
4bcb6c7
Update on "Implement correction argument in torch.masked.{std,var}"
peterbell10 Nov 2, 2022
33936b5
Update on "Implement correction argument in torch.masked.{std,var}"
peterbell10 Nov 2, 2022
b2a9b6d
Update on "Implement correction argument in torch.masked.{std,var}"
peterbell10 Nov 3, 2022
5427e78
Update on "Implement correction argument in torch.masked.{std,var}"
peterbell10 Nov 3, 2022
bb924df
Update on "Implement correction argument in torch.masked.{std,var}"
peterbell10 Nov 7, 2022
d5b9c6b
Rebase on "Implement correction argument in torch.masked.{std,var}"
peterbell10 Dec 4, 2022
2f85a57
Update on "Implement correction argument in torch.masked.{std,var}"
peterbell10 Dec 7, 2022
1 change: 1 addition & 0 deletions test/inductor/test_torchinductor_opinfo.py
@@ -198,6 +198,7 @@ def process(device_type):
"linalg.pinv.singular": {f32, f64},
"masked.norm": {f16},
"masked.normalize": {f16},
"masked.var": {f16},
"masked_fill": {f16},
"masked_scatter": {f16, f32, f64},
"masked_select": {b8, f16, f32, f64, i32, i64},
32 changes: 22 additions & 10 deletions torch/masked/_ops.py
@@ -1538,14 +1538,22 @@ def norm(

def _std_var(
input: Union[Tensor, MaskedTensor],
dim: DimOrDims = None,
unbiased: Optional[bool] = False,
dim: DimOrDims,
unbiased: Optional[bool],
*,
keepdim: Optional[bool] = False,
dtype: Optional[DType] = None,
mask: Optional[Tensor] = None,
take_sqrt: Optional[bool] = False,
correction: Optional[int],
keepdim: Optional[bool],
dtype: Optional[DType],
mask: Optional[Tensor],
take_sqrt: Optional[bool],
) -> Tensor:
assert (unbiased is None or correction is None), "Only one of unbiased and correction may be given"
correction_int = 1
if unbiased is not None:
correction_int = 1 if unbiased else 0
if correction is not None:
correction_int = correction

if dtype is None:
dtype = input.dtype
if not (dtype.is_floating_point or dtype.is_complex):
@@ -1584,8 +1592,8 @@ def _std_var(
)
if not keepdim:
count = count.reshape(total.shape)
if unbiased:
count = torch.subtract(count, 1)
if correction_int != 0:
count = torch.subtract(count, correction_int)
count = torch.maximum(count, count.new_zeros([]))
output = torch.divide(total, count).to(dtype=dtype)
if take_sqrt:
@@ -1601,8 +1609,9 @@ def var(
def var(
input: Union[Tensor, MaskedTensor],
dim: DimOrDims = None,
unbiased: Optional[bool] = False,
unbiased: Optional[bool] = None,
*,
correction: Optional[int] = None,
keepdim: Optional[bool] = False,
dtype: Optional[DType] = None,
mask: Optional[Tensor] = None,
@@ -1619,6 +1628,7 @@ def var(
input=input,
dim=dim,
unbiased=unbiased,
correction=correction,
keepdim=keepdim,
dtype=dtype,
mask=mask,
@@ -1630,8 +1640,9 @@ def std(
def std(
input: Union[Tensor, MaskedTensor],
dim: DimOrDims = None,
unbiased: Optional[bool] = False,
unbiased: Optional[bool] = None,
*,
correction: Optional[int] = None,
keepdim: Optional[bool] = False,
dtype: Optional[DType] = None,
mask: Optional[Tensor] = None,
@@ -1648,6 +1659,7 @@ def std(
input=input,
dim=dim,
unbiased=unbiased,
correction=correction,
keepdim=keepdim,
dtype=dtype,
mask=mask,
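(Not part of the diff: a minimal usage sketch of what this change enables, assuming the signatures shown above. `correction` mirrors the keyword on `torch.var`/`torch.std`; `unbiased` remains accepted for backward compatibility, and the assert added in `_std_var` rejects passing both. The tensor values are illustrative only.)

import torch

# Illustrative data, not taken from the PR.
x = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
mask = torch.tensor([[True, True, False], [True, True, True]])

# correction=1 applies Bessel's correction (the old unbiased=True behaviour).
v_unbiased = torch.masked.var(x, dim=1, correction=1, mask=mask)

# correction=0 gives the population variance (the old unbiased=False behaviour).
v_population = torch.masked.var(x, dim=1, correction=0, mask=mask)

# masked.std takes the same keyword-only argument.
s = torch.masked.std(x, dim=1, correction=1, mask=mask)

# Passing both keywords is rejected by the assert in _std_var:
# torch.masked.var(x, dim=1, unbiased=True, correction=1)  # AssertionError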
133 changes: 92 additions & 41 deletions torch/testing/_internal/opinfo/definitions/_masked.py
@@ -1,4 +1,5 @@
import unittest
from collections.abc import Sequence
from functools import partial
from typing import List

@@ -223,51 +224,101 @@ def sample_inputs_masked_norm(op_info, device, dtype, requires_grad, **kwargs):
)


def reference_masked_std_var(
numpy_fn,
):
ref = reference_reduction_numpy(numpy_fn)

# Translate unbiased or correction arguments into ddof
def func(
input,
dim=None,
unbiased=None,
*,
correction=None,
**kwargs,
):
ddof = 1
if unbiased is not None:
ddof = 1 if unbiased else 0
if correction is not None:
ddof = correction

if isinstance(dim, Sequence):
dim = tuple(dim)

return ref(input, dim, ddof=ddof, **kwargs)

return func


def sample_inputs_masked_std_var(op_info, device, dtype, requires_grad, **kwargs):
"""Sample inputs for masked std/var."""
for unbiased in [False, True]:
for sample_input in sample_inputs_masked_reduction(
kwargs["supports_multiple_dims"] = op_info.supports_multiple_dims
from torch.testing._internal.common_methods_invocations import sample_inputs_std_var

def masked_samples():
for sample_input in sample_inputs_std_var(
op_info, device, dtype, requires_grad, **kwargs
):
if sample_input.args:
dim = sample_input.args[0]
sample_input_args = (
sample_input.args[:1] + (unbiased,) + sample_input.args[1:]
if len(sample_input.args) and isinstance(sample_input.args[0], bool):
continue # masked.{std, var} doesn't support `.var(unbiased)`

for mask in _generate_masked_op_mask(
sample_input.input.shape, device, **kwargs
):
sample_input_args, sample_input_kwargs = sample_input.args, dict(
mask=mask, **sample_input.kwargs
)
sample_input_kwargs = sample_input.kwargs.copy()
else:
dim = sample_input.kwargs.get("dim")
sample_input_args = sample_input.args
sample_input_kwargs = dict(sample_input.kwargs, unbiased=unbiased)
if requires_grad:
if sample_input_kwargs.get("mask") is None:
orig_count = torch.masked.sum(
torch.ones(sample_input.input.shape, dtype=torch.int64),
dim,
keepdim=True,
)
else:
inmask = torch.masked._input_mask(
sample_input.input, *sample_input_args, **sample_input_kwargs
)
orig_count = torch.masked.sum(
inmask.new_ones(sample_input.input.shape, dtype=torch.int64),
dim,
keepdim=True,
mask=inmask,
)
if orig_count.min() <= int(unbiased) + 1:
# Skip samples that lead to singularities in var
# computation resulting nan values both in var and
# autograd output that test_grad_fn cannot handle
# correctly. Also, skip samples when the autograd output
# for std could not be handled correctly due to torch.sqrt
continue
yield SampleInput(
sample_input.input.detach().requires_grad_(requires_grad),
args=sample_input_args,
kwargs=sample_input_kwargs,
yield SampleInput(
sample_input.input.detach().requires_grad_(requires_grad),
args=sample_input_args,
kwargs=sample_input_kwargs,
)
if (
not requires_grad
and dtype.is_floating_point
and sample_input.input.ndim == 2
and mask is not None
and mask.shape == sample_input.input.shape
):
for v in [torch.inf, -torch.inf, torch.nan]:
t = sample_input.input.detach()
t.diagonal(0, -2, -1).fill_(v)
yield SampleInput(
t.requires_grad_(requires_grad),
args=sample_input_args,
kwargs=sample_input_kwargs,
)

for sample_input in masked_samples():
correction = sample_input.kwargs.get("correction")
if correction is None:
correction = int(sample_input.kwargs.get("unbiased", True))

dim = sample_input.kwargs.get("dim", None)

if sample_input.kwargs.get("mask") is None:
orig_count = torch.masked.sum(
torch.ones(sample_input.input.shape, dtype=torch.int64),
dim,
keepdim=True,
)
else:
inmask = torch.masked._input_mask(
sample_input.input, *sample_input.args, **sample_input.kwargs
)
orig_count = torch.masked.sum(
inmask.new_ones(sample_input.input.shape, dtype=torch.int64),
dim,
keepdim=True,
mask=inmask,
)
if orig_count.min() <= correction + 1:
# Skip samples that lead to nans in var computation
continue

yield sample_input


def sample_inputs_masked_softmax(
@@ -860,7 +911,7 @@ def sample_inputs_masked_normalize(op_info, device, dtype, requires_grad, **kwargs):
),
ReductionOpInfo(
"masked.var",
ref=reference_reduction_numpy(np.var)
ref=reference_masked_std_var(np.var)
if np.lib.NumpyVersion(np.__version__) >= "1.20.2"
else None,
method_variant=None,
Expand Down Expand Up @@ -938,7 +989,7 @@ def sample_inputs_masked_normalize(op_info, device, dtype, requires_grad, **kwar
),
ReductionOpInfo(
"masked.std",
ref=reference_reduction_numpy(np.std)
ref=reference_masked_std_var(np.std)
if np.lib.NumpyVersion(np.__version__) >= "1.20.2"
else None,
method_variant=None,
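(Also not part of the diff: a rough sketch of the unbiased/correction-to-ddof translation that reference_masked_std_var performs, cross-checked against NumPy for the unmasked case. The helper ddof_from_args is hypothetical and only illustrates the mapping.)

import numpy as np
import torch

# Hypothetical helper mirroring the translation done inside reference_masked_std_var.
def ddof_from_args(unbiased=None, correction=None):
    assert unbiased is None or correction is None
    ddof = 1  # default matches correction=1 (unbiased estimate)
    if unbiased is not None:
        ddof = 1 if unbiased else 0
    if correction is not None:
        ddof = correction
    return ddof

x = torch.arange(6.0).reshape(2, 3)
# With no mask (all elements valid), masked.var should agree with np.var for the same ddof.
expected = np.var(x.numpy(), axis=1, ddof=ddof_from_args(correction=0))
actual = torch.masked.var(x, dim=1, correction=0)
assert torch.allclose(actual, torch.from_numpy(expected))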
5 changes: 0 additions & 5 deletions torch/testing/_internal/opinfo/utils.py
@@ -243,11 +243,6 @@ def wrapper(x: np.ndarray, *args, **kwargs):
identity = identity.cpu()
kwargs["initial"] = identity.numpy()

if "unbiased" in keys:
unbiased = kwargs.pop("unbiased")
if unbiased is not None:
kwargs["ddof"] = int(unbiased)

result = f(x, *args, **kwargs)

# Unsqueeze reduced dimensions if NumPy does not support keepdims