xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/tf2tensorrt/convert/ops/like_ops.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #if GOOGLE_CUDA && GOOGLE_TENSORRT
17 
18 #include "tensorflow/compiler/tf2tensorrt/convert/convert_nodes.h"
19 #include "tensorflow/compiler/tf2tensorrt/convert/op_converter_registry.h"
20 #include "tensorflow/compiler/tf2tensorrt/convert/ops/layer_utils.h"
21 
22 namespace tensorflow {
23 namespace tensorrt {
24 namespace convert {
25 
26 #if IS_TRT_VERSION_GE(8, 2, 0, 0)
27 
28 template <int V>
29 class ConvertLikeOps : public OpConverterBase<ConvertLikeOps<V>> {
30  public:
ConvertLikeOps(OpConverterParams * params)31   explicit ConvertLikeOps(OpConverterParams *params)
32       : OpConverterBase<ConvertLikeOps<V>>(params) {}
33 
AllowedDataTypes()34   static constexpr std::array<DataType, 3> AllowedDataTypes() {
35     return {DataType::DT_FLOAT, DataType::DT_HALF, DataType::DT_INT32};
36   }
37 
InputSpec()38   static constexpr std::array<InputArgSpec, 1> InputSpec() {
39     return std::array<InputArgSpec, 1>{
40         InputArgSpec::Create("input", TrtInputArg::kBoth),
41     };
42   }
Validate()43   Status Validate() {
44     const auto &params = *this->params_;
45 
46     const std::string op_name = V == 0 ? "ZerosLike" : "OnesLike";
47     if (params.use_implicit_batch) {
48       return errors::Unimplemented("Conversion for " + op_name +
49                                    " is not implemented in"
50                                    " implicit batch mode");
51     }
52     const auto &inputs = params.inputs;
53     if (inputs.size() != 1) {
54       return errors::InvalidArgument(op_name, " expects 1 input, but received ",
55                                      inputs.size());
56     }
57     return Status::OK();
58   }
59 
Convert()60   Status Convert() {
61     const auto &params = *this->params_;
62     const auto &inputs = params.inputs;
63     auto *network = params.converter->network();
64 
65     const TRT_TensorOrWeights &input = inputs.at(0);
66     nvinfer1::Dims dims(input.GetTrtDims());
67 
68     const std::vector<int> value_input_dims_data = {1};
69     const DimsAdapter value_input_dims(value_input_dims_data);
70     StatusOr<TRT_ShapedWeights> value_weights =
71         params.weight_store->GetTempWeights(input.TrtDType(), value_input_dims);
72     TF_RETURN_IF_ERROR(value_weights.status());
73     TF_RETURN_IF_ERROR(value_weights->SetValues(V));
74     TRT_TensorOrWeights value_input(value_weights.ValueOrDie());
75 
76     const auto is_dims_static = HasStaticShape(dims);
77     auto builder = TRTNetworkBuilder::Create(network, params.weight_store);
78     ITensorProxyPtr dims_input_tensor;
79     if (!is_dims_static) {
80       StatusOr<nvinfer1::IShapeLayer *> shape_layer =
81           builder->Shape(input.tensor()->trt_tensor());
82       TF_RETURN_IF_ERROR(shape_layer.status());
83       dims_input_tensor = (*shape_layer)->getOutput(0);
84       dims.nbDims = 0;
85     }
86 
87     TRT_TensorOrWeights dims_input(dims_input_tensor);
88     StatusOr<nvinfer1::ILayer *> layer =
89         builder->AddFill(value_input, dims_input, true, is_dims_static,
90                          input.GetTrtDims().nbDims, dims);
91     ITensorProxyPtr output_tensor = (*layer)->getOutput(0);
92     this->AddOutput(TRT_TensorOrWeights(output_tensor));
93     return Status::OK();
94   }
95 };
96 
97 REGISTER_DEFAULT_TRT_OP_CONVERTER(MakeConverterFunction<ConvertLikeOps<0>>(),
98                                   "zeros_like");
99 REGISTER_DEFAULT_TRT_OP_CONVERTER(MakeConverterFunction<ConvertLikeOps<1>>(),
100                                   "ones_like");
101 REGISTER_DEFAULT_TRT_OP_CONVERTER(MakeConverterFunction<ConvertLikeOps<0>>(),
102                                   "ZerosLike");
103 REGISTER_DEFAULT_TRT_OP_CONVERTER(MakeConverterFunction<ConvertLikeOps<1>>(),
104                                   "OnesLike");
105 
106 #endif  // IS_TRT_VERSION_GE(8, 2, 0, 0)
107 
108 }  // namespace convert
109 }  // namespace tensorrt
110 }  // namespace tensorflow
111 #endif  // GOOGLE_CUDA && GOOGLE_TENSORRT
112