xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/mobile/model_tracer/CustomClassTracer.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <ATen/record_function.h>
4 #include <c10/util/Synchronized.h>
5 #include <map>
6 #include <set>
7 #include <string>
8 
9 namespace torch::jit::mobile {
10 
11 /* The CustomClassTracer class handles the attachment and removal of a recording
12  * callback that traces the invocation of code that handles loading custom
13  * classes on mobile.
14  *
15  * You can get the set of used custom classes using
16  * getLoadedClasses().
17  *
18  * Note: This class is not thread safe or re-entrant, and should not be used
19  * across multiple threads of execution.
20  *
21  */
22 struct CustomClassTracer final {
23   at::CallbackHandle handle_;
24   /* These are the custom class names (constant
25    * character string) which shows up in code.
26    */
27   typedef std::set<std::string> custom_classes_type;
28 
29   CustomClassTracer();
30   static c10::Synchronized<custom_classes_type>& getLoadedClasses();
31 
~CustomClassTracerfinal32   ~CustomClassTracer() {
33     at::removeCallback(handle_);
34   }
35 };
36 
37 } // namespace torch::jit::mobile
38