xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/mobile/model_tracer/OperatorCallTracer.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <ATen/record_function.h>
4 #include <c10/util/Synchronized.h>
5 
6 namespace torch::jit::mobile {
7 /* The OperatorCallTracer class handles the attachment and removal of a
8  * recording callback that traces invocation of ATen (and other) PyTorch
9  * operators that get called via the Dispatcher.
10  *
11  * You can get the set of operators that were called (op_name.overload_name)
12  * using getCalledOperators().
13  *
14  * Note: This class is not thread safe or re-entrant, and should not be used
15  * across multiple threads of execution.
16  *
17  */
18 struct OperatorCallTracer final {
19   at::CallbackHandle handle_;
20 
21   OperatorCallTracer();
22 
getCalledOperatorsfinal23   static c10::Synchronized<std::set<std::string>>& getCalledOperators() {
24     static c10::Synchronized<std::set<std::string>> called_operators_;
25     return called_operators_;
26   }
27 
~OperatorCallTracerfinal28   ~OperatorCallTracer() {
29     at::removeCallback(handle_);
30   }
31 };
32 } // namespace torch::jit::mobile
33