xref: /aosp_15_r20/external/tensorflow/tensorflow/c/kernels.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/c/kernels.h"
17 
18 #include <memory>
19 
20 #include "tensorflow/c/c_api_internal.h"
21 #include "tensorflow/c/c_api_macros.h"
22 #include "tensorflow/c/tf_buffer_internal.h"
23 #include "tensorflow/c/tf_status_helper.h"
24 #include "tensorflow/c/tf_tensor_internal.h"
25 #include "tensorflow/core/framework/attr_value.pb.h"
26 #include "tensorflow/core/framework/kernel_def_builder.h"
27 #include "tensorflow/core/framework/op_kernel.h"
28 #include "tensorflow/core/framework/register_types.h"
29 #include "tensorflow/core/framework/resource_mgr.h"
30 #include "tensorflow/core/framework/types.h"
31 // Required for IS_MOBILE_PLATFORM definition
32 #include "tensorflow/core/platform/platform.h"
33 #include "tensorflow/core/platform/types.h"
34 #if !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD)
35 #include "tensorflow/c/experimental/stream_executor/stream_executor_internal.h"
36 #include "tensorflow/stream_executor/stream.h"
37 #endif  // !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD)
38 
39 using tensorflow::errors::InvalidArgument;
40 // This file forms the basis of a stable ABI for third-party kernel
41 // implementations. It is crucial that changes to this file are made cautiously
42 // and with a focus on maintaining both source and binary compatibility.
43 
44 struct TF_KernelBuilder {
45   ::tensorflow::KernelDefBuilder* cc_builder;
46 
47   void* (*create_function)(TF_OpKernelConstruction*);
48   void (*compute_function)(void*, TF_OpKernelContext*);
49   void (*delete_function)(void*);
50 };
51 
TF_NewKernelBuilder(const char * op_name,const char * device_name,void * (* create_func)(TF_OpKernelConstruction *),void (* compute_func)(void *,TF_OpKernelContext *),void (* delete_func)(void *))52 TF_KernelBuilder* TF_NewKernelBuilder(
53     const char* op_name, const char* device_name,
54     void* (*create_func)(TF_OpKernelConstruction*),
55     void (*compute_func)(void*, TF_OpKernelContext*),
56     void (*delete_func)(void*)) {
57   TF_KernelBuilder* result = new TF_KernelBuilder;
58   result->cc_builder = new ::tensorflow::KernelDefBuilder(op_name);
59   result->cc_builder->Device(device_name);
60   result->create_function = create_func;
61   result->compute_function = compute_func;
62   result->delete_function = delete_func;
63   return result;
64 }
65 
TF_DeleteKernelBuilder(TF_KernelBuilder * builder)66 void TF_DeleteKernelBuilder(TF_KernelBuilder* builder) {
67   if (builder != nullptr) {
68     delete builder->cc_builder;
69     delete builder;
70   }
71 }
72 
73 namespace tensorflow {
74 namespace {
75 
76 #define CASE(type)                                               \
77   case DataTypeToEnum<type>::value: {                            \
78     kernel_builder->cc_builder->TypeConstraint<type>(attr_name); \
79     break;                                                       \
80   }
81 
AddTypeConstraint(TF_KernelBuilder * kernel_builder,const char * attr_name,const DataType dtype,TF_Status * status)82 void AddTypeConstraint(TF_KernelBuilder* kernel_builder, const char* attr_name,
83                        const DataType dtype, TF_Status* status) {
84   // This needs to be under tensorflow:: namespace so that
85   // TF_CALL_ALL_TYPES macro can find tensorflow::string as string.
86   switch (dtype) {
87     TF_CALL_ALL_TYPES(CASE);
88     TF_CALL_QUANTIZED_TYPES(CASE);
89     TF_CALL_quint16(CASE);
90     TF_CALL_qint16(CASE);
91     default:
92       status->status = errors::Unimplemented("Unexpected type ", dtype);
93       return;
94   }
95   TF_SetStatus(status, TF_OK, "");
96 }
97 #undef CASE
98 
99 }  // namespace
100 }  // namespace tensorflow
101 
102 namespace {
GetAttrValue(TF_OpKernelConstruction * ctx,const char * attr_name,TF_Status * status)103 const tensorflow::AttrValue* GetAttrValue(TF_OpKernelConstruction* ctx,
104                                           const char* attr_name,
105                                           TF_Status* status) {
106   auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelConstruction*>(ctx);
107   const tensorflow::AttrValue* attr =
108       ::tensorflow::AttrSlice(cc_ctx->def()).Find(attr_name);
109   if (attr == nullptr) {
110     status->status = InvalidArgument("Operation '", cc_ctx->def().name(),
111                                      "' has no attr named '", attr_name, "'.");
112   }
113   return attr;
114 }
115 }  // namespace
116 
TF_KernelBuilder_TypeConstraint(TF_KernelBuilder * kernel_builder,const char * attr_name,const TF_DataType type,TF_Status * status)117 void TF_KernelBuilder_TypeConstraint(TF_KernelBuilder* kernel_builder,
118                                      const char* attr_name,
119                                      const TF_DataType type,
120                                      TF_Status* status) {
121   tensorflow::DataType dtype = static_cast<tensorflow::DataType>(type);
122   tensorflow::AddTypeConstraint(kernel_builder, attr_name, dtype, status);
123 }
124 
TF_KernelBuilder_HostMemory(TF_KernelBuilder * kernel_builder,const char * arg_name)125 void TF_KernelBuilder_HostMemory(TF_KernelBuilder* kernel_builder,
126                                  const char* arg_name) {
127   kernel_builder->cc_builder->HostMemory(arg_name);
128 }
129 
TF_KernelBuilder_Priority(TF_KernelBuilder * kernel_builder,int32_t priority_number)130 void TF_KernelBuilder_Priority(TF_KernelBuilder* kernel_builder,
131                                int32_t priority_number) {
132   kernel_builder->cc_builder->Priority(priority_number);
133 }
134 
TF_KernelBuilder_Label(TF_KernelBuilder * kernel_builder,const char * label)135 void TF_KernelBuilder_Label(TF_KernelBuilder* kernel_builder,
136                             const char* label) {
137   kernel_builder->cc_builder->Label(label);
138 }
139 
140 namespace tensorflow {
141 namespace {
142 
143 // An OpKernel whose methods delegate to C function pointers.
144 class COpKernel : public OpKernel {
145  public:
COpKernel(OpKernelConstruction * ctx,void * (* create_func)(TF_OpKernelConstruction *),void (* compute_func)(void *,TF_OpKernelContext *),void (* delete_func)(void *))146   explicit COpKernel(OpKernelConstruction* ctx,
147                      void* (*create_func)(TF_OpKernelConstruction*),
148                      void (*compute_func)(void*, TF_OpKernelContext*),
149                      void (*delete_func)(void*))
150       : OpKernel(ctx), compute_func_(compute_func), delete_func_(delete_func) {
151     if (create_func != nullptr) {
152       c_kernel_ =
153           (*create_func)(reinterpret_cast<TF_OpKernelConstruction*>(ctx));
154     } else {
155       c_kernel_ = nullptr;
156     }
157   }
158 
Compute(OpKernelContext * ctx)159   void Compute(OpKernelContext* ctx) override {
160     (*compute_func_)(c_kernel_, reinterpret_cast<TF_OpKernelContext*>(ctx));
161   }
162 
~COpKernel()163   ~COpKernel() override {
164     if (delete_func_ != nullptr) {
165       (*delete_func_)(c_kernel_);
166     }
167   }
168 
169  private:
170   void (*compute_func_)(void*, TF_OpKernelContext* context);
171   void (*delete_func_)(void*);
172   void* c_kernel_;
173 };
174 
175 // A KernelFactory that returns COpKernel instances.
176 class KernelBuilderFactory
177     : public ::tensorflow::kernel_factory::OpKernelFactory {
178  public:
KernelBuilderFactory(TF_KernelBuilder * builder)179   explicit KernelBuilderFactory(TF_KernelBuilder* builder)
180       : builder_(builder) {}
Create(::tensorflow::OpKernelConstruction * context)181   ::tensorflow::OpKernel* Create(
182       ::tensorflow::OpKernelConstruction* context) override {
183     return new ::tensorflow::COpKernel(context, builder_->create_function,
184                                        builder_->compute_function,
185                                        builder_->delete_function);
186   }
~KernelBuilderFactory()187   ~KernelBuilderFactory() override { TF_DeleteKernelBuilder(builder_); }
188 
189  private:
190   TF_KernelBuilder* builder_;
191 };
192 }  // namespace
193 }  // namespace tensorflow
194 
TF_RegisterKernelBuilder(const char * name,TF_KernelBuilder * builder,TF_Status * status)195 void TF_RegisterKernelBuilder(const char* name, TF_KernelBuilder* builder,
196                               TF_Status* status) {
197   using tensorflow::register_kernel::Name;
198 
199   TF_RegisterKernelBuilderWithKernelDef(
200       /*serialized_kernel_def=*/nullptr, name, builder, status);
201 }
202 
TF_RegisterKernelBuilderWithKernelDef(const char * serialized_kernel_def,const char * name,TF_KernelBuilder * builder,TF_Status * status)203 void TF_RegisterKernelBuilderWithKernelDef(const char* serialized_kernel_def,
204                                            const char* name,
205                                            TF_KernelBuilder* builder,
206                                            TF_Status* status) {
207   using tensorflow::register_kernel::Name;
208   if (serialized_kernel_def == nullptr) {
209     // If user doesn't provide a serialized KernelDef, use the kernel builder
210     // to build a new one.
211     tensorflow::kernel_factory::OpKernelRegistrar(
212         builder->cc_builder->Build(), name,
213         std::make_unique<tensorflow::KernelBuilderFactory>(builder));
214 
215     TF_SetStatus(status, TF_OK, "");
216     return;
217   }
218 
219   tensorflow::KernelDef* kernel_def = new tensorflow::KernelDef();
220   bool success = kernel_def->ParsePartialFromString(serialized_kernel_def);
221   if (!success) {
222     TF_SetStatus(status, TF_INVALID_ARGUMENT,
223                  "Error parsing serialized KernelDef.");
224     return;
225   }
226 
227   tensorflow::kernel_factory::OpKernelRegistrar(
228       kernel_def, name,
229       std::make_unique<tensorflow::KernelBuilderFactory>(builder));
230 
231   TF_SetStatus(status, TF_OK, "");
232 }
233 
234 // This function is only for pluggable device.
235 // It will return nullptr in all other cases.
236 // This function is experimental and subject to change.
TF_GetStream(TF_OpKernelContext * ctx,TF_Status * status)237 SP_Stream TF_GetStream(TF_OpKernelContext* ctx, TF_Status* status) {
238 #if defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD)
239   status->status = tensorflow::errors::Unimplemented(
240       "Accessing device stream is not supported on mobile. File a bug at "
241       "https://github.com/tensorflow/tensorflow/issues if this feature is "
242       "important to you");
243   return nullptr;
244 #else
245   auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelContext*>(ctx);
246   if (cc_ctx->op_device_context() == nullptr) {  // CPU Device
247     status->status = tensorflow::errors::FailedPrecondition(
248         "Accessing device stream is not supported for a CPU device.");
249     return nullptr;
250   } else if (!cc_ctx->op_device_context()->IsPluggableDevice()) {
251     status->status = tensorflow::errors::FailedPrecondition(
252         "Accessing device stream is only supported for pluggable devices.");
253     return nullptr;
254   } else {  // Is a PluggableDevice
255     TF_SetStatus(status, TF_OK, "");
256     auto c_stream = static_cast<stream_executor::CStream*>(
257         cc_ctx->op_device_context()->stream()->implementation());
258     return c_stream->Handle();
259   }
260 #endif  // defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD)
261 }
262 
TF_NumInputs(TF_OpKernelContext * ctx)263 int TF_NumInputs(TF_OpKernelContext* ctx) {
264   auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelContext*>(ctx);
265   return cc_ctx->num_inputs();
266 }
267 
TF_NumOutputs(TF_OpKernelContext * ctx)268 int TF_NumOutputs(TF_OpKernelContext* ctx) {
269   auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelContext*>(ctx);
270   return cc_ctx->num_outputs();
271 }
272 
TF_GetInput(TF_OpKernelContext * ctx,int i,TF_Tensor ** tensor,TF_Status * status)273 void TF_GetInput(TF_OpKernelContext* ctx, int i, TF_Tensor** tensor,
274                  TF_Status* status) {
275   auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelContext*>(ctx);
276   if (i < 0 || i >= cc_ctx->num_inputs()) {
277     TF_SetStatus(status, TF_OUT_OF_RANGE, "input index out of range");
278     return;
279   }
280   const ::tensorflow::Tensor& cc_tensor(cc_ctx->input(i));
281   TF_Tensor* result =
282       ::tensorflow::TF_TensorFromTensor(cc_tensor, &status->status);
283   if (TF_GetCode(status) == TF_OK) {
284     *tensor = result;
285   }
286 }
287 
TF_InputRange(TF_OpKernelContext * ctx,const char * name,TF_InputRange_Args * args)288 void TF_InputRange(TF_OpKernelContext* ctx, const char* name,
289                    TF_InputRange_Args* args) {
290   auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelContext*>(ctx);
291   int start = -1, stop = -1;
292   auto status = cc_ctx->op_kernel().InputRange(name, &start, &stop);
293   args->start = start;
294   args->stop = stop;
295   tensorflow::Set_TF_Status_from_Status(args->status, status);
296 }
297 
TF_SetOutput(TF_OpKernelContext * ctx,int i,const TF_Tensor * tensor,TF_Status * status)298 void TF_SetOutput(TF_OpKernelContext* ctx, int i, const TF_Tensor* tensor,
299                   TF_Status* status) {
300   auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelContext*>(ctx);
301   if (i < 0 || i >= cc_ctx->num_outputs()) {
302     TF_SetStatus(status, TF_OUT_OF_RANGE, "output index out of range");
303     return;
304   }
305   ::tensorflow::Tensor cc_tensor;
306   ::tensorflow::Status s = ::tensorflow::TF_TensorToTensor(tensor, &cc_tensor);
307   TF_SetStatus(status, TF_OK, "");
308   ::tensorflow::Set_TF_Status_from_Status(status, s);
309   if (s.ok()) {
310     cc_ctx->set_output(i, cc_tensor);
311   }
312 }
313 
TF_GetMutableOutput(TF_OpKernelContext * ctx,int i,TF_Status * status)314 TF_Tensor* TF_GetMutableOutput(TF_OpKernelContext* ctx, int i,
315                                TF_Status* status) {
316   auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelContext*>(ctx);
317   if (i < 0 || i >= cc_ctx->num_outputs()) {
318     TF_SetStatus(status, TF_OUT_OF_RANGE, "output index out of range");
319     return nullptr;
320   }
321   const ::tensorflow::Tensor& cc_tensor = *(cc_ctx->mutable_output(i));
322   TF_Tensor* result =
323       ::tensorflow::TF_TensorFromTensor(cc_tensor, &status->status);
324   if (TF_GetCode(status) == TF_OK) {
325     return result;
326   } else {
327     return nullptr;
328   }
329 }
330 
TF_GetSerializedFunctionDefLibrary(TF_OpKernelContext * ctx,TF_Buffer * serialized_function_def_library,TF_Status * status)331 void TF_GetSerializedFunctionDefLibrary(
332     TF_OpKernelContext* ctx, TF_Buffer* serialized_function_def_library,
333     TF_Status* status) {
334   auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelContext*>(ctx);
335   auto fdef_lib =
336       cc_ctx->function_library()->GetFunctionLibraryDefinition()->ToProto();
337   auto cc_status =
338       tensorflow::MessageToBuffer(fdef_lib, serialized_function_def_library);
339   tensorflow::Set_TF_Status_from_Status(status, cc_status);
340 }
341 
TF_GetSerializedConfigProto(TF_OpKernelContext * ctx,TF_Buffer * serialized_config_proto,TF_Status * status)342 void TF_GetSerializedConfigProto(TF_OpKernelContext* ctx,
343                                  TF_Buffer* serialized_config_proto,
344                                  TF_Status* status) {
345   auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelContext*>(ctx);
346   const tensorflow::ConfigProto* config_proto_ptr =
347       cc_ctx->function_library()->config_proto();
348   tensorflow::ConfigProto config_proto;
349   if (config_proto_ptr != nullptr) {
350     config_proto = *config_proto_ptr;
351   }
352   auto cc_status =
353       tensorflow::MessageToBuffer(config_proto, serialized_config_proto);
354   tensorflow::Set_TF_Status_from_Status(status, cc_status);
355 }
356 
TF_OpKernelConstruction_Failure(TF_OpKernelConstruction * ctx,TF_Status * status)357 void TF_OpKernelConstruction_Failure(TF_OpKernelConstruction* ctx,
358                                      TF_Status* status) {
359   auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelConstruction*>(ctx);
360   ::tensorflow::Status s(::tensorflow::StatusFromTF_Status(status));
361   cc_ctx->CtxFailure(s);
362 }
363 
TF_OpKernelContext_Failure(TF_OpKernelContext * ctx,TF_Status * status)364 void TF_OpKernelContext_Failure(TF_OpKernelContext* ctx, TF_Status* status) {
365   auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelContext*>(ctx);
366   ::tensorflow::Status s(::tensorflow::StatusFromTF_Status(status));
367   cc_ctx->CtxFailure(s);
368 }
369 
TF_OpKernelConstruction_GetAttrSize(TF_OpKernelConstruction * ctx,const char * attr_name,int32_t * list_size,int32_t * total_size,TF_Status * status)370 void TF_OpKernelConstruction_GetAttrSize(TF_OpKernelConstruction* ctx,
371                                          const char* attr_name,
372                                          int32_t* list_size,
373                                          int32_t* total_size,
374                                          TF_Status* status) {
375   const tensorflow::AttrValue* attr = GetAttrValue(ctx, attr_name, status);
376   if (!status->status.ok()) {
377     *list_size = -1;
378     *total_size = -1;
379     return;
380   }
381   switch (attr->value_case()) {
382 #define SINGLE_CASE(kK, attr_type, size_expr) \
383   case tensorflow::AttrValue::kK:             \
384     *list_size = -1;                          \
385     *total_size = size_expr;                  \
386     break;
387 
388     SINGLE_CASE(kS, TF_ATTR_STRING, attr->s().length());
389     SINGLE_CASE(kI, TF_ATTR_INT, -1);
390     SINGLE_CASE(kF, TF_ATTR_FLOAT, -1);
391     SINGLE_CASE(kB, TF_ATTR_BOOL, -1);
392     SINGLE_CASE(kType, TF_ATTR_TYPE, -1);
393     SINGLE_CASE(kShape, TF_ATTR_SHAPE,
394                 attr->shape().unknown_rank() ? -1 : attr->shape().dim_size());
395     SINGLE_CASE(kTensor, TF_ATTR_TENSOR, -1);
396 #undef SINGLE_CASE
397 
398     case tensorflow::AttrValue::kList:
399       *list_size = 0;
400       *total_size = -1;
401 #define LIST_CASE(field, attr_type, ...)      \
402   if (attr->list().field##_size() > 0) {      \
403     *list_size = attr->list().field##_size(); \
404     __VA_ARGS__;                              \
405     break;                                    \
406   }
407 
408       LIST_CASE(
409           s, TF_ATTR_STRING, *total_size = 0;
410           for (int i = 0; i < attr->list().s_size();
411                ++i) { *total_size += attr->list().s(i).size(); });
412       LIST_CASE(i, TF_ATTR_INT);
413       LIST_CASE(f, TF_ATTR_FLOAT);
414       LIST_CASE(b, TF_ATTR_BOOL);
415       LIST_CASE(type, TF_ATTR_TYPE);
416       LIST_CASE(
417           shape, TF_ATTR_SHAPE, *total_size = 0;
418           for (int i = 0; i < attr->list().shape_size(); ++i) {
419             const auto& s = attr->list().shape(i);
420             *total_size += s.unknown_rank() ? 0 : s.dim_size();
421           });
422       LIST_CASE(tensor, TF_ATTR_TENSOR);
423       LIST_CASE(tensor, TF_ATTR_FUNC);
424 #undef LIST_CASE
425       break;
426 
427     case tensorflow::AttrValue::kPlaceholder:
428       *list_size = -1;
429       *total_size = -1;
430       break;
431 
432     case tensorflow::AttrValue::kFunc:
433       *list_size = -1;
434       *total_size = -1;
435       break;
436 
437     case tensorflow::AttrValue::VALUE_NOT_SET:
438       status->status =
439           InvalidArgument("Attribute '", attr_name, "' has no value set");
440       break;
441   }
442 }
443 
444 #define DEFINE_TF_GETATTR(func, c_type, cc_type, attr_type, list_field)        \
445   void TF_OpKernelConstruction_GetAttr##func(TF_OpKernelConstruction* ctx,     \
446                                              const char* attr_name,            \
447                                              c_type* val, TF_Status* status) { \
448     TF_SetStatus(status, TF_OK, "");                                           \
449     cc_type v;                                                                 \
450     auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelConstruction*>(ctx); \
451     ::tensorflow::Status s = cc_ctx->GetAttr(attr_name, &v);                   \
452     ::tensorflow::Set_TF_Status_from_Status(status, s);                        \
453     if (s.ok()) {                                                              \
454       *val = static_cast<c_type>(v);                                           \
455     }                                                                          \
456   }                                                                            \
457   void TF_OpKernelConstruction_GetAttr##func##List(                            \
458       TF_OpKernelConstruction* ctx, const char* attr_name, c_type* vals,       \
459       int max_vals, TF_Status* status) {                                       \
460     TF_SetStatus(status, TF_OK, "");                                           \
461     const tensorflow::AttrValue* attr = GetAttrValue(ctx, attr_name, status);  \
462     if (!status->status.ok()) return;                                          \
463     if (attr->value_case() != tensorflow::AttrValue::kList) {                  \
464       status->status =                                                         \
465           InvalidArgument("Value for '", attr_name, "' is not a list.");       \
466       return;                                                                  \
467     }                                                                          \
468     status->status =                                                           \
469         tensorflow::AttrValueHasType(*attr, "list(" attr_type ")");            \
470     if (!status->status.ok()) return;                                          \
471     const auto len = std::min(max_vals, attr->list().list_field##_size());     \
472     for (int i = 0; i < len; ++i) {                                            \
473       vals[i] = static_cast<c_type>(attr->list().list_field(i));               \
474     }                                                                          \
475   }
476 
477 DEFINE_TF_GETATTR(Type, TF_DataType, tensorflow::DataType, "type", type)
478 DEFINE_TF_GETATTR(Int32, int32_t, int32_t, "int", i)
479 DEFINE_TF_GETATTR(Int64, int64_t, int64_t, "int", i)
480 DEFINE_TF_GETATTR(Float, float, float, "float", f)
481 DEFINE_TF_GETATTR(Bool, TF_Bool, bool, "bool", b)
482 
TF_OpKernelConstruction_GetAttrString(TF_OpKernelConstruction * ctx,const char * attr_name,char * value,size_t max_length,TF_Status * status)483 void TF_OpKernelConstruction_GetAttrString(TF_OpKernelConstruction* ctx,
484                                            const char* attr_name, char* value,
485                                            size_t max_length,
486                                            TF_Status* status) {
487   std::string v;
488   auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelConstruction*>(ctx);
489   ::tensorflow::Status s = cc_ctx->GetAttr(attr_name, &v);
490   ::tensorflow::Set_TF_Status_from_Status(status, s);
491 
492   if (!status->status.ok()) return;
493 
494   if (max_length <= 0) {
495     return;
496   }
497   std::memcpy(value, v.data(), std::min<size_t>(v.length(), max_length));
498 }
499 
TF_OpKernelConstruction_GetAttrStringList(TF_OpKernelConstruction * ctx,const char * attr_name,char ** values,size_t * lengths,int max_values,void * storage,size_t storage_size,TF_Status * status)500 void TF_OpKernelConstruction_GetAttrStringList(TF_OpKernelConstruction* ctx,
501                                                const char* attr_name,
502                                                char** values, size_t* lengths,
503                                                int max_values, void* storage,
504                                                size_t storage_size,
505                                                TF_Status* status) {
506   std::vector<std::string> v;
507   auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelConstruction*>(ctx);
508   ::tensorflow::Status s = cc_ctx->GetAttr(attr_name, &v);
509   ::tensorflow::Set_TF_Status_from_Status(status, s);
510 
511   if (!status->status.ok()) return;
512 
513   const auto len = std::min(max_values, static_cast<int>(v.size()));
514   char* p = static_cast<char*>(storage);
515   for (int i = 0; i < len; ++i) {
516     const std::string& s = v[i];
517     values[i] = p;
518     lengths[i] = s.size();
519     if ((p + s.size()) > (static_cast<char*>(storage) + storage_size)) {
520       status->status = InvalidArgument(
521           "Not enough storage to hold the requested list of strings");
522       return;
523     }
524     memcpy(values[i], s.data(), s.size());
525     p += s.size();
526   }
527 }
528 
TF_OpKernelConstruction_GetAttrTensor(TF_OpKernelConstruction * ctx,const char * attr_name,TF_Tensor ** val,TF_Status * status)529 void TF_OpKernelConstruction_GetAttrTensor(TF_OpKernelConstruction* ctx,
530                                            const char* attr_name,
531                                            TF_Tensor** val, TF_Status* status) {
532   *val = nullptr;
533   ::tensorflow::Tensor t;
534   auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelConstruction*>(ctx);
535   ::tensorflow::Status s = cc_ctx->GetAttr(attr_name, &t);
536   ::tensorflow::Set_TF_Status_from_Status(status, s);
537 
538   if (!status->status.ok()) return;
539 
540   *val = TF_TensorFromTensor(t, &status->status);
541 }
542 
TF_OpKernelConstruction_GetAttrTensorList(TF_OpKernelConstruction * ctx,const char * attr_name,TF_Tensor ** vals,int max_values,TF_Status * status)543 void TF_OpKernelConstruction_GetAttrTensorList(TF_OpKernelConstruction* ctx,
544                                                const char* attr_name,
545                                                TF_Tensor** vals, int max_values,
546                                                TF_Status* status) {
547   std::vector<::tensorflow::Tensor> v;
548   auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelConstruction*>(ctx);
549   ::tensorflow::Status s = cc_ctx->GetAttr(attr_name, &v);
550   ::tensorflow::Set_TF_Status_from_Status(status, s);
551 
552   if (!status->status.ok()) return;
553 
554   const auto len = std::min(max_values, static_cast<int>(v.size()));
555   for (int i = 0; i < len; ++i) {
556     vals[i] = TF_TensorFromTensor(v[i], &status->status);
557     if (!status->status.ok()) return;
558   }
559 }
560 
TF_OpKernelConstruction_GetAttrFunction(TF_OpKernelConstruction * ctx,const char * attr_name,TF_Status * status)561 TF_Buffer* TF_OpKernelConstruction_GetAttrFunction(TF_OpKernelConstruction* ctx,
562                                                    const char* attr_name,
563                                                    TF_Status* status) {
564   auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelConstruction*>(ctx);
565   tensorflow::NameAttrList function;
566   auto cc_status = cc_ctx->GetAttr(attr_name, &function);
567   if (!cc_status.ok()) {
568     Set_TF_Status_from_Status(status, cc_status);
569     return nullptr;
570   }
571   TF_Buffer* buffer = TF_NewBuffer();
572   cc_status = tensorflow::MessageToBuffer(function, buffer);
573   Set_TF_Status_from_Status(status, cc_status);
574   if (!cc_status.ok())
575     return nullptr;
576   else
577     return buffer;
578 }
579 
TF_OpKernelConstruction_HasAttr(TF_OpKernelConstruction * ctx,const char * attr_name,TF_Status * status)580 bool TF_OpKernelConstruction_HasAttr(TF_OpKernelConstruction* ctx,
581                                      const char* attr_name, TF_Status* status) {
582   auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelConstruction*>(ctx);
583   return cc_ctx->HasAttr(attr_name);
584 }
585 
TF_OpKernelConstruction_GetName(TF_OpKernelConstruction * ctx)586 TF_StringView TF_OpKernelConstruction_GetName(TF_OpKernelConstruction* ctx) {
587   auto* cc_ctx = reinterpret_cast<tensorflow::OpKernelConstruction*>(ctx);
588   TF_StringView string_view_of_name;
589   string_view_of_name.data = cc_ctx->def().name().data();
590   string_view_of_name.len = cc_ctx->def().name().length();
591   return string_view_of_name;
592 }
593 
TF_ExpectedOutputDataType(TF_OpKernelContext * ctx,int i)594 TF_DataType TF_ExpectedOutputDataType(TF_OpKernelContext* ctx, int i) {
595   auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelContext*>(ctx);
596   CHECK_GE(i, 0);
597   CHECK_LT(i, cc_ctx->num_outputs());
598   return static_cast<TF_DataType>(cc_ctx->expected_output_dtype(i));
599 }
600 
TF_IsHostMemoryInput(TF_OpKernelContext * ctx,int i,TF_Status * status)601 bool TF_IsHostMemoryInput(TF_OpKernelContext* ctx, int i, TF_Status* status) {
602   auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelContext*>(ctx);
603   if (i < 0 || i >= cc_ctx->num_inputs()) {
604     TF_SetStatus(status, TF_OUT_OF_RANGE, "input index out of range");
605     return false;
606   }
607   TF_SetStatus(status, TF_OK, "");
608   return cc_ctx->input_memory_type(i) == tensorflow::HOST_MEMORY;
609 }
610 
TF_IsHostMemoryOutput(TF_OpKernelContext * ctx,int i,TF_Status * status)611 bool TF_IsHostMemoryOutput(TF_OpKernelContext* ctx, int i, TF_Status* status) {
612   auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelContext*>(ctx);
613   if (i < 0 || i >= cc_ctx->num_outputs()) {
614     TF_SetStatus(status, TF_OUT_OF_RANGE, "output index out of range");
615     return false;
616   }
617   TF_SetStatus(status, TF_OK, "");
618   return cc_ctx->output_memory_type(i) == tensorflow::HOST_MEMORY;
619 }
620 
TF_StepId(TF_OpKernelContext * ctx)621 int64_t TF_StepId(TF_OpKernelContext* ctx) {
622   return reinterpret_cast<::tensorflow::OpKernelContext*>(ctx)->step_id();
623 }
624 
TF_OpKernelConstruction_GetNodeDef(TF_OpKernelConstruction * ctx,TF_Status * status)625 TF_Buffer* TF_OpKernelConstruction_GetNodeDef(TF_OpKernelConstruction* ctx,
626                                               TF_Status* status) {
627   auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelConstruction*>(ctx);
628   TF_Buffer* ret = TF_NewBuffer();
629   status->status = MessageToBuffer(cc_ctx->def(), ret);
630   if (!status->status.ok()) {
631     TF_DeleteBuffer(ret);
632     return nullptr;
633   }
634   return ret;
635 }
636 
TF_GetFrameId(TF_OpKernelContext * ctx)637 uint64_t TF_GetFrameId(TF_OpKernelContext* ctx) {
638   return reinterpret_cast<::tensorflow::OpKernelContext*>(ctx)
639       ->frame_iter()
640       .frame_id;
641 }
642 
TF_GetGraphDefVersion(TF_OpKernelContext * ctx)643 int TF_GetGraphDefVersion(TF_OpKernelContext* ctx) {
644   return reinterpret_cast<::tensorflow::OpKernelContext*>(ctx)
645       ->function_library()
646       ->graph_def_version();
647 }
648 
TF_GetIterId(TF_OpKernelContext * ctx)649 int64_t TF_GetIterId(TF_OpKernelContext* ctx) {
650   return reinterpret_cast<::tensorflow::OpKernelContext*>(ctx)
651       ->frame_iter()
652       .iter_id;
653 }
654 
TF_GetOpKernelName(TF_OpKernelContext * ctx)655 TF_StringView TF_GetOpKernelName(TF_OpKernelContext* ctx) {
656   auto cc_ctx = reinterpret_cast<::tensorflow::OpKernelContext*>(ctx);
657   TF_StringView opkernel_name_sv;
658   opkernel_name_sv.data = cc_ctx->op_kernel().name().data();
659   opkernel_name_sv.len = cc_ctx->op_kernel().name().length();
660   return opkernel_name_sv;
661 }
662 
TF_GetResourceMgrDefaultContainerName(TF_OpKernelContext * ctx)663 TF_StringView TF_GetResourceMgrDefaultContainerName(TF_OpKernelContext* ctx) {
664   auto cc_ctx = reinterpret_cast<::tensorflow::OpKernelContext*>(ctx);
665   TF_StringView default_container_name_sv;
666   default_container_name_sv.data =
667       cc_ctx->resource_manager()->default_container().data();
668   default_container_name_sv.len =
669       cc_ctx->resource_manager()->default_container().length();
670   return default_container_name_sv;
671 }
672 
TF_GetOpKernelRequestedInput(TF_OpKernelContext * ctx,size_t index)673 TF_StringView TF_GetOpKernelRequestedInput(TF_OpKernelContext* ctx,
674                                            size_t index) {
675   auto cc_ctx = reinterpret_cast<::tensorflow::OpKernelContext*>(ctx);
676   TF_StringView requested_input_sv;
677   requested_input_sv.data = cc_ctx->op_kernel().requested_input(index).data();
678   requested_input_sv.len = cc_ctx->op_kernel().requested_input(index).length();
679   return requested_input_sv;
680 }
681 
TF_AllocateOutput(TF_OpKernelContext * context,int index,TF_DataType dtype,const int64_t * dims,int num_dims,size_t len,TF_Status * status)682 TF_Tensor* TF_AllocateOutput(TF_OpKernelContext* context, int index,
683                              TF_DataType dtype, const int64_t* dims,
684                              int num_dims, size_t len, TF_Status* status) {
685   TF_SetStatus(status, TF_OK, "");
686   auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelContext*>(context);
687   static_assert(sizeof(int64_t) == sizeof(int64_t),
688                 "64-bit int types should match in size");
689   tensorflow::gtl::ArraySlice<const int64_t> dimarray(
690       reinterpret_cast<const int64_t*>(dims), num_dims);
691   tensorflow::Tensor* tensor;
692   tensorflow::Status s = cc_ctx->allocate_output(
693       index, tensorflow::TensorShape(dimarray), &tensor);
694   if (!s.ok()) {
695     ::tensorflow::Set_TF_Status_from_Status(status, s);
696     return nullptr;
697   }
698   TF_Tensor* tf_tensor = TF_TensorFromTensor(*tensor, &s);
699   if (!s.ok()) {
700     ::tensorflow::Set_TF_Status_from_Status(status, s);
701     return nullptr;
702   }
703   return tf_tensor;
704 }
705 
TF_ForwardInputOrAllocateOutput(TF_OpKernelContext * context,const int * candidate_input_indices,int num_candidate_input_indices,int output_index,const int64_t * output_dims,int output_num_dims,int * forwarded_input,TF_Status * status)706 TF_Tensor* TF_ForwardInputOrAllocateOutput(
707     TF_OpKernelContext* context, const int* candidate_input_indices,
708     int num_candidate_input_indices, int output_index,
709     const int64_t* output_dims, int output_num_dims, int* forwarded_input,
710     TF_Status* status) {
711   TF_SetStatus(status, TF_OK, "");
712   auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelContext*>(context);
713 
714   static_assert(sizeof(int64_t) == sizeof(int64_t),
715                 "64-bit int types should match in size");
716   tensorflow::gtl::ArraySlice<int> input_indices_array(
717       candidate_input_indices, num_candidate_input_indices);
718   tensorflow::gtl::ArraySlice<const int64_t> output_dimarray(
719       reinterpret_cast<const int64_t*>(output_dims), output_num_dims);
720   tensorflow::Tensor* output_tensor_pointer;
721   tensorflow::Status s = cc_ctx->forward_input_or_allocate_output(
722       input_indices_array, output_index,
723       tensorflow::TensorShape(output_dimarray), &output_tensor_pointer,
724       forwarded_input);
725   if (!s.ok()) {
726     ::tensorflow::Set_TF_Status_from_Status(status, s);
727     return nullptr;
728   }
729   TF_Tensor* tf_tensor_output = TF_TensorFromTensor(*output_tensor_pointer, &s);
730   if (!s.ok()) {
731     ::tensorflow::Set_TF_Status_from_Status(status, s);
732     return nullptr;
733   }
734   return tf_tensor_output;
735 }
736 
TF_AllocateTemp(TF_OpKernelContext * context,TF_DataType dtype,const int64_t * dims,int num_dims,TF_AllocatorAttributes * attributes,TF_Status * status)737 TF_Tensor* TF_AllocateTemp(TF_OpKernelContext* context, TF_DataType dtype,
738                            const int64_t* dims, int num_dims,
739                            TF_AllocatorAttributes* attributes,
740                            TF_Status* status) {
741   auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelContext*>(context);
742   TF_SetStatus(status, TF_OK, "");
743   static_assert(sizeof(int64_t) == sizeof(int64_t),
744                 "64-bit int types should match in size");
745   tensorflow::gtl::ArraySlice<const int64_t> dimarray(
746       reinterpret_cast<const int64_t*>(dims), num_dims);
747   if (attributes && !attributes->struct_size) {
748     TF_SetStatus(
749         status, TF_INVALID_ARGUMENT,
750         "TF_AllocatorAttributes struct "
751         "size member must be set to TF_ALLOCATOR_ATTRIBUTES_STRUCT_SIZE");
752     return nullptr;
753   }
754   tensorflow::AllocatorAttributes allocator_attr;
755   if (attributes && attributes->on_host) {
756     allocator_attr.set_on_host(true);
757   }
758   tensorflow::Status s;
759   tensorflow::Tensor tensor;
760   s = cc_ctx->allocate_temp(static_cast<tensorflow::DataType>(dtype),
761                             tensorflow::TensorShape(dimarray), &tensor,
762                             allocator_attr);
763   if (!s.ok()) {
764     ::tensorflow::Set_TF_Status_from_Status(status, s);
765     return nullptr;
766   }
767   TF_Tensor* tf_tensor;
768   tf_tensor = TF_TensorFromTensor(tensor, &s);
769   if (!s.ok()) {
770     ::tensorflow::Set_TF_Status_from_Status(status, s);
771     return nullptr;
772   }
773   return tf_tensor;
774 }
775