/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/framework/types.h"
#define EIGEN_USE_THREADS

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/framework/device_base.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/platform/casts.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/profiler/lib/traceme.h"

namespace tensorflow {
typedef Eigen::GpuDevice GPUDevice;
typedef Eigen::ThreadPoolDevice CPUDevice;
typedef FunctionLibraryRuntime::Handle FHandle;
typedef std::vector<Tensor> TensorVec;

namespace {

// Helper to instantiate function "func" in the library "lib".
Status Instantiate(FunctionLibraryRuntime* lib, const NameAttrList& func,
                   FunctionLibraryRuntime::Handle* handle) {
  return lib->Instantiate(func.name(), AttrSlice(&func.attr()), handle);
}

Status Instantiate(OpKernelContext* ctx, const NameAttrList& func,
                   FunctionLibraryRuntime::Handle* handle) {
  FunctionLibraryRuntime::InstantiateOptions opts;
  opts.executor_type = ctx->executor_type();
  return ctx->function_library()->Instantiate(
      func.name(), AttrSlice(&func.attr()), opts, handle);
}

54 // If "t" is a scalar of a supported type, returns t != 0 in "*v".
ToBool(gtl::ArraySlice<Tensor> t,bool * v)55 Status ToBool(gtl::ArraySlice<Tensor> t, bool* v) {
56   if (t.size() != 1) {
57     return errors::InvalidArgument(
58         "Expected a single scalar which can be converted to a boolean, got ",
59         t.size(), " tensors.");
60   }
61   if (TensorShapeUtils::IsScalar(t[0].shape())) {
62     switch (t[0].dtype()) {
63 #define CASE(T)                   \
64   case DataTypeToEnum<T>::value:  \
65     *v = t[0].scalar<T>()() != 0; \
66     break;
67 
68       CASE(float);
69       CASE(double);
70       CASE(int32);
71       CASE(uint8);
72       CASE(int16);
73       CASE(int8);
74       CASE(int64_t);
75 #undef CASE
76       case DT_BOOL:
77         *v = t[0].scalar<bool>()();
78         break;
79       case DT_STRING:
80         *v = !t[0].scalar<tstring>()().empty();
81         break;
82       default:
83         return errors::InvalidArgument(DataTypeString(t[0].dtype()),
84                                        " cannot be converted to a boolean");
85     }
86   } else {
87     *v = t[0].NumElements() > 0;
88   }
89   return OkStatus();
90 }
91 
// Sets "rets" to be the output of "ctx". Validates rets' types based
// on "kernel".
Status SetOutputs(const OpKernel* kernel, OpKernelContext* ctx,
                  gtl::ArraySlice<Tensor> rets) {
  if (rets.size() != ctx->num_outputs()) {
    return errors::Internal("Expect to produce ", ctx->num_outputs(),
                            " tensors, but only get ", rets.size());
  }
  for (int i = 0; i < rets.size(); ++i) {
    if (rets[i].dtype() != kernel->output_type(i)) {
      return errors::Internal("Expect ", i, "-th output is of type ",
                              DataTypeString(kernel->output_type(i)),
                              " but get ", DataTypeString(rets[i].dtype()));
    }
    ctx->set_output(i, rets[i]);
  }
  return OkStatus();
}

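// Propagates the calling kernel's execution context (rendezvous, cancellation
// manager, collective executor, runner, step container, and optionally the
// stats collector) into the options used for the nested function call.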
void SetRunOptions(OpKernelContext* ctx, FunctionLibraryRuntime::Options* opts,
                   bool always_collect_stats) {
  opts->rendezvous = ctx->rendezvous();
  opts->cancellation_manager = ctx->cancellation_manager();
  opts->collective_executor = ctx->collective_executor();
  if (always_collect_stats) {
    opts->stats_collector = ctx->stats_collector();
  }
  opts->runner = ctx->runner();
  opts->run_all_kernels_inline = ctx->run_all_kernels_inline();
  opts->step_container = ctx->step_container();
}

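// IfOp asynchronously evaluates either the `then_branch` or the `else_branch`
// function, depending on the boolean value of its first input.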
class IfOp : public AsyncOpKernel {
 public:
  explicit IfOp(OpKernelConstruction* ctx) : AsyncOpKernel(ctx) {
    auto lib = ctx->function_library();
    OP_REQUIRES(ctx, lib != nullptr, errors::Internal("No function library"));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("then_branch", &then_func_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("else_branch", &else_func_));
  }

  ~IfOp() override {}

  void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override {
    FHandle then_handle;
    FHandle else_handle;
    OP_REQUIRES_OK_ASYNC(ctx, GetHandles(ctx, &then_handle, &else_handle),
                         done);
    bool cond;
    OP_REQUIRES_OK(ctx, ToBool({ctx->input(0)}, &cond));
    (new State(this, ctx, cond, then_handle, else_handle, done))->Start();
  }

 private:
  NameAttrList then_func_;
  NameAttrList else_func_;

  mutex mu_;
  std::unordered_map<FunctionLibraryRuntime*, std::pair<FHandle, FHandle>>
      handles_ ABSL_GUARDED_BY(mu_);

  class State {
   public:
    State(IfOp* kernel, OpKernelContext* ctx, bool cond, FHandle then_handle,
          FHandle else_handle, DoneCallback done)
        : kernel_(kernel),
          ctx_(ctx),
          cond_(cond),
          then_handle_(then_handle),
          else_handle_(else_handle),
          done_(std::move(done)),
          lib_(CHECK_NOTNULL(ctx_->function_library())) {
      SetRunOptions(ctx_, &opts_, true /* always_collect_stats */);
      for (int i = 1; i < ctx_->num_inputs(); ++i) {
        args_.push_back(ctx_->input(i));
      }
    }

    ~State() {}

    void Start() {
      FHandle handle = cond_ ? then_handle_ : else_handle_;
      rets_.clear();
      profiler::TraceMe trace_me("IfOp");
      lib_->Run(
          // Evaluate one of the branches.
          opts_, handle, args_, &rets_,
          // Done callback
          [this](Status s) {
            if (s.ok()) {
              s = SetOutputs(kernel_, ctx_, rets_);
            }
            ctx_->SetStatus(s);
            DoneCallback captured_done(std::move(done_));
            delete this;
            captured_done();
          });
    }

   private:
    IfOp* const kernel_;
    OpKernelContext* const ctx_;
    const bool cond_;
    FHandle then_handle_;
    FHandle else_handle_;
    DoneCallback done_;
    FunctionLibraryRuntime* const lib_;
    FunctionLibraryRuntime::Options opts_;
    TensorVec args_;
    TensorVec rets_;
  };

  Status GetHandles(OpKernelContext* ctx, FHandle* then_handle,
                    FHandle* else_handle) {
    // TODO(b/37549631): Because this op has `SetIsStateful()` in its
    // op registration, this kernel may be shared by multiple
    // subgraphs, which have different associated
    // `FunctionLibraryRuntime` objects and hence different `FHandle`
    // namespaces. We currently work around this by caching the map
    // from `FunctionLibraryRuntime*` to `FHandle` pairs for the two
    // functions this op uses.
    auto lib = ctx->function_library();
    if (lib == nullptr) return errors::Internal("No function library");
    *then_handle = kInvalidHandle;
    *else_handle = kInvalidHandle;
    {
      tf_shared_lock l(mu_);
      const auto iter = handles_.find(lib);
      if (TF_PREDICT_TRUE(iter != handles_.end())) {
        *then_handle = iter->second.first;
        *else_handle = iter->second.second;
      }
    }
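    // Cache miss: re-check under an exclusive lock in case another thread has
    // instantiated the branch functions in the meantime, and otherwise
    // instantiate and cache the handles for this FunctionLibraryRuntime.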
    if (TF_PREDICT_FALSE(*then_handle == kInvalidHandle)) {
      mutex_lock l(mu_);
      const auto iter = handles_.find(lib);
      if (TF_PREDICT_TRUE(iter != handles_.end())) {
        *then_handle = iter->second.first;
        *else_handle = iter->second.second;
      } else {
        TF_RETURN_IF_ERROR(Instantiate(ctx, then_func_, then_handle));
        TF_RETURN_IF_ERROR(Instantiate(ctx, else_func_, else_handle));
        handles_[lib] = {*then_handle, *else_handle};
      }
    }
    return OkStatus();
  }
};

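// CaseOp dispatches to one of the N `branches` functions based on the scalar
// `branch_index` input; indices outside [0, N) run the last branch, which
// serves as the default branch.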
class CaseOp : public AsyncOpKernel {
 public:
  explicit CaseOp(OpKernelConstruction* ctx) : AsyncOpKernel(ctx) {
    auto lib = ctx->function_library();
    OP_REQUIRES(ctx, lib != nullptr, errors::Internal("No function library"));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("branches", &branch_funcs_));
  }

  ~CaseOp() override {}

  void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override {
    auto lib = ctx->function_library();
    OP_REQUIRES_ASYNC(ctx, lib != nullptr,
                      errors::Internal("No function library"), done);

    // TODO(b/37549631): Because this op has `SetIsStateful()` in its op
    // registration, this kernel may be shared by multiple subgraphs, which have
    // different associated `FunctionLibraryRuntime` objects and hence different
    // `FHandle` namespaces. So we must call Instantiate() to make sure we get
    // the correct function handles with respect to `lib`. Note the underlying
    // `lib->Instantiate()` caches the created function handles, so calling
    // `Instantiate()` repeatedly on the same `lib` and function is cheap.
    std::vector<FHandle> branch_handles(branch_funcs_.size());
    for (int i = 0; i < branch_funcs_.size(); i++) {
      OP_REQUIRES_OK_ASYNC(
          ctx, Instantiate(lib, branch_funcs_[i], &branch_handles[i]), done);
    }

    const Tensor& branch_index = ctx->input(0);
    OP_REQUIRES_ASYNC(ctx, TensorShapeUtils::IsScalar(branch_index.shape()),
                      errors::InvalidArgument("branch_index must be scalar"),
                      done);
    int32_t branch = branch_index.scalar<int32>()();
    (new State(this, ctx, branch, branch_handles, done))->Start();
  }

 private:
  std::vector<NameAttrList> branch_funcs_;

  class State {
   public:
    State(CaseOp* kernel, OpKernelContext* ctx, int branch,
          std::vector<FHandle> branch_handles, DoneCallback done)
        : kernel_(kernel),
          ctx_(ctx),
          branch_(branch),
          branch_handles_(branch_handles),
          done_(std::move(done)),
          lib_(CHECK_NOTNULL(ctx_->function_library())) {
      SetRunOptions(ctx_, &opts_, true /* always_collect_stats */);
      for (int i = 1; i < ctx_->num_inputs(); ++i) {
        args_.push_back(ctx_->input(i));
      }
    }

    ~State() {}

    void Start() {
      int branch = branch_;
      // The last branch is the default branch.
      if (branch < 0 || branch >= branch_handles_.size()) {
        branch = branch_handles_.size() - 1;
      }
      rets_.clear();
      profiler::TraceMe trace_me("CaseOp");
      lib_->Run(
          // Evaluate one of the branches.
          opts_, branch_handles_[branch], args_, &rets_,
          // Done callback
          [this](Status s) {
            if (s.ok()) {
              s = SetOutputs(kernel_, ctx_, rets_);
            }
            ctx_->SetStatus(s);
            DoneCallback captured_done(std::move(done_));
            delete this;
            captured_done();
          });
    }

   private:
    CaseOp* const kernel_;
    OpKernelContext* const ctx_;
    const int branch_;
    std::vector<FHandle> branch_handles_;
    DoneCallback done_;
    FunctionLibraryRuntime* const lib_;
    FunctionLibraryRuntime::Options opts_;
    TensorVec args_;
    TensorVec rets_;
  };
};

// TODO(drpng): remove this.
REGISTER_KERNEL_BUILDER(Name("_If").Device(DEVICE_CPU), IfOp);
REGISTER_KERNEL_BUILDER(Name("_If").Device(DEVICE_DEFAULT).HostMemory("cond"),
                        IfOp);

REGISTER_KERNEL_BUILDER(Name("If").Device(DEVICE_CPU), IfOp);
REGISTER_KERNEL_BUILDER(Name("If").Device(DEVICE_DEFAULT).HostMemory("cond"),
                        IfOp);

REGISTER_KERNEL_BUILDER(Name("Case").Device(DEVICE_CPU), CaseOp);
REGISTER_KERNEL_BUILDER(
    Name("Case").Device(DEVICE_DEFAULT).HostMemory("branch_index"), CaseOp);
REGISTER_KERNEL_BUILDER(Name("StatelessCase").Device(DEVICE_CPU), CaseOp);
REGISTER_KERNEL_BUILDER(
    Name("StatelessCase").Device(DEVICE_DEFAULT).HostMemory("branch_index"),
    CaseOp);

REGISTER_KERNEL_BUILDER(Name("StatelessIf").Device(DEVICE_CPU), IfOp);
REGISTER_KERNEL_BUILDER(
    Name("StatelessIf").Device(DEVICE_DEFAULT).HostMemory("cond"), IfOp);

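// WhileOp repeatedly evaluates the `cond` function on the current loop
// variables and, while it returns true, feeds them through the `body`
// function. It has both an asynchronous (callback-based) implementation and a
// synchronous one; the latter is used when kernels run inline or when the
// caller donates a thread via Compute().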
class WhileOp : public AsyncOpKernel {
 public:
  explicit WhileOp(OpKernelConstruction* ctx) : AsyncOpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("cond", &cond_func_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("body", &body_func_));
  }

  ~WhileOp() override {}

  void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override {
    if (ctx->run_all_kernels_inline()) {
      // Use the non-callback-based implementation when kernels (and function
      // callbacks) execute inline to avoid stack overflow.
      OP_REQUIRES_OK_ASYNC(ctx, DoComputeSync(ctx), done);
      done();
    } else {
      FHandle cond_handle;
      FHandle body_handle;
      OP_REQUIRES_OK_ASYNC(ctx, GetHandles(ctx, &cond_handle, &body_handle),
                           done);
      (new State(this, ctx, cond_handle, body_handle, done))->Start();
    }
  }

  void Compute(OpKernelContext* ctx) override {
    // Use the non-callback-based implementation when the synchronous Compute()
    // method is invoked, because the caller is explicitly donating a thread.
    Status s = DoComputeSync(ctx);
    // NOTE: Unfortunately, we cannot use OP_REQUIRES_OK here, because this is
    // still an AsyncOpKernel, and there is a run-time check to avoid calling
    // OP_REQUIRES_OK in AsyncOpKernel::ComputeAsync() (which would deadlock in
    // the event of an error).
    if (TF_PREDICT_FALSE(!s.ok())) {
      ctx->SetStatus(s);
    }
  }

 private:
  NameAttrList cond_func_;
  NameAttrList body_func_;

  mutex mu_;
  std::unordered_map<FunctionLibraryRuntime*, std::pair<FHandle, FHandle>>
      handles_ ABSL_GUARDED_BY(mu_);

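  // Converts the tensor produced by the cond function to a boolean. If the
  // result may live in device memory (on an accelerator or pluggable device
  // and not allocated on the host), it is first copied to the host before
  // being converted with ToBool().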
  static Status CondResultToBool(OpKernelContext* ctx,
                                 const FunctionLibraryRuntime::Options& opts,
                                 const Tensor& cond_t, bool* out_result) {
    bool is_pluggable = ctx->op_device_context() &&
                        ctx->op_device_context()->IsPluggableDevice();
    const DeviceBase::AcceleratorDeviceInfo* accelerator_device_info =
        ctx->device()->tensorflow_accelerator_device_info();
    const bool is_hostmem_dtype =
        cond_t.dtype() == DT_INT32 || cond_t.dtype() == DT_INT64;
    if (!is_hostmem_dtype && (is_pluggable || accelerator_device_info) &&
        (opts.rets_alloc_attrs.empty() ||
         !opts.rets_alloc_attrs[0].on_host())) {
      // Copy the ret value to host if it's allocated on device.
      Device* device = down_cast<Device*>(ctx->device());
      DeviceContext* device_ctx = ctx->op_device_context();
      Tensor host_cond_t = Tensor(cond_t.dtype(), cond_t.shape());
      TF_RETURN_IF_ERROR(device_ctx->CopyDeviceTensorToCPUSync(
          &cond_t, /*tensor_name=*/"", device, &host_cond_t));
      return ToBool({host_cond_t}, out_result);
    }
    return ToBool({cond_t}, out_result);
  }

  // The initial loop variable args are the inputs to the kernel.
  //
  // We attempt to forward the input so that it can be consumed inside the
  // body function (and participate in buffer forwarding, etc.).
  static void GetArgsFromContext(OpKernelContext* ctx,
                                 std::vector<Tensor>* out_args,
                                 DataTypeVector* out_var_types) {
    const int num_loop_vars = ctx->num_inputs();
    out_args->reserve(num_loop_vars);
    out_var_types->resize(num_loop_vars);
    for (int i = 0; i < num_loop_vars; ++i) {
      const Tensor& input = ctx->input(i);
      (*out_var_types)[i] = input.dtype();
      std::unique_ptr<Tensor> maybe_forwarded_input = ctx->forward_input(
          i, /* output_index= */ OpKernelContext::Params::kNoReservation,
          input.dtype(), input.shape(), ctx->input_memory_type(i),
          ctx->input_alloc_attr(i));
      if (maybe_forwarded_input) {
        out_args->push_back(std::move(*maybe_forwarded_input));
      } else {
        out_args->push_back(input);
      }
    }
  }

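  // CallFrameInterface implementation passed to the body function. It exposes
  // the current loop variables as arguments (allowing them to be consumed and
  // moved rather than copied) and collects the body's return values, checking
  // their count and dtypes.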
  class BodyFuncCallFrame : public CallFrameInterface {
   public:
    BodyFuncCallFrame(std::vector<Tensor>* args, std::vector<Tensor>* retvals,
                      DataTypeSlice ret_types)
        : args_(args), retvals_(retvals), ret_types_(ret_types) {}

    size_t num_args() const override { return args_->size(); }
    size_t num_retvals() const override { return retvals_->size(); }

    Status GetArg(int index, const Tensor** val) override {
      if (index < args_->size()) {
        *val = &(*args_)[index];
        return OkStatus();
      } else {
        return errors::InvalidArgument("Argument ", index, " is out of range.");
      }
    }

    void ConsumeArg(int index, Tensor* val) override {
      DCHECK_GE(index, 0);
      DCHECK_LT(index, args_->size());
      *val = std::move((*args_)[index]);
    }
    bool CanConsumeArg(int index) const override {
      return index >= 0 && index < args_->size();
    }

    Status SetRetval(int index, const Tensor& val) override {
      if (TF_PREDICT_FALSE(index < 0)) {
        return errors::InvalidArgument(
            "Expected non-negative return value index, but got: ", index, ".");
      } else if (TF_PREDICT_FALSE(index >= retvals_->size())) {
        return errors::InvalidArgument("While loop body returned ", index + 1,
                                       " arguments. Expected: ", num_retvals(),
                                       ".");
      } else if (TF_PREDICT_FALSE(val.dtype() != ret_types_[index])) {
        return errors::InvalidArgument("Expected type ",
                                       DataTypeString(ret_types_[index]),
                                       " for return value ", index, " but got ",
                                       DataTypeString(val.dtype()), ".");
      }
      (*retvals_)[index] = val;
      return OkStatus();
    }

   private:
    std::vector<Tensor>* const args_;     // Not owned.
    std::vector<Tensor>* const retvals_;  // Not owned.
    DataTypeSlice ret_types_;

    TF_DISALLOW_COPY_AND_ASSIGN(BodyFuncCallFrame);
  };

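  // Heap-allocated state for the asynchronous implementation. Ownership passes
  // to the callback chain started by Start(); the object deletes itself in
  // Finish().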
  class State {
   public:
    State(WhileOp* kernel, OpKernelContext* ctx, FHandle cond_handle,
          FHandle body_handle, DoneCallback done)
        : kernel_(kernel),
          ctx_(ctx),
          cond_handle_(cond_handle),
          body_handle_(body_handle),
          done_(std::move(done)),
          lib_(CHECK_NOTNULL(ctx_->function_library())) {
      SetRunOptions(ctx_, &opts_, false /* always_collect_stats */);
      GetArgsFromContext(ctx, &args_, &loop_var_types_);
      body_frame_ =
          std::make_unique<BodyFuncCallFrame>(&args_, &rets_, loop_var_types_);
    }

    ~State() {}

    void Start() { EvalCond(); }

   private:
    WhileOp* const kernel_;
    OpKernelContext* const ctx_;
    const FHandle cond_handle_;
    const FHandle body_handle_;
    const DoneCallback done_;
    FunctionLibraryRuntime* const lib_;
    FunctionLibraryRuntime::Options opts_;
    TensorVec args_;
    TensorVec rets_;
    DataTypeVector loop_var_types_;
    std::unique_ptr<BodyFuncCallFrame> body_frame_;

    void EvalCond() {
      profiler::TraceMe trace_me("WhileOp-EvalCond");
      lib_->Run(
          // Evaluate the condition.
          opts_, cond_handle_, args_, &rets_,
          // Done cb.
          [this](const Status& s) {
            if (!s.ok()) {
              return Finish(s);
            }
            StartBody();
          });
    }

    void StartBody() {
      Status s;
      if (rets_.size() != 1) {
        s = errors::InvalidArgument(
            "Expected a single scalar return value from WhileOp cond, got ",
            rets_.size(), " tensors.");
        return Finish(s);
      }

      if (!s.ok()) {
        return Finish(s);
      }
      bool cond;
      s = CondResultToBool(ctx_, opts_, rets_[0], &cond);
      if (!s.ok()) {
        return Finish(s);
      }

      if (!cond) {
        return Finish(OkStatus());
      }
      rets_.clear();
      rets_.resize(args_.size());
      profiler::TraceMe trace_me("WhileOp-StartBody");
      lib_->Run(
          // Evaluate the body.
          opts_, body_handle_, body_frame_.get(),
          // Done callback
          [this](const Status& s) {
            if (!s.ok()) {
              return Finish(s);
            }
            if (args_.size() != rets_.size()) {
              return Finish(errors::InvalidArgument(
                  "While loop body returned ", rets_.size(),
                  " arguments. Expected: ", args_.size()));
            }
            args_.clear();
            using std::swap;
            swap(args_, rets_);
            EvalCond();
          });
    }

    void Finish(Status s) {
      if (s.ok()) {
        s = SetOutputs(kernel_, ctx_, args_);
      }
      ctx_->SetStatus(s);
      done_();
      delete this;
    }
  };

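  // Synchronous (thread-donating) implementation of the while loop.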
  Status DoComputeSync(OpKernelContext* ctx) {
    FHandle cond_handle;
    FHandle body_handle;
    TF_RETURN_IF_ERROR(GetHandles(ctx, &cond_handle, &body_handle));
    auto lib = ctx->function_library();
    FunctionLibraryRuntime::Options opts;
    SetRunOptions(ctx, &opts, false /* always_collect_stats */);

    // Pre-allocate argument and return value vectors for the cond and body
    // functions.
    std::vector<Tensor> args;
    const int num_loop_vars = ctx->num_inputs();
    DataTypeVector loop_var_types(num_loop_vars);
    GetArgsFromContext(ctx, &args, &loop_var_types);
    std::vector<Tensor> cond_rets;
    cond_rets.reserve(1);
    std::vector<Tensor> body_rets;
    body_rets.reserve(num_loop_vars);

    // Implement the logic of the while loop as a single C++ do-while loop that
    // executes the cond and body functions synchronously.
    do {
      // Evaluate the cond function on the current loop variables.
      {
        profiler::TraceMe trace_me("WhileOp-EvalCond");
        TF_RETURN_IF_ERROR(lib->RunSync(opts, cond_handle, args, &cond_rets));
      }
      if (cond_rets.size() != 1) {
        return errors::InvalidArgument(
            "Expected a single scalar return value from WhileOp cond, got ",
            cond_rets.size(), " tensors.");
      }

      // If the cond function evaluates to false, we are done: output the
      // current loop variables.
      bool cond_result;
      TF_RETURN_IF_ERROR(
          CondResultToBool(ctx, opts, cond_rets[0], &cond_result));
      if (!cond_result) {
        return SetOutputs(this, ctx, args);
      }

      // Evaluate the body function on the current loop variables, to get an
      // updated vector of loop variables.
      {
        profiler::TraceMe trace_me("WhileOp-StartBody");
        body_rets.resize(num_loop_vars);
        BodyFuncCallFrame call_frame(&args, &body_rets, loop_var_types);
        TF_RETURN_IF_ERROR(lib->RunSync(opts, body_handle, &call_frame));
      }
      std::swap(body_rets, args);
      body_rets.clear();
    } while (true);
  }

  Status GetHandles(OpKernelContext* ctx, FHandle* cond_handle,
                    FHandle* body_handle) {
    // TODO(b/37549631): Because this op has `SetIsStateful()` in its
    // op registration, this kernel may be shared by multiple
    // subgraphs, which have different associated
    // `FunctionLibraryRuntime` objects and hence different `FHandle`
    // namespaces. We currently work around this by caching the map
    // from `FunctionLibraryRuntime*` to `FHandle` pairs for the two
    // functions this op uses.
    auto lib = ctx->function_library();
    if (lib == nullptr) return errors::Internal("No function library");
    *cond_handle = kInvalidHandle;
    *body_handle = kInvalidHandle;
    {
      tf_shared_lock l(mu_);
      const auto iter = handles_.find(lib);
      if (TF_PREDICT_TRUE(iter != handles_.end())) {
        *cond_handle = iter->second.first;
        *body_handle = iter->second.second;
      }
    }
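    // Cache miss: re-check under an exclusive lock and instantiate the cond
    // and body functions for this FunctionLibraryRuntime if needed.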
    if (TF_PREDICT_FALSE(*cond_handle == kInvalidHandle)) {
      mutex_lock l(mu_);
      const auto iter = handles_.find(lib);
      if (TF_PREDICT_TRUE(iter != handles_.end())) {
        *cond_handle = iter->second.first;
        *body_handle = iter->second.second;
      } else {
        TF_RETURN_IF_ERROR(Instantiate(ctx, cond_func_, cond_handle));
        TF_RETURN_IF_ERROR(Instantiate(ctx, body_func_, body_handle));
        handles_[lib] = {*cond_handle, *body_handle};
      }
    }
    return OkStatus();
  }
};
// TODO(drpng): remove these.
REGISTER_KERNEL_BUILDER(Name("_While").Device(DEVICE_CPU), WhileOp);
REGISTER_KERNEL_BUILDER(Name("_While").Device(DEVICE_DEFAULT), WhileOp);

REGISTER_KERNEL_BUILDER(Name("While").Device(DEVICE_CPU), WhileOp);
REGISTER_KERNEL_BUILDER(Name("While").Device(DEVICE_DEFAULT), WhileOp);

REGISTER_KERNEL_BUILDER(Name("StatelessWhile").Device(DEVICE_CPU), WhileOp);
REGISTER_KERNEL_BUILDER(Name("StatelessWhile").Device(DEVICE_DEFAULT), WhileOp);

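// ToBoolOp converts its input to a boolean scalar using the same semantics as
// the If/While predicate: a scalar is compared against zero (or emptiness for
// strings), and a non-scalar is true iff it has at least one element.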
class ToBoolOp : public OpKernel {
 public:
  explicit ToBoolOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
  void Compute(OpKernelContext* ctx) override {
    bool b;
    OP_REQUIRES_OK(ctx, ToBool({ctx->input(0)}, &b));
    Tensor* out;
    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &out));
    out->scalar<bool>()() = b;
  }
};

REGISTER_KERNEL_BUILDER(Name("ToBool").Device(DEVICE_CPU), ToBoolOp);

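// Reads input `index` from `ctx` as an int32 scalar into `*value`, returning
// an InvalidArgument error (mentioning `label`) if the input is not a scalar.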
Status GetScalar(OpKernelContext* ctx, int index, int32* value,
                 const char* label) {
  Tensor t = ctx->input(index);
  if (!TensorShapeUtils::IsScalar(t.shape())) {
    return errors::InvalidArgument(label, " must be a scalar, but ",
                                   t.shape().DebugString());
  }
  *value = t.scalar<int32>()();
  return OkStatus();
}

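// ForOp runs the `body` function once per iteration of a counted loop defined
// by the scalar inputs `start`, `limit`, and `delta`, threading the remaining
// inputs through successive body calls as loop-carried values.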
class ForOp : public AsyncOpKernel {
 public:
  explicit ForOp(OpKernelConstruction* ctx) : AsyncOpKernel(ctx) {
    auto lib = ctx->function_library();
    OP_REQUIRES(ctx, lib != nullptr, errors::Internal("No function library"));
    const NameAttrList* func;
    OP_REQUIRES_OK(ctx, ctx->GetAttr("body", &func));
    OP_REQUIRES_OK(ctx, Instantiate(lib, *func, &body_handle_));
  }

  ~ForOp() override {}

  void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override {
    (new State(this, ctx, done))->Start();
  }

 private:
  FHandle body_handle_;

  class State {
   public:
    State(ForOp* kernel, OpKernelContext* ctx, DoneCallback done)
        : kernel_(kernel),
          ctx_(ctx),
          done_(std::move(done)),
          lib_(CHECK_NOTNULL(ctx_->function_library())),
          args_(1 + ctx_->num_inputs() - 3) {
      args_[0] = Tensor(DT_INT32, {});
      iter_ = &args_[0].scalar<int32>()();

      const int32_t num_loop_inputs = ctx_->num_inputs() - 3;
      rets_.reserve(num_loop_inputs);
      for (int i = 0; i < num_loop_inputs; ++i) {
        rets_.push_back(ctx_->input(3 + i));
      }
    }

    ~State() {}

    void Start() {
      Status s = StartLoop();
      if (!s.ok()) Finish(s);
    }

   private:
    ForOp* const kernel_;
    OpKernelContext* const ctx_;
    const DoneCallback done_;
    FunctionLibraryRuntime* const lib_;
    FunctionLibraryRuntime::Options opts_;
    TensorVec args_;
    TensorVec rets_;

    int32* iter_;  // points to args_[0].
    int32 limit_;
    int32 delta_;

    // If an error e is returned, caller must call Finish(e).
    // If OK is returned, the async loop execution has been started.
    Status StartLoop() {
      SetRunOptions(ctx_, &opts_, false /* always_collect_stats */);

      TF_RETURN_IF_ERROR(GetScalar(ctx_, 0, iter_, "start"));
      TF_RETURN_IF_ERROR(GetScalar(ctx_, 1, &limit_, "limit"));
      TF_RETURN_IF_ERROR(GetScalar(ctx_, 2, &delta_, "delta"));

      if ((delta_ > 0 && *iter_ <= limit_) ||
          (delta_ < 0 && *iter_ >= limit_) ||
          (delta_ == 0 && *iter_ == limit_)) {
        RunNext();
        return OkStatus();
      } else {
        return errors::InvalidArgument("Invalid start/limit/delta: ", *iter_,
                                       " ", limit_, " ", delta_);
      }
    }

    void RunNext() {
      bool done_loop;
      if (delta_ > 0) {
        done_loop = *iter_ >= limit_;
      } else {
        done_loop = *iter_ <= limit_;
      }
      if (done_loop) {
        Finish(OkStatus());
        return;
      }

      if (rets_.size() >= args_.size()) {
        Finish(errors::InvalidArgument(
            "For loop body returned ", rets_.size(),
            " arguments. Expected: ", args_.size() - 1));
        return;
      }
      for (int i = 0; i < rets_.size(); ++i) {
        args_[1 + i] = std::move(rets_[i]);
      }
      rets_.clear();
      profiler::TraceMe trace_me("ForOp");
      lib_->Run(opts_, kernel_->body_handle_, args_, &rets_,
                [this](const Status& s) {
                  if (s.ok()) {
                    *iter_ += delta_;
                    RunNext();
                  } else {
                    Finish(s);
                  }
                });
    }

    void Finish(Status s) {
      if (s.ok()) {
        s = SetOutputs(kernel_, ctx_, rets_);
      }
      ctx_->SetStatus(s);
      done_();
      delete this;
    }
  };
};

850 REGISTER_KERNEL_BUILDER(Name("For").Device(DEVICE_CPU), ForOp);
851 REGISTER_KERNEL_BUILDER(Name("For")
852                             .Device(DEVICE_DEFAULT)
853                             .HostMemory("start")
854                             .HostMemory("limit")
855                             .HostMemory("delta"),
856                         ForOp);
857 
858 // FakeParamOp allocates a tensor with a shape conforming to the expected
859 // output. This is necessary if the value will be stored in a while_loop's
860 // TensorList. The output is otherwise not expected to be consumed by anything
861 // else.
862 class FakeParamOp : public OpKernel {
863  public:
FakeParamOp(OpKernelConstruction * context)864   explicit FakeParamOp(OpKernelConstruction* context) : OpKernel(context) {
865     DataType dtype;
866     OP_REQUIRES_OK(context, context->GetAttr("dtype", &dtype));
867 
868     // Set shape to the specified shape, setting unknown dimensions to empty.
869     // If the specified shape is unknown, leave as an empty shape.
870     TensorShape shape;
871     PartialTensorShape partial_shape;
872     OP_REQUIRES_OK(context, context->GetAttr("shape", &partial_shape));
873     if (!partial_shape.unknown_rank()) {
874       for (int64_t d : partial_shape.dim_sizes()) {
875         shape.AddDim(d == -1 ? 0 : d);
876       }
877     }
878 
879     // Create a tensor that we can repeatedly return to save memory.
880     // TODO(b/119612758): add optimization to prevent sending this across
881     // devices on each Compute() call.
882     OP_REQUIRES_OK(context, context->allocate_temp(dtype, shape, &value_));
883   }
884 
Compute(OpKernelContext * context)885   void Compute(OpKernelContext* context) override {
886     context->set_output(0, value_);
887   }
888 
889  private:
890   Tensor value_;
891 };
892 
893 REGISTER_KERNEL_BUILDER(Name("FakeParam").Device(DEVICE_CPU), FakeParamOp);
894 REGISTER_KERNEL_BUILDER(Name("FakeParam").Device(DEVICE_DEFAULT), FakeParamOp);
895 
// DeviceIndexOp returns the index of the current device's type in the
// `device_names` attribute, or the number of names if the device type is not
// listed.
class DeviceIndexOp : public OpKernel {
 public:
  explicit DeviceIndexOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("device_names", &device_names_));
  }

  void Compute(OpKernelContext* ctx) override {
    Tensor* device_name_t;
    OP_REQUIRES_OK(ctx,
                   ctx->allocate_output(0, TensorShape({}), &device_name_t));
    DeviceNameUtils::ParsedName parsed_name;
    int index = device_names_.size();
    if (DeviceNameUtils::ParseFullName(ctx->device()->name(), &parsed_name) &&
        parsed_name.has_type) {
      auto it = absl::c_find(device_names_, parsed_name.type);
      if (it != device_names_.end()) {
        index = it - device_names_.begin();
      }
    }
    device_name_t->scalar<int32>()() = index;
  }

 private:
  std::vector<string> device_names_;
};

REGISTER_KERNEL_BUILDER(Name("DeviceIndex").Device(DEVICE_CPU), DeviceIndexOp);
REGISTER_KERNEL_BUILDER(
    Name("DeviceIndex").Device(DEVICE_DEFAULT).HostMemory("index"),
    DeviceIndexOp);

}  // namespace
}  // namespace tensorflow