xref: /aosp_15_r20/external/tensorflow/tensorflow/lite/delegates/coreml/builders/pad_op_builder.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/lite/delegates/coreml/builders/pad_op_builder.h"
16 
17 #include <string>
18 
19 #include "tensorflow/lite/builtin_ops.h"
20 #include "tensorflow/lite/c/builtin_op_data.h"
21 #include "tensorflow/lite/c/common.h"
22 #include "tensorflow/lite/delegates/coreml/builders/op_factory.h"
23 #include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
24 #include "tensorflow/lite/kernels/kernel_util.h"
25 
26 namespace tflite {
27 namespace delegates {
28 namespace coreml {
29 
DebugName()30 const std::string& PadOpBuilder::DebugName() {
31   if (!debug_name_.empty()) return debug_name_;
32   SetDebugName(padding_type_ == PadType::kPad ? "PadOpBuilder (PAD)"
33                                               : "PadOpBuilder (MIRROR_PAD)",
34                node_id_);
35   return debug_name_;
36 }
37 
Build()38 CoreML::Specification::NeuralNetworkLayer* PadOpBuilder::Build() {
39   layer_->set_name(DebugName());
40   if (padding_type_ == PadType::kPad) {
41     layer_->mutable_padding()->mutable_constant();
42   } else if (padding_type_ == PadType::kMirrorPad) {
43     layer_->mutable_padding()->mutable_reflection();
44   }
45   return layer_.release();
46 }
47 
48 // padding is d x 2 tensor, where d is the dimension of input.
49 // only paddings for width and height are considered.
SetPadding(const TfLiteTensor * padding)50 void PadOpBuilder::SetPadding(const TfLiteTensor* padding) {
51   const int32_t* padding_data = GetTensorData<int32_t>(padding);
52   for (int i = 1; i <= 2; ++i) {
53     auto* borderamount = layer_->mutable_padding()
54                              ->mutable_paddingamounts()
55                              ->add_borderamounts();
56     borderamount->set_startedgesize(padding_data[i * 2]);
57     borderamount->set_endedgesize(padding_data[i * 2 + 1]);
58   }
59 }
60 
SetConstantValue(const TfLiteTensor * constant_value)61 void PadOpBuilder::SetConstantValue(const TfLiteTensor* constant_value) {
62   layer_->mutable_padding()->mutable_constant()->set_value(
63       GetTensorData<float>(constant_value)[0]);
64 }
65 
RegisterInputs(const TfLiteIntArray * inputs,TfLiteContext * context)66 TfLiteStatus PadOpBuilder::RegisterInputs(const TfLiteIntArray* inputs,
67                                           TfLiteContext* context) {
68   if (!(inputs->size == 2 || inputs->size == 3)) {
69     TF_LITE_KERNEL_LOG(context, "Wrong # of inputs to Padding!.");
70     return kTfLiteError;
71   }
72   AddInput(inputs->data[0]);
73   SetPadding(GetInput(context, tflite_node_, 1));
74   if (inputs->size == 3) {
75     SetConstantValue(GetInput(context, tflite_node_, 2));
76   }
77 
78   return kTfLiteOk;
79 }
80 
RegisterOutputs(const TfLiteIntArray * outputs,TfLiteContext * context)81 TfLiteStatus PadOpBuilder::RegisterOutputs(const TfLiteIntArray* outputs,
82                                            TfLiteContext* context) {
83   if (outputs->size != 1) {
84     TF_LITE_KERNEL_LOG(context, "Wrong # of outputs to Padding!.");
85     return kTfLiteError;
86   }
87   graph_builder_->AddTensorWithID(outputs->data[0], GetOutput(context));
88   return kTfLiteOk;
89 }
90 
CreatePadOpBuilder(GraphBuilder * graph_builder)91 OpBuilder* CreatePadOpBuilder(GraphBuilder* graph_builder) {
92   return new PadOpBuilder(graph_builder, PadType::kPad);
93 }
94 
CreateMirrorPadOpBuilder(GraphBuilder * graph_builder)95 OpBuilder* CreateMirrorPadOpBuilder(GraphBuilder* graph_builder) {
96   return new PadOpBuilder(graph_builder, PadType::kMirrorPad);
97 }
98 
IsPadOpSupported(const TfLiteRegistration * registration,const TfLiteNode * node,TfLiteContext * context)99 bool IsPadOpSupported(const TfLiteRegistration* registration,
100                       const TfLiteNode* node, TfLiteContext* context) {
101   // padding is d x 2 tensor, where d is the dimension of input.
102   const TfLiteTensor* padding;
103   TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 1, &padding));
104   if (!IsConstantTensor(padding)) {
105     TF_LITE_KERNEL_LOG(context,
106                        "%s: Only constant padding is supported for PAD.",
107                        padding->name);
108     return false;
109   }
110   if (padding->dims->data[0] != 4 || padding->dims->data[1] != 2) {
111     TF_LITE_KERNEL_LOG(context, "%s: Only 4D inputs are supported for PAD.",
112                        padding->name);
113     return false;
114   }
115   const int32_t* padding_data = GetTensorData<int32_t>(padding);
116   if (!(padding_data[0] == 0 && padding_data[1] == 0)) {
117     TF_LITE_KERNEL_LOG(
118         context, "%s: Padding for batch dimension is not supported in PAD.",
119         padding->name);
120     return false;
121   }
122 
123   if (!(padding_data[6] == 0 && padding_data[7] == 0)) {
124     TF_LITE_KERNEL_LOG(
125         context, "%s: Padding for channel dimension is not supported in PAD.",
126         padding->name);
127     return false;
128   }
129   return true;
130 }
131 
IsMirrorPadOpSupported(const TfLiteRegistration * registration,const TfLiteNode * node,TfLiteContext * context)132 bool IsMirrorPadOpSupported(const TfLiteRegistration* registration,
133                             const TfLiteNode* node, TfLiteContext* context) {
134   auto* params =
135       reinterpret_cast<TfLiteMirrorPaddingParams*>(node->builtin_data);
136   if (params->mode != kTfLiteMirrorPaddingReflect) {
137     TF_LITE_KERNEL_LOG(context,
138                        "Only REFLECT mode is supported for MIRROR_PAD.");
139     return false;
140   }
141   return IsPadOpSupported(registration, node, context);
142 }
143 
144 }  // namespace coreml
145 }  // namespace delegates
146 }  // namespace tflite
147