1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #if GOOGLE_CUDA && GOOGLE_TENSORRT
16 #include "tensorflow/compiler/tf2tensorrt/common/utils.h"
17 #include "tensorflow/compiler/tf2tensorrt/convert/convert_nodes.h"
18 #include "tensorflow/compiler/tf2tensorrt/convert/op_converter.h"
19 #include "tensorflow/compiler/tf2tensorrt/convert/op_converter_registry.h"
20 #include "tensorflow/compiler/tf2tensorrt/convert/utils.h"
21 #include "tensorflow/core/lib/core/errors.h"
22 #include "tensorflow/core/lib/core/status.h"
23 #include "third_party/tensorrt/NvInfer.h"
24 #include "third_party/tensorrt/NvInferRuntimeCommon.h"
25
26 namespace tensorflow {
27 namespace tensorrt {
28 namespace convert {
29
// Returns the number of spatial dimensions described by `format`, computed as
// the number of characters in `format` minus two.
//
// This assumes the batch ('N') and channel ('C') dimensions each appear in
// `format`, so every remaining character is counted as a spatial dimension.
// The result is negative when `format` has fewer than two characters.
int get_spatial_dim_count(const string& format) {
  // Convert to a signed value before subtracting so the arithmetic never
  // relies on unsigned wrap-around followed by an implicit narrowing.
  return static_cast<int>(format.size()) - 2;
}
35
36 class ConvertDataFormatVecPermute
37 : public OpConverterBase<ConvertDataFormatVecPermute> {
38 public:
ConvertDataFormatVecPermute(OpConverterParams * params)39 ConvertDataFormatVecPermute(OpConverterParams* params)
40 : OpConverterBase<ConvertDataFormatVecPermute>(params) {}
41
42 struct DataFormatVecPermuteAttributes {
43 string dst_format;
44 string src_format;
45 int x_dim_count;
46 };
47
InputSpec()48 static constexpr std::array<InputArgSpec, 1> InputSpec() {
49 return {InputArgSpec::Create("x", TrtInputArg::kBoth)};
50 }
51
AllowedDataTypes()52 static constexpr std::array<DataType, 1> AllowedDataTypes() {
53 return {DataType::DT_INT32};
54 }
55
Validate()56 Status Validate() {
57 const auto& inputs = params_->inputs;
58 const auto& node_def = params_->node_def;
59
60 if (params_->use_implicit_batch) {
61 return errors::Unimplemented("Implicit batch mode not supported, at ",
62 node_def.name());
63 }
64
65 x_input_ = inputs.at(0);
66
67 // Check input rank.
68 const auto x_dims = x_input_.GetTrtDims();
69 int input_rank = x_dims.nbDims;
70 if (input_rank != 1 && input_rank != 2) {
71 return errors::InvalidArgument(
72 "Input must be a vector or matrix, but got rank ", input_rank,
73 ", at ", node_def.name());
74 }
75
76 // Verify and consume node attributes.
77 StatusOr<string> dst_format = GetAttrValue<string>("dst_format");
78 StatusOr<string> src_format = GetAttrValue<string>("src_format");
79 TRT_ENSURE_OK(dst_format);
80 TRT_ENSURE_OK(src_format);
81
82 // Check input dims.
83 const int full_dim_count = src_format->size();
84 const int spatial_dim_count = get_spatial_dim_count(*src_format);
85 if (input_rank == 1) {
86 if (x_dims.d[0] != spatial_dim_count && x_dims.d[0] != full_dim_count) {
87 return errors::InvalidArgument("1D input must be of size ",
88 spatial_dim_count, " or ",
89 full_dim_count, ", but got size ",
90 x_dims.d[0], ", at ", node_def.name());
91 }
92 } else if (input_rank == 2) {
93 if (x_dims.d[0] != spatial_dim_count && x_dims.d[0] != full_dim_count) {
94 return errors::InvalidArgument(
95 "First dimension of 2D input must be of size ", spatial_dim_count,
96 " or ", full_dim_count, ", but got shape (", x_dims.d[0], ", ",
97 x_dims.d[1], "), at ", node_def.name());
98 }
99 if (x_dims.d[1] != 2) {
100 return errors::InvalidArgument(
101 "Second dimension of 2D input must be of size 2, but got shape (",
102 x_dims.d[0], ", ", x_dims.d[1], "), at ", node_def.name());
103 }
104 }
105
106 // Set custom attributes.
107 attrs_.x_dim_count = x_dims.d[0];
108 attrs_.dst_format = *dst_format;
109 attrs_.src_format = *src_format;
110
111 return Status::OK();
112 }
113
Convert()114 Status Convert() {
115 const auto& node_def = params_->node_def;
116
117 // Copy format strings in case they need to be modified.
118 string dst_format = attrs_.dst_format;
119 string src_format = attrs_.src_format;
120 const int& spatial_dim_count = get_spatial_dim_count(src_format);
121
122 // If the input is a vector of size spatial_dim_count, treat the elements
123 // as spatial dimensions.
124 if (attrs_.x_dim_count == spatial_dim_count) {
125 auto keep_only_spatial_dimensions =
126 [spatial_dim_count](string* format_str) -> void {
127 auto new_end = std::remove_if(format_str->begin(), format_str->end(),
128 [spatial_dim_count](const char dim) {
129 return dim == 'N' || dim == 'C';
130 });
131 format_str->erase(new_end, format_str->end());
132 };
133 keep_only_spatial_dimensions(&src_format);
134 keep_only_spatial_dimensions(&dst_format);
135 }
136
137 // Create indices for the gather layer and make weights out of them.
138 std::vector<int32> dst_indices(attrs_.x_dim_count);
139 for (int i = 0; i < attrs_.x_dim_count; ++i) {
140 for (int j = 0; j < attrs_.x_dim_count; ++j) {
141 if (src_format[i] == dst_format[j]) {
142 dst_indices[j] = i;
143 break;
144 }
145 }
146 }
147 nvinfer1::Dims indices_dims = {1, {attrs_.x_dim_count}};
148 StatusOr<TRT_ShapedWeights> indices_weights =
149 params_->weight_store->GetTempWeights(nvinfer1::DataType::kINT32,
150 indices_dims);
151 TRT_ENSURE_OK(indices_weights);
152 int32* indices_ptr = indices_weights->GetPointer<int32>();
153 std::copy(dst_indices.data(), dst_indices.data() + attrs_.x_dim_count,
154 indices_ptr);
155 ITensorProxyPtr x_tensor =
156 x_input_.is_weights() ? params_->converter->CreateConstantLayer(
157 x_input_.weights(), x_input_.GetTrtDims())
158 : x_input_.tensor();
159 ITensorProxyPtr indices_tensor =
160 params_->converter->CreateConstantLayer(*indices_weights, indices_dims);
161
162 // Gather layer with 1D indices on axis 0, conserves shape.
163 nvinfer1::IGatherLayer* layer = params_->converter->network()->addGather(
164 *x_tensor->trt_tensor(), *indices_tensor->trt_tensor(), 0);
165 TRT_ENSURE(layer);
166 params_->converter->SetLayerName(layer, node_def);
167
168 ITensorProxyPtr output_tensor = layer->getOutput(0);
169
170 params_->outputs->push_back(TRT_TensorOrWeights(output_tensor));
171 return Status::OK();
172 }
173
174 private:
175 TRT_TensorOrWeights x_input_;
176 DataFormatVecPermuteAttributes attrs_{};
177 };
// Registers a converter function built from ConvertDataFormatVecPermute for
// the "DataFormatVecPermute" op name.
REGISTER_DEFAULT_TRT_OP_CONVERTER(
    MakeConverterFunction<ConvertDataFormatVecPermute>(),
    {"DataFormatVecPermute"});
181
182 } // namespace convert
183 } // namespace tensorrt
184 } // namespace tensorflow
185
186 #endif // GOOGLE_CUDA && GOOGLE_TENSORRT
187