Skip to content

Commit ebc1a2f

Browse files
Make IsSimplifiableReshape return Status instead of bool.
This is to allow remove `CHECK`-fails in subsequent commits. PiperOrigin-RevId: 409160987 Change-Id: I3f050218a3832271395c4372a0b8ea05f1c03d80
1 parent 71399f9 commit ebc1a2f

File tree

2 files changed

+21
-11
lines changed

2 files changed

+21
-11
lines changed

tensorflow/core/grappler/optimizers/constant_folding.cc

Lines changed: 19 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1684,15 +1684,17 @@ Status ConstantFolding::FoldGraph(
16841684
return Status::OK();
16851685
}
16861686

1687-
bool ConstantFolding::IsSimplifiableReshape(
1687+
Status ConstantFolding::IsSimplifiableReshape(
16881688
const NodeDef& node, const GraphProperties& properties) const {
16891689
if (!IsReshape(node)) {
1690-
return false;
1690+
return errors::Internal("Node ", node.name(), " is not a Reshape node");
16911691
}
16921692
CHECK_LE(2, node.input_size());
16931693
const NodeDef* new_shape = node_map_->GetNode(node.input(1));
16941694
if (!IsReallyConstant(*new_shape)) {
1695-
return false;
1695+
return errors::Internal("Node ", node.name(), " has shape ",
1696+
new_shape->DebugString(),
1697+
" which is not a constant");
16961698
}
16971699
TensorVector outputs;
16981700
auto outputs_cleanup = gtl::MakeCleanup([&outputs] {
@@ -1703,22 +1705,25 @@ bool ConstantFolding::IsSimplifiableReshape(
17031705

17041706
Status s = EvaluateNode(*new_shape, TensorVector(), &outputs);
17051707
if (!s.ok()) {
1706-
return false;
1708+
return errors::Internal("Could not evaluate node ", node.name());
17071709
}
17081710
CHECK_EQ(1, outputs.size());
17091711

17101712
const std::vector<OpInfo::TensorProperties>& props =
17111713
properties.GetInputProperties(node.name());
17121714
if (props.empty()) {
1713-
return false;
1715+
return errors::Internal("Node ", node.name(), " has no properties");
17141716
}
17151717
const OpInfo::TensorProperties& prop = props[0];
17161718
if (prop.dtype() == DT_INVALID) {
1717-
return false;
1719+
return errors::Internal("Node ", node.name(), " has property ",
1720+
prop.DebugString(), " with invalid dtype");
17181721
}
17191722
const PartialTensorShape shape(prop.shape());
17201723
if (!shape.IsFullyDefined()) {
1721-
return false;
1724+
return errors::Internal("Node ", node.name(), " has property ",
1725+
prop.DebugString(), " with shape ",
1726+
shape.DebugString(), " which is not fully defined");
17221727
}
17231728

17241729
PartialTensorShape new_dims;
@@ -1738,7 +1743,12 @@ bool ConstantFolding::IsSimplifiableReshape(
17381743
TF_CHECK_OK(TensorShapeUtils::MakeShape(shp, &new_dims));
17391744
}
17401745

1741-
return shape.IsCompatibleWith(new_dims);
1746+
if (!shape.IsCompatibleWith(new_dims)) {
1747+
return errors::Internal("Expected shape ", shape.DebugString(),
1748+
"to be compatible with ", new_dims.DebugString());
1749+
}
1750+
1751+
return Status::OK();
17421752
}
17431753

17441754
#define IS_VALUE_CASE(DTYPE, VALUE) \
@@ -2925,7 +2935,7 @@ bool ConstantFolding::SimplifyReduction(GraphDef* optimized_graph,
29252935
bool ConstantFolding::SimplifyReshape(const GraphProperties& properties,
29262936
bool use_shape_info, NodeDef* node) {
29272937
if (!use_shape_info || node->attr().count("T") == 0 ||
2928-
!IsSimplifiableReshape(*node, properties)) {
2938+
!IsSimplifiableReshape(*node, properties).ok()) {
29292939
return false;
29302940
}
29312941
DataType output_type = node->attr().at("T").type();

tensorflow/core/grappler/optimizers/constant_folding.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -129,8 +129,8 @@ class ConstantFolding : public GraphOptimizer {
129129
Status FoldGraph(const GraphProperties& properties, GraphDef* output,
130130
absl::flat_hash_set<string>* nodes_to_not_simplify);
131131

132-
bool IsSimplifiableReshape(const NodeDef& node,
133-
const GraphProperties& properties) const;
132+
Status IsSimplifiableReshape(const NodeDef& node,
133+
const GraphProperties& properties) const;
134134
Status SimplifyGraph(GraphDef* optimized_graph, GraphProperties* properties,
135135
absl::flat_hash_set<string>* nodes_to_not_simplify);
136136
Status SimplifyNode(NodeDef* node, GraphDef* optimized_graph,

0 commit comments

Comments
 (0)