-
Notifications
You must be signed in to change notification settings - Fork 247
Expand file tree
/
Copy pathgen_fc.py
More file actions
58 lines (52 loc) · 2.45 KB
/
gen_fc.py
File metadata and controls
58 lines (52 loc) · 2.45 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
from jinja_env import env2, Operator, Tensor, SingleOpTest
import tensorflow as tf
import numpy as np
# Group name handed to every SingleOpTest (presumably the gtest suite name --
# confirm against jinja_env.SingleOpTest).
test_group = "ReferenceFC"
# Number of randomized fully-connected tests generated by the __main__ driver.
num_tests = 5
# Generated gtest source file containing all rendered tests.
output_file = "test_fully_connected.cpp"
# Generated header holding the reference constant tensors; also passed to the
# templates as ``constants_header``.
const_file = "constants_fully_connected.hpp"
def gen_test(test_number, in_shape=(1, 2, 2, 64), out_features=512):
    """Generate one randomized FullyConnected reference test.

    A random input tensor is flattened to a single row. It is multiplied by the
    transpose of a random ``(out_features, in_features)`` weight matrix, and a
    bias is added. TensorFlow computes the reference output that the uTensor
    ``FullyConnectedOperator`` result is compared against.

    Args:
        test_number: Index of the test. It is used to build unique test and
            reference-constant names. Test 0 uses an all-zero bias; every other
            test uses a random bias.
        in_shape: Shape of the random input tensor before flattening. The
            default is the original hard-coded ``(1, 2, 2, 64)``.
        out_features: Number of output values (rows of the random weight
            matrix). The default is the original hard-coded 512.

    Returns:
        The ``(test_rendered, const_snippets)`` pair produced by
        ``SingleOpTest.render()``.
    """
    test_name = "random_gen_fc__%d" % test_number
    # The weight's inner dimension must match the flattened input length.
    in_features = int(np.prod(in_shape))

    in0 = tf.random.uniform(in_shape).numpy()
    w = tf.random.uniform([out_features, in_features]).numpy()
    if test_number == 0:
        b = np.zeros([out_features], dtype=np.float32)
    else:
        b = tf.random.uniform([1, out_features]).numpy().flatten()

    # Combine ops to behave like the final kernel: flatten the input to one
    # row, then compute in1 @ w^T + b.
    in1 = in0.reshape(1, -1)
    m = tf.linalg.matmul(in1, w, transpose_b=True)
    out_1 = tf.math.add(m, b).numpy()

    # The filter handed to the operator is the transposed weight, shape
    # (in_features, out_features).
    w_filter = np.transpose(w)

    in_ref_name = "s_ref_in_%d" % test_number
    w_ref_name = "s_ref_w_%d" % test_number
    b_ref_name = "s_ref_b_%d" % test_number
    out_ref_name = "s_ref_out_%d" % test_number

    in_t = Tensor("in", in1, ref_name=in_ref_name)
    w_t = Tensor("w", w_filter, ref_name=w_ref_name)
    b_t = Tensor("b", b, ref_name=b_ref_name)
    # Stored reference values the operator output is checked against.
    out_ref = Tensor("out_ref", out_1, ref_name=out_ref_name)
    # NOTE(review): out_1 is passed here too, presumably only to size the
    # output buffer -- confirm against jinja_env.Tensor.
    out_t = Tensor("out", out_1)

    param_str = "Fuseable::NoActivation<float>"
    op = Operator("FullyConnectedOperator", "fcOp", dtypes=["float"], param_str=param_str)
    op.set_inputs({"input": in_t, "filter": w_t, "bias": b_t}).set_outputs({"output": out_t})

    test = SingleOpTest(test_group, test_name, op)
    test.add_tensor_comparison(out_t, out_ref)
    test_rendered, const_snippets = test.render()
    return test_rendered, const_snippets
if __name__ == '__main__':
    # Generate num_tests randomized tests, then write the shared constants
    # header and the gtest source that includes it.
    tests = []
    const_snippets = []
    for i in range(num_tests):
        tr, cs = gen_test(i)
        tests.append(tr)
        const_snippets.extend(cs)

    # Render each template BEFORE opening its output file. open(..., "w")
    # truncates immediately, so rendering inside the ``with`` would leave an
    # empty file behind if the template raised.
    const_rendered = env2.get_template("const_container.hpp").render(
        constants=const_snippets, constants_header=const_file)
    with open(const_file, "w", encoding="utf-8") as fp:
        fp.write(const_rendered)

    gtest_rendered = env2.get_template("gtest_container.cpp").render(
        constants_header=const_file,
        using_directives=["using namespace uTensor::ReferenceOperators"],
        tests=tests)
    with open(output_file, "w", encoding="utf-8") as fp:
        fp.write(gtest_rendered)