// Source: /aosp_15_r20/external/pytorch/torch/csrc/jit/passes/pass_manager.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
#include <torch/csrc/jit/passes/pass_manager.h>

#include <algorithm>
#include <utility>
#include <vector>

3 namespace torch::jit {
4 
5 // Start UUID at 1
6 static GraphPassNameType graphPassID = 1;
7 
getCustomPostPasses()8 std::vector<GraphPassEntry>& getCustomPostPasses() {
9   static std::vector<GraphPassEntry> passes;
10   return passes;
11 }
12 
getCustomPrePasses()13 std::vector<GraphPassEntry>& getCustomPrePasses() {
14   static std::vector<GraphPassEntry> passes;
15   return passes;
16 }
17 
registerPostPass(GraphPass p)18 GraphPassNameType registerPostPass(GraphPass p) {
19   getCustomPostPasses().emplace_back(std::move(p), graphPassID);
20   return graphPassID++;
21 }
22 
registerPass(GraphPass p)23 static GraphPassNameType registerPass(GraphPass p) {
24   return registerPostPass(std::move(p));
25 }
26 
registerPrePass(GraphPass p)27 GraphPassNameType registerPrePass(GraphPass p) {
28   getCustomPrePasses().emplace_back(std::move(p), graphPassID);
29   return graphPassID++;
30 }
31 
clearPostPass(GraphPassNameType pid)32 void clearPostPass(GraphPassNameType pid) {
33   auto& passes = getCustomPostPasses();
34   auto it = passes.begin();
35   for (; it != passes.end(); it++) {
36     if (pid == (*it).second)
37       break;
38   }
39   if (it != passes.end())
40     passes.erase(it);
41 }
42 
clearPrePass(GraphPassNameType pid)43 void clearPrePass(GraphPassNameType pid) {
44   auto& passes = getCustomPrePasses();
45   auto it = passes.begin();
46   for (; it != passes.end(); it++) {
47     if (pid == (*it).second)
48       break;
49   }
50   if (it != passes.end())
51     passes.erase(it);
52 }
53 
clearAllPostPasses()54 void clearAllPostPasses() {
55   auto& passes = getCustomPostPasses();
56   passes.erase(passes.begin(), passes.end());
57 }
58 
clearAllPrePasses()59 void clearAllPrePasses() {
60   auto& passes = getCustomPrePasses();
61   passes.erase(passes.begin(), passes.end());
62 }
63 
64 // LEGACY CALL
RegisterPostPass(GraphPass p)65 RegisterPostPass::RegisterPostPass(GraphPass p) {
66   registerPass(std::move(p));
67 }
68 
69 } // namespace torch::jit
70