xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/stream_executor/tpu/c_api_decl.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
16 #ifndef TENSORFLOW_STREAM_EXECUTOR_TPU_C_API_DECL_H_
17 #define TENSORFLOW_STREAM_EXECUTOR_TPU_C_API_DECL_H_
18 
19 #include <stddef.h>
20 #include <stdint.h>
21 
22 #include "tensorflow/c/tf_attrtype.h"
23 #include "tensorflow/core/tpu/libtftpu.h"
24 
25 extern "C" {
26 
27 struct TF_Status;
28 typedef struct TF_Status TF_Status;
29 
30 // Maximum number of array elements to inline into structs for performance.
31 #define TPU_C_API_MAX_INLINED 6
32 
// Kinds of TPU cores addressable through this C API.
enum TpuCoreTypeEnum {
  kTensorCore,
  kEmbeddingV1,
  kEmbeddingV2,
};
38 
// TPU hardware generations recognized by the runtime.
enum TpuVersionEnum {
  kUnknownTpuVersion,
  kTpuV2,
  kTpuV3,
  kTpuV4,
};
45 
// Version of the TPU runtime: a numeric triple plus free-form metadata bytes.
typedef struct TpuRuntimeVersion {
  // The three version numbers are: major, minor, patch
  int version[3];
  // Metadata byte span; `metadata` need not be NUL-terminated, use
  // `metadata_size`. Ownership is not stated here — presumably not owned by
  // this struct; confirm with the producer.
  const char* metadata;
  size_t metadata_size;
} TpuRuntimeVersion;
52 
// Opaque handles to StreamExecutor objects; defined on the C++ side and only
// ever passed through this API by pointer.
typedef struct SE_Platform SE_Platform;
typedef struct SE_StreamExecutor SE_StreamExecutor;
typedef struct SE_Stream SE_Stream;
typedef struct SE_Event SE_Event;
typedef struct SE_Timer SE_Timer;
58 
// A serialized protocol buffer carried across the C API boundary as a raw
// byte span (not NUL-terminated; always use `size`).
typedef struct TpuSerializedProto {
  const char* bytes;
  size_t size;
} TpuSerializedProto;
63 
// Type-erased platform identifier.
typedef struct SE_PlatformId {
  void* id;  // aka stream_executor::Platform::Id
} SE_PlatformId;
67 typedef struct SE_StreamExecutorConfig SE_StreamExecutorConfig;
68 typedef struct SE_DeviceOptions SE_DeviceOptions;
69 typedef TF_Status* (*SE_StatusCallbackFn)(void*);
70 
// A region of device memory: an opaque device address plus its size in bytes.
typedef struct SE_DeviceMemoryBase {
  void* opaque;
  uint64_t size;
  // Extra per-allocation word; semantics are not visible in this header —
  // presumably mirrors stream_executor::DeviceMemoryBase::payload, confirm.
  uint64_t payload;
} SE_DeviceMemoryBase;
76 
77 typedef struct SE_ScopedDeviceMemory {
78   SE_DeviceMemoryBase wrapped;
79   int device_ordinal;
80 } SE_ScopedDeviceMemory;
81 
// Allocator usage statistics (C mirror of stream_executor::AllocatorStats).
typedef struct SE_AllocatorStats {
  int64_t num_allocs;
  int64_t bytes_in_use;
  int64_t peak_bytes_in_use;
  int64_t largest_alloc_size;

  // `bytes_limit` is meaningful only when `has_bytes_limit` is true.
  bool has_bytes_limit;
  int64_t bytes_limit;

  int64_t bytes_reserved;
  int64_t peak_bytes_reserved;

  // `bytes_reservable_limit` is meaningful only when
  // `has_bytes_reservable_limit` is true.
  bool has_bytes_reservable_limit;
  int64_t bytes_reservable_limit;

  int64_t largest_free_block_bytes;
} SE_AllocatorStats;
99 
100 // Note, due to the... odd way in which DeviceMemoryAllocator is used in TF, we
101 // cannot simply wrap an underlying pointer. Instead, we reverse the call
102 // direction and request memory via a callback.
103 typedef void (*SE_AllocateFn)(void* ctx, int device_ordinal, uint64_t size,
104                               bool retry_on_failure, int64_t memory_space,
105                               SE_ScopedDeviceMemory* result, TF_Status* status);
106 
107 typedef void (*SE_DeallocateFn)(void* ctx, SE_DeviceMemoryBase* base,
108                                 int device_ordinal, TF_Status* status);
109 
110 typedef struct SE_DeviceMemoryAllocator {
111   SE_Platform* platform;
112   void* ctx;
113   SE_AllocateFn allocate;
114   SE_DeallocateFn deallocate;
115 } SE_DeviceMemoryAllocator;
116 
// Static properties of a device (C mirror of
// stream_executor::DeviceDescription). String fields are mutable char* —
// presumably owned by this struct and filled in by the TPU library; confirm
// ownership with the accessor/free functions.
typedef struct SE_DeviceDescription {
  char* device_vendor;
  char* platform_version;
  char* driver_version;
  char* runtime_version;
  char* pci_bus_id;
  char* name;

  // Per-dimension thread/block launch limits.
  int64_t thread_dim_limit_x;
  int64_t thread_dim_limit_y;
  int64_t thread_dim_limit_z;
  int64_t block_dim_limit_x;
  int64_t block_dim_limit_y;
  int64_t block_dim_limit_z;

  int64_t threads_per_core_limit;
  int64_t threads_per_block_limit;
  int64_t threads_per_warp;

  int64_t registers_per_core_limit;
  int64_t registers_per_block_limit;

  int64_t device_address_bits;
  int64_t device_memory_size;
  int64_t memory_bandwidth;

  int64_t shared_memory_per_core;
  int64_t shared_memory_per_block;

  float clock_rate_ghz;

  // GPU-oriented fields kept for parity with DeviceDescription; presumably
  // unused for TPU devices — confirm.
  int cuda_compute_capability_major;
  int cuda_compute_capability_minor;

  int numa_node;
  int core_count;
  bool ecc_enabled;
} SE_DeviceDescription;
155 
156 typedef struct Tpu_Compiler Tpu_Compiler;
157 typedef struct SE_Executable SE_Executable;
158 
159 typedef struct SE_ExecutableRunOptions {
160   SE_DeviceMemoryAllocator allocator;
161   int device_ordinal;
162   SE_Stream* stream;
163   SE_Stream* host_to_device_stream;
164   TpuSerializedProto device_assignment;
165   int rng_seed;
166   int64_t run_id;
167   int launch_id;
168 } SE_ExecutableRunOptions;
169 
170 typedef struct SE_ExecutableSerializationHandle
171     SE_ExecutableSerializationHandle;
172 
173 typedef struct SE_MaybeOwningDeviceMemory {
174   SE_DeviceMemoryBase memory;
175   bool owned;
176 
177   // Set if owned
178   int device_ordinal;
179   SE_DeviceMemoryAllocator allocator;
180 } SE_MaybeOwningDeviceMemory;
181 
182 struct IntList {
183   union {
184     int* heap;  // owned
185     int inlined[TPU_C_API_MAX_INLINED];
186   };
187   int64_t size;
188 };
189 
190 struct Int64List {
191   union {
192     int64_t* heap;  // owned
193     int64_t inlined[TPU_C_API_MAX_INLINED];
194   };
195   int64_t size;
196 };
197 
198 struct FloatList {
199   union {
200     float* heap;  // owned
201     float inlined[TPU_C_API_MAX_INLINED];
202   };
203   int64_t size;
204 };
205 
206 struct BoolList {
207   union {
208     bool* heap;  // owned
209     bool inlined[TPU_C_API_MAX_INLINED];
210   };
211   int64_t size;
212 };
213 
214 struct FloatListRef {
215   float* ptr;  // not owned
216   int64_t size;
217 };
218 
219 typedef struct TpuEmbeddingEngineParameters {
220   FloatListRef** parameters[8];
221   size_t num_tables;
222 } TpuEmbeddingEngineParameters;
223 
224 typedef struct XLA_Tile {
225   Int64List dimensions;
226 } XLA_Tile;
227 
228 struct TileList {
229   union {
230     XLA_Tile* heap;  // owned
231     XLA_Tile inlined[TPU_C_API_MAX_INLINED];
232   };
233   int64_t size;
234 };
235 
236 typedef struct XLA_Layout {
237   Int64List minor_to_major;
238   IntList dim_level_types;
239   TileList tiles;
240   int64_t element_size_in_bits;
241   int64_t memory_space;
242 } XLA_Layout;
243 
244 // Represents an XLA shape tree.
245 typedef struct XLA_Shape {
246   int element_type;
247   Int64List dimensions;
248   BoolList dynamic_dimensions;
249   XLA_Shape* tuple_shapes;  // owned
250   int ntuple_shapes;
251   bool has_layout;
252   XLA_Layout layout;
253 } XLA_Shape;
254 
255 // Represents a leaf node for a XLA shaped buffer.
256 typedef struct XLA_ShapedBuffer {
257   XLA_Shape on_device_shape;
258   int device_ordinal;
259 
260   SE_DeviceMemoryBase* bases;
261   size_t count;
262 } XLA_ShapedBuffer;
263 
264 // Represents a leaf XLA literal.
265 typedef struct XLA_Literal {
266   char** buffers;
267   size_t* sizes;
268   size_t count;
269   XLA_Shape shape;
270 } XLA_Literal;
271 
272 typedef struct XLA_MaybeOwningDeviceMemoryShapeTree {
273   XLA_Shape shape;
274   SE_MaybeOwningDeviceMemory* buffers;
275 } XLA_MaybeOwningDeviceMemoryShapeTree;
276 
// Index into a shape tree: the first `count` entries of `indices` are valid.
// Fixed capacity of 8 bounds the supported tuple nesting depth.
typedef struct XLA_ShapeIndex {
  int64_t indices[8];
  int64_t count;
} XLA_ShapeIndex;
281 
282 typedef struct SE_ExecutionInput {
283   XLA_MaybeOwningDeviceMemoryShapeTree shape_tree;
284   XLA_ShapeIndex* unowned_indices;
285   int unowned_indices_size;
286   XLA_Shape dynamic_shape;
287 } SE_ExecutionInput;
288 
289 typedef struct SE_ExecutionOutput {
290   XLA_ShapedBuffer result;
291   SE_MaybeOwningDeviceMemory* to_be_released;
292   int to_be_released_size;
293   XLA_ShapeIndex* aliased_indices;
294   int aliased_indices_size;
295 } SE_ExecutionOutput;
296 
297 typedef struct XLA_ComputationLayout {
298   int parameter_count;
299   XLA_Shape* parameter_layouts;
300   XLA_Shape result_layout;
301 } XLA_ComputationLayout;
302 
303 typedef struct XLA_HloModuleConfig {
304   uint64_t seed;
305   int32_t launch_id;
306   int64_t replica_count;
307   int64_t num_partitions;
308   bool use_spmd_partitioning;
309   bool use_auto_spmd_partitioning;
310   Int64List auto_spmd_partitioning_mesh_shape;
311   Int64List auto_spmd_partitioning_mesh_ids;
312   TpuSerializedProto debug_options;
313   bool has_static_device_assignment;
314   TpuSerializedProto static_device_assignment;
315   bool has_entry_computation_layout;
316   XLA_ComputationLayout entry_computation_layout;
317 } XLA_HloModuleConfig;
318 
319 typedef struct SE_HloExecutionProfile SE_HloExecutionProfile;
320 
321 struct SE_StreamExecutorList {
322   SE_StreamExecutor** exec;
323   int count;
324 };
325 
326 typedef struct XLA_HloModuleGroup {
327   TpuSerializedProto proto;
328   XLA_HloModuleConfig* module_config;
329 } XLA_HloModuleGroup;
330 
331 typedef struct XLA_HloModule {
332   TpuSerializedProto proto;
333   XLA_HloModuleConfig module_config;
334 } XLA_HloModule;
335 
336 typedef struct XLA_TransferManager XLA_TransferManager;
337 
338 typedef struct XLA_ComputationPlacer XLA_ComputationPlacer;
339 
340 typedef void (*XLA_CallbackFn)(void*);
341 typedef void (*XLA_StatusCallbackFn)(void*, TF_Status*);
342 
343 typedef struct SE_TpuTopology SE_TpuTopology;
344 typedef struct SE_TpuTopology_Core SE_TpuTopology_Core;
345 typedef struct SE_TpuTopology_Core SE_TpuTopology_Host;
346 }
347 
348 #endif  // TENSORFLOW_STREAM_EXECUTOR_TPU_C_API_DECL_H_
349