1 /* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #if GOOGLE_CUDA && GOOGLE_TENSORRT
17
18 #include "tensorflow/compiler/tf2tensorrt/convert/op_converter_registry.h"
19 #include "tensorflow/compiler/tf2tensorrt/convert/ops/layer_utils.h"
20
21 namespace tensorflow {
22 namespace tensorrt {
23 namespace convert {
24
BinaryOperationMap()25 const BinaryOperationMapType* BinaryOperationMap() {
26 static const auto* map = new BinaryOperationMapType({
27 {"Add", nvinfer1::ElementWiseOperation::kSUM},
28 {"AddV2", nvinfer1::ElementWiseOperation::kSUM},
29 {"Mul", nvinfer1::ElementWiseOperation::kPROD},
30 {"Sub", nvinfer1::ElementWiseOperation::kSUB},
31 {"Div", nvinfer1::ElementWiseOperation::kDIV},
32 {"FloorDiv", nvinfer1::ElementWiseOperation::kFLOOR_DIV},
33 {"RealDiv", nvinfer1::ElementWiseOperation::kDIV},
34 {"Minimum", nvinfer1::ElementWiseOperation::kMIN},
35 {"Maximum", nvinfer1::ElementWiseOperation::kMAX},
36 {"Pow", nvinfer1::ElementWiseOperation::kPOW},
37 #if IS_TRT_VERSION_GE(8, 2, 0, 0)
38 {"Greater", nvinfer1::ElementWiseOperation::kGREATER},
39 {"Less", nvinfer1::ElementWiseOperation::kLESS},
40 {"Equal", nvinfer1::ElementWiseOperation::kEQUAL},
41 // Operators are implemented as NOT Less and NOT Greater, respectively.
42 {"GreaterEqual", nvinfer1::ElementWiseOperation::kLESS},
43 {"LessEqual", nvinfer1::ElementWiseOperation::kGREATER},
44 #endif
45 });
46 return map;
47 }
48
BinaryBooleanOperationMap()49 const BinaryOperationMapType* BinaryBooleanOperationMap() {
50 static const auto* map = new BinaryOperationMapType({
51 {"LogicalOr", nvinfer1::ElementWiseOperation::kOR},
52 {"LogicalAnd", nvinfer1::ElementWiseOperation::kAND},
53 });
54 return map;
55 }
56
57 namespace {
58 class ConvertBinaryImpl {
59 protected:
ConvertBinaryImpl(const BinaryOperationMapType * pOperMap)60 ConvertBinaryImpl(const BinaryOperationMapType* pOperMap)
61 : pOperMap_(pOperMap) {}
62
ValidateImpl(const OpConverterParams & params,const std::vector<string> & implicit_batch_not_supported_ops={},bool both_tensors=false)63 Status ValidateImpl(
64 const OpConverterParams& params,
65 const std::vector<string>& implicit_batch_not_supported_ops = {},
66 bool both_tensors = false) {
67 const auto& node_def = params.node_def;
68 const auto op = node_def.op();
69 const auto op_pair = pOperMap_->find(op);
70 if (op_pair == pOperMap_->end()) {
71 return errors::Unimplemented("Binary op: ", op, " not supported");
72 }
73
74 // Constant folding should have been done by TensorFlow.
75 const auto& inputs = params.inputs;
76 if (inputs.at(0).is_weights() && inputs.at(1).is_weights()) {
77 return errors::Unimplemented(
78 "Constant folding is falled back to TensorFlow, binary op '", op,
79 "' received both input as constant");
80 }
81
82 if ((convertToBool_ = find_name(op, implicit_batch_not_supported_ops))) {
83 if (params.use_implicit_batch) {
84 return errors::Unimplemented(
85 "Binary op: '", op, "' is not supported in implicit batch mode");
86 }
87 }
88
89 if (both_tensors) {
90 if (inputs.at(0).is_weights() || inputs.at(1).is_weights()) {
91 return errors::InvalidArgument("Both inputs of '", op,
92 "' are expected to be tensors");
93 }
94 // No need to convert the output of "LogicalOr" and "LogicalAnd"
95 convertToBool_ = false;
96 }
97
98 nvinfer1::Dims broadcasted_dims[2];
99 TF_RETURN_IF_ERROR(GetTrtBroadcastShape(
100 inputs.at(0), inputs.at(1), true, params.use_implicit_batch,
101 broadcasted_dims, broadcasted_dims + 1));
102
103 for (int i = 0; i < tensor_.size(); i++) {
104 // This will also convert constants to tensors.
105 TF_RETURN_IF_ERROR(PrepareTensorForShape(
106 params.converter, inputs.at(i), broadcasted_dims[i],
107 params.validation_only, &tensor_[i], node_def, i));
108 }
109 operation_ = op_pair->second;
110 return Status::OK();
111 }
112
ConvertImpl(const OpConverterParams & params,const std::vector<string> & revert_bool_ops={})113 Status ConvertImpl(const OpConverterParams& params,
114 const std::vector<string>& revert_bool_ops = {}) {
115 const auto& node_def = params.node_def;
116 // Add ElementWise layer.
117 auto* network = params.converter->network();
118 nvinfer1::ILayer* layer = network->addElementWise(
119 *tensor_[0]->trt_tensor(), *tensor_[1]->trt_tensor(), operation_);
120 TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name());
121
122 if (params.use_explicit_precision) {
123 layer->setPrecision(nvinfer1::DataType::kFLOAT);
124 }
125
126 params.converter->SetLayerName(layer, node_def);
127 const auto& output = layer->getOutput(0);
128 if (convertToBool_) {
129 output->setType(nvinfer1::DataType::kBOOL);
130 if (find_name(node_def.op(), revert_bool_ops)) {
131 nvinfer1::IUnaryLayer* unary_layer =
132 network->addUnary(*output, nvinfer1::UnaryOperation::kNOT);
133 TFTRT_RETURN_ERROR_IF_NULLPTR(unary_layer, node_def.name());
134 params.outputs->push_back(
135 TRT_TensorOrWeights(unary_layer->getOutput(0)));
136 return Status::OK();
137 }
138 }
139
140 params.outputs->push_back(TRT_TensorOrWeights(output));
141 return Status::OK();
142 }
143
InputSpec()144 static constexpr std::array<InputArgSpec, 2> InputSpec() {
145 return std::array<InputArgSpec, 2>{
146 InputArgSpec::Create("x", TrtInputArg::kBoth),
147 InputArgSpec::Create("y", TrtInputArg::kBoth)};
148 }
149
150 private:
151 const BinaryOperationMapType* pOperMap_;
152 std::array<ITensorProxyPtr, 2> tensor_{nullptr, nullptr};
153 nvinfer1::ElementWiseOperation operation_;
154 bool convertToBool_;
155 };
156
157 class ConvertBinary : public OpConverterBase<ConvertBinary>,
158 protected ConvertBinaryImpl {
159 public:
ConvertBinary(OpConverterParams * params)160 explicit ConvertBinary(OpConverterParams* params)
161 : OpConverterBase<ConvertBinary>(params),
162 ConvertBinaryImpl(BinaryOperationMap()) {}
163
AllowedDataTypes()164 static constexpr std::array<DataType, 3> AllowedDataTypes() {
165 return {DataType::DT_FLOAT, DataType::DT_HALF, DataType::DT_INT32};
166 }
167
InputSpec()168 static constexpr std::array<InputArgSpec, 2> InputSpec() {
169 return ConvertBinaryImpl::InputSpec();
170 }
171
Validate()172 Status Validate() {
173 const std::vector<string> implicit_batch_not_supported_ops {
174 #if IS_TRT_VERSION_GE(8, 2, 0, 0)
175 "Greater", "Less", "Equal", "GreaterEqual", "LessEqual"
176 #endif
177 };
178 return ValidateImpl(*params_, implicit_batch_not_supported_ops);
179 }
Convert()180 Status Convert() {
181 const std::vector<string> implemented_with_reverted_ops {
182 #if IS_TRT_VERSION_GE(8, 2, 0, 0)
183 "GreaterEqual", "LessEqual"
184 #endif
185 };
186 return ConvertImpl(*params_, implemented_with_reverted_ops);
187 }
188 };
189
190 class ConvertBooleanBinary : public OpConverterBase<ConvertBooleanBinary>,
191 public ConvertBinaryImpl {
192 public:
ConvertBooleanBinary(OpConverterParams * params)193 explicit ConvertBooleanBinary(OpConverterParams* params)
194 : OpConverterBase<ConvertBooleanBinary>(params),
195 ConvertBinaryImpl(BinaryBooleanOperationMap()) {}
196
AllowedDataTypes()197 static constexpr std::array<DataType, 1> AllowedDataTypes() {
198 return {DataType::DT_BOOL};
199 }
200
InputSpec()201 static constexpr std::array<InputArgSpec, 2> InputSpec() {
202 return ConvertBinaryImpl::InputSpec();
203 }
204
NodeDefDataTypeAttributeName()205 static constexpr const char* NodeDefDataTypeAttributeName() { return ""; }
Validate()206 Status Validate() {
207 #if IS_TRT_VERSION_GE(8, 2, 0, 0)
208 return ValidateImpl(*params_, {"LogicalOr", "LogicalAnd"}, true);
209 #else
210 return errors::Unimplemented("Boolean op: ", params_->node_def.op(),
211 " is not supported in TRT version < 8.2");
212 #endif
213 }
Convert()214 Status Convert() { return ConvertImpl(*params_); }
215 };
216 } // namespace
217
// Register each converter for every op name present in its operation map,
// so the supported-op lists stay in sync with the maps above.
REGISTER_DEFAULT_TRT_OP_CONVERTER(MakeConverterFunction<ConvertBinary>(),
                                  GetOperationNames(*BinaryOperationMap()));
REGISTER_DEFAULT_TRT_OP_CONVERTER(
    MakeConverterFunction<ConvertBooleanBinary>(),
    GetOperationNames(*BinaryBooleanOperationMap()));
223
224 } // namespace convert
225 } // namespace tensorrt
226 } // namespace tensorflow
227 #endif // GOOGLE_CUDA && GOOGLE_TENSORRT
228