1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #ifndef TENSORFLOW_COMPILER_TF2TENSORRT_CONVERT_OP_CONVERTER_H_
16 #define TENSORFLOW_COMPILER_TF2TENSORRT_CONVERT_OP_CONVERTER_H_
17
18 #if GOOGLE_CUDA && GOOGLE_TENSORRT
19
#include <algorithm>
#include <array>
#include <functional>
#include <memory>
#include <vector>

#include "absl/strings/str_format.h"
#include "absl/strings/str_join.h"
#include "tensorflow/compiler/tf2tensorrt/convert/trt_parameters.h"
#include "tensorflow/compiler/tf2tensorrt/convert/weights.h"
26
27 namespace tensorflow {
28 namespace tensorrt {
29 namespace convert {
30
31 class Converter;
32
// Specifies the expected type taken by a TRT_TensorOrWeights input during op
// conversion.
// kResource is only used for resource variable ops. For an operation like
// Add(tensor, ReadVariableOp(...)), the second operand of Add is the result of
// the ReadVariableOp, which is a kWeight.
// Note: the numeric values are flag-like — kBoth (3) == kTensor (1) | kWeight
// (2) — although the code in this header compares roles with == only.
enum class TrtInputArg { kTensor = 1, kWeight = 2, kBoth = 3, kResource = 4 };
39
40 // Parameters for each op converter.
// Parameters for each op converter.
struct OpConverterParams {
  // Constructor used for validation only: no Converter is supplied, so
  // `converter` stays null and (presumably) `validation_only` is set —
  // TODO(review): confirm against the constructor definition.
  OpConverterParams(const NodeDef& node_def,
                    const std::vector<TRT_TensorOrWeights>& inputs,
                    std::vector<TRT_TensorOrWeights>* outputs,
                    TrtWeightStore* weight_store,
                    TrtPrecisionMode precision_mode, bool use_calibration,
                    bool use_implicit_batch, bool use_explicit_precision);

  // Constructor used for conversion (an actual Converter is available).
  OpConverterParams(Converter* converter, const NodeDef& node_def,
                    const std::vector<TRT_TensorOrWeights>& inputs,
                    std::vector<TRT_TensorOrWeights>* outputs,
                    TrtWeightStore* weight_store);

  // Null during validation-only runs (see first constructor).
  Converter* converter = nullptr;
  // The TF node being converted. Reference member: the NodeDef must outlive
  // this params object.
  const NodeDef& node_def;
  // Inputs to the node, as tensors/weights. Reference member: must outlive
  // this object.
  const std::vector<TRT_TensorOrWeights>& inputs;
  // Output slot filled by the converter (see OpConverterBase::AddOutput).
  std::vector<TRT_TensorOrWeights>* outputs;
  // When true, converters should validate and return without emitting TRT
  // layers (see OpConverterBase::operator()).
  const bool validation_only;
  // Storage for weights created/owned during conversion.
  TrtWeightStore* weight_store;
  const TrtPrecisionMode precision_mode;
  const bool use_calibration;
  const bool use_implicit_batch;
  const bool use_explicit_precision;
};
67
// Operation converter function specification: takes the per-op parameters
// and returns a Status indicating whether validation/conversion succeeded.
using OpConverter = std::function<Status(OpConverterParams*)>;
70
71 struct InputArgSpec {
72 absl::string_view name;
73 TrtInputArg allowed_roles;
74
CreateInputArgSpec75 static constexpr InputArgSpec Create(absl::string_view n, TrtInputArg role) {
76 return InputArgSpec{n, role};
77 }
78 };
79
80 // A Curiously recurring template pattern (CRTP) template class for operation
81 // converters.
82 template <typename Impl>
83 class OpConverterBase {
84 public:
OpConverterBase(OpConverterParams * params)85 explicit OpConverterBase(OpConverterParams* params)
86 : params_(params), node_def_attrs_(params->node_def) {}
87
88 // Default NodeDef attribute name to inspect in order to determine node data
89 // type. The Impl class can override this by implementing the same function.
NodeDefDataTypeAttributeName()90 static constexpr const char* NodeDefDataTypeAttributeName() { return "T"; }
91
92 // Default allowed data types for the NodeDef data type attribute. The Impl
93 // class can override this by implementing the same function.
AllowedDataTypes()94 static constexpr std::array<DataType, 2> AllowedDataTypes() {
95 return {DataType::DT_FLOAT, DataType::DT_HALF};
96 }
97
98 // Validate data type of the given NodeDef against allowed types.
ValidateNodeDefDataType()99 Status ValidateNodeDefDataType() {
100 // If the attribute name is empty, we should skip this check.
101 if (absl::string_view(Impl::NodeDefDataTypeAttributeName()).empty()) {
102 return Status::OK();
103 }
104
105 // Get the NodeDef data type.
106 auto dtype = GetAttrValue<DataType>(Impl::NodeDefDataTypeAttributeName());
107 if (!dtype.ok()) {
108 return errors::InvalidArgument("Attribute with name ",
109 Impl::NodeDefDataTypeAttributeName(),
110 " not found.");
111 }
112
113 // Check allowed data types.
114 const auto& node_def = params_->node_def;
115 const auto& allowed_dtypes = Impl::AllowedDataTypes();
116 if (std::find(allowed_dtypes.begin(), allowed_dtypes.end(), *dtype) ==
117 allowed_dtypes.end()) {
118 std::string allowed_types_string = absl::StrJoin(
119 allowed_dtypes, ", ", [](std::string* out, const DataType& type) {
120 absl::StrAppendFormat(out, "%s", DataTypeString(type));
121 });
122 return errors::Unimplemented("Data type ", DataTypeString(*dtype),
123 " is not supported for ", node_def.op(),
124 ", must be one of [", allowed_types_string,
125 "], at ", node_def.name());
126 }
127 return Status::OK();
128 }
129
HasFixNumberOfInputs()130 static constexpr bool HasFixNumberOfInputs() { return true; }
131
132 // Validates input argument roles and data types.
ValidateInputs()133 Status ValidateInputs() {
134 const NodeDef& node_def = params_->node_def;
135 const auto& inputs = params_->inputs;
136 if (Impl::HasFixNumberOfInputs()) {
137 TRT_ENSURE(inputs.size() == Impl::InputSpec().size());
138 } else {
139 TRT_ENSURE(inputs.size() <= Impl::InputSpec().size());
140 }
141 for (int i = 0; i < inputs.size(); i++) {
142 const InputArgSpec arg_spec = Impl::InputSpec()[i];
143 if (arg_spec.allowed_roles == TrtInputArg::kWeight &&
144 inputs.at(i).is_tensor()) {
145 return errors::Unimplemented("The input \"", arg_spec.name, "\" for ",
146 node_def.op(), " must be a constant, at ",
147 node_def.name());
148 }
149 if (arg_spec.allowed_roles == TrtInputArg::kTensor &&
150 inputs.at(i).is_weights()) {
151 return errors::Unimplemented("The input \"", arg_spec.name, "\" for ",
152 node_def.op(), " must be a tensor, at ",
153 node_def.name());
154 }
155 }
156 return Status::OK();
157 }
158
operator()159 Status operator()() {
160 // Validate data type and inputs.
161 TF_RETURN_IF_ERROR(this->ValidateNodeDefDataType());
162 TF_RETURN_IF_ERROR(this->ValidateInputs());
163
164 // Perform op-level validation.
165 TF_RETURN_IF_ERROR(reinterpret_cast<Impl*>(this)->Validate());
166 if (params_->validation_only) {
167 return Status::OK();
168 }
169
170 // Perform conversion.
171 return reinterpret_cast<Impl*>(this)->Convert();
172 }
173
174 protected:
AddOutput(const TRT_TensorOrWeights & out)175 void AddOutput(const TRT_TensorOrWeights& out) {
176 params_->outputs->push_back(out);
177 }
178
179 template <typename T>
GetAttrValue(absl::string_view key)180 StatusOr<T> GetAttrValue(absl::string_view key) const {
181 T result;
182 TF_RETURN_IF_ERROR(GetNodeAttr(node_def_attrs_, key, &result));
183 return result;
184 }
185
186 OpConverterParams* const params_;
187 AttrSlice node_def_attrs_;
188 };
189
190 // Constructs and returns a converter function for a given operation converter
191 // class T. This requires T to be a derived class of StructuredOpConverter.
192 template <typename T>
MakeConverterFunction()193 OpConverter MakeConverterFunction() {
194 return [](OpConverterParams* params) -> Status {
195 T converter(params);
196 return converter();
197 };
198 }
199
200 } // namespace convert
201 } // namespace tensorrt
202 } // namespace tensorflow
203
204 #endif // GOOGLE_CUDA && GOOGLE_TENSORRT
205 #endif // TENSORFLOW_COMPILER_TF2TENSORRT_CONVERT_OP_CONVERTER_H_
206