// xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/tensorexpr/loopnest.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <string>
4 #include <unordered_map>
5 #include <unordered_set>
6 #include <vector>
7 
8 #include <torch/csrc/Export.h>
9 #include <torch/csrc/jit/tensorexpr/fwd_decls.h>
10 
11 namespace torch::jit::tensorexpr {
12 
13 class Expr;
14 class Var;
15 class Buf;
16 class Tensor;
17 class Function;
18 class Stmt;
19 class For;
20 class Block;
21 class Store;
22 class Dtype;
23 
24 class TORCH_API LoopNest {
25  public:
26   // A constructor for building a LoopNest from a list of Tensors
27   LoopNest(
28       const std::vector<Tensor>& output_tensors,
29       const std::vector<Tensor>& tensors_to_compute);
30 
31   // A convenience constructor for the case when all tensors are output tensors
32   LoopNest(const std::vector<Tensor>& output_tensors);
33 
34   // A constructor for building a LoopNest from an Stmt and a list of output
35   // buffers.
36   LoopNest(StmtPtr stmt, std::unordered_set<BufPtr> output_bufs);
37 
38   // A constructor for building a LoopNest from another loopnest. It clones the
39   // other loopnest's stmt.
40   LoopNest(const LoopNest& other);
41 
root_stmt()42   StmtPtr root_stmt() const {
43     return root_stmt_;
44   }
45 
46   std::vector<ForPtr> getLoopStmtsFor(const Tensor&) const;
47   std::vector<ForPtr> getLoopStmtsFor(const BufPtr&) const;
48   std::vector<ForPtr> getLoopStmtsFor(StmtPtr) const;
49   StmtPtr getLoopBodyFor(const Tensor&) const;
50   StmtPtr getLoopBodyFor(BufPtr) const;
51 
52   // Returns the For stmt indexed by 'indices' in the 'root' For stmt.
53   //'indices' indicates the path to the returned loop from 'root' in AST, e.g.,
54   //
55   // root: for(int i...){
56   // j_loop: for (int j...){
57   // k1_loop:  for (int k1...){
58   //            A[i, j, k1] = ....
59   //          }
60   //          B[i, j] = ...
61   // k2_loop:  for (int k2...){
62   //            A[i, j, k2] = ...
63   //          }
64   //        }
65   //      }
66   //
67   // the path from 'root' to 'j_loop' is [0]
68   // the path from 'root' to 'k1_loop' is [0, 0]
69   // the path from 'root' to 'k2_loop' is [0, 2]
70   ForPtr getLoopAt(ForPtr root, const std::vector<int>& indices) const;
71 
72   // Returns the For stmt that is immediately enclosing the given stmt.
73   static ForPtr getParentLoop(const StmtPtr& st);
74 
75   // Returns the list of For stmts corresponding to the loopnest that is
76   // enclosing the given stmt.
77   static std::vector<ForPtr> getEnclosingLoopNest(const StmtPtr& st);
78 
79   // Returns a list of all Stmts that write to the given buf.
80   std::vector<StmtPtr> getAllWritesToBuf(BufPtr) const;
81 
82   // The following methods return the For loops that contain writes to
83   // the given buf.
84   //
85   // For example, consider the following code:
86   //   for i1
87   //     for j1
88   //       a[i1,j1] =
89   //   for i2
90   //     for j2
91   //       for k2
92   //         a[i2,j2] =
93   //     for j3
94   //       a[i2,j3] =
95 
96   // Returns a list of For loops which directly contain a Stmt that writes
97   // to buf.
98   // For the above example:
99   //   getAllInnermostLoopsWritingToBuf(a) => {j1, k2, j3}
100   std::vector<ForPtr> getAllInnermostLoopsWritingToBuf(BufPtr) const;
101 
102   // Returns a list of For loopnests which contain a Stmt that writes to
103   // the given buf. Each loopnest here is a vector For loops.
104   // For the above example:
105   //   getAllLoopNestsWritingToBuf(a) => {{i1,j1}, {i2,j2,k2}, {i2,j3}}
106   std::vector<std::vector<ForPtr>> getAllLoopNestsWritingToBuf(BufPtr) const;
107 
108   StmtPtr simplify();
109 
110   // Sanitize variables and buffer names.
111   // The pass assigns predefined names for loop index variables
112   // (i,j,k,l,m,n,o,p,i1,j1,k1,...) and ensures these names are not conflicting
113   // anywhere. It also removes duplicates from other Buf nad Var names as well
114   // as replaces illegal characters in them with underscores.
115   //
116   // Note: since it's currently technically possible to use the same variable
117   // as index in two different loops, this transformation finds such cases and
118   // introduces new variables to avoid duplication.
119   static StmtPtr sanitizeNames(StmtPtr s);
120 
121   bool computeInline(const StmtPtr& s);
122   bool computeInline(const BufPtr& b);
123   void inlineIntermediateBufs(bool allow_duplicated_work);
124 
125   // Optimizes conditionals.
126   //
127   // Currently, only the following pattern of conditionals is optimized.
128   // This corresponds to the conditional format that is generated to handle
129   // `aten::cat` op.
130   //
131   //   for (int i = 0; i < 20; i++) {
132   //     A[i] = IfThenElse(i<5 ? 1 : 0, B[i], C[i-5])
133   //   }
134   //
135   // Constraints that must be satisfied for this optimization:
136   //   * All conditions should be of the form "var < expr".
137   //   * All conditions should have the same variable, say v.
138   //   * The condition variable found should be the same as the inner-most
139   //     loop variable. TODO: Remove this constraint.
140   //   * If there are multiple stores that contain conditionals using the same
141   //     loop variable, only the first conditional will be optimized.
142   //     TODO: Remove this constraint.
143   bool optimizeConditionals();
144 
145   // Splits the given loop into 2 nested loops with the given factor as the
146   // inner loop bound. If the factor does not evenly divide the loop bound,
147   // then the remaining iterations are extracted into a tail loop that is
148   // added after the given loop.
149   //
150   // For example, consider the following code:
151   //   for (int i = 0; i < 100; ++i) {
152   //     A[i] =
153   //   }
154   //
155   // splitWithTail(i, 8, ...) will result in:
156   //   for (int i_outer = 0; i_outer < 12; ++i_outer) {
157   //     for (int i_inner = 0; i_inner < 8; ++i_inner) {
158   //       A[i_outer * 8 + i_inner] =
159   //     }
160   //   }
161   //   for (int i_tail = 0; i_tail < 4; ++i_tail) {
162   //     A[i_tail + 96] =
163   //   }
164   //
165   // The given loop will be transformed to the outer loop after splitting.
166   // So, the pointer to the input loop should be valid after splitting and
167   // will point to the outer loop. The `inner` and `tail` parameters will be
168   // set to point to the inner and tail loops that are generated.
169   static void splitWithTail(
170       const ForPtr& f,
171       int factor,
172       ForPtr* inner,
173       ForPtr* tail);
174   // A convenience wrapper when the caller does not need to access the
175   // split loops.
176   static void splitWithTail(const ForPtr& f, int factor);
177 
178   // Splits the given loop into 2 nested loops with the given factor as the
179   // inner loop bound. If the factor does not evenly divide the loop bound,
180   // then a conditional is inserted into the body to handle the remaining
181   // iterations appropriately.
182   //
183   // For example, consider the following code:
184   //   for (int i = 0; i < 100; ++i) {
185   //     A[i] =
186   //   }
187   //
188   // splitWithMask(i, 8, ...) will result in:
189   //   for (int i_outer = 0; i_outer < 13; ++i_outer) {
190   //     for (int i_inner = 0; i_inner < 8; ++i_inner) {
191   //       if (i_outer * 8 + i_inner < 100) {
192   //         A[i_outer * 8 + i_inner] =
193   //       }
194   //     }
195   //   }
196   //
197   // The given loop will be transformed to the outer loop after splitting.
198   // So, the pointer to the input loop should be valid after splitting and
199   // will point to the outer loop. The `inner` parameter will be set to point
200   // to the inner loop that is generated.
201   static void splitWithMask(const ForPtr& f, int factor, ForPtr* inner);
202   // A convenience wrapper when the caller does not need to access the
203   // split loops.
204   static void splitWithMask(const ForPtr& f, int factor);
205 
206   // The following methods support loop distribution.
207   // For example, consider the following code. This will be used to
208   // demonstrate the methods below.
209   //
210   // S0:  for m
211   // S1:    for i
212   // S2:      A[i] = 0
213   // S3:      for j
214   // S4:        A[i] = A[i] +
215   // S5:      B[i] = A[i]
216   // S6:      for k
217   // S7:        B[i] = B[i] +
218 
219   // This method distributes the given loop over its body by splitting
220   // after every given pivot stmt.
221   //
222   // NOTE: Pivot stmts that are not in the given loop's body will be ignored.
223   //
224   // For the above example:
225   //   distributeLoop(S1, {S3, S5})
226   // will result in:
227   // S0:  for m
228   // S1:    for i
229   // S2:      A[i] = 0
230   // S3:      for j
231   // S4:        A[i] = A[i] +
232   //   :    for i
233   // S5:      B[i] = A[i]
234   //   :    for i
235   // S6:      for k
236   // S7:        B[i] = B[i] +
237   static std::vector<ForPtr> distributeLoop(
238       const ForPtr& loop,
239       const std::unordered_set<StmtPtr>& pivots);
240 
241   // This method distributes the given loop over every stmt in its body.
242   //
243   // For the above example:
244   //   distributeLoop(S1)
245   // will result in:
246   // S0:  for m
247   // S1:    for i
248   // S2:      A[i] = 0
249   //   :    for i
250   // S3:      for j
251   // S4:        A[i] = A[i] +
252   //   :    for i
253   // S5:      B[i] = A[i]
254   //   :    for i
255   // S6:      for k
256   // S7:        B[i] = B[i] +
257   static std::vector<ForPtr> distributeLoop(const ForPtr& loop);
258   // Same as above, but also distribute parent loops.
259   // Returns the result of distributing the outermost loop.
260   //
261   // For the above example:
262   //   distributeLoopAndParents(S1) will result in:
263   // S0:  for m
264   // S1:    for i
265   // S2:      A[i] = 0
266   //   :  for m
267   //   :    for i
268   // S3:      for j
269   // S4:        A[i] = A[i] +
270   //   :  for m
271   //   :    for i
272   // S5:      B[i] = A[i]
273   //   :  for m
274   //   :    for i
275   // S6:      for k
276   // S7:        B[i] = B[i] +
277   static std::vector<ForPtr> distributeLoopAndParents(const ForPtr& loop);
278 
279   // This method distributes the given loop over its body by splitting
280   // after every For stmt in its body.
281   //
282   // For the above example:
283   //   distributeLoopOverInnerLoops(S1)
284   // will result in:
285   // S0:  for m
286   // S1:    for i
287   // S2:      A[i] = 0
288   // S3:      for j
289   // S4:        A[i] = A[i] +
290   //   :    for i
291   // S5:      B[i] = A[i]
292   // S6:      for k
293   // S7:        B[i] = B[i] +
294   static std::vector<ForPtr> distributeLoopOverInnerLoops(const ForPtr& loop);
295   // Same as above, but also distribute parent loops.
296   // Returns the result of distributing the outermost loop.
297   //
298   // For the above example:
299   //   distributeLoopAndParentsOverInnerLoops(S1)
300   // will result in:
301   // S0:  for m
302   // S1:    for i
303   // S2:      A[i] = 0
304   // S3:      for j
305   // S4:        A[i] = A[i] +
306   //   :  for m
307   //   :    for i
308   // S5:      B[i] = A[i]
309   // S6:      for k
310   // S7:        B[i] = B[i] +
311   static std::vector<ForPtr> distributeLoopAndParentsOverInnerLoops(
312       const ForPtr& loop);
313 
314   // This method performs loop fusion.
315   // For example, consider the following code.
316   //
317   // S1:  for m
318   // S2:    A[m] = 0
319   // S3:    for j
320   // S4:      A[m] = A[m] +
321   // S5:  for n
322   // S5:    B[n] = A[n]
323   // S6:    for k
324   // S7:      B[n] = B[n] +
325   //
326   // fuseLoops({S1, S5}), will return the following loop:
327   // S1:  for m
328   // S2:    A[m] = 0
329   // S3:    for j
330   // S4:      A[m] = A[m] +
331   // S5:    B[m] = A[m]
332   // S6:    for k
333   // S7:      B[m] = B[m] +
334   //
335   // This transformation is unsafe as it simply add all loops into the body of
336   // the first loop for fusion without correctness checks.
337   //
338   // Below are the two requirements to apply unsafeFuseLoops:
339   //  * All the loops have the same parent.
340   //  * There are no statements between these loops in their parent body.
341   static bool unsafeFuseLoops(const std::vector<ForPtr>& loops, ForPtr* fused);
342 
343   // Loop fusion is done only when all the conditions below are satisfied.
344   //  * All the loops have the same parent.
345   //  * There are no statements between these loops in their parent body.
346   //  * The start bounds are the same for all loops.
347   //  * The stop bounds are the same for all loops.
348   //  * Fusing the loops does not violate or add any dependencies.
349   static bool fuseLoops(const std::vector<ForPtr>& loops, ForPtr* fused);
350 
351   static void reorderAxis(const ForPtr& a, const ForPtr& b);
352 
353   // Reorder the given list of loops according to the permutation specified.
354   // Here `permutation[i]` represents the position of the loop in the input
355   // which will end up at position `i` after the reorder.
356   //
357   // For example, consider the following code:
358   //   for p
359   //     for q
360   //       for r
361   //         for s
362   //           A[p,q,r,s] =
363   //
364   // reorder({p, q, r, s}, {2, 3, 0, 1}) will return the list of loops in the
365   // following form:
366   //    for r
367   //      for s
368   //        for p
369   //          for q
370   //            A[p,q,r,s] =
371   static std::vector<ForPtr> reorder(
372       const std::vector<ForPtr>& loops,
373       const std::vector<size_t>& permutation);
374 
375   // Tile takes a 2d domain (x, y) and splits it into small rectangular blocks
376   // each with shape (x_factor, y_factor). The traversal over the domain turns
377   // into an outer iteration over the blocks and an inner traversal over all
378   // points in the block.
379   // Note that if x dim % x_factor or y dim % y_factor does not equal to 0, the
380   // loop body will generate corresponding tailing loops.
381   // The transformation is in-place and returns 'xtail'.
382   //
383   // For example, consider the following code:
384   //   for i: [0, 64)
385   //     for j: [0, 64)
386   //       for k: [0, 32)
387   //         A[i, j] = B[i, k] + C[j, k]
388   //
389   // tile(i, j, 4, 8) will transform "i" for-stmt into the following nested
390   // loop:
391   //   for i_outer: [0, 16)
392   //     for j_outer: [0, 8)
393   //       for i_inner: [0, 4)
394   //         for j_inner: [0, 8)
395   //           for k: [0, 32)
396   //             A[i_outer * 4 + i_inner, j_outer * 8 + j_inner] =
397   //             B[i_outer * 4 + i_inner, k] + C[j_outer * 8 + j_inner, k]
398   //
399   // tile(i, j, 4, 9) will transform "i" for-stmt into the following nested
400   // loop:
401   //   for i_outer: [0, 16)
402   //     for j_outer: [0, 7)
403   //       for i_inner: [0, 4)
404   //         for j_inner: [0, 9)
405   //           for k: (0, 32)
406   //             A[i_outer * 4 + i_inner, j_outer * 9 + j_inner] =
407   //             B[i_outer * 4 + i_inner, k] + C[j_outer * 9 + j_inner, k]
408   //     for j_tail: [0, 1)
409   //       for i_inner: [0, 4)
410   //         for k: (0, 32)
411   //           A[i_outer * 4 + i_inner, 7 * 9 + j_tail] =
412   //           B[i_outer * 4 + i_inner, k] + C[7 * 9 + j_tail, k]
413   ForPtr tile(const ForPtr& x, const ForPtr& y, int x_factor, int y_factor);
414 
415   // Returns true if the given loops are perfectly nested, i.e., every loop
416   // (except the innermost) should have exactly one statement in its body
417   // and that statement must be the next inner loop.
418   static bool areLoopsPerfectlyNested(const std::vector<ForPtr>& loops);
419 
420   // Returns true if the given loop has a loop-carried dependence.
421   static bool hasLoopCarriedDependence(const ForPtr& loop);
422 
423   // Unrolls all the iterations of the given loop.
424   // Requires that the loop bounds are constant.
425   static void fullUnroll(const ForPtr& f, StmtPtr* unrolled);
426   static void fullUnroll(const ForPtr& f);
427 
428   // Unrolls the given loop for the specified factor.
429   // This does not require constant bounds for the loop being unrolled.
430   static void unroll(const ForPtr& f, int factor, ForPtr* tail);
431   static void unroll(const ForPtr& f, int factor);
432 
433   static bool normalize(const ForPtr& f);
434   static bool isNormalized(const ForPtr& f);
435 
436   static bool flatten(const std::vector<ForPtr>& f, ForPtr* flattened);
437   static bool flatten(const std::vector<ForPtr>& f);
438 
439   // Compresses the given buffer based on its use in the given Stmts.
440   //
441   // NOTE: This API assumes that there are no accesses to the given buffer
442   // outside the given statement. So, this should be called with the entire
443   // kernel statement to avoid incorrect buffer compressions.
444   //
445   // For example, given the input:
446   //
447   // for (int i = 0; i < 100; ++i) {
448   //   for (int j = 0; j < 200; ++j) {
449   //     A[i,j] = sin(i*j)
450   //   }
451   //   for (int j = 0; j < 199; ++j) {
452   //     B[i,j] = A[i,j] + A[i, j+1]
453   //   }
454   // }
455   //
456   // compressBuffer(A, ...) will compress buffer A from
457   // [100, 200] to [1, 200] and modify the code as follows:
458   //
459   // for (int i = 0; i < 100; ++i) {
460   //   for (int j = 0; j < 200; ++j) {
461   //     A[0,j] = sin(i*j)
462   //   }
463   //   for (int j = 0; j < 199; ++j) {
464   //     B[i,j] = A[0,j] + A[0, j+1]
465   //   }
466   // }
467   static void compressBuffer(const BufPtr& buf, const StmtPtr& stmt);
468 
469   // Compresses all buffers in the given statement.
470   //
471   // NOTE: This API assumes that there are no accesses to buffers outside
472   // the given statement. So, this should be called with the entire
473   // kernel statement to avoid incorrect buffer compressions.
474   //
475   // TODO: Add an IR verifier check to detect invalidly compressed buffers.
476   static void compressAllBuffers(const StmtPtr& stmt);
477 
478   // Get 'num' loops from the loopnest starting at 'f'.
479   static std::vector<ForPtr> getLoopStmtsInLoopNest(
480       const ForPtr& f,
481       size_t num);
482 
483   // LoopOptions are propagated to tail.
484   static void sliceHead(
485       const ForPtr& f,
486       int factor,
487       ForPtr* head,
488       ForPtr* tail);
489   static void sliceHead(const ForPtr& f, int factor);
490   // LoopOptions are propagated to head.
491   static void sliceTail(
492       const ForPtr& f,
493       int factor,
494       ForPtr* head,
495       ForPtr* tail);
496   static void sliceTail(const ForPtr& f, int factor);
497 
498   using AccessResult = std::pair<BufPtr, StmtPtr>;
499   // Insert a cache for the consumer's usages of the buffer produced in
500   // consumer, and redirect reads and writes in the consumer to that cache.
501   // Returns a pair of the new cache buffer, and the new rewritten consumer.
502   static AccessResult cacheAccesses(
503       const BufPtr& producer,
504       const std::string& name,
505       const StmtPtr& consumer);
506 
507   // Insert a temporary computation of statement S in the scope of loop AT.
508   // S is assumed to be a Store or a Block containing a Store. Along with the
509   // computation itself, this transformation inserts Alloc/Free statements for
510   // the temporary buffer used in the computation.
511   static void computeAt(const StmtPtr& s, const ForPtr& at);
512 
513   // Rfactor a reduction axis into a normal axis.
514   //
515   // Requirements:
516   //  * S is the reduction store
517   //  * S is the only statement in the innermost loop
518   //  * There is at least two reduction arguments in S
519   //  * OUTER_REDUCTION_FOR loop corresponds to the outermost reduction variable
520   //  used in the store and all other reduction variables are index variables of
521   //  children loops of OUTER_REDUCTION_FOR
522   //  * OUTER_REDUCTION_FOR is a perfect loop nest, i.e. it has only loops
523   //  corresponding to the other reduction variables and the store, nested into
524   //  each other
525   //
526   // What it does:
527   //   * Introduce a new buffer with an extra dimension of a size equal to the
528   //   span of the loop OUTER_REDUCTION_FOR (the new buffer is returned via
529   //   RFAC_BUF_PTR)
530   //   * Insert an initialization store for the new buffer in
531   //   OUTER_REDUCTION_FOR before its nested loop
532   //   * Replace the reduction store to the original buffer with the reduction
533   //   store to the temp buffer, removing the index var of OUTER_REDUCTION_FOR
534   //   from reduction arguments
535   //   * Insert a final reduction store over the extra dimension of the new
536   //   buffer to the original buffer
537   //   * Returns TRUE if the transformation succeeded and FALSE otherwise
538   //
539   // Example:
540   // Original IR:
541   // S1: for i      # normal axis
542   // S2:   X[i] = 0
543   // S3:   for j    # reduction axis
544   // S4:     for k  # reduction axis
545   // S5:       X[i] = ReduceOp(X[i] + Y[i,j,k], reduce_axis={j,k})
546   //
547   // After RFACTOR(S5, S3)
548   // S1: for i               # normal axis
549   // S2:   X[i] = 0
550   // S3:   for j             # reduction axis for X, normal axis for X_rfac
551   //         X_rfac[i,j] = 0
552   // S4:     for k           # reduction axis
553   //           X_rfac[i,j] = ReduceOp(X_rfac[i,j] + Y[i,j,k], reduce_axis={k})
554   //         X[i] = ReduceOp(X[i] + X_rfac[i,j], reduce_axis={j})
555   static bool rfactor(const StmtPtr& s, const ForPtr& outer_reduction_for);
556   static bool rfactor(
557       const StmtPtr& s,
558       const ForPtr& outer_reduction_for,
559       BufPtr* rfac_buf_ptr);
560 
561   // Vectorize the given loop. This method requires that the given loop
562   // does not perform a reduction.
563   // It returns true if vectorization is successful and false otherwise.
564   static bool vectorize(const ForPtr&);
565 
566   // Find the inner-most loops and vectorize them. Currently, this only works
567   // for the LLVM backend, when no reductions are involved.
568   void vectorizeInnerLoops();
569 
570   void eliminateDeadStores();
571 
572   void prepareForCodegen();
573 
574   const std::unordered_set<BufPtr> getInputBufs() const;
getOutputBufs()575   const std::unordered_set<BufPtr> getOutputBufs() const {
576     return output_bufs_;
577   }
578   std::vector<BufPtr> getIntermediateBufs() const;
579 
580   // Finds which is the outer For between a and b for loops. If neither of the 2
581   // Fors is an ancestor of the other, it returns nullptr.
582   static ForPtr findOuterFor(ForPtr a, ForPtr b);
583 
584  private:
585   void initialize(
586       const std::vector<Tensor>& output_tensors,
587       const std::vector<Tensor>& tensors_to_compute);
588 
589   StmtPtr root_stmt_;
590 
591   std::unordered_set<BufPtr> output_bufs_;
592 };
593 
594 TORCH_API StmtPtr FlattenIndexes(const StmtPtr& s);
595 
596 // TODO: Revisit this once we decide on how dependencies analysis should look
597 // like. Maybe we would choose to use a different API and BufUse would be
598 // removed, or if we decide to keep it we need to properly document its API.
599 struct BufLoadOrStoreUse {
600   StmtPtr s;
601   bool isStore;
602 };
603 
604 /*
605  * Returns a map ( Buf -> uses of this Buf), uses are represented as vectors of
606  * BufUse elements, which are StmtPtr and a bool isStore flag. The order of uses
607  * in the vectors reflects the order in which the uses appear in the given
608  * statement.
609  */
610 std::unordered_map<BufPtr, std::vector<BufLoadOrStoreUse>> findLoadOrStoreUses(
611     const StmtPtr& s);
612 
613 // replaces all invalid characters with underscore
614 TORCH_API std::string sanitizeName(const std::string& input_name);
615 
616 } // namespace torch::jit::tensorexpr
617