/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// This file implements TensorFlow Lite's broadcast gradient argument
// (BroadcastGradientArgs) custom operator.
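//
// Given the shapes of the two operands of a broadcasting binary op, the
// operator returns, for each operand, the axes of the broadcast result that
// must be summed over to reduce the result's gradient back to that operand's
// shape. Illustrative example (values chosen here for exposition): for input
// shapes s0 = [2, 3, 5] and s1 = [1, 5], the broadcast result has shape
// [2, 3, 5], so the reduction axes are r0 = [] and r1 = [0, 1].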

#include <algorithm>
#include <array>
#include <cmath>
#include <cstdint>
#include <vector>

#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/internal/common.h"
#include "tensorflow/lite/kernels/internal/tensor.h"
#include "tensorflow/lite/kernels/internal/types.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/kernels/padding.h"

namespace tflite {
namespace ops {
namespace custom {
namespace {

// Inputs: 1-D tensors holding the shape of each broadcast operand.
static const int kInputOneTensor = 0;
static const int kInputTwoTensor = 1;
// Outputs: 1-D tensors holding the gradient-reduction axes for each operand.
static const int kOutputOneTensor = 0;
static const int kOutputTwoTensor = 1;

TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
  // Check inputs and outputs.
  TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
  const TfLiteTensor* input1 = GetInput(context, node, kInputOneTensor);
  TF_LITE_ENSURE(context, input1 != nullptr);
  const RuntimeShape input1_shape = GetTensorShape(input1);
  TF_LITE_ENSURE(context,
                 input1->type == kTfLiteInt32 || input1->type == kTfLiteInt64);
  TF_LITE_ENSURE_EQ(context, input1_shape.DimensionsCount(), 1);

  const TfLiteTensor* input2 = GetInput(context, node, kInputTwoTensor);
  TF_LITE_ENSURE(context, input2 != nullptr);
  const RuntimeShape input2_shape = GetTensorShape(input2);
  TF_LITE_ENSURE_TYPES_EQ(context, input2->type, input1->type);
  TF_LITE_ENSURE_EQ(context, input2_shape.DimensionsCount(), 1);

  TF_LITE_ENSURE_EQ(context, NumOutputs(node), 2);
  TfLiteTensor* output1 = GetOutput(context, node, kOutputOneTensor);
  TF_LITE_ENSURE(context, output1 != nullptr);
  TF_LITE_ENSURE_TYPES_EQ(context, output1->type, input1->type);
  TfLiteTensor* output2 = GetOutput(context, node, kOutputTwoTensor);
  TF_LITE_ENSURE(context, output2 != nullptr);
  TF_LITE_ENSURE_TYPES_EQ(context, output2->type, input1->type);
  // The output lengths depend on the runtime values of the input shape
  // tensors, so the outputs are resized in Invoke.
  SetTensorToDynamic(output1);
  SetTensorToDynamic(output2);
  return kTfLiteOk;
}

TfLiteStatus Invoke(TfLiteContext* context, TfLiteNode* node) {
  const TfLiteTensor* input1 = GetInput(context, node, kInputOneTensor);
  TF_LITE_ENSURE(context, input1 != nullptr);
  const RuntimeShape input1_shape = GetTensorShape(input1);

  const TfLiteTensor* input2 = GetInput(context, node, kInputTwoTensor);
  TF_LITE_ENSURE(context, input2 != nullptr);
  const RuntimeShape input2_shape = GetTensorShape(input2);

  TfLiteTensor* output1 = GetOutput(context, node, kOutputOneTensor);
  TF_LITE_ENSURE(context, output1 != nullptr);
  TfLiteTensor* output2 = GetOutput(context, node, kOutputTwoTensor);
  TF_LITE_ENSURE(context, output2 != nullptr);

  // Copy both shape tensors into int64 vectors so the remaining logic is
  // independent of the input type (int32 or int64).
  std::vector<int64_t> input1_vec;
  std::vector<int64_t> input2_vec;
  if (input1->type == kTfLiteInt32) {
    input1_vec = std::vector<int64_t>(input1->data.i32,
                                      input1->data.i32 + input1_shape.Dims(0));
  } else {
    input1_vec = std::vector<int64_t>(input1->data.i64,
                                      input1->data.i64 + input1_shape.Dims(0));
  }
  if (input2->type == kTfLiteInt32) {
    input2_vec = std::vector<int64_t>(input2->data.i32,
                                      input2->data.i32 + input2_shape.Dims(0));
  } else {
    input2_vec = std::vector<int64_t>(input2->data.i64,
                                      input2->data.i64 + input2_shape.Dims(0));
  }

  if (input1_vec == input2_vec) {
    // The shapes are identical: neither gradient needs reduction, so both
    // outputs are empty.
    TfLiteIntArray* output1_shape = TfLiteIntArrayCreate(1);
    output1_shape->data[0] = 0;
    TF_LITE_ENSURE_OK(context,
                      context->ResizeTensor(context, output1, output1_shape));

    TfLiteIntArray* output2_shape = TfLiteIntArrayCreate(1);
    output2_shape->data[0] = 0;
    TF_LITE_ENSURE_OK(context,
                      context->ResizeTensor(context, output2, output2_shape));
    return kTfLiteOk;
  }

  size_t largest_rank = std::max(input1_vec.size(), input2_vec.size());

  // Reverse all the shapes for convenience.
  // After the reversal, index 0 is the innermost dimension.
  std::vector<int64_t> copy[2];
  copy[0] = std::vector<int64_t>(input1_vec.rbegin(), input1_vec.rend());
  copy[1] = std::vector<int64_t>(input2_vec.rbegin(), input2_vec.rend());

  // 1-extend and align all vectors.
  for (int i = 0; i < 2; ++i) {
    if (copy[i].size() < largest_rank) {
      copy[i].resize(largest_rank, 1);
    }
  }
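  // For example (illustrative values): with shape inputs [2, 3, 5] and
  // [1, 5], the reversed and 1-extended copies are copy[0] = [5, 3, 2] and
  // copy[1] = [5, 1, 1].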
  // Walk the dimensions from the innermost outwards and compare the
  // dimensions of the two inputs. A pair of dimensions is compatible if they
  // are equal or if either of them is 1.

  // Whether the j-th (reversed) dimension of each input is 1, for the current
  // and the previous iteration.
  std::array<bool, 2> prev_is_one = {false, false};
  std::array<bool, 2> current_is_one = {false, false};
  bool set_one = false;
  // Gradient-reduction axes collected for each input.
  std::vector<int64_t> grad_reduce_idx[2];

  for (int j = 0; j < largest_rank; ++j) {
    int output_dim = -1;
    bool output_dim_set = false;
    bool none_is_one = true;
    // Find which indices are 1.
    for (int i = 0; i < 2; ++i) {
      // Keep track of which indices are 1.
      if (copy[i][j] == 1) {
        current_is_one[i] = true;
        none_is_one = false;
      } else {
        current_is_one[i] = false;
        if (!output_dim_set || copy[i][j] == output_dim) {
          output_dim = copy[i][j];
          output_dim_set = true;
        } else {
          // Not broadcastable shapes.
          return kTfLiteError;
        }
      }
    }
    // All dimensions are 1.
    if (!output_dim_set) {
      for (int i = 0; i < 2; ++i) {
        grad_reduce_idx[i].push_back(largest_rank - 1 - j);
      }
      continue;
    } else if (current_is_one == prev_is_one && set_one) {
      // It is a run of the same broadcasting case as last time.
      // We can reshape the input so that fewer dimensions
      // are involved in the intermediate computation.
      for (int i = 0; i < 2; ++i) {
        if (current_is_one[i] && !none_is_one) {
          grad_reduce_idx[i].push_back(largest_rank - 1 - j);
        }
      }
    } else {
      for (int i = 0; i < 2; ++i) {
        if (current_is_one[i] && !none_is_one) {
          grad_reduce_idx[i].push_back(largest_rank - 1 - j);
        }
      }
    }
    set_one = true;
    for (int i = 0; i < 2; ++i) {
      prev_is_one[i] = current_is_one[i];
    }
  }
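  // Continuing the illustrative example above ([2, 3, 5] and [1, 5]), the
  // loop collects grad_reduce_idx[0] = [] and grad_reduce_idx[1] = [1, 0]
  // (innermost-first) at this point.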
  // The axes were collected from the innermost dimension outwards; reverse
  // them so each output lists its reduction axes in increasing order.
  for (int i = 0; i < 2; ++i) {
    std::reverse(grad_reduce_idx[i].begin(), grad_reduce_idx[i].end());
  }

  // Resize each output to the number of collected axes and copy them out in
  // the requested integer type.
  TfLiteIntArray* output1_shape = TfLiteIntArrayCreate(1);
  output1_shape->data[0] = grad_reduce_idx[0].size();
  TF_LITE_ENSURE_OK(context,
                    context->ResizeTensor(context, output1, output1_shape));
  if (output1->type == kTfLiteInt32) {
    for (int i = 0; i < grad_reduce_idx[0].size(); ++i) {
      output1->data.i32[i] = grad_reduce_idx[0][i];
    }
  } else if (output1->type == kTfLiteInt64) {
    for (int i = 0; i < grad_reduce_idx[0].size(); ++i) {
      output1->data.i64[i] = grad_reduce_idx[0][i];
    }
  }

  TfLiteIntArray* output2_shape = TfLiteIntArrayCreate(1);
  output2_shape->data[0] = grad_reduce_idx[1].size();
  TF_LITE_ENSURE_OK(context,
                    context->ResizeTensor(context, output2, output2_shape));
  if (output2->type == kTfLiteInt32) {
    for (int i = 0; i < grad_reduce_idx[1].size(); ++i) {
      output2->data.i32[i] = grad_reduce_idx[1][i];
    }
  } else if (output2->type == kTfLiteInt64) {
    for (int i = 0; i < grad_reduce_idx[1].size(); ++i) {
      output2->data.i64[i] = grad_reduce_idx[1][i];
    }
  }
  return kTfLiteOk;
}

}  // namespace

TfLiteRegistration* Register_BROADCAST_GRADIENT_ARGS() {
  static TfLiteRegistration reg = {/*init=*/nullptr,
                                   /*free=*/nullptr,
                                   /*prepare=*/Prepare,
                                   /*invoke=*/Invoke};
  return &reg;
}
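
// To use this kernel, it can be registered with an op resolver as a custom
// op. A minimal sketch (the registered name "BroadcastGradientArgs" is an
// assumption here and must match the custom op name stored in the model):
//
//   tflite::MutableOpResolver resolver;
//   resolver.AddCustom("BroadcastGradientArgs",
//                      tflite::ops::custom::Register_BROADCAST_GRADIENT_ARGS());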

}  // namespace custom
}  // namespace ops
}  // namespace tflite