xref: /aosp_15_r20/external/tensorflow/tensorflow/core/kernels/spacetobatch_functor.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// Specialization of SpaceToBatchFunctor for a CPUDevice.
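//
// SpaceToBatch rearranges blocks of spatial data into the batch dimension,
// padding as requested; BatchToSpace (B2S == true) is the inverse. For
// example, with NUM_BLOCK_DIMS == 2, a [1, 4, 4, 1] space tensor,
// block_shape [2, 2], and zero padding produce a [4, 2, 2, 1] batch tensor:
// the batch dimension grows by prod(block_shape) and each spatial dimension
// shrinks by its block size.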

#define EIGEN_USE_THREADS

#include "tensorflow/core/kernels/spacetobatch_functor.h"

#include "tensorflow/core/framework/register_types.h"

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;

namespace functor {

namespace {

// Implementation of nested loops for SpaceToBatchFunctor.
//
// To simplify the template implementation given the lack of constexpr if,
// both the input and output pointers are non-const.
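//
// Each recursion level handles one block dimension: the template parameter N
// is the number of block dimensions still to be processed, and every call
// advances the shape/stride/offset pointers by one entry before recursing.
// The N == 0 specialization below performs the innermost contiguous copy.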
template <int N, bool B2S>
struct SpaceToBatchHelper {
  template <typename T>
  static void run(T* space_tensor_ptr, const int64_t* space_tensor_shape,
                  const int64_t* space_tensor_strides,
                  const int64_t* block_shape, const int64_t* pad_start,
                  const int64_t* block_offsets,
                  const int64_t* batch_tensor_shape,
                  const int64_t* batch_tensor_strides, T* batch_tensor_ptr) {
    for (int64_t batch_tensor_pos = 0; batch_tensor_pos < batch_tensor_shape[0];
         ++batch_tensor_pos) {
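      // Corresponding position along this block dimension in the (unpadded)
      // space tensor; values outside [0, space_tensor_shape[0]) fall in the
      // padding region.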
      const int64_t space_tensor_pos =
          batch_tensor_pos * block_shape[0] + block_offsets[0] - pad_start[0];
      if (space_tensor_pos >= 0 && space_tensor_pos < space_tensor_shape[0]) {
        SpaceToBatchHelper<N - 1, B2S>::run(
            space_tensor_ptr + space_tensor_pos * space_tensor_strides[0],
            space_tensor_shape + 1, space_tensor_strides + 1, block_shape + 1,
            pad_start + 1, block_offsets + 1, batch_tensor_shape + 1,
            batch_tensor_strides + 1, batch_tensor_ptr);
      } else {
        if (B2S == false) {
          // Copy in padding.
          for (int64_t i = 0; i < batch_tensor_strides[0]; ++i) {
            batch_tensor_ptr[i] = static_cast<T>(0);
          }
        }
      }
      batch_tensor_ptr += batch_tensor_strides[0];
    }
  }
};

template <bool B2S>
struct SpaceToBatchHelper<0, B2S> {
  template <typename T>
  static void run(T* space_tensor_ptr, const int64_t* space_tensor_shape,
                  const int64_t* space_tensor_strides,
                  const int64_t* block_shape, const int64_t* pad_start,
                  const int64_t* block_offsets,
                  const int64_t* batch_tensor_shape,
                  const int64_t* batch_tensor_strides, T* batch_tensor_ptr) {
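    // Base case: copy the innermost contiguous run. The strides pointers have
    // been advanced once per block dimension (starting from &strides[1] at the
    // top-level call), so strides[-1] here is the stride of the last block
    // dimension, i.e. the number of remaining (depth) elements to copy.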
    for (int64_t i = 0; i < batch_tensor_strides[-1]; ++i) {
      if (B2S == false) {
        batch_tensor_ptr[i] = space_tensor_ptr[i];
      } else {
        space_tensor_ptr[i] = batch_tensor_ptr[i];
      }
    }
  }
};

}  // namespace

template <typename T, int NUM_BLOCK_DIMS, bool B2S>
struct SpaceToBatchFunctor<CPUDevice, T, NUM_BLOCK_DIMS, B2S> {
  using SpaceT = typename std::conditional<B2S, T, const T>::type;
  using BatchT = typename std::conditional<B2S, const T, T>::type;
  Status operator()(
      const CPUDevice& d,
      typename TTypes<SpaceT, NUM_BLOCK_DIMS + 2>::Tensor space_tensor,
      const int64_t block_shape_tensor[NUM_BLOCK_DIMS],
      const int64_t paddings_tensor[NUM_BLOCK_DIMS * 2],
      typename TTypes<BatchT, NUM_BLOCK_DIMS + 2>::Tensor batch_tensor) {
    const int64_t batch_tensor_batch = batch_tensor.dimension(0);

    const int64_t space_tensor_batch = space_tensor.dimension(0);

    // Copy into local array so that the compiler is free to place in a
    // register.
    int64_t pad_start[NUM_BLOCK_DIMS];
    int64_t block_shape[NUM_BLOCK_DIMS];
    int64_t space_tensor_shape[NUM_BLOCK_DIMS],
        batch_tensor_shape[NUM_BLOCK_DIMS];
    for (int block_dim = 0; block_dim < NUM_BLOCK_DIMS; ++block_dim) {
      pad_start[block_dim] = paddings_tensor[block_dim * 2];
      block_shape[block_dim] = block_shape_tensor[block_dim];
      space_tensor_shape[block_dim] = space_tensor.dimension(block_dim + 1);
      batch_tensor_shape[block_dim] = batch_tensor.dimension(block_dim + 1);
    }

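    // Row-major element strides over all NUM_BLOCK_DIMS + 2 dimensions
    // (batch, block dimensions, depth); the depth dimension is contiguous,
    // so its stride is 1.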
    int64_t space_tensor_strides[NUM_BLOCK_DIMS + 2],
        batch_tensor_strides[NUM_BLOCK_DIMS + 2];
    space_tensor_strides[NUM_BLOCK_DIMS + 1] =
        batch_tensor_strides[NUM_BLOCK_DIMS + 1] = 1;
    for (int dim = NUM_BLOCK_DIMS; dim >= 0; --dim) {
      space_tensor_strides[dim] =
          space_tensor_strides[dim + 1] * space_tensor.dimension(dim + 1);
      batch_tensor_strides[dim] =
          batch_tensor_strides[dim + 1] * batch_tensor.dimension(dim + 1);
    }

    // Use non-const pointers for both input and output to simplify template
    // implementation given lack of constexpr if.
    T* space_tensor_ptr = const_cast<T*>(space_tensor.data());
    T* batch_tensor_ptr = const_cast<T*>(batch_tensor.data());

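    // Each batch tensor batch entry batch_tensor_b corresponds to space
    // tensor batch entry space_tensor_b = batch_tensor_b % space_tensor_batch
    // and block index block_index = batch_tensor_b / space_tensor_batch;
    // block_index is decoded below (last block dimension varying fastest)
    // into per-dimension block_offsets.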
    for (int64_t batch_tensor_b = 0; batch_tensor_b < batch_tensor_batch;
         ++batch_tensor_b) {
      const int64_t space_tensor_b = batch_tensor_b % space_tensor_batch;
      int64_t block_index = batch_tensor_b / space_tensor_batch;
      int64_t block_offsets[NUM_BLOCK_DIMS];
      for (int block_dim = NUM_BLOCK_DIMS - 1; block_dim >= 0; --block_dim) {
        // Skip unnecessary remainder operation for block_dim == 0.
        block_offsets[block_dim] =
            block_dim > 0 ? block_index % block_shape[block_dim] : block_index;
        block_index /= block_shape[block_dim];
      }

      // The compiler should inline the nested loops generated by this template.
      SpaceToBatchHelper<NUM_BLOCK_DIMS, B2S>::run(
          space_tensor_ptr + space_tensor_b * space_tensor_strides[0],
          space_tensor_shape, &space_tensor_strides[1], block_shape, pad_start,
          block_offsets, batch_tensor_shape, &batch_tensor_strides[1],
          batch_tensor_ptr + batch_tensor_b * batch_tensor_strides[0]);
    }
    return OkStatus();
  }
};
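
// Illustrative call shape only (a sketch, not the actual call site; the real
// callers are the SpaceToBatchND/BatchToSpaceND op kernels, which select
// NUM_BLOCK_DIMS at runtime):
//
//   functor::SpaceToBatchFunctor<CPUDevice, float, 2, /*B2S=*/false> s2b;
//   Status s = s2b(device, space.tensor<float, 4>(), block_shape, paddings,
//                  batch.tensor<float, 4>());
//
// Here `space`/`batch` stand for suitably shaped tensorflow::Tensor objects
// and `block_shape`/`paddings` for int64_t arrays of length 2 and 4.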

// Instantiate.
#define INSTANTIATE(NUM_BLOCK_DIMS, T)                                      \
  template struct SpaceToBatchFunctor<CPUDevice, T, NUM_BLOCK_DIMS, false>; \
  template struct SpaceToBatchFunctor<CPUDevice, T, NUM_BLOCK_DIMS, true>;  \
  /**/

#define INSTANTIATE_FOR_T(T) \
  TF_SPACETOBATCH_FOR_EACH_NUM_BLOCK_DIMS(INSTANTIATE, T)

TF_CALL_REAL_NUMBER_TYPES(INSTANTIATE_FOR_T)

#undef INSTANTIATE_FOR_T
#undef INSTANTIATE

}  // namespace functor
}  // end namespace tensorflow