1 /* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #if GOOGLE_CUDA && GOOGLE_TENSORRT
16 #include <iterator>
17 #include <limits>
18 #include <memory>
19
20 #include "tensorflow/compiler/tf2tensorrt/common/utils.h"
21 #include "tensorflow/compiler/tf2tensorrt/convert/convert_nodes.h"
22 #include "tensorflow/compiler/tf2tensorrt/convert/op_converter.h"
23 #include "tensorflow/compiler/tf2tensorrt/convert/op_converter_registry.h"
24 #include "tensorflow/compiler/tf2tensorrt/convert/ops/layer_utils.h"
25 #include "tensorflow/compiler/tf2tensorrt/convert/utils.h"
26 #include "tensorflow/core/lib/core/errors.h"
27 #include "tensorflow/core/lib/core/status.h"
28 #include "tensorflow/core/util/einsum_op_util.h"
29 #include "third_party/tensorrt/NvInfer.h"
30
31 #if IS_TRT_VERSION_GE(7, 1, 3, 0)
32
33 namespace tensorflow {
34 namespace tensorrt {
35 namespace convert {
36
37 namespace {
38
39 #if !IS_TRT_VERSION_GE(8, 2, 0, 0)
40
41 // Finds the indices of elements in [begin, end) in array
42 // [array_begin, array_end), and appends the indices to permute. This is used to
43 // construct the permutation sequence for the operand with input labels
44 // [array_begin, array_end) to the desired permuted labels [begin, end).
45 template <typename T>
FindIndicesoOfAllValuesInSrc(absl::Span<const T> values,absl::Span<const T> src,std::vector<int> * indices)46 Status FindIndicesoOfAllValuesInSrc(absl::Span<const T> values,
47 absl::Span<const T> src,
48 std::vector<int>* indices) {
49 if (src.size() < values.size()) {
50 return errors::Internal(
51 "Span 'src' cannot contain all elements of 'values'");
52 }
53 for (auto i = 0; i < values.size(); i++) {
54 auto iter = absl::c_find(src, values[i]);
55 if (iter == src.end()) {
56 return errors::Internal("Label ", values[i], " not found");
57 }
58 int idx = std::distance(src.begin(), iter);
59 indices->push_back(idx);
60 }
61 return Status::OK();
62 }
63
// Layout of the einsum dimensions: Batch, Free and Contraction indices.
// Example: adbc,adce -> adbe. The first tensor has layout BFC, the second BCF.
// MIX means the labels are interleaved and a transpose is needed to reach
// BFC or BCF.
enum class EinsumLayout { BFC, BCF, MIX };

// Short aliases for the dimension-type enum, used heavily in this file.
using DimType = EinsumDimensionType;
constexpr auto kBatch = DimType::kBatch;
constexpr auto kFree = DimType::kFree;
constexpr auto kContract = DimType::kContract;
72
73 // Describes an operand: input shape, number of batch, free and contract
74 // dimensions, and the permutation that is needed to bring it to a matmul
75 // compatible form.
76 class EinsumDescriptor {
77 private:
78 // Checks whether input_labels[offset:offset+m] matches labels from other.
OrderMatches(const Labels & input_labels,int offset,int m,EinsumDimensionType dim_type,const std::unique_ptr<EinsumDescriptor> & other)79 static bool OrderMatches(const Labels& input_labels, int offset, int m,
80 EinsumDimensionType dim_type,
81 const std::unique_ptr<EinsumDescriptor>& other) {
82 if (other == nullptr) {
83 return true;
84 }
85 int offset_other = 0;
86 if (dim_type == kFree) {
87 offset = other->offset_f;
88 } else if (dim_type == kContract) {
89 offset = other->offset_c;
90 }
91 return std::equal(input_labels.begin() + offset,
92 input_labels.begin() + offset + m,
93 other->permuted_labels.begin() + offset_other);
94 }
95
96 using label_t_iterator = std::vector<EinsumDimensionType>::const_iterator;
CountLabels(label_t_iterator begin,label_t_iterator end,EinsumDimensionType val)97 static int32_t CountLabels(label_t_iterator begin, label_t_iterator end,
98 EinsumDimensionType val) {
99 return static_cast<int32_t>(std::count_if(
100 begin, end, [val](EinsumDimensionType t) { return t == val; }));
101 }
102
103 // Appends indices to the "permute" vector where types maches value.
AppendMatchingIndicesToPermute(const std::vector<EinsumDimensionType> & types,EinsumDimensionType val)104 void AppendMatchingIndicesToPermute(
105 const std::vector<EinsumDimensionType>& types, EinsumDimensionType val) {
106 for (int i = 0; i < types.size(); i++) {
107 if (types[i] == val) {
108 permute.push_back(i);
109 }
110 }
111 }
112
DetermineLayout(const Labels & input_labels,const std::vector<EinsumDimensionType> & types,const std::unique_ptr<EinsumDescriptor> & other)113 Status DetermineLayout(const Labels& input_labels,
114 const std::vector<EinsumDimensionType>& types,
115 const std::unique_ptr<EinsumDescriptor>& other) {
116 // Check if the current layout is BFC or BCF. In that case we could avoid
117 // transpose.
118 layout = EinsumLayout::MIX;
119 if (CountLabels(types.begin(), types.begin() + b, kBatch) == b &&
120 OrderMatches(input_labels, 0, b, kBatch, other)) {
121 // Batch dims are the leading dims. They have the same order as other.
122 if (CountLabels(types.begin() + b, types.begin() + b + f, kFree) == f) {
123 // All the free dims are placed consecutively after the batch dims.
124 // Their order is arbitrary. The final transpose will ensure that the
125 // output has correct order. We still have to check that the contract
126 // indices have correct order.
127 if (OrderMatches(input_labels, b + f, c, kContract, other)) {
128 layout = EinsumLayout::BFC;
129 }
130 } else if (CountLabels(types.begin() + b, types.begin() + b + c,
131 kContract) == c) {
132 // All the contract dims are placed consecutively after the batch
133 // dims. Check whether the contract dims have the same order as the
134 // contract dims in other.
135 if (OrderMatches(input_labels, b, c, kContract, other)) {
136 layout = EinsumLayout::BCF;
137 }
138 }
139 }
140 return Status::OK();
141 }
142
CalculateMixedLayoutPermutation(const EinsumLayout preferred_layout,const Labels & input_labels,const std::vector<EinsumDimensionType> & types,const std::unique_ptr<EinsumDescriptor> & other)143 Status CalculateMixedLayoutPermutation(
144 const EinsumLayout preferred_layout, const Labels& input_labels,
145 const std::vector<EinsumDimensionType>& types,
146 const std::unique_ptr<EinsumDescriptor>& other) {
147 // Input label types are mixed. Calculate a permutation that maps them
148 // to the preferred layout (BCF or BFC).
149 layout = preferred_layout;
150 if (other == nullptr) {
151 AppendMatchingIndicesToPermute(types, kBatch);
152 } else {
153 TF_RETURN_IF_ERROR(
154 FindIndicesoOfAllValuesInSrc(/*values=*/
155 absl::MakeConstSpan(
156 other->permuted_labels.begin(),
157 other->b),
158 /*src=*/
159 absl::MakeConstSpan(input_labels.begin(),
160 input_labels.size()),
161 /*indices=*/&permute));
162 }
163 if (layout == EinsumLayout::BFC) {
164 AppendMatchingIndicesToPermute(types, kFree);
165 if (other == nullptr) {
166 AppendMatchingIndicesToPermute(types, kContract);
167 } else {
168 TF_RETURN_IF_ERROR(FindIndicesoOfAllValuesInSrc(
169 /*values=*/absl::MakeConstSpan(
170 other->permuted_labels.begin() + other->offset_c, other->c),
171 /*src=*/
172 absl::MakeConstSpan(input_labels.begin(), input_labels.size()),
173 /*indices=*/&permute));
174 }
175 return Status::OK();
176 }
177 if (other == nullptr) {
178 AppendMatchingIndicesToPermute(types, kContract);
179 } else {
180 TF_RETURN_IF_ERROR(FindIndicesoOfAllValuesInSrc(
181 /*values=*/absl::MakeConstSpan(
182 other->permuted_labels.begin() + other->offset_c, other->c),
183 /*src=*/absl::MakeConstSpan(input_labels.begin(), input_labels.end()),
184 /*indices=*/&permute));
185 }
186 AppendMatchingIndicesToPermute(types, kFree);
187 return Status::OK();
188 }
189
Initialize(const TRT_TensorOrWeights & operand,Labels input_labels,std::vector<EinsumDimensionType> & label_types,EinsumLayout preferred_layout,const std::unique_ptr<EinsumDescriptor> & other=nullptr)190 Status Initialize(const TRT_TensorOrWeights& operand, Labels input_labels,
191 std::vector<EinsumDimensionType>& label_types,
192 EinsumLayout preferred_layout,
193 const std::unique_ptr<EinsumDescriptor>& other = nullptr) {
194 if (preferred_layout == EinsumLayout::MIX) {
195 return errors::Internal("Preferred einsum layout cannot be MIX");
196 }
197 // Map label indices to label types.
198 std::vector<EinsumDimensionType> types; // Input label types.
199 std::transform(input_labels.begin(), input_labels.end(),
200 std::back_inserter(types),
201 [&label_types](int i) { return label_types.at(i); });
202
203 b = CountLabels(types.begin(), types.end(), kBatch);
204 f = CountLabels(types.begin(), types.end(), kFree);
205 c = CountLabels(types.begin(), types.end(), kContract);
206
207 if (c == 0 || f == 0) {
208 VLOG(2) << "Einsum equation needs to have at least one free and one "
209 "contract dimension";
210 return errors::Unimplemented("No conversion for einsum equation.");
211 }
212
213 TF_RETURN_IF_ERROR(DetermineLayout(input_labels, types, other));
214 if (layout == EinsumLayout::MIX) {
215 TF_RETURN_IF_ERROR(CalculateMixedLayoutPermutation(
216 preferred_layout, input_labels, types, other));
217 }
218
219 if (layout == EinsumLayout::BFC) {
220 offset_f = b;
221 offset_c = f + b;
222 } else {
223 offset_f = b + c;
224 offset_c = b;
225 }
226
227 dims = operand.GetTrtDims();
228 for (int i = 0; i < b; i++) {
229 // Set unknown batch dims to zero. These dims will be used in reshape op,
230 // where zero is a special value for retaining the original dim size.
231 if (dims.d[i] == -1) {
232 dims.d[i] = 0;
233 }
234 }
235 permuted_labels = input_labels;
236 if (!permute.empty()) {
237 // Apply the permutation on the dimension array.
238 nvinfer1::Dims orig_dims = dims;
239 for (int i = 0; i < permute.size(); i++) {
240 dims.d[i] = orig_dims.d[permute[i]];
241 permuted_labels[i] = input_labels[permute[i]];
242 }
243 }
244 size_tensors.resize(dims.nbDims, nullptr);
245 return Status::OK();
246 }
247
248 public:
EinsumDescriptor()249 EinsumDescriptor() : b(0), f(0), c(0) {}
250
251 // Deduces the number of batch, free, contract dimensions from the input
252 // labels, decides what layout to use, and determines permutation indices for
253 // that layout.
Create(const TRT_TensorOrWeights & operand,Labels input_labels,std::vector<EinsumDimensionType> & label_types,EinsumLayout preferred_layout,const std::unique_ptr<EinsumDescriptor> & other=nullptr)254 static StatusOr<std::unique_ptr<EinsumDescriptor>> Create(
255 const TRT_TensorOrWeights& operand, Labels input_labels,
256 std::vector<EinsumDimensionType>& label_types,
257 EinsumLayout preferred_layout,
258 const std::unique_ptr<EinsumDescriptor>& other = nullptr) {
259 auto desc = std::make_unique<EinsumDescriptor>();
260 TF_RETURN_IF_ERROR(desc->Initialize(operand, input_labels, label_types,
261 preferred_layout, other));
262 VLOG(2) << desc->DebugString();
263 return desc;
264 }
265
NumBatchDims() const266 int NumBatchDims() const { return b; }
NumContractDims() const267 int NumContractDims() const { return c; }
NumFreeDims() const268 int NumFreeDims() const { return f; }
ContractDimOffset() const269 int ContractDimOffset() const { return offset_c; }
PermutedLabels() const270 const Labels& PermutedLabels() const { return permuted_labels; }
271
DebugString() const272 std::string DebugString() const {
273 return absl::StrCat("Descriptor with ",
274 (layout == EinsumLayout::BFC ? "BFC" : "BCF"),
275 " layout, b=", b, ", f=", f, ", c=", c);
276 }
277
278 // Returns whether the free and contract dimension have static shape.
HasStaticShape() const279 bool HasStaticShape() const {
280 return !std::any_of(dims.d + b, dims.d + dims.nbDims,
281 [](int k) { return k == -1; });
282 }
283
GetPermutation() const284 nvinfer1::Permutation GetPermutation() const {
285 nvinfer1::Permutation p;
286 std::copy(permute.begin(), permute.end(), p.order);
287 return p;
288 }
289
PermuteVector() const290 std::vector<int> PermuteVector() const { return permute; }
291
292 // Sets the "size_tensors" vector to be filled with scalar constant tensors
293 // representing the shape of the operand.
SetDynamicSize(TRTNetworkBuilder * builder,const TRT_TensorOrWeights & operand)294 Status SetDynamicSize(TRTNetworkBuilder* builder,
295 const TRT_TensorOrWeights& operand) {
296 TRT_ENSURE(operand.GetTrtDims().nbDims == dims.nbDims);
297 if (operand.is_weights()) {
298 // Generate constants for each dimension of the constant weight tensor's
299 // shape.
300 for (int i = 0; i < operand.GetTrtDims().nbDims; i++) {
301 StatusOr<nvinfer1::IConstantLayer*> size_tensor =
302 builder->Constant<int32_t>(dims.d[i], 1);
303 TRT_ENSURE_PTR_OK(size_tensor);
304 size_tensors[i] = (*size_tensor)->getOutput(0);
305 }
306 return Status::OK();
307 }
308
309 // If the operand is a dynamic tensor, compute the shape value dynamically.
310 StatusOr<nvinfer1::IShapeLayer*> shape_layer =
311 builder->Shape(operand.tensor()->trt_tensor());
312 TRT_ENSURE_PTR_OK(shape_layer);
313 nvinfer1::ITensor* shape = (*shape_layer)->getOutput(0);
314 for (int i = 0; i < operand.GetTrtDims().nbDims; i++) {
315 int idx = permute.empty() ? i : permute.at(i);
316 StatusOr<nvinfer1::ISliceLayer*> slice_layer =
317 builder->Slice(shape, {1, {idx}}, {1, {1}}, {1, {1}});
318 TRT_ENSURE_PTR_OK(slice_layer);
319 size_tensors[i] = (*slice_layer)->getOutput(0);
320 }
321 return Status::OK();
322 }
323
324 EinsumLayout layout;
325 int b; // number of batch dims
326 int f; // number of free dims
327 int c; // number of conraction dims
328 int offset_f;
329 int offset_c;
330 nvinfer1::Dims dims;
331 std::vector<int> permute;
332 std::vector<ITensorProxyPtr> size_tensors;
333 Labels permuted_labels;
334 };
335
336 // Reshapes operand so that the free dimensions are combined into a single dim,
337 // and the contract dimensions are combined into another single dim.
GetEinsumNewDynamicShape(TRTNetworkBuilder * builder,const EinsumDescriptor & desc,ITensorProxyPtr * new_shape)338 Status GetEinsumNewDynamicShape(TRTNetworkBuilder* builder,
339 const EinsumDescriptor& desc,
340 ITensorProxyPtr* new_shape) {
341 std::vector<nvinfer1::ITensor*> size;
342 size.reserve(desc.b + 2);
343 absl::c_transform(absl::MakeSpan(desc.size_tensors).subspan(0, desc.b + 2),
344 std::back_inserter(size),
345 [](const ITensorProxyPtr x) { return x->trt_tensor(); });
346
347 int idx_f = desc.layout == EinsumLayout::BFC ? desc.b : desc.b + 1;
348 int idx_c = desc.layout == EinsumLayout::BFC ? desc.b + 1 : desc.b;
349
350 std::vector<nvinfer1::ITensor*> size_tensors;
351 size_tensors.reserve(desc.size_tensors.size());
352 absl::c_transform(desc.size_tensors, std::back_inserter(size_tensors),
353 [](const ITensorProxyPtr x) -> nvinfer1::ITensor* {
354 return x->trt_tensor();
355 });
356
357 StatusOr<nvinfer1::ILayer*> shape_vol = builder->CumulativeProd(
358 absl::MakeSpan(size_tensors).subspan(desc.offset_f, desc.f));
359 TRT_ENSURE_PTR_OK(shape_vol);
360 size[idx_f] = (*shape_vol)->getOutput(0);
361
362 shape_vol = builder->CumulativeProd(
363 absl::MakeSpan(size_tensors).subspan(desc.offset_c, desc.c));
364 TRT_ENSURE_PTR_OK(shape_vol);
365 size[idx_c] = (*shape_vol)->getOutput(0);
366 StatusOr<nvinfer1::IConcatenationLayer*> layer =
367 builder->Concat(size, /*axis=*/0);
368 TRT_ENSURE_PTR_OK(layer);
369 *new_shape = (*layer)->getOutput(0);
370 return Status::OK();
371 }
372
373 // Reshapes operand so that the free dimensions are combined into a single dim,
374 // and the contract dimensions are combined into another single dim.
GetEinsumNewStaticShape(const EinsumDescriptor & desc,nvinfer1::Dims * new_dims)375 Status GetEinsumNewStaticShape(const EinsumDescriptor& desc,
376 nvinfer1::Dims* new_dims) {
377 // Copy the batch dims and append two additional dimensions.
378 DimsAdapter adap(
379 absl::MakeSpan(static_cast<const int32_t*>(desc.dims.d), desc.b));
380 adap.Append(1).Append(1);
381
382 // Combine free dims and contract dims.
383 int idx_f = desc.layout == EinsumLayout::BFC ? desc.b : desc.b + 1;
384 int idx_c = desc.layout == EinsumLayout::BFC ? desc.b + 1 : desc.b;
385
386 // Find the volume of the free dimensions.
387 int64_t vol_f =
388 DimsAdapter(
389 absl::MakeSpan(
390 static_cast<const int32_t*>(desc.dims.d) + desc.offset_f, desc.f))
391 .Volume();
392
393 // Find the volume of the contracted dimensions.
394 int64_t vol_c =
395 DimsAdapter(
396 absl::MakeSpan(
397 static_cast<const int32_t*>(desc.dims.d) + desc.offset_c, desc.c))
398 .Volume();
399
400 adap.dim(idx_f) = vol_f;
401 adap.dim(idx_c) = vol_c;
402 *new_dims = adap.AsTrtDims();
403 return Status::OK();
404 }
405
ConditionEinsumWeights(TRTNetworkBuilder * builder,const TRT_TensorOrWeights & operand,const EinsumDescriptor & desc,const bool need_transpose)406 StatusOr<TRT_TensorOrWeights> ConditionEinsumWeights(
407 TRTNetworkBuilder* builder, const TRT_TensorOrWeights& operand,
408 const EinsumDescriptor& desc, const bool need_transpose) {
409 TRT_ENSURE(operand.is_weights());
410 if (!need_transpose) {
411 // If we don't need to transpose, then the operand remains as a weights
412 // constant. In this case we also don't need a reshape.
413 TRT_ShapedWeights weights(operand.weights());
414 nvinfer1::Dims new_dims;
415 TF_RETURN_IF_ERROR(GetEinsumNewStaticShape(desc, &new_dims));
416 TF_RETURN_IF_ERROR(weights.SetShape(new_dims));
417 return TRT_TensorOrWeights(weights);
418 }
419
420 // Let TensorRT handle constant folding where possible.
421 StatusOr<nvinfer1::IConstantLayer*> tensor = builder->WeightsToConstant(
422 operand.weights().GetTrtWeights(), operand.GetTrtDims());
423 TRT_ENSURE_PTR_OK(tensor);
424 return TRT_TensorOrWeights((*tensor)->getOutput(0));
425 }
426
427 // Builds a TRT shuffle operation for the given operand. Replaces operand with a
428 // pointer to the shuffle output.
ConditionEinsumTensor(TRTNetworkBuilder * builder,std::unique_ptr<TRT_TensorOrWeights> * operand,const EinsumDescriptor & desc,const bool need_transpose,const bool need_reshape)429 Status ConditionEinsumTensor(TRTNetworkBuilder* builder,
430 std::unique_ptr<TRT_TensorOrWeights>* operand,
431 const EinsumDescriptor& desc,
432 const bool need_transpose,
433 const bool need_reshape) {
434 StatusOr<ShuffleBuilder> shuffle =
435 ShuffleBuilder::Create(builder, (*operand)->tensor()->trt_tensor());
436 TRT_ENSURE_OK(shuffle);
437
438 // Set new shape.
439 if (need_reshape) {
440 if (desc.HasStaticShape()) {
441 nvinfer1::Dims new_dims;
442 TF_RETURN_IF_ERROR(GetEinsumNewStaticShape(desc, &new_dims));
443 shuffle->SetReshape(new_dims);
444 } else {
445 ITensorProxyPtr new_shape;
446 TF_RETURN_IF_ERROR(GetEinsumNewDynamicShape(&*builder, desc, &new_shape));
447 shuffle->SetReshape(new_shape->trt_tensor());
448 }
449 }
450
451 if (need_transpose) {
452 shuffle->SetFirstTranspose(desc.GetPermutation());
453 }
454
455 StatusOr<nvinfer1::ITensor*> shuffle_out = shuffle->Output();
456 TRT_ENSURE_PTR_OK(shuffle_out);
457 *operand = std::make_unique<TRT_TensorOrWeights>(*shuffle_out);
458 return Status::OK();
459 }
460
461 // Handles einsum operand conditioning for both constant and non-constant
462 // inputs. This is supported using the ShuffleEinsumWeights and
463 // ShuffleEinsumTensor routines.
ConditionEinsumOperand(TRTNetworkBuilder * builder,std::unique_ptr<TRT_TensorOrWeights> * operand,const EinsumDescriptor & desc)464 Status ConditionEinsumOperand(TRTNetworkBuilder* builder,
465 std::unique_ptr<TRT_TensorOrWeights>* operand,
466 const EinsumDescriptor& desc) {
467 bool need_reshape = (desc.f != 1 || desc.c != 1);
468 bool need_transpose = !desc.permute.empty();
469
470 VLOG(2) << "Condition operand. Need reshape: " << need_reshape
471 << ". Need transpose: " << need_transpose;
472
473 if ((*operand)->is_weights()) {
474 StatusOr<TRT_TensorOrWeights> result =
475 ConditionEinsumWeights(builder, **operand, desc, need_transpose);
476 TRT_ENSURE_OK(result);
477 *operand = std::make_unique<TRT_TensorOrWeights>(std::move(result).value());
478 }
479
480 // If we didn't convert the operand to a tensor, we can return here.
481 if ((*operand)->is_weights()) {
482 return Status::OK();
483 }
484
485 TF_RETURN_IF_ERROR(ConditionEinsumTensor(builder, operand, desc,
486 need_transpose, need_reshape));
487
488 return Status::OK();
489 }
490
491 // Combines output dims/labels by copying batch and free dims/labels from input
492 // A, and concatenating free values from input B.
493 template <typename InputIterator, typename OutputIterator>
AssembleOutput(InputIterator begin_a,InputIterator begin_b,const EinsumDescriptor & desc_a,const EinsumDescriptor & desc_b,OutputIterator out)494 void AssembleOutput(InputIterator begin_a, InputIterator begin_b,
495 const EinsumDescriptor& desc_a,
496 const EinsumDescriptor& desc_b, OutputIterator out) {
497 std::copy(begin_a, begin_a + desc_a.b, out);
498 begin_a += desc_a.offset_f;
499 std::copy(begin_a, begin_a + desc_a.f, out + desc_a.b);
500 begin_b += desc_b.offset_f;
501 std::copy(begin_b, begin_b + desc_b.f, out + desc_a.b + desc_a.f);
502 }
503
// Restores free dimensions and sets final index order. Consider C = A * B,
// batched MatMul op, where A.shape = [B, x, k] and B.shape = [B, k, y]. Then
// C.shape = [B, x, y]. Here B can denote multiple batch indices while x, y, k
// are single indices. The original inputs to Einsum can have multiple free
// indices. These were combined into a singe free dimension x and y, for example
// x = f_a1 * f_a2 * f_a3, y = f_b1 * f_b2. This routine creates a shuffle layer
// to expand x into and y the original free dims, e.g. C is reshaped to
// [B, f_a1, f_a2, f_a3, f_b1, f_b2]. Finally, a permutation is applied to
// transform the shape to the shape of the original Einsum output.
Status ShuffleEinsumOutput(OpConverterParams* params, EinsumDescriptor desc_a,
                           EinsumDescriptor desc_b,
                           const std::vector<int>& permutation,
                           ITensorProxyPtr* output) {
  // Fast path: no transpose requested and each operand contributed a single
  // free dim, so the matmul output already has the right shape and order.
  if (permutation.empty() && (desc_a.f == 1 && desc_b.f == 1)) {
    return Status::OK();
  }

  nvinfer1::IShuffleLayer* layer =
      params->converter->network()->addShuffle(*(*output)->trt_tensor());
  TFTRT_RETURN_ERROR_IF_NULLPTR(layer, params->node_def.name());
  params->converter->SetLayerName(layer, params->node_def, "shuffle",
                                  /*sub_op_instance=*/2);

  // Output rank after expanding the combined free dims:
  // batch dims + A's free dims + B's free dims.
  int output_rank = desc_a.b + desc_a.f + desc_b.f;
  if (desc_a.f != 1 || desc_b.f != 1) {
    if (desc_a.HasStaticShape() && desc_b.HasStaticShape()) {
      // All relevant dims are known at build time: set a static reshape.
      nvinfer1::Dims dims_out = {output_rank, {}};
      AssembleOutput(desc_a.dims.d, desc_b.dims.d, desc_a, desc_b, dims_out.d);
      layer->setReshapeDimensions(dims_out);
    } else {
      // Dynamic case: assemble per-dim size tensors, concatenate them into a
      // 1D shape tensor, and feed it as the shuffle's second input.
      std::vector<ITensorProxyPtr> size_tensors(output_rank);
      AssembleOutput(desc_a.size_tensors.begin(), desc_b.size_tensors.begin(),
                     desc_a, desc_b, size_tensors.begin());
      ITensorProxyPtr new_shape;
      auto builder = TRTNetworkBuilder::Create(params->converter->network(),
                                               params->weight_store);
      TRT_ENSURE_OK(builder);
      std::vector<nvinfer1::ITensor*> size_itensors;
      absl::c_transform(size_tensors, std::back_inserter(size_itensors),
                        [](auto x) { return x->trt_tensor(); });
      StatusOr<nvinfer1::IConcatenationLayer*> concat =
          builder->Concat(size_itensors, /*axis=*/0);
      TRT_ENSURE_PTR_OK(concat);
      new_shape = (*concat)->getOutput(0);
      layer->setInput(1, *new_shape->trt_tensor());
    }
  }

  // Apply the final permutation mapping matmul order to equation order.
  if (!permutation.empty()) {
    nvinfer1::Permutation p;
    std::copy(permutation.begin(), permutation.end(), p.order);
    layer->setSecondTranspose(p);
  }
  *output = layer->getOutput(0);
  return Status::OK();
}
560
561 // Updates "final_transpose" according to the given descriptors and output
562 // labels.
GetOutputTranspose(const EinsumDescriptor & descriptor_a,const EinsumDescriptor & descriptor_b,Labels output_labels)563 StatusOr<std::vector<int>> GetOutputTranspose(
564 const EinsumDescriptor& descriptor_a, const EinsumDescriptor& descriptor_b,
565 Labels output_labels) {
566 // Get final transpose.
567 std::vector<int> final_transpose;
568 final_transpose.reserve(descriptor_a.b + descriptor_a.f + descriptor_b.f);
569 Labels matmul_output_labels(descriptor_a.b + descriptor_a.f + descriptor_b.f);
570 AssembleOutput(descriptor_a.permuted_labels.begin(),
571 descriptor_b.permuted_labels.begin(), descriptor_a,
572 descriptor_b, matmul_output_labels.begin());
573 TF_RETURN_IF_ERROR(
574 FindIndicesoOfAllValuesInSrc(/*values=*/
575 absl::MakeConstSpan(output_labels.begin(),
576 output_labels.end()),
577 /*src=*/
578 absl::MakeConstSpan(
579 matmul_output_labels.begin(),
580 matmul_output_labels.end()),
581 /*indices=*/&final_transpose));
582 // Clear identity transpose.
583 bool identity_transpose = true;
584 for (int i = 0; i < final_transpose.size() && identity_transpose; i++) {
585 identity_transpose &= final_transpose.at(i) == i;
586 }
587 if (identity_transpose) {
588 final_transpose.clear();
589 }
590 return final_transpose;
591 }
592
593 // Prepares EinsumDescriptors after parsing the equation and determines the
594 // final transpose.
ParseEquation(const std::string & equation,std::unique_ptr<TRT_TensorOrWeights> * input_a,std::unique_ptr<TRT_TensorOrWeights> * input_b,std::unique_ptr<EinsumDescriptor> * descriptor_a,std::unique_ptr<EinsumDescriptor> * descriptor_b,std::vector<int> * final_transpose)595 Status ParseEquation(const std::string& equation,
596 std::unique_ptr<TRT_TensorOrWeights>* input_a,
597 std::unique_ptr<TRT_TensorOrWeights>* input_b,
598 std::unique_ptr<EinsumDescriptor>* descriptor_a,
599 std::unique_ptr<EinsumDescriptor>* descriptor_b,
600 std::vector<int>* final_transpose) {
601 VLOG(2) << "Einsum equation " << equation;
602 OperandLabels input_labels;
603 Labels output_labels;
604 std::vector<EinsumDimensionType> label_types;
605 OperandLabelCounts input_label_counts;
606 LabelCounts output_label_counts;
607 absl::InlinedVector<bool, 2> input_has_ellipsis;
608 bool output_has_ellipsis;
609 TF_RETURN_IF_ERROR(
610 ParseEinsumEquation(equation, &input_labels, &output_labels, &label_types,
611 &input_label_counts, &output_label_counts,
612 &input_has_ellipsis, &output_has_ellipsis));
613
614 if (input_has_ellipsis[0] || input_has_ellipsis[1] || output_has_ellipsis) {
615 // TODO(tfeher): Handle ellipsis like EinsumHelper::ProcessDimensions.
616 // Note: ProcessDimensions would introduce kBroadcasting labels, which we
617 // need to replace with kBatch before we call InitDescriptor.
618 VLOG(2) << "Ellipsis not yet supported";
619 return errors::Unimplemented("No conversion for einsum equation.");
620 }
621
622 if (absl::c_any_of(label_types, [](auto l) {
623 return l == EinsumDimensionType::kReduce ||
624 l == EinsumDimensionType::kBroadcasting;
625 })) {
626 VLOG(2) << "Einsum reductions not implemented";
627 return errors::Unimplemented("No conversion for einsum equation.");
628 }
629
630 auto no_duplicated_labels = [](const LabelCounts& label_counts) {
631 return absl::c_any_of(label_counts, [](int i) { return i > 1; });
632 };
633 if (no_duplicated_labels(input_label_counts[0]) ||
634 no_duplicated_labels(input_label_counts[1]) ||
635 no_duplicated_labels(output_label_counts)) {
636 VLOG(2) << "Einsum invalid label count";
637 return errors::Unimplemented("No conversion for einsum equation.");
638 }
639
640 if ((*input_a)->is_weights() && (*input_b)->is_tensor()) {
641 // We prefer to use FC layer, needs A as tensor and B as weight.
642 std::swap(*input_a, *input_b);
643 std::swap(input_labels[0], input_labels[1]);
644 std::swap(input_label_counts[0], input_label_counts[1]);
645 }
646
647 auto desc = EinsumDescriptor::Create(**input_a, input_labels[0], label_types,
648 EinsumLayout::BFC);
649 TF_RETURN_IF_ERROR(desc.status());
650 *descriptor_a = std::move(desc).value();
651
652 desc = EinsumDescriptor::Create(**input_b, input_labels[1], label_types,
653 EinsumLayout::BCF, *descriptor_a);
654 TF_RETURN_IF_ERROR(desc.status());
655 *descriptor_b = std::move(desc).value();
656
657 auto out_transpose =
658 GetOutputTranspose(**descriptor_a, **descriptor_b, output_labels);
659
660 TRT_ENSURE_OK(out_transpose)
661 *final_transpose = std::move(out_transpose).value();
662 return Status::OK();
663 }
664
// Converter lowering a two-operand tf.Einsum (no ellipsis, no reductions,
// no repeated labels) to TensorRT reshape + transpose + batched-matmul
// layers. Requires dynamic shape (explicit batch) mode.
class ConvertEinsum : public OpConverterBase<ConvertEinsum> {
 public:
  explicit ConvertEinsum(OpConverterParams* params)
      : OpConverterBase<ConvertEinsum>(params) {}

  // NOTE(review): the array is sized 3 but only two types are listed; the
  // third element is value-initialized. Presumably harmless — confirm
  // against the other converters' AllowedDataTypes().
  static constexpr std::array<DataType, 3> AllowedDataTypes() {
    return {DataType::DT_FLOAT, DataType::DT_HALF};
  }

  static constexpr std::array<InputArgSpec, 2> InputSpec() {
    return {InputArgSpec::Create("input_a", TrtInputArg::kBoth),
            InputArgSpec::Create("input_b", TrtInputArg::kBoth)};
  }

  // Parses the "equation" attribute and precomputes the operand descriptors
  // and final transpose; rejects implicit batch mode and unsupported
  // equations before any network mutation happens.
  Status Validate() {
    const auto& inputs = params_->inputs;
    if (params_->use_implicit_batch) {
      return errors::Unimplemented(
          "Einsum converter requires dynamic shape mode");
    }

    input_a = std::make_unique<TRT_TensorOrWeights>(inputs.at(0));
    input_b = std::make_unique<TRT_TensorOrWeights>(inputs.at(1));

    StatusOr<std::string> eq = GetAttrValue<std::string>("equation");
    TRT_ENSURE_OK(eq);
    TF_RETURN_IF_ERROR(ParseEquation(*eq, &input_a, &input_b, &descriptor_a,
                                     &descriptor_b, &final_transpose));

    return Status::OK();
  }

  // Emits the TensorRT layers: conditions both operands, performs the
  // batched matmul, then restores the free dims and output label order.
  Status Convert() {
    auto builder = TRTNetworkBuilder::Create(params_->converter->network(),
                                             params_->weight_store);
    TRT_ENSURE_OK(builder);
    TRT_ENSURE(input_a && input_b);
    TRT_ENSURE(descriptor_a && descriptor_b);

    // Populate the size_tensor vector in the descriptor.
    TF_RETURN_IF_ERROR(descriptor_a->SetDynamicSize(&*builder, *input_a));
    TF_RETURN_IF_ERROR(descriptor_b->SetDynamicSize(&*builder, *input_b));

    // Condition the operands for lowering to matmul.
    TF_RETURN_IF_ERROR(
        ConditionEinsumOperand(&*builder, &input_a, *descriptor_a));
    TF_RETURN_IF_ERROR(
        ConditionEinsumOperand(&*builder, &input_b, *descriptor_b));

    // Build the matmul implementation. The transpose flags tell the matmul
    // which operand arrives with its contract dim leading (BCF for A,
    // BFC for B means "transposed" relative to the matmul's expectation).
    StatusOr<ITensorProxyPtr> result = ConvertMatMulImpl(
        params_, *input_a, *input_b, descriptor_a->layout == EinsumLayout::BCF,
        descriptor_b->layout == EinsumLayout::BFC);
    TF_RETURN_IF_ERROR(result.status());
    ITensorProxyPtr output = result.ValueOrDie();

    // Reshape and permute the output.
    TF_RETURN_IF_ERROR(ShuffleEinsumOutput(
        params_, *descriptor_a, *descriptor_b, final_transpose, &output));
    this->AddOutput(output);
    return Status::OK();
  }

 private:
  // Operands (possibly swapped by ParseEquation, then replaced by their
  // conditioned versions during Convert).
  std::unique_ptr<TRT_TensorOrWeights> input_a{nullptr};
  std::unique_ptr<TRT_TensorOrWeights> input_b{nullptr};
  // Permutation applied to the matmul output; empty means identity.
  std::vector<int> final_transpose;
  std::unique_ptr<EinsumDescriptor> descriptor_a{nullptr};
  std::unique_ptr<EinsumDescriptor> descriptor_b{nullptr};
};
735 #else
736
737 // Helper class to reindex equations to contain only lowercase characters. We
738 // simply define a mapping from the old character set to a new set.
739 // - The input is assumed to be a valid TF equation.
740 // - The input is TRT compatible, therefore it has max 8 dims. (Thus we have
741 // enough lowercase English characters to represent the equation.)
742 // How do we reindex/map equations:
743 // - Only uppercase letters are changed, if possible we just lowercase them.
744 // - If the equation contains both upper and lowercase variant of a letter, say
745 // X and x, then we map X to the first unused lowercase letter.
class ReIndexer {
 public:
  // Seeds the index map with every lowercase label already present in `eq`,
  // so those labels are never handed out as replacements later.
  // Takes the equation by const reference (the ctor only reads characters, so
  // copying the string was unnecessary), and casts to unsigned char before
  // calling <cctype> functions, whose behavior is undefined for negative
  // char values.
  explicit ReIndexer(const std::string& eq) {
    for (char c : eq) {
      if (std::islower(static_cast<unsigned char>(c))) {
        idx_map_[c] = c;
      }
    }
  }

  // Returns the (possibly newly assigned) lowercase label for uppercase `c`.
  // Any non-uppercase character (lowercase labels, ',', '-', '>') is passed
  // through unchanged. Repeated calls with the same uppercase letter return
  // the same replacement.
  char operator()(char c) {
    if (!std::isupper(static_cast<unsigned char>(c))) return c;
    if (idx_map_.count(c) > 0) return idx_map_[c];
    char new_idx = static_cast<char>(std::tolower(static_cast<unsigned char>(c)));

    // If lower(c) is not used in the equation, use it to replace c.
    if (idx_map_.count(new_idx) == 0) {
      idx_map_[c] = new_idx;
      idx_map_[new_idx] = new_idx;  // mark that new_idx is taken
      return new_idx;
    }

    // Otherwise, find the first available lowercase letter to replace c. A
    // free letter is expected to exist because a TRT-compatible equation has
    // at most 8 dims per operand (see the class comment).
    for (char k = 'a'; k <= 'z'; k++) {
      if (idx_map_.count(k) == 0) {
        new_idx = k;
        idx_map_[c] = new_idx;
        idx_map_[new_idx] = new_idx;  // mark that new_idx is taken
        break;
      }
    }
    return new_idx;
  }

 private:
  // Each key is an index used in the original or in the reindexed equation.
  // The values are the corresponding new lowercase indices.
  std::map<char, char> idx_map_;
};
786
787 class ConvertEinsum : public OpConverterBase<ConvertEinsum> {
788 public:
789 explicit ConvertEinsum(OpConverterParams* params)
790 : OpConverterBase<ConvertEinsum>(params) {}
791
792 static constexpr std::array<DataType, 3> AllowedDataTypes() {
793 return {DataType::DT_FLOAT, DataType::DT_HALF};
794 }
795
796 Status ValidateInputs() {
797 TRT_ENSURE(params_->inputs.size() <= 2);
798 return Status::OK();
799 }
800 static constexpr bool HasFixNumberOfInputs() { return false; }
801
802 static constexpr std::array<InputArgSpec, 2> InputSpec() {
803 return {InputArgSpec::Create("input_a", TrtInputArg::kBoth),
804 InputArgSpec::Create("input_b", TrtInputArg::kBoth)};
805 }
806
807 std::string MakeLowerCase(const std::string& eq) {
808 std::string res = eq;
809 ReIndexer reindexer(eq);
810 std::transform(eq.begin(), eq.end(), res.begin(), reindexer);
811 return res;
812 }
813
814 // Checks if the equation is supported by TRT.
815 Status ValidateEinsumEquation(const std::string& eq) {
816 const auto& inputs = params_->inputs;
817 OperandLabels input_labels;
818 Labels output_labels;
819 std::vector<EinsumDimensionType> label_types;
820 OperandLabelCounts input_label_counts;
821 LabelCounts output_label_counts;
822 absl::InlinedVector<bool, 2> input_has_ellipsis;
823 bool output_has_ellipsis;
824 VLOG(2) << "Parsing equation " << eq;
825 TF_RETURN_IF_ERROR(ParseEinsumEquation(
826 eq, &input_labels, &output_labels, &label_types, &input_label_counts,
827 &output_label_counts, &input_has_ellipsis, &output_has_ellipsis));
828
829 Status unimplemented =
830 errors::Unimplemented("No conversion for einsum equation.");
831 if (input_has_ellipsis[0] || (inputs.size() > 1 && input_has_ellipsis[1]) ||
832 output_has_ellipsis) {
833 VLOG(2) << "Ellipsis not yet supported";
834 return unimplemented;
835 }
836 for (int i = 0; i < input_label_counts.size(); i++) {
837 for (int k = 0; k < input_label_counts[i].size(); k++) {
838 if (input_label_counts[i][k] > 1) {
839 VLOG(2) << "Diagonal operation or reduction not yet supported";
840 return unimplemented;
841 }
842 }
843 }
844 bool has_out_idx =
845 std::reduce(output_label_counts.begin(), output_label_counts.end(),
846 false, std::logical_or<int>());
847 if (!has_out_idx) {
848 VLOG(2) << "Scalar output not allowed in dynamic shape mode";
849 return unimplemented;
850 }
851 // Check for outer product
852 if (input_label_counts.size() == 2 && output_label_counts.size() == 2 &&
853 output_label_counts[0] == 1 && output_label_counts[1] == 1) {
854 VLOG(2) << "Outer product not supported";
855 return unimplemented;
856 }
857 return Status::OK();
858 }
859
860 Status Validate() {
861 VLOG(2) << "Running validation using the new einsum "
862 "converter";
863 if (params_->use_implicit_batch) {
864 return errors::Unimplemented(
865 "Einsum converter requires dynamic shape mode");
866 }
867
868 StatusOr<std::string> eq = GetAttrValue<std::string>("equation");
869 TRT_ENSURE_OK(eq);
870
871 TF_RETURN_IF_ERROR(ValidateEinsumEquation(*eq));
872
873 // While TF has case sensitive equations, TensorRT expects lowercase eq (as
874 // of version 8.4). See
875 // https://docs.nvidia.com/deeplearning/tensorrt/developer-guide/index.html#einsum-layer
876 equation = MakeLowerCase(*eq);
877
878 return Status::OK();
879 }
880
881 Status Convert() {
882 auto builder = TRTNetworkBuilder::Create(params_->converter->network(),
883 params_->weight_store);
884 TRT_ENSURE_OK(builder);
885
886 std::vector<nvinfer1::ITensor*> trt_input;
887 for (const TRT_TensorOrWeights& input_arg : params_->inputs) {
888 ITensorProxyPtr ptr = nullptr;
889 if (input_arg.is_tensor()) {
890 ptr = input_arg.tensor();
891 } else {
892 StatusOr<nvinfer1::IConstantLayer*> const_layer =
893 builder->WeightsToConstant(input_arg.weights().GetTrtWeights(),
894 input_arg.GetTrtDims());
895 TRT_ENSURE_PTR_OK(const_layer);
896 ptr = (*const_layer)->getOutput(0);
897 }
898 trt_input.push_back(ptr->trt_tensor());
899 }
900 nvinfer1::IEinsumLayer* layer = params_->converter->network()->addEinsum(
901 trt_input.data(), trt_input.size(), equation.c_str());
902 TRT_ENSURE(layer);
903
904 ITensorProxyPtr output = layer->getOutput(0);
905 this->AddOutput(output);
906 return Status::OK();
907 }
908
909 private:
910 std::string equation;
911 };
912
913 #endif
914
915 } // namespace
916
917 REGISTER_DEFAULT_TRT_OP_CONVERTER(MakeConverterFunction<ConvertEinsum>(),
918 "Einsum");
919 #endif
920
921 } // namespace convert
922 } // namespace tensorrt
923 } // namespace tensorflow
924
925 #endif // GOOGLE_CUDA && GOOGLE_TENSORRT
926