1 #pragma once 2 3 #include <torch/csrc/profiler/stubs/base.h> 4 #include <torch/csrc/profiler/util.h> 5 #include <cstdint> 6 7 namespace torch { 8 namespace profiler { 9 namespace impl { 10 namespace vulkan { 11 12 // Using function pointer i.e. [std::tuple<std::string, uint64_t> (*)(int64_t)] 13 // doesn't work because we need to capture the QueryPool in the lambda context 14 // https://stackoverflow.com/a/28746827 15 using GetShaderNameAndDurationNsFn = 16 std::function<std::tuple<std::string, uint64_t>(int64_t)>; 17 TORCH_API void registerGetShaderNameAndDurationNs( 18 GetShaderNameAndDurationNsFn get_shader_name_and_duration_ns); 19 20 TORCH_API void deregisterGetShaderNameAndDurationNs(); 21 22 std::tuple<std::string, uint64_t> getShaderNameAndDurationNs( 23 const vulkan_id_t& vulkan_id); 24 25 } // namespace vulkan 26 } // namespace impl 27 } // namespace profiler 28 } // namespace torch 29