xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/passes/pass_manager.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <torch/csrc/jit/ir/ir.h>
4 
/* `getCustomPrePasses()` returns a vector of passes that will be executed
 * after differentiation but before any fusion. This is the de facto location
 * for compiler backends to insert passes.
 *
 * `getCustomPostPasses()` returns a vector of passes that will be
 * executed after differentiation and after fusion (if any). This is the
 * location for fusion cleanup passes if they are needed.
 *
 * Static registration of a pass can be done by creating a global
 * `Register{Pre,Post}Pass r(Pass)` variable in a compilation unit.
 *
 * pass_manager.h uses a Meyers singleton (a function-local static) to store
 * each vector of `Pass`es; the passes themselves modify the IR graph in place.
 */
19 
20 namespace torch::jit {
21 
22 // A pass modifies a Graph in place.
23 using GraphPass = std::function<void(std::shared_ptr<Graph>&)>;
24 
25 // Since Passes are std::functions, we associate a UUID to each pass, this way
26 // if we want to deregister a pass, we have something to reference it by.
27 using GraphPassNameType = unsigned int;
28 
29 // Graph pass entries have a name associated with them
30 using GraphPassEntry = std::pair<GraphPass, GraphPassNameType>;
31 
32 // Return currently registered passes. Passes are stored in a static vector
33 TORCH_API std::vector<std::pair<GraphPass, GraphPassNameType>>&
34 getCustomPostPasses();
35 TORCH_API std::vector<std::pair<GraphPass, GraphPassNameType>>&
36 getCustomPrePasses();
37 
38 TORCH_API GraphPassNameType registerPostPass(GraphPass p);
39 TORCH_API GraphPassNameType registerPrePass(GraphPass p);
40 
41 // Look up pass by name passed in, remove it from registered passes
42 TORCH_API void clearPostPass(GraphPassNameType p);
43 TORCH_API void clearPrePass(GraphPassNameType p);
44 
45 // Remove all passes
46 TORCH_API void clearAllPostPasses();
47 TORCH_API void clearAllPrePasses();
48 
49 // LEGACY CALL
50 struct TORCH_API RegisterPostPass {
51   RegisterPostPass(GraphPass p);
52 };
53 
54 using RegisterPass = RegisterPostPass;
55 
56 /*
57  * PassManager is a wrapper on the register/clear PostPass functions above. It
58  * will register the pass provided in "registerPass" and will hold on to its
59  * associated name that way clearPass can be later called and will delete the
60  * pass used to register when called.
61  *
62  * PassManager is templated because we want static variables based on a
63  * particular GraphPass. When deriving from PassManager, you should send as the
64  * template parameter your derived class as you would for the curiously
65  * recurring template pattern. This template parameter isn't actually used and
66  * is simply done to prevent static members from being shared across derived
67  * types.
68  */
69 template <typename DerivedType>
70 struct C10_EXPORT PassManager {
71  private:
72   // We want this class to be abstract because it's
73   virtual void abstract() = 0;
74 
75  protected:
76   /*
77    * isRegistered() will return if a pass has been registered
78    * isRegistered(true) will change the value of the internal static bool
79    *
80    * There's an internal static bool to this function to keep track of the
81    * state, this is so when functions are derived from this class, they don't
82    * have to worry about initializing the static members.
83    */
84   static bool isRegistered(bool flip_bit = false) {
85     static bool val = false;
86     if (flip_bit)
87       val = !val;
88     return val;
89   }
90 
91   /*
92    * name() will return the name of the registered pass
93    * name(pass_name, true) will set the name of the pass
94    * Similarly to isRegistered we use an internal static variable to hold the
95    * name.
96    */
97   static GraphPassNameType passID(
98       GraphPassNameType PassID = 0,
99       bool set = false) {
100     static GraphPassNameType pass_id = 0;
101     if (set)
102       pass_id = PassID;
103     return pass_id;
104   }
105 
106  public:
107   // registerPass(pass) will register the pass provided and set the
108   // name/isRegistered functions appropriately, it returns a bool value
109   // indicating whether the given pass is already registered previously.
registerPassPassManager110   static bool registerPass(GraphPass p) {
111     if (!isRegistered()) {
112       // If we don't already have a registered pass, register pass
113       // hold on to its name, change isRegistered to true
114       passID(registerPostPass(std::move(p)), true);
115       isRegistered(true);
116       return false;
117     }
118     return true;
119   }
120 
121   // Calls ClearPostPass(passID())
clearPassPassManager122   static void clearPass() {
123     // If the pass is registered, clear it and change isRegistered to false.
124     if (isRegistered()) {
125       clearPostPass(passID());
126       isRegistered(true);
127     }
128   }
129 
130   // clang-tidy requires virtual destructor;
131   virtual ~PassManager() = default;
132 };
133 
134 } // namespace torch::jit
135