1 #pragma once 2 3 #include <c10/core/Allocator.h> 4 #include <c10/core/GeneratorImpl.h> 5 #include <c10/util/Exception.h> 6 7 #include <c10/util/Registry.h> 8 9 #include <ATen/detail/AcceleratorHooksInterface.h> 10 11 #include <memory> 12 13 namespace at { 14 class Context; 15 } 16 17 // NB: Class must live in `at` due to limitations of Registry.h. 18 namespace at { 19 20 // The HIPHooksInterface is an omnibus interface for any HIP functionality 21 // which we may want to call into from CPU code (and thus must be dynamically 22 // dispatched, to allow for separate compilation of HIP code). See 23 // CUDAHooksInterface for more detailed motivation. 24 struct TORCH_API HIPHooksInterface : AcceleratorHooksInterface { 25 // This should never actually be implemented, but it is used to 26 // squelch -Werror=non-virtual-dtor 27 ~HIPHooksInterface() override = default; 28 29 // Initialize the HIP library state initHIPHIPHooksInterface30 virtual void initHIP() const { 31 AT_ERROR("Cannot initialize HIP without ATen_hip library."); 32 } 33 initHIPGeneratorHIPHooksInterface34 virtual std::unique_ptr<c10::GeneratorImpl> initHIPGenerator(Context*) const { 35 AT_ERROR("Cannot initialize HIP generator without ATen_hip library."); 36 } 37 hasHIPHIPHooksInterface38 virtual bool hasHIP() const { 39 return false; 40 } 41 current_deviceHIPHooksInterface42 virtual c10::DeviceIndex current_device() const { 43 return -1; 44 } 45 isPinnedPtrHIPHooksInterface46 bool isPinnedPtr(const void* data) const override { 47 return false; 48 } 49 getPinnedMemoryAllocatorHIPHooksInterface50 Allocator* getPinnedMemoryAllocator() const override { 51 AT_ERROR("Pinned memory requires HIP."); 52 } 53 registerHIPTypesHIPHooksInterface54 virtual void registerHIPTypes(Context*) const { 55 AT_ERROR("Cannot registerHIPTypes() without ATen_hip library."); 56 } 57 getNumGPUsHIPHooksInterface58 virtual int getNumGPUs() const { 59 return 0; 60 } 61 hasPrimaryContextHIPHooksInterface62 bool hasPrimaryContext(DeviceIndex device_index) const override { 63 AT_ERROR("Cannot check primary context without ATen_hip library."); 64 } 65 }; 66 67 // NB: dummy argument to suppress "ISO C++11 requires at least one argument 68 // for the "..." in a variadic macro" 69 struct TORCH_API HIPHooksArgs {}; 70 71 TORCH_DECLARE_REGISTRY(HIPHooksRegistry, HIPHooksInterface, HIPHooksArgs); 72 #define REGISTER_HIP_HOOKS(clsname) \ 73 C10_REGISTER_CLASS(HIPHooksRegistry, clsname, clsname) 74 75 namespace detail { 76 TORCH_API const HIPHooksInterface& getHIPHooks(); 77 78 } // namespace detail 79 } // namespace at 80