Closed
31 commits
bd0238a
Small comment fix
nouiz Sep 27, 2019
1542db8
Add tf_xla_supported_nodes flag to limit the nodes XLA considers.
nouiz Sep 27, 2019
e194f62
More Ops in the whitelist category.
nouiz Sep 30, 2019
08af421
Add more XLA whitelist shortcuts
nouiz Oct 1, 2019
c44dc1f
Better user error detection, less verbose output, and better error messages.
nouiz Oct 1, 2019
727edfd
Add an EXTRA category
nouiz Oct 1, 2019
65f614b
Add the flag value TF_XLA_FLAGS=--tf_xla_auto_jit=fusible to enable X…
nouiz Nov 1, 2019
85ed59b
Update documentation of the new flag
nouiz Nov 5, 2019
7cd6479
Use the absl containers.
nouiz Nov 5, 2019
4785944
Move code to a function to make it clearer.
nouiz Nov 5, 2019
03250e3
Rename the new flag to tf_xla_supported_ops
nouiz Nov 5, 2019
771393e
Fix typo and add comment.
nouiz Nov 14, 2019
eb9081b
Replace a series of ifs with a table to make the code simpler to understand.
nouiz Nov 15, 2019
aee8d90
Small update (typo, clang-format, const)
nouiz Nov 18, 2019
5599e95
XLALite: put BiasAddGrad in the right section and add the missing Bia…
nouiz Nov 19, 2019
902c64a
Code formatting
nouiz Nov 19, 2019
177ffc7
Repair the XLALite flag shortcut after the rebase
nouiz Nov 27, 2019
4576f25
clang-format
nouiz Nov 27, 2019
edf5517
Add TopKV2 to the XLALite whitelist. The XLA version is much faster than TF's.
nouiz Nov 28, 2019
7123fa0
Remove some duplicate names.
nouiz Dec 2, 2019
cec299a
Better documentation of the new parameter
nouiz Dec 2, 2019
0bdd35e
Rename tf_xla_supported_ops to tf_xla_ops_to_cluster
nouiz Dec 2, 2019
4bf1930
Add a warning about an experimental feature and simplify the code.
nouiz Dec 2, 2019
1f2a4cb
Remove the unique_ptr to simplify the code.
nouiz Dec 2, 2019
2654215
Small code simplification.
nouiz Dec 2, 2019
f370243
Convert non-trivial global destructors to local static variables.
nouiz Dec 3, 2019
4dcdcad
Rename a category
nouiz Dec 3, 2019
75e8b01
Add test to make sure all XLA-supported operations are in the XLALite wh…
nouiz Dec 3, 2019
9ff205f
Move the static inside the function to be safer.
nouiz Dec 9, 2019
0f45e20
Fix compilation failure in the new XLA test after the interface change.
nouiz Dec 13, 2019
cd2e988
Print the name that should be used to enable it. This
nouiz Dec 16, 2019
29 changes: 28 additions & 1 deletion tensorflow/compiler/jit/flags.cc
@@ -48,6 +48,15 @@ bool SetterForXlaAutoJitFlag(const string& value) {
     return true;
   }
 
+  if (value == "fusible") {
+    mark_for_compilation_flags->xla_auto_jit_flag
+        .optimization_level_single_gpu = 1;
+    mark_for_compilation_flags->xla_auto_jit_flag.optimization_level_general =
+        1;
+    mark_for_compilation_flags->tf_xla_ops_to_cluster = "FUSIBLE";
+    return true;
+  }
+
   absl::string_view value_sv(value);
   if (!absl::ConsumePrefix(&value_sv, "single-gpu(") ||
       !absl::ConsumeSuffix(&value_sv, ")") ||
@@ -65,7 +74,9 @@ void AppendMarkForCompilationPassFlagsInternal(std::vector<Flag>* flag_list) {
Flag("tf_xla_auto_jit", SetterForXlaAutoJitFlag, "0",
"Control compilation of operators into XLA computations on CPU and "
"GPU devices. 0 = use ConfigProto setting; -1 = off; 1 = on for "
"things very likely to be improved; 2 = on for everything. "
"things very likely to be improved; 2 = on for everything; "
"(experimental) fusible = only for Tensorflow operations that XLA "
"knows how to fuse. "
"If set to single-gpu(<N>) then this resolves to <N> for single-GPU "
"graphs (graphs that have at least one node placed on a GPU and no "
"more than one GPU is in use through the entire graph) and 0 "
@@ -78,6 +89,22 @@ void AppendMarkForCompilationPassFlagsInternal(std::vector<Flag>* flag_list) {
Flag("tf_xla_max_cluster_size",
&mark_for_compilation_flags->tf_xla_max_cluster_size,
"Maximum number of operators in an XLA compilation."),
Flag("tf_xla_ops_to_cluster",
&mark_for_compilation_flags->tf_xla_ops_to_cluster,
"(experimental) "
"Limit the operations clustered by XLA to these operations. "
"If multiple, separate them with commas. Shortcuts: "
" PW: All point-wise operations."
" RED: All reduction operations."
" MISC: Mixed operations."
" PWRED: TF operations that get converted to PW+RED operation in XLA."
" REDUCEWINDOW: TF operations like MaxPool/AvgPool that get "
"converted to ReduceWindow in XLA."
" REDUCEWINDOWPW: Operation that get converted to ReduceWindow + PW "
"(LRN, LRNGrad)."
" BN: TF FusedBatchNorm* operations."
" FUSIBLE: All TF operations that XLA can fuse (All the above). "
"You can also put any TF operation name, e.g. 'FUSIBLE,Matmul'."),
Flag("tf_xla_clustering_debug",
&mark_for_compilation_flags->tf_xla_clustering_debug,
"Dump graphs during XLA compilation."),
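The diff above wires the new fusible value into SetterForXlaAutoJitFlag and documents it in the tf_xla_auto_jit help text. As a minimal usage sketch (my illustration, not part of this PR), assuming a TensorFlow build that parses the TF_XLA_FLAGS environment variable at import time, as this file does:

import os

# TF_XLA_FLAGS is read when TensorFlow loads, so set it before the import.
# Per SetterForXlaAutoJitFlag above, "fusible" turns auto-clustering on at
# level 1 and restricts clustering to the FUSIBLE op category.
os.environ["TF_XLA_FLAGS"] = "--tf_xla_auto_jit=fusible"

import tensorflow as tf  # the flags take effect from here on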
3 changes: 3 additions & 0 deletions tensorflow/compiler/jit/flags.h
@@ -55,6 +55,9 @@ struct MarkForCompilationPassFlags {
   // Maximum number of operators in an XLA compilation.
   int32 tf_xla_max_cluster_size;
 
+  // If non-empty, limit XLA clustering to the following TF operations.
+  string tf_xla_ops_to_cluster;
+
   // Dump graphs during XLA compilation.
   bool tf_xla_clustering_debug;
 
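Since tf_xla_ops_to_cluster is declared here as a plain comma-separated string, category shortcuts and concrete op names can be mixed, as the help text's 'FUSIBLE,Matmul' example suggests. A hedged sketch of that (illustrative values; note the TF op is registered as MatMul):

import os

# Illustrative: cluster everything XLA can fuse, plus MatMul, with
# auto-clustering fully on. Again, set before importing TensorFlow.
os.environ["TF_XLA_FLAGS"] = (
    "--tf_xla_auto_jit=2 --tf_xla_ops_to_cluster=FUSIBLE,MatMul"
)

import tensorflow as tf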