1 #define TORCH_ASSERT_ONLY_METHOD_OPERATORS
2 #include <ATen/core/Tensor.h>
3 #include <ATen/Dispatch.h>
4 #include <ATen/cuda/CUDAApplyUtils.cuh>
5 #include <ATen/cuda/CUDAContext.h>
6 #include <ATen/cuda/EmptyTensor.h>
7 #include <ATen/InitialTensorOptions.h>
8 #include <ATen/native/cuda/Resize.h>
9 #include <ATen/native/TensorFactories.h>
10 #include <c10/util/accumulate.h>
11 #include <c10/util/Exception.h>
12 #include <ATen/native/cuda/Loops.cuh>
13
14 #ifndef AT_PER_OPERATOR_HEADERS
15 #include <ATen/Functions.h>
16 #include <ATen/NativeFunctions.h>
17 #else
18 #include <ATen/ops/_efficientzerotensor_native.h>
19 #include <ATen/ops/empty_native.h>
20 #include <ATen/ops/empty_strided_native.h>
21 #include <ATen/ops/eye_native.h>
22 #include <ATen/ops/tril_indices_native.h>
23 #include <ATen/ops/tril_native.h>
24 #include <ATen/ops/triu_indices_native.h>
25 #include <ATen/ops/triu_native.h>
26 #endif
27
28 #include <algorithm>
29 #include <cmath>
30 #include <cstddef>
31
32 namespace at::native {
33
// Square identity matrix: delegates to the (n, m) overload with m == n.
Tensor& eye_out_cuda(int64_t n, Tensor& result) {
  const int64_t m = n;  // columns default to the row count
  return at::native::eye_out_cuda(n, m, result);
}
38
// Writes an (n x m) identity-like matrix into `result`: zeros everywhere
// except ones on the main diagonal. `result` is resized in place and returned.
Tensor& eye_out_cuda(int64_t n, int64_t m, Tensor& result) {
  TORCH_CHECK(n >= 0, "n must be greater or equal to 0, got ", n);
  TORCH_CHECK(m >= 0, "m must be greater or equal to 0, got ", m);

  // Shape the output and clear it before writing the diagonal.
  result.resize_({n, m});
  result.zero_();

  // A 1-D strided view that advances one row and one column per step walks
  // the main diagonal; filling it with ones produces the identity pattern.
  const int64_t diag_len = n < m ? n : m;
  const int64_t diag_step = result.stride(0) + result.stride(1);
  result.as_strided({diag_len}, {diag_step}).fill_(1);

  return result;
}
53
// Allocates an uninitialized CUDA tensor. When deterministic algorithms and
// deterministic fill of uninitialized memory are both enabled, the new
// storage is filled via fill_empty_deterministic_.
Tensor empty_cuda(IntArrayRef size, std::optional<ScalarType> dtype_opt, std::optional<Layout> layout_opt, std::optional<Device> device_opt, std::optional<bool> pin_memory_opt, std::optional<c10::MemoryFormat> memory_format_opt) {
  Tensor out = at::detail::empty_cuda(
      size, dtype_opt, layout_opt, device_opt, pin_memory_opt, memory_format_opt);

  // See Note [Enabling Deterministic Operations]
  auto& ctx = at::globalContext();
  const bool fill_uninitialized =
      ctx.deterministicAlgorithms() && ctx.deterministicFillUninitializedMemory();
  if (C10_UNLIKELY(fill_uninitialized)) {
    fill_empty_deterministic_(out);
  }
  return out;
}
62
// Creates a CUDA tensor whose dispatch key set includes ZeroTensor, backed by
// ZeroTensorAllocator. If the requested device carries no index, the current
// CUDA device is used. `layout` and `pin_memory` are accepted but not read.
Tensor _efficientzerotensor_cuda(IntArrayRef size,
    std::optional<ScalarType> dtype,
    std::optional<Layout> layout,
    std::optional<Device> device,
    std::optional<bool> pin_memory) {
  auto target = device_or_default(device);
  if (!target.has_index()) {
    target.set_index(at::cuda::current_device());
  }

  at::native::ZeroTensorAllocator zero_alloc(target);
  const auto keys = at::DispatchKeySet(c10::DispatchKey::CUDA) |
      at::DispatchKeySet(c10::DispatchKey::ZeroTensor);
  return at::detail::empty_generic(
      size, &zero_alloc, keys, dtype_or_default(dtype), std::nullopt);
}
78
79
// Allocates an uninitialized CUDA tensor with explicit strides. When
// deterministic algorithms and deterministic fill of uninitialized memory are
// both enabled, the new storage is filled via fill_empty_deterministic_.
Tensor empty_strided_cuda(IntArrayRef size, IntArrayRef stride, std::optional<ScalarType> dtype_opt, std::optional<Layout> layout_opt, std::optional<Device> device_opt, std::optional<bool> pin_memory_opt) {
  Tensor out = at::detail::empty_strided_cuda(
      size, stride, dtype_opt, layout_opt, device_opt, pin_memory_opt);

  // See Note [Enabling Deterministic Operations]
  auto& ctx = at::globalContext();
  const bool fill_uninitialized =
      ctx.deterministicAlgorithms() && ctx.deterministicFillUninitializedMemory();
  if (C10_UNLIKELY(fill_uninitialized)) {
    fill_empty_deterministic_(out);
  }
  return out;
}
88
89 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ triangle ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
90
namespace {
// To find the max integer that does not exceed the root of an int64_t variable,
// we could use a loop to test one bit at a time, which takes up to 31
// iterations. This would give the accurate result, but is relatively slow and
// is overkill for most cases, where double's precision suffices.
//
// If we directly use sqrt to calculate the root, the conversion from int64_t
// to double would lose 11 bits of precision.
//
// The following solution uses sqrt directly for most cases, and only falls
// back to special handling if there is indeed precision loss.
//
// Computes floor((-b + sign * sqrt(b^2 - 4c)) / 2) for the monic quadratic
// row^2 + b*row + c, where cX4 = 4c is passed precomputed.
//   sign = +1 selects the larger root (used by the tril trapezoid),
//   sign = -1 selects the smaller root (used by the triu trapezoid).
// x is the linear index the quadratic was built from; it is only used by the
// binary-search fallback, which compares sign * (b + m) * m against 2x.
__device__
inline int64_t resolve_root_int(
    int64_t b, int64_t cX4, int64_t x, int32_t sign) {
  int64_t bXb_cX4 = b*b - cX4;
  // potential precision loss could occur here when casting int64_t (63 bits
  // of precision) to double (53-bit significand)
  double sr = ::sqrt((double)bXb_cX4);
  // round toward negative infinity to take the floor of the root
  int64_t res = ::__double2ll_rd((-b + sign * sr)/2);

  // have to cast double to int64_t, otherwise it would only compare up to the
  // precision of a double variable, ignoring the precision loss
  if (bXb_cX4 != (int64_t) (sr * sr)) {
    // handle precision loss by using binary search
    int64_t llsr = ::__double2ll_rd(sr);
    // Use the following math to reduce search space.
    // Suppose z is the accurate result of sqrt(bXb_cX4) without precision loss
    // let d = abs(bXb_cX4 - llsr * llsr), then we have:
    // z = sqrt(bXb_cX4) <= sqrt(llsr * llsr + d) <= llsr + sqrt(d)
    // z = sqrt(bXb_cX4) >= sqrt(llsr * llsr - d) >= llsr - sqrt(d)
    // Hence, it is sufficient to search range [llsr - sqrt(d), llsr + sqrt(d)).
    // And the true value of row would also be within the range
    // [res - sqrt(d), res + sqrt(d) + 1)
    // as the denominator would only reduce the precision penalty.
    int64_t diff =
      ::__double2ll_ru(::sqrt(::fabs((double)(bXb_cX4 - llsr * llsr))));
    // l never exceeds (could equal to) the target row index
    auto l = res > diff ? res - diff : 0;
    // r is always larger than the target row index
    auto r = res + diff + 1;

    // binary search for the correct answer
    x <<= 1; // the loop always compares with 2x, so do it once here
    while (l + 1 < r) {
      auto m = (l + r) >> 1;
      // for tril:
      //    b = 2f - 1, sign = 1, hence (2f + m - 1) * m / 2
      // for triu:
      //    b = -2f - 1, sign = -1, hence (2f - m + 1) * m / 2
      if (sign * (b + m) * m > x) {
        r = m;
      } else {
        l = m;
      }
    }
    res = l;
  }

  return res;
}

// f: the number of elements in the first row of the trapezoid.
// x: the index of the target coordinates ordered by row and then column.
// row, col: outputs, written through the references. row is relative to the
//           trapezoid's first row; col is the column index within the row
//           (rows of the tril trapezoid start at column 0).
//
// View the tril as a top trapezoid stacked on a bottom rectangle. Assume x
// corresponds to the coordinate (row, col) in the trapezoid, where the row and
// the col both start from 0, then we have:
//
//                   (f + f + row - 1) * row / 2 <= x                       [1]
//                 (f + f + row) * (row + 1) / 2  > x                       [2]
//
// Therefore, row is the maximum integer satisfying the following inequality:
//
//                       (row + 2f - 1)row <= 2x
//                  row^2 + (2f-1)row - 2x <= 0.                            [3]
//
// Based on inequality [3], we have the following coefficients for formula of
// root:
//                               a = 1
//                               b = 2f - 1
//                               c = -2x
// There are two roots, and we should use the largest integer that does not
// exceed the root on the right. Intuitively, it is because:
//  i)  the valid solution range of row is between two roots, as it is <= 0;
//  ii) as we count in more rows, the total # of elements should always
//      increase, hence so does the left-hand side row^2 + (2f-1)row - 2x.
//      Therefore, the valid range of row lies in between the nadir point and
//      the larger root on the right.
// Full proof can be derived from inequality [2]. So, we calculate the result
// coordinate as:
//
//                   row = floor((-b + sqrt(b^2 - 4c)) / 2)
//                   col = x - (f + f + row - 1) * row / 2
__device__
inline void get_coordinate_in_tril_trapezoid(
    int64_t f, int64_t x, int64_t & row, int64_t & col) {
  f <<= 1; // all statements use 2f, so only calculate it once here.
  auto b = f - 1;
  auto cX4 = - (x << 3); // 4 * c = 4 * (-2x) = -8x;
  row = resolve_root_int(b, cX4, x, 1);
  col = x - ((f + row - 1) * row >> 1);
}

// f: the number of elements in the first row of the bottom trapezoid.
// x: the index of the target coordinates ordered by row and then column.
// row, col: outputs, written through the references. row is relative to the
//           trapezoid's first row; col is relative to the first column of the
//           trapezoid's first row (row r starts r columns to the right, hence
//           the trailing "+ row" below). The caller adds any remaining
//           row/column offsets.
//
// View the triu as a top rectangle stacked on a bottom trapezoid, where the
// trapezoid is upside down. Assume x corresponds to the coordinate (row, col)
// in the bottom trapezoid, where the row and the col start from 0, then we
// have:
//
//                   (f + f - row + 1) * row / 2 <= x                       [1]
//                 (f + f - row) * (row + 1) / 2  > x                       [2]
//
// Therefore, row is the maximum integer satisfying the following inequality:
//
//                       (-row + 2f + 1)row <= 2x
//                   row^2 - (2f+1)row + 2x >= 0.                           [3]
//
// Based on inequality [3], we have the following coefficients for formula of
// root:
//                               a = 1
//                               b = -1 - 2f
//                               c = 2x
// There are two roots, and we should use the largest integer that does not
// exceed the root on the left. Intuitively, it is because:
//  i)  the valid solution range of row is outside of the two roots, as it is
//      >= 0;
//  ii) as we count in more rows, the total # of elements
//      (2f - row + 1) * row / 2 increases, so the left-hand side
//      row^2 - (2f+1)row + 2x decreases over the rows of the trapezoid
//      (row < f + 1/2, i.e. left of the vertex).
//      Therefore, the valid range of row lies to the left of the smaller root
//      on the left.
// Full proof can be derived from inequality [2]. So, we calculate the result
// coordinate as:
//
//                   row = floor((-b - sqrt(b^2 - 4c)) / 2)
//                   col = x - (f + f - row + 1) * row / 2
__device__
inline void get_coordinate_in_triu_trapezoid(
    int64_t f, int64_t x, int64_t & row, int64_t & col) {
  f <<= 1; // all statements use 2f, so only calculate it once here.
  auto b = -1 - f;
  auto cX4 = x << 3; // 4 * c = 4 * (2x) = 8x;
  row = resolve_root_int(b, cX4, x, -1);
  col = x - ((f - row + 1) * row >> 1) + row;
}

} // namespace
239
// Writes the coordinates of the lower-triangular part of a matrix.
//
// Output: for each linear index i in [0, tril_size), the row index is stored
// at tensor[i] and the column index at tensor[i + tril_size].
//
// The index space is split into a top trapezoid (i < trapezoid_size), whose
// coordinates are resolved by get_coordinate_in_tril_trapezoid, followed by a
// bottom rectangle of full-width rows (i >= trapezoid_size).
//
//   row_offset:     added to every row index (rows skipped above the trapezoid)
//   m_first_row:    number of elements in the trapezoid's first row
//   col:            number of matrix columns (width of the bottom rectangle)
//   trapezoid_size: number of linear indices that fall in the trapezoid
//   tril_size:      total number of coordinates to write
//
// Launch: 1-D grid of 1-D blocks, no shared memory. Under ROCm the kernel is
// compiled for at most 512 threads per block. The loop strides by the total
// number of launched threads, so every index is written even if the grid is
// smaller than tril_size.
template <typename scalar_t>
__global__
#if defined(USE_ROCM)
C10_LAUNCH_BOUNDS_1(512)
#endif
void tril_indices_kernel(scalar_t * tensor,
                         int64_t row_offset,
                         int64_t m_first_row,
                         int64_t col,
                         int64_t trapezoid_size,
                         int64_t tril_size) {
  // Widen before multiplying: blockIdx.x * blockDim.x is otherwise evaluated
  // in 32-bit unsigned arithmetic and wraps once tril_size exceeds 2^32.
  const int64_t grid_stride = static_cast<int64_t>(gridDim.x) * blockDim.x;
  for (int64_t linear_index =
           static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
       linear_index < tril_size;
       linear_index += grid_stride) {
    int64_t r, c;
    if (linear_index < trapezoid_size) {
      // the coordinate is within the top trapezoid
      get_coordinate_in_tril_trapezoid(m_first_row, linear_index, r, c);
    } else {
      // the coordinate falls in the bottom rectangle
      auto surplus = linear_index - trapezoid_size;
      // add the height of trapezoid: m_last_row (col) - m_first_row + 1
      r = surplus / col + col - m_first_row + 1;
      c = surplus % col;
    }
    r += row_offset;

    tensor[linear_index] = r;
    tensor[linear_index + tril_size] = c;
  }
}
271
// Some large test cases for the fallback binary-search path are disabled by
// default to speed up CI and to avoid OOM errors. When modifying the
// implementation, please enable them in test/test_cuda.py and make sure they
// pass on your local server.
// Returns a (2, N) tensor holding the row indices (first row) and column
// indices (second row) of the lower triangle of a (row x col) matrix whose
// diagonal is shifted by `offset`.
Tensor tril_indices_cuda(
    int64_t row, int64_t col, int64_t offset, std::optional<ScalarType> dtype_opt,
    std::optional<Layout> layout_opt, std::optional<Device> device_opt, std::optional<bool> pin_memory_opt) {
  check_args(row, col, layout_opt);

  auto tril_size = get_tril_size(row, col, offset);
  auto tensor = empty_cuda({2, tril_size}, dtype_opt, layout_opt, device_opt, pin_memory_opt);
  if (tril_size <= 0) {
    return tensor;  // nothing to fill
  }

  // The triangle is treated as a top trapezoid (each row one element wider
  // than the previous) stacked on a bottom rectangle of full-width rows.
  int64_t first_row_len = 0;  // elements in the trapezoid's first row
  if (offset > 0) {
    first_row_len = std::min<int64_t>(col, 1 + offset);  // capped at col
  } else if (row + offset > 0) {
    first_row_len = 1;  // first non-empty row holds a single element
  }

  // Rows above the trapezoid that contain no lower-triangular elements.
  const int64_t trapezoid_row_offset = std::max<int64_t>(0, -offset);
  // The trapezoid ends once its rows reach the full width `col`.
  const int64_t rectangle_row_offset =
      trapezoid_row_offset + col - first_row_len + 1;
  const int64_t rectangle_size =
      rectangle_row_offset < row ? (row - rectangle_row_offset) * col : 0;
  const int64_t trapezoid_size = tril_size - rectangle_size;

  // Size the grid by tril_size rather than tensor.numel(): each thread writes
  // both the row and the column entry of one coordinate.
  dim3 dim_block = cuda::getApplyBlock();
  dim3 dim_grid;
  TORCH_CHECK(
      cuda::getApplyGrid(tril_size, dim_grid, tensor.get_device()),
      "unable to get dim grid");

  AT_DISPATCH_INDEX_TYPES(tensor.scalar_type(), "tril_indices_cuda", [&] {
    tril_indices_kernel<<<
        dim_grid, dim_block, 0, at::cuda::getCurrentCUDAStream()>>>(
        tensor.mutable_data_ptr<index_t>(),
        trapezoid_row_offset,
        first_row_len,
        col,
        trapezoid_size,
        tril_size);
    C10_CUDA_KERNEL_LAUNCH_CHECK();
  });

  return tensor;
}
318
// Writes the coordinates of the upper-triangular part of a matrix.
//
// Output: for each linear index i in [0, triu_size), the row index is stored
// at tensor[i] and the column index at tensor[i + triu_size].
//
// The index space is split into a top rectangle of full-width rows
// (i < rectangle_size) followed by an upside-down bottom trapezoid, whose
// coordinates are resolved by get_coordinate_in_triu_trapezoid.
//
//   col_offset:     added to every column index
//   m_first_row:    number of elements in the trapezoid's first row
//   col:            number of matrix columns (width of the top rectangle)
//   rectangle_size: number of linear indices that fall in the rectangle
//   triu_size:      total number of coordinates to write
//
// Launch: 1-D grid of 1-D blocks, no shared memory. Under ROCm the kernel is
// compiled for at most 512 threads per block, matching tril_indices_kernel.
// The loop strides by the total number of launched threads, so every index is
// written even if the grid is smaller than triu_size.
template <typename scalar_t>
__global__
#if defined(USE_ROCM)
C10_LAUNCH_BOUNDS_1(512)
#endif
void triu_indices_kernel(scalar_t * tensor,
                         int64_t col_offset,
                         int64_t m_first_row,
                         int64_t col,
                         int64_t rectangle_size,
                         int64_t triu_size) {
  // Widen before multiplying: blockIdx.x * blockDim.x is otherwise evaluated
  // in 32-bit unsigned arithmetic and wraps once triu_size exceeds 2^32.
  const int64_t grid_stride = static_cast<int64_t>(gridDim.x) * blockDim.x;
  for (int64_t linear_index =
           static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
       linear_index < triu_size;
       linear_index += grid_stride) {
    int64_t r, c;
    if (linear_index < rectangle_size) {
      // the coordinate is within the top rectangle
      r = linear_index / col;
      c = linear_index % col;
    } else {
      // the coordinate falls in the bottom trapezoid
      get_coordinate_in_triu_trapezoid(
          m_first_row, linear_index - rectangle_size, r, c);
      // shift past the rows occupied by the rectangle
      r += rectangle_size / col;
    }

    c += col_offset;
    tensor[linear_index] = r;
    tensor[linear_index + triu_size] = c;
  }
}
347
// Some large test cases for the fallback binary-search path are disabled by
// default to speed up CI and to avoid OOM errors. When modifying the
// implementation, please enable them in test/test_cuda.py and make sure they
// pass on your local server.
// Returns a (2, N) tensor holding the row indices (first row) and column
// indices (second row) of the upper triangle of a (row x col) matrix whose
// diagonal is shifted by `offset`.
Tensor triu_indices_cuda(
    int64_t row, int64_t col, int64_t offset, std::optional<ScalarType> dtype_opt,
    std::optional<Layout> layout_opt, std::optional<Device> device_opt, std::optional<bool> pin_memory_opt) {
  check_args(row, col, layout_opt);

  // triu(offset) is everything in the matrix outside tril(offset - 1).
  auto triu_size = row * col - get_tril_size(row, col, offset - 1);
  auto tensor = empty_cuda({2, triu_size}, dtype_opt, layout_opt, device_opt, pin_memory_opt);
  if (triu_size <= 0) {
    return tensor;  // nothing to fill
  }

  // Elements in the first row of the bottom trapezoid: the whole width when
  // offset <= 0, otherwise col - offset clamped at 0.
  const int64_t first_row_len =
      offset > 0 ? std::max<int64_t>(col - offset, 0) : col;

  // A negative offset leaves up to -offset full-width rows above the
  // trapezoid; they form the top rectangle.
  const int64_t rectangle_size =
      offset < 0 ? std::min<int64_t>(row, -offset) * col : 0;

  // Column shift applied to every coordinate when the diagonal starts to the
  // right of column 0.
  const int64_t col_offset = std::max<int64_t>(0, offset);

  // Size the grid by triu_size rather than tensor.numel(): each thread writes
  // both the row and the column entry of one coordinate.
  dim3 dim_block = cuda::getApplyBlock();
  dim3 dim_grid;
  TORCH_CHECK(
      cuda::getApplyGrid(triu_size, dim_grid, tensor.get_device()),
      "unable to get dim grid");

  AT_DISPATCH_INDEX_TYPES(tensor.scalar_type(), "triu_indices_cuda", [&] {
    triu_indices_kernel<<<
        dim_grid, dim_block, 0, at::cuda::getCurrentCUDAStream()>>>(
        tensor.mutable_data_ptr<index_t>(),
        col_offset,
        first_row_len,
        col,
        rectangle_size,
        triu_size);
    C10_CUDA_KERNEL_LAUNCH_CHECK();
  });

  return tensor;
}
396
397 } // namespace at::native
398