xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/tensorexpr/external_functions_core.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <torch/csrc/jit/tensorexpr/external_functions_core.h>
2 
3 namespace torch::jit::tensorexpr {
4 
5 #ifdef C10_MOBILE
6 extern "C" {
7 #endif
8 
9 using ParallelCallee = void (*)(int64_t, int8_t*);
DispatchParallel(int8_t * func,int64_t start,int64_t stop,int8_t * packed_data)10 void DispatchParallel(
11     int8_t* func,
12     int64_t start,
13     int64_t stop,
14     int8_t* packed_data) noexcept {
15   // TODO: preserve the func type.
16   try {
17     ParallelCallee callee = reinterpret_cast<ParallelCallee>(func);
18     at::parallel_for(start, stop, 1, [&](int64_t f_begin, int64_t f_end) {
19       for (int64_t index = f_begin; index < f_end; index++) {
20         callee(index, packed_data);
21       }
22     });
23   } catch (...) {
24   }
25 }
26 
nnc_aten_free(size_t bufs_num,void ** ptrs)27 void nnc_aten_free(size_t bufs_num, void** ptrs) noexcept {
28   for (const auto i : c10::irange(bufs_num)) {
29     c10::raw::intrusive_ptr::decref((c10::TensorImpl*)ptrs[i]);
30   }
31 }
32 
33 #ifdef C10_MOBILE
34 } // extern "C"
35 #endif
36 
37 } // namespace torch::jit::tensorexpr
38