/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#if GOOGLE_CUDA && GOOGLE_TENSORRT

#include "tensorflow/compiler/tf2tensorrt/convert/op_converter_registry.h"
#include "tensorflow/compiler/tf2tensorrt/convert/ops/layer_utils.h"

namespace tensorflow {
namespace tensorrt {
namespace convert {

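// Maps TensorFlow binary op names to the TensorRT ElementWise operation used
// to implement them. Entries guarded by IS_TRT_VERSION_GE(8, 2, 0, 0) rely on
// comparison/boolean support available from TensorRT 8.2 onward.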
const BinaryOperationMapType* BinaryOperationMap() {
  static const auto* map = new BinaryOperationMapType({
      {"Add", nvinfer1::ElementWiseOperation::kSUM},
      {"AddV2", nvinfer1::ElementWiseOperation::kSUM},
      {"Mul", nvinfer1::ElementWiseOperation::kPROD},
      {"Sub", nvinfer1::ElementWiseOperation::kSUB},
      {"Div", nvinfer1::ElementWiseOperation::kDIV},
      {"FloorDiv", nvinfer1::ElementWiseOperation::kFLOOR_DIV},
      {"RealDiv", nvinfer1::ElementWiseOperation::kDIV},
      {"Minimum", nvinfer1::ElementWiseOperation::kMIN},
      {"Maximum", nvinfer1::ElementWiseOperation::kMAX},
      {"Pow", nvinfer1::ElementWiseOperation::kPOW},
#if IS_TRT_VERSION_GE(8, 2, 0, 0)
      {"Greater", nvinfer1::ElementWiseOperation::kGREATER},
      {"Less", nvinfer1::ElementWiseOperation::kLESS},
      {"Equal", nvinfer1::ElementWiseOperation::kEQUAL},
      // GreaterEqual and LessEqual are implemented as NOT Less and NOT
      // Greater, respectively (see ConvertImpl below).
      {"GreaterEqual", nvinfer1::ElementWiseOperation::kLESS},
      {"LessEqual", nvinfer1::ElementWiseOperation::kGREATER},
#endif
  });
  return map;
}

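// Maps TensorFlow logical (boolean) binary op names to their TensorRT
// ElementWise operations. Kept separate from BinaryOperationMap() because
// these ops take and produce DT_BOOL tensors and are validated differently
// (see ConvertBooleanBinary below).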
const BinaryOperationMapType* BinaryBooleanOperationMap() {
  static const auto* map = new BinaryOperationMapType({
      {"LogicalOr", nvinfer1::ElementWiseOperation::kOR},
      {"LogicalAnd", nvinfer1::ElementWiseOperation::kAND},
  });
  return map;
}

namespace {
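// Shared implementation for the binary and boolean-binary converters below.
// ValidateImpl() checks the inputs and prepares (broadcasts) them;
// ConvertImpl() then emits the TensorRT layers.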
class ConvertBinaryImpl {
 protected:
  explicit ConvertBinaryImpl(const BinaryOperationMapType* pOperMap)
      : pOperMap_(pOperMap) {}

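  // Validates the op and its two inputs, computes the broadcasted shapes, and
  // converts constant inputs to tensors. When `both_tensors` is true (used by
  // the boolean ops), weight inputs are rejected entirely. Sets operation_
  // and convertToBool_ for ConvertImpl().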
  Status ValidateImpl(
      const OpConverterParams& params,
      const std::vector<string>& implicit_batch_not_supported_ops = {},
      bool both_tensors = false) {
    const auto& node_def = params.node_def;
    const auto op = node_def.op();
    const auto op_pair = pOperMap_->find(op);
    if (op_pair == pOperMap_->end()) {
      return errors::Unimplemented("Binary op: ", op, " not supported");
    }

    // Constant folding should have been done by TensorFlow.
    const auto& inputs = params.inputs;
    if (inputs.at(0).is_weights() && inputs.at(1).is_weights()) {
      return errors::Unimplemented(
          "Constant folding is left to TensorFlow; binary op '", op,
          "' received both inputs as constants");
    }

    if ((convertToBool_ = find_name(op, implicit_batch_not_supported_ops))) {
      if (params.use_implicit_batch) {
        return errors::Unimplemented(
            "Binary op: '", op, "' is not supported in implicit batch mode");
      }
    }

    if (both_tensors) {
      if (inputs.at(0).is_weights() || inputs.at(1).is_weights()) {
        return errors::InvalidArgument("Both inputs of '", op,
                                       "' are expected to be tensors");
      }
      // The outputs of "LogicalOr" and "LogicalAnd" are already boolean, so
      // no conversion of the output type is needed.
      convertToBool_ = false;
    }

    nvinfer1::Dims broadcasted_dims[2];
    TF_RETURN_IF_ERROR(GetTrtBroadcastShape(
        inputs.at(0), inputs.at(1), true, params.use_implicit_batch,
        broadcasted_dims, broadcasted_dims + 1));

    for (int i = 0; i < tensor_.size(); i++) {
      // This will also convert constants to tensors.
      TF_RETURN_IF_ERROR(PrepareTensorForShape(
          params.converter, inputs.at(i), broadcasted_dims[i],
          params.validation_only, &tensor_[i], node_def, i));
    }
    operation_ = op_pair->second;
    return Status::OK();
  }

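  // Emits the ElementWise layer for the validated operation and records its
  // output. For ops listed in `revert_bool_ops` ("GreaterEqual"/"LessEqual"),
  // the underlying comparison is negated with a unary NOT layer, roughly:
  //   cmp = network->addElementWise(x, y, ElementWiseOperation::kLESS);
  //   out = network->addUnary(*cmp->getOutput(0), UnaryOperation::kNOT);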
  Status ConvertImpl(const OpConverterParams& params,
                     const std::vector<string>& revert_bool_ops = {}) {
    const auto& node_def = params.node_def;
    // Add the ElementWise layer.
    auto* network = params.converter->network();
    nvinfer1::ILayer* layer = network->addElementWise(
        *tensor_[0]->trt_tensor(), *tensor_[1]->trt_tensor(), operation_);
    TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name());

    if (params.use_explicit_precision) {
      layer->setPrecision(nvinfer1::DataType::kFLOAT);
    }

    params.converter->SetLayerName(layer, node_def);
    const auto& output = layer->getOutput(0);
    if (convertToBool_) {
      output->setType(nvinfer1::DataType::kBOOL);
      if (find_name(node_def.op(), revert_bool_ops)) {
        nvinfer1::IUnaryLayer* unary_layer =
            network->addUnary(*output, nvinfer1::UnaryOperation::kNOT);
        TFTRT_RETURN_ERROR_IF_NULLPTR(unary_layer, node_def.name());
        params.outputs->push_back(
            TRT_TensorOrWeights(unary_layer->getOutput(0)));
        return Status::OK();
      }
    }

    params.outputs->push_back(TRT_TensorOrWeights(output));
    return Status::OK();
  }

  static constexpr std::array<InputArgSpec, 2> InputSpec() {
    return std::array<InputArgSpec, 2>{
        InputArgSpec::Create("x", TrtInputArg::kBoth),
        InputArgSpec::Create("y", TrtInputArg::kBoth)};
  }

 private:
  const BinaryOperationMapType* pOperMap_;
  std::array<ITensorProxyPtr, 2> tensor_{nullptr, nullptr};
  nvinfer1::ElementWiseOperation operation_;
  bool convertToBool_ = false;
};

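// Converter for arithmetic and comparison binary ops (see BinaryOperationMap).
// The comparison ops require TensorRT >= 8.2 and explicit batch mode.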
class ConvertBinary : public OpConverterBase<ConvertBinary>,
                      protected ConvertBinaryImpl {
 public:
  explicit ConvertBinary(OpConverterParams* params)
      : OpConverterBase<ConvertBinary>(params),
        ConvertBinaryImpl(BinaryOperationMap()) {}

  static constexpr std::array<DataType, 3> AllowedDataTypes() {
    return {DataType::DT_FLOAT, DataType::DT_HALF, DataType::DT_INT32};
  }

  static constexpr std::array<InputArgSpec, 2> InputSpec() {
    return ConvertBinaryImpl::InputSpec();
  }

  Status Validate() {
    const std::vector<string> implicit_batch_not_supported_ops {
#if IS_TRT_VERSION_GE(8, 2, 0, 0)
      "Greater", "Less", "Equal", "GreaterEqual", "LessEqual"
#endif
    };
    return ValidateImpl(*params_, implicit_batch_not_supported_ops);
  }

  Status Convert() {
    const std::vector<string> implemented_with_reverted_ops {
#if IS_TRT_VERSION_GE(8, 2, 0, 0)
      "GreaterEqual", "LessEqual"
#endif
    };
    return ConvertImpl(*params_, implemented_with_reverted_ops);
  }
};

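// Converter for LogicalOr/LogicalAnd. Both inputs must already be boolean
// tensors (not weights), and TensorRT >= 8.2 is required.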
class ConvertBooleanBinary : public OpConverterBase<ConvertBooleanBinary>,
                             public ConvertBinaryImpl {
 public:
  explicit ConvertBooleanBinary(OpConverterParams* params)
      : OpConverterBase<ConvertBooleanBinary>(params),
        ConvertBinaryImpl(BinaryBooleanOperationMap()) {}

  static constexpr std::array<DataType, 1> AllowedDataTypes() {
    return {DataType::DT_BOOL};
  }

  static constexpr std::array<InputArgSpec, 2> InputSpec() {
    return ConvertBinaryImpl::InputSpec();
  }

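  // LogicalOr/LogicalAnd carry no data type attribute on their NodeDef; an
  // empty name is returned so the base class does not attempt the
  // attribute-based type lookup.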
  static constexpr const char* NodeDefDataTypeAttributeName() { return ""; }

  Status Validate() {
#if IS_TRT_VERSION_GE(8, 2, 0, 0)
    return ValidateImpl(*params_, {"LogicalOr", "LogicalAnd"}, true);
#else
    return errors::Unimplemented("Boolean op: ", params_->node_def.op(),
                                 " is not supported in TRT version < 8.2");
#endif
  }

  Status Convert() { return ConvertImpl(*params_); }
};
}  // namespace

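// Register both converters for every op name (map key) they support.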
REGISTER_DEFAULT_TRT_OP_CONVERTER(MakeConverterFunction<ConvertBinary>(),
                                  GetOperationNames(*BinaryOperationMap()));
REGISTER_DEFAULT_TRT_OP_CONVERTER(
    MakeConverterFunction<ConvertBooleanBinary>(),
    GetOperationNames(*BinaryBooleanOperationMap()));

}  // namespace convert
}  // namespace tensorrt
}  // namespace tensorflow
#endif  // GOOGLE_CUDA && GOOGLE_TENSORRT