/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <stddef.h>

#include <cstring>
#include <memory>
#include <vector>

#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/core/subgraph.h"
#include "tensorflow/lite/kernels/internal/compatibility.h"
#include "tensorflow/lite/kernels/kernel_util.h"

namespace tflite {
namespace ops {
namespace builtin {
namespace if_kernel {

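// Holds the indices of the two branch subgraphs, copied out of the builtin
// TfLiteIfParams in Init so they are available during Prepare and Eval.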
struct OpData {
  int then_subgraph_index;
  int else_subgraph_index;
};

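// Reads the branch subgraph indices out of the builtin op data. `buffer`
// points at the TfLiteIfParams produced when the model was parsed.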
void* Init(TfLiteContext* context, const char* buffer, size_t length) {
  auto* op_data = new OpData;
  const auto* params = reinterpret_cast<const TfLiteIfParams*>(buffer);
  op_data->then_subgraph_index = params->then_subgraph_index;
  op_data->else_subgraph_index = params->else_subgraph_index;
  return op_data;
}

void Free(TfLiteContext* context, void* buffer) {
  delete reinterpret_cast<OpData*>(buffer);
}

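// Validates that the condition input is a boolean scalar and that both branch
// subgraphs match this node's input/output signature, then propagates the
// input shapes into the subgraphs and sizes the node's outputs. Outputs are
// marked dynamic when either branch has dynamic tensors or when the two
// branches produce different static output shapes.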
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
  const OpData* op_data = reinterpret_cast<OpData*>(node->user_data);

  TF_LITE_ENSURE(context, node->inputs->size > 0);

  // The first input is the condition.
  const TfLiteTensor* cond;
  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &cond));
  // Currently only bool is supported.
  // TODO(ycling): Support other types, since TensorFlow also supports
  // non-bool types as the condition.
  TF_LITE_ENSURE_EQ(context, cond->type, kTfLiteBool);
  TF_LITE_ENSURE_EQ(context, NumElements(cond), 1);

  // The first input of the node is the condition. The rest of the inputs are
  // passed to the branch subgraphs. Therefore, the number of subgraph inputs
  // will be the number of node inputs - 1.
  int num_inputs = node->inputs->size - 1;
  int num_outputs = node->outputs->size;

  Subgraph* this_subgraph = reinterpret_cast<Subgraph*>(context->impl_);
  auto* subgraphs = this_subgraph->GetSubgraphs();
  TF_LITE_ENSURE(context, op_data->then_subgraph_index < subgraphs->size());
  TF_LITE_ENSURE(context, op_data->else_subgraph_index < subgraphs->size());

  Subgraph* then_subgraph = (*subgraphs)[op_data->then_subgraph_index].get();
  Subgraph* else_subgraph = (*subgraphs)[op_data->else_subgraph_index].get();

  for (auto* subgraph : {then_subgraph, else_subgraph}) {
    TF_LITE_ENSURE_EQ(context, num_inputs, subgraph->inputs().size());
    TF_LITE_ENSURE_EQ(context, num_outputs, subgraph->outputs().size());
  }

  bool has_dynamic_output_tensors = false;
  for (auto* subgraph : {then_subgraph, else_subgraph}) {
    for (int i = 0; i < num_inputs; ++i) {
      // The first input of the node is the condition, so the indices of the
      // inputs passed to the subgraphs are offset by 1.
      const TfLiteTensor* input;
      TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, i + 1, &input));
      std::vector<int> dims(input->dims->data,
                            input->dims->data + input->dims->size);
      subgraph->ResizeInputTensor(i, dims);
      TfLiteTensor* subgraph_input = subgraph->tensor(subgraph->inputs()[i]);
      if (IsDynamicTensor(input)) {
        SetTensorToDynamic(subgraph_input);
      }
      TF_LITE_ENSURE_TYPES_EQ(context, input->type, subgraph_input->type);
    }
    // Note: `Prepare` is responsible for running `AllocateTensors` on both
    // subgraphs, so we intentionally do not break out of the loop when a
    // dynamic output tensor is found.
    TF_LITE_ENSURE_OK(context, subgraph->AllocateTensors());
    has_dynamic_output_tensors |= subgraph->HasDynamicTensors();
  }

  if (!has_dynamic_output_tensors) {
    for (int i = 0; i < num_outputs; ++i) {
      TfLiteTensor* then_output =
          then_subgraph->tensor(then_subgraph->outputs()[i]);
      TfLiteTensor* else_output =
          else_subgraph->tensor(else_subgraph->outputs()[i]);
      // If the two subgraphs have static but different output shapes, the
      // output tensors of the IF op have dynamic sizes.
      if (!TfLiteIntArrayEqual(then_output->dims, else_output->dims)) {
        has_dynamic_output_tensors = true;
        break;
      }
    }
  }

  for (int i = 0; i < num_outputs; ++i) {
    TfLiteTensor* output;
    TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, i, &output));
    if (has_dynamic_output_tensors) {
      SetTensorToDynamic(output);
    } else {
      // When there are no dynamic output tensors, the two subgraphs have
      // exactly the same static-sized outputs, so either branch's shape works.
      TfLiteTensor* then_output =
          then_subgraph->tensor(then_subgraph->outputs()[i]);
      TfLiteIntArray* output_size = TfLiteIntArrayCopy(then_output->dims);
      TF_LITE_ENSURE_OK(context,
                        context->ResizeTensor(context, output, output_size));
    }
  }

  return kTfLiteOk;
}

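// Runs the branch selected by the condition value. Inputs are copied into the
// chosen subgraph, the subgraph is invoked, and its outputs are copied back
// into this node's outputs (resizing them first if they are dynamic).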
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
  const OpData* op_data = reinterpret_cast<OpData*>(node->user_data);

  const TfLiteTensor* cond;
  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &cond));
  bool cond_value = cond->data.b[0];

  Subgraph* this_subgraph = reinterpret_cast<Subgraph*>(context->impl_);
  auto* subgraphs = this_subgraph->GetSubgraphs();

  // Currently we copy the inputs / outputs between the subgraphs. This isn't
  // optimized yet.
  // TODO(b/120234921): Optimize and avoid copying tensors between subgraphs.
  int active_branch_subgraph_index =
      cond_value ? op_data->then_subgraph_index : op_data->else_subgraph_index;
  Subgraph& active_branch_subgraph =
      *(*subgraphs)[active_branch_subgraph_index];

  // We release the subgraphs' memory at the end of evaluation to save memory,
  // so AllocateTensors() must be called again before the second run.
  TF_LITE_ENSURE_OK(context, active_branch_subgraph.AllocateTensors());

  for (int i = 0; i < active_branch_subgraph.inputs().size(); ++i) {
    const TfLiteTensor* input;
    TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, i + 1, &input));
    TfLiteTensor* subgraph_input =
        active_branch_subgraph.tensor(active_branch_subgraph.inputs()[i]);

    if (IsDynamicTensor(subgraph_input)) {
      TfLiteTensorRealloc(input->bytes, subgraph_input);
    }

    TF_LITE_ENSURE_EQ(context, input->bytes, subgraph_input->bytes);
    TfLiteTensorCopy(input, subgraph_input);
  }

  TF_LITE_ENSURE_OK(context, active_branch_subgraph.Invoke());

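  // Make sure the invoked subgraph's outputs are readable on the CPU (they
  // may otherwise live in delegate-owned buffers).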
  for (int tensor_index : active_branch_subgraph.outputs()) {
    active_branch_subgraph.EnsureTensorDataIsReadable(tensor_index);
  }

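  // If any of this node's outputs were marked dynamic in Prepare, resize them
  // to the shapes the active branch actually produced before copying.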
  bool has_dynamic_output_tensors = false;
  for (int i = 0; i < node->outputs->size; ++i) {
    TfLiteTensor* output;
    TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, i, &output));
    if (IsDynamicTensor(output)) {
      has_dynamic_output_tensors = true;
      break;
    }
  }

  if (has_dynamic_output_tensors) {
    for (int i = 0; i < node->outputs->size; ++i) {
      TfLiteTensor* output;
      TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, i, &output));
      TfLiteTensor* subgraph_output =
          active_branch_subgraph.tensor(active_branch_subgraph.outputs()[i]);
      TfLiteIntArray* output_size = TfLiteIntArrayCopy(subgraph_output->dims);
      TF_LITE_ENSURE_OK(context,
                        context->ResizeTensor(context, output, output_size));
    }
  }

  for (int i = 0; i < active_branch_subgraph.outputs().size(); ++i) {
    const TfLiteTensor* subgraph_output =
        active_branch_subgraph.tensor(active_branch_subgraph.outputs()[i]);
    TfLiteTensor* output;
    TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, i, &output));

    if (IsDynamicTensor(output)) {
      TfLiteTensorRealloc(subgraph_output->bytes, output);
    }

    TF_LITE_ENSURE_EQ(context, output->bytes, subgraph_output->bytes);
    TfLiteTensorCopy(subgraph_output, output);
  }

  // Release the subgraphs' memory to save memory. Though this impacts latency,
  // the actual impact looks very small, so no additional option is introduced
  // for the feature until we find a case that needs one.
  Subgraph* then_subgraph = (*subgraphs)[op_data->then_subgraph_index].get();
  Subgraph* else_subgraph = (*subgraphs)[op_data->else_subgraph_index].get();
  TF_LITE_ENSURE_OK(context, then_subgraph->ReleaseMemory());
  TF_LITE_ENSURE_OK(context, else_subgraph->ReleaseMemory());

  return kTfLiteOk;
}

}  // namespace if_kernel

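// Returns the registration for the builtin IF op, wiring up the Init, Free,
// Prepare and Eval handlers above.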
TfLiteRegistration* Register_IF() {
  static TfLiteRegistration r = {if_kernel::Init, if_kernel::Free,
                                 if_kernel::Prepare, if_kernel::Eval};
  return &r;
}

}  // namespace builtin
}  // namespace ops
}  // namespace tflite