/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#if GOOGLE_CUDA && GOOGLE_TENSORRT

#include "tensorflow/compiler/tf2tensorrt/convert/op_converter_registry.h"
#include "tensorflow/compiler/tf2tensorrt/convert/ops/layer_utils.h"

namespace tensorflow {
namespace tensorrt {
namespace convert {

const UnaryOperationMapType* UnaryOperationMap() {
  static auto* const m = new UnaryOperationMapType({
      {"Neg", nvinfer1::UnaryOperation::kNEG},
      {"Exp", nvinfer1::UnaryOperation::kEXP},
      {"Log", nvinfer1::UnaryOperation::kLOG},
31 {"Rsqrt", nvinfer1::UnaryOperation::kSQRT},
32 {"Sqrt", nvinfer1::UnaryOperation::kSQRT},
33 {"Abs", nvinfer1::UnaryOperation::kABS},
34 {"Reciprocal", nvinfer1::UnaryOperation::kRECIP},
35 {"Sin", nvinfer1::UnaryOperation::kSIN},
36 {"Cos", nvinfer1::UnaryOperation::kCOS},
37 {"Tan", nvinfer1::UnaryOperation::kTAN},
38 {"Sinh", nvinfer1::UnaryOperation::kSINH},
39 {"Cosh", nvinfer1::UnaryOperation::kCOSH},
40 {"Asin", nvinfer1::UnaryOperation::kASIN},
41 {"Acos", nvinfer1::UnaryOperation::kACOS},
42 {"Atan", nvinfer1::UnaryOperation::kATAN},
43 {"Asinh", nvinfer1::UnaryOperation::kASINH},
44 {"Acosh", nvinfer1::UnaryOperation::kACOSH},
45 {"Atanh", nvinfer1::UnaryOperation::kATANH},
46 {"Ceil", nvinfer1::UnaryOperation::kCEIL},
47 {"Floor", nvinfer1::UnaryOperation::kFLOOR},
48 {"Erf", nvinfer1::UnaryOperation::kERF},
#if IS_TRT_VERSION_GE(8, 2, 0, 0)
      {"Round", nvinfer1::UnaryOperation::kROUND},
      {"Sign", nvinfer1::UnaryOperation::kSIGN},
#endif
  });
  return m;
}

const UnaryOperationMapType* UnaryBooleanOperationMap() {
  static auto* const m = new UnaryOperationMapType({
      {"LogicalNot", nvinfer1::UnaryOperation::kNOT},
  });
  return m;
}

const OperationMap<nvinfer1::ActivationType>* ActivationTypeMap() {
  static auto* const m = new OperationMap<nvinfer1::ActivationType>({
      {"LeakyRelu", nvinfer1::ActivationType::kLEAKY_RELU},
      {"Relu", nvinfer1::ActivationType::kRELU},
69 {"Relu6", nvinfer1::ActivationType::kCLIP},
70 {"Sigmoid", nvinfer1::ActivationType::kSIGMOID},
71 {"Tanh", nvinfer1::ActivationType::kTANH},
72 {"Elu", nvinfer1::ActivationType::kELU},
73 {"Selu", nvinfer1::ActivationType::kSELU},
74 {"Softsign", nvinfer1::ActivationType::kSOFTSIGN},
75 {"Softplus", nvinfer1::ActivationType::kSOFTPLUS},
76 });
77 return m;
78 }

template <typename T>
class ConvertUnaryImpl {
 protected:
  ConvertUnaryImpl(const OperationMap<T>* pOperMap) : pOperMap_(pOperMap) {}

  Status ValidateImpl(const OpConverterParams& params,
                      const std::vector<string>& not_supported_ops = {}) {
    const auto& op = params.node_def.op();
    if (pOperMap_->find(op) == pOperMap_->end()) {
      return errors::Unimplemented("Unary op: ", op, " not supported");
    }
    DimsAdapter input_dims(params.inputs.at(0).GetTrtDims());
    if (!input_dims.NumDims()) {
      return errors::InvalidArgument(
          "At least 1 dimension is required for UNARY operation '", op, "'");
    }

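    // Some ops cannot be handled when the batch dimension is implicit;
    // callers list those ops in `not_supported_ops`.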
    if (!not_supported_ops.empty() && params.use_implicit_batch) {
      const auto& end = not_supported_ops.end();
      if (std::find(not_supported_ops.begin(), end, op) != end) {
        return errors::Unimplemented(
            "Unary op: '", op, "' is not supported in implicit batch mode");
      }
    }

    return Status::OK();
  }

  Status ConvertImpl(const OpConverterParams& params) {
    const auto& node_def = params.node_def;
    auto* converter = params.converter;
    const auto op_pair = pOperMap_->find(node_def.op());
    ITensorProxyPtr tensor = params.inputs.at(0).tensor();
    nvinfer1::IUnaryLayer* layer =
        converter->network()->addUnary(*tensor->trt_tensor(), op_pair->second);
    TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name());
    converter->SetLayerName(layer, node_def);
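    // "Rsqrt" is lowered as 1 / sqrt(x): the kSQRT layer added above is
    // followed by a kRECIP layer.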
    if (node_def.op() == "Rsqrt") {
      layer = converter->network()->addUnary(*layer->getOutput(0),
                                             nvinfer1::UnaryOperation::kRECIP);
      TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name());
      converter->SetLayerName(layer, node_def, "recip");
    }
    params.outputs->push_back(TRT_TensorOrWeights(layer->getOutput(0)));
    return Status::OK();
  }

  static constexpr std::array<InputArgSpec, 1> InputSpec() {
    return std::array<InputArgSpec, 1>{
        InputArgSpec::Create("x", TrtInputArg::kTensor)};
  }

 protected:
  const OperationMap<T>* pOperMap_;
};

class ConvertUnary : public OpConverterBase<ConvertUnary>,
                     protected ConvertUnaryImpl<nvinfer1::UnaryOperation> {
 public:
  explicit ConvertUnary(OpConverterParams* params)
      : OpConverterBase<ConvertUnary>(params),
        ConvertUnaryImpl(UnaryOperationMap()) {}

  static constexpr std::array<DataType, 2> AllowedDataTypes() {
    return {DataType::DT_FLOAT, DataType::DT_HALF};
  }

  static constexpr std::array<InputArgSpec, 1> InputSpec() {
    return ConvertUnaryImpl::InputSpec();
  }

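  // An empty attribute name signals the op-converter base class that this op
  // carries no data type attribute to validate.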
  static constexpr const char* NodeDefDataTypeAttributeName() { return ""; }
  Status Validate() { return ValidateImpl(*params_, {"Sign", "Round"}); }
  Status Convert() { return ConvertImpl(*params_); }
};

class ConvertBooleanUnary : public OpConverterBase<ConvertBooleanUnary>,
                            public ConvertUnaryImpl<nvinfer1::UnaryOperation> {
 public:
  explicit ConvertBooleanUnary(OpConverterParams* params)
      : OpConverterBase<ConvertBooleanUnary>(params),
        ConvertUnaryImpl(UnaryBooleanOperationMap()) {}

  static constexpr std::array<DataType, 1> AllowedDataTypes() {
    return {DataType::DT_BOOL};
  }

  static constexpr std::array<InputArgSpec, 1> InputSpec() {
    return ConvertUnaryImpl::InputSpec();
  }

  static constexpr const char* NodeDefDataTypeAttributeName() { return ""; }
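  // kNOT requires TensorRT 8.2+; LogicalNot is additionally rejected in
  // implicit batch mode by ValidateImpl.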
  Status Validate() {
#if IS_TRT_VERSION_GE(8, 2, 0, 0)
    return ValidateImpl(*params_, {"LogicalNot"});
#else
    return errors::Unimplemented("Boolean op: ", params_->node_def.op(),
                                 " is not supported in TRT version < 8.2");
#endif
  }
  Status Convert() { return ConvertImpl(*params_); }
};

class ConvertActivation
    : public OpConverterBase<ConvertActivation>,
      protected ConvertUnaryImpl<nvinfer1::ActivationType> {
 public:
  explicit ConvertActivation(OpConverterParams* params)
      : OpConverterBase<ConvertActivation>(params),
        ConvertUnaryImpl(ActivationTypeMap()) {}

  static constexpr std::array<DataType, 2> AllowedDataTypes() {
    return {DataType::DT_FLOAT, DataType::DT_HALF};
  }

  static constexpr std::array<InputArgSpec, 1> InputSpec() {
    return std::array<InputArgSpec, 1>{
        InputArgSpec::Create("input", TrtInputArg::kTensor)};
  }

  static constexpr const char* NodeDefDataTypeAttributeName() { return ""; }
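  // LeakyRelu is the only op here with a relevant attribute: its negative
  // slope is read from the "alpha" attribute during validation.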
  Status Validate() {
    TF_RETURN_IF_ERROR(ValidateImpl(*params_));
    const auto& node_def = params_->node_def;
    if (node_def.op() == "LeakyRelu") {
      return GetNodeAttr(AttrSlice(node_def), "alpha", &alpha_);
    }
    alpha_ = 1.0f;
    return Status::OK();
  }

  Status Convert() {
    auto* converter = params_->converter;
    const auto& inputs = params_->inputs;
    const auto& node_def = params_->node_def;
    const auto& op = node_def.op();
    const auto op_pair = pOperMap_->find(op);
    nvinfer1::IActivationLayer* layer = converter->network()->addActivation(
        *inputs.at(0).tensor()->trt_tensor(), op_pair->second);
    TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name());
    converter->SetLayerName(layer, node_def, "activation");
    ITensorProxyPtr output_tensor = layer->getOutput(0);
    // Set parameters.
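    // TensorRT interprets alpha/beta per activation type: kSELU uses them as
    // the SELU scale constants, kCLIP clips to the range [alpha, beta], and
    // kSOFTPLUS computes alpha * log(exp(beta * x) + 1).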
220 if (op == "Selu") {
221 // From tensorflow/core/kernels/relu_op_functor.h
222 alpha_ = 1.7580993408473768599402175208123f;
223 layer->setBeta(1.0507009873554804934193349852946f);
224 } else if (op == "Softplus") {
225 layer->setBeta(1.0f);
226 } else if (op == "Relu6") {
227 layer->setBeta(6.0f);
228 converter->ProvideQuantizationRange(&output_tensor, alpha_ = 0.0f, 6.0f);
229 }
230 layer->setAlpha(alpha_);
231 params_->outputs->push_back(TRT_TensorOrWeights(output_tensor));
232 return Status::OK();
233 }
234
235 private:
236 float alpha_ = 0.f;
237 };
238
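// Register each converter for every op name present in its operation map.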
REGISTER_DEFAULT_TRT_OP_CONVERTER(MakeConverterFunction<ConvertUnary>(),
                                  GetOperationNames(*UnaryOperationMap()));
REGISTER_DEFAULT_TRT_OP_CONVERTER(
    MakeConverterFunction<ConvertBooleanUnary>(),
    GetOperationNames(*UnaryBooleanOperationMap()));

REGISTER_DEFAULT_TRT_OP_CONVERTER(MakeConverterFunction<ConvertActivation>(),
                                  GetOperationNames(*ActivationTypeMap()));

}  // namespace convert
}  // namespace tensorrt
}  // namespace tensorflow

#endif  // GOOGLE_CUDA && GOOGLE_TENSORRT