Given data and mask tensors, is there a PyTorch way to obtain masked aggregations of the data (mean, max, min, etc.)?
```python
import torch

x = torch.tensor([
    [1, 2, -1, -1],
    [10, 20, 30, -1]
])
mask = torch.tensor([
    [True, True, False, False],
    [True, True, True, False]
])
```
To compute a masked mean I can do the following, but is there a PyTorch built-in or a commonly used package that does this?
```python
n_mask = torch.sum(mask, axis=1)
x_mean = torch.sum(x * mask, axis=1) / n_mask
print(x_mean)
# tensor([ 1.5000, 20.0000])
```
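The snippet above can be wrapped in a small reusable helper. This is a minimal sketch, and the name `masked_mean` is hypothetical. It uses `torch.where` instead of `x * mask`, so a masked-out `inf` or `NaN` cannot leak into the sum (in floating point, `inf * 0` is `NaN`), and a row with no `True` entries comes out as `NaN` (0/0) instead of raising an error.

```python
import torch


def masked_mean(x: torch.Tensor, mask: torch.Tensor, dim: int) -> torch.Tensor:
    """Mean of x along `dim`, counting only positions where mask is True.

    Hypothetical helper: fully masked rows yield NaN (0 / 0).
    """
    x = x.float()
    # Replace masked-out entries with 0 rather than multiplying by the mask,
    # so inf/NaN padding values cannot turn into NaN in the sum.
    total = torch.where(mask, x, torch.zeros_like(x)).sum(dim=dim)
    # Number of valid (True) entries per row, as int64.
    count = mask.sum(dim=dim)
    # Float / int64 promotes to float; 0.0 / 0 gives NaN for empty rows.
    return total / count


x = torch.tensor([[1, 2, -1, -1], [10, 20, 30, -1]])
mask = torch.tensor([[True, True, False, False], [True, True, True, False]])
print(masked_mean(x, mask, dim=1))
```

Returning `NaN` for empty rows is a design choice of this sketch; substitute `count.clamp(min=1)` for `count` if you would rather get 0 there.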
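PyTorch provides masked reductions in the `torch.masked` module, e.g. `torch.masked.sum`, `torch.masked.amax`, and `torch.masked.amin`. See the PyTorch documentation.

Without that module, a masked max and min can be computed by filling the masked-out positions with -inf / +inf before reducing (casting to float first so the tensor can hold infinity):

```python
masked_max = torch.max(x.float().where(mask, float('-inf')), axis=1).values
masked_min = torch.min(x.float().where(mask, float('inf')), axis=1).values
```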
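A minimal sketch of the built-in route, assuming a PyTorch release that ships the `torch.masked` module (its documentation flags masked support as a prototype feature, so check that it exists in your version). Each reduction takes the dimension as its second argument and the boolean mask as the `mask=` keyword; the input is cast to float here to keep the mean well defined.

```python
import torch
import torch.masked  # assumption: available in your PyTorch version

x = torch.tensor([[1, 2, -1, -1], [10, 20, 30, -1]]).float()
mask = torch.tensor([[True, True, False, False], [True, True, True, False]])

# Positions where mask is False are ignored by every reduction.
row_sum = torch.masked.sum(x, 1, mask=mask)
row_mean = torch.masked.mean(x, 1, mask=mask)
row_max = torch.masked.amax(x, 1, mask=mask)
row_min = torch.masked.amin(x, 1, mask=mask)

print(row_sum, row_mean, row_max, row_min)
```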