Skip to content

Commit 49ccc41

Browse files
manuelcandalespytorchmergebot
authored andcommitted
[Vulkan] Enable QInt8 and QInt32 quantization (#89788)
Summary: Enabled Vulkan quantization for dtypes QInt8 and QInt32 Test Plan: On Mac ``` cd ~/fbsource buck1 run -c pt.vulkan_full_precision=1 //xplat/caffe2:pt_vulkan_quantized_api_test_binAppleMac\#macosx-arm64 ``` On Android ``` cd ~/fbsource buck1 build -c ndk.custom_libcxx=false -c pt.enable_qpl=0 -c pt.vulkan_full_precision=1 //xplat/caffe2:pt_vulkan_quantized_api_test_binAndroid\#android-arm64 --show-output adb push buck-out/gen/xplat/caffe2/pt_vulkan_quantized_api_test_binAndroid\#android-arm64 /data/local/tmp/vulkan_quantized_api_test adb shell "/data/local/tmp/vulkan_quantized_api_test" ``` Differential Revision: D41561661 Pull Request resolved: #89788 Approved by: https://github.com/digantdesai
1 parent 45b40be commit 49ccc41

File tree

7 files changed

+191
-45
lines changed

7 files changed

+191
-45
lines changed

aten/src/ATen/native/vulkan/api/Resource.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,10 @@ VkFormat vk_format(const at::ScalarType dtype) {
3636
#endif /* USE_VULKAN_FP16_INFERENCE */
3737
case c10::kQUInt8:
3838
return VK_FORMAT_R8G8B8A8_UINT;
39+
case c10::kQInt8:
40+
return VK_FORMAT_R8G8B8A8_SINT;
41+
case c10::kQInt32:
42+
return VK_FORMAT_R32G32B32A32_SINT;
3943

4044
default:
4145
TORCH_CHECK(
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
#version 450 core
2+
#define PRECISION $precision
3+
#define FORMAT $format
4+
5+
layout(std430) buffer;
6+
7+
/* Qualifiers: layout - storage - precision - memory */
8+
9+
layout(set = 0, binding = 0, rgba32i) uniform PRECISION restrict writeonly iimage3D uOutput;
10+
layout(set = 0, binding = 1) uniform PRECISION sampler3D uInput; //input
11+
layout(set = 0, binding = 2) uniform PRECISION restrict Block {
12+
ivec4 size;
13+
vec2 scale;
14+
ivec2 zero_point;
15+
} uBlock;
16+
17+
layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
18+
19+
void main() {
20+
const ivec3 pos = ivec3(gl_GlobalInvocationID);
21+
if (all(lessThan(pos, uBlock.size.xyz))) {
22+
vec4 q_res = roundEven(texelFetch(uInput, pos, 0) / uBlock.scale.x) + uBlock.zero_point.x;
23+
24+
ivec4 ret = ivec4(q_res);
25+
26+
imageStore(
27+
uOutput,
28+
pos,
29+
ret);
30+
}
31+
}
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
#version 450 core
2+
#define PRECISION $precision
3+
#define FORMAT $format
4+
5+
layout(std430) buffer;
6+
7+
/* Qualifiers: layout - storage - precision - memory */
8+
9+
layout(set = 0, binding = 0, rgba8i) uniform PRECISION restrict writeonly iimage3D uOutput;
10+
layout(set = 0, binding = 1) uniform PRECISION sampler3D uInput; //input
11+
layout(set = 0, binding = 2) uniform PRECISION restrict Block {
12+
ivec4 size;
13+
vec2 scale;
14+
ivec2 zero_point;
15+
} uBlock;
16+
17+
layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
18+
19+
void main() {
20+
const ivec3 pos = ivec3(gl_GlobalInvocationID);
21+
if (all(lessThan(pos, uBlock.size.xyz))) {
22+
vec4 q_res = roundEven(texelFetch(uInput, pos, 0) / uBlock.scale.x) + uBlock.zero_point.x;
23+
24+
ivec4 ret = ivec4(q_res);
25+
26+
imageStore(
27+
uOutput,
28+
pos,
29+
ret);
30+
}
31+
}

aten/src/ATen/native/vulkan/glsl/quantize_per_tensor.glsl renamed to aten/src/ATen/native/vulkan/glsl/quantize_per_tensor_quint8.glsl

File renamed without changes.

aten/src/ATen/native/vulkan/ops/Copy.cpp

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,10 +18,15 @@ void memcpy_to_mapping(const Tensor& src, api::MemoryMap& dst_mapping) {
1818
memcpy_to_mapping_impl<c10::Half>(src, dst_mapping);
1919
} else if (src.dtype() == c10::kQUInt8) {
2020
memcpy_to_mapping_impl<c10::quint8>(src, dst_mapping);
21+
} else if (src.dtype() == c10::kQInt8) {
22+
memcpy_to_mapping_impl<c10::qint8>(src, dst_mapping);
23+
} else if (src.dtype() == c10::kQInt32) {
24+
memcpy_to_mapping_impl<c10::qint32>(src, dst_mapping);
2125
} else {
2226
TORCH_CHECK(
2327
false,
24-
"Invalid Data Type: expected c10::QUint8, at::kHalf or at::Float but got ",
28+
"Invalid Data Type: expected c10::kQInt32, c10::kQInt8, c10::kQUInt8,",
29+
" at::kHalf or at::Float but got ",
2530
src.dtype());
2631
}
2732
}
@@ -33,10 +38,15 @@ void memcpy_from_mapping(api::MemoryMap& src_mapping, Tensor& dst) {
3338
memcpy_from_mapping_impl<c10::Half>(src_mapping, dst);
3439
} else if (dst.dtype() == c10::kQUInt8) {
3540
memcpy_from_mapping_impl<c10::quint8>(src_mapping, dst);
41+
} else if (dst.dtype() == c10::kQInt8) {
42+
memcpy_from_mapping_impl<c10::qint8>(src_mapping, dst);
43+
} else if (dst.dtype() == c10::kQInt32) {
44+
memcpy_from_mapping_impl<c10::qint32>(src_mapping, dst);
3645
} else {
3746
TORCH_CHECK(
3847
false,
39-
"Invalid Data Type: expected c10::QUint8, at::kHalf or Float but got ",
48+
"Invalid Data Type: expected c10::kQInt32, c10::kQInt8, c10::kQUInt8,",
49+
" at::kHalf or at::Float but got ",
4050
dst.dtype());
4151
}
4252
}

aten/src/ATen/native/vulkan/ops/QuantizedTensor.cpp

Lines changed: 20 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -10,24 +10,37 @@ namespace ops {
1010

1111
using namespace api::utils;
1212

13+
static api::ShaderSource get_quantize_per_tensor_shader(
14+
const c10::ScalarType dtype) {
15+
switch (dtype) {
16+
case c10::ScalarType::QUInt8:
17+
return VK_KERNEL(quantize_per_tensor_quint8);
18+
case c10::ScalarType::QInt8:
19+
return VK_KERNEL(quantize_per_tensor_qint8);
20+
case c10::ScalarType::QInt32:
21+
return VK_KERNEL(quantize_per_tensor_qint32);
22+
default:
23+
TORCH_CHECK(
24+
false,
25+
"Vulkan quantization currently not supported for dtype ",
26+
dtype);
27+
}
28+
}
29+
1330
Tensor quantize_per_tensor(
1431
const at::Tensor& input_arg,
1532
const double scale,
1633
const int64_t zero_point,
1734
const c10::ScalarType dtype) {
18-
TORCH_CHECK(dtype == c10::ScalarType::QUInt8, "Expected type c10::kQUint8");
35+
api::ShaderSource compute_shader = get_quantize_per_tensor_shader(dtype);
1936

2037
api::Context* const context = api::context();
2138

2239
const Tensor input = input_arg.is_vulkan() ? input_arg : input_arg.vulkan();
2340
const vTensor& v_input = convert(input);
2441

2542
vTensor v_output{
26-
context,
27-
input.sizes(),
28-
input.options().dtype(c10::kQUInt8),
29-
scale,
30-
zero_point};
43+
context, input.sizes(), input.options().dtype(dtype), scale, zero_point};
3144

3245
const struct Block final {
3346
uvec3 extents;
@@ -50,7 +63,7 @@ Tensor quantize_per_tensor(
5063

5164
context->submit_compute_job(
5265
// shader descriptor
53-
VK_KERNEL(quantize_per_tensor),
66+
compute_shader,
5467
// barrier
5568
pipeline_barrier,
5669
// global work group size

aten/src/ATen/test/vulkan_quantized_api_test.cpp

Lines changed: 93 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -449,66 +449,123 @@ void test_quantize_per_tensor_and_dequantize(
449449
const at::IntArrayRef input_shape,
450450
const double input_scale,
451451
const int input_zero_point,
452-
const float tolerance = 0) {
453-
at::Tensor input = at::rand(input_shape, at::device(at::kCPU).dtype(at::kFloat));
452+
const c10::ScalarType dtype = c10::ScalarType::QUInt8) {
453+
at::Tensor input = produce_random_tensor(input_shape);
454454

455455
// quantize tensors
456456
at::Tensor out_q_cpu = at::quantize_per_tensor(
457-
input, input_scale, input_zero_point, c10::ScalarType::QUInt8);
457+
input, input_scale, input_zero_point, dtype);
458458
at::Tensor out_q_vk = at::quantize_per_tensor(
459-
input.vulkan(), input_scale, input_zero_point, c10::ScalarType::QUInt8);
459+
input.vulkan(), input_scale, input_zero_point, dtype);
460460

461461
// dequantize tensors
462462
const auto out_cpu_deq = at::dequantize(out_q_cpu);
463463
const auto out_vk_deq = at::dequantize(out_q_vk);
464+
const auto out_vk_deq_cpu = out_vk_deq.cpu();
464465

465466
// check dequantized tensor are equal
466-
const auto check = almostEqual(out_cpu_deq, out_vk_deq.cpu(), tolerance);
467+
const float tolerance = input_scale;
468+
// tolerated error = scale, to allow for precision differences after dividing
469+
// by random scale, which could result on a difference of 1 unit in the
470+
// quantized result.
471+
const auto check = almostEqual(out_cpu_deq, out_vk_deq_cpu, tolerance);
467472

468473
if (!check) {
474+
const auto error = at::abs(out_vk_deq_cpu - out_cpu_deq).max().item<float>();
469475
std::cout
470476
<< "Quantize and Dequantize failed with input shape: " << input_shape
471477
<< " scale: " << input_scale << " and zero point: " << input_zero_point
472478
<< std::endl;
479+
std::cout << "Error: " << error << std::endl;
473480
}
474481
ASSERT_TRUE(check);
475482
}
476483

477-
void test_quantize_per_tensor_and_dequantize_random() {
478-
const double scale = 0.0001 + (double)rand() / (double)RAND_MAX;
479-
const int zero_point = int((double)rand() / (double)RAND_MAX * 255);
480-
const int n = 1 + int((double)rand() / (double)RAND_MAX * 30);
481-
const int c = 1 + int((double)rand() / (double)RAND_MAX * 30);
482-
const int h = 1 + int((double)rand() / (double)RAND_MAX * 100);
483-
const int w = 1 + int((double)rand() / (double)RAND_MAX * 100);
484-
// tolerated error = scale, to allow for precision differences after dividing
485-
// by random scale, which could result on a difference of 1 unit in the
486-
// quantized result.
487-
test_quantize_per_tensor_and_dequantize({n, c, h, w}, scale, zero_point, scale);
484+
void test_quantize_per_tensor_and_dequantize_random(
485+
const c10::ScalarType dtype) {
486+
const double scale = produce_random_scale();
487+
const int64_t zero_point = produce_random_zero_point(dtype);
488+
const at::IntArrayRef tensor_shape =
489+
{rand_pos_int(30), rand_pos_int(30), rand_pos_int(100), rand_pos_int(100)};
490+
test_quantize_per_tensor_and_dequantize(
491+
tensor_shape, scale, zero_point, dtype);
492+
}
493+
494+
TEST_F(VulkanAPITest, quantize_per_tensor_and_dequantize_quint8) {
495+
const c10::ScalarType dtype = c10::ScalarType::QUInt8;
496+
test_quantize_per_tensor_and_dequantize({1, 1, 1, 1}, 0.13, 21, dtype);
497+
test_quantize_per_tensor_and_dequantize({1, 1, 1, 4}, 0.3, 87, dtype);
498+
test_quantize_per_tensor_and_dequantize({1, 1, 4, 1}, 0.2, 120, dtype);
499+
test_quantize_per_tensor_and_dequantize({1, 1, 7, 7}, 0.3, 87, dtype);
500+
test_quantize_per_tensor_and_dequantize({1, 1, 8, 8}, 0.1, 10, dtype);
501+
test_quantize_per_tensor_and_dequantize({3, 5, 8, 8}, 0.04, 97, dtype);
502+
test_quantize_per_tensor_and_dequantize({1, 1, 11, 17}, 0.07, 15, dtype);
503+
test_quantize_per_tensor_and_dequantize({1, 1, 12, 17}, 0.1, 10, dtype);
504+
test_quantize_per_tensor_and_dequantize({3, 5, 12, 17}, 0.1, 10, dtype);
505+
test_quantize_per_tensor_and_dequantize({1, 1, 17, 12}, 0.1, 10, dtype);
506+
test_quantize_per_tensor_and_dequantize({2, 4, 17, 12}, 0.1, 10, dtype);
507+
test_quantize_per_tensor_and_dequantize({1, 1, 10, 14}, 0.001, 101, dtype);
508+
test_quantize_per_tensor_and_dequantize({3, 5, 10, 14}, 0.009, 43, dtype);
509+
test_quantize_per_tensor_and_dequantize({3, 5, 10, 15}, 0.1, 19, dtype);
510+
test_quantize_per_tensor_and_dequantize({4, 4, 9, 17}, 0.1, 19, dtype);
511+
test_quantize_per_tensor_and_dequantize({3, 5, 25, 29}, 0.1, 19, dtype);
512+
test_quantize_per_tensor_and_dequantize({4, 4, 25, 29}, 0.1, 19, dtype);
513+
test_quantize_per_tensor_and_dequantize({11, 17, 25, 29}, 0.027, 89, dtype);
514+
515+
for (int i = 0; i < 20; i += 1) {
516+
test_quantize_per_tensor_and_dequantize_random(dtype);
517+
}
518+
}
519+
520+
TEST_F(VulkanAPITest, quantize_per_tensor_and_dequantize_qint8) {
521+
const c10::ScalarType dtype = c10::ScalarType::QInt8;
522+
test_quantize_per_tensor_and_dequantize({1, 1, 1, 1}, 0.13, -21, dtype);
523+
test_quantize_per_tensor_and_dequantize({1, 1, 1, 4}, 0.3, 87, dtype);
524+
test_quantize_per_tensor_and_dequantize({1, 1, 4, 1}, 0.2, -120, dtype);
525+
test_quantize_per_tensor_and_dequantize({1, 1, 7, 7}, 0.3, 87, dtype);
526+
test_quantize_per_tensor_and_dequantize({1, 1, 8, 8}, 0.1, -10, dtype);
527+
test_quantize_per_tensor_and_dequantize({3, 5, 8, 8}, 0.04, 97, dtype);
528+
test_quantize_per_tensor_and_dequantize({1, 1, 11, 17}, 0.07, -15, dtype);
529+
test_quantize_per_tensor_and_dequantize({1, 1, 12, 17}, 0.1, 10, dtype);
530+
test_quantize_per_tensor_and_dequantize({3, 5, 12, 17}, 0.1, -10, dtype);
531+
test_quantize_per_tensor_and_dequantize({1, 1, 17, 12}, 0.1, 10, dtype);
532+
test_quantize_per_tensor_and_dequantize({2, 4, 17, 12}, 0.1, -10, dtype);
533+
test_quantize_per_tensor_and_dequantize({1, 1, 10, 14}, 0.001, 101, dtype);
534+
test_quantize_per_tensor_and_dequantize({3, 5, 10, 14}, 0.009, -43, dtype);
535+
test_quantize_per_tensor_and_dequantize({3, 5, 10, 15}, 0.1, 19, dtype);
536+
test_quantize_per_tensor_and_dequantize({4, 4, 9, 17}, 0.1, -19, dtype);
537+
test_quantize_per_tensor_and_dequantize({3, 5, 25, 29}, 0.1, 19, dtype);
538+
test_quantize_per_tensor_and_dequantize({4, 4, 25, 29}, 0.1, -19, dtype);
539+
test_quantize_per_tensor_and_dequantize({11, 17, 25, 29}, 0.027, 89, dtype);
540+
541+
for (int i = 0; i < 20; i += 1) {
542+
test_quantize_per_tensor_and_dequantize_random(dtype);
543+
}
488544
}
489545

490-
TEST_F(VulkanAPITest, quantize_per_tensor_and_dequantize) {
491-
test_quantize_per_tensor_and_dequantize({1, 1, 1, 1}, 0.13, 21);
492-
test_quantize_per_tensor_and_dequantize({1, 1, 1, 4}, 0.3, 87);
493-
test_quantize_per_tensor_and_dequantize({1, 1, 4, 1}, 0.2, 120);
494-
test_quantize_per_tensor_and_dequantize({1, 1, 7, 7}, 0.3, 87);
495-
test_quantize_per_tensor_and_dequantize({1, 1, 8, 8}, 0.1, 10);
496-
test_quantize_per_tensor_and_dequantize({3, 5, 8, 8}, 0.04, 97);
497-
test_quantize_per_tensor_and_dequantize({1, 1, 11, 17}, 0.07, 15);
498-
test_quantize_per_tensor_and_dequantize({1, 1, 12, 17}, 0.1, 10);
499-
test_quantize_per_tensor_and_dequantize({3, 5, 12, 17}, 0.1, 10);
500-
test_quantize_per_tensor_and_dequantize({1, 1, 17, 12}, 0.1, 10);
501-
test_quantize_per_tensor_and_dequantize({2, 4, 17, 12}, 0.1, 10);
502-
test_quantize_per_tensor_and_dequantize({1, 1, 10, 14}, 0.0001, 101);
503-
test_quantize_per_tensor_and_dequantize({3, 5, 10, 14}, 0.009, 43);
504-
test_quantize_per_tensor_and_dequantize({3, 5, 10, 15}, 0.1, 19);
505-
test_quantize_per_tensor_and_dequantize({4, 4, 9, 17}, 0.1, 19);
506-
test_quantize_per_tensor_and_dequantize({3, 5, 25, 29}, 0.1, 19);
507-
test_quantize_per_tensor_and_dequantize({4, 4, 25, 29}, 0.1, 19);
508-
test_quantize_per_tensor_and_dequantize({11, 17, 25, 29}, 0.027, 89);
546+
TEST_F(VulkanAPITest, quantize_per_tensor_and_dequantize_qint32) {
547+
const c10::ScalarType dtype = c10::ScalarType::QInt32;
548+
test_quantize_per_tensor_and_dequantize({1, 1, 1, 1}, 0.13, -21123, dtype);
549+
test_quantize_per_tensor_and_dequantize({1, 1, 1, 4}, 0.339, 8734, dtype);
550+
test_quantize_per_tensor_and_dequantize({1, 1, 4, 1}, 0.228, -12023, dtype);
551+
test_quantize_per_tensor_and_dequantize({1, 1, 7, 7}, 0.338, 8723, dtype);
552+
test_quantize_per_tensor_and_dequantize({1, 1, 8, 8}, 0.193, -1023, dtype);
553+
test_quantize_per_tensor_and_dequantize({3, 5, 8, 8}, 0.0449, 972, dtype);
554+
test_quantize_per_tensor_and_dequantize({1, 1, 11, 17}, 0.073, -15, dtype);
555+
test_quantize_per_tensor_and_dequantize({1, 1, 12, 17}, 0.1572, 102, dtype);
556+
test_quantize_per_tensor_and_dequantize({3, 5, 12, 17}, 0.147, -156, dtype);
557+
test_quantize_per_tensor_and_dequantize({1, 1, 17, 12}, 0.129, 10448, dtype);
558+
test_quantize_per_tensor_and_dequantize({2, 4, 17, 12}, 0.137, -10, dtype);
559+
test_quantize_per_tensor_and_dequantize({1, 1, 10, 14}, 0.001, 101, dtype);
560+
test_quantize_per_tensor_and_dequantize({3, 5, 10, 14}, 0.009, -43267, dtype);
561+
test_quantize_per_tensor_and_dequantize({3, 5, 10, 15}, 0.1243, 19, dtype);
562+
test_quantize_per_tensor_and_dequantize({4, 4, 9, 17}, 0.1889, -19784, dtype);
563+
test_quantize_per_tensor_and_dequantize({3, 5, 25, 29}, 0.1345, 196, dtype);
564+
test_quantize_per_tensor_and_dequantize({4, 4, 25, 29}, 0.129, -19489, dtype);
565+
test_quantize_per_tensor_and_dequantize({11, 17, 25, 29}, 0.027, 89, dtype);
509566

510567
for (int i = 0; i < 20; i += 1) {
511-
test_quantize_per_tensor_and_dequantize_random();
568+
test_quantize_per_tensor_and_dequantize_random(dtype);
512569
}
513570
}
514571

0 commit comments

Comments
 (0)