xref: /aosp_15_r20/external/pytorch/torch/csrc/profiler/orchestration/vulkan.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
#include <torch/csrc/profiler/stubs/base.h>
#include <torch/csrc/profiler/util.h>

#include <cstdint>
#include <functional>
#include <string>
#include <tuple>
6 
7 namespace torch {
8 namespace profiler {
9 namespace impl {
10 namespace vulkan {
11 
12 // Using function pointer i.e. [std::tuple<std::string, uint64_t> (*)(int64_t)]
13 // doesn't work because we need to capture the QueryPool in the lambda context
14 // https://stackoverflow.com/a/28746827
15 using GetShaderNameAndDurationNsFn =
16     std::function<std::tuple<std::string, uint64_t>(int64_t)>;
17 TORCH_API void registerGetShaderNameAndDurationNs(
18     GetShaderNameAndDurationNsFn get_shader_name_and_duration_ns);
19 
20 TORCH_API void deregisterGetShaderNameAndDurationNs();
21 
22 std::tuple<std::string, uint64_t> getShaderNameAndDurationNs(
23     const vulkan_id_t& vulkan_id);
24 
25 } // namespace vulkan
26 } // namespace impl
27 } // namespace profiler
28 } // namespace torch
29