1 #include <ATen/Tensor.h> 2 #include <c10/util/ArrayRef.h> 3 4 namespace at::native::metal { 5 6 class MPSImageWrapper; 7 class MetalTensorImplStorage final { 8 class Impl; 9 10 public: MetalTensorImplStorage()11 MetalTensorImplStorage(){}; 12 MetalTensorImplStorage(const std::vector<int64_t>& sizes); 13 MetalTensorImplStorage( 14 const std::vector<int64_t>& sizes, 15 const std::vector<int64_t>& strides); 16 ~MetalTensorImplStorage() = default; 17 18 MetalTensorImplStorage(MetalTensorImplStorage&&) = default; 19 MetalTensorImplStorage& operator=(MetalTensorImplStorage&&) = default; 20 21 MetalTensorImplStorage(const MetalTensorImplStorage&) = default; 22 MetalTensorImplStorage& operator=(const MetalTensorImplStorage&) = default; 23 24 friend std::ostream& operator<<( 25 std::ostream& output, 26 const MetalTensorImplStorage& mt); 27 28 bool defined() const; 29 IntArrayRef sizes() const; 30 IntArrayRef strides() const; 31 int64_t dim() const; 32 int64_t numel() const; 33 void set_data_from_host(const float* inputData); 34 void copy_data_to_host(float* host); 35 MPSImageWrapper* texture() const; 36 37 private: 38 std::shared_ptr<Impl> impl(); 39 std::shared_ptr<const Impl> impl() const; 40 std::shared_ptr<Impl> _impl; 41 }; 42 43 } // namespace at::native::metal 44