xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/metal/MetalTensorImplStorage.mm (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1#import <ATen/native/metal/MetalTensorImpl.h>
2#import <ATen/native/metal/MetalTensorImplStorage.h>
3#import <ATen/native/metal/MetalTensorUtils.h>
4#import <ATen/native/metal/mpscnn/MPSImageWrapper.h>
5
6#include <ATen/Utils.h>
7#include <c10/util/accumulate.h>
8
9namespace at {
10namespace native {
11namespace metal {
12
13class API_AVAILABLE(ios(11.0), macos(10.13)) MetalTensorImplStorage::Impl {
14 public:
15  Impl(const std::vector<int64_t>& sizes, const std::vector<int64_t>& strides)
16      : _sizes(sizes),
17        _strides(strides),
18        _numel(c10::multiply_integers(std::begin(_sizes), std::end(_sizes))),
19        _textureImpl(std::make_unique<MPSImageWrapper>(sizes)) {}
20
21  IntArrayRef sizes() const {
22    return _sizes;
23  }
24  IntArrayRef strides() const {
25    return _strides;
26  }
27  int64_t dim() const {
28    return _sizes.size();
29  }
30  int64_t numel() const {
31    return _numel;
32  }
33  void set_data_from_host(const float* inputData) {
34    _textureImpl->copyDataFromHost(inputData);
35  }
36  void copy_data_to_host(float* host) {
37    _textureImpl->copyDataToHost(host);
38  }
39  MPSImageWrapper* texture() const {
40    return _textureImpl.get();
41  }
42
43 private:
44  std::vector<int64_t> _sizes;
45  std::vector<int64_t> _strides;
46  int64_t _numel;
47  std::unique_ptr<MPSImageWrapper> _textureImpl;
48};
49
50MetalTensorImplStorage::MetalTensorImplStorage(
51    const std::vector<int64_t>& sizes)
52    : MetalTensorImplStorage(sizes, computeStrides(sizes)) {}
53
54MetalTensorImplStorage::MetalTensorImplStorage(
55    const std::vector<int64_t>& sizes,
56    const std::vector<int64_t>& strides)
57    : _impl(std::make_shared<Impl>(std::move(sizes), std::move(strides))) {}
58
59bool MetalTensorImplStorage::defined() const {
60  return static_cast<bool>(_impl);
61}
62
63std::shared_ptr<MetalTensorImplStorage::Impl> MetalTensorImplStorage::impl() {
64  return _impl;
65}
66
67std::shared_ptr<const MetalTensorImplStorage::Impl> MetalTensorImplStorage::
68    impl() const {
69  return _impl;
70}
71
72IntArrayRef MetalTensorImplStorage::sizes() const {
73  return impl()->sizes();
74}
75
76IntArrayRef MetalTensorImplStorage::strides() const {
77  return impl()->strides();
78}
79
80int64_t MetalTensorImplStorage::dim() const {
81  return impl()->dim();
82}
83
84int64_t MetalTensorImplStorage::numel() const {
85  return impl()->numel();
86}
87
88void MetalTensorImplStorage::set_data_from_host(const float* inputData) {
89  impl()->set_data_from_host(inputData);
90}
91
92void MetalTensorImplStorage::copy_data_to_host(float* hostData) {
93  impl()->copy_data_to_host(hostData);
94}
95
96API_AVAILABLE(ios(11.0))
97MPSImageWrapper* MetalTensorImplStorage::texture() const {
98  return impl()->texture();
99}
100
101std::ostream& operator<<(
102    std::ostream& output,
103    const MetalTensorImplStorage& mt) {
104  auto&& sizes = mt.sizes();
105  auto&& strides = mt.strides();
106  output << "[MetalTensorImplStorage] | Size:{";
107  std::ostringstream oss;
108  std::copy(
109      sizes.begin(), sizes.end() - 1, std::ostream_iterator<int>(oss, ","));
110  oss << sizes.back();
111  output << oss.str() << "}, Stride:{";
112  std::string sizesStr = oss.str();
113  oss.str("");
114  oss.clear();
115  std::copy(
116      strides.begin(), strides.end() - 1, std::ostream_iterator<int>(oss, ","));
117  oss << sizes.back();
118  output << oss.str() << "}";
119  return output;
120}
121
122}
123}
124}
125