1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_C_KERNELS_EXPERIMENTAL_H_ 17 #define TENSORFLOW_C_KERNELS_EXPERIMENTAL_H_ 18 19 #include "tensorflow/c/kernels.h" 20 21 // -------------------------------------------------------------------------- 22 // Experimental kernel C API for TensorFlow. 23 // 24 // The API here is subject to changes in the future. 25 // -------------------------------------------------------------------------- 26 27 // Macro to control visibility of exported symbols in the shared library (.so, 28 // .dylib, .dll). 29 // This duplicates the TF_EXPORT macro definition in 30 // tensorflow/core/platform/macros.h in order to keep this .h file independent 31 // of any other includes. 32 #ifdef SWIG 33 #define TF_CAPI_EXPORT 34 #else 35 #if defined(_WIN32) 36 #ifdef TF_COMPILE_LIBRARY 37 #define TF_CAPI_EXPORT __declspec(dllexport) 38 #else 39 #define TF_CAPI_EXPORT __declspec(dllimport) 40 #endif // TF_COMPILE_LIBRARY 41 #else 42 #define TF_CAPI_EXPORT __attribute__((visibility("default"))) 43 #endif // _WIN32 44 #endif // SWIG 45 46 #ifdef __cplusplus 47 extern "C" { 48 #endif 49 50 typedef struct TF_VariableInputLockHolder TF_VariableInputLockHolder; 51 52 // Expose higher level Assignment operation for Pluggable vendors to implement 53 // in the plugin for Training. The API takes in the context with indices for 54 // the input and value tensors. It also accepts the copy callback provided by 55 // pluggable vendor to do the copying of the tensors. The caller takes ownership 56 // of the `source` and `dest` tensors and is responsible for freeing them with 57 // TF_DeleteTensor. This function will return an error when the following 58 // conditions are met: 59 // 1. `validate_shape` is set to `true` 60 // 2. The variable is initialized 61 // 3. The shape of the value tensor doesn't match the shape of the variable 62 // tensor. 63 TF_CAPI_EXPORT extern void TF_AssignVariable( 64 TF_OpKernelContext* ctx, int input_index, int value_index, 65 bool validate_shape, 66 void (*copyFunc)(TF_OpKernelContext* ctx, TF_Tensor* source, 67 TF_Tensor* dest), 68 TF_Status* status); 69 70 // Expose higher level Assignment operation for Pluggable vendors to implement 71 // in the plugin for Training on ref variables. The API takes in the context 72 // with indices for the input and value tensors. It also accepts the copy 73 // callback provided by pluggable vendor to do the copying of the tensors. The 74 // caller takes ownership of the `source` and `dest` tensors and is responsible 75 // for freeing them with TF_DeleteTensor. 76 TF_CAPI_EXPORT extern void TF_AssignRefVariable( 77 TF_OpKernelContext* ctx, int input_ref_index, int output_ref_index, 78 int value_index, bool use_locking, bool validate_shape, 79 void (*copyFunc)(TF_OpKernelContext* ctx, TF_Tensor* source, 80 TF_Tensor* dest), 81 TF_Status* status); 82 83 // Expose higher level AssignUpdate operation for Pluggable vendors to implement 84 // in the plugin for Training. The API takes in the context with indices for the 85 // input and value tensors. It also accepts the copy callback provided by 86 // pluggable vendor to do the copying of the tensors and the update callback to 87 // apply the arithmetic operation. The caller takes ownership of the `source`, 88 // `dest`, `tensor` and `value` tensors and is responsible for freeing them with 89 // TF_DeleteTensor. 90 TF_CAPI_EXPORT extern void TF_AssignUpdateVariable( 91 TF_OpKernelContext* ctx, int input_index, int value_index, int Op, 92 int isVariantType, 93 void (*copyFunc)(TF_OpKernelContext* ctx, TF_Tensor* source, 94 TF_Tensor* dest), 95 void (*updateFunc)(TF_OpKernelContext* ctx, TF_Tensor* tensor, 96 TF_Tensor* value, int Op), 97 TF_Status* status); 98 99 // This is a helper function which acquires mutexes in-order to provide 100 // thread-safe way of performing weights update during the optimizer op. It 101 // returns an opaque LockHolder handle back to plugin. This handle is passed to 102 // the Release API for releasing the locks when the weight update is done. The 103 // caller takes ownership of the `source` and `dest` tensors and is responsible 104 // for freeing them with TF_DeleteTensor. 105 TF_CAPI_EXPORT extern void TF_MaybeLockVariableInputMutexesInOrder( 106 TF_OpKernelContext* ctx, bool do_lock, bool sparse, const int* const inputs, 107 size_t len, 108 void (*copyFunc)(TF_OpKernelContext* ctx, TF_Tensor* source, 109 TF_Tensor* dest), 110 TF_VariableInputLockHolder** lockHolder, TF_Status* status); 111 112 // This interface returns `out` tensor which is updated corresponding to the 113 // variable passed with input index. The caller takes ownership of the `source` 114 // and `dest` tensors and is responsible for freeing them with TF_DeleteTensor. 115 TF_CAPI_EXPORT extern void TF_GetInputTensorFromVariable( 116 TF_OpKernelContext* ctx, int input, bool lock_held, bool isVariantType, 117 bool sparse, 118 void (*copyFunc)(TF_OpKernelContext* ctx, TF_Tensor* source, 119 TF_Tensor* dest), 120 TF_Tensor** out, TF_Status* status); 121 122 // This interface forwards the reference from input to the output tensors 123 // corresponding to the indices provided with `input_index` and `output_index` 124 TF_CAPI_EXPORT extern void TF_OpKernelContext_ForwardRefInputToRefOutput( 125 TF_OpKernelContext* ctx, int32_t input_index, int32_t output_index); 126 127 // The API releases the opaque lock handle returned with 128 // `TF_MaybeLockVariableInputMutexesInOrder` API 129 TF_CAPI_EXPORT extern void TF_ReleaseVariableInputLockHolder( 130 TF_VariableInputLockHolder* lockHolder); 131 132 // Allows plugin to get TF_Tensor when passed its input_name 133 TF_CAPI_EXPORT extern void TF_GetInputByName(TF_OpKernelContext* ctx, 134 const char* inputName, 135 TF_Tensor** tensor, 136 TF_Status* status); 137 138 // Interprets the named kernel construction attribute as a shape attribute and 139 // fills in `vals` with the size of each dimension. `vals` must point to an 140 // array of length at least `max_values` (ideally set to total_size from 141 // TF_OpKernelConstruction_GetAttrSize(ctx, attr_name, &list_size, 142 // &total_size)). 143 TF_CAPI_EXPORT extern void TF_OpKernelConstruction_GetAttrTensorShape( 144 TF_OpKernelConstruction* ctx, const char* attr_name, int64_t* dims, 145 size_t num_dims, TF_Status* status); 146 147 TF_CAPI_EXPORT extern bool TF_IsRefInput(TF_OpKernelContext* ctx, int i, 148 TF_Status* status); 149 150 #ifndef IS_MOBILE_PLATFORM 151 // Expose higher level AddN operation for Pluggable vendors to implement 152 // in the plugin for Variant data types. The API takes in the context and a 153 // callback provided by pluggable vendor to do a Binary Add operation on the 154 // tensors unwrapped from the Variant tensors. The caller takes ownership of the 155 // `a`, `b` and `out` tensors and is responsible for freeing them with 156 // TF_DeleteTensor. 157 TF_CAPI_EXPORT extern void TF_AddNVariant( 158 TF_OpKernelContext* ctx, 159 void (*binary_add_func)(TF_OpKernelContext* ctx, TF_Tensor* a, TF_Tensor* b, 160 TF_Tensor* out), 161 TF_Status* status); 162 163 // Expose higher level ZerosLike operation for Pluggable vendors to implement 164 // in the plugin for Variant data types. The API takes in the context and a 165 // callback provided by pluggable vendor to do a ZerosLike operation on the 166 // tensors unwrapped from the Variant tensors. The caller takes ownership of the 167 // `input` and `out` tensors and is responsible for freeing them with 168 // TF_DeleteTensor. 169 TF_CAPI_EXPORT extern void TF_ZerosLikeVariant( 170 TF_OpKernelContext* ctx, 171 void (*zeros_like_func)(TF_OpKernelContext* ctx, TF_Tensor* input, 172 TF_Tensor* out), 173 TF_Status* status); 174 #endif // IS_MOBILE_PLATFORM 175 176 #ifdef __cplusplus 177 } /* end extern "C" */ 178 #endif 179 180 #endif // TENSORFLOW_C_KERNELS_EXPERIMENTAL_H_ 181