/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/common_runtime/single_threaded_executor.h"

#include <utility>

#include "tensorflow/core/common_runtime/entry.h"
#include "tensorflow/core/common_runtime/executor.h"
#include "tensorflow/core/common_runtime/executor_factory.h"
#include "tensorflow/core/common_runtime/renamed_device.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/macros.h"

namespace tensorflow {

Status ValidateOpIsSafeForSyncExecution(
    const Node& n, bool allow_control_flow_sync_execution) {
  for (DataType dt : n.output_types()) {
    if (IsRefType(dt)) {
      return errors::Unimplemented(
          "Single-threaded executor does not support reference-typed "
          "edges, but saw type ",
          DataTypeString(dt), " in outputs of node ", n.name());
    }
  }
  // Executing Switch nodes requires propagating deadness, which is
  // not currently supported in the SingleThreadedExecutor.
  if (n.IsSwitch()) {
    return errors::FailedPrecondition(
        "Single-threaded executor does not support switch op, but saw node ",
        n.name(),
        ". Perhaps your graph contains old-style control flow primitives? "
        "Try using tf.compat.v1.enable_control_flow_v2().");
  }
  if (n.IsControlFlow() && !allow_control_flow_sync_execution) {
    return errors::FailedPrecondition(
        "Single-threaded executor does not support low-level control flow, "
        "but saw control flow node ",
        n.name(),
        ". Perhaps your graph contains old-style control flow primitives? "
        "Try using tf.compat.v1.enable_control_flow_v2().");
  }
  return OkStatus();
}
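
// For example (illustrative, not exhaustive): graphs that use TF1-style
// reference variables (e.g. VariableV2/Assign) fail the ref-type check above,
// and graphs lowered with v1 control flow (Switch/Merge/Enter/Exit/
// NextIteration) fail the control-flow checks unless
// `allow_control_flow_sync_execution` is set; Switch is rejected
// unconditionally because deadness propagation is not implemented here.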

namespace {

typedef gtl::InlinedVector<TensorValue, 4> TensorValueVec;
typedef gtl::InlinedVector<AllocatorAttributes, 4> AllocatorAttributeVec;

static const string& kSingleThreadedExecutor =
    *new string("SINGLE_THREADED_EXECUTOR");

class SingleThreadedExecutorImpl : public Executor {
 public:
  explicit SingleThreadedExecutorImpl(const LocalExecutorParams& params)
      : params_(params) {}

  ~SingleThreadedExecutorImpl() override {
    for (const KernelState& kernel_state : kernels_) {
      params_.delete_kernel(kernel_state.kernel);
    }
    for (const ConstTensorKernelState& kernel_state : const_tensor_kernels_) {
      params_.delete_kernel(kernel_state.kernel);
    }
  }

  Status Initialize(const Graph& graph) {
    // Topologically sort `graph` to get a sequence of OpKernels.
    std::vector<Node*> ordered_nodes;
    ordered_nodes.reserve(graph.num_nodes());
    GetReversePostOrder(graph, &ordered_nodes);
    int ordered_nodes_size = ordered_nodes.size();
    if (ordered_nodes_size != graph.num_nodes()) {
      return errors::InvalidArgument("Graph had ", graph.num_nodes(),
                                     " nodes, but reverse post-order had ",
                                     ordered_nodes.size());
    }

    // We reserve space for two fewer nodes, because we do not need to create
    // kernels for the _SOURCE and _SINK nodes.
    kernels_.reserve(ordered_nodes.size() - 2);
    std::vector<Node*> nodes_with_kernels;
    std::vector<Node*> nodes_with_const_tensor_kernels;
    nodes_with_kernels.reserve(ordered_nodes.size() - 2);

    std::map<size_t, Node*> arg_index_to_node_map;
    absl::flat_hash_map<Node*, size_t> node_to_index_map;

    // Create the kernel and input-related structures for each node in `graph`.
    for (Node* n : ordered_nodes) {
      if (n->IsSource() || n->IsSink()) {
        continue;
      }
      TF_RETURN_IF_ERROR(ValidateOpIsSafeForSyncExecution(
          *n, params_.allow_control_flow_sync_execution));
      if (n->IsArg()) {
        int32_t arg_index;
        TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "index", &arg_index));
        if (arg_index < 0) {
          return errors::InvalidArgument("Invalid argument index ", arg_index,
                                         " in node ", n->name());
        }
        arg_index_to_node_map[arg_index] = n;
        // We do not create a kernel for Arg nodes, and instead inline the
        // argument handling directly in the executor code.
        continue;
      }

      OpKernel* kernel;
      TF_RETURN_IF_ERROR(params_.create_kernel(n->properties(), &kernel));

      const Tensor* const_tensor;
      if (n->num_outputs() == 1 && (const_tensor = kernel->const_tensor())) {
        // Nodes that produce a single constant tensor are handled specially:
        // we evaluate the tensor once, and propagate it to its consumers as
        // a `const Tensor*`, to avoid refcount manipulation.
        const size_t kernel_index = const_tensor_kernels_.size();
        const_tensor_kernels_.push_back({});
        nodes_with_const_tensor_kernels.push_back(n);
        ConstTensorKernelState& kernel_state =
            const_tensor_kernels_[kernel_index];
        kernel_state.kernel = kernel;
        kernel_state.const_tensor = *const_tensor;
      } else {
        const size_t kernel_index = kernels_.size();
        kernels_.push_back({});
        nodes_with_kernels.push_back(n);
        KernelState& kernel_state = kernels_[kernel_index];
        kernel_state.kernel = kernel;
        kernel_state.num_inputs = n->num_inputs();
        kernel_state.num_outputs = n->num_outputs();
        node_to_index_map[n] = kernel_index;
        if (kernel_index == 0) {
          kernel_state.input_start_index = 0;
        } else {
          const KernelState& previous_kernel_state = kernels_[kernel_index - 1];
          kernel_state.input_start_index =
              previous_kernel_state.input_start_index +
              previous_kernel_state.num_inputs;
        }
      }
    }
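
    // At this point, `kernels_` holds one entry per non-Arg op node in
    // (reverse post-order) topological order, and each `input_start_index` is
    // the running sum of the preceding kernels' `num_inputs`, which defines
    // that kernel's slice of the flat `inputs` vector used in `Run()`.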

    // Build the mapping from each Arg node output to the input slot for the
    // corresponding destination node.
    if (!arg_index_to_node_map.empty()) {
      const size_t num_args = arg_index_to_node_map.rbegin()->first + 1;
      arg_output_locations_.resize(num_args);
      for (const auto& arg_index_node_pair : arg_index_to_node_map) {
        const size_t arg_index = arg_index_node_pair.first;
        const Node* arg_node = arg_index_node_pair.second;
        arg_output_locations_[arg_index].reserve(arg_node->out_edges().size());
        for (const Edge* e : arg_node->out_edges()) {
          if (e->src_output() == Graph::kControlSlot) {
            continue;
          } else if (e->src_output() != 0) {
            return errors::Internal("Invalid output index ", e->src_output(),
                                    " from argument node ", arg_index);
          }
          arg_output_locations_[arg_index].push_back(
              kernels_[node_to_index_map[e->dst()]].input_start_index +
              e->dst_input());
        }
      }
    }

    // Build the mapping from each const tensor kernel to the input slot for the
    // corresponding destination node.
    for (size_t i = 0; i < const_tensor_kernels_.size(); ++i) {
      Node* n = nodes_with_const_tensor_kernels[i];
      ConstTensorKernelState& kernel_state = const_tensor_kernels_[i];
      for (const Edge* e : n->out_edges()) {
        if (e->src_output() == Graph::kControlSlot) {
          continue;
        } else if (e->src_output() != 0) {
          return errors::Internal("Invalid output index ", e->src_output(),
                                  " from node ", n->DebugString());
        }
        kernel_state.output_locations.push_back(
            kernels_[node_to_index_map[e->dst()]].input_start_index +
            e->dst_input());
      }

      bool on_host =
          kernel_state.kernel->output_memory_types()[0] == HOST_MEMORY;
      kernel_state.output_alloc_attr.set_on_host(on_host);
    }

    // Build the mapping from each node output to the input slot for the
    // corresponding destination node.
    for (size_t i = 0; i < kernels_.size(); ++i) {
      Node* n = nodes_with_kernels[i];
      KernelState& kernel_state = kernels_[i];
      kernel_state.output_locations.resize(kernel_state.num_outputs);
      for (const Edge* e : n->out_edges()) {
        if (!e->IsControlEdge()) {
          kernel_state.output_locations[e->src_output()].push_back(
              kernels_[node_to_index_map[e->dst()]].input_start_index +
              e->dst_input());
        }
      }

      // Compute allocator attributes for each node output, and corresponding
      // node input.
      kernel_state.output_alloc_attrs.resize(kernel_state.num_outputs);
      AllocatorAttributes* attrs = kernel_state.output_alloc_attrs.data();

      OpKernel* op_kernel = kernel_state.kernel;
      for (int out = 0; out < n->num_outputs(); out++) {
        DCHECK_LT(out, op_kernel->output_memory_types().size());
        bool on_host = op_kernel->output_memory_types()[out] == HOST_MEMORY;
        if (on_host) {
          AllocatorAttributes h;
          h.set_on_host(on_host);
          attrs[out].Merge(h);
        }
      }
    }

    if (!kernels_.empty()) {
      const KernelState& last_kernel_state = kernels_.back();
      total_num_inputs_ =
          last_kernel_state.input_start_index + last_kernel_state.num_inputs;
      input_alloc_attrs_.resize(total_num_inputs_);
      for (size_t i = 0; i < kernels_.size(); ++i) {
        for (size_t j = 0; j < kernels_[i].output_locations.size(); ++j) {
          for (size_t output_location : kernels_[i].output_locations[j]) {
            input_alloc_attrs_[output_location] =
                kernels_[i].output_alloc_attrs[j];
          }
        }
      }
    } else {
      total_num_inputs_ = 0;
    }
    return OkStatus();
  }

  Status Run(const Args& args) override {
    // The inputs to each kernel are stored contiguously in `inputs`.
    //
    // We use `kernels_[i].input_start_index` and `kernels_[i].num_inputs` to
    // determine the range of elements in this vector that correspond to
    // the inputs of `kernels_[i]`.
    //
    // This vector has the following layout:
    //
    // * Kernel 0, input 0.
    // * Kernel 0, input 1.
    // * ...
    // * Kernel 0, input `kernels_[0].num_inputs - 1`.
    // * Kernel 1, input 0.
    // * ...
    // * Kernel 1, input `kernels_[1].num_inputs - 1`.
    // * ...
    // * Kernel `kernels_.size() - 1`, input 0.
    // * ...
    // * Kernel `kernels_.size() - 1`, input `kernels_.back().num_inputs - 1`.
    //
    // Note that kernels with zero inputs do not correspond to any elements in
    // this vector.
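    //
    // For example (illustrative values only): if `kernels_[0]` has two inputs
    // and `kernels_[1]` has one input, then
    //
    //   inputs[0]  <- kernel 0, input 0  (kernels_[0].input_start_index == 0)
    //   inputs[1]  <- kernel 0, input 1
    //   inputs[2]  <- kernel 1, input 0  (kernels_[1].input_start_index == 2)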
    //
    // We use `ManualConstructor<Tensor>` to avoid the overhead of
    // default-constructing an invalid `Tensor` for each slot at the beginning
    // of execution:
    // * Elements are initialized when the outputs of a kernel execution are
    //   propagated to the inputs of kernels that depend on them.
    // * The elements corresponding to the inputs for kernel `i` are destroyed
    //   after kernel `i` executes.
    // * In an error case (see below), we use the connectivity information in
    //   `KernelState::output_locations` to determine which locations have been
    //   initialized, and manually destroy them.
    std::vector<Entry> inputs(total_num_inputs_);

    // TODO(mrry): Can we avoid copying into these vectors? Consider modifying
    // OpKernelContext to take the TensorValueVec as a pointer into `inputs`.
    TensorValueVec node_inputs;
    AllocatorAttributeVec input_alloc_attrs;

    // Override intra op thread pool if requested.
    Device* device = params_.device;
    std::unique_ptr<Device> user_device;
    if (args.user_intra_op_threadpool != nullptr) {
      user_device = RenamedDevice::NewRenamedDevice(
          device->name(), device, /*owns_underlying=*/false,
          /*isolate_session_state=*/false, args.user_intra_op_threadpool);
      device = user_device.get();
    }

    // Prepare the parameters that will be the same for all kernels.
    OpKernelContext::Params params;
    params.step_id = args.step_id;
    params.device = device;
    params.log_memory = false;  // TODO(mrry): Too severe?
    params.rendezvous = args.rendezvous;
    params.session_state = args.session_state;
    params.session_metadata = params_.session_metadata;
    params.tensor_store = args.tensor_store;
    params.cancellation_manager = args.cancellation_manager;
    params.call_frame = args.call_frame;
    params.function_library = params_.function_library;
    params.resource_manager = device->resource_manager();
    params.step_container = args.step_container;
    params.collective_executor = args.collective_executor;
    params.stack_trace = args.stack_trace;
    params.slice_reader_cache = nullptr;  // TODO(mrry): Too severe?

    Args::Runner runner_copy = args.runner;
    params.runner = &runner_copy;
    params.run_all_kernels_inline = args.run_all_kernels_inline;
    params.stats_collector = args.stats_collector;
    params.executor_type = &kSingleThreadedExecutor;

    // NOTE(mrry): We are assuming that the graph is loopless and condless.
    params.frame_iter = FrameAndIter(0, 0);
    params.is_input_dead = false;

    device->TryGetDeviceContext(&params.op_device_context).IgnoreError();
    auto context_cleanup = gtl::MakeCleanup([&params] {
      if (params.op_device_context != nullptr) {
        params.op_device_context->Unref();
      }
    });

    // TODO(mrry): Consider implementing forwarding.
    params.forward_from_array = nullptr;

    const size_t received_args =
        args.call_frame ? args.call_frame->num_args() : 0;
    if (TF_PREDICT_FALSE(arg_output_locations_.size() > received_args)) {
      return errors::InvalidArgument("Expected ", arg_output_locations_.size(),
                                     " arguments, but only received ",
                                     received_args, ".");
    }

    // ArgOp is a relatively expensive OpKernel due to the Tensor
    // allocations that it performs. Therefore we specialize its implementation
    // and forward arguments directly to the inputs of kernels that consume
    // them.
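    //
    // Sketch of the two paths in the loop below: if the call frame owns the
    // argument and permits it (`CanConsumeArg()`), the tensor is moved into
    // the first consuming input and shallow-copied into the remaining ones;
    // otherwise `GetArg()` returns a borrowed tensor that is shallow-copied
    // into every consuming input.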
    for (size_t i = 0; i < arg_output_locations_.size(); ++i) {
      const size_t num_destinations = arg_output_locations_[i].size();
      if (num_destinations > 0) {
        if (args.call_frame->CanConsumeArg(i)) {
          // The first destination input can consume the argument.
          Entry& first_input = inputs[arg_output_locations_[i][0]];
          first_input.state = Entry::State::HAS_VALUE;
          first_input.val.Init();
          args.call_frame->ConsumeArg(i, first_input.val.get());
          // All subsequent destination inputs get a shallow copy of the first
          // destination input.
          //
          // NOTE: If we had metadata about which kernels might attempt to
          // forward their input, we could arrange the kernel order so that
          // one of those kernels was executed last.
          for (size_t j = 1; j < num_destinations; ++j) {
            Entry& input = inputs[arg_output_locations_[i][j]];
            input.state = Entry::State::HAS_VALUE;
            input.val.Init(*first_input.val);
          }
        } else {
          const Tensor* arg;
          TF_RETURN_IF_ERROR(args.call_frame->GetArg(i, &arg));
          for (size_t j = 0; j < num_destinations; ++j) {
            Entry& input = inputs[arg_output_locations_[i][j]];
            // NOTE: We must make at least one shallow copy of the argument
            // tensor that remains live until all consuming kernels have
            // executed, to keep the reference count > 1, and inhibit buffer
            // forwarding. For simplicity, we shallow copy into the input entry
            // for each consuming kernel.
            input.state = Entry::State::HAS_VALUE;
            input.val.Init(*arg);
          }
        }
      }
    }

    // Kernels that return a constant value (e.g. ConstOp) are relatively
    // expensive due to the Tensor allocations that they perform. Therefore we
    // specialize their implementation and forward their constant value directly
    // to the inputs of kernels that consume them.
    for (const ConstTensorKernelState& kernel_state : const_tensor_kernels_) {
      for (size_t i = 0; i < kernel_state.output_locations.size(); ++i) {
        Entry& input = inputs[kernel_state.output_locations[i]];
        input.state = Entry::State::HAS_CONST_TENSOR;
        input.const_tensor = &kernel_state.const_tensor;
      }
    }

    // Execute the kernels one-at-a-time in topological order.
    for (size_t i = 0; i < kernels_.size(); ++i) {
      const KernelState& kernel_state = kernels_[i];

      // Prepare the per-kernel parameters.
      const size_t input_start_index = kernel_state.input_start_index;
      const size_t num_inputs = kernel_state.num_inputs;
      const size_t num_outputs = kernel_state.num_outputs;

      node_inputs.clear();
      node_inputs.resize(num_inputs);
      input_alloc_attrs.clear();
      input_alloc_attrs.resize(num_inputs);
      for (size_t j = 0; j < num_inputs; ++j) {
        Entry& input = inputs[input_start_index + j];
        switch (input.state) {
          case Entry::State::HAS_CONST_TENSOR:
            // NOTE(mrry): This `const_cast` is necessary because `TensorValue`
            // stores a non-const `Tensor*`, and relies on the `OpKernelContext`
            // accessors making dynamic checks that prevent using an immutable
            // tensor as a mutable tensor.
            node_inputs[j].tensor = const_cast<Tensor*>(input.const_tensor);
            break;
          case Entry::State::HAS_VALUE:
            node_inputs[j].tensor = input.val.get();
            break;
          default:
            DCHECK(false) << "Input did not have a valid value.";
        }
        input_alloc_attrs[j] = input_alloc_attrs_[input_start_index + j];
      }
      params.inputs = node_inputs;
      params.input_alloc_attrs = input_alloc_attrs;
      params.op_kernel = kernel_state.kernel;
      params.output_attr_array = kernel_state.output_alloc_attrs.data();
      OpKernelContext ctx(&params, num_outputs);

      // Actually execute the kernel.
      device->Compute(kernel_state.kernel, &ctx);
      TF_RETURN_IF_ERROR(ctx.status());

      // Free the inputs to the current kernel.
      for (size_t j = 0; j < num_inputs; ++j) {
        inputs[input_start_index + j].ClearVal();
      }

      // Forward the outputs of the kernel to the inputs of subsequent kernels.
      for (size_t j = 0; j < num_outputs; ++j) {
        TensorValue val = ctx.release_output(j);
        const size_t num_destinations = kernel_state.output_locations[j].size();
        if (num_destinations > 0) {
          // TODO(mrry): Consider flattening the `output_locations` vector
          // to improve the cache-friendliness of this loop.
          for (size_t k = 0; k < num_destinations - 1; ++k) {
            // TODO(mrry): Validate that the types match the expected values or
            // ensure that the necessary validation has already happened.
            Entry& input = inputs[kernel_state.output_locations[j][k]];
            input.state = Entry::State::HAS_VALUE;
            if (val.tensor != nullptr) {
              input.val.Init(*val.tensor);
            } else {
              input.val.Init(Tensor(kernel_state.kernel->output_type(j)));
            }
          }
          // Move the output value to its last consumer to avoid the cost of
          // copying it.
          Entry& input =
              inputs[kernel_state.output_locations[j][num_destinations - 1]];
          input.state = Entry::State::HAS_VALUE;
          if (val.tensor != nullptr) {
            input.val.Init(std::move(*val.tensor));
          } else {
            input.val.Init(Tensor(kernel_state.kernel->output_type(j)));
          }
        }
        delete val.tensor;
      }
    }
    return OkStatus();
  }

  // Execute all operations in the calling thread when asynchronous execution
  // is requested. Callers may expect to perform expensive work in the calling
  // thread even when the execution itself is single-threaded.
  //
  // This also avoids stack-overflow issues with functional control flow.
  void RunAsync(const Args& args, DoneCallback done) override {
    args.runner([this, args, done]() { done(Run(args)); });
  }

 private:
  const LocalExecutorParams params_;

  // All following members are read-only after Initialize().

  // The sum of the number of inputs for each node in the graph. This determines
  // the length of the flat `inputs` vector. See comment at the beginning of
  // `Run()` for details.
  size_t total_num_inputs_;

  // Represents cached graph structure state for each kernel.
  struct KernelState {
    // The kernel object. Not owned.
    //
    // This pointer is managed by `params_.create_kernel()` and
    // `params_.delete_kernel()`.
    OpKernel* kernel;

    // These fields determine the range of elements in `inputs` that corresponds
    // to the inputs of `kernel`.
    size_t input_start_index;
    size_t num_inputs;

    size_t num_outputs;

    // For the `j`th output of `kernel`, `output_locations[j]` contains the
    // locations in the flat `inputs` vector to which that output must be
    // copied. See comment at the beginning of `Run()` for details.
    std::vector<std::vector<size_t>>
        output_locations;  // Length = `num_outputs`.
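
    // For example (illustrative values only): `output_locations[1] == {4, 9}`
    // means that output 1 of `kernel` must be copied into `inputs[4]` and
    // moved into `inputs[9]` (the last destination receives the value by
    // move; see the forwarding loop in `Run()`).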

    // Memory space information for each output of `kernel`.
    std::vector<AllocatorAttributes>
        output_alloc_attrs;  // Length = `num_outputs`.
  };
  std::vector<KernelState> kernels_;

  // For the `i`th argument, `arg_output_locations_[i]` contains the locations
  // in the flat `inputs` vector to which that argument must be copied.
  std::vector<std::vector<size_t>>
      arg_output_locations_;  // Length = `num_args`.

  // Represents cached graph structure state for each kernel that produces
  // a single constant-valued tensor.
  struct ConstTensorKernelState {
    // The kernel object. Not owned.
    //
    // This pointer is managed by `params_.create_kernel()` and
    // `params_.delete_kernel()`.
    OpKernel* kernel;

    // The cached value of `kernel->const_tensor()`.
    //
    // NOTE: We keep a `Tensor` rather than a `const Tensor*` here in order to
    // keep the reference count on the underlying buffer above 1. Otherwise, a
    // kernel could interpret the input as a forwardable tensor, and mutate the
    // underlying constant tensor.
    Tensor const_tensor;

    // For the single output of `kernel`, `output_locations` contains the
    // locations in the flat `inputs` vector to which that output must be
    // copied. See comment at the beginning of `Run()` for details.
    std::vector<size_t> output_locations;  // One entry per destination input.

    // Memory space information for the single output of `kernel`.
    AllocatorAttributes output_alloc_attr;
  };
  std::vector<ConstTensorKernelState> const_tensor_kernels_;

  // Memory space information for each input. This information is stored in the
  // same order as the flat `inputs` vector. See comment at the beginning of
  // `Run()` for details.
  std::vector<AllocatorAttributes>
      input_alloc_attrs_;  // Length = `total_num_inputs_`.
};

class SingleThreadedExecutorRegistrar {
 public:
  SingleThreadedExecutorRegistrar() {
    ExecutorFactory::Register(kSingleThreadedExecutor, new Factory());
  }

 private:
  class Factory : public ExecutorFactory {
    Status NewExecutor(const LocalExecutorParams& params, const Graph& graph,
                       std::unique_ptr<Executor>* out_executor) override {
      Executor* ret;
      TF_RETURN_IF_ERROR(NewSingleThreadedExecutor(params, graph, &ret));
      out_executor->reset(ret);
      return OkStatus();
    }
  };
};
static SingleThreadedExecutorRegistrar registrar;
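
// Usage sketch (an assumption for illustration; this file only performs the
// registration above): callers that instantiate functions through the
// function library runtime can request this executor by name, e.g.
//
//   FunctionLibraryRuntime::InstantiateOptions opts;
//   opts.executor_type = "SINGLE_THREADED_EXECUTOR";
//
// Callers that build executors directly can instead use
// `NewSingleThreadedExecutor()` below.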

}  // namespace

Status NewSingleThreadedExecutor(const LocalExecutorParams& params,
                                 const Graph& graph, Executor** executor) {
  auto impl = std::make_unique<SingleThreadedExecutorImpl>(params);
  TF_RETURN_IF_ERROR(impl->Initialize(graph));
  *executor = impl.release();
  return OkStatus();
}

}  // namespace tensorflow