1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 #ifndef TENSORFLOW_C_EXPERIMENTAL_STREAM_EXECUTOR_STREAM_EXECUTOR_H_ 16 #define TENSORFLOW_C_EXPERIMENTAL_STREAM_EXECUTOR_STREAM_EXECUTOR_H_ 17 #include <stddef.h> 18 #include <stdint.h> 19 20 #include "tensorflow/c/c_api_macros.h" 21 #include "tensorflow/c/tf_status.h" 22 23 // -------------------------------------------------------------------------- 24 // C API for StreamExecutor. The API is under active development and eventually 25 // should allow registering a pluggable device with TensorFlow. 26 // 27 // Conventions: 28 // * Struct prefix indicates whether struct fields should be filled by the 29 // plugin or core implementation: 30 // * SE_ : set/filled by core unless explicitly marked otherwise. 31 // * SP_ : set/filled by plugin unless explicitly marked otherwise. 32 // * We use `struct_size` for version checking. It is exempt from the `SE/SP` 33 // rule above and should be set both by core and the plugin. 34 // * For example, `create_device` function receives `SP_Device*` as input 35 // with `struct_size` populated by core. The plugin is responsible for 36 // setting `struct_size` as well, along with all other fields. 37 // * Refer to "TensorFlow Versioning Strategy" section at 38 // https://github.com/tensorflow/community/pull/257/files. 39 // * Note that the API is still under active development and doesn't have 40 // versioning guarantees yet. 41 // * `void* ext` is a free-form field that can be populated by 42 // a plugin in `SP_*` structs or potential future extension points in `SE_` 43 // structs. 44 // 45 // Example usage: 46 // 47 // /* Sample TensorFlow code below, exact implementation might differ. */ 48 // // Version checking uses `struct_size`. It is exempt from the `SE/SP` rule 49 // // above and should be set both by core and the plugin." 50 // SP_Device device { SP_DEVICE_STRUCT_SIZE }; 51 // SE_CreateDeviceParams params { SE_CREATE_DEVICE_PARAMS_STRUCT_SIZE } ; 52 // params.device = &device; 53 // 54 // /* Plugin code below */ 55 // constexpr char DEVICE_NAME[] = "MY_DEVICE"; 56 // constexpr char DEVICE_TYPE[] = "GPU"; 57 // 58 // void create_device(const SP_Platform* platform, 59 // SE_CreateDeviceParams* params, TF_Status* status) { 60 // // Custom actions based on TensorFlow's view of SP_Device. 61 // OnTFDeviceView(params->device->struct_size); 62 // params->device = { SP_DEVICE_STRUCT_SIZE }; 63 // params->device->device_handle = get_my_device_handle(device->ordinal); 64 // params->device->ordinal = params->ordinal; 65 // ... 66 // } 67 // 68 // void destroy_device(const SP_Platform* platform, SP_Device* device) { 69 // delete_my_device_handle(device->device_handle); 70 // } 71 // 72 // void SE_InitPlugin( 73 // SE_PlatformRegistrationParams* params, 74 // TF_Status* status) { 75 // params->platform = { SP_PLATFORM_STRUCT_SIZE }; 76 // // Values such as `name` and `type` must outlive SE_InitPlugin call. 77 // params->platform->name = DEVICE_NAME; 78 // params->platform->type = DEVICE_TYPE; 79 // params->platform_fns->get_device_count = get_device_count; 80 // params->platform_fns->create_device = create_device; 81 // params->platform_fns->destroy_device = destroy_device; 82 // ... 83 // } 84 85 #define SE_MAJOR 0 86 #define SE_MINOR 0 87 #define SE_PATCH 1 88 89 #ifdef __cplusplus 90 extern "C" { 91 #endif 92 93 typedef struct SP_Stream_st* SP_Stream; 94 typedef struct SP_Event_st* SP_Event; 95 typedef struct SP_Timer_st* SP_Timer; 96 // Takes `callback_arg` passed to `host_callback` as the first argument. 97 typedef void (*SE_StatusCallbackFn)(void* const, TF_Status* const); 98 99 typedef struct SP_TimerFns { 100 size_t struct_size; 101 void* ext; // reserved for future use 102 uint64_t (*nanoseconds)(SP_Timer timer); 103 } SP_TimerFns; 104 105 #define SP_TIMER_FNS_STRUCT_SIZE TF_OFFSET_OF_END(SP_TimerFns, nanoseconds) 106 107 typedef struct SP_AllocatorStats { 108 size_t struct_size; 109 int64_t num_allocs; 110 int64_t bytes_in_use; 111 int64_t peak_bytes_in_use; 112 int64_t largest_alloc_size; 113 114 int8_t has_bytes_limit; 115 int64_t bytes_limit; 116 117 int64_t bytes_reserved; 118 int64_t peak_bytes_reserved; 119 120 int8_t has_bytes_reservable_limit; 121 int64_t bytes_reservable_limit; 122 123 int64_t largest_free_block_bytes; 124 } SP_AllocatorStats; 125 126 #define SP_ALLOCATORSTATS_STRUCT_SIZE \ 127 TF_OFFSET_OF_END(SP_AllocatorStats, largest_free_block_bytes) 128 129 // Potential states for an SP_Event. If `poll_for_status` returns anything aside 130 // from kPending or kComplete, an error has occurred; kUnknown is a bad state. 131 typedef enum SE_EventStatus { 132 SE_EVENT_UNKNOWN, 133 SE_EVENT_ERROR, 134 SE_EVENT_PENDING, 135 SE_EVENT_COMPLETE, 136 } SE_EventStatus; 137 138 // Memory allocation information. 139 // This matches DeviceMemoryBase defined here: 140 // https://cs.opensource.google/tensorflow/tensorflow/+/refs/tags/v2.3.0:tensorflow/stream_executor/device_memory.h;l=57 141 typedef struct SP_DeviceMemoryBase { 142 size_t struct_size; 143 void* ext; // Reserved for future use 144 // Platform-dependent value representing allocated memory. 145 // Note that the pointer does not have to be to the virtual address itself. 146 void* opaque; 147 uint64_t size; // Size in bytes of this allocation. 148 uint64_t payload; // Value for plugin's use 149 } SP_DeviceMemoryBase; 150 151 #define SP_DEVICE_MEMORY_BASE_STRUCT_SIZE \ 152 TF_OFFSET_OF_END(SP_DeviceMemoryBase, payload) 153 154 typedef struct SP_Device { 155 size_t struct_size; 156 void* ext; // free-form data set by plugin 157 int32_t ordinal; // device index 158 159 // Device vendor can store handle to their device representation 160 // here. 161 void* device_handle; 162 163 // [Optional] 164 // Device hardware name. Used for printing. 165 // Must be null-terminated. 166 const char* hardware_name; 167 168 // [Optional] 169 // Device vendor name. Used for printing. 170 // Must be null-terminated. 171 const char* device_vendor; 172 173 // [Optional] 174 // Returns the PCI bus identifier for this device, of the form 175 // [domain]:[bus]:[device].[function] 176 // where domain number is usually 0000. 177 // Example: 0000:00:02.1 178 // For more information see: 179 // https://en.wikipedia.org/wiki/PCI_configuration_space 180 // https://www.oreilly.com/library/view/linux-device-drivers/0596005903/ch12.html 181 // Used for printing. Must be null-terminated. 182 const char* pci_bus_id; 183 } SP_Device; 184 185 #define SP_DEVICE_STRUCT_SIZE TF_OFFSET_OF_END(SP_Device, pci_bus_id) 186 187 typedef struct SE_CreateDeviceParams { 188 size_t struct_size; 189 void* ext; // reserved for future use 190 int32_t ordinal; // device index 191 192 SP_Device* device; // Input/output, struct_size set by TF for plugin to read. 193 // Subsequently plugin fills the entire struct. 194 } SE_CreateDeviceParams; 195 196 #define SE_CREATE_DEVICE_PARAMS_STRUCT_SIZE \ 197 TF_OFFSET_OF_END(SE_CreateDeviceParams, device) 198 199 typedef struct SP_DeviceFns { 200 size_t struct_size; 201 void* ext; // reserved for future use 202 203 // [Optional] 204 // Returns the NUMA node associated with this device, for use in 205 // determining socket locality. If the NUMA node could not be determined, -1 206 // is returned. 207 // Negative values are treated as "unset". 208 int32_t (*get_numa_node)(const SP_Device* device); 209 210 // [Optional] 211 // Device's memory bandwidth in bytes/sec. (This is for reads/writes to/from 212 // the device's own memory, not for transfers between the host and device.) 213 // Negative values are treated as "unset". 214 int64_t (*get_memory_bandwidth)(const SP_Device* device); 215 216 // [Optional] 217 // Estimate of average number of floating point operations per second for 218 // this device * 10e-9. 219 // Negative values are treated as "unset". 220 double (*get_gflops)(const SP_Device* device); 221 } SP_DeviceFns; 222 223 #define SP_DEVICE_FNS_STRUCT_SIZE TF_OFFSET_OF_END(SP_DeviceFns, get_gflops) 224 225 typedef struct SE_CreateDeviceFnsParams { 226 size_t struct_size; 227 void* ext; // reserved for future use 228 229 SP_DeviceFns* device_fns; // output, to be filled by plugin 230 } SE_CreateDeviceFnsParams; 231 232 #define SE_CREATE_DEVICE_FNS_PARAMS_STRUCT_SIZE \ 233 TF_OFFSET_OF_END(SE_CreateDeviceFnsParams, device_fns) 234 235 typedef struct SP_StreamExecutor { 236 size_t struct_size; 237 void* ext; // reserved for future use 238 239 /*** ALLOCATION CALLBACKS ***/ 240 // Synchronously allocates `size` bytes on the underlying platform and returns 241 // `SP_DeviceMemoryBase` representing that allocation. In the case of failure, 242 // nullptr is returned. 243 // `memory_space` is reserved for a potential future usage and should be set 244 // to 0. 245 void (*allocate)(const SP_Device* device, uint64_t size, int64_t memory_space, 246 SP_DeviceMemoryBase* mem); 247 248 // Deallocate the device memory previously allocated via this interface. 249 // Deallocation of a nullptr-representative value is permitted. 250 void (*deallocate)(const SP_Device* device, SP_DeviceMemoryBase* memory); 251 252 // Allocates a region of host memory and registers it with the platform API. 253 // Memory allocated in this manner is required for use in asynchronous memcpy 254 // operations, such as `memcpy_dtoh`. 255 void* (*host_memory_allocate)(const SP_Device* device, uint64_t size); 256 257 // Deallocates a region of host memory allocated by `host_memory_allocate`. 258 void (*host_memory_deallocate)(const SP_Device* device, void* mem); 259 260 // Allocates unified memory space of the given size, if supported. Unified 261 // memory support should be added by setting `supports_unified_memory` field 262 // in `SP_Platform`. 263 void* (*unified_memory_allocate)(const SP_Device* device, uint64_t bytes); 264 265 // Deallocates unified memory space previously allocated with 266 // `unified_memory_allocate`. Unified 267 // memory support should be added by setting `supports_unified_memory` field 268 // in `SP_Platform`. 269 void (*unified_memory_deallocate)(const SP_Device* device, void* location); 270 271 // Fills SP_AllocatorStats with allocator statistics, if it is available. 272 // If it is not available, return false. 273 TF_Bool (*get_allocator_stats)(const SP_Device* device, 274 SP_AllocatorStats* stats); 275 // Fills the underlying device memory usage information, if it is 276 // available. If it is not available (false is returned), free/total need not 277 // be initialized. 278 TF_Bool (*device_memory_usage)(const SP_Device* device, int64_t* free, 279 int64_t* total); 280 281 /*** STREAM CALLBACKS ***/ 282 // Creates SP_Stream. This call should also allocate stream 283 // resources on the underlying platform and initializes its 284 // internals. 285 void (*create_stream)(const SP_Device* device, SP_Stream* stream, 286 TF_Status* status); 287 288 // Destroys SP_Stream and deallocates any underlying resources. 289 void (*destroy_stream)(const SP_Device* device, SP_Stream stream); 290 291 // Causes `dependent` to not begin execution until `other` has finished its 292 // last-enqueued work. 293 void (*create_stream_dependency)(const SP_Device* device, SP_Stream dependent, 294 SP_Stream other, TF_Status* status); 295 296 // Without blocking the device, retrieve the current stream status. 297 void (*get_stream_status)(const SP_Device* device, SP_Stream stream, 298 TF_Status* status); 299 300 /*** EVENT CALLBACKS ***/ 301 // Create SP_Event. Performs platform-specific allocation and initialization 302 // of an event. 303 void (*create_event)(const SP_Device* device, SP_Event* event, 304 TF_Status* status); 305 306 // Destroy SE_Event and perform any platform-specific deallocation and 307 // cleanup of an event. 308 void (*destroy_event)(const SP_Device* device, SP_Event event); 309 310 // Requests the current status of the event from the underlying platform. 311 SE_EventStatus (*get_event_status)(const SP_Device* device, SP_Event event); 312 // Inserts the specified event at the end of the specified stream. 313 void (*record_event)(const SP_Device* device, SP_Stream stream, 314 SP_Event event, TF_Status* status); 315 316 // Wait for the specified event at the end of the specified stream. 317 void (*wait_for_event)(const SP_Device* const device, SP_Stream stream, 318 SP_Event event, TF_Status* const status); 319 320 /*** TIMER CALLBACKS ***/ 321 // Creates SP_Timer. Allocates timer resources on the underlying platform 322 // and initializes its internals, setting `timer` output variable. Sets 323 // values in `timer_fns` struct. 324 void (*create_timer)(const SP_Device* device, SP_Timer* timer, 325 TF_Status* status); 326 327 // Destroy timer and deallocates timer resources on the underlying platform. 328 void (*destroy_timer)(const SP_Device* device, SP_Timer timer); 329 330 // Records a start event for an interval timer. 331 void (*start_timer)(const SP_Device* device, SP_Stream stream, SP_Timer timer, 332 TF_Status* status); 333 334 // Records a stop event for an interval timer. 335 void (*stop_timer)(const SP_Device* device, SP_Stream stream, SP_Timer timer, 336 TF_Status* status); 337 338 /*** MEMCPY CALLBACKS ***/ 339 // Enqueues a memcpy operation onto stream, with a host destination location 340 // `host_dst` and a device memory source, with target size `size`. 341 void (*memcpy_dtoh)(const SP_Device* device, SP_Stream stream, void* host_dst, 342 const SP_DeviceMemoryBase* device_src, uint64_t size, 343 TF_Status* status); 344 345 // Enqueues a memcpy operation onto stream, with a device destination 346 // location and a host memory source, with target size `size`. 347 void (*memcpy_htod)(const SP_Device* device, SP_Stream stream, 348 SP_DeviceMemoryBase* device_dst, const void* host_src, 349 uint64_t size, TF_Status* status); 350 351 // Enqueues a memcpy operation onto stream, with a device destination 352 // location and a device memory source, with target size `size`. 353 void (*memcpy_dtod)(const SP_Device* device, SP_Stream stream, 354 SP_DeviceMemoryBase* device_dst, 355 const SP_DeviceMemoryBase* device_src, uint64_t size, 356 TF_Status* status); 357 358 // Blocks the caller while a data segment of the given size is 359 // copied from the device source to the host destination. 360 void (*sync_memcpy_dtoh)(const SP_Device* device, void* host_dst, 361 const SP_DeviceMemoryBase* device_src, uint64_t size, 362 TF_Status* status); 363 364 // Blocks the caller while a data segment of the given size is 365 // copied from the host source to the device destination. 366 void (*sync_memcpy_htod)(const SP_Device* device, 367 SP_DeviceMemoryBase* device_dst, 368 const void* host_src, uint64_t size, 369 TF_Status* status); 370 371 // Blocks the caller while a data segment of the given size is copied from the 372 // device source to the device destination. 373 void (*sync_memcpy_dtod)(const SP_Device* device, 374 SP_DeviceMemoryBase* device_dst, 375 const SP_DeviceMemoryBase* device_src, uint64_t size, 376 TF_Status* status); 377 378 // Causes the host code to synchronously wait for the event to complete. 379 void (*block_host_for_event)(const SP_Device* device, SP_Event event, 380 TF_Status* status); 381 382 // [Optional] 383 // Causes the host code to synchronously wait for operations entrained onto 384 // stream to complete. Effectively a join on the asynchronous device 385 // operations enqueued on the stream before this program point. 386 // If not set, then corresponding functionality will be implemented 387 // by registering an event on the `stream` and waiting for it using 388 // `block_host_for_event`. 389 void (*block_host_until_done)(const SP_Device* device, SP_Stream stream, 390 TF_Status* status); 391 392 // Synchronizes all activity occurring in the StreamExecutor's context (most 393 // likely a whole device). 394 void (*synchronize_all_activity)(const SP_Device* device, TF_Status* status); 395 396 // Zero out `size` bytes starting at the location. 397 void (*mem_zero)(const SP_Device* device, SP_Stream stream, 398 SP_DeviceMemoryBase* location, uint64_t size, 399 TF_Status* status); 400 401 // Set the 8-bit patterns starting at the location with `size` bytes. 402 void (*memset)(const SP_Device* device, SP_Stream stream, 403 SP_DeviceMemoryBase* location, uint8_t pattern, uint64_t size, 404 TF_Status* status); 405 406 // Set the 32-bit patterns starting at the location with `size` bytes. 407 void (*memset32)(const SP_Device* device, SP_Stream stream, 408 SP_DeviceMemoryBase* location, uint32_t pattern, 409 uint64_t size, TF_Status* status); 410 411 // Enqueues on a stream a user-specified function to be run on the host. 412 // `callback_arg` should be passed as the first argument to `callback_fn`. 413 TF_Bool (*host_callback)(const SP_Device* device, SP_Stream stream, 414 SE_StatusCallbackFn callback_fn, void* callback_arg); 415 } SP_StreamExecutor; 416 417 #define SP_STREAMEXECUTOR_STRUCT_SIZE \ 418 TF_OFFSET_OF_END(SP_StreamExecutor, host_callback) 419 420 typedef struct SE_CreateStreamExecutorParams { 421 size_t struct_size; 422 void* ext; // reserved for future use 423 424 SP_StreamExecutor* stream_executor; // output, to be filled by plugin 425 } SE_CreateStreamExecutorParams; 426 427 #define SE_CREATE_STREAM_EXECUTOR_PARAMS_STRUCT_SIZE \ 428 TF_OFFSET_OF_END(SE_CreateStreamExecutorParams, stream_executor) 429 430 typedef struct SP_Platform { 431 size_t struct_size; 432 433 void* ext; // free-form data set by plugin 434 435 // Platform name (also referred to as subtype), for example MY_DEVICE. 436 // The name must start with a capital letter and consist of 437 // capital letters and underscores. 438 // Must be null-terminated. 439 const char* name; 440 441 // Device type name, for example GPU. Must be null-terminated. 442 // The name must start with a capital letter and consist of 443 // capital letters and underscores. 444 const char* type; 445 446 // Whether this platform supports unified memory. 447 // Unified memory is a single memory address space accessible from any device. 448 TF_Bool supports_unified_memory; 449 450 // Whether to wrap allocator for this device with an allocator that uses BFC 451 // (best-fit with coalescing) strategy. 452 TF_Bool use_bfc_allocator; 453 454 // Whether to force the memory allocations to grow over time instead of 455 // allocating it all at once. When this is set to true, the value of 456 // allow_growth is ignored. 457 TF_Bool force_memory_growth; 458 } SP_Platform; 459 460 #define SP_PLATFORM_STRUCT_SIZE \ 461 TF_OFFSET_OF_END(SP_Platform, force_memory_growth) 462 463 typedef struct SP_PlatformFns { 464 size_t struct_size; 465 466 void* ext; // reserved for future use 467 468 // Callbacks for getting device count 469 void (*get_device_count)(const SP_Platform* platform, int* device_count, 470 TF_Status* status); 471 // Callbacks for creating/destroying SP_Device. 472 void (*create_device)(const SP_Platform* platform, 473 SE_CreateDeviceParams* params, TF_Status* status); 474 475 // Clean up fields inside SP_Device that were allocated 476 // by the plugin. `device` itself should not be deleted here. 477 void (*destroy_device)(const SP_Platform* platform, SP_Device* device); 478 479 // Callbacks for creating/destroying SP_DeviceFns. 480 void (*create_device_fns)(const SP_Platform* platform, 481 SE_CreateDeviceFnsParams* params, 482 TF_Status* status); 483 484 // Clean up fields inside SP_DeviceFns that were allocated 485 // by the plugin. `device_fns` itself should not be deleted here. 486 void (*destroy_device_fns)(const SP_Platform* platform, 487 SP_DeviceFns* device_fns); 488 489 // Callbacks for creating/destroying SP_StreamExecutor. 490 void (*create_stream_executor)(const SP_Platform* platform, 491 SE_CreateStreamExecutorParams* params, 492 TF_Status* status); 493 // Clean up fields inside SP_StreamExecutor that were allocated 494 // by the plugin. `stream_executor` itself should not be deleted here. 495 void (*destroy_stream_executor)(const SP_Platform* platform, 496 SP_StreamExecutor* stream_executor); 497 498 // Callbacks for creating/destroying SP_TimerFns. 499 void (*create_timer_fns)(const SP_Platform* platform, SP_TimerFns* timer, 500 TF_Status* status); 501 502 void (*destroy_timer_fns)(const SP_Platform* platform, 503 SP_TimerFns* timer_fns); 504 } SP_PlatformFns; 505 506 #define SP_PLATFORM_FNS_STRUCT_SIZE \ 507 TF_OFFSET_OF_END(SP_PlatformFns, destroy_timer_fns) 508 509 typedef struct SE_PlatformRegistrationParams { 510 size_t struct_size; 511 void* ext; // reserved for future use 512 513 // StreamExecutor C API version. 514 int32_t major_version; 515 int32_t minor_version; 516 int32_t patch_version; 517 518 SP_Platform* platform; // output, set by plugin 519 SP_PlatformFns* platform_fns; // output, set by plugin 520 // Clean up fields inside SP_Platform that were allocated 521 // by the plugin. `platform` itself should not be deleted here. 522 void (*destroy_platform)(SP_Platform* platform); // out, set by plugin 523 void (*destroy_platform_fns)( 524 SP_PlatformFns* platform_fns); // out, set by plugin 525 } SE_PlatformRegistrationParams; 526 527 #define SE_PLATFORM_REGISTRATION_PARAMS_STRUCT_SIZE \ 528 TF_OFFSET_OF_END(SE_PlatformRegistrationParams, destroy_platform_fns) 529 530 void SE_InitPlugin(SE_PlatformRegistrationParams* params, TF_Status* status); 531 532 #ifdef __cplusplus 533 } // extern "C" 534 #endif 535 536 #endif // TENSORFLOW_C_EXPERIMENTAL_STREAM_EXECUTOR_STREAM_EXECUTOR_H_ 537