xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/tf2tensorrt/convert/ops/unary_ops.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #if GOOGLE_CUDA && GOOGLE_TENSORRT
17 
18 #include "tensorflow/compiler/tf2tensorrt/convert/op_converter_registry.h"
19 #include "tensorflow/compiler/tf2tensorrt/convert/ops/layer_utils.h"
20 
21 namespace tensorflow {
22 namespace tensorrt {
23 namespace convert {
24 
UnaryOperationMap()25 const UnaryOperationMapType* UnaryOperationMap() {
26   static auto* const m =
27       new std::unordered_map<string, nvinfer1::UnaryOperation>({
28         {"Neg", nvinfer1::UnaryOperation::kNEG},
29             {"Exp", nvinfer1::UnaryOperation::kEXP},
30             {"Log", nvinfer1::UnaryOperation::kLOG},
31             {"Rsqrt", nvinfer1::UnaryOperation::kSQRT},
32             {"Sqrt", nvinfer1::UnaryOperation::kSQRT},
33             {"Abs", nvinfer1::UnaryOperation::kABS},
34             {"Reciprocal", nvinfer1::UnaryOperation::kRECIP},
35             {"Sin", nvinfer1::UnaryOperation::kSIN},
36             {"Cos", nvinfer1::UnaryOperation::kCOS},
37             {"Tan", nvinfer1::UnaryOperation::kTAN},
38             {"Sinh", nvinfer1::UnaryOperation::kSINH},
39             {"Cosh", nvinfer1::UnaryOperation::kCOSH},
40             {"Asin", nvinfer1::UnaryOperation::kASIN},
41             {"Acos", nvinfer1::UnaryOperation::kACOS},
42             {"Atan", nvinfer1::UnaryOperation::kATAN},
43             {"Asinh", nvinfer1::UnaryOperation::kASINH},
44             {"Acosh", nvinfer1::UnaryOperation::kACOSH},
45             {"Atanh", nvinfer1::UnaryOperation::kATANH},
46             {"Ceil", nvinfer1::UnaryOperation::kCEIL},
47             {"Floor", nvinfer1::UnaryOperation::kFLOOR},
48             {"Erf", nvinfer1::UnaryOperation::kERF},
49 #if IS_TRT_VERSION_GE(8, 2, 0, 0)
50             {"Round", nvinfer1::UnaryOperation::kROUND},
51             {"Sign", nvinfer1::UnaryOperation::kSIGN},
52 #endif
53       });
54   return m;
55 }
56 
UnaryBooleanOperationMap()57 const UnaryOperationMapType* UnaryBooleanOperationMap() {
58   static auto* const m = new UnaryOperationMapType({
59       {"LogicalNot", nvinfer1::UnaryOperation::kNOT},
60   });
61   return m;
62 }
63 
ActivationTypeMap()64 const OperationMap<nvinfer1::ActivationType>* ActivationTypeMap() {
65   static auto* const m =
66       new std::unordered_map<string, nvinfer1::ActivationType>({
67           {"LeakyRelu", nvinfer1::ActivationType::kLEAKY_RELU},
68           {"Relu", nvinfer1::ActivationType::kRELU},
69           {"Relu6", nvinfer1::ActivationType::kCLIP},
70           {"Sigmoid", nvinfer1::ActivationType::kSIGMOID},
71           {"Tanh", nvinfer1::ActivationType::kTANH},
72           {"Elu", nvinfer1::ActivationType::kELU},
73           {"Selu", nvinfer1::ActivationType::kSELU},
74           {"Softsign", nvinfer1::ActivationType::kSOFTSIGN},
75           {"Softplus", nvinfer1::ActivationType::kSOFTPLUS},
76       });
77   return m;
78 }
79 
80 template <typename T>
81 class ConvertUnaryImpl {
82  protected:
ConvertUnaryImpl(const OperationMap<T> * pOperMap)83   ConvertUnaryImpl(const OperationMap<T>* pOperMap) : pOperMap_(pOperMap) {}
84 
ValidateImpl(const OpConverterParams & params,const std::vector<string> & not_supported_ops={})85   Status ValidateImpl(const OpConverterParams& params,
86                       const std::vector<string>& not_supported_ops = {}) {
87     const auto& op = params.node_def.op();
88     if (pOperMap_->find(op) == pOperMap_->end()) {
89       return errors::Unimplemented("Unary op: ", op, " not supported");
90     }
91     DimsAdapter input_dims(params.inputs.at(0).GetTrtDims());
92     if (!input_dims.NumDims()) {
93       return errors::InvalidArgument(
94           "At least 1 dimension is required for UNARY operation '", op, "'");
95     }
96 
97     if (!not_supported_ops.empty() && params.use_implicit_batch) {
98       const auto& end = not_supported_ops.end();
99       if (std::find(not_supported_ops.begin(), end, op) != end) {
100         return errors::Unimplemented(
101             "Unary op: '", op, "' is not supported in implicit batch mode");
102       }
103     }
104 
105     return Status::OK();
106   }
107 
ConvertImpl(const OpConverterParams & params)108   Status ConvertImpl(const OpConverterParams& params) {
109     const auto& node_def = params.node_def;
110     auto* converter = params.converter;
111     const auto op_pair = pOperMap_->find(node_def.op());
112     ITensorProxyPtr tensor = params.inputs.at(0).tensor();
113     nvinfer1::IUnaryLayer* layer =
114         converter->network()->addUnary(*tensor->trt_tensor(), op_pair->second);
115     TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name());
116     converter->SetLayerName(layer, node_def);
117     if (node_def.op() == "Rsqrt") {
118       layer = converter->network()->addUnary(*layer->getOutput(0),
119                                              nvinfer1::UnaryOperation::kRECIP);
120       TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name());
121       converter->SetLayerName(layer, node_def, "recip");
122     }
123     params.outputs->push_back(TRT_TensorOrWeights(layer->getOutput(0)));
124     return Status::OK();
125   }
InputSpec()126   static constexpr std::array<InputArgSpec, 1> InputSpec() {
127     return std::array<InputArgSpec, 1>{
128         InputArgSpec::Create("x", TrtInputArg::kTensor)};
129   }
130 
131  protected:
132   const OperationMap<T>* pOperMap_;
133 };
134 
135 class ConvertUnary : public OpConverterBase<ConvertUnary>,
136                      protected ConvertUnaryImpl<nvinfer1::UnaryOperation> {
137  public:
ConvertUnary(OpConverterParams * params)138   explicit ConvertUnary(OpConverterParams* params)
139       : OpConverterBase<ConvertUnary>(params),
140         ConvertUnaryImpl(UnaryOperationMap()) {}
141 
AllowedDataTypes()142   static constexpr std::array<DataType, 2> AllowedDataTypes() {
143     return {DataType::DT_FLOAT, DataType::DT_HALF};
144   }
145 
InputSpec()146   static constexpr std::array<InputArgSpec, 1> InputSpec() {
147     return ConvertUnaryImpl::InputSpec();
148   }
149 
NodeDefDataTypeAttributeName()150   static constexpr const char* NodeDefDataTypeAttributeName() { return ""; }
Validate()151   Status Validate() { return ValidateImpl(*params_, {"Sign", "Round"}); }
Convert()152   Status Convert() { return ConvertImpl(*params_); }
153 };
154 
155 class ConvertBooleanUnary : public OpConverterBase<ConvertBooleanUnary>,
156                             public ConvertUnaryImpl<nvinfer1::UnaryOperation> {
157  public:
ConvertBooleanUnary(OpConverterParams * params)158   explicit ConvertBooleanUnary(OpConverterParams* params)
159       : OpConverterBase<ConvertBooleanUnary>(params),
160         ConvertUnaryImpl(UnaryBooleanOperationMap()) {}
161 
AllowedDataTypes()162   static constexpr std::array<DataType, 1> AllowedDataTypes() {
163     return {DataType::DT_BOOL};
164   }
165 
InputSpec()166   static constexpr std::array<InputArgSpec, 1> InputSpec() {
167     return ConvertUnaryImpl::InputSpec();
168   }
169 
NodeDefDataTypeAttributeName()170   static constexpr const char* NodeDefDataTypeAttributeName() { return ""; }
Validate()171   Status Validate() {
172 #if IS_TRT_VERSION_GE(8, 2, 0, 0)
173     return ValidateImpl(*params_, {"LogicalNot"});
174 #else
175     return errors::Unimplemented("Boolean op: ", params_->node_def.op(),
176                                  " is not supported in TRT version < 8.2");
177 #endif
178   }
Convert()179   Status Convert() { return ConvertImpl(*params_); }
180 };
181 
182 class ConvertActivation : public OpConverterBase<ConvertActivation>,
183                           protected ConvertUnaryImpl<nvinfer1::ActivationType> {
184  public:
ConvertActivation(OpConverterParams * params)185   explicit ConvertActivation(OpConverterParams* params)
186       : OpConverterBase<ConvertActivation>(params),
187         ConvertUnaryImpl(ActivationTypeMap()) {}
188 
AllowedDataTypes()189   static constexpr std::array<DataType, 2> AllowedDataTypes() {
190     return {DataType::DT_FLOAT, DataType::DT_HALF};
191   }
192 
InputSpec()193   static constexpr std::array<InputArgSpec, 1> InputSpec() {
194     return std::array<InputArgSpec, 1>{
195         InputArgSpec::Create("input", TrtInputArg::kTensor)};
196   }
197 
NodeDefDataTypeAttributeName()198   static constexpr const char* NodeDefDataTypeAttributeName() { return ""; }
Validate()199   Status Validate() {
200     TF_RETURN_IF_ERROR(ValidateImpl(*params_));
201     const auto& node_def = params_->node_def;
202     if (node_def.op() == "LeakyRelu") {
203       return GetNodeAttr(AttrSlice(node_def), "alpha", &alpha_);
204     }
205     alpha_ = 1.0f;
206     return Status::OK();
207   }
Convert()208   Status Convert() {
209     auto* converter = params_->converter;
210     const auto& inputs = params_->inputs;
211     const auto& node_def = params_->node_def;
212     const auto& op = node_def.op();
213     const auto op_pair = pOperMap_->find(op);
214     nvinfer1::IActivationLayer* layer = converter->network()->addActivation(
215         *inputs.at(0).tensor()->trt_tensor(), op_pair->second);
216     TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name());
217     converter->SetLayerName(layer, node_def, "activation");
218     ITensorProxyPtr output_tensor = layer->getOutput(0);
219     // Set parameters.
220     if (op == "Selu") {
221       // From tensorflow/core/kernels/relu_op_functor.h
222       alpha_ = 1.7580993408473768599402175208123f;
223       layer->setBeta(1.0507009873554804934193349852946f);
224     } else if (op == "Softplus") {
225       layer->setBeta(1.0f);
226     } else if (op == "Relu6") {
227       layer->setBeta(6.0f);
228       converter->ProvideQuantizationRange(&output_tensor, alpha_ = 0.0f, 6.0f);
229     }
230     layer->setAlpha(alpha_);
231     params_->outputs->push_back(TRT_TensorOrWeights(output_tensor));
232     return Status::OK();
233   }
234 
235  private:
236   float alpha_ = 0.f;
237 };
238 
239 REGISTER_DEFAULT_TRT_OP_CONVERTER(MakeConverterFunction<ConvertUnary>(),
240                                   GetOperationNames(*UnaryOperationMap()));
241 REGISTER_DEFAULT_TRT_OP_CONVERTER(
242     MakeConverterFunction<ConvertBooleanUnary>(),
243     GetOperationNames(*UnaryBooleanOperationMap()));
244 
245 REGISTER_DEFAULT_TRT_OP_CONVERTER(MakeConverterFunction<ConvertActivation>(),
246                                   GetOperationNames(*ActivationTypeMap()));
247 }  // namespace convert
248 }  // namespace tensorrt
249 }  // namespace tensorflow
250 #endif  // GOOGLE_CUDA && GOOGLE_TENSORRT
251