Skip to content

Commit e3c630d

Browse files
authored
Merge branch 'r1.7' into update_readme
2 parents e6c83df + 6485bb7 commit e3c630d

File tree

1 file changed

+24
-2
lines changed

1 file changed

+24
-2
lines changed

tensorflow/core/graph/mkl_layout_pass.cc

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2492,10 +2492,10 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
24922492
mkl_op_registry::GetMklOpName(csinfo_.identity),
24932493
CopyAttrsDataType, AlwaysRewrite});
24942494
rinfo_.push_back({csinfo_.lrn, mkl_op_registry::GetMklOpName(csinfo_.lrn),
2495-
CopyAttrsLRN, AlwaysRewrite});
2495+
CopyAttrsLRN, LrnRewrite});
24962496
rinfo_.push_back({csinfo_.lrn_grad,
24972497
mkl_op_registry::GetMklOpName(csinfo_.lrn_grad),
2498-
CopyAttrsLRN, AlwaysRewrite});
2498+
CopyAttrsLRN, LrnRewrite});
24992499
rinfo_.push_back({csinfo_.max_pool,
25002500
mkl_op_registry::GetMklOpName(csinfo_.max_pool),
25012501
CopyAttrsPooling, NonDepthBatchWisePoolRewrite});
@@ -2865,6 +2865,28 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
28652865
return false;
28662866
}
28672867

2868+
// If the depth_radius of LRN is not 2, then MKL DNN takes unoptimized
2869+
// path. The unoptimized path is slow. Thus we don't rewrite the node
2870+
// and use default Eigen. But for depth_radius=2, MKL DNN optimized
2871+
// path is taken, i.e., the Eigen node is rewritten by an MKL DNN node.
2872+
static bool LrnRewrite(const Node* n) {
  CHECK_NOTNULL(n);

  // Read the LRN node's "depth_radius" attribute. A failed lookup is treated
  // as a fatal error (CHECK) rather than a reason to skip the rewrite.
  int depth_radius;
  CHECK_EQ(GetNodeAttr(n->def(), "depth_radius", &depth_radius).ok(), true);

  // Only depth_radius == 2 is reported as rewritable; returning true lets
  // the node be replaced by its MKL DNN counterpart.
  if (depth_radius == 2) {
    return true;
  }

  // Any other depth_radius keeps the default Eigen op. Each fragment ends
  // with a space so the streamed pieces do not run together in the log
  // ("whichcase", "opfor").
  VLOG(1) << "LrnRewrite: The model sets depth_radius as not 2 which "
          << "case is not optimized by Intel MKL, thus using Eigen op "
          << "for LRN";

  return false;
}
2889+
28682890
static bool AddNRewrite(const Node* n) {
28692891
CHECK_NOTNULL(n);
28702892

0 commit comments

Comments
 (0)