/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#if GOOGLE_CUDA && GOOGLE_TENSORRT
#include <algorithm>
#include <array>
#include <cctype>
#include <iterator>
#include <limits>
#include <map>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "tensorflow/compiler/tf2tensorrt/common/utils.h"
#include "tensorflow/compiler/tf2tensorrt/convert/convert_nodes.h"
#include "tensorflow/compiler/tf2tensorrt/convert/op_converter.h"
#include "tensorflow/compiler/tf2tensorrt/convert/op_converter_registry.h"
#include "tensorflow/compiler/tf2tensorrt/convert/ops/layer_utils.h"
#include "tensorflow/compiler/tf2tensorrt/convert/utils.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/util/einsum_op_util.h"
#include "third_party/tensorrt/NvInfer.h"

#if IS_TRT_VERSION_GE(7, 1, 3, 0)

namespace tensorflow {
namespace tensorrt {
namespace convert {

namespace {

#if !IS_TRT_VERSION_GE(8, 2, 0, 0)

// Finds the index of each element of `values` within `src` and appends these
// indices to `indices`. This is used to construct the permutation that maps
// an operand's input labels (`src`) to the desired permuted labels
// (`values`).
template <typename T>
Status FindIndicesoOfAllValuesInSrc(absl::Span<const T> values,
                                    absl::Span<const T> src,
                                    std::vector<int>* indices) {
  if (src.size() < values.size()) {
    return errors::Internal(
        "Span 'src' cannot contain all elements of 'values'");
  }
  for (auto i = 0; i < values.size(); i++) {
    auto iter = absl::c_find(src, values[i]);
    if (iter == src.end()) {
      return errors::Internal("Label ", values[i], " not found");
    }
    int idx = std::distance(src.begin(), iter);
    indices->push_back(idx);
  }
  return Status::OK();
}
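
// Illustrative example (added for documentation; not in the original source):
// for values = {'c', 'a'} and src = {'a', 'b', 'c'}, the loop appends 2 (the
// position of 'c' in src) and then 0 (the position of 'a'), so *indices ends
// up as {2, 0}.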

// Layout of the einsum dimensions: Batch, Free and Contraction indices.
// Example: adbc,adce -> adbe. The first tensor has layout BFC, the second BCF.
enum class EinsumLayout { BFC, BCF, MIX };

using DimType = EinsumDimensionType;
constexpr auto kBatch = DimType::kBatch;
constexpr auto kFree = DimType::kFree;
constexpr auto kContract = DimType::kContract;

// Describes an operand: input shape, number of batch, free and contract
// dimensions, and the permutation that is needed to bring it to a matmul
// compatible form.
class EinsumDescriptor {
 private:
  // Checks whether input_labels[offset:offset+m] matches the corresponding
  // labels of `other`.
  static bool OrderMatches(const Labels& input_labels, int offset, int m,
                           EinsumDimensionType dim_type,
                           const std::unique_ptr<EinsumDescriptor>& other) {
    if (other == nullptr) {
      return true;
    }
    int offset_other = 0;
    if (dim_type == kFree) {
      offset_other = other->offset_f;
    } else if (dim_type == kContract) {
      offset_other = other->offset_c;
    }
    return std::equal(input_labels.begin() + offset,
                      input_labels.begin() + offset + m,
                      other->permuted_labels.begin() + offset_other);
  }

  using label_t_iterator = std::vector<EinsumDimensionType>::const_iterator;

  static int32_t CountLabels(label_t_iterator begin, label_t_iterator end,
                             EinsumDimensionType val) {
    return static_cast<int32_t>(std::count_if(
        begin, end, [val](EinsumDimensionType t) { return t == val; }));
  }

  // Appends to the "permute" vector the indices at which `types` matches
  // `val`.
  void AppendMatchingIndicesToPermute(
      const std::vector<EinsumDimensionType>& types, EinsumDimensionType val) {
    for (int i = 0; i < types.size(); i++) {
      if (types[i] == val) {
        permute.push_back(i);
      }
    }
  }

  Status DetermineLayout(const Labels& input_labels,
                         const std::vector<EinsumDimensionType>& types,
                         const std::unique_ptr<EinsumDescriptor>& other) {
    // Check whether the current layout is already BFC or BCF. In that case we
    // can avoid a transpose.
    layout = EinsumLayout::MIX;
    if (CountLabels(types.begin(), types.begin() + b, kBatch) == b &&
        OrderMatches(input_labels, 0, b, kBatch, other)) {
      // Batch dims are the leading dims. They have the same order as other.
      if (CountLabels(types.begin() + b, types.begin() + b + f, kFree) == f) {
        // All the free dims are placed consecutively after the batch dims.
        // Their order is arbitrary. The final transpose will ensure that the
        // output has the correct order. We still have to check that the
        // contract indices have the correct order.
        if (OrderMatches(input_labels, b + f, c, kContract, other)) {
          layout = EinsumLayout::BFC;
        }
      } else if (CountLabels(types.begin() + b, types.begin() + b + c,
                             kContract) == c) {
        // All the contract dims are placed consecutively after the batch
        // dims. Check whether the contract dims have the same order as the
        // contract dims in other.
        if (OrderMatches(input_labels, b, c, kContract, other)) {
          layout = EinsumLayout::BCF;
        }
      }
    }
    return Status::OK();
  }
142 
CalculateMixedLayoutPermutation(const EinsumLayout preferred_layout,const Labels & input_labels,const std::vector<EinsumDimensionType> & types,const std::unique_ptr<EinsumDescriptor> & other)143   Status CalculateMixedLayoutPermutation(
144       const EinsumLayout preferred_layout, const Labels& input_labels,
145       const std::vector<EinsumDimensionType>& types,
146       const std::unique_ptr<EinsumDescriptor>& other) {
147     // Input label types are mixed. Calculate a permutation that maps them
148     // to the preferred layout (BCF or BFC).
149     layout = preferred_layout;
150     if (other == nullptr) {
151       AppendMatchingIndicesToPermute(types, kBatch);
152     } else {
153       TF_RETURN_IF_ERROR(
154           FindIndicesoOfAllValuesInSrc(/*values=*/
155                                        absl::MakeConstSpan(
156                                            other->permuted_labels.begin(),
157                                            other->b),
158                                        /*src=*/
159                                        absl::MakeConstSpan(input_labels.begin(),
160                                                            input_labels.size()),
161                                        /*indices=*/&permute));
162     }
163     if (layout == EinsumLayout::BFC) {
164       AppendMatchingIndicesToPermute(types, kFree);
165       if (other == nullptr) {
166         AppendMatchingIndicesToPermute(types, kContract);
167       } else {
168         TF_RETURN_IF_ERROR(FindIndicesoOfAllValuesInSrc(
169             /*values=*/absl::MakeConstSpan(
170                 other->permuted_labels.begin() + other->offset_c, other->c),
171             /*src=*/
172             absl::MakeConstSpan(input_labels.begin(), input_labels.size()),
173             /*indices=*/&permute));
174       }
175       return Status::OK();
176     }
177     if (other == nullptr) {
178       AppendMatchingIndicesToPermute(types, kContract);
179     } else {
180       TF_RETURN_IF_ERROR(FindIndicesoOfAllValuesInSrc(
181           /*values=*/absl::MakeConstSpan(
182               other->permuted_labels.begin() + other->offset_c, other->c),
183           /*src=*/absl::MakeConstSpan(input_labels.begin(), input_labels.end()),
184           /*indices=*/&permute));
185     }
186     AppendMatchingIndicesToPermute(types, kFree);
187     return Status::OK();
188   }
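
  // Illustrative example (added for documentation; not in the original
  // source): suppose the operand labels are [f0, b0, c0] with types
  // [kFree, kBatch, kContract], the preferred layout is BFC, and other is
  // null. The permutation is assembled as batch indices {1}, then free
  // indices {0}, then contract indices {2}, i.e. permute = {1, 0, 2}, which
  // transposes [f0, b0, c0] into [b0, f0, c0].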
189 
Initialize(const TRT_TensorOrWeights & operand,Labels input_labels,std::vector<EinsumDimensionType> & label_types,EinsumLayout preferred_layout,const std::unique_ptr<EinsumDescriptor> & other=nullptr)190   Status Initialize(const TRT_TensorOrWeights& operand, Labels input_labels,
191                     std::vector<EinsumDimensionType>& label_types,
192                     EinsumLayout preferred_layout,
193                     const std::unique_ptr<EinsumDescriptor>& other = nullptr) {
194     if (preferred_layout == EinsumLayout::MIX) {
195       return errors::Internal("Preferred einsum layout cannot be MIX");
196     }
197     // Map label indices to label types.
198     std::vector<EinsumDimensionType> types;  // Input label types.
199     std::transform(input_labels.begin(), input_labels.end(),
200                    std::back_inserter(types),
201                    [&label_types](int i) { return label_types.at(i); });
202 
203     b = CountLabels(types.begin(), types.end(), kBatch);
204     f = CountLabels(types.begin(), types.end(), kFree);
205     c = CountLabels(types.begin(), types.end(), kContract);
206 
207     if (c == 0 || f == 0) {
208       VLOG(2) << "Einsum equation needs to have at least one free and one "
209                  "contract dimension";
210       return errors::Unimplemented("No conversion for einsum equation.");
211     }
212 
213     TF_RETURN_IF_ERROR(DetermineLayout(input_labels, types, other));
214     if (layout == EinsumLayout::MIX) {
215       TF_RETURN_IF_ERROR(CalculateMixedLayoutPermutation(
216           preferred_layout, input_labels, types, other));
217     }
218 
219     if (layout == EinsumLayout::BFC) {
220       offset_f = b;
221       offset_c = f + b;
222     } else {
223       offset_f = b + c;
224       offset_c = b;
225     }
226 
227     dims = operand.GetTrtDims();
228     for (int i = 0; i < b; i++) {
229       // Set unknown batch dims to zero. These dims will be used in reshape op,
230       // where zero is a special value for retaining the original dim size.
231       if (dims.d[i] == -1) {
232         dims.d[i] = 0;
233       }
234     }
235     permuted_labels = input_labels;
236     if (!permute.empty()) {
237       // Apply the permutation on the dimension array.
238       nvinfer1::Dims orig_dims = dims;
239       for (int i = 0; i < permute.size(); i++) {
240         dims.d[i] = orig_dims.d[permute[i]];
241         permuted_labels[i] = input_labels[permute[i]];
242       }
243     }
244     size_tensors.resize(dims.nbDims, nullptr);
245     return Status::OK();
246   }

 public:
  EinsumDescriptor() : b(0), f(0), c(0) {}

  // Deduces the number of batch, free, contract dimensions from the input
  // labels, decides what layout to use, and determines permutation indices for
  // that layout.
  static StatusOr<std::unique_ptr<EinsumDescriptor>> Create(
      const TRT_TensorOrWeights& operand, Labels input_labels,
      std::vector<EinsumDimensionType>& label_types,
      EinsumLayout preferred_layout,
      const std::unique_ptr<EinsumDescriptor>& other = nullptr) {
    auto desc = std::make_unique<EinsumDescriptor>();
    TF_RETURN_IF_ERROR(desc->Initialize(operand, input_labels, label_types,
                                        preferred_layout, other));
    VLOG(2) << desc->DebugString();
    return desc;
  }

  int NumBatchDims() const { return b; }
  int NumContractDims() const { return c; }
  int NumFreeDims() const { return f; }
  int ContractDimOffset() const { return offset_c; }
  const Labels& PermutedLabels() const { return permuted_labels; }

  std::string DebugString() const {
    return absl::StrCat("Descriptor with ",
                        (layout == EinsumLayout::BFC ? "BFC" : "BCF"),
                        " layout, b=", b, ", f=", f, ", c=", c);
  }

  // Returns whether the free and contract dimensions have static shape.
  bool HasStaticShape() const {
    return !std::any_of(dims.d + b, dims.d + dims.nbDims,
                        [](int k) { return k == -1; });
  }

  nvinfer1::Permutation GetPermutation() const {
    nvinfer1::Permutation p;
    std::copy(permute.begin(), permute.end(), p.order);
    return p;
  }

  std::vector<int> PermuteVector() const { return permute; }

  // Sets the "size_tensors" vector to be filled with scalar constant tensors
  // representing the shape of the operand.
  Status SetDynamicSize(TRTNetworkBuilder* builder,
                        const TRT_TensorOrWeights& operand) {
    TRT_ENSURE(operand.GetTrtDims().nbDims == dims.nbDims);
    if (operand.is_weights()) {
      // Generate constants for each dimension of the constant weight tensor's
      // shape.
      for (int i = 0; i < operand.GetTrtDims().nbDims; i++) {
        StatusOr<nvinfer1::IConstantLayer*> size_tensor =
            builder->Constant<int32_t>(dims.d[i], 1);
        TRT_ENSURE_PTR_OK(size_tensor);
        size_tensors[i] = (*size_tensor)->getOutput(0);
      }
      return Status::OK();
    }

    // If the operand is a dynamic tensor, compute the shape value dynamically.
    StatusOr<nvinfer1::IShapeLayer*> shape_layer =
        builder->Shape(operand.tensor()->trt_tensor());
    TRT_ENSURE_PTR_OK(shape_layer);
    nvinfer1::ITensor* shape = (*shape_layer)->getOutput(0);
    for (int i = 0; i < operand.GetTrtDims().nbDims; i++) {
      int idx = permute.empty() ? i : permute.at(i);
      StatusOr<nvinfer1::ISliceLayer*> slice_layer =
          builder->Slice(shape, {1, {idx}}, {1, {1}}, {1, {1}});
      TRT_ENSURE_PTR_OK(slice_layer);
      size_tensors[i] = (*slice_layer)->getOutput(0);
    }
    return Status::OK();
  }

  EinsumLayout layout;
  int b;  // number of batch dims
  int f;  // number of free dims
  int c;  // number of contraction dims
  int offset_f;
  int offset_c;
  nvinfer1::Dims dims;
  std::vector<int> permute;
  std::vector<ITensorProxyPtr> size_tensors;
  Labels permuted_labels;
};
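
// Illustrative example (added for documentation; not in the original source):
// for the equation "abc,acd->abd", label 'a' appears in both inputs and the
// output (batch), 'b' is free in the first operand, 'd' is free in the
// second, and 'c' is contracted. Operand A (labels a,b,c) therefore has
// b=f=c=1 and is already in BFC layout, while operand B (labels a,c,d) is
// already in BCF layout, so neither operand needs a transpose.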

// Builds the shape tensor for reshaping the operand: the free dimensions are
// combined into a single dim, and the contract dimensions are combined into
// another single dim.
Status GetEinsumNewDynamicShape(TRTNetworkBuilder* builder,
                                const EinsumDescriptor& desc,
                                ITensorProxyPtr* new_shape) {
  std::vector<nvinfer1::ITensor*> size;
  size.reserve(desc.b + 2);
  absl::c_transform(absl::MakeSpan(desc.size_tensors).subspan(0, desc.b + 2),
                    std::back_inserter(size),
                    [](const ITensorProxyPtr x) { return x->trt_tensor(); });

  int idx_f = desc.layout == EinsumLayout::BFC ? desc.b : desc.b + 1;
  int idx_c = desc.layout == EinsumLayout::BFC ? desc.b + 1 : desc.b;

  std::vector<nvinfer1::ITensor*> size_tensors;
  size_tensors.reserve(desc.size_tensors.size());
  absl::c_transform(desc.size_tensors, std::back_inserter(size_tensors),
                    [](const ITensorProxyPtr x) -> nvinfer1::ITensor* {
                      return x->trt_tensor();
                    });

  StatusOr<nvinfer1::ILayer*> shape_vol = builder->CumulativeProd(
      absl::MakeSpan(size_tensors).subspan(desc.offset_f, desc.f));
  TRT_ENSURE_PTR_OK(shape_vol);
  size[idx_f] = (*shape_vol)->getOutput(0);

  shape_vol = builder->CumulativeProd(
      absl::MakeSpan(size_tensors).subspan(desc.offset_c, desc.c));
  TRT_ENSURE_PTR_OK(shape_vol);
  size[idx_c] = (*shape_vol)->getOutput(0);
  StatusOr<nvinfer1::IConcatenationLayer*> layer =
      builder->Concat(size, /*axis=*/0);
  TRT_ENSURE_PTR_OK(layer);
  *new_shape = (*layer)->getOutput(0);
  return Status::OK();
}

// Computes the new static shape for reshaping the operand: the free
// dimensions are combined into a single dim, and the contract dimensions are
// combined into another single dim.
Status GetEinsumNewStaticShape(const EinsumDescriptor& desc,
                               nvinfer1::Dims* new_dims) {
  // Copy the batch dims and append two additional dimensions.
  DimsAdapter adap(
      absl::MakeSpan(static_cast<const int32_t*>(desc.dims.d), desc.b));
  adap.Append(1).Append(1);

  // Combine free dims and contract dims.
  int idx_f = desc.layout == EinsumLayout::BFC ? desc.b : desc.b + 1;
  int idx_c = desc.layout == EinsumLayout::BFC ? desc.b + 1 : desc.b;

  // Find the volume of the free dimensions.
  int64_t vol_f =
      DimsAdapter(
          absl::MakeSpan(
              static_cast<const int32_t*>(desc.dims.d) + desc.offset_f, desc.f))
          .Volume();

  // Find the volume of the contracted dimensions.
  int64_t vol_c =
      DimsAdapter(
          absl::MakeSpan(
              static_cast<const int32_t*>(desc.dims.d) + desc.offset_c, desc.c))
          .Volume();

  adap.dim(idx_f) = vol_f;
  adap.dim(idx_c) = vol_c;
  *new_dims = adap.AsTrtDims();
  return Status::OK();
}
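
// Illustrative example (added for documentation; not in the original source):
// for a BFC operand with dims [2, 3, 4, 5] where b=1, f=2 and c=1, the free
// dims 3 and 4 are merged into 3 * 4 = 12, giving the new static shape
// [2, 12, 5].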

StatusOr<TRT_TensorOrWeights> ConditionEinsumWeights(
    TRTNetworkBuilder* builder, const TRT_TensorOrWeights& operand,
    const EinsumDescriptor& desc, const bool need_transpose) {
  TRT_ENSURE(operand.is_weights());
  if (!need_transpose) {
    // If we don't need to transpose, then the operand remains a weights
    // constant. In this case we also don't need a reshape.
    TRT_ShapedWeights weights(operand.weights());
    nvinfer1::Dims new_dims;
    TF_RETURN_IF_ERROR(GetEinsumNewStaticShape(desc, &new_dims));
    TF_RETURN_IF_ERROR(weights.SetShape(new_dims));
    return TRT_TensorOrWeights(weights);
  }

  // Let TensorRT handle constant folding where possible.
  StatusOr<nvinfer1::IConstantLayer*> tensor = builder->WeightsToConstant(
      operand.weights().GetTrtWeights(), operand.GetTrtDims());
  TRT_ENSURE_PTR_OK(tensor);
  return TRT_TensorOrWeights((*tensor)->getOutput(0));
}

// Builds a TRT shuffle operation for the given operand. Replaces the operand
// with a pointer to the shuffle output.
Status ConditionEinsumTensor(TRTNetworkBuilder* builder,
                             std::unique_ptr<TRT_TensorOrWeights>* operand,
                             const EinsumDescriptor& desc,
                             const bool need_transpose,
                             const bool need_reshape) {
  StatusOr<ShuffleBuilder> shuffle =
      ShuffleBuilder::Create(builder, (*operand)->tensor()->trt_tensor());
  TRT_ENSURE_OK(shuffle);

  // Set the new shape.
  if (need_reshape) {
    if (desc.HasStaticShape()) {
      nvinfer1::Dims new_dims;
      TF_RETURN_IF_ERROR(GetEinsumNewStaticShape(desc, &new_dims));
      shuffle->SetReshape(new_dims);
    } else {
      ITensorProxyPtr new_shape;
      TF_RETURN_IF_ERROR(GetEinsumNewDynamicShape(&*builder, desc, &new_shape));
      shuffle->SetReshape(new_shape->trt_tensor());
    }
  }

  if (need_transpose) {
    shuffle->SetFirstTranspose(desc.GetPermutation());
  }

  StatusOr<nvinfer1::ITensor*> shuffle_out = shuffle->Output();
  TRT_ENSURE_PTR_OK(shuffle_out);
  *operand = std::make_unique<TRT_TensorOrWeights>(*shuffle_out);
  return Status::OK();
}

// Handles einsum operand conditioning for both constant and non-constant
// inputs. This is implemented by the ConditionEinsumWeights and
// ConditionEinsumTensor routines above.
Status ConditionEinsumOperand(TRTNetworkBuilder* builder,
                              std::unique_ptr<TRT_TensorOrWeights>* operand,
                              const EinsumDescriptor& desc) {
  bool need_reshape = (desc.f != 1 || desc.c != 1);
  bool need_transpose = !desc.permute.empty();

  VLOG(2) << "Condition operand. Need reshape: " << need_reshape
          << ". Need transpose: " << need_transpose;

  if ((*operand)->is_weights()) {
    StatusOr<TRT_TensorOrWeights> result =
        ConditionEinsumWeights(builder, **operand, desc, need_transpose);
    TRT_ENSURE_OK(result);
    *operand = std::make_unique<TRT_TensorOrWeights>(std::move(result).value());
  }

  // If we didn't convert the operand to a tensor, we can return here.
  if ((*operand)->is_weights()) {
    return Status::OK();
  }

  TF_RETURN_IF_ERROR(ConditionEinsumTensor(builder, operand, desc,
                                           need_transpose, need_reshape));

  return Status::OK();
}

// Combines output dims/labels by copying the batch and free dims/labels from
// input A, and concatenating the free values from input B.
template <typename InputIterator, typename OutputIterator>
void AssembleOutput(InputIterator begin_a, InputIterator begin_b,
                    const EinsumDescriptor& desc_a,
                    const EinsumDescriptor& desc_b, OutputIterator out) {
  std::copy(begin_a, begin_a + desc_a.b, out);
  begin_a += desc_a.offset_f;
  std::copy(begin_a, begin_a + desc_a.f, out + desc_a.b);
  begin_b += desc_b.offset_f;
  std::copy(begin_b, begin_b + desc_b.f, out + desc_a.b + desc_a.f);
}
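
// Illustrative example (added for documentation; not in the original source):
// with desc_a in BFC layout over dims [B, f_a, c] and desc_b in BCF layout
// over dims [B, c, f_b], AssembleOutput produces [B, f_a, f_b]: the batch and
// free entries of A followed by the free entries of B.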

// Restores free dimensions and sets the final index order. Consider C = A * B,
// a batched MatMul op, where A.shape = [B, x, k] and B.shape = [B, k, y]. Then
// C.shape = [B, x, y]. Here B can denote multiple batch indices while x, y, k
// are single indices. The original inputs to Einsum can have multiple free
// indices. These were combined into the single free dimensions x and y, for
// example x = f_a1 * f_a2 * f_a3, y = f_b1 * f_b2. This routine creates a
// shuffle layer to expand x and y into the original free dims, e.g. C is
// reshaped to [B, f_a1, f_a2, f_a3, f_b1, f_b2]. Finally, a permutation is
// applied to transform the shape to the shape of the original Einsum output.
Status ShuffleEinsumOutput(OpConverterParams* params, EinsumDescriptor desc_a,
                           EinsumDescriptor desc_b,
                           const std::vector<int>& permutation,
                           ITensorProxyPtr* output) {
  if (permutation.empty() && (desc_a.f == 1 && desc_b.f == 1)) {
    return Status::OK();
  }

  nvinfer1::IShuffleLayer* layer =
      params->converter->network()->addShuffle(*(*output)->trt_tensor());
  TFTRT_RETURN_ERROR_IF_NULLPTR(layer, params->node_def.name());
  params->converter->SetLayerName(layer, params->node_def, "shuffle",
                                  /*sub_op_instance=*/2);

  int output_rank = desc_a.b + desc_a.f + desc_b.f;
  if (desc_a.f != 1 || desc_b.f != 1) {
    if (desc_a.HasStaticShape() && desc_b.HasStaticShape()) {
      nvinfer1::Dims dims_out = {output_rank, {}};
      AssembleOutput(desc_a.dims.d, desc_b.dims.d, desc_a, desc_b, dims_out.d);
      layer->setReshapeDimensions(dims_out);
    } else {
      std::vector<ITensorProxyPtr> size_tensors(output_rank);
      AssembleOutput(desc_a.size_tensors.begin(), desc_b.size_tensors.begin(),
                     desc_a, desc_b, size_tensors.begin());
      ITensorProxyPtr new_shape;
      auto builder = TRTNetworkBuilder::Create(params->converter->network(),
                                               params->weight_store);
      TRT_ENSURE_OK(builder);
      std::vector<nvinfer1::ITensor*> size_itensors;
      absl::c_transform(size_tensors, std::back_inserter(size_itensors),
                        [](auto x) { return x->trt_tensor(); });
      StatusOr<nvinfer1::IConcatenationLayer*> concat =
          builder->Concat(size_itensors, /*axis=*/0);
      TRT_ENSURE_PTR_OK(concat);
      new_shape = (*concat)->getOutput(0);
      layer->setInput(1, *new_shape->trt_tensor());
    }
  }

  if (!permutation.empty()) {
    nvinfer1::Permutation p;
    std::copy(permutation.begin(), permutation.end(), p.order);
    layer->setSecondTranspose(p);
  }
  *output = layer->getOutput(0);
  return Status::OK();
}
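
// Illustrative example (added for documentation; not in the original source):
// if A's free dims (2, 3, 4) were combined into x = 24 and B's free dims
// (5, 6) into y = 30, the matmul output [B, 24, 30] is reshaped here to
// [B, 2, 3, 4, 5, 6] before the final transpose is applied.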

// Computes the final transpose according to the given descriptors and output
// labels.
StatusOr<std::vector<int>> GetOutputTranspose(
    const EinsumDescriptor& descriptor_a, const EinsumDescriptor& descriptor_b,
    Labels output_labels) {
  // Get the final transpose.
  std::vector<int> final_transpose;
  final_transpose.reserve(descriptor_a.b + descriptor_a.f + descriptor_b.f);
  Labels matmul_output_labels(descriptor_a.b + descriptor_a.f + descriptor_b.f);
  AssembleOutput(descriptor_a.permuted_labels.begin(),
                 descriptor_b.permuted_labels.begin(), descriptor_a,
                 descriptor_b, matmul_output_labels.begin());
  TF_RETURN_IF_ERROR(
      FindIndicesoOfAllValuesInSrc(/*values=*/
                                   absl::MakeConstSpan(output_labels.begin(),
                                                       output_labels.end()),
                                   /*src=*/
                                   absl::MakeConstSpan(
                                       matmul_output_labels.begin(),
                                       matmul_output_labels.end()),
                                   /*indices=*/&final_transpose));
  // If the computed transpose is the identity, clear it.
  bool identity_transpose = true;
  for (int i = 0; i < final_transpose.size() && identity_transpose; i++) {
    identity_transpose &= final_transpose.at(i) == i;
  }
  if (identity_transpose) {
    final_transpose.clear();
  }
  return final_transpose;
}
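
// Illustrative example (added for documentation; not in the original source):
// if the matmul output labels are (a, b, d) while the equation's output
// labels are (a, d, b), the returned transpose is {0, 2, 1}; if the two
// orders already match, the identity transpose is cleared and no shuffle is
// needed.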

// Prepares the EinsumDescriptors after parsing the equation and determines the
// final transpose.
Status ParseEquation(const std::string& equation,
                     std::unique_ptr<TRT_TensorOrWeights>* input_a,
                     std::unique_ptr<TRT_TensorOrWeights>* input_b,
                     std::unique_ptr<EinsumDescriptor>* descriptor_a,
                     std::unique_ptr<EinsumDescriptor>* descriptor_b,
                     std::vector<int>* final_transpose) {
  VLOG(2) << "Einsum equation " << equation;
  OperandLabels input_labels;
  Labels output_labels;
  std::vector<EinsumDimensionType> label_types;
  OperandLabelCounts input_label_counts;
  LabelCounts output_label_counts;
  absl::InlinedVector<bool, 2> input_has_ellipsis;
  bool output_has_ellipsis;
  TF_RETURN_IF_ERROR(
      ParseEinsumEquation(equation, &input_labels, &output_labels, &label_types,
                          &input_label_counts, &output_label_counts,
                          &input_has_ellipsis, &output_has_ellipsis));

  if (input_has_ellipsis[0] || input_has_ellipsis[1] || output_has_ellipsis) {
    // TODO(tfeher): Handle ellipsis like EinsumHelper::ProcessDimensions.
    // Note: ProcessDimensions would introduce kBroadcasting labels, which we
    // need to replace with kBatch before we call InitDescriptor.
    VLOG(2) << "Ellipsis not yet supported";
    return errors::Unimplemented("No conversion for einsum equation.");
  }

  if (absl::c_any_of(label_types, [](auto l) {
        return l == EinsumDimensionType::kReduce ||
               l == EinsumDimensionType::kBroadcasting;
      })) {
    VLOG(2) << "Einsum reductions not implemented";
    return errors::Unimplemented("No conversion for einsum equation.");
  }

  auto has_duplicated_labels = [](const LabelCounts& label_counts) {
    return absl::c_any_of(label_counts, [](int i) { return i > 1; });
  };
  if (has_duplicated_labels(input_label_counts[0]) ||
      has_duplicated_labels(input_label_counts[1]) ||
      has_duplicated_labels(output_label_counts)) {
    VLOG(2) << "Einsum invalid label count";
    return errors::Unimplemented("No conversion for einsum equation.");
  }

  if ((*input_a)->is_weights() && (*input_b)->is_tensor()) {
    // We prefer to use the FC layer, which needs A as a tensor and B as
    // weights.
    std::swap(*input_a, *input_b);
    std::swap(input_labels[0], input_labels[1]);
    std::swap(input_label_counts[0], input_label_counts[1]);
  }

  auto desc = EinsumDescriptor::Create(**input_a, input_labels[0], label_types,
                                       EinsumLayout::BFC);
  TF_RETURN_IF_ERROR(desc.status());
  *descriptor_a = std::move(desc).value();

  desc = EinsumDescriptor::Create(**input_b, input_labels[1], label_types,
                                  EinsumLayout::BCF, *descriptor_a);
  TF_RETURN_IF_ERROR(desc.status());
  *descriptor_b = std::move(desc).value();

  auto out_transpose =
      GetOutputTranspose(**descriptor_a, **descriptor_b, output_labels);

  TRT_ENSURE_OK(out_transpose);
  *final_transpose = std::move(out_transpose).value();
  return Status::OK();
}

class ConvertEinsum : public OpConverterBase<ConvertEinsum> {
 public:
  explicit ConvertEinsum(OpConverterParams* params)
      : OpConverterBase<ConvertEinsum>(params) {}

  static constexpr std::array<DataType, 2> AllowedDataTypes() {
    return {DataType::DT_FLOAT, DataType::DT_HALF};
  }

  static constexpr std::array<InputArgSpec, 2> InputSpec() {
    return {InputArgSpec::Create("input_a", TrtInputArg::kBoth),
            InputArgSpec::Create("input_b", TrtInputArg::kBoth)};
  }

  Status Validate() {
    const auto& inputs = params_->inputs;
    if (params_->use_implicit_batch) {
      return errors::Unimplemented(
          "Einsum converter requires dynamic shape mode");
    }

    input_a = std::make_unique<TRT_TensorOrWeights>(inputs.at(0));
    input_b = std::make_unique<TRT_TensorOrWeights>(inputs.at(1));

    StatusOr<std::string> eq = GetAttrValue<std::string>("equation");
    TRT_ENSURE_OK(eq);
    TF_RETURN_IF_ERROR(ParseEquation(*eq, &input_a, &input_b, &descriptor_a,
                                     &descriptor_b, &final_transpose));

    return Status::OK();
  }

  Status Convert() {
    auto builder = TRTNetworkBuilder::Create(params_->converter->network(),
                                             params_->weight_store);
    TRT_ENSURE_OK(builder);
    TRT_ENSURE(input_a && input_b);
    TRT_ENSURE(descriptor_a && descriptor_b);

    // Populate the size_tensors vector in the descriptors.
    TF_RETURN_IF_ERROR(descriptor_a->SetDynamicSize(&*builder, *input_a));
    TF_RETURN_IF_ERROR(descriptor_b->SetDynamicSize(&*builder, *input_b));

    // Condition the operands for lowering to matmul.
    TF_RETURN_IF_ERROR(
        ConditionEinsumOperand(&*builder, &input_a, *descriptor_a));
    TF_RETURN_IF_ERROR(
        ConditionEinsumOperand(&*builder, &input_b, *descriptor_b));

    // Build the matmul implementation.
    StatusOr<ITensorProxyPtr> result = ConvertMatMulImpl(
        params_, *input_a, *input_b, descriptor_a->layout == EinsumLayout::BCF,
        descriptor_b->layout == EinsumLayout::BFC);
    TF_RETURN_IF_ERROR(result.status());
    ITensorProxyPtr output = result.ValueOrDie();

    // Reshape and permute the output.
    TF_RETURN_IF_ERROR(ShuffleEinsumOutput(
        params_, *descriptor_a, *descriptor_b, final_transpose, &output));
    this->AddOutput(output);
    return Status::OK();
  }

 private:
  std::unique_ptr<TRT_TensorOrWeights> input_a{nullptr};
  std::unique_ptr<TRT_TensorOrWeights> input_b{nullptr};
  std::vector<int> final_transpose;
  std::unique_ptr<EinsumDescriptor> descriptor_a{nullptr};
  std::unique_ptr<EinsumDescriptor> descriptor_b{nullptr};
};
#else

// Helper class to reindex equations to contain only lowercase characters. We
// simply define a mapping from the old character set to a new set.
// - The input is assumed to be a valid TF equation.
// - The input is TRT compatible, therefore it has at most 8 dims. (Thus we
//   have enough lowercase English characters to represent the equation.)
// How we reindex/map equations:
// - Only uppercase letters are changed; if possible we just lowercase them.
// - If the equation contains both the upper and lowercase variant of a letter,
//   say X and x, then we map X to the first unused lowercase letter.
class ReIndexer {
 public:
  // Initializes the index map with the existing lowercase labels.
  explicit ReIndexer(const std::string& eq) {
    for (char c : eq) {
      if (std::islower(c)) {
        idx_map_[c] = c;
      }
    }
  }
  // Finds the new character for uppercase character c.
  char operator()(char c) {
    if (!std::isupper(c)) return c;
    if (idx_map_.count(c) > 0) return idx_map_[c];
    char new_idx = std::tolower(c);

    // If lower(c) is not used in the equation, use it to replace c.
    if (idx_map_.count(new_idx) == 0) {
      idx_map_[c] = new_idx;
      idx_map_[new_idx] = new_idx;  // Mark that new_idx is taken.
      return new_idx;
    }

    // Otherwise, find the first available lowercase letter to replace c.
    for (char k = 'a'; k <= 'z'; k++) {
      if (idx_map_.count(k) == 0) {
        new_idx = k;
        idx_map_[c] = new_idx;
        idx_map_[new_idx] = new_idx;  // Mark that new_idx is taken.
        break;
      }
    }
    return new_idx;
  }

 private:
  // Each key is an index used in the original or in the reindexed equation.
  // The values are the corresponding new lowercase indices.
  std::map<char, char> idx_map_;
};
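
// Illustrative example (added for documentation; not in the original source):
// for the equation "Xx,xy->Xy", lowercase 'x' is already taken, so 'X' maps
// to the first unused lowercase letter 'a', and the reindexed equation is
// "ax,xy->ay".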

class ConvertEinsum : public OpConverterBase<ConvertEinsum> {
 public:
  explicit ConvertEinsum(OpConverterParams* params)
      : OpConverterBase<ConvertEinsum>(params) {}

  static constexpr std::array<DataType, 2> AllowedDataTypes() {
    return {DataType::DT_FLOAT, DataType::DT_HALF};
  }

  Status ValidateInputs() {
    TRT_ENSURE(params_->inputs.size() <= 2);
    return Status::OK();
  }

  static constexpr bool HasFixNumberOfInputs() { return false; }

  static constexpr std::array<InputArgSpec, 2> InputSpec() {
    return {InputArgSpec::Create("input_a", TrtInputArg::kBoth),
            InputArgSpec::Create("input_b", TrtInputArg::kBoth)};
  }

  std::string MakeLowerCase(const std::string& eq) {
    std::string res = eq;
    ReIndexer reindexer(eq);
    std::transform(eq.begin(), eq.end(), res.begin(), reindexer);
    return res;
  }
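
  // Illustrative examples (added for documentation; not in the original
  // source): MakeLowerCase("ABc,cd->ABd") returns "abc,cd->abd", while
  // MakeLowerCase("Aa,ab->Ab") returns "ca,ab->cb", since 'a' and 'b' are
  // taken and 'c' is the first unused lowercase letter.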

  // Checks if the equation is supported by TRT.
  Status ValidateEinsumEquation(const std::string& eq) {
    const auto& inputs = params_->inputs;
    OperandLabels input_labels;
    Labels output_labels;
    std::vector<EinsumDimensionType> label_types;
    OperandLabelCounts input_label_counts;
    LabelCounts output_label_counts;
    absl::InlinedVector<bool, 2> input_has_ellipsis;
    bool output_has_ellipsis;
    VLOG(2) << "Parsing equation " << eq;
    TF_RETURN_IF_ERROR(ParseEinsumEquation(
        eq, &input_labels, &output_labels, &label_types, &input_label_counts,
        &output_label_counts, &input_has_ellipsis, &output_has_ellipsis));

    Status unimplemented =
        errors::Unimplemented("No conversion for einsum equation.");
    if (input_has_ellipsis[0] || (inputs.size() > 1 && input_has_ellipsis[1]) ||
        output_has_ellipsis) {
      VLOG(2) << "Ellipsis not yet supported";
      return unimplemented;
    }
    for (int i = 0; i < input_label_counts.size(); i++) {
      for (int k = 0; k < input_label_counts[i].size(); k++) {
        if (input_label_counts[i][k] > 1) {
          VLOG(2) << "Diagonal operation or reduction not yet supported";
          return unimplemented;
        }
      }
    }
    bool has_out_idx = absl::c_any_of(output_label_counts,
                                      [](int count) { return count > 0; });
    if (!has_out_idx) {
      VLOG(2) << "Scalar output not allowed in dynamic shape mode";
      return unimplemented;
    }
    // Check for an outer product.
    if (input_label_counts.size() == 2 && output_label_counts.size() == 2 &&
        output_label_counts[0] == 1 && output_label_counts[1] == 1) {
      VLOG(2) << "Outer product not supported";
      return unimplemented;
    }
    return Status::OK();
  }

  Status Validate() {
    VLOG(2) << "Running validation using the new einsum converter";
    if (params_->use_implicit_batch) {
      return errors::Unimplemented(
          "Einsum converter requires dynamic shape mode");
    }

    StatusOr<std::string> eq = GetAttrValue<std::string>("equation");
    TRT_ENSURE_OK(eq);

    TF_RETURN_IF_ERROR(ValidateEinsumEquation(*eq));

    // While TF has case-sensitive equations, TensorRT expects a lowercase
    // equation (as of version 8.4). See
    // https://docs.nvidia.com/deeplearning/tensorrt/developer-guide/index.html#einsum-layer
    equation = MakeLowerCase(*eq);

    return Status::OK();
  }

  Status Convert() {
    auto builder = TRTNetworkBuilder::Create(params_->converter->network(),
                                             params_->weight_store);
    TRT_ENSURE_OK(builder);

    std::vector<nvinfer1::ITensor*> trt_input;
    for (const TRT_TensorOrWeights& input_arg : params_->inputs) {
      ITensorProxyPtr ptr = nullptr;
      if (input_arg.is_tensor()) {
        ptr = input_arg.tensor();
      } else {
        StatusOr<nvinfer1::IConstantLayer*> const_layer =
            builder->WeightsToConstant(input_arg.weights().GetTrtWeights(),
                                       input_arg.GetTrtDims());
        TRT_ENSURE_PTR_OK(const_layer);
        ptr = (*const_layer)->getOutput(0);
      }
      trt_input.push_back(ptr->trt_tensor());
    }

    nvinfer1::IEinsumLayer* layer = params_->converter->network()->addEinsum(
        trt_input.data(), trt_input.size(), equation.c_str());
    TRT_ENSURE(layer);

    ITensorProxyPtr output = layer->getOutput(0);
    this->AddOutput(output);
    return Status::OK();
  }

 private:
  std::string equation;
};

#endif  // !IS_TRT_VERSION_GE(8, 2, 0, 0)

}  // namespace

REGISTER_DEFAULT_TRT_OP_CONVERTER(MakeConverterFunction<ConvertEinsum>(),
                                  "Einsum");

}  // namespace convert
}  // namespace tensorrt
}  // namespace tensorflow

#endif  // IS_TRT_VERSION_GE(7, 1, 3, 0)

#endif  // GOOGLE_CUDA && GOOGLE_TENSORRT