|
| 1 | +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. |
| 2 | +# |
| 3 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | +# you may not use this file except in compliance with the License. |
| 5 | +# You may obtain a copy of the License at |
| 6 | +# |
| 7 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 8 | +# |
| 9 | +# Unless required by applicable law or agreed to in writing, software |
| 10 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | +# See the License for the specific language governing permissions and |
| 13 | +# limitations under the License. |
| 14 | +# ============================================================================== |
| 15 | +"""Provides a lazy wrapper for deferring Tensor creation.""" |
| 16 | + |
| 17 | +import threading |
| 18 | + |
| 19 | +from tensorboard.compat import tf2 as tf |
| 20 | + |
| 21 | + |
| 22 | +# Sentinel used for LazyTensorCreator._tensor to indicate that a value is |
| 23 | +# currently being computed, in order to fail hard on reentrancy. |
| 24 | +_CALL_IN_PROGRESS_SENTINEL = object() |
| 25 | + |
| 26 | + |
| 27 | +class LazyTensorCreator(object): |
| 28 | + """Lazy auto-converting wrapper for a callable that returns a `tf.Tensor`. |
| 29 | +
|
| 30 | + This class wraps an arbitrary callable that returns a `Tensor` so that it |
| 31 | + will be automatically converted to a `Tensor` by any logic that calls |
| 32 | + `tf.convert_to_tensor()`. This also memoizes the callable so that it is |
| 33 | + called at most once. |
| 34 | +
|
| 35 | + The intended use of this class is to defer the construction of a `Tensor` |
| 36 | + (e.g. to avoid unnecessary wasted computation, or ensure any new ops are |
| 37 | + created in a context only available later on in execution), while remaining |
| 38 | + compatible with APIs that expect to be given an already materialized value |
| 39 | + that can be converted to a `Tensor`. |
| 40 | +
|
| 41 | + This class is thread-safe. |
| 42 | + """ |
| 43 | + |
| 44 | + def __init__(self, tensor_callable): |
| 45 | + """Initializes a LazyTensorCreator object. |
| 46 | +
|
| 47 | + Args: |
| 48 | + tensor_callable: A callable that returns a `tf.Tensor`. |
| 49 | + """ |
| 50 | + if not callable(tensor_callable): |
| 51 | + raise ValueError("Not a callable: %r" % tensor_callable) |
| 52 | + self._tensor_callable = tensor_callable |
| 53 | + self._tensor = None |
| 54 | + self._tensor_lock = threading.RLock() |
| 55 | + _register_conversion_function_once() |
| 56 | + |
| 57 | + def __call__(self): |
| 58 | + if self._tensor is None or self._tensor is _CALL_IN_PROGRESS_SENTINEL: |
| 59 | + with self._tensor_lock: |
| 60 | + if self._tensor is _CALL_IN_PROGRESS_SENTINEL: |
| 61 | + raise RuntimeError("Cannot use LazyTensorCreator with reentrant callable") |
| 62 | + elif self._tensor is None: |
| 63 | + self._tensor = _CALL_IN_PROGRESS_SENTINEL |
| 64 | + self._tensor = self._tensor_callable() |
| 65 | + return self._tensor |
| 66 | + |
| 67 | + |
| 68 | +def _lazy_tensor_creator_converter(value, dtype=None, name=None, as_ref=False): |
| 69 | + del name # ignored |
| 70 | + if not isinstance(value, LazyTensorCreator): |
| 71 | + raise RuntimeError("Expected LazyTensorCreator, got %r" % value) |
| 72 | + if as_ref: |
| 73 | + raise RuntimeError("Cannot use LazyTensorCreator to create ref tensor") |
| 74 | + tensor = value() |
| 75 | + if dtype not in (None, tensor.dtype): |
| 76 | + raise RuntimeError( |
| 77 | + "Cannot convert LazyTensorCreator returning dtype %s to dtype %s" % ( |
| 78 | + tensor.dtype, dtype)) |
| 79 | + return tensor |
| 80 | + |
| 81 | + |
| 82 | +# Use module-level bit and lock to ensure that registration of the |
| 83 | +# LazyTensorCreator conversion function happens only once. |
| 84 | +_conversion_registered = False |
| 85 | +_conversion_registered_lock = threading.Lock() |
| 86 | + |
| 87 | + |
| 88 | +def _register_conversion_function_once(): |
| 89 | + """Performs one-time registration of `_lazy_tensor_creator_converter`. |
| 90 | +
|
| 91 | + This helper can be invoked multiple times but only registers the conversion |
| 92 | + function on the first invocation, making it suitable for calling when |
| 93 | + constructing a LazyTensorCreator. |
| 94 | +
|
| 95 | + Deferring the registration is necessary because doing it at at module import |
| 96 | + time would trigger the lazy TensorFlow import to resolve, and that in turn |
| 97 | + would break the delicate `tf.summary` import cycle avoidance scheme. |
| 98 | + """ |
| 99 | + global _conversion_registered |
| 100 | + if not _conversion_registered: |
| 101 | + with _conversion_registered_lock: |
| 102 | + if not _conversion_registered: |
| 103 | + _conversion_registered = True |
| 104 | + tf.register_tensor_conversion_function( |
| 105 | + base_type=LazyTensorCreator, |
| 106 | + conversion_func=_lazy_tensor_creator_converter, |
| 107 | + priority=0) |
0 commit comments