
Commit eba61a3

tensorflower-gardener authored and committed
Improve the templated TF Lite shim example. Previously, users had to write their own wrapper class to delegate to the correctly typed op object. With this change, the OpWrapper is created for the user, who only needs to pass in the correct attribute types.
PiperOrigin-RevId: 503260524
1 parent e0c10b8 commit eba61a3
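The effect of the change is easiest to see in tmpl_tflite_op.cc below: registering a templated op now only requires declaring the attribute names and the types each attribute may take. The following is a minimal sketch of that pattern, restated with fully qualified names; the alias and constant names WrappedTmplOp, kAType, and kBType are illustrative, the OpWrapper/Attr/AttrName interface is assumed to behave as its use in the diff suggests, and the new tflite_op_wrapper.h itself is not part of this excerpt.

// Sketch of the new registration pattern (mirrors tmpl_tflite_op.cc in the diff below).
#include <cstdint>

#include "tensorflow/lite/kernels/shim/test_op/tmpl_op.h"
#include "tensorflow/lite/kernels/shim/tflite_op_shim.h"
#include "tensorflow/lite/kernels/shim/tflite_op_wrapper.h"
#include "tensorflow/lite/mutable_op_resolver.h"

namespace {
// Attribute names get static storage so they can be used as template
// arguments to AttrName.
const char kAType[] = "AType";
const char kBType[] = "BType";
}  // namespace

// The wrapper instantiates TmplOp for every combination of the listed
// attribute types (AType in {int32, float}, BType in {int32, int64, bool})
// and dispatches to the matching instantiation when the op is initialized.
template <tflite::shim::Runtime Rt>
using WrappedTmplOp = tflite::shim::op_wrapper::OpWrapper<
    Rt, tflite::shim::TmplOp,
    tflite::shim::op_wrapper::Attr<tflite::shim::op_wrapper::AttrName<kAType>,
                                   int32_t, float>,
    tflite::shim::op_wrapper::Attr<tflite::shim::op_wrapper::AttrName<kBType>,
                                   int32_t, int64_t, bool>>;

using OpKernel = tflite::shim::TfLiteOpKernel<WrappedTmplOp>;

// Registration is unchanged: clients still call Add() on a MutableOpResolver.
void AddTmplOp(tflite::MutableOpResolver* resolver) { OpKernel::Add(resolver); }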

7 files changed (+1014, −127 lines)


tensorflow/lite/kernels/shim/BUILD

Lines changed: 27 additions & 0 deletions
@@ -193,6 +193,33 @@ cc_library(
     ],
 )
 
+cc_library(
+    name = "tflite_op_wrapper",
+    hdrs = ["tflite_op_wrapper.h"],
+    visibility = ["//visibility:public"],
+    deps = [
+        ":op_kernel",
+        ":status_macros",
+        "//tensorflow/lite:type_to_tflitetype",
+        "//tensorflow/lite/c:common",
+    ],
+)
+
+cc_test(
+    name = "tflite_op_wrapper_test",
+    srcs = ["tflite_op_wrapper_test.cc"],
+    deps = [
+        ":op_kernel",
+        ":tflite_op_shim",
+        ":tflite_op_wrapper",
+        "//tensorflow/core/platform:tstring",
+        "@com_google_absl//absl/status",
+        "@com_google_absl//absl/strings",
+        "@com_google_googletest//:gtest_main",
+        "@flatbuffers//:public_headers_lib",
+    ],
+)
+
 cc_library(
     name = "status_macros",
     hdrs = ["status_macros.h"],
tensorflow/lite/kernels/shim/test_op/BUILD

Lines changed: 4 additions & 4 deletions
@@ -103,12 +103,11 @@ tf_kernel_library(
     hdrs = ["tmpl_tflite_op.h"],
     deps = [
         ":tmpl_op",
-        "//tensorflow/core:framework",
-        "//tensorflow/core:protos_all_cc",
         "//tensorflow/lite:mutable_op_resolver",
         "//tensorflow/lite/c:common",
+        "//tensorflow/lite/core/c:common",
         "//tensorflow/lite/kernels/shim:tflite_op_shim",
-        "@com_google_absl//absl/types:variant",
+        "//tensorflow/lite/kernels/shim:tflite_op_wrapper",
     ],
 )
 
@@ -125,7 +124,8 @@ tf_cc_test(
         ":tmpl_tf_op",
         "//tensorflow/core:framework",
         "//tensorflow/core:protos_all_cc",
-        "//tensorflow/core:testlib",
+        "//tensorflow/core/framework:fake_input",
+        "//tensorflow/core/framework:tensor_testutil",
         "//tensorflow/core/kernels:ops_testutil",
         "@com_google_googletest//:gtest_main",
     ],

tensorflow/lite/kernels/shim/test_op/tmpl_tflite_op.cc

Lines changed: 13 additions & 2 deletions
@@ -17,13 +17,24 @@ limitations under the License.
 #include "tensorflow/lite/c/common.h"
 #include "tensorflow/lite/kernels/shim/test_op/tmpl_op.h"
 #include "tensorflow/lite/kernels/shim/tflite_op_shim.h"
+#include "tensorflow/lite/kernels/shim/tflite_op_wrapper.h"
 
 namespace tflite {
 namespace ops {
 namespace custom {
+namespace {
+const char a_type[]("AType"), b_type[]("BType");
+}  // namespace
 
-using OpKernel = ::tflite::shim::TfLiteOpKernel<
-    tflite::ops::custom::tmpl_tflite_op::OpWrapper>;
+using ::tflite::shim::op_wrapper::Attr;
+using ::tflite::shim::op_wrapper::AttrName;
+using ::tflite::shim::op_wrapper::OpWrapper;
+
+template <shim::Runtime Rt>
+using Op = OpWrapper<Rt, shim::TmplOp, Attr<AttrName<a_type>, int32_t, float>,
+                     Attr<AttrName<b_type>, int32_t, int64_t, bool>>;
+
+using OpKernel = ::tflite::shim::TfLiteOpKernel<Op>;
 
 void AddTmplOp(MutableOpResolver* resolver) { OpKernel::Add(resolver); }
 
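For intuition, the Attr/AttrName specification above admits the full cross-product of the listed types, the same family of TmplOp instantiations that the hand-written wrapper removed from tmpl_tflite_op.h below enumerated by hand, extended here with bool as a BType. A sketch of that type set follows, assuming only the TmplOp template from tmpl_op.h; the alias name TmplOpVariants is illustrative, and how tflite_op_wrapper.h actually stores and selects the instantiations is not shown in this excerpt.

#include <cstdint>
#include <variant>

#include "tensorflow/lite/kernels/shim/test_op/tmpl_op.h"

// The 2 x 3 cross-product of instantiations implied by
// Attr<AttrName<a_type>, int32_t, float> and
// Attr<AttrName<b_type>, int32_t, int64_t, bool>.
template <tflite::shim::Runtime Rt>
using TmplOpVariants =
    std::variant<tflite::shim::TmplOp<Rt, int32_t, int32_t>,
                 tflite::shim::TmplOp<Rt, int32_t, int64_t>,
                 tflite::shim::TmplOp<Rt, int32_t, bool>,
                 tflite::shim::TmplOp<Rt, float, int32_t>,
                 tflite::shim::TmplOp<Rt, float, int64_t>,
                 tflite::shim::TmplOp<Rt, float, bool>>;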

tensorflow/lite/kernels/shim/test_op/tmpl_tflite_op.h

Lines changed: 1 addition & 107 deletions
@@ -15,118 +15,12 @@ limitations under the License.
 #ifndef TENSORFLOW_LITE_KERNELS_SHIM_TEST_OP_TMPL_TFLITE_OP_H_
 #define TENSORFLOW_LITE_KERNELS_SHIM_TEST_OP_TMPL_TFLITE_OP_H_
 
-#include <memory>
-#include <string>
-#include <variant>
-#include <vector>
-
-#include "absl/types/variant.h"
-#include "tensorflow/core/framework/types.pb.h"
-#include "tensorflow/lite/c/common.h"
-#include "tensorflow/lite/kernels/shim/op_kernel.h"
-#include "tensorflow/lite/kernels/shim/status_macros.h"
-#include "tensorflow/lite/kernels/shim/test_op/tmpl_op.h"
+#include "tensorflow/lite/core/c/common.h"
 #include "tensorflow/lite/mutable_op_resolver.h"
 
 namespace tflite {
 namespace ops {
 namespace custom {
-namespace tmpl_tflite_op {
-
-using ::tensorflow::DT_FLOAT;
-using ::tensorflow::DT_INT32;
-using ::tensorflow::DT_INT64;
-using ::tflite::shim::OpKernelShim;
-using ::tflite::shim::Runtime;
-using ::tflite::shim::TmplOp;
-
-template <Runtime Rt>
-class OpWrapper : public OpKernelShim<OpWrapper, Rt> {
-  // Atype: int32 or float
-  // Btype: int32 or int64
-  using TmplOpType =
-      std::variant<TmplOp<Rt, int32_t, int32_t>, TmplOp<Rt, int32_t, int64_t>,
-                   TmplOp<Rt, float, int32_t>, TmplOp<Rt, float, int64_t>>;
-  using TmplOpType0 = typename std::variant_alternative<0, TmplOpType>::type;
-
- public:
-  using typename OpKernelShim<OpWrapper, Rt>::InitContext;
-  using typename OpKernelShim<OpWrapper, Rt>::InvokeContext;
-  using typename OpKernelShim<OpWrapper, Rt>::ShapeInferenceContext;
-  OpWrapper() = default;
-
-  // These two char*s should be copied from the wrapped op.
-  static constexpr char kOpName[] = "TemplatizedOperation";
-  static constexpr char kDoc[] = R"doc(
-Description:
-  Templatized op for testing and demonstration purposes.
-
-Attrs
-  AType: The type for input0
-  BType: The type for input1
-Inputs
-  in0: AType, shape=[] - A scalar input
-  in1: BType, shape=[] - A scalar input
-Outputs
-  out0: int, shape=[] - first output
-)doc";
-
-  static const char* OpName() { return kOpName; }
-  static const char* Doc() { return kDoc; }
-
-  // For the static methods, they shouldn't change based on the types.
-  static std::vector<std::string> Attrs() { return TmplOpType0::Attrs(); }
-  static std::vector<std::string> Inputs() { return TmplOpType0::Inputs(); }
-  static std::vector<std::string> Outputs() { return TmplOpType0::Outputs(); }
-  static absl::Status ShapeInference(ShapeInferenceContext* context) {
-    return TmplOpType0::ShapeInference(context);
-  }
-
-  // Init should create the correctly typed wrapped object.
-  absl::Status Init(InitContext* context) {
-    int64_t datatype_a, datatype_b;
-    SH_RETURN_IF_ERROR(context->GetAttr("AType", &datatype_a));
-    SH_RETURN_IF_ERROR(context->GetAttr("BType", &datatype_b));
-    if (datatype_a == DT_INT32 && datatype_b == DT_INT32) {
-      op_ = std::make_unique<TmplOpType>(TmplOp<Rt, int32_t, int32_t>());
-      type_num_ = 0;
-      return std::get<0>(*op_).Init(context);
-    } else if (datatype_a == DT_INT32 && datatype_b == DT_INT64) {
-      op_ = std::make_unique<TmplOpType>(TmplOp<Rt, int32_t, int64_t>());
-      type_num_ = 1;
-      return std::get<1>(*op_).Init(context);
-    } else if (datatype_a == DT_FLOAT && datatype_b == DT_INT32) {
-      op_ = std::make_unique<TmplOpType>(TmplOp<Rt, float, int32_t>());
-      type_num_ = 2;
-      return std::get<2>(*op_).Init(context);
-    } else if (datatype_a == DT_FLOAT && datatype_b == DT_INT64) {
-      op_ = std::make_unique<TmplOpType>(TmplOp<Rt, float, int64_t>());
-      type_num_ = 3;
-      return std::get<3>(*op_).Init(context);
-    }
-    return absl::InvalidArgumentError("Attribute is of wrong type.");
-  }
-
-  // Call invoke on the created wrapped object.
-  absl::Status Invoke(InvokeContext* context) {
-    if (type_num_ == 0) {
-      return std::get<0>(*op_).Invoke(context);
-    } else if (type_num_ == 1) {
-      return std::get<1>(*op_).Invoke(context);
-    } else if (type_num_ == 2) {
-      return std::get<2>(*op_).Invoke(context);
-    } else if (type_num_ == 3) {
-      return std::get<3>(*op_).Invoke(context);
-    }
-    return absl::InternalError("Unknown type.");
-  }
-
- protected:
-  std::unique_ptr<TmplOpType> op_;
-  int type_num_;
-};
-
-}  // namespace tmpl_tflite_op
 
 // Add TmplOp to the resolver
 void AddTmplOp(MutableOpResolver* resolver);

tensorflow/lite/kernels/shim/test_op/tmpl_tflite_op_test.cc

Lines changed: 29 additions & 14 deletions
@@ -19,20 +19,12 @@ limitations under the License.
 
 #include <gtest/gtest.h>
 #include "flatbuffers/flexbuffers.h"  // from @flatbuffers
-#include "tensorflow/core/framework/fake_input.h"
-#include "tensorflow/core/framework/node_def_builder.h"
-#include "tensorflow/core/framework/tensor_testutil.h"
-#include "tensorflow/core/framework/types.pb.h"
 #include "tensorflow/lite/kernels/test_util.h"
 
 namespace tflite {
 namespace shim {
 namespace {
 
-using ::tensorflow::DT_FLOAT;
-using ::tensorflow::DT_INT32;
-using ::tensorflow::DT_INT64;
-
 template <typename AType, typename BType>
 class TmplOpModel : public SingleOpModel {
  public:
@@ -79,8 +71,8 @@ TEST(TmplOpModel, float_int32) {
   // Test input
   flexbuffers::Builder builder;
   builder.Map([&]() {
-    builder.Int("AType", DT_FLOAT);
-    builder.Int("BType", DT_INT32);
+    builder.Int("AType", kTfLiteFloat32);
+    builder.Int("BType", kTfLiteInt32);
   });
   builder.Finish();
   std::vector<std::vector<int>> input_shapes = {{}, {}};
@@ -94,16 +86,16 @@ TEST(TmplOpModel, float_int32) {
       /*op_options=*/builder.GetBuffer(), input_types, input_shapes, input0,
       input1, output_types);
   ASSERT_EQ(m.Invoke(), kTfLiteOk);
-  // // Assertions
+  // Assertions
   EXPECT_THAT(m.GetOutput<float>(0), testing::ElementsAre(8.6f));
 }
 
 TEST(TmplOpModel, int32_int64) {
   // Test input
   flexbuffers::Builder builder;
   builder.Map([&]() {
-    builder.Int("AType", DT_INT32);
-    builder.Int("BType", DT_INT64);
+    builder.Int("AType", kTfLiteInt32);
+    builder.Int("BType", kTfLiteInt64);
   });
   builder.Finish();
   std::vector<std::vector<int>> input_shapes = {{}, {}};
@@ -117,10 +109,33 @@ TEST(TmplOpModel, int32_int64) {
       /*op_options=*/builder.GetBuffer(), input_types, input_shapes, input0,
       input1, output_types);
   ASSERT_EQ(m.Invoke(), kTfLiteOk);
-  // // Assertions
+  // Assertions
   EXPECT_THAT(m.GetOutput<float>(0), testing::ElementsAre(45.0f));
 }
 
+TEST(TmplOpModel, int32_bool) {
+  // Test input
+  flexbuffers::Builder builder;
+  builder.Map([&]() {
+    builder.Int("AType", kTfLiteInt32);
+    builder.Int("BType", kTfLiteBool);
+  });
+  builder.Finish();
+  std::vector<std::vector<int>> input_shapes = {{}, {}};
+  std::vector<tflite::TensorType> input_types = {tflite::TensorType_INT32,
+                                                 tflite::TensorType_BOOL};
+  std::vector<tflite::TensorType> output_types = {tflite::TensorType_FLOAT32};
+  const std::vector<int32_t> input0 = {12};
+  const std::vector<bool> input1 = {true};
+  // Run the op
+  TmplOpModel<int32_t, bool> m(
+      /*op_options=*/builder.GetBuffer(), input_types, input_shapes, input0,
+      input1, output_types);
+  ASSERT_EQ(m.Invoke(), kTfLiteOk);
+  // Assertions
+  EXPECT_THAT(m.GetOutput<float>(0), testing::ElementsAre(13.0f));
+}
+
 }  // namespace
 }  // namespace shim
 }  // namespace tflite
