1 #include <ATen/CPUGeneratorImpl.h>
2 // TODO(antoniojkim): Add CUDA support for make_generator_for_device
3 // #ifdef USE_CUDA
4 // #include <ATen/cuda/CUDAGeneratorImpl.h>
5 // #endif
6 #ifdef USE_MPS
7 #include <ATen/mps/MPSGeneratorImpl.h>
8 #endif
9
10 #include <torch/csrc/jit/runtime/register_ops_utils.h>
11 #include <torch/csrc/jit/runtime/slice_indices_adjust.h>
12 #include <limits>
13
14 #include <c10/util/irange.h>
15
16 namespace torch::jit {
17
template <>
// Specialization for IValue elements: builds an empty GenericList carrying
// the runtime element-type tag `elemType`, used by the list ops below to
// construct properly typed result lists.
c10::impl::GenericList make_result_list<IValue>(const TypePtr& elemType) {
  return c10::impl::GenericList(elemType);
}
22
23 template <>
listIndex(Stack & stack)24 void listIndex<at::Tensor>(Stack& stack) {
25 at::Tensor elem = pop(stack).to<at::Tensor>();
26 c10::List<at::Tensor> list = pop(stack).to<c10::List<at::Tensor>>();
27
28 auto pos =
29 std::find_if(list.begin(), list.end(), [elem](const at::Tensor& b) {
30 const auto cmp_result = elem.eq(b);
31 return at::native::is_nonzero(cmp_result);
32 });
33
34 if (pos != list.end()) {
35 push(stack, static_cast<int64_t>(std::distance(list.begin(), pos)));
36 } else {
37 AT_ERROR("'", elem, "' is not in list");
38 }
39 }
40
41 template <>
listCount(Stack & stack)42 void listCount<at::Tensor>(Stack& stack) {
43 at::Tensor elem = pop(stack).to<at::Tensor>();
44 c10::List<at::Tensor> list = pop(stack).to<c10::List<at::Tensor>>();
45
46 const int64_t count =
47 std::count_if(list.begin(), list.end(), [&](const at::Tensor& b) {
48 const auto cmp_result = elem.eq(b);
49 return at::native::is_nonzero(cmp_result);
50 });
51 push(stack, count);
52 }
53
54 template <>
listEq(Stack & stack)55 void listEq<at::Tensor>(Stack& stack) {
56 c10::List<at::Tensor> b = pop(stack).to<c10::List<at::Tensor>>();
57 c10::List<at::Tensor> a = pop(stack).to<c10::List<at::Tensor>>();
58 push(stack, tensor_list_equal(a, b));
59 }
60
61 template <>
listNe(Stack & stack)62 void listNe<at::Tensor>(Stack& stack) {
63 c10::List<at::Tensor> b = pop(stack).to<c10::List<at::Tensor>>();
64 c10::List<at::Tensor> a = pop(stack).to<c10::List<at::Tensor>>();
65 push(stack, !tensor_list_equal(a, b));
66 }
67
template <>
// In-place sort of a Tensor list, ascending by elementwise `lt`; `reverse`
// flips to descending order.
void listSort<at::Tensor>(Stack& stack) {
  bool reverse = pop(stack).toBool();
  c10::List<at::Tensor> list = pop(stack).toTensorList();
  std::sort(
      list.begin(),
      list.end(),
      [reverse](const at::Tensor& a, const at::Tensor& b) -> bool {
        // "strict weak ordering" issue - see other sort: a comparator must
        // return false when comparing an element with itself, but
        // `lt ^ reverse` would return true for identical tensors when
        // reverse is set, which is UB for std::sort. Short-circuit on
        // pointer identity to keep the ordering strict.
        if (a.getIntrusivePtr() == b.getIntrusivePtr()) {
          return false;
        }
        // XOR with `reverse` inverts the ascending comparison.
        return (at::native::is_nonzero(a.lt(b))) ^ reverse;
      });
}
83
84 template <>
listCopyAndSort(Stack & stack)85 void listCopyAndSort<at::Tensor>(Stack& stack) {
86 c10::List<at::Tensor> list = pop(stack).toTensorList();
87 auto list_copied = list.copy();
88 std::sort(
89 list_copied.begin(),
90 list_copied.end(),
91 [](const at::Tensor& a, const at::Tensor& b) {
92 return at::native::is_nonzero(a.lt(b));
93 });
94 push(stack, list_copied);
95 }
96
97 template <>
listRemove(Stack & stack)98 void listRemove<at::Tensor>(Stack& stack) {
99 at::Tensor elem = pop(stack).to<at::Tensor>();
100 c10::List<at::Tensor> list = pop(stack).to<c10::List<at::Tensor>>();
101
102 auto pos = std::find_if(list.begin(), list.end(), [&](const at::Tensor& b) {
103 const auto cmp_result = elem.eq(b);
104 return at::native::is_nonzero(cmp_result);
105 });
106
107 if (pos != list.end()) {
108 list.erase(pos);
109 } else {
110 AT_ERROR("list.remove(x): x not in list");
111 }
112 }
113
// Validate that tensor `t` may be implicitly converted to a Python scalar.
// Conditions are checked in order — the first violation determines the
// error message: (1) must not require grad, (2) must be 0-dimensional,
// (3) when an int is requested, must have an integral non-bool dtype.
// Throws std::runtime_error on failure.
void checkImplicitTensorToNum(const at::Tensor& t, bool toInt) {
  if (t.requires_grad()) {
    throw std::runtime_error(
        "Cannot input a tensor that requires grad as a scalar argument");
  }
  if (!t.sizes().empty()) {
    // A non-empty size list means dim > 0; only 0-dim tensors convert.
    throw std::runtime_error(
        "Cannot input a tensor of dimension other than 0 as a scalar argument");
  }
  if (toInt && !isIntegralType(t.scalar_type(), /*includeBool=*/false)) {
    std::stringstream ss;
    ss << "Cannot input a tensor of type " << t.scalar_type()
       << " as an integral argument";
    throw std::runtime_error(ss.str());
  }
}
130
checkDoubleInRange(double a)131 void checkDoubleInRange(double a) {
132 if (std::isnan(a) || std::isinf(a) ||
133 a > double(std::numeric_limits<int64_t>::max()) ||
134 a < double(std::numeric_limits<int64_t>::min())) {
135 throw c10::Error(
136 "Cannot convert float " + std::to_string(a) + " to integer");
137 return;
138 }
139 }
140
// Product of the odd integers in [n, m] (both odd), computed by recursive
// range halving so the factor sizes stay balanced. Helper for the split
// recursive factorial below.
int64_t partProduct(int n, int m) {
  // Ranges of one or two odd numbers are the base cases.
  if (m <= (n + 1)) {
    return (int64_t)n;
  }
  if (m == (n + 2)) {
    return (int64_t)n * m;
  }
  // Overflow-safe midpoint, then step down to the nearest odd number so
  // both recursive ranges keep odd endpoints.
  auto mid = n + (m - n) / 2;
  if ((mid & 1) != 1) {
    mid = mid - 1;
  }
  return partProduct(n, mid) * partProduct(mid + 2, m);
}
151
// One level of the split-recursive factorial: after recursing on n/2,
// `p` is extended with the product of the odd numbers in the next range
// (the index arithmetic forces both endpoints passed to partProduct to be
// odd), and `r` accumulates the running product of the p values across
// levels. factorial() later restores the even factors with a shift.
void loop(int n, int64_t& p, int64_t& r) {
  if (n <= 2)
    return;
  loop(n / 2, p, r);
  p = p * partProduct(n / 2 + 1 + ((n / 2) & 1), n - 1 + (n & 1));
  r = r * p;
}
159
// Returns v minus the number of set bits in v. For the factorial below
// this is the exponent of 2 in v! (Legendre's formula). The popcount is
// computed with the classic parallel bit-summing reduction.
int nminussumofbits(int v) {
  long bits = (long)v;
  bits -= (0xaaaaaaaa & bits) >> 1; // NOLINT  pairwise bit sums
  bits = (bits & 0x33333333) + ((bits >> 2) & 0x33333333); // NOLINT  nibble sums
  bits = (bits + (bits >> 4)) & 0x0f0f0f0f; // NOLINT  byte sums
  bits += bits >> 8; // NOLINT
  bits += bits >> 16; // NOLINT
  return v - (int)(bits & 0xff); // NOLINT
}
169
// n! via the split recursive algorithm: loop()/partProduct() accumulate
// the products of odd factors, and the final shift restores the power of
// two, whose exponent is n - popcount(n) by Legendre's formula.
// Throws std::runtime_error for negative n. No overflow check is
// performed, so results silently overflow int64_t for large n.
int64_t factorial(int n) {
  if (n < 0) {
    throw std::runtime_error("factorial() not defined for negative values");
  }
  // p: product of the current odd range; r: running product across levels.
  int64_t p = 1, r = 1;
  loop(n, p, r);
  return r << nminussumofbits(n);
}
178
// Convert an angle from radians to degrees (multiplies by the radToDeg
// constant declared elsewhere).
double degrees(double x) {
  return x * radToDeg;
}
// Convert an angle from degrees to radians (multiplies by the degToRad
// constant declared elsewhere).
double radians(double x) {
  return x * degToRad;
}
185
listAppend(Stack & stack)186 void listAppend(Stack& stack) {
187 IValue el = pop(stack).to<IValue>();
188 c10::List<IValue> list = pop(stack).to<c10::List<IValue>>();
189
190 list.push_back(std::move(el));
191 push(stack, std::move(list));
192 }
193
listReverse(Stack & stack)194 void listReverse(Stack& stack) {
195 c10::List<IValue> list = pop(stack).to<c10::List<IValue>>();
196
197 std::reverse(list.begin(), list.end());
198 }
199
listPopImpl(Stack & stack,const char * empty_message)200 void listPopImpl(Stack& stack, const char* empty_message) {
201 int64_t idx = pop(stack).to<int64_t>();
202 c10::List<IValue> list = pop(stack).to<c10::List<IValue>>();
203
204 const int64_t list_size = list.size();
205 const int64_t normalized_idx = normalizeIndex(idx, list_size);
206
207 if (list_size == 0) {
208 AT_ERROR(empty_message);
209 }
210
211 push(stack, getItem(list, idx));
212 list.erase(list.begin() + normalized_idx);
213 }
214
listPop(Stack & stack)215 void listPop(Stack& stack) {
216 return listPopImpl(stack, "pop from empty list");
217 }
218
listClear(Stack & stack)219 void listClear(Stack& stack) {
220 c10::List<IValue> list = pop(stack).to<c10::List<IValue>>();
221
222 list.clear();
223 }
224
listDelete(Stack & stack)225 void listDelete(Stack& stack) {
226 listPopImpl(stack, "pop index out of range");
227 pop(stack);
228 }
229
listInsert(Stack & stack)230 void listInsert(Stack& stack) {
231 IValue elem = pop(stack).to<IValue>();
232 int64_t idx = pop(stack).to<int64_t>();
233 c10::List<IValue> list = pop(stack).to<c10::List<IValue>>();
234
235 const int64_t list_size = list.size();
236 const int64_t normalized_idx = normalizeIndex(idx, list_size);
237
238 if (normalized_idx < 0 || normalized_idx >= list_size) {
239 if (normalized_idx < 0) {
240 list.insert(list.begin(), elem);
241 } else {
242 list.push_back(elem);
243 }
244 } else {
245 list.insert(list.begin() + normalized_idx, elem);
246 }
247 }
248
listExtend(Stack & stack)249 void listExtend(Stack& stack) {
250 c10::List<IValue> b = pop(stack).to<c10::List<IValue>>();
251 c10::List<IValue> a = pop(stack).to<c10::List<IValue>>();
252
253 a.reserve(a.size() + b.size());
254 for (const auto i : c10::irange(b.size())) {
255 a.push_back(b.get(i));
256 }
257 }
258
listCopy(Stack & stack)259 void listCopy(Stack& stack) {
260 c10::List<IValue> list = pop(stack).to<c10::List<IValue>>();
261 push(stack, list.copy());
262 }
263
listSelect(Stack & stack)264 void listSelect(Stack& stack) {
265 int64_t idx = pop(stack).to<int64_t>();
266 c10::List<IValue> list = pop(stack).to<c10::List<IValue>>();
267
268 push(stack, getItem(list, idx));
269 }
270
listLen(Stack & stack)271 void listLen(Stack& stack) {
272 c10::List<IValue> a = pop(stack).to<c10::List<IValue>>();
273
274 const int64_t size = a.size();
275 push(stack, size);
276 }
277
listList(Stack & stack)278 void listList(Stack& stack) {
279 c10::List<IValue> a = pop(stack).to<c10::List<IValue>>();
280 push(stack, a.copy());
281 }
282
listAdd(Stack & stack)283 void listAdd(Stack& stack) {
284 c10::List<IValue> b = pop(stack).to<c10::List<IValue>>();
285 c10::List<IValue> a = pop(stack).to<c10::List<IValue>>();
286
287 c10::List<IValue> ret = make_result_list<IValue>(a.elementType());
288
289 if (a.use_count() == 1) {
290 ret = a;
291 } else {
292 ret = a.copy();
293 }
294
295 ret.append(b);
296
297 push(stack, std::move(ret));
298 }
299
listInplaceAdd(Stack & stack)300 void listInplaceAdd(Stack& stack) {
301 c10::List<IValue> b = pop(stack).to<c10::List<IValue>>();
302 c10::List<IValue> a = pop(stack).to<c10::List<IValue>>();
303 a.append(b);
304 push(stack, std::move(a));
305 }
306
listMulIntLeftInPlace(Stack & stack)307 void listMulIntLeftInPlace(Stack& stack) {
308 int64_t n = pop(stack).to<int64_t>();
309 c10::List<IValue> list = pop(stack).to<c10::List<IValue>>();
310 if (n <= 0) {
311 list.clear();
312 } else if (n > 1) {
313 size_t list_size = list.size();
314 for (const auto i : c10::irange(1, n)) {
315 (void)i; // Suppress unused variable warning
316 for (const auto j : c10::irange(list_size)) {
317 list.push_back(list.get(j));
318 }
319 }
320 }
321
322 push(stack, std::move(list));
323 }
324
listMulIntLeft(Stack & stack)325 void listMulIntLeft(Stack& stack) {
326 int64_t n = pop(stack).to<int64_t>();
327 c10::List<IValue> list = pop(stack).to<c10::List<IValue>>();
328
329 c10::List<IValue> ret = make_result_list<IValue>(list.elementType());
330 const auto size = list.size() * n;
331 ret.reserve(size);
332
333 for (const auto i : c10::irange(n)) {
334 (void)i; // Suppress unused variable warning
335 for (IValue e : list) {
336 ret.push_back(std::move(e));
337 }
338 }
339
340 push(stack, std::move(ret));
341 }
342
listMulIntRight(Stack & stack)343 void listMulIntRight(Stack& stack) {
344 c10::List<IValue> list = pop(stack).to<c10::List<IValue>>();
345 int64_t n = pop(stack).to<int64_t>();
346
347 c10::List<IValue> ret = make_result_list<IValue>(list.elementType());
348 const auto size = list.size() * n;
349 ret.reserve(size);
350
351 for (const auto i : c10::irange(n)) {
352 (void)i; // Suppress unused variable warning
353 for (IValue e : list) {
354 ret.push_back(std::move(e));
355 }
356 }
357
358 push(stack, std::move(ret));
359 }
360
listSlice(Stack & stack)361 void listSlice(Stack& stack) {
362 auto step_val = pop(stack);
363 auto end_val = pop(stack);
364 auto start_val = pop(stack);
365
366 // By default, both start and end will be None.
367 // By python convention, they will be translated into
368 // INT64_MAX. If the step size is not given, it will be 1.
369 int64_t step = step_val.isInt() ? step_val.to<int64_t>() : 1;
370 int64_t end = end_val.isInt() ? end_val.to<int64_t>()
371 : std::numeric_limits<int64_t>::max();
372 int64_t start = start_val.isInt() ? start_val.to<int64_t>()
373 : std::numeric_limits<int64_t>::max();
374
375 c10::List<IValue> list = pop(stack).to<c10::List<IValue>>();
376
377 const int64_t list_size = list.size();
378
379 c10::List<IValue> sliced_list = make_result_list<IValue>(list.elementType());
380 const int64_t num_values =
381 slice_indices_adjust(list_size, &start, &end, step);
382 sliced_list.reserve(num_values);
383
384 int i = start;
385 for (const auto j : c10::irange(num_values)) {
386 (void)j; // Suppress unused variable warning
387 sliced_list.push_back(list.get(i));
388 i += step;
389 }
390
391 push(stack, std::move(sliced_list));
392 }
393
listSetItem(Stack & stack)394 void listSetItem(Stack& stack) {
395 IValue value = pop(stack).to<IValue>();
396 int64_t idx = pop(stack).to<int64_t>();
397 c10::List<IValue> list = pop(stack).to<c10::List<IValue>>();
398
399 setItem(list, idx, std::move(value));
400
401 push(stack, std::move(list));
402 }
403
// Create a fresh random-number Generator for `device`, seeding it when a
// seed is provided. Only CPU is supported unconditionally; MPS is added
// when built with USE_MPS, and CUDA support is intentionally disabled
// (see the TODO below). Any other device raises.
at::Generator make_generator_for_device(
    c10::Device device,
    std::optional<int64_t> seed) {
  if (device.is_cpu()) {
    if (seed.has_value()) {
      return at::detail::createCPUGenerator(seed.value());
    } else {
      return at::detail::createCPUGenerator();
    }
    // TODO(antoniojkim): Enable support for CUDA device
    // Implementation below causes issues during rocm build
    // #ifdef USE_CUDA
    // } else if (device.is_cuda()) {
    //   auto generator = at::cuda::detail::createCUDAGenerator(device.index());
    //   if (seed.has_value()) {
    //     generator.set_current_seed(seed.value());
    //   }
    //   return generator;
    // #endif
#ifdef USE_MPS
  } else if (device.is_mps()) {
    if (seed.has_value()) {
      return at::mps::detail::createMPSGenerator(seed.value());
    } else {
      return at::mps::detail::createMPSGenerator();
    }
#endif
  } else {
    AT_ERROR(
        "Unsupported device for at::make_generator_for_device found: ",
        device.str());
  }
}
437
438 } // namespace torch::jit
439