xref: /aosp_15_r20/external/pytorch/aten/src/ATen/vulkan/Context.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <atomic>
2 
3 #include <ATen/Tensor.h>
4 #include <ATen/vulkan/Context.h>
5 
6 #ifdef USE_VULKAN_API
7 #include <ATen/native/vulkan/api/Context.h>
8 #endif /* USE_VULKAN_API */
9 
10 namespace at {
11 namespace vulkan {
12 
13 std::atomic<const VulkanImplInterface*> g_vulkan_impl_registry;
14 
VulkanImplRegistrar(VulkanImplInterface * impl)15 VulkanImplRegistrar::VulkanImplRegistrar(VulkanImplInterface* impl) {
16   g_vulkan_impl_registry.store(impl);
17 }
18 
vulkan_copy_(at::Tensor & self,const at::Tensor & src)19 at::Tensor& vulkan_copy_(at::Tensor& self, const at::Tensor& src) {
20   auto p = at::vulkan::g_vulkan_impl_registry.load();
21   if (p) {
22     return p->vulkan_copy_(self, src);
23   }
24   AT_ERROR("Vulkan backend was not linked to the build");
25 }
26 } // namespace vulkan
27 
28 namespace native {
is_vulkan_available()29 bool is_vulkan_available() {
30 #ifdef USE_VULKAN_API
31   return native::vulkan::api::available();
32 #else
33   auto p = at::vulkan::g_vulkan_impl_registry.load();
34   return p ? p->is_vulkan_available() : false;
35 #endif
36 }
37 } // namespace native
38 
39 } // namespace at
40