#pragma once

#include <ATen/EmptyTensor.h>
#include <ATen/Formatting.h>
#include <ATen/core/ATenGeneral.h>
#include <ATen/core/Generator.h>
#include <c10/core/ScalarType.h>
#include <c10/core/StorageImpl.h>
#include <c10/core/UndefinedTensorImpl.h>
#include <c10/util/ArrayRef.h>
#include <c10/util/Exception.h>
#include <c10/util/accumulate.h>
#include <c10/util/irange.h>

#include <algorithm>

#define AT_DISALLOW_COPY_AND_ASSIGN(TypeName) \
  TypeName(const TypeName&) = delete;         \
  void operator=(const TypeName&) = delete

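// A minimal usage sketch (not part of the original header; the class name
// `Foo` is hypothetical):
//
//   class Foo {
//    public:
//     Foo() = default;
//     AT_DISALLOW_COPY_AND_ASSIGN(Foo);
//   };
//
// Afterwards, `Foo b(a);` and `b = a;` fail to compile, because the copy
// constructor and copy assignment operator are explicitly deleted.
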
namespace at {

TORCH_API int _crash_if_asan(int);

// Converts a TensorList (i.e. ArrayRef<Tensor>) to a vector of TensorImpl*.
// NB: This is ONLY used by legacy TH bindings, and ONLY used by cat.
// Once cat is ported entirely to ATen this can be deleted!
inline std::vector<TensorImpl*> checked_dense_tensor_list_unwrap(
    ArrayRef<Tensor> tensors,
    const char* name,
    int pos,
    c10::DeviceType device_type,
    ScalarType scalar_type) {
  std::vector<TensorImpl*> unwrapped;
  unwrapped.reserve(tensors.size());
  for (const auto i : c10::irange(tensors.size())) {
    const auto& expr = tensors[i];
    if (expr.layout() != Layout::Strided) {
      AT_ERROR(
          "Expected dense tensor but got ",
          expr.layout(),
          " for sequence element ",
          i,
          " in sequence argument at position #",
          pos,
          " '",
          name,
          "'");
    }
    if (expr.device().type() != device_type) {
      AT_ERROR(
          "Expected object of device type ",
          device_type,
          " but got device type ",
          expr.device().type(),
          " for sequence element ",
          i,
          " in sequence argument at position #",
          pos,
          " '",
          name,
          "'");
    }
    if (expr.scalar_type() != scalar_type) {
      AT_ERROR(
          "Expected object of scalar type ",
          scalar_type,
          " but got scalar type ",
          expr.scalar_type(),
          " for sequence element ",
          i,
          " in sequence argument at position #",
          pos,
          " '",
          name,
          "'");
    }
    unwrapped.emplace_back(expr.unsafeGetTensorImpl());
  }
  return unwrapped;
}
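
// A usage sketch (not part of the original header; the variable names are
// hypothetical). Callers typically validate an argument list before handing
// raw TensorImpl pointers to a legacy kernel:
//
//   std::vector<at::Tensor> tensors = {a, b, c};
//   // Throws if any element is non-strided, on the wrong device, or has the
//   // wrong dtype; otherwise returns the unwrapped TensorImpl* pointers.
//   std::vector<at::TensorImpl*> impls = at::checked_dense_tensor_list_unwrap(
//       tensors, "tensors", /*pos=*/1, at::kCPU, at::kFloat);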

template <size_t N>
std::array<int64_t, N> check_intlist(
    ArrayRef<int64_t> list,
    const char* name,
    int pos) {
  if (list.empty()) {
    // TODO: is this necessary?  We used to treat nullptr-vs-not in IntList
    // differently with strides as a way of faking optional.
    list = {};
  }
  auto res = std::array<int64_t, N>();
  if (list.size() == 1 && N > 1) {
    res.fill(list[0]);
    return res;
  }
  if (list.size() != N) {
    AT_ERROR(
        "Expected a list of ",
        N,
        " ints but got ",
        list.size(),
        " for argument #",
        pos,
        " '",
        name,
        "'");
  }
  std::copy_n(list.begin(), N, res.begin());
  return res;
}
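
// A usage sketch (not from the original header; the literal values are
// illustrative). A single value is broadcast to all N slots, so both calls
// below yield {3, 3}:
//
//   std::array<int64_t, 2> a = at::check_intlist<2>({3}, "kernel_size", 1);
//   std::array<int64_t, 2> b = at::check_intlist<2>({3, 3}, "kernel_size", 1);
//
// Passing a list whose length is neither 1 nor N raises an error that names
// the argument and its position.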

using at::detail::check_size_nonnegative;
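
// A usage sketch (not from the original header). check_size_nonnegative is
// brought in from at::detail (see the ATen/EmptyTensor.h include above); it
// validates a proposed shape and raises an error if any dimension is
// negative:
//
//   at::check_size_nonnegative({2, 3, 4});   // fine
//   at::check_size_nonnegative({2, -1, 4});  // throws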

namespace detail {

template <typename T>
TORCH_API Tensor tensor_cpu(ArrayRef<T> values, const TensorOptions& options);

template <typename T>
TORCH_API Tensor
tensor_backend(ArrayRef<T> values, const TensorOptions& options);

template <typename T>
TORCH_API Tensor
tensor_complex_cpu(ArrayRef<T> values, const TensorOptions& options);

template <typename T>
TORCH_API Tensor
tensor_complex_backend(ArrayRef<T> values, const TensorOptions& options);
} // namespace detail
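
// A hedged note (not from the original header): these templates are the
// backend-specific helpers that tensor-from-values factories such as
// at::tensor(ArrayRef<T>, options) are expected to dispatch to; direct calls
// like the illustrative one below are uncommon outside ATen itself.
//
//   at::Tensor t = at::detail::tensor_cpu<double>(
//       {1.0, 2.0, 3.0}, at::TensorOptions().dtype(at::kDouble));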

} // namespace at