
Given data and mask tensors, is there a PyTorch way to obtain masked aggregations of the data (mean, max, min, etc.)?

import torch

x = torch.tensor([
    [1, 2, -1, -1],
    [10, 20, 30, -1]
])

mask = torch.tensor([
    [True, True, False, False],
    [True, True, True, False]
])

To compute a masked mean I can do the following, but is there a PyTorch built-in or a commonly used package that does this?

n_mask = torch.sum(mask, dim=1)
x_mean = torch.sum(x * mask, dim=1) / n_mask

print(x_mean)
# tensor([ 1.5000, 20.0000])
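The manual approach above extends to max and min: fill the masked-off entries with -inf (for max) or +inf (for min) so they can never win the reduction. A minimal sketch wrapping all three in helpers; the names `masked_mean`, `masked_max`, and `masked_min` are hypothetical, and clamping the count to at least 1 (so an all-False row returns 0 instead of NaN) is a design choice, not something the question specifies.

```python
import torch


def masked_mean(x: torch.Tensor, mask: torch.Tensor, dim: int = -1) -> torch.Tensor:
    # Zero out masked-off entries, then divide by the number of kept entries.
    # clamp(min=1) avoids 0/0 for rows with no True entries (they return 0).
    total = (x * mask).sum(dim=dim)
    count = mask.sum(dim=dim).clamp(min=1)
    return total / count


def masked_max(x: torch.Tensor, mask: torch.Tensor, dim: int = -1) -> torch.Tensor:
    # Cast to float so -inf is representable, then hide masked-off entries.
    filled = x.float().masked_fill(~mask, float("-inf"))
    return filled.amax(dim=dim)


def masked_min(x: torch.Tensor, mask: torch.Tensor, dim: int = -1) -> torch.Tensor:
    # Same idea with +inf, so masked-off entries never become the minimum.
    filled = x.float().masked_fill(~mask, float("inf"))
    return filled.amin(dim=dim)


x = torch.tensor([[1, 2, -1, -1], [10, 20, 30, -1]])
mask = torch.tensor([[True, True, False, False], [True, True, True, False]])

print(masked_mean(x, mask, dim=1))  # tensor([ 1.5000, 20.0000])
print(masked_max(x, mask, dim=1))   # tensor([ 2., 30.])
print(masked_min(x, mask, dim=1))   # tensor([ 1., 10.])
```

For a row whose mask is entirely False, `masked_max` returns -inf and `masked_min` returns inf, which can act as a "no valid entries" sentinel.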
  • You didn't search on Google. Yes, PyTorch provides torch.masked.sum, torch.masked.amax, and torch.masked.amin. See the PyTorch docs. Commented Mar 2 at 12:01
  • Would you recommend it? It gives a warning that masked tensors are still in the prototype stage. Commented Mar 2 at 12:08
  • As an alternative, you can use masked_max = torch.max(x.float().where(mask, float('-inf')), axis=1).values and masked_min = torch.min(x.float().where(mask, float('inf')), axis=1).values Commented Mar 2 at 12:55
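A minimal sketch of the functional reductions mentioned in the comments, assuming PyTorch 2.x, where `torch.masked.sum`, `torch.masked.mean`, `torch.masked.amax`, and `torch.masked.amin` accept an ordinary (strided) tensor plus a boolean `mask=` keyword. The docs mark the `torch.masked` module as prototype, so its behavior may change between releases.

```python
import torch

x = torch.tensor([[1, 2, -1, -1], [10, 20, 30, -1]], dtype=torch.float32)
mask = torch.tensor([[True, True, False, False], [True, True, True, False]])

# Each reduction ignores the entries where mask is False along `dim`.
total = torch.masked.sum(x, dim=1, mask=mask)
mean = torch.masked.mean(x, dim=1, mask=mask)
amax = torch.masked.amax(x, dim=1, mask=mask)
amin = torch.masked.amin(x, dim=1, mask=mask)

print("sum ", total)
print("mean", mean)
print("amax", amax)
print("amin", amin)
```

For the tensors in the question this should reproduce the manual results: means of 1.5 and 20.0, maxima of 2 and 30, minima of 1 and 10.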

1 Answer


If you don't want to use torch.masked because it is still in the prototype stage, you can use scatter_reduce to aggregate with sum, prod, mean, amax, or amin.

x = torch.tensor([
    [1, 2, -1, -1],
    [10, 20, 30, -1]
]).float()  # cast to float: the output below is allocated with x.dtype, so an integer x could not hold a fractional mean

mask = torch.tensor([
    [True, True, False, False],
    [True, True, True, False]
])

rows, cols = mask.nonzero().T

for reduction in ['mean', 'sum', 'prod', 'amax', 'amin']:
    output = torch.zeros(x.shape[0], device=x.device, dtype=x.dtype)
    output = output.scatter_reduce(0, rows, x[rows, cols], reduce=reduction, include_self=False)
    print(f"{reduction}\t{output}")
    

# printed output:
# mean  tensor([ 1.5000, 20.0000])
# sum   tensor([ 3., 60.])
# prod  tensor([2.0000e+00, 6.0000e+03])
# amax  tensor([ 2., 30.])
# amin  tensor([ 1., 10.])
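With include_self=False, scatter_reduce only overwrites the positions that some index points to; every other position keeps the value the output tensor was initialized with. In the code above that initial value is 0, so a row whose mask is entirely False would report 0 for every reduction, indistinguishable from a genuine result of 0. A sketch that initializes with NaN instead so empty rows stand out (the NaN choice and the fully masked second row are illustrative assumptions, not part of the original answer):

```python
import torch

x = torch.tensor([[1.0, 2.0, -1.0, -1.0],
                  [5.0, 6.0, -1.0, -1.0]])
mask = torch.tensor([[True, True, False, False],
                     [False, False, False, False]])  # row 1: nothing selected

rows, cols = mask.nonzero().T

results = {}
for reduction in ["mean", "amax"]:
    # NaN (instead of 0) marks rows that receive no scattered values.
    output = torch.full((x.shape[0],), float("nan"), dtype=x.dtype, device=x.device)
    output = output.scatter_reduce(0, rows, x[rows, cols], reduce=reduction, include_self=False)
    results[reduction] = output
    print(f"{reduction}\t{output}")
```

Row 0 gets its usual reductions (mean 1.5, amax 2.0), while row 1 stays NaN because no index ever points at it.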

2 Comments

I am trying to access the documentation for scatter_reduce, but I don't know why this page is empty: pytorch.org/docs/stable/generated/torch.scatter_reduce.html
Click the link on that page for the in-place version (pytorch.org/docs/stable/generated/…).
