xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/cpu/RangeFactoriesKernel.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
#define TORCH_ASSERT_NO_OPERATORS
#include <ATen/native/RangeFactories.h>
#include <cmath>
#include <ATen/Config.h>
#include <ATen/Dispatch.h>
#include <ATen/native/DispatchStub.h>

#include <ATen/AccumulateType.h>
#include <ATen/cpu/vec/vec.h>
#include <ATen/native/TensorIterator.h>
#include <ATen/Parallel.h>
#include <ATen/native/cpu/Loops.h>

#include <c10/core/Scalar.h>

namespace at::native {
namespace {

using namespace vec;

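// Writes the arithmetic sequence start, start + step, start + 2 * step, ... into the
// output of `iter` for `steps` elements. The work is split across threads with
// at::parallel_for, and each chunk uses a vectorized fast path where available.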
static void arange_kernel(TensorIterator& iter, const Scalar& scalar_start, const Scalar& scalar_steps, const Scalar& scalar_step) {
  AT_DISPATCH_ALL_TYPES_AND2(kHalf, kBFloat16, iter.dtype(), "arange_cpu", [&]() {
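    // Accumulate in acc_type (e.g. float for Half/BFloat16) so that start + step * idx
    // is computed at higher precision before being cast back to scalar_t.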
    using accscalar_t = at::acc_type<scalar_t, false>;
    auto start = scalar_start.to<accscalar_t>();
    auto steps = scalar_steps.to<accscalar_t>();
    auto step = scalar_step.to<accscalar_t>();
    at::parallel_for(0, steps, internal::GRAIN_SIZE, [&](int64_t p_begin, int64_t p_end) {
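      // Each thread works on its own copy of the iterator over [p_begin, p_end);
      // `idx` tracks the absolute output index so every chunk produces the same
      // values a serial run would.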
      int64_t idx(p_begin);
      TensorIterator it(iter);
      cpu_serial_kernel_vec(
          it,
          [start, step, &idx]() -> scalar_t {
            return start + step * (idx++);
          },
          [start, step, &idx]() -> Vectorized<scalar_t> {
            Vectorized<scalar_t> res;
            res = Vectorized<scalar_t>::arange(start + step * idx, step);
            idx += Vectorized<scalar_t>::size();
            return res;
          }, {p_begin, p_end});
    });
  });
}

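// Writes `steps` evenly spaced values from scalar_start to scalar_end into the output
// of `iter`. The first half of the sequence is computed forward from `start` and the
// second half backward from `end`, which keeps the final element exactly equal to `end`
// and limits accumulated rounding error.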
static void linspace_kernel(TensorIterator& iter, const Scalar& scalar_start, const Scalar& scalar_end, int64_t steps) {
  AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(kHalf, kBFloat16, iter.dtype(), "linspace_cpu", [&]() {
    // step should be of double type for all integral types
    using step_t = std::conditional_t<std::is_integral<scalar_t>::value, double, scalar_t>;
    const scalar_t start = scalar_start.to<scalar_t>();
    const scalar_t end = scalar_end.to<scalar_t>();
    // Cast `end` and `start` to `step_t`, since range can be larger than scalar_t for integral types
    const step_t step = (static_cast<step_t>(end) - static_cast<step_t>(start)) / (steps - 1);
    int64_t halfway = steps / 2;
    at::parallel_for(0, steps, internal::GRAIN_SIZE, [&](int64_t p_begin, int64_t p_end) {
      int64_t idx(p_begin);
      TensorIterator it(iter);
      // The vectorized implementation was removed due to precision issues when mixing
      // integral output types with a double step; the scalar path does not hurt performance.
      cpu_serial_kernel(
          it,
          [start, end, step, halfway, steps, &idx]() -> scalar_t {
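            // Compute the first half from `start` and the second half from `end` so the
            // endpoint stays exact regardless of rounding in `step`.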
            if (idx < halfway) {
              return start + step * (idx++);
            } else {
              return end - step * (steps - (idx++) - 1);
            }
          }, {p_begin, p_end});
    });
  });
}

} // anonymous namespace

REGISTER_DISPATCH(arange_stub, &arange_kernel);
REGISTER_DISPATCH(linspace_stub, &linspace_kernel);
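
// A minimal usage sketch (not part of this file; assumes the usual factory entry points,
// which route here through arange_stub / linspace_stub on CPU):
//   at::Tensor a = at::arange(/*start=*/0, /*end=*/10, /*step=*/2);        // -> [0, 2, 4, 6, 8]
//   at::Tensor l = at::linspace(/*start=*/0.0, /*end=*/1.0, /*steps=*/5);  // -> [0, 0.25, 0.5, 0.75, 1]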

} // namespace at::native