Skip to content

Commit bb96cbb

Browse files
authored
Merge pull request python-quantities#235 from zm711/numpy-2.0-compat
Add numpy 2.0 compatibility
2 parents b47419a + b6efa33 commit bb96cbb

File tree

9 files changed

+109
-94
lines changed

9 files changed

+109
-94
lines changed

.github/workflows/test.yml

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,9 @@ jobs:
3434
- python-version: "3.12"
3535
numpy-version: "1.26"
3636
os: ubuntu-latest
37+
- python-version: "3.12"
38+
numpy-version: "2.0"
39+
os: ubuntu-latest
3740
steps:
3841
- uses: actions/checkout@v2
3942

@@ -128,4 +131,4 @@ jobs:
128131
129132
- name: Check type information
130133
run: |
131-
mypy quantities
134+
mypy quantities

quantities/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -265,6 +265,8 @@
265265
266266
"""
267267

268+
class QuantitiesDeprecationWarning(DeprecationWarning):
269+
pass
268270

269271
from ._version import __version__
270272

quantities/dimensionality.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@
99
from .registry import unit_registry
1010
from .decorators import memoize
1111

12+
_np_version = tuple(map(int, np.__version__.split('.')))
13+
1214
def assert_isinstance(obj, types):
1315
try:
1416
assert isinstance(obj, types)
@@ -329,10 +331,11 @@ def _d_copy(q1, out=None):
329331

330332
def _d_clip(a1, a2, a3, q):
331333
return q.dimensionality
332-
try:
334+
335+
if _np_version < (2, 0, 0):
333336
p_dict[np.core.umath.clip] = _d_clip
334-
except AttributeError:
335-
pass # For compatibility with Numpy < 1.17 when clip wasn't a ufunc yet
337+
else:
338+
p_dict[np.clip] = _d_clip
336339

337340
def _d_sqrt(q1, out=None):
338341
return q1._dimensionality**0.5

quantities/quantity.py

Lines changed: 50 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,11 @@
33

44
import copy
55
from functools import wraps
6+
import warnings
67

78
import numpy as np
89

9-
from . import markup
10+
from . import markup, QuantitiesDeprecationWarning
1011
from .dimensionality import Dimensionality, p_dict
1112
from .registry import unit_registry
1213
from .decorators import with_doc
@@ -114,15 +115,19 @@ class Quantity(np.ndarray):
114115
# TODO: what is an appropriate value?
115116
__array_priority__ = 21
116117

117-
def __new__(cls, data, units='', dtype=None, copy=True):
118+
def __new__(cls, data, units='', dtype=None, copy=None):
119+
if copy is not None:
120+
warnings.warn(("The 'copy' argument in Quantity is deprecated and will be removed in the future. "
121+
"The argument has no effect since quantities-0.16.0 (to aid numpy-2.0 support)."),
122+
QuantitiesDeprecationWarning, stacklevel=2)
118123
if isinstance(data, Quantity):
119124
if units:
120125
data = data.rescale(units)
121126
if isinstance(data, unit_registry['UnitQuantity']):
122127
return 1*data
123-
return np.array(data, dtype=dtype, copy=copy, subok=True).view(cls)
128+
return np.asanyarray(data, dtype=dtype).view(cls)
124129

125-
ret = np.array(data, dtype=dtype, copy=copy).view(cls)
130+
ret = np.asarray(data, dtype=dtype).view(cls)
126131
ret._dimensionality.update(validate_dimensionality(units))
127132
return ret
128133

@@ -210,15 +215,17 @@ def rescale(self, units=None, dtype=None):
210215
dtype = self.dtype
211216
if self.dimensionality == to_dims:
212217
return self.astype(dtype)
213-
to_u = Quantity(1.0, to_dims)
214-
from_u = Quantity(1.0, self.dimensionality)
218+
to_u = Quantity(1.0, to_dims, dtype=dtype)
219+
from_u = Quantity(1.0, self.dimensionality, dtype=dtype)
215220
try:
216221
cf = get_conversion_factor(from_u, to_u)
217222
except AssertionError:
218223
raise ValueError(
219224
'Unable to convert between units of "%s" and "%s"'
220225
%(from_u._dimensionality, to_u._dimensionality)
221226
)
227+
if np.dtype(dtype).kind in 'fc':
228+
cf = np.array(cf, dtype=dtype)
222229
new_magnitude = cf*self.magnitude
223230
dtype = np.result_type(dtype, new_magnitude)
224231
return Quantity(new_magnitude, to_u, dtype=dtype)
@@ -272,7 +279,7 @@ def __array_prepare__(self, obj, context=None):
272279
uf, objs, huh = context
273280
if uf.__name__.startswith('is'):
274281
return obj
275-
#print self, obj, res, uf, objs
282+
276283
try:
277284
res._dimensionality = p_dict[uf](*objs)
278285
except KeyError:
@@ -283,11 +290,21 @@ def __array_prepare__(self, obj, context=None):
283290
)
284291
return res
285292

286-
def __array_wrap__(self, obj, context=None):
287-
if not isinstance(obj, Quantity):
288-
# backwards compatibility with numpy-1.3
289-
obj = self.__array_prepare__(obj, context)
290-
return obj
293+
def __array_wrap__(self, obj, context=None, return_scalar=False):
294+
_np_version = tuple(map(int, np.__version__.split('.')))
295+
# For NumPy < 2.0 we do old behavior
296+
if _np_version < (2, 0, 0):
297+
if not isinstance(obj, Quantity):
298+
return self.__array_prepare__(obj, context)
299+
else:
300+
return obj
301+
# For NumPy > 2.0 we either do the prepare or the wrap
302+
else:
303+
if not isinstance(obj, Quantity):
304+
return self.__array_prepare__(obj, context)
305+
else:
306+
return super().__array_wrap__(obj, context, return_scalar)
307+
291308

292309
@with_doc(np.ndarray.__add__)
293310
@scale_other_units
@@ -476,7 +493,7 @@ def sum(self, axis=None, dtype=None, out=None):
476493
ret = self.magnitude.sum(axis, dtype, None if out is None else out.magnitude)
477494
dim = self.dimensionality
478495
if out is None:
479-
return Quantity(ret, dim, copy=False)
496+
return Quantity(ret, dim)
480497
if not isinstance(out, Quantity):
481498
raise TypeError("out parameter must be a Quantity")
482499
out._dimensionality = dim
@@ -487,8 +504,7 @@ def nansum(self, axis=None, dtype=None, out=None):
487504
import numpy as np
488505
return Quantity(
489506
np.nansum(self.magnitude, axis, dtype, out),
490-
self.dimensionality,
491-
copy=False
507+
self.dimensionality
492508
)
493509

494510
@with_doc(np.ndarray.fill)
@@ -523,7 +539,7 @@ def argsort(self, axis=-1, kind='quick', order=None):
523539
@with_doc(np.ndarray.searchsorted)
524540
def searchsorted(self,values, side='left'):
525541
if not isinstance (values, Quantity):
526-
values = Quantity(values, copy=False)
542+
values = Quantity(values)
527543

528544
if values._dimensionality != self._dimensionality:
529545
raise ValueError("values does not have the same units as self")
@@ -539,7 +555,7 @@ def max(self, axis=None, out=None):
539555
ret = self.magnitude.max(axis, None if out is None else out.magnitude)
540556
dim = self.dimensionality
541557
if out is None:
542-
return Quantity(ret, dim, copy=False)
558+
return Quantity(ret, dim)
543559
if not isinstance(out, Quantity):
544560
raise TypeError("out parameter must be a Quantity")
545561
out._dimensionality = dim
@@ -553,16 +569,15 @@ def argmax(self, axis=None, out=None):
553569
def nanmax(self, axis=None, out=None):
554570
return Quantity(
555571
np.nanmax(self.magnitude),
556-
self.dimensionality,
557-
copy=False
572+
self.dimensionality
558573
)
559574

560575
@with_doc(np.ndarray.min)
561576
def min(self, axis=None, out=None):
562577
ret = self.magnitude.min(axis, None if out is None else out.magnitude)
563578
dim = self.dimensionality
564579
if out is None:
565-
return Quantity(ret, dim, copy=False)
580+
return Quantity(ret, dim)
566581
if not isinstance(out, Quantity):
567582
raise TypeError("out parameter must be a Quantity")
568583
out._dimensionality = dim
@@ -572,8 +587,7 @@ def min(self, axis=None, out=None):
572587
def nanmin(self, axis=None, out=None):
573588
return Quantity(
574589
np.nanmin(self.magnitude),
575-
self.dimensionality,
576-
copy=False
590+
self.dimensionality
577591
)
578592

579593
@with_doc(np.ndarray.argmin)
@@ -590,10 +604,10 @@ def nanargmax(self,axis=None, out=None):
590604

591605
@with_doc(np.ndarray.ptp)
592606
def ptp(self, axis=None, out=None):
593-
ret = self.magnitude.ptp(axis, None if out is None else out.magnitude)
607+
ret = np.ptp(self.magnitude, axis, None if out is None else out.magnitude)
594608
dim = self.dimensionality
595609
if out is None:
596-
return Quantity(ret, dim, copy=False)
610+
return Quantity(ret, dim)
597611
if not isinstance(out, Quantity):
598612
raise TypeError("out parameter must be a Quantity")
599613
out._dimensionality = dim
@@ -620,7 +634,7 @@ def clip(self, min=None, max=None, out=None):
620634
)
621635
dim = self.dimensionality
622636
if out is None:
623-
return Quantity(clipped, dim, copy=False)
637+
return Quantity(clipped, dim)
624638
if not isinstance(out, Quantity):
625639
raise TypeError("out parameter must be a Quantity")
626640
out._dimensionality = dim
@@ -631,7 +645,7 @@ def round(self, decimals=0, out=None):
631645
ret = self.magnitude.round(decimals, None if out is None else out.magnitude)
632646
dim = self.dimensionality
633647
if out is None:
634-
return Quantity(ret, dim, copy=False)
648+
return Quantity(ret, dim)
635649
if not isinstance(out, Quantity):
636650
raise TypeError("out parameter must be a Quantity")
637651
out._dimensionality = dim
@@ -642,7 +656,7 @@ def trace(self, offset=0, axis1=0, axis2=1, dtype=None, out=None):
642656
ret = self.magnitude.trace(offset, axis1, axis2, dtype, None if out is None else out.magnitude)
643657
dim = self.dimensionality
644658
if out is None:
645-
return Quantity(ret, dim, copy=False)
659+
return Quantity(ret, dim)
646660
if not isinstance(out, Quantity):
647661
raise TypeError("out parameter must be a Quantity")
648662
out._dimensionality = dim
@@ -652,16 +666,15 @@ def trace(self, offset=0, axis1=0, axis2=1, dtype=None, out=None):
652666
def squeeze(self, axis=None):
653667
return Quantity(
654668
self.magnitude.squeeze(axis),
655-
self.dimensionality,
656-
copy=False
669+
self.dimensionality
657670
)
658671

659672
@with_doc(np.ndarray.mean)
660673
def mean(self, axis=None, dtype=None, out=None):
661674
ret = self.magnitude.mean(axis, dtype, None if out is None else out.magnitude)
662675
dim = self.dimensionality
663676
if out is None:
664-
return Quantity(ret, dim, copy=False)
677+
return Quantity(ret, dim)
665678
if not isinstance(out, Quantity):
666679
raise TypeError("out parameter must be a Quantity")
667680
out._dimensionality = dim
@@ -672,15 +685,14 @@ def nanmean(self, axis=None, dtype=None, out=None):
672685
import numpy as np
673686
return Quantity(
674687
np.nanmean(self.magnitude, axis, dtype, out),
675-
self.dimensionality,
676-
copy=False)
688+
self.dimensionality)
677689

678690
@with_doc(np.ndarray.var)
679691
def var(self, axis=None, dtype=None, out=None, ddof=0):
680692
ret = self.magnitude.var(axis, dtype, out, ddof)
681693
dim = self._dimensionality**2
682694
if out is None:
683-
return Quantity(ret, dim, copy=False)
695+
return Quantity(ret, dim)
684696
if not isinstance(out, Quantity):
685697
raise TypeError("out parameter must be a Quantity")
686698
out._dimensionality = dim
@@ -691,7 +703,7 @@ def std(self, axis=None, dtype=None, out=None, ddof=0):
691703
ret = self.magnitude.std(axis, dtype, out, ddof)
692704
dim = self.dimensionality
693705
if out is None:
694-
return Quantity(ret, dim, copy=False)
706+
return Quantity(ret, dim)
695707
if not isinstance(out, Quantity):
696708
raise TypeError("out parameter must be a Quantity")
697709
out._dimensionality = dim
@@ -701,8 +713,7 @@ def std(self, axis=None, dtype=None, out=None, ddof=0):
701713
def nanstd(self, axis=None, dtype=None, out=None, ddof=0):
702714
return Quantity(
703715
np.nanstd(self.magnitude, axis, dtype, out, ddof),
704-
self._dimensionality,
705-
copy=False
716+
self._dimensionality
706717
)
707718

708719
@with_doc(np.ndarray.prod)
@@ -715,7 +726,7 @@ def prod(self, axis=None, dtype=None, out=None):
715726
ret = self.magnitude.prod(axis, dtype, None if out is None else out.magnitude)
716727
dim = self._dimensionality**power
717728
if out is None:
718-
return Quantity(ret, dim, copy=False)
729+
return Quantity(ret, dim)
719730
if not isinstance(out, Quantity):
720731
raise TypeError("out parameter must be a Quantity")
721732
out._dimensionality = dim
@@ -726,7 +737,7 @@ def cumsum(self, axis=None, dtype=None, out=None):
726737
ret = self.magnitude.cumsum(axis, dtype, None if out is None else out.magnitude)
727738
dim = self.dimensionality
728739
if out is None:
729-
return Quantity(ret, dim, copy=False)
740+
return Quantity(ret, dim)
730741
if not isinstance(out, Quantity):
731742
raise TypeError("out parameter must be a Quantity")
732743
out._dimensionality = dim
@@ -743,7 +754,7 @@ def cumprod(self, axis=None, dtype=None, out=None):
743754
ret = self.magnitude.cumprod(axis, dtype, out)
744755
dim = self.dimensionality
745756
if out is None:
746-
return Quantity(ret, dim, copy=False)
757+
return Quantity(ret, dim)
747758
if isinstance(out, Quantity):
748759
out._dimensionality = dim
749760
return out

quantities/tests/common.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,11 @@ def assertQuantityEqual(self, q1, q2, msg=None, delta=None):
1919
Make sure q1 and q2 are the same quantities to within the given
2020
precision.
2121
"""
22-
delta = 1e-5 if delta is None else delta
22+
if delta is None:
23+
# NumPy 2 introduced float16, so we base tolerance on machine epsilon
24+
delta1 = np.finfo(q1.dtype).eps if isinstance(q1, np.ndarray) and q1.dtype.kind in 'fc' else 1e-15
25+
delta2 = np.finfo(q2.dtype).eps if isinstance(q2, np.ndarray) and q2.dtype.kind in 'fc' else 1e-15
26+
delta = max(delta1, delta2)**0.3
2327
msg = '' if msg is None else ' (%s)' % msg
2428

2529
q1 = Quantity(q1)
@@ -28,6 +32,7 @@ def assertQuantityEqual(self, q1, q2, msg=None, delta=None):
2832
raise self.failureException(
2933
f"Shape mismatch ({q1.shape} vs {q2.shape}){msg}"
3034
)
35+
3136
if not np.all(np.abs(q1.magnitude - q2.magnitude) < delta):
3237
raise self.failureException(
3338
"Magnitudes differ by more than %g (%s vs %s)%s"

quantities/tests/test_arithmetic.py

Lines changed: 2 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -48,32 +48,9 @@ def check(f, *args, **kwargs):
4848
return (new, )
4949

5050

51-
class iter_dtypes:
52-
53-
def __init__(self):
54-
self._i = 1
55-
self._typeDict = np.sctypeDict.copy()
56-
self._typeDict[17] = int
57-
self._typeDict[18] = long
58-
self._typeDict[19] = float
59-
self._typeDict[20] = complex
60-
61-
def __iter__(self):
62-
return self
63-
64-
def __next__(self):
65-
if self._i > 20:
66-
raise StopIteration
67-
68-
i = self._i
69-
self._i += 1
70-
return self._typeDict[i]
71-
72-
def next(self):
73-
return self.__next__()
74-
7551
def get_dtypes():
76-
return list(iter_dtypes())
52+
numeric_dtypes = 'iufc' # https://numpy.org/doc/stable/reference/generated/numpy.dtype.kind.html
53+
return [v for v in np.sctypeDict.values() if np.dtype(v).kind in numeric_dtypes] + [int, long, float, complex]
7754

7855

7956
class iter_types:

0 commit comments

Comments
 (0)