#import <ATen/native/metal/MetalTensorImpl.h>
#import <ATen/native/metal/MetalTensorImplStorage.h>
#import <ATen/native/metal/MetalTensorUtils.h>
#import <ATen/native/metal/mpscnn/MPSImageWrapper.h>

#include <ATen/Utils.h>
#include <c10/util/accumulate.h>

namespace at {
namespace native {
namespace metal {

// Private implementation behind MetalTensorImplStorage. It stores the sizes
// and strides it was constructed with, a cached element count, and the
// MPSImageWrapper that every host <-> texture transfer is forwarded to.
class API_AVAILABLE(ios(11.0), macos(10.13)) MetalTensorImplStorage::Impl {
 public:
  // Copies `sizes` and `strides`, caches the value returned by
  // c10::multiply_integers over the sizes as the element count, and
  // constructs the MPSImageWrapper from `sizes`.
  Impl(const std::vector<int64_t>& sizes, const std::vector<int64_t>& strides)
      : _sizes(sizes),
        _strides(strides),
        // Reads _sizes, which is declared (and so initialized) before _numel.
        _numel(c10::multiply_integers(std::begin(_sizes), std::end(_sizes))),
        _textureImpl(std::make_unique<MPSImageWrapper>(sizes)) {}

  // View over the stored sizes; it refers to this object's own vector.
  IntArrayRef sizes() const {
    return _sizes;
  }

  // View over the stored strides; it refers to this object's own vector.
  IntArrayRef strides() const {
    return _strides;
  }

  // Number of dimensions, i.e. the length of the size list.
  int64_t dim() const {
    return static_cast<int64_t>(_sizes.size());
  }

  // Element count cached at construction time.
  int64_t numel() const {
    return _numel;
  }

  // Forwards `inputData` to MPSImageWrapper::copyDataFromHost.
  // NOTE(review): the required length/layout of the buffer is defined by
  // MPSImageWrapper, not here -- confirm against that class.
  void set_data_from_host(const float* inputData) {
    _textureImpl->copyDataFromHost(inputData);
  }

  // Forwards `host` to MPSImageWrapper::copyDataToHost.
  // NOTE(review): the required capacity of `host` is defined by
  // MPSImageWrapper, not here -- confirm against that class.
  void copy_data_to_host(float* host) {
    _textureImpl->copyDataToHost(host);
  }

  // Non-owning pointer to the wrapper; ownership stays with this object.
  MPSImageWrapper* texture() const {
    return _textureImpl.get();
  }

 private:
  std::vector<int64_t> _sizes;
  std::vector<int64_t> _strides;
  int64_t _numel;
  std::unique_ptr<MPSImageWrapper> _textureImpl;
};

// Creates a storage for `sizes`, delegating to the two-argument constructor
// with the strides returned by computeStrides(sizes).
MetalTensorImplStorage::MetalTensorImplStorage(
    const std::vector<int64_t>& sizes)
    : MetalTensorImplStorage(sizes, computeStrides(sizes)) {}

// Creates a storage with explicit sizes and strides. Both vectors are copied
// into the shared Impl: the parameters are const references, so wrapping them
// in std::move could never turn the copy into a move.
MetalTensorImplStorage::MetalTensorImplStorage(
    const std::vector<int64_t>& sizes,
    const std::vector<int64_t>& strides)
    : _impl(std::make_shared<Impl>(sizes, strides)) {}

// True when this storage holds an Impl.
bool MetalTensorImplStorage::defined() const {
  return static_cast<bool>(_impl);
}

// Returns a shared handle to the mutable Impl.
std::shared_ptr<MetalTensorImplStorage::Impl> MetalTensorImplStorage::impl() {
  return _impl;
}

// Returns a shared handle to the Impl, typed as const.
std::shared_ptr<const MetalTensorImplStorage::Impl> MetalTensorImplStorage::
    impl() const {
  return _impl;
}
// Returns the sizes held by the underlying Impl.
// NOTE(review): this accessor and the ones below dereference impl() without
// checking defined() -- confirm callers only use them on a defined storage.
IntArrayRef MetalTensorImplStorage::sizes() const {
  return impl()->sizes();
}

// Returns the strides held by the underlying Impl.
IntArrayRef MetalTensorImplStorage::strides() const {
  return impl()->strides();
}

// Returns the number of dimensions reported by the underlying Impl.
int64_t MetalTensorImplStorage::dim() const {
  return impl()->dim();
}

// Returns the element count reported by the underlying Impl.
int64_t MetalTensorImplStorage::numel() const {
  return impl()->numel();
}

// Forwards `inputData` to Impl::set_data_from_host.
void MetalTensorImplStorage::set_data_from_host(const float* inputData) {
  impl()->set_data_from_host(inputData);
}

// Forwards `hostData` to Impl::copy_data_to_host.
void MetalTensorImplStorage::copy_data_to_host(float* hostData) {
  impl()->copy_data_to_host(hostData);
}

// Returns the MPSImageWrapper pointer exposed by the underlying Impl.
API_AVAILABLE(ios(11.0))
MPSImageWrapper* MetalTensorImplStorage::texture() const {
  return impl()->texture();
}

// Writes a one-line summary of `mt` to `output` and returns `output`:
//   [MetalTensorImplStorage] | Size:{s0,s1,...}, Stride:{t0,t1,...}
// An empty size or stride list is printed as "{}".
std::ostream& operator<<(
    std::ostream& output,
    const MetalTensorImplStorage& mt) {
  // Renders `values` as a comma-separated list. The text is built in its own
  // stream so formatting state set on `output` does not leak into the
  // numbers, and each element is streamed as int64_t rather than narrowed
  // to int.
  const auto joinWithCommas = [](IntArrayRef values) {
    std::ostringstream oss;
    const char* separator = "";
    for (const int64_t value : values) {
      oss << separator << value;
      separator = ",";
    }
    return oss.str();
  };

  output << "[MetalTensorImplStorage] | Size:{";
  output << joinWithCommas(mt.sizes()) << "}, Stride:{";
  output << joinWithCommas(mt.strides()) << "}";
  return output;
}

} // namespace metal
} // namespace native
} // namespace at