1 #pragma once
2
3 #include <ATen/Dispatch.h>
4 #include <ATen/ScalarOps.h>
5 #include <ATen/core/Tensor.h>
6 #include <ATen/core/grad_mode.h>
7
8 #include <c10/util/irange.h>
9
10 #ifndef AT_PER_OPERATOR_HEADERS
11 #include <ATen/Functions.h>
12 #else
13 #include <ATen/ops/empty.h>
14 #include <ATen/ops/tensor.h>
15 #endif
16
17 #include <initializer_list>
18
19 namespace torch {
20
21 namespace detail {
22
// Discriminator for the three states a `TensorDataContainer` can be in:
// a single scalar value, a (possibly nested) braced-init-list, or a
// materialized `at::Tensor`.
enum class TensorDataContainerType { Scalar, InitList, Tensor };
24
25 struct TensorDataContainer;
26
27 inline std::ostream& operator<<(
28 std::ostream& stream,
29 const TensorDataContainer& tensor_data_container);
30
compute_desired_dtype(c10::ScalarType scalar_type)31 inline c10::ScalarType compute_desired_dtype(c10::ScalarType scalar_type) {
32 if (scalar_type == at::kInt || scalar_type == at::kLong) {
33 // C++ `torch::tensor` with an integer type or an `at::ArrayRef` /
34 // `std::vector` / (nested) braced-init-list of integer types always
35 // produces a tensor of dtype `at::kLong` (aka. int64_t), matching Python
36 // `torch.tensor` behavior.
37 return at::kLong;
38 } else if (scalar_type == at::kFloat || scalar_type == at::kDouble) {
39 // C++ `torch::tensor` with a floating-point type or an `at::ArrayRef` /
40 // `std::vector` / (nested) braced-init-list of floating-point types always
41 // produces a tensor of dtype `torch::get_default_dtype()`, matching Python
42 // `torch.tensor` behavior.
43 return at::typeMetaToScalarType(at::get_default_dtype());
44 } else {
45 return scalar_type;
46 }
47 }
48
49 // We use `TensorDataContainer` to support converting the following data
50 // container types into the equivalent Tensor:
51 //
52 // 1. Arbitrarily nested braced-init-list (e.g. `{{1, 2}, {3, 4}}`).
53 // 2. `at::ArrayRef` of supported tensor data types.
54 // 3. `std::vector` of supported tensor data types.
55 //
56 // At any time, a `TensorDataContainer` object represents one of the following:
57 //
58 // 1. A scalar with value `scalar()` and type `scalar_type()`.
59 // 2. A Tensor represented in `std::initializer_list<TensorDataContainer>` form,
60 // with value `init_list()`, Tensor scalar type `scalar_type()`, and Tensor
61 // sizes `sizes()`.
62 // 3. A Tensor represented in `at::Tensor` form, with value `tensor()`, scalar
63 // type `scalar_type()`,
64 // and Tensor sizes `sizes()`.
65 //
66 // All the infrastructure here is mostly to support converting an arbitrarily
67 // nested braced-init-list to the equivalent Tensor successfully. Consider the
68 // following example:
69 //
70 // `torch::tensor({{1}, {2}})`
71 //
72 // this will call into the `torch::tensor` function:
73 //
74 // `at::Tensor tensor(detail::TensorDataContainer tensor_data_container, const
75 // at::TensorOptions& options = {})`
76 //
77 // the compiler will first try to convert `{{1}, {2}}` to `TensorDataContainer`
78 // type:
79 //
80 // `TensorDataContainer({{1}, {2}})`
81 //
82 // which matches to the
83 // `TensorDataContainer(std::initializer_list<TensorDataContainer>)`
84 // constructor, and in an attempt to convert `{1}` and `{2}` to
85 // `TensorDataContainer`, it calls the following:
86 //
87 // `TensorDataContainer({1})` (same call path happens for `{2}`, and we'll just
88 // focus on `{1}` here)
89 //
90 // At this point, theoretically there are two plausible ways for `{1}` to be
91 // matched to one of the constructors of `TensorDataContainer`:
92 //
93 // 1. It can be a list-initialization of a scalar value, thus matching
94 // `TensorDataContainer(int value)`.
95 // 2. It can be converted to `std::initializer_list<TensorDataContainer>`, thus
96 // matching
97 // `TensorDataContainer(std::initializer_list<TensorDataContainer>)`.
98 //
99 // How does the compiler decide which one to choose? According to
100 // `https://en.cppreference.com/w/cpp/language/list_initialization`,
101 // braced-init-list always prefers the constructor that takes
102 // `std::initializer_list`. Hence we happily move forward with constructor #2,
103 // and it calls the following:
104 //
105 // `TensorDataContainer(1)`
106 //
107 // Now it matches `TensorDataContainer(int value)`, which stores `1` as a scalar
108 // value. All is good.
struct TensorDataContainer {
  // NOTE: For tensors with zero-size dimensions (e.g. `torch::tensor({{},
  // {}})`), the innermost empty braced-init-list `{}` matches the default
  // constructor of the innermost `TensorDataContainer`.
  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
  TensorDataContainer()
      : sizes_({0}),
        // NOTE: In Python, the dtype of tensors with zero-size dimensions (e.g.
        // `torch.tensor([[], []])`) depends on the value of
        // `torch.get_default_dtype()`, and we should do the same for the C++
        // equivalent.
        scalar_type_(at::typeMetaToScalarType(at::get_default_dtype())),
        type_(TensorDataContainerType::InitList) {}

  // Implicit converting constructors from each supported C++ scalar type,
  // generated once per (type, ScalarType-suffix) pair by the `AT_FORALL_*`
  // macros below. Each stores the value in `scalar_` and records the matching
  // ATen scalar type; `sizes_` stays empty since a scalar is 0-dimensional.
#define TENSOR(T, S)                            \
  TensorDataContainer(T value)                  \
      : sizes_(),                               \
        scalar_type_(at::k##S),                 \
        type_(TensorDataContainerType::Scalar), \
        scalar_(value) {}
  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
  AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, TENSOR)
  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
  AT_FORALL_COMPLEX_TYPES(TENSOR)
#undef TENSOR

  // Constructor for a (possibly nested) braced-init-list. Validates that every
  // element agrees with the first element in both sizes and scalar type, then
  // derives this container's sizes as `{init_list.size()}` prepended to the
  // first element's sizes.
  //
  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
  TensorDataContainer(std::initializer_list<TensorDataContainer> init_list)
      : sizes_(),
        scalar_type_(init_list.begin()->scalar_type()),
        type_(TensorDataContainerType::InitList),
        init_list_(init_list) {
    const TensorDataContainer& first_elem = *(init_list.begin());
    for (const auto& elem : init_list) {
      TORCH_CHECK(
          elem.sizes() == first_elem.sizes(),
          "Expected all sub-lists to have sizes: ",
          first_elem.sizes(),
          " (e.g. ",
          first_elem,
          "), ",
          "but got sub-list ",
          elem,
          " with sizes: ",
          elem.sizes());
      TORCH_CHECK(
          elem.scalar_type() == first_elem.scalar_type(),
          "Expected all elements of the tensor to have the same scalar type: ",
          first_elem.scalar_type(),
          ", but got element of scalar type: ",
          elem.scalar_type());
    }
    sizes_.reserve(first_elem.sizes().size() + 1);
    sizes_.push_back(init_list.size());
    sizes_.insert(
        sizes_.end(), first_elem.sizes().begin(), first_elem.sizes().end());
  }

  // Constructors from `at::ArrayRef<T>` for each supported scalar type;
  // these eagerly materialize a 1-D CPU tensor via `at::tensor`.
  // NOTE(review): `kBool` takes a separate branch that omits the explicit
  // dtype and lets `at::tensor` infer it from the ArrayRef element type —
  // presumably to sidestep an issue with the explicit-dtype overload for
  // bool; confirm before consolidating the two branches.
#define TENSOR(T, S)                                                          \
  TensorDataContainer(at::ArrayRef<T> values)                                 \
      : sizes_({(int64_t)values.size()}),                                     \
        scalar_type_(at::k##S),                                               \
        type_(TensorDataContainerType::Tensor) {                              \
    at::AutoDispatchBelowAutograd mode;                                       \
    if (scalar_type_ == at::kBool) {                                          \
      tensor_ = at::tensor(values, at::TensorOptions().device(at::kCPU));     \
    } else {                                                                  \
      tensor_ = at::tensor(values, at::dtype(scalar_type_).device(at::kCPU)); \
    }                                                                         \
  }
  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
  AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, TENSOR)
  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
  AT_FORALL_COMPLEX_TYPES(TENSOR)
#undef TENSOR

  // NOTE: We need to handle `std::vector` explicitly instead of relying on an
  // implicit conversion to `at::ArrayRef`, otherwise the following error can be
  // thrown when calling `torch::tensor(std::vector<int>({1, 2}))`:
  // ```
  // error: no matching function for call to 'tensor(const std::vector<int>&)'
  // no known conversion for argument 1 from 'const std::vector<int>' to
  // 'torch::detail::TensorDataContainer'
  // ```
  //
  // NOTE: `torch::tensor(std::vector<bool>)` is not supported for now, because
  // ArrayRef<bool> cannot be constructed from a std::vector<bool> bitfield.
#define TENSOR(T, S)                                \
  TensorDataContainer(const std::vector<T>& values) \
      : TensorDataContainer(at::ArrayRef<T>(values)) {}
  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
  AT_FORALL_SCALAR_TYPES_AND2(Half, BFloat16, TENSOR)
  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
  AT_FORALL_COMPLEX_TYPES(TENSOR)
#undef TENSOR

  // True iff this container holds a single scalar value.
  bool is_scalar() const {
    return type_ == TensorDataContainerType::Scalar;
  }

  // The held scalar; only valid when `is_scalar()` is true (checked).
  const c10::Scalar& scalar() const {
    TORCH_CHECK(
        is_scalar(),
        "Can only call `scalar()` on a TensorDataContainer that has `is_scalar() == true`");
    return scalar_;
  }

  // True iff this container holds a braced-init-list of sub-containers.
  bool is_init_list() const {
    return type_ == TensorDataContainerType::InitList;
  }

  // The held init-list; only valid when `is_init_list()` is true (checked).
  const std::initializer_list<TensorDataContainer>& init_list() const {
    TORCH_CHECK(
        is_init_list(),
        "Can only call `init_list()` on a TensorDataContainer that has `is_init_list() == true`");
    return init_list_;
  }

  // True iff this container holds an already-materialized `at::Tensor`.
  bool is_tensor() const {
    return type_ == TensorDataContainerType::Tensor;
  }

  // The held tensor; only valid when `is_tensor()` is true (checked).
  const at::Tensor& tensor() const {
    TORCH_CHECK(
        is_tensor(),
        "Can only call `tensor()` on a TensorDataContainer that has `is_tensor() == true`");
    return tensor_;
  }

  // Sizes of the tensor this container represents (empty for a scalar).
  const std::vector<int64_t>& sizes() const {
    return sizes_;
  }

  // Scalar type deduced from the input data (before dtype inference).
  const c10::ScalarType& scalar_type() const {
    return scalar_type_;
  }

  // Converts this container into an `at::Tensor` honoring `options`; when no
  // dtype is given, infers it via `compute_desired_dtype(scalar_type_)`.
  at::Tensor convert_to_tensor(at::TensorOptions options) const {
    if (!options.has_dtype()) {
      options = options.dtype(compute_desired_dtype(scalar_type_));
    }

    if (is_scalar()) {
      at::AutoDispatchBelowAutograd mode;
      return at::scalar_tensor(scalar_, options);
    } else if (is_init_list()) {
      // NOTE: Here we explicitly choose to initialize the tensor on CPU first,
      // fill each element of the tensor, and then move the tensor to the
      // desired device. For CUDA device, this approach only involves 1 CUDA
      // kernel launch, and is much faster than initializing the tensor on CUDA
      // first and then filling each element of it (which involves `N` CUDA
      // kernel launches where `N` is the number of the elements in the tensor).
      at::Tensor tensor = ([&]() {
        at::AutoDispatchBelowAutograd mode;
        return at::empty(sizes_, options.device(at::kCPU));
      })();
      fill_tensor(tensor);
      return tensor.to(options.device());
    } else if (is_tensor()) {
      auto output = tensor_.to(options);
      TORCH_CHECK(
          !tensor_.is_complex() || output.is_complex(),
          "can not do torch::tensor(complex, dtype=non-complex) because complex can not be casted to real number without loss of information");
      return output;
    } else {
      TORCH_INTERNAL_ASSERT(false, "Invalid TensorDataContainer type");
    }
  }

  // Streams this container in nested braced-init-list form (e.g. `{{1, 2}}`);
  // used to build the error messages in the init-list constructor above.
  void pretty_print_recursive(std::ostream& stream) const {
    if (is_scalar()) {
      AT_DISPATCH_ALL_TYPES_AND3(
          at::kBool,
          at::kHalf,
          at::kBFloat16,
          scalar_type_,
          "TensorDataContainer_pretty_print_scalar",
          [&] { stream << scalar_.to<scalar_t>(); });
    } else if (is_init_list()) {
      stream << "{";
      for (const TensorDataContainer* it = init_list_.begin();
           it != init_list_.end();
           it++) {
        stream << *it;
        if (std::next(it) != init_list_.end())
          stream << ", ";
      }
      stream << "}";
    } else if (is_tensor()) {
      stream << "{";
      // Tensors stored here come from 1-D ArrayRef/vector constructors, so
      // printing iterates the first (only) dimension.
      for (const auto i : c10::irange(tensor_.sizes()[0])) {
        AT_DISPATCH_ALL_TYPES_AND3(
            at::kBool,
            at::kHalf,
            at::kBFloat16,
            scalar_type_,
            "TensorDataContainer_pretty_print_tensor_item",
            [&] { stream << tensor_[i].item<scalar_t>(); });
        if (i != tensor_.sizes()[0] - 1)
          stream << ", ";
      }
      stream << "}";
    } else {
      TORCH_INTERNAL_ASSERT(false, "Invalid TensorDataContainer type");
    }
  }

 private:
  // Recursively writes this container's values into a pre-allocated `tensor`
  // (created by `convert_to_tensor`); scalars fill a 0-dim slice, init-lists
  // recurse into `tensor[index]` slices. Must not be called on Tensor-type
  // containers — those are already materialized.
  void fill_tensor(at::Tensor& tensor) const {
    if (is_scalar()) {
      TORCH_INTERNAL_ASSERT(
          tensor.dim() == 0,
          "Expected a 0-dim Tensor, but got Tensor with dimensions: ",
          tensor.dim());
      at::NoGradGuard guard;
      tensor.fill_(scalar_);
    } else if (is_init_list()) {
      TORCH_INTERNAL_ASSERT(
          tensor.sizes()[0] == (int64_t)init_list_.size(),
          "Expected a Tensor with size ",
          init_list_.size(),
          " in its first dimension, but got Tensor with size ",
          tensor.sizes()[0],
          " in its first dimension");
      size_t index = 0;
      for (const auto& elem : init_list_) {
        at::Tensor slice = tensor[index];
        elem.fill_tensor(slice);
        index++;
      }
    } else if (is_tensor()) {
      TORCH_INTERNAL_ASSERT(
          false,
          "TensorDataContainer is already a Tensor type, `fill_tensor` should not be called");
    } else {
      TORCH_INTERNAL_ASSERT(false, "Invalid TensorDataContainer type");
    }
  }

  // Shape of the represented tensor; empty for a scalar, `{0}` for `{}`.
  std::vector<int64_t> sizes_;
  // Scalar type deduced from the raw input data.
  c10::ScalarType scalar_type_;
  // Which of the three alternatives below is active.
  TensorDataContainerType type_;
  // Active when `type_ == Scalar`.
  c10::Scalar scalar_;
  // Active when `type_ == InitList`. NOTE(review): `std::initializer_list`
  // does not own its backing array — this member is only valid while the
  // original braced-init-list full-expression is alive (the intended usage
  // is conversion inside a single `torch::tensor(...)` call).
  std::initializer_list<TensorDataContainer> init_list_;
  // Active when `type_ == Tensor` (built by the ArrayRef/vector ctors).
  at::Tensor tensor_;
};
353
354 inline std::ostream& operator<<(
355 std::ostream& stream,
356 const TensorDataContainer& tensor_data_container) {
357 tensor_data_container.pretty_print_recursive(stream);
358 return stream;
359 }
360
361 } // namespace detail
362
363 } // namespace torch
364