-
Notifications
You must be signed in to change notification settings - Fork 181
Expand file tree
/
Copy pathint8_fprop.cpp
More file actions
102 lines (77 loc) · 4.01 KB
/
int8_fprop.cpp
File metadata and controls
102 lines (77 loc) · 4.01 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
/*
* Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a
* copy of this software and associated documentation files (the "Software"),
* to deal in the Software without restriction, including without limitation
* the rights to use, copy, modify, merge, publish, distribute, sublicense,
* and/or sell copies of the Software, and to permit persons to whom the
* Software is furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
* THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
* FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
* DEALINGS IN THE SOFTWARE.
*/
#include <catch2/catch_test_macros.hpp>
#include "../utils/helpers.h"
#include <cudnn_frontend.h>
// Builds a cuDNN frontend graph for an INT8 forward convolution (optionally followed by an
// IDENTITY pointwise op), runs it once, and REQUIREs that every graph-building stage and the
// execution report success. Skipped when CUDNN_VERSION < 8600 or when
// check_device_arch_newer_than("ampere") returns false.
TEST_CASE("Conv with Int8 datatypes", "[conv][graph][caching]") {
    namespace fe = cudnn_frontend;

    // Problem size: input is N x C x H x W, filter is K x C x R x S.
    int64_t n = 1, c = 64, h = 32, w = 32, k = 4, r = 3, s = 3;

    // Convolution hyper-parameters. They are shared by the graph description and by the
    // output-size computation below so the two cannot drift apart.
    int64_t const pad = 1, conv_stride = 1, dilation = 1;

    // Output spatial extent (P x Q) from the standard convolution output-size formula.
    // Used to size the output buffer instead of assuming it equals the input H x W.
    int64_t const p = (h + 2 * pad - dilation * (r - 1) - 1) / conv_stride + 1;
    int64_t const q = (w + 2 * pad - dilation * (s - 1) - 1) / conv_stride + 1;

    bool const include_identity = true;

    // Builds, validates and finalizes the graph on `handle`; returns the graph together with
    // its input (X), filter (W) and output (Y) tensor handles for use in the variant pack.
    auto build_new_graph = [=](cudnnHandle_t handle) {
        auto graph = std::make_shared<fe::graph::Graph>();
        graph->set_io_data_type(fe::DataType_t::INT8)
            .set_intermediate_data_type(fe::DataType_t::INT32)
            .set_compute_data_type(fe::DataType_t::INT32);

        // Strides give channel stride 1, i.e. a channels-last (NHWC) memory layout.
        auto X = graph->tensor(fe::graph::Tensor_attributes()
                                   .set_name("image")
                                   .set_dim({n, c, h, w})
                                   .set_stride({c * h * w, 1, c * w, c}));

        // Filter uses the same channels-last layout (KRSC in memory).
        auto W = graph->tensor(fe::graph::Tensor_attributes()
                                   .set_name("filter")
                                   .set_dim({k, c, r, s})
                                   .set_stride({c * r * s, 1, c * s, c}));

        auto conv_options = fe::graph::Conv_fprop_attributes()
                                .set_padding({pad, pad})
                                .set_stride({conv_stride, conv_stride})
                                .set_dilation({dilation, dilation});
        auto conv_output = graph->conv_fprop(X, W, conv_options);

        auto Y = conv_output;
        if (include_identity) {
            // Optional trailing pointwise op, so the graph is more than a bare convolution.
            auto identity = fe::graph::Pointwise_attributes().set_mode(fe::PointwiseMode_t::IDENTITY);
            Y             = graph->pointwise(conv_output, conv_output, identity);
        }
        // The result is kept at INT32 rather than the INT8 I/O type set on the graph.
        Y->set_output(true).set_data_type(fe::DataType_t::INT32);

        REQUIRE(graph->validate().is_good());
        REQUIRE(graph->build_operation_graph(handle).is_good());
        REQUIRE(graph->create_execution_plans({fe::HeurMode_t::A}).is_good());
        REQUIRE(graph->check_support().is_good());
        REQUIRE(graph->build_plans().is_good());

        return std::make_tuple(graph, X, W, Y);
    };

    // Decide whether to skip before creating any cuDNN state.
#if (CUDNN_VERSION < 8600)
    SKIP("Conv Int8 requires cudnn 8.6 and up");
#endif
    if (check_device_arch_newer_than("ampere") == false) {
        SKIP("Int8 datatype convolutions require Ampere and later architectures");
    }

    // Create a unique_ptr for the cuDNN handle
    auto handle_ptr = create_cudnn_handle();
    auto handle     = *handle_ptr;

    auto [graph, X, W, Y] = build_new_graph(handle);

    Surface<int8_t> x_tensor(n * c * h * w);
    Surface<int8_t> w_tensor(k * c * r * s);
    // Output buffer sized from the computed output extent (N x K x P x Q).
    Surface<int32_t> y_tensor(n * k * p * q);

    std::unordered_map<std::shared_ptr<fe::graph::Tensor_attributes>, void*> variant_pack = {
        {X, x_tensor.devPtr}, {W, w_tensor.devPtr}, {Y, y_tensor.devPtr}};

    int64_t workspace_size = 0;
    REQUIRE(graph->get_workspace_size(workspace_size).is_good());
    Surface<int8_t> workspace(workspace_size);

    REQUIRE(graph->execute(handle, variant_pack, workspace.devPtr).is_good());
}