xref: /aosp_15_r20/external/tensorflow/tensorflow/lite/delegates/coreml/builders/pooling_layer_builder.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/lite/delegates/coreml/builders/pooling_layer_builder.h"
16 
17 #include <string>
18 
19 #include "tensorflow/lite/builtin_ops.h"
20 #include "tensorflow/lite/c/builtin_op_data.h"
21 #include "tensorflow/lite/c/common.h"
22 #include "tensorflow/lite/delegates/coreml/builders/op_factory.h"
23 #include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
24 #include "tensorflow/lite/kernels/kernel_util.h"
25 
26 namespace tflite {
27 namespace delegates {
28 namespace coreml {
29 
DebugName()30 const std::string& PoolingLayerBuilder::DebugName() {
31   if (!debug_name_.empty()) return debug_name_;
32   switch (pooling_type_) {
33     case kTfLiteBuiltinAveragePool2d:
34       SetDebugName("PoolingLayerBuilder (AVERAGE)", node_id_);
35       break;
36     case kTfLiteBuiltinMaxPool2d:
37       SetDebugName("PoolingLayerBuilder (MAX)", node_id_);
38       break;
39     case kTfLiteBuiltinL2Pool2d:
40       SetDebugName("PoolingLayerBuilder (L2, unsupported)", node_id_);
41       break;
42     case kTfLiteBuiltinMean:
43       SetDebugName("PoolingLayerBuilder (MEAN)", node_id_);
44       break;
45     default:
46       SetDebugName("PoolingLayerBuilder (ERROR)", node_id_);
47   }
48   return debug_name_;
49 }
50 
Build()51 CoreML::Specification::NeuralNetworkLayer* PoolingLayerBuilder::Build() {
52   layer_->set_name(DebugName());
53   auto* pooling_params = layer_->mutable_pooling();
54 
55   if (pooling_type_ == kTfLiteBuiltinMean) {
56     pooling_params->set_type(
57         CoreML::Specification::PoolingLayerParams::AVERAGE);
58     pooling_params->set_globalpooling(true);
59     return layer_.release();
60   }
61 
62   const TfLitePoolParams* params =
63       reinterpret_cast<const TfLitePoolParams*>(builtin_data_);
64   pooling_params->mutable_stride()->Add(params->stride_height);
65   pooling_params->mutable_stride()->Add(params->stride_width);
66   pooling_params->mutable_kernelsize()->Add(params->filter_height);
67   pooling_params->mutable_kernelsize()->Add(params->filter_width);
68 
69   if (params->padding == kTfLitePaddingSame) {
70     pooling_params->mutable_same();
71   } else {
72     pooling_params->mutable_valid();
73   }
74 
75   switch (pooling_type_) {
76     case kTfLiteBuiltinAveragePool2d:
77       pooling_params->set_type(
78           CoreML::Specification::PoolingLayerParams::AVERAGE);
79       pooling_params->set_avgpoolexcludepadding(true);
80       break;
81     case kTfLiteBuiltinMaxPool2d:
82       pooling_params->set_type(CoreML::Specification::PoolingLayerParams::MAX);
83       break;
84     case kTfLiteBuiltinL2Pool2d:
85       // TODO(b/145873272) implement L2 pooling
86       // NOLINTNEXTLINE: minimize absl usage
87       fprintf(stderr, "L2 pooling is not supported yet.\n");
88       return nullptr;
89     default:
90       // NOLINTNEXTLINE: minimize absl usage
91       fprintf(stderr, "Unexpected pooling type.\n");  // Should not reach here.
92       return nullptr;
93   }
94 
95   // TODO(b/145582958): Add padding values.
96   // TODO(b/145582958): Handle fused activation function.
97   return layer_.release();
98 }
99 
RegisterInputs(const TfLiteIntArray * inputs,TfLiteContext * context)100 TfLiteStatus PoolingLayerBuilder::RegisterInputs(const TfLiteIntArray* inputs,
101                                                  TfLiteContext* context) {
102   if (pooling_type_ == kTfLiteBuiltinMean) {
103     if (inputs->size != 2) {
104       TF_LITE_KERNEL_LOG(context, "Wrong # of inputs to Mean!.");
105       return kTfLiteError;
106     }
107   } else if (inputs->size != 1) {
108     TF_LITE_KERNEL_LOG(context, "Wrong # of inputs to Pooling!.");
109     return kTfLiteError;
110   }
111   AddInput(inputs->data[0]);
112   return kTfLiteOk;
113 }
114 
RegisterOutputs(const TfLiteIntArray * outputs,TfLiteContext * context)115 TfLiteStatus PoolingLayerBuilder::RegisterOutputs(const TfLiteIntArray* outputs,
116                                                   TfLiteContext* context) {
117   if (outputs->size != 1) {
118     TF_LITE_KERNEL_LOG(context, "Wrong # of outputs to Pooling!.");
119     return kTfLiteError;
120   }
121   graph_builder_->AddTensorWithID(outputs->data[0], GetOutput(context));
122   return kTfLiteOk;
123 }
124 
CreateAveragePool2dOpBuilder(GraphBuilder * graph_builder)125 OpBuilder* CreateAveragePool2dOpBuilder(GraphBuilder* graph_builder) {
126   return new PoolingLayerBuilder(graph_builder, kTfLiteBuiltinAveragePool2d);
127 }
128 
CreateMaxPool2dOpBuilder(GraphBuilder * graph_builder)129 OpBuilder* CreateMaxPool2dOpBuilder(GraphBuilder* graph_builder) {
130   return new PoolingLayerBuilder(graph_builder, kTfLiteBuiltinMaxPool2d);
131 }
132 
CreateMeanOpBuilder(GraphBuilder * graph_builder)133 OpBuilder* CreateMeanOpBuilder(GraphBuilder* graph_builder) {
134   return new PoolingLayerBuilder(graph_builder, kTfLiteBuiltinMean);
135 }
136 
137 // Only supports averaging over H and W dimensions, as
IsMeanOpSupported(const TfLiteRegistration * registration,const TfLiteNode * node,TfLiteContext * context)138 bool IsMeanOpSupported(const TfLiteRegistration* registration,
139                        const TfLiteNode* node, TfLiteContext* context) {
140   const TfLiteTensor* input = GetInput(context, node, 0);
141   const TfLiteTensor* axis = GetInput(context, node, 1);
142   const auto* params =
143       reinterpret_cast<TfLiteReducerParams*>(node->builtin_data);
144 
145   if (!params->keep_dims) {
146     TF_LITE_KERNEL_LOG(context, "keep_dims should be true for Mean op.");
147     return false;
148   }
149   if (input->dims->size != 4) {
150     TF_LITE_KERNEL_LOG(context, "Mean op is only supported for 4D input.");
151     return false;
152   }
153   const int* axis_data = GetTensorData<int>(axis);
154   std::vector<bool> axis_mask = {false, true, true, false};
155   for (int i = 0; i < axis->dims->data[0]; ++i) {
156     if (!axis_mask[(axis_data[i] + 4) % 4]) {
157       TF_LITE_KERNEL_LOG(context,
158                          "Mean op should reduce for H and W dimensions.");
159       return false;
160     }
161   }
162   return true;
163 }
164 
165 }  // namespace coreml
166 }  // namespace delegates
167 }  // namespace tflite
168