/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_TF2TENSORRT_CONVERT_OP_CONVERTER_H_
#define TENSORFLOW_COMPILER_TF2TENSORRT_CONVERT_OP_CONVERTER_H_

#if GOOGLE_CUDA && GOOGLE_TENSORRT

#include <algorithm>
#include <array>
#include <functional>
#include <memory>
#include <string>
#include <vector>

#include "absl/strings/str_format.h"
#include "absl/strings/str_join.h"
#include "tensorflow/compiler/tf2tensorrt/convert/trt_parameters.h"
#include "tensorflow/compiler/tf2tensorrt/convert/weights.h"

namespace tensorflow {
namespace tensorrt {
namespace convert {

class Converter;

// Specifies the expected type taken by a TRT_TensorOrWeights input during op
// conversion.
// kResource is only used for resource variable ops. For an operation like
// Add(tensor, ReadVariableOp(...)), the second operand of Add is the result of
// the ReadVariableOp, which is a kWeight.
enum class TrtInputArg { kTensor = 1, kWeight = 2, kBoth = 3, kResource = 4 };

// Parameters for each op converter.
struct OpConverterParams {
  // Constructor used for validation only.
  OpConverterParams(const NodeDef& node_def,
                    const std::vector<TRT_TensorOrWeights>& inputs,
                    std::vector<TRT_TensorOrWeights>* outputs,
                    TrtWeightStore* weight_store,
                    TrtPrecisionMode precision_mode, bool use_calibration,
                    bool use_implicit_batch, bool use_explicit_precision);

  // Constructor used for conversion.
  OpConverterParams(Converter* converter, const NodeDef& node_def,
                    const std::vector<TRT_TensorOrWeights>& inputs,
                    std::vector<TRT_TensorOrWeights>* outputs,
                    TrtWeightStore* weight_store);

  Converter* converter = nullptr;
  const NodeDef& node_def;
  const std::vector<TRT_TensorOrWeights>& inputs;
  std::vector<TRT_TensorOrWeights>* outputs;
  const bool validation_only;
  TrtWeightStore* weight_store;
  const TrtPrecisionMode precision_mode;
  const bool use_calibration;
  const bool use_implicit_batch;
  const bool use_explicit_precision;
};

// Operation converter function specification.
using OpConverter = std::function<Status(OpConverterParams*)>;
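
// A converter can be written directly as a function matching this signature.
// Minimal sketch (the lambda below is illustrative, not part of this header):
//
//   OpConverter forward_first_input = [](OpConverterParams* params) -> Status {
//     if (params->validation_only) return Status::OK();
//     params->outputs->push_back(params->inputs.at(0));
//     return Status::OK();
//   };
//
// In practice, most converters derive from OpConverterBase below and are
// wrapped into this function form with MakeConverterFunction().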

struct InputArgSpec {
  absl::string_view name;
  TrtInputArg allowed_roles;

  static constexpr InputArgSpec Create(absl::string_view n, TrtInputArg role) {
    return InputArgSpec{n, role};
  }
};
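
// Example (sketch): converter implementations typically describe their
// expected inputs with a static InputSpec() member built from Create(); the
// argument names "value" and "axis" below are hypothetical.
//
//   static constexpr std::array<InputArgSpec, 2> InputSpec() {
//     return std::array<InputArgSpec, 2>{
//         InputArgSpec::Create("value", TrtInputArg::kTensor),
//         InputArgSpec::Create("axis", TrtInputArg::kWeight)};
//   }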

// A curiously recurring template pattern (CRTP) base class for operation
// converters.
template <typename Impl>
class OpConverterBase {
 public:
  explicit OpConverterBase(OpConverterParams* params)
      : params_(params), node_def_attrs_(params->node_def) {}

  // Default NodeDef attribute name to inspect in order to determine node data
  // type. The Impl class can override this by implementing the same function.
  static constexpr const char* NodeDefDataTypeAttributeName() { return "T"; }

  // Default allowed data types for the NodeDef data type attribute. The Impl
  // class can override this by implementing the same function.
  static constexpr std::array<DataType, 2> AllowedDataTypes() {
    return {DataType::DT_FLOAT, DataType::DT_HALF};
  }

  // Validate data type of the given NodeDef against allowed types.
  Status ValidateNodeDefDataType() {
    // If the attribute name is empty, we should skip this check.
    if (absl::string_view(Impl::NodeDefDataTypeAttributeName()).empty()) {
      return Status::OK();
    }

    // Get the NodeDef data type.
    auto dtype = GetAttrValue<DataType>(Impl::NodeDefDataTypeAttributeName());
    if (!dtype.ok()) {
      return errors::InvalidArgument("Attribute with name ",
                                     Impl::NodeDefDataTypeAttributeName(),
                                     " not found.");
    }

    // Check allowed data types.
    const auto& node_def = params_->node_def;
    const auto& allowed_dtypes = Impl::AllowedDataTypes();
    if (std::find(allowed_dtypes.begin(), allowed_dtypes.end(), *dtype) ==
        allowed_dtypes.end()) {
      std::string allowed_types_string = absl::StrJoin(
          allowed_dtypes, ", ", [](std::string* out, const DataType& type) {
            absl::StrAppendFormat(out, "%s", DataTypeString(type));
          });
      return errors::Unimplemented("Data type ", DataTypeString(*dtype),
                                   " is not supported for ", node_def.op(),
                                   ", must be one of [", allowed_types_string,
                                   "], at ", node_def.name());
    }
    return Status::OK();
  }

  // Whether the op expects exactly InputSpec().size() inputs (the default).
  // The Impl class can override this to return false, in which case fewer
  // inputs are accepted.
  static constexpr bool HasFixNumberOfInputs() { return true; }

  // Validates the roles (tensor vs. weight) of the inputs against InputSpec().
  Status ValidateInputs() {
    const NodeDef& node_def = params_->node_def;
    const auto& inputs = params_->inputs;
    if (Impl::HasFixNumberOfInputs()) {
      TRT_ENSURE(inputs.size() == Impl::InputSpec().size());
    } else {
      TRT_ENSURE(inputs.size() <= Impl::InputSpec().size());
    }
    for (int i = 0; i < inputs.size(); i++) {
      const InputArgSpec arg_spec = Impl::InputSpec()[i];
      if (arg_spec.allowed_roles == TrtInputArg::kWeight &&
          inputs.at(i).is_tensor()) {
        return errors::Unimplemented("The input \"", arg_spec.name, "\" for ",
                                     node_def.op(), " must be a constant, at ",
                                     node_def.name());
      }
      if (arg_spec.allowed_roles == TrtInputArg::kTensor &&
          inputs.at(i).is_weights()) {
        return errors::Unimplemented("The input \"", arg_spec.name, "\" for ",
                                     node_def.op(), " must be a tensor, at ",
                                     node_def.name());
      }
    }
    return Status::OK();
  }

  Status operator()() {
    // Validate data type and inputs.
    TF_RETURN_IF_ERROR(this->ValidateNodeDefDataType());
    TF_RETURN_IF_ERROR(this->ValidateInputs());

    // Perform op-level validation.
    TF_RETURN_IF_ERROR(static_cast<Impl*>(this)->Validate());
    if (params_->validation_only) {
      return Status::OK();
    }

    // Perform conversion.
    return static_cast<Impl*>(this)->Convert();
  }

 protected:
  void AddOutput(const TRT_TensorOrWeights& out) {
    params_->outputs->push_back(out);
  }

  template <typename T>
  StatusOr<T> GetAttrValue(absl::string_view key) const {
    T result;
    TF_RETURN_IF_ERROR(GetNodeAttr(node_def_attrs_, key, &result));
    return result;
  }

  OpConverterParams* const params_;
  AttrSlice node_def_attrs_;
};
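
// Example (sketch): a hypothetical converter built on OpConverterBase, showing
// only the members the base class expects. "ConvertMyOp", its input names, and
// its allowed types are illustrative and not part of this header.
//
//   class ConvertMyOp : public OpConverterBase<ConvertMyOp> {
//    public:
//     explicit ConvertMyOp(OpConverterParams* params)
//         : OpConverterBase<ConvertMyOp>(params) {}
//
//     // Expected inputs, checked by OpConverterBase::ValidateInputs().
//     static constexpr std::array<InputArgSpec, 2> InputSpec() {
//       return std::array<InputArgSpec, 2>{
//           InputArgSpec::Create("x", TrtInputArg::kTensor),
//           InputArgSpec::Create("y", TrtInputArg::kBoth)};
//     }
//
//     // Optional override; the default allows only DT_FLOAT and DT_HALF.
//     static constexpr std::array<DataType, 3> AllowedDataTypes() {
//       return {DataType::DT_FLOAT, DataType::DT_HALF, DataType::DT_INT32};
//     }
//
//     // Op-specific validation, run in both validation and conversion mode.
//     Status Validate() { return Status::OK(); }
//
//     // Builds the TensorRT layers; here it simply forwards the first input.
//     Status Convert() {
//       AddOutput(params_->inputs.at(0));
//       return Status::OK();
//     }
//   };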

// Constructs and returns a converter function for a given operation converter
// class T. This requires T to be a class derived from OpConverterBase.
template <typename T>
OpConverter MakeConverterFunction() {
  return [](OpConverterParams* params) -> Status {
    T converter(params);
    return converter();
  };
}
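
// Usage (sketch): wrap a class-based converter into the OpConverter function
// form, e.g. for the hypothetical ConvertMyOp above:
//
//   OpConverter fn = MakeConverterFunction<ConvertMyOp>();
//
// In TF-TRT, such functions are then registered for an op name through the
// converter registry (see op_converter_registry.h).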

}  // namespace convert
}  // namespace tensorrt
}  // namespace tensorflow

#endif  // GOOGLE_CUDA && GOOGLE_TENSORRT
#endif  // TENSORFLOW_COMPILER_TF2TENSORRT_CONVERT_OP_CONVERTER_H_