xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/vulkan/api/QueryPool.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 // @lint-ignore-every CLANGTIDY facebook-hte-BadMemberName
4 
5 #include <functional>
6 #ifdef USE_VULKAN_API
7 
8 #include <ATen/native/vulkan/api/vk_api.h>
9 
10 #include <ATen/native/vulkan/api/Adapter.h>
11 #include <ATen/native/vulkan/api/Command.h>
12 #include <ATen/native/vulkan/api/Pipeline.h>
13 
14 namespace at {
15 namespace native {
16 namespace vulkan {
17 namespace api {
18 
/**
 * Construction parameters for a QueryPool.
 *
 * Members are value-initialized by default so a default-constructed config
 * never carries indeterminate values; the struct remains an aggregate, so
 * `QueryPoolConfig{count, reserve}` brace-initialization keeps working.
 */
struct QueryPoolConfig final {
  // NOTE(review): presumably the number of query slots to allocate in the
  // VkQueryPool — confirm against QueryPool.cpp.
  uint32_t maxQueryCount{0u};
  // NOTE(review): presumably the initial capacity reserved for the shader log
  // — confirm against QueryPool.cpp.
  uint32_t initialReserveSize{0u};
};
23 
24 struct ShaderDuration final {
25   uint32_t idx;
26 
27   // Execution Properties
28   std::string kernel_name;
29   VkExtent3D global_workgroup_size;
30   VkExtent3D local_workgroup_size;
31 
32   // Query indexes
33   uint32_t start_query_idx;
34   uint32_t end_query_idx;
35 
36   // Timings
37   uint64_t start_time_ns;
38   uint64_t end_time_ns;
39   uint64_t execution_duration_ns;
40 };
41 
42 class QueryPool final {
43  public:
44   explicit QueryPool(const QueryPoolConfig&, const Adapter* adapter_p);
45 
46   QueryPool(const QueryPool&) = delete;
47   QueryPool& operator=(const QueryPool&) = delete;
48 
49   QueryPool(QueryPool&&) = delete;
50   QueryPool& operator=(QueryPool&&) = delete;
51 
52   ~QueryPool();
53 
54  private:
55   std::mutex mutex_;
56 
57   VkDevice device_;
58   QueryPoolConfig config_;
59 
60   VkQueryPool querypool_;
61 
62   std::vector<std::vector<ShaderDuration>> shader_logs_;
63   size_t in_use_;
64 
65   /** Total number of entries in shader logs from before most recent reset */
66   size_t previous_shader_count_;
67 
68   /**
69    * Indicates whether there are new log entries in the shader log since the
70    * most recent call to extract_results()
71    */
72   bool results_pending_;
73 
74  private:
75   size_t write_timestamp(const CommandBuffer&);
76 
77   std::string generate_string_report();
78 
79   /** Most recent shader log since the last time the QueryPool was reset */
shader_log()80   inline std::vector<ShaderDuration>& shader_log() {
81     return shader_logs_[shader_logs_.size() - 1];
82   }
83 
84   /** Total number of entries in all shader logs, but without locking mutex */
85   size_t shader_logs_entry_count_thread_unsafe();
86 
87  public:
is_enabled()88   inline bool is_enabled() const {
89     return VK_NULL_HANDLE != querypool_;
90   }
91 
92   void reset(const CommandBuffer&);
93 
94   uint32_t shader_profile_begin(
95       const CommandBuffer&,
96       const std::string&,
97       const VkExtent3D,
98       const VkExtent3D);
99   void shader_profile_end(const CommandBuffer&, const uint32_t);
100 
101   void extract_results();
102   void print_results();
103   uint64_t get_total_op_ns(const std::string& op_name);
104   uint64_t ns_per_tick_;
105   void shader_log_for_each(std::function<void(const ShaderDuration&)> fn);
106   /**
107    * query_index is what number entry across all of the QueryPool's shader logs
108    * is being queried, regardless of resets. This may be different than
109    * ShaderDuration's idx field, which is what number entry it is since the last
110    * reset before it was added to the shader logs.
111    */
112   std::tuple<std::string, uint64_t> get_shader_name_and_execution_duration_ns(
113       size_t query_index);
114   /** Total number of entries in all shader logs */
115   size_t shader_logs_entry_count();
116 };
117 
118 } // namespace api
119 } // namespace vulkan
120 } // namespace native
121 } // namespace at
122 
123 #endif /* USE_VULKAN_API */
124