1 #include <torch/csrc/jit/mobile/model_tracer/KernelDTypeTracer.h> 2 #include <map> 3 #include <mutex> 4 #include <set> 5 #include <string> 6 7 namespace torch::jit::mobile { KernelDTypeTracer()8KernelDTypeTracer::KernelDTypeTracer() { 9 auto recorder_cb = 10 [](const at::RecordFunction& fn) -> std::unique_ptr<at::ObserverContext> { 11 std::string name = fn.name(); 12 size_t dollar_pos = name.find_first_of('$'); 13 std::string kernel_tag = name.substr(0, dollar_pos); 14 std::string dtype = name.substr(dollar_pos + 1); 15 16 getCalledKernelTags().withLock([&](kernel_tags_type& kernel_tags) { 17 kernel_tags[kernel_tag].insert(dtype); 18 }); 19 return nullptr; 20 }; 21 22 handle_ = at::addGlobalCallback( 23 at::RecordFunctionCallback(recorder_cb) 24 .scopes({at::RecordScope::KERNEL_FUNCTION_DTYPE})); 25 } 26 27 c10::Synchronized<KernelDTypeTracer::kernel_tags_type>& KernelDTypeTracer:: getCalledKernelTags()28 getCalledKernelTags() { 29 static c10::Synchronized<kernel_tags_type> called_kernel_tags; 30 return called_kernel_tags; 31 } 32 33 } // namespace torch::jit::mobile 34