1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/lite/delegates/coreml/builders/pooling_layer_builder.h"
16
17 #include <string>
18
19 #include "tensorflow/lite/builtin_ops.h"
20 #include "tensorflow/lite/c/builtin_op_data.h"
21 #include "tensorflow/lite/c/common.h"
22 #include "tensorflow/lite/delegates/coreml/builders/op_factory.h"
23 #include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
24 #include "tensorflow/lite/kernels/kernel_util.h"
25
26 namespace tflite {
27 namespace delegates {
28 namespace coreml {
29
DebugName()30 const std::string& PoolingLayerBuilder::DebugName() {
31 if (!debug_name_.empty()) return debug_name_;
32 switch (pooling_type_) {
33 case kTfLiteBuiltinAveragePool2d:
34 SetDebugName("PoolingLayerBuilder (AVERAGE)", node_id_);
35 break;
36 case kTfLiteBuiltinMaxPool2d:
37 SetDebugName("PoolingLayerBuilder (MAX)", node_id_);
38 break;
39 case kTfLiteBuiltinL2Pool2d:
40 SetDebugName("PoolingLayerBuilder (L2, unsupported)", node_id_);
41 break;
42 case kTfLiteBuiltinMean:
43 SetDebugName("PoolingLayerBuilder (MEAN)", node_id_);
44 break;
45 default:
46 SetDebugName("PoolingLayerBuilder (ERROR)", node_id_);
47 }
48 return debug_name_;
49 }
50
Build()51 CoreML::Specification::NeuralNetworkLayer* PoolingLayerBuilder::Build() {
52 layer_->set_name(DebugName());
53 auto* pooling_params = layer_->mutable_pooling();
54
55 if (pooling_type_ == kTfLiteBuiltinMean) {
56 pooling_params->set_type(
57 CoreML::Specification::PoolingLayerParams::AVERAGE);
58 pooling_params->set_globalpooling(true);
59 return layer_.release();
60 }
61
62 const TfLitePoolParams* params =
63 reinterpret_cast<const TfLitePoolParams*>(builtin_data_);
64 pooling_params->mutable_stride()->Add(params->stride_height);
65 pooling_params->mutable_stride()->Add(params->stride_width);
66 pooling_params->mutable_kernelsize()->Add(params->filter_height);
67 pooling_params->mutable_kernelsize()->Add(params->filter_width);
68
69 if (params->padding == kTfLitePaddingSame) {
70 pooling_params->mutable_same();
71 } else {
72 pooling_params->mutable_valid();
73 }
74
75 switch (pooling_type_) {
76 case kTfLiteBuiltinAveragePool2d:
77 pooling_params->set_type(
78 CoreML::Specification::PoolingLayerParams::AVERAGE);
79 pooling_params->set_avgpoolexcludepadding(true);
80 break;
81 case kTfLiteBuiltinMaxPool2d:
82 pooling_params->set_type(CoreML::Specification::PoolingLayerParams::MAX);
83 break;
84 case kTfLiteBuiltinL2Pool2d:
85 // TODO(b/145873272) implement L2 pooling
86 // NOLINTNEXTLINE: minimize absl usage
87 fprintf(stderr, "L2 pooling is not supported yet.\n");
88 return nullptr;
89 default:
90 // NOLINTNEXTLINE: minimize absl usage
91 fprintf(stderr, "Unexpected pooling type.\n"); // Should not reach here.
92 return nullptr;
93 }
94
95 // TODO(b/145582958): Add padding values.
96 // TODO(b/145582958): Handle fused activation function.
97 return layer_.release();
98 }
99
RegisterInputs(const TfLiteIntArray * inputs,TfLiteContext * context)100 TfLiteStatus PoolingLayerBuilder::RegisterInputs(const TfLiteIntArray* inputs,
101 TfLiteContext* context) {
102 if (pooling_type_ == kTfLiteBuiltinMean) {
103 if (inputs->size != 2) {
104 TF_LITE_KERNEL_LOG(context, "Wrong # of inputs to Mean!.");
105 return kTfLiteError;
106 }
107 } else if (inputs->size != 1) {
108 TF_LITE_KERNEL_LOG(context, "Wrong # of inputs to Pooling!.");
109 return kTfLiteError;
110 }
111 AddInput(inputs->data[0]);
112 return kTfLiteOk;
113 }
114
RegisterOutputs(const TfLiteIntArray * outputs,TfLiteContext * context)115 TfLiteStatus PoolingLayerBuilder::RegisterOutputs(const TfLiteIntArray* outputs,
116 TfLiteContext* context) {
117 if (outputs->size != 1) {
118 TF_LITE_KERNEL_LOG(context, "Wrong # of outputs to Pooling!.");
119 return kTfLiteError;
120 }
121 graph_builder_->AddTensorWithID(outputs->data[0], GetOutput(context));
122 return kTfLiteOk;
123 }
124
CreateAveragePool2dOpBuilder(GraphBuilder * graph_builder)125 OpBuilder* CreateAveragePool2dOpBuilder(GraphBuilder* graph_builder) {
126 return new PoolingLayerBuilder(graph_builder, kTfLiteBuiltinAveragePool2d);
127 }
128
CreateMaxPool2dOpBuilder(GraphBuilder * graph_builder)129 OpBuilder* CreateMaxPool2dOpBuilder(GraphBuilder* graph_builder) {
130 return new PoolingLayerBuilder(graph_builder, kTfLiteBuiltinMaxPool2d);
131 }
132
CreateMeanOpBuilder(GraphBuilder * graph_builder)133 OpBuilder* CreateMeanOpBuilder(GraphBuilder* graph_builder) {
134 return new PoolingLayerBuilder(graph_builder, kTfLiteBuiltinMean);
135 }
136
137 // Only supports averaging over H and W dimensions, as
IsMeanOpSupported(const TfLiteRegistration * registration,const TfLiteNode * node,TfLiteContext * context)138 bool IsMeanOpSupported(const TfLiteRegistration* registration,
139 const TfLiteNode* node, TfLiteContext* context) {
140 const TfLiteTensor* input = GetInput(context, node, 0);
141 const TfLiteTensor* axis = GetInput(context, node, 1);
142 const auto* params =
143 reinterpret_cast<TfLiteReducerParams*>(node->builtin_data);
144
145 if (!params->keep_dims) {
146 TF_LITE_KERNEL_LOG(context, "keep_dims should be true for Mean op.");
147 return false;
148 }
149 if (input->dims->size != 4) {
150 TF_LITE_KERNEL_LOG(context, "Mean op is only supported for 4D input.");
151 return false;
152 }
153 const int* axis_data = GetTensorData<int>(axis);
154 std::vector<bool> axis_mask = {false, true, true, false};
155 for (int i = 0; i < axis->dims->data[0]; ++i) {
156 if (!axis_mask[(axis_data[i] + 4) % 4]) {
157 TF_LITE_KERNEL_LOG(context,
158 "Mean op should reduce for H and W dimensions.");
159 return false;
160 }
161 }
162 return true;
163 }
164
165 } // namespace coreml
166 } // namespace delegates
167 } // namespace tflite
168