/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#if GOOGLE_CUDA && GOOGLE_TENSORRT

#include "tensorflow/compiler/tf2tensorrt/convert/convert_nodes.h"
#include "tensorflow/compiler/tf2tensorrt/convert/op_converter_registry.h"
#include "tensorflow/compiler/tf2tensorrt/convert/ops/layer_utils.h"

namespace tensorflow {
namespace tensorrt {
namespace convert {

#if IS_TRT_VERSION_GE(8, 2, 0, 0)

// Converter shared by the ZerosLike and OnesLike ops.
//
// The template argument `V` is the constant written into the fill value:
// the converter is registered below with V == 0 for "ZerosLike"/"zeros_like"
// and with V == 1 for "OnesLike"/"ones_like".
template <int V>
class ConvertLikeOps : public OpConverterBase<ConvertLikeOps<V>> {
 public:
  explicit ConvertLikeOps(OpConverterParams *params)
      : OpConverterBase<ConvertLikeOps<V>>(params) {}

  // Data types this converter accepts.
  static constexpr std::array<DataType, 3> AllowedDataTypes() {
    return {DataType::DT_FLOAT, DataType::DT_HALF, DataType::DT_INT32};
  }

  // The op takes a single argument named "input", declared as
  // TrtInputArg::kBoth.
  static constexpr std::array<InputArgSpec, 1> InputSpec() {
    return std::array<InputArgSpec, 1>{
        InputArgSpec::Create("input", TrtInputArg::kBoth),
    };
  }

  // Checks that the node can be converted.
  //
  // Returns Unimplemented when implicit batch mode is in use, and
  // InvalidArgument when the number of inputs is not exactly one.
  Status Validate() {
    const auto &params = *this->params_;

    const std::string op_name = V == 0 ? "ZerosLike" : "OnesLike";
    if (params.use_implicit_batch) {
      return errors::Unimplemented("Conversion for " + op_name +
                                   " is not implemented in"
                                   " implicit batch mode");
    }
    const auto &inputs = params.inputs;
    if (inputs.size() != 1) {
      return errors::InvalidArgument(op_name, " expects 1 input, but received ",
                                     inputs.size());
    }
    return Status::OK();
  }

  // Adds a fill layer whose value is `V` and whose dimensions come from the
  // input, then registers the layer's first output as the op output.
  //
  // Any error reported while allocating the value weights or while adding the
  // shape/fill layers is returned to the caller.
  Status Convert() {
    const auto &params = *this->params_;
    const auto &inputs = params.inputs;
    auto *network = params.converter->network();

    const TRT_TensorOrWeights &input = inputs.at(0);
    nvinfer1::Dims dims(input.GetTrtDims());

    // One-element temporary weights of the input's TRT dtype, set to `V`.
    const std::vector<int> value_input_dims_data = {1};
    const DimsAdapter value_input_dims(value_input_dims_data);
    StatusOr<TRT_ShapedWeights> value_weights =
        params.weight_store->GetTempWeights(input.TrtDType(), value_input_dims);
    TF_RETURN_IF_ERROR(value_weights.status());
    TF_RETURN_IF_ERROR(value_weights->SetValues(V));
    TRT_TensorOrWeights value_input(value_weights.ValueOrDie());

    const auto is_dims_static = HasStaticShape(dims);
    // NOTE(review): `builder` is dereferenced below without a status check;
    // confirm whether TRTNetworkBuilder::Create can fail and needs one.
    auto builder = TRTNetworkBuilder::Create(network, params.weight_store);
    ITensorProxyPtr dims_input_tensor;
    if (!is_dims_static) {
      // The shape is not static: take the dimensions from a shape layer on
      // the input tensor, and clear the rank of the static `dims`.
      StatusOr<nvinfer1::IShapeLayer *> shape_layer =
          builder->Shape(input.tensor()->trt_tensor());
      TF_RETURN_IF_ERROR(shape_layer.status());
      dims_input_tensor = (*shape_layer)->getOutput(0);
      dims.nbDims = 0;
    }

    TRT_TensorOrWeights dims_input(dims_input_tensor);
    // NOTE(review): the literal `true` presumably marks the fill value as
    // static -- confirm against the AddFill declaration in layer_utils.h.
    StatusOr<nvinfer1::ILayer *> layer =
        builder->AddFill(value_input, dims_input, true, is_dims_static,
                         input.GetTrtDims().nbDims, dims);
    // Propagate a failed AddFill instead of dereferencing an errored StatusOr.
    TF_RETURN_IF_ERROR(layer.status());
    ITensorProxyPtr output_tensor = (*layer)->getOutput(0);
    this->AddOutput(TRT_TensorOrWeights(output_tensor));
    return Status::OK();
  }
};

REGISTER_DEFAULT_TRT_OP_CONVERTER(MakeConverterFunction<ConvertLikeOps<0>>(),
                                  "zeros_like");
REGISTER_DEFAULT_TRT_OP_CONVERTER(MakeConverterFunction<ConvertLikeOps<1>>(),
                                  "ones_like");
REGISTER_DEFAULT_TRT_OP_CONVERTER(MakeConverterFunction<ConvertLikeOps<0>>(),
                                  "ZerosLike");
REGISTER_DEFAULT_TRT_OP_CONVERTER(MakeConverterFunction<ConvertLikeOps<1>>(),
                                  "OnesLike");

#endif  // IS_TRT_VERSION_GE(8, 2, 0, 0)

}  // namespace convert
}  // namespace tensorrt
}  // namespace tensorflow
#endif  // GOOGLE_CUDA && GOOGLE_TENSORRT