1 #pragma once
2 #include <ATen/core/Tensor.h>
3
4 namespace at {
5
6 namespace detail {
7
noopDelete(void *)8 TORCH_API inline void noopDelete(void*) {}
9
10 } // namespace detail
11
12 /// Provides a fluent API to construct tensors from external data.
13 ///
14 /// The fluent API can be used instead of `from_blob` functions in case the
15 /// required set of parameters does not align with the existing overloads.
16 ///
17 /// at::Tensor tensor = at::for_blob(data, sizes)
18 /// .strides(strides)
19 /// .context(context, [](void *ctx) { delete static_cast<Ctx*>(ctx);
20 /// }) .options(...) .make_tensor();
21 ///
22 class TORCH_API TensorMaker {
23 friend TensorMaker for_blob(void* data, IntArrayRef sizes) noexcept;
24
25 public:
26 using ContextDeleter = DeleterFnPtr;
27
strides(OptionalIntArrayRef value)28 TensorMaker& strides(OptionalIntArrayRef value) noexcept {
29 strides_ = value;
30
31 return *this;
32 }
33
storage_offset(std::optional<int64_t> value)34 TensorMaker& storage_offset(std::optional<int64_t> value) noexcept {
35 storage_offset_ = value;
36
37 return *this;
38 }
39
deleter(std::function<void (void *)> value)40 TensorMaker& deleter(std::function<void(void*)> value) noexcept {
41 deleter_ = std::move(value);
42
43 return *this;
44 }
45
46 TensorMaker& context(void* value, ContextDeleter deleter = nullptr) noexcept {
47 ctx_ = std::unique_ptr<void, ContextDeleter>{
48 value, deleter != nullptr ? deleter : detail::noopDelete};
49
50 return *this;
51 }
52
target_device(std::optional<Device> value)53 TensorMaker& target_device(std::optional<Device> value) noexcept {
54 device_ = value;
55
56 return *this;
57 }
58
options(TensorOptions value)59 TensorMaker& options(TensorOptions value) noexcept {
60 opts_ = value;
61
62 return *this;
63 }
64
resizeable_storage()65 TensorMaker& resizeable_storage() noexcept {
66 resizeable_ = true;
67
68 return *this;
69 }
70
allocator(c10::Allocator * allocator)71 TensorMaker& allocator(c10::Allocator* allocator) noexcept {
72 allocator_ = allocator;
73
74 return *this;
75 }
76
77 Tensor make_tensor();
78
79 private:
TensorMaker(void * data,IntArrayRef sizes)80 explicit TensorMaker(void* data, IntArrayRef sizes) noexcept
81 : data_{data}, sizes_{sizes} {}
82
83 std::size_t computeStorageSize() const noexcept;
84
85 DataPtr makeDataPtrFromDeleter() noexcept;
86
87 DataPtr makeDataPtrFromContext() noexcept;
88
89 IntArrayRef makeTempSizes() const noexcept;
90
91 void* data_;
92 IntArrayRef sizes_;
93 OptionalIntArrayRef strides_{};
94 std::optional<int64_t> storage_offset_{};
95 std::function<void(void*)> deleter_{};
96 std::unique_ptr<void, ContextDeleter> ctx_{nullptr, detail::noopDelete};
97 std::optional<Device> device_{};
98 TensorOptions opts_{};
99 bool resizeable_{};
100 c10::Allocator* allocator_{};
101 };
102
for_blob(void * data,IntArrayRef sizes)103 inline TensorMaker for_blob(void* data, IntArrayRef sizes) noexcept {
104 return TensorMaker{data, sizes};
105 }
106
107 inline Tensor from_blob(
108 void* data,
109 IntArrayRef sizes,
110 IntArrayRef strides,
111 const std::function<void(void*)>& deleter,
112 const TensorOptions& options = {},
113 const std::optional<Device> target_device = std::nullopt) {
114 return for_blob(data, sizes)
115 .strides(strides)
116 .deleter(deleter)
117 .options(options)
118 .target_device(target_device)
119 .make_tensor();
120 }
121
122 inline Tensor from_blob(
123 void* data,
124 IntArrayRef sizes,
125 IntArrayRef strides,
126 int64_t storage_offset,
127 const std::function<void(void*)>& deleter,
128 const TensorOptions& options = {},
129 const std::optional<Device> target_device = std::nullopt) {
130 return for_blob(data, sizes)
131 .strides(strides)
132 .storage_offset(storage_offset)
133 .deleter(deleter)
134 .options(options)
135 .target_device(target_device)
136 .make_tensor();
137 }
138
139 inline Tensor from_blob(
140 void* data,
141 IntArrayRef sizes,
142 std::function<void(void*)> deleter,
143 const TensorOptions& options = {},
144 const std::optional<Device> target_device = std::nullopt) {
145 return for_blob(data, sizes)
146 .deleter(std::move(deleter))
147 .options(options)
148 .target_device(target_device)
149 .make_tensor();
150 }
151
152 inline Tensor from_blob(
153 void* data,
154 IntArrayRef sizes,
155 IntArrayRef strides,
156 const TensorOptions& options = {}) {
157 return for_blob(data, sizes).strides(strides).options(options).make_tensor();
158 }
159
160 inline Tensor from_blob(
161 void* data,
162 IntArrayRef sizes,
163 const TensorOptions& options = {}) {
164 return for_blob(data, sizes).options(options).make_tensor();
165 }
166
167 } // namespace at
168