// xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/tensorexpr/external_functions_registry.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <torch/csrc/Export.h>
4 #include <cstdint>
5 #include <string>
6 #include <unordered_map>
7 
8 namespace torch::jit::tensorexpr {
9 
// The external functions that can be called from NNC must have the signature
// defined by `NNCExternalFunction`.
//
// Why this signature?
// It was picked for two reasons: 1) it should be generic enough to represent
// most of the ops we might want to call, and 2) it should be possible to
// generate code for this call in the LLVM codegen.
//
// The first 6 parameters describe an array of `bufs_num` buffers, which lets
// us pass any number of CPU tensors in case we need to run aten ops (TODO:
// support different devices):
//   - bufs_num:    number of buffers described by the arrays below.
//   - buf_data:    one data pointer per buffer. The first buffer in the array
//                  is assumed to be the output buffer.
//   - buf_ranks:   the rank of each buffer.
//   - buf_dims:    the dimensions of all buffers, concatenated into a single
//                  array. Its length is not passed, since it can be deduced
//                  from the number of buffers and their ranks.
//   - buf_strides: buffer strides; presumably concatenated the same way as
//                  buf_dims (TODO: confirm against the codegen that fills it).
//   - buf_dtypes:  presumably one dtype code per buffer (TODO: confirm the
//                  encoding).
// This interface was originally described as taking contiguous CPU tensors.
// NOTE(review): since strides are now passed too, confirm whether contiguity
// is still required.
//
// We couldn't use `at::Tensor` (or `c10::IValue`) directly here, as that would
// require declaring it in LLVM IR form in the LLVM codegen, which would be
// very cumbersome and hard to maintain.
//
// The last 2 parameters pass any non-tensor arguments, encoded as an array of
// `args_num` int64_t values (`extra_args`). The encoding is not specified and
// can be arbitrary - whatever is most convenient for the specific bridge
// function.
//
// The bridge functions must not throw exceptions - properly propagating them
// from the generated code is too cumbersome, so all calls to functions that
// could throw must be wrapped in try-catch blocks.
using NNCExternalFunction = void (*)(
    int64_t bufs_num,
    void** buf_data,
    int64_t* buf_ranks,
    int64_t* buf_dims,
    int64_t* buf_strides,
    int8_t* buf_dtypes,
    int64_t args_num,
    int64_t* extra_args);
43 
44 // Return a global map "function-name" -> "function-pointer" for all registered
45 // in NNC external functions
46 TORCH_API std::unordered_map<std::string, NNCExternalFunction>&
47 getNNCFunctionRegistry();
48 
49 // To register a new external function in NNC one needs to create an instance of
50 // this struct
51 struct RegisterNNCExternalFunction {
RegisterNNCExternalFunctionRegisterNNCExternalFunction52   RegisterNNCExternalFunction(const std::string& name, NNCExternalFunction fn) {
53     getNNCFunctionRegistry()[name] = fn;
54   }
55 };
56 
57 } // namespace torch::jit::tensorexpr
58