#pragma once

#include <ATen/Dispatch.h>
#include <ATen/ScalarOps.h>
#include <ATen/core/Tensor.h>
#include <ATen/core/grad_mode.h>

#include <c10/util/irange.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#else
#include <ATen/ops/empty.h>
#include <ATen/ops/tensor.h>
#endif

#include <initializer_list>
#include <ostream>
#include <vector>

namespace torch {

namespace detail {

enum class TensorDataContainerType { Scalar, InitList, Tensor };

struct TensorDataContainer;

inline std::ostream& operator<<(
    std::ostream& stream,
    const TensorDataContainer& tensor_data_container);

inline c10::ScalarType compute_desired_dtype(c10::ScalarType scalar_type) {
  if (scalar_type == at::kInt || scalar_type == at::kLong) {
    // C++ `torch::tensor` with an integer type or an `at::ArrayRef` /
    // `std::vector` / (nested) braced-init-list of integer types always
    // produces a tensor of dtype `at::kLong` (aka. int64_t), matching Python
    // `torch.tensor` behavior.
    return at::kLong;
  } else if (scalar_type == at::kFloat || scalar_type == at::kDouble) {
    // C++ `torch::tensor` with a floating-point type or an `at::ArrayRef` /
    // `std::vector` / (nested) braced-init-list of floating-point types always
    // produces a tensor of dtype `torch::get_default_dtype()`, matching Python
    // `torch.tensor` behavior.
    return at::typeMetaToScalarType(at::get_default_dtype());
  } else {
    return scalar_type;
  }
}
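
// For illustration, a minimal sketch (assuming the default dtype is the
// usual `torch::kFloat`):
//
//   compute_desired_dtype(at::kInt);    // -> at::kLong
//   compute_desired_dtype(at::kDouble); // -> at::kFloat (the default dtype)
//   compute_desired_dtype(at::kBool);   // -> at::kBool (returned unchanged)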

// We use `TensorDataContainer` to support converting the following data
// container types into the equivalent Tensor:
//
// 1. Arbitrarily nested braced-init-list (e.g. `{{1, 2}, {3, 4}}`).
// 2. `at::ArrayRef` of supported tensor data types.
// 3. `std::vector` of supported tensor data types.
//
// At any time, a `TensorDataContainer` object represents one of the following:
//
// 1. A scalar with value `scalar()` and type `scalar_type()`.
// 2. A Tensor represented in `std::initializer_list<TensorDataContainer>`
//    form, with value `init_list()`, Tensor scalar type `scalar_type()`, and
//    Tensor sizes `sizes()`.
// 3. A Tensor represented in `at::Tensor` form, with value `tensor()`, scalar
//    type `scalar_type()`, and Tensor sizes `sizes()`.
//
// Most of the infrastructure here exists to support converting an arbitrarily
// nested braced-init-list to the equivalent Tensor. Consider the following
// example:
//
// `torch::tensor({{1}, {2}})`
//
// this calls into the `torch::tensor` function:
//
// `at::Tensor tensor(detail::TensorDataContainer tensor_data_container, const
// at::TensorOptions& options = {})`
//
// The compiler first tries to convert `{{1}, {2}}` to the
// `TensorDataContainer` type:
//
// `TensorDataContainer({{1}, {2}})`
//
// which matches the
// `TensorDataContainer(std::initializer_list<TensorDataContainer>)`
// constructor, and in an attempt to convert `{1}` and `{2}` to
// `TensorDataContainer`, it calls the following:
//
// `TensorDataContainer({1})`  (the same call path happens for `{2}`; we'll
// just focus on `{1}` here)
//
// At this point, there are theoretically two plausible ways for `{1}` to be
// matched to one of the constructors of `TensorDataContainer`:
//
// 1. It can be a list-initialization of a scalar value, thus matching
//    `TensorDataContainer(int value)`.
// 2. It can be converted to `std::initializer_list<TensorDataContainer>`,
//    thus matching
//    `TensorDataContainer(std::initializer_list<TensorDataContainer>)`.
//
// How does the compiler decide which one to choose? According to
// `https://en.cppreference.com/w/cpp/language/list_initialization`,
// a braced-init-list always prefers the constructor that takes
// `std::initializer_list`. Hence we happily move forward with constructor #2,
// and it calls the following:
//
// `TensorDataContainer(1)`
//
// Now it matches `TensorDataContainer(int value)`, which stores `1` as a
// scalar value. All is good.
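//
// Putting it together, a minimal sketch (assuming the default dtype is
// `torch::kFloat`):
//
//   // `{{1, 2}, {3, 4}}` becomes a TensorDataContainer tree whose outer node
//   // has `sizes() == {2, 2}` and `scalar_type() == at::kInt`; the ints are
//   // then promoted to `at::kLong` by `compute_desired_dtype`.
//   at::Tensor t = torch::tensor({{1, 2}, {3, 4}});
//   // t.sizes() == {2, 2}, t.dtype() == torch::kLong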
struct TensorDataContainer {
  // NOTE: For tensors with zero-size dimensions (e.g. `torch::tensor({{},
  // {}})`), the innermost empty braced-init-list `{}` matches the default
  // constructor of the innermost `TensorDataContainer`.
  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
  TensorDataContainer()
      : sizes_({0}),
        // NOTE: In Python, the dtype of tensors with zero-size dimensions (e.g.
        // `torch.tensor([[], []])`) depends on the value of
        // `torch.get_default_dtype()`, and we should do the same for the C++
        // equivalent.
        scalar_type_(at::typeMetaToScalarType(at::get_default_dtype())),
        type_(TensorDataContainerType::InitList) {}
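  // For example (a sketch, assuming the default dtype is `torch::kFloat`):
  // `torch::tensor({{}, {}})` yields a tensor with sizes `{2, 0}` and dtype
  // `torch::kFloat`, because each inner `{}` hits this default constructor.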
#define TENSOR(T, S)                            \
  TensorDataContainer(T value)                  \
      : sizes_(),                               \
        scalar_type_(at::k##S),                 \
        type_(TensorDataContainerType::Scalar), \
        scalar_(value) {}
  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
  AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, TENSOR)
  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
  AT_FORALL_COMPLEX_TYPES(TENSOR)
#undef TENSOR
  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
  TensorDataContainer(std::initializer_list<TensorDataContainer> init_list)
      : sizes_(),
        scalar_type_(init_list.begin()->scalar_type()),
        type_(TensorDataContainerType::InitList),
        init_list_(init_list) {
    const TensorDataContainer& first_elem = *(init_list.begin());
    for (const auto& elem : init_list) {
      TORCH_CHECK(
          elem.sizes() == first_elem.sizes(),
          "Expected all sub-lists to have sizes: ",
          first_elem.sizes(),
          " (e.g. ",
          first_elem,
          "), ",
          "but got sub-list ",
          elem,
          " with sizes: ",
          elem.sizes());
      TORCH_CHECK(
          elem.scalar_type() == first_elem.scalar_type(),
          "Expected all elements of the tensor to have the same scalar type: ",
          first_elem.scalar_type(),
          ", but got element of scalar type: ",
          elem.scalar_type());
    }
    sizes_.reserve(first_elem.sizes().size() + 1);
    sizes_.push_back(init_list.size());
    sizes_.insert(
        sizes_.end(), first_elem.sizes().begin(), first_elem.sizes().end());
  }
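
  // For example (an illustrative sketch): for `{{1, 2, 3}, {4, 5, 6}}`, each
  // sub-list records sizes `{3}`, so this constructor records
  // `sizes_ == {2, 3}`; a ragged list such as `{{1, 2}, {3}}` fails the
  // sub-list size check above with a `TORCH_CHECK` error.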

#define TENSOR(T, S)                                                          \
  TensorDataContainer(at::ArrayRef<T> values)                                 \
      : sizes_({(int64_t)values.size()}),                                     \
        scalar_type_(at::k##S),                                               \
        type_(TensorDataContainerType::Tensor) {                              \
    at::AutoDispatchBelowAutograd mode;                                       \
    if (scalar_type_ == at::kBool) {                                          \
      tensor_ = at::tensor(values, at::TensorOptions().device(at::kCPU));     \
    } else {                                                                  \
      tensor_ = at::tensor(values, at::dtype(scalar_type_).device(at::kCPU)); \
    }                                                                         \
  }
  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
  AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, TENSOR)
  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
  AT_FORALL_COMPLEX_TYPES(TENSOR)
#undef TENSOR
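
  // Example usage (a sketch; `data` is a hypothetical local array):
  //
  //   float data[3] = {1.f, 2.f, 3.f};
  //   // Converted through TensorDataContainer(at::ArrayRef<float>); the
  //   // resulting tensor has sizes {3} and the default floating-point dtype.
  //   at::Tensor t = torch::tensor(at::ArrayRef<float>(data, 3));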

  // NOTE: We need to handle `std::vector` explicitly instead of relying on an
  // implicit conversion to `at::ArrayRef`, otherwise the following compile
  // error can occur when calling `torch::tensor(std::vector<int>({1, 2}))`:
  // ```
  // error: no matching function for call to 'tensor(const std::vector<int>&)'
  // no known conversion for argument 1 from 'const std::vector<int>' to
  // 'torch::detail::TensorDataContainer'
  // ```
  //
  // NOTE: `torch::tensor(std::vector<bool>)` is not supported for now, because
  // ArrayRef<bool> cannot be constructed from a std::vector<bool> bitfield.
#define TENSOR(T, S)                                \
  TensorDataContainer(const std::vector<T>& values) \
      : TensorDataContainer(at::ArrayRef<T>(values)) {}
  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
  AT_FORALL_SCALAR_TYPES_AND2(Half, BFloat16, TENSOR)
  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
  AT_FORALL_COMPLEX_TYPES(TENSOR)
#undef TENSOR
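
  // Example usage (a sketch of the conversions described above):
  //
  //   // dtype at::kLong after `compute_desired_dtype` promotion, sizes {2}
  //   at::Tensor a = torch::tensor(std::vector<int>({1, 2}));
  //   // dtype torch::get_default_dtype(), sizes {3}
  //   at::Tensor b = torch::tensor(std::vector<double>({1.0, 2.0, 3.0}));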

  bool is_scalar() const {
    return type_ == TensorDataContainerType::Scalar;
  }

  const c10::Scalar& scalar() const {
    TORCH_CHECK(
        is_scalar(),
        "Can only call `scalar()` on a TensorDataContainer that has `is_scalar() == true`");
    return scalar_;
  }

  bool is_init_list() const {
    return type_ == TensorDataContainerType::InitList;
  }

  const std::initializer_list<TensorDataContainer>& init_list() const {
    TORCH_CHECK(
        is_init_list(),
        "Can only call `init_list()` on a TensorDataContainer that has `is_init_list() == true`");
    return init_list_;
  }

  bool is_tensor() const {
    return type_ == TensorDataContainerType::Tensor;
  }

  const at::Tensor& tensor() const {
    TORCH_CHECK(
        is_tensor(),
        "Can only call `tensor()` on a TensorDataContainer that has `is_tensor() == true`");
    return tensor_;
  }

  const std::vector<int64_t>& sizes() const {
    return sizes_;
  }

  const c10::ScalarType& scalar_type() const {
    return scalar_type_;
  }

  at::Tensor convert_to_tensor(at::TensorOptions options) const {
    if (!options.has_dtype()) {
      options = options.dtype(compute_desired_dtype(scalar_type_));
    }

    if (is_scalar()) {
      at::AutoDispatchBelowAutograd mode;
      return at::scalar_tensor(scalar_, options);
    } else if (is_init_list()) {
      // NOTE: Here we explicitly choose to initialize the tensor on CPU first,
      // fill each element of the tensor, and then move the tensor to the
      // desired device. For CUDA device, this approach only involves 1 CUDA
      // kernel launch, and is much faster than initializing the tensor on CUDA
      // first and then filling each element of it (which involves `N` CUDA
      // kernel launches where `N` is the number of the elements in the tensor).
      at::Tensor tensor = ([&]() {
        at::AutoDispatchBelowAutograd mode;
        return at::empty(sizes_, options.device(at::kCPU));
      })();
      fill_tensor(tensor);
      return tensor.to(options.device());
    } else if (is_tensor()) {
      auto output = tensor_.to(options);
      TORCH_CHECK(
          !tensor_.is_complex() || output.is_complex(),
          "can not do torch::tensor(complex, dtype=non-complex) because complex can not be casted to real number without loss of information");
      return output;
    } else {
      TORCH_INTERNAL_ASSERT(false, "Invalid TensorDataContainer type");
    }
  }
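
  // For example (a sketch; assumes a CUDA build with an available device):
  //
  //   // The nested-list container is first materialized as a 2x2 CPU tensor,
  //   // filled element by element via `fill_tensor`, and then moved to CUDA
  //   // with a single `Tensor::to` copy, per the NOTE above.
  //   at::Tensor t = torch::tensor(
  //       {{1.0, 2.0}, {3.0, 4.0}}, torch::device(torch::kCUDA));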

  void pretty_print_recursive(std::ostream& stream) const {
    if (is_scalar()) {
      AT_DISPATCH_ALL_TYPES_AND3(
          at::kBool,
          at::kHalf,
          at::kBFloat16,
          scalar_type_,
          "TensorDataContainer_pretty_print_scalar",
          [&] { stream << scalar_.to<scalar_t>(); });
    } else if (is_init_list()) {
      stream << "{";
      for (const TensorDataContainer* it = init_list_.begin();
           it != init_list_.end();
           it++) {
        stream << *it;
        if (std::next(it) != init_list_.end())
          stream << ", ";
      }
      stream << "}";
    } else if (is_tensor()) {
      stream << "{";
      for (const auto i : c10::irange(tensor_.sizes()[0])) {
        AT_DISPATCH_ALL_TYPES_AND3(
            at::kBool,
            at::kHalf,
            at::kBFloat16,
            scalar_type_,
            "TensorDataContainer_pretty_print_tensor_item",
            [&] { stream << tensor_[i].item<scalar_t>(); });
        if (i != tensor_.sizes()[0] - 1)
          stream << ", ";
      }
      stream << "}";
    } else {
      TORCH_INTERNAL_ASSERT(false, "Invalid TensorDataContainer type");
    }
  }
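
  // For example (an illustrative sketch): streaming the container built from
  // `{{1, 2}, {3, 4}}` prints `{{1, 2}, {3, 4}}`; this is what the
  // `TORCH_CHECK` messages in the braced-init-list constructor rely on to
  // show the offending sub-lists.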

 private:
  void fill_tensor(at::Tensor& tensor) const {
    if (is_scalar()) {
      TORCH_INTERNAL_ASSERT(
          tensor.dim() == 0,
          "Expected a 0-dim Tensor, but got Tensor with dimensions: ",
          tensor.dim());
      at::NoGradGuard guard;
      tensor.fill_(scalar_);
    } else if (is_init_list()) {
      TORCH_INTERNAL_ASSERT(
          tensor.sizes()[0] == (int64_t)init_list_.size(),
          "Expected a Tensor with size ",
          init_list_.size(),
          " in its first dimension, but got Tensor with size ",
          tensor.sizes()[0],
          " in its first dimension");
      size_t index = 0;
      for (const auto& elem : init_list_) {
        at::Tensor slice = tensor[index];
        elem.fill_tensor(slice);
        index++;
      }
    } else if (is_tensor()) {
      TORCH_INTERNAL_ASSERT(
          false,
          "TensorDataContainer is already a Tensor type, `fill_tensor` should not be called");
    } else {
      TORCH_INTERNAL_ASSERT(false, "Invalid TensorDataContainer type");
    }
  }

  std::vector<int64_t> sizes_;
  c10::ScalarType scalar_type_;
  TensorDataContainerType type_;
  c10::Scalar scalar_;
  std::initializer_list<TensorDataContainer> init_list_;
  at::Tensor tensor_;
};

inline std::ostream& operator<<(
    std::ostream& stream,
    const TensorDataContainer& tensor_data_container) {
  tensor_data_container.pretty_print_recursive(stream);
  return stream;
}

} // namespace detail

} // namespace torch