#include <torch/csrc/jit/mobile/module.h>

#include <torch/csrc/jit/backends/backend_exception.h>
#include <torch/csrc/jit/mobile/interpreter.h>
#include <torch/csrc/jit/mobile/observer.h>
#include <torch/csrc/jit/mobile/type_parser.h>
#include <torch/csrc/jit/runtime/jit_exception.h>

#include <ATen/record_function.h>
#include <c10/util/ScopeExit.h>
#include <c10/util/irange.h>

namespace torch::jit {
std::ostream& operator<<(std::ostream& out, Instruction inst);
namespace mobile {

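// Takes ownership of a lowered function and appends it to this unit's method
// list. Lookup is linear, by qualified name, via find_function() below.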
void CompilationUnit::register_function(std::unique_ptr<Function> fn) {
  methods_.emplace_back(std::move(fn));
}

const Function* CompilationUnit::find_function(
    const c10::QualifiedName& qn) const {
  for (auto& fn : methods_) {
    if (fn->qualname() == qn) {
      return fn.get();
    }
  }
  return nullptr;
}

Function* CompilationUnit::find_function(const c10::QualifiedName& qn) {
  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
  return const_cast<Function*>(
      static_cast<const CompilationUnit*>(this)->find_function(qn));
}

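// Usage sketch (hypothetical model file and input shape; _load_for_mobile is
// declared in torch/csrc/jit/mobile/import.h):
//   mobile::Module m = _load_for_mobile("model.ptl");
//   std::vector<c10::IValue> inputs{torch::ones({1, 10})};
//   c10::IValue out = m.get_method("forward")(std::move(inputs));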
Method Module::get_method(const std::string& name) const {
  if (auto method = find_method(name)) {
    return *method;
  }
  AT_ERROR("Method '", name, "' is not defined.");
}

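// Returns true only if both method names resolve to functions in this
// module's compilation unit and their schemas compare equal; returns false
// if either method is missing.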
bool Module::compareMethodSchemas(
    const std::string& name_1,
    const std::string& name_2) {
  std::optional<c10::FunctionSchema> schema_1, schema_2;
  for (const auto& fn : cu_->methods()) {
    if (fn->name() == name_1) {
      schema_1 = fn->getSchema();
    }
    if (fn->name() == name_2) {
      schema_2 = fn->getSchema();
    }
  }
  if (schema_1.has_value() && schema_2.has_value()) {
    return (schema_1 == schema_2);
  }
  return false;
}

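// "Unsafe" because the caller must guarantee that a method named `basename`
// exists: if it does not, the loop below runs past the end and the index
// handed to unsafeRemoveFunction() is out of range.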
void Module::unsafeRemoveMethod(const std::string& basename) {
  int64_t i = 0;
  for (; i < static_cast<int64_t>(cu_->methods().size()); ++i) {
    if ((cu_->methods()[i])->name() == basename) {
      break;
    }
  }
  object_->type()->unsafeRemoveMethod(basename);
  cu_->unsafeRemoveFunction(i);
}

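// Registers a copy of `to_be_copied` under `new_method_name`, keeping the
// original function's qualified-name prefix, code, and schema. Fails if a
// method with the new name already exists.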
void Module::unsafeCopyMethod(
    const std::string& new_method_name,
    const Function& to_be_copied) {
  TORCH_CHECK(
      !find_method(new_method_name).has_value(),
      "Trying to replace existing method.");
  const c10::QualifiedName& tobe_copied_name = to_be_copied.qualname();
  c10::QualifiedName qualified_method_name(
      tobe_copied_name.prefix(), new_method_name);
  std::unique_ptr<Function> new_fn = std::make_unique<Function>(
      qualified_method_name, to_be_copied.get_code(), to_be_copied.getSchema());
  object_->type()->addMethod(new_fn.get());
  cu_->register_function(std::move(new_fn));
}

std::optional<Method> Module::find_method(const std::string& basename) const {
  for (const auto& fn : cu_->methods()) {
    if (fn->name() == basename) {
      return std::make_optional<Method>(Method(this, fn.get()));
    }
  }
  return std::nullopt;
}

namespace {
// For JIT, there is a private function to get all modules by iteration in
// struct slot_iterator_impl (jit/api/module.h). The following function uses
// recursion to mimic that logic without allocating extra memory for a module
// list, and sets the training attribute directly.
void set_train_recurse(
    const c10::intrusive_ptr<c10::ivalue::Object>& obj,
    bool on) {
  if (auto slot = obj->type()->findAttributeSlot("training")) {
    obj->setSlot(*slot, on);
  } else {
    TORCH_INTERNAL_ASSERT(
        false,
        "'training' attribute not found. Did you accidentally "
        "call .eval() before saving your model?");
  }
  for (const auto& slot : obj->slots()) {
    // slots is a list of IValue. Continue setting the training attribute only
    // if the slot is an object and a module.
    if (slot.isObject() && slot.toObjectRef().type()->is_module()) {
      set_train_recurse(slot.toObject(), on);
    }
  }
}

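// Depth-first collection of every tensor slot reachable from `obj`,
// appending each one to `params` in traversal order.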
void slot_params_recurse(
    const c10::intrusive_ptr<c10::ivalue::Object>& obj,
    std::vector<at::Tensor>* params) {
  for (const auto& slot : obj->slots()) {
    if (slot.isTensor()) {
      params->emplace_back(slot.toTensor());
    } else if (slot.isObject()) {
      slot_params_recurse(slot.toObject(), params);
    }
  }
}

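// Like slot_params_recurse, but records dotted attribute paths (e.g.
// "submodule.weight") as keys. See the requires_grad caveat in the TODO
// below.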
void slot_named_params_recurse(
    const c10::intrusive_ptr<c10::ivalue::Object>& obj,
    std::map<std::string, at::Tensor>* params,
    const std::string& parent_name) {
  auto slots = obj->slots();
  size_t nslots = slots.size();
  for (const auto i : c10::irange(nslots)) {
    auto slot = slots[i];
    std::string name = parent_name.empty() ? parent_name : parent_name + ".";
    name += obj->type()->getAttributeName(i);
    // TODO: Fix this filter. requires_grad is not the appropriate filter for
    // a parameter; it is a temporary hack to help probable users of this API.
    // The correct behavior is to filter by obj->type()->is_parameter(), but
    // that currently always returns false on mobile.
    if (slot.isTensor() && slot.toTensor().requires_grad()) {
      (*params)[name] = slot.toTensor();
    } else if (slot.isObject()) {
      slot_named_params_recurse(slot.toObject(), params, name);
    }
  }
}

#if defined(SYMBOLICATE_MOBILE_DEBUG_HANDLE)
std::string getTopModuleTypeName(const Module& m) {
  std::string name;
  if (m._ivalue()->type() && m._ivalue()->type()->name()) {
    name = m._ivalue()->type()->name().value().name();
  }
  return name;
}
#endif

} // namespace

const std::vector<at::Tensor> Module::parameters() const {
  std::vector<at::Tensor> params;
  slot_params_recurse(object_, &params);
  return params;
}

// Returns a mapping of all attributes with requires_grad=True in a module.
// This behavior differs from full TorchScript modules. This is a bug,
// but currently there is no way to correctly label parameters in the
// loading of a mobile module. TODO
const std::map<std::string, at::Tensor> Module::named_parameters() const {
  std::map<std::string, at::Tensor> params;
  const std::string name = "";
  slot_named_params_recurse(object_, &params, name);
  return params;
}

std::string Module::getModuleHierarchy(const int64_t debug_handle) const {
#if defined(SYMBOLICATE_MOBILE_DEBUG_HANDLE)
  return getDebugTable().getModuleHierarchyInfo(
      debug_handle, getTopModuleTypeName(*this));
#else
  return "";
#endif
}

std::string Module::getCallStack(const int64_t debug_handle) const {
#if defined(SYMBOLICATE_MOBILE_DEBUG_HANDLE)
  return getDebugTable().getSourceDebugString(
      debug_handle, getTopModuleTypeName(*this));
#else
  return "";
#endif
}

// We will continue to support this API for now, as profiling still relies on
// it. This really needs to change: the first step toward profiling support
// for delegates will be to rewrite how profiling is done for the lite
// interpreter.
std::string Module::get_forward_method_debug_info(int64_t debug_handle) const {
#if defined(SYMBOLICATE_MOBILE_DEBUG_HANDLE)
  return getDebugTable().getModuleHierarchyInfo(
      debug_handle, getTopModuleTypeName(*this));
#else
  return "";
#endif
}

void Module::train(bool on) {
  set_train_recurse(object_, on);
}

bool Module::is_training() const {
  if (auto slot = object_->type()->findAttributeSlot("training")) {
    return object_->getSlot(*slot).toBool();
  }
  return true;
}

const std::vector<Method> Module::get_methods() const {
  std::vector<Method> methods;
  for (std::unique_ptr<Function>& fn : cu_->methods()) {
    methods.emplace_back(this, fn.get());
  }
  return methods;
}

Method::Method(const Module* owner, Function* function)
    : owner_(owner), function_(function) {}

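// Execution flow: notify the module observer (if any) on entry, install a
// DebugInfoGuard tagged with the model and method names, and arm a scope
// guard that reports failures to the observer. The owning module's ivalue is
// pushed at the front of the stack as `self` before the function runs;
// exceptions are annotated with symbolicated debug info (when built with
// SYMBOLICATE_MOBILE_DEBUG_HANDLE) and rethrown.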
void Method::run(Stack& stack) const {
  auto observer = torch::observerConfig().getModuleObserver();
  // NOLINTNEXTLINE(clang-analyzer-security.insecureAPI.rand)
  auto instance_key = std::rand();
  // Copy the owner's metadata; the "model_name" entry (empty if absent) is
  // used below to tag the mobile debug info for this run.
  std::unordered_map<std::string, std::string> copied_metadata =
      owner_->getMetadata();

  if (observer) {
    observer->onEnterRunMethod(instance_key);
  }

  auto debug_info = std::make_shared<MobileDebugInfo>();
  std::string name = copied_metadata["model_name"];
  debug_info->setModelName(name);
  debug_info->setMethodName(function_->name());
  at::DebugInfoGuard guard(at::DebugInfoKind::MOBILE_RUNTIME_INFO, debug_info);

  std::string error_message;
  auto failure_guard = c10::make_scope_exit([&]() {
    if (!observer) {
      return;
    }

#if defined(SYMBOLICATE_MOBILE_DEBUG_HANDLE)
    if (error_message.empty()) {
      error_message = owner_->getDebugTable().getSourceDebugString(
          function_->getExceptionDebugHandles(), getTopModuleTypeName(*owner_));
    }
#endif

    observer->onFailRunMethod(
        copied_metadata,
        function_->name(),
        instance_key,
        error_message.empty() ? "Unknown exception" : error_message.c_str());
  });

  try {
    stack.insert(stack.begin(), owner_->_ivalue()); // self
    function_->run(stack);
    if (observer) {
      observer->onExitRunMethod(
          copied_metadata, function_->name(), instance_key);
    }
    failure_guard.release();
    // This exception must be caught first as it derives from c10::Error
  } catch (c10::BackendRuntimeException& e) {
#if defined(SYMBOLICATE_MOBILE_DEBUG_HANDLE)
    for (auto handle : function_->getExceptionDebugHandles()) {
      e.pushDebugHandle(handle);
    }
    // symbolicate all handles
    auto debug_string = owner_->getDebugTable().getSourceDebugString(
        e.getDebugHandles(), getTopModuleTypeName(*owner_));
    e.add_context(debug_string);
#endif
    error_message = e.what();
    TORCH_RETHROW(e);
  } catch (c10::Error& error) {
#if defined(SYMBOLICATE_MOBILE_DEBUG_HANDLE)
    auto debug_string = owner_->getDebugTable().getSourceDebugString(
        function_->getExceptionDebugHandles(), getTopModuleTypeName(*owner_));
    error.add_context(debug_string);
#endif
    error_message = error.what();
    TORCH_RETHROW(error);
  }
}

c10::IValue Method::operator()(std::vector<c10::IValue> stack) const {
  run(stack);
  TORCH_INTERNAL_ASSERT(!stack.empty());
  return stack.front();
}

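// Type printer passed to annotation_str() in get_module_info() below: prints
// the qualified name for named types and the fallback annotation for
// DynamicType; returning nullopt defers to the default printer.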
static std::optional<std::string> print_type(const c10::Type& t) {
  auto namedType = t.cast<c10::NamedType>();
  if (namedType && namedType->name()) {
    return namedType->name().value().qualifiedName();
  }
  if (auto dyn = t.castRaw<c10::DynamicType>()) {
    return dyn->fallback()->annotation_str();
  }
  return std::nullopt;
}

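// Usage sketch (hypothetical model path; assumes the module was saved for the
// lite interpreter):
//   mobile::Module m = _load_for_mobile("model.ptl");
//   ModuleInfo info = get_module_info(m);
//   // info.opname_to_num_args maps each operator name to its input count;
//   // info.type_names lists the types referenced by the bytecode.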
TORCH_API ModuleInfo get_module_info(const mobile::Module& module) {
  ModuleInfo minfo;
  minfo.operator_version = module.min_operator_version();
  minfo.bytecode_version = module.bytecode_version();
  std::vector<std::string> type_name_list;
  for (const auto& func_ptr : module.compilation_unit().methods()) {
    const auto& function = *func_ptr;
    for (const auto i : c10::irange(function.get_code().op_names_.size())) {
      const auto& op = function.get_code().op_names_[i];
      minfo.opname_to_num_args[mobile::operator_str(op)] =
          function.get_code().operator_input_sizes_[i];
    }
    for (const c10::TypePtr& tp : function.get_code().types_) {
      type_name_list.push_back(tp->annotation_str(print_type));
    }
    minfo.function_names.insert(function.qualname().qualifiedName());
  }
  c10::TypeParser parser(type_name_list);
  parser.parseList();
  minfo.type_names = parser.getContainedTypes();
  return minfo;
}

} // namespace mobile
} // namespace torch::jit