xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/metal/mpscnn/MPSImageWrapper.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #ifndef MPSImageWrapper_h
2 #define MPSImageWrapper_h
3 
#import <ATen/native/metal/MetalCommandBuffer.h>
#import <MetalPerformanceShaders/MetalPerformanceShaders.h>
#include <c10/util/ArrayRef.h>

#include <vector>
7 
8 namespace at {
9 namespace native {
10 namespace metal {
11 
/// Bundles the Metal objects behind one image: an `MPSImage`, an
/// `MTLBuffer`, the `MetalCommandBuffer` currently associated with it, and a
/// `PTMetalCommandBuffer` delegate, together with a vector of image sizes.
/// Declared available on iOS 11.0+ and macOS 10.13+.
///
/// Only declarations appear in this header; what each method actually does
/// is defined in the matching implementation file.
///
/// NOTE(review): a destructor is user-declared while the copy constructor and
/// copy assignment are still implicitly generated. Confirm that copying an
/// instance is safe (or never happens) with respect to that destructor.
class API_AVAILABLE(ios(11.0), macos(10.13)) MPSImageWrapper {
 public:
  // --- Lifetime -----------------------------------------------------------

  /// Builds a wrapper from a list of image sizes.
  /// NOTE(review): single-argument and not `explicit`, so any `IntArrayRef`
  /// converts implicitly to an `MPSImageWrapper`; consider `explicit` once
  /// all callers are checked.
  MPSImageWrapper(IntArrayRef sizes);
  ~MPSImageWrapper();
  /// NOTE(review): the name suggests this drops the held Metal objects —
  /// confirm against the implementation.
  void release();

  // --- Storage ------------------------------------------------------------

  /// Allocates storage for an image of the given sizes.
  void allocateStorage(IntArrayRef sizes);
  /// Temporary-storage variant of `allocateStorage`; additionally receives the
  /// `MetalCommandBuffer` to allocate against.
  void allocateTemporaryStorage(IntArrayRef sizes,
                                MetalCommandBuffer* commandBuffer);

  // --- Host data transfer -------------------------------------------------

  /// Reads image contents from a read-only host `float` buffer.
  /// NOTE(review): the expected element count is not visible here; presumably
  /// it follows from the image sizes — confirm.
  void copyDataFromHost(const float* inputData);
  /// Writes image contents into a caller-provided, writable host `float`
  /// buffer (same sizing caveat as `copyDataFromHost`).
  void copyDataToHost(float* hostData);

  // --- Associated objects -------------------------------------------------

  void setCommandBuffer(MetalCommandBuffer* buffer);
  MetalCommandBuffer* commandBuffer() const;
  void setImage(MPSImage* image);
  MPSImage* image() const;
  id<MTLBuffer> buffer() const;

  // --- prepare / synchronize ----------------------------------------------

  /// NOTE(review): any required ordering between `prepare()` and
  /// `synchronize()` is defined in the implementation — confirm before use.
  void prepare();
  void synchronize();

 private:
  // Member order is deliberately unchanged: it determines the object layout
  // and the order in which members are initialized.
  std::vector<int64_t> _imageSizes;
  MPSImage* _image{nil};
  id<MTLBuffer> _buffer{nil};
  MetalCommandBuffer* _commandBuffer{nil};
  id<PTMetalCommandBuffer> _delegate{nil};
};
38 
39 } // namespace metal
40 } // namespace native
41 } // namespace at
42 
43 #endif /* MPSImageWrapper_h */
44