/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/mlir/tensorflow/transforms/einsum.h"

#include <algorithm>
#include <cctype>
#include <climits>
#include <cstdint>
#include <string>
#include <tuple>
#include <utility>

#include "absl/memory/memory.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/Optional.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/ADT/StringSwitch.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/Regex.h"
#include "mlir/Dialect/Affine/Analysis/LoopAnalysis.h"  // from @llvm-project
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"  // from @llvm-project
#include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
#include "mlir/Dialect/Traits.h"  // from @llvm-project
#include "mlir/IR/Attributes.h"  // from @llvm-project
#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
#include "mlir/IR/OpImplementation.h"  // from @llvm-project
#include "mlir/IR/PatternMatch.h"  // from @llvm-project
#include "mlir/Pass/Pass.h"  // from @llvm-project
#include "mlir/Support/LLVM.h"  // from @llvm-project
#include "mlir/Support/LogicalResult.h"  // from @llvm-project
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes_detail.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/verification_utils.h"
#include "tensorflow/core/util/matmul_bcast.h"

namespace mlir {
namespace TF {

namespace {

// Creates ConstOp for int32_t value.
ConstOp createI32ConstOp(int32_t value, Location loc,
                         PatternRewriter* rewriter) {
  auto int_attr = IntegerAttr::get(rewriter->getIntegerType(32), value);
  return rewriter->create<ConstOp>(loc, int_attr);
}

// Creates ConstantOp for array of int32_t.
arith::ConstantOp createI32ConstantOp(llvm::ArrayRef<int32_t> values,
                                      Location loc, PatternRewriter* rewriter) {
  auto values_type = RankedTensorType::get(
      {static_cast<int32_t>(values.size())}, rewriter->getIntegerType(32));
  auto constant_attr = rewriter->getI32TensorAttr(values);
  return rewriter->create<arith::ConstantOp>(loc, values_type, constant_attr);
}

// Creates ConstantOp for array of int64_t.
arith::ConstantOp createI64ConstantOp(llvm::ArrayRef<int64_t> values,
                                      Location loc, PatternRewriter* rewriter) {
  auto values_type = RankedTensorType::get(
      {static_cast<int64_t>(values.size())}, rewriter->getIntegerType(64));
  auto constant_attr = rewriter->getI64TensorAttr(values);
  return rewriter->create<arith::ConstantOp>(loc, values_type, constant_attr);
}

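// Creates a tf.Transpose op that permutes the dimensions of `value` according
// to `permutation`, giving the result the correspondingly permuted ranked
// tensor type.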
TF::TransposeOp createTransposeOp(Value value, Location loc,
                                  llvm::ArrayRef<int32_t> permutation,
                                  PatternRewriter* rewriter) {
  auto perm_op = createI32ConstantOp(permutation, loc, rewriter);
  auto value_type = value.getType().cast<RankedTensorType>();
  auto shape = value_type.getShape();
  SmallVector<int64_t, 4> transposed_shape(shape.begin(), shape.end());
  for (int i = 0, end = shape.size(); i < end; ++i) {
    transposed_shape[i] = shape[permutation[i]];
  }
  auto transposed_type =
      RankedTensorType::get(transposed_shape, value_type.getElementType());
  return rewriter->create<TF::TransposeOp>(loc, transposed_type, value,
                                           perm_op);
}

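// Creates a tf.Reshape op whose target `shape` is known statically and is
// supplied as a constant tensor.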
TF::ReshapeOp createReshapeOp(Value value, ArrayRef<int64_t> shape,
                              Type element_type, Location loc,
                              PatternRewriter* rewriter) {
  auto shape_tensor = createI64ConstantOp(shape, loc, rewriter);
  Type resultType = RankedTensorType::get(shape, element_type);
  return rewriter->create<TF::ReshapeOp>(loc, resultType, /*tensor=*/value,
                                         /*shape=*/shape_tensor);
}

// Creates a ReshapeOp that computes the required shape at runtime to support
// dynamic shapes. The shape is calculated with Shape and UnsortedSegmentProd
// ops. The `reshape_segids` and `num_reshape_segids` arguments for
// UnsortedSegmentProd are calculated in `reshapeForBatchMatmul`.
TF::ReshapeOp createReshapeOpForDynamic(Value value, ArrayRef<int64_t> shape,
                                        ArrayRef<int32_t> reshape_segids,
                                        int32_t num_reshape_segids,
                                        Location loc,
                                        PatternRewriter* rewriter) {
  // Build ShapeOp
  auto input_shape =
      rewriter->create<TF::ShapeOp>(loc, value, rewriter->getBoolAttr(true));

  // Build UnsortedSegmentProdOp
  Type segProdresultType =
      RankedTensorType::get(num_reshape_segids, rewriter->getIntegerType(32));
  auto segids_tensor = createI32ConstantOp(reshape_segids, loc, rewriter);
  auto num_reshape_segids_tensor =
      createI32ConstOp(num_reshape_segids, loc, rewriter);
  auto segprod = rewriter->create<TF::UnsortedSegmentProdOp>(
      loc, segProdresultType, input_shape->getResults()[0], segids_tensor,
      num_reshape_segids_tensor);

  // Build ReshapeOp with the result of UnsortedSegmentProdOp.
  Type out_tensor_type =
      RankedTensorType::get(shape, getElementTypeOrSelf(value.getType()));
  return rewriter->create<TF::ReshapeOp>(loc, out_tensor_type,
                                         /*tensor=*/value,
                                         /*shape=*/segprod->getResults()[0]);
}

struct EinsumDimensionNumbers {
  // Each field contains the list of dimensions appearing only in the
  // specified arguments of the einsum op with natural ordering. For example,
  // `rhs_out` contains the dimensions appearing in the RHS and the OUTPUT of
  // the einsum but not in the LHS.
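  //
  // For example, for the equation "bij,bjk->bik": 'b' appears in the LHS, RHS
  // and output, so lhs_rhs_out = {(0, 0, 0)}; 'i' appears in the LHS and
  // output, so lhs_out = {(1, 1)}; 'j' gives lhs_rhs = {(2, 1)}; and 'k'
  // gives rhs_out = {(2, 2)}.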
  std::vector<int64_t> lhs;
  std::vector<int64_t> rhs;
  std::vector<std::tuple<int64_t, int64_t>> lhs_rhs;
  std::vector<std::tuple<int64_t, int64_t>> lhs_out;
  std::vector<std::tuple<int64_t, int64_t>> rhs_out;
  std::vector<std::tuple<int64_t, int64_t, int64_t>> lhs_rhs_out;
};

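// Creates a tf.Reshape op for the einsum output whose shape is computed at
// runtime: the batch (B0,...,Bn) and LHS-only (L0,...,Ln) dimension sizes are
// gathered from the shape of the original LHS, the RHS-only (R0,...,Rn) sizes
// from the shape of the original RHS, and the pieces are concatenated into
// the target shape tensor.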
TF::ReshapeOp createOutputReshapeOpForDynamic(
    Value value, ArrayRef<int64_t> shape, Value org_lhs, Value org_rhs,
    EinsumDimensionNumbers& dnums, Location loc, PatternRewriter* rewriter) {
  BoolAttr true_attr = rewriter->getBoolAttr(true);
  // Build ShapeOp
  auto shape_lhs = rewriter->create<TF::ShapeOp>(loc, org_lhs, true_attr);
  auto shape_rhs = rewriter->create<TF::ShapeOp>(loc, org_rhs, true_attr);

  std::vector<int32_t> bl_index;  // Indexes of B0,...,Bn and L0,...,Ln
  bl_index.reserve(dnums.lhs_rhs_out.size() + dnums.lhs_out.size());
  for (auto i : dnums.lhs_rhs_out) {
    bl_index.push_back(std::get<0>(i));
  }
  for (auto i : dnums.lhs_out) {
    bl_index.push_back(std::get<0>(i));
  }
  std::vector<int32_t> r_index;  // Indexes of R0,...,Rn
  r_index.reserve(dnums.rhs_out.size());
  for (auto i : dnums.rhs_out) {
    r_index.push_back(std::get<0>(i));
  }

  auto lhs_index_tensor = createI32ConstantOp(bl_index, loc, rewriter);
  auto gather_lhs = rewriter->create<TF::GatherOp>(
      loc,
      RankedTensorType::get({static_cast<int>(bl_index.size())},
                            rewriter->getIntegerType(32)),
      shape_lhs->getResults()[0], lhs_index_tensor->getResults()[0], true_attr);
  auto rhs_index_tensor = createI32ConstantOp(r_index, loc, rewriter);
  auto gather_rhs = rewriter->create<TF::GatherOp>(
      loc,
      RankedTensorType::get({static_cast<int>(r_index.size())},
                            rewriter->getIntegerType(32)),
      shape_rhs->getResults()[0], rhs_index_tensor->getResults()[0], true_attr);
  Value zero_value = createI32ConstOp(0, loc, rewriter);
  auto concat_out_shape = rewriter->create<TF::ConcatOp>(
      loc,
      RankedTensorType::get({static_cast<int>(bl_index.size()) +
                             static_cast<int>(r_index.size())},
                            rewriter->getIntegerType(32)),
      zero_value,
      ArrayRef<Value>(
          {gather_lhs->getResults()[0], gather_rhs->getResults()[0]}));

  // Build ReshapeOp with the calculated output shape.
  Type out_type =
      RankedTensorType::get(shape, getElementTypeOrSelf(value.getType()));
  return rewriter->create<TF::ReshapeOp>(
      loc, out_type,
      /*tensor=*/value,
      /*shape=*/concat_out_shape->getResults()[0]);
}

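// Maps each label in `equation` to its dimension index. Returns llvm::None if
// the equation contains a non-alphabetic character or a repeated label.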
llvm::Optional<llvm::SmallDenseMap<char, int64_t>> EquationToMap(
    llvm::StringRef equation) {
  llvm::SmallDenseMap<char, int64_t> map;
  for (int64_t i = 0; i < equation.size(); ++i) {
    if (!std::isalpha(equation[i])) {
      // Unsupported character in the equation.
      return llvm::None;
    }
    if (map.count(equation[i])) {
      // Duplicate character in the equation.
      return llvm::None;
    }
    map.try_emplace(equation[i], i);
  }
  return map;
}

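// Returns the set of alphabetic labels that appear in neither `lhs` nor
// `rhs`, so that they can be used to name the dimensions covered by an
// ellipsis, and counts the named (non-ellipsis) labels of each operand.
// Returns llvm::None on an unsupported character or a malformed ellipsis.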
llvm::Optional<llvm::SetVector<char>> GetAvailableLabels(
    llvm::StringRef lhs, llvm::StringRef rhs, int* lhs_named_label_count,
    int* rhs_named_label_count) {
  llvm::SetVector<char> labels;
  for (int i = 0; i < 26; ++i) {
    labels.insert('a' + i);
    labels.insert('A' + i);
  }

  auto is_start_of_ellipsis = [](StringRef equation, int start_index) {
    if (equation.size() < (start_index + 3)) return false;

    if (equation.substr(start_index, 3) != "...") return false;
    return true;
  };

  int lhs_count = 0;
  const int lhs_size = lhs.size();
  for (int i = 0; i < lhs_size; ++i) {
    const char label = lhs[i];
    if (std::isalpha(label)) {
      labels.remove(label);
      ++lhs_count;
    } else if (label == '.') {
      if (!is_start_of_ellipsis(lhs, i)) return llvm::None;
      i += 2;
    } else {
      // Unsupported character in the equation.
      return llvm::None;
    }
  }
  *lhs_named_label_count = lhs_count;

  int rhs_count = 0;
  const int rhs_size = rhs.size();
  for (int i = 0; i < rhs_size; ++i) {
    const char label = rhs[i];
    if (std::isalpha(label)) {
      labels.remove(label);
      ++rhs_count;
    } else if (label == '.') {
      if (!is_start_of_ellipsis(rhs, i)) return llvm::None;
      i += 2;
    } else {
      // Unsupported character in the equation.
      return llvm::None;
    }
  }

  *rhs_named_label_count = rhs_count;
  return labels;
}

// Generates new unnamed labels for the expression.
// For example, GenerateLabels(2, {'b', 'c', 'd'}) for "...xy" with a rank-4
// operand produces "cb" for the ellipsis, giving "cbxy"; with a rank-5
// operand it produces "dcb", giving "dcbxy".
std::string GenerateLabels(int count,
                           const llvm::SetVector<char>& available_labels) {
  std::string new_labels(count, 0);
  for (int i = 0; i < count; ++i) {
    new_labels[count - 1 - i] = available_labels[i];
  }

  return new_labels;
}

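// Expands the ellipsis in each of the lhs, rhs and output subscripts into
// explicit labels drawn from `available_labels`. For example, on rank-3
// operands, "...ij,...jk->...ik" may be flattened to "aij,ajk->aik".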
std::tuple<std::string, std::string, std::string> FlattenEllipsis(
    llvm::StringRef lhs, int lhs_named_label_count, llvm::StringRef rhs,
    int rhs_named_label_count, llvm::StringRef out, RankedTensorType lhs_ty,
    RankedTensorType rhs_ty, const llvm::SetVector<char>& available_labels) {
  std::string new_labels;
  std::string new_lhs;
  for (int i = 0; i < lhs.size(); ++i) {
    const char label = lhs[i];
    if (std::isalpha(label)) {
      new_lhs.push_back(label);
    } else {
      // Encountered an ellipsis: generate unnamed labels and insert them into
      // the new labels.
      new_labels = GenerateLabels(lhs_ty.getRank() - lhs_named_label_count,
                                  available_labels);
      new_lhs.append(new_labels);
      i += 2;
    }
  }

  std::string new_rhs, new_rhs_labels;
  for (int i = 0; i < rhs.size(); ++i) {
    const char label = rhs[i];
    if (std::isalpha(label)) {
      new_rhs.push_back(label);
    } else {
      // Encountered an ellipsis: generate unnamed labels and insert them into
      // the new labels.
      new_rhs_labels = GenerateLabels(rhs_ty.getRank() - rhs_named_label_count,
                                      available_labels);
      new_rhs.append(new_rhs_labels);
      i += 2;
      if (new_rhs_labels.size() > new_labels.size()) {
        new_labels = new_rhs_labels;
      }
    }
  }

  // Deal with the output next.
  std::string new_output;
  for (int i = 0; i < out.size(); ++i) {
    const char label = out[i];
    if (std::isalpha(label)) {
      new_output.push_back(label);
    } else {
      // Encountered an ellipsis: insert the generated labels into the new
      // output labels.
      new_output.append(new_labels);
      i += 2;
    }
  }

  return std::make_tuple(new_lhs, new_rhs, new_output);
}

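// Parses `equation`, expands any ellipses via FlattenEllipsis, and sorts each
// label into the EinsumDimensionNumbers buckets defined above. Returns
// llvm::None for equations this pass does not support.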
llvm::Optional<EinsumDimensionNumbers> GetEinsumDimensionNumbers(
    llvm::StringRef equation, RankedTensorType lhs_ty,
    RankedTensorType rhs_ty) {
  llvm::StringRef lhs_rhs;
  llvm::StringRef out;
  std::tie(lhs_rhs, out) = equation.split("->");
  if (lhs_rhs.empty() || out.empty()) return llvm::None;

  llvm::StringRef lhs;
  llvm::StringRef rhs;
  std::tie(lhs, rhs) = lhs_rhs.split(',');
  if (lhs.empty() || rhs.empty()) return llvm::None;

  // Try to flatten the "..." if possible.
  int lhs_named_label, rhs_named_label;
  auto available_labels =
      GetAvailableLabels(lhs, rhs, &lhs_named_label, &rhs_named_label);
  if (!available_labels.has_value()) return llvm::None;

  auto flattened_labels =
      FlattenEllipsis(lhs, lhs_named_label, rhs, rhs_named_label, out, lhs_ty,
                      rhs_ty, available_labels.getValue());

  lhs = std::get<0>(flattened_labels);
  rhs = std::get<1>(flattened_labels);
  out = std::get<2>(flattened_labels);

  auto lhs_map_or = EquationToMap(lhs);
  if (!lhs_map_or.has_value()) return llvm::None;
  auto lhs_map = lhs_map_or.getValue();

  auto rhs_map_or = EquationToMap(rhs);
  if (!rhs_map_or.has_value()) return llvm::None;
  auto rhs_map = rhs_map_or.getValue();

  auto out_map_or = EquationToMap(out);
  if (!out_map_or.has_value()) return llvm::None;
  auto out_map = out_map_or.getValue();

  EinsumDimensionNumbers dnums;
  for (int64_t i = 0, e = lhs.size(); i < e; ++i) {
    auto rhs_index = rhs_map.find(lhs[i]);
    auto out_index = out_map.find(lhs[i]);
    if (rhs_index == rhs_map.end() && out_index == out_map.end()) {
      dnums.lhs.emplace_back(i);
    } else if (rhs_index == rhs_map.end()) {
      dnums.lhs_out.emplace_back(i, out_index->second);
    } else if (out_index == out_map.end()) {
      dnums.lhs_rhs.emplace_back(i, rhs_index->second);
    } else {
      dnums.lhs_rhs_out.emplace_back(i, rhs_index->second, out_index->second);
    }
  }
  for (int64_t i = 0, e = rhs.size(); i < e; ++i) {
    auto lhs_index = lhs_map.find(rhs[i]);
    auto out_index = out_map.find(rhs[i]);
    if (lhs_index == lhs_map.end()) {
      if (out_index == out_map.end()) {
        dnums.rhs.emplace_back(i);
      } else {
        dnums.rhs_out.emplace_back(i, out_index->second);
      }
    }
  }
  for (int64_t i = 0, e = out.size(); i < e; ++i) {
    auto lhs_index = lhs_map.find(out[i]);
    auto rhs_index = rhs_map.find(out[i]);
    if (lhs_index == lhs_map.end() && rhs_index == rhs_map.end()) {
      // Labels that appear only in the output are not supported.
      return llvm::None;
    }
  }
  return dnums;
}

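// Applies the inverse of `permutation` to `input`, i.e. computes `output`
// such that output[permutation[i]] = input[i].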
std::vector<int64_t> inverseTransposeVector(
    llvm::ArrayRef<int64_t> input, llvm::ArrayRef<int32_t> permutation) {
  std::vector<int64_t> output(input.size());
  for (int64_t i = 0; i < input.size(); ++i) {
    output[permutation[i]] = input[i];
  }
  return output;
}

// Computes the transpositions required to convert dnums to one supported by
// tf.BatchMatmulV2 and returns the new set of dimension numbers with them.
// The transposed LHS shape will be B0,...,Bn,L0,...,Ln,C0,...,Cn and the
// transposed RHS shape will be B0,...,Bn,C0,...,Cn,R0,...,Rn, respectively.
LogicalResult transposeForBatchMatmul(
    const Location& loc, EinsumDimensionNumbers& dnums, Value* lhs, Value* rhs,
    std::vector<int32_t>* out_inverse_transpose, PatternRewriter* rewriter) {
  std::vector<int32_t> lhs_transpose;
  std::vector<int32_t> rhs_transpose;
  std::vector<int32_t> out_transpose;
  lhs_transpose.reserve(dnums.lhs_rhs_out.size() + dnums.lhs_out.size() +
                        dnums.lhs_rhs.size());
  rhs_transpose.reserve(dnums.lhs_rhs_out.size() + dnums.rhs_out.size() +
                        dnums.lhs_rhs.size());
  out_transpose.reserve(dnums.lhs_rhs_out.size() + dnums.lhs_out.size() +
                        dnums.rhs_out.size());
  // Generate the transpose permutation for B0,...,Bn.
  for (int64_t i = 0, e = dnums.lhs_rhs_out.size(); i < e; ++i) {
    lhs_transpose.push_back(std::get<0>(dnums.lhs_rhs_out[i]));
    rhs_transpose.push_back(std::get<1>(dnums.lhs_rhs_out[i]));
    out_transpose.push_back(std::get<2>(dnums.lhs_rhs_out[i]));
    dnums.lhs_rhs_out[i] = std::make_tuple(i, i, i);
  }

  // Generate the transpose permutation for L0,...,Ln.
  for (int64_t i = 0, e = dnums.lhs_out.size(); i < e; ++i) {
    lhs_transpose.push_back(std::get<0>(dnums.lhs_out[i]));
    out_transpose.push_back(std::get<1>(dnums.lhs_out[i]));
    dnums.lhs_out[i] =
        std::make_tuple(lhs_transpose.size() - 1, out_transpose.size() - 1);
  }
  // Generate the transpose permutation for C0,...,Cn.
  for (int64_t i = 0, e = dnums.lhs_rhs.size(); i < e; ++i) {
    lhs_transpose.push_back(std::get<0>(dnums.lhs_rhs[i]));
    rhs_transpose.push_back(std::get<1>(dnums.lhs_rhs[i]));
    dnums.lhs_rhs[i] =
        std::make_tuple(lhs_transpose.size() - 1, rhs_transpose.size() - 1);
  }
  // Generate the transpose permutation for R0,...,Rn.
  for (int64_t i = 0, e = dnums.rhs_out.size(); i < e; ++i) {
    rhs_transpose.push_back(std::get<0>(dnums.rhs_out[i]));
    out_transpose.push_back(std::get<1>(dnums.rhs_out[i]));
    dnums.rhs_out[i] =
        std::make_tuple(rhs_transpose.size() - 1, out_transpose.size() - 1);
  }

  out_inverse_transpose->resize(out_transpose.size());
  for (int64_t i = 0, e = out_transpose.size(); i < e; ++i) {
    out_inverse_transpose->at(out_transpose[i]) = i;
  }

  *lhs = createTransposeOp(*lhs, loc, lhs_transpose, rewriter);
  *rhs = createTransposeOp(*rhs, loc, rhs_transpose, rewriter);
  return success();
}

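// Returns the product of the dimensions of `shape` selected by the I-th
// element of each tuple in `index_tuples`, or -1 if any selected dimension is
// dynamic.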
template <int I>
inline int64_t ProdShapeWithIndexInTuple(
    ArrayRef<int64_t> shape,
    const std::vector<std::tuple<int64_t, int64_t>>& index_tuples) {
  int64_t prod_shape = 1;
  for (auto index_tuple : index_tuples) {
    const int64_t shape_i = shape[std::get<I>(index_tuple)];
    if (shape_i == -1) return -1;
    prod_shape *= shape_i;
  }
  return prod_shape;
}

// Reshapes LHS and RHS to have B0,...,Bn,L,C and B0,...,Bn,C,R shape
// respectively while assuming that the initial shape for them is
// B0,...,Bn,L0,...,Ln,C0,...,Cn and B0,...,Bn,C0,...,Cn,R0,...,Rn respectively.
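// For example, an LHS of shape [B0,L0,L1,C0,C1] is reshaped to
// [B0, L0*L1, C0*C1] when the products are static; otherwise the products are
// computed at runtime with tf.UnsortedSegmentProd (see
// `createReshapeOpForDynamic`).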
LogicalResult reshapeForBatchMatmul(const Location& loc,
                                    EinsumDimensionNumbers& dnums, Value* lhs,
                                    Value* rhs,
                                    SmallVectorImpl<int64_t>* out_shape,
                                    PatternRewriter* rewriter) {
  RankedTensorType lhs_type = lhs->getType().cast<RankedTensorType>();
  RankedTensorType rhs_type = rhs->getType().cast<RankedTensorType>();

  int32_t num_lhs_reshape_segids = 0;
  int32_t num_rhs_reshape_segids = 0;
  std::vector<int32_t> lhs_reshape_segids;
  int lhs_rank =
      dnums.lhs_rhs_out.size() + dnums.lhs_out.size() + dnums.lhs_rhs.size();
  lhs_reshape_segids.resize(lhs_rank);
  std::vector<int32_t> rhs_reshape_segids;
  int rhs_rank =
      dnums.lhs_rhs_out.size() + dnums.rhs_out.size() + dnums.lhs_rhs.size();
  rhs_reshape_segids.resize(rhs_rank);

  // Labels that exist in all of lhs, rhs and output are the batch labels
  // B0,...,Bn.
  std::vector<int64_t> lhs_shape;
  std::vector<int64_t> rhs_shape;
  lhs_shape.reserve(dnums.lhs_rhs_out.size() + dnums.lhs_out.size() + 1);
  rhs_shape.reserve(dnums.lhs_rhs_out.size() + 2);
  for (auto i : dnums.lhs_rhs_out) {
    const int64_t b1 = lhs_type.getShape()[std::get<0>(i)];
    lhs_shape.push_back(b1);
    const int64_t b2 = rhs_type.getShape()[std::get<1>(i)];
    rhs_shape.push_back(b2);

    lhs_reshape_segids.at(std::get<0>(i)) = num_lhs_reshape_segids++;
    rhs_reshape_segids.at(std::get<1>(i)) = num_rhs_reshape_segids++;
  }
  if (!OpTrait::util::getBroadcastedShape(lhs_shape, rhs_shape, *out_shape)) {
    return failure();
  }

  // Calculates the dimension for the label L from L0,...,Ln in lhs.
  if (dnums.lhs_out.empty()) {
    lhs_shape.push_back(1);
    out_shape->push_back(1);
    dnums.lhs_out.emplace_back(lhs_shape.size() - 1, out_shape->size() - 1);
    ++num_lhs_reshape_segids;
  } else if (dnums.lhs_rhs_out.empty()) {
    // If there are no batch labels B0,...,Bn, it is safe to use L0,...,Ln as
    // the batch labels in lhs; the rhs will be broadcasted.
    for (auto i : dnums.lhs_out) {
      const int64_t b = lhs_type.getShape()[std::get<0>(i)];
      lhs_shape.push_back(b);
      out_shape->push_back(b);

      lhs_reshape_segids.at(std::get<0>(i)) = num_lhs_reshape_segids++;
    }
  } else {
    const int64_t lhs_out_size =
        ProdShapeWithIndexInTuple<0>(lhs_type.getShape(), dnums.lhs_out);
    lhs_shape.push_back(lhs_out_size);
    out_shape->push_back(lhs_out_size);

    for (auto i : dnums.lhs_out) {
      lhs_reshape_segids.at(std::get<0>(i)) = num_lhs_reshape_segids;
    }
    ++num_lhs_reshape_segids;
  }

  // Calculates the dimension for the common label C from labels C0,...,Cn
  // that exist in both lhs and rhs.
  const int64_t lhs_size =
      ProdShapeWithIndexInTuple<0>(lhs_type.getShape(), dnums.lhs_rhs);
  const int64_t rhs_size =
      ProdShapeWithIndexInTuple<1>(rhs_type.getShape(), dnums.lhs_rhs);
  lhs_shape.push_back(lhs_size);
  rhs_shape.push_back(rhs_size);

  for (auto i : dnums.lhs_rhs) {
    lhs_reshape_segids.at(std::get<0>(i)) = num_lhs_reshape_segids;
    rhs_reshape_segids.at(std::get<1>(i)) = num_rhs_reshape_segids;
  }
  ++num_lhs_reshape_segids;
  ++num_rhs_reshape_segids;

  // Calculates the dimension for the label R from R0,...,Rn in rhs.
  const int64_t rhs_out_size =
      ProdShapeWithIndexInTuple<0>(rhs_type.getShape(), dnums.rhs_out);
  rhs_shape.push_back(rhs_out_size);
  out_shape->push_back(rhs_out_size);

  for (auto i : dnums.rhs_out) {
    rhs_reshape_segids.at(std::get<0>(i)) = num_rhs_reshape_segids;
  }
  ++num_rhs_reshape_segids;

  // If the LHS requires reshapes.
  if (lhs_rank != num_lhs_reshape_segids) {
    if (succeeded(VerifyShapeOfReshapeOp(lhs_shape))) {
      *lhs = createReshapeOp(*lhs, lhs_shape, lhs_type.getElementType(), loc,
                             rewriter);
    } else {
      // Check if the LHS shape can be calculated with UnsortedSegmentProd.
      // This requires at least one index in each of lhs_out and lhs_rhs.
      if (dnums.lhs_out.empty() || dnums.lhs_rhs.empty()) return failure();
      *lhs = createReshapeOpForDynamic(*lhs, lhs_shape, lhs_reshape_segids,
                                       num_lhs_reshape_segids, loc, rewriter);
    }
  }
  // If the RHS requires reshapes.
  if (rhs_rank != num_rhs_reshape_segids) {
    if (succeeded(VerifyShapeOfReshapeOp(rhs_shape))) {
      *rhs = createReshapeOp(*rhs, rhs_shape, rhs_type.getElementType(), loc,
                             rewriter);
    } else {
      // Check if the RHS shape can be calculated with UnsortedSegmentProd.
      // This requires at least one index in each of rhs_out and lhs_rhs.
      if (dnums.rhs_out.empty() || dnums.lhs_rhs.empty()) return failure();
      *rhs = createReshapeOpForDynamic(*rhs, rhs_shape, rhs_reshape_segids,
                                       num_rhs_reshape_segids, loc, rewriter);
    }
  }

  dnums.lhs_rhs.assign(
      {std::make_tuple(dnums.lhs_rhs_out.size() + dnums.lhs_out.size(),
                       dnums.lhs_rhs_out.size())});
  dnums.rhs_out.assign(
      {std::make_tuple(dnums.lhs_rhs_out.size() + dnums.lhs_out.size(),
                       dnums.lhs_rhs_out.size() + dnums.lhs_out.size())});
  return success();
}

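// Rewrites the einsum op into tf.BatchMatMulV2, wrapped in the tf.Transpose
// and tf.Reshape ops that bring the operands into batched matmul form and
// bring the result back to the expected output shape.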
LogicalResult rewriteToBatchMatmul(TF::EinsumOp op,
                                   EinsumDimensionNumbers dnums,
                                   PatternRewriter& rewriter) {
  if (!dnums.lhs.empty() || !dnums.rhs.empty()) return failure();

  auto inputs = op.inputs();
  if (inputs.size() != 2) return failure();
  Value lhs = inputs.front();
  Value rhs = inputs.back();
  // Back up the original values for the later output shape calculation in
  // `createOutputReshapeOpForDynamic`.
  Value original_lhs = lhs;
  Value original_rhs = rhs;
  EinsumDimensionNumbers original_dnums = dnums;

  RankedTensorType original_type =
      op.getResult().getType().dyn_cast_or_null<RankedTensorType>();
  if (!original_type) return failure();

  std::vector<int32_t> out_transpose;
  if (failed(transposeForBatchMatmul(op.getLoc(), dnums, &lhs, &rhs,
                                     &out_transpose, &rewriter)))
    return failure();

  llvm::SmallVector<int64_t, 4> matmul_shape;
  if (failed(reshapeForBatchMatmul(op.getLoc(), dnums, &lhs, &rhs,
                                   &matmul_shape, &rewriter)))
    return failure();

  std::vector<int64_t> reshape_shape =
      inverseTransposeVector(original_type.getShape(), out_transpose);

  auto matmul_type =
      RankedTensorType::get(matmul_shape, original_type.getElementType());
  Value out = rewriter.create<TF::BatchMatMulV2Op>(
      op.getLoc(), matmul_type, lhs, rhs, rewriter.getBoolAttr(false),
      rewriter.getBoolAttr(false));

  bool out_reshape_need = (reshape_shape.size() != matmul_shape.size() ||
                           original_type.getRank() != matmul_shape.size());
  // Always add a reshape for concrete output shapes.
  if (succeeded(VerifyShapeOfReshapeOp(reshape_shape))) {
    out = createReshapeOp(out, reshape_shape, original_type.getElementType(),
                          op.getLoc(), &rewriter);
  } else if (out_reshape_need) {
    out = createOutputReshapeOpForDynamic(out, reshape_shape, original_lhs,
                                          original_rhs, original_dnums,
                                          op.getLoc(), &rewriter);
  }
  out = createTransposeOp(out, op.getLoc(), out_transpose, &rewriter);

  rewriter.replaceOp(op, out);
  return success();
}

// Transforms Einsum to other TF ops for the supported variants.
struct TransformEinsumPass
    : public TransformEinsumPassBase<TransformEinsumPass> {
  void runOnOperation() override;
};

void TransformEinsumPass::runOnOperation() {
  RewritePatternSet patterns(&getContext());
  auto func = getOperation();

  patterns.add<ConvertTFEinsumOp>(&getContext());
  (void)applyPatternsAndFoldGreedily(func, std::move(patterns));
}

}  // namespace

LogicalResult ConvertTFEinsumOp::matchAndRewrite(
    TF::EinsumOp op, PatternRewriter& rewriter) const {
  RankedTensorType lhs =
      op.getOperand(0).getType().dyn_cast_or_null<RankedTensorType>();
  RankedTensorType rhs =
      op.getOperand(1).getType().dyn_cast_or_null<RankedTensorType>();
  if (!lhs || !rhs) {
    return failure();
  }

  // TODO(b/162328998) Better support Einsum with dynamic input. Currently,
  // one dynamic dimension is always supported. If there are two or more
  // dynamic dimensions, they are supported only if they all exist in a single
  // component among L0,...,Ln, R0,...,Rn, or C0,...,Cn.
  if (const auto dnums_or = GetEinsumDimensionNumbers(op.equation(), lhs, rhs))
    return rewriteToBatchMatmul(op, dnums_or.getValue(), rewriter);
  return rewriter.notifyMatchFailure(op, "unsupported einsum lowering");
}

std::unique_ptr<OperationPass<func::FuncOp>> CreateTransformEinsumPass() {
  return std::make_unique<TransformEinsumPass>();
}

}  // namespace TF
}  // namespace mlir