-
Notifications
You must be signed in to change notification settings - Fork 238
Expand file tree
/
Copy pathtorch_backend.py
More file actions
321 lines (269 loc) · 11.9 KB
/
torch_backend.py
File metadata and controls
321 lines (269 loc) · 11.9 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
from typing import TYPE_CHECKING, Any, List, Optional, Tuple, Union
import numpy as np
from docarray.computation.abstract_comp_backend import AbstractComputationalBackend
from docarray.utils._internal.misc import import_library
if TYPE_CHECKING:
import torch
else:
torch = import_library('torch', raise_error=True)
def _unsqueeze_if_single_axis(*matrices: torch.Tensor) -> List[torch.Tensor]:
"""Unsqueezes tensors that only have one axis, at dim 0.
This ensures that all outputs can be treated as matrices, not vectors.
:param matrices: Matrices to be unsqueezed
:return: List of the input matrices,
where single axis matrices are unsqueezed at dim 0.
"""
unsqueezed = []
for m in matrices:
if len(m.shape) == 1:
unsqueezed.append(m.unsqueeze(0))
else:
unsqueezed.append(m)
return unsqueezed
def _unsqueeze_if_scalar(t: torch.Tensor):
if len(t.shape) == 0: # avoid scalar output
t = t.unsqueeze(0)
return t
class TorchCompBackend(AbstractComputationalBackend[torch.Tensor]):
    """
    Computational backend for PyTorch.
    """

    @classmethod
    def stack(
        cls, tensors: Union[List['torch.Tensor'], Tuple['torch.Tensor']], dim: int = 0
    ) -> 'torch.Tensor':
        """Stack a sequence of tensors along a new dimension ``dim``."""
        return torch.stack(tensors, dim=dim)

    @classmethod
    def copy(cls, tensor: 'torch.Tensor') -> 'torch.Tensor':
        """return a copy/clone of the tensor"""
        return tensor.clone()

    @classmethod
    def to_device(cls, tensor: 'torch.Tensor', device: str) -> 'torch.Tensor':
        """Move the tensor to the specified device."""
        return tensor.to(device)

    @classmethod
    def device(cls, tensor: 'torch.Tensor') -> Optional[str]:
        """Return device on which the tensor is allocated."""
        return str(tensor.device)

    @classmethod
    def empty(
        cls,
        shape: Tuple[int, ...],
        dtype: Optional[Any] = None,
        device: Optional[Any] = None,
    ) -> torch.Tensor:
        """Create an uninitialized tensor of the given ``shape``.

        ``dtype`` and ``device`` are only forwarded to ``torch.empty`` when
        given, so torch's own defaults apply otherwise.
        """
        extra_param = {}
        if dtype is not None:
            extra_param['dtype'] = dtype
        if device is not None:
            extra_param['device'] = device
        return torch.empty(shape, **extra_param)

    @classmethod
    def n_dim(cls, array: 'torch.Tensor') -> int:
        """Return the number of dimensions of ``array``."""
        return array.ndim

    @classmethod
    def squeeze(cls, tensor: 'torch.Tensor') -> 'torch.Tensor':
        """
        Returns a tensor with all the dimensions of tensor of size 1 removed.
        """
        return torch.squeeze(tensor)

    @classmethod
    def to_numpy(cls, array: 'torch.Tensor') -> 'np.ndarray':
        """Convert ``array`` to a numpy array (detached, moved to CPU)."""
        return array.cpu().detach().numpy()

    @classmethod
    def none_value(
        cls,
    ) -> Any:
        """Provide a compatible value that represents None in torch."""
        return torch.tensor(float('nan'))

    @classmethod
    def shape(cls, tensor: 'torch.Tensor') -> Tuple[int, ...]:
        """Return the shape of ``tensor`` as a plain tuple of ints."""
        return tuple(tensor.shape)

    @classmethod
    def reshape(cls, tensor: 'torch.Tensor', shape: Tuple[int, ...]) -> 'torch.Tensor':
        """
        Gives a new shape to tensor without changing its data.

        :param tensor: tensor to be reshaped
        :param shape: the new shape
        :return: a tensor with the same data and number of elements as tensor
            but with the specified shape.
        """
        return tensor.reshape(shape)

    @classmethod
    def equal(cls, tensor1: 'torch.Tensor', tensor2: 'torch.Tensor') -> bool:
        """
        Check if two tensors are equal.

        :param tensor1: the first tensor
        :param tensor2: the second tensor
        :return: True if two tensors are equal, False otherwise.
            If one or more of the inputs is not a torch.Tensor, return False.
        """
        are_torch = isinstance(tensor1, torch.Tensor) and isinstance(
            tensor2, torch.Tensor
        )
        return are_torch and torch.equal(tensor1, tensor2)

    @classmethod
    def detach(cls, tensor: 'torch.Tensor') -> 'torch.Tensor':
        """
        Returns the tensor detached from its current graph.

        :param tensor: tensor to be detached
        :return: a detached tensor with the same data.
        """
        return tensor.detach()

    @classmethod
    def dtype(cls, tensor: 'torch.Tensor') -> torch.dtype:
        """Get the data type of the tensor."""
        return tensor.dtype

    @classmethod
    def isnan(cls, tensor: 'torch.Tensor') -> 'torch.Tensor':
        """Check element-wise for nan and return result as a boolean array"""
        return torch.isnan(tensor)

    @classmethod
    def minmax_normalize(
        cls,
        tensor: 'torch.Tensor',
        t_range: Tuple = (0, 1),
        x_range: Optional[Tuple] = None,
        eps: float = 1e-7,
    ) -> 'torch.Tensor':
        """
        Normalize values in `tensor` into `t_range`.

        `tensor` can be a 1D array or a 2D array. When `tensor` is a 2D array,
        then normalization is row-based.

        !!! note
            - with `t_range=(0, 1)` will normalize the min-value of data to 0,
              max to 1;
            - with `t_range=(1, 0)` will normalize the min-value of data to 1,
              max value of the data to 0.

        :param tensor: the data to be normalized
        :param t_range: a tuple represents the target range.
        :param x_range: a tuple represents tensors range.
        :param eps: a small jitter to avoid divide by zero
        :return: normalized data in `t_range`
        """
        a, b = t_range
        # Explicit None check: only fall back to the per-row min/max when no
        # range was supplied (truthiness would mishandle a falsy tuple).
        min_d = (
            x_range[0]
            if x_range is not None
            else torch.min(tensor, dim=-1, keepdim=True).values
        )
        max_d = (
            x_range[1]
            if x_range is not None
            else torch.max(tensor, dim=-1, keepdim=True).values
        )
        r = (b - a) * (tensor - min_d) / (max_d - min_d + eps) + a
        # t_range may be descending (e.g. (1, 0)); clip with ordered bounds.
        normalized = torch.clip(r, *((a, b) if a < b else (b, a)))
        return normalized.to(tensor.dtype)

    class Retrieval(AbstractComputationalBackend.Retrieval[torch.Tensor]):
        """
        Abstract class for retrieval and ranking functionalities
        """

        @staticmethod
        def top_k(
            values: 'torch.Tensor',
            k: int,
            descending: bool = False,
            device: Optional[str] = None,
        ) -> Tuple['torch.Tensor', 'torch.Tensor']:
            """
            Retrieves the top k smallest values in `values`,
            and returns them alongside their indices in the input `values`.
            Can also be used to retrieve the top k largest values,
            by setting the `descending` flag.

            :param values: Torch tensor of values to rank.
                Should be of shape (n_queries, n_values_per_query).
                Inputs of shape (n_values_per_query,) will be expanded
                to (1, n_values_per_query).
            :param k: number of values to retrieve
            :param descending: retrieve largest values instead of
                smallest values
            :param device: the computational device to use,
                can be either `cpu` or a `cuda` device.
            :return: Tuple containing the retrieved values, and their indices.
                Both are of shape (n_queries, k)
            """
            if device is not None:
                values = values.to(device)
            if len(values.shape) <= 1:
                values = values.view(1, -1)
            # After the reshape above `values` is always 2D, so the number of
            # candidates per query is the size of the last axis; cap k there.
            k = min(k, values.shape[-1])
            return torch.topk(
                input=values, k=k, largest=descending, sorted=True, dim=-1
            )

    class Metrics(AbstractComputationalBackend.Metrics[torch.Tensor]):
        """
        Abstract base class for metrics (distances and similarities).
        """

        @staticmethod
        def cosine_sim(
            x_mat: torch.Tensor,
            y_mat: torch.Tensor,
            eps: float = 1e-7,
            device: Optional[str] = None,
        ) -> torch.Tensor:
            """Pairwise cosine similarities between all vectors in x_mat and
            y_mat.

            :param x_mat: tensor of shape (n_vectors, n_dim), where n_vectors
                is the number of vectors and n_dim is the number of dimensions
                of each example.
            :param y_mat: tensor of shape (n_vectors, n_dim), where n_vectors
                is the number of vectors and n_dim is the number of dimensions
                of each example.
            :param eps: a small jitter to avoid divide by zero
            :param device: the device to use for pytorch computations.
                Either 'cpu' or a 'cuda' device.
                If not provided, the devices of x_mat and y_mat are used.
            :return: torch Tensor of shape (n_vectors, n_vectors) containing
                all pairwise cosine similarities.
                The index [i_x, i_y] contains the cosine similarity between
                x_mat[i_x] and y_mat[i_y].
            """
            if device is not None:
                x_mat = x_mat.to(device)
                y_mat = y_mat.to(device)

            x_mat, y_mat = _unsqueeze_if_single_axis(x_mat, y_mat)

            # Normalize rows, clamping the norms to eps to avoid dividing
            # by zero for all-zero vectors.
            a_n, b_n = x_mat.norm(dim=1)[:, None], y_mat.norm(dim=1)[:, None]
            a_norm = x_mat / torch.clamp(a_n, min=eps)
            b_norm = y_mat / torch.clamp(b_n, min=eps)
            sims = torch.mm(a_norm, b_norm.transpose(0, 1)).squeeze()

            return _unsqueeze_if_scalar(sims)

        @staticmethod
        def euclidean_dist(
            x_mat: torch.Tensor, y_mat: torch.Tensor, device: Optional[str] = None
        ) -> torch.Tensor:
            """Pairwise Euclidean distances between all vectors in x_mat and
            y_mat.

            :param x_mat: tensor of shape (n_vectors, n_dim), where n_vectors
                is the number of vectors and n_dim is the number of dimensions
                of each example.
            :param y_mat: tensor of shape (n_vectors, n_dim), where n_vectors
                is the number of vectors and n_dim is the number of dimensions
                of each example.
            :param device: the device to use for pytorch computations.
                Either 'cpu' or a 'cuda' device.
                If not provided, the devices of x_mat and y_mat are used.
            :return: torch Tensor of shape (n_vectors, n_vectors) containing
                all pairwise Euclidean distances.
                The index [i_x, i_y] contains the Euclidean distance between
                x_mat[i_x] and y_mat[i_y].
            """
            if device is not None:
                x_mat = x_mat.to(device)
                y_mat = y_mat.to(device)

            x_mat, y_mat = _unsqueeze_if_single_axis(x_mat, y_mat)

            dists = torch.cdist(x_mat, y_mat).squeeze()

            return _unsqueeze_if_scalar(dists)

        @staticmethod
        def sqeuclidean_dist(
            x_mat: torch.Tensor, y_mat: torch.Tensor, device: Optional[str] = None
        ) -> torch.Tensor:
            """Pairwise squared Euclidean distances between all vectors in
            x_mat and y_mat.

            :param x_mat: tensor of shape (n_vectors, n_dim), where n_vectors
                is the number of vectors and n_dim is the number of dimensions
                of each example.
            :param y_mat: tensor of shape (n_vectors, n_dim), where n_vectors
                is the number of vectors and n_dim is the number of dimensions
                of each example.
            :param device: the device to use for pytorch computations.
                Either 'cpu' or a 'cuda' device.
                If not provided, the devices of x_mat and y_mat are used.
            :return: torch Tensor of shape (n_vectors, n_vectors) containing
                all pairwise squared Euclidean distances.
                The index [i_x, i_y] contains the squared Euclidean distance
                between x_mat[i_x] and y_mat[i_y].
            """
            if device is not None:
                x_mat = x_mat.to(device)
                y_mat = y_mat.to(device)

            x_mat, y_mat = _unsqueeze_if_single_axis(x_mat, y_mat)

            return _unsqueeze_if_scalar((torch.cdist(x_mat, y_mat) ** 2).squeeze())