xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/mobile/model_tracer/CustomClassTracer.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
#include <torch/csrc/jit/mobile/model_tracer/CustomClassTracer.h>
#include <mutex>
#include <utility>
3 
4 namespace torch::jit::mobile {
CustomClassTracer()5 CustomClassTracer::CustomClassTracer() {
6   auto recorder_cb =
7       [](const at::RecordFunction& fn) -> std::unique_ptr<at::ObserverContext> {
8     std::string name = fn.name();
9     getLoadedClasses().withLock(
10         [&name](CustomClassTracer::custom_classes_type& custom_classes) {
11           custom_classes.insert(name);
12         });
13     return nullptr;
14   };
15 
16   handle_ = at::addGlobalCallback(at::RecordFunctionCallback(recorder_cb)
17                                       .scopes({at::RecordScope::CUSTOM_CLASS}));
18 }
19 
20 c10::Synchronized<CustomClassTracer::custom_classes_type>& CustomClassTracer::
getLoadedClasses()21     getLoadedClasses() {
22   static c10::Synchronized<custom_classes_type> loaded_classes;
23   return loaded_classes;
24 }
25 
26 } // namespace torch::jit::mobile
27