xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/metal/MetalTensorImplStorage.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <ATen/Tensor.h>
2 #include <c10/util/ArrayRef.h>
3 
4 namespace at::native::metal {
5 
6 class MPSImageWrapper;
7 class MetalTensorImplStorage final {
8   class Impl;
9 
10  public:
MetalTensorImplStorage()11   MetalTensorImplStorage(){};
12   MetalTensorImplStorage(const std::vector<int64_t>& sizes);
13   MetalTensorImplStorage(
14       const std::vector<int64_t>& sizes,
15       const std::vector<int64_t>& strides);
16   ~MetalTensorImplStorage() = default;
17 
18   MetalTensorImplStorage(MetalTensorImplStorage&&) = default;
19   MetalTensorImplStorage& operator=(MetalTensorImplStorage&&) = default;
20 
21   MetalTensorImplStorage(const MetalTensorImplStorage&) = default;
22   MetalTensorImplStorage& operator=(const MetalTensorImplStorage&) = default;
23 
24   friend std::ostream& operator<<(
25       std::ostream& output,
26       const MetalTensorImplStorage& mt);
27 
28   bool defined() const;
29   IntArrayRef sizes() const;
30   IntArrayRef strides() const;
31   int64_t dim() const;
32   int64_t numel() const;
33   void set_data_from_host(const float* inputData);
34   void copy_data_to_host(float* host);
35   MPSImageWrapper* texture() const;
36 
37  private:
38   std::shared_ptr<Impl> impl();
39   std::shared_ptr<const Impl> impl() const;
40   std::shared_ptr<Impl> _impl;
41 };
42 
43 } // namespace at::native::metal
44