1 // Copyright © 2022 Apple Inc. 2 3 #pragma once 4 #include <c10/core/impl/DeviceGuardImplInterface.h> 5 #include <c10/macros/Macros.h> 6 #include <c10/util/Exception.h> 7 #include <ATen/Context.h> 8 #include <ATen/mps/MPSStream.h> 9 #include <ATen/mps/MPSEvent.h> 10 11 #ifdef __OBJC__ 12 #include <Foundation/Foundation.h> 13 #include <Metal/Metal.h> 14 #include <MetalPerformanceShaders/MetalPerformanceShaders.h> 15 #endif 16 17 #include <ATen/Tensor.h> 18 #include <c10/core/MemoryFormat.h> 19 #include <c10/core/Storage.h> 20 #include <c10/core/TensorImpl.h> 21 #include <sys/_types/_size_t.h> 22 #include <memory> 23 #include <c10/core/UndefinedTensorImpl.h> 24 #include <c10/util/intrusive_ptr.h> 25 26 27 namespace at::mps { 28 29 typedef MPSEvent* mpsEvent_t; 30 31 // TODO: Move the MPSGuardImpl to inherit from NoOpDeviceGuardImpl 32 // https://github.com/pytorch/pytorch/issues/77170 33 struct TORCH_API MPSGuardImpl final : public c10::impl::DeviceGuardImplInterface { 34 static constexpr c10::DeviceType static_type = c10::DeviceType::MPS; 35 36 // constructor MPSGuardImplfinal37 MPSGuardImpl() {} MPSGuardImplfinal38 explicit MPSGuardImpl(c10::DeviceType t) { 39 TORCH_INTERNAL_ASSERT(t == c10::DeviceType::MPS); 40 } 41 42 // returns the type typefinal43 c10::DeviceType type() const override { 44 return c10::DeviceType::MPS; 45 } 46 exchangeDevicefinal47 Device exchangeDevice(Device d) const override { 48 return Device(c10::DeviceType::MPS, 0); 49 } 50 getDevicefinal51 Device getDevice() const override { 52 return Device(c10::DeviceType::MPS, 0); 53 } 54 uncheckedGetDevicefinal55 std::optional<Device> uncheckedGetDevice() const noexcept { 56 return Device(c10::DeviceType::MPS, 0); 57 } 58 setDevicefinal59 void setDevice(Device d) const override { 60 TORCH_INTERNAL_ASSERT(d.is_mps()); 61 } 62 uncheckedSetDevicefinal63 void uncheckedSetDevice(Device d) const noexcept override { 64 // TODO: Currently setting only device 0 65 } 66 getStreamfinal67 Stream getStream(Device d) const noexcept override { 68 return Stream(Stream::DEFAULT, Device(c10::DeviceType::MPS, 0)); 69 } 70 71 Stream getNewStream(Device, int priority = 0) const override { 72 (void)priority; 73 return Stream(Stream::DEFAULT, Device(c10::DeviceType::MPS, 0)); 74 } 75 getDefaultStreamfinal76 Stream getDefaultStream(Device d) const override { 77 return Stream(Stream::DEFAULT, Device(c10::DeviceType::MPS, 0)); 78 } 79 80 // NB: These do NOT set the current device exchangeStreamfinal81 Stream exchangeStream(Stream s) const noexcept override { 82 return Stream(Stream::DEFAULT, Device(c10::DeviceType::MPS, 0)); 83 } deviceCountfinal84 DeviceIndex deviceCount() const noexcept override { 85 if (at::hasMPS()) { 86 //TODO: extend it for multi-device case 87 return 1; 88 } else { 89 return 0; 90 } 91 } 92 93 // Event-related functions 94 void createEvent( 95 mpsEvent_t* event, 96 const EventFlag flag) const; 97 98 void destroyEvent( 99 void* event, 100 const DeviceIndex device_index) const noexcept override; 101 102 void record( 103 void** event, 104 const Stream& stream, 105 const DeviceIndex device_index, 106 const EventFlag flag) const override; 107 108 void block( 109 void* event, 110 const Stream& stream) const override; 111 112 bool queryEvent(void* event) const override; 113 114 }; 115 116 /// A variant of OptionalDeviceGuard that is specialized for MPS. 117 struct OptionalMPSGuard { OptionalMPSGuardOptionalMPSGuard118 explicit OptionalMPSGuard() : guard_() {} 119 OptionalMPSGuardOptionalMPSGuard120 explicit OptionalMPSGuard(std::optional<Device> device_opt) 121 : guard_(device_opt) {} 122 123 /// Set the current MPS device to the passed device index, if it is not 124 /// nullopt OptionalMPSGuardOptionalMPSGuard125 explicit OptionalMPSGuard(std::optional<DeviceIndex> device_index_opt) 126 : guard_(device_index_opt) {} 127 128 // Copy is not allowed 129 OptionalMPSGuard(const OptionalMPSGuard&) = delete; 130 OptionalMPSGuard& operator=(const OptionalMPSGuard&) = delete; 131 OptionalMPSGuard(OptionalMPSGuard&& other) = delete; 132 OptionalMPSGuard& operator=(OptionalMPSGuard&& other) = delete; 133 134 /// Sets the MPS device to the given device, initializing the guard if it 135 /// is not already initialized. Errors if the given device is not a MPS 136 /// device. set_deviceOptionalMPSGuard137 void set_device(Device device) { 138 guard_.set_device(device); 139 } 140 141 /// Sets the MPS device to the given device, initializing the guard if it is 142 /// not already initialized. Errors if the given device is not a MPS device. reset_deviceOptionalMPSGuard143 void reset_device(Device device) { 144 guard_.reset_device(device); 145 } 146 147 /// Sets the MPS device to the given device index, initializing the guard if 148 /// it is not already initialized. set_indexOptionalMPSGuard149 void set_index(DeviceIndex device_index) { 150 guard_.set_index(device_index); 151 } 152 153 /// Returns the device that was set immediately prior to initialization of the 154 /// guard, or nullopt if the guard is uninitialized. original_deviceOptionalMPSGuard155 std::optional<Device> original_device() const { 156 return guard_.original_device(); 157 } 158 159 /// Returns the most recent device that was set using this device guard, 160 /// either from construction, or via set_device, if the guard is initialized, 161 /// or nullopt if the guard is uninitialized. current_deviceOptionalMPSGuard162 std::optional<Device> current_device() const { 163 return guard_.current_device(); 164 } 165 166 /// Restore the original MPS device, resetting this guard to uninitialized 167 /// state. resetOptionalMPSGuard168 void reset() { 169 guard_.reset(); 170 } 171 172 private: 173 c10::impl::InlineOptionalDeviceGuard<MPSGuardImpl> guard_; 174 }; 175 176 177 C10_REGISTER_GUARD_IMPL(MPS, MPSGuardImpl); 178 179 } // namespace at::mps 180