|
27 | 27 |
|
28 | 28 | class ReducedShapeTest(tf.test.TestCase): |
29 | 29 |
|
| 30 | + def _check(self, shape, axes, result): |
| 31 | + output = math_ops.reduced_shape(shape, axes=axes) |
| 32 | + self.assertAllEqual(output.eval(), result) |
| 33 | + |
30 | 34 | def testSimple(self): |
31 | 35 | with self.test_session(): |
32 | | - def check(shape, axes, result): |
33 | | - output = math_ops.reduced_shape(shape, axes=axes) |
34 | | - self.assertAllEqual(output.eval(), result) |
35 | | - check([3], [], [3]) |
36 | | - check([3], [0], [1]) |
37 | | - check([5, 3], [], [5, 3]) |
38 | | - check([5, 3], [0], [1, 3]) |
39 | | - check([5, 3], [1], [5, 1]) |
40 | | - check([5, 3], [0, 1], [1, 1]) |
| 36 | + self._check([3], [], [3]) |
| 37 | + self._check([3], [0], [1]) |
| 38 | + self._check([5, 3], [], [5, 3]) |
| 39 | + self._check([5, 3], [0], [1, 3]) |
| 40 | + self._check([5, 3], [1], [5, 1]) |
| 41 | + self._check([5, 3], [0, 1], [1, 1]) |
41 | 42 |
|
42 | 43 | def testZeros(self): |
43 | 44 | """Check that reduced_shape does the right thing with zero dimensions.""" |
44 | 45 | with self.test_session(): |
45 | | - def check(shape, axes, result): |
46 | | - output = math_ops.reduced_shape(shape, axes=axes) |
47 | | - self.assertAllEqual(output.eval(), result) |
48 | | - check([0], [], [0]) |
49 | | - check([0], [0], [1]) |
50 | | - check([0, 3], [], [0, 3]) |
51 | | - check([0, 3], [0], [1, 3]) |
52 | | - check([0, 3], [1], [0, 1]) |
53 | | - check([0, 3], [0, 1], [1, 1]) |
54 | | - check([3, 0], [], [3, 0]) |
55 | | - check([3, 0], [0], [1, 0]) |
56 | | - check([3, 0], [1], [3, 1]) |
57 | | - check([3, 0], [0, 1], [1, 1]) |
| 46 | + self._check([0], [], [0]) |
| 47 | + self._check([0], [0], [1]) |
| 48 | + self._check([0, 3], [], [0, 3]) |
| 49 | + self._check([0, 3], [0], [1, 3]) |
| 50 | + self._check([0, 3], [1], [0, 1]) |
| 51 | + self._check([0, 3], [0, 1], [1, 1]) |
| 52 | + self._check([3, 0], [], [3, 0]) |
| 53 | + self._check([3, 0], [0], [1, 0]) |
| 54 | + self._check([3, 0], [1], [3, 1]) |
| 55 | + self._check([3, 0], [0, 1], [1, 1]) |
| 56 | + |
| 57 | + def testNegAxes(self): |
| 58 | + with self.test_session(): |
| 59 | + self._check([10, 10, 10], [-1], [10, 10, 1]) |
| 60 | + self._check([10, 10, 10], [-1, 2], [10, 10, 1]) |
| 61 | + self._check([10, 10, 10], [-1, -1], [10, 10, 1]) |
| 62 | + self._check([10, 10, 10], [-1, 0], [1, 10, 1]) |
| 63 | + self._check([10, 10, 10], [-3], [1, 10, 10]) |
58 | 64 |
|
59 | 65 |
|
60 | 66 | class SumReductionTest(tf.test.TestCase): |
@@ -110,6 +116,9 @@ def testFloatReduce3D(self): |
110 | 116 | self._compareAll(np_arr, [1, 2]) |
111 | 117 | self._compareAll(np_arr, [0, 2]) |
112 | 118 | self._compareAll(np_arr, [0, 1, 2]) |
| 119 | + self._compareAll(np_arr, [-1]) |
| 120 | + self._compareAll(np_arr, [-1, -3]) |
| 121 | + self._compareAll(np_arr, [-1, 1]) |
113 | 122 |
|
114 | 123 | def testFloatReduce4D(self): |
115 | 124 | # Create a 4D array of floats and reduce across some |
@@ -167,7 +176,7 @@ def testInvalidIndex(self): |
167 | 176 | input_tensor = tf.convert_to_tensor(np_arr) |
168 | 177 | with self.assertRaisesWithPredicateMatch( |
169 | 178 | ValueError, lambda e: "Invalid reduction dimension" in str(e)): |
170 | | - tf.reduce_sum(input_tensor, [-1]) |
| 179 | + tf.reduce_sum(input_tensor, [-3]) |
171 | 180 | with self.assertRaisesWithPredicateMatch( |
172 | 181 | ValueError, lambda e: "Invalid reduction dimension" in str(e)): |
173 | 182 | tf.reduce_sum(input_tensor, [2]) |
|
0 commit comments