#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/native/RangeFactories.h>
#include <ATen/AccumulateType.h>
#include <ATen/Dispatch.h>
#include <ATen/Parallel.h>
#include <ATen/TensorIterator.h>
#include <c10/util/irange.h>
#include <cmath>
#include <limits>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/linspace.h>
#include <ATen/ops/logspace.h>
#include <ATen/ops/arange_native.h>
#include <ATen/ops/linspace_native.h>
#include <ATen/ops/logspace_native.h>
#include <ATen/ops/range_native.h>
#endif

namespace at::native {

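// The linspace_out overloads that accept 0-dimensional start/end tensors simply
// extract the scalar value with .item() and forward to the Scalar/Scalar
// overload below, which does the actual work.
//
// Illustrative example (not part of this file):
//   at::linspace(0.0, 1.0, /*steps=*/5) -> [0.00, 0.25, 0.50, 0.75, 1.00]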
Tensor& linspace_out(const Tensor& start, const Tensor& end, int64_t steps, Tensor& result) {
  TORCH_CHECK(start.dim() == 0 && end.dim() == 0, "linspace only supports 0-dimensional start and end tensors, "
    "but got start with ", start.dim(), " dimension(s) and end with ", end.dim(), " dimension(s).");
  return at::linspace_out(result, start.item(), end.item(), steps);
}

Tensor& linspace_out(const Tensor& start, const Scalar& end, int64_t steps, Tensor& result) {
  TORCH_CHECK(start.dim() == 0, "linspace only supports 0-dimensional start and end tensors, "
    "but got start with ", start.dim(), " dimension(s).");
  return at::linspace_out(result, start.item(), end, steps);
}

Tensor& linspace_out(const Scalar& start, const Tensor& end, int64_t steps, Tensor& result) {
  TORCH_CHECK(end.dim() == 0, "linspace only supports 0-dimensional start and end tensors, "
    "but got end with ", end.dim(), " dimension(s).");
  return at::linspace_out(result, start, end.item(), steps);
}

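// Core linspace implementation: resizes the output to `steps` elements, returns
// early for meta tensors (shape-only computation), handles the trivial
// steps == 0 and steps == 1 cases directly, and otherwise fills a contiguous
// buffer through linspace_stub, the per-device kernel registered via
// DispatchStub. For a non-contiguous `result`, the kernel writes into a
// contiguous clone which is then copied back.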
Tensor& linspace_out(const Scalar& start, const Scalar& end, int64_t steps, Tensor& result) {
  TORCH_CHECK(steps >= 0, "number of steps must be non-negative");
  if (result.numel() != steps) {
    result.resize_({steps});
  }

  if (result.device() == kMeta) {
    return result;
  }

  if (steps == 0) {
    // skip
  } else if (steps == 1) {
    result.fill_(start);
  } else {
    Tensor r = result.is_contiguous() ? result : result.contiguous();
    auto iter = TensorIterator::borrowing_nullary_op(r);
    linspace_stub(iter.device_type(), iter, start, end, steps);
    if (!result.is_contiguous()) {
      result.copy_(r);
    }
  }

  return result;
}

Tensor& logspace_out(const Tensor& start, const Tensor& end, int64_t steps, double base, Tensor& result) {
  TORCH_CHECK(start.dim() == 0 && end.dim() == 0, "logspace only supports 0-dimensional start and end tensors, "
    "but got start with ", start.dim(), " dimension(s) and end with ", end.dim(), " dimension(s).");
  return at::logspace_out(result, start.item(), end.item(), steps, base);
}

Tensor& logspace_out(const Tensor& start, const Scalar& end, int64_t steps, double base, Tensor& result) {
  TORCH_CHECK(start.dim() == 0, "logspace only supports 0-dimensional start and end tensors, "
    "but got start with ", start.dim(), " dimension(s).");
  return at::logspace_out(result, start.item(), end, steps, base);
}

Tensor& logspace_out(const Scalar& start, const Tensor& end, int64_t steps, double base, Tensor& result) {
  TORCH_CHECK(end.dim() == 0, "logspace only supports 0-dimensional start and end tensors, "
    "but got end with ", end.dim(), " dimension(s).");
  return at::logspace_out(result, start, end.item(), steps, base);
}

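// Core logspace implementation: the exponents are spaced linearly between
// `start` and `end`, and each element is base raised to its exponent. The first
// half of the output steps forward from `start`, while the second half is
// computed backward from `end` (see `halfway` below); this keeps the last
// element anchored exactly at pow(base, end) and limits accumulated rounding
// error in the tail.
//
// Illustrative example (not part of this file):
//   at::logspace(0.0, 3.0, /*steps=*/4, /*base=*/10.0) -> [1, 10, 100, 1000]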
Tensor& logspace_out(const Scalar& start, const Scalar& end, int64_t steps, double base, Tensor& result) {
  TORCH_CHECK(steps >= 0, "number of steps must be non-negative");

  if (result.numel() != steps) {
    result.resize_({steps});
  }

  if (result.device() == kMeta) {
    return result;
  }

  Tensor r = result.is_contiguous() ? result : result.contiguous();

  if (steps == 0) {
    // skip
  } else if (steps == 1) {
    if (isComplexType(r.scalar_type())) {
      r.fill_(std::pow(base, start.to<c10::complex<double>>()));
    } else {
      r.fill_(std::pow(base, start.to<double>()));
    }
  } else if (isComplexType(r.scalar_type())) {
    AT_DISPATCH_COMPLEX_TYPES(r.scalar_type(), "logspace_cpu", [&]() {
      scalar_t scalar_base = static_cast<scalar_t>(base);
      scalar_t scalar_start = start.to<scalar_t>();
      scalar_t scalar_end = end.to<scalar_t>();
      scalar_t *data_ptr = r.data_ptr<scalar_t>();
      scalar_t step = (scalar_end - scalar_start) / static_cast<scalar_t>(steps - 1);
      const int64_t halfway = steps / 2;
      at::parallel_for(0, steps, internal::GRAIN_SIZE, [&](int64_t p_begin, int64_t p_end) {
        scalar_t is = static_cast<scalar_t>(p_begin);
        for (int64_t i = p_begin; i < p_end; ++i, is += 1) { // std::complex does not support operator++
          if (i < halfway) {
            data_ptr[i] = std::pow(scalar_base, scalar_start + step * is);
          } else {
            data_ptr[i] = std::pow(scalar_base, scalar_end - (step * static_cast<scalar_t>(steps - i - 1)));
          }
        }
      });
    });
  } else {
    AT_DISPATCH_ALL_TYPES_AND2(kBFloat16, kHalf, r.scalar_type(), "logspace_cpu", [&]() {
      double scalar_base = static_cast<double>(base); // will be autopromoted anyway
      scalar_t scalar_start = start.to<scalar_t>();
      scalar_t scalar_end = end.to<scalar_t>();
      scalar_t *data_ptr = r.data_ptr<scalar_t>();
      double step = static_cast<double>(scalar_end - scalar_start) / (steps - 1);
      const int64_t halfway = steps / 2;
      at::parallel_for(0, steps, internal::GRAIN_SIZE, [&](int64_t p_begin, int64_t p_end) {
        for (const auto i : c10::irange(p_begin, p_end)) {
          if (i < halfway) {
            data_ptr[i] = std::pow(scalar_base, scalar_start + step * i);
          } else {
            data_ptr[i] = std::pow(scalar_base, scalar_end - step * (steps - i - 1));
          }
        }
      });
    });
  }

  if (!result.is_contiguous()) {
    result.copy_(r);
  }
  return result;
}

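// range_out produces start, start + step, start + 2*step, ... with `end` treated
// as an inclusive upper bound (unlike arange's half-open interval). The element
// count is (end - start) / step + 1, truncated toward zero, and the values are
// written directly into a contiguous buffer with at::parallel_for.
//
// Illustrative example (not part of this file):
//   at::range(0.0, 1.0, /*step=*/0.25) -> [0.00, 0.25, 0.50, 0.75, 1.00]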
Tensor& range_out(const Scalar& start, const Scalar& end, const Scalar& step, Tensor& result) {
  AT_DISPATCH_ALL_TYPES_AND2(kBFloat16, kHalf, result.scalar_type(), "range_cpu", [&]() {
    using accscalar_t = at::acc_type<scalar_t, false>;
    auto xstart = start.to<accscalar_t>();
    auto xend = end.to<accscalar_t>();
    auto xstep = step.to<accscalar_t>();

    TORCH_CHECK(xstep > 0 || xstep < 0, "step must be nonzero");
    TORCH_CHECK(std::isfinite(static_cast<double>(xstart)) &&
                std::isfinite(static_cast<double>(xend)),
                "unsupported range: ", xstart, " -> ", xend);
    TORCH_CHECK(((xstep > 0) && (xend >= xstart)) || ((xstep < 0) && (xend <= xstart)),
                "upper bound and lower bound inconsistent with step sign");
    int64_t size = static_cast<int64_t>(((xend - xstart) / xstep) + 1);
    if (result.numel() != size) {
      result.resize_({size});
    }

    if (result.device() == kMeta) {
      return;
    }

    Tensor r = result.is_contiguous() ? result : result.contiguous();
    scalar_t *data_ptr = r.data_ptr<scalar_t>();

    at::parallel_for(0, size, internal::GRAIN_SIZE, [&](int64_t p_begin, int64_t p_end) {
      accscalar_t is = p_begin;
      for (int64_t i = p_begin; i < p_end; ++i, ++is) {
        data_ptr[i] = xstart + is * xstep;
      }
    });
    if (!result.is_contiguous()) {
      result.copy_(r);
    }
  });

  return result;
}

Tensor& range_out_no_step(const Scalar& start, const Scalar& end, Tensor& result) {
  return range_out(start, end, /*step=*/1, result);
}

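// arange_out produces values in the half-open interval [start, end), advancing
// by `step`. The element count is derived below in double precision (with an
// exact integer path for int64_t), and the actual fill is delegated to
// arange_stub, the per-device kernel registered via DispatchStub.
//
// Illustrative example (not part of this file):
//   at::arange(0, 10, /*step=*/3) -> [0, 3, 6, 9]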
Tensor& arange_out(const Scalar& start, const Scalar& end, const Scalar& step, Tensor& result) {
  AT_DISPATCH_ALL_TYPES_AND2(kHalf, kBFloat16, result.scalar_type(), "arange_cpu", [&]() {
    using accscalar_t = at::acc_type<scalar_t, false>;
    auto xstart = start.to<accscalar_t>();
    auto xend = end.to<accscalar_t>();
    auto xstep = step.to<accscalar_t>();

    TORCH_CHECK(xstep > 0 || xstep < 0, "step must be nonzero");
    TORCH_CHECK(std::isfinite(static_cast<double>(xstart)) &&
                std::isfinite(static_cast<double>(xend)),
                "unsupported range: ", xstart, " -> ", xend);
    TORCH_CHECK(((xstep > 0) && (xend >= xstart)) || ((xstep < 0) && (xend <= xstart)),
                "upper bound and lower bound inconsistent with step sign");

    // We use double precision for (end - start) / step to compute size_d for
    // consistency across devices. The problem with using accscalar_t is that
    // accscalar_t might be float32 on GPU for a float32 scalar_t, but double on
    // CPU for the same, and the effective output size would then differ between
    // CPU and GPU because of precision issues, which we don't want.
    // The corner case we do want to take into account is int64_t, which has
    // higher precision than double.
    double size_d;
    if constexpr (std::is_same_v<scalar_t, int64_t>) {
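      // For int64_t, ceil((end - start) / step) is evaluated entirely in integer
      // arithmetic: adding (step - sgn(step)) to the numerator before the
      // truncating division yields the ceiling of the (positive) quotient, so the
      // subsequent std::ceil is a no-op and very large 64-bit values never lose
      // precision in a double round-trip.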
      int64_t sgn = (xstep > 0) - (xstep < 0);
      size_d = std::ceil((xend - xstart + xstep - sgn) / xstep);
    } else {
      size_d = std::ceil(static_cast<double>(end.to<double>() - start.to<double>())
                         / step.to<double>());
    }

    TORCH_CHECK(size_d >= 0 && size_d <= static_cast<double>(std::numeric_limits<int64_t>::max()),
                "invalid size, possible overflow?");

    int64_t size = static_cast<int64_t>(size_d);
    int64_t numel = result.numel();

    if (numel != size) {
      if (numel > 0) {
        TORCH_WARN("The number of elements in the out tensor of shape ", result.sizes(),
                   " is ", numel, " which does not match the computed number of elements ", size,
                   ". Note that this may occur as a result of rounding error. "
                   "The out tensor will be resized to a tensor of shape (", size, ",).");
      }
      result.resize_({size});
    }

    if (result.device() == kMeta) {
      return;
    }

    Tensor r = result.is_contiguous() ? result : result.contiguous();
    auto iter = TensorIterator::borrowing_nullary_op(r);
    arange_stub(iter.device_type(), iter, start, size, step);
    if (!result.is_contiguous()) {
      result.copy_(r);
    }
  });

  return result;
}

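// DispatchStub definitions for the kernels used above; the per-device
// implementations (e.g. the CPU kernels under native/cpu/) register themselves
// against these stubs via REGISTER_DISPATCH.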
DEFINE_DISPATCH(arange_stub);
DEFINE_DISPATCH(linspace_stub);

} // namespace at::native