Skip to content

Commit 88dbe55

Browse files
authored
Defer image/audio/histogram v2 summary preprocessing using LazyTensorCreator (#2899)
* Add LazyTensorCreator * Defer summary preprocessing using LazyTensorCreator * Add guard against reentrant callables to avoid deadlock * Improve separation of concerns between lazy creation and tensor conversion routine * Make LazyTensorCreatorTest pass against TF 1.x * CR: remove unused functools import and nested_call parameter * Defer registration of tensor conversion function until after module import time
1 parent da9ca84 commit 88dbe55

File tree

9 files changed

+290
-37
lines changed

9 files changed

+290
-37
lines changed

tensorboard/plugins/audio/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,7 @@ py_library(
9292
deps = [
9393
":metadata",
9494
"//tensorboard/compat",
95+
"//tensorboard/util:lazy_tensor_creator",
9596
],
9697
)
9798

tensorboard/plugins/audio/summary_v2.py

Lines changed: 24 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929

3030
from tensorboard.compat import tf2 as tf
3131
from tensorboard.plugins.audio import metadata
32+
from tensorboard.util import lazy_tensor_creator
3233

3334

3435
def audio(name,
@@ -91,19 +92,27 @@ def audio(name,
9192
tf.summary.summary_scope)
9293
with summary_scope(
9394
name, 'audio_summary', values=inputs) as (tag, _):
94-
tf.debugging.assert_rank(data, 3)
95-
tf.debugging.assert_non_negative(max_outputs)
96-
limited_audio = data[:max_outputs]
97-
encode_fn = functools.partial(audio_ops.encode_wav,
98-
sample_rate=sample_rate)
99-
encoded_audio = tf.map_fn(encode_fn, limited_audio,
100-
dtype=tf.string,
101-
name='encode_each_audio')
102-
# Workaround for map_fn returning float dtype for an empty elems input.
103-
encoded_audio = tf.cond(
104-
tf.shape(input=encoded_audio)[0] > 0,
105-
lambda: encoded_audio, lambda: tf.constant([], tf.string))
106-
limited_labels = tf.tile([''], tf.shape(input=limited_audio)[:1])
107-
tensor = tf.transpose(a=tf.stack([encoded_audio, limited_labels]))
95+
# Defer audio encoding preprocessing by passing it as a callable to write(),
96+
# wrapped in a LazyTensorCreator for backwards compatibility, so that we
97+
# only do this work when summaries are actually written.
98+
@lazy_tensor_creator.LazyTensorCreator
99+
def lazy_tensor():
100+
tf.debugging.assert_rank(data, 3)
101+
tf.debugging.assert_non_negative(max_outputs)
102+
limited_audio = data[:max_outputs]
103+
encode_fn = functools.partial(audio_ops.encode_wav,
104+
sample_rate=sample_rate)
105+
encoded_audio = tf.map_fn(encode_fn, limited_audio,
106+
dtype=tf.string,
107+
name='encode_each_audio')
108+
# Workaround for map_fn returning float dtype for an empty elems input.
109+
encoded_audio = tf.cond(
110+
tf.shape(input=encoded_audio)[0] > 0,
111+
lambda: encoded_audio, lambda: tf.constant([], tf.string))
112+
limited_labels = tf.tile([''], tf.shape(input=limited_audio)[:1])
113+
return tf.transpose(a=tf.stack([encoded_audio, limited_labels]))
114+
115+
# To ensure that audio encoding logic is only executed when summaries
116+
# are written, we pass callable to `tensor` parameter.
108117
return tf.summary.write(
109-
tag=tag, tensor=tensor, step=step, metadata=summary_metadata)
118+
tag=tag, tensor=lazy_tensor, step=step, metadata=summary_metadata)

tensorboard/plugins/histogram/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,7 @@ py_library(
109109
"//tensorboard:expect_numpy_installed",
110110
"//tensorboard/compat",
111111
"//tensorboard/compat/proto:protos_all_py_pb2",
112+
"//tensorboard/util:lazy_tensor_creator",
112113
"//tensorboard/util:tensor_util",
113114
],
114115
)

tensorboard/plugins/histogram/summary_v2.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
from tensorboard.compat import tf2 as tf
3535
from tensorboard.compat.proto import summary_pb2
3636
from tensorboard.plugins.histogram import metadata
37+
from tensorboard.util import lazy_tensor_creator
3738
from tensorboard.util import tensor_util
3839

3940

@@ -76,9 +77,14 @@ def histogram(name, data, step=None, buckets=None, description=None):
7677
def histogram_summary(data, buckets, histogram_metadata, step):
7778
with summary_scope(
7879
name, 'histogram_summary', values=[data, buckets, step]) as (tag, _):
79-
tensor = _buckets(data, bucket_count=buckets)
80+
# Defer histogram bucketing logic by passing it as a callable to write(),
81+
# wrapped in a LazyTensorCreator for backwards compatibility, so that we
82+
# only do this work when summaries are actually written.
83+
@lazy_tensor_creator.LazyTensorCreator
84+
def lazy_tensor():
85+
return _buckets(data, buckets)
8086
return tf.summary.write(
81-
tag=tag, tensor=tensor, step=step, metadata=histogram_metadata)
87+
tag=tag, tensor=lazy_tensor, step=step, metadata=summary_metadata)
8288

8389
# `_buckets()` has dynamic output shapes which is not supported on TPU's. As so, place
8490
# the bucketing ops on outside compilation cluster so that the function in executed on CPU.

tensorboard/plugins/image/BUILD

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,7 @@ py_library(
107107
":metadata",
108108
"//tensorboard/compat",
109109
"//tensorboard/compat/proto:protos_all_py_pb2",
110-
"//tensorboard/util:tensor_util",
110+
"//tensorboard/util:lazy_tensor_creator",
111111
],
112112
)
113113

tensorboard/plugins/image/summary_v2.py

Lines changed: 26 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424

2525
from tensorboard.compat import tf2 as tf
2626
from tensorboard.plugins.image import metadata
27+
from tensorboard.util import lazy_tensor_creator
2728

2829

2930
def image(name,
@@ -68,21 +69,29 @@ def image(name,
6869
tf.summary.summary_scope)
6970
with summary_scope(
7071
name, 'image_summary', values=[data, max_outputs, step]) as (tag, _):
71-
tf.debugging.assert_rank(data, 4)
72-
tf.debugging.assert_non_negative(max_outputs)
73-
images = tf.image.convert_image_dtype(data, tf.uint8, saturate=True)
74-
limited_images = images[:max_outputs]
75-
encoded_images = tf.map_fn(tf.image.encode_png, limited_images,
76-
dtype=tf.string,
77-
name='encode_each_image')
78-
# Workaround for map_fn returning float dtype for an empty elems input.
79-
encoded_images = tf.cond(
80-
tf.shape(input=encoded_images)[0] > 0,
81-
lambda: encoded_images, lambda: tf.constant([], tf.string))
82-
image_shape = tf.shape(input=images)
83-
dimensions = tf.stack([tf.as_string(image_shape[2], name='width'),
84-
tf.as_string(image_shape[1], name='height')],
85-
name='dimensions')
86-
tensor = tf.concat([dimensions, encoded_images], axis=0)
72+
# Defer image encoding preprocessing by passing it as a callable to write(),
73+
# wrapped in a LazyTensorCreator for backwards compatibility, so that we
74+
# only do this work when summaries are actually written.
75+
@lazy_tensor_creator.LazyTensorCreator
76+
def lazy_tensor():
77+
tf.debugging.assert_rank(data, 4)
78+
tf.debugging.assert_non_negative(max_outputs)
79+
images = tf.image.convert_image_dtype(data, tf.uint8, saturate=True)
80+
limited_images = images[:max_outputs]
81+
encoded_images = tf.map_fn(tf.image.encode_png, limited_images,
82+
dtype=tf.string,
83+
name='encode_each_image')
84+
# Workaround for map_fn returning float dtype for an empty elems input.
85+
encoded_images = tf.cond(
86+
tf.shape(input=encoded_images)[0] > 0,
87+
lambda: encoded_images, lambda: tf.constant([], tf.string))
88+
image_shape = tf.shape(input=images)
89+
dimensions = tf.stack([tf.as_string(image_shape[2], name='width'),
90+
tf.as_string(image_shape[1], name='height')],
91+
name='dimensions')
92+
return tf.concat([dimensions, encoded_images], axis=0)
93+
94+
# To ensure that image encoding logic is only executed when summaries
95+
# are written, we pass callable to `tensor` parameter.
8796
return tf.summary.write(
88-
tag=tag, tensor=tensor, step=step, metadata=summary_metadata)
97+
tag=tag, tensor=lazy_tensor, step=step, metadata=summary_metadata)

tensorboard/util/BUILD

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -79,9 +79,29 @@ py_test(
7979

8080
tb_proto_library(
8181
name = "grpc_util_test_proto",
82-
has_services = True,
83-
srcs = ["grpc_util_test.proto"],
8482
testonly = True,
83+
srcs = ["grpc_util_test.proto"],
84+
has_services = True,
85+
)
86+
87+
py_library(
88+
name = "lazy_tensor_creator",
89+
srcs = ["lazy_tensor_creator.py"],
90+
srcs_version = "PY2AND3",
91+
deps = [
92+
"//tensorboard/compat",
93+
],
94+
)
95+
96+
py_test(
97+
name = "lazy_tensor_creator_test",
98+
size = "small",
99+
srcs = ["lazy_tensor_creator_test.py"],
100+
srcs_version = "PY2AND3",
101+
deps = [
102+
":lazy_tensor_creator",
103+
"//tensorboard:expect_tensorflow_installed",
104+
],
85105
)
86106

87107
py_library(
Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,107 @@
1+
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ==============================================================================
15+
"""Provides a lazy wrapper for deferring Tensor creation."""
16+
17+
import threading
18+
19+
from tensorboard.compat import tf2 as tf
20+
21+
22+
# Sentinel used for LazyTensorCreator._tensor to indicate that a value is
23+
# currently being computed, in order to fail hard on reentrancy.
24+
_CALL_IN_PROGRESS_SENTINEL = object()
25+
26+
27+
class LazyTensorCreator(object):
28+
"""Lazy auto-converting wrapper for a callable that returns a `tf.Tensor`.
29+
30+
This class wraps an arbitrary callable that returns a `Tensor` so that it
31+
will be automatically converted to a `Tensor` by any logic that calls
32+
`tf.convert_to_tensor()`. This also memoizes the callable so that it is
33+
called at most once.
34+
35+
The intended use of this class is to defer the construction of a `Tensor`
36+
(e.g. to avoid unnecessary wasted computation, or ensure any new ops are
37+
created in a context only available later on in execution), while remaining
38+
compatible with APIs that expect to be given an already materialized value
39+
that can be converted to a `Tensor`.
40+
41+
This class is thread-safe.
42+
"""
43+
44+
def __init__(self, tensor_callable):
45+
"""Initializes a LazyTensorCreator object.
46+
47+
Args:
48+
tensor_callable: A callable that returns a `tf.Tensor`.
49+
"""
50+
if not callable(tensor_callable):
51+
raise ValueError("Not a callable: %r" % tensor_callable)
52+
self._tensor_callable = tensor_callable
53+
self._tensor = None
54+
self._tensor_lock = threading.RLock()
55+
_register_conversion_function_once()
56+
57+
def __call__(self):
58+
if self._tensor is None or self._tensor is _CALL_IN_PROGRESS_SENTINEL:
59+
with self._tensor_lock:
60+
if self._tensor is _CALL_IN_PROGRESS_SENTINEL:
61+
raise RuntimeError("Cannot use LazyTensorCreator with reentrant callable")
62+
elif self._tensor is None:
63+
self._tensor = _CALL_IN_PROGRESS_SENTINEL
64+
self._tensor = self._tensor_callable()
65+
return self._tensor
66+
67+
68+
def _lazy_tensor_creator_converter(value, dtype=None, name=None, as_ref=False):
69+
del name # ignored
70+
if not isinstance(value, LazyTensorCreator):
71+
raise RuntimeError("Expected LazyTensorCreator, got %r" % value)
72+
if as_ref:
73+
raise RuntimeError("Cannot use LazyTensorCreator to create ref tensor")
74+
tensor = value()
75+
if dtype not in (None, tensor.dtype):
76+
raise RuntimeError(
77+
"Cannot convert LazyTensorCreator returning dtype %s to dtype %s" % (
78+
tensor.dtype, dtype))
79+
return tensor
80+
81+
82+
# Use module-level bit and lock to ensure that registration of the
83+
# LazyTensorCreator conversion function happens only once.
84+
_conversion_registered = False
85+
_conversion_registered_lock = threading.Lock()
86+
87+
88+
def _register_conversion_function_once():
89+
"""Performs one-time registration of `_lazy_tensor_creator_converter`.
90+
91+
This helper can be invoked multiple times but only registers the conversion
92+
function on the first invocation, making it suitable for calling when
93+
constructing a LazyTensorCreator.
94+
95+
Deferring the registration is necessary because doing it at at module import
96+
time would trigger the lazy TensorFlow import to resolve, and that in turn
97+
would break the delicate `tf.summary` import cycle avoidance scheme.
98+
"""
99+
global _conversion_registered
100+
if not _conversion_registered:
101+
with _conversion_registered_lock:
102+
if not _conversion_registered:
103+
_conversion_registered = True
104+
tf.register_tensor_conversion_function(
105+
base_type=LazyTensorCreator,
106+
conversion_func=_lazy_tensor_creator_converter,
107+
priority=0)
Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,100 @@
1+
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from __future__ import absolute_import
16+
from __future__ import division
17+
from __future__ import print_function
18+
19+
import tensorflow as tf
20+
21+
from tensorboard.util import lazy_tensor_creator
22+
23+
24+
tf.compat.v1.enable_eager_execution()
25+
26+
27+
class LazyTensorCreatorTest(tf.test.TestCase):
28+
29+
def assertEqualAsNumpy(self, a, b):
30+
# TODO(#2507): Remove after we no longer test against TF 1.x.
31+
self.assertEqual(a.numpy(), b.numpy())
32+
33+
def test_lazy_creation_with_memoization(self):
34+
boxed_count = [0]
35+
@lazy_tensor_creator.LazyTensorCreator
36+
def lazy_tensor():
37+
boxed_count[0] = boxed_count[0] + 1
38+
return tf.constant(1)
39+
self.assertEqual(0, boxed_count[0])
40+
real_tensor = lazy_tensor()
41+
self.assertEqual(1, boxed_count[0])
42+
lazy_tensor()
43+
self.assertEqual(1, boxed_count[0])
44+
45+
def test_conversion_explicit(self):
46+
@lazy_tensor_creator.LazyTensorCreator
47+
def lazy_tensor():
48+
return tf.constant(1)
49+
real_tensor = tf.convert_to_tensor(lazy_tensor)
50+
self.assertEqualAsNumpy(tf.constant(1), real_tensor)
51+
52+
def test_conversion_identity(self):
53+
@lazy_tensor_creator.LazyTensorCreator
54+
def lazy_tensor():
55+
return tf.constant(1)
56+
real_tensor = tf.identity(lazy_tensor)
57+
self.assertEqualAsNumpy(tf.constant(1), real_tensor)
58+
59+
def test_conversion_implicit(self):
60+
@lazy_tensor_creator.LazyTensorCreator
61+
def lazy_tensor():
62+
return tf.constant(1)
63+
real_tensor = lazy_tensor + tf.constant(1)
64+
self.assertEqualAsNumpy(tf.constant(2), real_tensor)
65+
66+
def test_explicit_dtype_okay_if_matches(self):
67+
@lazy_tensor_creator.LazyTensorCreator
68+
def lazy_tensor():
69+
return tf.constant(1, dtype=tf.int32)
70+
real_tensor = tf.convert_to_tensor(lazy_tensor, dtype=tf.int32)
71+
self.assertEqual(tf.int32, real_tensor.dtype)
72+
self.assertEqualAsNumpy(tf.constant(1, dtype=tf.int32), real_tensor)
73+
74+
def test_explicit_dtype_rejected_if_different(self):
75+
@lazy_tensor_creator.LazyTensorCreator
76+
def lazy_tensor():
77+
return tf.constant(1, dtype=tf.int32)
78+
with self.assertRaisesRegex(RuntimeError, "dtype"):
79+
tf.convert_to_tensor(lazy_tensor, dtype=tf.int64)
80+
81+
def test_as_ref_rejected(self):
82+
@lazy_tensor_creator.LazyTensorCreator
83+
def lazy_tensor():
84+
return tf.constant(1, dtype=tf.int32)
85+
with self.assertRaisesRegex(RuntimeError, "ref tensor"):
86+
# Call conversion routine manually since this isn't actually
87+
# exposed as an argument to tf.convert_to_tensor.
88+
lazy_tensor_creator._lazy_tensor_creator_converter(
89+
lazy_tensor, as_ref=True)
90+
91+
def test_reentrant_callable_does_not_deadlock(self):
92+
@lazy_tensor_creator.LazyTensorCreator
93+
def lazy_tensor():
94+
return lazy_tensor()
95+
with self.assertRaisesRegex(RuntimeError, "reentrant callable"):
96+
lazy_tensor()
97+
98+
99+
if __name__ == '__main__':
100+
tf.test.main()

0 commit comments

Comments
 (0)