72 changes: 72 additions & 0 deletions aten/src/ATen/native/Integration.cpp
@@ -0,0 +1,72 @@
#include <ATen/ATen.h>
#include <ATen/NativeFunctions.h>
#include <ATen/WrapDimUtils.h>
#include <ATen/core/DimVector.h>
#include <c10/util/Exception.h>

namespace at {
namespace native {
namespace {

// The estimated integral of a function y of x,
// with samples (y_1, ..., y_n) taken at points separated by distances (dx_1, ..., dx_{n-1}),
// is given by the trapezoid rule:
//
// \sum_{i=1}^{n-1} dx_i * (y_i + y_{i+1}) / 2
//
// TODO: if we extend TensorIterator to accept 3 inputs,
// we can probably make this a bit more performant.
Tensor do_trapz(const Tensor& y, const Tensor& dx, int64_t dim) {
  Tensor left = y.slice(dim, 0, -1);
  Tensor right = y.slice(dim, 1);

  return ((left + right) * dx).sum(dim) / 2.;
}

// When dx is constant, the above formula simplifies
// to dx * [(\sum_{i=1}^n y_i) - (y_1 + y_n)/2]
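// (this follows because each interior sample y_i appears in two trapezoids,
// while the two endpoint samples appear in only one)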
Tensor do_trapz(const Tensor& y, double dx, int64_t dim) {
  return (y.sum(dim) - (y.select(dim, 0) + y.select(dim, -1)) * 0.5) * dx;
}

Tensor zeros_like_except(const Tensor& y, int64_t dim) {
  auto sizes = y.sizes().vec();
  dim = maybe_wrap_dim(dim, y.dim());
  sizes.erase(sizes.begin() + dim);
  return at::zeros(sizes, y.options());
}

} // anonymous namespace

Tensor trapz(const Tensor& y, const Tensor& x, int64_t dim) {
  dim = maybe_wrap_dim(dim, y.dim());
  // asking for the integral with zero samples is a bit nonsensical,
  // but we'll return "0" to match numpy behavior.
  if (y.size(dim) == 0) {
    return zeros_like_except(y, dim);
  }
  Tensor x_viewed;
  if (x.dim() == 1) {
    TORCH_CHECK(x.size(0) == y.size(dim), "trapz: There must be one `x` value for each sample point");
    DimVector sizes(y.dim(), 1);
    sizes[dim] = x.size(0);
    x_viewed = x.view(sizes);
  } else {
    x_viewed = x;
  }
  Tensor x_left = x_viewed.slice(dim, 0, -1);
  Tensor x_right = x_viewed.slice(dim, 1);

  Tensor dx = x_right - x_left;
  return do_trapz(y, dx, dim);
}

Tensor trapz(const Tensor& y, double dx, int64_t dim) {
  // see above
  if (y.size(dim) == 0) {
    return zeros_like_except(y, dim);
  }
  return do_trapz(y, dx, dim);
}

}} // namespace at::native
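For readers skimming the kernel above, here is a minimal Python sketch of the same computation, restricted to the last dimension for brevity (`ref_trapz_lastdim` is a hypothetical helper for illustration, not part of this PR):

import torch

def ref_trapz_lastdim(y, x):
    # Pair up adjacent samples and their spacing, then sum the trapezoid areas.
    left, right = y[..., :-1], y[..., 1:]
    dx = x[..., 1:] - x[..., :-1]
    return ((left + right) * dx).sum(-1) / 2.0

y = torch.tensor([1.0, 2.0, 3.0])
x = torch.tensor([0.0, 1.0, 3.0])
print(ref_trapz_lastdim(y, x))  # tensor(6.5000) == 1*(1+2)/2 + 2*(2+3)/2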
4 changes: 4 additions & 0 deletions aten/src/ATen/native/native_functions.yaml
@@ -1987,6 +1987,10 @@
- func: rot90(Tensor self, int k=1, int[] dims=[0,1]) -> Tensor
  variants: function, method

- func: trapz(Tensor y, Tensor x, *, int dim=-1) -> Tensor

- func: trapz(Tensor y, *, float dx=1, int dim=-1) -> Tensor

- func: _trilinear(Tensor i1, Tensor i2, Tensor i3, int[] expand1, int[] expand2, int[] expand3, int[] sumdim, int unroll_dim=1) -> Tensor

- func: triplet_margin_loss(Tensor anchor, Tensor positive, Tensor negative, float margin=1.0, float p=2, float eps=1e-06, bool swap=False, int reduction=Mean) -> Tensor
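The two new declarations register a pair of overloads: one taking explicit sample points `x`, one taking a uniform spacing `dx`. A quick sanity check of how they are intended to be called (assuming the schemas above are wired through to the kernels in Integration.cpp):

import torch

y = torch.tensor([1.0, 4.0, 9.0])

# trapz(Tensor y, Tensor x, *, int dim=-1)
print(torch.trapz(y, torch.tensor([0.0, 1.0, 2.0])))  # tensor(9.)

# trapz(Tensor y, *, float dx=1, int dim=-1) -- dx is keyword-only
print(torch.trapz(y, dx=0.5))  # tensor(4.5000)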
1 change: 1 addition & 0 deletions docs/source/torch.rst
@@ -341,6 +341,7 @@ BLAS and LAPACK Operations
.. autofunction:: solve
.. autofunction:: svd
.. autofunction:: symeig
.. autofunction:: trapz
.. autofunction:: triangular_solve
.. autofunction:: trtrs

8 changes: 8 additions & 0 deletions test/test_autograd.py
@@ -2189,6 +2189,14 @@ def test_cat_empty(self):
                              lambda a, b: torch.cat((a, b)),
                              True, f_args_variable, f_args_tensor)

    def test_trapz(self):
        f_args_variable = (torch.randn(2, 3, requires_grad=True),
                           torch.tensor([[1.0, 2.0, 5.5], [2.3, 0.5, 6.2]], requires_grad=True))
        f_args_tensor = deepcopy(unpack_variables(f_args_variable))
        run_functional_checks(self, "test_trapz", "trapz",
                              lambda y, x: torch.trapz(y, x),
                              True, f_args_variable, f_args_tensor)
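`run_functional_checks` drives PyTorch's gradcheck machinery over the functional. Since `trapz` is composed entirely of differentiable primitives (`slice`, `sum`, arithmetic), no hand-written backward is needed; a standalone sketch of the same verification using the public `torch.autograd.gradcheck` might look like this:

import torch
from torch.autograd import gradcheck

# gradcheck compares analytic and numeric Jacobians; double precision is required.
y = torch.randn(2, 3, dtype=torch.double, requires_grad=True)
x = torch.tensor([[1.0, 2.0, 5.5], [2.3, 0.5, 6.2]],
                 dtype=torch.double, requires_grad=True)
assert gradcheck(lambda y, x: torch.trapz(y, x), (y, x))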

    def test_cdist(self):
        for p in [0, 1, 2, 3, 1.5, 2.5, float('inf')]:
            f_args_variable = (torch.randn(S, S, requires_grad=True),
38 changes: 38 additions & 0 deletions test/test_torch.py
@@ -11013,6 +11013,44 @@ def test_multiplication_numpy_scalar(self):
        self.assertTrue(r2.dtype == t_dtype)
        self.assertTrue(r2.requires_grad)

    @unittest.skipIf(not TEST_NUMPY, "Numpy not found")
    def test_trapz(self):
        def test_dx(sizes, dim, dx, device):
            t = torch.randn(sizes, device=device)
            actual = torch.trapz(t, dx=dx, dim=dim)
            expected = np.trapz(t.cpu().numpy(), dx=dx, axis=dim)
            self.assertEqual(expected.shape, actual.shape)
            self.assertTrue(np.allclose(expected, actual.cpu().numpy()))

        def test_x(sizes, dim, x, device):
            t = torch.randn(sizes, device=device)
            actual = torch.trapz(t, x=torch.tensor(x, device=device), dim=dim)
            expected = np.trapz(t.cpu().numpy(), x=x, axis=dim)
            self.assertEqual(expected.shape, actual.shape)
            self.assertTrue(np.allclose(expected, actual.cpu().numpy()))

        for device in torch.testing.get_all_device_types():
            test_dx((2, 3, 4), 1, 1, device)
            test_dx((10, 2), 0, 0.1, device)
            test_dx((1, 10), 0, 2.3, device)
            test_dx((0, 2), 0, 1.0, device)
            test_dx((0, 2), 1, 1.0, device)
            test_x((2, 3, 4), 1, [1.0, 2.0, 3.0], device)
            test_x((10, 2), 0, [2.0, 3.0, 4.0, 7.0, 11.0, 14.0, 22.0, 26.0, 26.1, 30.3], device)
            test_x((1, 10), 0, [1.0], device)
            test_x((0, 2), 0, [], device)
            test_x((0, 2), 1, [1.0, 2.0], device)
            with self.assertRaisesRegex(
                    IndexError,
                    'Dimension out of range'):
                test_x((2, 3), 2, [], device)
                test_dx((2, 3), 2, 1.0, device)
            with self.assertRaisesRegex(
                    RuntimeError,
                    'There must be one `x` value for each sample point'):
                test_x((2, 3), 1, [1.0, 2.0], device)
                test_x((2, 3), 1, [1.0, 2.0, 3.0, 4.0], device)

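The zero-length cases above exercise the `zeros_like_except` path in Integration.cpp; a quick standalone check (a sketch, assuming a NumPy install) that the convention for empty inputs matches NumPy's:

import numpy as np
import torch

t = torch.randn(0, 2)
print(torch.trapz(t, dx=1.0, dim=0))        # tensor([0., 0.])
print(np.trapz(t.numpy(), dx=1.0, axis=0))  # [0. 0.]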
    def test_error_msg_type_translation(self):
        with self.assertRaisesRegex(
                RuntimeError,
45 changes: 45 additions & 0 deletions torch/_torch_docs.py
@@ -6556,6 +6556,51 @@ def merge_dicts(*dicts):
            [3, 3]])
""")

add_docstr(torch.trapz,
           r"""
.. function:: trapz(y, x, *, dim=-1) -> Tensor

Estimate :math:`\int y\,dx` along `dim`, using the trapezoid rule.

Arguments:
    y (Tensor): The values of the function to integrate.
    x (Tensor): The points at which the function `y` is sampled.
        If `x` is not in ascending order, intervals on which it is decreasing
        contribute negatively to the estimated integral (i.e., the convention
        :math:`\int_a^b f = -\int_b^a f` is followed).
    dim (int): The dimension along which to integrate.
        By default, use the last dimension.

Returns:
    A Tensor with the same shape as the input, except with `dim` removed.
    Each element of the returned tensor represents the estimated integral
    :math:`\int y\,dx` along `dim`.

Example::

    >>> y = torch.randn((2, 3))
    >>> y
    tensor([[-2.1156,  0.6857, -0.2700],
            [-1.2145,  0.5540,  2.0431]])
    >>> x = torch.tensor([[1, 3, 4], [1, 2, 3]])
    >>> torch.trapz(y, x)
    tensor([-1.2220,  0.9683])

.. function:: trapz(y, *, dx=1, dim=-1) -> Tensor

As above, but the sample points are spaced uniformly at a distance of `dx`.

Arguments:
    y (Tensor): The values of the function to integrate.
    dx (float): The distance between points at which `y` is sampled.
    dim (int): The dimension along which to integrate.
        By default, use the last dimension.

Returns:
    A Tensor with the same shape as the input, except with `dim` removed.
    Each element of the returned tensor represents the estimated integral
    :math:`\int y\,dx` along `dim`.
""")

add_docstr(torch.repeat_interleave,
r"""