xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/Linear.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #define TORCH_ASSERT_ONLY_METHOD_OPERATORS
2 #include <ATen/core/Tensor.h>
3 #include <ATen/native/Resize.h>
4 #include <ATen/native/xnnpack/Engine.h>
5 #include <ATen/WrapDimUtilsMulti.h>
6 #include <ATen/TensorOperators.h>
7 #include <c10/util/irange.h>
8 #include <c10/core/SymInt.h>
9 #include <c10/util/MaybeOwned.h>
10 #include <ATen/TensorSubclassLikeUtils.h>
11 
12 #ifndef AT_PER_OPERATOR_HEADERS
13 #include <ATen/Functions.h>
14 #include <ATen/NativeFunctions.h>
15 #else
16 #include <ATen/ops/_trilinear.h>
17 #include <ATen/ops/_trilinear_native.h>
18 #include <ATen/ops/add.h>
19 #include <ATen/ops/addmm.h>
20 #include <ATen/ops/bilinear_native.h>
21 #include <ATen/ops/bmm.h>
22 #include <ATen/ops/einsum_native.h>
23 #include <ATen/ops/linear_native.h>
24 #include <ATen/ops/matmul.h>
25 #include <ATen/ops/mkldnn_linear.h>
26 #include <ATen/ops/mm.h>
27 #include <ATen/ops/mul.h>
28 #include <ATen/ops/tensordot_native.h>
29 #include <ATen/ops/zeros.h>
30 #endif
31 
#include <algorithm>
#include <cctype>
#include <cstdlib>
#include <cstring>
#include <deque>
#include <numeric>
#include <string>
#include <utility>
#include <vector>
37 
38 namespace at::native {
39 
40 // Parse environment variable "TORCH_LINEAR_FLATTEN_3D"
static inline bool parseLinearFlatten3d() {
  // Cache the env lookup in a function-local static. C++11 "magic statics"
  // guarantee thread-safe one-time initialization; the previous plain
  // `static int value = -1` cache was a read-modify-write race when
  // at::native::linear was called concurrently from multiple threads.
  static const bool flatten = []() {
    const char* env_str = std::getenv("TORCH_LINEAR_FLATTEN_3D");
    // Only the exact value "1" enables flattening; anything else (including
    // unset) disables it.
    return env_str != nullptr && strcmp(env_str, "1") == 0;
  }();
  return flatten;
}
54 
55 // `_flatten_nd_linear` flattens all but the last dimension of the input tensor
56 // before passing it to linear operation
static inline Tensor _flatten_nd_linear(const Tensor& input, const Tensor& weight, const Tensor& bias) {
    const auto sizes = input.sym_sizes();
    const int64_t last = static_cast<int64_t>(sizes.size()) - 1;
    // Collapse all leading dims into one. We compute the product explicitly
    // instead of reshape(-1, ...) because -1 errors out when some dim is 0.
    c10::SymInt leading = 1;
    for (int64_t d = 0; d < last; ++d) {
      leading = leading * sizes[d];
    }
    const auto flat_input = input.reshape_symint({leading, sizes.at(last)});
    const auto mm_result = at::addmm(bias, flat_input, weight.t());
    // Rebuild the original leading dims and append the new feature dim.
    auto head = sizes.slice(0, sizes.size() - 1);
    c10::SymDimVector out_shape(head.begin(), head.end());
    out_shape.push_back(mm_result.sym_size(1));
    return mm_result.view_symint(out_shape);
}
71 
72 
// Computes input @ weight.t() (+ bias), dispatching to the fastest available
// kernel for the backend/shape: mkldnn, xnnpack (mobile), fused addmm for 2D,
// a flatten-to-2D addmm path for higher-rank contiguous inputs, and a generic
// matmul + add fallback otherwise.
Tensor linear(const Tensor& input, const Tensor& weight, const std::optional<Tensor>& bias_opt) {
  // _matmul_impl checks this again later, but _flatten_nd_linear does not work on scalars inputs,
  // so let's try to catch this here already
  const auto input_dim = input.dim();
  const auto weight_dim = weight.dim();
  TORCH_CHECK(input_dim != 0 && weight_dim != 0,
              "both arguments to linear need to be at least 1D, but they are ",
              input_dim, "D and ", weight_dim, "D");

  // See [Note: hacky wrapper removal for optional tensor]
  // `bias` is an undefined Tensor when no bias was passed; later branches
  // test bias->defined().
  auto bias = bias_opt.has_value()
    ? c10::MaybeOwned<Tensor>::borrowed(*bias_opt)
    : c10::MaybeOwned<Tensor>::owned(std::in_place);
  // MKLDNN tensors have their own fused linear kernel.
  if (input.is_mkldnn()) {
    return at::mkldnn_linear(input, weight, *bias);
  }
#if defined(C10_MOBILE)
  // On mobile builds, prefer XNNPACK when it supports these arguments.
  if (xnnpack::use_linear(input, weight, *bias)) {
    return xnnpack::linear(input, weight, *bias);
  }
#endif
  if (input_dim == 2 && bias->defined()) {
    // Fused op is marginally faster.
    return at::addmm(*bias, input, weight.t());
  }
  if (bias->defined() && !input.is_xla()) {
    // Also hit the fused path for contiguous 3D input, if not using xla
    // backend. Reshaping/flattening has some performance implications on xla.
    if (input.is_contiguous() && input_dim == 3) {
      return _flatten_nd_linear(input, weight, *bias);
    } else if (input.is_contiguous() && input.layout() == c10::kStrided && weight.layout() == c10::kStrided && bias->dim() == 1) {
      return _flatten_nd_linear(input, weight, *bias);
    } else if (parseLinearFlatten3d() && input_dim == 3) {
      // If user forces flattening via env var
      const Tensor input_cont = input.contiguous();
      return _flatten_nd_linear(input_cont, weight, *bias);
    }
  }
  // Generic fallback: matmul handles broadcasting and 1D weight cases.
  auto output = at::matmul(input, weight.t());
  if (bias->defined()) {
    // for composite compliance use out-of-place version of `add`
    if (isTensorSubclassLike(*bias) ||
        bias->_fw_grad(/*level*/ 0).defined()) {
      output = at::add(output, *bias);
    } else {
      output.add_(*bias);
    }
  }
  return output;
}
123 
// Out-of-place variant of linear() writing into a caller-provided `output`.
// MKLDNN tensors are rejected because their linear kernel has no out variant.
Tensor& linear_out(const Tensor& input, const Tensor& weight, const std::optional<Tensor>& bias_opt, Tensor& output) {
  TORCH_CHECK(!input.is_mkldnn(), "linear doesn't support out for MKLDNN tensors");
  // See [Note: hacky wrapper removal for optional tensor]
  auto bias = bias_opt.has_value()
              ? c10::MaybeOwned<Tensor>::borrowed(*bias_opt)
              : c10::MaybeOwned<Tensor>::owned(std::in_place);

  const bool has_bias = bias->defined();
  // 2D input with a bias maps directly onto the fused addmm kernel.
  if (has_bias && input.dim() == 2) {
    return at::addmm_out(output, *bias, input, weight.t());
  }
  // Otherwise: generic matmul into `output`, then add the bias in place.
  output = at::matmul_out(output, input, weight.t());
  if (has_bias) {
    output.add_(*bias);
  }
  return output;
}
141 
142 // sumproduct_pair computes `(left*right).sum(sumdims)` by means of permutation and
143 // batch matrix multiplication
144 // its main purpose is to provide a pairwise reduction for einsum
// Computes (left_ * right_).sum(sum_dims_) via permute + reshape + bmm.
// `keepdim` keeps the summed dimensions as size-1 dims in the result.
static Tensor sumproduct_pair(const Tensor& left_, const Tensor& right_, IntArrayRef sum_dims_, bool keepdim) {
  // assumes that tensors have been pre-unsqueezed (so that all dimensions match - after broadcasting)
  // but makes no other assumptions on the order of dimensions
  TORCH_CHECK(left_.dim()==right_.dim(), "number of dimensions must match");
  if (sum_dims_.empty())
    return at::mul(left_, right_);
  int64_t dim = left_.dim();
  auto sum_dims = at::dim_list_to_bitset(sum_dims_, dim);
  // dimensions that will be part of the output (i.e. not summed over) in three vectors:
  // dims in lro appear in left, right and output, similarly, lo: left and output, ro: right and output
  // also the sizes are kept track of for reshaping
  std::vector<int64_t> lro, lo, ro;
  SymInt lro_size = 1, lo_size = 1, ro_size = 1, sum_size = 1;
  Tensor left = left_;
  Tensor right = right_;
  // Classify every dimension into one of the buckets above, eagerly summing
  // out dims that only one side contributes to.
  for (const auto i : c10::irange(dim)) {
    auto sl = left.sym_size(i)!=1;
    auto sr = right.sym_size(i)!=1;
    if (sum_dims[i]) { // first dimensions that will be summed over after multiplication
      if (sl && sr) {  // dimensions nontrivially in both left and right must be of the same size
        TORCH_CHECK(left.sym_size(i)==right.sym_size(i), "non-broadcast dimensions must match");
        sum_size *= left.sym_size(i);
      } else if (sl) { // if it is only in one of left and right, we can sum right away
        left = left.sum(i, true);
      } else if (sr) {
        right = right.sum(i, true);
      }
    } else if (sl && sr) { // now deal with dimensions that will be in the output
      // dimensions nontrivially in both left and right must be of the same size
      TORCH_CHECK(left.sym_size(i)==right.sym_size(i), "non-broadcast dimensions must match");
      lro.push_back(i);
      lro_size *= left.sym_size(i);
    } else if (sl) { // keep track of dimensions appearing only once
      lo.push_back(i);
      lo_size *= left.sym_size(i);
    } else {
      ro.push_back(i);
      ro_size *= right.sym_size(i);
    }
  }
  // we now work with the following permutations / shapes.
  // the pipeline is permute inputs -> reshape inputs -> batch matrix mul -> reshape(view) output -> permute output
  // output: "lro, lo, 1-for-summed-dims, ro" with original shape dimensions
  // left:   "lro, lo, summed" permuted with lpermutation and the three flattened
  // right:  "lro, summed, ro" permuted with rpermutation and the three flattened
  // then the permuted output is a view of bmm(left, right)
  // finally, opermutation reverts the permutation to the original order of dimensions
  auto out_num_dim = lro.size() + lo.size() + sum_dims_.size() + ro.size();
  std::vector<SymInt> out_size;
  out_size.reserve(out_num_dim);
  for (auto& d : lro) out_size.push_back(left.sym_size(d));
  for (auto& d : lo) out_size.push_back(left.sym_size(d));
  for (auto& d : sum_dims_) { out_size.emplace_back(1); (void)(d); }; // avoid warning about not using d
  for (auto& d : ro) out_size.push_back(right.sym_size(d));

  std::vector<int64_t> lpermutation(lro);
  lpermutation.insert(lpermutation.end(), lo.begin(), lo.end());
  lpermutation.insert(lpermutation.end(), sum_dims_.begin(), sum_dims_.end());
  lpermutation.insert(lpermutation.end(), ro.begin(), ro.end());

  std::vector<int64_t> rpermutation(lro);
  rpermutation.insert(rpermutation.end(), sum_dims_.begin(), sum_dims_.end());
  rpermutation.insert(rpermutation.end(), ro.begin(), ro.end());
  rpermutation.insert(rpermutation.end(), lo.begin(), lo.end());

  // opermutation maps original-dim index -> position in the bmm output view,
  // i.e. the inverse of the "lro, lo, summed, ro" ordering used above.
  std::vector<int64_t> opermutation(out_num_dim, -1);
  {
    int64_t i = 0;

    for (auto it = lro.cbegin(); it != lro.cend(); i++, it++) {
      opermutation[*it] = i;
    }
    for (auto it = lo.cbegin(); it != lo.cend(); i++, it++) {
      opermutation[*it] = i;
    }
    for (auto it = sum_dims_.cbegin(); it != sum_dims_.cend(); i++, it++) {
      opermutation[*it] = i;
    }
    for (auto it = ro.cbegin(); it != ro.cend(); i++, it++) {
      opermutation[*it] = i;
    }
  }

  // now we can execute the operations above
  left = left.permute(lpermutation).reshape_symint({lro_size, std::move(lo_size), sum_size});
  right = right.permute(rpermutation).reshape_symint({std::move(lro_size), std::move(sum_size), std::move(ro_size)});
  Tensor result = at::bmm(left, right);
  result = result.view_symint(out_size).permute(opermutation);

  // finally squeeze summed dimensions if desired
  if (! keepdim) {
    auto sizes = result.sizes().vec();
    // Iterate back-to-front so erasing does not shift yet-unvisited indices.
    for (auto i = dim-1; i>=0; i--) {
      if (sum_dims[i]) {
        sizes.erase(sizes.begin() + i);
      }
    }
    result = result.view(sizes);
  }
  return result;
}
246 
247 // There are roughly three parts to computing einsum:
248 // 1. Parse equation to extract the labels for each input operand and output
249 // 2. Unsqueeze missing dimensions from input operands and permute to align them
250 // 3. Compute result by multiplying input operands and summing contraction
251 //    dimensions. We do the last part by reducing to bmm.
252 // If a path is specified, we reduce in the order specified by the path, else we
253 // default to going left => right. The path is a list of indices processed the same
254 // way as opt-einsum: https://optimized-einsum.readthedocs.io/en/stable/path_finding.html#format-of-the-path
einsum(c10::string_view equation,TensorList operands,at::OptionalIntArrayRef path)255 Tensor einsum(c10::string_view equation, TensorList operands, at::OptionalIntArrayRef path) {
256   TORCH_CHECK(!operands.empty(), "einsum(): must provide at least one operand");
257   const auto num_ops = operands.size();
258 
259   if (path.has_value()) {
260     const auto path_size = num_ops == 1 ? 1 : (num_ops - 1) * 2;
261     TORCH_CHECK(
262         path->size() == path_size,
263         "einsum(): expected contraction path given in path parameter to have size ",
264         path_size,
265         " but got ",
266         path->size());
267   }
268 
269   // Labels must be in range [A-Za-z]
270   constexpr uint8_t NUM_OF_LETTERS = 'z' - 'a' + 1;
271   constexpr uint8_t TOTAL_LABELS = NUM_OF_LETTERS * 2;
272 
273   // Code used to identify ELLIPSIS ("...")
274   constexpr uint8_t ELLIPSIS = TOTAL_LABELS;
275 
276   // Convert label in [A-Za-z] to subscript in [0, TOTAL_LABELS)
277   auto label_to_subscript = [=](unsigned char label) -> uint8_t {
278     return std::isupper(label) ? label - 'A' : label - 'a' + NUM_OF_LETTERS;
279   };
280 
281 #ifndef STRIP_ERROR_MESSAGES
282   // Convert subscript in [0, TOTAL_LABELS) to label in [A-Za-z]
283   auto subscript_to_label = [=](uint8_t s) -> unsigned char {
284     return s < NUM_OF_LETTERS ? s + 'A' : s + 'a' - NUM_OF_LETTERS;
285   };
286 #endif
287 
288   // Find arrow (->) to split equation into lhs and rhs
289   const auto arrow_pos = equation.find("->");
290   const auto lhs = equation.substr(0, arrow_pos);
291 
292   // Convert labels for input operands into an index in [0, 52) and store
293   // them in op_labels for each operand along with ELLIPSIS if present.
294   std::vector<std::vector<uint8_t>> op_labels(num_ops);
295   bool ell_in_input = false;
296   std::size_t curr_op = 0;
297   for (std::size_t i = 0; i < lhs.length(); ++i) {
298     const unsigned char label = lhs[i];
299     switch (label) {
300       case ' ':
301         // Ignore spaces
302         break;
303 
304       case '.':
305         TORCH_CHECK(
306             // Only one ellipsis per operand can be given
307             !ell_in_input,
308             "einsum(): found \'.\' for operand ",
309             curr_op,
310             " for which an ellipsis was already found");
311         TORCH_CHECK(
312             // Ensure it's a valid ellipsis
313             i + 2 < lhs.length() && lhs[++i] == '.' && lhs[++i] == '.',
314             "einsum(): found \'.\' for operand ",
315             curr_op,
316             " that is not part of any ellipsis");
317         op_labels[curr_op].push_back(ELLIPSIS);
318         ell_in_input = true;
319         break;
320 
321       case ',':
322         // Move onto next operand
323         ++curr_op;
324         TORCH_CHECK(
325             curr_op < num_ops,
326             "einsum(): fewer operands were provided than specified in the equation");
327         ell_in_input = false;
328         break;
329 
330       default:
331         // Parse label
332         TORCH_CHECK(
333             std::isalpha(label),
334             "einsum(): invalid subscript given at index ",
335             i,
336             " in the equation string, subscripts must be in [a-zA-Z]");
337         op_labels[curr_op].push_back(label_to_subscript(label));
338     }
339   }
340 
341   TORCH_CHECK(
342       curr_op == num_ops - 1,
343       "einsum(): more operands were provided than specified in the equation");
344 
345   std::vector<int64_t> label_count(TOTAL_LABELS, 0);
346 
347   // The maximum number of dimensions covered by any ellipsis, needed when
348   // unsqueezing missing dimensions from operands to permute and broadcast
349   int64_t ell_num_dim = 0;
350 
351   // Compute label frequency and number of dimensions covered by ellipsis
352   // We do this after parsing labels to make it more readable and simpler
353   // to compute the number of dimensions covered by ellipsis.
354   for(const auto i : c10::irange(num_ops)) {
355     const auto& operand = operands[i];
356     const auto labels = op_labels[i];
357     const auto ndims = operand.dim();
358     int64_t nlabels = static_cast<int64_t>(labels.size());
359     bool has_ellipsis = false;
360 
361     for (const auto& label : labels) {
362       if (label == ELLIPSIS) {
363         --nlabels;
364         has_ellipsis = true;
365         ell_num_dim = std::max(ell_num_dim, ndims - nlabels);
366       } else {
367         ++label_count[label];
368       }
369     }
370 
371     TORCH_CHECK(
372         has_ellipsis ? nlabels <= ndims : nlabels == ndims,
373         "einsum(): the number of subscripts in the equation (",
374         nlabels,
375         has_ellipsis ? ") is more than the number of dimensions ("
376                      : ") does not match the number of dimensions (",
377         ndims,
378         ") for operand ",
379         i,
380         has_ellipsis ? "" : " and no ellipsis was given");
381   }
382 
383   // We want to align the dimensions of every input tensor to have
384   // shape out_dims + sum_dims. For this, we create a mapping of label
385   // to index into the permuted shape.
386   std::vector<int64_t> label_perm_index(TOTAL_LABELS, -1);
387 
388   // Current index in the permuted shape
389   int64_t perm_index = 0;
390 
391   // Start index of ellipsis dimensions in the permuted shape
392   int64_t ell_index = 0;
393   bool ell_in_output = false;
394 
395   if (arrow_pos == std::string::npos) {
396     // Implicit output is ellipsis (...) + labels seen only once
397     perm_index = ell_num_dim;
398     // ell_in_output is used to stop us from reducing ellipses dims later
399     ell_in_output = true;
400     for (const auto label : c10::irange(TOTAL_LABELS)) {
401       if (label_count[label] == 1) {
402         label_perm_index[label] = perm_index++;
403       }
404     }
405   } else {
406     // Parse explicit output
407     const auto rhs = equation.substr(arrow_pos + 2);
408     for (std::size_t i = 0; i < rhs.length(); ++i) {
409       const unsigned char label = rhs[i];
410       switch (label) {
411         case ' ':
412           // Ignore spaces
413           break;
414 
415         case '.':
416           TORCH_CHECK(
417               // There can only be one ellipsis in the output
418               !ell_in_output,
419               "einsum(): found \'.\' for output but an ellipsis (...) was already found");
420           TORCH_CHECK(
421               // Ensure ellipsis is correct
422               i + 2 < rhs.length() && rhs[++i] == '.' && rhs[++i] == '.',
423               "einsum(): found \'.\' for output that is not part of any ellipsis (...)");
424           ell_index = perm_index;
425           perm_index += ell_num_dim;
426           ell_in_output = true;
427           break;
428 
429         default:
430           TORCH_CHECK(
431               std::isalpha(label),
432               "einsum(): invalid subscript given at index ",
433               lhs.size() + 2 + i,
434               " in the equation string, subscripts must be in [a-zA-Z]");
435           const auto index = label_to_subscript(label);
436           TORCH_CHECK(
437               // Ensure label appeared at least once for some input operand and at
438               // most once for the output
439               label_count[index] > 0 && label_perm_index[index] == -1,
440               "einsum(): output subscript ",
441               label,
442               label_perm_index[index] > -1
443                   ? " appears more than once in the output"
444                   : " does not appear in the equation for any input operand");
445           label_perm_index[index] = perm_index++;
446       }
447     }
448   }
449 
450   // Save number of dimensions in output before adding contraction dims (dims to sum out)
451   const int64_t out_num_dim = perm_index;
452 
453   // If ellipsis is not part of the output, add to contraction dimensions
454   if (!ell_in_output) {
455     ell_index = perm_index;
456     perm_index += ell_num_dim;
457   }
458 
459   // Add contraction labels (labels not present in output)
460   for (const auto label : c10::irange(TOTAL_LABELS)) {
461     if (label_count[label] > 0 && label_perm_index[label] == -1) {
462       label_perm_index[label] = perm_index++;
463     }
464   }
465 
466   // Next: we check the sizes, take diagonals for repeated labels, unsqueeze
467   // missing dimensions so all operands have the same dimensions and permute
468   // the operands to align the dimensions following the indices computed above.
469   // We also count how many operands have dimension with size != 1 for each
470   // label used to identify which dimensions can be contracted.
471   std::vector<SymInt> label_size(TOTAL_LABELS, 1);
472   std::vector<SymInt> ell_sizes(ell_num_dim, 1);
473   std::vector<uint64_t> dim_counts(perm_index, 0);
474   std::deque<Tensor> ops;
475   for (const auto i : irange(num_ops)) {
476     auto op = operands[i];
477     std::vector<int64_t> permutation(perm_index, -1);
478     std::int64_t dim = 0;
479     for (const auto s : op_labels[i]) {
480       if (s == ELLIPSIS) {
481         // Iterate over each dimension covered by ellipsis
482         const auto ndim = operands[i].ndimension() - (static_cast<int64_t>(op_labels[i].size()) - 1);
483         for (auto j = ell_num_dim - ndim; j < ell_num_dim; ++j) {
484           if (op.sym_size(dim) != 1) {
485             // Update ellipsis size
486             TORCH_CHECK(
487                 ell_sizes[j] == 1 || ell_sizes[j] == op.sym_size(dim),
488                 "einsum(): dimension ",
489                 dim,
490                 " covered by ellipsis in operand ",
491                 i,
492                 "has size ",
493                 op.size(dim),
494                 " which does not broadcast with previously seen ellipsis with size ",
495                 ell_sizes[j],
496                 " for the respective dimension");
497             ell_sizes[j] = op.sym_size(dim);
498             ++dim_counts[ell_index + j];
499           }
500           permutation[ell_index + j] = dim++;
501         }
502       } else if (permutation[label_perm_index[s]] == -1) {
503         if (op.sym_size(dim) != 1) {
504           // Update subscript
505           TORCH_CHECK(
506               label_size[s] == 1 || label_size[s] == op.sym_size(dim),
507               "einsum(): subscript ",
508               subscript_to_label(s),
509               " has size ",
510               op.sym_size(dim),
511               " for operand ",
512               i,
513               " which does not broadcast with previously seen size ",
514               label_size[s]);
515           label_size[s] = op.sym_size(dim);
516           ++dim_counts[label_perm_index[s]];
517         }
518         permutation[label_perm_index[s]] = dim++;
519       } else {
520         // Repeated label, take diagonal
521         const auto prev_dim = permutation[label_perm_index[s]];
522         TORCH_CHECK(
523           op.sym_size(dim) == op.sym_size(prev_dim),
524             "einsum(): subscript ",
525             subscript_to_label(s),
526             " is repeated for operand ",
527             i,
528             " but the sizes don't match, ",
529             op.sym_size(dim),
530             " != ",
531             op.sym_size(prev_dim));
532         op = op.diagonal(0, prev_dim, dim).movedim(-1, prev_dim);
533       }
534     }
535 
536     // Add dimensions for missing labels
537     for (auto& val : permutation) {
538       if (val == -1) {
539         op = op.unsqueeze(dim);
540         val = dim++;
541       }
542     }
543     ops.emplace_back(op.permute(permutation));
544   }
545 
546   const auto contract_path = path.value_or(std::vector<int64_t>{});
547   auto it = contract_path.begin();
548 
549   // Contract
550   while (ops.size() > 1) {
551     int64_t i = 0;
552     int64_t j = 1;
553 
554     if (path.has_value()) {
555       i = *it++;
556       j = *it++;
557       if (j < i) {
558         std::swap(i, j);
559       }
560 
561       TORCH_CHECK(
562           i != j && i >= 0 && j < static_cast<int64_t>(ops.size()),
563           "einsum(): invalid contraction (",
564           i,
565           ", ",
566           j,
567           i == j ? ") cannot contract an operand with itself"
568                  : ") operand index is out of bounds");
569     }
570 
571     auto a = ops[i];
572     auto b = ops[j];
573     ops.erase(ops.begin() + j);
574     ops.erase(ops.begin() + i);
575 
576     // Collect dimensions that can be summed now
577     std::vector<int64_t> sum_dims;
578     SmallVector<int64_t, 5> a_dims_to_sum;
579     SmallVector<int64_t, 5> b_dims_to_sum;
580     for (auto dim = out_num_dim; dim < perm_index; ++dim) {
581       if (a.sym_size(dim) != 1 && b.sym_size(dim) != 1) {
582         if (--dim_counts[dim] == 1) {
583           sum_dims.push_back(dim);
584           dim_counts[dim] = 0;
585         }
586       } else if (dim_counts[dim] == 1) {
587         if (a.sym_size(dim) != 1) {
588           a_dims_to_sum.push_back(dim);
589           dim_counts[dim] = 0;
590         } else if (b.sym_size(dim) != 1) {
591           b_dims_to_sum.push_back(dim);
592           dim_counts[dim] = 0;
593         }
594       }
595     }
596 
597     // Sum multiple dims at a time to minimize the number of kernel calls to sum
598     if (!a_dims_to_sum.empty()) {
599       a = a.sum(a_dims_to_sum, true);
600     }
601     if (!b_dims_to_sum.empty()) {
602       b = b.sum(b_dims_to_sum, true);
603     }
604 
605     if (path.has_value()) {
606       ops.emplace_back(sumproduct_pair(a, b, sum_dims, true));
607     } else {
608       ops.emplace_front(sumproduct_pair(a, b, sum_dims, true));
609     }
610   }
611 
612   // Sum out contraction dims
613   if (perm_index - out_num_dim > 0) {
614     // if there were ops to contract, we would have already done so
615     // in the previous loop and all the dims to sum are now 1
616     // NB: use view instead of squeeze (or sum) for faster (mps) performance
617     if (num_ops > 1) {
618       auto sizes = ops[0].sym_sizes().vec();
619       for (auto dim = perm_index - 1; dim >= out_num_dim; --dim) {
620         sizes.erase(sizes.begin() + dim);
621       }
622       return ops[0].view_symint(sizes);
623     } else {
624       std::vector<int64_t> sum_dims(perm_index - out_num_dim);
625       std::iota(sum_dims.begin(), sum_dims.end(), out_num_dim);
626       return ops[0].sum(sum_dims);
627     }
628   }
629 
630   return ops[0];
631 }
632 
633 // _trilinear computes a trilinear einstein sum with an unrolled dimension
634 // the result is `(i1.unsqueeze(expand1)*i2.unsqueeze(expand2)*i2.unsqueeze(expand3)).sum(sumdim)`
635 // the computation is unrolled in the unroll_dim dimension
636 // its main purpose is to unify the computations in bilinear and bilinear_backward
Tensor _trilinear(const Tensor& i1_, const Tensor& i2_, const Tensor& i3_,
                  IntArrayRef expand1_, IntArrayRef expand2_, IntArrayRef expand3_,
                  IntArrayRef sumdim_, int64_t unroll_dim) {
  // total_dim is the common rank after each input is unsqueezed at its
  // expandN_ positions so the three tensors line up dimension-by-dimension.
  int64_t total_dim = i1_.dim()+expand1_.size();
  TORCH_CHECK((unroll_dim >= 0) && (unroll_dim < total_dim), "unroll_dim must be in [0,", total_dim-1, "]");
  auto expand1 = at::dim_list_to_bitset(expand1_, total_dim);
  auto expand2 = at::dim_list_to_bitset(expand2_, total_dim);
  auto expand3 = at::dim_list_to_bitset(expand3_, total_dim);
  auto sumdim  = at::dim_list_to_bitset(sumdim_,  total_dim);
  Tensor i1 = i1_;
  Tensor i2 = i2_;
  Tensor i3 = i3_;
  std::vector<c10::SymInt> output_size;
  // Summed dims are split into those reduced in the first pairwise product
  // (i1*i2) and those reduced in the second (buf*i3); see the loop below.
  std::vector<int64_t> sum_dims_12, sum_dims_23;
  int64_t unroll_size = -1;
  // asserts...
  for (const auto i : c10::irange(total_dim)) {
    // `s` ends up holding the size contributed by whichever input actually
    // has this dimension (the non-expanded one).
    c10::SymInt s = 0;
    if (expand1[i]) {
      i1 = i1.unsqueeze(i);
    } else  {
      s = i1.sym_size(i);
    }
    if (expand2[i]) {
      i2 = i2.unsqueeze(i);
    } else  {
      s = i2.sym_size(i);
    }
    if (expand3[i]) {
      i3 = i3.unsqueeze(i);
      // i3 does not participate here, so the dim can be reduced when
      // multiplying i1*i2 (unless it is the unrolled dim).
      if (sumdim[i] && (i != unroll_dim))
        sum_dims_12.push_back(i);
    } else  {
      s = i3.sym_size(i);
      if (sumdim[i] && (i != unroll_dim))
        sum_dims_23.push_back(i);
    }
    // Summed dims are kept as size 1 until the final squeeze below.
    output_size.push_back(sumdim[i] ? 1 : s);
    if (i == unroll_dim)
      unroll_size = s.guard_int(__FILE__, __LINE__);
  }
  // slicemulN == 0 means input N broadcasts along unroll_dim (narrow at 0).
  int64_t slicemul1 = (expand1[unroll_dim] ? 0 : 1);
  int64_t slicemul2 = (expand2[unroll_dim] ? 0 : 1);
  int64_t slicemul3 = (expand3[unroll_dim] ? 0 : 1);

  auto output = at::zeros_symint(output_size, i1.options());

  // Three conditionals are necessary since this function is meant to work for both
  // forward and backward, which changes the dimensions of the inputs.
  // Note that if output has zero elems is because (at least) one of i1, i2, i3 has zero elems.
  if (i1.sym_numel() != 0 && i2.sym_numel() != 0 && i3.sym_numel() != 0) {
    if (! sumdim[unroll_dim]) {
      // Unrolled dim survives in the output: accumulate slice-by-slice.
      for (const auto k : c10::irange(unroll_size)) {
        Tensor buf = at::native::sumproduct_pair(i1.narrow(unroll_dim, k * slicemul1, 1),
                                                 i2.narrow(unroll_dim, k * slicemul2, 1),
                                                 sum_dims_12, true);
        buf = at::native::sumproduct_pair(buf, i3.narrow(unroll_dim, k * slicemul3, 1), sum_dims_23, true);
        output.narrow(unroll_dim, k, 1).add_(buf);
      }
    }
    else {
      // Unrolled dim is summed out: accumulate every slice into the whole output.
      for (const auto k : c10::irange(unroll_size)) {
        Tensor buf = at::native::sumproduct_pair(i1.narrow(unroll_dim, k*slicemul1, 1),
                                                 i2.narrow(unroll_dim, k*slicemul2, 1), sum_dims_12, true);
        buf = at::native::sumproduct_pair(buf, i3.narrow(unroll_dim, k*slicemul3, 1), sum_dims_23, true);
        output.add_(buf);
      }
    }
  }
  // Drop the size-1 dims kept for the summed dimensions (back-to-front so
  // indices stay valid while squeezing).
  for (int64_t i = output.dim()-1; i >= 0; i--)
    if (sumdim[i])
      output.squeeze_(i);
  return output;
}
711 
bilinear(const Tensor & input1,const Tensor & input2,const Tensor & weight,const std::optional<Tensor> & bias_opt)712 Tensor bilinear(const Tensor& input1, const Tensor& input2, const Tensor& weight, const std::optional<Tensor>& bias_opt) {
713   // See [Note: hacky wrapper removal for optional tensor]
714   c10::MaybeOwned<Tensor> bias_maybe_owned = at::borrow_from_optional_tensor(bias_opt);
715   const Tensor& bias = *bias_maybe_owned;
716   if (bias.defined()) {
717     TORCH_CHECK(
718         input1.dtype() == input2.dtype() && input1.dtype() == weight.dtype() &&
719             input1.dtype() == bias.dtype(),
720         "All tensors must have the same dtype, got input1: ",
721         input1.dtype(),
722         ", input2: ",
723         input2.dtype(),
724         ", weight: ",
725         weight.dtype(),
726         ", bias: ",
727         bias.dtype());
728   } else {
729     TORCH_CHECK(
730         input1.dtype() == input2.dtype() && input1.dtype() == weight.dtype(),
731         "All tensors must have the same dtype, got input1: ",
732         input1.dtype(),
733         ", input2: ",
734         input2.dtype(),
735         ", weight: ",
736         weight.dtype());
737   }
738 
739   TORCH_CHECK(input1.dim() == input2.dim(), "bilinear(): input dimensions do not match: got ", input1.dim(), " and ", input2.dim());
740   for (const auto i : c10::irange(input1.dim() - 1)) {
741     TORCH_CHECK(input1.sym_size(i) == input2.sym_size(i),
742               "bilinear(): input batch dimensions do not match at dim ", i, ": got ", input1.sym_size(i), " and ", input2.sym_size(i));
743   }
744   TORCH_CHECK(input1.sym_size(input1.dim() - 1) == weight.sym_size(1),
745             "bilinear(): input1 size does not match weight size: got ",
746             input1.sym_size(input1.dim() - 1), " but expected ", weight.sym_size(1));
747   TORCH_CHECK(input2.sym_size(input2.dim() - 1) == weight.sym_size(2),
748             "bilinear(): input2 size does not match weight size: got ",
749             input2.sym_size(input2.dim() - 1), " but expected ", weight.sym_size(2));
750   TORCH_CHECK(!bias.defined() || bias.sym_size(0) == weight.sym_size(0),
751             "bilinear(): bias size does not match weight size: got ",
752             bias.sym_size(0), " but expected ", weight.sym_size(0));
753 
754   std::vector<c10::SymInt> output_size;
755   auto size1 = input1.sym_sizes();
756   output_size.insert(output_size.end(), size1.begin(), size1.end() - 1);
757   output_size.push_back(weight.sym_size(0));
758   auto input1_flattened = input1.reshape_symint({-1, input1.sym_size(-1)});
759   auto input2_flattened = input2.reshape_symint({-1, input2.sym_size(-1)});
760   Tensor output = at::_trilinear(input1_flattened, weight, input2_flattened, {1,3}, {0}, {1,2}, {2,3}).reshape_symint(output_size);
761   if (bias.defined()) {
762     output = output + bias;
763   }
764   return output;
765 }
766 
// Implements tensordot, a matrix-multiplication-like contraction: the
// dimensions listed in `dims1` (of input1) are contracted against the
// corresponding dimensions in `dims2` (of input2). The result holds the
// remaining (non-contracted) dims of input1 followed by those of input2.
Tensor tensordot(const Tensor& input1, const Tensor& input2, IntArrayRef dims1, IntArrayRef dims2) {
  TORCH_CHECK(dims1.size() == dims2.size(), "both dimension lists should have same length");
  TORCH_CHECK(input1.scalar_type() == input2.scalar_type(), "both inputs should have same dtype");
  SymInt csize = 1;  // total size of the contracted dimensions
  Tensor t1 = input1;
  Tensor t2 = input2;
  // Contracting a size-n dim against a size-1 dim broadcasts: the size-1 side
  // contributes a constant factor, so the size-n side can simply be summed
  // over that dim up front (keepdim=true so later indices stay valid).
  for (const auto i : c10::irange(dims1.size())) {
    SymInt s1 = input1.sym_size(dims1[i]);
    SymInt s2 = input2.sym_size(dims2[i]);
    if (s2 == 1) { // broadcasted dimensions can be summed right away
      t1 = t1.sum(dims1[i], true, t1.scalar_type());
    } else if (s1 == 1) {
      t2 = t2.sum(dims2[i], true, t2.scalar_type());
    } else {
      TORCH_CHECK(s1 == s2, "contracted dimensions need to match, but first has size ", s1, " in dim ", dims1[i],
               " and second has size ", s2, " in dim ", dims2[i]);
      csize *= s1;  // only genuinely contracted dims count toward csize
    }
  }

  // Bitsets marking which dims of each input are contracted (also validates
  // and wraps the dim indices).
  auto cdims1 = at::dim_list_to_bitset(dims1, input1.dim());
  auto cdims2 = at::dim_list_to_bitset(dims2, input2.dim());
  std::vector<int64_t> p1, p2;  // p1, p2: input permutations
  std::vector<SymInt> rsizes;  // rsizes: sizes of the result
  p1.reserve(input1.dim());
  p2.reserve(input2.dim());
  rsizes.reserve(input1.dim() + input2.dim() - (int64_t) dims1.size());
  SymInt size1 = 1; // number of non-contracted elements in input1
  SymInt size2 = 1; // number of non-contracted elements in input2

  // Fill the permutations and compute sizes. Ordering matters:
  // p1 = [non-contracted dims of input1..., dims1...] (contracted dims last),
  // p2 = [dims2..., non-contracted dims of input2...] (contracted dims first),
  // so the contracted axes line up as the inner dims of the matmul below.
  // Note: sizes are read from t1/t2, where pre-summed dims have size 1.
  for (const auto i : c10::irange(input1.dim())) {
    if (! cdims1[i]) {
      p1.emplace_back(i);
      size1 *= t1.sym_size(i);
      rsizes.emplace_back(t1.sym_size(i));
    }
  }
  for (const auto x : dims1) {
    p1.emplace_back(x);
  }
  for (const auto x : dims2) {
    p2.emplace_back(x);
  }
  for (const auto i : c10::irange(input2.dim())) {
    if (! cdims2[i]) {
      p2.emplace_back(i);
      size2 *= t2.sym_size(i);
      rsizes.emplace_back(t2.sym_size(i));
    }
  }
  // permute and reshape for matrix multiplication:
  // t1 -> (size1, csize), t2 -> (csize, size2)
  t1 = t1.permute(p1).reshape_symint({size1, csize});
  t2 = t2.permute(p2).reshape_symint({csize, size2});
  // multiply and reshape to target size
  return at::mm(t1, t2).reshape_symint(rsizes);
}
826 
// Out-variant of tensordot: computes the contraction into a fresh tensor,
// validates that `result` lives on the inputs' device and already has the
// computed dtype (tensordot does no type promotion into `out`), then resizes
// `result` and copies the answer in.
Tensor &tensordot_out(const Tensor& input1, const Tensor& input2, IntArrayRef dims1, IntArrayRef dims2, Tensor& result) {
  Tensor computed = at::native::tensordot(input1, input2, dims1, dims2);
  const auto out_device = result.device();
  const auto dev_a = input1.device();
  const auto dev_b = input2.device();
  // All three tensors must share one device.
  TORCH_CHECK(
    (out_device == dev_a) && (dev_a == dev_b),
    "tensordot: Expected the output and input tensors to be on the "
    "same device, but got the output tensor on ", out_device,
    ", input tensor a on ", dev_a, ", and input tensor b on ", dev_b);
  // The out tensor's dtype must already match the computed result's dtype.
  const auto computed_dtype = computed.scalar_type();
  const auto out_dtype = result.scalar_type();
  TORCH_CHECK(
    computed_dtype == out_dtype, "tensordot",
    ": Expected the output tensor to have dtype ", computed_dtype,
    ", but got an output tensor with dtype ", out_dtype);
  at::native::resize_output(result, computed.sizes());
  result.copy_(computed);
  return result;
}
850 
851 }  // namespace at::native
852