1 #pragma once 2 3 #include <torch/csrc/Export.h> 4 #include <cstdint> 5 #include <string> 6 #include <unordered_map> 7 8 namespace torch::jit::tensorexpr { 9 10 // The external functions that could be called from NNC must have the same 11 // signature defined by `NNCExternalFunction`. 12 // 13 // Why this signature? 14 // It was picked for two reasons: 1) it should be generic enough to represent 15 // most of the ops we might want to call, 2) it should be possible to generate a 16 // code for this call in LLVM codegen. 17 // The first 5 parameters allow to pass any number of contiguous CPU tensors in 18 // case we need to run aten ops (TODO: support different devices). The first 19 // buffer in the array is assumed to be the output buffer. We couldn't use 20 // `at::Tensor` (or `c10::IValue`) type there directly as it would mean that 21 // we'd need to declare it in LLVM codegen in LLVM IR form, which would be very 22 // cumbersome and hard to maintain. Note that the dimensions of all tensors are 23 // concatenated into a single array buf_dims. We do not need to pass its length, 24 // since it can be deduced from total number of buffers and their ranks. 25 // 26 // The last 2 arguments allow to pass any non-tensor arguments encoded as an 27 // array of int64_t values. The way they are encoded is not specified and could 28 // be arbitrary - whatever the most convenient for the specific bridge function 29 // is. 30 // 31 // The bridge functions must not throw exceptions - properly propagating them 32 // from the generated code is too cumbersome, and thus all calls to functions 33 // that could throw must be wrapped with try-catch blocks. 34 using NNCExternalFunction = void (*)( 35 int64_t bufs_num, 36 void** buf_data, 37 int64_t* buf_ranks, 38 int64_t* buf_dims, 39 int64_t* buf_strides, 40 int8_t* buf_dtypes, 41 int64_t args_num, 42 int64_t* extra_args); 43 44 // Return a global map "function-name" -> "function-pointer" for all registered 45 // in NNC external functions 46 TORCH_API std::unordered_map<std::string, NNCExternalFunction>& 47 getNNCFunctionRegistry(); 48 49 // To register a new external function in NNC one needs to create an instance of 50 // this struct 51 struct RegisterNNCExternalFunction { RegisterNNCExternalFunctionRegisterNNCExternalFunction52 RegisterNNCExternalFunction(const std::string& name, NNCExternalFunction fn) { 53 getNNCFunctionRegistry()[name] = fn; 54 } 55 }; 56 57 } // namespace torch::jit::tensorexpr 58