/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/c/kernels_experimental.h"

#include <algorithm>
#include <string>
#include <utility>

#include "tensorflow/c/tf_status_helper.h"
#include "tensorflow/c/tf_status_internal.h"
#include "tensorflow/c/tf_tensor_internal.h"
#include "tensorflow/core/framework/ref_var.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/framework/resource_var.h"
#include "tensorflow/core/framework/variant.h"

#ifndef IS_MOBILE_PLATFORM
#include "tensorflow/core/kernels/data/optional_ops_util.h"
#include "tensorflow/core/kernels/tensor_list.h"
#include "tensorflow/core/kernels/tensor_list_util.h"
#include "tensorflow/core/kernels/variant_ops_util.h"
#include "tensorflow/core/platform/abi.h"
#endif  // IS_MOBILE_PLATFORM

#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/refcount.h"

using tensorflow::AllocatorAttributes;
using tensorflow::mutex_lock;
using tensorflow::Status;
using tensorflow::Tensor;
using tensorflow::TF_TensorFromTensor;
using tensorflow::Var;
using tensorflow::Variant;
using tensorflow::errors::InvalidArgument;

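// Holds the variables and the exclusive/shared locks taken by
// TF_MaybeLockVariableInputMutexesInOrder so that they can be released later
// via TF_ReleaseVariableInputLockHolder. The holder owns a reference to each
// Var, which is dropped on release.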
struct TF_VariableInputLockHolder {
  TF_VariableInputLockHolder(
      std::vector<tensorflow::Var*> vars,
      std::unique_ptr<std::vector<tensorflow::mutex_lock>> locks,
      std::unique_ptr<std::vector<tensorflow::tf_shared_lock>> shared_locks)
      : vars(std::move(vars)),
        locks(std::move(locks)),
        shared_locks(std::move(shared_locks)) {}

  std::vector<tensorflow::Var*> vars;
  std::unique_ptr<std::vector<tensorflow::mutex_lock>> locks;
  std::unique_ptr<std::vector<tensorflow::tf_shared_lock>> shared_locks;
};

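// Switches `var` into copy-on-read mode. If other readers still hold
// references to the variable's buffer, the tensor is copied first
// (element-by-element on the host for variants, via `copyFunc` otherwise).
// Acquires the variable's mutex internally, so the caller must not hold it.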
tensorflow::Status EnsureSparseVariableAccess(
    TF_OpKernelContext* ctx, bool variantType,
    void (*copyFunc)(TF_OpKernelContext* ctx, TF_Tensor* source,
                     TF_Tensor* dest),
    tensorflow::Var* var) {
  auto* context = reinterpret_cast<::tensorflow::OpKernelContext*>(ctx);
  if (var->copy_on_read_mode.load()) {
    return ::tensorflow::OkStatus();
  }
  mutex_lock ml(*var->mu());
  // Once copy-on-read mode is True the refcount is guaranteed to be 1. This
  // can also happen if there are no concurrent reads of the variable and
  // copy-on-read mode is false.
  if (var->tensor()->RefCountIsOne()) {
    var->copy_on_read_mode.store(true);
    return ::tensorflow::OkStatus();
  }
  Tensor tmp;
  if (variantType) {
    AllocatorAttributes attr;
    attr.set_on_host(true);
    TF_RETURN_IF_ERROR(context->allocate_temp(
        var->tensor()->dtype(), var->tensor()->shape(), &tmp, attr));

    const auto elements_in = var->tensor()->flat<Variant>();
    auto elements_out = tmp.flat<Variant>();
    for (int64_t i = 0; i < elements_in.size(); ++i) {
      elements_out(i) = elements_in(i);
    }
  } else {
    AllocatorAttributes attr;
    attr.set_gpu_compatible(true);
    attr.set_nic_compatible(true);
    TF_RETURN_IF_ERROR(context->allocate_temp(
        var->tensor()->dtype(), var->tensor()->shape(), &tmp, attr));
    tensorflow::Status s;
    TF_Tensor* tf_tmp = TF_TensorFromTensor(tmp, &s);
    TF_Tensor* tf_tensor = TF_TensorFromTensor(*var->tensor(), &s);
    copyFunc(ctx, tf_tensor, tf_tmp);
  }
  *var->tensor() = tmp;
  var->copy_on_read_mode.store(true);
  return ::tensorflow::OkStatus();
}

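// Makes `tensor` safe to update in place: if the variable is in copy-on-read
// mode or its buffer is shared with concurrent reads (refcount > 1), the
// buffer is copied first (element-wise for variants, via `copyFunc` otherwise)
// and `tensor` is repointed at the copy.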
tensorflow::Status PrepareToUpdateVariable(
    TF_OpKernelContext* ctx, tensorflow::Tensor* tensor, bool copy_on_read_mode,
    bool variantType,
    void (*copyFunc)(TF_OpKernelContext* ctx, TF_Tensor* source,
                     TF_Tensor* dest)) {
  auto* context = reinterpret_cast<::tensorflow::OpKernelContext*>(ctx);
  if (copy_on_read_mode || !tensor->RefCountIsOne()) {
    // Tensor's buffer is in use by some read, so we need to copy before
    // updating.
    Tensor tmp;
    if (variantType) {
      AllocatorAttributes attr;
      attr.set_on_host(true);
      TF_RETURN_IF_ERROR(
          context->allocate_temp(tensor->dtype(), tensor->shape(), &tmp, attr));

      const auto elements_in = tensor->flat<Variant>();
      auto elements_out = tmp.flat<Variant>();
      for (int64_t i = 0; i < elements_in.size(); ++i) {
        elements_out(i) = elements_in(i);
      }
    } else {
      AllocatorAttributes attr;
      attr.set_gpu_compatible(true);
      attr.set_nic_compatible(true);
      TF_RETURN_IF_ERROR(
          context->allocate_temp(tensor->dtype(), tensor->shape(), &tmp, attr));
      tensorflow::Status s;
      TF_Tensor* tf_tmp = TF_TensorFromTensor(tmp, &s);
      TF_Tensor* tf_tensor = TF_TensorFromTensor(*tensor, &s);
      copyFunc(ctx, tf_tensor, tf_tmp);
    }
    *tensor = tmp;
  }
  return ::tensorflow::OkStatus();
}

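// Returns the mutex guarding the variable behind input `input`. For
// DT_RESOURCE inputs the Var is looked up and returned through
// `maybe_resource` with an extra reference that the caller must Unref; for
// ref inputs the context's input ref mutex is returned. Returns nullptr if
// the resource lookup fails.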
tensorflow::mutex* GetTrainingVariableMutex(
    TF_OpKernelContext* ctx, int32_t input, bool sparse,
    void (*copyFunc)(TF_OpKernelContext* ctx, TF_Tensor* source,
                     TF_Tensor* dest),
    tensorflow::Var** maybe_resource) {
  auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelContext*>(ctx);
  *maybe_resource = nullptr;
  if (cc_ctx->input_dtype(input) == tensorflow::DT_RESOURCE) {
    if (LookupResource(cc_ctx, HandleFromInput(cc_ctx, input), maybe_resource)
            .ok()) {
      if (sparse) {
        TF_CHECK_OK(
            EnsureSparseVariableAccess(ctx, false, copyFunc, *maybe_resource));
      }
      return (*maybe_resource)->mu();
    } else {
      cc_ctx->CtxFailureWithWarning(
          tensorflow::errors::Internal("Invalid variable reference."));
      return nullptr;
    }
  }
  return cc_ctx->input_ref_mutex(input);
}

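// Implements Assign for a resource variable: looks up (or creates) the Var
// referenced by `input_index` and assigns the tensor at `value_index` to it,
// optionally validating that the shapes match. If the variable is in
// copy-on-read mode, the value is copied into a fresh buffer with `copyFunc`
// instead of being aliased.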
void TF_AssignVariable(TF_OpKernelContext* ctx, int input_index,
                       int value_index, bool validate_shape,
                       void (*copyFunc)(TF_OpKernelContext* ctx,
                                        TF_Tensor* source, TF_Tensor* dest),
                       TF_Status* status) {
  auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelContext*>(ctx);
  tensorflow::core::RefCountPtr<tensorflow::Var> variable;
  const tensorflow::Tensor& value = cc_ctx->input(value_index);
  OP_REQUIRES_OK(cc_ctx, tensorflow::LookupOrCreateResource<tensorflow::Var>(
                             cc_ctx, HandleFromInput(cc_ctx, input_index),
                             &variable, [&value](tensorflow::Var** ptr) {
                               *ptr = new tensorflow::Var(value.dtype());
                               *(*ptr)->tensor() = value;
                               (*ptr)->is_initialized = true;
                               return ::tensorflow::OkStatus();
                             }));
  tensorflow::mutex_lock ml(*variable->mu());

  if (validate_shape) {
    OP_REQUIRES(cc_ctx,
                (!variable->is_initialized ||
                 variable->tensor()->shape().IsSameSize(value.shape())),
                InvalidArgument(
                    "Trying to assign to variable with tensor with wrong shape."
                    " Expected ",
                    variable->tensor()->shape().DebugString(), " got ",
                    value.shape().DebugString()));
  }

  if (variable->copy_on_read_mode.load()) {
    tensorflow::Tensor tmp;
    tensorflow::AllocatorAttributes attr;
    attr.set_gpu_compatible(true);
    attr.set_nic_compatible(true);
    OP_REQUIRES_OK(cc_ctx, cc_ctx->allocate_temp(value.dtype(), value.shape(),
                                                 &tmp, attr));
    tensorflow::Status s;
    TF_Tensor* tf_tmp = TF_TensorFromTensor(tmp, &s);
    TF_Tensor* tf_value = TF_TensorFromTensor(value, &s);
    copyFunc(ctx, tf_value, tf_tmp);
    *variable->tensor() = tmp;
  } else {
    *variable->tensor() = value;
  }
  variable->is_initialized = true;
  TF_SetStatus(status, TF_OK, "");
}

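// Implements Assign for legacy ref variables by delegating to
// tensorflow::AssignRefVariable, wrapping `copyFunc` so the actual data copy
// is performed through C API tensor handles.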
void TF_AssignRefVariable(TF_OpKernelContext* ctx, int input_ref_index,
                          int output_ref_index, int value_index,
                          bool use_locking, bool validate_shape,
                          void (*copyFunc)(TF_OpKernelContext* ctx,
                                           TF_Tensor* source, TF_Tensor* dest),
                          TF_Status* status) {
  auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelContext*>(ctx);

  auto copy = [copyFunc, ctx](::tensorflow::OpKernelContext* cc_ctx,
                              ::tensorflow::Tensor* lhs,
                              const ::tensorflow::Tensor& rhs) {
    ::tensorflow::Status s;
    TF_Tensor* tf_lhs = TF_TensorFromTensor(*lhs, &s);
    OP_REQUIRES_OK(cc_ctx, s);

    TF_Tensor* tf_rhs = TF_TensorFromTensor(rhs, &s);

    if (!s.ok()) {
      TF_DeleteTensor(tf_lhs);
      OP_REQUIRES_OK(cc_ctx, s);
    }

    copyFunc(ctx, tf_rhs, tf_lhs);
  };

  ::tensorflow::AssignRefVariable(cc_ctx, input_ref_index, output_ref_index,
                                  value_index, use_locking, validate_shape,
                                  false, copy);
  TF_SetStatus(status, TF_OK, "");
}

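// Applies an in-place update (`updateFunc`, parameterised by `Op`) to the
// resource variable at `input_index` using the tensor at `value_index`.
// Shapes must match exactly; the variable's tensor is copied first when needed
// so the update never clobbers a buffer shared with readers.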
void TF_AssignUpdateVariable(TF_OpKernelContext* ctx, int input_index,
                             int value_index, int Op, int isVariantType,
                             void (*copyFunc)(TF_OpKernelContext* ctx,
                                              TF_Tensor* source,
                                              TF_Tensor* dest),
                             void (*updateFunc)(TF_OpKernelContext* ctx,
                                                TF_Tensor* tensor,
                                                TF_Tensor* value, int Op),
                             TF_Status* tf_status) {
  auto* context = reinterpret_cast<::tensorflow::OpKernelContext*>(ctx);
  tensorflow::core::RefCountPtr<Var> variable;
  Status status =
      LookupResource(context, HandleFromInput(context, input_index), &variable);
  if (!status.ok()) {
    printf("Failed with error: %s\n", status.error_message().c_str());
    abort();
  }
  const Tensor& value = context->input(value_index);
  mutex_lock ml(*variable->mu());
  Tensor* var_tensor = variable->tensor();
  OP_REQUIRES(
      context, var_tensor->shape().IsSameSize(value.shape()),
      InvalidArgument("Cannot update variable with shape ",
                      var_tensor->shape().DebugString(),
                      " using a Tensor with shape ",
                      value.shape().DebugString(), ", shapes must be equal."));
  OP_REQUIRES_OK(context,
                 PrepareToUpdateVariable(ctx, var_tensor,
                                         variable->copy_on_read_mode.load(),
                                         isVariantType, copyFunc));
  tensorflow::Status s;
  TF_Tensor* tf_var_tensor = TF_TensorFromTensor(*var_tensor, &s);
  TF_Tensor* tf_value = TF_TensorFromTensor(value, &s);
  updateFunc(ctx, tf_var_tensor, tf_value, Op);
  TF_SetStatus(tf_status, TF_OK, "");
}

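// Acquires the mutexes of all variable inputs listed in `inputs`, deduplicated
// and sorted by mutex address so that kernels locking the same set of
// variables cannot deadlock one another. With do_lock=false, locks are still
// taken (as shared locks) when at least one input is a DT_RESOURCE. The
// acquired locks and variable references are returned through `lockHolder`
// and must be released with TF_ReleaseVariableInputLockHolder.
//
// Illustrative usage from a plugin kernel's Compute function (a sketch only;
// `MyCopyFunc` and the input indices are hypothetical):
//
//   int indices[] = {0};
//   TF_VariableInputLockHolder* holder = nullptr;
//   TF_MaybeLockVariableInputMutexesInOrder(ctx, /*do_lock=*/true,
//                                           /*sparse=*/false, indices,
//                                           /*len=*/1, MyCopyFunc, &holder,
//                                           status);
//   // ... read or update the variable tensors ...
//   TF_ReleaseVariableInputLockHolder(holder);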
void TF_MaybeLockVariableInputMutexesInOrder(
    TF_OpKernelContext* ctx, bool do_lock, bool sparse, const int* const inputs,
    size_t len,
    void (*copyFunc)(TF_OpKernelContext* ctx, TF_Tensor* source,
                     TF_Tensor* dest),
    TF_VariableInputLockHolder** lockHolder, TF_Status* status) {
  auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelContext*>(ctx);
  bool any_resource = false;
  std::vector<int> input_ids(inputs, inputs + len);
  for (auto i : input_ids) {
    if (cc_ctx->input_dtype(i) == tensorflow::DT_RESOURCE) {
      any_resource = true;
      break;
    }
  }
  if (!do_lock && !any_resource) {
    *lockHolder = new TF_VariableInputLockHolder({}, {}, {});
    TF_SetStatus(status, TF_OK, "");
    return;
  }
  std::vector<tensorflow::Var*> vars;
  std::vector<tensorflow::mutex*> mutexes;
  std::vector<int32_t> acquire_order;
  for (auto input : input_ids) {
    tensorflow::Var* var;
    tensorflow::mutex* mutex =
        GetTrainingVariableMutex(ctx, input, sparse, copyFunc, &var);
    if (var) vars.push_back(var);
    // Only lock each mutex once if duplicates exist (n^2 but n is 2 or 3).
    if (std::find(mutexes.begin(), mutexes.end(), mutex) == mutexes.end()) {
      acquire_order.push_back(mutexes.size());
      mutexes.push_back(mutex);
    }
  }
  std::sort(acquire_order.begin(), acquire_order.end(),
            [&mutexes](int a, int b) { return mutexes[a] < mutexes[b]; });

  auto locks = absl::make_unique<std::vector<tensorflow::mutex_lock>>();
  auto shared_locks =
      absl::make_unique<std::vector<tensorflow::tf_shared_lock>>();
  locks->reserve(acquire_order.size());

  for (auto input : acquire_order) {
    tensorflow::Var* var;
    tensorflow::mutex* mu =
        GetTrainingVariableMutex(ctx, input, sparse, copyFunc, &var);
    tensorflow::core::ScopedUnref scoped_unref(var);
    if (mu != nullptr) {
      if (do_lock) {
        locks->emplace_back(*mu);
      } else {
        shared_locks->emplace_back(*mu);
      }
    }
  }
  *lockHolder = new TF_VariableInputLockHolder(
      std::move(vars), std::move(locks), std::move(shared_locks));
  TF_SetStatus(status, TF_OK, "");
}

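// Returns (through `out`) a TF_Tensor wrapping the tensor behind input
// `input`. For DT_RESOURCE inputs the Var is looked up and, depending on
// `sparse`, either switched to copy-on-read mode or prepared for a dense
// update before its tensor is exposed; for ref inputs the mutable input
// tensor is returned directly.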
void TF_GetInputTensorFromVariable(TF_OpKernelContext* ctx, int input,
                                   bool lock_held, bool isVariantType,
                                   bool sparse,
                                   void (*copyFunc)(TF_OpKernelContext* ctx,
                                                    TF_Tensor* source,
                                                    TF_Tensor* dest),
                                   TF_Tensor** out, TF_Status* status) {
  auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelContext*>(ctx);
  tensorflow::Status s;
  if (cc_ctx->input_dtype(input) == tensorflow::DT_RESOURCE) {
    tensorflow::core::RefCountPtr<tensorflow::Var> var;
    OP_REQUIRES_OK(
        cc_ctx, LookupResource(cc_ctx, HandleFromInput(cc_ctx, input), &var));
    if (sparse) {
      OP_REQUIRES_OK(cc_ctx, EnsureSparseVariableAccess(ctx, isVariantType,
                                                        copyFunc, var.get()));
      *out = ::tensorflow::TF_TensorFromTensor(*var->tensor(), &s);
      ::tensorflow::Set_TF_Status_from_Status(status, s);
      return;
    }
    OP_REQUIRES_OK(cc_ctx, PrepareToUpdateVariable(
                               ctx, var->tensor(),
                               var->copy_on_read_mode.load(), false, copyFunc));
    *out = ::tensorflow::TF_TensorFromTensor(*var->tensor(), &s);
    ::tensorflow::Set_TF_Status_from_Status(status, s);
    return;
  }
  *out = ::tensorflow::TF_TensorFromTensor(
      cc_ctx->mutable_input(input, lock_held), &s);
  ::tensorflow::Set_TF_Status_from_Status(status, s);
}

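// Forwards a ref input to a ref output, unless the input is actually a
// DT_RESOURCE handle, in which case this is a no-op.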
void TF_OpKernelContext_ForwardRefInputToRefOutput(TF_OpKernelContext* ctx,
                                                   int32_t input_index,
                                                   int32_t output_index) {
  auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelContext*>(ctx);
  if (cc_ctx->input_dtype(input_index) != tensorflow::DT_RESOURCE) {
    cc_ctx->forward_ref_input_to_ref_output(input_index, output_index);
  }
}

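// Releases the locks acquired by TF_MaybeLockVariableInputMutexesInOrder and
// drops the references held on the corresponding variables.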
void TF_ReleaseVariableInputLockHolder(TF_VariableInputLockHolder* lockHolder) {
  if (lockHolder != nullptr) {
    lockHolder->locks.reset();
    for (tensorflow::Var* var : lockHolder->vars) {
      var->Unref();
    }
    delete lockHolder;
  }
}

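// Looks up an op input by name and returns it as a newly created TF_Tensor.
// On failure `status` is set and `*tensor` is left untouched.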
void TF_GetInputByName(TF_OpKernelContext* ctx, const char* inputName,
                       TF_Tensor** tensor, TF_Status* status) {
  auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelContext*>(ctx);
  const ::tensorflow::Tensor* cc_tensor = nullptr;
  tensorflow::Status s = cc_ctx->input(inputName, &cc_tensor);

  if (!s.ok()) {
    ::tensorflow::Set_TF_Status_from_Status(status, s);
    return;
  }
  TF_Tensor* result =
      ::tensorflow::TF_TensorFromTensor(*cc_tensor, &status->status);
  if (TF_GetCode(status) == TF_OK) {
    *tensor = result;
  }
}

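// Reads a TensorShape attribute during kernel construction and copies its
// dimension sizes into `dims`. Fails with InvalidArgument if the attribute's
// rank does not equal `num_dims`.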
void TF_OpKernelConstruction_GetAttrTensorShape(TF_OpKernelConstruction* ctx,
                                                const char* attr_name,
                                                int64_t* dims, size_t num_dims,
                                                TF_Status* status) {
  ::tensorflow::TensorShape shape;
  auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelConstruction*>(ctx);
  ::tensorflow::Status s = cc_ctx->GetAttr(attr_name, &shape);
  ::tensorflow::Set_TF_Status_from_Status(status, s);
  size_t rank = static_cast<size_t>(shape.dims());

  if (!status->status.ok()) return;

  if (num_dims != rank) {
    status->status = InvalidArgument("Expected rank is ", num_dims,
                                     " but actual rank is ", rank);
    return;
  }

  for (int i = 0; i < rank; ++i) {
    dims[i] = static_cast<int64_t>(shape.dim_size(i));
  }
}

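// Returns true if input `i` is a legacy ref-typed input; sets TF_OUT_OF_RANGE
// if the index is outside [0, num_inputs).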
bool TF_IsRefInput(TF_OpKernelContext* ctx, int i, TF_Status* status) {
  auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelContext*>(ctx);
  if (i < 0 || i >= cc_ctx->num_inputs()) {
    TF_SetStatus(status, TF_OUT_OF_RANGE, "input index out of range");
    return false;
  }
  TF_SetStatus(status, TF_OK, "");
  return cc_ctx->input_is_ref(i);
}

#ifndef IS_MOBILE_PLATFORM
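// Returns an Internal error if `variant` does not actually hold a value of
// type T (i.e. the stored object cannot be accessed as T).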
template <typename T>
static Status ValidateVariantType(const Variant& variant) {
  if (variant.get<T>() == nullptr) {
    const std::string type_index_name =
        ::tensorflow::port::MaybeAbiDemangle(variant.TypeId().name());

    return ::tensorflow::errors::Internal(
        "VariantBinaryOpFn: Could not access object 'a', type_index: ",
        type_index_name);
  }

  return ::tensorflow::OkStatus();
}

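// Implements AddN for DT_VARIANT inputs (currently TensorList and the tf.data
// OptionalVariant). The element-wise additions are delegated back to the
// caller through `binary_add_func`, which receives C-API tensors for a, b and
// a pre-allocated output.
//
// A minimal sketch of a caller-side callback (names are hypothetical):
//
//   void MyBinaryAdd(TF_OpKernelContext* ctx, TF_Tensor* a, TF_Tensor* b,
//                    TF_Tensor* out) {
//     // launch the device kernel that writes a + b into out
//   }
//   ...
//   TF_AddNVariant(ctx, MyBinaryAdd, status);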
void TF_AddNVariant(TF_OpKernelContext* ctx,
                    void (*binary_add_func)(TF_OpKernelContext* ctx,
                                            TF_Tensor* a, TF_Tensor* b,
                                            TF_Tensor* out),
                    TF_Status* status) {
  auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelContext*>(ctx);

  auto cc_binary_add_func = [binary_add_func](
                                ::tensorflow::OpKernelContext* cc_ctx,
                                const Tensor& cc_a, const Tensor& cc_b,
                                Tensor* cc_out) {
    if (cc_a.dtype() == ::tensorflow::DT_INVALID) {
      *cc_out = cc_b;
      return ::tensorflow::OkStatus();
    }
    if (cc_b.dtype() == ::tensorflow::DT_INVALID) {
      *cc_out = cc_a;
      return ::tensorflow::OkStatus();
    }

    Status status;
    TF_Tensor* a = TF_TensorFromTensor(cc_a, &status);
    TF_RETURN_IF_ERROR(status);

    TF_Tensor* b = TF_TensorFromTensor(cc_b, &status);
    if (!status.ok()) {
      TF_DeleteTensor(a);
      return status;
    }

    ::tensorflow::AllocatorAttributes attr;
    if (cc_a.dtype() == ::tensorflow::DT_VARIANT) {
      attr.set_on_host(true);
    }

    status = cc_ctx->allocate_temp(cc_a.dtype(), cc_a.shape(), cc_out, attr);
    if (!status.ok()) {
      TF_DeleteTensor(a);
      TF_DeleteTensor(b);
      return status;
    }

    TF_Tensor* out = TF_TensorFromTensor(*cc_out, &status);
    if (!status.ok()) {
      TF_DeleteTensor(a);
      TF_DeleteTensor(b);
      return status;
    }

    auto* ctx = reinterpret_cast<TF_OpKernelContext*>(cc_ctx);
    binary_add_func(ctx, a, b, out);
    return cc_ctx->status();
  };

  auto binary_add_variant = [cc_binary_add_func](
                                ::tensorflow::OpKernelContext* cc_ctx,
                                const Variant& a, const Variant& b,
                                Variant* out) {
    if (out == nullptr) {
      return ::tensorflow::errors::Internal(
          "The output variant hasn't been initialized");
    }

    if (a.TypeId() != b.TypeId()) {
      return ::tensorflow::errors::Internal(
          "BinaryOpVariants: Variants a and b have different "
          "type ids.  Type names: '",
          a.TypeName(), "' vs. '", b.TypeName(), "'");
    }

    if (a.TypeId() == tensorflow::TypeIndex::Make<::tensorflow::TensorList>()) {
      TF_RETURN_IF_ERROR(ValidateVariantType<::tensorflow::TensorList>(a));
      *out = ::tensorflow::TensorList();

      return ::tensorflow::TensorListBinaryAdd(
          cc_ctx, *a.get<::tensorflow::TensorList>(),
          *b.get<::tensorflow::TensorList>(),
          out->get<::tensorflow::TensorList>(), cc_binary_add_func);
    } else if (a.TypeId() == tensorflow::TypeIndex::Make<
                                 ::tensorflow::data::OptionalVariant>()) {
      TF_RETURN_IF_ERROR(
          ValidateVariantType<::tensorflow::data::OptionalVariant>(a));
      *out = ::tensorflow::data::OptionalVariant();

      return ::tensorflow::data::OptionalBinaryAdd(
          cc_ctx, *a.get<::tensorflow::data::OptionalVariant>(),
          *b.get<::tensorflow::data::OptionalVariant>(),
          out->get<::tensorflow::data::OptionalVariant>(), cc_binary_add_func);
    }

    const std::string type_index_name =
        ::tensorflow::port::MaybeAbiDemangle(a.TypeId().name());

    return ::tensorflow::errors::Internal(
        "No unary variant binary_op function found for op ADD Variant "
        "type_name: ",
        type_index_name, " for device type: ", cc_ctx->device()->name());
  };
  ::tensorflow::AddNVariant(cc_ctx, binary_add_variant);
  ::tensorflow::Set_TF_Status_from_Status(status, cc_ctx->status());
}

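// Produces a zeros-like value for a Variant (TensorList or OptionalVariant),
// delegating the per-tensor zero fill to `zeros_like_func` and recursing when
// a contained tensor is itself a DT_VARIANT.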
static Status ZerosLikeVariant(::tensorflow::OpKernelContext* cc_ctx,
                               const Variant& input, Variant* out,
                               void (*zeros_like_func)(TF_OpKernelContext* ctx,
                                                       TF_Tensor* input,
                                                       TF_Tensor* out)) {
  auto cc_zeros_like_func = [zeros_like_func](
                                ::tensorflow::OpKernelContext* cc_ctx,
                                const Tensor& cc_input, Tensor* cc_out) {
    AllocatorAttributes attr;
    if (cc_input.dtype() == ::tensorflow::DT_VARIANT) {
      attr.set_on_host(true);
    }
    TF_RETURN_IF_ERROR(cc_ctx->allocate_temp(cc_input.dtype(), cc_input.shape(),
                                             cc_out, attr));

    switch (cc_input.dtype()) {
      case ::tensorflow::DT_INVALID: {
        *cc_out = Tensor(::tensorflow::DT_INVALID);
        break;
      }
      case ::tensorflow::DT_VARIANT: {
        // If the wrapped tensor is also a variant, recursively call
        // ZerosLikeVariant to unwrap it the same way
        Variant* out_variant = cc_out->scalar<Variant>().data();
        TF_RETURN_IF_ERROR(ZerosLikeVariant(cc_ctx,
                                            cc_input.scalar<Variant>()(),
                                            out_variant, zeros_like_func));
        break;
      }
      default: {
        Status status;
        TF_Tensor* input = TF_TensorFromTensor(cc_input, &status);
        TF_RETURN_IF_ERROR(status);

        TF_Tensor* out = TF_TensorFromTensor(*cc_out, &status);
        if (!status.ok()) {
          TF_DeleteTensor(input);
          return status;
        }

        auto* ctx = reinterpret_cast<TF_OpKernelContext*>(cc_ctx);
        zeros_like_func(ctx, input, out);
      }
    }
    return cc_ctx->status();
  };

  if (out == nullptr) {
    return ::tensorflow::errors::Internal(
        "The output variant hasn't been initialized");
  }

  if (input.TypeId() ==
      tensorflow::TypeIndex::Make<::tensorflow::TensorList>()) {
    TF_RETURN_IF_ERROR(ValidateVariantType<::tensorflow::TensorList>(input));
    *out = ::tensorflow::TensorList();

    return ::tensorflow::TensorListZerosLike(
        cc_ctx, *input.get<::tensorflow::TensorList>(),
        out->get<::tensorflow::TensorList>(), cc_zeros_like_func);
  } else if (input.TypeId() == tensorflow::TypeIndex::Make<
                                   ::tensorflow::data::OptionalVariant>()) {
    TF_RETURN_IF_ERROR(
        ValidateVariantType<::tensorflow::data::OptionalVariant>(input));
    *out = ::tensorflow::data::OptionalVariant();

    return ::tensorflow::data::OptionalZerosLike(
        cc_ctx, *input.get<::tensorflow::data::OptionalVariant>(),
        out->get<::tensorflow::data::OptionalVariant>(), cc_zeros_like_func);
  }

  const std::string type_index_name =
      ::tensorflow::port::MaybeAbiDemangle(input.TypeId().name());

  return ::tensorflow::errors::Internal(
      "No unary variant unary_op function found for op ZEROS_LIKE Variant "
      "type_name: ",
      type_index_name, " for device type: ", cc_ctx->device()->name());
}

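// Implements ZerosLike for a scalar DT_VARIANT input: the variant is unwrapped
// with ZerosLikeVariant above and the result is written to output 0 as a
// host-allocated scalar variant tensor.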
void TF_ZerosLikeVariant(TF_OpKernelContext* ctx,
                         void (*zeros_like_func)(TF_OpKernelContext* ctx,
                                                 TF_Tensor* input,
                                                 TF_Tensor* out),
                         TF_Status* status) {
  auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelContext*>(ctx);

  const Tensor& input = cc_ctx->input(0);
  OP_REQUIRES(cc_ctx, input.dims() == 0,
              InvalidArgument(
                  "ZerosLike non-scalar Tensor with dtype=DT_VARIANT is not "
                  "supported."));
  const Variant& v = input.scalar<Variant>()();
  // DT_VARIANT tensors must be allocated on CPU since they wrap C++
  // objects which can not be efficiently represented in GPU memory.
  int numa_node = cc_ctx->device()->NumaNode();
  Tensor out(::tensorflow::cpu_allocator(numa_node), ::tensorflow::DT_VARIANT,
             ::tensorflow::TensorShape({}));
  Variant* out_v = &(out.scalar<Variant>()());
  Status cc_status = ZerosLikeVariant(cc_ctx, v, out_v, zeros_like_func);
  ::tensorflow::Set_TF_Status_from_Status(status, cc_status);
  OP_REQUIRES_OK(cc_ctx, cc_status);
  cc_ctx->set_output(0, out);
}
#endif  // IS_MOBILE_PLATFORM