1 #pragma once 2 3 #include <torch/csrc/jit/ir/ir.h> 4 5 /* `getCustomPrePasses()` returns a vector of passes that will be executed 6 * after differentiation but before any fusion. This is the de-facto location 7 * for compiler backends to insert passes. 8 * 9 * `getCustomPostPasses()` returns a vector of passes that will be 10 * executed after differentiation and after fusion (if any). This is the 11 * location for fusion cleanup passes if they are needed. 12 * 13 * Static registration of a pass can be done by creating a global 14 * `Register{Pre,Post}Pass r(Pass)` variable in a compilation unit. 15 * 16 * pass_manager.h uses a Meyer's singleton to store a vector of `Pass`es, which 17 * modify the IR graph in place. 18 */ 19 20 namespace torch::jit { 21 22 // A pass modifies a Graph in place. 23 using GraphPass = std::function<void(std::shared_ptr<Graph>&)>; 24 25 // Since Passes are std::functions, we associate a UUID to each pass, this way 26 // if we want to deregister a pass, we have something to reference it by. 27 using GraphPassNameType = unsigned int; 28 29 // Graph pass entries have a name associated with them 30 using GraphPassEntry = std::pair<GraphPass, GraphPassNameType>; 31 32 // Return currently registered passes. Passes are stored in a static vector 33 TORCH_API std::vector<std::pair<GraphPass, GraphPassNameType>>& 34 getCustomPostPasses(); 35 TORCH_API std::vector<std::pair<GraphPass, GraphPassNameType>>& 36 getCustomPrePasses(); 37 38 TORCH_API GraphPassNameType registerPostPass(GraphPass p); 39 TORCH_API GraphPassNameType registerPrePass(GraphPass p); 40 41 // Look up pass by name passed in, remove it from registered passes 42 TORCH_API void clearPostPass(GraphPassNameType p); 43 TORCH_API void clearPrePass(GraphPassNameType p); 44 45 // Remove all passes 46 TORCH_API void clearAllPostPasses(); 47 TORCH_API void clearAllPrePasses(); 48 49 // LEGACY CALL 50 struct TORCH_API RegisterPostPass { 51 RegisterPostPass(GraphPass p); 52 }; 53 54 using RegisterPass = RegisterPostPass; 55 56 /* 57 * PassManager is a wrapper on the register/clear PostPass functions above. It 58 * will register the pass provided in "registerPass" and will hold on to its 59 * associated name that way clearPass can be later called and will delete the 60 * pass used to register when called. 61 * 62 * PassManager is templated because we want static variables based on a 63 * particular GraphPass. When deriving from PassManager, you should send as the 64 * template parameter your derived class as you would for the curiously 65 * recurring template pattern. This template parameter isn't actually used and 66 * is simply done to prevent static members from being shared across derived 67 * types. 68 */ 69 template <typename DerivedType> 70 struct C10_EXPORT PassManager { 71 private: 72 // We want this class to be abstract because it's 73 virtual void abstract() = 0; 74 75 protected: 76 /* 77 * isRegistered() will return if a pass has been registered 78 * isRegistered(true) will change the value of the internal static bool 79 * 80 * There's an internal static bool to this function to keep track of the 81 * state, this is so when functions are derived from this class, they don't 82 * have to worry about initializing the static members. 83 */ 84 static bool isRegistered(bool flip_bit = false) { 85 static bool val = false; 86 if (flip_bit) 87 val = !val; 88 return val; 89 } 90 91 /* 92 * name() will return the name of the registered pass 93 * name(pass_name, true) will set the name of the pass 94 * Similarly to isRegistered we use an internal static variable to hold the 95 * name. 96 */ 97 static GraphPassNameType passID( 98 GraphPassNameType PassID = 0, 99 bool set = false) { 100 static GraphPassNameType pass_id = 0; 101 if (set) 102 pass_id = PassID; 103 return pass_id; 104 } 105 106 public: 107 // registerPass(pass) will register the pass provided and set the 108 // name/isRegistered functions appropriately, it returns a bool value 109 // indicating whether the given pass is already registered previously. registerPassPassManager110 static bool registerPass(GraphPass p) { 111 if (!isRegistered()) { 112 // If we don't already have a registered pass, register pass 113 // hold on to its name, change isRegistered to true 114 passID(registerPostPass(std::move(p)), true); 115 isRegistered(true); 116 return false; 117 } 118 return true; 119 } 120 121 // Calls ClearPostPass(passID()) clearPassPassManager122 static void clearPass() { 123 // If the pass is registered, clear it and change isRegistered to false. 124 if (isRegistered()) { 125 clearPostPass(passID()); 126 isRegistered(true); 127 } 128 } 129 130 // clang-tidy requires virtual destructor; 131 virtual ~PassManager() = default; 132 }; 133 134 } // namespace torch::jit 135