1 /* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #if GOOGLE_CUDA && GOOGLE_TENSORRT
17
18 #include "tensorflow/compiler/tf2tensorrt/convert/convert_nodes.h"
19 #include "tensorflow/compiler/tf2tensorrt/convert/op_converter_registry.h"
20 #include "tensorflow/compiler/tf2tensorrt/convert/ops/layer_utils.h"
21
22 namespace tensorflow {
23 namespace tensorrt {
24 namespace convert {
25
26 #if IS_TRT_VERSION_GE(8, 2, 0, 0)
27
// CRTP base class shared by the fill-like converters below (Fill, Range).
// Provides the common allowed-dtype list and the common validation step.
template <typename Impl>
class ConvertFillBase : public OpConverterBase<Impl> {
 public:
  explicit ConvertFillBase(OpConverterParams* params)
      : OpConverterBase<Impl>(params) {}

  // Data types these converters accept for their inputs.
  static constexpr std::array<DataType, 3> AllowedDataTypes() {
    return {DataType::DT_FLOAT, DataType::DT_HALF, DataType::DT_INT32};
  }

  // Shared validation: these conversions are only implemented for explicit
  // batch mode, so reject implicit-batch networks up front.
  Status ValidateFillBase(const OpConverterParams& params) {
    if (params.use_implicit_batch) {
      return errors::Unimplemented("Conversion for ", params.node_def.op(),
                                   " is not implemented in"
                                   " implicit batch mode");
    }
    return Status::OK();
  }
};
47
48 class ConvertFill : public ConvertFillBase<ConvertFill> {
49 public:
ConvertFill(OpConverterParams * params)50 explicit ConvertFill(OpConverterParams* params)
51 : ConvertFillBase<ConvertFill>(params) {}
52
InputSpec()53 static constexpr std::array<InputArgSpec, 2> InputSpec() {
54 return std::array<InputArgSpec, 2>{
55 InputArgSpec::Create("dims", TrtInputArg::kBoth),
56 InputArgSpec::Create("value", TrtInputArg::kBoth)};
57 }
58
Validate()59 Status Validate() {
60 const auto& params = *this->params_;
61 TF_RETURN_IF_ERROR(this->ValidateFillBase(params));
62
63 const auto& inputs = params.inputs;
64 const auto& node_def = params.node_def;
65 const TRT_TensorOrWeights& dims_input = inputs.at(0);
66
67 const auto dims_type = dims_input.TrtDType();
68 if (dims_type != nvinfer1::DataType::kINT32) {
69 return errors::InvalidArgument("The dims parameter of ", node_def.op(),
70 " operation in ", node_def.name(),
71 " is expected to be of type ",
72 DebugString(nvinfer1::DataType::kINT32),
73 " type, got ", DebugString(dims_type));
74 }
75
76 const auto nbDims = dims_input.GetTrtDims().nbDims;
77 if (nbDims < 0) {
78 return errors::InvalidArgument("The shape of parameter ", node_def.op(),
79 " operation in ", node_def.name(),
80 " cannot be partial.");
81 }
82 return Status::OK();
83 }
84
Convert()85 Status Convert() {
86 const auto& params = *this->params_;
87 auto* network = params.converter->network();
88 const auto& inputs = params.inputs;
89
90 const bool is_dims_static = inputs[0].is_weights();
91 const bool is_value_static = inputs[1].is_weights();
92
93 const TRT_TensorOrWeights& dims_input = inputs.at(0);
94 const TRT_TensorOrWeights& value_input = inputs.at(1);
95
96 int nbDims = dims_input.GetTrtDims().d[0];
97
98 nvinfer1::Dims trt_dims{0};
99 if (is_dims_static) {
100 const auto dims_weights = dims_input.weights();
101 DimsAdapter dims_adapter(dims_weights.GetSpan<int32>());
102 dims_adapter.TrtDims(&trt_dims);
103 }
104
105 auto builder = TRTNetworkBuilder::Create(network, params.weight_store);
106 StatusOr<nvinfer1::ILayer*> layer =
107 builder->AddFill(value_input, dims_input, is_value_static,
108 is_dims_static, nbDims, trt_dims);
109 ITensorProxyPtr output_tensor = (*layer)->getOutput(0);
110 this->AddOutput(TRT_TensorOrWeights(output_tensor));
111 return Status::OK();
112 }
113 };
114
115 class ConvertRange : public ConvertFillBase<ConvertRange> {
116 public:
ConvertRange(OpConverterParams * params)117 explicit ConvertRange(OpConverterParams* params)
118 : ConvertFillBase<ConvertRange>(params) {}
119
InputSpec()120 static constexpr std::array<InputArgSpec, 3> InputSpec() {
121 return std::array<InputArgSpec, 3>{
122 InputArgSpec::Create("start", TrtInputArg::kBoth),
123 InputArgSpec::Create("limit", TrtInputArg::kBoth),
124 InputArgSpec::Create("delta", TrtInputArg::kBoth)};
125 }
126
NodeDefDataTypeAttributeName()127 static constexpr const char* NodeDefDataTypeAttributeName() { return ""; }
Validate()128 Status Validate() {
129 const auto& params = *this->params_;
130 TF_RETURN_IF_ERROR(this->ValidateFillBase(params));
131
132 const auto& inputs = params.inputs;
133 const auto& node_def = params.node_def;
134
135 float param[3];
136 all_weights_ = all_integers_ = true;
137 for (int i = 0; i < 3; i++) {
138 const auto& input = inputs.at(i);
139 all_integers_ &= input.TrtDType() == nvinfer1::DataType::kINT32;
140 if (input.is_weights()) {
141 switch (input.TrtDType()) {
142 case nvinfer1::DataType::kFLOAT:
143 param[i] = get_input_param<float>(input);
144 break;
145 case nvinfer1::DataType::kHALF:
146 param[i] = get_input_param<Eigen::half>(input);
147 break;
148 case nvinfer1::DataType::kINT32:
149 param[i] = get_input_param<int>(input);
150 break;
151 default:
152 return errors::InvalidArgument(
153 "Unsupported data type ", DebugString(input.TrtDType()),
154 " used for '", InputSpec()[i].name, "'");
155 }
156 } else {
157 all_weights_ = false;
158 }
159 }
160
161 if (!(all_weights_ || all_integers_)) {
162 // As of 06/03/2022, when at least one of the (start, limit, delta)
163 // is passed as a tensor, they must all be of type kINT32
164 return errors::Unimplemented(convert_range_expected_msg(node_def));
165 }
166
167 if (inputs.at(2).is_weights()) {
168 if ((delta_ = param[2]) == 0) {
169 return errors::InvalidArgument("The delta parameter of ", node_def.op(),
170 " operation cannot be equal to 0");
171 }
172
173 if (!all_weights_ && delta_ < 0) {
174 return errors::InvalidArgument(
175 "The delta parameter of Range operation "
176 "cannot be negative, when one of (start, limit) is passed as "
177 "a tensor, but got ",
178 delta_);
179 }
180 }
181
182 for (int i = 0; i < 3; i++) {
183 const auto& input = inputs.at(i);
184 const auto& dims = input.GetTrtDims();
185 if (dims.nbDims != 1 || dims.d[0] != 1) {
186 return errors::InvalidArgument("Dimension for '", InputSpec()[i].name,
187 "' of ", node_def.op(), " operator ",
188 "should be equal to 1");
189 }
190 }
191
192 if (all_weights_) {
193 const auto num_intervals_float =
194 (param[1] - (start_ = param[0])) / delta_;
195 if (num_intervals_float < 0) {
196 const auto error = convert_range_error_msg(start_, param[1], delta_);
197 return errors::InvalidArgument(error);
198 }
199
200 num_values_ = static_cast<int>(num_intervals_float);
201 if (start_ + delta_ * num_values_ != param[1]) {
202 num_values_++;
203 }
204 }
205
206 return Status::OK();
207 }
208
Convert()209 Status Convert() {
210 const auto& params = *this->params_;
211 const auto& inputs = params.inputs;
212 const TRT_TensorOrWeights& input = inputs.at(0);
213 TRT_TensorOrWeights value_input;
214 nvinfer1::Dims trt_dims{1};
215 auto builder = TRTNetworkBuilder::Create(params.converter->network(),
216 params.weight_store);
217 TRT_ENSURE_OK(builder);
218 ITensorProxyPtr dims_input_tensor = nullptr;
219 ITensorProxyPtr beta_tensor = nullptr;
220 ITensorProxyPtr scalar_tensor = nullptr;
221 if (!all_weights_) {
222 ITensorProxyPtr tensors[3];
223 for (int i = 0; i < 3; i++) {
224 TF_RETURN_IF_ERROR(
225 builder->get_tensor4TensorOrWeights(inputs.at(i), tensors + i));
226 }
227
228 StatusOr<nvinfer1::IElementWiseLayer*> num =
229 builder->Sub(/*limit*/ tensors[1]->trt_tensor(),
230 /*start*/ tensors[0]->trt_tensor());
231
232 TRT_ENSURE_PTR_OK(num);
233 StatusOr<nvinfer1::IElementWiseLayer*> ceil_div = builder->FloorDiv(
234 (*num)->getOutput(0), (beta_tensor = tensors[2])->trt_tensor());
235 TRT_ENSURE_PTR_OK(ceil_div);
236 dims_input_tensor = (*ceil_div)->getOutput(0);
237 dims_input_tensor->setType(nvinfer1::DataType::kINT32);
238
239 nvinfer1::Dims scalar_dims{0};
240 TF_RETURN_IF_ERROR(PrepareTensorForShape(
241 params.converter, params.inputs.at(0), scalar_dims, false,
242 &scalar_tensor, params.node_def));
243 } else {
244 DimsAdapter value_input_dims(std::vector<int>{1});
245 StatusOr<TRT_ShapedWeights> value_weights =
246 params.weight_store->GetTempWeights(input.TrtDType(),
247 value_input_dims);
248
249 TF_RETURN_IF_ERROR(value_weights.status());
250 TF_RETURN_IF_ERROR(value_weights->SetValues(start_));
251 value_input = TRT_TensorOrWeights(value_weights.ValueOrDie());
252
253 trt_dims.d[0] = num_values_;
254 StatusOr<nvinfer1::IConstantLayer*> const_layer =
255 builder->ConstantShape(value_input_dims);
256 TRT_ENSURE_PTR_OK(const_layer);
257 dims_input_tensor = (*const_layer)->getOutput(0);
258 }
259
260 TRT_TensorOrWeights dims_input(dims_input_tensor);
261
262 StatusOr<nvinfer1::ILayer*> layer =
263 builder->AddFill(value_input, dims_input, all_weights_, all_weights_, 1,
264 trt_dims, scalar_tensor, beta_tensor, delta_);
265
266 ITensorProxyPtr output_tensor = (*layer)->getOutput(0);
267 if (all_integers_) {
268 output_tensor->setType(nvinfer1::DataType::kINT32);
269 }
270
271 this->AddOutput(TRT_TensorOrWeights(output_tensor));
272 return Status::OK();
273 }
274
275 private:
276 template <typename T>
get_input_param(const TRT_TensorOrWeights & input)277 float get_input_param(const TRT_TensorOrWeights& input) {
278 return static_cast<float>(*input.weights().GetPointer<T>());
279 }
280
281 float start_;
282 float delta_;
283 int num_values_;
284 bool all_weights_;
285 bool all_integers_;
286 };
287
convert_range_error_msg(float start,float limit,float delta)288 std::string convert_range_error_msg(float start, float limit, float delta) {
289 const char* format_string =
290 "For parameters (start, limit) = (%.2f, %.2f) "
291 "of the Range operation delta cannot be %s, got %.2f";
292 return absl::StrFormat(format_string, start, limit,
293 start < limit ? "negative" : "positive", delta);
294 }
295
convert_range_expected_msg(const NodeDef & node_def)296 std::string convert_range_expected_msg(const NodeDef& node_def) {
297 return "When at least one of parameters (start, limit, delta) of " +
298 node_def.op() + " operation in " + node_def.name() +
299 " is passed as a tensor, they must all be of type kINT32";
300 }
301
// Register the converters above with the default TF-TRT op converter
// registry under their TensorFlow op names.
REGISTER_DEFAULT_TRT_OP_CONVERTER(MakeConverterFunction<ConvertFill>(), "Fill");
REGISTER_DEFAULT_TRT_OP_CONVERTER(MakeConverterFunction<ConvertRange>(),
                                  "Range");
305
306 #endif // IS_TRT_VERSION_GE(8, 2, 0, 0)
307
308 } // namespace convert
309 } // namespace tensorrt
310 } // namespace tensorflow
311 #endif // GOOGLE_CUDA && GOOGLE_TENSORRT
312