1 #pragma once 2 #include <ATen/core/jit_type.h> 3 #include <torch/csrc/jit/mobile/debug_info.h> 4 #include <torch/csrc/jit/mobile/function.h> 5 #include <torch/csrc/jit/mobile/method.h> 6 #include <torch/csrc/jit/mobile/quantization.h> 7 8 #include <utility> 9 10 namespace torch::jit::mobile { 11 using Stack = std::vector<c10::IValue>; 12 13 // A CompilationUnit object is the one that gets executed by the lite 14 // interpreter. 15 // 16 // A CompilationUnit object contains a list of Method Objects. These are methods 17 // that appear in the original PyTorch Model. These method correspond to Python 18 // member functions of the Model class. 19 // 20 // Methods in turn contain a Function, and a back-pointer to the Module that 21 // owns this Method instance. 22 // 23 // A Function contains a Code Object (code_) which is defined in interpreter.h 24 // 25 // A Code object contains the following: 26 // 27 // std::vector<Instruction> instructions_; 28 // std::vector<c10::OperatorName> op_names_; 29 // std::vector<std::function<void(Stack&)>> operators_; 30 // std::vector<c10::IValue> constants_; 31 // std::vector<c10::TypePtr> types_; 32 // size_t register_size_; // Aggregated output size. 33 // 34 class CompilationUnit { 35 public: 36 void register_function(std::unique_ptr<Function> fn); methods()37 std::vector<std::unique_ptr<Function>>& methods() { 38 return methods_; 39 } methods()40 const std::vector<std::unique_ptr<Function>>& methods() const { 41 return methods_; 42 } 43 Function* find_function(const c10::QualifiedName& qn); 44 const Function* find_function(const c10::QualifiedName& qn) const; 45 unsafeRemoveFunction(const int64_t index)46 void unsafeRemoveFunction(const int64_t index) { 47 methods_.erase(methods_.begin() + index); 48 } 49 50 private: 51 std::vector<std::unique_ptr<Function>> methods_; 52 }; 53 54 // A Torch Mobile Module is a representation of the model (trained in case 55 // of inference). A Mobile Module contains 56 // 57 // 1. data (object_) 58 // 2. metadata (optional) about the model (metadata_ from the metadata.pkl 59 // file added after training) 60 // 3. Compilation Unit (cu_) 61 // 62 class TORCH_API Module { 63 public: Module(c10::intrusive_ptr<c10::ivalue::Object> object,std::shared_ptr<CompilationUnit> cu)64 Module( 65 c10::intrusive_ptr<c10::ivalue::Object> object, 66 std::shared_ptr<CompilationUnit> cu) 67 : object_(std::move(object)), cu_(std::move(cu)) {} 68 Module() = default; 69 Method get_method(const std::string& method_name) const; 70 template <typename... Types> run_method(const std::string & method_name,Types &&...args)71 c10::IValue run_method(const std::string& method_name, Types&&... args) { 72 return get_method(method_name)({IValue(std::forward<Types>(args))...}); 73 } forward(std::vector<c10::IValue> inputs)74 c10::IValue forward(std::vector<c10::IValue> inputs) { 75 return get_method("forward")(std::move(inputs)); 76 } 77 std::optional<Method> find_method(const std::string& basename) const; 78 name()79 const std::string name() const { 80 return object_->name(); 81 } slots()82 const std::vector<at::IValue>& slots() const { 83 return object_->slots(); 84 } _ivalue()85 const c10::intrusive_ptr<c10::ivalue::Object> _ivalue() const { 86 return object_; 87 } 88 const std::vector<at::Tensor> parameters() const; 89 const std::map<std::string, at::Tensor> named_parameters() const; 90 std::string get_forward_method_debug_info(int64_t debug_handle) const; 91 std::string getModuleHierarchy(const int64_t debug_handle) const; 92 std::string getCallStack(const int64_t debug_handle) const; 93 /// Enables "training" mode. 94 void train(bool on = true); 95 /// Calls train(false) to enable "eval" mode. eval()96 void eval() { 97 train(/*on=*/false); 98 } 99 /// True if the module is in training mode. 100 bool is_training() const; getMetadata()101 const std::unordered_map<std::string, std::string> getMetadata() const { 102 return metadata_; 103 } setMetadata(const std::unordered_map<std::string,std::string> & metadata)104 void setMetadata( 105 const std::unordered_map<std::string, std::string>& metadata) { 106 metadata_ = metadata; 107 } 108 const std::vector<Method> get_methods() const; 109 attr(const std::string & name,c10::IValue or_else)110 c10::IValue attr(const std::string& name, c10::IValue or_else) const { 111 if (auto r = object_->type()->findAttributeSlot(name)) { 112 return object_->getSlot(*r); 113 } 114 if (auto r = object_->type()->findConstantSlot(name)) { 115 return object_->type()->getConstant(*r); 116 } 117 return or_else; 118 } 119 setDebugTable(MobileDebugTable && debug_table)120 void setDebugTable(MobileDebugTable&& debug_table) { 121 debug_table_ = std::move(debug_table); 122 } getDebugTable()123 const MobileDebugTable& getDebugTable() const { 124 return debug_table_; 125 } 126 setHasDebugHandles(bool has_debug_handles)127 void setHasDebugHandles(bool has_debug_handles) { 128 has_debug_handles_ = has_debug_handles; 129 } 130 hasDebugHandles()131 bool hasDebugHandles() const { 132 return has_debug_handles_; 133 } 134 compilation_unit()135 const CompilationUnit& compilation_unit() const { 136 return *cu_; 137 } 138 set_delete_memory(std::shared_ptr<char> delete_mem)139 void set_delete_memory(std::shared_ptr<char> delete_mem) { 140 mem_to_delete_ = std::move(delete_mem); 141 } 142 set_min_operator_version(int64_t version)143 void set_min_operator_version(int64_t version) { 144 min_operator_version_ = version; 145 } 146 min_operator_version()147 int64_t min_operator_version() const { 148 return min_operator_version_; 149 } 150 set_bytecode_version(int64_t version)151 void set_bytecode_version(int64_t version) { 152 bytecode_version_ = version; 153 } 154 bytecode_version()155 int64_t bytecode_version() const { 156 return bytecode_version_; 157 } 158 159 private: 160 friend class quantization::PTQQuanizationHelper; 161 162 bool compareMethodSchemas( 163 const std::string& name_1, 164 const std::string& name_2); 165 166 void unsafeRemoveMethod(const std::string& basename); 167 168 void unsafeCopyMethod( 169 const std::string& new_method_name, 170 const Function& to_be_copied); 171 172 c10::intrusive_ptr<c10::ivalue::Object> object_; 173 std::unordered_map<std::string, std::string> metadata_; 174 std::shared_ptr<CompilationUnit> cu_; 175 MobileDebugTable debug_table_; 176 bool has_debug_handles_ = false; 177 int64_t min_operator_version_ = 4; 178 int64_t bytecode_version_ = 4; 179 180 // Extra handle for the module to delete when itself is deleted 181 std::shared_ptr<char> mem_to_delete_; 182 }; 183 184 struct TORCH_API ModuleInfo { 185 uint64_t bytecode_version; 186 uint64_t operator_version; 187 std::unordered_map<std::string, int> opname_to_num_args; 188 std::unordered_set<std::string> function_names; 189 std::unordered_set<std::string> type_names; 190 }; 191 TORCH_API ModuleInfo get_module_info(const mobile::Module& module); 192 193 } // namespace torch::jit::mobile 194