Skip to content

Commit cce41d3

Browse files
kevemantensorflower-gardener
authored andcommitted
Support negative values in the reduction_indices argument of reduce_*
functions. Fixes tensorflow#2426 Change: 122735328
1 parent 144855b commit cce41d3

3 files changed

Lines changed: 41 additions & 26 deletions

File tree

tensorflow/core/kernels/reduction_ops_common.cc

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -61,12 +61,13 @@ Status ReductionHelper::Simplify(const Tensor& data, const Tensor& axis,
6161
gtl::InlinedVector<bool, 4> bitmap(data.dims(), false);
6262
auto axis_vec = axis.flat<int32>();
6363
for (int64 i = 0; i < axis.NumElements(); ++i) {
64-
const int32 index = axis_vec(i);
65-
if (index < 0 || index >= data.dims()) {
64+
int32 index = axis_vec(i);
65+
if (index < -data.dims() || index >= data.dims()) {
6666
return errors::InvalidArgument("Invalid reduction dimension (", index,
6767
" for input with ", data.dims(),
6868
" dimension(s)");
6969
}
70+
index = (index + data.dims()) % data.dims();
7071
bitmap[index] = true;
7172
}
7273

tensorflow/python/kernel_tests/reduction_ops_test.py

Lines changed: 32 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -27,34 +27,40 @@
2727

2828
class ReducedShapeTest(tf.test.TestCase):
2929

30+
def _check(self, shape, axes, result):
31+
output = math_ops.reduced_shape(shape, axes=axes)
32+
self.assertAllEqual(output.eval(), result)
33+
3034
def testSimple(self):
3135
with self.test_session():
32-
def check(shape, axes, result):
33-
output = math_ops.reduced_shape(shape, axes=axes)
34-
self.assertAllEqual(output.eval(), result)
35-
check([3], [], [3])
36-
check([3], [0], [1])
37-
check([5, 3], [], [5, 3])
38-
check([5, 3], [0], [1, 3])
39-
check([5, 3], [1], [5, 1])
40-
check([5, 3], [0, 1], [1, 1])
36+
self._check([3], [], [3])
37+
self._check([3], [0], [1])
38+
self._check([5, 3], [], [5, 3])
39+
self._check([5, 3], [0], [1, 3])
40+
self._check([5, 3], [1], [5, 1])
41+
self._check([5, 3], [0, 1], [1, 1])
4142

4243
def testZeros(self):
4344
"""Check that reduced_shape does the right thing with zero dimensions."""
4445
with self.test_session():
45-
def check(shape, axes, result):
46-
output = math_ops.reduced_shape(shape, axes=axes)
47-
self.assertAllEqual(output.eval(), result)
48-
check([0], [], [0])
49-
check([0], [0], [1])
50-
check([0, 3], [], [0, 3])
51-
check([0, 3], [0], [1, 3])
52-
check([0, 3], [1], [0, 1])
53-
check([0, 3], [0, 1], [1, 1])
54-
check([3, 0], [], [3, 0])
55-
check([3, 0], [0], [1, 0])
56-
check([3, 0], [1], [3, 1])
57-
check([3, 0], [0, 1], [1, 1])
46+
self._check([0], [], [0])
47+
self._check([0], [0], [1])
48+
self._check([0, 3], [], [0, 3])
49+
self._check([0, 3], [0], [1, 3])
50+
self._check([0, 3], [1], [0, 1])
51+
self._check([0, 3], [0, 1], [1, 1])
52+
self._check([3, 0], [], [3, 0])
53+
self._check([3, 0], [0], [1, 0])
54+
self._check([3, 0], [1], [3, 1])
55+
self._check([3, 0], [0, 1], [1, 1])
56+
57+
def testNegAxes(self):
58+
with self.test_session():
59+
self._check([10, 10, 10], [-1], [10, 10, 1])
60+
self._check([10, 10, 10], [-1, 2], [10, 10, 1])
61+
self._check([10, 10, 10], [-1, -1], [10, 10, 1])
62+
self._check([10, 10, 10], [-1, 0], [1, 10, 1])
63+
self._check([10, 10, 10], [-3], [1, 10, 10])
5864

5965

6066
class SumReductionTest(tf.test.TestCase):
@@ -110,6 +116,9 @@ def testFloatReduce3D(self):
110116
self._compareAll(np_arr, [1, 2])
111117
self._compareAll(np_arr, [0, 2])
112118
self._compareAll(np_arr, [0, 1, 2])
119+
self._compareAll(np_arr, [-1])
120+
self._compareAll(np_arr, [-1, -3])
121+
self._compareAll(np_arr, [-1, 1])
113122

114123
def testFloatReduce4D(self):
115124
# Create a 4D array of floats and reduce across some
@@ -167,7 +176,7 @@ def testInvalidIndex(self):
167176
input_tensor = tf.convert_to_tensor(np_arr)
168177
with self.assertRaisesWithPredicateMatch(
169178
ValueError, lambda e: "Invalid reduction dimension" in str(e)):
170-
tf.reduce_sum(input_tensor, [-1])
179+
tf.reduce_sum(input_tensor, [-3])
171180
with self.assertRaisesWithPredicateMatch(
172181
ValueError, lambda e: "Invalid reduction dimension" in str(e)):
173182
tf.reduce_sum(input_tensor, [2])

tensorflow/python/ops/math_ops.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1527,10 +1527,14 @@ def _ReductionShape(op):
15271527
reduction_indices = np.ravel(reduction_indices)
15281528

15291529
for reduction_index in reduction_indices:
1530-
if reduction_index < 0 or reduction_index >= input_shape.ndims:
1530+
if (reduction_index < -input_shape.ndims or
1531+
reduction_index >= input_shape.ndims):
15311532
raise ValueError("Invalid reduction dimension %d for input with %d "
15321533
"dimensions" % (reduction_index, input_shape.ndims))
15331534

1535+
reduction_indices = set([(x + input_shape.ndims) % input_shape.ndims
1536+
for x in reduction_indices])
1537+
15341538
returned_dims = []
15351539
if keep_dims:
15361540
for i, dim in enumerate(input_shape.dims):
@@ -1624,6 +1628,7 @@ def reduced_shape(input_shape, axes):
16241628
axes = to_int32(axes) # [1, 2]
16251629

16261630
input_rank = array_ops.size(input_shape) # 4
1631+
axes = (axes + input_rank) % input_rank
16271632
axes_shape = array_ops.shape(axes) # [2]
16281633
return gen_data_flow_ops.dynamic_stitch( # [2, 1, 1, 7]
16291634
[range(input_rank), # [0, 1, 2, 3]

0 commit comments

Comments
 (0)