xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/runtime/register_ops_utils.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <ATen/CPUGeneratorImpl.h>
2 // TODO(antoniojkim): Add CUDA support for make_generator_for_device
3 // #ifdef USE_CUDA
4 // #include <ATen/cuda/CUDAGeneratorImpl.h>
5 // #endif
6 #ifdef USE_MPS
7 #include <ATen/mps/MPSGeneratorImpl.h>
8 #endif
9 
10 #include <torch/csrc/jit/runtime/register_ops_utils.h>
11 #include <torch/csrc/jit/runtime/slice_indices_adjust.h>
12 #include <limits>
13 
14 #include <c10/util/irange.h>
15 
16 namespace torch::jit {
17 
// IValue specialization: result lists for untyped elements are type-erased
// GenericLists, so the element type must be carried explicitly via elemType.
template <>
c10::impl::GenericList make_result_list<IValue>(const TypePtr& elemType) {
  return c10::impl::GenericList(elemType);
}
22 
23 template <>
listIndex(Stack & stack)24 void listIndex<at::Tensor>(Stack& stack) {
25   at::Tensor elem = pop(stack).to<at::Tensor>();
26   c10::List<at::Tensor> list = pop(stack).to<c10::List<at::Tensor>>();
27 
28   auto pos =
29       std::find_if(list.begin(), list.end(), [elem](const at::Tensor& b) {
30         const auto cmp_result = elem.eq(b);
31         return at::native::is_nonzero(cmp_result);
32       });
33 
34   if (pos != list.end()) {
35     push(stack, static_cast<int64_t>(std::distance(list.begin(), pos)));
36   } else {
37     AT_ERROR("'", elem, "' is not in list");
38   }
39 }
40 
41 template <>
listCount(Stack & stack)42 void listCount<at::Tensor>(Stack& stack) {
43   at::Tensor elem = pop(stack).to<at::Tensor>();
44   c10::List<at::Tensor> list = pop(stack).to<c10::List<at::Tensor>>();
45 
46   const int64_t count =
47       std::count_if(list.begin(), list.end(), [&](const at::Tensor& b) {
48         const auto cmp_result = elem.eq(b);
49         return at::native::is_nonzero(cmp_result);
50       });
51   push(stack, count);
52 }
53 
54 template <>
listEq(Stack & stack)55 void listEq<at::Tensor>(Stack& stack) {
56   c10::List<at::Tensor> b = pop(stack).to<c10::List<at::Tensor>>();
57   c10::List<at::Tensor> a = pop(stack).to<c10::List<at::Tensor>>();
58   push(stack, tensor_list_equal(a, b));
59 }
60 
61 template <>
listNe(Stack & stack)62 void listNe<at::Tensor>(Stack& stack) {
63   c10::List<at::Tensor> b = pop(stack).to<c10::List<at::Tensor>>();
64   c10::List<at::Tensor> a = pop(stack).to<c10::List<at::Tensor>>();
65   push(stack, !tensor_list_equal(a, b));
66 }
67 
68 template <>
listSort(Stack & stack)69 void listSort<at::Tensor>(Stack& stack) {
70   bool reverse = pop(stack).toBool();
71   c10::List<at::Tensor> list = pop(stack).toTensorList();
72   std::sort(
73       list.begin(),
74       list.end(),
75       [reverse](const at::Tensor& a, const at::Tensor& b) -> bool {
76         // "strict weak ordering" issue - see other sort
77         if (a.getIntrusivePtr() == b.getIntrusivePtr()) {
78           return false;
79         }
80         return (at::native::is_nonzero(a.lt(b))) ^ reverse;
81       });
82 }
83 
84 template <>
listCopyAndSort(Stack & stack)85 void listCopyAndSort<at::Tensor>(Stack& stack) {
86   c10::List<at::Tensor> list = pop(stack).toTensorList();
87   auto list_copied = list.copy();
88   std::sort(
89       list_copied.begin(),
90       list_copied.end(),
91       [](const at::Tensor& a, const at::Tensor& b) {
92         return at::native::is_nonzero(a.lt(b));
93       });
94   push(stack, list_copied);
95 }
96 
97 template <>
listRemove(Stack & stack)98 void listRemove<at::Tensor>(Stack& stack) {
99   at::Tensor elem = pop(stack).to<at::Tensor>();
100   c10::List<at::Tensor> list = pop(stack).to<c10::List<at::Tensor>>();
101 
102   auto pos = std::find_if(list.begin(), list.end(), [&](const at::Tensor& b) {
103     const auto cmp_result = elem.eq(b);
104     return at::native::is_nonzero(cmp_result);
105   });
106 
107   if (pos != list.end()) {
108     list.erase(pos);
109   } else {
110     AT_ERROR("list.remove(x): x not in list");
111   }
112 }
113 
// Validates that tensor `t` may be implicitly converted to a scalar: it must
// not require grad, must be 0-dimensional, and - when an integer result is
// requested (toInt) - must already hold an integral, non-bool dtype.
// Throws std::runtime_error describing the first violated condition.
void checkImplicitTensorToNum(const at::Tensor& t, bool toInt) {
  if (t.requires_grad()) {
    throw std::runtime_error(
        "Cannot input a tensor that requires grad as a scalar argument");
  }
  // Only 0-dim (scalar) tensors have an empty size list.
  if (!t.sizes().empty()) {
    throw std::runtime_error(
        "Cannot input a tensor of dimension other than 0 as a scalar argument");
  }
  if (toInt && !isIntegralType(t.scalar_type(), /*includeBool=*/false)) {
    std::stringstream ss;
    ss << "Cannot input a tensor of type " << t.scalar_type()
       << " as an integral argument";
    throw std::runtime_error(ss.str());
  }
}
130 
checkDoubleInRange(double a)131 void checkDoubleInRange(double a) {
132   if (std::isnan(a) || std::isinf(a) ||
133       a > double(std::numeric_limits<int64_t>::max()) ||
134       a < double(std::numeric_limits<int64_t>::min())) {
135     throw c10::Error(
136         "Cannot convert float " + std::to_string(a) + " to integer");
137     return;
138   }
139 }
140 
// Divide-and-conquer product of the odd integers in [n, m] (callers pass odd
// bounds); splitting keeps the two subproducts balanced. Helper for
// factorial()'s split-recursive algorithm.
int64_t partProduct(int n, int m) {
  if (m <= n + 1) {
    return static_cast<int64_t>(n);
  }
  if (m == n + 2) {
    return static_cast<int64_t>(n) * m;
  }
  // Midpoint written as n + (m - n) / 2 so (n + m) can never overflow.
  int k = n + (m - n) / 2;
  if (k % 2 == 0) {
    k -= 1; // keep the split point odd
  }
  return partProduct(n, k) * partProduct(k + 2, m);
}
151 
loop(int n,int64_t & p,int64_t & r)152 void loop(int n, int64_t& p, int64_t& r) {
153   if (n <= 2)
154     return;
155   loop(n / 2, p, r);
156   p = p * partProduct(n / 2 + 1 + ((n / 2) & 1), n - 1 + (n & 1));
157   r = r * p;
158 }
159 
// Returns v minus the number of set bits in v, via the classic parallel
// popcount. For factorial() this is the exponent of 2 in v! (Legendre's
// formula: v - s_2(v)).
int nminussumofbits(int v) {
  long acc = static_cast<long>(v);
  acc -= (0xaaaaaaaa & acc) >> 1; // NOLINT
  acc = (acc & 0x33333333) + ((acc >> 2) & 0x33333333); // NOLINT
  acc = (acc + (acc >> 4)) & 0x0f0f0f0f; // NOLINT
  acc += acc >> 8; // NOLINT
  acc += acc >> 16; // NOLINT
  // acc's low byte now holds popcount(v).
  return v - static_cast<int>(acc & 0xff); // NOLINT
}
169 
factorial(int n)170 int64_t factorial(int n) {
171   if (n < 0) {
172     throw std::runtime_error("factorial() not defined for negative values");
173   }
174   int64_t p = 1, r = 1;
175   loop(n, p, r);
176   return r << nminussumofbits(n);
177 }
178 
// Converts `x` from radians to degrees by scaling with radToDeg.
double degrees(double x) {
  return x * radToDeg;
}
// Converts `x` from degrees to radians by scaling with degToRad.
double radians(double x) {
  return x * degToRad;
}
185 
listAppend(Stack & stack)186 void listAppend(Stack& stack) {
187   IValue el = pop(stack).to<IValue>();
188   c10::List<IValue> list = pop(stack).to<c10::List<IValue>>();
189 
190   list.push_back(std::move(el));
191   push(stack, std::move(list));
192 }
193 
listReverse(Stack & stack)194 void listReverse(Stack& stack) {
195   c10::List<IValue> list = pop(stack).to<c10::List<IValue>>();
196 
197   std::reverse(list.begin(), list.end());
198 }
199 
listPopImpl(Stack & stack,const char * empty_message)200 void listPopImpl(Stack& stack, const char* empty_message) {
201   int64_t idx = pop(stack).to<int64_t>();
202   c10::List<IValue> list = pop(stack).to<c10::List<IValue>>();
203 
204   const int64_t list_size = list.size();
205   const int64_t normalized_idx = normalizeIndex(idx, list_size);
206 
207   if (list_size == 0) {
208     AT_ERROR(empty_message);
209   }
210 
211   push(stack, getItem(list, idx));
212   list.erase(list.begin() + normalized_idx);
213 }
214 
// list.pop([idx]): delegates to listPopImpl with Python's message for
// popping from an empty list.
void listPop(Stack& stack) {
  return listPopImpl(stack, "pop from empty list");
}
218 
listClear(Stack & stack)219 void listClear(Stack& stack) {
220   c10::List<IValue> list = pop(stack).to<c10::List<IValue>>();
221 
222   list.clear();
223 }
224 
// del list[idx]: reuses listPopImpl (which pushes the removed element) and
// then pops that element off again, since delete produces no value.
void listDelete(Stack& stack) {
  listPopImpl(stack, "pop index out of range");
  pop(stack);
}
229 
listInsert(Stack & stack)230 void listInsert(Stack& stack) {
231   IValue elem = pop(stack).to<IValue>();
232   int64_t idx = pop(stack).to<int64_t>();
233   c10::List<IValue> list = pop(stack).to<c10::List<IValue>>();
234 
235   const int64_t list_size = list.size();
236   const int64_t normalized_idx = normalizeIndex(idx, list_size);
237 
238   if (normalized_idx < 0 || normalized_idx >= list_size) {
239     if (normalized_idx < 0) {
240       list.insert(list.begin(), elem);
241     } else {
242       list.push_back(elem);
243     }
244   } else {
245     list.insert(list.begin() + normalized_idx, elem);
246   }
247 }
248 
listExtend(Stack & stack)249 void listExtend(Stack& stack) {
250   c10::List<IValue> b = pop(stack).to<c10::List<IValue>>();
251   c10::List<IValue> a = pop(stack).to<c10::List<IValue>>();
252 
253   a.reserve(a.size() + b.size());
254   for (const auto i : c10::irange(b.size())) {
255     a.push_back(b.get(i));
256   }
257 }
258 
// list.copy(): pushes a copy of the list (via c10::List::copy).
void listCopy(Stack& stack) {
  c10::List<IValue> list = pop(stack).to<c10::List<IValue>>();
  push(stack, list.copy());
}
263 
listSelect(Stack & stack)264 void listSelect(Stack& stack) {
265   int64_t idx = pop(stack).to<int64_t>();
266   c10::List<IValue> list = pop(stack).to<c10::List<IValue>>();
267 
268   push(stack, getItem(list, idx));
269 }
270 
listLen(Stack & stack)271 void listLen(Stack& stack) {
272   c10::List<IValue> a = pop(stack).to<c10::List<IValue>>();
273 
274   const int64_t size = a.size();
275   push(stack, size);
276 }
277 
// list(list): pushes a copy, matching Python's list() constructor producing
// a new list rather than aliasing the argument.
void listList(Stack& stack) {
  c10::List<IValue> a = pop(stack).to<c10::List<IValue>>();
  push(stack, a.copy());
}
282 
listAdd(Stack & stack)283 void listAdd(Stack& stack) {
284   c10::List<IValue> b = pop(stack).to<c10::List<IValue>>();
285   c10::List<IValue> a = pop(stack).to<c10::List<IValue>>();
286 
287   c10::List<IValue> ret = make_result_list<IValue>(a.elementType());
288 
289   if (a.use_count() == 1) {
290     ret = a;
291   } else {
292     ret = a.copy();
293   }
294 
295   ret.append(b);
296 
297   push(stack, std::move(ret));
298 }
299 
listInplaceAdd(Stack & stack)300 void listInplaceAdd(Stack& stack) {
301   c10::List<IValue> b = pop(stack).to<c10::List<IValue>>();
302   c10::List<IValue> a = pop(stack).to<c10::List<IValue>>();
303   a.append(b);
304   push(stack, std::move(a));
305 }
306 
listMulIntLeftInPlace(Stack & stack)307 void listMulIntLeftInPlace(Stack& stack) {
308   int64_t n = pop(stack).to<int64_t>();
309   c10::List<IValue> list = pop(stack).to<c10::List<IValue>>();
310   if (n <= 0) {
311     list.clear();
312   } else if (n > 1) {
313     size_t list_size = list.size();
314     for (const auto i : c10::irange(1, n)) {
315       (void)i; // Suppress unused variable warning
316       for (const auto j : c10::irange(list_size)) {
317         list.push_back(list.get(j));
318       }
319     }
320   }
321 
322   push(stack, std::move(list));
323 }
324 
listMulIntLeft(Stack & stack)325 void listMulIntLeft(Stack& stack) {
326   int64_t n = pop(stack).to<int64_t>();
327   c10::List<IValue> list = pop(stack).to<c10::List<IValue>>();
328 
329   c10::List<IValue> ret = make_result_list<IValue>(list.elementType());
330   const auto size = list.size() * n;
331   ret.reserve(size);
332 
333   for (const auto i : c10::irange(n)) {
334     (void)i; // Suppress unused variable warning
335     for (IValue e : list) {
336       ret.push_back(std::move(e));
337     }
338   }
339 
340   push(stack, std::move(ret));
341 }
342 
listMulIntRight(Stack & stack)343 void listMulIntRight(Stack& stack) {
344   c10::List<IValue> list = pop(stack).to<c10::List<IValue>>();
345   int64_t n = pop(stack).to<int64_t>();
346 
347   c10::List<IValue> ret = make_result_list<IValue>(list.elementType());
348   const auto size = list.size() * n;
349   ret.reserve(size);
350 
351   for (const auto i : c10::irange(n)) {
352     (void)i; // Suppress unused variable warning
353     for (IValue e : list) {
354       ret.push_back(std::move(e));
355     }
356   }
357 
358   push(stack, std::move(ret));
359 }
360 
listSlice(Stack & stack)361 void listSlice(Stack& stack) {
362   auto step_val = pop(stack);
363   auto end_val = pop(stack);
364   auto start_val = pop(stack);
365 
366   // By default, both start and end will be None.
367   // By python convention, they will be translated into
368   // INT64_MAX. If the step size is not given, it will be 1.
369   int64_t step = step_val.isInt() ? step_val.to<int64_t>() : 1;
370   int64_t end = end_val.isInt() ? end_val.to<int64_t>()
371                                 : std::numeric_limits<int64_t>::max();
372   int64_t start = start_val.isInt() ? start_val.to<int64_t>()
373                                     : std::numeric_limits<int64_t>::max();
374 
375   c10::List<IValue> list = pop(stack).to<c10::List<IValue>>();
376 
377   const int64_t list_size = list.size();
378 
379   c10::List<IValue> sliced_list = make_result_list<IValue>(list.elementType());
380   const int64_t num_values =
381       slice_indices_adjust(list_size, &start, &end, step);
382   sliced_list.reserve(num_values);
383 
384   int i = start;
385   for (const auto j : c10::irange(num_values)) {
386     (void)j; // Suppress unused variable warning
387     sliced_list.push_back(list.get(i));
388     i += step;
389   }
390 
391   push(stack, std::move(sliced_list));
392 }
393 
listSetItem(Stack & stack)394 void listSetItem(Stack& stack) {
395   IValue value = pop(stack).to<IValue>();
396   int64_t idx = pop(stack).to<int64_t>();
397   c10::List<IValue> list = pop(stack).to<c10::List<IValue>>();
398 
399   setItem(list, idx, std::move(value));
400 
401   push(stack, std::move(list));
402 }
403 
// Creates a fresh random-number Generator for `device`, seeded with `seed`
// when one is provided. Supports CPU always, MPS when built with USE_MPS;
// CUDA support is intentionally commented out (see TODO). Any other device
// raises via AT_ERROR.
at::Generator make_generator_for_device(
    c10::Device device,
    std::optional<int64_t> seed) {
  if (device.is_cpu()) {
    if (seed.has_value()) {
      return at::detail::createCPUGenerator(seed.value());
    } else {
      return at::detail::createCPUGenerator();
    }
// TODO(antoniojkim): Enable support for CUDA device
//                    Implementation below causes issues during rocm build
// #ifdef USE_CUDA
//   } else if (device.is_cuda()) {
//     auto generator = at::cuda::detail::createCUDAGenerator(device.index());
//     if (seed.has_value()) {
//       generator.set_current_seed(seed.value());
//     }
//     return generator;
// #endif
#ifdef USE_MPS
  } else if (device.is_mps()) {
    if (seed.has_value()) {
      return at::mps::detail::createMPSGenerator(seed.value());
    } else {
      return at::mps::detail::createMPSGenerator();
    }
#endif
  } else {
    AT_ERROR(
        "Unsupported device for at::make_generator_for_device found: ",
        device.str());
  }
}
437 
438 } // namespace torch::jit
439