Skip to content

Commit 7aa9563

Browse files
aselletensorflower-gardener
authored andcommitted
Create op ReverseV2 that takes indices for reversing rather than a bool array
This does not implement the Python API change yet for forward compatibility. e.g. eventually tf.reverse(a, [0,1,-1]) will be the same as the old API of tf.reverse(a, [True, True, False, False, False, True]) for a 6 dimensional tensor a. Change: 136675570
1 parent bd82fd0 commit 7aa9563

6 files changed

Lines changed: 306 additions & 37 deletions

File tree

tensorflow/core/kernels/reverse_op.cc

Lines changed: 136 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ limitations under the License.
2424
#include "tensorflow/core/framework/tensor.h"
2525
#include "tensorflow/core/framework/tensor_shape.h"
2626
#include "tensorflow/core/framework/types.h"
27+
#include "tensorflow/core/kernels/bounds_check.h"
2728
#include "tensorflow/core/lib/core/status.h"
2829
#include "tensorflow/core/platform/logging.h"
2930

@@ -32,6 +33,19 @@ namespace tensorflow {
3233
typedef Eigen::ThreadPoolDevice CPUDevice;
3334
typedef Eigen::GpuDevice GPUDevice;
3435

36+
template <typename Device, typename T, int NDIMS>
37+
void HandleReverseCase(OpKernelContext* context,
38+
typename TTypes<bool, 1>::ConstTensor dims,
39+
Tensor* result) {
40+
typename Eigen::array<bool, NDIMS> axes_di;
41+
for (int i = 0; i < NDIMS; i++) {
42+
axes_di[i] = dims(i);
43+
}
44+
functor::Reverse<Device, T, NDIMS>()(context->eigen_device<Device>(),
45+
context->input(0).tensor<T, NDIMS>(),
46+
axes_di, result->tensor<T, NDIMS>());
47+
}
48+
3549
template <typename Device, typename T>
3650
class ReverseOp : public OpKernel {
3751
public:
@@ -67,11 +81,9 @@ class ReverseOp : public OpKernel {
6781
OP_REQUIRES_OK(context,
6882
context->allocate_output(0, input.shape(), &output));
6983

70-
#define HANDLE_REVERSE(NDIMS) \
71-
case NDIMS: \
72-
functor::Reverse<Device, T, NDIMS>()( \
73-
context->eigen_device<Device>(), input.tensor<T, NDIMS>(), \
74-
dims.vec<bool>(), output->tensor<T, NDIMS>()); \
84+
#define HANDLE_REVERSE(NDIMS) \
85+
case NDIMS: \
86+
HandleReverseCase<Device, T, NDIMS>(context, dims.vec<bool>(), output); \
7587
return;
7688

7789
switch (input_dims) {
@@ -90,15 +102,97 @@ class ReverseOp : public OpKernel {
90102
}
91103
};
92104

93-
#define REGISTER_KERNEL(T) \
94-
REGISTER_KERNEL_BUILDER(Name("Reverse") \
95-
.Device(DEVICE_CPU) \
96-
.TypeConstraint<T>("T") \
97-
.HostMemory("dims"), \
98-
ReverseOp<CPUDevice, T>)
105+
template <typename Device, typename T, int NDIMS>
106+
void HandleReverseV2Case(OpKernelContext* context,
107+
const gtl::ArraySlice<bool>& axes, Tensor* result) {
108+
typename Eigen::array<bool, NDIMS> axes_di;
109+
for (int i = 0; i < NDIMS; i++) {
110+
axes_di[i] = axes[i];
111+
}
112+
functor::Reverse<Device, T, NDIMS>()(context->eigen_device<Device>(),
113+
context->input(0).tensor<T, NDIMS>(),
114+
axes_di, result->tensor<T, NDIMS>());
115+
}
116+
117+
template <typename Device, typename T>
118+
class ReverseV2Op : public OpKernel {
119+
public:
120+
explicit ReverseV2Op(OpKernelConstruction* context) : OpKernel(context) {}
121+
122+
void Compute(OpKernelContext* context) override {
123+
const Tensor& input = context->input(0);
124+
const Tensor& sparse_dims = context->input(1);
125+
126+
if (TensorShapeUtils::IsScalar(input.shape())) {
127+
Tensor* output = nullptr;
128+
OP_REQUIRES_OK(context,
129+
context->allocate_output(0, input.shape(), &output));
130+
output->scalar<T>() = input.scalar<T>();
131+
} else {
132+
const int input_dims = input.dims();
133+
const TensorShape& sparse_dims_shape = sparse_dims.shape();
134+
const auto& axes_sparse_flat = sparse_dims.flat<int32>();
135+
136+
OP_REQUIRES(context, TensorShapeUtils::IsVector(sparse_dims_shape),
137+
errors::InvalidArgument("'dims' must be 1-dimension, not ",
138+
sparse_dims.dims()));
139+
gtl::InlinedVector<bool, 8> axes_dense(input_dims, false);
140+
for (int dummy = 0; dummy < axes_sparse_flat.size(); dummy++) {
141+
int32 axis = internal::SubtleMustCopy<int32>(axes_sparse_flat(dummy));
142+
int32 canonical_axis = axis < 0 ? input_dims + axis : axis;
143+
OP_REQUIRES(context, canonical_axis >= 0 && canonical_axis < input_dims,
144+
errors::InvalidArgument("'axis'[", dummy, "] = ", axis,
145+
" is out of valid range [", 0, ", ",
146+
input_dims - 1));
147+
OP_REQUIRES(context, !axes_dense[canonical_axis],
148+
errors::InvalidArgument("axis ", canonical_axis,
149+
" specified more than once."));
150+
axes_dense[canonical_axis] = true;
151+
}
99152

100-
TF_CALL_POD_TYPES(REGISTER_KERNEL);
101-
#undef REGISTER_KERNEL
153+
OP_REQUIRES(context, input_dims <= 8,
154+
errors::Unimplemented(
155+
"reverse is not implemented for tensors of rank > 8."));
156+
157+
Tensor* output = nullptr;
158+
OP_REQUIRES_OK(context,
159+
context->allocate_output(0, input.shape(), &output));
160+
161+
#define HANDLE_REVERSE(NDIMS) \
162+
case NDIMS: \
163+
HandleReverseV2Case<Device, T, NDIMS>(context, axes_dense, output); \
164+
return;
165+
166+
switch (input_dims) {
167+
HANDLE_REVERSE(0);
168+
HANDLE_REVERSE(1);
169+
HANDLE_REVERSE(2);
170+
HANDLE_REVERSE(3);
171+
HANDLE_REVERSE(4);
172+
HANDLE_REVERSE(5);
173+
HANDLE_REVERSE(6);
174+
HANDLE_REVERSE(7);
175+
HANDLE_REVERSE(8);
176+
}
177+
#undef HANDLE_REVERSE
178+
}
179+
}
180+
};
181+
182+
#define REGISTER_KERNELS(T) \
183+
REGISTER_KERNEL_BUILDER(Name("Reverse") \
184+
.Device(DEVICE_CPU) \
185+
.TypeConstraint<T>("T") \
186+
.HostMemory("dims"), \
187+
ReverseOp<CPUDevice, T>) \
188+
REGISTER_KERNEL_BUILDER(Name("ReverseV2") \
189+
.Device(DEVICE_CPU) \
190+
.TypeConstraint<T>("T") \
191+
.TypeConstraint<int32>("Tidx") \
192+
.HostMemory("axis"), \
193+
ReverseV2Op<CPUDevice, T>)
194+
TF_CALL_POD_TYPES(REGISTER_KERNELS);
195+
#undef REGISTER_KERNELS
102196

103197
#if GOOGLE_CUDA
104198

@@ -109,7 +203,7 @@ namespace functor {
109203
template <> \
110204
void Reverse<GPUDevice, T, DIM>::operator()( \
111205
const GPUDevice& d, typename TTypes<T, DIM>::ConstTensor input, \
112-
typename TTypes<bool, 1>::ConstTensor dims, \
206+
const Eigen::array<bool, DIM>& reverse_dims, \
113207
typename TTypes<T, DIM>::Tensor output); \
114208
extern template struct Reverse<GPUDevice, T, DIM>;
115209
#define DECLARE_GPU_SPEC(T) \
@@ -136,21 +230,27 @@ TF_CALL_complex128(DECLARE_GPU_SPEC);
136230
} // namespace functor
137231

138232
// Registration of the GPU implementations.
139-
#define REGISTER_GPU_KERNEL(T) \
140-
REGISTER_KERNEL_BUILDER(Name("Reverse") \
141-
.Device(DEVICE_GPU) \
142-
.TypeConstraint<T>("T") \
143-
.HostMemory("dims"), \
144-
ReverseOp<GPUDevice, T>)
145-
TF_CALL_uint8(REGISTER_GPU_KERNEL);
146-
TF_CALL_int8(REGISTER_GPU_KERNEL);
233+
#define REGISTER_GPU_KERNELS(T) \
234+
REGISTER_KERNEL_BUILDER(Name("Reverse") \
235+
.Device(DEVICE_GPU) \
236+
.TypeConstraint<T>("T") \
237+
.HostMemory("dims"), \
238+
ReverseOp<GPUDevice, T>) \
239+
REGISTER_KERNEL_BUILDER(Name("ReverseV2") \
240+
.Device(DEVICE_GPU) \
241+
.TypeConstraint<T>("T") \
242+
.TypeConstraint<int32>("Tidx") \
243+
.HostMemory("axis"), \
244+
ReverseV2Op<GPUDevice, T>)
245+
TF_CALL_uint8(REGISTER_GPU_KERNELS);
246+
TF_CALL_int8(REGISTER_GPU_KERNELS);
147247
// TODO decide whether we want to enable the bool kernel.
148-
// TF_CALL_bool(REGISTER_GPU_KERNEL);
149-
TF_CALL_half(REGISTER_GPU_KERNEL);
150-
TF_CALL_float(REGISTER_GPU_KERNEL);
151-
TF_CALL_double(REGISTER_GPU_KERNEL);
152-
TF_CALL_complex64(REGISTER_GPU_KERNEL);
153-
TF_CALL_complex128(REGISTER_GPU_KERNEL);
248+
// TF_CALL_bool(REGISTER_GPU_KERNELS);
249+
TF_CALL_half(REGISTER_GPU_KERNELS);
250+
TF_CALL_float(REGISTER_GPU_KERNELS);
251+
TF_CALL_double(REGISTER_GPU_KERNELS);
252+
TF_CALL_complex64(REGISTER_GPU_KERNELS);
253+
TF_CALL_complex128(REGISTER_GPU_KERNELS);
154254
#undef REGISTER_GPU_KERNEL
155255

156256
// A special GPU kernel for int32.
@@ -163,7 +263,14 @@ REGISTER_KERNEL_BUILDER(Name("Reverse")
163263
.HostMemory("dims")
164264
.HostMemory("output"),
165265
ReverseOp<CPUDevice, int32>);
166-
266+
REGISTER_KERNEL_BUILDER(Name("ReverseV2")
267+
.Device(DEVICE_GPU)
268+
.TypeConstraint<int32>("T")
269+
.TypeConstraint<int32>("Tidx")
270+
.HostMemory("tensor")
271+
.HostMemory("axis")
272+
.HostMemory("output"),
273+
ReverseV2Op<CPUDevice, int32>);
167274
#endif // GOOGLE_CUDA
168275

169276
} // namespace tensorflow

tensorflow/core/kernels/reverse_op.h

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -22,25 +22,20 @@ limitations under the License.
2222
namespace tensorflow {
2323
namespace functor {
2424

25-
// Functor used by MirrorOp to do the computations.
25+
// Functor used by ReverseOp to do the computations.
2626
template <typename Device, typename T, int Dims>
2727
struct Reverse {
2828
void operator()(const Device& d, typename TTypes<T, Dims>::ConstTensor input,
29-
typename TTypes<bool, 1>::ConstTensor dims,
29+
const Eigen::array<bool, Dims>& reverse_dims,
3030
typename TTypes<T, Dims>::Tensor output) {
31-
// mirror is in host memory
32-
Eigen::array<bool, Dims> reverse_dims;
33-
for (int i = 0; i < Dims; ++i) {
34-
reverse_dims[i] = dims(i);
35-
}
3631
output.device(d) = input.reverse(reverse_dims);
3732
}
3833
};
3934

4035
template <typename Device, typename T>
4136
struct Reverse<Device, T, 0> {
4237
void operator()(const Device& d, typename TTypes<T, 0>::ConstTensor input,
43-
typename TTypes<bool, 1>::ConstTensor,
38+
const Eigen::array<bool, 0>& reverse_dims,
4439
typename TTypes<T, 0>::Tensor output) {
4540
// Reversing a scalar is copying it.
4641
output.device(d) = input;

tensorflow/core/ops/array_ops.cc

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -791,6 +791,79 @@ dims: 1-D. The dimensions to reverse.
791791
output: The same shape as `tensor`.
792792
)Doc");
793793

794+
// --------------------------------------------------------------------------
795+
REGISTER_OP("ReverseV2")
796+
.Input("tensor: T")
797+
.Input("axis: Tidx")
798+
.Output("output: T")
799+
.Attr("Tidx: {int32, int64} = DT_INT32")
800+
.Attr(
801+
"T: {uint8, int8, int32, int64, bool, half, float, double, complex64, "
802+
"complex128}")
803+
.SetShapeFn([](InferenceContext* c) {
804+
ShapeHandle input = c->input(0);
805+
ShapeHandle axis;
806+
TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &axis));
807+
// TODO(aselle): if input(0)'s dimension is known we could validate axis
808+
if (c->Rank(input) > 8) {
809+
return errors::InvalidArgument(
810+
"reverse does not work on tensors with more than 8 dimensions");
811+
}
812+
c->set_output(0, input);
813+
return Status::OK();
814+
})
815+
.Doc(R"Doc(
816+
Reverses specific dimensions of a tensor.
817+
818+
Given a `tensor`, and a `int32` tensor `axis` representing the set of
819+
dimensions of `tensor` to reverse. This operation reverses each dimension
820+
`i` for which there exists `j` s.t. `axis[j] == i`.
821+
822+
`tensor` can have up to 8 dimensions. The number of dimensions specified
823+
in `axis` may be 0 or more entries. If an index is specified more than
824+
once, a InvalidArgument error is raised.
825+
826+
For example:
827+
828+
```prettyprint
829+
# tensor 't' is [[[[ 0, 1, 2, 3],
830+
# [ 4, 5, 6, 7],
831+
# [ 8, 9, 10, 11]],
832+
# [[12, 13, 14, 15],
833+
# [16, 17, 18, 19],
834+
# [20, 21, 22, 23]]]]
835+
# tensor 't' shape is [1, 2, 3, 4]
836+
837+
# 'dims' is [3] or 'dims' is -1
838+
reverse(t, dims) ==> [[[[ 3, 2, 1, 0],
839+
[ 7, 6, 5, 4],
840+
[ 11, 10, 9, 8]],
841+
[[15, 14, 13, 12],
842+
[19, 18, 17, 16],
843+
[23, 22, 21, 20]]]]
844+
845+
# 'dims' is '[1]' (or 'dims' is '[-3]')
846+
reverse(t, dims) ==> [[[[12, 13, 14, 15],
847+
[16, 17, 18, 19],
848+
[20, 21, 22, 23]
849+
[[ 0, 1, 2, 3],
850+
[ 4, 5, 6, 7],
851+
[ 8, 9, 10, 11]]]]
852+
853+
# 'dims' is '[2]' (or 'dims' is '[-2]')
854+
reverse(t, dims) ==> [[[[8, 9, 10, 11],
855+
[4, 5, 6, 7],
856+
[0, 1, 2, 3]]
857+
[[20, 21, 22, 23],
858+
[16, 17, 18, 19],
859+
[12, 13, 14, 15]]]]
860+
```
861+
862+
tensor: Up to 8-D.
863+
axis: 1-D. The indices of the dimensions to reverse.
864+
output: The same shape as `tensor`.
865+
)Doc");
866+
794867
// --------------------------------------------------------------------------
795868
REGISTER_OP("EditDistance")
796869
.Input("hypothesis_indices: int64")

tensorflow/core/ops/array_ops_test.cc

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -203,6 +203,18 @@ TEST(ArrayOpsTest, Reverse_ShapeFn) {
203203
INFER_OK(op, "[1,2,3,?,5,6,7,8];[8]", "in0");
204204
}
205205

206+
TEST(ArrayOpsTest, ReverseV2_ShapeFn) {
207+
ShapeInferenceTestOp op("ReverseV2");
208+
INFER_OK(op, "?;?", "in0");
209+
INFER_ERROR("Shape must be rank 1 but is rank 0", op, "?;[]");
210+
INFER_ERROR("Shape must be rank 1 but is rank 2", op, "?;[?,2]");
211+
INFER_OK(op, "[1,2,3];[2]", "in0");
212+
INFER_ERROR("reverse does not work on tensors with more than 8 dimensions",
213+
op, "[1,2,3,4,5,6,7,8,9];[9]");
214+
INFER_OK(op, "[1,2,3,?];[4]", "in0");
215+
INFER_OK(op, "[1,2,3,?,5,6,7,8];[8]", "in0");
216+
}
217+
206218
TEST(ArrayOpsTest, Fill_ShapeFn) {
207219
ShapeInferenceTestOp op("Fill");
208220
op.input_tensors.resize(2);

0 commit comments

Comments
 (0)