1 #pragma once 2 3 // WARNING: Be careful when adding new includes here. This header will be used 4 // in model.so, and should not refer to any aten/c10 headers except the stable 5 // C ABI defined in torch/csrc/inductor/aoti_torch/c/shim.h. The same rule 6 // applies to other files under torch/csrc/inductor/aoti_runtime/. 7 #include <torch/csrc/inductor/aoti_runtime/utils.h> 8 9 extern "C" { 10 struct AOTInductorModelOpaque; 11 using AOTInductorModelHandle = AOTInductorModelOpaque*; 12 13 struct AOTInductorModelContainerOpaque; 14 using AOTInductorModelContainerHandle = AOTInductorModelContainerOpaque*; 15 16 struct AOTInductorStreamOpaque; 17 using AOTInductorStreamHandle = AOTInductorStreamOpaque*; 18 19 struct AOTInductorConstantMap; 20 using AOTInductorConstantMapHandle = AOTInductorConstantMap*; 21 22 // TODO: Deprecate this API. This was kept for BC compatibility. 23 // Please use AOTInductorModelContainerCreateWithDevice instead. 24 AOTIRuntimeError AOTInductorModelContainerCreate( 25 AOTInductorModelContainerHandle* container_handle, 26 size_t num_models, 27 bool is_cpu, 28 const char* cubin_dir); 29 30 // Creates an AOTInductor model container. The parameter num_models 31 // specifies the number of model instances that may be run concurrently for 32 // the same input model. 33 // `device_str` MUST NOT be nullptr. It must be a valid device string, e.g. 34 // "cpu", "cuda", "cuda:0", etc. If the device index is not specified for CUDA 35 // device, runtime will use the device index returned by 36 // "cudaGetDevice(&device_idx)" 37 AOTIRuntimeError AOTInductorModelContainerCreateWithDevice( 38 AOTInductorModelContainerHandle* container_handle, 39 size_t num_models, 40 const char* device_str, 41 const char* cubin_dir); 42 43 // Deletes the AOTInductor model container. 44 AOTIRuntimeError AOTInductorModelContainerDelete( 45 AOTInductorModelContainerHandle container_handle); 46 47 // Runs the inference. 48 AOTIRuntimeError AOTInductorModelContainerRun( 49 AOTInductorModelContainerHandle container_handle, 50 AtenTensorHandle* input_handles, // array of input AtenTensorHandle; handles 51 // are stolen; the array itself is borrowed 52 size_t num_inputs, 53 AtenTensorHandle* 54 output_handles, // array for writing output AtenTensorHandle; handles 55 // will be stolen by the caller; the array itself is 56 // borrowed 57 size_t num_outputs, 58 AOTInductorStreamHandle stream_handle, 59 AOTIProxyExecutorHandle proxy_executor_handle); 60 61 // Retrieves the number of constants for the model. 62 AOTIRuntimeError AOTInductorModelContainerGetNumConstants( 63 AOTInductorModelContainerHandle container_handle, 64 size_t* num_constants); 65 66 // Retrieves a constant's name. 67 // idx is the index of the internal's constants. 68 // Need idx < num_constants from AOTInductorModelContainerGetNumConstants 69 AOTIRuntimeError AOTInductorModelContainerGetConstantName( 70 AOTInductorModelContainerHandle container_handle, 71 size_t idx, 72 const char** name); 73 74 // Retrieves a constant's original FQN. 75 // idx is the index of the internal's constants. 76 // Need idx < num_constants from AOTInductorModelContainerGetNumConstants 77 AOTIRuntimeError AOTInductorModelContainerGetConstantOriginalFQN( 78 AOTInductorModelContainerHandle container_handle, 79 size_t idx, 80 const char** original_fqn); 81 82 // Retrieves whether a constant is from folded. 83 // idx is the index of the internal's constants. 84 // Need idx < num_constants from AOTInductorModelContainerGetNumConstants 85 AOTIRuntimeError AOTInductorModelContainerGetConstantFromFolded( 86 AOTInductorModelContainerHandle container_handle, 87 size_t idx, 88 bool* from_folded); 89 90 // Retrieves a constant's dtype. 91 // idx is the index of the internal's constants. 92 // Need idx < num_constants from AOTInductorModelContainerGetNumConstants 93 AOTIRuntimeError AOTInductorModelContainerGetConstantDtype( 94 AOTInductorModelContainerHandle container_handle, 95 size_t idx, 96 int32_t* dtype); 97 98 // Setup the constant buffer in model container with provided ConstantMap 99 // use_inactive should be set as true if the inactive buffer is to be updated. 100 // validate_full_update checks if all constants are included in the ConstantMap 101 AOTIRuntimeError AOTInductorModelContainerUpdateConstantBuffer( 102 AOTInductorModelContainerHandle container_handle, 103 AOTInductorConstantMapHandle constant_map_handle, 104 bool use_inactive, 105 bool validate_full_update); 106 107 // Setup the inactive constant buffer in model container with provided 108 // ConstantMap 109 AOTIRuntimeError AOTInductorModelContainerUpdateInactiveConstantBuffer( 110 AOTInductorModelContainerHandle container_handle, 111 AOTInductorConstantMapHandle constant_map_handle); 112 113 // Run constant folding on constant buffer. 114 AOTIRuntimeError AOTInductorModelContainerRunConstantFolding( 115 AOTInductorModelContainerHandle container_handle, 116 bool use_inactive, 117 AOTInductorStreamHandle stream_handle, 118 AOTIProxyExecutorHandle proxy_executor_handle); 119 120 // Swap the constant buffer being used to the inactive one. 121 AOTIRuntimeError AOTInductorModelContainerSwapConstantBuffer( 122 AOTInductorModelContainerHandle container_handle); 123 124 // Retrieves the number of inputs for the model. 125 AOTIRuntimeError AOTInductorModelContainerGetNumInputs( 126 AOTInductorModelContainerHandle container_handle, 127 size_t* ret_num_inputs); 128 129 // Retrieves the input name at the given index. 130 AOTIRuntimeError AOTInductorModelContainerGetInputName( 131 AOTInductorModelContainerHandle container_handle, 132 size_t input_idx, 133 const char** ret_input_names); 134 135 // Retrieves the number of outputs for the model. 136 AOTIRuntimeError AOTInductorModelContainerGetNumOutputs( 137 AOTInductorModelContainerHandle container_handle, 138 size_t* ret_num_outputs); 139 140 // Retrieves the output name at the given index. 141 AOTIRuntimeError AOTInductorModelContainerGetOutputName( 142 AOTInductorModelContainerHandle container_handle, 143 size_t output_idx, 144 const char** ret_output_names); 145 146 // Creates an AOTInductorModel instance. This is a thin and light wrapper 147 // around the compiled model; it doesn't handle concurrency, queueing, device 148 // management, etc. Use this if bare-metal performance is needed and you are 149 // willing to handle other "management" aspects yourself. 150 // 151 // constant_map_handle is an opaque type to satisfy the C ABI. It should be a 152 // std::unordered_map<std::string, at::Tensor*>*. 153 AOTIRuntimeError AOTInductorModelCreate( 154 AOTInductorModelHandle* model_handle, 155 AOTInductorConstantMapHandle constant_map_handle); 156 157 // Run an AOTInductorModel (see AOTInductorModelCreate for when one should use 158 // this function versus AOTInductorModelContainerRun). 159 AOTIRuntimeError AOTInductorModelRun( 160 AOTInductorModelHandle model_handle, 161 AtenTensorHandle* input_handles, 162 AtenTensorHandle* output_handles); 163 164 // Replace AOTInductorModel's constant map. Note it doesn't handle concurrency 165 // so be sure to handle ordering if AOTInductorModelRun is ran concurrently. 166 AOTIRuntimeError AOTInductorModelUpdateConstantsMap( 167 AOTInductorModelHandle model_handle, 168 AOTInductorConstantMapHandle constant_map_handle); 169 170 // Delete an AOTInductorModel created by AOTInductorModelCreate. 171 AOTIRuntimeError AOTInductorModelDelete(AOTInductorModelHandle model_handle); 172 173 AOTIRuntimeError AOTInductorModelGetNumOutputs( 174 AOTInductorModelHandle model_handle, 175 size_t* ret_num_outputs); 176 177 AOTIRuntimeError AOTInductorModelContainerGetCallSpec( 178 AOTInductorModelContainerHandle container_handle, 179 const char** in_spec, 180 const char** out_spec); 181 182 } // extern "C" 183