1 #include <torch/csrc/jit/mobile/model_tracer/CustomClassTracer.h> 2 #include <mutex> 3 4 namespace torch::jit::mobile { CustomClassTracer()5CustomClassTracer::CustomClassTracer() { 6 auto recorder_cb = 7 [](const at::RecordFunction& fn) -> std::unique_ptr<at::ObserverContext> { 8 std::string name = fn.name(); 9 getLoadedClasses().withLock( 10 [&name](CustomClassTracer::custom_classes_type& custom_classes) { 11 custom_classes.insert(name); 12 }); 13 return nullptr; 14 }; 15 16 handle_ = at::addGlobalCallback(at::RecordFunctionCallback(recorder_cb) 17 .scopes({at::RecordScope::CUSTOM_CLASS})); 18 } 19 20 c10::Synchronized<CustomClassTracer::custom_classes_type>& CustomClassTracer:: getLoadedClasses()21 getLoadedClasses() { 22 static c10::Synchronized<custom_classes_type> loaded_classes; 23 return loaded_classes; 24 } 25 26 } // namespace torch::jit::mobile 27