@@ -24,6 +24,7 @@ limitations under the License.
2424#include " tensorflow/core/framework/tensor.h"
2525#include " tensorflow/core/framework/tensor_shape.h"
2626#include " tensorflow/core/framework/types.h"
27+ #include " tensorflow/core/kernels/bounds_check.h"
2728#include " tensorflow/core/lib/core/status.h"
2829#include " tensorflow/core/platform/logging.h"
2930
@@ -32,6 +33,19 @@ namespace tensorflow {
3233typedef Eigen::ThreadPoolDevice CPUDevice;
3334typedef Eigen::GpuDevice GPUDevice;
3435
36+ template <typename Device, typename T, int NDIMS >
37+ void HandleReverseCase (OpKernelContext* context,
38+ typename TTypes<bool , 1 >::ConstTensor dims,
39+ Tensor* result) {
40+ typename Eigen::array<bool , NDIMS > axes_di;
41+ for (int i = 0 ; i < NDIMS ; i++) {
42+ axes_di[i] = dims (i);
43+ }
44+ functor::Reverse<Device, T, NDIMS >()(context->eigen_device <Device>(),
45+ context->input (0 ).tensor <T, NDIMS >(),
46+ axes_di, result->tensor <T, NDIMS >());
47+ }
48+
3549template <typename Device, typename T>
3650class ReverseOp : public OpKernel {
3751 public:
@@ -67,11 +81,9 @@ class ReverseOp : public OpKernel {
6781 OP_REQUIRES_OK (context,
6882 context->allocate_output (0 , input.shape (), &output));
6983
70- #define HANDLE_REVERSE (NDIMS ) \
71- case NDIMS : \
72- functor::Reverse<Device, T, NDIMS >()( \
73- context->eigen_device <Device>(), input.tensor <T, NDIMS >(), \
74- dims.vec <bool >(), output->tensor <T, NDIMS >()); \
84+ #define HANDLE_REVERSE (NDIMS ) \
85+ case NDIMS : \
86+ HandleReverseCase<Device, T, NDIMS >(context, dims.vec <bool >(), output); \
7587 return ;
7688
7789 switch (input_dims) {
@@ -90,15 +102,97 @@ class ReverseOp : public OpKernel {
90102 }
91103};
92104
93- #define REGISTER_KERNEL (T ) \
94- REGISTER_KERNEL_BUILDER (Name(" Reverse" ) \
95- .Device(DEVICE_CPU ) \
96- .TypeConstraint<T>(" T" ) \
97- .HostMemory(" dims" ), \
98- ReverseOp<CPUDevice, T>)
105+ template <typename Device, typename T, int NDIMS >
106+ void HandleReverseV2Case (OpKernelContext* context,
107+ const gtl::ArraySlice<bool >& axes, Tensor* result) {
108+ typename Eigen::array<bool , NDIMS > axes_di;
109+ for (int i = 0 ; i < NDIMS ; i++) {
110+ axes_di[i] = axes[i];
111+ }
112+ functor::Reverse<Device, T, NDIMS >()(context->eigen_device <Device>(),
113+ context->input (0 ).tensor <T, NDIMS >(),
114+ axes_di, result->tensor <T, NDIMS >());
115+ }
116+
117+ template <typename Device, typename T>
118+ class ReverseV2Op : public OpKernel {
119+ public:
120+ explicit ReverseV2Op (OpKernelConstruction* context) : OpKernel(context) {}
121+
122+ void Compute (OpKernelContext* context) override {
123+ const Tensor& input = context->input (0 );
124+ const Tensor& sparse_dims = context->input (1 );
125+
126+ if (TensorShapeUtils::IsScalar (input.shape ())) {
127+ Tensor* output = nullptr ;
128+ OP_REQUIRES_OK (context,
129+ context->allocate_output (0 , input.shape (), &output));
130+ output->scalar <T>() = input.scalar <T>();
131+ } else {
132+ const int input_dims = input.dims ();
133+ const TensorShape& sparse_dims_shape = sparse_dims.shape ();
134+ const auto & axes_sparse_flat = sparse_dims.flat <int32>();
135+
136+ OP_REQUIRES (context, TensorShapeUtils::IsVector (sparse_dims_shape),
137+ errors::InvalidArgument (" 'dims' must be 1-dimension, not " ,
138+ sparse_dims.dims ()));
139+ gtl::InlinedVector<bool , 8 > axes_dense (input_dims, false );
140+ for (int dummy = 0 ; dummy < axes_sparse_flat.size (); dummy++) {
141+ int32 axis = internal::SubtleMustCopy<int32>(axes_sparse_flat (dummy));
142+ int32 canonical_axis = axis < 0 ? input_dims + axis : axis;
143+ OP_REQUIRES (context, canonical_axis >= 0 && canonical_axis < input_dims,
144+ errors::InvalidArgument (" 'axis'[" , dummy, " ] = " , axis,
145+ " is out of valid range [" , 0 , " , " ,
146+ input_dims - 1 ));
147+ OP_REQUIRES (context, !axes_dense[canonical_axis],
148+ errors::InvalidArgument (" axis " , canonical_axis,
149+ " specified more than once." ));
150+ axes_dense[canonical_axis] = true ;
151+ }
99152
100- TF_CALL_POD_TYPES (REGISTER_KERNEL );
101- #undef REGISTER_KERNEL
153+ OP_REQUIRES (context, input_dims <= 8 ,
154+ errors::Unimplemented (
155+ " reverse is not implemented for tensors of rank > 8." ));
156+
157+ Tensor* output = nullptr ;
158+ OP_REQUIRES_OK (context,
159+ context->allocate_output (0 , input.shape (), &output));
160+
161+ #define HANDLE_REVERSE (NDIMS ) \
162+ case NDIMS : \
163+ HandleReverseV2Case<Device, T, NDIMS >(context, axes_dense, output); \
164+ return ;
165+
166+ switch (input_dims) {
167+ HANDLE_REVERSE (0 );
168+ HANDLE_REVERSE (1 );
169+ HANDLE_REVERSE (2 );
170+ HANDLE_REVERSE (3 );
171+ HANDLE_REVERSE (4 );
172+ HANDLE_REVERSE (5 );
173+ HANDLE_REVERSE (6 );
174+ HANDLE_REVERSE (7 );
175+ HANDLE_REVERSE (8 );
176+ }
177+ #undef HANDLE_REVERSE
178+ }
179+ }
180+ };
181+
// CPU registrations. REGISTER_KERNELS(T) registers two kernels on DEVICE_CPU
// for element type T: ReverseOp<CPUDevice, T> for the Reverse op (its dims
// argument is marked HostMemory) and ReverseV2Op<CPUDevice, T> for the
// ReverseV2 op (its axis argument is marked HostMemory and Tidx is constrained
// to int32). The macro is applied through TF_CALL_POD_TYPES and then
// undefined.
182+ #define REGISTER_KERNELS (T ) \
183+ REGISTER_KERNEL_BUILDER (Name(" Reverse" ) \
184+ .Device(DEVICE_CPU ) \
185+ .TypeConstraint<T>(" T" ) \
186+ .HostMemory(" dims" ), \
187+ ReverseOp<CPUDevice, T>) \
188+ REGISTER_KERNEL_BUILDER (Name(" ReverseV2" ) \
189+ .Device(DEVICE_CPU ) \
190+ .TypeConstraint<T>(" T" ) \
191+ .TypeConstraint<int32>(" Tidx" ) \
192+ .HostMemory(" axis" ), \
193+ ReverseV2Op<CPUDevice, T>)
194+ TF_CALL_POD_TYPES (REGISTER_KERNELS );
195+ #undef REGISTER_KERNELS
102196
103197#if GOOGLE_CUDA
104198
@@ -109,7 +203,7 @@ namespace functor {
109203 template <> \
110204 void Reverse<GPUDevice, T, DIM >::operator ()( \
111205 const GPUDevice& d, typename TTypes<T, DIM >::ConstTensor input, \
112- typename TTypes <bool , 1 >::ConstTensor dims, \
206+ const Eigen::array <bool , DIM >& reverse_dims, \
113207 typename TTypes<T, DIM >::Tensor output); \
114208 extern template struct Reverse <GPUDevice, T, DIM >;
115209#define DECLARE_GPU_SPEC (T ) \
@@ -136,21 +230,27 @@ TF_CALL_complex128(DECLARE_GPU_SPEC);
136230} // namespace functor
137231
138232// Registration of the GPU implementations.
139- #define REGISTER_GPU_KERNEL (T ) \
140- REGISTER_KERNEL_BUILDER (Name(" Reverse" ) \
141- .Device(DEVICE_GPU ) \
142- .TypeConstraint<T>(" T" ) \
143- .HostMemory(" dims" ), \
144- ReverseOp<GPUDevice, T>)
145- TF_CALL_uint8(REGISTER_GPU_KERNEL );
146- TF_CALL_int8 (REGISTER_GPU_KERNEL );
233+ #define REGISTER_GPU_KERNELS (T ) \
234+ REGISTER_KERNEL_BUILDER (Name(" Reverse" ) \
235+ .Device(DEVICE_GPU ) \
236+ .TypeConstraint<T>(" T" ) \
237+ .HostMemory(" dims" ), \
238+ ReverseOp<GPUDevice, T>) \
239+ REGISTER_KERNEL_BUILDER(Name(" ReverseV2" ) \
240+ .Device(DEVICE_GPU ) \
241+ .TypeConstraint<T>(" T" ) \
242+ .TypeConstraint<int32>(" Tidx" ) \
243+ .HostMemory(" axis" ), \
244+ ReverseV2Op<GPUDevice, T>)
245+ TF_CALL_uint8(REGISTER_GPU_KERNELS );
246+ TF_CALL_int8 (REGISTER_GPU_KERNELS );
147247// TODO decide whether we want to enable the bool kernel.
148- // TF_CALL_bool(REGISTER_GPU_KERNEL );
149- TF_CALL_half (REGISTER_GPU_KERNEL );
150- TF_CALL_float (REGISTER_GPU_KERNEL );
151- TF_CALL_double (REGISTER_GPU_KERNEL );
152- TF_CALL_complex64 (REGISTER_GPU_KERNEL );
153- TF_CALL_complex128 (REGISTER_GPU_KERNEL );
248+ // TF_CALL_bool(REGISTER_GPU_KERNELS );
249+ TF_CALL_half (REGISTER_GPU_KERNELS );
250+ TF_CALL_float (REGISTER_GPU_KERNELS );
251+ TF_CALL_double (REGISTER_GPU_KERNELS );
252+ TF_CALL_complex64 (REGISTER_GPU_KERNELS );
253+ TF_CALL_complex128 (REGISTER_GPU_KERNELS );
154254#undef REGISTER_GPU_KERNEL
155255
156256// A special GPU kernel for int32.
@@ -163,7 +263,14 @@ REGISTER_KERNEL_BUILDER(Name("Reverse")
163263 .HostMemory(" dims" )
164264 .HostMemory(" output" ),
165265 ReverseOp<CPUDevice, int32>);
166-
// ReverseV2 registration for int32 on DEVICE_GPU. The kernel class is the
// CPUDevice instantiation (ReverseV2Op<CPUDevice, int32>), and the tensor,
// axis and output arguments are all marked HostMemory.
266+ REGISTER_KERNEL_BUILDER (Name(" ReverseV2" )
267+ .Device(DEVICE_GPU )
268+ .TypeConstraint<int32>(" T" )
269+ .TypeConstraint<int32>(" Tidx" )
270+ .HostMemory(" tensor" )
271+ .HostMemory(" axis" )
272+ .HostMemory(" output" ),
273+ ReverseV2Op<CPUDevice, int32>);
167274#endif // GOOGLE_CUDA
168275
169276} // namespace tensorflow
0 commit comments