forked from tensorflow/tensorflow
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathcontrol_flow_assert.py
More file actions
130 lines (113 loc) · 4.65 KB
/
Copy pathcontrol_flow_assert.py
File metadata and controls
130 lines (113 loc) · 4.65 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Assert functions for Control Flow Operations."""
from tensorflow.python.eager import context
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import cond
from tensorflow.python.ops import gen_control_flow_ops
from tensorflow.python.ops import gen_logging_ops
from tensorflow.python.ops import gen_math_ops
from tensorflow.python.util import dispatch
from tensorflow.python.util import tf_should_use
from tensorflow.python.util.tf_export import tf_export
def _summarize_eager(tensor, summarize=None):
  """Builds a short, comma-separated preview of an eager tensor's values.

  Emulates the output format of the C++ `Tensor::SummarizeValue()`.

  Args:
    tensor: EagerTensor whose values should be previewed.
    summarize: How many leading elements of `tensor` to include. `None`
      selects 3 elements; a negative value selects `array_ops.size(tensor)`
      elements, i.e. all of them.

  Returns:
    The selected elements converted with `str` and joined by ", ". For a
    non-scalar tensor a trailing "..." entry is appended when fewer elements
    were included than the tensor holds.
  """
  # Resolve the element budget first, exactly as Tensor::SummarizeValue() does.
  if summarize is None:
    summarize = 3
  elif summarize < 0:
    summarize = array_ops.size(tensor)

  if not tensor._rank():  # pylint: disable=protected-access
    # tensor.numpy() yields a bare scalar for zero-dimensional tensors, so
    # there is nothing to flatten: show the value unless zero elements were
    # requested.
    show_scalar = gen_math_ops.not_equal(summarize, 0)
    pieces = [str(tensor.numpy())] if show_scalar else []
    return ", ".join(pieces)

  # reshape((-1,)) is the fastest way to get a flat array view.
  values = tensor.numpy().reshape((-1,))
  pieces = list(map(str, values[:summarize]))
  if len(pieces) < values.size:
    # Signal that some elements were left out of the preview.
    pieces.append("...")
  return ", ".join(pieces)
# Assert and Print are special symbols in python, so we must
# use an upper-case version of them.
@tf_export("debugging.Assert", "Assert")
@dispatch.add_dispatch_support
@tf_should_use.should_use_result
def Assert(condition, data, summarize=None, name=None):
  """Asserts that the given condition is true.

  If `condition` evaluates to false, print the list of tensors in `data`.
  `summarize` determines how many entries of the tensors to print.

  Args:
    condition: The condition to evaluate.
    data: The tensors to print out when condition is false.
    summarize: Print this many entries of each tensor.
    name: A name for this operation (optional).

  Returns:
    assert_op: An `Operation` that, when executed, raises a
    `tf.errors.InvalidArgumentError` if `condition` is not true.
    @compatibility(eager)
    returns None
    @end_compatibility

  Raises:
    tf.errors.InvalidArgumentError: When executing eagerly and `condition`
      is false. The message contains `condition` and the summarized `data`.

  @compatibility(TF1)
  When in TF V1 mode (that is, outside `tf.function`) Assert needs a control
  dependency on the output to ensure the assertion executes:

  ```python
  # Ensure maximum element of x is smaller or equal to 1
  assert_op = tf.Assert(tf.less_equal(tf.reduce_max(x), 1.), [x])
  with tf.control_dependencies([assert_op]):
    ... code using x ...
  ```

  @end_compatibility
  """
  if context.executing_eagerly():
    # Eager mode: check the condition immediately and raise from Python
    # instead of creating an op.
    if not condition:
      xs = ops.convert_n_to_tensor(data)
      data_str = [_summarize_eager(x, summarize) for x in xs]
      raise errors.InvalidArgumentError(
          node_def=None,
          op=None,
          message="Expected '%s' to be true. Summarized data: %s" %
          (condition, "\n".join(data_str)))
    return None

  # Graph mode from here on: the eager case has already returned above.
  with ops.name_scope(name, "Assert", [condition, data]) as name:
    xs = ops.convert_n_to_tensor(data)
    if all(x.dtype in {dtypes.string, dtypes.int32} for x in xs):
      # As a simple heuristic, we assume that string and int32 are
      # on host to avoid the need to use cond. If that is not the case,
      # we will pay the price copying the tensor to host memory.
      return gen_logging_ops._assert(  # pylint: disable=protected-access
          condition, data, summarize, name="Assert")

    condition = ops.convert_to_tensor(condition, name="Condition")

    def true_assert():
      # Despite its name, this is passed as the false branch of the cond
      # below, so the Assert op only runs when `condition` is false.
      return gen_logging_ops._assert(  # pylint: disable=protected-access
          condition, data, summarize, name="Assert")

    guarded_assert = cond.cond(
        condition,
        gen_control_flow_ops.no_op,
        true_assert,
        name="AssertGuard")
    # No second eager check is needed here: eager execution returned at the
    # top of the function, so this point is only reached in graph mode.
    return guarded_assert.op