1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #if GOOGLE_CUDA && GOOGLE_TENSORRT
16 #include "tensorflow/compiler/tf2tensorrt/common/utils.h"
17 #include "tensorflow/compiler/tf2tensorrt/convert/convert_nodes.h"
18 #include "tensorflow/compiler/tf2tensorrt/convert/op_converter.h"
19 #include "tensorflow/compiler/tf2tensorrt/convert/op_converter_registry.h"
20 #include "tensorflow/compiler/tf2tensorrt/convert/utils.h"
21 #include "tensorflow/core/lib/core/errors.h"
22 #include "tensorflow/core/lib/core/status.h"
23 #include "third_party/tensorrt/NvInfer.h"
24 #include "third_party/tensorrt/NvInferRuntimeCommon.h"
25 
26 namespace tensorflow {
27 namespace tensorrt {
28 namespace convert {
29 
// Returns the number of spatial dimensions described by `format`.
//
// Spatial dimensions are the dimensions besides N and C. This assumes both N
// and C appear in the format string, so the result is the string length minus
// two. The argument is taken by const reference to avoid copying the string
// on every call, and the subtraction is done on a signed value so that a
// format shorter than two characters yields a negative count without relying
// on unsigned wrap-around followed by an implicit narrowing conversion.
int get_spatial_dim_count(const string& format) {
  return static_cast<int>(format.size()) - 2;
}
35 
36 class ConvertDataFormatVecPermute
37     : public OpConverterBase<ConvertDataFormatVecPermute> {
38  public:
ConvertDataFormatVecPermute(OpConverterParams * params)39   ConvertDataFormatVecPermute(OpConverterParams* params)
40       : OpConverterBase<ConvertDataFormatVecPermute>(params) {}
41 
42   struct DataFormatVecPermuteAttributes {
43     string dst_format;
44     string src_format;
45     int x_dim_count;
46   };
47 
InputSpec()48   static constexpr std::array<InputArgSpec, 1> InputSpec() {
49     return {InputArgSpec::Create("x", TrtInputArg::kBoth)};
50   }
51 
AllowedDataTypes()52   static constexpr std::array<DataType, 1> AllowedDataTypes() {
53     return {DataType::DT_INT32};
54   }
55 
Validate()56   Status Validate() {
57     const auto& inputs = params_->inputs;
58     const auto& node_def = params_->node_def;
59 
60     if (params_->use_implicit_batch) {
61       return errors::Unimplemented("Implicit batch mode not supported, at ",
62                                    node_def.name());
63     }
64 
65     x_input_ = inputs.at(0);
66 
67     // Check input rank.
68     const auto x_dims = x_input_.GetTrtDims();
69     int input_rank = x_dims.nbDims;
70     if (input_rank != 1 && input_rank != 2) {
71       return errors::InvalidArgument(
72           "Input must be a vector or matrix, but got rank ", input_rank,
73           ", at ", node_def.name());
74     }
75 
76     // Verify and consume node attributes.
77     StatusOr<string> dst_format = GetAttrValue<string>("dst_format");
78     StatusOr<string> src_format = GetAttrValue<string>("src_format");
79     TRT_ENSURE_OK(dst_format);
80     TRT_ENSURE_OK(src_format);
81 
82     // Check input dims.
83     const int full_dim_count = src_format->size();
84     const int spatial_dim_count = get_spatial_dim_count(*src_format);
85     if (input_rank == 1) {
86       if (x_dims.d[0] != spatial_dim_count && x_dims.d[0] != full_dim_count) {
87         return errors::InvalidArgument("1D input must be of size ",
88                                        spatial_dim_count, " or ",
89                                        full_dim_count, ", but got size ",
90                                        x_dims.d[0], ", at ", node_def.name());
91       }
92     } else if (input_rank == 2) {
93       if (x_dims.d[0] != spatial_dim_count && x_dims.d[0] != full_dim_count) {
94         return errors::InvalidArgument(
95             "First dimension of 2D input must be of size ", spatial_dim_count,
96             " or ", full_dim_count, ", but got shape (", x_dims.d[0], ", ",
97             x_dims.d[1], "), at ", node_def.name());
98       }
99       if (x_dims.d[1] != 2) {
100         return errors::InvalidArgument(
101             "Second dimension of 2D input must be of size 2, but got shape (",
102             x_dims.d[0], ", ", x_dims.d[1], "), at ", node_def.name());
103       }
104     }
105 
106     // Set custom attributes.
107     attrs_.x_dim_count = x_dims.d[0];
108     attrs_.dst_format = *dst_format;
109     attrs_.src_format = *src_format;
110 
111     return Status::OK();
112   }
113 
Convert()114   Status Convert() {
115     const auto& node_def = params_->node_def;
116 
117     // Copy format strings in case they need to be modified.
118     string dst_format = attrs_.dst_format;
119     string src_format = attrs_.src_format;
120     const int& spatial_dim_count = get_spatial_dim_count(src_format);
121 
122     // If the input is a vector of size spatial_dim_count, treat the elements
123     // as spatial dimensions.
124     if (attrs_.x_dim_count == spatial_dim_count) {
125       auto keep_only_spatial_dimensions =
126           [spatial_dim_count](string* format_str) -> void {
127         auto new_end = std::remove_if(format_str->begin(), format_str->end(),
128                                       [spatial_dim_count](const char dim) {
129                                         return dim == 'N' || dim == 'C';
130                                       });
131         format_str->erase(new_end, format_str->end());
132       };
133       keep_only_spatial_dimensions(&src_format);
134       keep_only_spatial_dimensions(&dst_format);
135     }
136 
137     // Create indices for the gather layer and make weights out of them.
138     std::vector<int32> dst_indices(attrs_.x_dim_count);
139     for (int i = 0; i < attrs_.x_dim_count; ++i) {
140       for (int j = 0; j < attrs_.x_dim_count; ++j) {
141         if (src_format[i] == dst_format[j]) {
142           dst_indices[j] = i;
143           break;
144         }
145       }
146     }
147     nvinfer1::Dims indices_dims = {1, {attrs_.x_dim_count}};
148     StatusOr<TRT_ShapedWeights> indices_weights =
149         params_->weight_store->GetTempWeights(nvinfer1::DataType::kINT32,
150                                               indices_dims);
151     TRT_ENSURE_OK(indices_weights);
152     int32* indices_ptr = indices_weights->GetPointer<int32>();
153     std::copy(dst_indices.data(), dst_indices.data() + attrs_.x_dim_count,
154               indices_ptr);
155     ITensorProxyPtr x_tensor =
156         x_input_.is_weights() ? params_->converter->CreateConstantLayer(
157                                     x_input_.weights(), x_input_.GetTrtDims())
158                               : x_input_.tensor();
159     ITensorProxyPtr indices_tensor =
160         params_->converter->CreateConstantLayer(*indices_weights, indices_dims);
161 
162     // Gather layer with 1D indices on axis 0, conserves shape.
163     nvinfer1::IGatherLayer* layer = params_->converter->network()->addGather(
164         *x_tensor->trt_tensor(), *indices_tensor->trt_tensor(), 0);
165     TRT_ENSURE(layer);
166     params_->converter->SetLayerName(layer, node_def);
167 
168     ITensorProxyPtr output_tensor = layer->getOutput(0);
169 
170     params_->outputs->push_back(TRT_TensorOrWeights(output_tensor));
171     return Status::OK();
172   }
173 
174  private:
175   TRT_TensorOrWeights x_input_;
176   DataFormatVecPermuteAttributes attrs_{};
177 };
178 REGISTER_DEFAULT_TRT_OP_CONVERTER(
179     MakeConverterFunction<ConvertDataFormatVecPermute>(),
180     {"DataFormatVecPermute"});
181 
182 }  // namespace convert
183 }  // namespace tensorrt
184 }  // namespace tensorflow
185 
186 #endif  // GOOGLE_CUDA && GOOGLE_TENSORRT
187