xref: /aosp_15_r20/external/tensorflow/tensorflow/core/data/captured_function.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/core/data/captured_function.h"
16 
17 #include <utility>
18 
19 #include "absl/time/clock.h"
20 #include "tensorflow/core/common_runtime/function.h"
21 #include "tensorflow/core/common_runtime/step_stats_collector.h"
22 #include "tensorflow/core/data/dataset_utils.h"
23 #include "tensorflow/core/data/stats_utils.h"
24 #include "tensorflow/core/framework/attr_value.pb.h"
25 #include "tensorflow/core/framework/cancellation.h"
26 #include "tensorflow/core/framework/function.h"
27 #include "tensorflow/core/framework/function_handle_cache.h"
28 #include "tensorflow/core/framework/op_kernel.h"
29 #include "tensorflow/core/framework/stats_aggregator.h"
30 #include "tensorflow/core/lib/core/errors.h"
31 #include "tensorflow/core/lib/gtl/optional.h"
32 #include "tensorflow/core/lib/random/random.h"
33 #include "tensorflow/core/lib/strings/strcat.h"
34 #include "tensorflow/core/platform/errors.h"
35 #include "tensorflow/core/platform/notification.h"
36 #include "tensorflow/core/profiler/lib/traceme.h"
37 
38 #if !defined(IS_MOBILE_PLATFORM)
39 #include "tensorflow/core/grappler/grappler_item.h"
40 #include "tensorflow/core/grappler/optimizers/meta_optimizer.h"
41 #endif  // !IS_MOBILE_PLATFORM
42 
43 namespace tensorflow {
44 namespace data {
45 namespace {
46 
// Experiment name looked up via GetExperiments() in
// CapturedFunction::Instantiate; when present, instantiation enables
// `allow_small_function_optimizations`.
constexpr char kAllowSmallFunctionOptimizations[]{
    "allow_small_function_optimizations"};
49 
50 // Simplistic implementation of the `StepStatsCollectorInterface` that only
51 // cares about collecting the CPU time needed to execute a captured function.
52 class SimpleStepStatsCollector : public StepStatsCollectorInterface {
53  public:
IncrementProcessingTime(int64_t delta)54   void IncrementProcessingTime(int64_t delta) {
55     mutex_lock l(mu_);
56     processing_time_ += delta;
57   }
58 
CreateNodeExecStats(const NodeDef * node)59   NodeExecStatsInterface* CreateNodeExecStats(const NodeDef* node) override {
60     return new SimpleNodeExecStats(this);
61   }
62 
ReportAllocsOnResourceExhausted(const string & err)63   string ReportAllocsOnResourceExhausted(const string& err) override {
64     return "";
65   }
66 
processing_time()67   int64_t processing_time() {
68     tf_shared_lock l(mu_);
69     return processing_time_;
70   }
71 
72  private:
73   class SimpleNodeExecStats : public NodeExecStatsInterface {
74    public:
SimpleNodeExecStats(SimpleStepStatsCollector * step_stats_collector)75     explicit SimpleNodeExecStats(SimpleStepStatsCollector* step_stats_collector)
76         : step_stats_collector_(step_stats_collector) {}
77 
Done(const string & device)78     void Done(const string& device) override {
79       step_stats_collector_->IncrementProcessingTime(end_time_ns_ -
80                                                      start_time_ns_);
81       delete this;
82     }
83 
RecordExecutorStarted()84     void RecordExecutorStarted() override {
85       start_time_ns_ = absl::GetCurrentTimeNanos();
86     }
87 
RecordComputeStarted()88     void RecordComputeStarted() override {}
89 
RecordComputeEnded()90     void RecordComputeEnded() override {}
91 
RecordExecutorEnded()92     void RecordExecutorEnded() override {
93       end_time_ns_ = absl::GetCurrentTimeNanos();
94     }
95 
TrackAllocations() const96     bool TrackAllocations() const override { return false; }
97 
SetMemory(OpKernelContext * ctx)98     void SetMemory(OpKernelContext* ctx) override {}
99 
SetOutput(int slot,const Tensor * tensor)100     void SetOutput(int slot, const Tensor* tensor) override {}
101 
SetScheduled(int64_t nanos)102     void SetScheduled(int64_t nanos) override {}
103 
104    private:
105     int64_t start_time_ns_ = 0;
106     int64_t end_time_ns_ = 0;
107     SimpleStepStatsCollector* step_stats_collector_;  // Not owned.
108   };
109 
110   mutex mu_;
111   int64_t processing_time_ TF_GUARDED_BY(mu_) = 0;
112 };
113 
GetCapturedInput(const CapturedFunction * const func,int index,const Tensor ** out)114 Status GetCapturedInput(const CapturedFunction* const func, int index,
115                         const Tensor** out) {
116   if (TF_PREDICT_FALSE(index >= func->captured_inputs().size())) {
117     return errors::OutOfRange(
118         "Out of range access to captured inputs for function ",
119         func->func().name(), ". Index: ", index,
120         ". Num captured inputs: ", func->captured_inputs().size());
121   }
122   *out = &func->captured_inputs()[index];
123   return OkStatus();
124 }
125 
RunShortCircuit(const ShortCircuitInfo & info,const std::vector<Tensor> & args,const CapturedFunction * const func,std::vector<Tensor> * rets)126 Status RunShortCircuit(const ShortCircuitInfo& info,
127                        const std::vector<Tensor>& args,
128                        const CapturedFunction* const func,
129                        std::vector<Tensor>* rets) {
130   VLOG(3) << "Running function " << func->func().name() << " short circuit";
131   const int num_args = args.size();
132   rets->reserve(info.indices.size());
133   for (size_t i = 0; i < info.indices.size(); ++i) {
134     if (info.indices[i] < num_args) {
135       rets->push_back(args[info.indices[i]]);
136     } else {
137       const Tensor* captured_input;
138       TF_RETURN_IF_ERROR(
139           GetCapturedInput(func, info.indices[i] - num_args, &captured_input));
140       rets->push_back(*captured_input);
141     }
142   }
143   return OkStatus();
144 }
145 
RunShortCircuit(const ShortCircuitInfo & info,std::vector<Tensor> && args,const CapturedFunction * const func,std::vector<Tensor> * rets)146 Status RunShortCircuit(const ShortCircuitInfo& info, std::vector<Tensor>&& args,
147                        const CapturedFunction* const func,
148                        std::vector<Tensor>* rets) {
149   VLOG(3) << "Running function " << func->func().name() << " short circuit";
150   const int num_args = args.size();
151   rets->reserve(info.indices.size());
152   for (size_t i = 0; i < info.indices.size(); ++i) {
153     if (info.indices[i] < num_args) {
154       if (info.can_move[i]) {
155         rets->push_back(std::move(args[info.indices[i]]));
156       } else {
157         rets->push_back(args[info.indices[i]]);
158       }
159     } else {
160       const Tensor* captured_input;
161       TF_RETURN_IF_ERROR(
162           GetCapturedInput(func, info.indices[i] - num_args, &captured_input));
163       rets->push_back(*captured_input);
164     }
165   }
166   return OkStatus();
167 }
168 
CreateShortCircuitInfo(OpKernelConstruction * ctx,const NameAttrList & func,ShortCircuitInfo * info)169 Status CreateShortCircuitInfo(OpKernelConstruction* ctx,
170                               const NameAttrList& func,
171                               ShortCircuitInfo* info) {
172   auto& indices = info->indices;
173 
174   FunctionLibraryRuntime::Handle fn_handle;
175   TF_RETURN_IF_ERROR(ctx->function_library()->Instantiate(
176       func.name(), AttrSlice(&func.attr()), &fn_handle));
177   auto cleanup = gtl::MakeCleanup([ctx, fn_handle]() {
178     Status s = ctx->function_library()->ReleaseHandle(fn_handle);
179     if (!s.ok()) {
180       LOG(WARNING) << "Failed to release handle: " << s.error_message();
181     }
182   });
183 
184   // If the function contains any stateful operations, we conservatively execute
185   // the entire function.
186   if (ctx->function_library()->IsStateful(func.name())) {
187     return OkStatus();
188   }
189 
190   const FunctionBody* fn_body =
191       ctx->function_library()->GetFunctionBody(fn_handle);
192   indices.resize(fn_body->ret_nodes.size());
193 
194   for (size_t i = 0; i < fn_body->ret_nodes.size(); ++i) {
195     Node* ret_node = fn_body->ret_nodes[i];
196     Node* ret_input_node;
197     TF_RETURN_IF_ERROR(ret_node->input_node(0, &ret_input_node));
198 
199     while (ret_input_node->def().op() == "Identity") {
200       TF_RETURN_IF_ERROR(ret_input_node->input_node(0, &ret_input_node));
201     }
202 
203     if (ret_input_node->def().op() == FunctionLibraryDefinition::kArgOp) {
204       TF_RETURN_IF_ERROR(
205           GetNodeAttr(ret_input_node->def(), "index", &(indices[i])));
206     } else {
207       indices.clear();
208       break;
209     }
210   }
211 
212   // Compute the `can_move` vector.
213   if (!indices.empty()) {
214     auto& can_move = info->can_move;
215     std::map<int, int> last_use;
216     for (size_t i = 0; i < indices.size(); ++i) {
217       last_use[indices[i]] = i;
218     }
219     can_move.resize(indices.size());
220     for (int i = 0, end = indices.size(); i < end; ++i) {
221       can_move[i] = last_use[indices[i]] == i;
222     }
223   }
224 
225   return OkStatus();
226 }
227 
CreateFunctionLibraryDefinition(const FunctionLibraryDefinition * lib_def,const string & func_name,std::unique_ptr<FunctionLibraryDefinition> * result)228 Status CreateFunctionLibraryDefinition(
229     const FunctionLibraryDefinition* lib_def, const string& func_name,
230     std::unique_ptr<FunctionLibraryDefinition>* result) {
231   DCHECK(lib_def != nullptr);
232   const FunctionDef* fdef = lib_def->Find(func_name);
233   if (TF_PREDICT_FALSE(fdef == nullptr)) {
234     return errors::FailedPrecondition(strings::StrCat(
235         "Could not find required function definition ", func_name));
236   }
237   *result = std::make_unique<FunctionLibraryDefinition>(
238       lib_def->ReachableDefinitions(*fdef));
239   return (*result)->CopyFunctionDefFrom(func_name, *lib_def);
240 }
241 
LookupFunction(const FunctionLibraryDefinition & lib_def,const string & name,const FunctionDef ** fdef)242 Status LookupFunction(const FunctionLibraryDefinition& lib_def,
243                       const string& name, const FunctionDef** fdef) {
244   *fdef = lib_def.Find(name);
245   if (*fdef == nullptr) {
246     return errors::InvalidArgument(
247         "Failed to find function ", name,
248         " in function library: ", lib_def.ToProto().DebugString());
249   }
250   return OkStatus();
251 }
252 
253 class CallFrameBase : public CallFrameInterface {
254  public:
CallFrameBase(DataTypeSlice ret_types)255   explicit CallFrameBase(DataTypeSlice ret_types)
256       : ret_types_(ret_types), retvals_(ret_types.size()) {}
257 
258   // Caller methods.
ConsumeRetvals(std::vector<Tensor> * retvals)259   Status ConsumeRetvals(std::vector<Tensor>* retvals) {
260     retvals->reserve(retvals_.size());
261     int i = 0;
262     for (auto&& val : retvals_) {
263       if (!val) {
264         return errors::Internal("No return value for index ", i, ".");
265       }
266       retvals->emplace_back(std::move(val.value()));
267       ++i;
268     }
269     return OkStatus();
270   }
271 
num_retvals() const272   size_t num_retvals() const override { return retvals_.size(); }
273 
274   // Callee methods.
SetRetval(int index,const Tensor & val)275   Status SetRetval(int index, const Tensor& val) override {
276     const int retvals_size = retvals_.size();
277     if (index < retvals_size && val.dtype() == ret_types_[index] &&
278         !retvals_[index]) {
279       retvals_[index] = val;
280       return OkStatus();
281     } else if (index >= retvals_size) {
282       return errors::InvalidArgument("Return value ", index,
283                                      " is out of range.");
284     } else if (val.dtype() != ret_types_[index]) {
285       return errors::InvalidArgument("Expected type ",
286                                      DataTypeString(ret_types_[index]),
287                                      " for return value ", index, " but got ",
288                                      DataTypeString(val.dtype()), ".");
289     } else {
290       return errors::Internal("Attempted to set return value ", index,
291                               " more than once.");
292     }
293   }
294 
295  private:
296   DataTypeSlice ret_types_;
297   std::vector<gtl::optional<Tensor>> retvals_;
298   TF_DISALLOW_COPY_AND_ASSIGN(CallFrameBase);
299 };
300 
301 class OwnedArgsCallFrame : public CallFrameBase {
302  public:
OwnedArgsCallFrame(std::vector<Tensor> && args,const std::vector<Tensor> * captured_inputs,DataTypeSlice ret_types)303   OwnedArgsCallFrame(std::vector<Tensor>&& args,
304                      const std::vector<Tensor>* captured_inputs,
305                      DataTypeSlice ret_types)
306       : CallFrameBase(ret_types),
307         args_(std::move(args)),
308         captured_inputs_(captured_inputs) {}
309 
num_args() const310   size_t num_args() const override {
311     return args_.size() + captured_inputs_->size();
312   }
313 
314   // Callee methods.
GetArg(int index,const Tensor ** val)315   Status GetArg(int index, const Tensor** val) override {
316     const int args_size = args_.size();
317     const int captured_inputs_size = captured_inputs_->size();
318     if (index < args_size) {
319       *val = &args_[index];
320       return OkStatus();
321     } else if (index < args_size + captured_inputs_size) {
322       *val = &(*captured_inputs_)[index - args_.size()];
323       return OkStatus();
324     } else {
325       return errors::InvalidArgument("Argument ", index, " is out of range.");
326     }
327   }
328 
329   // Since we own the argument tensors in `args_`, we can implement
330   // `ConsumeArg()` for those arguments.
ConsumeArg(int index,Tensor * val)331   void ConsumeArg(int index, Tensor* val) override {
332     DCHECK_GE(index, 0);
333     DCHECK_LT(index, args_.size());
334     *val = std::move(args_[index]);
335   }
CanConsumeArg(int index) const336   bool CanConsumeArg(int index) const override {
337     return index >= 0 && index < static_cast<int>(args_.size());
338   }
339 
340  private:
341   std::vector<Tensor> args_;
342   const std::vector<Tensor>* const captured_inputs_;  // Not owned.
343 };
344 
345 class BorrowedArgsCallFrame : public CallFrameBase {
346  public:
BorrowedArgsCallFrame(const std::vector<Tensor> & args,const std::vector<Tensor> * captured_inputs,DataTypeSlice ret_types)347   BorrowedArgsCallFrame(const std::vector<Tensor>& args,
348                         const std::vector<Tensor>* captured_inputs,
349                         DataTypeSlice ret_types)
350       : CallFrameBase(ret_types),
351         args_(args),
352         captured_inputs_(captured_inputs) {}
353 
num_args() const354   size_t num_args() const override {
355     return args_.size() + captured_inputs_->size();
356   }
357 
358   // Callee methods.
GetArg(int index,const Tensor ** val)359   Status GetArg(int index, const Tensor** val) override {
360     const int args_size = args_.size();
361     const int captured_inputs_size = captured_inputs_->size();
362     if (index < args_size) {
363       *val = &args_[index];
364       return OkStatus();
365     } else if (index < args_size + captured_inputs_size) {
366       *val = &(*captured_inputs_)[index - args_size];
367       return OkStatus();
368     } else {
369       return errors::InvalidArgument("Argument ", index, " is out of range.");
370     }
371   }
372 
373  private:
374   const std::vector<Tensor>& args_;                   // Not owned.
375   const std::vector<Tensor>* const captured_inputs_;  // Not owned.
376 };
377 
378 }  // namespace
379 
MakeIteratorFromInputElement(IteratorContext * ctx,const IteratorBase * parent,const std::vector<Tensor> & input_element,int64_t thread_index,const InstantiatedCapturedFunction & inst_captured_func,StringPiece prefix,std::unique_ptr<IteratorBase> * out_iterator)380 Status MakeIteratorFromInputElement(
381     IteratorContext* ctx, const IteratorBase* parent,
382     const std::vector<Tensor>& input_element, int64_t thread_index,
383     const InstantiatedCapturedFunction& inst_captured_func, StringPiece prefix,
384     std::unique_ptr<IteratorBase>* out_iterator) {
385   return MakeIteratorFromInputElement(ctx, parent, input_element, thread_index,
386                                       inst_captured_func, prefix, out_iterator,
387                                       /*node=*/nullptr);
388 }
389 
MakeIteratorFromInputElement(IteratorContext * ctx,const IteratorBase * parent,const std::vector<Tensor> & input_element,int64_t thread_index,const InstantiatedCapturedFunction & inst_captured_func,StringPiece prefix,std::unique_ptr<IteratorBase> * out_iterator,const std::shared_ptr<model::Node> & node)390 Status MakeIteratorFromInputElement(
391     IteratorContext* ctx, const IteratorBase* parent,
392     const std::vector<Tensor>& input_element, int64_t thread_index,
393     const InstantiatedCapturedFunction& inst_captured_func, StringPiece prefix,
394     std::unique_ptr<IteratorBase>* out_iterator,
395     const std::shared_ptr<model::Node>& node) {
396   std::vector<Tensor> return_values;
397 
398   TF_RETURN_IF_ERROR(inst_captured_func.RunWithBorrowedArgs(
399       ctx, input_element, &return_values, node));
400 
401   if (!(return_values.size() == 1 && return_values[0].dtype() == DT_VARIANT &&
402         TensorShapeUtils::IsScalar(return_values[0].shape()))) {
403     return errors::InvalidArgument(
404         "Function must return a single scalar of dtype DT_VARIANT.");
405   }
406 
407   // Retrieve the dataset that was created in `f`.
408   DatasetBase* returned_dataset;
409   TF_RETURN_IF_ERROR(
410       GetDatasetFromVariantTensor(return_values[0], &returned_dataset));
411 
412   // Create an iterator for the dataset that was returned by `f`.
413   std::string iterator_prefix = strings::StrCat(prefix, "[", thread_index, "]");
414 
415   return returned_dataset->MakeIterator(MakeNestedIteratorContext(ctx), parent,
416                                         iterator_prefix, out_iterator);
417 }
418 
MakeNestedIteratorContext(IteratorContext * ctx)419 IteratorContext MakeNestedIteratorContext(IteratorContext* ctx) {
420   // Strip out any split providers so that they don't apply to sub-iterators.
421   if (ctx->split_providers().empty()) {
422     return *ctx;
423   }
424   IteratorContext::Params params(ctx);
425   params.split_providers.clear();
426   return IteratorContext(std::move(params));
427 }
428 
429 /* static */
Create(OpKernelConstruction * ctx,const string & func_name,Params params,std::shared_ptr<FunctionMetadata> * out_metadata)430 Status FunctionMetadata::Create(
431     OpKernelConstruction* ctx, const string& func_name, Params params,
432     std::shared_ptr<FunctionMetadata>* out_metadata) {
433   NameAttrList func;
434   TF_RETURN_IF_ERROR(ctx->GetAttr(func_name, &func));
435   return Create(ctx, std::move(func), params, out_metadata);
436 }
437 
Create(OpKernelConstruction * ctx,NameAttrList && func,Params params,std::shared_ptr<FunctionMetadata> * out_metadata)438 Status FunctionMetadata::Create(
439     OpKernelConstruction* ctx, NameAttrList&& func, Params params,
440     std::shared_ptr<FunctionMetadata>* out_metadata) {
441   out_metadata->reset(new FunctionMetadata(std::move(func), params));
442   TF_RETURN_IF_ERROR(CreateFunctionLibraryDefinition(
443       ctx->function_library()->GetFunctionLibraryDefinition(),
444       (*out_metadata)->func_.name(), &(*out_metadata)->lib_def_));
445   TF_RETURN_IF_ERROR(CreateShortCircuitInfo(
446       ctx, (*out_metadata)->func_, &(*out_metadata)->short_circuit_info_));
447   const FunctionDef* fdef;
448   TF_RETURN_IF_ERROR(LookupFunction(*(*out_metadata)->lib_def(),
449                                     (*out_metadata)->func().name(), &fdef));
450 
451   auto attr = fdef->attr().find(FunctionLibraryDefinition::kIntsOnDeviceAttr);
452   if (attr != fdef->attr().end() && attr->second.b()) {
453     VLOG(1) << "Disabling multi-device execution for a function that uses the "
454             << FunctionLibraryDefinition::kIntsOnDeviceAttr << " attribute.";
455     (*out_metadata)->use_multi_device_function_ = false;
456     return OkStatus();
457   }
458   auto validate_arg = [](const OpDef::ArgDef& arg) {
459     if (!arg.number_attr().empty() || !arg.type_list_attr().empty()) {
460       VLOG(1) << "Disabling multi-device execution for a function with "
461               << "a vector argument " << arg.name() << ".";
462       return false;
463     }
464     return true;
465   };
466   for (const auto& arg : fdef->signature().input_arg()) {
467     if (!validate_arg(arg)) {
468       (*out_metadata)->use_multi_device_function_ = false;
469       return OkStatus();
470     }
471   }
472   for (const auto& arg : fdef->signature().output_arg()) {
473     if (!validate_arg(arg)) {
474       (*out_metadata)->use_multi_device_function_ = false;
475       return OkStatus();
476     }
477   }
478   return OkStatus();
479 }
480 
481 /* static */
Create(OpKernelContext * ctx,std::shared_ptr<const FunctionMetadata> metadata,const string & argument_name,std::unique_ptr<CapturedFunction> * out_function)482 Status CapturedFunction::Create(
483     OpKernelContext* ctx, std::shared_ptr<const FunctionMetadata> metadata,
484     const string& argument_name,
485     std::unique_ptr<CapturedFunction>* out_function) {
486   OpInputList inputs;
487   TF_RETURN_IF_ERROR(ctx->input_list(argument_name, &inputs));
488   std::vector<Tensor> captured_inputs(inputs.begin(), inputs.end());
489   return Create(ctx, std::move(metadata), std::move(captured_inputs),
490                 out_function);
491 }
492 
493 /* static */
Create(OpKernelContext * ctx,std::shared_ptr<const FunctionMetadata> metadata,std::vector<Tensor> && captured_inputs,std::unique_ptr<CapturedFunction> * out_function)494 Status CapturedFunction::Create(
495     OpKernelContext* ctx, std::shared_ptr<const FunctionMetadata> metadata,
496     std::vector<Tensor>&& captured_inputs,
497     std::unique_ptr<CapturedFunction>* out_function) {
498   *out_function = absl::WrapUnique(
499       new CapturedFunction(std::move(metadata), std::move(captured_inputs)));
500   return OkStatus();
501 }
502 
AddToGraph(SerializationContext * ctx,DatasetBase::DatasetGraphDefBuilder * b,std::vector<Node * > * other_arguments,DataTypeVector * other_arguments_types) const503 Status CapturedFunction::AddToGraph(
504     SerializationContext* ctx, DatasetBase::DatasetGraphDefBuilder* b,
505     std::vector<Node*>* other_arguments,
506     DataTypeVector* other_arguments_types) const {
507   other_arguments->reserve(captured_inputs_.size());
508   other_arguments_types->reserve(captured_inputs_.size());
509   for (const Tensor& t : captured_inputs_) {
510     Node* node;
511     if (!ctx->is_graph_rewrite()) {
512       TF_RETURN_IF_ERROR(b->AddDatasetOrTensor(ctx, t, &node));
513     } else {
514       TF_RETURN_IF_ERROR(b->AddPlaceholder(t, &node));
515       DCHECK_NE(ctx->input_list(), nullptr);
516       ctx->input_list()->emplace_back(node->name(), t);
517     }
518     other_arguments->emplace_back(node);
519     other_arguments_types->emplace_back(t.dtype());
520   }
521   TF_RETURN_IF_ERROR(
522       b->AddFunction(ctx, metadata_->func().name(), *metadata_->lib_def()));
523   return OkStatus();
524 }
525 
Instantiate(IteratorContext * ctx,std::unique_ptr<InstantiatedCapturedFunction> * instantiated_captured_function)526 Status CapturedFunction::Instantiate(
527     IteratorContext* ctx, std::unique_ptr<InstantiatedCapturedFunction>*
528                               instantiated_captured_function) {
529   return CapturedFunction::Instantiate(InstantiateCapturedFunctionParams(ctx),
530                                        instantiated_captured_function);
531 }
532 
533 // TODO(b/190831948): Check whether the function creates a resource and if so,
534 // produce a warning.
Instantiate(InstantiateCapturedFunctionParams params,std::unique_ptr<InstantiatedCapturedFunction> * instantiated_captured_function)535 Status CapturedFunction::Instantiate(
536     InstantiateCapturedFunctionParams params,
537     std::unique_ptr<InstantiatedCapturedFunction>*
538         instantiated_captured_function) {
539   // The context's runtime will be used for all subsequent calls.
540   FunctionLibraryRuntime* lib = params.flr;
541   FunctionLibraryRuntime::InstantiateOptions inst_opts;
542   inst_opts.lib_def = metadata_->lib_def();
543   inst_opts.create_kernels_eagerly = true;
544   inst_opts.default_device_to_target = metadata_->use_default_device();
545   inst_opts.config_proto =
546       lib->config_proto() ? *lib->config_proto() : ConfigProto();
547   if (GetExperiments().contains(kAllowSmallFunctionOptimizations)) {
548     inst_opts.allow_small_function_optimizations = true;
549   } else {
550     if (!metadata_->use_inter_op_parallelism()) {
551       inst_opts.executor_type = "SINGLE_THREADED_EXECUTOR";
552     }
553   }
554   inst_opts.is_multi_device_function = metadata_->use_multi_device_function();
555   if (!params.function_handle_cache) {
556     // If the caller does not provide a cache, we use the FLR cache.
557     inst_opts.use_function_cache = true;
558   }
559 
560   // We infer the target device from the function library runtime.
561   DCHECK(lib->device() != nullptr);
562   inst_opts.target = lib->device()->name();
563 
564   // Maps from a CompositeDevice name to underlying physical device names.
565   absl::flat_hash_map<string, std::vector<string>> composite_devices;
566 
567   if (inst_opts.is_multi_device_function) {
568     // Compute devices of non-captured inputs.
569     //
570     // We infer the number of non-captured inputs by subtracting the number
571     // of captured inputs from the number of input arguments and we infer the
572     // input devices from the function library runtime.
573     const FunctionDef* fdef;
574     TF_RETURN_IF_ERROR(
575         LookupFunction(*metadata_->lib_def(), metadata_->func().name(), &fdef));
576     size_t num_non_captured_inputs =
577         fdef->signature().input_arg_size() - captured_inputs_.size();
578     for (size_t i = 0; i < num_non_captured_inputs; ++i) {
579       inst_opts.input_devices.push_back(inst_opts.target);
580     }
581     // Compute devices of captured inputs.
582     // TODO(jsimsa): Correctly handle tensors on devices other than CPU:0.
583     Device* cpu_device;
584     TF_RETURN_IF_ERROR(lib->device_mgr()->LookupDevice("CPU:0", &cpu_device));
585     std::unordered_map<int, DtypeAndPartialTensorShape>&
586         input_resource_variable_dtypes_and_shapes =
587             inst_opts.input_resource_dtypes_and_shapes;
588     for (size_t i = 0; i < captured_inputs_.size(); ++i) {
589       const auto& input = captured_inputs_[i];
590       DataType dtype = input.dtype();
591       if (dtype == DT_RESOURCE) {
592         const auto& handles = input.flat<ResourceHandle>();
593         const ResourceHandle& handle0 = handles(0);
594         string composite_device;
595         auto iter = fdef->arg_attr().find(num_non_captured_inputs + i);
596         if (iter != fdef->arg_attr().end()) {
597           auto arg_attr = iter->second.attr().find("_composite_device");
598           if (arg_attr != iter->second.attr().end()) {
599             composite_device = arg_attr->second.s();
600           }
601         }
602         if (!composite_device.empty()) {
603           if (composite_devices.find(composite_device) ==
604               composite_devices.end()) {
605             for (int i = 0; i < handles.size(); ++i) {
606               composite_devices[composite_device].push_back(
607                   handles(i).device());
608             }
609           }
610           inst_opts.input_devices.push_back(composite_device);
611         } else {
612           inst_opts.input_devices.push_back(handle0.device());
613         }
614         const auto& dtypes_and_shapes = handle0.dtypes_and_shapes();
615         // Set dtypes and shapes for resource variable inputs.
616         if (!dtypes_and_shapes.empty()) {
617           input_resource_variable_dtypes_and_shapes[num_non_captured_inputs +
618                                                     i] =
619               dtypes_and_shapes.at(0);
620         }
621       } else if (MTypeFromDType(dtype) == HOST_MEMORY) {
622         inst_opts.input_devices.push_back(cpu_device->name());
623       } else {
624         // Fall back to using the function library runtime device.
625         inst_opts.input_devices.push_back(inst_opts.target);
626       }
627     }
628 
629     for (const auto& it : composite_devices) {
630       inst_opts.composite_devices[it.first] = &it.second;
631     }
632 
633     for (int i = 0, end = fdef->signature().output_arg_size(); i < end; ++i) {
634       inst_opts.output_devices.push_back(inst_opts.target);
635     }
636 
637 #if !defined(IS_MOBILE_PLATFORM)
638     grappler::GrapplerItem::OptimizationOptions optimization_options;
639     optimization_options.allow_pruning_stateful_and_dataset_ops = false;
640     ConfigProto config_proto = inst_opts.config_proto;
641     // Layout optimizations are excluded because they assume that ops without
642     // explicit device assignment will be placed on GPU (if available) but
643     // that's not the case for operations within tf.data functions.
644     config_proto.mutable_graph_options()
645         ->mutable_rewrite_options()
646         ->set_layout_optimizer(RewriterConfig::OFF);
647     // TODO(b/120437209): Re-enable constant folding.
648     config_proto.mutable_graph_options()
649         ->mutable_rewrite_options()
650         ->set_constant_folding(RewriterConfig::OFF);
651     inst_opts.optimize_graph_fn =
652         std::bind(tensorflow::grappler::OptimizeGraph, std::placeholders::_1,
653                   std::placeholders::_2, std::placeholders::_3,
654                   std::placeholders::_4, std::placeholders::_5,
655                   std::move(config_proto), fdef->signature().name(),
656                   std::move(optimization_options), std::placeholders::_6);
657 #endif  // !IS_MOBILE_PLATFORM
658   }
659 
660   FunctionLibraryRuntime::Handle f_handle;
661   if (params.function_handle_cache) {
662     TF_RETURN_IF_ERROR(params.function_handle_cache->Instantiate(
663         metadata_->func().name(), AttrSlice(&metadata_->func().attr()),
664         inst_opts, &f_handle));
665   } else {
666     TF_RETURN_IF_ERROR(lib->Instantiate(metadata_->func().name(),
667                                         AttrSlice(&metadata_->func().attr()),
668                                         inst_opts, &f_handle));
669   }
670 
671   DataTypeVector ret_types;
672   TF_RETURN_IF_ERROR(lib->GetRetTypes(f_handle, &ret_types));
673 
674   bool is_multi_device;
675   TF_RETURN_IF_ERROR(IsMultiDevice(lib, &is_multi_device));
676   *instantiated_captured_function = absl::WrapUnique(
677       new InstantiatedCapturedFunction(lib, f_handle, std::move(ret_types),
678                                        *params.runner, this, is_multi_device));
679   return OkStatus();
680 }
681 
CheckExternalState() const682 Status CapturedFunction::CheckExternalState() const {
683   for (const auto& name : lib_def()->ListFunctionNames()) {
684     TF_RETURN_IF_ERROR(
685         IsFunctionStateful(*lib_def(), *(lib_def()->Find(name))));
686   }
687   return OkStatus();
688 }
689 
CapturedFunction(std::shared_ptr<const FunctionMetadata> metadata,std::vector<Tensor> captured_inputs)690 CapturedFunction::CapturedFunction(
691     std::shared_ptr<const FunctionMetadata> metadata,
692     std::vector<Tensor> captured_inputs)
693     : metadata_(std::move(metadata)),
694       captured_inputs_(std::move(captured_inputs)) {}
695 
IsMultiDevice(FunctionLibraryRuntime * flr,bool * is_multi_device) const696 Status CapturedFunction::IsMultiDevice(FunctionLibraryRuntime* flr,
697                                        bool* is_multi_device) const {
698   if (!metadata_->use_multi_device_function()) {
699     *is_multi_device = false;
700     return OkStatus();
701   }
702 
703   const FunctionDef* fdef;
704   TF_RETURN_IF_ERROR(
705       LookupFunction(*metadata_->lib_def(), metadata_->func().name(), &fdef));
706 
707   Device* current_device = flr->device();
708   DeviceType current_device_type(current_device->device_type());
709   DeviceNameUtils::ParsedName current_device_name;
710   if (!DeviceNameUtils::ParseFullName(current_device->name(),
711                                       &current_device_name)) {
712     return errors::InvalidArgument("Failed to parse device name: ",
713                                    current_device->name());
714   }
715 
716   // Check if any of the captured inputs are placed on a device not compatible
717   // with the current device. For non-captured inputs, we assume they are placed
718   // on the current device.
719   for (const auto& input : captured_inputs_) {
720     DataType dtype = input.dtype();
721     if (dtype == DT_RESOURCE) {
722       const ResourceHandle& handle = input.flat<ResourceHandle>()(0);
723       DeviceNameUtils::ParsedName resource_device_name;
724       if (!DeviceNameUtils::ParseFullName(handle.device(),
725                                           &resource_device_name)) {
726         return errors::InvalidArgument("Failed to parse device name: ",
727                                        handle.device());
728       }
729       if (!DeviceNameUtils::AreCompatibleDevNames(current_device_name,
730                                                   resource_device_name)) {
731         *is_multi_device = true;
732         return OkStatus();
733       }
734     }
735   }
736 
737   // Check if all ops could be placed on the current device.
738   for (const auto& name : metadata_->lib_def()->ListFunctionNames()) {
739     const FunctionDef* fdef;
740     TF_RETURN_IF_ERROR(LookupFunction(*metadata_->lib_def(), name, &fdef));
741     for (const auto& node : fdef->node_def()) {
742       // Check if the op has a kernel available for the current device.
743       if (!KernelDefAvailable(current_device_type, node)) {
744         *is_multi_device = true;
745         return OkStatus();
746       }
747       // If the op has a requested device, check if the requested device is
748       // compatible with the current device.
749       if (!node.device().empty()) {
750         DeviceNameUtils::ParsedName node_device_name;
751         if (!DeviceNameUtils::ParseFullName(node.device(), &node_device_name)) {
752           return errors::InvalidArgument("Failed to parse device name: ",
753                                          node.device());
754         }
755         if (!DeviceNameUtils::AreCompatibleDevNames(current_device_name,
756                                                     node_device_name)) {
757           *is_multi_device = true;
758           return OkStatus();
759         }
760       }
761     }
762   }
763 
764   *is_multi_device = false;
765   return OkStatus();
766 }
767 
InstantiatedCapturedFunction(FunctionLibraryRuntime * lib,FunctionLibraryRuntime::Handle f_handle,DataTypeVector ret_types,std::function<void (std::function<void ()>)> runner,CapturedFunction * captured_func,bool is_multi_device)768 InstantiatedCapturedFunction::InstantiatedCapturedFunction(
769     FunctionLibraryRuntime* lib, FunctionLibraryRuntime::Handle f_handle,
770     DataTypeVector ret_types, std::function<void(std::function<void()>)> runner,
771     CapturedFunction* captured_func, bool is_multi_device)
772     : lib_(lib),
773       f_handle_(f_handle),
774       ret_types_(std::move(ret_types)),
775       captured_runner_(std::move(runner)),
776       captured_func_(captured_func),
777       is_multi_device_(is_multi_device) {}
778 
Run(IteratorContext * ctx,std::vector<Tensor> && args,std::vector<Tensor> * rets) const779 Status InstantiatedCapturedFunction::Run(IteratorContext* ctx,
780                                          std::vector<Tensor>&& args,
781                                          std::vector<Tensor>* rets) const {
782   return Run(ctx, std::move(args), rets, /*node=*/nullptr);
783 }
784 
Run(IteratorContext * ctx,std::vector<Tensor> && args,std::vector<Tensor> * rets,const std::shared_ptr<model::Node> & node) const785 Status InstantiatedCapturedFunction::Run(
786     IteratorContext* ctx, std::vector<Tensor>&& args, std::vector<Tensor>* rets,
787     const std::shared_ptr<model::Node>& node) const {
788   auto& info = captured_func_->short_circuit_info();
789   if (!info.indices.empty()) {
790     return RunShortCircuit(info, std::move(args), captured_func_, rets);
791   }
792 
793   FunctionLibraryRuntime::Options f_opts;
794   ScopedStepContainer step_container(
795       f_opts.step_id, [this](const string& name) {
796         lib_->device()->resource_manager()->Cleanup(name).IgnoreError();
797       });
798   f_opts.step_container = &step_container;
799   f_opts.runner = ctx->runner();
800   f_opts.create_rendezvous = ShouldCreateRendezvous();
801   CancellationManager cancellation_manager(ctx->cancellation_manager());
802   f_opts.cancellation_manager = &cancellation_manager;
803   f_opts.collective_executor = ctx->collective_executor();
804 
805   std::shared_ptr<SimpleStepStatsCollector> stats_collector;
806   if (node || ctx->stats_aggregator()) {
807     stats_collector = std::make_shared<SimpleStepStatsCollector>();
808   }
809   const bool collect_usage = node && ctx->model();
810   f_opts.stats_collector = stats_collector.get();
811 
812   OwnedArgsCallFrame frame(std::move(args), &captured_func_->captured_inputs(),
813                            ret_types_);
814   profiler::TraceMe activity(
815       [&] {
816         return profiler::TraceMeEncode("InstantiatedCapturedFunction::Run",
817                                        {{"id", f_opts.step_id}});
818       },
819       profiler::TraceMeLevel::kInfo);
820   if (node) {
821     // Resource usage for function execution is gathered from the executor.
822     // TODO(jsimsa): Factor out common code for Run, RunAsync, and
823     // RunWithBorrowedArguments
824     if (collect_usage) node->record_stop(EnvTime::NowNanos());
825     TF_RETURN_IF_ERROR(lib_->RunSync(std::move(f_opts), f_handle_, &frame));
826     if (ctx->stats_aggregator()) {
827       string prefix_with_func_name = strings::StrCat(
828           node->name(), stats_utils::kDelimiter, captured_func_->func().name());
829       ctx->stats_aggregator()->AddToHistogram(
830           stats_utils::ExecutionTimeHistogramName(prefix_with_func_name),
831           {static_cast<float>(stats_collector->processing_time())},
832           node->num_elements());
833     }
834     node->add_processing_time(stats_collector->processing_time());
835     if (collect_usage) node->record_start(EnvTime::NowNanos());
836   } else {
837     TF_RETURN_IF_ERROR(lib_->RunSync(std::move(f_opts), f_handle_, &frame));
838   }
839   return frame.ConsumeRetvals(rets);
840 }
841 
RunWithBorrowedArgs(IteratorContext * ctx,const std::vector<Tensor> & args,std::vector<Tensor> * ret) const842 Status InstantiatedCapturedFunction::RunWithBorrowedArgs(
843     IteratorContext* ctx, const std::vector<Tensor>& args,
844     std::vector<Tensor>* ret) const {
845   return RunWithBorrowedArgs(ctx, args, ret, /*node=*/nullptr);
846 }
847 
RunWithBorrowedArgs(IteratorContext * ctx,const std::vector<Tensor> & args,std::vector<Tensor> * rets,const std::shared_ptr<model::Node> & node) const848 Status InstantiatedCapturedFunction::RunWithBorrowedArgs(
849     IteratorContext* ctx, const std::vector<Tensor>& args,
850     std::vector<Tensor>* rets, const std::shared_ptr<model::Node>& node) const {
851   auto& info = captured_func_->short_circuit_info();
852   if (!info.indices.empty()) {
853     return RunShortCircuit(info, args, captured_func_, rets);
854   }
855 
856   FunctionLibraryRuntime::Options f_opts;
857   ScopedStepContainer step_container(
858       f_opts.step_id, [this](const string& name) {
859         lib_->device()->resource_manager()->Cleanup(name).IgnoreError();
860       });
861   f_opts.step_container = &step_container;
862   f_opts.runner = ctx->runner();
863   f_opts.create_rendezvous = ShouldCreateRendezvous();
864   CancellationManager cancellation_manager(ctx->cancellation_manager());
865   f_opts.cancellation_manager = &cancellation_manager;
866   f_opts.collective_executor = ctx->collective_executor();
867 
868   std::shared_ptr<SimpleStepStatsCollector> stats_collector;
869   if (node || ctx->stats_aggregator()) {
870     stats_collector = std::make_shared<SimpleStepStatsCollector>();
871   }
872   const bool collect_usage = node && ctx->model();
873   f_opts.stats_collector = stats_collector.get();
874 
875   BorrowedArgsCallFrame frame(args, &captured_func_->captured_inputs(),
876                               ret_types_);
877   profiler::TraceMe activity(
878       [&] {
879         return profiler::TraceMeEncode(
880             "InstantiatedCapturedFunction::RunWithBorrowedArgs",
881             {{"id", f_opts.step_id}});
882       },
883       profiler::TraceMeLevel::kInfo);
884   if (node) {
885     // Resource usage for function execution is gathered from the executor.
886     if (collect_usage) node->record_stop(EnvTime::NowNanos());
887     TF_RETURN_IF_ERROR(lib_->RunSync(std::move(f_opts), f_handle_, &frame));
888     if (ctx->stats_aggregator()) {
889       string prefix_with_func_name = strings::StrCat(
890           node->name(), stats_utils::kDelimiter, captured_func_->func().name());
891       ctx->stats_aggregator()->AddToHistogram(
892           stats_utils::ExecutionTimeHistogramName(prefix_with_func_name),
893           {static_cast<float>(stats_collector->processing_time())},
894           node->num_elements());
895     }
896     node->add_processing_time(stats_collector->processing_time());
897     if (collect_usage) node->record_start(EnvTime::NowNanos());
898   } else {
899     TF_RETURN_IF_ERROR(lib_->RunSync(std::move(f_opts), f_handle_, &frame));
900   }
901   return frame.ConsumeRetvals(rets);
902 }
903 
RunInstantiated(const std::vector<Tensor> & args,std::vector<Tensor> * rets)904 Status InstantiatedCapturedFunction::RunInstantiated(
905     const std::vector<Tensor>& args, std::vector<Tensor>* rets) {
906   auto& info = captured_func_->short_circuit_info();
907   if (!info.indices.empty()) {
908     return RunShortCircuit(info, args, captured_func_, rets);
909   }
910 
911   FunctionLibraryRuntime::Options f_opts;
912   ScopedStepContainer step_container(
913       f_opts.step_id, [this](const string& name) {
914         lib_->device()->resource_manager()->Cleanup(name).IgnoreError();
915       });
916   f_opts.step_container = &step_container;
917   f_opts.runner = &captured_runner_;
918   f_opts.create_rendezvous = ShouldCreateRendezvous();
919   CancellationManager cancellation_manager;
920   f_opts.cancellation_manager = &cancellation_manager;
921 
922   BorrowedArgsCallFrame frame(args, &captured_func_->captured_inputs(),
923                               ret_types_);
924   profiler::TraceMe activity(
925       [&] {
926         return profiler::TraceMeEncode(
927             "InstantiatedCapturedFunction::RunInstantiated",
928             {{"id", f_opts.step_id}});
929       },
930       profiler::TraceMeLevel::kInfo);
931   TF_RETURN_IF_ERROR(lib_->RunSync(std::move(f_opts), f_handle_, &frame));
932   return frame.ConsumeRetvals(rets);
933 }
934 
RunAsync(IteratorContext * ctx,std::vector<Tensor> && args,std::vector<Tensor> * rets,FunctionLibraryRuntime::DoneCallback done,const std::shared_ptr<model::Node> & node) const935 void InstantiatedCapturedFunction::RunAsync(
936     IteratorContext* ctx, std::vector<Tensor>&& args, std::vector<Tensor>* rets,
937     FunctionLibraryRuntime::DoneCallback done,
938     const std::shared_ptr<model::Node>& node) const {
939   auto& info = captured_func_->short_circuit_info();
940   if (!info.indices.empty()) {
941     // Run the `done` callback on a threadpool thread, because it will
942     // potentially do a non-trivial amount of (e.g. copying) work, and we may
943     // want to run that concurrently with the next invocation.
944     Status s = RunShortCircuit(info, std::move(args), captured_func_, rets);
945     (*ctx->runner())(
946         std::bind([s](FunctionLibraryRuntime::DoneCallback& done) { done(s); },
947                   std::move(done)));
948     return;
949   }
950 
951   // NOTE(mrry): This method does not transfer ownership of `ctx`, and it may
952   // be deleted before `done` is called. Take care not to capture `ctx` in any
953   // code that may execute asynchronously in this function.
954   OwnedArgsCallFrame* frame = new OwnedArgsCallFrame(
955       std::move(args), &captured_func_->captured_inputs(), ret_types_);
956 
957   FunctionLibraryRuntime::Options f_opts;
958   ResourceMgr* resource_mgr = lib_->device()->resource_manager();
959   ScopedStepContainer* step_container = new ScopedStepContainer(
960       f_opts.step_id, [resource_mgr](const string& name) {
961         resource_mgr->Cleanup(name).IgnoreError();
962       });
963   f_opts.step_container = step_container;
964   f_opts.runner = ctx->runner();
965   f_opts.create_rendezvous = ShouldCreateRendezvous();
966   auto cancellation_manager =
967       std::make_unique<CancellationManager>(ctx->cancellation_manager());
968   f_opts.cancellation_manager = cancellation_manager.get();
969   f_opts.collective_executor = ctx->collective_executor();
970 
971   std::shared_ptr<SimpleStepStatsCollector> stats_collector;
972   if (node || ctx->stats_aggregator()) {
973     stats_collector = std::make_shared<SimpleStepStatsCollector>();
974   }
975   const bool collect_usage = node && ctx->model();
976   f_opts.stats_collector = stats_collector.get();
977 
978   // Transfer ownership of the cancellation manager to `callback`.
979   CancellationManager* raw_cancellation_manager =
980       cancellation_manager.release();
981   auto callback = std::bind(
982       [this, rets, step_container, raw_cancellation_manager, frame, node,
983        collect_usage](
984           const FunctionLibraryRuntime::DoneCallback& done,
985           IteratorContext* ctx,
986           const std::shared_ptr<SimpleStepStatsCollector>& stats_collector,
987           // Begin unbound arguments.
988           Status s) {
989         delete step_container;
990         delete raw_cancellation_manager;
991         if (s.ok()) {
992           s = frame->ConsumeRetvals(rets);
993         }
994         delete frame;
995         if (node) {
996           // TODO(b/129085499) Utilize the `node_name` which would be unique
997           // than the prefix for the function execution time statistics.
998           // prefix_with_func_name would then be node_name + func_name.
999           if (ctx->stats_aggregator()) {
1000             string prefix_with_func_name =
1001                 strings::StrCat(node->name(), stats_utils::kDelimiter,
1002                                 captured_func_->func().name());
1003             ctx->stats_aggregator()->AddToHistogram(
1004                 stats_utils::ExecutionTimeHistogramName(prefix_with_func_name),
1005                 {static_cast<float>(stats_collector->processing_time())},
1006                 node->num_elements());
1007           }
1008           node->add_processing_time(stats_collector->processing_time());
1009         }
1010         if (collect_usage) {
1011           node->record_start(EnvTime::NowNanos());
1012         }
1013         done(s);
1014         if (collect_usage) {
1015           node->record_stop(EnvTime::NowNanos());
1016         }
1017       },
1018       std::move(done), ctx, std::move(stats_collector), std::placeholders::_1);
1019 
1020   profiler::TraceMe activity(
1021       [&] {
1022         return profiler::TraceMeEncode("InstantiatedCapturedFunction::RunAsync",
1023                                        {{"id", f_opts.step_id}});
1024       },
1025       profiler::TraceMeLevel::kInfo);
1026   // Stop the usage collection before calling `Run()` because `callback` may
1027   // be executed synchronously, and so the `node->record_start()` call within
1028   // `callback` would violate nesting.
1029   if (collect_usage) node->record_stop(EnvTime::NowNanos());
1030   lib_->Run(f_opts, f_handle_, frame, std::move(callback));
1031   if (collect_usage) node->record_start(EnvTime::NowNanos());
1032 }
1033 
ShouldCreateRendezvous() const1034 bool InstantiatedCapturedFunction::ShouldCreateRendezvous() const {
1035   // Rendezvous should only be created by the FLR for non-CPU single-device
1036   // functions. For multi-device functions the appropriate rendezvous will be
1037   // created by the process FLR.
1038   return lib_->device()->device_type() != DEVICE_CPU && !is_multi_device_;
1039 }
1040 
1041 }  // namespace data
1042 }  // namespace tensorflow
1043