xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/tf2tensorrt/convert/ops/fill_ops.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #if GOOGLE_CUDA && GOOGLE_TENSORRT
17 
18 #include "tensorflow/compiler/tf2tensorrt/convert/convert_nodes.h"
19 #include "tensorflow/compiler/tf2tensorrt/convert/op_converter_registry.h"
20 #include "tensorflow/compiler/tf2tensorrt/convert/ops/layer_utils.h"
21 
22 namespace tensorflow {
23 namespace tensorrt {
24 namespace convert {
25 
26 #if IS_TRT_VERSION_GE(8, 2, 0, 0)
27 
28 template <typename Impl>
29 class ConvertFillBase : public OpConverterBase<Impl> {
30  public:
ConvertFillBase(OpConverterParams * params)31   explicit ConvertFillBase(OpConverterParams* params)
32       : OpConverterBase<Impl>(params) {}
33 
AllowedDataTypes()34   static constexpr std::array<DataType, 3> AllowedDataTypes() {
35     return {DataType::DT_FLOAT, DataType::DT_HALF, DataType::DT_INT32};
36   }
37 
ValidateFillBase(const OpConverterParams & params)38   Status ValidateFillBase(const OpConverterParams& params) {
39     if (params.use_implicit_batch) {
40       return errors::Unimplemented("Conversion for ", params.node_def.op(),
41                                    " is not implemented in"
42                                    " implicit batch mode");
43     }
44     return Status::OK();
45   }
46 };
47 
48 class ConvertFill : public ConvertFillBase<ConvertFill> {
49  public:
ConvertFill(OpConverterParams * params)50   explicit ConvertFill(OpConverterParams* params)
51       : ConvertFillBase<ConvertFill>(params) {}
52 
InputSpec()53   static constexpr std::array<InputArgSpec, 2> InputSpec() {
54     return std::array<InputArgSpec, 2>{
55         InputArgSpec::Create("dims", TrtInputArg::kBoth),
56         InputArgSpec::Create("value", TrtInputArg::kBoth)};
57   }
58 
Validate()59   Status Validate() {
60     const auto& params = *this->params_;
61     TF_RETURN_IF_ERROR(this->ValidateFillBase(params));
62 
63     const auto& inputs = params.inputs;
64     const auto& node_def = params.node_def;
65     const TRT_TensorOrWeights& dims_input = inputs.at(0);
66 
67     const auto dims_type = dims_input.TrtDType();
68     if (dims_type != nvinfer1::DataType::kINT32) {
69       return errors::InvalidArgument("The dims parameter of ", node_def.op(),
70                                      " operation in ", node_def.name(),
71                                      " is expected to be of type ",
72                                      DebugString(nvinfer1::DataType::kINT32),
73                                      " type, got ", DebugString(dims_type));
74     }
75 
76     const auto nbDims = dims_input.GetTrtDims().nbDims;
77     if (nbDims < 0) {
78       return errors::InvalidArgument("The shape of parameter ", node_def.op(),
79                                      " operation in ", node_def.name(),
80                                      " cannot be partial.");
81     }
82     return Status::OK();
83   }
84 
Convert()85   Status Convert() {
86     const auto& params = *this->params_;
87     auto* network = params.converter->network();
88     const auto& inputs = params.inputs;
89 
90     const bool is_dims_static = inputs[0].is_weights();
91     const bool is_value_static = inputs[1].is_weights();
92 
93     const TRT_TensorOrWeights& dims_input = inputs.at(0);
94     const TRT_TensorOrWeights& value_input = inputs.at(1);
95 
96     int nbDims = dims_input.GetTrtDims().d[0];
97 
98     nvinfer1::Dims trt_dims{0};
99     if (is_dims_static) {
100       const auto dims_weights = dims_input.weights();
101       DimsAdapter dims_adapter(dims_weights.GetSpan<int32>());
102       dims_adapter.TrtDims(&trt_dims);
103     }
104 
105     auto builder = TRTNetworkBuilder::Create(network, params.weight_store);
106     StatusOr<nvinfer1::ILayer*> layer =
107         builder->AddFill(value_input, dims_input, is_value_static,
108                          is_dims_static, nbDims, trt_dims);
109     ITensorProxyPtr output_tensor = (*layer)->getOutput(0);
110     this->AddOutput(TRT_TensorOrWeights(output_tensor));
111     return Status::OK();
112   }
113 };
114 
115 class ConvertRange : public ConvertFillBase<ConvertRange> {
116  public:
ConvertRange(OpConverterParams * params)117   explicit ConvertRange(OpConverterParams* params)
118       : ConvertFillBase<ConvertRange>(params) {}
119 
InputSpec()120   static constexpr std::array<InputArgSpec, 3> InputSpec() {
121     return std::array<InputArgSpec, 3>{
122         InputArgSpec::Create("start", TrtInputArg::kBoth),
123         InputArgSpec::Create("limit", TrtInputArg::kBoth),
124         InputArgSpec::Create("delta", TrtInputArg::kBoth)};
125   }
126 
NodeDefDataTypeAttributeName()127   static constexpr const char* NodeDefDataTypeAttributeName() { return ""; }
Validate()128   Status Validate() {
129     const auto& params = *this->params_;
130     TF_RETURN_IF_ERROR(this->ValidateFillBase(params));
131 
132     const auto& inputs = params.inputs;
133     const auto& node_def = params.node_def;
134 
135     float param[3];
136     all_weights_ = all_integers_ = true;
137     for (int i = 0; i < 3; i++) {
138       const auto& input = inputs.at(i);
139       all_integers_ &= input.TrtDType() == nvinfer1::DataType::kINT32;
140       if (input.is_weights()) {
141         switch (input.TrtDType()) {
142           case nvinfer1::DataType::kFLOAT:
143             param[i] = get_input_param<float>(input);
144             break;
145           case nvinfer1::DataType::kHALF:
146             param[i] = get_input_param<Eigen::half>(input);
147             break;
148           case nvinfer1::DataType::kINT32:
149             param[i] = get_input_param<int>(input);
150             break;
151           default:
152             return errors::InvalidArgument(
153                 "Unsupported data type ", DebugString(input.TrtDType()),
154                 " used for '", InputSpec()[i].name, "'");
155         }
156       } else {
157         all_weights_ = false;
158       }
159     }
160 
161     if (!(all_weights_ || all_integers_)) {
162       // As of 06/03/2022, when at least one of the (start, limit, delta)
163       // is passed as a tensor, they must all be of type kINT32
164       return errors::Unimplemented(convert_range_expected_msg(node_def));
165     }
166 
167     if (inputs.at(2).is_weights()) {
168       if ((delta_ = param[2]) == 0) {
169         return errors::InvalidArgument("The delta parameter of ", node_def.op(),
170                                        " operation cannot be equal to 0");
171       }
172 
173       if (!all_weights_ && delta_ < 0) {
174         return errors::InvalidArgument(
175             "The delta parameter of Range operation "
176             "cannot be negative, when one of (start, limit) is passed as "
177             "a tensor, but got ",
178             delta_);
179       }
180     }
181 
182     for (int i = 0; i < 3; i++) {
183       const auto& input = inputs.at(i);
184       const auto& dims = input.GetTrtDims();
185       if (dims.nbDims != 1 || dims.d[0] != 1) {
186         return errors::InvalidArgument("Dimension for '", InputSpec()[i].name,
187                                        "' of ", node_def.op(), " operator ",
188                                        "should be equal to 1");
189       }
190     }
191 
192     if (all_weights_) {
193       const auto num_intervals_float =
194           (param[1] - (start_ = param[0])) / delta_;
195       if (num_intervals_float < 0) {
196         const auto error = convert_range_error_msg(start_, param[1], delta_);
197         return errors::InvalidArgument(error);
198       }
199 
200       num_values_ = static_cast<int>(num_intervals_float);
201       if (start_ + delta_ * num_values_ != param[1]) {
202         num_values_++;
203       }
204     }
205 
206     return Status::OK();
207   }
208 
Convert()209   Status Convert() {
210     const auto& params = *this->params_;
211     const auto& inputs = params.inputs;
212     const TRT_TensorOrWeights& input = inputs.at(0);
213     TRT_TensorOrWeights value_input;
214     nvinfer1::Dims trt_dims{1};
215     auto builder = TRTNetworkBuilder::Create(params.converter->network(),
216                                              params.weight_store);
217     TRT_ENSURE_OK(builder);
218     ITensorProxyPtr dims_input_tensor = nullptr;
219     ITensorProxyPtr beta_tensor = nullptr;
220     ITensorProxyPtr scalar_tensor = nullptr;
221     if (!all_weights_) {
222       ITensorProxyPtr tensors[3];
223       for (int i = 0; i < 3; i++) {
224         TF_RETURN_IF_ERROR(
225             builder->get_tensor4TensorOrWeights(inputs.at(i), tensors + i));
226       }
227 
228       StatusOr<nvinfer1::IElementWiseLayer*> num =
229           builder->Sub(/*limit*/ tensors[1]->trt_tensor(),
230                        /*start*/ tensors[0]->trt_tensor());
231 
232       TRT_ENSURE_PTR_OK(num);
233       StatusOr<nvinfer1::IElementWiseLayer*> ceil_div = builder->FloorDiv(
234           (*num)->getOutput(0), (beta_tensor = tensors[2])->trt_tensor());
235       TRT_ENSURE_PTR_OK(ceil_div);
236       dims_input_tensor = (*ceil_div)->getOutput(0);
237       dims_input_tensor->setType(nvinfer1::DataType::kINT32);
238 
239       nvinfer1::Dims scalar_dims{0};
240       TF_RETURN_IF_ERROR(PrepareTensorForShape(
241           params.converter, params.inputs.at(0), scalar_dims, false,
242           &scalar_tensor, params.node_def));
243     } else {
244       DimsAdapter value_input_dims(std::vector<int>{1});
245       StatusOr<TRT_ShapedWeights> value_weights =
246           params.weight_store->GetTempWeights(input.TrtDType(),
247                                               value_input_dims);
248 
249       TF_RETURN_IF_ERROR(value_weights.status());
250       TF_RETURN_IF_ERROR(value_weights->SetValues(start_));
251       value_input = TRT_TensorOrWeights(value_weights.ValueOrDie());
252 
253       trt_dims.d[0] = num_values_;
254       StatusOr<nvinfer1::IConstantLayer*> const_layer =
255           builder->ConstantShape(value_input_dims);
256       TRT_ENSURE_PTR_OK(const_layer);
257       dims_input_tensor = (*const_layer)->getOutput(0);
258     }
259 
260     TRT_TensorOrWeights dims_input(dims_input_tensor);
261 
262     StatusOr<nvinfer1::ILayer*> layer =
263         builder->AddFill(value_input, dims_input, all_weights_, all_weights_, 1,
264                          trt_dims, scalar_tensor, beta_tensor, delta_);
265 
266     ITensorProxyPtr output_tensor = (*layer)->getOutput(0);
267     if (all_integers_) {
268       output_tensor->setType(nvinfer1::DataType::kINT32);
269     }
270 
271     this->AddOutput(TRT_TensorOrWeights(output_tensor));
272     return Status::OK();
273   }
274 
275  private:
276   template <typename T>
get_input_param(const TRT_TensorOrWeights & input)277   float get_input_param(const TRT_TensorOrWeights& input) {
278     return static_cast<float>(*input.weights().GetPointer<T>());
279   }
280 
281   float start_;
282   float delta_;
283   int num_values_;
284   bool all_weights_;
285   bool all_integers_;
286 };
287 
convert_range_error_msg(float start,float limit,float delta)288 std::string convert_range_error_msg(float start, float limit, float delta) {
289   const char* format_string =
290       "For parameters (start, limit) = (%.2f, %.2f) "
291       "of the Range operation delta cannot be %s, got %.2f";
292   return absl::StrFormat(format_string, start, limit,
293                          start < limit ? "negative" : "positive", delta);
294 }
295 
convert_range_expected_msg(const NodeDef & node_def)296 std::string convert_range_expected_msg(const NodeDef& node_def) {
297   return "When at least one of parameters (start, limit, delta) of " +
298          node_def.op() + " operation in " + node_def.name() +
299          " is passed as a tensor, they must all be of type kINT32";
300 }
301 
// Register the converters so the TF2TensorRT bridge routes "Fill" and
// "Range" nodes through the classes above (TRT >= 8.2, explicit batch only).
REGISTER_DEFAULT_TRT_OP_CONVERTER(MakeConverterFunction<ConvertFill>(), "Fill");
REGISTER_DEFAULT_TRT_OP_CONVERTER(MakeConverterFunction<ConvertRange>(),
                                  "Range");
305 
306 #endif  // IS_TRT_VERSION_GE(8, 2, 0, 0)
307 
308 }  // namespace convert
309 }  // namespace tensorrt
310 }  // namespace tensorflow
311 #endif  // GOOGLE_CUDA && GOOGLE_TENSORRT
312