Skip to content

Commit 169124c

Browse files
Merge pull request #34655 from nouiz:xlalite_pr
PiperOrigin-RevId: 285730750 Change-Id: Ib53f29df2e956b8c4904d08af3d6f33f1c419a9f
2 parents 6e17132 + 9ff205f commit 169124c

File tree

5 files changed

+409
-3
lines changed

5 files changed

+409
-3
lines changed

tensorflow/compiler/jit/flags.cc

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,15 @@ bool SetterForXlaAutoJitFlag(const string& value) {
4848
return true;
4949
}
5050

51+
if (value == "fusible") {
52+
mark_for_compilation_flags->xla_auto_jit_flag
53+
.optimization_level_single_gpu = 1;
54+
mark_for_compilation_flags->xla_auto_jit_flag.optimization_level_general =
55+
1;
56+
mark_for_compilation_flags->tf_xla_ops_to_cluster = "FUSIBLE";
57+
return true;
58+
}
59+
5160
absl::string_view value_sv(value);
5261
if (!absl::ConsumePrefix(&value_sv, "single-gpu(") ||
5362
!absl::ConsumeSuffix(&value_sv, ")") ||
@@ -65,7 +74,9 @@ void AppendMarkForCompilationPassFlagsInternal(std::vector<Flag>* flag_list) {
6574
Flag("tf_xla_auto_jit", SetterForXlaAutoJitFlag, "0",
6675
"Control compilation of operators into XLA computations on CPU and "
6776
"GPU devices. 0 = use ConfigProto setting; -1 = off; 1 = on for "
68-
"things very likely to be improved; 2 = on for everything. "
77+
"things very likely to be improved; 2 = on for everything; "
78+
"(experimental) fusible = only for Tensorflow operations that XLA "
79+
"knows how to fuse. "
6980
"If set to single-gpu(<N>) then this resolves to <N> for single-GPU "
7081
"graphs (graphs that have at least one node placed on a GPU and no "
7182
"more than one GPU is in use through the entire graph) and 0 "
@@ -78,6 +89,23 @@ void AppendMarkForCompilationPassFlagsInternal(std::vector<Flag>* flag_list) {
7889
Flag("tf_xla_max_cluster_size",
7990
&mark_for_compilation_flags->tf_xla_max_cluster_size,
8091
"Maximum number of operators in an XLA compilation."),
92+
Flag(
93+
"tf_xla_ops_to_cluster",
94+
&mark_for_compilation_flags->tf_xla_ops_to_cluster,
95+
"(experimental) "
96+
"Limit the operations clustered by XLA to these operations. "
97+
"If multiple, separate them with commas. Shortcuts: "
98+
" PW: All point-wise operations."
99+
" RED: All reduction operations."
100+
" MISC: Mixed operations."
101+
" PWRED: TF operations that get converted to PW+RED operation in XLA."
102+
" REDUCEWINDOW: TF operations like MaxPool/AvgPool that get "
103+
"converted to ReduceWindow in XLA."
104+
" REDUCEWINDOWPW: Operation that get converted to ReduceWindow + PW "
105+
"(LRN, LRNGrad)."
106+
" BN: TF FusedBatchNorm* operations."
107+
" FUSIBLE: All TF operations that XLA can fuse (All the above). "
108+
"You can also put any TF operation name, e.g. 'FUSIBLE,Matmul'."),
81109
Flag("tf_xla_clustering_debug",
82110
&mark_for_compilation_flags->tf_xla_clustering_debug,
83111
"Dump graphs during XLA compilation."),

tensorflow/compiler/jit/flags.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,9 @@ struct MarkForCompilationPassFlags {
5555
// Maximum number of operators in an XLA compilation.
5656
int32 tf_xla_max_cluster_size;
5757

58+
// If non-empty, limit XLA clustering to the following TF operations.
59+
string tf_xla_ops_to_cluster;
60+
5861
// Dump graphs during XLA compilation.
5962
bool tf_xla_clustering_debug;
6063

0 commit comments

Comments
 (0)