// xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/mobile/model_tracer/OperatorCallTracer.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
#include <torch/csrc/jit/mobile/model_tracer/OperatorCallTracer.h>

#include <utility>

3 namespace torch::jit::mobile {
OperatorCallTracer()4 OperatorCallTracer::OperatorCallTracer() {
5   getCalledOperators().withLock([](std::set<std::string>& called_operators) {
6     called_operators.clear();
7   });
8 
9   auto recorder_cb =
10       [](const at::RecordFunction& fn) -> std::unique_ptr<at::ObserverContext> {
11     std::optional<c10::OperatorName> op_name = fn.operator_name();
12     if (op_name.has_value()) {
13       getCalledOperators().withLock(
14           [op_name](std::set<std::string>& called_operators) {
15             called_operators.insert(c10::toString(*op_name));
16           });
17     }
18     return nullptr;
19   };
20 
21   handle_ = at::addGlobalCallback(at::RecordFunctionCallback(recorder_cb)
22                                       .scopes({at::RecordScope::FUNCTION}));
23 }
24 
25 } // namespace torch::jit::mobile
26