/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_STREAM_EXECUTOR_TPU_C_API_DECL_H_
#define TENSORFLOW_STREAM_EXECUTOR_TPU_C_API_DECL_H_

#include <stddef.h>
#include <stdint.h>

#include "tensorflow/c/tf_attrtype.h"
#include "tensorflow/core/tpu/libtftpu.h"

// Everything up to the matching closing brace at the end of this header is
// declared with C language linkage.
// NOTE(review): the linkage specification is not wrapped in
// `#ifdef __cplusplus`, so this header can only be included from C++
// translation units.
extern "C" {

// Forward declaration only: TF_Status is used here purely through pointers and
// is not defined in this header.
struct TF_Status;
typedef struct TF_Status TF_Status;

// Maximum number of array elements to inline into structs for performance.
#define TPU_C_API_MAX_INLINED 6

// Kinds of TPU core. The enumerators carry no explicit values, so they are
// numbered 0, 1, 2 in declaration order; reordering them changes the numeric
// values.
enum TpuCoreTypeEnum {
  kTensorCore,
  kEmbeddingV1,
  kEmbeddingV2,
};

// TPU hardware generations. Numbered implicitly from 0 in declaration order,
// with kUnknownTpuVersion == 0.
enum TpuVersionEnum {
  kUnknownTpuVersion,
  kTpuV2,
  kTpuV3,
  kTpuV4,
};

// Version triple plus a free-form metadata buffer.
typedef struct TpuRuntimeVersion {
  // The three version numbers are: major, minor, patch
  int version[3];
  // Metadata bytes and their length.
  // NOTE(review): presumably metadata_size is the byte length of `metadata`
  // and the buffer need not be NUL-terminated -- confirm against the producer.
  const char* metadata;
  size_t metadata_size;
} TpuRuntimeVersion;

// Opaque handle types: declared here but never defined in this header, so
// they can only be used through pointers.
typedef struct SE_Platform SE_Platform;
typedef struct SE_StreamExecutor SE_StreamExecutor;
typedef struct SE_Stream SE_Stream;
typedef struct SE_Event SE_Event;
typedef struct SE_Timer SE_Timer;

// A (pointer, length) view of a byte buffer.
// NOTE(review): per the name this presumably holds a serialized protobuf;
// ownership of `bytes` is not specified here -- confirm with the users of
// this struct.
typedef struct TpuSerializedProto {
  const char* bytes;
  size_t size;
} TpuSerializedProto;

typedef struct SE_PlatformId {
  void* id;  // aka stream_executor::Platform::Id
} SE_PlatformId;
// More opaque handle types (declared only, not defined in this header).
typedef struct SE_StreamExecutorConfig SE_StreamExecutorConfig;
typedef struct SE_DeviceOptions SE_DeviceOptions;
// Callback that receives one untyped pointer argument and returns a
// TF_Status*.
// NOTE(review): ownership of the returned status is not specified here.
typedef TF_Status* (*SE_StatusCallbackFn)(void*);

// Plain-data description of a device memory region: an untyped pointer, a
// size and an additional 64-bit payload value.
// NOTE(review): the meaning of `payload` is not documented in this header.
typedef struct SE_DeviceMemoryBase {
  void* opaque;
  uint64_t size;
  uint64_t payload;
} SE_DeviceMemoryBase;

// An SE_DeviceMemoryBase paired with a device ordinal.
typedef struct SE_ScopedDeviceMemory {
  SE_DeviceMemoryBase wrapped;
  int device_ordinal;
} SE_ScopedDeviceMemory;

// Allocator statistics counters.
// NOTE(review): each `has_*` flag presumably tells whether the limit field
// that follows it holds a meaningful value -- confirm with the code that
// fills this struct.
typedef struct SE_AllocatorStats {
  int64_t num_allocs;
  int64_t bytes_in_use;
  int64_t peak_bytes_in_use;
  int64_t largest_alloc_size;

  bool has_bytes_limit;
  int64_t bytes_limit;

  int64_t bytes_reserved;
  int64_t peak_bytes_reserved;

  bool has_bytes_reservable_limit;
  int64_t bytes_reservable_limit;

  int64_t largest_free_block_bytes;
} SE_AllocatorStats;

// Note, due to the... odd way in which DeviceMemoryAllocator is used in TF, we
// cannot simply wrap an underlying pointer. Instead, we reverse the call
// direction and request memory via a callback.
// SE_AllocateFn: callback through which device memory is requested.
// NOTE(review): presumably the allocation is written to `*result` and failure
// is reported through `*status`; `ctx` looks like the user pointer stored in
// SE_DeviceMemoryAllocator::ctx -- confirm against the implementation.
typedef void (*SE_AllocateFn)(void* ctx, int device_ordinal, uint64_t size,
                              bool retry_on_failure, int64_t memory_space,
                              SE_ScopedDeviceMemory* result, TF_Status* status);

// SE_DeallocateFn: counterpart callback taking the memory region to release
// and the device ordinal.
// NOTE(review): presumably reports failure through `*status` -- confirm.
typedef void (*SE_DeallocateFn)(void* ctx, SE_DeviceMemoryBase* base,
                                int device_ordinal, TF_Status* status);

// Callback-based allocator: a platform handle, an untyped context pointer and
// the allocate/deallocate callbacks.
typedef struct SE_DeviceMemoryAllocator {
  SE_Platform* platform;
  void* ctx;
  SE_AllocateFn allocate;
  SE_DeallocateFn deallocate;
} SE_DeviceMemoryAllocator;

// Flat, plain-data description of a device.
// NOTE(review): ownership of the `char*` strings is not specified in this
// header -- confirm who allocates and frees them.
typedef struct SE_DeviceDescription {
  char* device_vendor;
  char* platform_version;
  char* driver_version;
  char* runtime_version;
  char* pci_bus_id;
  char* name;

  int64_t thread_dim_limit_x;
  int64_t thread_dim_limit_y;
  int64_t thread_dim_limit_z;
  int64_t block_dim_limit_x;
  int64_t block_dim_limit_y;
  int64_t block_dim_limit_z;

  int64_t threads_per_core_limit;
  int64_t threads_per_block_limit;
  int64_t threads_per_warp;

  int64_t registers_per_core_limit;
  int64_t registers_per_block_limit;

  int64_t device_address_bits;
  int64_t device_memory_size;
  int64_t memory_bandwidth;

  int64_t shared_memory_per_core;
  int64_t shared_memory_per_block;

  float clock_rate_ghz;

  int cuda_compute_capability_major;
  int cuda_compute_capability_minor;

  int numa_node;
  int core_count;
  bool ecc_enabled;
} SE_DeviceDescription;

// Opaque handle types (declared only, not defined in this header).
typedef struct Tpu_Compiler Tpu_Compiler;
typedef struct SE_Executable SE_Executable;

// Options passed along with an executable run. The allocator is embedded by
// value; the streams are referenced through opaque handles.
typedef struct SE_ExecutableRunOptions {
  SE_DeviceMemoryAllocator allocator;
  int device_ordinal;
  SE_Stream* stream;
  SE_Stream* host_to_device_stream;
  TpuSerializedProto device_assignment;
  int rng_seed;
  int64_t run_id;
  int launch_id;
} SE_ExecutableRunOptions;

// Opaque handle type (declared only, not defined in this header).
typedef struct SE_ExecutableSerializationHandle
    SE_ExecutableSerializationHandle;

// A device memory region together with a flag saying whether it is owned.
typedef struct SE_MaybeOwningDeviceMemory {
  SE_DeviceMemoryBase memory;
  bool owned;

  // Set if owned
  int device_ordinal;
  SE_DeviceMemoryAllocator allocator;
} SE_MaybeOwningDeviceMemory;

// Small-buffer list types. Each one stores `size` plus an anonymous union of
// an owned heap pointer and a fixed array of TPU_C_API_MAX_INLINED elements.
// NOTE(review): presumably `inlined` is the active member while
// size <= TPU_C_API_MAX_INLINED and `heap` otherwise; nothing in the struct
// itself records which member is active -- confirm against the helpers that
// build and free these lists.
struct IntList {
  union {
    int* heap;  // owned
    int inlined[TPU_C_API_MAX_INLINED];
  };
  int64_t size;
};

// Same layout scheme as IntList, for int64_t elements.
struct Int64List {
  union {
    int64_t* heap;  // owned
    int64_t inlined[TPU_C_API_MAX_INLINED];
  };
  int64_t size;
};

// Same layout scheme as IntList, for float elements.
struct FloatList {
  union {
    float* heap;  // owned
    float inlined[TPU_C_API_MAX_INLINED];
  };
  int64_t size;
};

// Same layout scheme as IntList, for bool elements.
struct BoolList {
  union {
    bool* heap;  // owned
    bool inlined[TPU_C_API_MAX_INLINED];
  };
  int64_t size;
};

// Non-owning (pointer, size) view of a float array.
struct FloatListRef {
  float* ptr;  // not owned
  int64_t size;
};

// Eight arrays of FloatListRef pointers plus a table count.
// NOTE(review): presumably each `parameters[i]` holds `num_tables` entries;
// the meaning of the fixed bound 8 is not documented here -- confirm with the
// embedding engine code.
typedef struct TpuEmbeddingEngineParameters {
  FloatListRef** parameters[8];
  size_t num_tables;
} TpuEmbeddingEngineParameters;

// A tile, described only by its list of dimensions.
typedef struct XLA_Tile {
  Int64List dimensions;
} XLA_Tile;

// Small-buffer list of XLA_Tile, using the same heap/inlined union scheme as
// the scalar list types.
struct TileList {
  union {
    XLA_Tile* heap;  // owned
    XLA_Tile inlined[TPU_C_API_MAX_INLINED];
  };
  int64_t size;
};

// Plain-data layout description.
// NOTE(review): field names follow xla::Layout; the exact semantics are
// defined there, not in this header.
typedef struct XLA_Layout {
  Int64List minor_to_major;
  IntList dim_level_types;
  TileList tiles;
  int64_t element_size_in_bits;
  int64_t memory_space;
} XLA_Layout;

// Represents an XLA shape tree.
typedef struct XLA_Shape {
  // NOTE(review): stored as a plain int; presumably an xla::PrimitiveType
  // value -- confirm.
  int element_type;
  Int64List dimensions;
  BoolList dynamic_dimensions;
  // Child shapes; the struct refers to itself, which is what makes it a tree.
  // NOTE(review): presumably `ntuple_shapes` is the element count of
  // `tuple_shapes`.
  XLA_Shape* tuple_shapes;  // owned
  int ntuple_shapes;
  // NOTE(review): presumably `layout` is meaningful only when `has_layout` is
  // true.
  bool has_layout;
  XLA_Layout layout;
} XLA_Shape;

// Represents a leaf node for an XLA shaped buffer.
typedef struct XLA_ShapedBuffer {
  XLA_Shape on_device_shape;
  int device_ordinal;

  // NOTE(review): presumably `count` is the number of entries in `bases`;
  // ownership of the array is not specified here.
  SE_DeviceMemoryBase* bases;
  size_t count;
} XLA_ShapedBuffer;

// Represents a leaf XLA literal.
typedef struct XLA_Literal {
  // NOTE(review): presumably `buffers` and `sizes` are parallel arrays of
  // `count` entries each (buffer pointer and its byte size) -- confirm
  // against the conversion helpers.
  char** buffers;
  size_t* sizes;
  size_t count;
  XLA_Shape shape;
} XLA_Literal;

// A shape plus an array of possibly-owned device buffers.
// NOTE(review): no element count is stored for `buffers`; presumably it is
// derived from `shape` -- confirm.
typedef struct XLA_MaybeOwningDeviceMemoryShapeTree {
  XLA_Shape shape;
  SE_MaybeOwningDeviceMemory* buffers;
} XLA_MaybeOwningDeviceMemoryShapeTree;

// Fixed-capacity index path: at most 8 indices are stored inline.
// NOTE(review): presumably `count` is the number of valid entries in
// `indices`.
typedef struct XLA_ShapeIndex {
  int64_t indices[8];
  int64_t count;
} XLA_ShapeIndex;

// Input to an execution: a buffer tree, an array of shape indices with its
// length, and a dynamic shape.
typedef struct SE_ExecutionInput {
  XLA_MaybeOwningDeviceMemoryShapeTree shape_tree;
  XLA_ShapeIndex* unowned_indices;
  int unowned_indices_size;
  XLA_Shape dynamic_shape;
} SE_ExecutionInput;

// Output of an execution: the result buffer plus two (array, length) pairs.
typedef struct SE_ExecutionOutput {
  XLA_ShapedBuffer result;
  SE_MaybeOwningDeviceMemory* to_be_released;
  int to_be_released_size;
  XLA_ShapeIndex* aliased_indices;
  int aliased_indices_size;
} SE_ExecutionOutput;

// Layouts of a computation's parameters and of its result.
// NOTE(review): presumably `parameter_layouts` has `parameter_count` entries.
typedef struct XLA_ComputationLayout {
  int parameter_count;
  XLA_Shape* parameter_layouts;
  XLA_Shape result_layout;
} XLA_ComputationLayout;

// Plain-data module configuration.
// NOTE(review): each `has_*` flag presumably tells whether the field that
// follows it holds a meaningful value -- confirm with the code that fills
// this struct.
typedef struct XLA_HloModuleConfig {
  uint64_t seed;
  int32_t launch_id;
  int64_t replica_count;
  int64_t num_partitions;
  bool use_spmd_partitioning;
  bool use_auto_spmd_partitioning;
  Int64List auto_spmd_partitioning_mesh_shape;
  Int64List auto_spmd_partitioning_mesh_ids;
  TpuSerializedProto debug_options;
  bool has_static_device_assignment;
  TpuSerializedProto static_device_assignment;
  bool has_entry_computation_layout;
  XLA_ComputationLayout entry_computation_layout;
} XLA_HloModuleConfig;

// Opaque handle type (declared only, not defined in this header).
typedef struct SE_HloExecutionProfile SE_HloExecutionProfile;

// Array of stream-executor handles with its length.
struct SE_StreamExecutorList {
  SE_StreamExecutor** exec;
  int count;
};

// A serialized proto plus a pointer to module configuration(s).
// NOTE(review): `module_config` is a pointer here but a by-value member in
// XLA_HloModule; presumably it points to an array with one entry per module
// in the group -- confirm.
typedef struct XLA_HloModuleGroup {
  TpuSerializedProto proto;
  XLA_HloModuleConfig* module_config;
} XLA_HloModuleGroup;

// A serialized proto plus its configuration, embedded by value.
typedef struct XLA_HloModule {
  TpuSerializedProto proto;
  XLA_HloModuleConfig module_config;
} XLA_HloModule;

// Opaque handle types (declared only, not defined in this header).
typedef struct XLA_TransferManager XLA_TransferManager;

typedef struct XLA_ComputationPlacer XLA_ComputationPlacer;

// Callback taking one untyped pointer argument.
typedef void (*XLA_CallbackFn)(void*);
// Callback taking an untyped pointer argument and a TF_Status*.
typedef void (*XLA_StatusCallbackFn)(void*, TF_Status*);

// Opaque topology handle types (declared only, not defined in this header).
typedef struct SE_TpuTopology SE_TpuTopology;
typedef struct SE_TpuTopology_Core SE_TpuTopology_Core;
// NOTE(review): SE_TpuTopology_Host is an alias of struct SE_TpuTopology_Core
// rather than of a distinct struct, so SE_TpuTopology_Host* and
// SE_TpuTopology_Core* are the same type to the compiler. Confirm this is
// intentional before changing it; existing callers may rely on the two being
// interchangeable.
typedef struct SE_TpuTopology_Core SE_TpuTopology_Host;
}  // extern "C"

#endif  // TENSORFLOW_STREAM_EXECUTOR_TPU_C_API_DECL_H_