Skip to content

Commit e32259b

Browse files
committed
port numpy trapz into umath
1 parent 6b745f4 commit e32259b

File tree

1 file changed

+92
-4
lines changed

1 file changed

+92
-4
lines changed

quantities/umath.py

Lines changed: 92 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -119,16 +119,73 @@ def cross (a, b , axisa=-1, axisb=-1, axisc=-1, axis=None):
119119
copy=False
120120
)
121121

122-
@with_doc(np.trapz)
122+
123123
def trapz(y, x=None, dx=1.0, axis=-1):
124+
"""
125+
Integrate along the given axis using the composite trapezoidal rule.
126+
127+
If `x` is provided, the integration happens in sequence along its
128+
elements - they are not sorted.
129+
130+
Integrate `y` (`x`) along each 1d slice on the given axis, compute
131+
:math:`\int y(x) dx`.
132+
When `x` is specified, this integrates along the parametric curve,
133+
computing :math:`\int_t y(t) dt =
134+
\int_t y(t) \left.\frac{dx}{dt}\right|_{x=x(t)} dt`.
135+
136+
Parameters
137+
----------
138+
y : array_like
139+
Input array to integrate.
140+
x : array_like, optional
141+
The sample points corresponding to the `y` values. If `x` is None,
142+
the sample points are assumed to be evenly spaced `dx` apart. The
143+
default is None.
144+
dx : scalar, optional
145+
The spacing between sample points when `x` is None. The default is 1.
146+
axis : int, optional
147+
The axis along which to integrate.
148+
149+
Returns
150+
-------
151+
trapz : float or ndarray
152+
Definite integral of `y` = n-dimensional array as approximated along
153+
a single axis by the trapezoidal rule. If `y` is a 1-dimensional array,
154+
then the result is a float. If `n` is greater than 1, then the result
155+
is an `n`-1 dimensional array.
156+
157+
See Also
158+
--------
159+
sum, cumsum
160+
161+
Notes
162+
-----
163+
Image [2]_ illustrates trapezoidal rule -- y-axis locations of points
164+
will be taken from `y` array, by default x-axis distances between
165+
points will be 1.0, alternatively they can be provided with `x` array
166+
or with `dx` scalar. Return value will be equal to combined area under
167+
the red lines.
168+
169+
Docstring is from the numpy 1.26 code base
170+
https://github.com/numpy/numpy/blob/v1.26.0/numpy/lib/function_base.py#L4857-L4984
171+
172+
173+
References
174+
----------
175+
.. [1] Wikipedia page: https://en.wikipedia.org/wiki/Trapezoidal_rule
176+
177+
.. [2] Illustration image:
178+
https://en.wikipedia.org/wiki/File:Composite_trapezoidal_rule_illustration.png
179+
180+
"""
124181
# this function has a weird input structure, so it is tricky to wrap it
125182
# perhaps there is a simpler way to do this
126183
if (
127184
not isinstance(y, Quantity)
128185
and not isinstance(x, Quantity)
129186
and not isinstance(dx, Quantity)
130187
):
131-
return np.trapz(y, x, dx, axis)
188+
return _trapz(y, x, dx, axis)
132189

133190
if not isinstance(y, Quantity):
134191
y = Quantity(y, copy = False)
@@ -138,11 +195,42 @@ def trapz(y, x=None, dx=1.0, axis=-1):
138195
dx = Quantity(dx, copy = False)
139196

140197
if x is None:
141-
ret = np.trapz(y.magnitude , x, dx.magnitude, axis)
198+
ret = _trapz(y.magnitude , x, dx.magnitude, axis)
142199
return Quantity ( ret, y.units * dx.units)
143200
else:
144-
ret = np.trapz(y.magnitude , x.magnitude, dx.magnitude, axis)
201+
ret = _trapz(y.magnitude , x.magnitude, dx.magnitude, axis)
145202
return Quantity ( ret, y.units * x.units)
203+
204+
def _trapz(y, x, dx, axis):
205+
"""ported from numpy 1.26 since it will be deprecated and removed"""
206+
from numpy.core.numeric import asanyarray
207+
from numpy.core.umath import add
208+
y = asanyarray(y)
209+
if x is None:
210+
d = dx
211+
else:
212+
x = asanyarray(x)
213+
if x.ndim == 1:
214+
d = diff(x)
215+
# reshape to correct shape
216+
shape = [1]*y.ndim
217+
shape[axis] = d.shape[0]
218+
d = d.reshape(shape)
219+
else:
220+
d = diff(x, axis=axis)
221+
nd = y.ndim
222+
slice1 = [slice(None)]*nd
223+
slice2 = [slice(None)]*nd
224+
slice1[axis] = slice(1, None)
225+
slice2[axis] = slice(None, -1)
226+
try:
227+
ret = (d * (y[tuple(slice1)] + y[tuple(slice2)]) / 2.0).sum(axis)
228+
except ValueError:
229+
# Operations didn't work, cast to ndarray
230+
d = np.asarray(d)
231+
y = np.asarray(y)
232+
ret = add.reduce(d * (y[tuple(slice1)]+y[tuple(slice2)])/2.0, axis)
233+
return ret
146234

147235
@with_doc(np.sin)
148236
def sin(x, out=None):

0 commit comments

Comments
 (0)