/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// This file implements TensorFlow Lite's custom broadcast gradient args
// operator.
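//
// Given the shapes of two broadcasting operands (as 1-D integer tensors), the
// op returns, for each operand, the axes of the broadcast result along which
// that operand's gradient must be summed. This mirrors TensorFlow's
// BroadcastGradientArgs op: for example, shapes [2, 3, 5] and [1, 3, 1]
// produce the reduction indices [] and [0, 2], respectively.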

#include <algorithm>
#include <array>
#include <cmath>
#include <cstdint>
#include <vector>

#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/internal/common.h"
#include "tensorflow/lite/kernels/internal/tensor.h"
#include "tensorflow/lite/kernels/internal/types.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/kernels/padding.h"

namespace tflite {
namespace ops {
namespace custom {
namespace {

static const int kInputOneTensor = 0;
static const int kInputTwoTensor = 1;
static const int kOutputOneTensor = 0;
static const int kOutputTwoTensor = 1;
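// Both inputs hold an operand shape as a 1-D int32 or int64 tensor; each
// output receives the corresponding gradient-reduction axes in the same type.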

TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
  // Check inputs and outputs.
  TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
  const TfLiteTensor* input1 = GetInput(context, node, kInputOneTensor);
  TF_LITE_ENSURE(context, input1 != nullptr);
  const RuntimeShape input1_shape = GetTensorShape(input1);
  TF_LITE_ENSURE(context,
                 input1->type == kTfLiteInt32 || input1->type == kTfLiteInt64);
  TF_LITE_ENSURE_EQ(context, input1_shape.DimensionsCount(), 1);

  const TfLiteTensor* input2 = GetInput(context, node, kInputTwoTensor);
  TF_LITE_ENSURE(context, input2 != nullptr);
  const RuntimeShape input2_shape = GetTensorShape(input2);
  TF_LITE_ENSURE_TYPES_EQ(context, input2->type, input1->type);
  TF_LITE_ENSURE_EQ(context, input2_shape.DimensionsCount(), 1);

  TF_LITE_ENSURE_EQ(context, NumOutputs(node), 2);
  TfLiteTensor* output1 = GetOutput(context, node, kOutputOneTensor);
  TF_LITE_ENSURE(context, output1 != nullptr);
  TF_LITE_ENSURE_TYPES_EQ(context, output1->type, input1->type);
  TfLiteTensor* output2 = GetOutput(context, node, kOutputTwoTensor);
  TF_LITE_ENSURE(context, output2 != nullptr);
  TF_LITE_ENSURE_TYPES_EQ(context, output2->type, input1->type);
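  // The output sizes depend on the runtime values of the inputs, not just
  // their shapes, so the outputs are marked dynamic and resized in Invoke.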
  SetTensorToDynamic(output1);
  SetTensorToDynamic(output2);
  return kTfLiteOk;
}

TfLiteStatus Invoke(TfLiteContext* context, TfLiteNode* node) {
  const TfLiteTensor* input1 = GetInput(context, node, kInputOneTensor);
  TF_LITE_ENSURE(context, input1 != nullptr);
  const RuntimeShape input1_shape = GetTensorShape(input1);

  const TfLiteTensor* input2 = GetInput(context, node, kInputTwoTensor);
  TF_LITE_ENSURE(context, input2 != nullptr);
  const RuntimeShape input2_shape = GetTensorShape(input2);

  TfLiteTensor* output1 = GetOutput(context, node, kOutputOneTensor);
  TF_LITE_ENSURE(context, output1 != nullptr);
  TfLiteTensor* output2 = GetOutput(context, node, kOutputTwoTensor);
  TF_LITE_ENSURE(context, output2 != nullptr);

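  // Widen both shape vectors to int64_t so the rest of the computation is
  // independent of the input dtype.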
  std::vector<int64_t> input1_vec;
  std::vector<int64_t> input2_vec;
  if (input1->type == kTfLiteInt32) {
    input1_vec = std::vector<int64_t>(input1->data.i32,
                                      input1->data.i32 + input1_shape.Dims(0));
  } else {
    input1_vec = std::vector<int64_t>(input1->data.i64,
                                      input1->data.i64 + input1_shape.Dims(0));
  }
  if (input2->type == kTfLiteInt32) {
    input2_vec = std::vector<int64_t>(input2->data.i32,
                                      input2->data.i32 + input2_shape.Dims(0));
  } else {
    input2_vec = std::vector<int64_t>(input2->data.i64,
                                      input2->data.i64 + input2_shape.Dims(0));
  }

  if (input1_vec == input2_vec) {
    // The shapes are identical, so no gradient reduction is needed for either
    // input; return empty (zero-length) axis lists.
    TfLiteIntArray* output1_shape = TfLiteIntArrayCreate(1);
    output1_shape->data[0] = 0;
    TF_LITE_ENSURE_OK(context,
                      context->ResizeTensor(context, output1, output1_shape));

    TfLiteIntArray* output2_shape = TfLiteIntArrayCreate(1);
    output2_shape->data[0] = 0;
    TF_LITE_ENSURE_OK(context,
                      context->ResizeTensor(context, output2, output2_shape));
    return kTfLiteOk;
  }

  size_t largest_rank = std::max(input1_vec.size(), input2_vec.size());

  // Reverse both shapes for convenience: after the reverse, index 0 is the
  // inner-most dimension.
  std::vector<int64_t> copy[2];
  copy[0] = std::vector<int64_t>(input1_vec.rbegin(), input1_vec.rend());
  copy[1] = std::vector<int64_t>(input2_vec.rbegin(), input2_vec.rend());

  // 1-extend and align both vectors.
  for (int i = 0; i < 2; ++i) {
    if (copy[i].size() < largest_rank) {
      copy[i].resize(largest_rank, 1);
    }
  }
  // Walk through each dimension, starting from the inner-most one, and compare
  // the dimensions of the two inputs: they are compatible if they are equal or
  // if either is 1.
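  //
  // For example (values chosen for illustration): shapes [2, 3, 5] and [3, 1]
  // become copy[0] = {5, 3, 2} and copy[1] = {1, 3, 1} after reversing and
  // 1-extending. Dimensions 0 and 2 of the aligned second shape are 1, so the
  // scan below records reduction axes {0, 2} for input 2 and none for input 1.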

  // Whether the previous / current dimension of each input equals 1.
  std::array<bool, 2> prev_is_one = {false, false};
  std::array<bool, 2> current_is_one = {false, false};
  bool set_one = false;
  // Gradient-reduction axes for each input.
  std::vector<int64_t> grad_reduce_idx[2];

  for (int j = 0; j < largest_rank; ++j) {
    int output_dim = -1;
    bool output_dim_set = false;
    bool none_is_one = true;
    // Find which indices are 1.
    for (int i = 0; i < 2; ++i) {
      // Keep track of which indices are 1.
      if (copy[i][j] == 1) {
        current_is_one[i] = true;
        none_is_one = false;
      } else {
        current_is_one[i] = false;
        if (!output_dim_set || copy[i][j] == output_dim) {
          output_dim = copy[i][j];
          output_dim_set = true;
        } else {
          // Not broadcastable shapes.
          return kTfLiteError;
        }
      }
    }
    // All dimensions are 1.
    if (!output_dim_set) {
      for (int i = 0; i < 2; ++i) {
        grad_reduce_idx[i].push_back(largest_rank - 1 - j);
      }
      continue;
    } else if (current_is_one == prev_is_one && set_one) {
      // It is a run of the same broadcasting case as last time.
      // We can reshape the input so that fewer dimensions
      // are involved in the intermediate computation.
      for (int i = 0; i < 2; ++i) {
        if (current_is_one[i] && !none_is_one) {
          grad_reduce_idx[i].push_back(largest_rank - 1 - j);
        }
      }
    } else {
      for (int i = 0; i < 2; ++i) {
        if (current_is_one[i] && !none_is_one) {
          grad_reduce_idx[i].push_back(largest_rank - 1 - j);
        }
      }
    }
    set_one = true;
    for (int i = 0; i < 2; ++i) {
      prev_is_one[i] = current_is_one[i];
    }
  }
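  // Axes were collected from the inner-most dimension outward; restore
  // ascending order.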
  for (int i = 0; i < 2; ++i) {
    std::reverse(grad_reduce_idx[i].begin(), grad_reduce_idx[i].end());
  }
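  // Resize each output to hold its axis list and write the axes out in the
  // tensor's integer type.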
  TfLiteIntArray* output1_shape = TfLiteIntArrayCreate(1);
  output1_shape->data[0] = grad_reduce_idx[0].size();
  TF_LITE_ENSURE_OK(context,
                    context->ResizeTensor(context, output1, output1_shape));
  if (output1->type == kTfLiteInt32) {
    for (int i = 0; i < grad_reduce_idx[0].size(); ++i) {
      output1->data.i32[i] = grad_reduce_idx[0][i];
    }
  } else if (output1->type == kTfLiteInt64) {
    for (int i = 0; i < grad_reduce_idx[0].size(); ++i) {
      output1->data.i64[i] = grad_reduce_idx[0][i];
    }
  }

  TfLiteIntArray* output2_shape = TfLiteIntArrayCreate(1);
  output2_shape->data[0] = grad_reduce_idx[1].size();
  TF_LITE_ENSURE_OK(context,
                    context->ResizeTensor(context, output2, output2_shape));
  if (output2->type == kTfLiteInt32) {
    for (int i = 0; i < grad_reduce_idx[1].size(); ++i) {
      output2->data.i32[i] = grad_reduce_idx[1][i];
    }
  } else if (output2->type == kTfLiteInt64) {
    for (int i = 0; i < grad_reduce_idx[1].size(); ++i) {
      output2->data.i64[i] = grad_reduce_idx[1][i];
    }
  }
  return kTfLiteOk;
}

}  // namespace

TfLiteRegistration* Register_BROADCAST_GRADIENT_ARGS() {
  static TfLiteRegistration reg = {/*init=*/nullptr,
                                   /*free=*/nullptr,
                                   /*prepare=*/Prepare,
                                   /*invoke=*/Invoke};
  return &reg;
}

}  // namespace custom
}  // namespace ops
}  // namespace tflite