Skip to content

Commit 72cd0b0

Browse files
In C++ shape inference, support registering a shape inference function in OP
registration. Change op registration signature to return Status and return the registration data as an out parameter. Add a shape inference function and test for AddN. Support in function library needs to be expanded in a future change. Change: 124871850
1 parent 406c7d9 commit 72cd0b0

31 files changed

Lines changed: 962 additions & 385 deletions

tensorflow/core/BUILD

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -543,6 +543,7 @@ cc_library(
543543
"common_runtime/kernel_benchmark_testlib.h",
544544
"framework/fake_input.h",
545545
"framework/function_testlib.h",
546+
"framework/shape_inference_testutil.h",
546547
"framework/tensor_testutil.h",
547548
"graph/testlib.h",
548549
# TODO(josh11b): Drop this once users are depending on
@@ -559,6 +560,7 @@ cc_library(
559560
":lib",
560561
":proto_text",
561562
":protos_all_cc",
563+
":shape_inference_testutil",
562564
":tensor_testutil",
563565
":test",
564566
"//tensorflow/core/kernels:constant_op",
@@ -748,6 +750,8 @@ filegroup(
748750
srcs = [
749751
"//tensorflow/core:framework/fake_input.cc",
750752
"//tensorflow/core:framework/fake_input.h",
753+
"//tensorflow/core:framework/shape_inference_testutil.cc",
754+
"//tensorflow/core:framework/shape_inference_testutil.h",
751755
"//tensorflow/core:framework/tensor_testutil.cc",
752756
"//tensorflow/core:framework/tensor_testutil.h",
753757
"//tensorflow/core:platform/test.h",
@@ -1212,6 +1216,19 @@ cc_library(
12121216
],
12131217
)
12141218

1219+
cc_library(
1220+
name = "shape_inference_testutil",
1221+
testonly = 1,
1222+
srcs = ["framework/shape_inference_testutil.cc"],
1223+
hdrs = ["framework/shape_inference_testutil.h"],
1224+
copts = tf_copts(),
1225+
deps = [
1226+
":framework",
1227+
":lib",
1228+
":test",
1229+
],
1230+
)
1231+
12151232
# Main program for tests
12161233
cc_library(
12171234
name = "test_main",

tensorflow/core/common_runtime/function.cc

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -301,9 +301,7 @@ FunctionLibraryRuntimeImpl::FunctionLibraryRuntimeImpl(
301301
lib_def_(lib_def),
302302
optimizer_(optimizer_options) {
303303
get_func_sig_ = [this](const string& op, const OpDef** sig) {
304-
Status s;
305-
*sig = lib_def_->LookUp(op, &s);
306-
return s;
304+
return lib_def_->LookUpOpDef(op, sig);
307305
};
308306
create_kernel_ = [this](const NodeDef& ndef, OpKernel** kernel) {
309307
return CreateKernel(ndef, kernel);
@@ -689,9 +687,9 @@ void FunctionLibraryRuntimeImpl::Run(const Options& opts, Handle handle,
689687
}
690688

691689
bool FunctionLibraryRuntimeImpl::IsStateful(const string& func) {
692-
Status s;
693-
auto sig = lib_def_->LookUp(func, &s);
694-
return s.ok() && sig->is_stateful();
690+
const OpDef* op_def;
691+
const Status s = lib_def_->LookUpOpDef(func, &op_def);
692+
return s.ok() && op_def->is_stateful();
695693
}
696694

697695
FunctionLibraryRuntime* NewFunctionLibraryRuntime(

tensorflow/core/common_runtime/function_test.cc

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -38,9 +38,7 @@ namespace tensorflow {
3838
typedef FunctionDefHelper FDH;
3939

4040
Status GetOpSig(const string& op, const OpDef** sig) {
41-
Status s;
42-
*sig = OpRegistry::Global()->LookUp(op, &s);
43-
return s;
41+
return OpRegistry::Global()->LookUpOpDef(op, sig);
4442
}
4543

4644
void FunctionTestSchedClosure(std::function<void()> fn) {

tensorflow/core/framework/function.cc

Lines changed: 23 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -708,14 +708,19 @@ Status FunctionCallFrame::SetRetval(int index, const Tensor& val) {
708708

709709
FunctionLibraryDefinition::FunctionLibraryDefinition(
710710
const FunctionLibraryDefinition& other)
711-
: function_defs_(other.function_defs_), func_grad_(other.func_grad_) {}
711+
: func_grad_(other.func_grad_) {
712+
for (const auto& it : other.function_defs_) {
713+
TF_CHECK_OK(AddFunctionDef(it.second->fdef));
714+
}
715+
}
712716

713717
FunctionLibraryDefinition::FunctionLibraryDefinition(
714718
const FunctionDefLibrary& def_lib)
715719
: function_defs_(def_lib.function_size()) {
716720
for (const auto& fdef : def_lib.function()) {
717721
// The latter function definition wins.
718-
function_defs_[fdef.signature().name()] = fdef;
722+
auto& ptr = function_defs_[fdef.signature().name()];
723+
ptr.reset(new FunctionDefAndOpRegistration(fdef));
719724
}
720725
for (const auto& grad : def_lib.gradient()) {
721726
func_grad_[grad.function_name()] = grad.gradient_func();
@@ -729,36 +734,39 @@ const FunctionDef* FunctionLibraryDefinition::Find(const string& name) const {
729734
if (iter == function_defs_.end()) {
730735
return nullptr;
731736
} else {
732-
return &iter->second;
737+
return &iter->second->fdef;
733738
}
734739
}
735740

736741
Status FunctionLibraryDefinition::AddFunctionDef(const FunctionDef& fdef) {
737-
if (!function_defs_.insert({fdef.signature().name(), fdef}).second) {
742+
auto& ptr = function_defs_[fdef.signature().name()];
743+
if (ptr != nullptr) {
738744
return errors::InvalidArgument("Function with name: ",
739745
fdef.signature().name(),
740746
" already exists in function library.");
741747
}
748+
ptr.reset(new FunctionDefAndOpRegistration(fdef));
742749
return Status::OK();
743750
}
744751

745752
string FunctionLibraryDefinition::FindGradient(const string& func) const {
746753
return gtl::FindWithDefault(func_grad_, func, "");
747754
}
748755

749-
const OpDef* FunctionLibraryDefinition::LookUp(const string& op,
750-
Status* status) const {
751-
auto fdef = Find(op);
752-
if (fdef != nullptr) {
753-
return &(fdef->signature());
756+
Status FunctionLibraryDefinition::LookUp(
757+
const string& op, const OpRegistrationData** op_reg_data) const {
758+
auto iter = function_defs_.find(op);
759+
if (iter != function_defs_.end()) {
760+
*op_reg_data = &iter->second->op_registration_data;
761+
return Status::OK();
754762
}
755-
return OpRegistry::Global()->LookUp(op, status);
763+
return OpRegistry::Global()->LookUp(op, op_reg_data);
756764
}
757765

758766
FunctionDefLibrary FunctionLibraryDefinition::ToProto() const {
759767
FunctionDefLibrary lib;
760768
for (const auto& f : function_defs_) {
761-
*lib.add_function() = f.second;
769+
*lib.add_function() = f.second->fdef;
762770
}
763771
for (const auto& g : func_grad_) {
764772
GradientDef* gd = lib.add_gradient();
@@ -845,7 +853,10 @@ FunctionDef FunctionDefHelper::Define(const string& name,
845853
for (const auto& a : arg_def) b.Input(a);
846854
for (const auto& r : ret_def) b.Output(r);
847855
for (const auto& a : attr_def) b.Attr(a);
848-
TF_CHECK_OK(b.Finalize(fdef.mutable_signature()));
856+
857+
OpRegistrationData op_reg_data;
858+
TF_CHECK_OK(b.Finalize(&op_reg_data));
859+
fdef.mutable_signature()->Swap(&op_reg_data.op_def);
849860
for (const auto& n : node_def) {
850861
*(fdef.add_node()) = n.ToProto();
851862
}

tensorflow/core/framework/function.h

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -277,14 +277,25 @@ class FunctionLibraryDefinition : public OpRegistryInterface {
277277
//
278278
// If "op" is defined in the library, returns its signature.
279279
// Otherwise, assume "op" is a primitive op and returns its op
280-
// signature.
281-
const OpDef* LookUp(const string& op, Status* status) const override;
280+
// signature and shape inference function.
281+
Status LookUp(const string& op_type_name,
282+
const OpRegistrationData** op_reg_data) const override;
282283

283284
// Returns a proto representation of the state of this function library.
284285
FunctionDefLibrary ToProto() const;
285286

286287
private:
287-
std::unordered_map<string, FunctionDef> function_defs_;
288+
// TODO(cwhipkey): support shape functions in FunctionDefLibrary.
289+
struct FunctionDefAndOpRegistration {
290+
FunctionDefAndOpRegistration(const FunctionDef& fdef_in)
291+
: fdef(fdef_in), op_registration_data(fdef.signature()) {}
292+
293+
FunctionDef fdef;
294+
OpRegistrationData op_registration_data;
295+
};
296+
297+
std::unordered_map<string, std::unique_ptr<FunctionDefAndOpRegistration>>
298+
function_defs_;
288299
std::unordered_map<string, string> func_grad_;
289300
};
290301

tensorflow/core/framework/function_test.cc

Lines changed: 15 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -33,9 +33,7 @@ namespace tensorflow {
3333
typedef FunctionDefHelper FDH;
3434

3535
Status GetOpSig(const string& op, const OpDef** sig) {
36-
Status s;
37-
*sig = OpRegistry::Global()->LookUp(op, &s);
38-
return s;
36+
return OpRegistry::Global()->LookUpOpDef(op, sig);
3937
}
4038

4139
REGISTER_OP("One")
@@ -643,12 +641,12 @@ TEST(FunctionLibraryDefinitionTest, LookUp) {
643641
*proto.add_function() = test::function::XTimesTwo();
644642
FunctionLibraryDefinition lib_def(proto);
645643

646-
Status s;
647-
EXPECT_EQ(lib_def.LookUp("XTimes16", &s), nullptr);
644+
const OpDef* op_def;
645+
EXPECT_TRUE(!lib_def.LookUpOpDef("XTimes16", &op_def).ok());
648646

649-
auto found = lib_def.LookUp("XTimesTwo", &s);
650-
ASSERT_NE(found, nullptr);
651-
EXPECT_EQ(found->DebugString(),
647+
TF_EXPECT_OK(lib_def.LookUpOpDef("XTimesTwo", &op_def));
648+
ASSERT_NE(op_def, nullptr);
649+
EXPECT_EQ(op_def->DebugString(),
652650
test::function::XTimesTwo().signature().DebugString());
653651
}
654652

@@ -662,14 +660,15 @@ TEST(FunctionLibraryDefinitionTest, AddFunctionDef) {
662660
TF_EXPECT_OK(lib_def.AddFunctionDef(test::function::WXPlusB()));
663661

664662
// Test lookup of first function.
665-
Status s;
666-
auto first = lib_def.LookUp("XTimesTwo", &s);
663+
const OpDef* first;
664+
TF_EXPECT_OK(lib_def.LookUpOpDef("XTimesTwo", &first));
667665
ASSERT_NE(first, nullptr);
668666
EXPECT_EQ(first->DebugString(),
669667
test::function::XTimesTwo().signature().DebugString());
670668

671669
// Test lookup of second function.
672-
auto second = lib_def.LookUp("WXPlusB", &s);
670+
const OpDef* second;
671+
TF_EXPECT_OK(lib_def.LookUpOpDef("WXPlusB", &second));
673672
ASSERT_NE(second, nullptr);
674673
EXPECT_EQ(second->DebugString(),
675674
test::function::WXPlusB().signature().DebugString());
@@ -689,18 +688,14 @@ TEST(FunctionLibraryDefinitionTest, ToProto) {
689688
FunctionLibraryDefinition lib_def2(proto2);
690689

691690
// Test that the first function exists in both libraries.
692-
Status s;
693-
auto f1 = lib_def1.LookUp("XTimesTwo", &s);
694-
TF_EXPECT_OK(s);
695-
auto f2 = lib_def1.LookUp("XTimesTwo", &s);
696-
TF_EXPECT_OK(s);
691+
const OpDef *f1, *f2, *f3, *f4;
692+
TF_EXPECT_OK(lib_def1.LookUpOpDef("XTimesTwo", &f1));
693+
TF_EXPECT_OK(lib_def2.LookUpOpDef("XTimesTwo", &f2));
697694
EXPECT_EQ(f1->DebugString(), f2->DebugString());
698695

699696
// Test that the second function exists in both libraries.
700-
auto f3 = lib_def1.LookUp("WXPlusB", &s);
701-
TF_EXPECT_OK(s);
702-
auto f4 = lib_def1.LookUp("WXPlusB", &s);
703-
TF_EXPECT_OK(s);
697+
TF_EXPECT_OK(lib_def1.LookUpOpDef("WXPlusB", &f3));
698+
TF_EXPECT_OK(lib_def2.LookUpOpDef("WXPlusB", &f4));
704699
EXPECT_EQ(f3->DebugString(), f4->DebugString());
705700
}
706701

tensorflow/core/framework/graph_def_util.cc

Lines changed: 13 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -56,17 +56,14 @@ Status AddDefaultAttrsToGraphDef(GraphDef* graph_def,
5656
node_offset, " with total nodes in graph: ", graph_def->node_size());
5757
}
5858

59-
Status s;
6059
for (int i = node_offset; i < graph_def->node_size(); ++i) {
6160
NodeDef* node_def = graph_def->mutable_node(i);
62-
const OpDef* op_def = op_registry.LookUp(node_def->op(), &s);
63-
if (!s.ok()) {
64-
return s;
65-
}
61+
const OpDef* op_def;
62+
TF_RETURN_IF_ERROR(op_registry.LookUpOpDef(node_def->op(), &op_def));
6663
AddDefaultsToNodeDef(*op_def, node_def);
6764
}
6865

69-
return s;
66+
return Status::OK();
7067
}
7168

7269
Status RemoveNewDefaultAttrsFromGraphDef(
@@ -77,12 +74,13 @@ Status RemoveNewDefaultAttrsFromGraphDef(
7774
std::vector<string> to_remove;
7875
for (int n = 0; n < graph_def->node_size(); ++n) {
7976
NodeDef* node_def = graph_def->mutable_node(n);
80-
const OpDef* producer_op_def =
81-
producer_op_registry.LookUp(node_def->op(), &s);
82-
if (!s.ok()) return s;
83-
const OpDef* consumer_op_def =
84-
consumer_op_registry.LookUp(node_def->op(), &s);
85-
if (!s.ok()) return s;
77+
const OpDef* producer_op_def;
78+
const OpDef* consumer_op_def;
79+
80+
TF_RETURN_IF_ERROR(
81+
producer_op_registry.LookUpOpDef(node_def->op(), &producer_op_def));
82+
TF_RETURN_IF_ERROR(
83+
consumer_op_registry.LookUpOpDef(node_def->op(), &consumer_op_def));
8684

8785
for (const auto& attr : node_def->attr()) {
8886
// If the attr is not in consumer_op_def and doesn't start with '_'...
@@ -172,13 +170,12 @@ Status StrippedOpListForGraph(const GraphDef& graph_def,
172170
OpsUsedByGraph(graph_def, &used_ops);
173171

174172
// Build the stripped op list in sorted order, ignoring functions.
175-
Status status;
176173
stripped_op_list->clear_op();
177174
for (const string& op_name : used_ops) {
178-
const OpDef* op = op_registry.LookUp(op_name, &status);
179-
if (!op) return status;
175+
const OpDef* op_def;
176+
TF_RETURN_IF_ERROR(op_registry.LookUpOpDef(op_name, &op_def));
180177
OpDef* stripped_op = stripped_op_list->add_op();
181-
stripped_op->CopyFrom(*op);
178+
stripped_op->CopyFrom(*op_def);
182179
RemoveDescriptionsFromOpDef(stripped_op);
183180
}
184181
return Status::OK();

0 commit comments

Comments
 (0)