
Commit 75bfbc3 (parent: 97c4f58)

port spmm_sum to pytorch and optimize it on CPU

[ghstack-poisoned]

5 files changed (+200, −0 lines)

Lines changed: 32 additions & 0 deletions
@@ -0,0 +1,32 @@
#include <ATen/ATen.h>
#include <ATen/NativeFunctions.h>
#include <ATen/native/SpmmReduce.h>

namespace at { namespace native {

Tensor spmm_sum_cpu(
    const Tensor& rowptr,
    const Tensor& col,
    const c10::optional<Tensor>& optional_value,
    const Tensor& mat) {
  TORCH_CHECK(rowptr.dim() == 1);
  TORCH_CHECK(col.dim() == 1);
  if (optional_value.has_value()) {
    TORCH_CHECK(optional_value.value().dim() == 1);
    TORCH_CHECK(optional_value.value().size(0) == col.size(0));
  }
  TORCH_CHECK(mat.dim() >= 2);

  Tensor other = mat.contiguous();

  // output keeps the sizes of `mat`, with the second-to-last (row) dim
  // replaced by the number of sparse rows, M = rowptr.numel() - 1
  auto sizes = other.sizes().vec();
  sizes[other.dim() - 2] = rowptr.numel() - 1;
  Tensor result = at::empty(sizes, other.options());
  spmm_sum_stub(kCPU, result, rowptr, col, optional_value, other);

  return result;
}

DEFINE_DISPATCH(spmm_sum_stub);

}} // at::native
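The wrapper above only checks shapes and allocates the output; the reduction itself is delegated to spmm_sum_stub. For reference, the semantics being set up (CSR row pointers, column indices, optional per-nonzero values, and a dense mat) can be sketched in Python. The function ref_spmm_sum below is an illustrative name, not part of the commit, and is written for clarity rather than speed.

import torch

def ref_spmm_sum(rowptr, col, value, mat):
    # rowptr: (M + 1,) int64, col: (nnz,) int64, value: (nnz,) or None, mat: (..., N, K)
    M = rowptr.numel() - 1
    out = mat.new_zeros(mat.shape[:-2] + (M, mat.size(-1)))
    for m in range(M):
        # accumulate the rows of `mat` selected by CSR row m,
        # weighted by `value` (or by 1 when no value tensor is given)
        for e in range(rowptr[m], rowptr[m + 1]):
            v = value[e] if value is not None else 1.0
            out[..., m, :] += v * mat[..., col[e], :]
    return out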

aten/src/ATen/native/SpmmReduce.h

Lines changed: 12 additions & 0 deletions
@@ -0,0 +1,12 @@
#pragma once

#include <ATen/core/Tensor.h>
#include <ATen/native/DispatchStub.h>

namespace at { namespace native {

using spmm_sum_fn = void(*)(const Tensor&, const Tensor&, const Tensor&, const c10::optional<Tensor>&, const Tensor&);
DECLARE_DISPATCH(spmm_sum_fn, spmm_sum_stub);

}} // at::native
Lines changed: 150 additions & 0 deletions
@@ -0,0 +1,150 @@
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>

#include <ATen/Dispatch.h>
#include <ATen/native/SpmmReduce.h>
#include <ATen/Parallel.h>
#include <ATen/cpu/vec/functional.h>
#include <ATen/cpu/vec/vec.h>
#include <c10/util/irange.h>

namespace at { namespace native {

namespace {

template <typename scalar_t, bool has_optional_value>
void spmm_sum_kernel_impl(
    const Tensor& result,
    const Tensor& rowptr,
    const Tensor& col,
    const c10::optional<Tensor>& optional_value,
    const Tensor& mat) {

  scalar_t* result_data = result.data_ptr<scalar_t>();
  int64_t* rowptr_data = rowptr.data_ptr<int64_t>();
  int64_t* col_data = col.data_ptr<int64_t>();
  scalar_t* value_data = has_optional_value ? optional_value.value().data_ptr<scalar_t>() : nullptr;
  scalar_t* mat_data = mat.data_ptr<scalar_t>();

  int64_t M = rowptr.numel() - 1;
  int64_t N = mat.size(-2);
  int64_t K = mat.size(-1);
  int64_t B = mat.numel() / (N * K);

  // Parallelizing directly over `B * M` may lead to load imbalance;
  // statically determine the thread partition here so that each thread
  // gets roughly the same payload (number of nonzeros).
  int num_threads = at::get_num_threads();
  std::vector<int64_t> thread_splits(num_threads + 1, B * M);
  int64_t thread_averge_payload = (rowptr_data[M] - rowptr_data[0]) / num_threads;

  thread_splits[0] = 0;
  int64_t sum = 0;
  int64_t t = 1;
  for (const auto m : c10::irange(M)) {
    int64_t row_start = rowptr_data[m];
    int64_t row_end = rowptr_data[m + 1];
    sum += row_end - row_start;
    if (sum > t * thread_averge_payload) {
      thread_splits[t] = B * m;
      t++;
    }
  }
  // Restore the last index explicitly: rounding error when computing
  // `thread_averge_payload` can leave it short of `B * M`.
  thread_splits[num_threads] = B * M;

  // TODO: add bfloat16 support here
  using Vec = vec::Vectorized<scalar_t>;
  at::parallel_for(0, num_threads, 1, [&](int64_t cbegin, int64_t cend) {
    int tid = at::get_thread_num();
    int64_t begin = thread_splits[tid];
    int64_t end = thread_splits[tid + 1];

    int64_t row_start, row_end, b, m, c;
    for (const auto i : c10::irange(begin, end)) {
      b = i / M;
      m = i % M;
      row_start = rowptr_data[m];
      row_end = rowptr_data[m + 1];

      scalar_t* result_ptr = result_data + i * K;

      constexpr int64_t kVecSize = Vec::size();
      constexpr int64_t kVLEN = kVecSize * 4;
      constexpr int64_t CHUNK_SIZE = 16;

      // zero-init the output lane
      vec::map<scalar_t>([](Vec x) { return Vec(0); }, result_ptr, result_ptr, K);

      // block along the nonzeros of each row to reduce write memory bandwidth
      for (int64_t e0 = row_start; e0 < row_end; e0 += CHUNK_SIZE) {
        int64_t e1 = std::min(e0 + CHUNK_SIZE, row_end);

        // unroll by 4 vector lanes
        int64_t k = 0;
        for (; k < K - (K % kVLEN); k += kVLEN) {
          Vec out_vec0 = Vec::loadu(result_ptr + k);
          Vec out_vec1 = Vec::loadu(result_ptr + k + kVecSize);
          Vec out_vec2 = Vec::loadu(result_ptr + k + kVecSize * 2);
          Vec out_vec3 = Vec::loadu(result_ptr + k + kVecSize * 3);
          for (const auto e : c10::irange(e0, e1)) {
            c = col_data[e];
            scalar_t val = has_optional_value ? value_data[e] : scalar_t(1);
            scalar_t* mat_ptr = mat_data + b * N * K + c * K + k;

            out_vec0 += Vec::loadu(mat_ptr) * Vec(val);
            out_vec1 += Vec::loadu(mat_ptr + kVecSize) * Vec(val);
            out_vec2 += Vec::loadu(mat_ptr + kVecSize * 2) * Vec(val);
            out_vec3 += Vec::loadu(mat_ptr + kVecSize * 3) * Vec(val);
          }
          out_vec0.store(result_ptr + k);
          out_vec1.store(result_ptr + k + kVecSize);
          out_vec2.store(result_ptr + k + kVecSize * 2);
          out_vec3.store(result_ptr + k + kVecSize * 3);
        }
        for (; k < K - (K % Vec::size()); k += Vec::size()) {
          Vec out_vec = Vec::loadu(result_ptr + k);
          for (const auto e : c10::irange(e0, e1)) {
            c = col_data[e];
            scalar_t val = has_optional_value ? value_data[e] : scalar_t(1);
            scalar_t* mat_ptr = mat_data + b * N * K + c * K;
            out_vec += Vec::loadu(mat_ptr + k) * Vec(val);
          }
          out_vec.store(result_ptr + k);
        }
        // scalar tail
        for (; k < K; k++) {
          scalar_t out_val = result_ptr[k];
          for (const auto e : c10::irange(e0, e1)) {
            c = col_data[e];
            scalar_t val = has_optional_value ? value_data[e] : scalar_t(1);
            scalar_t* mat_ptr = mat_data + b * N * K + c * K;
            out_val += mat_ptr[k] * val;
          }
          result_ptr[k] = out_val;
        }
      }
    }
  });
}

void spmm_sum_kernel(
    const Tensor& result,
    const Tensor& rowptr,
    const Tensor& col,
    const c10::optional<Tensor>& optional_value,
    const Tensor& mat) {
  AT_DISPATCH_FLOATING_TYPES(result.scalar_type(), "spmm_sum_kernel", [&]() {
    if (optional_value.has_value()) {
      spmm_sum_kernel_impl<scalar_t, true>(result, rowptr, col, optional_value, mat);
    } else {
      spmm_sum_kernel_impl<scalar_t, false>(result, rowptr, col, optional_value, mat);
    }
  });
}

} // anonymous namespace

REGISTER_DISPATCH(spmm_sum_stub, &spmm_sum_kernel);

}} // at::native
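The main scheduling idea in the kernel above is the static thread partition: instead of giving every thread the same number of rows, the split points are chosen so that each thread processes roughly the same number of nonzeros. Below is a minimal Python sketch of that idea for a single (non-batched) matrix; the helper name balanced_row_splits and the explicit bound on t are illustrative additions, not part of the commit.

def balanced_row_splits(rowptr, num_threads):
    # rowptr: CSR row-pointer list of length M + 1, num_threads >= 1
    M = len(rowptr) - 1
    splits = [M] * (num_threads + 1)
    splits[0] = 0
    avg = (rowptr[M] - rowptr[0]) // num_threads  # target nonzeros per thread
    nnz_so_far, t = 0, 1
    for m in range(M):
        nnz_so_far += rowptr[m + 1] - rowptr[m]
        if t < num_threads and nnz_so_far > t * avg:
            splits[t] = m  # thread t starts at row m
            t += 1
    splits[num_threads] = M  # the last thread always ends at the final row
    return splits

Thread t then handles rows splits[t] through splits[t + 1] - 1, so rows with many nonzeros are spread across threads instead of piling up on one thread while others sit idle.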

aten/src/ATen/native/native_functions.yaml

Lines changed: 5 additions & 0 deletions
@@ -3609,6 +3609,11 @@
     SparseCUDA: sparse_mask_helper_cuda
   autogen: _sparse_mask_helper.out
 
+- func: spmm_sum(Tensor rowptr, Tensor col, Tensor? optional_value, Tensor mat) -> Tensor
+  variants: function
+  dispatch:
+    CPU: spmm_sum_cpu
+
 - func: mode(Tensor self, int dim=-1, bool keepdim=False) -> (Tensor values, Tensor indices)
   variants: function, method
   dispatch:
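Because the entry uses variants: function, the operator is exposed as a plain function (torch.spmm_sum, matching the torch/overrides.py change below). On a build that includes this commit, a call would look roughly as follows; the tensors are a made-up 3x3 CSR matrix, shown only to illustrate the argument order.

import torch

# Hypothetical 3x3 CSR sparse matrix with 4 nonzeros (values made up).
rowptr = torch.tensor([0, 2, 3, 4])          # row m owns nonzeros rowptr[m]:rowptr[m + 1]
col    = torch.tensor([0, 2, 1, 2])          # column index of each nonzero
value  = torch.tensor([1.0, 2.0, 3.0, 4.0])  # weight of each nonzero
mat    = torch.randn(3, 8)                   # dense right-hand side (N x K)

out      = torch.spmm_sum(rowptr, col, value, mat)  # weighted row sums, shape (3, 8)
out_ones = torch.spmm_sum(rowptr, col, None, mat)   # omitted values are treated as 1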

torch/overrides.py

Lines changed: 1 addition & 0 deletions
@@ -176,6 +176,7 @@ def get_ignored_functions() -> Set[Callable]:
     torch.sparse_csc_tensor,
     torch.sparse_bsr_tensor,
     torch.sparse_bsc_tensor,
+    torch.spmm_sum,
     torch.tril_indices,
     torch.triu_indices,
     torch.vander,
