xref: /aosp_15_r20/external/executorch/backends/vulkan/runtime/vk_api/Runtime.h (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1 /*
2  * Copyright (c) Meta Platforms, Inc. and affiliates.
3  * All rights reserved.
4  *
5  * This source code is licensed under the BSD-style license found in the
6  * LICENSE file in the root directory of this source tree.
7  */
8 
9 #pragma once
10 
11 // @lint-ignore-every CLANGTIDY facebook-hte-BadMemberName
12 
#include <executorch/backends/vulkan/runtime/vk_api/vk_api.h>

#include <executorch/backends/vulkan/runtime/vk_api/Adapter.h>

#include <cstdint>
#include <functional>
#include <memory>
#include <string>
#include <utility>
#include <vector>
19 
20 namespace vkcompute {
21 namespace vkapi {
22 
//
// A Vulkan Runtime initializes a Vulkan instance and decouples the concept of
// Vulkan instance initialization from initialization of, and subsequent
// interactions with, Vulkan [physical and logical] devices as a precursor to
// multi-GPU support. The Vulkan Runtime can be queried for available Adapters
// (i.e. physical devices) in the system, which in turn can be used for creation
// of a Vulkan Context (i.e. logical devices). All Vulkan tensors in PyTorch
// are associated with a Context to make tensor <-> device affinity explicit.
//
32 
// Strategy for choosing the default adapter; stored in
// RuntimeConfig::default_selector. Only `First` is currently defined.
enum AdapterSelector {
  First,
};

// Options passed to the Runtime constructor.
//
// Every field is value-initialized (false / First / 0 / empty string) so that
// a default-constructed config never contains indeterminate values. The type
// remains an aggregate (C++14+), so existing brace-initialization such as
// `RuntimeConfig{true, true, First, 1u, path}` keeps working; any trailing
// fields omitted from such an initializer fall back to these defaults.
struct RuntimeConfig final {
  bool enable_validation_messages{};
  bool init_default_device{};
  AdapterSelector default_selector{};
  uint32_t num_requested_queues{};
  std::string cache_data_path{};
};
44 
45 class Runtime final {
46  public:
47   explicit Runtime(const RuntimeConfig);
48 
49   // Do not allow copying. There should be only one global instance of this
50   // class.
51   Runtime(const Runtime&) = delete;
52   Runtime& operator=(const Runtime&) = delete;
53 
54   Runtime(Runtime&&) = delete;
55   Runtime& operator=(Runtime&&) = delete;
56 
57   ~Runtime();
58 
59   using DeviceMapping = std::pair<PhysicalDevice, int32_t>;
60   using AdapterPtr = std::unique_ptr<Adapter>;
61 
62  private:
63   RuntimeConfig config_;
64 
65   VkInstance instance_;
66 
67   std::vector<DeviceMapping> device_mappings_;
68   std::vector<AdapterPtr> adapters_;
69   uint32_t default_adapter_i_;
70 
71   VkDebugReportCallbackEXT debug_report_callback_;
72 
73  public:
instance()74   inline VkInstance instance() const {
75     return instance_;
76   }
77 
get_adapter_p()78   inline Adapter* get_adapter_p() {
79     VK_CHECK_COND(
80         default_adapter_i_ >= 0 && default_adapter_i_ < adapters_.size(),
81         "Pytorch Vulkan Runtime: Default device adapter is not set correctly!");
82     return adapters_[default_adapter_i_].get();
83   }
84 
get_adapter_p(uint32_t i)85   inline Adapter* get_adapter_p(uint32_t i) {
86     VK_CHECK_COND(
87         i >= 0 && i < adapters_.size(),
88         "Pytorch Vulkan Runtime: Adapter at index ",
89         i,
90         " is not available!");
91     return adapters_[i].get();
92   }
93 
default_adapter_i()94   inline uint32_t default_adapter_i() const {
95     return default_adapter_i_;
96   }
97 
98   using Selector =
99       std::function<uint32_t(const std::vector<Runtime::DeviceMapping>&)>;
100   uint32_t create_adapter(const Selector&);
101 };
102 
103 // The global runtime is retrieved using this function, where it is declared as
104 // a static local variable.
105 Runtime* runtime();
106 
107 } // namespace vkapi
108 } // namespace vkcompute
109