xref: /aosp_15_r20/external/pytorch/torch/csrc/inductor/aoti_runtime/interface.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 // WARNING: Be careful when adding new includes here. This header will be used
4 // in model.so, and should not refer to any aten/c10 headers except the stable
5 // C ABI defined in torch/csrc/inductor/aoti_torch/c/shim.h. The same rule
6 // applies to other files under torch/csrc/inductor/aoti_runtime/.
7 #include <torch/csrc/inductor/aoti_runtime/utils.h>
8 
9 extern "C" {
// Opaque model type: it is only forward-declared here, so callers can hold a
// pointer to it but cannot look inside. AOTInductorModelHandle is that pointer
// type; it is the handle taken by AOTInductorModelCreate, AOTInductorModelRun,
// AOTInductorModelUpdateConstantsMap, AOTInductorModelDelete and
// AOTInductorModelGetNumOutputs.
10 struct AOTInductorModelOpaque;
11 using AOTInductorModelHandle = AOTInductorModelOpaque*;
12 
// Opaque model-container type (forward declaration only) and its pointer
// alias. AOTInductorModelContainerHandle is the handle taken by the
// AOTInductorModelContainer* functions in this header.
13 struct AOTInductorModelContainerOpaque;
14 using AOTInductorModelContainerHandle = AOTInductorModelContainerOpaque*;
15 
// Opaque stream type (forward declaration only) and its pointer alias. It is
// passed as stream_handle to AOTInductorModelContainerRun and
// AOTInductorModelContainerRunConstantFolding.
// NOTE(review): presumably stands in for a device stream such as a CUDA
// stream -- confirm against the implementation.
16 struct AOTInductorStreamOpaque;
17 using AOTInductorStreamHandle = AOTInductorStreamOpaque*;
18 
// Opaque constant-map type (forward declaration only) and its pointer alias.
// The comment on AOTInductorModelCreate says the handle it receives should be
// a std::unordered_map<std::string, at::Tensor*>*. This header does not state
// whether the container-level functions expect the same pointee type --
// TODO confirm before casting.
19 struct AOTInductorConstantMap;
20 using AOTInductorConstantMapHandle = AOTInductorConstantMap*;
21 
22 // TODO: Deprecate this API. It is kept only for backward compatibility.
23 // Please use AOTInductorModelContainerCreateWithDevice instead.
//
// Legacy creation entry point: it takes a boolean is_cpu flag where
// AOTInductorModelContainerCreateWithDevice takes a device string.
//   container_handle: pointer to a handle; presumably receives the newly
//                     created container -- TODO confirm.
//   num_models:       see AOTInductorModelContainerCreateWithDevice.
//   is_cpu:           presumably selects the CPU device when true -- the
//                     device used when false is not stated here.
//   cubin_dir:        NOTE(review): looks like a directory of CUDA cubin
//                     files; whether nullptr is accepted is not stated here.
// Returns an AOTIRuntimeError (declared outside this header).
24 AOTIRuntimeError AOTInductorModelContainerCreate(
25     AOTInductorModelContainerHandle* container_handle,
26     size_t num_models,
27     bool is_cpu,
28     const char* cubin_dir);
29 
30 // Creates an AOTInductor model container. The parameter num_models
31 // specifies the number of model instances that may be run concurrently for
32 // the same input model.
33 // `device_str` MUST NOT be nullptr. It must be a valid device string, e.g.
34 // "cpu", "cuda", "cuda:0", etc. If no device index is specified for a CUDA
35 // device, the runtime will use the device index returned by
36 // "cudaGetDevice(&device_idx)".
//   container_handle: pointer to a handle; presumably receives the newly
//                     created container -- TODO confirm.
//   cubin_dir:        NOTE(review): looks like a directory of CUDA cubin
//                     files; whether nullptr is accepted is not stated here.
// Returns an AOTIRuntimeError (declared outside this header).
37 AOTIRuntimeError AOTInductorModelContainerCreateWithDevice(
38     AOTInductorModelContainerHandle* container_handle,
39     size_t num_models,
40     const char* device_str,
41     const char* cubin_dir);
42 
43 // Deletes the AOTInductor model container identified by container_handle.
// NOTE(review): the handle should presumably not be used after this call;
// behavior for a null handle is not stated in this header.
44 AOTIRuntimeError AOTInductorModelContainerDelete(
45     AOTInductorModelContainerHandle container_handle);
46 
47 // Runs the inference on the given container.
// Ownership, as spelled out in the parameter comments below:
//   - each input handle is "stolen", i.e. ownership passes to the callee;
//   - each output handle written by this call is owned by the caller;
//   - both arrays themselves are only borrowed for the duration of the call.
// num_inputs / num_outputs are presumably the lengths of input_handles /
// output_handles -- TODO confirm how a mismatch with the model is reported.
// stream_handle and proxy_executor_handle are opaque to this header
// (AOTIProxyExecutorHandle is declared elsewhere); whether either may be
// nullptr is not stated here.
48 AOTIRuntimeError AOTInductorModelContainerRun(
49     AOTInductorModelContainerHandle container_handle,
50     AtenTensorHandle* input_handles, // array of input AtenTensorHandle; handles
51                                      // are stolen; the array itself is borrowed
52     size_t num_inputs,
53     AtenTensorHandle*
54         output_handles, // array for writing output AtenTensorHandle; handles
55                         // will be stolen by the caller; the array itself is
56                         // borrowed
57     size_t num_outputs,
58     AOTInductorStreamHandle stream_handle,
59     AOTIProxyExecutorHandle proxy_executor_handle);
60 
61 // Retrieves the number of constants for the model.
// The count is returned through the num_constants out-parameter. It is the
// upper bound for the idx argument of the AOTInductorModelContainerGetConstant*
// functions.
62 AOTIRuntimeError AOTInductorModelContainerGetNumConstants(
63     AOTInductorModelContainerHandle container_handle,
64     size_t* num_constants);
65 
66 // Retrieves a constant's name through the name out-parameter.
67 // idx is the index into the internal constants.
68 // Requires idx < num_constants from AOTInductorModelContainerGetNumConstants.
// NOTE(review): ownership and lifetime of the returned string are not
// documented in this header -- confirm before freeing or caching it.
69 AOTIRuntimeError AOTInductorModelContainerGetConstantName(
70     AOTInductorModelContainerHandle container_handle,
71     size_t idx,
72     const char** name);
73 
74 // Retrieves a constant's original FQN through the original_fqn out-parameter.
75 // idx is the index into the internal constants.
76 // Requires idx < num_constants from AOTInductorModelContainerGetNumConstants.
// NOTE(review): FQN presumably means "fully qualified name" (the constant's
// name in the original module) -- confirm. Ownership and lifetime of the
// returned string are not documented in this header.
77 AOTIRuntimeError AOTInductorModelContainerGetConstantOriginalFQN(
78     AOTInductorModelContainerHandle container_handle,
79     size_t idx,
80     const char** original_fqn);
81 
82 // Retrieves whether a constant is flagged as "from folded".
83 // idx is the index into the internal constants.
84 // Requires idx < num_constants from AOTInductorModelContainerGetNumConstants.
// The flag is returned through the from_folded out-parameter.
// NOTE(review): presumably true when the constant was produced by constant
// folding (see AOTInductorModelContainerRunConstantFolding) -- confirm.
85 AOTIRuntimeError AOTInductorModelContainerGetConstantFromFolded(
86     AOTInductorModelContainerHandle container_handle,
87     size_t idx,
88     bool* from_folded);
89 
90 // Retrieves a constant's dtype through the dtype out-parameter.
91 // idx is the index into the internal constants.
92 // Requires idx < num_constants from AOTInductorModelContainerGetNumConstants.
// NOTE(review): the meaning of the int32_t value is not defined in this
// header; presumably it uses the dtype codes of the aoti_torch C shim --
// TODO confirm.
93 AOTIRuntimeError AOTInductorModelContainerGetConstantDtype(
94     AOTInductorModelContainerHandle container_handle,
95     size_t idx,
96     int32_t* dtype);
97 
98 // Sets up the constant buffer in the model container from the provided
99 // ConstantMap. use_inactive should be true if the inactive buffer is the one
100 // to update. validate_full_update checks that all constants are included in
// the ConstantMap.
// NOTE(review): how a failed validate_full_update check is reported is not
// stated here; presumably through the returned AOTIRuntimeError -- confirm.
101 AOTIRuntimeError AOTInductorModelContainerUpdateConstantBuffer(
102     AOTInductorModelContainerHandle container_handle,
103     AOTInductorConstantMapHandle constant_map_handle,
104     bool use_inactive,
105     bool validate_full_update);
106 
107 // Sets up the inactive constant buffer in the model container from the
108 // provided ConstantMap.
// NOTE(review): looks like AOTInductorModelContainerUpdateConstantBuffer with
// use_inactive fixed to true; whether a full-update validation is performed
// is not stated in this header -- confirm.
109 AOTIRuntimeError AOTInductorModelContainerUpdateInactiveConstantBuffer(
110     AOTInductorModelContainerHandle container_handle,
111     AOTInductorConstantMapHandle constant_map_handle);
112 
113 // Runs constant folding on a constant buffer of the container.
// use_inactive presumably selects the inactive buffer, mirroring the flag of
// the same name on AOTInductorModelContainerUpdateConstantBuffer -- TODO
// confirm. stream_handle and proxy_executor_handle are opaque to this header.
114 AOTIRuntimeError AOTInductorModelContainerRunConstantFolding(
115     AOTInductorModelContainerHandle container_handle,
116     bool use_inactive,
117     AOTInductorStreamHandle stream_handle,
118     AOTIProxyExecutorHandle proxy_executor_handle);
119 
120 // Switches the constant buffer in use over to the inactive one.
// NOTE(review): interaction with an in-flight AOTInductorModelContainerRun
// is not documented in this header -- confirm before swapping concurrently.
121 AOTIRuntimeError AOTInductorModelContainerSwapConstantBuffer(
122     AOTInductorModelContainerHandle container_handle);
123 
124 // Retrieves the number of inputs for the model.
// The count is returned through the ret_num_inputs out-parameter.
125 AOTIRuntimeError AOTInductorModelContainerGetNumInputs(
126     AOTInductorModelContainerHandle container_handle,
127     size_t* ret_num_inputs);
128 
129 // Retrieves the input name at the given index.
// The name is returned through ret_input_names (despite the plural, the
// parameter is a single const char* out-slot).
// NOTE(review): presumably requires input_idx < the count reported by
// AOTInductorModelContainerGetNumInputs; ownership and lifetime of the
// returned string are not documented in this header.
130 AOTIRuntimeError AOTInductorModelContainerGetInputName(
131     AOTInductorModelContainerHandle container_handle,
132     size_t input_idx,
133     const char** ret_input_names);
134 
135 // Retrieves the number of outputs for the model.
// The count is returned through the ret_num_outputs out-parameter.
136 AOTIRuntimeError AOTInductorModelContainerGetNumOutputs(
137     AOTInductorModelContainerHandle container_handle,
138     size_t* ret_num_outputs);
139 
140 // Retrieves the output name at the given index.
// The name is returned through ret_output_names (despite the plural, the
// parameter is a single const char* out-slot).
// NOTE(review): presumably requires output_idx < the count reported by
// AOTInductorModelContainerGetNumOutputs; ownership and lifetime of the
// returned string are not documented in this header.
141 AOTIRuntimeError AOTInductorModelContainerGetOutputName(
142     AOTInductorModelContainerHandle container_handle,
143     size_t output_idx,
144     const char** ret_output_names);
145 
146 // Creates an AOTInductorModel instance. This is a thin and light wrapper
147 // around the compiled model; it doesn't handle concurrency, queueing, device
148 // management, etc. Use this if bare-metal performance is needed and you are
149 // willing to handle other "management" aspects yourself.
150 //
151 // constant_map_handle is an opaque type to satisfy the C ABI. It should be a
152 // std::unordered_map<std::string, at::Tensor*>*.
//
// model_handle is a pointer to a handle; presumably it receives the newly
// created model -- TODO confirm. Per AOTInductorModelDelete's comment, models
// created here are deleted with AOTInductorModelDelete.
153 AOTIRuntimeError AOTInductorModelCreate(
154     AOTInductorModelHandle* model_handle,
155     AOTInductorConstantMapHandle constant_map_handle);
156 
157 // Runs an AOTInductorModel (see AOTInductorModelCreate for when one should
158 // use this function versus AOTInductorModelContainerRun).
// Unlike AOTInductorModelContainerRun, no array lengths, stream or proxy
// executor are passed. The required sizes of input_handles and output_handles
// are not stated in this header.
// NOTE(review): handle ownership is not documented here; it presumably
// follows the convention described on AOTInductorModelContainerRun -- confirm.
159 AOTIRuntimeError AOTInductorModelRun(
160     AOTInductorModelHandle model_handle,
161     AtenTensorHandle* input_handles,
162     AtenTensorHandle* output_handles);
163 
164 // Replaces an AOTInductorModel's constant map. Note it doesn't handle
165 // concurrency, so be sure to handle ordering if AOTInductorModelRun is run
// concurrently.
// constant_map_handle is the same opaque handle type that
// AOTInductorModelCreate takes; see the comment there for the expected
// pointee type.
166 AOTIRuntimeError AOTInductorModelUpdateConstantsMap(
167     AOTInductorModelHandle model_handle,
168     AOTInductorConstantMapHandle constant_map_handle);
169 
170 // Deletes an AOTInductorModel created by AOTInductorModelCreate.
// NOTE(review): the handle should presumably not be used after this call;
// behavior for a null handle is not stated in this header.
171 AOTIRuntimeError AOTInductorModelDelete(AOTInductorModelHandle model_handle);
172 
// Model-handle counterpart of AOTInductorModelContainerGetNumOutputs: takes
// an AOTInductorModelHandle instead of a container handle.
// NOTE(review): presumably writes the model's number of outputs to
// ret_num_outputs -- the header carries no comment for this function.
173 AOTIRuntimeError AOTInductorModelGetNumOutputs(
174     AOTInductorModelHandle model_handle,
175     size_t* ret_num_outputs);
176 
// Returns two strings for the container through the in_spec and out_spec
// out-parameters.
// NOTE(review): the header carries no comment for this function. The strings
// presumably describe the structure of the model's inputs and outputs (the
// "call spec"); their format, ownership and lifetime are not defined here --
// confirm against the implementation.
177 AOTIRuntimeError AOTInductorModelContainerGetCallSpec(
178     AOTInductorModelContainerHandle container_handle,
179     const char** in_spec,
180     const char** out_spec);
181 
182 } // extern "C"
183