xref: /aosp_15_r20/external/tensorflow/tensorflow/c/kernels_experimental.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_C_KERNELS_EXPERIMENTAL_H_
17 #define TENSORFLOW_C_KERNELS_EXPERIMENTAL_H_
18 
19 #include "tensorflow/c/kernels.h"
20 
21 // --------------------------------------------------------------------------
22 // Experimental kernel C API for TensorFlow.
23 //
// The API here is subject to change in the future.
25 // --------------------------------------------------------------------------
26 
// Controls the visibility of exported symbols in the shared library (.so,
// .dylib, .dll). It duplicates the TF_EXPORT definition in
// tensorflow/core/platform/macros.h so that this .h file stays independent
// of any other includes.
//   - SWIG:                            expands to nothing.
//   - Windows, building the library:   __declspec(dllexport).
//   - Windows, consuming the library:  __declspec(dllimport).
//   - Any other platform:              default symbol visibility attribute.
// The replacement lists are kept token-identical to the other TensorFlow C
// headers so that a repeated definition is a benign redefinition.
#if defined(SWIG)
#define TF_CAPI_EXPORT
#elif defined(_WIN32) && defined(TF_COMPILE_LIBRARY)
#define TF_CAPI_EXPORT __declspec(dllexport)
#elif defined(_WIN32)
#define TF_CAPI_EXPORT __declspec(dllimport)
#else
#define TF_CAPI_EXPORT __attribute__((visibility("default")))
#endif  // SWIG / _WIN32 / TF_COMPILE_LIBRARY
45 
46 #ifdef __cplusplus
47 extern "C" {
48 #endif
49 
// Opaque handle for the variable-input locks acquired by
// TF_MaybeLockVariableInputMutexesInOrder and released with
// TF_ReleaseVariableInputLockHolder. The struct is only declared (never
// defined) in this header, so plugins can hold pointers to it but cannot
// inspect its contents.
struct TF_VariableInputLockHolder;
typedef struct TF_VariableInputLockHolder TF_VariableInputLockHolder;
51 
52 // Expose higher level Assignment operation for Pluggable vendors to implement
53 // in the plugin for Training. The API takes in the context with indices for
54 // the input and value tensors. It also accepts the copy callback provided by
55 // pluggable vendor to do the copying of the tensors. The caller takes ownership
56 // of the `source` and `dest` tensors and is responsible for freeing them with
57 // TF_DeleteTensor. This function will return an error when the following
58 // conditions are met:
59 //   1. `validate_shape` is set to `true`
60 //   2. The variable is initialized
61 //   3. The shape of the value tensor doesn't match the shape of the variable
62 //      tensor.
63 TF_CAPI_EXPORT extern void TF_AssignVariable(
64     TF_OpKernelContext* ctx, int input_index, int value_index,
65     bool validate_shape,
66     void (*copyFunc)(TF_OpKernelContext* ctx, TF_Tensor* source,
67                      TF_Tensor* dest),
68     TF_Status* status);
69 
70 // Expose higher level Assignment operation for Pluggable vendors to implement
71 // in the plugin for Training on ref variables. The API takes in the context
72 // with indices for the input and value tensors. It also accepts the copy
73 // callback provided by pluggable vendor to do the copying of the tensors. The
74 // caller takes ownership of the `source` and `dest` tensors and is responsible
75 // for freeing them with TF_DeleteTensor.
76 TF_CAPI_EXPORT extern void TF_AssignRefVariable(
77     TF_OpKernelContext* ctx, int input_ref_index, int output_ref_index,
78     int value_index, bool use_locking, bool validate_shape,
79     void (*copyFunc)(TF_OpKernelContext* ctx, TF_Tensor* source,
80                      TF_Tensor* dest),
81     TF_Status* status);
82 
83 // Expose higher level AssignUpdate operation for Pluggable vendors to implement
84 // in the plugin for Training. The API takes in the context with indices for the
85 // input and value tensors. It also accepts the copy callback provided by
86 // pluggable vendor to do the copying of the tensors and the update callback to
87 // apply the arithmetic operation. The caller takes ownership of the `source`,
88 // `dest`, `tensor` and `value` tensors and is responsible for freeing them with
89 // TF_DeleteTensor.
90 TF_CAPI_EXPORT extern void TF_AssignUpdateVariable(
91     TF_OpKernelContext* ctx, int input_index, int value_index, int Op,
92     int isVariantType,
93     void (*copyFunc)(TF_OpKernelContext* ctx, TF_Tensor* source,
94                      TF_Tensor* dest),
95     void (*updateFunc)(TF_OpKernelContext* ctx, TF_Tensor* tensor,
96                        TF_Tensor* value, int Op),
97     TF_Status* status);
98 
99 // This is a helper function which acquires mutexes in-order to provide
100 // thread-safe way of performing weights update during the optimizer op. It
101 // returns an opaque LockHolder handle back to plugin. This handle is passed to
102 // the Release API for releasing the locks when the weight update is done. The
103 // caller takes ownership of the `source` and `dest` tensors and is responsible
104 // for freeing them with TF_DeleteTensor.
105 TF_CAPI_EXPORT extern void TF_MaybeLockVariableInputMutexesInOrder(
106     TF_OpKernelContext* ctx, bool do_lock, bool sparse, const int* const inputs,
107     size_t len,
108     void (*copyFunc)(TF_OpKernelContext* ctx, TF_Tensor* source,
109                      TF_Tensor* dest),
110     TF_VariableInputLockHolder** lockHolder, TF_Status* status);
111 
112 // This interface returns `out` tensor which is updated corresponding to the
113 // variable passed with input index. The caller takes ownership of the `source`
114 // and `dest` tensors and is responsible for freeing them with TF_DeleteTensor.
115 TF_CAPI_EXPORT extern void TF_GetInputTensorFromVariable(
116     TF_OpKernelContext* ctx, int input, bool lock_held, bool isVariantType,
117     bool sparse,
118     void (*copyFunc)(TF_OpKernelContext* ctx, TF_Tensor* source,
119                      TF_Tensor* dest),
120     TF_Tensor** out, TF_Status* status);
121 
122 // This interface forwards the reference from input to the output tensors
123 // corresponding to the indices provided with `input_index` and `output_index`
124 TF_CAPI_EXPORT extern void TF_OpKernelContext_ForwardRefInputToRefOutput(
125     TF_OpKernelContext* ctx, int32_t input_index, int32_t output_index);
126 
127 // The API releases the opaque lock handle returned with
128 // `TF_MaybeLockVariableInputMutexesInOrder` API
129 TF_CAPI_EXPORT extern void TF_ReleaseVariableInputLockHolder(
130     TF_VariableInputLockHolder* lockHolder);
131 
132 // Allows plugin to get TF_Tensor when passed its input_name
133 TF_CAPI_EXPORT extern void TF_GetInputByName(TF_OpKernelContext* ctx,
134                                              const char* inputName,
135                                              TF_Tensor** tensor,
136                                              TF_Status* status);
137 
138 // Interprets the named kernel construction attribute as a shape attribute and
139 // fills in `vals` with the size of each dimension. `vals` must point to an
140 // array of length at least `max_values` (ideally set to total_size from
141 // TF_OpKernelConstruction_GetAttrSize(ctx, attr_name, &list_size,
142 // &total_size)).
143 TF_CAPI_EXPORT extern void TF_OpKernelConstruction_GetAttrTensorShape(
144     TF_OpKernelConstruction* ctx, const char* attr_name, int64_t* dims,
145     size_t num_dims, TF_Status* status);
146 
147 TF_CAPI_EXPORT extern bool TF_IsRefInput(TF_OpKernelContext* ctx, int i,
148                                          TF_Status* status);
149 
150 #ifndef IS_MOBILE_PLATFORM
151 // Expose higher level AddN operation for Pluggable vendors to implement
152 // in the plugin for Variant data types. The API takes in the context and a
153 // callback provided by pluggable vendor to do a Binary Add operation on the
154 // tensors unwrapped from the Variant tensors. The caller takes ownership of the
155 // `a`, `b` and `out` tensors and is responsible for freeing them with
156 // TF_DeleteTensor.
157 TF_CAPI_EXPORT extern void TF_AddNVariant(
158     TF_OpKernelContext* ctx,
159     void (*binary_add_func)(TF_OpKernelContext* ctx, TF_Tensor* a, TF_Tensor* b,
160                             TF_Tensor* out),
161     TF_Status* status);
162 
163 // Expose higher level ZerosLike operation for Pluggable vendors to implement
164 // in the plugin for Variant data types. The API takes in the context and a
165 // callback provided by pluggable vendor to do a ZerosLike operation on the
166 // tensors unwrapped from the Variant tensors. The caller takes ownership of the
167 // `input` and `out` tensors and is responsible for freeing them with
168 // TF_DeleteTensor.
169 TF_CAPI_EXPORT extern void TF_ZerosLikeVariant(
170     TF_OpKernelContext* ctx,
171     void (*zeros_like_func)(TF_OpKernelContext* ctx, TF_Tensor* input,
172                             TF_Tensor* out),
173     TF_Status* status);
174 #endif  // IS_MOBILE_PLATFORM
175 
176 #ifdef __cplusplus
177 } /* end extern "C" */
178 #endif
179 
180 #endif  // TENSORFLOW_C_KERNELS_EXPERIMENTAL_H_
181