Skip to content

Commit dfdae84

Browse files
committed
Use scipy implementation of trapz if available
1 parent 45dfd60 commit dfdae84

File tree

1 file changed

+35
-28
lines changed

1 file changed

+35
-28
lines changed

quantities/umath.py

Lines changed: 35 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -200,37 +200,44 @@ def trapz(y, x=None, dx=1.0, axis=-1):
200200
else:
201201
ret = _trapz(y.magnitude , x.magnitude, dx.magnitude, axis)
202202
return Quantity ( ret, y.units * x.units)
203-
203+
204204
def _trapz(y, x, dx, axis):
205205
"""ported from numpy 1.26 since it will be deprecated and removed"""
206-
from numpy.core.numeric import asanyarray
207-
from numpy.core.umath import add
208-
y = asanyarray(y)
209-
if x is None:
210-
d = dx
211-
else:
212-
x = asanyarray(x)
213-
if x.ndim == 1:
214-
d = diff(x)
215-
# reshape to correct shape
216-
shape = [1]*y.ndim
217-
shape[axis] = d.shape[0]
218-
d = d.reshape(shape)
219-
else:
220-
d = diff(x, axis=axis)
221-
nd = y.ndim
222-
slice1 = [slice(None)]*nd
223-
slice2 = [slice(None)]*nd
224-
slice1[axis] = slice(1, None)
225-
slice2[axis] = slice(None, -1)
226206
try:
227-
ret = (d * (y[tuple(slice1)] + y[tuple(slice2)]) / 2.0).sum(axis)
228-
except ValueError:
229-
# Operations didn't work, cast to ndarray
230-
d = np.asarray(d)
231-
y = np.asarray(y)
232-
ret = add.reduce(d * (y[tuple(slice1)]+y[tuple(slice2)])/2.0, axis)
233-
return ret
207+
# if scipy is available, we use it
208+
from scipy.integrate import trapezoid # type: ignore
209+
except ImportError:
210+
# otherwise we use the implementation ported from numpy 1.26
211+
from numpy.core.numeric import asanyarray
212+
from numpy.core.umath import add
213+
y = asanyarray(y)
214+
if x is None:
215+
d = dx
216+
else:
217+
x = asanyarray(x)
218+
if x.ndim == 1:
219+
d = diff(x)
220+
# reshape to correct shape
221+
shape = [1]*y.ndim
222+
shape[axis] = d.shape[0]
223+
d = d.reshape(shape)
224+
else:
225+
d = diff(x, axis=axis)
226+
nd = y.ndim
227+
slice1 = [slice(None)]*nd
228+
slice2 = [slice(None)]*nd
229+
slice1[axis] = slice(1, None)
230+
slice2[axis] = slice(None, -1)
231+
try:
232+
ret = (d * (y[tuple(slice1)] + y[tuple(slice2)]) / 2.0).sum(axis)
233+
except ValueError:
234+
# Operations didn't work, cast to ndarray
235+
d = np.asarray(d)
236+
y = np.asarray(y)
237+
ret = add.reduce(d * (y[tuple(slice1)]+y[tuple(slice2)])/2.0, axis)
238+
return ret
239+
else:
240+
return trapezoid(y, x=x, dx=dx, axis=axis)
234241

235242
@with_doc(np.sin)
236243
def sin(x, out=None):

0 commit comments

Comments
 (0)