xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/vulkan/api/Runtime.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 // @lint-ignore-every CLANGTIDY facebook-hte-BadMemberName
4 
5 #include <functional>
6 #include <memory>
7 #ifdef USE_VULKAN_API
8 
9 #include <ATen/native/vulkan/api/vk_api.h>
10 
11 #include <ATen/native/vulkan/api/Adapter.h>
12 
13 namespace at {
14 namespace native {
15 namespace vulkan {
16 namespace api {
17 
//
// A Vulkan Runtime initializes a Vulkan instance and decouples the concept of
// Vulkan instance initialization from initialization of, and subsequent
// interactions with, Vulkan [physical and logical] devices as a precursor to
// multi-GPU support. The Vulkan Runtime can be queried for the Adapters
// (i.e. physical devices) available in the system, which in turn can be used
// to create a Vulkan Context (i.e. logical devices). All Vulkan tensors in
// PyTorch are associated with a Context to make tensor <-> device affinity
// explicit.
//
27 
// Strategy for picking an Adapter. NOTE(review): inferred from the name and
// from RuntimeConfiguration::defaultSelector; `First` presumably selects the
// first suitable physical device — confirm against Runtime.cpp.
enum AdapterSelector { First };
31 
32 struct RuntimeConfiguration final {
33   bool enableValidationMessages;
34   bool initDefaultDevice;
35   AdapterSelector defaultSelector;
36   uint32_t numRequestedQueues;
37 };
38 
39 class Runtime final {
40  public:
41   explicit Runtime(const RuntimeConfiguration);
42 
43   // Do not allow copying. There should be only one global instance of this
44   // class.
45   Runtime(const Runtime&) = delete;
46   Runtime& operator=(const Runtime&) = delete;
47 
48   Runtime(Runtime&&) noexcept;
49   Runtime& operator=(Runtime&&) = delete;
50 
51   ~Runtime();
52 
53   using DeviceMapping = std::pair<PhysicalDevice, int32_t>;
54   using AdapterPtr = std::unique_ptr<Adapter>;
55 
56  private:
57   RuntimeConfiguration config_;
58 
59   VkInstance instance_;
60 
61   std::vector<DeviceMapping> device_mappings_;
62   std::vector<AdapterPtr> adapters_;
63   uint32_t default_adapter_i_;
64 
65   VkDebugReportCallbackEXT debug_report_callback_;
66 
67  public:
instance()68   inline VkInstance instance() const {
69     return instance_;
70   }
71 
get_adapter_p()72   inline Adapter* get_adapter_p() {
73     VK_CHECK_COND(
74         default_adapter_i_ >= 0 && default_adapter_i_ < adapters_.size(),
75         "Pytorch Vulkan Runtime: Default device adapter is not set correctly!");
76     return adapters_[default_adapter_i_].get();
77   }
78 
get_adapter_p(uint32_t i)79   inline Adapter* get_adapter_p(uint32_t i) {
80     VK_CHECK_COND(
81         i >= 0 && i < adapters_.size(),
82         "Pytorch Vulkan Runtime: Adapter at index ",
83         i,
84         " is not available!");
85     return adapters_[i].get();
86   }
87 
default_adapter_i()88   inline uint32_t default_adapter_i() const {
89     return default_adapter_i_;
90   }
91 
92   using Selector =
93       std::function<uint32_t(const std::vector<Runtime::DeviceMapping>&)>;
94   uint32_t create_adapter(const Selector&);
95 };
96 
97 // The global runtime is retrieved using this function, where it is declared as
98 // a static local variable.
99 Runtime* runtime();
100 
101 } // namespace api
102 } // namespace vulkan
103 } // namespace native
104 } // namespace at
105 
106 #endif /* USE_VULKAN_API */
107