/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_KERNELS_STRIDED_SLICE_OP_IMPL_H_
#define TENSORFLOW_CORE_KERNELS_STRIDED_SLICE_OP_IMPL_H_

// Functor definition for StridedSliceOp, must be compilable by nvcc.

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/bounds_check.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/register_types_traits.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/variant.h"
#include "tensorflow/core/framework/variant_encode_decode.h"
#include "tensorflow/core/kernels/dense_update_functor.h"
#include "tensorflow/core/kernels/ops_util.h"
#include "tensorflow/core/kernels/slice_op.h"
#include "tensorflow/core/kernels/strided_slice_op.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/mem.h"

namespace tensorflow {

template <typename Device, typename T, int NDIM>
void HandleStridedSliceCase(OpKernelContext* context,
                            const gtl::ArraySlice<int64_t>& begin,
                            const gtl::ArraySlice<int64_t>& end,
                            const gtl::ArraySlice<int64_t>& strides,
                            const TensorShape& processing_shape,
                            bool is_simple_slice, Tensor* result);

template <typename Device, typename T, int NDIM>
void HandleStridedSliceGradCase(OpKernelContext* context,
                                const gtl::ArraySlice<int64_t>& begin,
                                const gtl::ArraySlice<int64_t>& end,
                                const gtl::ArraySlice<int64_t>& strides,
                                const TensorShape& processing_shape,
                                bool is_simple_slice, Tensor* result);

template <typename Device, typename T, int NDIM>
class HandleStridedSliceAssignCase {
 public:
  void operator()(OpKernelContext* context,
                  const gtl::ArraySlice<int64_t>& begin,
                  const gtl::ArraySlice<int64_t>& end,
                  const gtl::ArraySlice<int64_t>& strides,
                  const StridedSliceAssignBCast& bcast, Tensor* result);
};
}  // namespace tensorflow

// The actual implementation. This is designed so multiple
// translation units can include this file in the form
//
// #define STRIDED_SLICE_INSTANTIATE_DIM 1
// #include <thisfile>
// #undef STRIDED_SLICE_INSTANTIATE_DIM
//
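// As an illustrative sketch only (the file name below is hypothetical, not a
// reference to an actual build target), an including translation unit that
// instantiates the 1-D handlers could look like:
//
//   // strided_slice_inst_dim1.cc -- hypothetical example
//   #define STRIDED_SLICE_INSTANTIATE_DIM 1
//   #include "tensorflow/core/kernels/strided_slice_op_impl.h"
//   #undef STRIDED_SLICE_INSTANTIATE_DIM
//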
#ifdef STRIDED_SLICE_INSTANTIATE_DIM

namespace tensorflow {

template <typename Device, typename T, int NDIM>
void HandleStridedSliceCase(OpKernelContext* context,
                            const gtl::ArraySlice<int64_t>& begin,
                            const gtl::ArraySlice<int64_t>& end,
                            const gtl::ArraySlice<int64_t>& strides,
                            const TensorShape& processing_shape,
                            bool is_simple_slice, Tensor* result) {
  typedef typename proxy_type<Device, T>::type Proxy;

  gtl::InlinedVector<int64_t, 4> processing_dims = processing_shape.dim_sizes();
  if (is_simple_slice) {
    Eigen::DSizes<Eigen::DenseIndex, NDIM> begin_di;
    Eigen::DSizes<Eigen::DenseIndex, NDIM> sizes_di;
    for (int i = 0; i < NDIM; ++i) {
      begin_di[i] = begin[i];
      sizes_di[i] = end[i] - begin[i];
    }
    functor::Slice<Device, Proxy, NDIM>()(
        context->eigen_device<Device>(),
        result->bit_casted_shaped<Proxy, NDIM>(processing_dims),
        context->input(0).bit_casted_tensor<Proxy, NDIM>(), begin_di, sizes_di);
  } else {
    Eigen::DSizes<Eigen::DenseIndex, NDIM> begin_di;
    Eigen::DSizes<Eigen::DenseIndex, NDIM> end_di;
    Eigen::DSizes<Eigen::DenseIndex, NDIM> strides_di;
    for (int i = 0; i < NDIM; ++i) {
      begin_di[i] = begin[i];
      end_di[i] = end[i];
      strides_di[i] = strides[i];
    }
    functor::StridedSlice<Device, Proxy, NDIM>()(
        context->eigen_device<Device>(),
        result->bit_casted_shaped<Proxy, NDIM>(processing_dims),
        context->input(0).bit_casted_tensor<Proxy, NDIM>(), begin_di, end_di,
        strides_di);
  }
}

template <typename Device, typename T, int NDIM>
void HandleStridedSliceGradCase(OpKernelContext* context,
                                const gtl::ArraySlice<int64_t>& begin,
                                const gtl::ArraySlice<int64_t>& end,
                                const gtl::ArraySlice<int64_t>& strides,
                                const TensorShape& processing_shape,
                                bool is_simple_slice, Tensor* result) {
  gtl::InlinedVector<int64_t, 4> processing_dims = processing_shape.dim_sizes();

  Eigen::DSizes<Eigen::DenseIndex, NDIM> begin_di;
  Eigen::DSizes<Eigen::DenseIndex, NDIM> end_di;
  Eigen::DSizes<Eigen::DenseIndex, NDIM> strides_di;
  for (int i = 0; i < NDIM; ++i) {
    begin_di[i] = begin[i];
    end_di[i] = end[i];
    strides_di[i] = strides[i];
  }

  typedef typename proxy_type<Device, T>::type Proxy;
  functor::StridedSliceGrad<Device, Proxy, NDIM>()(
      context->eigen_device<Device>(), result->bit_casted_tensor<Proxy, NDIM>(),
      context->input(4).bit_casted_shaped<Proxy, NDIM>(processing_dims),
      begin_di, end_di, strides_di);
}

template <typename Device, typename T, int NDIM>
void HandleStridedSliceAssignCase<Device, T, NDIM>::operator()(
    OpKernelContext* context, const gtl::ArraySlice<int64_t>& begin,
    const gtl::ArraySlice<int64_t>& end,
    const gtl::ArraySlice<int64_t>& strides,
    const StridedSliceAssignBCast& bcast, Tensor* result) {
  typedef typename proxy_type<Device, T>::type Proxy;
  Eigen::DSizes<Eigen::DenseIndex, NDIM> begin_di;
  Eigen::DSizes<Eigen::DenseIndex, NDIM> end_di;
  Eigen::DSizes<Eigen::DenseIndex, NDIM> strides_di;
  for (int i = 0; i < NDIM; ++i) {
    begin_di[i] = begin[i];
    end_di[i] = end[i];
    strides_di[i] = strides[i];
  }

  constexpr int kRhsInput = 4;
  const Tensor& input = context->input(kRhsInput);
  functor::StridedSliceAssign<Device, Proxy, NDIM>()(
      context->eigen_device<Device>(), result->bit_casted_tensor<Proxy, NDIM>(),
      input.bit_casted_shaped<Proxy, NDIM>(bcast.reshape()), begin_di, end_di,
      strides_di, bcast);
}

template <typename Device, typename T>
class HandleStridedSliceAssignCase<Device, T, 0> {
 public:
  enum { NDIM_PROXY = 1 };
  void operator()(OpKernelContext* context,
                  const gtl::ArraySlice<int64_t>& begin,
                  const gtl::ArraySlice<int64_t>& end,
                  const gtl::ArraySlice<int64_t>& strides,
                  const StridedSliceAssignBCast& bcast, Tensor* result) {
    gtl::InlinedVector<int64_t, 1> processing_dims(1);
    processing_dims[0] = 1;

    typedef typename proxy_type<Device, T>::type Proxy;
    functor::StridedSliceAssignScalar<Device, Proxy>()(
        context->eigen_device<Device>(),
        result->bit_casted_shaped<Proxy, 1>(processing_dims),
        context->input(4).bit_casted_shaped<Proxy, 1>(processing_dims));
  }
};

// NOTE(aselle): according to bsteiner, we need this because otherwise
// nvcc instantiates templates that are invalid. strided_slice_op_gpu.cu
// handles the instantiation externally. It is important that this is done
// before the HandleXXCase templates are instantiated, to avoid duplicate
// specialization errors.
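//
// A minimal sketch of the mechanism (the <float, 3> specialization below is
// just an illustrative choice of arguments): a declaration such as
//
//   extern template struct functor::StridedSlice<GPUDevice, float, 3>;
//
// promises the compiler that this specialization is explicitly instantiated
// in some other translation unit, so it is not implicitly instantiated here.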

#define PREVENT_INSTANTIATE_DIM1_AND_UP(T, NDIM)                   \
  namespace functor {                                              \
  template <>                                                      \
  void StridedSlice<GPUDevice, T, NDIM>::operator()(               \
      const GPUDevice& d, typename TTypes<T, NDIM>::Tensor output, \
      typename TTypes<T, NDIM>::ConstTensor input,                 \
      const Eigen::DSizes<Eigen::DenseIndex, NDIM>& start,         \
      const Eigen::DSizes<Eigen::DenseIndex, NDIM>& stop,          \
      const Eigen::DSizes<Eigen::DenseIndex, NDIM>& strides);      \
  extern template struct StridedSlice<GPUDevice, T, NDIM>;         \
  template <>                                                      \
  void Slice<GPUDevice, T, NDIM>::operator()(                      \
      const GPUDevice& d, typename TTypes<T, NDIM>::Tensor output, \
      typename TTypes<T, NDIM>::ConstTensor input,                 \
      const Eigen::DSizes<Eigen::DenseIndex, NDIM>& indices,       \
      const Eigen::DSizes<Eigen::DenseIndex, NDIM>& sizes);        \
  extern template struct Slice<GPUDevice, T, NDIM>;                \
  template <>                                                      \
  void StridedSliceGrad<GPUDevice, T, NDIM>::operator()(           \
      const GPUDevice& d, typename TTypes<T, NDIM>::Tensor output, \
      typename TTypes<T, NDIM>::ConstTensor input,                 \
      const Eigen::DSizes<Eigen::DenseIndex, NDIM>& start,         \
      const Eigen::DSizes<Eigen::DenseIndex, NDIM>& stop,          \
      const Eigen::DSizes<Eigen::DenseIndex, NDIM>& strides);      \
  extern template struct StridedSliceGrad<GPUDevice, T, NDIM>;     \
  template <>                                                      \
  void StridedSliceAssign<GPUDevice, T, NDIM>::operator()(         \
      const GPUDevice& d, typename TTypes<T, NDIM>::Tensor output, \
      typename TTypes<T, NDIM>::ConstTensor input,                 \
      const Eigen::DSizes<Eigen::DenseIndex, NDIM>& start,         \
      const Eigen::DSizes<Eigen::DenseIndex, NDIM>& stop,          \
      const Eigen::DSizes<Eigen::DenseIndex, NDIM>& strides,       \
      const StridedSliceAssignBCast& bcast);                       \
  extern template struct StridedSliceAssign<GPUDevice, T, NDIM>;   \
  }  // namespace functor
#define PREVENT_INSTANTIATE_DIM0_ONLY(T, NDIM)                   \
  namespace functor {                                            \
  template <>                                                    \
  void StridedSliceAssignScalar<GPUDevice, T>::operator()(       \
      const GPUDevice& d, typename TTypes<T, 1>::Tensor output,  \
      typename TTypes<T, 1>::ConstTensor input);                 \
  extern template struct StridedSliceAssignScalar<GPUDevice, T>; \
  }  // namespace functor

// Dimension 0 only instantiates some functors, so we only need to prevent
// the ones defined by PREVENT_INSTANTIATE_DIM0_ONLY.
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#if STRIDED_SLICE_INSTANTIATE_DIM == 0
#define PREVENT_INSTANTIATE(T, NDIM) PREVENT_INSTANTIATE_DIM0_ONLY(T, NDIM)
#else
#define PREVENT_INSTANTIATE(T, NDIM) PREVENT_INSTANTIATE_DIM1_AND_UP(T, NDIM)
#endif
#else
#define PREVENT_INSTANTIATE(T, NDIM)
#endif

#define INSTANTIATE_DIM1_AND_UP_HANDLERS(DEVICE, T, DIM)                \
  template void HandleStridedSliceCase<DEVICE, T, DIM>(                 \
      OpKernelContext * context, const gtl::ArraySlice<int64_t>& begin, \
      const gtl::ArraySlice<int64_t>& end,                              \
      const gtl::ArraySlice<int64_t>& strides,                          \
      const TensorShape& processing_shape, bool is_simple_slice,        \
      Tensor* result);                                                  \
  template void HandleStridedSliceGradCase<DEVICE, T, DIM>(             \
      OpKernelContext * context, const gtl::ArraySlice<int64_t>& begin, \
      const gtl::ArraySlice<int64_t>& end,                              \
      const gtl::ArraySlice<int64_t>& strides,                          \
      const TensorShape& processing_shape, bool is_simple_slice,        \
      Tensor* result);

#define INSTANTIATE_DIM0_AND_UP_HANDLERS(DEVICE, T, DIM) \
  template class HandleStridedSliceAssignCase<DEVICE, T, DIM>;

// Only some kernels need to be instantiated on dim 0.
#if STRIDED_SLICE_INSTANTIATE_DIM == 0
#define INSTANTIATE(DEVICE, T, DIM) \
  INSTANTIATE_DIM0_AND_UP_HANDLERS(DEVICE, T, DIM)
#else
#define INSTANTIATE(DEVICE, T, DIM)                \
  INSTANTIATE_DIM0_AND_UP_HANDLERS(DEVICE, T, DIM) \
  INSTANTIATE_DIM1_AND_UP_HANDLERS(DEVICE, T, DIM)
#endif
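
// As an illustrative expansion only (assuming STRIDED_SLICE_INSTANTIATE_DIM
// is 2 and T is float), INSTANTIATE(CPUDevice, float, 2) emits the explicit
// instantiations
//
//   template class HandleStridedSliceAssignCase<CPUDevice, float, 2>;
//   template void HandleStridedSliceCase<CPUDevice, float, 2>(...);
//   template void HandleStridedSliceGradCase<CPUDevice, float, 2>(...);
//
// with the full parameter lists spelled out as in the macros above.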

#define DECLARE_FOR_N_CPU(T) \
  INSTANTIATE(CPUDevice, T, STRIDED_SLICE_INSTANTIATE_DIM)

#define PREVENT_FOR_N_GPU(T) \
  PREVENT_INSTANTIATE(T, STRIDED_SLICE_INSTANTIATE_DIM)

#define DECLARE_FOR_N_GPU(T) \
  INSTANTIATE(GPUDevice, T, STRIDED_SLICE_INSTANTIATE_DIM)

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
TF_CALL_GPU_PROXY_TYPES(PREVENT_FOR_N_GPU);
TF_CALL_COMPLEX_TYPES(PREVENT_FOR_N_GPU);

TF_CALL_uint8(DECLARE_FOR_N_GPU);
TF_CALL_int8(DECLARE_FOR_N_GPU);
TF_CALL_int32(DECLARE_FOR_N_GPU);
TF_CALL_int64(DECLARE_FOR_N_GPU);
TF_CALL_uint32(DECLARE_FOR_N_GPU);
TF_CALL_GPU_ALL_TYPES(DECLARE_FOR_N_GPU);
#endif  // END GOOGLE_CUDA || TENSORFLOW_USE_ROCM

TF_CALL_ALL_TYPES(DECLARE_FOR_N_CPU);
TF_CALL_QUANTIZED_TYPES(DECLARE_FOR_N_CPU);

#undef INSTANTIATE
#undef DECLARE_FOR_N_CPU
#undef DECLARE_FOR_N_GPU

}  // end namespace tensorflow

#endif  // END STRIDED_SLICE_INSTANTIATE_DIM
#endif  // TENSORFLOW_CORE_KERNELS_STRIDED_SLICE_OP_IMPL_H_