#pragma once

#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>

#include <torch/csrc/Export.h>
#include <torch/csrc/jit/tensorexpr/fwd_decls.h>

namespace torch::jit::tensorexpr {

class Expr;
class Var;
class Buf;
class Tensor;
class Function;
class Stmt;
class For;
class Block;
class Store;
class Dtype;

class TORCH_API LoopNest {
 public:
  // A constructor for building a LoopNest from a list of Tensors
  LoopNest(
      const std::vector<Tensor>& output_tensors,
      const std::vector<Tensor>& tensors_to_compute);

  // A convenience constructor for the case when all tensors are output
  // tensors
  LoopNest(const std::vector<Tensor>& output_tensors);

  // A constructor for building a LoopNest from a Stmt and a list of output
  // buffers.
  LoopNest(StmtPtr stmt, std::unordered_set<BufPtr> output_bufs);

  // A constructor for building a LoopNest from another loopnest. It clones
  // the other loopnest's stmt.
  LoopNest(const LoopNest& other);

  StmtPtr root_stmt() const {
    return root_stmt_;
  }

  std::vector<ForPtr> getLoopStmtsFor(const Tensor&) const;
  std::vector<ForPtr> getLoopStmtsFor(const BufPtr&) const;
  std::vector<ForPtr> getLoopStmtsFor(StmtPtr) const;
  StmtPtr getLoopBodyFor(const Tensor&) const;
  StmtPtr getLoopBodyFor(BufPtr) const;

  // Returns the For stmt indexed by 'indices' in the 'root' For stmt.
  // 'indices' indicates the path from 'root' to the returned loop in the
  // AST, e.g.,
  //
  // root: for (int i...) {
  //   j_loop: for (int j...) {
  //     k1_loop: for (int k1...) {
  //       A[i, j, k1] = ....
  //     }
  //     B[i, j] = ...
  //     k2_loop: for (int k2...) {
  //       A[i, j, k2] = ...
  //     }
  //   }
  // }
  //
  // the path from 'root' to 'j_loop' is [0]
  // the path from 'root' to 'k1_loop' is [0, 0]
  // the path from 'root' to 'k2_loop' is [0, 2]
  ForPtr getLoopAt(ForPtr root, const std::vector<int>& indices) const;
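
  // Example use of getLoopAt (an illustrative sketch; 'ln' and 'root' are
  // hypothetical names for a LoopNest and its outermost For stmt):
  //
  //   ForPtr j_loop = ln.getLoopAt(root, {0});
  //   ForPtr k2_loop = ln.getLoopAt(root, {0, 2});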

  // Returns the For stmt that immediately encloses the given stmt.
  static ForPtr getParentLoop(const StmtPtr& st);

  // Returns the list of For stmts corresponding to the loopnest that
  // encloses the given stmt.
  static std::vector<ForPtr> getEnclosingLoopNest(const StmtPtr& st);

  // Returns a list of all Stmts that write to the given buf.
  std::vector<StmtPtr> getAllWritesToBuf(BufPtr) const;

  // The following methods return the For loops that contain writes to
  // the given buf.
  //
  // For example, consider the following code:
  //   for i1
  //     for j1
  //       a[i1,j1] =
  //   for i2
  //     for j2
  //       for k2
  //         a[i2,j2] =
  //     for j3
  //       a[i2,j3] =

  // Returns a list of For loops which directly contain a Stmt that writes
  // to buf.
  // For the above example:
  //   getAllInnermostLoopsWritingToBuf(a) => {j1, k2, j3}
  std::vector<ForPtr> getAllInnermostLoopsWritingToBuf(BufPtr) const;

  // Returns a list of For loopnests which contain a Stmt that writes to
  // the given buf. Each loopnest here is a vector of For loops.
  // For the above example:
  //   getAllLoopNestsWritingToBuf(a) => {{i1,j1}, {i2,j2,k2}, {i2,j3}}
  std::vector<std::vector<ForPtr>> getAllLoopNestsWritingToBuf(BufPtr) const;

  StmtPtr simplify();

  // Sanitizes variable and buffer names.
  // The pass assigns predefined names to loop index variables
  // (i,j,k,l,m,n,o,p,i1,j1,k1,...) and ensures these names do not conflict
  // anywhere. It also removes duplicates from other Buf and Var names, and
  // replaces illegal characters in them with underscores.
  //
  // Note: since it's currently technically possible to use the same variable
  // as an index in two different loops, this transformation finds such cases
  // and introduces new variables to avoid duplication.
  static StmtPtr sanitizeNames(StmtPtr s);

  bool computeInline(const StmtPtr& s);
  bool computeInline(const BufPtr& b);
  void inlineIntermediateBufs(bool allow_duplicated_work);

  // Optimizes conditionals.
  //
  // Currently, only the following pattern of conditionals is optimized.
  // This corresponds to the conditional format that is generated to handle
  // the `aten::cat` op.
  //
  //   for (int i = 0; i < 20; i++) {
  //     A[i] = IfThenElse(i<5 ? 1 : 0, B[i], C[i-5])
  //   }
  //
  // Constraints that must be satisfied for this optimization:
  //   * All conditions should be of the form "var < expr".
  //   * All conditions should have the same variable, say v.
  //   * The condition variable found should be the same as the inner-most
  //     loop variable. TODO: Remove this constraint.
  //   * If there are multiple stores that contain conditionals using the
  //     same loop variable, only the first conditional will be optimized.
  //     TODO: Remove this constraint.
  bool optimizeConditionals();

  // Splits the given loop into 2 nested loops with the given factor as the
  // inner loop bound. If the factor does not evenly divide the loop bound,
  // then the remaining iterations are extracted into a tail loop that is
  // added after the given loop.
  //
  // For example, consider the following code:
  //   for (int i = 0; i < 100; ++i) {
  //     A[i] =
  //   }
  //
  // splitWithTail(i, 8, ...) will result in:
  //   for (int i_outer = 0; i_outer < 12; ++i_outer) {
  //     for (int i_inner = 0; i_inner < 8; ++i_inner) {
  //       A[i_outer * 8 + i_inner] =
  //     }
  //   }
  //   for (int i_tail = 0; i_tail < 4; ++i_tail) {
  //     A[i_tail + 96] =
  //   }
  //
  // The given loop will be transformed to the outer loop after splitting.
  // So, the pointer to the input loop should be valid after splitting and
  // will point to the outer loop. The `inner` and `tail` parameters will be
  // set to point to the inner and tail loops that are generated.
  static void splitWithTail(
      const ForPtr& f,
      int factor,
      ForPtr* inner,
      ForPtr* tail);
  // A convenience wrapper when the caller does not need to access the
  // split loops.
  static void splitWithTail(const ForPtr& f, int factor);
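
  // Example use of splitWithTail (an illustrative sketch; 'loop' is a
  // hypothetical ForPtr, e.g. obtained from getLoopStmtsFor):
  //
  //   ForPtr inner, tail;
  //   LoopNest::splitWithTail(loop, 8, &inner, &tail);
  //   // 'loop' now points to the outer loop; 'inner' and 'tail' point to
  //   // the generated inner and tail loops.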

  // Splits the given loop into 2 nested loops with the given factor as the
  // inner loop bound. If the factor does not evenly divide the loop bound,
  // then a conditional is inserted into the body to handle the remaining
  // iterations appropriately.
  //
  // For example, consider the following code:
  //   for (int i = 0; i < 100; ++i) {
  //     A[i] =
  //   }
  //
  // splitWithMask(i, 8, ...) will result in:
  //   for (int i_outer = 0; i_outer < 13; ++i_outer) {
  //     for (int i_inner = 0; i_inner < 8; ++i_inner) {
  //       if (i_outer * 8 + i_inner < 100) {
  //         A[i_outer * 8 + i_inner] =
  //       }
  //     }
  //   }
  //
  // The given loop will be transformed to the outer loop after splitting.
  // So, the pointer to the input loop should be valid after splitting and
  // will point to the outer loop. The `inner` parameter will be set to point
  // to the inner loop that is generated.
  static void splitWithMask(const ForPtr& f, int factor, ForPtr* inner);
  // A convenience wrapper when the caller does not need to access the
  // split loops.
  static void splitWithMask(const ForPtr& f, int factor);

  // The following methods support loop distribution.
  // For example, consider the following code. This will be used to
  // demonstrate the methods below.
  //
  // S0: for m
  // S1:   for i
  // S2:     A[i] = 0
  // S3:     for j
  // S4:       A[i] = A[i] +
  // S5:     B[i] = A[i]
  // S6:     for k
  // S7:       B[i] = B[i] +

  // This method distributes the given loop over its body by splitting
  // after every given pivot stmt.
  //
  // NOTE: Pivot stmts that are not in the given loop's body will be ignored.
  //
  // For the above example:
  //   distributeLoop(S1, {S3, S5})
  // will result in:
  // S0: for m
  // S1:   for i
  // S2:     A[i] = 0
  // S3:     for j
  // S4:       A[i] = A[i] +
  //   :   for i
  // S5:     B[i] = A[i]
  //   :   for i
  // S6:     for k
  // S7:       B[i] = B[i] +
  static std::vector<ForPtr> distributeLoop(
      const ForPtr& loop,
      const std::unordered_set<StmtPtr>& pivots);

  // This method distributes the given loop over every stmt in its body.
  //
  // For the above example:
  //   distributeLoop(S1)
  // will result in:
  // S0: for m
  // S1:   for i
  // S2:     A[i] = 0
  //   :   for i
  // S3:     for j
  // S4:       A[i] = A[i] +
  //   :   for i
  // S5:     B[i] = A[i]
  //   :   for i
  // S6:     for k
  // S7:       B[i] = B[i] +
  static std::vector<ForPtr> distributeLoop(const ForPtr& loop);
  // Same as above, but also distributes parent loops.
  // Returns the result of distributing the outermost loop.
  //
  // For the above example:
  //   distributeLoopAndParents(S1)
  // will result in:
  // S0: for m
  // S1:   for i
  // S2:     A[i] = 0
  //   : for m
  //   :   for i
  // S3:     for j
  // S4:       A[i] = A[i] +
  //   : for m
  //   :   for i
  // S5:     B[i] = A[i]
  //   : for m
  //   :   for i
  // S6:     for k
  // S7:       B[i] = B[i] +
  static std::vector<ForPtr> distributeLoopAndParents(const ForPtr& loop);

  // This method distributes the given loop over its body by splitting
  // after every For stmt in its body.
  //
  // For the above example:
  //   distributeLoopOverInnerLoops(S1)
  // will result in:
  // S0: for m
  // S1:   for i
  // S2:     A[i] = 0
  // S3:     for j
  // S4:       A[i] = A[i] +
  //   :   for i
  // S5:     B[i] = A[i]
  // S6:     for k
  // S7:       B[i] = B[i] +
  static std::vector<ForPtr> distributeLoopOverInnerLoops(const ForPtr& loop);
  // Same as above, but also distributes parent loops.
  // Returns the result of distributing the outermost loop.
  //
  // For the above example:
  //   distributeLoopAndParentsOverInnerLoops(S1)
  // will result in:
  // S0: for m
  // S1:   for i
  // S2:     A[i] = 0
  // S3:     for j
  // S4:       A[i] = A[i] +
  //   : for m
  //   :   for i
  // S5:     B[i] = A[i]
  // S6:     for k
  // S7:       B[i] = B[i] +
  static std::vector<ForPtr> distributeLoopAndParentsOverInnerLoops(
      const ForPtr& loop);
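
  // Example use of the distribution methods above (an illustrative sketch;
  // S1, S3, and S5 are hypothetical handles to the stmts labeled in the
  // running example):
  //
  //   std::vector<ForPtr> new_loops = LoopNest::distributeLoop(S1, {S3, S5});
  //   // 'new_loops' is expected to hold the "for i" loops that replace S1.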

  // This method performs loop fusion.
  // For example, consider the following code.
  //
  // S1: for m
  // S2:   A[m] = 0
  // S3:   for j
  // S4:     A[m] = A[m] +
  // S5: for n
  // S6:   B[n] = A[n]
  // S7:   for k
  // S8:     B[n] = B[n] +
  //
  // fuseLoops({S1, S5}) will return the following loop:
  // S1: for m
  // S2:   A[m] = 0
  // S3:   for j
  // S4:     A[m] = A[m] +
  // S6:   B[m] = A[m]
  // S7:   for k
  // S8:     B[m] = B[m] +
  //
  // This transformation is unsafe as it simply adds all loops into the body
  // of the first loop for fusion, without any correctness checks.
  //
  // Below are the two requirements to apply unsafeFuseLoops:
  //   * All the loops have the same parent.
  //   * There are no statements between these loops in their parent body.
  static bool unsafeFuseLoops(const std::vector<ForPtr>& loops, ForPtr* fused);

  // Loop fusion is done only when all the conditions below are satisfied.
  //   * All the loops have the same parent.
  //   * There are no statements between these loops in their parent body.
  //   * The start bounds are the same for all loops.
  //   * The stop bounds are the same for all loops.
  //   * Fusing the loops does not violate or add any dependencies.
  static bool fuseLoops(const std::vector<ForPtr>& loops, ForPtr* fused);

  static void reorderAxis(const ForPtr& a, const ForPtr& b);

  // Reorders the given list of loops according to the specified permutation.
  // Here `permutation[i]` represents the position of the loop in the input
  // which will end up at position `i` after the reorder.
  //
  // For example, consider the following code:
  //   for p
  //     for q
  //       for r
  //         for s
  //           A[p,q,r,s] =
  //
  // reorder({p, q, r, s}, {2, 3, 0, 1}) will return the list of loops in the
  // following form:
  //   for r
  //     for s
  //       for p
  //         for q
  //           A[p,q,r,s] =
  static std::vector<ForPtr> reorder(
      const std::vector<ForPtr>& loops,
      const std::vector<size_t>& permutation);
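
  // Example use of reorder (an illustrative sketch; 'loops' is a
  // hypothetical vector holding the p, q, r, s loops in that order):
  //
  //   std::vector<ForPtr> reordered = LoopNest::reorder(loops, {2, 3, 0, 1});
  //   // 'reordered' is expected to be {r, s, p, q}, outermost first.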

  // Tile takes a 2d domain (x, y) and splits it into small rectangular
  // blocks, each with shape (x_factor, y_factor). The traversal over the
  // domain turns into an outer iteration over the blocks and an inner
  // traversal over all points in the block.
  // Note that if x dim % x_factor or y dim % y_factor is not 0, the
  // transformation will generate corresponding tail loops.
  // The transformation is in-place and returns 'xtail'.
  //
  // For example, consider the following code:
  //   for i: [0, 64)
  //     for j: [0, 64)
  //       for k: [0, 32)
  //         A[i, j] = B[i, k] + C[j, k]
  //
  // tile(i, j, 4, 8) will transform the "i" for-stmt into the following
  // nested loop:
  //   for i_outer: [0, 16)
  //     for j_outer: [0, 8)
  //       for i_inner: [0, 4)
  //         for j_inner: [0, 8)
  //           for k: [0, 32)
  //             A[i_outer * 4 + i_inner, j_outer * 8 + j_inner] =
  //                 B[i_outer * 4 + i_inner, k] + C[j_outer * 8 + j_inner, k]
  //
  // tile(i, j, 4, 9) will transform the "i" for-stmt into the following
  // nested loop:
  //   for i_outer: [0, 16)
  //     for j_outer: [0, 7)
  //       for i_inner: [0, 4)
  //         for j_inner: [0, 9)
  //           for k: [0, 32)
  //             A[i_outer * 4 + i_inner, j_outer * 9 + j_inner] =
  //                 B[i_outer * 4 + i_inner, k] + C[j_outer * 9 + j_inner, k]
  //     for j_tail: [0, 1)
  //       for i_inner: [0, 4)
  //         for k: [0, 32)
  //           A[i_outer * 4 + i_inner, 7 * 9 + j_tail] =
  //               B[i_outer * 4 + i_inner, k] + C[7 * 9 + j_tail, k]
  ForPtr tile(const ForPtr& x, const ForPtr& y, int x_factor, int y_factor);

  // Returns true if the given loops are perfectly nested, i.e., every loop
  // (except the innermost) has exactly one statement in its body, and that
  // statement is the next inner loop.
  static bool areLoopsPerfectlyNested(const std::vector<ForPtr>& loops);

  // Returns true if the given loop has a loop-carried dependence.
  static bool hasLoopCarriedDependence(const ForPtr& loop);

  // Unrolls all the iterations of the given loop.
  // Requires that the loop bounds are constant.
  static void fullUnroll(const ForPtr& f, StmtPtr* unrolled);
  static void fullUnroll(const ForPtr& f);

  // Unrolls the given loop by the specified factor.
  // This does not require constant bounds for the loop being unrolled.
  static void unroll(const ForPtr& f, int factor, ForPtr* tail);
  static void unroll(const ForPtr& f, int factor);

  static bool normalize(const ForPtr& f);
  static bool isNormalized(const ForPtr& f);

  static bool flatten(const std::vector<ForPtr>& f, ForPtr* flattened);
  static bool flatten(const std::vector<ForPtr>& f);
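
  // Example use of flatten (an illustrative sketch; 'i' and 'j' are
  // hypothetical ForPtrs forming a perfectly nested pair of loops):
  //
  //   ForPtr flattened;
  //   if (LoopNest::flatten({i, j}, &flattened)) {
  //     // 'flattened' is expected to be a single loop covering the
  //     // combined iteration space of i and j.
  //   }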

  // Compresses the given buffer based on its use in the given Stmts.
  //
  // NOTE: This API assumes that there are no accesses to the given buffer
  // outside the given statement. So, this should be called with the entire
  // kernel statement to avoid incorrect buffer compressions.
  //
  // For example, given the input:
  //
  //   for (int i = 0; i < 100; ++i) {
  //     for (int j = 0; j < 200; ++j) {
  //       A[i,j] = sin(i*j)
  //     }
  //     for (int j = 0; j < 199; ++j) {
  //       B[i,j] = A[i,j] + A[i, j+1]
  //     }
  //   }
  //
  // compressBuffer(A, ...) will compress buffer A from
  // [100, 200] to [1, 200] and modify the code as follows:
  //
  //   for (int i = 0; i < 100; ++i) {
  //     for (int j = 0; j < 200; ++j) {
  //       A[0,j] = sin(i*j)
  //     }
  //     for (int j = 0; j < 199; ++j) {
  //       B[i,j] = A[0,j] + A[0, j+1]
  //     }
  //   }
  static void compressBuffer(const BufPtr& buf, const StmtPtr& stmt);

  // Compresses all buffers in the given statement.
  //
  // NOTE: This API assumes that there are no accesses to buffers outside
  // the given statement. So, this should be called with the entire
  // kernel statement to avoid incorrect buffer compressions.
  //
  // TODO: Add an IR verifier check to detect invalidly compressed buffers.
  static void compressAllBuffers(const StmtPtr& stmt);

  // Gets 'num' loops from the loopnest starting at 'f'.
  static std::vector<ForPtr> getLoopStmtsInLoopNest(
      const ForPtr& f,
      size_t num);

  // LoopOptions are propagated to tail.
  static void sliceHead(
      const ForPtr& f,
      int factor,
      ForPtr* head,
      ForPtr* tail);
  static void sliceHead(const ForPtr& f, int factor);
  // LoopOptions are propagated to head.
  static void sliceTail(
      const ForPtr& f,
      int factor,
      ForPtr* head,
      ForPtr* tail);
  static void sliceTail(const ForPtr& f, int factor);

  using AccessResult = std::pair<BufPtr, StmtPtr>;
  // Inserts a cache for the accesses to buffer 'producer' within the stmt
  // 'consumer', and redirects reads and writes in the consumer to that
  // cache.
  // Returns a pair of the new cache buffer and the new rewritten consumer.
  static AccessResult cacheAccesses(
      const BufPtr& producer,
      const std::string& name,
      const StmtPtr& consumer);
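
  // Example use of cacheAccesses (an illustrative sketch; 'a_buf' and
  // 'consumer_stmt' are hypothetical handles to a buffer and the stmt that
  // accesses it):
  //
  //   LoopNest::AccessResult res =
  //       LoopNest::cacheAccesses(a_buf, "A_local", consumer_stmt);
  //   BufPtr cache_buf = res.first; // the new cache buffer
  //   StmtPtr new_consumer = res.second; // the rewritten consumer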

  // Inserts a temporary computation of statement S in the scope of loop AT.
  // S is assumed to be a Store or a Block containing a Store. Along with the
  // computation itself, this transformation inserts Alloc/Free statements
  // for the temporary buffer used in the computation.
  static void computeAt(const StmtPtr& s, const ForPtr& at);

  // Rfactors a reduction axis into a normal axis.
  //
  // Requirements:
  //   * S is the reduction store
  //   * S is the only statement in the innermost loop
  //   * There are at least two reduction arguments in S
  //   * The OUTER_REDUCTION_FOR loop corresponds to the outermost reduction
  //     variable used in the store, and all other reduction variables are
  //     index variables of children loops of OUTER_REDUCTION_FOR
  //   * OUTER_REDUCTION_FOR is a perfect loop nest, i.e. it has only loops
  //     corresponding to the other reduction variables and the store, nested
  //     into each other
  //
  // What it does:
  //   * Introduces a new buffer with an extra dimension of a size equal to
  //     the span of the loop OUTER_REDUCTION_FOR (the new buffer is returned
  //     via RFAC_BUF_PTR)
  //   * Inserts an initialization store for the new buffer in
  //     OUTER_REDUCTION_FOR before its nested loop
  //   * Replaces the reduction store to the original buffer with a reduction
  //     store to the temp buffer, removing the index var of
  //     OUTER_REDUCTION_FOR from the reduction arguments
  //   * Inserts a final reduction store over the extra dimension of the new
  //     buffer to the original buffer
  //   * Returns TRUE if the transformation succeeded and FALSE otherwise
  //
  // Example:
  // Original IR:
  // S1: for i      # normal axis
  // S2:   X[i] = 0
  // S3:   for j    # reduction axis
  // S4:     for k  # reduction axis
  // S5:       X[i] = ReduceOp(X[i] + Y[i,j,k], reduce_axis={j,k})
  //
  // After RFACTOR(S5, S3):
  // S1: for i      # normal axis
  // S2:   X[i] = 0
  // S3:   for j    # reduction axis for X, normal axis for X_rfac
  //         X_rfac[i,j] = 0
  // S4:     for k  # reduction axis
  //           X_rfac[i,j] = ReduceOp(X_rfac[i,j] + Y[i,j,k], reduce_axis={k})
  //         X[i] = ReduceOp(X[i] + X_rfac[i,j], reduce_axis={j})
  static bool rfactor(const StmtPtr& s, const ForPtr& outer_reduction_for);
  static bool rfactor(
      const StmtPtr& s,
      const ForPtr& outer_reduction_for,
      BufPtr* rfac_buf_ptr);
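
  // Example use of rfactor (an illustrative sketch; 'S5' and 'S3' are
  // hypothetical handles to the store and the loop labeled above):
  //
  //   BufPtr rfac_buf;
  //   if (LoopNest::rfactor(S5, S3, &rfac_buf)) {
  //     // 'rfac_buf' holds X_rfac, the buffer with the extra dimension.
  //   }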

  // Vectorizes the given loop. This method requires that the given loop
  // does not perform a reduction.
  // It returns true if vectorization is successful and false otherwise.
  static bool vectorize(const ForPtr&);

  // Finds the inner-most loops and vectorizes them. Currently, this only
  // works for the LLVM backend, when no reductions are involved.
  void vectorizeInnerLoops();

  void eliminateDeadStores();

  void prepareForCodegen();

  const std::unordered_set<BufPtr> getInputBufs() const;
  const std::unordered_set<BufPtr> getOutputBufs() const {
    return output_bufs_;
  }
  std::vector<BufPtr> getIntermediateBufs() const;

  // Finds which of the two given For loops is the outer one. Returns
  // nullptr if neither of the two Fors is an ancestor of the other.
  static ForPtr findOuterFor(ForPtr a, ForPtr b);

 private:
  void initialize(
      const std::vector<Tensor>& output_tensors,
      const std::vector<Tensor>& tensors_to_compute);

  StmtPtr root_stmt_;

  std::unordered_set<BufPtr> output_bufs_;
};

TORCH_API StmtPtr FlattenIndexes(const StmtPtr& s);

// TODO: Revisit this once we decide on how dependency analysis should look.
// Maybe we would choose to use a different API and BufLoadOrStoreUse would
// be removed, or, if we decide to keep it, we need to properly document its
// API.
struct BufLoadOrStoreUse {
  StmtPtr s;
  bool isStore;
};

/*
 * Returns a map (Buf -> uses of this Buf), where the uses are represented as
 * vectors of BufLoadOrStoreUse elements, each consisting of a StmtPtr and a
 * bool isStore flag. The order of uses in the vectors reflects the order in
 * which the uses appear in the given statement.
 */
std::unordered_map<BufPtr, std::vector<BufLoadOrStoreUse>> findLoadOrStoreUses(
    const StmtPtr& s);

// Replaces all invalid characters with underscores.
TORCH_API std::string sanitizeName(const std::string& input_name);

} // namespace torch::jit::tensorexpr