/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/mlir/tensorflow/transforms/einsum.h"

#include <algorithm>
#include <cctype>
#include <climits>
#include <cstdint>
#include <string>
#include <tuple>
#include <utility>

#include "absl/memory/memory.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/Optional.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/ADT/StringSwitch.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/Regex.h"
#include "mlir/Dialect/Affine/Analysis/LoopAnalysis.h"  // from @llvm-project
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"  // from @llvm-project
#include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
#include "mlir/Dialect/Traits.h"  // from @llvm-project
#include "mlir/IR/Attributes.h"  // from @llvm-project
#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
#include "mlir/IR/OpImplementation.h"  // from @llvm-project
#include "mlir/IR/PatternMatch.h"  // from @llvm-project
#include "mlir/Pass/Pass.h"  // from @llvm-project
#include "mlir/Support/LLVM.h"  // from @llvm-project
#include "mlir/Support/LogicalResult.h"  // from @llvm-project
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes_detail.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/verification_utils.h"
#include "tensorflow/core/util/matmul_bcast.h"

namespace mlir {
namespace TF {

namespace {

// Creates ConstOp for int32_t value.
ConstOp createI32ConstOp(int32_t value, Location loc,
                         PatternRewriter* rewriter) {
  auto int_attr = IntegerAttr::get(rewriter->getIntegerType(32), value);
  return rewriter->create<ConstOp>(loc, int_attr);
}

// Creates ConstantOp for array of int32_t.
arith::ConstantOp createI32ConstantOp(llvm::ArrayRef<int32_t> values,
                                      Location loc, PatternRewriter* rewriter) {
  auto values_type = RankedTensorType::get(
      {static_cast<int32_t>(values.size())}, rewriter->getIntegerType(32));
  auto constant_attr = rewriter->getI32TensorAttr(values);
  return rewriter->create<arith::ConstantOp>(loc, values_type, constant_attr);
}

// Creates ConstantOp for array of int64_t.
arith::ConstantOp createI64ConstantOp(llvm::ArrayRef<int64_t> values,
                                      Location loc, PatternRewriter* rewriter) {
  auto values_type = RankedTensorType::get(
      {static_cast<int64_t>(values.size())}, rewriter->getIntegerType(64));
  auto constant_attr = rewriter->getI64TensorAttr(values);
  return rewriter->create<arith::ConstantOp>(loc, values_type, constant_attr);
}

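// Creates a tf.Transpose op that permutes `value` according to `permutation`,
// deriving the transposed result type from the input's ranked shape.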
TF::TransposeOp createTransposeOp(Value value, Location loc,
                                  llvm::ArrayRef<int32_t> permutation,
                                  PatternRewriter* rewriter) {
  auto perm_op = createI32ConstantOp(permutation, loc, rewriter);
  auto value_type = value.getType().cast<RankedTensorType>();
  auto shape = value_type.getShape();
  SmallVector<int64_t, 4> transposed_shape(shape.begin(), shape.end());
  for (int i = 0, end = shape.size(); i < end; ++i) {
    transposed_shape[i] = shape[permutation[i]];
  }
  auto transposed_type =
      RankedTensorType::get(transposed_shape, value_type.getElementType());
  return rewriter->create<TF::TransposeOp>(loc, transposed_type, value,
                                           perm_op);
}

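// Creates a tf.Reshape op to the given static `shape`, materializing the
// target shape as an i64 constant.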
TF::ReshapeOp createReshapeOp(Value value, ArrayRef<int64_t> shape,
                              Type element_type, Location loc,
                              PatternRewriter* rewriter) {
  auto shape_tensor = createI64ConstantOp(shape, loc, rewriter);
  Type resultType = RankedTensorType::get(shape, element_type);
  return rewriter->create<TF::ReshapeOp>(loc, resultType, /*tensor=*/value,
                                         /*shape=*/shape_tensor);
}

// Creates a ReshapeOp with runtime calculation of the required shape to
// support dynamic shapes. The shape is calculated with Shape and
// UnsortedSegmentProd ops. `reshape_segids` and `num_reshape_segids` for
// UnsortedSegmentProd are calculated in `reshapeForBatchMatmul`.
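// For example, for an input of shape [?, 4, 5] with `reshape_segids` {0, 1, 1}
// and `num_reshape_segids` 2, the runtime shape evaluates to [?, 20]: the last
// two dimensions fall into one segment and are multiplied together.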
TF::ReshapeOp createReshapeOpForDynamic(Value value, ArrayRef<int64_t> shape,
                                        ArrayRef<int32_t> reshape_segids,
                                        int32_t num_reshape_segids,
                                        Location loc,
                                        PatternRewriter* rewriter) {
  // Build ShapeOp
  auto input_shape =
      rewriter->create<TF::ShapeOp>(loc, value, rewriter->getBoolAttr(true));

  // Build UnsortedSegmentProdOp
  Type segProdresultType =
      RankedTensorType::get(num_reshape_segids, rewriter->getIntegerType(32));
  auto segids_tensor = createI32ConstantOp(reshape_segids, loc, rewriter);
  auto num_reshape_segids_tensor =
      createI32ConstOp(num_reshape_segids, loc, rewriter);
  auto segprod = rewriter->create<TF::UnsortedSegmentProdOp>(
      loc, segProdresultType, input_shape->getResults()[0], segids_tensor,
      num_reshape_segids_tensor);

  // Build ReshapeOp with the result of UnsortedSegmentProdOp.
  Type out_tensor_type =
      RankedTensorType::get(shape, getElementTypeOrSelf(value.getType()));
  return rewriter->create<TF::ReshapeOp>(loc, out_tensor_type,
                                         /*tensor=*/value,
                                         /*shape=*/segprod->getResults()[0]);
}

struct EinsumDimensionNumbers {
  // Each field contains the list of dimensions appearing only in the specified
  // arguments of the einsum op with natural ordering. For example `rhs_out`
  // contains the dimensions appearing in the RHS and the OUTPUT of the einsum
  // but not in the LHS.
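  //
  // For example, the equation "bij,bjk->bik" yields:
  //   lhs_rhs_out = {(0, 0, 0)}  // 'b' appears in LHS, RHS and OUTPUT.
  //   lhs_out     = {(1, 1)}     // 'i' appears in LHS and OUTPUT only.
  //   lhs_rhs     = {(2, 1)}     // 'j' appears in LHS and RHS only.
  //   rhs_out     = {(2, 2)}     // 'k' appears in RHS and OUTPUT only.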
  std::vector<int64_t> lhs;
  std::vector<int64_t> rhs;
  std::vector<std::tuple<int64_t, int64_t>> lhs_rhs;
  std::vector<std::tuple<int64_t, int64_t>> lhs_out;
  std::vector<std::tuple<int64_t, int64_t>> rhs_out;
  std::vector<std::tuple<int64_t, int64_t, int64_t>> lhs_rhs_out;
};

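// Creates a ReshapeOp for the einsum output when the output shape is only
// known at runtime: gathers the batch (B) and LHS-only (L) dimension sizes
// from the LHS shape and the RHS-only (R) dimension sizes from the RHS shape,
// concatenates them into the target shape, and reshapes `value` to it.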
TF::ReshapeOp createOutputReshapeOpForDynamic(
    Value value, ArrayRef<int64_t> shape, Value org_lhs, Value org_rhs,
    EinsumDimensionNumbers& dnums, Location loc, PatternRewriter* rewriter) {
  BoolAttr true_attr = rewriter->getBoolAttr(true);
  // Build ShapeOp
  auto shape_lhs = rewriter->create<TF::ShapeOp>(loc, org_lhs, true_attr);
  auto shape_rhs = rewriter->create<TF::ShapeOp>(loc, org_rhs, true_attr);

  std::vector<int32_t> bl_index;  // Indexes of B0,...,Bn and L0,...,Ln
  bl_index.reserve(dnums.lhs_rhs_out.size() + dnums.lhs_out.size());
  for (auto i : dnums.lhs_rhs_out) {
    bl_index.push_back(std::get<0>(i));
  }
  for (auto i : dnums.lhs_out) {
    bl_index.push_back(std::get<0>(i));
  }
  std::vector<int32_t> r_index;  // Indexes of R0,...,Rn
  r_index.reserve(dnums.rhs_out.size());
  for (auto i : dnums.rhs_out) {
    r_index.push_back(std::get<0>(i));
  }

  auto lhs_index_tensor = createI32ConstantOp(bl_index, loc, rewriter);
  auto gather_lhs = rewriter->create<TF::GatherOp>(
      loc,
      RankedTensorType::get({static_cast<int>(bl_index.size())},
                            rewriter->getIntegerType(32)),
      shape_lhs->getResults()[0], lhs_index_tensor->getResults()[0], true_attr);
  auto rhs_index_tensor = createI32ConstantOp(r_index, loc, rewriter);
  auto gather_rhs = rewriter->create<TF::GatherOp>(
      loc,
      RankedTensorType::get({static_cast<int>(r_index.size())},
                            rewriter->getIntegerType(32)),
      shape_rhs->getResults()[0], rhs_index_tensor->getResults()[0], true_attr);
  Value zero_value = createI32ConstOp(0, loc, rewriter);
  auto concat_out_shape = rewriter->create<TF::ConcatOp>(
      loc,
      RankedTensorType::get({static_cast<int>(bl_index.size()) +
                             static_cast<int>(r_index.size())},
                            rewriter->getIntegerType(32)),
      zero_value,
      ArrayRef<Value>(
          {gather_lhs->getResults()[0], gather_rhs->getResults()[0]}));

  // Build ReshapeOp with the calculated output shape.
  Type out_type =
      RankedTensorType::get(shape, getElementTypeOrSelf(value.getType()));
  return rewriter->create<TF::ReshapeOp>(
      loc, out_type,
      /*tensor=*/value,
      /*shape=*/concat_out_shape->getResults()[0]);
}

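// Maps each label in `equation` to its dimension index; e.g. "bij" maps to
// {'b': 0, 'i': 1, 'j': 2}. Returns llvm::None if the equation contains a
// non-alphabetic or duplicate label.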
llvm::Optional<llvm::SmallDenseMap<char, int64_t>> EquationToMap(
    llvm::StringRef equation) {
  llvm::SmallDenseMap<char, int64_t> map;
  for (int64_t i = 0; i < equation.size(); ++i) {
    if (!std::isalpha(equation[i])) {
      // Unsupported character in the equation.
      return llvm::None;
    }
    if (map.count(equation[i])) {
      // Duplicate character in the equation.
      return llvm::None;
    }
    map.try_emplace(equation[i], i);
  }
  return map;
}

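// Collects the labels in 'a'-'z' and 'A'-'Z' that are not used by `lhs` or
// `rhs`, and reports the number of named (non-ellipsis) labels on each side.
// Returns llvm::None if either side contains an unsupported character or a
// malformed ellipsis.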
llvm::Optional<llvm::SetVector<char>> GetAvailableLabels(
    llvm::StringRef lhs, llvm::StringRef rhs, int* lhs_named_label_count,
    int* rhs_named_label_count) {
  llvm::SetVector<char> labels;
  for (int i = 0; i < 26; ++i) {
    labels.insert('a' + i);
    labels.insert('A' + i);
  }

  auto is_start_of_ellipsis = [](StringRef equation, int start_index) {
    if (equation.size() < (start_index + 3)) return false;

    if (equation.substr(start_index, 3) != "...") return false;
    return true;
  };

  int lhs_count = 0;
  const int lhs_size = lhs.size();
  for (int i = 0; i < lhs_size; ++i) {
    const char label = lhs[i];
    if (std::isalpha(label)) {
      labels.remove(label);
      ++lhs_count;
    } else if (label == '.') {
      if (!is_start_of_ellipsis(lhs, i)) return llvm::None;
      i += 2;
    } else {
      // Unsupported character in the equation.
      return llvm::None;
    }
  }
  *lhs_named_label_count = lhs_count;

  int rhs_count = 0;
  const int rhs_size = rhs.size();
  for (int i = 0; i < rhs_size; ++i) {
    const char label = rhs[i];
    if (std::isalpha(label)) {
      labels.remove(label);
      ++rhs_count;
    } else if (label == '.') {
      if (!is_start_of_ellipsis(rhs, i)) return llvm::None;
      i += 2;
    } else {
      // Unsupported character in the equation.
      return llvm::None;
    }
  }

  *rhs_named_label_count = rhs_count;
  return labels;
}

// Generates new unnamed labels for an ellipsis expression. For example,
// GenerateLabels(2, {'b', 'c', 'd'}) for "...xy" produces "cb", giving "cbxy"
// for a rank-4 operand; for a rank-5 operand, GenerateLabels(3, ...) produces
// "dcb", giving "dcbxy". Note that the generated labels are emitted in reverse
// order of availability.
std::string GenerateLabels(int count,
                           const llvm::SetVector<char>& available_labels) {
  std::string new_labels(count, 0);
  for (int i = 0; i < count; ++i) {
    new_labels[count - 1 - i] = available_labels[i];
  }

  return new_labels;
}

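// Replaces the ellipsis in `lhs`, `rhs` and `out` with labels generated from
// `available_labels`, so the returned equation parts contain only named
// labels. For example, assuming two rank-4 operands, "...ij,...jk->...ik"
// becomes ("Aaij", "Aajk", "Aaik"), since 'a' and 'A' are the first two
// available labels.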
std::tuple<std::string, std::string, std::string> FlattenEllipsis(
    llvm::StringRef lhs, int lhs_named_label_count, llvm::StringRef rhs,
    int rhs_named_label_count, llvm::StringRef out, RankedTensorType lhs_ty,
    RankedTensorType rhs_ty, const llvm::SetVector<char>& available_labels) {
  std::string new_labels;
  std::string new_lhs;
  for (int i = 0; i < lhs.size(); ++i) {
    const char label = lhs[i];
    if (std::isalpha(label)) {
      new_lhs.push_back(label);
    } else {
      // Encountered an ellipsis: generate unnamed labels and insert them into
      // the new labels.
      new_labels = GenerateLabels(lhs_ty.getRank() - lhs_named_label_count,
                                  available_labels);
      new_lhs.append(new_labels);
      i += 2;
    }
  }

  std::string new_rhs, new_rhs_labels;
  for (int i = 0; i < rhs.size(); ++i) {
    const char label = rhs[i];
    if (std::isalpha(label)) {
      new_rhs.push_back(label);
    } else {
      // Encountered an ellipsis: generate unnamed labels and insert them into
      // the new labels.
      new_rhs_labels = GenerateLabels(rhs_ty.getRank() - rhs_named_label_count,
                                      available_labels);
      new_rhs.append(new_rhs_labels);
      i += 2;
      if (new_rhs_labels.size() > new_labels.size()) {
        new_labels = new_rhs_labels;
      }
    }
  }

  // Deal with the output next.
  std::string new_output;
  for (int i = 0; i < out.size(); ++i) {
    const char label = out[i];
    if (std::isalpha(label)) {
      new_output.push_back(label);
    } else {
      // Encountered an ellipsis: insert the generated labels into the new
      // output labels.
      new_output.append(new_labels);
      i += 2;
    }
  }

  return std::make_tuple(new_lhs, new_rhs, new_output);
}

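// Parses `equation` into EinsumDimensionNumbers, first flattening any ellipsis
// against the operand ranks. Returns llvm::None for equations this lowering
// cannot handle: a missing "->" or ',', unsupported or duplicate labels, or a
// label that appears only in the output.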
llvm::Optional<EinsumDimensionNumbers> GetEinsumDimensionNumbers(
    llvm::StringRef equation, RankedTensorType lhs_ty,
    RankedTensorType rhs_ty) {
  llvm::StringRef lhs_rhs;
  llvm::StringRef out;
  std::tie(lhs_rhs, out) = equation.split("->");
  if (lhs_rhs.empty() || out.empty()) return llvm::None;

  llvm::StringRef lhs;
  llvm::StringRef rhs;
  std::tie(lhs, rhs) = lhs_rhs.split(',');
  if (lhs.empty() || rhs.empty()) return llvm::None;

  // Try to flatten the "..." if possible.
  int lhs_named_label, rhs_named_label;
  auto available_labels =
      GetAvailableLabels(lhs, rhs, &lhs_named_label, &rhs_named_label);
  if (!available_labels.has_value()) return llvm::None;

  auto flattened_labels =
      FlattenEllipsis(lhs, lhs_named_label, rhs, rhs_named_label, out, lhs_ty,
                      rhs_ty, available_labels.getValue());

  lhs = std::get<0>(flattened_labels);
  rhs = std::get<1>(flattened_labels);
  out = std::get<2>(flattened_labels);

  auto lhs_map_or = EquationToMap(lhs);
  if (!lhs_map_or.has_value()) return llvm::None;
  auto lhs_map = lhs_map_or.getValue();

  auto rhs_map_or = EquationToMap(rhs);
  if (!rhs_map_or.has_value()) return llvm::None;
  auto rhs_map = rhs_map_or.getValue();

  auto out_map_or = EquationToMap(out);
  if (!out_map_or.has_value()) return llvm::None;
  auto out_map = out_map_or.getValue();

  EinsumDimensionNumbers dnums;
  for (int64_t i = 0, e = lhs.size(); i < e; ++i) {
    auto rhs_index = rhs_map.find(lhs[i]);
    auto out_index = out_map.find(lhs[i]);
    if (rhs_index == rhs_map.end() && out_index == out_map.end()) {
      dnums.lhs.emplace_back(i);
    } else if (rhs_index == rhs_map.end()) {
      dnums.lhs_out.emplace_back(i, out_index->second);
    } else if (out_index == out_map.end()) {
      dnums.lhs_rhs.emplace_back(i, rhs_index->second);
    } else {
      dnums.lhs_rhs_out.emplace_back(i, rhs_index->second, out_index->second);
    }
  }
  for (int64_t i = 0, e = rhs.size(); i < e; ++i) {
    auto lhs_index = lhs_map.find(rhs[i]);
    auto out_index = out_map.find(rhs[i]);
    if (lhs_index == lhs_map.end()) {
      if (out_index == out_map.end()) {
        dnums.rhs.emplace_back(i);
      } else {
        dnums.rhs_out.emplace_back(i, out_index->second);
      }
    }
  }
  for (int64_t i = 0, e = out.size(); i < e; ++i) {
    auto lhs_index = lhs_map.find(out[i]);
    auto rhs_index = rhs_map.find(out[i]);
    if (lhs_index == lhs_map.end() && rhs_index == rhs_map.end()) {
      // Labels that appear only in the output are not supported.
      return llvm::None;
    }
  }
  return dnums;
}

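// Applies the inverse of `permutation` to `input`, i.e.
// output[permutation[i]] = input[i]. For example,
// inverseTransposeVector({10, 20, 30}, {2, 0, 1}) returns {20, 30, 10}.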
std::vector<int64_t> inverseTransposeVector(
    llvm::ArrayRef<int64_t> input, llvm::ArrayRef<int32_t> permutation) {
  std::vector<int64_t> output(input.size());
  for (int64_t i = 0; i < input.size(); ++i) {
    output[permutation[i]] = input[i];
  }
  return output;
}

// Computes the transpositions required to convert dnums to one supported by
// tf.BatchMatmulV2 and returns the new set of dimension numbers with them.
// The transposed LHS shape will be B0,...,Bn,L0,...,Ln,C0,...,Cn and the
// transposed RHS shape will be B0,...,Bn,C0,...,Cn,R0,...,Rn respectively.
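// For example, for "ibj,jbk->bik" the LHS is transposed with permutation
// {1, 0, 2} (ibj -> bij) and the RHS with {1, 0, 2} (jbk -> bjk), so that a
// plain tf.BatchMatMulV2 then produces the "bik" layout.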
LogicalResult transposeForBatchMatmul(
    const Location& loc, EinsumDimensionNumbers& dnums, Value* lhs, Value* rhs,
    std::vector<int32_t>* out_inverse_transpose, PatternRewriter* rewriter) {
  std::vector<int32_t> lhs_transpose;
  std::vector<int32_t> rhs_transpose;
  std::vector<int32_t> out_transpose;
  lhs_transpose.reserve(dnums.lhs_rhs_out.size() + dnums.lhs_out.size() +
                        dnums.lhs_rhs.size());
  rhs_transpose.reserve(dnums.lhs_rhs_out.size() + dnums.rhs_out.size() +
                        dnums.lhs_rhs.size());
  out_transpose.reserve(dnums.lhs_rhs_out.size() + dnums.lhs_out.size() +
                        dnums.rhs_out.size());
  // Generate transpose matrix for B0,...,Bn
  for (int64_t i = 0, e = dnums.lhs_rhs_out.size(); i < e; ++i) {
    lhs_transpose.push_back(std::get<0>(dnums.lhs_rhs_out[i]));
    rhs_transpose.push_back(std::get<1>(dnums.lhs_rhs_out[i]));
    out_transpose.push_back(std::get<2>(dnums.lhs_rhs_out[i]));
    dnums.lhs_rhs_out[i] = std::make_tuple(i, i, i);
  }

  // Generate transpose matrix for L0,...,Ln
  for (int64_t i = 0, e = dnums.lhs_out.size(); i < e; ++i) {
    lhs_transpose.push_back(std::get<0>(dnums.lhs_out[i]));
    out_transpose.push_back(std::get<1>(dnums.lhs_out[i]));
    dnums.lhs_out[i] =
        std::make_tuple(lhs_transpose.size() - 1, out_transpose.size() - 1);
  }
  // Generate transpose matrix for C0,...,Cn
  for (int64_t i = 0, e = dnums.lhs_rhs.size(); i < e; ++i) {
    lhs_transpose.push_back(std::get<0>(dnums.lhs_rhs[i]));
    rhs_transpose.push_back(std::get<1>(dnums.lhs_rhs[i]));
    dnums.lhs_rhs[i] =
        std::make_tuple(lhs_transpose.size() - 1, rhs_transpose.size() - 1);
  }
  // Generate transpose matrix for R0,...,Rn
  for (int64_t i = 0, e = dnums.rhs_out.size(); i < e; ++i) {
    rhs_transpose.push_back(std::get<0>(dnums.rhs_out[i]));
    out_transpose.push_back(std::get<1>(dnums.rhs_out[i]));
    dnums.rhs_out[i] =
        std::make_tuple(rhs_transpose.size() - 1, out_transpose.size() - 1);
  }

  out_inverse_transpose->resize(out_transpose.size());
  for (int64_t i = 0, e = out_transpose.size(); i < e; ++i) {
    out_inverse_transpose->at(out_transpose[i]) = i;
  }

  *lhs = createTransposeOp(*lhs, loc, lhs_transpose, rewriter);
  *rhs = createTransposeOp(*rhs, loc, rhs_transpose, rewriter);
  return success();
}

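// Multiplies together the `shape` dimensions selected by element <I> of each
// tuple in `index_tuples`; returns -1 if any selected dimension is dynamic.
// For example, ProdShapeWithIndexInTuple<0>({2, 3, 4}, {{1, 0}, {2, 0}})
// returns 3 * 4 = 12.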
template <int I>
inline int64_t ProdShapeWithIndexInTuple(
    ArrayRef<int64_t> shape,
    const std::vector<std::tuple<int64_t, int64_t>>& index_tuples) {
  int64_t prod_shape = 1;
  for (auto index_tuple : index_tuples) {
    const int64_t shape_i = shape[std::get<I>(index_tuple)];
    if (shape_i == -1) return -1;
    prod_shape *= shape_i;
  }
  return prod_shape;
}

// Reshapes LHS and RHS to have B0,...,Bn,L,C and B0,...,Bn,C,R shape
// respectively while assuming that the initial shape for them is
// B0,...,Bn,L0,...,Ln,C0,...,Cn and B0,...,Bn,C0,...,Cn,R0,...,Rn respectively.
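// For example, a static LHS of shape [B, L0, L1, C0, C1] is reshaped to
// [B, L0*L1, C0*C1]. When the flattened sizes are dynamic, the target shape is
// instead computed at runtime via createReshapeOpForDynamic.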
LogicalResult reshapeForBatchMatmul(const Location& loc,
                                    EinsumDimensionNumbers& dnums, Value* lhs,
                                    Value* rhs,
                                    SmallVectorImpl<int64_t>* out_shape,
                                    PatternRewriter* rewriter) {
  RankedTensorType lhs_type = lhs->getType().cast<RankedTensorType>();
  RankedTensorType rhs_type = rhs->getType().cast<RankedTensorType>();

  int32_t num_lhs_reshape_segids = 0;
  int32_t num_rhs_reshape_segids = 0;
  std::vector<int32_t> lhs_reshape_segids;
  int lhs_rank =
      dnums.lhs_rhs_out.size() + dnums.lhs_out.size() + dnums.lhs_rhs.size();
  lhs_reshape_segids.resize(lhs_rank);
  std::vector<int32_t> rhs_reshape_segids;
  int rhs_rank =
      dnums.lhs_rhs_out.size() + dnums.rhs_out.size() + dnums.lhs_rhs.size();
  rhs_reshape_segids.resize(rhs_rank);

  // Labels that exist in lhs, rhs and the output are the batch labels
  // B0,...,Bn.
  std::vector<int64_t> lhs_shape;
  std::vector<int64_t> rhs_shape;
  lhs_shape.reserve(dnums.lhs_rhs_out.size() + dnums.lhs_out.size() + 1);
  rhs_shape.reserve(dnums.lhs_rhs_out.size() + 2);
  for (auto i : dnums.lhs_rhs_out) {
    const int64_t b1 = lhs_type.getShape()[std::get<0>(i)];
    lhs_shape.push_back(b1);
    const int64_t b2 = rhs_type.getShape()[std::get<1>(i)];
    rhs_shape.push_back(b2);

    lhs_reshape_segids.at(std::get<0>(i)) = num_lhs_reshape_segids++;
    rhs_reshape_segids.at(std::get<1>(i)) = num_rhs_reshape_segids++;
  }
  if (!OpTrait::util::getBroadcastedShape(lhs_shape, rhs_shape, *out_shape)) {
    return failure();
  }

  // Calculates the dimension for the label L from L0,...,Ln in lhs.
  if (dnums.lhs_out.empty()) {
    lhs_shape.push_back(1);
    out_shape->push_back(1);
    dnums.lhs_out.emplace_back(lhs_shape.size() - 1, out_shape->size() - 1);
    ++num_lhs_reshape_segids;
  } else if (dnums.lhs_rhs_out.empty()) {
    // If there are no batch labels B0,...,Bn, it is safe to use L0,...,Ln as
    // the batch labels in lhs; the rhs will be broadcast.
    for (auto i : dnums.lhs_out) {
      const int64_t b = lhs_type.getShape()[std::get<0>(i)];
      lhs_shape.push_back(b);
      out_shape->push_back(b);

      lhs_reshape_segids.at(std::get<0>(i)) = num_lhs_reshape_segids++;
    }
  } else {
    const int64_t lhs_out_size =
        ProdShapeWithIndexInTuple<0>(lhs_type.getShape(), dnums.lhs_out);
    lhs_shape.push_back(lhs_out_size);
    out_shape->push_back(lhs_out_size);

    for (auto i : dnums.lhs_out) {
      lhs_reshape_segids.at(std::get<0>(i)) = num_lhs_reshape_segids;
    }
    ++num_lhs_reshape_segids;
  }

  // Calculates the dimension for the common label C from labels C0,...,Cn that
  // exist in both lhs and rhs.
  const int64_t lhs_size =
      ProdShapeWithIndexInTuple<0>(lhs_type.getShape(), dnums.lhs_rhs);
  const int64_t rhs_size =
      ProdShapeWithIndexInTuple<1>(rhs_type.getShape(), dnums.lhs_rhs);
  lhs_shape.push_back(lhs_size);
  rhs_shape.push_back(rhs_size);

  for (auto i : dnums.lhs_rhs) {
    lhs_reshape_segids.at(std::get<0>(i)) = num_lhs_reshape_segids;
    rhs_reshape_segids.at(std::get<1>(i)) = num_rhs_reshape_segids;
  }
  ++num_lhs_reshape_segids;
  ++num_rhs_reshape_segids;

  // Calculates the dimension for the label R from R0,...,Rn in rhs.
  const int64_t rhs_out_size =
      ProdShapeWithIndexInTuple<0>(rhs_type.getShape(), dnums.rhs_out);
  rhs_shape.push_back(rhs_out_size);
  out_shape->push_back(rhs_out_size);

  for (auto i : dnums.rhs_out) {
    rhs_reshape_segids.at(std::get<0>(i)) = num_rhs_reshape_segids;
  }
  ++num_rhs_reshape_segids;

  // If the LHS requires reshapes.
  if (lhs_rank != num_lhs_reshape_segids) {
    if (succeeded(VerifyShapeOfReshapeOp(lhs_shape))) {
      *lhs = createReshapeOp(*lhs, lhs_shape, lhs_type.getElementType(), loc,
                             rewriter);
    } else {
      // Check if the LHS shape can be calculated with UnsortedSegmentProd.
      // This requires at least one index in both lhs_out and lhs_rhs.
      if (dnums.lhs_out.empty() || dnums.lhs_rhs.empty()) return failure();
      *lhs = createReshapeOpForDynamic(*lhs, lhs_shape, lhs_reshape_segids,
                                       num_lhs_reshape_segids, loc, rewriter);
    }
  }
  // If the RHS requires reshapes.
  if (rhs_rank != num_rhs_reshape_segids) {
    if (succeeded(VerifyShapeOfReshapeOp(rhs_shape))) {
      *rhs = createReshapeOp(*rhs, rhs_shape, rhs_type.getElementType(), loc,
                             rewriter);
    } else {
      // Check if the RHS shape can be calculated with UnsortedSegmentProd.
      // This requires at least one index in both rhs_out and lhs_rhs.
      if (dnums.rhs_out.empty() || dnums.lhs_rhs.empty()) return failure();
      *rhs = createReshapeOpForDynamic(*rhs, rhs_shape, rhs_reshape_segids,
                                       num_rhs_reshape_segids, loc, rewriter);
    }
  }

  dnums.lhs_rhs.assign(
      {std::make_tuple(dnums.lhs_rhs_out.size() + dnums.lhs_out.size(),
                       dnums.lhs_rhs_out.size())});
  dnums.rhs_out.assign(
      {std::make_tuple(dnums.lhs_rhs_out.size() + dnums.lhs_out.size(),
                       dnums.lhs_rhs_out.size() + dnums.lhs_out.size())});
  return success();
}

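// Rewrites a tf.Einsum with the given dimension numbers into tf.Transpose and
// tf.Reshape ops around a single tf.BatchMatMulV2, then restores the requested
// output layout. Fails if any label appears in only one of the operands (i.e.
// dnums.lhs or dnums.rhs is non-empty).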
LogicalResult rewriteToBatchMatmul(TF::EinsumOp op,
                                   EinsumDimensionNumbers dnums,
                                   PatternRewriter& rewriter) {
  if (!dnums.lhs.empty() || !dnums.rhs.empty()) return failure();

  auto inputs = op.inputs();
  if (inputs.size() != 2) return failure();
  Value lhs = inputs.front();
  Value rhs = inputs.back();
  // Back up the original values for the later output shape calculation in
  // `createOutputReshapeOpForDynamic`.
  Value original_lhs = lhs;
  Value original_rhs = rhs;
  EinsumDimensionNumbers original_dnums = dnums;

  RankedTensorType original_type =
      op.getResult().getType().dyn_cast_or_null<RankedTensorType>();
  if (!original_type) return failure();

  std::vector<int32_t> out_transpose;
  if (failed(transposeForBatchMatmul(op.getLoc(), dnums, &lhs, &rhs,
                                     &out_transpose, &rewriter)))
    return failure();

  llvm::SmallVector<int64_t, 4> matmul_shape;
  if (failed(reshapeForBatchMatmul(op.getLoc(), dnums, &lhs, &rhs,
                                   &matmul_shape, &rewriter)))
    return failure();

  std::vector<int64_t> reshape_shape =
      inverseTransposeVector(original_type.getShape(), out_transpose);

  auto matmul_type =
      RankedTensorType::get(matmul_shape, original_type.getElementType());
  Value out = rewriter.create<TF::BatchMatMulV2Op>(
      op.getLoc(), matmul_type, lhs, rhs, rewriter.getBoolAttr(false),
      rewriter.getBoolAttr(false));

  bool out_reshape_need = (reshape_shape.size() != matmul_shape.size() ||
                           original_type.getRank() != matmul_shape.size());
  // Always add a reshape for concrete output shapes.
  if (succeeded(VerifyShapeOfReshapeOp(reshape_shape))) {
    out = createReshapeOp(out, reshape_shape, original_type.getElementType(),
                          op.getLoc(), &rewriter);
  } else if (out_reshape_need) {
    out = createOutputReshapeOpForDynamic(out, reshape_shape, original_lhs,
                                          original_rhs, original_dnums,
                                          op.getLoc(), &rewriter);
  }
  out = createTransposeOp(out, op.getLoc(), out_transpose, &rewriter);

  rewriter.replaceOp(op, out);
  return success();
}

// Transforms Einsum to other TF ops for the supported variants.
struct TransformEinsumPass
    : public TransformEinsumPassBase<TransformEinsumPass> {
  void runOnOperation() override;
};

void TransformEinsumPass::runOnOperation() {
  RewritePatternSet patterns(&getContext());
  auto func = getOperation();

  patterns.add<ConvertTFEinsumOp>(&getContext());
  (void)applyPatternsAndFoldGreedily(func, std::move(patterns));
}

}  // namespace

LogicalResult ConvertTFEinsumOp::matchAndRewrite(
    TF::EinsumOp op, PatternRewriter& rewriter) const {
  RankedTensorType lhs =
      op.getOperand(0).getType().dyn_cast_or_null<RankedTensorType>();
  RankedTensorType rhs =
      op.getOperand(1).getType().dyn_cast_or_null<RankedTensorType>();
  if (!lhs || !rhs) {
    return failure();
  }

  // TODO(b/162328998) Better support Einsum with dynamic input. Currently, one
  // dynamic dimension is always supported. If there are two or more dynamic
  // dimensions, they are supported only if they all exist in a single
  // component among: L0,...,Ln, R0,...,Rn or C0,...,Cn.
  if (const auto dnums_or = GetEinsumDimensionNumbers(op.equation(), lhs, rhs))
    return rewriteToBatchMatmul(op, dnums_or.getValue(), rewriter);
  return rewriter.notifyMatchFailure(op, "unsupported einsum lowering");
}

std::unique_ptr<OperationPass<func::FuncOp>> CreateTransformEinsumPass() {
  return std::make_unique<TransformEinsumPass>();
}

}  // namespace TF
}  // namespace mlir