1 #include <torch/csrc/jit/passes/pass_manager.h> 2 3 namespace torch::jit { 4 5 // Start UUID at 1 6 static GraphPassNameType graphPassID = 1; 7 getCustomPostPasses()8std::vector<GraphPassEntry>& getCustomPostPasses() { 9 static std::vector<GraphPassEntry> passes; 10 return passes; 11 } 12 getCustomPrePasses()13std::vector<GraphPassEntry>& getCustomPrePasses() { 14 static std::vector<GraphPassEntry> passes; 15 return passes; 16 } 17 registerPostPass(GraphPass p)18GraphPassNameType registerPostPass(GraphPass p) { 19 getCustomPostPasses().emplace_back(std::move(p), graphPassID); 20 return graphPassID++; 21 } 22 registerPass(GraphPass p)23static GraphPassNameType registerPass(GraphPass p) { 24 return registerPostPass(std::move(p)); 25 } 26 registerPrePass(GraphPass p)27GraphPassNameType registerPrePass(GraphPass p) { 28 getCustomPrePasses().emplace_back(std::move(p), graphPassID); 29 return graphPassID++; 30 } 31 clearPostPass(GraphPassNameType pid)32void clearPostPass(GraphPassNameType pid) { 33 auto& passes = getCustomPostPasses(); 34 auto it = passes.begin(); 35 for (; it != passes.end(); it++) { 36 if (pid == (*it).second) 37 break; 38 } 39 if (it != passes.end()) 40 passes.erase(it); 41 } 42 clearPrePass(GraphPassNameType pid)43void clearPrePass(GraphPassNameType pid) { 44 auto& passes = getCustomPrePasses(); 45 auto it = passes.begin(); 46 for (; it != passes.end(); it++) { 47 if (pid == (*it).second) 48 break; 49 } 50 if (it != passes.end()) 51 passes.erase(it); 52 } 53 clearAllPostPasses()54void clearAllPostPasses() { 55 auto& passes = getCustomPostPasses(); 56 passes.erase(passes.begin(), passes.end()); 57 } 58 clearAllPrePasses()59void clearAllPrePasses() { 60 auto& passes = getCustomPrePasses(); 61 passes.erase(passes.begin(), passes.end()); 62 } 63 64 // LEGACY CALL RegisterPostPass(GraphPass p)65RegisterPostPass::RegisterPostPass(GraphPass p) { 66 registerPass(std::move(p)); 67 } 68 69 } // namespace torch::jit 70