@@ -48,6 +48,15 @@ bool SetterForXlaAutoJitFlag(const string& value) {
     return true;
   }

+  if (value == "fusible") {
+    mark_for_compilation_flags->xla_auto_jit_flag
+        .optimization_level_single_gpu = 1;
+    mark_for_compilation_flags->xla_auto_jit_flag.optimization_level_general =
+        1;
+    mark_for_compilation_flags->tf_xla_ops_to_cluster = "FUSIBLE";
+    return true;
+  }
+
   absl::string_view value_sv(value);
   if (!absl::ConsumePrefix(&value_sv, "single-gpu(") ||
       !absl::ConsumeSuffix(&value_sv, ")") ||
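
Note (illustration, not part of the patch): the new "fusible" value is a shortcut that turns auto-jit on at level 1 for both the single-GPU and general cases and restricts clustering to the FUSIBLE op set. Below is a minimal C++ sketch of how a host program could opt into this mode through the TF_XLA_FLAGS environment variable; the flag value comes from this patch, while the surrounding program and the use of setenv() are assumptions for the example.

// Sketch only: requesting the new "fusible" auto-jit mode from a host program.
// TF_XLA_FLAGS must be set before TensorFlow first parses its XLA flags.
// setenv() is POSIX; the program around it is assumed, not TensorFlow code.
#include <cstdlib>

int main() {
  setenv("TF_XLA_FLAGS", "--tf_xla_auto_jit=fusible", /*overwrite=*/1);
  // ... build and run the TensorFlow graph/session here ...
  return 0;
}
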
@@ -65,7 +74,9 @@ void AppendMarkForCompilationPassFlagsInternal(std::vector<Flag>* flag_list) {
       Flag("tf_xla_auto_jit", SetterForXlaAutoJitFlag, "0",
            "Control compilation of operators into XLA computations on CPU and "
            "GPU devices. 0 = use ConfigProto setting; -1 = off; 1 = on for "
-           "things very likely to be improved; 2 = on for everything. "
+           "things very likely to be improved; 2 = on for everything; "
+           "(experimental) fusible = only for TensorFlow operations that XLA "
+           "knows how to fuse. "
            "If set to single-gpu(<N>) then this resolves to <N> for single-GPU "
            "graphs (graphs that have at least one node placed on a GPU and no "
            "more than one GPU is in use through the entire graph) and 0 "
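
Note (illustration, not part of the patch): the help text above refers to the single-gpu(<N>) form, which the setter in the first hunk parses with absl string helpers. A self-contained sketch of that parsing pattern follows; ParseSingleGpuLevel is a hypothetical helper name used only for this example, not the function from the patch.

// Sketch only: parsing a "single-gpu(<N>)" value with the same absl helpers
// used in SetterForXlaAutoJitFlag. ParseSingleGpuLevel is hypothetical.
#include <cstdint>
#include <iostream>

#include "absl/strings/numbers.h"
#include "absl/strings/string_view.h"
#include "absl/strings/strip.h"

bool ParseSingleGpuLevel(absl::string_view value, int32_t* level) {
  // Peel off the "single-gpu(" prefix and ")" suffix, then read the integer.
  return absl::ConsumePrefix(&value, "single-gpu(") &&
         absl::ConsumeSuffix(&value, ")") && absl::SimpleAtoi(value, level);
}

int main() {
  int32_t level = 0;
  if (ParseSingleGpuLevel("single-gpu(2)", &level)) {
    std::cout << "single-GPU optimization level: " << level << "\n";  // 2
  }
  return 0;
}
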
@@ -78,6 +89,23 @@ void AppendMarkForCompilationPassFlagsInternal(std::vector<Flag>* flag_list) {
       Flag("tf_xla_max_cluster_size",
            &mark_for_compilation_flags->tf_xla_max_cluster_size,
            "Maximum number of operators in an XLA compilation."),
+      Flag(
+          "tf_xla_ops_to_cluster",
+          &mark_for_compilation_flags->tf_xla_ops_to_cluster,
+          "(experimental) "
+          "Limit the operations clustered by XLA to these operations. "
+          "If multiple, separate them with commas. Shortcuts: "
+          " PW: All point-wise operations."
+          " RED: All reduction operations."
+          " MISC: Mixed operations."
+          " PWRED: TF operations that get converted to PW+RED operations in XLA."
+          " REDUCEWINDOW: TF operations like MaxPool/AvgPool that get "
+          "converted to ReduceWindow in XLA."
+          " REDUCEWINDOWPW: Operations that get converted to ReduceWindow + PW "
+          "(LRN, LRNGrad)."
+          " BN: TF FusedBatchNorm* operations."
+          " FUSIBLE: All TF operations that XLA can fuse (all of the above). "
+          "You can also put any TF operation name, e.g. 'FUSIBLE,MatMul'."),
       Flag("tf_xla_clustering_debug",
            &mark_for_compilation_flags->tf_xla_clustering_debug,
            "Dump graphs during XLA compilation."),