xref: /aosp_15_r20/external/pytorch/aten/src/ATen/ops/from_blob.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 #include <ATen/core/Tensor.h>
3 
4 namespace at {
5 
6 namespace detail {
7 
noopDelete(void *)8 TORCH_API inline void noopDelete(void*) {}
9 
10 } // namespace detail
11 
12 /// Provides a fluent API to construct tensors from external data.
13 ///
14 /// The fluent API can be used instead of `from_blob` functions in case the
15 /// required set of parameters does not align with the existing overloads.
16 ///
17 ///     at::Tensor tensor = at::for_blob(data, sizes)
18 ///             .strides(strides)
19 ///             .context(context, [](void *ctx) { delete static_cast<Ctx*>(ctx);
20 ///             }) .options(...) .make_tensor();
21 ///
22 class TORCH_API TensorMaker {
23   friend TensorMaker for_blob(void* data, IntArrayRef sizes) noexcept;
24 
25  public:
26   using ContextDeleter = DeleterFnPtr;
27 
strides(OptionalIntArrayRef value)28   TensorMaker& strides(OptionalIntArrayRef value) noexcept {
29     strides_ = value;
30 
31     return *this;
32   }
33 
storage_offset(std::optional<int64_t> value)34   TensorMaker& storage_offset(std::optional<int64_t> value) noexcept {
35     storage_offset_ = value;
36 
37     return *this;
38   }
39 
deleter(std::function<void (void *)> value)40   TensorMaker& deleter(std::function<void(void*)> value) noexcept {
41     deleter_ = std::move(value);
42 
43     return *this;
44   }
45 
46   TensorMaker& context(void* value, ContextDeleter deleter = nullptr) noexcept {
47     ctx_ = std::unique_ptr<void, ContextDeleter>{
48         value, deleter != nullptr ? deleter : detail::noopDelete};
49 
50     return *this;
51   }
52 
target_device(std::optional<Device> value)53   TensorMaker& target_device(std::optional<Device> value) noexcept {
54     device_ = value;
55 
56     return *this;
57   }
58 
options(TensorOptions value)59   TensorMaker& options(TensorOptions value) noexcept {
60     opts_ = value;
61 
62     return *this;
63   }
64 
resizeable_storage()65   TensorMaker& resizeable_storage() noexcept {
66     resizeable_ = true;
67 
68     return *this;
69   }
70 
allocator(c10::Allocator * allocator)71   TensorMaker& allocator(c10::Allocator* allocator) noexcept {
72     allocator_ = allocator;
73 
74     return *this;
75   }
76 
77   Tensor make_tensor();
78 
79  private:
TensorMaker(void * data,IntArrayRef sizes)80   explicit TensorMaker(void* data, IntArrayRef sizes) noexcept
81       : data_{data}, sizes_{sizes} {}
82 
83   std::size_t computeStorageSize() const noexcept;
84 
85   DataPtr makeDataPtrFromDeleter() noexcept;
86 
87   DataPtr makeDataPtrFromContext() noexcept;
88 
89   IntArrayRef makeTempSizes() const noexcept;
90 
91   void* data_;
92   IntArrayRef sizes_;
93   OptionalIntArrayRef strides_{};
94   std::optional<int64_t> storage_offset_{};
95   std::function<void(void*)> deleter_{};
96   std::unique_ptr<void, ContextDeleter> ctx_{nullptr, detail::noopDelete};
97   std::optional<Device> device_{};
98   TensorOptions opts_{};
99   bool resizeable_{};
100   c10::Allocator* allocator_{};
101 };
102 
for_blob(void * data,IntArrayRef sizes)103 inline TensorMaker for_blob(void* data, IntArrayRef sizes) noexcept {
104   return TensorMaker{data, sizes};
105 }
106 
107 inline Tensor from_blob(
108     void* data,
109     IntArrayRef sizes,
110     IntArrayRef strides,
111     const std::function<void(void*)>& deleter,
112     const TensorOptions& options = {},
113     const std::optional<Device> target_device = std::nullopt) {
114   return for_blob(data, sizes)
115       .strides(strides)
116       .deleter(deleter)
117       .options(options)
118       .target_device(target_device)
119       .make_tensor();
120 }
121 
122 inline Tensor from_blob(
123     void* data,
124     IntArrayRef sizes,
125     IntArrayRef strides,
126     int64_t storage_offset,
127     const std::function<void(void*)>& deleter,
128     const TensorOptions& options = {},
129     const std::optional<Device> target_device = std::nullopt) {
130   return for_blob(data, sizes)
131       .strides(strides)
132       .storage_offset(storage_offset)
133       .deleter(deleter)
134       .options(options)
135       .target_device(target_device)
136       .make_tensor();
137 }
138 
139 inline Tensor from_blob(
140     void* data,
141     IntArrayRef sizes,
142     std::function<void(void*)> deleter,
143     const TensorOptions& options = {},
144     const std::optional<Device> target_device = std::nullopt) {
145   return for_blob(data, sizes)
146       .deleter(std::move(deleter))
147       .options(options)
148       .target_device(target_device)
149       .make_tensor();
150 }
151 
152 inline Tensor from_blob(
153     void* data,
154     IntArrayRef sizes,
155     IntArrayRef strides,
156     const TensorOptions& options = {}) {
157   return for_blob(data, sizes).strides(strides).options(options).make_tensor();
158 }
159 
160 inline Tensor from_blob(
161     void* data,
162     IntArrayRef sizes,
163     const TensorOptions& options = {}) {
164   return for_blob(data, sizes).options(options).make_tensor();
165 }
166 
167 } // namespace at
168