#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>
#include <ATen/native/Resize.h>
#include <ATen/native/xnnpack/Engine.h>
#include <ATen/WrapDimUtilsMulti.h>
#include <ATen/TensorOperators.h>
#include <c10/util/irange.h>
#include <c10/core/SymInt.h>
#include <c10/util/MaybeOwned.h>
#include <c10/util/SmallVector.h>
#include <ATen/TensorSubclassLikeUtils.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/_trilinear.h>
#include <ATen/ops/_trilinear_native.h>
#include <ATen/ops/add.h>
#include <ATen/ops/addmm.h>
#include <ATen/ops/bilinear_native.h>
#include <ATen/ops/bmm.h>
#include <ATen/ops/einsum_native.h>
#include <ATen/ops/linear_native.h>
#include <ATen/ops/matmul.h>
#include <ATen/ops/mkldnn_linear.h>
#include <ATen/ops/mm.h>
#include <ATen/ops/mul.h>
#include <ATen/ops/tensordot_native.h>
#include <ATen/ops/zeros.h>
#endif

#include <cctype>
#include <cstdlib>  // std::getenv
#include <cstring>  // strcmp
#include <deque>
#include <numeric>  // std::iota
#include <string>
#include <utility>
#include <vector>

namespace at::native {

// Parse environment variable "TORCH_LINEAR_FLATTEN_3D"
static inline bool parseLinearFlatten3d() {
  // Uninitialized value
  static int value = -1;
  if (value == -1) {
    const char* env_str = std::getenv("TORCH_LINEAR_FLATTEN_3D");
    if (env_str != nullptr && strcmp(env_str, "1") == 0) {
      value = 1;
    } else {
      value = 0;
    }
  }
  return bool(value);
}
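// Example (illustrative): because `value` is cached, the environment variable
// must be set before the first call, e.g.
//   setenv("TORCH_LINEAR_FLATTEN_3D", "1", /*overwrite=*/1);  // POSIX
//   parseLinearFlatten3d();  // returns true; later env changes are ignored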

// `_flatten_nd_linear` flattens all but the last dimension of the input tensor
// before passing it to the linear operation
static inline Tensor _flatten_nd_linear(const Tensor& input, const Tensor& weight, const Tensor& bias) {
  const auto input_sizes = input.sym_sizes();
  // can't use -1 in reshape because it errors when a dimension is 0
  c10::SymInt flattened_dim = 1;
  for (int64_t i = 0, ndim = input_sizes.size(); i < ndim - 1; ++i) {
    flattened_dim = flattened_dim * input_sizes[i];
  }
  auto inp_reshape = input.reshape_symint({flattened_dim, input_sizes.at(input_sizes.size() - 1)});
  const auto result = at::addmm(bias, inp_reshape, weight.t());
  auto new_size = input_sizes.slice(0, input_sizes.size() - 1);
  c10::SymDimVector sizes_vec(new_size.begin(), new_size.end());
  sizes_vec.push_back(result.sym_size(1));
  return result.view_symint(sizes_vec);
}
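// Shape sketch (illustrative): for input [B, T, C], weight [F, C], bias [F]:
//   inp_reshape: [B*T, C]
//   addmm(bias, inp_reshape, weight.t()): [B*T, F]
//   final view: [B, T, F]
// so a single GEMM handles all leading dimensions at once.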


Tensor linear(const Tensor& input, const Tensor& weight, const std::optional<Tensor>& bias_opt) {
  // _matmul_impl checks this again later, but _flatten_nd_linear does not work on scalar inputs,
  // so let's try to catch this here already
  const auto input_dim = input.dim();
  const auto weight_dim = weight.dim();
  TORCH_CHECK(input_dim != 0 && weight_dim != 0,
              "both arguments to linear need to be at least 1D, but they are ",
              input_dim, "D and ", weight_dim, "D");

  // See [Note: hacky wrapper removal for optional tensor]
  auto bias = bias_opt.has_value()
      ? c10::MaybeOwned<Tensor>::borrowed(*bias_opt)
      : c10::MaybeOwned<Tensor>::owned(std::in_place);
  if (input.is_mkldnn()) {
    return at::mkldnn_linear(input, weight, *bias);
  }
#if defined(C10_MOBILE)
  if (xnnpack::use_linear(input, weight, *bias)) {
    return xnnpack::linear(input, weight, *bias);
  }
#endif
  if (input_dim == 2 && bias->defined()) {
    // Fused op is marginally faster.
    return at::addmm(*bias, input, weight.t());
  }
  if (bias->defined() && !input.is_xla()) {
    // Also hit the fused path for contiguous 3D input, if not using xla
    // backend. Reshaping/flattening has some performance implications on xla.
    if (input.is_contiguous() && input_dim == 3) {
      return _flatten_nd_linear(input, weight, *bias);
    } else if (input.is_contiguous() && input.layout() == c10::kStrided && weight.layout() == c10::kStrided && bias->dim() == 1) {
      return _flatten_nd_linear(input, weight, *bias);
    } else if (parseLinearFlatten3d() && input_dim == 3) {
      // If user forces flattening via env var
      const Tensor input_cont = input.contiguous();
      return _flatten_nd_linear(input_cont, weight, *bias);
    }
  }
  auto output = at::matmul(input, weight.t());
  if (bias->defined()) {
    // for composite compliance use out-of-place version of `add`
    if (isTensorSubclassLike(*bias) ||
        bias->_fw_grad(/*level*/ 0).defined()) {
      output = at::add(output, *bias);
    } else {
      output.add_(*bias);
    }
  }
  return output;
}
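// Usage sketch (illustrative, shapes are assumptions):
//   auto x = at::randn({8, 3, 64});  // [*, in_features]
//   auto w = at::randn({128, 64});   // [out_features, in_features]
//   auto b = at::randn({128});
//   auto y = at::linear(x, w, b);    // [8, 3, 128], same as matmul(x, w.t()) + b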

Tensor& linear_out(const Tensor& input, const Tensor& weight, const std::optional<Tensor>& bias_opt, Tensor& output) {
  TORCH_CHECK(!input.is_mkldnn(), "linear doesn't support out for MKLDNN tensors");
  // See [Note: hacky wrapper removal for optional tensor]
  auto bias = bias_opt.has_value()
      ? c10::MaybeOwned<Tensor>::borrowed(*bias_opt)
      : c10::MaybeOwned<Tensor>::owned(std::in_place);

  if (input.dim() == 2 && bias->defined()) {
    // Fused op is marginally faster.
    return at::addmm_out(output, *bias, input, weight.t());
  }
  output = at::matmul_out(output, input, weight.t());
  if (bias->defined()) {
    output.add_(*bias);
  }
  return output;
}
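// Usage sketch (illustrative): the out variant writes into a preallocated
// tensor, resizing it if needed via the addmm_out/matmul_out paths:
//   auto out = at::empty({0}, x.options());
//   at::linear_out(out, x, w, b);  // same result as at::linear(x, w, b)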

// sumproduct_pair computes `(left*right).sum(sumdims)` by means of permutation and
// batch matrix multiplication
// its main purpose is to provide a pairwise reduction for einsum
static Tensor sumproduct_pair(const Tensor& left_, const Tensor& right_, IntArrayRef sum_dims_, bool keepdim) {
  // assumes that tensors have been pre-unsqueezed (so that all dimensions match - after broadcasting)
  // but makes no other assumptions on the order of dimensions
  TORCH_CHECK(left_.dim() == right_.dim(), "number of dimensions must match");
  if (sum_dims_.empty())
    return at::mul(left_, right_);
  int64_t dim = left_.dim();
  auto sum_dims = at::dim_list_to_bitset(sum_dims_, dim);
  // dimensions that will be part of the output (i.e. not summed over) in three vectors:
  // dims in lro appear in left, right and output, similarly, lo: left and output, ro: right and output
  // also the sizes are kept track of for reshaping
  std::vector<int64_t> lro, lo, ro;
  SymInt lro_size = 1, lo_size = 1, ro_size = 1, sum_size = 1;
  Tensor left = left_;
  Tensor right = right_;
  for (const auto i : c10::irange(dim)) {
    auto sl = left.sym_size(i) != 1;
    auto sr = right.sym_size(i) != 1;
    if (sum_dims[i]) { // first dimensions that will be summed over after multiplication
      if (sl && sr) { // dimensions nontrivially in both left and right must be of the same size
        TORCH_CHECK(left.sym_size(i) == right.sym_size(i), "non-broadcast dimensions must match");
        sum_size *= left.sym_size(i);
      } else if (sl) { // if it is only in one of left and right, we can sum right away
        left = left.sum(i, true);
      } else if (sr) {
        right = right.sum(i, true);
      }
    } else if (sl && sr) { // now deal with dimensions that will be in the output
      // dimensions nontrivially in both left and right must be of the same size
      TORCH_CHECK(left.sym_size(i) == right.sym_size(i), "non-broadcast dimensions must match");
      lro.push_back(i);
      lro_size *= left.sym_size(i);
    } else if (sl) { // keep track of dimensions appearing only once
      lo.push_back(i);
      lo_size *= left.sym_size(i);
    } else {
      ro.push_back(i);
      ro_size *= right.sym_size(i);
    }
  }
  // we now work with the following permutations / shapes.
  // the pipeline is permute inputs -> reshape inputs -> batch matrix mul -> reshape(view) output -> permute output
  // output: "lro, lo, 1-for-summed-dims, ro" with original shape dimensions
  // left: "lro, lo, summed" permuted with lpermutation and the three flattened
  // right: "lro, summed, ro" permuted with rpermutation and the three flattened
  // then the permuted output is a view of bmm(left, right)
  // finally, opermutation reverts the permutation to the original order of dimensions
  auto out_num_dim = lro.size() + lo.size() + sum_dims_.size() + ro.size();
  std::vector<SymInt> out_size;
  out_size.reserve(out_num_dim);
  for (auto& d : lro) out_size.push_back(left.sym_size(d));
  for (auto& d : lo) out_size.push_back(left.sym_size(d));
  for (auto& d : sum_dims_) { out_size.emplace_back(1); (void)d; } // avoid warning about not using d
  for (auto& d : ro) out_size.push_back(right.sym_size(d));

  std::vector<int64_t> lpermutation(lro);
  lpermutation.insert(lpermutation.end(), lo.begin(), lo.end());
  lpermutation.insert(lpermutation.end(), sum_dims_.begin(), sum_dims_.end());
  lpermutation.insert(lpermutation.end(), ro.begin(), ro.end());

  std::vector<int64_t> rpermutation(lro);
  rpermutation.insert(rpermutation.end(), sum_dims_.begin(), sum_dims_.end());
  rpermutation.insert(rpermutation.end(), ro.begin(), ro.end());
  rpermutation.insert(rpermutation.end(), lo.begin(), lo.end());

  std::vector<int64_t> opermutation(out_num_dim, -1);
  {
    int64_t i = 0;

    for (auto it = lro.cbegin(); it != lro.cend(); i++, it++) {
      opermutation[*it] = i;
    }
    for (auto it = lo.cbegin(); it != lo.cend(); i++, it++) {
      opermutation[*it] = i;
    }
    for (auto it = sum_dims_.cbegin(); it != sum_dims_.cend(); i++, it++) {
      opermutation[*it] = i;
    }
    for (auto it = ro.cbegin(); it != ro.cend(); i++, it++) {
      opermutation[*it] = i;
    }
  }

  // now we can execute the operations above
  left = left.permute(lpermutation).reshape_symint({lro_size, std::move(lo_size), sum_size});
  right = right.permute(rpermutation).reshape_symint({std::move(lro_size), std::move(sum_size), std::move(ro_size)});
  Tensor result = at::bmm(left, right);
  result = result.view_symint(out_size).permute(opermutation);

  // finally squeeze summed dimensions if desired
  if (!keepdim) {
    auto sizes = result.sizes().vec();
    for (auto i = dim - 1; i >= 0; i--) {
      if (sum_dims[i]) {
        sizes.erase(sizes.begin() + i);
      }
    }
    result = result.view(sizes);
  }
  return result;
}
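// Worked example (illustrative): for the batched-matmul-like reduction in
// einsum("bij,bjk->bik"), the aligned inputs are left [b, i, 1, j] and
// right [b, 1, k, j] with sum_dims_ = {3}. Then lro = {0}, lo = {1}, ro = {2};
// left reshapes to [b, i, j], right to [b, j, k], bmm yields [b, i, k], and the
// result is viewed and permuted back to [b, i, k, 1] (keepdim).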

// There are roughly three parts to computing einsum:
// 1. Parse equation to extract the labels for each input operand and output
// 2. Unsqueeze missing dimensions from input operands and permute to align them
// 3. Compute result by multiplying input operands and summing contraction
//    dimensions. We do the last part by reducing to bmm.
// If a path is specified, we reduce in the order specified by the path, else we
// default to going left => right. The path is a list of indices processed the same
// way as opt-einsum: https://optimized-einsum.readthedocs.io/en/stable/path_finding.html#format-of-the-path
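// For example (illustrative), with three operands a path of {0, 1, 0, 1} first
// contracts operands 0 and 1 and appends the intermediate to the operand list,
// then contracts the remaining operand with that intermediate.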
Tensor einsum(c10::string_view equation, TensorList operands, at::OptionalIntArrayRef path) {
  TORCH_CHECK(!operands.empty(), "einsum(): must provide at least one operand");
  const auto num_ops = operands.size();

  if (path.has_value()) {
    const auto path_size = num_ops == 1 ? 1 : (num_ops - 1) * 2;
    TORCH_CHECK(
        path->size() == path_size,
        "einsum(): expected contraction path given in path parameter to have size ",
        path_size,
        " but got ",
        path->size());
  }

  // Labels must be in range [A-Za-z]
  constexpr uint8_t NUM_OF_LETTERS = 'z' - 'a' + 1;
  constexpr uint8_t TOTAL_LABELS = NUM_OF_LETTERS * 2;

  // Code used to identify ELLIPSIS ("...")
  constexpr uint8_t ELLIPSIS = TOTAL_LABELS;

  // Convert label in [A-Za-z] to subscript in [0, TOTAL_LABELS)
  auto label_to_subscript = [=](unsigned char label) -> uint8_t {
    return std::isupper(label) ? label - 'A' : label - 'a' + NUM_OF_LETTERS;
  };

#ifndef STRIP_ERROR_MESSAGES
  // Convert subscript in [0, TOTAL_LABELS) to label in [A-Za-z]
  auto subscript_to_label = [=](uint8_t s) -> unsigned char {
    return s < NUM_OF_LETTERS ? s + 'A' : s + 'a' - NUM_OF_LETTERS;
  };
#endif

  // Find arrow (->) to split equation into lhs and rhs
  const auto arrow_pos = equation.find("->");
  const auto lhs = equation.substr(0, arrow_pos);

  // Convert labels for input operands into an index in [0, 52) and store
  // them in op_labels for each operand along with ELLIPSIS if present.
  std::vector<std::vector<uint8_t>> op_labels(num_ops);
  bool ell_in_input = false;
  std::size_t curr_op = 0;
  for (std::size_t i = 0; i < lhs.length(); ++i) {
    const unsigned char label = lhs[i];
    switch (label) {
      case ' ':
        // Ignore spaces
        break;

      case '.':
        TORCH_CHECK(
            // Only one ellipsis per operand can be given
            !ell_in_input,
            "einsum(): found \'.\' for operand ",
            curr_op,
            " for which an ellipsis was already found");
        TORCH_CHECK(
            // Ensure it's a valid ellipsis
            i + 2 < lhs.length() && lhs[++i] == '.' && lhs[++i] == '.',
            "einsum(): found \'.\' for operand ",
            curr_op,
            " that is not part of any ellipsis");
        op_labels[curr_op].push_back(ELLIPSIS);
        ell_in_input = true;
        break;

      case ',':
        // Move onto next operand
        ++curr_op;
        TORCH_CHECK(
            curr_op < num_ops,
            "einsum(): fewer operands were provided than specified in the equation");
        ell_in_input = false;
        break;

      default:
        // Parse label
        TORCH_CHECK(
            std::isalpha(label),
            "einsum(): invalid subscript given at index ",
            i,
            " in the equation string, subscripts must be in [a-zA-Z]");
        op_labels[curr_op].push_back(label_to_subscript(label));
    }
  }

  TORCH_CHECK(
      curr_op == num_ops - 1,
      "einsum(): more operands were provided than specified in the equation");

  std::vector<int64_t> label_count(TOTAL_LABELS, 0);

  // The maximum number of dimensions covered by any ellipsis, needed when
  // unsqueezing missing dimensions from operands to permute and broadcast
  int64_t ell_num_dim = 0;

  // Compute label frequency and number of dimensions covered by ellipsis
  // We do this after parsing labels to make it more readable and simpler
  // to compute the number of dimensions covered by ellipsis.
  for (const auto i : c10::irange(num_ops)) {
    const auto& operand = operands[i];
    const auto labels = op_labels[i];
    const auto ndims = operand.dim();
    int64_t nlabels = static_cast<int64_t>(labels.size());
    bool has_ellipsis = false;

    for (const auto& label : labels) {
      if (label == ELLIPSIS) {
        --nlabels;
        has_ellipsis = true;
        ell_num_dim = std::max(ell_num_dim, ndims - nlabels);
      } else {
        ++label_count[label];
      }
    }

    TORCH_CHECK(
        has_ellipsis ? nlabels <= ndims : nlabels == ndims,
        "einsum(): the number of subscripts in the equation (",
        nlabels,
        has_ellipsis ? ") is more than the number of dimensions ("
                     : ") does not match the number of dimensions (",
        ndims,
        ") for operand ",
        i,
        has_ellipsis ? "" : " and no ellipsis was given");
  }

  // We want to align the dimensions of every input tensor to have
  // shape out_dims + sum_dims. For this, we create a mapping of label
  // to index into the permuted shape.
  std::vector<int64_t> label_perm_index(TOTAL_LABELS, -1);

  // Current index in the permuted shape
  int64_t perm_index = 0;

  // Start index of ellipsis dimensions in the permuted shape
  int64_t ell_index = 0;
  bool ell_in_output = false;

  if (arrow_pos == std::string::npos) {
    // Implicit output is ellipsis (...) + labels seen only once
    perm_index = ell_num_dim;
    // ell_in_output is used to stop us from reducing ellipses dims later
    ell_in_output = true;
    for (const auto label : c10::irange(TOTAL_LABELS)) {
      if (label_count[label] == 1) {
        label_perm_index[label] = perm_index++;
      }
    }
  } else {
    // Parse explicit output
    const auto rhs = equation.substr(arrow_pos + 2);
    for (std::size_t i = 0; i < rhs.length(); ++i) {
      const unsigned char label = rhs[i];
      switch (label) {
        case ' ':
          // Ignore spaces
          break;

        case '.':
          TORCH_CHECK(
              // There can only be one ellipsis in the output
              !ell_in_output,
              "einsum(): found \'.\' for output but an ellipsis (...) was already found");
          TORCH_CHECK(
              // Ensure ellipsis is correct
              i + 2 < rhs.length() && rhs[++i] == '.' && rhs[++i] == '.',
              "einsum(): found \'.\' for output that is not part of any ellipsis (...)");
          ell_index = perm_index;
          perm_index += ell_num_dim;
          ell_in_output = true;
          break;

        default:
          TORCH_CHECK(
              std::isalpha(label),
              "einsum(): invalid subscript given at index ",
              lhs.size() + 2 + i,
              " in the equation string, subscripts must be in [a-zA-Z]");
          const auto index = label_to_subscript(label);
          TORCH_CHECK(
              // Ensure label appeared at least once for some input operand and at
              // most once for the output
              label_count[index] > 0 && label_perm_index[index] == -1,
              "einsum(): output subscript ",
              label,
              label_perm_index[index] > -1
                  ? " appears more than once in the output"
                  : " does not appear in the equation for any input operand");
          label_perm_index[index] = perm_index++;
      }
    }
  }

  // Save number of dimensions in output before adding contraction dims (dims to sum out)
  const int64_t out_num_dim = perm_index;

  // If ellipsis is not part of the output, add to contraction dimensions
  if (!ell_in_output) {
    ell_index = perm_index;
    perm_index += ell_num_dim;
  }

  // Add contraction labels (labels not present in output)
  for (const auto label : c10::irange(TOTAL_LABELS)) {
    if (label_count[label] > 0 && label_perm_index[label] == -1) {
      label_perm_index[label] = perm_index++;
    }
  }

  // Next: we check the sizes, take diagonals for repeated labels, unsqueeze
  // missing dimensions so all operands have the same dimensions and permute
  // the operands to align the dimensions following the indices computed above.
  // We also count how many operands have dimension with size != 1 for each
  // label used to identify which dimensions can be contracted.
  std::vector<SymInt> label_size(TOTAL_LABELS, 1);
  std::vector<SymInt> ell_sizes(ell_num_dim, 1);
  std::vector<uint64_t> dim_counts(perm_index, 0);
  std::deque<Tensor> ops;
  for (const auto i : c10::irange(num_ops)) {
    auto op = operands[i];
    std::vector<int64_t> permutation(perm_index, -1);
    std::int64_t dim = 0;
    for (const auto s : op_labels[i]) {
      if (s == ELLIPSIS) {
        // Iterate over each dimension covered by ellipsis
        const auto ndim = operands[i].ndimension() - (static_cast<int64_t>(op_labels[i].size()) - 1);
        for (auto j = ell_num_dim - ndim; j < ell_num_dim; ++j) {
          if (op.sym_size(dim) != 1) {
            // Update ellipsis size
            TORCH_CHECK(
                ell_sizes[j] == 1 || ell_sizes[j] == op.sym_size(dim),
                "einsum(): dimension ",
                dim,
                " covered by ellipsis in operand ",
                i,
                " has size ",
                op.sym_size(dim),
                " which does not broadcast with previously seen ellipsis with size ",
                ell_sizes[j],
                " for the respective dimension");
            ell_sizes[j] = op.sym_size(dim);
            ++dim_counts[ell_index + j];
          }
          permutation[ell_index + j] = dim++;
        }
      } else if (permutation[label_perm_index[s]] == -1) {
        if (op.sym_size(dim) != 1) {
          // Update subscript
          TORCH_CHECK(
              label_size[s] == 1 || label_size[s] == op.sym_size(dim),
              "einsum(): subscript ",
              subscript_to_label(s),
              " has size ",
              op.sym_size(dim),
              " for operand ",
              i,
              " which does not broadcast with previously seen size ",
              label_size[s]);
          label_size[s] = op.sym_size(dim);
          ++dim_counts[label_perm_index[s]];
        }
        permutation[label_perm_index[s]] = dim++;
      } else {
        // Repeated label, take diagonal
        const auto prev_dim = permutation[label_perm_index[s]];
        TORCH_CHECK(
            op.sym_size(dim) == op.sym_size(prev_dim),
            "einsum(): subscript ",
            subscript_to_label(s),
            " is repeated for operand ",
            i,
            " but the sizes don't match, ",
            op.sym_size(dim),
            " != ",
            op.sym_size(prev_dim));
        op = op.diagonal(0, prev_dim, dim).movedim(-1, prev_dim);
      }
    }

    // Add dimensions for missing labels
    for (auto& val : permutation) {
      if (val == -1) {
        op = op.unsqueeze(dim);
        val = dim++;
      }
    }
    ops.emplace_back(op.permute(permutation));
  }

  const auto contract_path = path.value_or(std::vector<int64_t>{});
  auto it = contract_path.begin();

  // Contract
  while (ops.size() > 1) {
    int64_t i = 0;
    int64_t j = 1;

    if (path.has_value()) {
      i = *it++;
      j = *it++;
      if (j < i) {
        std::swap(i, j);
      }

      TORCH_CHECK(
          i != j && i >= 0 && j < static_cast<int64_t>(ops.size()),
          "einsum(): invalid contraction (",
          i,
          ", ",
          j,
          i == j ? ") cannot contract an operand with itself"
                 : ") operand index is out of bounds");
    }

    auto a = ops[i];
    auto b = ops[j];
    ops.erase(ops.begin() + j);
    ops.erase(ops.begin() + i);

    // Collect dimensions that can be summed now
    std::vector<int64_t> sum_dims;
    SmallVector<int64_t, 5> a_dims_to_sum;
    SmallVector<int64_t, 5> b_dims_to_sum;
    for (auto dim = out_num_dim; dim < perm_index; ++dim) {
      if (a.sym_size(dim) != 1 && b.sym_size(dim) != 1) {
        if (--dim_counts[dim] == 1) {
          sum_dims.push_back(dim);
          dim_counts[dim] = 0;
        }
      } else if (dim_counts[dim] == 1) {
        if (a.sym_size(dim) != 1) {
          a_dims_to_sum.push_back(dim);
          dim_counts[dim] = 0;
        } else if (b.sym_size(dim) != 1) {
          b_dims_to_sum.push_back(dim);
          dim_counts[dim] = 0;
        }
      }
    }

    // Sum multiple dims at a time to minimize the number of kernel calls to sum
    if (!a_dims_to_sum.empty()) {
      a = a.sum(a_dims_to_sum, true);
    }
    if (!b_dims_to_sum.empty()) {
      b = b.sum(b_dims_to_sum, true);
    }

    if (path.has_value()) {
      ops.emplace_back(sumproduct_pair(a, b, sum_dims, true));
    } else {
      ops.emplace_front(sumproduct_pair(a, b, sum_dims, true));
    }
  }

  // Sum out contraction dims
  if (perm_index - out_num_dim > 0) {
    // if there were ops to contract, we would have already done so
    // in the previous loop and all the dims to sum are now 1
    // NB: use view instead of squeeze (or sum) for faster (mps) performance
    if (num_ops > 1) {
      auto sizes = ops[0].sym_sizes().vec();
      for (auto dim = perm_index - 1; dim >= out_num_dim; --dim) {
        sizes.erase(sizes.begin() + dim);
      }
      return ops[0].view_symint(sizes);
    } else {
      std::vector<int64_t> sum_dims(perm_index - out_num_dim);
      std::iota(sum_dims.begin(), sum_dims.end(), out_num_dim);
      return ops[0].sum(sum_dims);
    }
  }

  return ops[0];
}
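// Usage sketch (illustrative):
//   auto A = at::randn({3, 4});
//   auto B = at::randn({4, 5});
//   auto C = at::einsum("ij,jk->ik", {A, B});        // matrix product, shape [3, 5]
//   auto t = at::einsum("ii", {at::randn({4, 4})});  // trace via implicit output, 0-dim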

// _trilinear computes a trilinear einstein sum with an unrolled dimension
// the result is `(i1.unsqueeze(expand1)*i2.unsqueeze(expand2)*i3.unsqueeze(expand3)).sum(sumdim)`
// the computation is unrolled in the unroll_dim dimension
// its main purpose is to unify the computations in bilinear and bilinear_backward
Tensor _trilinear(const Tensor& i1_, const Tensor& i2_, const Tensor& i3_,
                  IntArrayRef expand1_, IntArrayRef expand2_, IntArrayRef expand3_,
                  IntArrayRef sumdim_, int64_t unroll_dim) {
  int64_t total_dim = i1_.dim() + expand1_.size();
  TORCH_CHECK((unroll_dim >= 0) && (unroll_dim < total_dim), "unroll_dim must be in [0,", total_dim - 1, "]");
  auto expand1 = at::dim_list_to_bitset(expand1_, total_dim);
  auto expand2 = at::dim_list_to_bitset(expand2_, total_dim);
  auto expand3 = at::dim_list_to_bitset(expand3_, total_dim);
  auto sumdim = at::dim_list_to_bitset(sumdim_, total_dim);
  Tensor i1 = i1_;
  Tensor i2 = i2_;
  Tensor i3 = i3_;
  std::vector<c10::SymInt> output_size;
  std::vector<int64_t> sum_dims_12, sum_dims_23;
  int64_t unroll_size = -1;
  // asserts...
  for (const auto i : c10::irange(total_dim)) {
    c10::SymInt s = 0;
    if (expand1[i]) {
      i1 = i1.unsqueeze(i);
    } else {
      s = i1.sym_size(i);
    }
    if (expand2[i]) {
      i2 = i2.unsqueeze(i);
    } else {
      s = i2.sym_size(i);
    }
    if (expand3[i]) {
      i3 = i3.unsqueeze(i);
      if (sumdim[i] && (i != unroll_dim))
        sum_dims_12.push_back(i);
    } else {
      s = i3.sym_size(i);
      if (sumdim[i] && (i != unroll_dim))
        sum_dims_23.push_back(i);
    }
    output_size.push_back(sumdim[i] ? 1 : s);
    if (i == unroll_dim)
      unroll_size = s.guard_int(__FILE__, __LINE__);
  }
  int64_t slicemul1 = (expand1[unroll_dim] ? 0 : 1);
  int64_t slicemul2 = (expand2[unroll_dim] ? 0 : 1);
  int64_t slicemul3 = (expand3[unroll_dim] ? 0 : 1);

  auto output = at::zeros_symint(output_size, i1.options());

  // Three conditionals are necessary since this function is meant to work for both
  // forward and backward, which changes the dimensions of the inputs.
  // Note that if the output has zero elements, it is because (at least) one of i1, i2, i3 has zero elements.
  if (i1.sym_numel() != 0 && i2.sym_numel() != 0 && i3.sym_numel() != 0) {
    if (!sumdim[unroll_dim]) {
      for (const auto k : c10::irange(unroll_size)) {
        Tensor buf = at::native::sumproduct_pair(i1.narrow(unroll_dim, k * slicemul1, 1),
                                                 i2.narrow(unroll_dim, k * slicemul2, 1),
                                                 sum_dims_12, true);
        buf = at::native::sumproduct_pair(buf, i3.narrow(unroll_dim, k * slicemul3, 1), sum_dims_23, true);
        output.narrow(unroll_dim, k, 1).add_(buf);
      }
    } else {
      for (const auto k : c10::irange(unroll_size)) {
        Tensor buf = at::native::sumproduct_pair(i1.narrow(unroll_dim, k * slicemul1, 1),
                                                 i2.narrow(unroll_dim, k * slicemul2, 1), sum_dims_12, true);
        buf = at::native::sumproduct_pair(buf, i3.narrow(unroll_dim, k * slicemul3, 1), sum_dims_23, true);
        output.add_(buf);
      }
    }
  }
  for (int64_t i = output.dim() - 1; i >= 0; i--)
    if (sumdim[i])
      output.squeeze_(i);
  return output;
}
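// Worked example (illustrative): the bilinear forward below calls
// at::_trilinear(x1, w, x2, {1,3}, {0}, {1,2}, {2,3}). With x1 [N, H1],
// w [F, H1, H2], x2 [N, H2], the operands are unsqueezed to [N, 1, H1, 1],
// [1, F, H1, H2] and [N, 1, 1, H2]; multiplying and summing dims {2, 3}
// leaves an [N, F] result.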

Tensor bilinear(const Tensor& input1, const Tensor& input2, const Tensor& weight, const std::optional<Tensor>& bias_opt) {
  // See [Note: hacky wrapper removal for optional tensor]
  c10::MaybeOwned<Tensor> bias_maybe_owned = at::borrow_from_optional_tensor(bias_opt);
  const Tensor& bias = *bias_maybe_owned;
  if (bias.defined()) {
    TORCH_CHECK(
        input1.dtype() == input2.dtype() && input1.dtype() == weight.dtype() &&
            input1.dtype() == bias.dtype(),
        "All tensors must have the same dtype, got input1: ",
        input1.dtype(),
        ", input2: ",
        input2.dtype(),
        ", weight: ",
        weight.dtype(),
        ", bias: ",
        bias.dtype());
  } else {
    TORCH_CHECK(
        input1.dtype() == input2.dtype() && input1.dtype() == weight.dtype(),
        "All tensors must have the same dtype, got input1: ",
        input1.dtype(),
        ", input2: ",
        input2.dtype(),
        ", weight: ",
        weight.dtype());
  }

  TORCH_CHECK(input1.dim() == input2.dim(), "bilinear(): input dimensions do not match: got ", input1.dim(), " and ", input2.dim());
  for (const auto i : c10::irange(input1.dim() - 1)) {
    TORCH_CHECK(input1.sym_size(i) == input2.sym_size(i),
                "bilinear(): input batch dimensions do not match at dim ", i, ": got ", input1.sym_size(i), " and ", input2.sym_size(i));
  }
  TORCH_CHECK(input1.sym_size(input1.dim() - 1) == weight.sym_size(1),
              "bilinear(): input1 size does not match weight size: got ",
              input1.sym_size(input1.dim() - 1), " but expected ", weight.sym_size(1));
  TORCH_CHECK(input2.sym_size(input2.dim() - 1) == weight.sym_size(2),
              "bilinear(): input2 size does not match weight size: got ",
              input2.sym_size(input2.dim() - 1), " but expected ", weight.sym_size(2));
  TORCH_CHECK(!bias.defined() || bias.sym_size(0) == weight.sym_size(0),
              "bilinear(): bias size does not match weight size: got ",
              bias.sym_size(0), " but expected ", weight.sym_size(0));

  std::vector<c10::SymInt> output_size;
  auto size1 = input1.sym_sizes();
  output_size.insert(output_size.end(), size1.begin(), size1.end() - 1);
  output_size.push_back(weight.sym_size(0));
  auto input1_flattened = input1.reshape_symint({-1, input1.sym_size(-1)});
  auto input2_flattened = input2.reshape_symint({-1, input2.sym_size(-1)});
  Tensor output = at::_trilinear(input1_flattened, weight, input2_flattened, {1,3}, {0}, {1,2}, {2,3}).reshape_symint(output_size);
  if (bias.defined()) {
    output = output + bias;
  }
  return output;
}
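// Usage sketch (illustrative, shapes are assumptions):
//   auto x1 = at::randn({32, 20});      // [N, H1]
//   auto x2 = at::randn({32, 30});      // [N, H2]
//   auto w  = at::randn({40, 20, 30});  // [out_features, H1, H2]
//   auto b  = at::randn({40});
//   auto y  = at::bilinear(x1, x2, w, b);  // [32, 40]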

// implements tensordot, a matrix-multiplication-like contraction over the
// dimensions given in the two dimension lists
Tensor tensordot(const Tensor& input1, const Tensor& input2, IntArrayRef dims1, IntArrayRef dims2) {
  TORCH_CHECK(dims1.size() == dims2.size(), "both dimension lists should have same length");
  TORCH_CHECK(input1.scalar_type() == input2.scalar_type(), "both inputs should have same dtype");
  SymInt csize = 1; // total size of the contracted dimensions
  Tensor t1 = input1;
  Tensor t2 = input2;
  for (const auto i : c10::irange(dims1.size())) {
    SymInt s1 = input1.sym_size(dims1[i]);
    SymInt s2 = input2.sym_size(dims2[i]);
    if (s2 == 1) { // broadcasted dimensions can be summed right away
      t1 = t1.sum(dims1[i], true, t1.scalar_type());
    } else if (s1 == 1) {
      t2 = t2.sum(dims2[i], true, t2.scalar_type());
    } else {
      TORCH_CHECK(s1 == s2, "contracted dimensions need to match, but first has size ", s1, " in dim ", dims1[i],
                  " and second has size ", s2, " in dim ", dims2[i]);
      csize *= s1;
    }
  }

  auto cdims1 = at::dim_list_to_bitset(dims1, input1.dim());
  auto cdims2 = at::dim_list_to_bitset(dims2, input2.dim());
  std::vector<int64_t> p1, p2; // p1, p2: input permutations
  std::vector<SymInt> rsizes; // rsizes: sizes of the result
  p1.reserve(input1.dim());
  p2.reserve(input2.dim());
  rsizes.reserve(input1.dim() + input2.dim() - (int64_t) dims1.size());
  SymInt size1 = 1; // number of non-contracted elements in input1
  SymInt size2 = 1; // number of non-contracted elements in input2

  // fill the permutations and compute sizes
  for (const auto i : c10::irange(input1.dim())) {
    if (!cdims1[i]) {
      p1.emplace_back(i);
      size1 *= t1.sym_size(i);
      rsizes.emplace_back(t1.sym_size(i));
    }
  }
  for (const auto x : dims1) {
    p1.emplace_back(x);
  }
  for (const auto x : dims2) {
    p2.emplace_back(x);
  }
  for (const auto i : c10::irange(input2.dim())) {
    if (!cdims2[i]) {
      p2.emplace_back(i);
      size2 *= t2.sym_size(i);
      rsizes.emplace_back(t2.sym_size(i));
    }
  }
  // permute and reshape for matrix multiplication
  t1 = t1.permute(p1).reshape_symint({size1, csize});
  t2 = t2.permute(p2).reshape_symint({csize, size2});
  // multiply and reshape to target size
  return at::mm(t1, t2).reshape_symint(rsizes);
}
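// Usage sketch (illustrative):
//   auto a = at::randn({3, 4, 5});
//   auto b = at::randn({4, 3, 2});
//   // contract dim 0 of `a` with dim 1 of `b`, and dim 1 of `a` with dim 0 of `b`
//   auto c = at::tensordot(a, b, {0, 1}, {1, 0});  // result shape [5, 2]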

Tensor& tensordot_out(const Tensor& input1, const Tensor& input2, IntArrayRef dims1, IntArrayRef dims2, Tensor& result) {
  Tensor result_tmp = at::native::tensordot(input1, input2, dims1, dims2);
  auto result_dtype = result_tmp.scalar_type();
  auto output_tensor_dtype = result.scalar_type();
  auto output_device = result.device();
  auto input1_device = input1.device();
  auto input2_device = input2.device();
  // check if the input & output tensors are on the same device.
  TORCH_CHECK(
      (output_device == input1_device) && (input1_device == input2_device),
      "tensordot: Expected the output and input tensors to be on the "
      "same device, but got the output tensor on ", output_device,
      ", input tensor a on ", input1_device, ", and input tensor b on ", input2_device);
  // check if the computed result has the same dtype as the out tensor
  // (because tensordot does not support type promotion)
  TORCH_CHECK(
      result_dtype == output_tensor_dtype, "tensordot",
      ": Expected the output tensor to have dtype ", result_dtype,
      ", but got an output tensor with dtype ", output_tensor_dtype);
  at::native::resize_output(result, result_tmp.sizes());
  result.copy_(result_tmp);
  return result;
}

} // namespace at::native