xref: /aosp_15_r20/external/pytorch/aten/src/ATen/vulkan/Context.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <atomic>
4 
5 #include <ATen/Tensor.h>
6 
7 namespace at {
8 namespace vulkan {
9 
10 struct VulkanImplInterface {
11   virtual ~VulkanImplInterface() = default;
12   virtual bool is_vulkan_available() const = 0;
13   virtual at::Tensor& vulkan_copy_(at::Tensor& self, const at::Tensor& src)
14       const = 0;
15 };
16 
17 extern std::atomic<const VulkanImplInterface*> g_vulkan_impl_registry;
18 
19 class VulkanImplRegistrar {
20  public:
21   explicit VulkanImplRegistrar(VulkanImplInterface*);
22 };
23 
24 at::Tensor& vulkan_copy_(at::Tensor& self, const at::Tensor& src);
25 } // namespace vulkan
26 
namespace native {

/// Native-namespace query for whether Vulkan is usable.
///
/// Declared here and defined in another translation unit.
/// NOTE(review): presumably consults `vulkan::g_vulkan_impl_registry` —
/// confirm in the definition.
bool is_vulkan_available();

} // namespace native
30 
31 } // namespace at
32