forked from deepspeedai/DeepSpeed
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathsparse_tensor.py
More file actions
68 lines (56 loc) · 2.36 KB
/
sparse_tensor.py
File metadata and controls
68 lines (56 loc) · 2.36 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
"""
Implementation of a compressed sparse tensor. Similar in
functionality to TensorFlow's IndexedSlices implementation.
"""
import torch
class SparseTensor(object):
    """Compressed, row-sparse representation of a 2-D tensor.

    Only the rows that carry data are stored:

    * ``indices`` -- 1-D tensor with the row index of every stored row.
    * ``values`` -- the stored rows themselves, one per entry of ``indices``.
    * ``dense_size`` -- shape of the original tensor as a ``list``.
    * ``is_sparse`` -- whether the tensor given to the constructor was a
      ``torch`` sparse tensor (``False`` for an empty instance).
    * ``orig_dense_tensor`` -- the tensor given to the constructor.

    The methods index ``dense_size[1]`` and ``values.size()[1]``, so they
    expect a 2-D tensor.
    """

    def __init__(self, dense_tensor=None):
        """Build the compressed form of ``dense_tensor``.

        Args:
            dense_tensor: a strided 2-D tensor, a ``torch`` sparse COO tensor,
                or ``None`` to create an empty placeholder whose ``indices``,
                ``values`` and ``dense_size`` are all ``None``.
        """
        self.orig_dense_tensor = dense_tensor

        if dense_tensor is None:
            # Empty placeholder: nothing to compress. ``is_sparse`` must not
            # be read from ``None``.
            self.is_sparse = False
            self.indices = None
            self.values = None
            self.dense_size = None
            return

        self.is_sparse = dense_tensor.is_sparse
        if dense_tensor.is_sparse:
            # Coalescing merges duplicate indices so that ``indices()`` and
            # ``values()`` can be read.
            dense_tensor = dense_tensor.coalesce()
            self.indices = dense_tensor.indices().flatten()
            self.values = dense_tensor.values()
        else:
            # Keep every row holding at least one non-zero entry. Testing the
            # entries (rather than the row sum) keeps rows whose values cancel
            # out, e.g. [1, -1].
            row_has_data = (dense_tensor != 0).any(dim=1)
            self.indices = row_has_data.nonzero().flatten()
            self.values = dense_tensor[self.indices]
        self.dense_size = list(dense_tensor.size())

    def to_coo_tensor(self):
        """Return the stored rows as a ``torch`` sparse COO tensor."""
        return torch.sparse_coo_tensor(self.indices.unsqueeze(0), self.values, self.dense_size)

    @staticmethod
    def type():
        """Return the type tag string of this class."""
        return "deepspeed.SparseTensor"

    def to_dense(self):
        """Materialize the full tensor of shape ``dense_size``.

        Rows that share the same index (possible after ``add``) are summed.
        """
        # Broadcast each row index across all columns so the index tensor has
        # one entry per element of ``values``; ``expand`` creates a view
        # instead of concatenating ``dense_size[1]`` copies.
        full_indices = self.indices.unsqueeze(1).expand(-1, self.dense_size[1])
        return self.values.new_zeros(self.dense_size).scatter_add_(0, full_indices, self.values)

    def sparse_size(self):
        """Return ``(sparse_element_count, dense_element_count)``.

        The sparse count is the number of stored indices plus the number of
        stored values; the dense count is ``dense_size[0] * dense_size[1]``.
        """
        index_size = self.indices.size(0)
        value_size = self.values.size(0) * self.values.size(1)
        dense_size = self.dense_size[0] * self.dense_size[1]
        return index_size + value_size, dense_size

    def add(self, b):
        """Accumulate ``b`` into this tensor in place.

        The stored rows of ``b`` are appended; duplicates are not merged here
        and are summed when converting with ``to_dense``.

        Raises:
            AssertionError: if the dense shapes of the two tensors differ.
        """
        assert self.dense_size == b.dense_size
        self.indices = torch.cat([self.indices, b.indices])
        self.values = torch.cat([self.values, b.values])

    def __str__(self):
        template = "DeepSpeed.SparseTensor(indices_size={}, values_size={}, " \
                   "dense_size={}, device={}, reduction_factor={})"
        if self.indices is None or self.values is None:
            # Empty placeholder: there are no tensors to query.
            return template.format(None, None, self.dense_size, None, None)

        sparse_size, dense_size = self.sparse_size()
        # No stored rows means nothing to divide by; report an unbounded
        # reduction instead of raising ZeroDivisionError.
        reduction_factor = dense_size / sparse_size if sparse_size else float("inf")
        return template.format(self.indices.size(), self.values.size(), self.dense_size,
                               self.indices.get_device(), reduction_factor)

    def __repr__(self):
        return self.__str__()