forked from deepspeedai/DeepSpeed
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathpt_binding.cpp
More file actions
161 lines (139 loc) · 5.65 KB
/
pt_binding.cpp
File metadata and controls
161 lines (139 loc) · 5.65 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
// Copyright (c) Microsoft Corporation.
// SPDX-License-Identifier: Apache-2.0
// DeepSpeed Team
#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>
#include <cassert>
#include <cstdint>
#include <limits>
#include <vector>
#include "quantization.h"
// Runs launch_fake_quantize_kernel over the storage of `vals` on the current CUDA stream.
//
// vals:   tensor whose data pointer is handed to the kernel as a T* buffer
//         (presumably rewritten in place by the kernel; the same tensor is returned).
// groups: number of groups the elements are split into; must be positive.
// bits:   bit width forwarded unchanged to the kernel.
//
// Returns `vals`. The kernel is launched only when
// ((numel / groups) - 1) / 4096 + 1 <= 256; otherwise `vals` is returned without
// any kernel having been launched.
// Raises (TORCH_CHECK) when `groups` is not positive or the element count does not fit in int.
template <typename T>
at::Tensor ds_quantize(at::Tensor& vals, int groups, int bits)
{
    // groups is used as a divisor below; zero would trap and a negative value is meaningless.
    TORCH_CHECK(groups > 0, "ds_quantize: groups must be positive, got ", groups);

    // numel() is 64-bit, but the count is carried as an int from here on; refuse to wrap.
    const int64_t total_elems = vals.numel();
    TORCH_CHECK(total_elems <= std::numeric_limits<int>::max(),
                "ds_quantize: tensor has ",
                total_elems,
                " elements, which does not fit in a 32-bit int");
    const int size = static_cast<int>(total_elems);

    if ((((size / groups) - 1) / 4096 + 1) <= 256) {
        launch_fake_quantize_kernel(static_cast<T*>(vals.data_ptr()),
                                    size,
                                    groups,
                                    bits,
                                    at::cuda::getCurrentCUDAStream());
    }
    return vals;
}
// Runs launch_sr_fake_quantize_kernel over the storage of `vals` on the current CUDA stream.
//
// vals:   tensor whose data pointer is handed to the kernel as a T* buffer
//         (presumably rewritten in place by the kernel; the same tensor is returned).
// groups: number of groups the elements are split into; must be positive.
// bits:   bit width forwarded unchanged to the kernel.
//
// Returns `vals`. The kernel is launched only when (numel / groups) / 4 / 1024 <= 256;
// otherwise `vals` is returned without any kernel having been launched.
// Raises (TORCH_CHECK) when `groups` is not positive or the element count does not fit in int.
template <typename T>
at::Tensor ds_sr_quantize(at::Tensor& vals, int groups, int bits)
{
    // groups is used as a divisor below; zero would trap and a negative value is meaningless.
    TORCH_CHECK(groups > 0, "ds_sr_quantize: groups must be positive, got ", groups);

    // numel() is 64-bit, but the count is carried as an int from here on; refuse to wrap.
    const int64_t total_elems = vals.numel();
    TORCH_CHECK(total_elems <= std::numeric_limits<int>::max(),
                "ds_sr_quantize: tensor has ",
                total_elems,
                " elements, which does not fit in a 32-bit int");
    const int size = static_cast<int>(total_elems);

    if (((size / groups) / 4 / 1024) <= 256) {
        launch_sr_fake_quantize_kernel(static_cast<T*>(vals.data_ptr()),
                                       size,
                                       groups,
                                       bits,
                                       at::cuda::getCurrentCUDAStream());
    }
    return vals;
}
// Runs launch_fake_quantize_kernel_asym over the storage of `vals` on the current CUDA stream.
//
// vals:   tensor whose data pointer is handed to the kernel as a T* buffer
//         (presumably rewritten in place by the kernel; the same tensor is returned).
// groups: number of groups the elements are split into; must be positive.
// bits:   bit width forwarded unchanged to the kernel.
//
// Returns `vals`. The kernel is launched only when
// ((numel / groups) - 1) / 4096 + 1 <= 256; otherwise `vals` is returned without
// any kernel having been launched.
// Raises (TORCH_CHECK) when `groups` is not positive or the element count does not fit in int.
template <typename T>
at::Tensor ds_quantize_asym(at::Tensor& vals, int groups, int bits)
{
    // groups is used as a divisor below; zero would trap and a negative value is meaningless.
    TORCH_CHECK(groups > 0, "ds_quantize_asym: groups must be positive, got ", groups);

    // numel() is 64-bit, but the count is carried as an int from here on; refuse to wrap.
    const int64_t total_elems = vals.numel();
    TORCH_CHECK(total_elems <= std::numeric_limits<int>::max(),
                "ds_quantize_asym: tensor has ",
                total_elems,
                " elements, which does not fit in a 32-bit int");
    const int size = static_cast<int>(total_elems);

    if ((((size / groups) - 1) / 4096 + 1) <= 256) {
        launch_fake_quantize_kernel_asym(static_cast<T*>(vals.data_ptr()),
                                         size,
                                         groups,
                                         bits,
                                         at::cuda::getCurrentCUDAStream());
    }
    return vals;
}
// Runs launch_sr_fake_quantize_kernel_asym over the storage of `vals` on the current CUDA stream.
//
// vals:   tensor whose data pointer is handed to the kernel as a T* buffer
//         (presumably rewritten in place by the kernel; the same tensor is returned).
// groups: number of groups the elements are split into; must be positive.
// bits:   bit width forwarded unchanged to the kernel.
//
// Returns `vals`. The kernel is launched only when (numel / groups) / 4 / 1024 <= 256;
// otherwise `vals` is returned without any kernel having been launched.
// Raises (TORCH_CHECK) when `groups` is not positive or the element count does not fit in int.
template <typename T>
at::Tensor ds_sr_quantize_asym(at::Tensor& vals, int groups, int bits)
{
    // groups is used as a divisor below; zero would trap and a negative value is meaningless.
    TORCH_CHECK(groups > 0, "ds_sr_quantize_asym: groups must be positive, got ", groups);

    // numel() is 64-bit, but the count is carried as an int from here on; refuse to wrap.
    const int64_t total_elems = vals.numel();
    TORCH_CHECK(total_elems <= std::numeric_limits<int>::max(),
                "ds_sr_quantize_asym: tensor has ",
                total_elems,
                " elements, which does not fit in a 32-bit int");
    const int size = static_cast<int>(total_elems);

    if (((size / groups) / 4 / 1024) <= 256) {
        launch_sr_fake_quantize_kernel_asym(static_cast<T*>(vals.data_ptr()),
                                            size,
                                            groups,
                                            bits,
                                            at::cuda::getCurrentCUDAStream());
    }
    return vals;
}
// Quantizes `input_vals` into a freshly allocated int8 tensor via launch_quant.
//
// input_vals: fp16 tensor with at least one dimension; its buffer is passed to
//             launch_quant as a __half* (NOTE(review): presumably must also be a
//             contiguous CUDA tensor -- not checked here).
// groups:     number of quantization groups; must be positive.
// numBits:    8 keeps the last dimension as is; any other value halves it
//             in the output shape.
// quantType:  forwarded to launch_quant; also decides the params width.
//
// Returns {output, params}:
//   output -- int8 CUDA tensor shaped like `input_vals`, last dimension divided
//             by 2 unless numBits == 8.
//   params -- float CUDA tensor of shape {groups, 2} when
//             quantize::requires_offset(quantType) is true, else {groups, 1}.
// Raises (TORCH_CHECK) on non-positive `groups`, a 0-dim or non-fp16 input, or a
// per-group element count that does not fit in int.
std::vector<at::Tensor> quantize_kernel(at::Tensor& input_vals,
                                        int groups,
                                        int numBits,
                                        quantize::Type quantType)
{
    // groups is used as a divisor below.
    TORCH_CHECK(groups > 0, "quantize: groups must be positive, got ", groups);
    // The last dimension is rewritten below, so there has to be one.
    TORCH_CHECK(input_vals.dim() > 0, "quantize: input must have at least one dimension");
    // The buffer is reinterpreted as __half* unconditionally; any other dtype is misread.
    TORCH_CHECK(input_vals.scalar_type() == at::kHalf,
                "quantize: input must be a float16 tensor, got ",
                input_vals.scalar_type());

    // at::numel is 64-bit; make the narrowing to the int handed to launch_quant explicit.
    const int64_t group_elems = at::numel(input_vals) / groups;
    TORCH_CHECK(group_elems <= std::numeric_limits<int>::max(),
                "quantize: ",
                group_elems,
                " elements per group does not fit in a 32-bit int");
    const int elems_per_group = static_cast<int>(group_elems);

    const auto params_options = at::TensorOptions()
                                    .dtype(at::kFloat)
                                    .layout(at::kStrided)
                                    .device(at::kCUDA)
                                    .requires_grad(false);
    const int param_elems = quantize::requires_offset(quantType) ? 2 : 1;
    auto params = torch::empty({groups, param_elems}, params_options);

    const auto output_options = at::TensorOptions()
                                    .dtype(at::kChar)
                                    .layout(at::kStrided)
                                    .device(at::kCUDA)
                                    .requires_grad(false);
    auto output_sizes = input_vals.sizes().vec();
    output_sizes.back() /= (numBits == 8) ? 1 : 2;
    auto output = torch::empty(output_sizes, output_options);

    launch_quant(static_cast<int8_t*>(output.data_ptr()),
                 static_cast<float*>(params.data_ptr()),
                 static_cast<__half*>(input_vals.data_ptr()),
                 groups,
                 elems_per_group,
                 numBits,
                 quantType,
                 at::cuda::getCurrentCUDAStream());

    return {output, params};
}
// Expands int8 quantized data into a freshly allocated tensor via launch_dequantize_kernel.
//
// T:              output element type; float yields a float32 tensor, anything else float16.
// quantized_data: tensor with at least one dimension whose buffer is read as const int8_t*.
// params:         tensor whose buffer is read as const float*.
// groups:         number of quantization groups; must be positive.
// num_bits:       8 keeps the last dimension as is; any other value doubles it
//                 in the output shape.
// quant_type:     forwarded unchanged to the kernel launcher.
//
// Returns a CUDA tensor shaped like `quantized_data`, last dimension multiplied by 2
// unless num_bits == 8.
// Raises (TORCH_CHECK) on non-positive `groups`, a 0-dim input, or an output element
// count that does not fit in int.
template <typename T>
at::Tensor dequantize(at::Tensor& quantized_data,
                      at::Tensor& params,
                      int groups,
                      int num_bits,
                      quantize::Type quant_type)
{
    // groups is used as a divisor below.
    TORCH_CHECK(groups > 0, "dequantize: groups must be positive, got ", groups);
    // The last dimension is rewritten below, so there has to be one.
    TORCH_CHECK(quantized_data.dim() > 0,
                "dequantize: quantized_data must have at least one dimension");

    const int64_t expand_factor = (num_bits == 8) ? 1 : 2;

    // Only the last dimension is scaled, so the output element count is the input count
    // times the factor. Check it before allocating: it is passed on as an int.
    const int64_t output_elems = quantized_data.numel() * expand_factor;
    TORCH_CHECK(output_elems <= std::numeric_limits<int>::max(),
                "dequantize: output would have ",
                output_elems,
                " elements, which does not fit in a 32-bit int");

    const auto dtype = (std::is_same<T, float>::value) ? torch::kFloat32 : torch::kFloat16;
    const auto output_options = at::TensorOptions()
                                    .dtype(dtype)
                                    .layout(at::kStrided)
                                    .device(at::kCUDA)
                                    .requires_grad(false);

    auto output_sizes = quantized_data.sizes().vec();
    output_sizes.back() *= expand_factor;
    auto output = torch::empty(output_sizes, output_options);

    const int total_elems = static_cast<int>(output_elems);
    const int elems_per_group = total_elems / groups;

    launch_dequantize_kernel(static_cast<T*>(output.data_ptr()),
                             static_cast<const int8_t*>(quantized_data.data_ptr()),
                             static_cast<const float*>(params.data_ptr()),
                             quant_type,
                             num_bits,
                             elems_per_group,
                             total_elems,
                             at::cuda::getCurrentCUDAStream());

    return output;
}
// Python module definition: exposes the fake-quantize entry points for float and
// __half, the QuantizationType enum, and the quantize / dequantize functions.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
    using QuantType = quantize::Type;

    // Every fake-quantize binding carries one of these two docstrings.
    constexpr const char* kFp32Doc = "DeepSpeed Quantize with fp32 (CUDA)";
    constexpr const char* kFp16Doc = "DeepSpeed Quantize with fp16 (CUDA)";

    // ds_quantize / ds_sr_quantize instantiations.
    m.def("ds_quantize_fp32", &ds_quantize<float>, kFp32Doc);
    m.def("ds_quantize_fp16", &ds_quantize<__half>, kFp16Doc);
    m.def("ds_sr_quantize_fp32", &ds_sr_quantize<float>, kFp32Doc);
    m.def("ds_sr_quantize_fp16", &ds_sr_quantize<__half>, kFp16Doc);

    // ds_quantize_asym / ds_sr_quantize_asym instantiations.
    m.def("ds_quantize_asym_fp32", &ds_quantize_asym<float>, kFp32Doc);
    m.def("ds_quantize_asym_fp16", &ds_quantize_asym<__half>, kFp16Doc);
    m.def("ds_sr_quantize_asym_fp32", &ds_sr_quantize_asym<float>, kFp32Doc);
    m.def("ds_sr_quantize_asym_fp16", &ds_sr_quantize_asym<__half>, kFp16Doc);

    // Enum taken by quantize / dequantize; export_values() also publishes the
    // enumerators as module-level attributes.
    pybind11::enum_<QuantType>(m, "QuantizationType")
        .value("Symmetric", QuantType::Symmetric)
        .value("Asymmetric", QuantType::Asymmetric)
        .export_values();

    // quantize_kernel plus the two dequantize instantiations; the unsuffixed
    // "dequantize" is the __half one.
    m.def("quantize", &quantize_kernel);
    m.def("dequantize", &dequantize<__half>);
    m.def("dequantize_fp32", &dequantize<float>);
}