xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/tf2tensorrt/convert/ops/slice_ops.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/tf2tensorrt/convert/ops/slice_ops.h"
17 
18 #if GOOGLE_CUDA && GOOGLE_TENSORRT
19 
20 #include <bitset>
21 #include <vector>
22 
23 #include "absl/container/inlined_vector.h"
24 #include "absl/strings/str_format.h"
25 #include "absl/strings/string_view.h"
26 #include "absl/types/optional.h"
27 #include "absl/types/span.h"
28 #include "tensorflow/compiler/tf2tensorrt/convert/convert_nodes.h"
29 #include "tensorflow/compiler/tf2tensorrt/convert/ops/layer_utils.h"
30 #include "tensorflow/compiler/tf2tensorrt/utils/trt_tensor_proxy.h"
31 #include "tensorflow/core/framework/tensor_shape.h"
32 #include "tensorflow/core/lib/core/errors.h"
33 #include "tensorflow/core/lib/core/status.h"
34 #include "tensorflow/core/platform/errors.h"
35 #include "tensorflow/core/util/strided_slice_op.h"
36 #include "third_party/tensorrt/NvInfer.h"
37 
38 namespace tensorflow {
39 namespace tensorrt {
40 namespace convert {
41 
42 // Adds a set of operations to the network which set the parameters for the
43 // given "slice_layer" in order to handle dynamic input shape.
44 Status HandleDynamicStridedSliceInput(
45     TRTNetworkBuilder* builder, nvinfer1::ISliceLayer* slice_layer,
46     const StridedSliceShapeSpec& strided_slice_spec,
47     const absl::InlinedVector<int64, 4>& dynamic_input_size_indices,
48     nvinfer1::Dims begin_dims, nvinfer1::Dims stride_dims,
49     nvinfer1::Dims end_dims);
50 
ConvertStridedSliceHelper(OpConverterParams * params,const TRT_TensorOrWeights & input,const PartialTensorShape & input_dims,const SliceDims & begin,const SliceDims & stride,const SliceDims & end,std::optional<nvinfer1::Dims> final_shape,std::optional<int> op_instance,std::optional<StridedSliceShapeSpec> strided_slice_spec)51 Status ConvertStridedSliceHelper(
52     OpConverterParams* params, const TRT_TensorOrWeights& input,
53     const PartialTensorShape& input_dims, const SliceDims& begin,
54     const SliceDims& stride, const SliceDims& end,
55     std::optional<nvinfer1::Dims> final_shape, std::optional<int> op_instance,
56     std::optional<StridedSliceShapeSpec> strided_slice_spec) {
57   const auto& node_def = params->node_def;
58 
59   auto begin_dims = DimsAdapter::Create(begin, params->use_implicit_batch);
60   auto stride_dims = DimsAdapter::Create(stride, params->use_implicit_batch);
61   auto end_dims = DimsAdapter::Create(end, params->use_implicit_batch);
62   TRT_ENSURE_OK(begin_dims);
63   TRT_ENSURE_OK(stride_dims);
64   TRT_ENSURE_OK(end_dims);
65 
66   // For each dimension, gather information about static vs dynamic dimension
67   // and slice size.
68   nvinfer1::Dims size_dims = begin_dims->AsTrtDims();
69   absl::InlinedVector<int64, 4> static_input_size_indices;
70   absl::InlinedVector<int64, 4> dynamic_input_size_indices;
71   for (int i = 0; i < begin_dims->NumDims(); i++) {
72     size_dims.d[i] = (std::abs(end_dims->dim(i) - begin_dims->dim(i)) +
73                       std::abs(stride_dims->dim(i)) - 1) /
74                      std::abs(stride_dims->dim(i));
75 
76     if (input_dims.dim_size(i) < 0) {
77       // end_dims and begin_dims do not have valid information yet.
78       dynamic_input_size_indices.push_back(i);
79     } else {
80       static_input_size_indices.push_back(i);
81       if (end_dims->dim(i) < begin_dims->dim(i) && stride_dims->dim(i) > 0) {
82         return errors::InvalidArgument(
83             "\"size\" cannot be negative for StridedSlice");
84       }
85     }
86   }
87 
88   if (!dynamic_input_size_indices.empty()) {
89     if (strided_slice_spec == std::nullopt) {
90       return errors::InvalidArgument(
91           "The argument `strided_slice_spec` is "
92           "`std::nullopt` with `dynamic_input_size_indices` non empty.");
93     }
94     if (params->use_implicit_batch) {
95       return errors::InvalidArgument(
96           "In implicit batch mode, dynamic input size is not supported.");
97     }
98   }
99 
100   if (params->validation_only) return Status::OK();
101 
102   StatusOr<TRTNetworkBuilder> builder = TRTNetworkBuilder::Create(
103       params->converter->network(), params->weight_store);
104   TRT_ENSURE_OK(builder);
105 
106   // VLOG(2) << "strided slice helper:"
107   //         << " begin:" << DebugString(begin_dims)
108   //         << "\n stride: " << DebugString(stride_dims)
109   //         << "\n end: " << DebugString(end_dims)
110   //         << "\n size: " << DebugString(size_dims)
111   //         << "\n Dynamic indices: " <<
112   //         DebugString(dynamic_input_size_indices)
113   //         << "\n Static indices: " << DebugString(static_input_size_indices);
114   // Create the slice operation. For dynamic dims, the inputs of the operations
115   // may be reassigned later.
116   StatusOr<nvinfer1::ISliceLayer*> slice =
117       builder->Slice(input.tensor()->trt_tensor(), begin_dims->AsTrtDims(),
118                      size_dims, stride_dims->AsTrtDims());
119   TRT_ENSURE_PTR_OK(slice);
120 
121   // Handle dynamic input shapes.
122   if (!dynamic_input_size_indices.empty()) {
123     TF_RETURN_IF_ERROR(HandleDynamicStridedSliceInput(
124         &*builder, *slice, *strided_slice_spec, dynamic_input_size_indices,
125         begin_dims->AsTrtDims(), stride_dims->AsTrtDims(),
126         end_dims->AsTrtDims()));
127   }
128 
129   params->converter->SetLayerName(*slice, params->node_def, "slice",
130                                   op_instance);
131   ITensorProxyPtr tensor = (*slice)->getOutput(0);
132 
133   // Reshape for shrink_axis.
134   if (final_shape) {
135     TF_RETURN_IF_ERROR(PrepareTensorForShape(
136         params->converter, TRT_TensorOrWeights(tensor), *final_shape,
137         /*validation_only=*/false, &tensor, node_def, op_instance));
138   }
139   params->outputs->push_back(TRT_TensorOrWeights(tensor));
140   return Status::OK();
141 }
142 
HandleDynamicStridedSliceInput(TRTNetworkBuilder * builder,nvinfer1::ISliceLayer * slice_layer,const StridedSliceShapeSpec & strided_slice_spec,const absl::InlinedVector<int64,4> & dynamic_input_size_indices,nvinfer1::Dims begin_dims,nvinfer1::Dims stride_dims,nvinfer1::Dims end_dims)143 Status HandleDynamicStridedSliceInput(
144     TRTNetworkBuilder* builder, nvinfer1::ISliceLayer* slice_layer,
145     const StridedSliceShapeSpec& strided_slice_spec,
146     const absl::InlinedVector<int64, 4>& dynamic_input_size_indices,
147     nvinfer1::Dims begin_dims, nvinfer1::Dims stride_dims,
148     nvinfer1::Dims end_dims) {
149   TRT_ENSURE(builder);
150   TRT_ENSURE(slice_layer);
151 
152   nvinfer1::ITensor* input_tensor = slice_layer->getInput(0);
153   TRT_ENSURE(input_tensor);
154 
155   // For each dynamic input dimension of the input, do some preprocessing based
156   // on whether this dimension is set in "begin_mask" or "end_mask" and the sign
157   // of the dimension's stride value.
158   // When stride is negative:
159   //   - If "begin_mask[dynamic_idx]" is set, then we need to adjust the slice
160   //     start of dimension[i] to the dynamic size.
161   //   - If "end_mask[dynamic_idx]" is set, it suffices to set
162   //     end_dims[dynamic_idx] to -1.
163   // When stride is positive:
164   //   - If "begin_mask[dynamic_idx]" is set, it suffices to set
165   //     begin_dims[dynamic_idx] to zero.
166   //   - If "end_mask[dynamic_idx]" is set, we need to adjust slice end to the
167   //     dynamic size of dimension "dynamic_idx".
168   absl::InlinedVector<int64, 4> dynamic_begin_indices;
169   absl::InlinedVector<int64, 4> dynamic_end_indices;
170   const auto begin_mask = std::bitset<32>(strided_slice_spec.begin_dense_mask);
171   const auto end_mask = std::bitset<32>(strided_slice_spec.end_dense_mask);
172   for (int i = 0; i < dynamic_input_size_indices.size(); i++) {
173     auto dynamic_idx = dynamic_input_size_indices[i];
174     if (begin_mask[dynamic_idx]) {
175       begin_dims.d[dynamic_idx] = 0;
176       if (stride_dims.d[dynamic_idx] < 0) {
177         dynamic_begin_indices.push_back(dynamic_idx);
178       }
179     }
180     if (end_mask[dynamic_idx]) {
181       end_dims.d[dynamic_idx] = stride_dims.d[dynamic_idx] > 0 ? 0 : -1;
182       if (stride_dims.d[dynamic_idx] > 0) {
183         dynamic_end_indices.push_back(dynamic_idx);
184       }
185     }
186   }
187 
188   // VLOG(2) << " Dynamic begin indices: " << DebugString(dynamic_begin_indices)
189   //         << " Dynamic end indices: " << DebugString(dynamic_end_indices);
190 
191   // Create ITensors for each of the begin/stride/end constants.
192   StatusOr<nvinfer1::IConstantLayer*> begin_const = builder->Constant(
193       std::vector<int>(begin_dims.d, begin_dims.d + begin_dims.nbDims));
194   TRT_ENSURE_PTR_OK(begin_const);
195   nvinfer1::ITensor* begin_tensor = (*begin_const)->getOutput(0);
196   StatusOr<nvinfer1::IConstantLayer*> stride_const = builder->Constant(
197       std::vector<int>(stride_dims.d, stride_dims.d + stride_dims.nbDims));
198   TRT_ENSURE_PTR_OK(stride_const);
199   StatusOr<nvinfer1::IConstantLayer*> end_const = builder->Constant(
200       std::vector<int>(end_dims.d, end_dims.d + end_dims.nbDims));
201   TRT_ENSURE_PTR_OK(end_const);
202   nvinfer1::ITensor* end_tensor = (*end_const)->getOutput(0);
203 
204   // Make corrections based on the begin_mask/end_mask values.
205   if (dynamic_end_indices.size() > 0) {
206     StatusOr<nvinfer1::IGatherLayer*> dynamic_end_masked_tensor =
207         builder->GetPartialShapeOf(input_tensor, dynamic_end_indices,
208                                    /*sub_one=*/false);
209     TRT_ENSURE_PTR_OK(dynamic_end_masked_tensor);
210     StatusOr<nvinfer1::IElementWiseLayer*> end_corrected =
211         builder->Add((*dynamic_end_masked_tensor)->getOutput(0), end_tensor);
212     TRT_ENSURE_PTR_OK(end_corrected);
213     end_tensor = (*end_corrected)->getOutput(0);
214   }
215   if (dynamic_begin_indices.size() > 0) {
216     StatusOr<nvinfer1::IGatherLayer*> dynamic_begin_masked_tensor =
217         builder->GetPartialShapeOf(input_tensor, dynamic_begin_indices,
218                                    /*sub_one=*/true);
219     TRT_ENSURE_PTR_OK(dynamic_begin_masked_tensor);
220 
221     // Add back the original "begin" values for static dimensions.
222     StatusOr<nvinfer1::IElementWiseLayer*> begin_corrected = builder->Add(
223         (*dynamic_begin_masked_tensor)->getOutput(0), begin_tensor);
224     TRT_ENSURE_PTR_OK(begin_corrected);
225     begin_tensor = (*begin_corrected)->getOutput(0);
226   }
227 
228   // Calculate the final size of the slice dynamicaly.
229   nvinfer1::ITensor* size_tensor;
230   {
231     StatusOr<nvinfer1::IElementWiseLayer*> num =
232         builder->Sub(end_tensor, begin_tensor);
233     TRT_ENSURE_PTR_OK(num);
234     StatusOr<nvinfer1::IElementWiseLayer*> ceil_div = builder->AbsCeilDivInt(
235         (*num)->getOutput(0), (*stride_const)->getOutput(0));
236     TRT_ENSURE_PTR_OK(ceil_div);
237     size_tensor = (*ceil_div)->getOutput(0);
238   }
239 
240   slice_layer->setInput(1, *begin_tensor);
241   slice_layer->setInput(2, *size_tensor);
242   slice_layer->setInput(3, *(*stride_const)->getOutput(0));
243 
244   return Status::OK();
245 }
246 
247 }  // namespace convert
248 }  // namespace tensorrt
249 }  // namespace tensorflow
250 #endif  // GOOGLE_CUDA && GOOGLE_TENSORRT
251