#pragma once

#include <c10/util/ThreadLocalDebugInfo.h>
#include <cstdint>
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>

namespace torch {

class MobileDebugInfo : public c10::DebugInfoBase {
 public:
  const std::string& getModelName() {
    return model_name_;
  }

  void setModelName(const std::string& model_name) {
    model_name_ = model_name;
  }

  const std::string& getMethodName() {
    return method_name_;
  }

  void setMethodName(const std::string& method_name) {
    method_name_ = method_name;
  }

  size_t getOpIdx() {
    return op_idx_;
  }

  void setOpIdx(size_t op_idx) {
    op_idx_ = op_idx;
  }

 private:
  std::string model_name_;
  std::string method_name_;
  // TODO: Kimish
  // If we launch a thread (e.g. via at::launch or an interpreter
  // continuation) and the caching allocator is enabled in the base thread,
  // then we can use the mechanism provided by ThreadLocalDebugInfo to
  // propagate the fact that the caching allocator is enabled across thread
  // boundaries.
  // Once the thread-local MobileDebugInfo is accessible in the launched
  // thread, that thread can read it and set its own thread-local
  // CachingAllocatorInfo.
  // However, we cannot expect every launched thread to extract and set its
  // own thread-local copy of CachingAllocatorInfo.
  // But this can be done in the lite interpreter, whose run method could do:
  //   info = c10::ThreadLocalDebugInfo::get(
  //              c10::DebugInfoKind::MOBILE_RUNTIME_INFO)
  //              .get_caching_allocator_info();
  //   GetThreadLocalCachingAllocatorInfo() = info;
  // Another option is to make MobileDebugInfo itself the place where the
  // thread-local copy of CachingAllocatorInfo is stored. Then
  // DefaultMobileCPUAllocator would inspect it to decide whether to use the
  // CachingAllocator. However, the current lite interpreter does not support
  // FORK, so the run method of the lite interpreter is not going to launch
  // another instance of the lite interpreter on a different thread. So for
  // now we are not bothering with passing CachingAllocatorInfo across thread
  // boundaries.
  // c10::CachingAllocatorInfo caching_allocator_info;
  size_t op_idx_ = 0;
};
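
// Usage sketch (illustrative, not part of this header): code running on a
// thread that has a MobileDebugInfo registered under
// c10::DebugInfoKind::MOBILE_RUNTIME_INFO can read and update it through
// c10::ThreadLocalDebugInfo, e.g.
//
//   if (auto* debug_info = static_cast<MobileDebugInfo*>(
//           c10::ThreadLocalDebugInfo::get(
//               c10::DebugInfoKind::MOBILE_RUNTIME_INFO))) {
//     debug_info->setMethodName("forward");
//     debug_info->setOpIdx(current_op_idx); // current_op_idx: caller-provided
//   }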

class MobileModuleObserver {
 public:
  virtual ~MobileModuleObserver() = default;

  virtual void onEnterRunMethod(const int32_t) {}
  virtual void onExitRunMethod(
      const std::unordered_map<std::string, std::string>&,
      const std::string&,
      const int32_t) {}
  virtual void onFailRunMethod(
      const std::unordered_map<std::string, std::string>&,
      const std::string&,
      const int32_t,
      const char*) {}
  virtual void onEnterLoadModel(const int32_t) {}
  // map key: filename, value: file content
  virtual void onExitLoadModel(
      const int32_t,
      const std::unordered_map<std::string, std::string>&) {}
  virtual void onFailLoadModel(const int32_t, const char*) {}
  virtual void onFailLoadModel(
      const int32_t,
      const char*,
      const std::unordered_map<std::string, std::string>&) {}
  virtual std::vector<std::string> getDefaultExtraFiles() = 0;
  virtual std::unordered_map<std::string, std::string> processMetadataFromExtra(
      const std::unordered_map<std::string, std::string>&) = 0;
};
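
// Implementation sketch (illustrative, not part of this header): a concrete
// observer overrides the callbacks it needs and must implement the two pure
// virtual methods, e.g.
//
//   class LoggingModuleObserver final : public torch::MobileModuleObserver {
//    public:
//     void onEnterRunMethod(const int32_t instance_key) override {
//       // e.g. start a timer keyed by instance_key
//     }
//     void onExitRunMethod(
//         const std::unordered_map<std::string, std::string>& metadata,
//         const std::string& method_name,
//         const int32_t instance_key) override {
//       // e.g. stop the timer and report metadata and method_name
//     }
//     std::vector<std::string> getDefaultExtraFiles() override {
//       return {}; // no extra files requested at load time
//     }
//     std::unordered_map<std::string, std::string> processMetadataFromExtra(
//         const std::unordered_map<std::string, std::string>& extra_files)
//         override {
//       return extra_files; // pass metadata through unchanged
//     }
//   };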

class MobileObserverConfig {
 public:
  void setModuleObserver(std::unique_ptr<MobileModuleObserver> reporter) {
    module_observer_ = std::move(reporter);
  }
  MobileModuleObserver* getModuleObserver() {
    return module_observer_.get();
  }

 private:
  std::unique_ptr<MobileModuleObserver> module_observer_;
};

MobileObserverConfig& observerConfig();
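
// Usage sketch (illustrative): a client installs its observer once, e.g. at
// process startup, and the mobile runtime later retrieves it through
// getModuleObserver(). Using the LoggingModuleObserver from the sketch above:
//
//   torch::observerConfig().setModuleObserver(
//       std::make_unique<LoggingModuleObserver>());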

} // namespace torch