Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
88 changes: 33 additions & 55 deletions lib/matplotlib/colors.py
Original file line number Diff line number Diff line change
Expand Up @@ -2804,7 +2804,8 @@ def halfrange(self, halfrange):
self.vmax = self.vcenter + abs(halfrange)


def make_norm_from_scale(scale_cls, base_norm_cls=None, *, init=None):
def make_norm_from_scale(scale_cls, base_norm_cls=None, *, init=None,
norm_before_trf=False):
"""
Decorator for building a `.Normalize` subclass from a `~.scale.ScaleBase`
subclass.
Expand Down Expand Up @@ -2836,7 +2837,8 @@ class norm_cls(Normalize):
"""

if base_norm_cls is None:
return functools.partial(make_norm_from_scale, scale_cls, init=init)
return functools.partial(make_norm_from_scale, scale_cls, init=init,
norm_before_trf=norm_before_trf)

if isinstance(scale_cls, functools.partial):
scale_args = scale_cls.args
Expand All @@ -2850,13 +2852,13 @@ def init(vmin=None, vmax=None, clip=False): pass

return _make_norm_from_scale(
scale_cls, scale_args, scale_kwargs_items,
base_norm_cls, inspect.signature(init))
base_norm_cls, inspect.signature(init), norm_before_trf)


@functools.cache
def _make_norm_from_scale(
scale_cls, scale_args, scale_kwargs_items,
base_norm_cls, bound_init_signature,
base_norm_cls, bound_init_signature, norm_before_trf
):
"""
Helper for `make_norm_from_scale`.
Expand Down Expand Up @@ -2888,7 +2890,7 @@ def __reduce__(self):
pass
return (_picklable_norm_constructor,
(scale_cls, scale_args, scale_kwargs_items,
base_norm_cls, bound_init_signature),
base_norm_cls, bound_init_signature, norm_before_trf),
vars(self))

def __init__(self, *args, **kwargs):
Expand Down Expand Up @@ -2917,6 +2919,14 @@ def __call__(self, value, clip=None):
clip = self.clip
if clip:
value = np.clip(value, self.vmin, self.vmax)

if norm_before_trf:
value -= self.vmin
value /= (self.vmax - self.vmin)
t_value = self._trf.transform(value).reshape(np.shape(value))
t_value = np.ma.masked_invalid(t_value, copy=False)
return t_value[0] if is_scalar else t_value

t_value = self._trf.transform(value).reshape(np.shape(value))
t_vmin, t_vmax = self._trf.transform([self.vmin, self.vmax])
if not np.isfinite([t_vmin, t_vmax]).all():
Expand All @@ -2931,10 +2941,17 @@ def inverse(self, value):
raise ValueError("Not invertible until scaled")
if self.vmin > self.vmax:
raise ValueError("vmin must be less or equal to vmax")
value, is_scalar = self.process_value(value)

if norm_before_trf:
value = self._trf.inverted().transform(value).reshape(np.shape(value))
rescaled = value * (self.vmax - self.vmin)
rescaled += self.vmin
return rescaled[0] if is_scalar else rescaled

t_vmin, t_vmax = self._trf.transform([self.vmin, self.vmax])
if not np.isfinite([t_vmin, t_vmax]).all():
raise ValueError("Invalid vmin or vmax")
value, is_scalar = self.process_value(value)
rescaled = value * (t_vmax - t_vmin)
rescaled += t_vmin
value = (self._trf
Expand Down Expand Up @@ -3083,6 +3100,10 @@ def linear_width(self, value):
self._scale.linear_width = value


@make_norm_from_scale(
scale.PowerScale,
init=lambda gamma=0.5, vmin=None, vmax=None, clip=False: None,
norm_before_trf=True)
class PowerNorm(Normalize):
r"""
Linearly map a given value to the 0-1 range and then apply
Expand Down Expand Up @@ -3119,56 +3140,13 @@ class PowerNorm(Normalize):

For input values below *vmin*, gamma is set to one.
"""
def __init__(self, gamma, vmin=None, vmax=None, clip=False):
super().__init__(vmin, vmax, clip)
self.gamma = gamma

def __call__(self, value, clip=None):
if clip is None:
clip = self.clip

result, is_scalar = self.process_value(value)

self.autoscale_None(result)
gamma = self.gamma
vmin, vmax = self.vmin, self.vmax
if vmin > vmax:
raise ValueError("minvalue must be less than or equal to maxvalue")
elif vmin == vmax:
result.fill(0)
else:
if clip:
mask = np.ma.getmask(result)
result = np.ma.array(np.clip(result.filled(vmax), vmin, vmax),
mask=mask)
resdat = result.data
resdat -= vmin
resdat /= (vmax - vmin)
resdat[resdat > 0] = np.power(resdat[resdat > 0], gamma)

result = np.ma.array(resdat, mask=result.mask, copy=False)
if is_scalar:
result = result[0]
return result

def inverse(self, value):
if not self.scaled():
raise ValueError("Not invertible until scaled")

result, is_scalar = self.process_value(value)

gamma = self.gamma
vmin, vmax = self.vmin, self.vmax

resdat = result.data
resdat[resdat > 0] = np.power(resdat[resdat > 0], 1 / gamma)
resdat *= (vmax - vmin)
resdat += vmin
@property
def gamma(self):
return self._scale.gamma

result = np.ma.array(resdat, mask=result.mask, copy=False)
if is_scalar:
result = result[0]
return result
@gamma.setter
def gamma(self, value):
self._scale.gamma = value


class BoundaryNorm(Normalize):
Expand Down
13 changes: 9 additions & 4 deletions lib/matplotlib/colors.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -339,14 +339,16 @@ def make_norm_from_scale(
scale_cls: type[scale.ScaleBase],
base_norm_cls: type[Normalize],
*,
init: Callable | None = ...
init: Callable | None = ...,
norm_before_trf: bool = ...,
) -> type[Normalize]: ...
@overload
def make_norm_from_scale(
scale_cls: type[scale.ScaleBase],
base_norm_cls: None = ...,
*,
init: Callable | None = ...
init: Callable | None = ...,
norm_before_trf: bool = ...,
) -> Callable[[type[Normalize]], type[Normalize]]: ...

class FuncNorm(Normalize):
Expand Down Expand Up @@ -389,14 +391,17 @@ class AsinhNorm(Normalize):
def linear_width(self, value: float) -> None: ...

class PowerNorm(Normalize):
gamma: float
def __init__(
self,
gamma: float,
gamma: float = ...,
vmin: float | None = ...,
vmax: float | None = ...,
clip: bool = ...,
) -> None: ...
@property
def gamma(self) -> float: ...
@gamma.setter
def gamma(self, value: float) -> None: ...

class BoundaryNorm(Normalize):
boundaries: np.ndarray
Expand Down
107 changes: 107 additions & 0 deletions lib/matplotlib/scale.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
"log" `LogScale` `LogTransform` `InvertedLogTransform`
"logit" `LogitScale` `LogitTransform` `LogisticTransform`
"symlog" `SymmetricalLogScale` `SymmetricalLogTransform` `InvertedSymmetricalLogTransform`
"power" `PowerScale` `PowerTransform` `InvertedPowerTransform`
============= ===================== ================================ =================================

A user will often only use the scale name, e.g. when setting the scale through
Expand Down Expand Up @@ -282,6 +283,111 @@ def set_default_locators_and_formatters(self, axis):
axis.set_minor_locator(NullLocator())


class PowerTransform(Transform):
"""
A simple power transformation used by `.PowerScale`.

This transformation applies a power-law scaling to positive values, while
nonpositive values remain unchanged.
"""
input_dims = output_dims = 1

def __init__(self, gamma):
"""
Parameters
----------
gamma : float
Power law exponent.
"""
super().__init__()
self.gamma = gamma

def __str__(self):
return "{}(gamma={})".format(
type(self).__name__, self.gamma)

def transform_non_affine(self, a):
with np.errstate(divide="ignore", invalid="ignore"):
mask = np.ma.getmask(a)
d = np.asarray(a.data)
out = np.where(d > 0, np.power(d, self.gamma), d)
mout = np.ma.masked_array(out, mask=mask)
return mout

def inverted(self):
return InvertedPowerTransform(self.gamma)


class InvertedPowerTransform(Transform):
"""
The inverse of the `.PowerTransform`.

This transformation applies an inverse power-law scaling to positive values,
while nonpositive values remain unchanged.
"""
input_dims = output_dims = 1

def __init__(self, gamma):
"""
Parameters
----------
gamma : float
Power law exponent.
"""
super().__init__()
if gamma == 0:
raise ValueError('gamma cannot be 0')
self.gamma = gamma

def transform_non_affine(self, a):
with np.errstate(divide="ignore", invalid="ignore"):
input_mask = np.ma.getmask(a)
d = np.asarray(a.data)
out = np.where(d > 0, np.power(d, 1. / self.gamma), d)
mout = np.ma.array(out, mask=input_mask)
return mout

def inverted(self):
return PowerTransform(self.gamma)


class PowerScale(ScaleBase):
"""
A standard power scale
"""
name = 'power'

@_make_axis_parameter_optional
def __init__(self, axis=None, *, gamma=0.5):
"""
Parameters
----------
axis : `~matplotlib.axis.Axis`
The axis for the scale.
gamma : float, default: 0.5
Power law exponent.
"""
self._transform = PowerTransform(gamma)

gamma = property(lambda self: self._transform.gamma)

def get_transform(self):
"""Return the `.PowerTransform` associated with this scale."""
return self._transform

def set_default_locators_and_formatters(self, axis):
# docstring inherited
axis.set_major_locator(AutoLocator())
axis.set_major_formatter(ScalarFormatter())
axis.set_minor_formatter(NullFormatter())
# update the minor locator for x and y axis based on rcParams
if (axis.axis_name == 'x' and mpl.rcParams['xtick.minor.visible'] or
axis.axis_name == 'y' and mpl.rcParams['ytick.minor.visible']):
axis.set_minor_locator(AutoMinorLocator())
else:
axis.set_minor_locator(NullLocator())


class LogTransform(Transform):
input_dims = output_dims = 1

Expand Down Expand Up @@ -807,6 +913,7 @@ def limit_range_for_scale(self, vmin, vmax, minpos):
'logit': LogitScale,
'function': FuncScale,
'functionlog': FuncScaleLog,
'power': PowerScale,
}

# caching of signature info
Expand Down
23 changes: 23 additions & 0 deletions lib/matplotlib/scale.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,29 @@ class FuncScale(ScaleBase):
],
) -> None: ...

class PowerTransform(Transform):
def __init__(self, gamma: float) -> None: ...
def __str__(self) -> str: ...
def transform_non_affine(self, a: ArrayLike) -> ArrayLike: ...
def inverted(self) -> InvertedPowerTransform: ...

class InvertedPowerTransform(Transform):
def __init__(self, gamma: float) -> None: ...
def transform_non_affine(self, a: ArrayLike) -> ArrayLike: ...
def inverted(self) -> PowerTransform: ...

class PowerScale(ScaleBase):
name: str
def __init__(
self,
axis: Axis | None = ...,
*,
gamma: float = ...,
) -> None: ...
@property
def gamma(self) -> float: ...
def get_transform(self) -> PowerTransform: ...

class LogTransform(Transform):
input_dims: int
output_dims: int
Expand Down
Loading