// aten/src/ATen/native/TensorFactories.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
#pragma once

#include <ATen/core/Tensor.h>
#include <ATen/EmptyTensor.h>
#include <ATen/TensorIterator.h>
#include <ATen/Dispatch.h>
#include <ATen/Dispatch_v2.h>
#include <ATen/native/DispatchStub.h>

#include <algorithm>
#include <cstdint>
#include <limits>
#include <optional>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#else
#include <ATen/ops/scalar_tensor.h>
#endif
namespace at::native {
// Different combinations of row, col, and offset can lead to two cases:
//
// Case 1 - Trapezoid (Triangle as a special case): row + offset <= col
//    Example A: offset > 0
//      1 1 0 0 0
//      1 1 1 0 0
//      1 1 1 1 0
//    Example B: offset <= 0
//      0 0 0
//      1 0 0
//      1 1 0
//    In this case, we calculate the number of elements in the first row and
//    the last row of the tril respectively, and then compute the tril size.
//
// Case 2 - Trapezoid + Rectangle: row + offset > col
//    Example:
//      1 1 0
//      1 1 1
//      1 1 1
//    In this case, we first calculate the size of the top trapezoid, and then
//    calculate the size of the bottom rectangle.
inline int64_t get_tril_size(int64_t row, int64_t col, int64_t offset) {
  // If either dimension is 0 then there is no tril
  if (row == 0 || col == 0) {
    return 0;
  }
  // number of elements in the first row of the tril
  auto m_first_row = offset > 0 ?
    std::min<int64_t>(col, 1 + offset) : // upper bounded by col
    row + offset > 0; // either 0 or 1
  // number of elements in the last row of the tril, bounded by [0, col]
  auto m_last_row = std::max<int64_t>(0, std::min<int64_t>(col, row + offset));
  // number of rows, bounded by [0, row]
  auto n_row_all = std::max<int64_t>(0, std::min<int64_t>(row, row + offset));
  auto n_row_trapezoid = (m_last_row - m_first_row + 1);

  // calculate # of elements in the top trapezoid
  auto tril_size = (m_first_row + m_last_row) * n_row_trapezoid >> 1;

  // calculate # of elements in the bottom rectangle if there is any
  auto diff_row = n_row_all - n_row_trapezoid;
  if (diff_row > 0) {
    tril_size += diff_row * col;
  }

  return tril_size;
}
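
// Worked example (illustrative; mirrors the Case 2 picture above): with
// row = 3, col = 3, offset = 1 we have row + offset > col, so:
//   m_first_row     = min(3, 1 + 1)     = 2
//   m_last_row      = max(0, min(3, 4)) = 3
//   n_row_all       = max(0, min(3, 4)) = 3
//   n_row_trapezoid = 3 - 2 + 1         = 2
//   trapezoid part  = (2 + 3) * 2 / 2   = 5
//   rectangle part  = (3 - 2) * 3       = 3
// so get_tril_size(3, 3, 1) == 8, the number of 1s in the Case 2 example.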

inline void check_args(
    int64_t row, int64_t col, std::optional<Layout> layout_opt) {
  TORCH_CHECK(row >= 0, "row must be non-negative, got ", row);
  TORCH_CHECK(col >= 0, "col must be non-negative, got ", col);
  if (layout_opt.has_value()) {
    TORCH_CHECK(
      *layout_opt == at::kStrided,
      "only supports layout=torch.strided, got ",
      *layout_opt);
  }
}

using at::check_size_nonnegative;

// assumes maximum value in created tensor is n-1 (e.g., torch.randperm(n))
inline void check_supported_max_int_with_precision(int64_t n, const Tensor& tensor) {
  // match defined() to behavior of checks below
  TORCH_CHECK(at::scalar_tensor(n > 0 ? n - 1 : n, tensor.options()).defined(),
              "n is too large for result tensor type: '", tensor.toString(), "'");

  // Ensure sufficient precision for floating point representation.
  switch (tensor.scalar_type()) {
    case at::ScalarType::Half:
      TORCH_CHECK(n <= (int64_t(1) << 11) + 1, "n cannot be greater than 2^11+1 = 2049 for Half type.");
      break;
    case at::ScalarType::Float:
      TORCH_CHECK(n <= (int64_t(1) << 24) + 1, "n cannot be greater than 2^24+1 for Float type.");
      break;
    case at::ScalarType::Double:  // Unlikely to happen, but doesn't hurt to check
      TORCH_CHECK(n <= (int64_t(1) << 53) + 1, "n cannot be greater than 2^53+1 for Double type.");
      break;
    default:
      break;
  }
}
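
// Why these bounds (illustrative note; the constants follow from IEEE-754
// significand widths): Half, Float, and Double carry 11-, 24-, and 53-bit
// significands (implicit bit included), and a p-bit significand represents
// every integer in [0, 2^p] exactly. randperm(n) must be able to represent
// values up to n - 1, so for Half the largest safe maximum value is
// 2^11 = 2048, i.e. n <= 2049; the value 2049 itself would round to 2048 in
// Half and two outputs would collide.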

// Called by `empty*` functions when deterministic algorithms are enabled to
// fill the tensor with NaN if it is a floating point or complex type, or
// with the max value if it is an integer type
inline Tensor& fill_empty_deterministic_(Tensor& tensor) {
  if (tensor.is_floating_point() || tensor.is_complex()) {
    AT_DISPATCH_V2(
      tensor.scalar_type(), "fill_empty_deterministic_", AT_WRAP([&]() {
        tensor.fill_(std::numeric_limits<scalar_t>::quiet_NaN());
    }), AT_EXPAND(AT_FLOATING_TYPES), AT_EXPAND(AT_COMPLEX_TYPES), AT_EXPAND(AT_FLOAT8_TYPES), kBFloat16, kHalf);
  } else {
    AT_DISPATCH_V2(
      tensor.scalar_type(), "fill_empty_deterministic_", AT_WRAP([&]() {
        tensor.fill_(std::numeric_limits<scalar_t>::max());
    }), kBool, AT_EXPAND(AT_INTEGRAL_TYPES_V2));
  }
  return tensor;
}
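
// Illustrative usage (not exercised in this header; assumes an ATen build):
//   at::Tensor f = at::empty({2, 2}, at::kFloat);
//   at::native::fill_empty_deterministic_(f);  // every element becomes NaN
//   at::Tensor i = at::empty({2, 2}, at::kLong);
//   at::native::fill_empty_deterministic_(i);  // every element becomes INT64_MAX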

// The ZeroTensor allocator ignores whatever allocation is requested and
// always gives you nullptr
struct ZeroTensorAllocator final : public at::Allocator {
  ZeroTensorAllocator(at::Device device) : device_(device) {}
  ~ZeroTensorAllocator() override = default;
  static void deleter(void* const pointer) {
    TORCH_INTERNAL_ASSERT(!pointer);
  }
  DataPtr allocate(const size_t /*nbytes*/) override {
    return {nullptr, nullptr, &deleter, device_};
  }
  DeleterFnPtr raw_deleter() const override {
    return deleter;
  }
  void copy_data(void* dest [[maybe_unused]], const void* src [[maybe_unused]], std::size_t count [[maybe_unused]]) const final {}
  at::Device device_;
};
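
// Illustrative usage (hypothetical; nothing in this header invokes it):
//   ZeroTensorAllocator alloc(at::Device(at::kCPU));
//   at::DataPtr p = alloc.allocate(1024);       // requested size is ignored
//   TORCH_INTERNAL_ASSERT(p.get() == nullptr);  // always a null DataPtr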

using binary_fn = void (*)(TensorIterator&);

DECLARE_DISPATCH(binary_fn, complex_stub);
DECLARE_DISPATCH(binary_fn, polar_stub);

} // namespace at::native