xref: /aosp_15_r20/external/pytorch/aten/src/ATen/detail/HIPHooksInterface.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <c10/core/Allocator.h>
4 #include <c10/core/GeneratorImpl.h>
5 #include <c10/util/Exception.h>
6 
7 #include <c10/util/Registry.h>
8 
9 #include <ATen/detail/AcceleratorHooksInterface.h>
10 
11 #include <memory>
12 
// Forward declaration only: this header refers to Context exclusively through
// pointer parameters, so the complete type is not required here.
namespace at {
class Context;
} // namespace at
16 
17 // NB: Class must live in `at` due to limitations of Registry.h.
18 namespace at {
19 
20 // The HIPHooksInterface is an omnibus interface for any HIP functionality
21 // which we may want to call into from CPU code (and thus must be dynamically
22 // dispatched, to allow for separate compilation of HIP code).  See
23 // CUDAHooksInterface for more detailed motivation.
24 struct TORCH_API HIPHooksInterface : AcceleratorHooksInterface {
25   // This should never actually be implemented, but it is used to
26   // squelch -Werror=non-virtual-dtor
27   ~HIPHooksInterface() override = default;
28 
29   // Initialize the HIP library state
initHIPHIPHooksInterface30   virtual void initHIP() const {
31     AT_ERROR("Cannot initialize HIP without ATen_hip library.");
32   }
33 
initHIPGeneratorHIPHooksInterface34   virtual std::unique_ptr<c10::GeneratorImpl> initHIPGenerator(Context*) const {
35     AT_ERROR("Cannot initialize HIP generator without ATen_hip library.");
36   }
37 
hasHIPHIPHooksInterface38   virtual bool hasHIP() const {
39     return false;
40   }
41 
current_deviceHIPHooksInterface42   virtual c10::DeviceIndex current_device() const {
43     return -1;
44   }
45 
isPinnedPtrHIPHooksInterface46   bool isPinnedPtr(const void* data) const override {
47     return false;
48   }
49 
getPinnedMemoryAllocatorHIPHooksInterface50   Allocator* getPinnedMemoryAllocator() const override {
51     AT_ERROR("Pinned memory requires HIP.");
52   }
53 
registerHIPTypesHIPHooksInterface54   virtual void registerHIPTypes(Context*) const {
55     AT_ERROR("Cannot registerHIPTypes() without ATen_hip library.");
56   }
57 
getNumGPUsHIPHooksInterface58   virtual int getNumGPUs() const {
59     return 0;
60   }
61 
hasPrimaryContextHIPHooksInterface62   bool hasPrimaryContext(DeviceIndex device_index) const override {
63     AT_ERROR("Cannot check primary context without ATen_hip library.");
64   }
65 };
66 
67 // NB: dummy argument to suppress "ISO C++11 requires at least one argument
68 // for the "..." in a variadic macro"
69 struct TORCH_API HIPHooksArgs {};
70 
71 TORCH_DECLARE_REGISTRY(HIPHooksRegistry, HIPHooksInterface, HIPHooksArgs);
72 #define REGISTER_HIP_HOOKS(clsname) \
73   C10_REGISTER_CLASS(HIPHooksRegistry, clsname, clsname)
74 
75 namespace detail {
76 TORCH_API const HIPHooksInterface& getHIPHooks();
77 
78 } // namespace detail
79 } // namespace at
80