#define TORCH_ASSERT_NO_OPERATORS
#include <ATen/native/RangeFactories.h>
#include <cmath>
#include <ATen/Config.h>
#include <ATen/Dispatch.h>
#include <ATen/native/DispatchStub.h>

#include <ATen/AccumulateType.h>
#include <ATen/cpu/vec/vec.h>
#include <ATen/native/TensorIterator.h>
#include <ATen/Parallel.h>
#include <ATen/native/cpu/Loops.h>

#include <c10/core/Scalar.h>

namespace at::native {
namespace {

using namespace vec;

static void arange_kernel(TensorIterator& iter, const Scalar& scalar_start, const Scalar& scalar_steps, const Scalar& scalar_step) {
  AT_DISPATCH_ALL_TYPES_AND2(kHalf, kBFloat16, iter.dtype(), "arange_cpu", [&]() {
    // Accumulate in a wider type (e.g. float for Half/BFloat16) so that
    // start + step * idx is computed without losing precision.
    using accscalar_t = at::acc_type<scalar_t, false>;
    auto start = scalar_start.to<accscalar_t>();
    auto steps = scalar_steps.to<accscalar_t>();
    auto step = scalar_step.to<accscalar_t>();
    at::parallel_for(0, steps, internal::GRAIN_SIZE, [&](int64_t p_begin, int64_t p_end) {
      // Each worker fills its own [p_begin, p_end) chunk; idx tracks the
      // global element index, so chunks are independent of one another.
      int64_t idx(p_begin);
      TensorIterator it(iter);
      cpu_serial_kernel_vec(
          it,
          // Scalar path: one element per call.
          [start, step, &idx]() -> scalar_t {
            return start + step * (idx++);
          },
          // Vectorized path: fill one SIMD vector of consecutive values
          // beginning at start + step * idx.
          [start, step, &idx]() -> Vectorized<scalar_t> {
            Vectorized<scalar_t> res;
            res = Vectorized<scalar_t>::arange(start + step * idx, step);
            idx += Vectorized<scalar_t>::size();
            return res;
          }, {p_begin, p_end});
    });
  });
}

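// For intuition, arange_kernel above is semantically equivalent to this
// serial loop (a sketch with hypothetical `out` and `n`; the real code adds
// parallel chunking and SIMD):
//
//   for (int64_t i = 0; i < n; i++) {
//     // Each value is computed directly from the index rather than by
//     // repeatedly accumulating `step`, so rounding error does not grow
//     // with i.
//     out[i] = start + step * i;
//   }
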
static void linspace_kernel(TensorIterator& iter, const Scalar& scalar_start, const Scalar& scalar_end, int64_t steps) {
  AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(kHalf, kBFloat16, iter.dtype(), "linspace_cpu", [&]() {
    // `step` should be of double type for all integral types, since the
    // quotient below is generally not representable as an integer.
    using step_t = std::conditional_t<std::is_integral<scalar_t>::value, double, scalar_t>;
    const scalar_t start = scalar_start.to<scalar_t>();
    const scalar_t end = scalar_end.to<scalar_t>();
    // Cast `end` and `start` to `step_t`, since the range `end - start` can
    // overflow scalar_t for integral types.
    const step_t step = (static_cast<step_t>(end) - static_cast<step_t>(start)) / (steps - 1);
    int64_t halfway = steps / 2;
    at::parallel_for(0, steps, internal::GRAIN_SIZE, [&](int64_t p_begin, int64_t p_end) {
      int64_t idx(p_begin);
      TensorIterator it(iter);
      // The vectorized implementation was removed because of precision
      // issues when mixing integer and double arithmetic; removing it does
      // not harm performance.
      cpu_serial_kernel(
          it,
          [start, end, step, halfway, steps, &idx]() -> scalar_t {
            if (idx < halfway) {
              // First half: step forward from `start`.
              return start + step * (idx++);
            } else {
              // Second half: step backward from `end`, so the last element
              // is exactly `end`.
              return end - step * (steps - (idx++) - 1);
            }
          }, {p_begin, p_end});
    });
  });
}
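
// Worked example of the symmetric fill in linspace_kernel (a sketch):
// linspace(0, 10, 5) gives step = 2.5 and halfway = 2, so indices 0-1 are
// computed forward from start (0.0, 2.5) and indices 2-4 backward from end
// (10 - 2.5*2 = 5.0, 10 - 2.5*1 = 7.5, 10 - 2.5*0 = 10.0). Filling the top
// half from `end` keeps the output symmetric and both endpoints exact.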

} // anonymous namespace

REGISTER_DISPATCH(arange_stub, &arange_kernel);
REGISTER_DISPATCH(linspace_stub, &linspace_kernel);
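// The registrations above bind the CPU kernels to the dispatch stubs
// declared in ATen/native/RangeFactories.h. Following the usual
// DispatchStub pattern, callers invoke them roughly like this (a sketch;
// `iter` is a TensorIterator set up over the output tensor):
//
//   arange_stub(iter.device_type(), iter, start, steps, step);
//   linspace_stub(iter.device_type(), iter, start, end, steps);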

} // namespace at::native