xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/mobile/module.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 #include <ATen/core/jit_type.h>
3 #include <torch/csrc/jit/mobile/debug_info.h>
4 #include <torch/csrc/jit/mobile/function.h>
5 #include <torch/csrc/jit/mobile/method.h>
6 #include <torch/csrc/jit/mobile/quantization.h>
7 
8 #include <utility>
9 
10 namespace torch::jit::mobile {
11 using Stack = std::vector<c10::IValue>;
12 
13 // A CompilationUnit object is the one that gets executed by the lite
14 // interpreter.
15 //
16 // A CompilationUnit object contains a list of Method Objects. These are methods
17 // that appear in the original PyTorch Model. These method correspond to Python
18 // member functions of the Model class.
19 //
20 // Methods in turn contain a Function, and a back-pointer to the Module that
21 // owns this Method instance.
22 //
23 // A Function contains a Code Object (code_) which is defined in interpreter.h
24 //
25 // A Code object contains the following:
26 //
27 // std::vector<Instruction> instructions_;
28 // std::vector<c10::OperatorName> op_names_;
29 // std::vector<std::function<void(Stack&)>> operators_;
30 // std::vector<c10::IValue> constants_;
31 // std::vector<c10::TypePtr> types_;
32 // size_t register_size_; // Aggregated output size.
33 //
34 class CompilationUnit {
35  public:
36   void register_function(std::unique_ptr<Function> fn);
methods()37   std::vector<std::unique_ptr<Function>>& methods() {
38     return methods_;
39   }
methods()40   const std::vector<std::unique_ptr<Function>>& methods() const {
41     return methods_;
42   }
43   Function* find_function(const c10::QualifiedName& qn);
44   const Function* find_function(const c10::QualifiedName& qn) const;
45 
unsafeRemoveFunction(const int64_t index)46   void unsafeRemoveFunction(const int64_t index) {
47     methods_.erase(methods_.begin() + index);
48   }
49 
50  private:
51   std::vector<std::unique_ptr<Function>> methods_;
52 };
53 
54 // A Torch Mobile Module is a representation of the model (trained in case
55 // of inference). A Mobile Module contains
56 //
57 // 1. data (object_)
58 // 2. metadata (optional) about the model (metadata_ from the metadata.pkl
59 //    file added after training)
60 // 3. Compilation Unit (cu_)
61 //
62 class TORCH_API Module {
63  public:
Module(c10::intrusive_ptr<c10::ivalue::Object> object,std::shared_ptr<CompilationUnit> cu)64   Module(
65       c10::intrusive_ptr<c10::ivalue::Object> object,
66       std::shared_ptr<CompilationUnit> cu)
67       : object_(std::move(object)), cu_(std::move(cu)) {}
68   Module() = default;
69   Method get_method(const std::string& method_name) const;
70   template <typename... Types>
run_method(const std::string & method_name,Types &&...args)71   c10::IValue run_method(const std::string& method_name, Types&&... args) {
72     return get_method(method_name)({IValue(std::forward<Types>(args))...});
73   }
forward(std::vector<c10::IValue> inputs)74   c10::IValue forward(std::vector<c10::IValue> inputs) {
75     return get_method("forward")(std::move(inputs));
76   }
77   std::optional<Method> find_method(const std::string& basename) const;
78 
name()79   const std::string name() const {
80     return object_->name();
81   }
slots()82   const std::vector<at::IValue>& slots() const {
83     return object_->slots();
84   }
_ivalue()85   const c10::intrusive_ptr<c10::ivalue::Object> _ivalue() const {
86     return object_;
87   }
88   const std::vector<at::Tensor> parameters() const;
89   const std::map<std::string, at::Tensor> named_parameters() const;
90   std::string get_forward_method_debug_info(int64_t debug_handle) const;
91   std::string getModuleHierarchy(const int64_t debug_handle) const;
92   std::string getCallStack(const int64_t debug_handle) const;
93   /// Enables "training" mode.
94   void train(bool on = true);
95   /// Calls train(false) to enable "eval" mode.
eval()96   void eval() {
97     train(/*on=*/false);
98   }
99   /// True if the module is in training mode.
100   bool is_training() const;
getMetadata()101   const std::unordered_map<std::string, std::string> getMetadata() const {
102     return metadata_;
103   }
setMetadata(const std::unordered_map<std::string,std::string> & metadata)104   void setMetadata(
105       const std::unordered_map<std::string, std::string>& metadata) {
106     metadata_ = metadata;
107   }
108   const std::vector<Method> get_methods() const;
109 
attr(const std::string & name,c10::IValue or_else)110   c10::IValue attr(const std::string& name, c10::IValue or_else) const {
111     if (auto r = object_->type()->findAttributeSlot(name)) {
112       return object_->getSlot(*r);
113     }
114     if (auto r = object_->type()->findConstantSlot(name)) {
115       return object_->type()->getConstant(*r);
116     }
117     return or_else;
118   }
119 
setDebugTable(MobileDebugTable && debug_table)120   void setDebugTable(MobileDebugTable&& debug_table) {
121     debug_table_ = std::move(debug_table);
122   }
getDebugTable()123   const MobileDebugTable& getDebugTable() const {
124     return debug_table_;
125   }
126 
setHasDebugHandles(bool has_debug_handles)127   void setHasDebugHandles(bool has_debug_handles) {
128     has_debug_handles_ = has_debug_handles;
129   }
130 
hasDebugHandles()131   bool hasDebugHandles() const {
132     return has_debug_handles_;
133   }
134 
compilation_unit()135   const CompilationUnit& compilation_unit() const {
136     return *cu_;
137   }
138 
set_delete_memory(std::shared_ptr<char> delete_mem)139   void set_delete_memory(std::shared_ptr<char> delete_mem) {
140     mem_to_delete_ = std::move(delete_mem);
141   }
142 
set_min_operator_version(int64_t version)143   void set_min_operator_version(int64_t version) {
144     min_operator_version_ = version;
145   }
146 
min_operator_version()147   int64_t min_operator_version() const {
148     return min_operator_version_;
149   }
150 
set_bytecode_version(int64_t version)151   void set_bytecode_version(int64_t version) {
152     bytecode_version_ = version;
153   }
154 
bytecode_version()155   int64_t bytecode_version() const {
156     return bytecode_version_;
157   }
158 
159  private:
160   friend class quantization::PTQQuanizationHelper;
161 
162   bool compareMethodSchemas(
163       const std::string& name_1,
164       const std::string& name_2);
165 
166   void unsafeRemoveMethod(const std::string& basename);
167 
168   void unsafeCopyMethod(
169       const std::string& new_method_name,
170       const Function& to_be_copied);
171 
172   c10::intrusive_ptr<c10::ivalue::Object> object_;
173   std::unordered_map<std::string, std::string> metadata_;
174   std::shared_ptr<CompilationUnit> cu_;
175   MobileDebugTable debug_table_;
176   bool has_debug_handles_ = false;
177   int64_t min_operator_version_ = 4;
178   int64_t bytecode_version_ = 4;
179 
180   // Extra handle for the module to delete when itself is deleted
181   std::shared_ptr<char> mem_to_delete_;
182 };
183 
// Plain summary record describing a mobile Module; produced by
// get_module_info().
//
// The two version fields are zero-initialized so that a default-constructed
// ModuleInfo never exposes indeterminate values. The struct remains an
// aggregate, so brace initialization of all five fields keeps working.
struct TORCH_API ModuleInfo {
  uint64_t bytecode_version{0};
  uint64_t operator_version{0};
  // NOTE(review): presumably operator name -> number of arguments; confirm
  // against get_module_info().
  std::unordered_map<std::string, int> opname_to_num_args;
  // NOTE(review): presumably the names of the module's functions; confirm.
  std::unordered_set<std::string> function_names;
  // NOTE(review): presumably the names of types used by the module; confirm.
  std::unordered_set<std::string> type_names;
};
191 TORCH_API ModuleInfo get_module_info(const mobile::Module& module);
192 
193 } // namespace torch::jit::mobile
194