1 #pragma once 2 3 #include <ATen/record_function.h> 4 #include <c10/util/Synchronized.h> 5 6 namespace torch::jit::mobile { 7 /* The OperatorCallTracer class handles the attachment and removal of a 8 * recording callback that traces invocation of ATen (and other) PyTorch 9 * operators that get called via the Dispatcher. 10 * 11 * You can get the set of operators that were called (op_name.overload_name) 12 * using getCalledOperators(). 13 * 14 * Note: This class is not thread safe or re-entrant, and should not be used 15 * across multiple threads of execution. 16 * 17 */ 18 struct OperatorCallTracer final { 19 at::CallbackHandle handle_; 20 21 OperatorCallTracer(); 22 getCalledOperatorsfinal23 static c10::Synchronized<std::set<std::string>>& getCalledOperators() { 24 static c10::Synchronized<std::set<std::string>> called_operators_; 25 return called_operators_; 26 } 27 ~OperatorCallTracerfinal28 ~OperatorCallTracer() { 29 at::removeCallback(handle_); 30 } 31 }; 32 } // namespace torch::jit::mobile 33