xref: /aosp_15_r20/external/tensorflow/tensorflow/c/experimental/stream_executor/stream_executor.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #ifndef TENSORFLOW_C_EXPERIMENTAL_STREAM_EXECUTOR_STREAM_EXECUTOR_H_
16 #define TENSORFLOW_C_EXPERIMENTAL_STREAM_EXECUTOR_STREAM_EXECUTOR_H_
17 #include <stddef.h>
18 #include <stdint.h>
19 
20 #include "tensorflow/c/c_api_macros.h"
21 #include "tensorflow/c/tf_status.h"
22 
23 // --------------------------------------------------------------------------
24 // C API for StreamExecutor. The API is under active development and eventually
25 // should allow registering a pluggable device with TensorFlow.
26 //
27 // Conventions:
28 //   * Struct prefix indicates whether struct fields should be filled by the
29 //     plugin or core implementation:
30 //     * SE_ : set/filled by core unless explicitly marked otherwise.
31 //     * SP_ : set/filled by plugin unless explicitly marked otherwise.
32 //   * We use `struct_size` for version checking. It is exempt from the `SE/SP`
33 //     rule above and should be set both by core and the plugin.
34 //     * For example, `create_device` function receives `SP_Device*` as input
35 //       with `struct_size` populated by core. The plugin is responsible for
36 //       setting `struct_size` as well, along with all other fields.
37 //     * Refer to "TensorFlow Versioning Strategy" section at
38 //       https://github.com/tensorflow/community/pull/257/files.
39 //     * Note that the API is still under active development and doesn't have
40 //       versioning guarantees yet.
41 //   * `void* ext` is a free-form field that can be populated by
42 //     a plugin in `SP_*` structs or potential future extension points in `SE_`
43 //     structs.
44 //
45 // Example usage:
46 //
47 //   /* Sample TensorFlow code below, exact implementation might differ. */
48 //   // Version checking uses `struct_size`. It is exempt from the `SE/SP` rule
//   // above and should be set both by core and the plugin.
//   SP_Device device { SP_DEVICE_STRUCT_SIZE };
//   SE_CreateDeviceParams params { SE_CREATE_DEVICE_PARAMS_STRUCT_SIZE };
52 //   params.device = &device;
53 //
54 //   /* Plugin code below */
55 //   constexpr char DEVICE_NAME[] = "MY_DEVICE";
56 //   constexpr char DEVICE_TYPE[] = "GPU";
57 //
58 //   void create_device(const SP_Platform* platform,
59 //                      SE_CreateDeviceParams* params, TF_Status* status) {
60 //     // Custom actions based on TensorFlow's view of SP_Device.
61 //     OnTFDeviceView(params->device->struct_size);
//     *params->device = { SP_DEVICE_STRUCT_SIZE };
//     params->device->device_handle = get_my_device_handle(params->ordinal);
64 //     params->device->ordinal = params->ordinal;
65 //     ...
66 //   }
67 //
68 //   void destroy_device(const SP_Platform* platform, SP_Device* device) {
69 //     delete_my_device_handle(device->device_handle);
70 //   }
71 //
72 //   void SE_InitPlugin(
73 //       SE_PlatformRegistrationParams* params,
74 //       TF_Status* status) {
//     *params->platform = { SP_PLATFORM_STRUCT_SIZE };
//     // Values such as `name` and `type` must outlive the SE_InitPlugin call.
//     params->platform->name = DEVICE_NAME;
//     params->platform->type = DEVICE_TYPE;
//     *params->platform_fns = { SP_PLATFORM_FNS_STRUCT_SIZE };
//     params->platform_fns->get_device_count = get_device_count;
80 //     params->platform_fns->create_device = create_device;
81 //     params->platform_fns->destroy_device = destroy_device;
82 //     ...
83 //   }
84 
// StreamExecutor C API version (major.minor.patch). Kept as plain macros
// rather than enumerators so they stay usable in `#if` conditionals.
#define SE_MAJOR 0
#define SE_MINOR 0
#define SE_PATCH 1
88 
89 #ifdef __cplusplus
90 extern "C" {
91 #endif
92 
// Opaque, plugin-owned handles. The `*_st` structs are only declared here,
// never defined, so code that sees just this header cannot dereference them;
// it can only pass the handles back into plugin callbacks.
typedef struct SP_Stream_st *SP_Stream;
typedef struct SP_Event_st *SP_Event;
typedef struct SP_Timer_st *SP_Timer;
96 // Takes `callback_arg` passed to `host_callback` as the first argument.
97 typedef void (*SE_StatusCallbackFn)(void* const, TF_Status* const);
98 
99 typedef struct SP_TimerFns {
100   size_t struct_size;
101   void* ext;  // reserved for future use
102   uint64_t (*nanoseconds)(SP_Timer timer);
103 } SP_TimerFns;
104 
105 #define SP_TIMER_FNS_STRUCT_SIZE TF_OFFSET_OF_END(SP_TimerFns, nanoseconds)
106 
// Allocator statistics, filled in by the plugin's
// `SP_StreamExecutor.get_allocator_stats` callback.
// NOTE(review): unlike most structs in this API there is no `ext` field, and
// the flags are int8_t rather than TF_Bool. Changing either would alter the
// ABI, so both are left as-is.
typedef struct SP_AllocatorStats SP_AllocatorStats;
struct SP_AllocatorStats {
  size_t struct_size;  // version check; set by both core and plugin

  // Usage counters.
  int64_t num_allocs;
  int64_t bytes_in_use;
  int64_t peak_bytes_in_use;
  int64_t largest_alloc_size;

  // Optional limit: presumably `bytes_limit` is meaningful only when
  // `has_bytes_limit` is nonzero.
  int8_t has_bytes_limit;
  int64_t bytes_limit;

  // Reservation counters.
  int64_t bytes_reserved;
  int64_t peak_bytes_reserved;

  // Optional reservable limit: presumably `bytes_reservable_limit` is
  // meaningful only when `has_bytes_reservable_limit` is nonzero.
  int8_t has_bytes_reservable_limit;
  int64_t bytes_reservable_limit;

  int64_t largest_free_block_bytes;
};

#define SP_ALLOCATORSTATS_STRUCT_SIZE \
  TF_OFFSET_OF_END(SP_AllocatorStats, largest_free_block_bytes)
128 
// State of an SP_Event, as reported by `SP_StreamExecutor.get_event_status`.
// Any value other than SE_EVENT_PENDING or SE_EVENT_COMPLETE means an error
// occurred; SE_EVENT_UNKNOWN is a bad state. The numeric values are written
// out because this enum is returned across the core/plugin boundary.
typedef enum SE_EventStatus {
  SE_EVENT_UNKNOWN = 0,
  SE_EVENT_ERROR = 1,
  SE_EVENT_PENDING = 2,
  SE_EVENT_COMPLETE = 3,
} SE_EventStatus;
137 
// Describes a single device memory allocation. It mirrors StreamExecutor's
// DeviceMemoryBase:
// https://cs.opensource.google/tensorflow/tensorflow/+/refs/tags/v2.3.0:tensorflow/stream_executor/device_memory.h;l=57
typedef struct SP_DeviceMemoryBase SP_DeviceMemoryBase;
struct SP_DeviceMemoryBase {
  size_t struct_size;  // version check; set by both core and plugin
  void* ext;           // reserved for future use

  // Platform-dependent value identifying the allocation. It need not be the
  // virtual address of the memory itself.
  void* opaque;

  // Allocation size, in bytes.
  uint64_t size;

  // Free for the plugin's own use.
  uint64_t payload;
};

#define SP_DEVICE_MEMORY_BASE_STRUCT_SIZE \
  TF_OFFSET_OF_END(SP_DeviceMemoryBase, payload)
153 
// One device. Core passes it to `SP_PlatformFns.create_device` (through
// SE_CreateDeviceParams) and the plugin fills it in.
typedef struct SP_Device SP_Device;
struct SP_Device {
  size_t struct_size;  // version check; set by both core and plugin
  void* ext;           // free-form data set by plugin

  // Device index.
  int32_t ordinal;

  // Handle to the vendor's own representation of the device.
  void* device_handle;

  // The three strings below are [Optional], used only for printing, and must
  // be null-terminated when provided.

  // Device hardware name.
  const char* hardware_name;

  // Device vendor name.
  const char* device_vendor;

  // PCI bus identifier of this device, formatted as
  //   [domain]:[bus]:[device].[function]
  // where the domain is usually 0000, e.g. 0000:00:02.1. See:
  //   https://en.wikipedia.org/wiki/PCI_configuration_space
  //   https://www.oreilly.com/library/view/linux-device-drivers/0596005903/ch12.html
  const char* pci_bus_id;
};

#define SP_DEVICE_STRUCT_SIZE TF_OFFSET_OF_END(SP_Device, pci_bus_id)
186 
187 typedef struct SE_CreateDeviceParams {
188   size_t struct_size;
189   void* ext;        // reserved for future use
190   int32_t ordinal;  // device index
191 
192   SP_Device* device;  // Input/output, struct_size set by TF for plugin to read.
193                       // Subsequently plugin fills the entire struct.
194 } SE_CreateDeviceParams;
195 
196 #define SE_CREATE_DEVICE_PARAMS_STRUCT_SIZE \
197   TF_OFFSET_OF_END(SE_CreateDeviceParams, device)
198 
199 typedef struct SP_DeviceFns {
200   size_t struct_size;
201   void* ext;  // reserved for future use
202 
203   // [Optional]
204   // Returns the NUMA node associated with this device, for use in
205   // determining socket locality. If the NUMA node could not be determined, -1
206   // is returned.
207   // Negative values are treated as "unset".
208   int32_t (*get_numa_node)(const SP_Device* device);
209 
210   // [Optional]
211   // Device's memory bandwidth in bytes/sec.  (This is for reads/writes to/from
212   // the device's own memory, not for transfers between the host and device.)
213   // Negative values are treated as "unset".
214   int64_t (*get_memory_bandwidth)(const SP_Device* device);
215 
216   // [Optional]
217   // Estimate of average number of floating point operations per second for
218   // this device * 10e-9.
219   // Negative values are treated as "unset".
220   double (*get_gflops)(const SP_Device* device);
221 } SP_DeviceFns;
222 
223 #define SP_DEVICE_FNS_STRUCT_SIZE TF_OFFSET_OF_END(SP_DeviceFns, get_gflops)
224 
225 typedef struct SE_CreateDeviceFnsParams {
226   size_t struct_size;
227   void* ext;  // reserved for future use
228 
229   SP_DeviceFns* device_fns;  // output, to be filled by plugin
230 } SE_CreateDeviceFnsParams;
231 
232 #define SE_CREATE_DEVICE_FNS_PARAMS_STRUCT_SIZE \
233   TF_OFFSET_OF_END(SE_CreateDeviceFnsParams, device_fns)
234 
235 typedef struct SP_StreamExecutor {
236   size_t struct_size;
237   void* ext;  // reserved for future use
238 
239   /*** ALLOCATION CALLBACKS ***/
240   // Synchronously allocates `size` bytes on the underlying platform and returns
241   // `SP_DeviceMemoryBase` representing that allocation. In the case of failure,
242   // nullptr is returned.
243   // `memory_space` is reserved for a potential future usage and should be set
244   // to 0.
245   void (*allocate)(const SP_Device* device, uint64_t size, int64_t memory_space,
246                    SP_DeviceMemoryBase* mem);
247 
248   // Deallocate the device memory previously allocated via this interface.
249   // Deallocation of a nullptr-representative value is permitted.
250   void (*deallocate)(const SP_Device* device, SP_DeviceMemoryBase* memory);
251 
252   // Allocates a region of host memory and registers it with the platform API.
253   // Memory allocated in this manner is required for use in asynchronous memcpy
254   // operations, such as `memcpy_dtoh`.
255   void* (*host_memory_allocate)(const SP_Device* device, uint64_t size);
256 
257   // Deallocates a region of host memory allocated by `host_memory_allocate`.
258   void (*host_memory_deallocate)(const SP_Device* device, void* mem);
259 
260   // Allocates unified memory space of the given size, if supported. Unified
261   // memory support should be added by setting `supports_unified_memory` field
262   // in `SP_Platform`.
263   void* (*unified_memory_allocate)(const SP_Device* device, uint64_t bytes);
264 
265   // Deallocates unified memory space previously allocated with
266   // `unified_memory_allocate`. Unified
267   // memory support should be added by setting `supports_unified_memory` field
268   // in `SP_Platform`.
269   void (*unified_memory_deallocate)(const SP_Device* device, void* location);
270 
271   // Fills SP_AllocatorStats with allocator statistics, if it is available.
272   // If it is not available, return false.
273   TF_Bool (*get_allocator_stats)(const SP_Device* device,
274                                  SP_AllocatorStats* stats);
275   // Fills the underlying device memory usage information, if it is
276   // available. If it is not available (false is returned), free/total need not
277   // be initialized.
278   TF_Bool (*device_memory_usage)(const SP_Device* device, int64_t* free,
279                                  int64_t* total);
280 
281   /*** STREAM CALLBACKS ***/
282   // Creates SP_Stream. This call should also allocate stream
283   // resources on the underlying platform and initializes its
284   // internals.
285   void (*create_stream)(const SP_Device* device, SP_Stream* stream,
286                         TF_Status* status);
287 
288   // Destroys SP_Stream and deallocates any underlying resources.
289   void (*destroy_stream)(const SP_Device* device, SP_Stream stream);
290 
291   // Causes `dependent` to not begin execution until `other` has finished its
292   // last-enqueued work.
293   void (*create_stream_dependency)(const SP_Device* device, SP_Stream dependent,
294                                    SP_Stream other, TF_Status* status);
295 
296   // Without blocking the device, retrieve the current stream status.
297   void (*get_stream_status)(const SP_Device* device, SP_Stream stream,
298                             TF_Status* status);
299 
300   /*** EVENT CALLBACKS ***/
301   // Create SP_Event. Performs platform-specific allocation and initialization
302   // of an event.
303   void (*create_event)(const SP_Device* device, SP_Event* event,
304                        TF_Status* status);
305 
306   // Destroy SE_Event and perform any platform-specific deallocation and
307   // cleanup of an event.
308   void (*destroy_event)(const SP_Device* device, SP_Event event);
309 
310   // Requests the current status of the event from the underlying platform.
311   SE_EventStatus (*get_event_status)(const SP_Device* device, SP_Event event);
312   // Inserts the specified event at the end of the specified stream.
313   void (*record_event)(const SP_Device* device, SP_Stream stream,
314                        SP_Event event, TF_Status* status);
315 
316   // Wait for the specified event at the end of the specified stream.
317   void (*wait_for_event)(const SP_Device* const device, SP_Stream stream,
318                          SP_Event event, TF_Status* const status);
319 
320   /*** TIMER CALLBACKS ***/
321   // Creates SP_Timer. Allocates timer resources on the underlying platform
322   // and initializes its internals, setting `timer` output variable. Sets
323   // values in `timer_fns` struct.
324   void (*create_timer)(const SP_Device* device, SP_Timer* timer,
325                        TF_Status* status);
326 
327   // Destroy timer and deallocates timer resources on the underlying platform.
328   void (*destroy_timer)(const SP_Device* device, SP_Timer timer);
329 
330   // Records a start event for an interval timer.
331   void (*start_timer)(const SP_Device* device, SP_Stream stream, SP_Timer timer,
332                       TF_Status* status);
333 
334   // Records a stop event for an interval timer.
335   void (*stop_timer)(const SP_Device* device, SP_Stream stream, SP_Timer timer,
336                      TF_Status* status);
337 
338   /*** MEMCPY CALLBACKS ***/
339   // Enqueues a memcpy operation onto stream, with a host destination location
340   // `host_dst` and a device memory source, with target size `size`.
341   void (*memcpy_dtoh)(const SP_Device* device, SP_Stream stream, void* host_dst,
342                       const SP_DeviceMemoryBase* device_src, uint64_t size,
343                       TF_Status* status);
344 
345   // Enqueues a memcpy operation onto stream, with a device destination
346   // location and a host memory source, with target size `size`.
347   void (*memcpy_htod)(const SP_Device* device, SP_Stream stream,
348                       SP_DeviceMemoryBase* device_dst, const void* host_src,
349                       uint64_t size, TF_Status* status);
350 
351   // Enqueues a memcpy operation onto stream, with a device destination
352   // location and a device memory source, with target size `size`.
353   void (*memcpy_dtod)(const SP_Device* device, SP_Stream stream,
354                       SP_DeviceMemoryBase* device_dst,
355                       const SP_DeviceMemoryBase* device_src, uint64_t size,
356                       TF_Status* status);
357 
358   // Blocks the caller while a data segment of the given size is
359   // copied from the device source to the host destination.
360   void (*sync_memcpy_dtoh)(const SP_Device* device, void* host_dst,
361                            const SP_DeviceMemoryBase* device_src, uint64_t size,
362                            TF_Status* status);
363 
364   // Blocks the caller while a data segment of the given size is
365   // copied from the host source to the device destination.
366   void (*sync_memcpy_htod)(const SP_Device* device,
367                            SP_DeviceMemoryBase* device_dst,
368                            const void* host_src, uint64_t size,
369                            TF_Status* status);
370 
371   // Blocks the caller while a data segment of the given size is copied from the
372   // device source to the device destination.
373   void (*sync_memcpy_dtod)(const SP_Device* device,
374                            SP_DeviceMemoryBase* device_dst,
375                            const SP_DeviceMemoryBase* device_src, uint64_t size,
376                            TF_Status* status);
377 
378   // Causes the host code to synchronously wait for the event to complete.
379   void (*block_host_for_event)(const SP_Device* device, SP_Event event,
380                                TF_Status* status);
381 
382   // [Optional]
383   // Causes the host code to synchronously wait for operations entrained onto
384   // stream to complete. Effectively a join on the asynchronous device
385   // operations enqueued on the stream before this program point.
386   // If not set, then corresponding functionality will be implemented
387   // by registering an event on the `stream` and waiting for it using
388   // `block_host_for_event`.
389   void (*block_host_until_done)(const SP_Device* device, SP_Stream stream,
390                                 TF_Status* status);
391 
392   // Synchronizes all activity occurring in the StreamExecutor's context (most
393   // likely a whole device).
394   void (*synchronize_all_activity)(const SP_Device* device, TF_Status* status);
395 
396   // Zero out `size` bytes starting at the location.
397   void (*mem_zero)(const SP_Device* device, SP_Stream stream,
398                    SP_DeviceMemoryBase* location, uint64_t size,
399                    TF_Status* status);
400 
401   // Set the 8-bit patterns starting at the location with `size` bytes.
402   void (*memset)(const SP_Device* device, SP_Stream stream,
403                  SP_DeviceMemoryBase* location, uint8_t pattern, uint64_t size,
404                  TF_Status* status);
405 
406   // Set the 32-bit patterns starting at the location with `size` bytes.
407   void (*memset32)(const SP_Device* device, SP_Stream stream,
408                    SP_DeviceMemoryBase* location, uint32_t pattern,
409                    uint64_t size, TF_Status* status);
410 
411   // Enqueues on a stream a user-specified function to be run on the host.
412   // `callback_arg` should be passed as the first argument to `callback_fn`.
413   TF_Bool (*host_callback)(const SP_Device* device, SP_Stream stream,
414                            SE_StatusCallbackFn callback_fn, void* callback_arg);
415 } SP_StreamExecutor;
416 
417 #define SP_STREAMEXECUTOR_STRUCT_SIZE \
418   TF_OFFSET_OF_END(SP_StreamExecutor, host_callback)
419 
420 typedef struct SE_CreateStreamExecutorParams {
421   size_t struct_size;
422   void* ext;  // reserved for future use
423 
424   SP_StreamExecutor* stream_executor;  // output, to be filled by plugin
425 } SE_CreateStreamExecutorParams;
426 
427 #define SE_CREATE_STREAM_EXECUTOR_PARAMS_STRUCT_SIZE \
428   TF_OFFSET_OF_END(SE_CreateStreamExecutorParams, stream_executor)
429 
430 typedef struct SP_Platform {
431   size_t struct_size;
432 
433   void* ext;  // free-form data set by plugin
434 
435   // Platform name (also referred to as subtype), for example MY_DEVICE.
436   // The name must start with a capital letter and consist of
437   // capital letters and underscores.
438   // Must be null-terminated.
439   const char* name;
440 
441   // Device type name, for example GPU. Must be null-terminated.
442   // The name must start with a capital letter and consist of
443   // capital letters and underscores.
444   const char* type;
445 
446   // Whether this platform supports unified memory.
447   // Unified memory is a single memory address space accessible from any device.
448   TF_Bool supports_unified_memory;
449 
450   // Whether to wrap allocator for this device with an allocator that uses BFC
451   // (best-fit with coalescing) strategy.
452   TF_Bool use_bfc_allocator;
453 
454   // Whether to force the memory allocations to grow over time instead of
455   // allocating it all at once. When this is set to true, the value of
456   // allow_growth is ignored.
457   TF_Bool force_memory_growth;
458 } SP_Platform;
459 
460 #define SP_PLATFORM_STRUCT_SIZE \
461   TF_OFFSET_OF_END(SP_Platform, force_memory_growth)
462 
463 typedef struct SP_PlatformFns {
464   size_t struct_size;
465 
466   void* ext;  // reserved for future use
467 
468   // Callbacks for getting device count
469   void (*get_device_count)(const SP_Platform* platform, int* device_count,
470                            TF_Status* status);
471   // Callbacks for creating/destroying SP_Device.
472   void (*create_device)(const SP_Platform* platform,
473                         SE_CreateDeviceParams* params, TF_Status* status);
474 
475   // Clean up fields inside SP_Device that were allocated
476   // by the plugin. `device` itself should not be deleted here.
477   void (*destroy_device)(const SP_Platform* platform, SP_Device* device);
478 
479   // Callbacks for creating/destroying SP_DeviceFns.
480   void (*create_device_fns)(const SP_Platform* platform,
481                             SE_CreateDeviceFnsParams* params,
482                             TF_Status* status);
483 
484   // Clean up fields inside SP_DeviceFns that were allocated
485   // by the plugin. `device_fns` itself should not be deleted here.
486   void (*destroy_device_fns)(const SP_Platform* platform,
487                              SP_DeviceFns* device_fns);
488 
489   // Callbacks for creating/destroying SP_StreamExecutor.
490   void (*create_stream_executor)(const SP_Platform* platform,
491                                  SE_CreateStreamExecutorParams* params,
492                                  TF_Status* status);
493   // Clean up fields inside SP_StreamExecutor that were allocated
494   // by the plugin. `stream_executor` itself should not be deleted here.
495   void (*destroy_stream_executor)(const SP_Platform* platform,
496                                   SP_StreamExecutor* stream_executor);
497 
498   // Callbacks for creating/destroying SP_TimerFns.
499   void (*create_timer_fns)(const SP_Platform* platform, SP_TimerFns* timer,
500                            TF_Status* status);
501 
502   void (*destroy_timer_fns)(const SP_Platform* platform,
503                             SP_TimerFns* timer_fns);
504 } SP_PlatformFns;
505 
506 #define SP_PLATFORM_FNS_STRUCT_SIZE \
507   TF_OFFSET_OF_END(SP_PlatformFns, destroy_timer_fns)
508 
509 typedef struct SE_PlatformRegistrationParams {
510   size_t struct_size;
511   void* ext;  // reserved for future use
512 
513   // StreamExecutor C API version.
514   int32_t major_version;
515   int32_t minor_version;
516   int32_t patch_version;
517 
518   SP_Platform* platform;         // output, set by plugin
519   SP_PlatformFns* platform_fns;  // output, set by plugin
520   // Clean up fields inside SP_Platform that were allocated
521   // by the plugin. `platform` itself should not be deleted here.
522   void (*destroy_platform)(SP_Platform* platform);  // out, set by plugin
523   void (*destroy_platform_fns)(
524       SP_PlatformFns* platform_fns);  // out, set by plugin
525 } SE_PlatformRegistrationParams;
526 
527 #define SE_PLATFORM_REGISTRATION_PARAMS_STRUCT_SIZE \
528   TF_OFFSET_OF_END(SE_PlatformRegistrationParams, destroy_platform_fns)
529 
530 void SE_InitPlugin(SE_PlatformRegistrationParams* params, TF_Status* status);
531 
532 #ifdef __cplusplus
533 }  // extern "C"
534 #endif
535 
536 #endif  // TENSORFLOW_C_EXPERIMENTAL_STREAM_EXECUTOR_STREAM_EXECUTOR_H_
537