/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <stddef.h>

#include <cstring>
#include <vector>

#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/context_util.h"
#include "tensorflow/lite/core/subgraph.h"
#include "tensorflow/lite/kernels/kernel_util.h"

namespace tflite {
namespace ops {
namespace builtin {
namespace while_kernel {

struct OpData {
  // Indices of the condition and body subgraphs within the model.
  int cond_subgraph_index;
  int body_subgraph_index;
  // True if the condition subgraph produces a dynamic output tensor.
  bool cond_has_dynamic_output_tensors;
  // True if the body subgraph produces dynamic output tensors; selects the
  // dynamic evaluation path.
  bool body_has_dynamic_output_tensors;
  // True if body inputs share the caller's buffers via shallow copy instead
  // of being deep-copied (memory optimization for large tensors).
  bool body_use_shallow_copy;
  // Set when Prepare_impl() is called.
  bool subgraphs_prepared;
};

namespace {

// Propagates tensor shapes and types from `src_tensor_indices` in
// `src_subgraph` to `dst_tensor_indices` in `dst_subgraph`.
//
// When `resize_subgraph_inputs` is true, the function calls the destination
// subgraph's `ResizeInputTensor` function, which may trigger the memory
// planner to reallocate memory.
// When `resize_subgraph_inputs` is false, it implies that `context` belongs to
// `dst_subgraph`, and the function calls `context->ResizeTensor`. This happens
// when resizing the `While` op's outputs.
template <typename SrcVector, typename DstVector>
TfLiteStatus CopyTensorsShapeAndType(TfLiteContext* context,
                                     Subgraph* src_subgraph,
                                     const SrcVector& src_tensor_indices,
                                     Subgraph* dst_subgraph,
                                     const DstVector& dst_tensor_indices,
                                     bool resize_subgraph_inputs) {
  TF_LITE_ENSURE_EQ(context, src_tensor_indices.size(),
                    dst_tensor_indices.size());
  for (int i = 0; i < src_tensor_indices.size(); ++i) {
    // Skip copying unused destination tensors.
    if (dst_tensor_indices[i] == kTfLiteOptionalTensor) continue;

    const TfLiteTensor* src_tensor =
        src_subgraph->tensor(src_tensor_indices[i]);

    TfLiteTensor* dst_tensor = dst_subgraph->tensor(dst_tensor_indices[i]);
    if (resize_subgraph_inputs) {
      std::vector<int> dims(src_tensor->dims->data,
                            src_tensor->dims->data + src_tensor->dims->size);
      dst_subgraph->ResizeInputTensor(dst_tensor_indices[i], dims);
    } else {
      TF_LITE_ENSURE_OK(
          context, context->ResizeTensor(context, dst_tensor,
                                         TfLiteIntArrayCopy(src_tensor->dims)));
    }
    dst_tensor->type = src_tensor->type;
  }
  return kTfLiteOk;
}

// Copies the data of the tensors `src_tensor_indices` in `src_subgraph`
// to the tensors `dst_tensor_indices` in `dst_subgraph`.
template <typename SrcVector, typename DstVector>
TfLiteStatus CopyTensorsData(TfLiteContext* context, Subgraph* src_subgraph,
                             const SrcVector& src_tensor_indices,
                             Subgraph* dst_subgraph,
                             const DstVector& dst_tensor_indices) {
  TF_LITE_ENSURE_EQ(context, src_tensor_indices.size(),
                    dst_tensor_indices.size());
  for (int i = 0; i < src_tensor_indices.size(); ++i) {
    // Skip copying unused destination tensors.
    if (dst_tensor_indices[i] == kTfLiteOptionalTensor) continue;

    const TfLiteTensor* src_tensor =
        src_subgraph->tensor(src_tensor_indices[i]);
    TfLiteTensor* dst_tensor = dst_subgraph->tensor(dst_tensor_indices[i]);
    if (IsDynamicTensor(dst_tensor)) {
      TfLiteTensorRealloc(src_tensor->bytes, dst_tensor);
    }
    TF_LITE_ENSURE_EQ(context, src_tensor->bytes, dst_tensor->bytes);
    TfLiteTensorCopy(src_tensor, dst_tensor);
  }
  return kTfLiteOk;
}

// Propagate tensor shapes and types from `src_tensor_indices` in `src_subgraph`
// to `dst_tensor_indices` in `dst_subgraph` and copy data deeply.
template <typename SrcVector, typename DstVector>
TfLiteStatus DeepCopyTensorsShapeTypeData(TfLiteContext* context,
                                          TfLiteNode* node,
                                          Subgraph* src_subgraph,
                                          const SrcVector& src_tensor_indices,
                                          Subgraph* dst_subgraph,
                                          const DstVector& dst_tensor_indices) {
  const OpData* op_data = reinterpret_cast<OpData*>(node->user_data);

  if (op_data->body_has_dynamic_output_tensors) {
    Subgraph* this_subgraph = reinterpret_cast<Subgraph*>(context->impl_);
    bool resize_subgraph_inputs = (dst_subgraph != this_subgraph);
    TF_LITE_ENSURE_OK(
        context, CopyTensorsShapeAndType(
                     context, src_subgraph, src_tensor_indices, dst_subgraph,
                     dst_tensor_indices, resize_subgraph_inputs));
    if (resize_subgraph_inputs) {
      TF_LITE_ENSURE_OK(context, dst_subgraph->AllocateTensors());
    }
  }
  TF_LITE_ENSURE_OK(context,
                    CopyTensorsData(context, src_subgraph, src_tensor_indices,
                                    dst_subgraph, dst_tensor_indices));
  return kTfLiteOk;
}

// Propagate tensor shapes and types from `src_tensor_indices` in `src_subgraph`
// to `dst_tensor_indices` in `dst_subgraph` and copy data shallowly.
template <typename SrcVector, typename DstVector>
TfLiteStatus ShallowCopyTensorsShapeTypeData(
    TfLiteContext* context, TfLiteNode* node, Subgraph* src_subgraph,
    const SrcVector& src_tensor_indices, Subgraph* dst_subgraph,
    const DstVector& dst_tensor_indices) {
  const OpData* op_data = reinterpret_cast<OpData*>(node->user_data);
  Subgraph* this_subgraph = reinterpret_cast<Subgraph*>(context->impl_);
  TF_LITE_ENSURE_EQ(context, op_data->body_has_dynamic_output_tensors, true);
  // Only allow shallow copy from main node input.
  TF_LITE_ENSURE_EQ(context, src_subgraph, this_subgraph);

  TF_LITE_ENSURE_EQ(context, src_tensor_indices.size(),
                    dst_tensor_indices.size());
  bool reallocation_needed = false;
  for (int i = 0; i < src_tensor_indices.size(); ++i) {
    // Skip copying unused destination tensors.
    if (dst_tensor_indices[i] == kTfLiteOptionalTensor) continue;

    const TfLiteTensor* src_tensor =
        src_subgraph->tensor(src_tensor_indices[i]);
    TfLiteTensor* dst_tensor = dst_subgraph->tensor(dst_tensor_indices[i]);

    if (!TfLiteIntArrayEqual(src_tensor->dims, dst_tensor->dims)) {
      reallocation_needed = true;
      TfLiteIntArrayFree(dst_tensor->dims);
      dst_tensor->dims = TfLiteIntArrayCopy(src_tensor->dims);
    }
    dst_tensor->type = src_tensor->type;
    dst_tensor->bytes = 0;  // Don't allocate memory with AllocateTensors().
    dst_tensor->data.raw = nullptr;
  }

  if (reallocation_needed && dst_subgraph != this_subgraph) {
    TF_LITE_ENSURE_OK(context, dst_subgraph->AllocateTensors());
  }

  for (int i = 0; i < src_tensor_indices.size(); ++i) {
    // Skip copying unused destination tensors.
    if (dst_tensor_indices[i] == kTfLiteOptionalTensor) continue;

    const TfLiteTensor* src_tensor =
        src_subgraph->tensor(src_tensor_indices[i]);
    TfLiteTensor* dst_tensor = dst_subgraph->tensor(dst_tensor_indices[i]);

    dst_tensor->bytes = src_tensor->bytes;
    dst_tensor->data.raw = src_tensor->data.raw;
  }

  return kTfLiteOk;
}

TfLiteStatus CheckCondOutput(TfLiteContext* context,
                             const TfLiteTensor* cond_output) {
  // The condition output must be a single boolean value.
  TF_LITE_ENSURE_TYPES_EQ(context, cond_output->type, kTfLiteBool);
  if (cond_output->dims->size == 0) {
    // It's okay if it's a 0D scalar.
    return kTfLiteOk;
  }
  // Otherwise it must be 1D with shape [1].
  TF_LITE_ENSURE_EQ(context, cond_output->dims->size, 1);
  TF_LITE_ENSURE_EQ(context, cond_output->dims->data[0], 1);
  return kTfLiteOk;
}

}  // namespace

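// Reads the builtin `TfLiteWhileParams` from `buffer` and allocates the
// kernel's OpData. All bookkeeping flags start out false.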
void* Init(TfLiteContext* context, const char* buffer, size_t length) {
  auto* op_data = new OpData;
  const auto* params = reinterpret_cast<const TfLiteWhileParams*>(buffer);
  op_data->cond_subgraph_index = params->cond_subgraph_index;
  op_data->body_subgraph_index = params->body_subgraph_index;
  op_data->cond_has_dynamic_output_tensors = false;
  op_data->body_has_dynamic_output_tensors = false;
  op_data->body_use_shallow_copy = false;
  op_data->subgraphs_prepared = false;
  return op_data;
}

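// Releases the OpData allocated by Init().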
void Free(TfLiteContext* context, void* buffer) {
  delete reinterpret_cast<OpData*>(buffer);
}

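// Full preparation: validates the subgraph indices and the input/output
// counts of the cond and body subgraphs, propagates the WHILE op's input
// shapes and types into both subgraphs, allocates them, decides between the
// static and dynamic (optionally shallow-copy) evaluation paths, and resizes
// the op's outputs accordingly.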
TfLiteStatus Prepare_impl(TfLiteContext* context, TfLiteNode* node) {
  OpData* op_data = reinterpret_cast<OpData*>(node->user_data);
  int num_inputs = node->inputs->size;
  // The number of outputs should be the same as the number of inputs.
  TF_LITE_ENSURE_EQ(context, node->outputs->size, num_inputs);

  // Check subgraph indices and get subgraphs.
  Subgraph* this_subgraph = reinterpret_cast<Subgraph*>(context->impl_);
  auto* subgraphs = this_subgraph->GetSubgraphs();
  TF_LITE_ENSURE(context, op_data->cond_subgraph_index < subgraphs->size());
  TF_LITE_ENSURE(context, op_data->body_subgraph_index < subgraphs->size());
  TF_LITE_ENSURE(context,
                 op_data->cond_subgraph_index != op_data->body_subgraph_index);

  Subgraph* cond_subgraph = (*subgraphs)[op_data->cond_subgraph_index].get();
  Subgraph* body_subgraph = (*subgraphs)[op_data->body_subgraph_index].get();

  // Check input & output count of the condition subgraph.
  TF_LITE_ENSURE_EQ(context, cond_subgraph->inputs().size(), num_inputs);
  TF_LITE_ENSURE_EQ(context, cond_subgraph->outputs().size(), 1);

  // Check input & output count of the body subgraph.
  TF_LITE_ENSURE_EQ(context, body_subgraph->inputs().size(), num_inputs);
  TF_LITE_ENSURE_EQ(context, body_subgraph->outputs().size(), num_inputs);

  // Remove unused inputs of the condition subgraph to skip copying unnecessary
  // inputs.
  cond_subgraph->RemoveUnusedInputs();

  // Prepare and check the condition subgraph.
  TF_LITE_ENSURE_OK(
      context, CopyTensorsShapeAndType(
                   context, this_subgraph, TfLiteIntArrayView(node->inputs),
                   cond_subgraph, cond_subgraph->inputs(), true));
  TF_LITE_ENSURE_OK(context, cond_subgraph->AllocateTensors());
  TfLiteTensor* cond_output =
      cond_subgraph->tensor(cond_subgraph->outputs()[0]);
  // This should rarely happen. In most cases the output is static with shape
  // [1]; however, intermediate tensors in the cond subgraph can theoretically
  // be dynamic.
  if (IsDynamicTensor(cond_output)) {
    op_data->cond_has_dynamic_output_tensors = true;
  } else {
    TF_LITE_ENSURE_STATUS(CheckCondOutput(context, cond_output));
  }

  // Prepare and check the body subgraph.
  TF_LITE_ENSURE_OK(
      context, CopyTensorsShapeAndType(
                   context, this_subgraph, TfLiteIntArrayView(node->inputs),
                   body_subgraph, body_subgraph->inputs(), true));

  bool input_has_resource_or_variant_tensor = false;
  for (int i = 0; i < num_inputs; ++i) {
    if (IsResourceOrVariant(
            body_subgraph->tensor(body_subgraph->inputs()[i]))) {
      input_has_resource_or_variant_tensor = true;
      break;
    }
  }
  if (this_subgraph->ShouldOptimizeMemoryForLargeTensors() &&
      !input_has_resource_or_variant_tensor) {
    // The current shallow copy requires the use of dynamic tensors, which
    // introduces additional overhead. Therefore, use this method only if
    // dynamic allocation is enabled.
    op_data->body_use_shallow_copy = true;
    op_data->body_has_dynamic_output_tensors = true;
    // Make body inputs dynamic to use shallow copy with Eval_dynamic().
    for (int i = 0; i < num_inputs; ++i) {
      TfLiteTensor* body_input =
          body_subgraph->tensor(body_subgraph->inputs()[i]);
      SetTensorToDynamic(body_input);
      body_input->bytes = 0;
    }
  }

  TF_LITE_ENSURE_OK(context, body_subgraph->AllocateTensors());
  if (body_subgraph->HasDynamicTensors()) {
    op_data->body_has_dynamic_output_tensors = true;
  } else {
    for (int i = 0; i < num_inputs; ++i) {
      TfLiteTensor* body_input =
          body_subgraph->tensor(body_subgraph->inputs()[i]);
      TfLiteTensor* body_output =
          body_subgraph->tensor(body_subgraph->outputs()[i]);
      TF_LITE_ENSURE_TYPES_EQ(context, body_input->type, body_output->type);

      TF_LITE_ENSURE(context, !IsDynamicTensor(body_output));
      if (!TfLiteIntArrayEqual(body_input->dims, body_output->dims)) {
        // If the output shape of the body subgraph is static w.r.t. a fixed
        // input size but differs from the input shape, the output is still
        // considered dynamic. For example, if a subgraph keeps padding its
        // input with a fixed padding, the output shape is static w.r.t. the
        // input shape and padding, but running it in a loop keeps bloating
        // the tensor.
        op_data->body_has_dynamic_output_tensors = true;
        break;
      }
    }
  }
  for (int i = 0; i < num_inputs; ++i) {
    TfLiteTensor* output;
    TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, i, &output));
    if (op_data->body_has_dynamic_output_tensors) {
      SetTensorToDynamic(output);
    } else {
      TfLiteTensor* body_output =
          body_subgraph->tensor(body_subgraph->outputs()[i]);
      TfLiteIntArray* output_size = TfLiteIntArrayCopy(body_output->dims);
      TF_LITE_ENSURE_OK(context,
                        context->ResizeTensor(context, output, output_size));
    }
  }
  op_data->subgraphs_prepared = true;
  return kTfLiteOk;
}

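// When memory optimization for large tensors is enabled, defers the full
// preparation to the first Eval() call and only marks the op's outputs
// dynamic; otherwise prepares eagerly via Prepare_impl().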
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
  Subgraph* this_subgraph = reinterpret_cast<Subgraph*>(context->impl_);
  if (this_subgraph->ShouldOptimizeMemoryForLargeTensors()) {
    // Apply lazy initialization of the WHILE kernel: just make the node's
    // output tensors dynamic.
    int num_outputs = node->outputs->size;
    for (int i = 0; i < num_outputs; ++i) {
      TfLiteTensor* output;
      TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, i, &output));
      SetTensorToDynamic(output);
    }
    return kTfLiteOk;
  }
  return Prepare_impl(context, node);
}

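// Deferred preparation entry point, invoked from Eval() when Prepare() chose
// lazy initialization.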
TfLiteStatus Prepare_lazy(TfLiteContext* context, TfLiteNode* node) {
  return Prepare_impl(context, node);
}

// Evaluate cond subgraph and set the result.
TfLiteStatus Eval_cond_subgraph(TfLiteContext* context, Subgraph* cond_subgraph,
                                bool cond_has_dynamic_output_tensors,
                                bool* cond_subgraph_output) {
  TF_LITE_ENSURE_OK(context, cond_subgraph->Invoke());
  int cond_subgraph_output_index = cond_subgraph->outputs()[0];
  cond_subgraph->EnsureTensorDataIsReadable(cond_subgraph_output_index);
  TfLiteTensor* cond_output = cond_subgraph->tensor(cond_subgraph_output_index);
  if (cond_has_dynamic_output_tensors) {
    TF_LITE_ENSURE_STATUS(CheckCondOutput(context, cond_output));
  }

  *cond_subgraph_output = (cond_output->data.b[0]);
  return kTfLiteOk;
}

// Evaluate WHILE op when body subgraph has dynamic outputs.
TfLiteStatus Eval_dynamic(TfLiteContext* context, TfLiteNode* node) {
  OpData* op_data = reinterpret_cast<OpData*>(node->user_data);
  Subgraph* this_subgraph = reinterpret_cast<Subgraph*>(context->impl_);
  auto* subgraphs = this_subgraph->GetSubgraphs();
  Subgraph* cond_subgraph = (*subgraphs)[op_data->cond_subgraph_index].get();
  Subgraph* body_subgraph = (*subgraphs)[op_data->body_subgraph_index].get();

  // The following graph illustrates the current implementation.
  //
  // This Subgraph          Cond Subgraph         Body Subgraph
  // +-----------+   (1)   +------------+         +------------+
  // |   WHILE   |-------->|  SUBGRAPH  |         |  SUBGRAPH  |
  // |   INPUT   |         |   INPUT    |         |   INPUT    |
  // |           |         |     ---------------->|            |
  // |           |         |   /        | <----   |            |
  // +-----------+         +--/---------+      \  +------------+
  //      |                 /    |              \       |
  //      | (2)       (4) /      | (3)       (6) \      | (5)
  //      v             /        v                \     v
  // +-----------+    /    +------------+         +------------+
  // |   WHILE   |--/      |  SUBGRAPH  |         |  SUBGRAPH  |
  // |   OUTPUT  |    (7)  |   OUTPUT   |         |   OUTPUT   |
  // |           |<-------------------------------|            |
  // +-----------+         +------------+         +------------+
  //
  // (1) Copy the inputs of WHILE op to the inputs of condition subgraph.
  // (2) Copy the inputs of WHILE op to the outputs of WHILE op.
  // (3) Invoke condition subgraph.
  //     Exit the loop if the result is false.
  // (4) Copy the outputs of WHILE op to the inputs of body subgraph.
  // (5) Invoke body subgraph.
  // (6) Copy the outputs of body subgraph to the inputs of the condition
  //     subgraph.
  // (7) Copy the outputs of body subgraph to the outputs of WHILE op.
  //     Jump back to step 3!
  //
  // If the body subgraph has dynamically sized outputs, the tensors must be
  // resized before copying in steps 1, 2, 4, 6 and 7.
  //
  // Note the flow is carefully designed to handle the dynamically sized output
  // case. The loop invariant is: the newest value is in the inputs of the
  // condition subgraph. This is always true before step 3.

  // Step 1. node->inputs -> cond->inputs (fast)
  TF_LITE_ENSURE_OK(context, DeepCopyTensorsShapeTypeData(
                                 context, node, this_subgraph,
                                 TfLiteIntArrayView(node->inputs),
                                 cond_subgraph, cond_subgraph->inputs()));

  // Step 2. node->inputs -> node->outputs
  TF_LITE_ENSURE_OK(
      context, DeepCopyTensorsShapeTypeData(context, node, this_subgraph,
                                            TfLiteIntArrayView(node->inputs),
                                            this_subgraph,
                                            TfLiteIntArrayView(node->outputs)));

  while (true) {
    // Step 3. Eval cond subgraph
    bool cond_subgraph_output;
    TF_LITE_ENSURE_OK(
        context, Eval_cond_subgraph(context, cond_subgraph,
                                    op_data->cond_has_dynamic_output_tensors,
                                    &cond_subgraph_output));
    if (!cond_subgraph_output) {
      break;
    }

    // Step 4. node->outputs -> body->inputs
    if (op_data->body_use_shallow_copy) {
      TF_LITE_ENSURE_OK(context, ShallowCopyTensorsShapeTypeData(
                                     context, node, this_subgraph,
                                     TfLiteIntArrayView(node->outputs),
                                     body_subgraph, body_subgraph->inputs()));
    } else {
      TF_LITE_ENSURE_OK(context, DeepCopyTensorsShapeTypeData(
                                     context, node, this_subgraph,
                                     TfLiteIntArrayView(node->outputs),
                                     body_subgraph, body_subgraph->inputs()));
    }

    // Step 5. Invoke body subgraph
    TF_LITE_ENSURE_OK(context, body_subgraph->Invoke());
    for (int tensor_index : body_subgraph->outputs()) {
      body_subgraph->EnsureTensorDataIsReadable(tensor_index);
    }

    // Step 6. body->outputs -> cond->inputs (fast)
    TF_LITE_ENSURE_OK(
        context, DeepCopyTensorsShapeTypeData(
                     context, node, body_subgraph, body_subgraph->outputs(),
                     cond_subgraph, cond_subgraph->inputs()));

    // Step 7. body->outputs -> node->outputs
    TF_LITE_ENSURE_OK(
        context, DeepCopyTensorsShapeTypeData(
                     context, node, body_subgraph, body_subgraph->outputs(),
                     this_subgraph, TfLiteIntArrayView(node->outputs)));
  }

  if (op_data->body_use_shallow_copy) {
    // Clean up shallow copied pointer of body inputs.
    for (int i = 0; i < body_subgraph->inputs().size(); ++i) {
      TfLiteTensor* body_input =
          body_subgraph->tensor(body_subgraph->inputs()[i]);
      body_input->data.raw = nullptr;
    }
  }

  return kTfLiteOk;
}

// Evaluate WHILE op when body subgraph has static outputs.
TfLiteStatus Eval_static(TfLiteContext* context, TfLiteNode* node) {
  OpData* op_data = reinterpret_cast<OpData*>(node->user_data);
  Subgraph* this_subgraph = reinterpret_cast<Subgraph*>(context->impl_);
  auto* subgraphs = this_subgraph->GetSubgraphs();
  Subgraph* cond_subgraph = (*subgraphs)[op_data->cond_subgraph_index].get();
  Subgraph* body_subgraph = (*subgraphs)[op_data->body_subgraph_index].get();

  // The following graph illustrates the current implementation.
  //
  // This Subgraph          Cond Subgraph         Body Subgraph
  // +-----------+   (1)   +------------+         +------------+
  // |   WHILE   |-------->|  SUBGRAPH  |         |  SUBGRAPH  |
  // |   INPUT   |  (3-1) /|   INPUT    |         |   INPUT    |
  // |           |------------------------------->|            |
  // |           |         |            | <----   |            |
  // +-----------+         +------------+      \  +------------+
  //                             |              \       |     ^
  //                             | (2)       (5) \      | (4) | (3-2)
  //                             v                \     v     |
  // +-----------+         +------------+         +------------+
  // |   WHILE   |         |  SUBGRAPH  |         |  SUBGRAPH  |
  // |   OUTPUT  |    (6)  |   OUTPUT   |         |   OUTPUT   |
  // |           |<-------------------------------|            |
  // +-----------+         +------------+         +------------+
  //
  // (1) Copy the inputs of WHILE op to the inputs of condition subgraph.
  // (2) Invoke condition subgraph.
  //     Jump to step 6 if the result is false.
  // (3) If the body has never been invoked, run step 3-1; otherwise run step 3-2.
  // (3-1) Copy the inputs of WHILE op to the inputs of body subgraph.
  // (3-2) Copy the outputs of body subgraph to the inputs of body subgraph.
  // (4) Invoke body subgraph.
  // (5) Copy the outputs of body subgraph to the inputs of the condition
  //     subgraph.
  //     Jump back to step 2!
  // (6) Copy the outputs of body subgraph to the outputs of WHILE op.
  //
  // The body subgraph shouldn't have dynamically sized outputs.

  // Step 1. node->inputs -> cond->inputs (fast)
  TF_LITE_ENSURE_OK(
      context,
      CopyTensorsData(context, this_subgraph, TfLiteIntArrayView(node->inputs),
                      cond_subgraph, cond_subgraph->inputs()));

  bool body_invoked = false;
  while (true) {
    // Step 2. Eval cond subgraph
    bool cond_subgraph_output;
    TF_LITE_ENSURE_OK(
        context, Eval_cond_subgraph(context, cond_subgraph,
                                    op_data->cond_has_dynamic_output_tensors,
                                    &cond_subgraph_output));
    if (!cond_subgraph_output) {
      break;
    }

    if (body_invoked) {
      // Step 3-2. body->output -> body->inputs
      TF_LITE_ENSURE_OK(
          context,
          CopyTensorsData(context, body_subgraph, body_subgraph->outputs(),
                          body_subgraph, body_subgraph->inputs()));
    } else {
      // Step 3-1. node->inputs -> body->inputs
      TF_LITE_ENSURE_OK(
          context, CopyTensorsData(context, this_subgraph,
                                   TfLiteIntArrayView(node->inputs),
                                   body_subgraph, body_subgraph->inputs()));
    }

    // Step 4. Invoke body subgraph
    TF_LITE_ENSURE_OK(context, body_subgraph->Invoke());
    body_invoked = true;
    for (int tensor_index : body_subgraph->outputs()) {
      body_subgraph->EnsureTensorDataIsReadable(tensor_index);
    }

    // Step 5. body->output -> cond->inputs (fast)
    TF_LITE_ENSURE_OK(
        context,
        CopyTensorsData(context, body_subgraph, body_subgraph->outputs(),
                        cond_subgraph, cond_subgraph->inputs()));
  }

  if (body_invoked) {
    // Step 6. Copy body->output -> node->outputs
    TF_LITE_ENSURE_OK(
        context,
        CopyTensorsData(context, body_subgraph, body_subgraph->outputs(),
                        this_subgraph, TfLiteIntArrayView(node->outputs)));
  } else {
    // Copy node->inputs if body is never invoked.
    TF_LITE_ENSURE_OK(
        context, CopyTensorsData(
                     context, this_subgraph, TfLiteIntArrayView(node->inputs),
                     this_subgraph, TfLiteIntArrayView(node->outputs)));
  }

  return kTfLiteOk;
}

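// Evaluates the WHILE op: finishes deferred preparation if needed (otherwise
// re-allocates the subgraphs' tensors), dispatches to Eval_dynamic() or
// Eval_static(), and finally releases the subgraphs' memory unless all
// tensors must be preserved.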
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
  OpData* op_data = reinterpret_cast<OpData*>(node->user_data);
  Subgraph* this_subgraph = reinterpret_cast<Subgraph*>(context->impl_);
  auto* subgraphs = this_subgraph->GetSubgraphs();
  Subgraph* cond_subgraph = (*subgraphs)[op_data->cond_subgraph_index].get();
  Subgraph* body_subgraph = (*subgraphs)[op_data->body_subgraph_index].get();

  if (op_data->subgraphs_prepared == false) {
    TF_LITE_ENSURE_OK(context, Prepare_lazy(context, node));
  } else {
    TF_LITE_ENSURE_OK(context, cond_subgraph->AllocateTensors());
    TF_LITE_ENSURE_OK(context, body_subgraph->AllocateTensors());
  }

  if (op_data->body_has_dynamic_output_tensors) {
    TF_LITE_ENSURE_OK(context, Eval_dynamic(context, node));
  } else {
    TF_LITE_ENSURE_OK(context, Eval_static(context, node));
  }

  if (!this_subgraph->ShouldPreserveAllTensors()) {
    TF_LITE_ENSURE_OK(context, cond_subgraph->ReleaseMemory());
    TF_LITE_ENSURE_OK(context, body_subgraph->ReleaseMemory());
  }

  return kTfLiteOk;
}

}  // namespace while_kernel

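// Returns the registration for the builtin WHILE kernel.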
TfLiteRegistration* Register_WHILE() {
  static TfLiteRegistration r = {while_kernel::Init, while_kernel::Free,
                                 while_kernel::Prepare, while_kernel::Eval};
  return &r;
}

}  // namespace builtin
}  // namespace ops
}  // namespace tflite