-
Notifications
You must be signed in to change notification settings - Fork 247
Expand file tree
/
Copy pathgen_squeeze.py
More file actions
99 lines (80 loc) · 3.43 KB
/
gen_squeeze.py
File metadata and controls
99 lines (80 loc) · 3.43 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
import tensorflow as tf
import numpy as np
import jinja2
# Generated artifacts: the gtest source file and the header that holds the
# reference input/output arrays the tests read.
output_file, const_file = "test_squeeze.cpp", "constants_squeeze.hpp"

# Jinja2 source for one test case's constant arrays: the flattened input
# (`s_in_<name>`) and the flattened expected output (`s_ref_out_<name>`).
const_str = """
static const {{ dtype }} s_in_{{ test_name }}[{{ input_size }}] = { {% for x in ref_in %} {{ x }}{{ "," if not loop.last }} {% endfor %} };
static const {{ dtype }} s_ref_out_{{ test_name }}[{{ out_size }}] = { {% for x in ref_out %} {{ x }}{{ "," if not loop.last }} {% endfor %} };
"""

# Jinja2 source for one gtest case. It fills a RamTensor of shape `in_shape`
# from the input array, runs SqueezeOperator on it, then compares the first
# `out_size` elements of that same tensor against the reference array.
test_str = """
TEST(Squeeze, random_inputs_{{ test_name }}) {
localCircularArenaAllocator<1024> meta_allocator;
localCircularArenaAllocator<{{ out_size }}*2*sizeof({{ dtype }}), uint32_t> ram_allocator;
Context::get_default_context()->set_metadata_allocator(&meta_allocator);
Context::get_default_context()->set_ram_data_allocator(&ram_allocator);
Tensor io = new RamTensor({ {% for x in in_shape %}{{ x }}{{ "," if not loop.last }}{% endfor %} }, {{ u_dtype }});
for(int i = 0; i < {{ input_size }}; i++) {
io(i) = s_in_{{ test_name }}[i];
}
SqueezeOperator<{{ dtype }}> Squeeze;
Squeeze
.set_inputs({ {SqueezeOperator<{{ dtype }}>::x, io}})
.eval();
{{ dtype }} tmp;
for(int i = 0; i < {{ out_size }}; i++) {
tmp = io(i);
EXPECT_NEAR( (float)(tmp - s_ref_out_{{ test_name }}[i]), 0, 0.0001);
}
}
"""

# Jinja2 source for the whole test translation unit. It includes the
# generated constants header and concatenates every rendered test case.
container_str = """
#include <cstring>
#include <iostream>
#include "uTensor.h"
#include "gtest/gtest.h"
#include "{{ constants_header }}"
using std::cout;
using std::endl;
using namespace uTensor;
{% for test in tests %}
/*********************************************
* Generated Test number {{ loop.counter }}
*********************************************/
{{ test }}
{% endfor %}
"""

# Jinja2 source for the constants header. The include guard is derived from
# the header filename with '.' replaced by '_'.
const_container_str = """
#ifndef {{ constants_header | replace(".", "_") }}
#define {{ constants_header | replace(".", "_") }}
{% for constant_snippet in constants %}
{{ constant_snippet }}
{% endfor %}
#endif
"""
# Compile each Jinja2 source exactly once so the rest of the script can
# render them; the unpacking order matches the order of the source tuple.
(
    test_Template,
    const_Template,
    container_Template,
    const_container_Template,
) = (
    jinja2.Template(source)
    for source in (test_str, const_str, container_str, const_container_str)
)
# Number of random test cases generated per element type.
num_tests = 5
# Optional global TensorFlow seed. The default None keeps the original
# behavior (fresh random data on every run). Set an int to make regenerated
# files reproducible.
seed = None
if seed is not None:
    tf.random.set_seed(seed)

tests = []      # rendered gtest bodies, one per generated case
constants = []  # rendered constant-array snippets, one per generated case

# Each entry: (numpy dtype used to cast the reference data,
#              C++ element type, uTensor type tag used in the test template)
for np_dtype, dtype, u_dtype in [
    (np.float32, "float", "flt"),
    (np.int8, "int8_t", "i8"),
    (np.int16, "int16_t", "i16"),
    (np.int32, "int32_t", "i32"),
]:
    for i in range(num_tests):
        # Standard-normal input of shape [1, 8, 8, 1]. tf.squeeze (no axis)
        # removes both size-1 dimensions.
        in_1 = tf.random.normal([1, 8, 8, 1])
        out_1 = tf.squeeze(in_1)
        # For the integer types, astype truncates the float samples toward
        # zero, so the reference data are small integers.
        in_flat = in_1.numpy().astype(np_dtype).flatten()
        out_flat = out_1.numpy().astype(np_dtype).flatten()
        # Arguments shared by both per-case templates. out_shape is passed
        # for completeness but is not referenced by either template.
        render_args = dict(
            test_name=f"{dtype}_{i}",
            input_size=in_flat.shape[0],
            out_size=out_flat.shape[0],
            ref_in=in_flat,
            ref_out=out_flat,
            in_shape=in_1.shape,
            out_shape=out_1.shape,
            dtype=dtype,
        )
        # Only the test template needs the uTensor type tag.
        tests.append(test_Template.render(u_dtype=u_dtype, **render_args))
        constants.append(const_Template.render(**render_args))
# Stitch the per-case snippets into the final gtest source and its constants
# header. Both files are written to paths relative to the current working
# directory.
container_rendered = container_Template.render(tests=tests, constants_header=const_file)
consts_container_rendered = const_container_Template.render(constants=constants, constants_header=const_file)

for path, text in (
    (output_file, container_rendered),
    (const_file, consts_container_rendered),
):
    # Explicit encoding so the output does not depend on the platform locale.
    with open(path, "w", encoding="utf-8") as fp:
        fp.write(text)

print("Complete")