Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions tensorboard/plugins/histogram/summary_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,6 +235,24 @@ def graph_fn():
graph_fn()
writer.close()

def test_no_gradient_error_xla(self):
@tf2.function(jit_compile=True)
def graph_fn():
x = tf.constant(1.0)
with tf2.GradientTape() as tape1:
with tf2.GradientTape() as tape2:
tape1.watch(x)
tape2.watch(x)
summary.histogram(name="loss", step=0, data=x, buckets=10)

# Note that XLA CPU/GPU has no outside compilation support, so summaries
# won't actually run in a jit_compiled function. TPUs do, and follow
# some similar codepaths, so this test stops at graph building to
# exercise those paths without a TPU available.
writer = tf2.summary.create_file_writer(self.get_temp_dir())
with writer.as_default():
graph_fn.get_concrete_function()


if __name__ == "__main__":
tf.test.main()
4 changes: 4 additions & 0 deletions tensorboard/plugins/histogram/summary_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,10 @@ def histogram(name, data, step=None, buckets=None, description=None):
ValueError: if a default writer exists, but no step was provided and
`tf.summary.experimental.get_step()` is None.
"""
# Avoid building unused gradient graphs for conds below. This works around
# an error building second-order gradient graphs when XlaDynamicUpdateSlice
# is used, and will generally speed up graph building slightly.
data = tf.stop_gradient(data)
summary_metadata = metadata.create_summary_metadata(
display_name=None, description=description
)
Expand Down