#pragma once

// @lint-ignore-every CLANGTIDY facebook-hte-BadMemberName

#include <cstdint>
#include <functional>
#include <memory>
#include <utility>
#include <vector>

#ifdef USE_VULKAN_API

#include <ATen/native/vulkan/api/vk_api.h>

#include <ATen/native/vulkan/api/Adapter.h>

namespace at {
namespace native {
namespace vulkan {
namespace api {

//
// A Vulkan Runtime initializes a Vulkan instance and decouples the concept of
// Vulkan instance initialization from the initialization of, and subsequent
// interactions with, Vulkan [physical and logical] devices, as a precursor to
// multi-GPU support. The Vulkan Runtime can be queried for available Adapters
// (i.e. physical devices) in the system, which in turn can be used to create
// a Vulkan Context (i.e. a logical device). All Vulkan tensors in PyTorch are
// associated with a Context to make tensor <-> device affinity explicit.
//

// Strategy used to choose the default adapter. Only one strategy exists;
// presumably `First` picks the first suitable physical device — TODO confirm
// against the Runtime implementation.
enum AdapterSelector {
  First,
};

// Options passed to the Runtime constructor. The constructor body is not in
// this header, so the per-field notes below are inferred from names.
struct RuntimeConfiguration final {
  // Presumably enables validation-layer messages (see debug_report_callback_).
  bool enableValidationMessages;
  // Presumably creates an Adapter for the default device at construction.
  bool initDefaultDevice;
  // Strategy used to choose the default device.
  AdapterSelector defaultSelector;
  // Presumably the number of queues requested per Adapter — TODO confirm.
  uint32_t numRequestedQueues;
};

class Runtime final {
 public:
  explicit Runtime(const RuntimeConfiguration);

  // Do not allow copying. There should be only one global instance of this
  // class.
  Runtime(const Runtime&) = delete;
  Runtime& operator=(const Runtime&) = delete;

  // Move construction is permitted (declared noexcept); move assignment is
  // not.
  Runtime(Runtime&&) noexcept;
  Runtime& operator=(Runtime&&) = delete;

  ~Runtime();

  // Associates a physical device with an int32_t. Presumably the int is the
  // index of the Adapter created for that device in adapters_, with a negative
  // value meaning "no adapter yet" — TODO confirm against Runtime.cpp.
  using DeviceMapping = std::pair<PhysicalDevice, int32_t>;
  // Adapters are uniquely owned by the Runtime through this pointer type.
  using AdapterPtr = std::unique_ptr<Adapter>;

 private:
  RuntimeConfiguration config_;

  VkInstance instance_;

  std::vector<DeviceMapping> device_mappings_;
  std::vector<AdapterPtr> adapters_;
  // Index into adapters_ of the default adapter (validated on access).
  uint32_t default_adapter_i_;

  VkDebugReportCallbackEXT debug_report_callback_;

 public:
  // Returns the Vulkan instance handle held by this Runtime.
  inline VkInstance instance() const {
    return instance_;
  }

  // Returns a non-owning pointer to the default Adapter; the Adapter itself
  // remains owned by this Runtime (stored in adapters_). Reports an error via
  // VK_CHECK_COND if default_adapter_i_ does not index into adapters_.
  inline Adapter* get_adapter_p() {
    // default_adapter_i_ is unsigned, so only the upper bound can be violated.
    VK_CHECK_COND(
        default_adapter_i_ < adapters_.size(),
        "Pytorch Vulkan Runtime: Default device adapter is not set correctly!");
    return adapters_[default_adapter_i_].get();
  }

  // Returns a non-owning pointer to the Adapter at index `i` of adapters_;
  // the Adapter remains owned by this Runtime. Reports an error via
  // VK_CHECK_COND if `i` is out of range.
  inline Adapter* get_adapter_p(uint32_t i) {
    // `i` is unsigned, so only the upper bound can be violated.
    VK_CHECK_COND(
        i < adapters_.size(),
        "Pytorch Vulkan Runtime: Adapter at index ",
        i,
        " is not available!");
    return adapters_[i].get();
  }

  // Returns the index of the default adapter within adapters_.
  inline uint32_t default_adapter_i() const {
    return default_adapter_i_;
  }

  // Callable that inspects the known device mappings and returns a uint32_t;
  // presumably the index of the chosen entry in that vector — TODO confirm.
  using Selector =
      std::function<uint32_t(const std::vector<Runtime::DeviceMapping>&)>;
  // Creates an Adapter for the device chosen by `selector`. The returned value
  // is presumably the new Adapter's index in adapters_ — TODO confirm.
  uint32_t create_adapter(const Selector&);
};

// The global runtime is retrieved using this function, where it is declared as
// a static local variable.
Runtime* runtime();

} // namespace api
} // namespace vulkan
} // namespace native
} // namespace at

#endif /* USE_VULKAN_API */