1 #pragma once 2 3 #include <torch/csrc/jit/ir/ir.h> 4 5 namespace torch::jit { 6 7 // return true if graph is modified 8 TORCH_API bool UnrollLoops(std::shared_ptr<Graph>& graph); 9 10 // Only unrolls constant loops. Will unroll them regardless of loop block size 11 TORCH_API bool UnrollConstantLoops(std::shared_ptr<Graph>& graph); 12 13 TORCH_API Node* PeelLoop(Node* n, size_t times); 14 15 // return true if graph is modified 16 TORCH_API bool PeelProfilingLoops(const std::shared_ptr<Graph>& graph); 17 18 struct TORCH_API LoopsPeeler { 19 LoopsPeeler(std::function<bool(Node* n)> callback, size_t num_iterations = 1) callback_LoopsPeeler20 : callback_(std::move(callback)), num_iterations_(num_iterations) {} 21 22 bool run(const std::shared_ptr<Graph>& graph); 23 24 private: 25 void collectLoop(Node* n); 26 void collectLoops(Block* block); 27 void peelLoops(); 28 29 std::function<bool(Node* n)> callback_ = nullptr; 30 Node* in_loop_ = nullptr; 31 std::list<Node*> loops_to_peel_; 32 size_t num_iterations_ = 1; 33 }; 34 } // namespace torch::jit 35