1 #include <torch/csrc/jit/tensorexpr/external_functions_core.h>
2
3 namespace torch::jit::tensorexpr {
4
5 #ifdef C10_MOBILE
6 extern "C" {
7 #endif
8
9 using ParallelCallee = void (*)(int64_t, int8_t*);
DispatchParallel(int8_t * func,int64_t start,int64_t stop,int8_t * packed_data)10 void DispatchParallel(
11 int8_t* func,
12 int64_t start,
13 int64_t stop,
14 int8_t* packed_data) noexcept {
15 // TODO: preserve the func type.
16 try {
17 ParallelCallee callee = reinterpret_cast<ParallelCallee>(func);
18 at::parallel_for(start, stop, 1, [&](int64_t f_begin, int64_t f_end) {
19 for (int64_t index = f_begin; index < f_end; index++) {
20 callee(index, packed_data);
21 }
22 });
23 } catch (...) {
24 }
25 }
26
nnc_aten_free(size_t bufs_num,void ** ptrs)27 void nnc_aten_free(size_t bufs_num, void** ptrs) noexcept {
28 for (const auto i : c10::irange(bufs_num)) {
29 c10::raw::intrusive_ptr::decref((c10::TensorImpl*)ptrs[i]);
30 }
31 }
32
33 #ifdef C10_MOBILE
34 } // extern "C"
35 #endif
36
37 } // namespace torch::jit::tensorexpr
38