forked from grst/spatialdata-plot
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathbasic.py
More file actions
299 lines (245 loc) · 11.5 KB
/
basic.py
File metadata and controls
299 lines (245 loc) · 11.5 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
import re
from collections import OrderedDict
from typing import Union

import spatialdata as sd
from anndata import AnnData
from dask.dataframe.core import DataFrame as DaskDataFrame
from geopandas import GeoDataFrame
from multiscale_spatial_image.multiscale_spatial_image import MultiscaleSpatialImage
from spatial_image import SpatialImage

from spatialdata_plot._accessor import register_spatial_data_accessor
from spatialdata_plot.pp.utils import (
    _get_coordinate_system_mapping,
    _get_region_key,
    _verify_plotting_tree,
)
# from .colorize import _colorize
@register_spatial_data_accessor("pp")
class PreprocessingAccessor:
    """
    Preprocessing functions for SpatialData objects.

    Parameters
    ----------
    sdata :
        A spatial data object.
    """

    def __init__(self, sdata: sd.SpatialData) -> None:
        self._sdata = sdata

    @property
    def sdata(self) -> sd.SpatialData:
        """The `SpatialData` object to provide preprocessing functions for."""
        return self._sdata

    @sdata.setter
    def sdata(self, sdata: sd.SpatialData) -> None:
        self._sdata = sdata

    def _copy(
        self,
        images: Union[None, dict[str, Union[SpatialImage, MultiscaleSpatialImage]]] = None,
        labels: Union[None, dict[str, Union[SpatialImage, MultiscaleSpatialImage]]] = None,
        points: Union[None, dict[str, DaskDataFrame]] = None,
        shapes: Union[None, dict[str, GeoDataFrame]] = None,
        table: Union[None, AnnData] = None,
    ) -> sd.SpatialData:
        """
        Build a new SpatialData object that reuses the original's elements by reference.

        Parameters
        ----------
        images, labels, points, shapes, table :
            Replacement for the corresponding part of the wrapped object.
            Every argument left as ``None`` is taken from the wrapped object.

        Returns
        -------
        sd.SpatialData
            The new object. Its ``plotting_tree`` is a shallow copy of the
            wrapped object's tree, or an empty ``OrderedDict`` if the wrapped
            object has none.
        """
        sdata = sd.SpatialData(
            images=self._sdata.images if images is None else images,
            labels=self._sdata.labels if labels is None else labels,
            points=self._sdata.points if points is None else points,
            shapes=self._sdata.shapes if shapes is None else shapes,
            table=self._sdata.table if table is None else table,
        )
        # Copy the tree instead of aliasing it: steps recorded on the new
        # object must not show up on the object it was derived from.
        if hasattr(self._sdata, "plotting_tree"):
            sdata.plotting_tree = OrderedDict(self._sdata.plotting_tree)
        else:
            sdata.plotting_tree = OrderedDict()
        return sdata

    def _verify_plotting_tree_exists(self) -> None:
        """Attach an empty ``plotting_tree`` to the wrapped object if it has none."""
        if not hasattr(self._sdata, "plotting_tree"):
            self._sdata.plotting_tree = OrderedDict()

    def _valid_keys(self, attribute: str) -> Union[None, list[str]]:
        """Return the keys stored under `attribute`, or ``None`` if the wrapped object lacks it."""
        if not hasattr(self._sdata, attribute):
            return None
        return list(getattr(self._sdata, attribute).keys())

    @staticmethod
    def _not_found_message(
        element: str,
        valid_coord_keys: Union[None, list[str]],
        valid_keys: dict[str, Union[None, list[str]]],
    ) -> str:
        """Compose the error text listing every valid choice, grouped by element type."""
        msg = f"Element '{element}' not found. Valid choices are:"
        choices = {"coordinate_systems": valid_coord_keys, **valid_keys}
        for name, keys in choices.items():
            if keys is not None:
                msg += f"\n\n{name}\n├ "
                msg += "\n├ ".join(keys)
        return msg

    @staticmethod
    def _as_slice(value: Union[slice, list[int], tuple[int, int]], name: str) -> slice:
        """
        Validate one axis range of a bounding box and normalise it to a step-free slice.

        Open bounds (``None``) are accepted and skip the emptiness check.

        Raises
        ------
        TypeError
            If `value` is not a slice, list or tuple.
        ValueError
            If a list/tuple does not have exactly two entries, or if both
            bounds are given and the upper bound is not larger than the lower.
        """
        if not isinstance(value, (slice, list, tuple)):
            raise TypeError(f"Parameter '{name}' must be one of 'slice', 'list', 'tuple'.")

        if isinstance(value, slice):
            start, stop = value.start, value.stop
        else:
            if len(value) != 2:
                raise ValueError(f"Parameter '{name}' must be of length 2.")
            start, stop = value

        if start is not None and stop is not None and stop <= start:
            raise ValueError(f"The current choice of '{name}' would result in an empty slice.")

        # The step is dropped on purpose: a bounding box is a contiguous range.
        return slice(start, stop)

    def get_elements(self, elements: Union[str, list[str]]) -> sd.SpatialData:
        """
        Get a subset of the spatial data object by specifying elements to keep.

        Parameters
        ----------
        elements :
            A string or a list of strings naming the elements to keep. Each
            name must be a key of one of:

            - 'coordinate_systems'
            - 'images'
            - 'labels'
            - 'shapes'

            Naming a coordinate system implicitly selects the elements that
            the coordinate-system mapping associates with it.

        Returns
        -------
        sd.SpatialData
            A new spatial data object containing only the specified elements.

        Raises
        ------
        TypeError
            If `elements` is not a string or a list of strings.
        ValueError
            If any of the specified elements is not present in the original
            spatialdata object.
        AssertionError
            If label or shape keys were selected and the table does not have
            'uns' or 'obs' attributes.

        Notes
        -----
        If the object has a table and the selection contains label or shape
        keys, the returned object carries a copy of the table restricted to
        the rows whose region entry contains one of those keys, and the
        table's ``region`` attribute is set to the regions that remain.
        Otherwise the returned object has no table.
        """
        if not isinstance(elements, (str, list)):
            raise TypeError("Parameter 'elements' must be a string or a list of strings.")

        if isinstance(elements, str):
            elements = [elements]
        elif not all(isinstance(e, str) for e in elements):
            raise TypeError("When parameter 'elements' is a list, all elements must be strings.")

        # Valid keys per type; ``None`` marks a type the wrapped object does not expose.
        valid_coord_keys = self._sdata.coordinate_systems if hasattr(self._sdata, "coordinate_systems") else None
        valid_keys = {attribute: self._valid_keys(attribute) for attribute in ("images", "labels", "shapes")}

        coord_keys: list[str] = []
        selected: dict[str, list[str]] = {attribute: [] for attribute in valid_keys}

        # First, extract coordinate system keys because they generate implicit keys.
        mapping = _get_coordinate_system_mapping(self._sdata)
        implicit_keys: list[str] = []
        for e in elements:
            if (valid_coord_keys is not None) and (e in valid_coord_keys):
                if e not in coord_keys:
                    coord_keys.append(e)
                implicit_keys.extend(mapping[e])

        # Then sort every requested and implied key into its element type.
        # Precedence on name clashes: coordinate system, image, label, shape.
        for e in elements + implicit_keys:
            if (valid_coord_keys is not None) and (e in valid_coord_keys):
                if e not in coord_keys:
                    coord_keys.append(e)
                continue
            for attribute, keys in valid_keys.items():
                if (keys is not None) and (e in keys):
                    selected[attribute].append(e)
                    break
            else:
                raise ValueError(self._not_found_message(e, valid_coord_keys, valid_keys))

        # copy that we hard-modify
        sdata = self._copy()

        if coord_keys:
            sdata = sdata.filter_by_coordinate_system(coord_keys)
        else:
            # Without a coordinate-system filter, drop every element that was not requested.
            for attribute, keys in valid_keys.items():
                if keys is None:
                    continue
                container = getattr(sdata, attribute)
                for key in keys:
                    if key not in selected[attribute]:
                        del container[key]

        # subset table if it is present and regions were selected
        region_names = selected["shapes"] + selected["labels"]
        if sdata.table is not None and region_names:
            assert hasattr(sdata.table, "uns"), "Table in SpatialData object does not have 'uns'."
            assert hasattr(sdata.table, "obs"), "Table in SpatialData object does not have 'obs'."

            region_key = _get_region_key(sdata)

            # Names are matched literally (as substrings), so escape regex metacharacters.
            pattern = "|".join(re.escape(name) for name in region_names)
            mask = list(sdata.table.obs[region_key].str.contains(pattern))

            # create copy and delete original so we can reuse slot
            old_table = sdata.table.copy()
            new_table = old_table[mask, :].copy()
            # dict.fromkeys de-duplicates while keeping a reproducible order.
            new_table.uns["spatialdata_attrs"]["region"] = list(dict.fromkeys(new_table.obs[region_key]))
            del sdata.table
            sdata.table = new_table
        else:
            del sdata.table

        return sdata

    def get_bb(
        self,
        x: Union[slice, list[int], tuple[int, int]] = (0, 0),
        y: Union[slice, list[int], tuple[int, int]] = (0, 0),
    ) -> sd.SpatialData:
        """
        Crop all images and labels to a rectangular bounding box.

        Points, shapes and the table are carried over unchanged.

        Parameters
        ----------
        x :
            x range of the bounding box, as a slice or a ``(lower, upper)``
            pair. Stepsize will be ignored if slice. The default ``(0, 0)``
            is an empty range and is rejected, so a range must be supplied.
        y :
            y range of the bounding box, same conventions as `x`.

        Returns
        -------
        sd.SpatialData
            subsetted SpatialData object

        Raises
        ------
        TypeError
            If `x` or `y` is not a slice, list or tuple.
        ValueError
            If a list/tuple is not of length 2, or a range would be empty.
        """
        x = self._as_slice(x, "x")
        y = self._as_slice(y, "y")

        selection = {"x": x, "y": y}  # makes use of xarray sel method

        # TODO: error handling if selection is out of bounds
        cropped_images = {key: img.sel(selection) for key, img in self._sdata.images.items()}
        cropped_labels = {key: img.sel(selection) for key, img in self._sdata.labels.items()}

        sdata = self._copy(
            images=cropped_images,
            labels=cropped_labels,
        )

        # NOTE(review): imported helper; presumably guarantees that the wrapped
        # object has a `plotting_tree` - confirm in spatialdata_plot.pp.utils.
        self._sdata = _verify_plotting_tree(self._sdata)

        # get current number of steps to create a unique key
        n_steps = len(self._sdata.plotting_tree.keys())
        sdata.plotting_tree[f"{n_steps+1}_get_bb"] = {
            "x": x,
            "y": y,
        }

        return sdata