Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 14 additions & 14 deletions quantities/quantity.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ def simplified(self):

@property
def units(self):
return Quantity(1.0, (self.dimensionality))
return self.__class__(1.0, (self.dimensionality))
@units.setter
def units(self, units):
try:
Expand Down Expand Up @@ -195,7 +195,7 @@ def rescale(self, units):
'Unable to convert between units of "%s" and "%s"'
%(from_u._dimensionality, to_u._dimensionality)
)
return Quantity(cf*self.magnitude, to_u)
return self.__class__(cf*self.magnitude, to_u)

@with_doc(np.ndarray.astype)
def astype(self, dtype=None):
Expand Down Expand Up @@ -349,7 +349,7 @@ def __getitem__(self, key):
if isinstance(ret, Quantity):
return ret
else:
return Quantity(ret, self._dimensionality)
return self.__class__(ret, self._dimensionality)

@with_doc(np.ndarray.__setitem__)
def __setitem__(self, key, value):
Expand Down Expand Up @@ -422,7 +422,7 @@ def _tolist(self, work_list):

@with_doc(np.ndarray.sum)
def sum(self, axis=None, dtype=None, out=None):
return Quantity(
return self.__class__(
self.magnitude.sum(axis, dtype, out),
self.dimensionality,
copy=False
Expand Down Expand Up @@ -471,15 +471,15 @@ def nonzero(self):

@with_doc(np.ndarray.max)
def max(self, axis=None, out=None):
return Quantity(
return self.__class__(
self.magnitude.max(),
self.dimensionality,
copy=False
)

@with_doc(np.ndarray.min)
def min(self, axis=None, out=None):
return Quantity(
return self.__class__(
self.magnitude.min(),
self.dimensionality,
copy=False
Expand All @@ -491,7 +491,7 @@ def argmin(self,axis=None, out=None):

@with_doc(np.ndarray.ptp)
def ptp(self, axis=None, out=None):
return Quantity(
return self.__class__(
self.magnitude.ptp(),
self.dimensionality,
copy=False
Expand All @@ -516,42 +516,42 @@ def clip(self, min=None, max=None, out=None):
max.rescale(self._dimensionality).magnitude,
out
)
return Quantity(clipped, self.dimensionality, copy=False)
return self.__class__(clipped, self.dimensionality, copy=False)

@with_doc(np.ndarray.round)
def round(self, decimals=0, out=None):
return Quantity(
return self.__class__(
self.magnitude.round(decimals, out),
self.dimensionality,
copy=False
)

@with_doc(np.ndarray.trace)
def trace(self, offset=0, axis1=0, axis2=1, dtype=None, out=None):
return Quantity(
return self.__class__(
self.magnitude.trace(offset, axis1, axis2, dtype, out),
self.dimensionality,
copy=False
)

@with_doc(np.ndarray.mean)
def mean(self, axis=None, dtype=None, out=None):
return Quantity(
return self.__class__(
self.magnitude.mean(axis, dtype, out),
self.dimensionality,
copy=False)

@with_doc(np.ndarray.var)
def var(self, axis=None, dtype=None, out=None, ddof=0):
return Quantity(
return self.__class__(
self.magnitude.var(axis, dtype, out, ddof),
self._dimensionality**2,
copy=False
)

@with_doc(np.ndarray.std)
def std(self, axis=None, dtype=None, out=None, ddof=0):
return Quantity(
return self.__class__(
self.magnitude.std(axis, dtype, out, ddof),
self._dimensionality,
copy=False
Expand All @@ -564,7 +564,7 @@ def prod(self, axis=None, dtype=None, out=None):
else:
power = self.shape[axis]

return Quantity(
return self.__class__(
self.magnitude.prod(axis, dtype, out),
self._dimensionality**power,
copy=False
Expand Down
29 changes: 29 additions & 0 deletions quantities/tests/test_inheritance.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
# -*- coding: utf-8 -*-

from ..quantity import Quantity
from .test_methods import TestQuantityMethods, TestCase

class ChildQuantity(Quantity):
def __new__(cls, data, units='', dtype=None, copy=True):
obj = Quantity.__new__(cls, data, units, dtype, copy).view(cls)
return obj

class TestQuantityInheritance(TestCase):

def setUp(self):
self.cq = ChildQuantity([1,5], '')

def test_resulting_type(self):
self.assertTrue (isinstance(self.cq, ChildQuantity))
self.assertTrue (isinstance(self.cq + self.cq, ChildQuantity))
self.assertTrue (isinstance(self.cq * self.cq, ChildQuantity))
self.assertTrue (isinstance(self.cq / self.cq, ChildQuantity))
self.assertTrue (isinstance(self.cq - self.cq, ChildQuantity))
self.assertTrue (isinstance(self.cq.max(), ChildQuantity))
self.assertTrue (isinstance(self.cq.min(), ChildQuantity))
self.assertTrue (isinstance(self.cq.mean(), ChildQuantity))
self.assertTrue (isinstance(self.cq.var(), ChildQuantity))
self.assertTrue (isinstance(self.cq.std(), ChildQuantity))
self.assertTrue (isinstance(self.cq.prod(), ChildQuantity))
self.assertTrue (isinstance(self.cq.cumsum(), ChildQuantity))
self.assertTrue (isinstance(self.cq.cumprod(), ChildQuantity))