1 #pragma once 2 3 #include <c10/util/ThreadLocalDebugInfo.h> 4 #include <string> 5 #include <unordered_map> 6 #include <vector> 7 8 namespace torch { 9 10 class MobileDebugInfo : public c10::DebugInfoBase { 11 public: getModelName()12 const std::string& getModelName() { 13 return model_name_; 14 } 15 setModelName(const std::string & model_name)16 void setModelName(const std::string& model_name) { 17 model_name_ = model_name; 18 } 19 getMethodName()20 const std::string& getMethodName() { 21 return method_name_; 22 } 23 setMethodName(const std::string & method_name)24 void setMethodName(const std::string& method_name) { 25 method_name_ = method_name; 26 } 27 getOpIdx()28 size_t getOpIdx() { 29 return op_idx_; 30 } 31 setOpIdx(size_t op_idx)32 void setOpIdx(size_t op_idx) { 33 op_idx_ = op_idx; 34 } 35 36 private: 37 std::string model_name_; 38 std::string method_name_; 39 // TODO: Kimish 40 // If we launch a thread such as for at::launch, interepter continuation 41 // and if the caching allocator is enabled in the base thread 42 // then, in order to propagate this information, that is caching allocator 43 // is enabled, across thread boundaries we can use the mechanism provided 44 // by ThreadLocalDebugInfo 45 // Once the thread local MobileDebugInfo is accessible in the launched 46 // thread, it can be accessed in that thread and that thread can set 47 // its own thread local CachingAllocatorInfo. 48 // However, we cannot expect every launched thread to extract and set 49 // its own thread local copy of CachingAllocatorInfo. 50 // But this can be done in lite interpreter, where in the run method 51 // it can do info = 52 // c10::ThreadLocalDebugInfo::get(c10::DebugInfoKind::MOBILE_RUNTIME_INFO)) 53 // .get_caching_allocator_info(); 54 // GetThreadLocalCachingAllocatorInfo() = info; 55 // Other option is to have MobileDebugInfo itself be the place where thread 56 // local copy of CachingAllocatorInfo is stored. Then 57 // DefaultMobileCPUAllocator inspects this to decide if to use 58 // CachingAllocator. However, current lite interpreter does not support FORK, 59 // thus from the run method of lite interpreter we are not really gonna launch 60 // another instance of lite interpreter in a different thread. So for now not 61 // getting bothered about passing CachingAllocatorInfo across thread 62 // boundaries. c10::CachingAllocatorInfo caching_allocator_info; 63 size_t op_idx_ = 0; 64 }; 65 66 class MobileModuleObserver { 67 public: 68 virtual ~MobileModuleObserver() = default; 69 onEnterRunMethod(const int32_t)70 virtual void onEnterRunMethod(const int32_t) {} onExitRunMethod(const std::unordered_map<std::string,std::string> &,const std::string &,const int32_t)71 virtual void onExitRunMethod( 72 const std::unordered_map<std::string, std::string>&, 73 const std::string&, 74 const int32_t) {} onFailRunMethod(const std::unordered_map<std::string,std::string> &,const std::string &,const int32_t,const char *)75 virtual void onFailRunMethod( 76 const std::unordered_map<std::string, std::string>&, 77 const std::string&, 78 const int32_t, 79 const char*) {} onEnterLoadModel(const int32_t)80 virtual void onEnterLoadModel(const int32_t) {} onExitLoadModel(const int32_t,const std::unordered_map<std::string,std::string> &)81 virtual void onExitLoadModel( 82 const int32_t, 83 const std::unordered_map<std::string, std::string>&) { 84 } // key: filename, value: file content onFailLoadModel(const int32_t,const char *)85 virtual void onFailLoadModel(const int32_t, const char*) {} onFailLoadModel(const int32_t,const char *,const std::unordered_map<std::string,std::string> &)86 virtual void onFailLoadModel( 87 const int32_t, 88 const char*, 89 const std::unordered_map<std::string, std::string>&) {} 90 virtual std::vector<std::string> getDefaultExtraFiles() = 0; 91 virtual std::unordered_map<std::string, std::string> processMetadataFromExtra( 92 const std::unordered_map<std::string, std::string>&) = 0; 93 }; 94 95 class MobileObserverConfig { 96 public: setModuleObserver(std::unique_ptr<MobileModuleObserver> reporter)97 void setModuleObserver(std::unique_ptr<MobileModuleObserver> reporter) { 98 module_observer_ = std::move(reporter); 99 } getModuleObserver()100 MobileModuleObserver* getModuleObserver() { 101 return module_observer_.get(); 102 } 103 104 private: 105 std::unique_ptr<MobileModuleObserver> module_observer_; 106 }; 107 108 MobileObserverConfig& observerConfig(); 109 110 } // namespace torch 111