// xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/passes/loop_unrolling.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
#pragma once

#include <torch/csrc/jit/ir/ir.h>

#include <cstddef>
#include <functional>
#include <list>
#include <memory>
#include <utility>

5 namespace torch::jit {
6 
7 // return true if graph is modified
8 TORCH_API bool UnrollLoops(std::shared_ptr<Graph>& graph);
9 
10 // Only unrolls constant loops. Will unroll them regardless of loop block size
11 TORCH_API bool UnrollConstantLoops(std::shared_ptr<Graph>& graph);
12 
13 TORCH_API Node* PeelLoop(Node* n, size_t times);
14 
15 // return true if graph is modified
16 TORCH_API bool PeelProfilingLoops(const std::shared_ptr<Graph>& graph);
17 
18 struct TORCH_API LoopsPeeler {
19   LoopsPeeler(std::function<bool(Node* n)> callback, size_t num_iterations = 1)
callback_LoopsPeeler20       : callback_(std::move(callback)), num_iterations_(num_iterations) {}
21 
22   bool run(const std::shared_ptr<Graph>& graph);
23 
24  private:
25   void collectLoop(Node* n);
26   void collectLoops(Block* block);
27   void peelLoops();
28 
29   std::function<bool(Node* n)> callback_ = nullptr;
30   Node* in_loop_ = nullptr;
31   std::list<Node*> loops_to_peel_;
32   size_t num_iterations_ = 1;
33 };
34 } // namespace torch::jit
35