xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/mobile/model_tracer/KernelDTypeTracer.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <torch/csrc/jit/mobile/model_tracer/KernelDTypeTracer.h>
2 #include <map>
3 #include <mutex>
4 #include <set>
5 #include <string>
6 
7 namespace torch::jit::mobile {
KernelDTypeTracer()8 KernelDTypeTracer::KernelDTypeTracer() {
9   auto recorder_cb =
10       [](const at::RecordFunction& fn) -> std::unique_ptr<at::ObserverContext> {
11     std::string name = fn.name();
12     size_t dollar_pos = name.find_first_of('$');
13     std::string kernel_tag = name.substr(0, dollar_pos);
14     std::string dtype = name.substr(dollar_pos + 1);
15 
16     getCalledKernelTags().withLock([&](kernel_tags_type& kernel_tags) {
17       kernel_tags[kernel_tag].insert(dtype);
18     });
19     return nullptr;
20   };
21 
22   handle_ = at::addGlobalCallback(
23       at::RecordFunctionCallback(recorder_cb)
24           .scopes({at::RecordScope::KERNEL_FUNCTION_DTYPE}));
25 }
26 
27 c10::Synchronized<KernelDTypeTracer::kernel_tags_type>& KernelDTypeTracer::
getCalledKernelTags()28     getCalledKernelTags() {
29   static c10::Synchronized<kernel_tags_type> called_kernel_tags;
30   return called_kernel_tags;
31 }
32 
33 } // namespace torch::jit::mobile
34