1 #include <torch/csrc/jit/mobile/module.h>
2
3 #include <torch/csrc/jit/backends/backend_exception.h>
4 #include <torch/csrc/jit/mobile/interpreter.h>
5 #include <torch/csrc/jit/mobile/observer.h>
6 #include <torch/csrc/jit/mobile/type_parser.h>
7 #include <torch/csrc/jit/runtime/jit_exception.h>
8
9 #include <ATen/record_function.h>
10 #include <c10/util/ScopeExit.h>
11 #include <c10/util/irange.h>
12
13 namespace torch::jit {
14 std::ostream& operator<<(std::ostream& out, Instruction inst);
15 namespace mobile {
16
register_function(std::unique_ptr<Function> fn)17 void CompilationUnit::register_function(std::unique_ptr<Function> fn) {
18 methods_.emplace_back(std::move(fn));
19 }
20
find_function(const c10::QualifiedName & qn) const21 const Function* CompilationUnit::find_function(
22 const c10::QualifiedName& qn) const {
23 for (auto& fn : methods_) {
24 if (fn->qualname() == qn) {
25 return fn.get();
26 }
27 }
28 return nullptr;
29 }
30
find_function(const c10::QualifiedName & qn)31 Function* CompilationUnit::find_function(const c10::QualifiedName& qn) {
32 // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
33 return const_cast<Function*>(
34 static_cast<const CompilationUnit*>(this)->find_function(qn));
35 }
36
get_method(const std::string & name) const37 Method Module::get_method(const std::string& name) const {
38 if (auto method = find_method(name)) {
39 return *method;
40 }
41 AT_ERROR("Method '", name, "' is not defined.");
42 }
43
compareMethodSchemas(const std::string & name_1,const std::string & name_2)44 bool Module::compareMethodSchemas(
45 const std::string& name_1,
46 const std::string& name_2) {
47 std::optional<c10::FunctionSchema> schema_1, schema_2;
48 for (const auto& fn : cu_->methods()) {
49 if (fn->name() == name_1) {
50 schema_1 = fn->getSchema();
51 }
52 if (fn->name() == name_2) {
53 schema_2 = fn->getSchema();
54 }
55 }
56 if (schema_1.has_value() && schema_2.has_value()) {
57 return (schema_1 == schema_2);
58 }
59 return false;
60 }
61
unsafeRemoveMethod(const std::string & basename)62 void Module::unsafeRemoveMethod(const std::string& basename) {
63 int64_t i = 0;
64 for (; i < static_cast<int64_t>(cu_->methods().size()); ++i) {
65 if ((cu_->methods()[i])->name() == basename) {
66 break;
67 }
68 }
69 object_->type()->unsafeRemoveMethod(basename);
70 cu_->unsafeRemoveFunction(i);
71 }
72
unsafeCopyMethod(const std::string & new_method_name,const Function & to_be_copied)73 void Module::unsafeCopyMethod(
74 const std::string& new_method_name,
75 const Function& to_be_copied) {
76 TORCH_CHECK(
77 !find_method(new_method_name).has_value(),
78 "Trying to replace existing method.");
79 const c10::QualifiedName& tobe_copied_name = to_be_copied.qualname();
80 c10::QualifiedName qualified_method_name(
81 tobe_copied_name.prefix(), new_method_name);
82 std::unique_ptr<Function> new_fn = std::make_unique<Function>(
83 qualified_method_name, to_be_copied.get_code(), to_be_copied.getSchema());
84 object_->type()->addMethod(new_fn.get());
85 cu_->register_function(std::move(new_fn));
86 }
87
find_method(const std::string & basename) const88 std::optional<Method> Module::find_method(const std::string& basename) const {
89 for (const auto& fn : cu_->methods()) {
90 if (fn->name() == basename) {
91 return std::make_optional<Method>(Method(this, fn.get()));
92 }
93 }
94 return std::nullopt;
95 }
96
97 namespace {
98 // For JIT, there is a private function to get all modules by iteration in
99 // struct slot_iterator_impl (jit/api/module.h). The following function use
100 // recursion to mimic the logic without allocating extra memory to get module
101 // list and set training attribute directly.
set_train_recurse(const c10::intrusive_ptr<c10::ivalue::Object> & obj,bool on)102 void set_train_recurse(
103 const c10::intrusive_ptr<c10::ivalue::Object>& obj,
104 bool on) {
105 if (auto slot = obj->type()->findAttributeSlot("training")) {
106 obj->setSlot(*slot, on);
107 } else {
108 TORCH_INTERNAL_ASSERT(
109 false,
110 "'training' attribute not found. Did you accidentally "
111 "call .eval() before saving your model?");
112 }
113 for (const auto& slot : obj->slots()) {
114 // slots is a list of IValue. Continue setting training attribute only
115 // if the slot is an object and a module.
116 if (slot.isObject() && slot.toObjectRef().type()->is_module()) {
117 set_train_recurse(slot.toObject(), on);
118 }
119 }
120 }
121
slot_params_recurse(const c10::intrusive_ptr<c10::ivalue::Object> & obj,std::vector<at::Tensor> * params)122 void slot_params_recurse(
123 const c10::intrusive_ptr<c10::ivalue::Object>& obj,
124 std::vector<at::Tensor>* params) {
125 for (const auto& slot : obj->slots()) {
126 if (slot.isTensor()) {
127 params->emplace_back(slot.toTensor());
128 } else if (slot.isObject()) {
129 slot_params_recurse(slot.toObject(), params);
130 }
131 }
132 }
133
slot_named_params_recurse(const c10::intrusive_ptr<c10::ivalue::Object> & obj,std::map<std::string,at::Tensor> * params,const std::string & parent_name)134 void slot_named_params_recurse(
135 const c10::intrusive_ptr<c10::ivalue::Object>& obj,
136 std::map<std::string, at::Tensor>* params,
137 const std::string& parent_name) {
138 auto slots = obj->slots();
139 size_t nslots = slots.size();
140 for (const auto i : c10::irange(nslots)) {
141 auto slot = slots[i];
142 std::string name = parent_name.empty() ? parent_name : parent_name + ".";
143 name += obj->type()->getAttributeName(i);
144 // TODO: Fix this filter. Requires_grad is not the appropriate
145 // filter of a parameter, but is a temporary hack to help probable
146 // users of this api. The correct behavior is to filter by the
147 // obj->type->is_parameter() but this currently always returns
148 // false on mobile.
149 if (slot.isTensor() && slot.toTensor().requires_grad()) {
150 (*params)[name] = slot.toTensor();
151 } else if (slot.isObject()) {
152 slot_named_params_recurse(slot.toObject(), params, name);
153 }
154 }
155 }
156
#if defined(SYMBOLICATE_MOBILE_DEBUG_HANDLE)
// Returns the unqualified name of the type of `m`'s underlying object, or an
// empty string when the object has no type or the type has no name.
std::string getTopModuleTypeName(const Module& m) {
  const auto type = m._ivalue()->type();
  if (!type || !type->name()) {
    return std::string();
  }
  return type->name().value().name();
}
#endif
166
167 } // namespace
168
parameters() const169 const std::vector<at::Tensor> Module::parameters() const {
170 std::vector<at::Tensor> params;
171 slot_params_recurse(object_, ¶ms);
172 return params;
173 }
174
175 // Returns a mapping for all attributes that requires_grad=True in a module.
176 // This behavior differs from full torch script modules. This is a bug,
177 // but currently there is no way to correctly label parameters in the
178 // loading of a mobile module. TODO
named_parameters() const179 const std::map<std::string, at::Tensor> Module::named_parameters() const {
180 std::map<std::string, at::Tensor> params;
181 const std::string name = "";
182 slot_named_params_recurse(object_, ¶ms, name);
183 return params;
184 }
185
getModuleHierarchy(const int64_t debug_handle) const186 std::string Module::getModuleHierarchy(const int64_t debug_handle) const {
187 #if defined(SYMBOLICATE_MOBILE_DEBUG_HANDLE)
188 return getDebugTable().getModuleHierarchyInfo(
189 debug_handle, getTopModuleTypeName(*this));
190 #else
191 return "";
192 #endif
193 }
194
getCallStack(const int64_t debug_handle) const195 std::string Module::getCallStack(const int64_t debug_handle) const {
196 #if defined(SYMBOLICATE_MOBILE_DEBUG_HANDLE)
197 return getDebugTable().getSourceDebugString(
198 debug_handle, getTopModuleTypeName(*this));
199 #else
200 return "";
201 #endif
202 }
203
204 // We will continue to support this API for now as this is being relied upon
205 // for profiling.
206 // We really need to change this part, so in the next step for profiling support
207 // for delegates, the first thing will be to rewrite how profiling is done
208 // for lite interpreter.
get_forward_method_debug_info(int64_t debug_handle) const209 std::string Module::get_forward_method_debug_info(int64_t debug_handle) const {
210 #if defined(SYMBOLICATE_MOBILE_DEBUG_HANDLE)
211 return getDebugTable().getModuleHierarchyInfo(
212 debug_handle, getTopModuleTypeName(*this));
213 #else
214 return "";
215 #endif
216 }
217
train(bool on)218 void Module::train(bool on) {
219 set_train_recurse(object_, on);
220 }
221
is_training() const222 bool Module::is_training() const {
223 if (auto slot = object_->type()->findAttributeSlot("training")) {
224 return object_->getSlot(*slot).toBool();
225 }
226 return true;
227 }
228
get_methods() const229 const std::vector<Method> Module::get_methods() const {
230 std::vector<Method> methods;
231 for (std::unique_ptr<Function>& fn : cu_->methods()) {
232 methods.emplace_back(this, fn.get());
233 }
234 return methods;
235 }
236
Method(const Module * owner,Function * function)237 Method::Method(const Module* owner, Function* function)
238 : owner_(owner), function_(function) {}
239
run(Stack & stack) const240 void Method::run(Stack& stack) const {
241 auto observer = torch::observerConfig().getModuleObserver();
242 // NOLINTNEXTLINE(clang-analyzer-security.insecureAPI.rand)
243 auto instance_key = std::rand();
244 /* if the metadata dict doesn't contain "model_name", copy the metadata and
245 set the value of "model_name" as name() */
246 std::unordered_map<std::string, std::string> copied_metadata =
247 owner_->getMetadata();
248
249 if (observer) {
250 observer->onEnterRunMethod(instance_key);
251 }
252
253 auto debug_info = std::make_shared<MobileDebugInfo>();
254 std::string name = copied_metadata["model_name"];
255 debug_info->setModelName(name);
256 debug_info->setMethodName(function_->name());
257 at::DebugInfoGuard guard(at::DebugInfoKind::MOBILE_RUNTIME_INFO, debug_info);
258
259 std::string error_message;
260 auto failure_guard = c10::make_scope_exit([&]() {
261 if (!observer) {
262 return;
263 }
264
265 #if defined(SYMBOLICATE_MOBILE_DEBUG_HANDLE)
266 if (error_message.empty()) {
267 error_message = owner_->getDebugTable().getSourceDebugString(
268 function_->getExceptionDebugHandles(), getTopModuleTypeName(*owner_));
269 }
270 #endif
271
272 observer->onFailRunMethod(
273 copied_metadata,
274 function_->name(),
275 instance_key,
276 error_message.empty() ? "Unknown exception" : error_message.c_str());
277 });
278
279 try {
280 stack.insert(stack.begin(), owner_->_ivalue()); // self
281 function_->run(stack);
282 if (observer) {
283 observer->onExitRunMethod(
284 copied_metadata, function_->name(), instance_key);
285 }
286 failure_guard.release();
287 // This exception must be caught first as it derived from c10::Error
288 } catch (c10::BackendRuntimeException& e) {
289 #if defined(SYMBOLICATE_MOBILE_DEBUG_HANDLE)
290 for (auto handle : function_->getExceptionDebugHandles()) {
291 e.pushDebugHandle(handle);
292 }
293 // symbolicate all handles
294 auto debug_string = owner_->getDebugTable().getSourceDebugString(
295 e.getDebugHandles(), getTopModuleTypeName(*owner_));
296 e.add_context(debug_string);
297 #endif
298 error_message = e.what();
299 TORCH_RETHROW(e);
300 } catch (c10::Error& error) {
301 #if defined(SYMBOLICATE_MOBILE_DEBUG_HANDLE)
302 auto debug_string = owner_->getDebugTable().getSourceDebugString(
303 function_->getExceptionDebugHandles(), getTopModuleTypeName(*owner_));
304 error.add_context(debug_string);
305 #endif
306 error_message = error.what();
307 TORCH_RETHROW(error);
308 }
309 }
310
operator ()(std::vector<c10::IValue> stack) const311 c10::IValue Method::operator()(std::vector<c10::IValue> stack) const {
312 run(stack);
313 TORCH_INTERNAL_ASSERT(!stack.empty());
314 return stack.front();
315 }
316
print_type(const c10::Type & t)317 static std::optional<std::string> print_type(const c10::Type& t) {
318 auto namedType = t.cast<c10::NamedType>();
319 if (namedType && namedType->name()) {
320 return namedType->name().value().qualifiedName();
321 }
322 if (auto dyn = t.castRaw<c10::DynamicType>()) {
323 return dyn->fallback()->annotation_str();
324 }
325 return std::nullopt;
326 }
327
get_module_info(const mobile::Module & module)328 TORCH_API ModuleInfo get_module_info(const mobile::Module& module) {
329 ModuleInfo minfo;
330 minfo.operator_version = module.min_operator_version();
331 minfo.bytecode_version = module.bytecode_version();
332 std::vector<std::string> type_name_list;
333 for (const auto& func_ptr : module.compilation_unit().methods()) {
334 const auto& function = *func_ptr;
335 for (const auto i : c10::irange(function.get_code().op_names_.size())) {
336 const auto& op = function.get_code().op_names_[i];
337 minfo.opname_to_num_args[mobile::operator_str(op)] =
338 function.get_code().operator_input_sizes_[i];
339 }
340 for (const c10::TypePtr& tp : function.get_code().types_) {
341 type_name_list.push_back(tp->annotation_str(print_type));
342 }
343 minfo.function_names.insert(function.qualname().qualifiedName());
344 }
345 c10::TypeParser parser(type_name_list);
346 parser.parseList();
347 minfo.type_names = parser.getContainedTypes();
348 return minfo;
349 }
350
351 } // namespace mobile
352 } // namespace torch::jit
353