xref: /aosp_15_r20/external/tensorflow/tensorflow/lite/delegates/hexagon/builders/batch_seq_builder.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/lite/delegates/hexagon/builders/batch_seq_builder.h"
16 
17 namespace tflite {
18 namespace delegates {
19 namespace hexagon {
20 
PopulateSubGraph(const TfLiteIntArray * inputs,const TfLiteIntArray * outputs,TfLiteContext * context)21 TfLiteStatus BatchSeqBuilder::PopulateSubGraph(const TfLiteIntArray* inputs,
22                                                const TfLiteIntArray* outputs,
23                                                TfLiteContext* context) {
24   // Add config input.
25   static const int config_shape[] = {1, 1, 1, 3};
26   // TODO(b/152562126): Allow custom setting for BQ (preferred batch multiple),
27   // and Options.
28   // BQ is preferred batch multiple
29   // Options is currently 0 or 1, 0 is default and batches
30   // will run in increasing order, this behavior can be disabled by setting 1.
31   // Refer to Hexagon NN docs for more details.
32   int config[] = {max_size_for_batch_, 1, 0};
33 
34   auto* input_config = graph_builder_->AddConstNodeWithData(
35       config_shape, reinterpret_cast<char*>(&config), sizeof(int) * 3);
36   AddInput(TensorID(input_config->GetID(), 0));
37 
38   // Add Input batch details.
39   const int input_batch_dims_shape[] = {1, 1, 1, input_batch_dims_->size};
40   auto* input_batch_dims_node = graph_builder_->AddConstNodeWithData(
41       input_batch_dims_shape, reinterpret_cast<char*>(input_batch_dims_->data),
42       sizeof(input_batch_dims_[0]) * input_batch_dims_->size);
43   AddInput(TensorID(input_batch_dims_node->GetID(), 0));
44 
45   // Add Output batch details.
46   const int output_batch_dims_shape[] = {1, 1, 1, output_batch_dims_->size};
47   auto* output_batch_dims_node = graph_builder_->AddConstNodeWithData(
48       output_batch_dims_shape,
49       reinterpret_cast<char*>(output_batch_dims_->data),
50       sizeof(output_batch_dims_[0]) * output_batch_dims_->size);
51   AddInput(TensorID(output_batch_dims_node->GetID(), 0));
52 
53   return kTfLiteOk;
54 }
55 
CreateBatchSeqBuilder(GraphBuilder * graph_builder,int op_type,int max_size_for_batch,TfLiteIntArray * input_batch_dimensions,TfLiteIntArray * output_batch_dimensions)56 OpBuilder* CreateBatchSeqBuilder(GraphBuilder* graph_builder, int op_type,
57                                  int max_size_for_batch,
58                                  TfLiteIntArray* input_batch_dimensions,
59                                  TfLiteIntArray* output_batch_dimensions) {
60   auto* builder = new BatchSeqBuilder(graph_builder, op_type);
61   builder->SetMaxSizeForBatch(max_size_for_batch);
62   builder->SetInputBatchDimensions(input_batch_dimensions);
63   builder->SetOutputBatchDimensions(output_batch_dimensions);
64   return builder;
65 }
66 
67 }  // namespace hexagon
68 }  // namespace delegates
69 }  // namespace tflite
70