Skip to content

Commit 0882871

Browse files
committed
Merge pull request python-quantities#32 from phippo/master
allows Quantity to be subclassed
2 parents c2df012 + 8e5504e commit 0882871

File tree

2 files changed

+43
-14
lines changed

2 files changed

+43
-14
lines changed

quantities/quantity.py

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -152,7 +152,7 @@ def simplified(self):
152152

153153
@property
154154
def units(self):
155-
return Quantity(1.0, (self.dimensionality))
155+
return self.__class__(1.0, (self.dimensionality))
156156
@units.setter
157157
def units(self, units):
158158
try:
@@ -195,7 +195,7 @@ def rescale(self, units):
195195
'Unable to convert between units of "%s" and "%s"'
196196
%(from_u._dimensionality, to_u._dimensionality)
197197
)
198-
return Quantity(cf*self.magnitude, to_u)
198+
return self.__class__(cf*self.magnitude, to_u)
199199

200200
@with_doc(np.ndarray.astype)
201201
def astype(self, dtype=None):
@@ -349,7 +349,7 @@ def __getitem__(self, key):
349349
if isinstance(ret, Quantity):
350350
return ret
351351
else:
352-
return Quantity(ret, self._dimensionality)
352+
return self.__class__(ret, self._dimensionality)
353353

354354
@with_doc(np.ndarray.__setitem__)
355355
def __setitem__(self, key, value):
@@ -422,7 +422,7 @@ def _tolist(self, work_list):
422422

423423
@with_doc(np.ndarray.sum)
424424
def sum(self, axis=None, dtype=None, out=None):
425-
return Quantity(
425+
return self.__class__(
426426
self.magnitude.sum(axis, dtype, out),
427427
self.dimensionality,
428428
copy=False
@@ -471,15 +471,15 @@ def nonzero(self):
471471

472472
@with_doc(np.ndarray.max)
473473
def max(self, axis=None, out=None):
474-
return Quantity(
474+
return self.__class__(
475475
self.magnitude.max(),
476476
self.dimensionality,
477477
copy=False
478478
)
479479

480480
@with_doc(np.ndarray.min)
481481
def min(self, axis=None, out=None):
482-
return Quantity(
482+
return self.__class__(
483483
self.magnitude.min(),
484484
self.dimensionality,
485485
copy=False
@@ -491,7 +491,7 @@ def argmin(self,axis=None, out=None):
491491

492492
@with_doc(np.ndarray.ptp)
493493
def ptp(self, axis=None, out=None):
494-
return Quantity(
494+
return self.__class__(
495495
self.magnitude.ptp(),
496496
self.dimensionality,
497497
copy=False
@@ -516,42 +516,42 @@ def clip(self, min=None, max=None, out=None):
516516
max.rescale(self._dimensionality).magnitude,
517517
out
518518
)
519-
return Quantity(clipped, self.dimensionality, copy=False)
519+
return self.__class__(clipped, self.dimensionality, copy=False)
520520

521521
@with_doc(np.ndarray.round)
522522
def round(self, decimals=0, out=None):
523-
return Quantity(
523+
return self.__class__(
524524
self.magnitude.round(decimals, out),
525525
self.dimensionality,
526526
copy=False
527527
)
528528

529529
@with_doc(np.ndarray.trace)
530530
def trace(self, offset=0, axis1=0, axis2=1, dtype=None, out=None):
531-
return Quantity(
531+
return self.__class__(
532532
self.magnitude.trace(offset, axis1, axis2, dtype, out),
533533
self.dimensionality,
534534
copy=False
535535
)
536536

537537
@with_doc(np.ndarray.mean)
538538
def mean(self, axis=None, dtype=None, out=None):
539-
return Quantity(
539+
return self.__class__(
540540
self.magnitude.mean(axis, dtype, out),
541541
self.dimensionality,
542542
copy=False)
543543

544544
@with_doc(np.ndarray.var)
545545
def var(self, axis=None, dtype=None, out=None, ddof=0):
546-
return Quantity(
546+
return self.__class__(
547547
self.magnitude.var(axis, dtype, out, ddof),
548548
self._dimensionality**2,
549549
copy=False
550550
)
551551

552552
@with_doc(np.ndarray.std)
553553
def std(self, axis=None, dtype=None, out=None, ddof=0):
554-
return Quantity(
554+
return self.__class__(
555555
self.magnitude.std(axis, dtype, out, ddof),
556556
self._dimensionality,
557557
copy=False
@@ -564,7 +564,7 @@ def prod(self, axis=None, dtype=None, out=None):
564564
else:
565565
power = self.shape[axis]
566566

567-
return Quantity(
567+
return self.__class__(
568568
self.magnitude.prod(axis, dtype, out),
569569
self._dimensionality**power,
570570
copy=False
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
# -*- coding: utf-8 -*-
2+
3+
from ..quantity import Quantity
4+
from .test_methods import TestQuantityMethods, TestCase
5+
6+
class ChildQuantity(Quantity):
7+
def __new__(cls, data, units='', dtype=None, copy=True):
8+
obj = Quantity.__new__(cls, data, units, dtype, copy).view(cls)
9+
return obj
10+
11+
class TestQuantityInheritance(TestCase):
12+
13+
def setUp(self):
14+
self.cq = ChildQuantity([1,5], '')
15+
16+
def test_resulting_type(self):
17+
self.assertTrue (isinstance(self.cq, ChildQuantity))
18+
self.assertTrue (isinstance(self.cq + self.cq, ChildQuantity))
19+
self.assertTrue (isinstance(self.cq * self.cq, ChildQuantity))
20+
self.assertTrue (isinstance(self.cq / self.cq, ChildQuantity))
21+
self.assertTrue (isinstance(self.cq - self.cq, ChildQuantity))
22+
self.assertTrue (isinstance(self.cq.max(), ChildQuantity))
23+
self.assertTrue (isinstance(self.cq.min(), ChildQuantity))
24+
self.assertTrue (isinstance(self.cq.mean(), ChildQuantity))
25+
self.assertTrue (isinstance(self.cq.var(), ChildQuantity))
26+
self.assertTrue (isinstance(self.cq.std(), ChildQuantity))
27+
self.assertTrue (isinstance(self.cq.prod(), ChildQuantity))
28+
self.assertTrue (isinstance(self.cq.cumsum(), ChildQuantity))
29+
self.assertTrue (isinstance(self.cq.cumprod(), ChildQuantity))

0 commit comments

Comments
 (0)