xref: /aosp_15_r20/external/tensorflow/tensorflow/lite/delegates/coreml/builders/activation_layer_builder.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/lite/delegates/coreml/builders/activation_layer_builder.h"
16 
17 #include <string>
18 
19 #include "tensorflow/lite/builtin_ops.h"
20 #include "tensorflow/lite/c/builtin_op_data.h"
21 #include "tensorflow/lite/c/common.h"
22 #include "tensorflow/lite/delegates/coreml/builders/op_factory.h"
23 #include "tensorflow/lite/delegates/coreml/builders/threshold_layer_builder.h"
24 
25 namespace tflite {
26 namespace delegates {
27 namespace coreml {
28 
DebugName()29 const std::string& ActivationLayerBuilder::DebugName() {
30   if (debug_name_.empty()) SetDebugName("ActivationLayerBuilder", node_id_);
31   return debug_name_;
32 }
33 
Build()34 CoreML::Specification::NeuralNetworkLayer* ActivationLayerBuilder::Build() {
35   layer_->set_name(DebugName());
36   switch (activation_) {
37     // ActNone is used for sclalar multiplication (linear activation)
38     case kTfLiteActNone:
39       layer_->mutable_activation()->mutable_linear()->set_alpha(alpha_);
40       break;
41     case kTfLiteActRelu:
42       layer_->mutable_activation()->mutable_relu();
43       break;
44     // Relu1 and Relu6 layers are fully composed in PopulateSubgraph().
45     case kTfLiteActReluN1To1:  // clip(-1, 1)
46       layer_->mutable_unary()->set_alpha(-1);
47       layer_->mutable_unary()->set_type(
48           CoreML::Specification::UnaryFunctionLayerParams::THRESHOLD);
49       break;
50     case kTfLiteActRelu6:  // clip(0, 6)
51       layer_->mutable_activation()->mutable_relu();
52       break;
53     case kTfLiteActTanh:
54       layer_->mutable_activation()->mutable_tanh();
55       break;
56     case kTfLiteActSigmoid:
57       layer_->mutable_activation()->mutable_sigmoid();
58       break;
59     // TODO(taeheej): signbit is not implemented.
60     default:
61       fprintf(stderr, "Activation %d is not supported.\n", activation_);
62       break;
63   }
64   return layer_.release();
65 }
66 
PopulateSubgraph(TfLiteContext * context)67 TfLiteStatus ActivationLayerBuilder::PopulateSubgraph(TfLiteContext* context) {
68   if (!(activation_ == kTfLiteActRelu6 || activation_ == kTfLiteActReluN1To1)) {
69     builder_output_ = AddOutput();
70     return kTfLiteOk;
71   }
72 
73   // Relu1: Threshold(-1) -> Threshold(-1) with scale: -1 -> Negation
74   // Relu6: ReLU -> Threshold(-6) with scale: -1 -> Negation
75   const int relu_threshold = activation_ == kTfLiteActRelu6 ? 6 : 1;
76   ThresholdLayerBuilder* threshold_builder =
77       reinterpret_cast<ThresholdLayerBuilder*>(
78           graph_builder_->AddBuilder(CreateThresholdLayerBuilder, nullptr));
79 
80   threshold_builder->SetAlpha(-relu_threshold);
81   threshold_builder->SetScale(-1);
82 
83   threshold_builder->AddInput(AddOutput());
84 
85   ActivationLayerBuilder* negation_builder =
86       reinterpret_cast<ActivationLayerBuilder*>(
87           graph_builder_->AddBuilder(CreateActivationLayerBuilder, nullptr));
88   negation_builder->SetActivation(kTfLiteActNone);
89   negation_builder->SetAlpha(-1);
90 
91   negation_builder->AddInput(threshold_builder->AddOutput());
92   builder_output_ = negation_builder->AddOutput();
93   return kTfLiteOk;
94 }
95 
RegisterInputs(const TfLiteIntArray * inputs,TfLiteContext * context)96 TfLiteStatus ActivationLayerBuilder::RegisterInputs(
97     const TfLiteIntArray* inputs, TfLiteContext* context) {
98   if (inputs->size != 1) {
99     TF_LITE_KERNEL_LOG(context, "Activation: Wrong # of inputs!.");
100     return kTfLiteError;
101   }
102   AddInput(inputs->data[0]);
103   return kTfLiteOk;
104 }
105 
RegisterOutputs(const TfLiteIntArray * outputs,TfLiteContext * context)106 TfLiteStatus ActivationLayerBuilder::RegisterOutputs(
107     const TfLiteIntArray* outputs, TfLiteContext* context) {
108   if (outputs->size != 1) {
109     TF_LITE_KERNEL_LOG(context, "Activation: Wrong # of outputs!.");
110     return kTfLiteError;
111   }
112   graph_builder_->AddTensorWithID(outputs->data[0], GetOutput(context));
113   return kTfLiteOk;
114 }
115 
CreateActivationLayerBuilder(GraphBuilder * graph_builder)116 OpBuilder* CreateActivationLayerBuilder(GraphBuilder* graph_builder) {
117   return new ActivationLayerBuilder(graph_builder);
118 }
119 
CreateLogisticOpBuilder(GraphBuilder * graph_builder)120 OpBuilder* CreateLogisticOpBuilder(GraphBuilder* graph_builder) {
121   return new ActivationLayerBuilder(graph_builder, kTfLiteActSigmoid);
122 }
123 
CreateReluOpBuilder(GraphBuilder * graph_builder)124 OpBuilder* CreateReluOpBuilder(GraphBuilder* graph_builder) {
125   return new ActivationLayerBuilder(graph_builder, kTfLiteActRelu);
126 }
127 
CreateReluN1To1OpBuilder(GraphBuilder * graph_builder)128 OpBuilder* CreateReluN1To1OpBuilder(GraphBuilder* graph_builder) {
129   return new ActivationLayerBuilder(graph_builder, kTfLiteActReluN1To1);
130 }
131 
CreateRelu6OpBuilder(GraphBuilder * graph_builder)132 OpBuilder* CreateRelu6OpBuilder(GraphBuilder* graph_builder) {
133   return new ActivationLayerBuilder(graph_builder, kTfLiteActRelu6);
134 }
135 
CreateTanhOpBuilder(GraphBuilder * graph_builder)136 OpBuilder* CreateTanhOpBuilder(GraphBuilder* graph_builder) {
137   return new ActivationLayerBuilder(graph_builder, kTfLiteActTanh);
138 }
139 
140 }  // namespace coreml
141 }  // namespace delegates
142 }  // namespace tflite
143