xref: /aosp_15_r20/external/pytorch/aten/src/ATen/mps/MPSGuardImpl.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 //  Copyright © 2022 Apple Inc.
2 
3 #pragma once
4 #include <c10/core/impl/DeviceGuardImplInterface.h>
5 #include <c10/macros/Macros.h>
6 #include <c10/util/Exception.h>
7 #include <ATen/Context.h>
8 #include <ATen/mps/MPSStream.h>
9 #include <ATen/mps/MPSEvent.h>
10 
11 #ifdef __OBJC__
12 #include <Foundation/Foundation.h>
13 #include <Metal/Metal.h>
14 #include <MetalPerformanceShaders/MetalPerformanceShaders.h>
15 #endif
16 
17 #include <ATen/Tensor.h>
18 #include <c10/core/MemoryFormat.h>
19 #include <c10/core/Storage.h>
20 #include <c10/core/TensorImpl.h>
21 #include <sys/_types/_size_t.h>
22 #include <memory>
23 #include <c10/core/UndefinedTensorImpl.h>
24 #include <c10/util/intrusive_ptr.h>
25 
26 
27 namespace at::mps {
28 
29 typedef MPSEvent* mpsEvent_t;
30 
31 // TODO: Move the MPSGuardImpl to inherit from NoOpDeviceGuardImpl
32 // https://github.com/pytorch/pytorch/issues/77170
33 struct TORCH_API MPSGuardImpl final : public c10::impl::DeviceGuardImplInterface {
34   static constexpr c10::DeviceType static_type = c10::DeviceType::MPS;
35 
36   // constructor
MPSGuardImplfinal37   MPSGuardImpl() {}
MPSGuardImplfinal38   explicit MPSGuardImpl(c10::DeviceType t) {
39     TORCH_INTERNAL_ASSERT(t == c10::DeviceType::MPS);
40   }
41 
42   // returns the type
typefinal43   c10::DeviceType type() const override {
44     return c10::DeviceType::MPS;
45   }
46 
exchangeDevicefinal47   Device exchangeDevice(Device d) const override {
48     return Device(c10::DeviceType::MPS, 0);
49   }
50 
getDevicefinal51   Device getDevice() const override {
52     return Device(c10::DeviceType::MPS, 0);
53   }
54 
uncheckedGetDevicefinal55   std::optional<Device> uncheckedGetDevice() const noexcept {
56     return Device(c10::DeviceType::MPS, 0);
57   }
58 
setDevicefinal59   void setDevice(Device d) const override {
60     TORCH_INTERNAL_ASSERT(d.is_mps());
61   }
62 
uncheckedSetDevicefinal63   void uncheckedSetDevice(Device d) const noexcept override {
64     // TODO: Currently setting only device 0
65   }
66 
getStreamfinal67   Stream getStream(Device d) const noexcept override {
68     return Stream(Stream::DEFAULT, Device(c10::DeviceType::MPS, 0));
69   }
70 
71   Stream getNewStream(Device, int priority = 0) const override {
72     (void)priority;
73     return Stream(Stream::DEFAULT, Device(c10::DeviceType::MPS, 0));
74   }
75 
getDefaultStreamfinal76   Stream getDefaultStream(Device d) const override {
77     return Stream(Stream::DEFAULT, Device(c10::DeviceType::MPS, 0));
78   }
79 
80   // NB: These do NOT set the current device
exchangeStreamfinal81   Stream exchangeStream(Stream s) const noexcept override {
82     return Stream(Stream::DEFAULT, Device(c10::DeviceType::MPS, 0));
83   }
deviceCountfinal84   DeviceIndex deviceCount() const noexcept override {
85     if (at::hasMPS()) {
86       //TODO: extend it for multi-device case
87       return 1;
88     } else {
89       return 0;
90     }
91   }
92 
93   // Event-related functions
94   void createEvent(
95     mpsEvent_t* event,
96     const EventFlag flag) const;
97 
98   void destroyEvent(
99     void* event,
100     const DeviceIndex device_index) const noexcept override;
101 
102   void record(
103     void** event,
104     const Stream& stream,
105     const DeviceIndex device_index,
106     const EventFlag flag) const override;
107 
108   void block(
109     void* event,
110     const Stream& stream) const override;
111 
112   bool queryEvent(void* event) const override;
113 
114 };
115 
116 /// A variant of OptionalDeviceGuard that is specialized for MPS.
117 struct OptionalMPSGuard {
OptionalMPSGuardOptionalMPSGuard118   explicit OptionalMPSGuard() : guard_() {}
119 
OptionalMPSGuardOptionalMPSGuard120   explicit OptionalMPSGuard(std::optional<Device> device_opt)
121       : guard_(device_opt) {}
122 
123   /// Set the current MPS device to the passed device index, if it is not
124   /// nullopt
OptionalMPSGuardOptionalMPSGuard125   explicit OptionalMPSGuard(std::optional<DeviceIndex> device_index_opt)
126       : guard_(device_index_opt) {}
127 
128   // Copy is not allowed
129   OptionalMPSGuard(const OptionalMPSGuard&) = delete;
130   OptionalMPSGuard& operator=(const OptionalMPSGuard&) = delete;
131   OptionalMPSGuard(OptionalMPSGuard&& other) = delete;
132   OptionalMPSGuard& operator=(OptionalMPSGuard&& other) = delete;
133 
134   /// Sets the MPS device to the given device, initializing the guard if it
135   /// is not already initialized.  Errors if the given device is not a MPS
136   /// device.
set_deviceOptionalMPSGuard137   void set_device(Device device) {
138     guard_.set_device(device);
139   }
140 
141   /// Sets the MPS device to the given device, initializing the guard if it is
142   /// not already initialized.  Errors if the given device is not a MPS device.
reset_deviceOptionalMPSGuard143   void reset_device(Device device) {
144     guard_.reset_device(device);
145   }
146 
147   /// Sets the MPS device to the given device index, initializing the guard if
148   /// it is not already initialized.
set_indexOptionalMPSGuard149   void set_index(DeviceIndex device_index) {
150     guard_.set_index(device_index);
151   }
152 
153   /// Returns the device that was set immediately prior to initialization of the
154   /// guard, or nullopt if the guard is uninitialized.
original_deviceOptionalMPSGuard155   std::optional<Device> original_device() const {
156     return guard_.original_device();
157   }
158 
159   /// Returns the most recent device that was set using this device guard,
160   /// either from construction, or via set_device, if the guard is initialized,
161   /// or nullopt if the guard is uninitialized.
current_deviceOptionalMPSGuard162   std::optional<Device> current_device() const {
163     return guard_.current_device();
164   }
165 
166   /// Restore the original MPS device, resetting this guard to uninitialized
167   /// state.
resetOptionalMPSGuard168   void reset() {
169     guard_.reset();
170   }
171 
172  private:
173   c10::impl::InlineOptionalDeviceGuard<MPSGuardImpl> guard_;
174 };
175 
176 
// Hooks MPSGuardImpl up as the guard implementation for the MPS device type
// (per the macro's name; the macro body is not visible here — confirm).
// NOTE(review): this invocation sits in a header. Depending on how
// C10_REGISTER_GUARD_IMPL expands, every translation unit that includes this
// file may perform its own registration. Consider moving it into a single
// source file — TODO confirm.
C10_REGISTER_GUARD_IMPL(MPS, MPSGuardImpl);
178 
179 } // namespace at::mps
180