forked from scverse/spatialdata-plot
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathbasic.py
More file actions
170 lines (139 loc) · 5.52 KB
/
basic.py
File metadata and controls
170 lines (139 loc) · 5.52 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
from typing import Callable, List, Union
import spatialdata as sd
from anndata import AnnData
from ..accessor import register_spatial_data_accessor
from .colorize import _colorize
from .render import _render_label
from .utils import _get_listed_colormap
@register_spatial_data_accessor("pp")
class PreprocessingAccessor:
    """Preprocessing helpers registered on SpatialData objects under ``pp``.

    Every public method builds and returns a new ``SpatialData`` object via
    ``_copy``; elements a method does not replace are taken from the wrapped
    object unchanged.
    """

    def __init__(self, sdata):
        # The SpatialData object this accessor operates on.
        self._sdata = sdata

    def _copy(
        self,
        images: Union[None, dict] = None,
        labels: Union[None, dict] = None,
        points: Union[None, dict] = None,
        polygons: Union[None, dict] = None,
        shapes: Union[None, dict] = None,
        table: Union[None, dict, AnnData] = None,
    ) -> sd.SpatialData:
        """Build a new SpatialData object from replacement and original elements.

        Every argument left as ``None`` falls back to the corresponding element
        of the wrapped object. That element is passed through as-is; no copy
        is made here.

        Parameters
        ----------
        images, labels, points, polygons, shapes : dict, optional
            Replacement element mappings.
        table : dict or AnnData, optional
            Replacement table.

        Returns
        -------
        sd.SpatialData
            New SpatialData object combining the replacements with the
            remaining original elements.
        """
        original = self._sdata
        return sd.SpatialData(
            images=original.images if images is None else images,
            labels=original.labels if labels is None else labels,
            points=original.points if points is None else points,
            polygons=original.polygons if polygons is None else polygons,
            shapes=original.shapes if shapes is None else shapes,
            table=original.table if table is None else table,
        )

    def get_bb(self, x: Union[slice, list, tuple], y: Union[slice, list, tuple]) -> sd.SpatialData:
        """Crop every image and label element to a rectangular region.

        Parameters
        ----------
        x : Union[slice, list, tuple]
            Selection along the ``x`` dimension, forwarded to ``.sel``.
        y : Union[slice, list, tuple]
            Selection along the ``y`` dimension, forwarded to ``.sel``.

        Returns
        -------
        sd.SpatialData
            SpatialData object with cropped images and labels. Points,
            polygons, shapes and the table are carried over unchanged.
        """
        # TODO: add support for list and tuple inputs? Only ``slice`` is intended
        # to work here; other values are handed to ``.sel`` untouched.
        selection = dict(x=x, y=y)  # makes use of the xarray ``sel`` method
        # TODO: error handling if selection is out of bounds
        cropped_images = {key: img.sel(selection) for key, img in self._sdata.images.items()}
        cropped_labels = {key: img.sel(selection) for key, img in self._sdata.labels.items()}
        return self._copy(images=cropped_images, labels=cropped_labels)

    def get_images(self, keys: Union[list, str], label_func: Callable = lambda x: x) -> sd.SpatialData:
        """Keep only the images (and their matching labels) with the given keys.

        Parameters
        ----------
        keys : Union[list, str]
            Image key, or list of image keys, to keep.
        label_func : Callable, optional
            Maps a label key to the image key it belongs to. A label element
            is kept when ``label_func(label_key)`` is in ``keys``. The default
            is the identity, so labels must share their image's key.

        Returns
        -------
        sd.SpatialData
            Subsetted SpatialData object.
        """
        # TODO: error handling if keys are not in images
        if isinstance(keys, str):
            keys = [keys]
        selected_images = {key: img for key, img in self._sdata.images.items() if key in keys}
        # TODO: how to handle labels? There might be multiple labels per image
        # (e.g. nuclei and cell segmentation masks).
        selected_labels = {key: img for key, img in self._sdata.labels.items() if label_func(key) in keys}
        return self._copy(images=selected_images, labels=selected_labels)

    def get_channels(self, keys: Union[list, slice]) -> sd.SpatialData:
        """Select channels from every image.

        Parameters
        ----------
        keys : Union[list, slice]
            Channel selection along the ``c`` dimension, forwarded to ``.sel``.

        Returns
        -------
        sd.SpatialData
            SpatialData object whose images contain only the selected channels.
            Labels and all other elements are carried over unchanged.
        """
        selection = dict(c=keys)
        # TODO: error handling if selection is out of bounds
        channels_images = {key: img.sel(selection) for key, img in self._sdata.images.items()}
        return self._copy(images=channels_images)

    def colorize(
        self,
        colors: Union[None, List[str]] = None,
        background: str = "black",
        normalize: bool = True,
        merge=True,
    ) -> sd.SpatialData:
        """Colorize every image by assigning one color per channel.

        Parameters
        ----------
        colors : List[str], optional
            Color of each channel. Defaults to ``["C0", "C1", "C2", "C3"]``.
        background : str
            Background color of the colorized image.
        normalize : bool
            Normalize the image before colorizing it.
        merge : bool
            Intended to control merging of the channel dimension. It is
            currently ignored: the colorized channels are always summed.

        Returns
        -------
        sd.SpatialData
            SpatialData object whose images are replaced by their colorized
            versions.
        """
        # Build the default per call instead of sharing one mutable list
        # across every invocation.
        if colors is None:
            colors = ["C0", "C1", "C2", "C3"]
        rendered = {}
        for key, img in self._sdata.images.items():
            # Sum over axis 0 to combine the per-channel colorized layers.
            colored_image = _colorize(
                img,
                colors=colors,
                background=background,
                normalize=normalize,
            ).sum(0)
            rendered[key] = sd.Image2DModel.parse(colored_image.swapaxes(0, 2))
        return self._copy(images=rendered)

    def render_labels(
        self,
        alpha: float = 0,
        alpha_boundary: float = 1,
        mode: str = "inner",
        label_func: Callable = lambda x: x,
    ) -> sd.SpatialData:
        """Render every image together with its matching label element.

        For each image key, the label element ``self._sdata.labels[label_func(key)]``
        is drawn onto the image by ``_render_label``. It uses a colormap built
        from ``{1: "white"}``.

        Parameters
        ----------
        alpha : float
            Forwarded to ``_render_label``. Presumably the opacity of the
            label fill (verify against ``_render_label``).
        alpha_boundary : float
            Forwarded to ``_render_label``. Presumably the opacity of the
            label boundaries (verify against ``_render_label``).
        mode : str
            Forwarded to ``_render_label``.
        label_func : Callable, optional
            Maps an image key to the key of its label element. Defaults to
            the identity.

        Returns
        -------
        sd.SpatialData
            SpatialData object whose images are replaced by the rendered ones.

        Raises
        ------
        KeyError
            Presumably raised when no label element exists for
            ``label_func(key)``.
        """
        color_dict = {1: "white"}
        cmap = _get_listed_colormap(color_dict)
        rendered = {}
        for key, img in self._sdata.images.items():
            labels = self._sdata.labels[label_func(key)]
            # The image array is transposed here and the result swapped back
            # below; presumably to match the axis order _render_label expects.
            rendered_image = _render_label(
                labels.values,
                cmap,
                img.values.T,
                alpha=alpha,
                alpha_boundary=alpha_boundary,
                mode=mode,
            )
            rendered[key] = sd.Image2DModel.parse(rendered_image.swapaxes(0, 2))
        return self._copy(images=rendered)