xref: /aosp_15_r20/external/tensorflow/tensorflow/core/common_runtime/function_test.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/common_runtime/function.h"
17 
18 #include <atomic>
19 #include <functional>
20 #include <utility>
21 
22 #include "absl/memory/memory.h"
23 #include "absl/strings/numbers.h"
24 #include "absl/strings/str_split.h"
25 #include "tensorflow/cc/ops/array_ops_internal.h"
26 #include "tensorflow/cc/ops/function_ops.h"
27 #include "tensorflow/cc/ops/functional_ops.h"
28 #include "tensorflow/cc/ops/sendrecv_ops.h"
29 #include "tensorflow/cc/ops/standard_ops.h"
30 #include "tensorflow/core/common_runtime/constant_folding.h"
31 #include "tensorflow/core/common_runtime/device.h"
32 #include "tensorflow/core/common_runtime/device_factory.h"
33 #include "tensorflow/core/common_runtime/executor.h"
34 #include "tensorflow/core/common_runtime/executor_factory.h"
35 #include "tensorflow/core/common_runtime/function_testlib.h"
36 #include "tensorflow/core/common_runtime/graph_constructor.h"
37 #include "tensorflow/core/common_runtime/rendezvous_mgr.h"
38 #include "tensorflow/core/common_runtime/step_stats_collector.h"
39 #include "tensorflow/core/framework/function.h"
40 #include "tensorflow/core/framework/function_testlib.h"
41 #include "tensorflow/core/framework/metrics.h"
42 #include "tensorflow/core/framework/op.h"
43 #include "tensorflow/core/framework/op_kernel.h"
44 #include "tensorflow/core/framework/op_requires.h"
45 #include "tensorflow/core/framework/tensor_testutil.h"
46 #include "tensorflow/core/framework/versions.pb.h"
47 #include "tensorflow/core/lib/core/notification.h"
48 #include "tensorflow/core/lib/core/status.h"
49 #include "tensorflow/core/lib/core/status_test_util.h"
50 #include "tensorflow/core/lib/core/threadpool.h"
51 #include "tensorflow/core/lib/strings/str_util.h"
52 #include "tensorflow/core/platform/errors.h"
53 #include "tensorflow/core/platform/test.h"
54 #include "tensorflow/core/platform/threadpool_interface.h"
55 #include "tensorflow/core/protobuf/error_codes.pb.h"
56 #include "tensorflow/core/public/session_options.h"
57 #include "tensorflow/core/public/version.h"
58 #include "tensorflow/core/util/equal_graph_def.h"
59 
60 namespace tensorflow {
61 namespace {
62 
63 using FDH = ::tensorflow::FunctionDefHelper;
64 
65 using OutputControlSrc = InlineFunctionBodyOptions::OutputControlSource;
66 
// Looks up the OpDef for `op` in the process-global op registry. Used as the
// signature-lookup callback for InstantiateFunction() below.
Status GetOpSig(const string& op, const OpDef** sig) {
  return OpRegistry::Global()->LookUpOpDef(op, sig);
}
70 
HasError(const Status & s,const error::Code code,StringPiece substr)71 void HasError(const Status& s, const error::Code code, StringPiece substr) {
72   EXPECT_EQ(s.code(), code) << s;
73   EXPECT_TRUE(absl::StrContains(s.error_message(), substr))
74       << s << ", expected substring " << substr;
75 }
76 
77 class FunctionTest : public ::testing::Test {
78  protected:
FunctionTest()79   FunctionTest()
80       : device_(DeviceFactory::NewDevice("CPU", {},
81                                          "/job:localhost/replica:0/task:0")) {}
82 
Create(const FunctionDef & fdef,test::function::Attrs attrs)83   void Create(const FunctionDef& fdef, test::function::Attrs attrs) {
84     exec_ = nullptr;
85     InstantiationResult result;
86     TF_CHECK_OK(InstantiateFunction(fdef, attrs, GetOpSig, &result));
87 
88     arg_types_ = result.arg_types;
89     ret_types_ = result.ret_types;
90 
91     std::unique_ptr<Graph> g(new Graph(OpRegistry::Global()));
92     GraphConstructorOptions opts;
93     opts.allow_internal_ops = true;
94     opts.expect_device_spec = false;
95     TF_CHECK_OK(ConvertNodeDefsToGraph(opts, result.nodes, g.get()));
96 
97     const int version = g->versions().producer();
98     LocalExecutorParams params;
99     params.device = device_.get();
100     params.create_kernel =
101         [this, version](const std::shared_ptr<const NodeProperties>& props,
102                         OpKernel** kernel) {
103           return CreateNonCachedKernel(device_.get(), nullptr, props, version,
104                                        kernel);
105         };
106     params.delete_kernel = [](OpKernel* kernel) {
107       DeleteNonCachedKernel(kernel);
108     };
109     Executor* exec;
110     TF_CHECK_OK(NewLocalExecutor(params, *g, &exec));
111     exec_.reset(exec);
112   }
113 
Run(const std::vector<Tensor> & args,std::vector<Tensor * > rets)114   void Run(const std::vector<Tensor>& args, std::vector<Tensor*> rets) {
115     FunctionCallFrame frame(arg_types_, ret_types_);
116     TF_CHECK_OK(frame.SetArgs(args));
117     Executor::Args exec_args;
118     exec_args.call_frame = &frame;
119     exec_args.runner = test::function::FunctionTestSchedClosure;
120     TF_CHECK_OK(exec_->Run(exec_args));
121     std::vector<Tensor> computed;
122     TF_CHECK_OK(frame.GetRetvals(&computed));
123     CHECK_EQ(computed.size(), rets.size());
124     for (int i = 0; i < rets.size(); ++i) {
125       *(rets[i]) = computed[i];
126     }
127   }
128 
129   std::unique_ptr<Device> device_;
130   std::unique_ptr<Executor> exec_;
131   DataTypeVector arg_types_;
132   DataTypeVector ret_types_;
133 };
134 
// XTimesTwo doubles each element: [1, 2, 3, 4] -> [2, 4, 6, 8].
TEST_F(FunctionTest, XTimesTwo) {
  Create(test::function::XTimesTwo(), {{"T", DT_FLOAT}});
  const auto input = test::AsTensor<float>({1, 2, 3, 4});
  Tensor output;
  Run({input}, {&output});
  test::ExpectTensorEqual<float>(output, test::AsTensor<float>({2, 4, 6, 8}));
}
142 
// WXPlusB computes w * x + b for 2x2 matrices with a length-2 bias.
TEST_F(FunctionTest, WXPlusB) {
  Create(test::function::WXPlusB(), {{"T", DT_FLOAT}});
  const auto weights = test::AsTensor<float>({1., 2., 3., 4.}, {2, 2});
  const auto inputs = test::AsTensor<float>({1., 3., 2., 4.}, {2, 2});
  const auto bias = test::AsTensor<float>({0.5, 2.5}, {2});
  Tensor output;
  Run({weights, inputs, bias}, {&output});
  test::ExpectTensorEqual<float>(
      output, test::AsTensor<float>({5.5, 13.5, 11.5, 27.5}, {2, 2}));
}
153 
// Test fixture that builds a ProcessFunctionLibraryRuntime over three CPU
// devices and exposes a per-device FunctionLibraryRuntime (flr0_..flr2_),
// plus helpers to instantiate and run functions through them.
class FunctionLibraryRuntimeTest : public ::testing::Test {
 protected:
  // Builds the device manager, the function library from `flib`, and the
  // process FLR. Must be called before any other helper in this fixture.
  void Init(const std::vector<FunctionDef>& flib) {
    SessionOptions options;
    auto* device_count = options.config.mutable_device_count();
    device_count->insert({"CPU", 3});
    std::vector<std::unique_ptr<Device>> devices;
    TF_CHECK_OK(DeviceFactory::AddDevices(
        options, "/job:localhost/replica:0/task:0", &devices));

    FunctionDefLibrary proto;
    for (const auto& fdef : flib) *(proto.add_function()) = fdef;
    lib_def_.reset(new FunctionLibraryDefinition(OpRegistry::Global(), proto));
    OptimizerOptions opts;
    device_mgr_ = std::make_unique<StaticDeviceMgr>(std::move(devices));
    pflr_.reset(new ProcessFunctionLibraryRuntime(
        device_mgr_.get(), Env::Default(), &options.config,
        TF_GRAPH_DEF_VERSION, lib_def_.get(), opts, /*thread_pool=*/nullptr,
        /*parent=*/nullptr, /*session_metadata=*/nullptr,
        Rendezvous::Factory{
            [](const int64_t, const DeviceMgr* device_mgr, Rendezvous** r) {
              *r = new IntraProcessRendezvous(device_mgr);
              return OkStatus();
            }}));
    flr0_ = pflr_->GetFLR("/job:localhost/replica:0/task:0/cpu:0");
    flr1_ = pflr_->GetFLR("/job:localhost/replica:0/task:0/cpu:1");
    flr2_ = pflr_->GetFLR("/job:localhost/replica:0/task:0/cpu:2");
    fdef_lib_ = lib_def_->ToProto();
  }

  // Runs an already-instantiated `handle` on `flr` with `args`, blocking on a
  // Notification until the asynchronous Run completes, then copies the
  // outputs into `rets` (sizes must match).
  Status Run(FunctionLibraryRuntime* flr, FunctionLibraryRuntime::Handle handle,
             FunctionLibraryRuntime::Options opts,
             const std::vector<Tensor>& args, std::vector<Tensor*> rets) {
    std::function<void(std::function<void()>)> runner =
        [](std::function<void()> fn) {
          test::function::FunctionTestSchedClosure(fn);
        };
    opts.runner = &runner;
    Notification done;
    std::vector<Tensor> out;
    Status status;
    flr->Run(opts, handle, args, &out, [&status, &done](const Status& s) {
      status = s;
      done.Notify();
    });
    done.WaitForNotification();
    if (!status.ok()) {
      return status;
    }
    CHECK_EQ(rets.size(), out.size());
    for (size_t i = 0; i < rets.size(); ++i) {
      *rets[i] = out[i];
    }
    return OkStatus();
  }

  // Instantiates `name` with `attrs` on `flr` using default
  // InstantiateOptions.
  Status Instantiate(FunctionLibraryRuntime* flr, const string& name,
                     test::function::Attrs attrs,
                     FunctionLibraryRuntime::Handle* handle) {
    return flr->Instantiate(name, attrs, handle);
  }

  // Same as above, but with caller-supplied InstantiateOptions.
  Status Instantiate(FunctionLibraryRuntime* flr, const string& name,
                     test::function::Attrs attrs,
                     const FunctionLibraryRuntime::InstantiateOptions& options,
                     FunctionLibraryRuntime::Handle* handle) {
    return flr->Instantiate(name, attrs, options, handle);
  }

  // Convenience overload using default InstantiateOptions.
  Status InstantiateAndRun(FunctionLibraryRuntime* flr, const string& name,
                           test::function::Attrs attrs,
                           const std::vector<Tensor>& args,
                           std::vector<Tensor*> rets) {
    return InstantiateAndRun(flr, name, attrs,
                             FunctionLibraryRuntime::InstantiateOptions(), args,
                             std::move(rets));
  }

  // Instantiates `name`, runs it once, then releases the handle and verifies
  // that running the released handle fails with NOT_FOUND. Returns the status
  // of the first (successful-path) run.
  Status InstantiateAndRun(
      FunctionLibraryRuntime* flr, const string& name,
      test::function::Attrs attrs,
      const FunctionLibraryRuntime::InstantiateOptions& options,
      const std::vector<Tensor>& args, std::vector<Tensor*> rets) {
    FunctionLibraryRuntime::Handle handle;
    Status status = flr->Instantiate(name, attrs, options, &handle);
    if (!status.ok()) {
      return status;
    }
    FunctionLibraryRuntime::Options opts;
    status = Run(flr, handle, opts, args, rets);
    if (!status.ok()) return status;

    // Release the handle and try running again. It should not succeed.
    status = flr->ReleaseHandle(handle);
    if (!status.ok()) return status;

    Status status2 = Run(flr, handle, opts, args, std::move(rets));
    EXPECT_TRUE(errors::IsNotFound(status2))
        << "Actual status: " << status2.ToString();
    EXPECT_TRUE(absl::StrContains(status2.error_message(), "Handle"));
    EXPECT_TRUE(absl::StrContains(status2.error_message(), "not found"));

    return status;
  }

  // Runs `handle` through the CallFrameInterface overload of
  // FunctionLibraryRuntime::Run, blocking until completion.
  Status Run(FunctionLibraryRuntime* flr, FunctionLibraryRuntime::Handle handle,
             FunctionLibraryRuntime::Options opts, CallFrameInterface* frame) {
    std::function<void(std::function<void()>)> runner =
        [](std::function<void()> fn) {
          test::function::FunctionTestSchedClosure(fn);
        };
    opts.runner = &runner;
    Notification done;
    Status status;
    flr->Run(opts, handle, frame, [&status, &done](const Status& s) {
      status = s;
      done.Notify();
    });
    done.WaitForNotification();
    if (!status.ok()) {
      return status;
    }

    return OkStatus();
  }

  // Like InstantiateAndRun, but drives the execution through a
  // FunctionCallFrame (the CallFrameInterface path) instead of tensor
  // vectors. Also verifies that a released handle fails with NOT_FOUND.
  Status InstantiateAndRunViaCallFrameInterface(FunctionLibraryRuntime* flr,
                                                const string& name,
                                                test::function::Attrs attrs,
                                                const std::vector<Tensor>& args,
                                                std::vector<Tensor*> rets) {
    FunctionLibraryRuntime::Handle handle;
    Status status = flr->Instantiate(name, attrs, &handle);
    if (!status.ok()) {
      return status;
    }
    const FunctionBody* fbody = flr->GetFunctionBody(handle);
    FunctionCallFrame frame(fbody->arg_types, fbody->ret_types);
    TF_RETURN_IF_ERROR(frame.SetArgs(args));

    FunctionLibraryRuntime::Options opts;
    status = Run(flr, handle, opts, &frame);
    if (!status.ok()) return status;

    std::vector<Tensor> retvals;
    TF_RETURN_IF_ERROR(frame.GetRetvals(&retvals));
    CHECK_EQ(rets.size(), retvals.size());
    for (size_t i = 0; i < rets.size(); ++i) {
      *rets[i] = retvals[i];
    }

    // Release the handle and try running again. It should not succeed.
    status = flr->ReleaseHandle(handle);
    if (!status.ok()) return status;

    Status status2 = Run(flr, handle, opts, args, std::move(rets));
    EXPECT_TRUE(errors::IsNotFound(status2));
    EXPECT_TRUE(absl::StrContains(status2.error_message(), "Handle"));
    EXPECT_TRUE(absl::StrContains(status2.error_message(), "not found"));

    return status;
  }

  // Returns a copy of the instantiated body graph of `name`, or nullptr if
  // instantiation fails (the error is logged).
  std::unique_ptr<Graph> GetFuncBody(FunctionLibraryRuntime* flr,
                                     const string& name,
                                     test::function::Attrs attrs) {
    FunctionLibraryRuntime::Handle handle;
    Status status = flr->Instantiate(name, attrs, &handle);
    if (!status.ok()) {
      LOG(ERROR) << status;
      return nullptr;
    }
    const FunctionBody* fbody = flr->GetFunctionBody(handle);
    CHECK_NOTNULL(fbody);
    std::unique_ptr<Graph> ret(new Graph(lib_def_.get()));
    CopyGraph(*fbody->graph, ret.get());
    return ret;
  }

  // Returns a copy of the symbolic-gradient graph of `func`, or nullptr if
  // instantiation fails (the error is logged).
  std::unique_ptr<Graph> GetGradBody(FunctionLibraryRuntime* flr,
                                     const string& func,
                                     test::function::Attrs attrs) {
    FunctionLibraryRuntime::Handle handle;
    Status status = flr->Instantiate(func, attrs, &handle);
    if (!status.ok()) {
      LOG(ERROR) << status;
      return nullptr;
    }
    const FunctionBody* fbody = flr->GetFunctionBody(handle);
    CHECK_NOTNULL(fbody);
    std::unique_ptr<FunctionBody> gbody(SymbolicGradient(*fbody));
    CHECK_NOTNULL(gbody);
    std::unique_ptr<Graph> ret(new Graph(lib_def_.get()));
    CopyGraph(*gbody->graph, ret.get());
    return ret;
  }

  // Non-owning FLRs for cpu:0..2; owned by pflr_.
  FunctionLibraryRuntime* flr0_;
  FunctionLibraryRuntime* flr1_;
  FunctionLibraryRuntime* flr2_;
  std::unique_ptr<DeviceMgr> device_mgr_;
  std::unique_ptr<FunctionLibraryDefinition> lib_def_;
  std::unique_ptr<ProcessFunctionLibraryRuntime> pflr_;
  FunctionDefLibrary fdef_lib_;
};
359 
// Variable ops are stateful; other names report non-stateful.
TEST_F(FunctionLibraryRuntimeTest, IsStateful) {
  Init({});
  EXPECT_TRUE(flr0_->IsStateful("Variable"));
  EXPECT_TRUE(flr0_->IsStateful("VariableV2"));
  // NOTE(review): the matmul op is registered as "MatMul"; "Matmul" is not a
  // registered op name, and unregistered names presumably also report
  // non-stateful — confirm this spelling is intentional.
  EXPECT_FALSE(flr0_->IsStateful("Matmul"));
}
366 
// Exercises both run paths for XTimesTwo: the tensor-vector path and the
// CallFrameInterface path. Both must produce [2, 4, 6, 8].
TEST_F(FunctionLibraryRuntimeTest, XTimesTwo) {
  Init({test::function::XTimesTwo()});
  const auto input = test::AsTensor<float>({1, 2, 3, 4});
  const auto expected = test::AsTensor<float>({2, 4, 6, 8});
  Tensor result;
  TF_CHECK_OK(InstantiateAndRun(flr0_, "XTimesTwo", {{"T", DT_FLOAT}}, {input},
                                {&result}));
  test::ExpectTensorEqual<float>(result, expected);
  TF_CHECK_OK(InstantiateAndRunViaCallFrameInterface(
      flr0_, "XTimesTwo", {{"T", DT_FLOAT}}, {input}, {&result}));
  test::ExpectTensorEqual<float>(result, expected);
}
378 
// Verifies that stack traces attached to a FunctionDef's nodes survive
// instantiation: the trace registered for node "two" must be reachable from
// the instantiated FunctionBody's graph.
TEST_F(FunctionLibraryRuntimeTest, InstantiationStackTraceCopying) {
  // Minimal stack-trace stub with a recognizable ToString() payload.
  class DummyStackTrace : public AbstractStackTrace {
    absl::Span<StackFrame const> ToFrames() const override { return {}; }

    std::string ToString(const TracePrintingOptions& opts) const override {
      return "DummyStackTrace";
    }

    StackFrame LastUserFrame() const override { return StackFrame{}; }
  };

  FunctionDef func = test::function::XTimesTwo();
  Init({});

  // Attach the dummy trace to the node named "two" in XTimesTwo.
  StackTracesMap stack_traces;
  stack_traces["two"] = std::make_shared<DummyStackTrace>();

  TF_CHECK_OK(lib_def_->AddFunctionDef(func, stack_traces));

  FunctionLibraryRuntime::Handle handle;
  TF_CHECK_OK(Instantiate(flr0_, "XTimesTwo", {{"T", DT_FLOAT}}, {}, &handle));

  // The instantiated graph node "two" should carry the copied trace.
  const FunctionBody* func_body = flr0_->GetFunctionBody(handle);
  for (const Node* node : func_body->graph->nodes()) {
    if (node->name() == "two") {
      EXPECT_EQ(node->GetStackTrace()->ToString({}), "DummyStackTrace");
    }
  }
  TF_CHECK_OK(flr0_->ReleaseHandle(handle));
}
409 
// Same XTimesTwo computation, but instantiated as a multi-device function.
TEST_F(FunctionLibraryRuntimeTest, XTimesTwo_MultiDeviceBacked) {
  Init({test::function::XTimesTwo()});
  const auto input = test::AsTensor<float>({1, 2, 3, 4});
  Tensor result;

  FunctionLibraryRuntime::InstantiateOptions options;
  options.is_multi_device_function = true;

  TF_CHECK_OK(InstantiateAndRun(flr0_, "XTimesTwo", {{"T", DT_FLOAT}}, options,
                                {input}, {&result}));
  test::ExpectTensorEqual<float>(result, test::AsTensor<float>({2, 4, 6, 8}));
}
422 
// A single-argument, single-retval call frame that hands its argument to the
// runtime by move (ConsumeArg), allowing the runtime to forward the input
// buffer into the output. GetArg is deliberately fatal: callers that support
// consumption must use ConsumeArg instead.
class ConsumeArgumentCallFrame : public CallFrameInterface {
 public:
  // `arg` is moved from when the runtime consumes it; `retval` receives the
  // function's single output. Both pointers must outlive the frame.
  ConsumeArgumentCallFrame(Tensor* arg, Tensor* retval)
      : arg_(arg), retval_(retval) {}

  size_t num_args() const override { return 1; }
  size_t num_retvals() const override { return 1; }

  // Non-consuming access is unsupported by design.
  Status GetArg(int index, const Tensor** val) override {
    LOG(FATAL) << "Should not be called.";
  }

  bool CanConsumeArg(int index) const override { return index == 0; }

  void ConsumeArg(int index, Tensor* val) override { *val = std::move(*arg_); }

  Status SetRetval(int index, const Tensor& val) override {
    CHECK_EQ(index, 0);
    *retval_ = val;
    return OkStatus();
  }

 private:
  Tensor* const arg_;     // Left in a moved-from state after ConsumeArg.
  Tensor* const retval_;
};
449 
// Verifies that the default executor forwards a consumed argument buffer into
// the output tensor (zero-copy), and that the "default" executor metric is
// the one that fires.
TEST_F(FunctionLibraryRuntimeTest, XTimesTwo_ConsumeArgument_DefaultExecutor) {
  Init({test::function::XTimesTwo()});
  auto default_executor = metrics::TestDelta("flr_executor", "default");
  auto single_threaded = metrics::TestDelta("flr_executor", "single_threaded");
  FunctionLibraryRuntime::Handle handle;
  TF_CHECK_OK(flr0_->Instantiate(
      "XTimesTwo", test::function::Attrs({{"T", DT_FLOAT}}), &handle));

  auto x = test::AsTensor<float>({1, 2, 3, 4});
  // Remember the input buffer's address so forwarding can be detected below.
  float* x_base_ptr = &x.flat<float>()(0);
  Tensor y;
  ConsumeArgumentCallFrame frame(&x, &y);

  FunctionLibraryRuntime::Options opts;
  TF_CHECK_OK(Run(flr0_, handle, opts, &frame));

  test::ExpectTensorEqual<float>(y, test::AsTensor<float>({2, 4, 6, 8}));

  // Expect that the buffer for `x` has been forwarded to and used as the buffer
  // for `y`.
  float* y_base_ptr = &y.flat<float>()(0);
  EXPECT_EQ(x_base_ptr, y_base_ptr);
  EXPECT_FALSE(x.IsInitialized());

  TF_CHECK_OK(flr0_->ReleaseHandle(handle));
  // Only the default executor metric should have incremented.
  EXPECT_GT(default_executor.Get(), 0);
  EXPECT_EQ(single_threaded.Get(), 0);
}
478 
// Same buffer-forwarding check as above, but with the single-threaded
// executor selected via InstantiateOptions::executor_type; the metrics
// expectations are inverted accordingly.
TEST_F(FunctionLibraryRuntimeTest,
       XTimesTwo_ConsumeArgument_SingleThreadedExecutor) {
  Init({test::function::XTimesTwo()});
  auto default_executor = metrics::TestDelta("flr_executor", "default");
  auto single_threaded = metrics::TestDelta("flr_executor", "single_threaded");
  FunctionLibraryRuntime::InstantiateOptions instantiate_opts;
  instantiate_opts.executor_type = "SINGLE_THREADED_EXECUTOR";
  FunctionLibraryRuntime::Handle handle;
  TF_CHECK_OK(flr0_->Instantiate("XTimesTwo",
                                 test::function::Attrs({{"T", DT_FLOAT}}),
                                 instantiate_opts, &handle));

  auto x = test::AsTensor<float>({1, 2, 3, 4});
  // Remember the input buffer's address so forwarding can be detected below.
  float* x_base_ptr = &x.flat<float>()(0);
  Tensor y;
  ConsumeArgumentCallFrame frame(&x, &y);

  FunctionLibraryRuntime::Options opts;
  TF_CHECK_OK(Run(flr0_, handle, opts, &frame));

  test::ExpectTensorEqual<float>(y, test::AsTensor<float>({2, 4, 6, 8}));

  // Expect that the buffer for `x` has been forwarded to and used as the buffer
  // for `y`.
  float* y_base_ptr = &y.flat<float>()(0);
  EXPECT_EQ(x_base_ptr, y_base_ptr);
  EXPECT_FALSE(x.IsInitialized());

  TF_CHECK_OK(flr0_->ReleaseHandle(handle));
  // Only the single-threaded executor metric should have incremented.
  EXPECT_EQ(default_executor.Get(), 0);
  EXPECT_GT(single_threaded.Get(), 0);
}
511 
// Nested function calls: XTimesFour and XTimes16 are built by composing
// XTimesTwo, so each level should multiply the input by another factor of 2.
TEST_F(FunctionLibraryRuntimeTest, XTimesN) {
  Init({test::function::XTimesTwo(), test::function::XTimesFour(),
        test::function::XTimes16()});
  const auto input = test::AsTensor<float>({1, 2, 3, 4});
  // Runs `func_name` on `input` and checks the output against `expected`.
  auto check = [&](const string& func_name,
                   const std::vector<float>& expected) {
    Tensor result;
    TF_CHECK_OK(InstantiateAndRun(flr0_, func_name, {{"T", DT_FLOAT}}, {input},
                                  {&result}));
    test::ExpectTensorEqual<float>(result, test::AsTensor<float>(expected));
  };
  check("XTimesTwo", {2, 4, 6, 8});
  check("XTimesFour", {4, 8, 12, 16});
  check("XTimes16", {16, 32, 48, 64});
}
527 
// Verifies that InstantiateOptions::lib_def overrides the base library: the
// functions live only in a side FunctionLibraryDefinition, so lookups without
// `options` must fail, and running with `options` must not install them into
// the base library.
TEST_F(FunctionLibraryRuntimeTest, XTimesNInLibDef) {
  Init({});
  FunctionDefLibrary proto;
  *proto.add_function() = test::function::XTimesTwo();
  *proto.add_function() = test::function::XTimesFour();
  *proto.add_function() = test::function::XTimes16();
  std::unique_ptr<FunctionLibraryDefinition> lib_def(
      new FunctionLibraryDefinition(OpRegistry::Global(), proto));

  FunctionLibraryRuntime::InstantiateOptions options;
  options.lib_def = lib_def.get();

  auto x = test::AsTensor<float>({1, 2, 3, 4});
  Tensor y;

  // Ensure that the function is not installed in the base library.
  HasError(InstantiateAndRun(flr0_, "XTimesTwo", {{"T", DT_FLOAT}},
                             {} /* options */, {x}, {&y}),
           error::NOT_FOUND, "Function XTimesTwo is not defined.");

  // With the overlay lib_def, all three functions resolve and run.
  TF_CHECK_OK(InstantiateAndRun(flr0_, "XTimesTwo", {{"T", DT_FLOAT}}, options,
                                {x}, {&y}));
  test::ExpectTensorEqual<float>(y, test::AsTensor<float>({2, 4, 6, 8}));
  TF_CHECK_OK(InstantiateAndRun(flr0_, "XTimesFour", {{"T", DT_FLOAT}}, options,
                                {x}, {&y}));
  test::ExpectTensorEqual<float>(y, test::AsTensor<float>({4, 8, 12, 16}));
  TF_CHECK_OK(InstantiateAndRun(flr0_, "XTimes16", {{"T", DT_FLOAT}}, options,
                                {x}, {&y}));
  test::ExpectTensorEqual<float>(y, test::AsTensor<float>({16, 32, 48, 64}));

  // Ensure that the function is still not installed in the base library.
  HasError(InstantiateAndRun(flr0_, "XTimesTwo", {{"T", DT_FLOAT}},
                             {} /* options */, {x}, {&y}),
           error::NOT_FOUND, "Function XTimesTwo is not defined.");
}
563 
// Verifies that a function called through PartitionedCall (which defers
// instantiation until first Compute) is resolved against the lib_def in
// effect at call time, so overriding XTimesFour's body changes the result.
TEST_F(FunctionLibraryRuntimeTest, XTimesNInLibDefAndDelayedInstantiation) {
  using FDH = ::tensorflow::FunctionDefHelper;

  Init({});

  // Call XTimesFour via PartitionedCall which delays functions instantiation
  // to the first call to Compute/ComputeAsync.
  FunctionDef my_xt4 = FunctionDefHelper::Create(
      "MyXTimesFour", {"x:float"}, {"z:float"}, {},
      {{{"x_times_four"},
        "PartitionedCall",
        {"x"},
        {{"Tin", DataTypeSlice({DT_FLOAT})},
         {"Tout", DataTypeSlice({DT_FLOAT})},
         {"f", FDH::FunctionRef("XTimesFour", {{"T", DT_FLOAT}})}}}},
      /* Mapping between function returns and function node outputs. */
      {{"z", "x_times_four:output:0"}});

  FunctionDefLibrary lib;
  *lib.add_function() = test::function::XTimesTwo();
  *lib.add_function() = test::function::XTimesFour();
  *lib.add_function() = my_xt4;
  std::unique_ptr<FunctionLibraryDefinition> lib_def(
      new FunctionLibraryDefinition(OpRegistry::Global(), lib));

  FunctionLibraryRuntime::InstantiateOptions options;
  options.lib_def = lib_def.get();

  auto x = test::AsTensor<float>({1, 2, 3, 4});
  Tensor y;

  // When we instantiate with `options` we should get x*4.
  TF_CHECK_OK(InstantiateAndRun(flr0_, "MyXTimesFour", {}, options, {x}, {&y}));
  test::ExpectTensorEqual<float>(y, test::AsTensor<float>({4, 8, 12, 16}));

  // Create options that override XTimesFour body with XTimesTwo body.
  FunctionDef xt4_override = test::function::XTimesTwo();
  xt4_override.mutable_signature()->set_name("XTimesFour");
  FunctionDefLibrary lib_override;
  *lib_override.add_function() = xt4_override;
  *lib_override.add_function() = my_xt4;
  std::unique_ptr<FunctionLibraryDefinition> lib_def_override(
      new FunctionLibraryDefinition(OpRegistry::Global(), lib_override));
  options.lib_def = lib_def_override.get();

  // When we instantiate with `options` we should get x*2.
  TF_CHECK_OK(InstantiateAndRun(flr0_, "MyXTimesFour", {}, options, {x}, {&y}));
  test::ExpectTensorEqual<float>(y, test::AsTensor<float>({2, 4, 6, 8}));
}
613 
// Verifies that InstantiateOptions::state_handle isolates the state of
// stateful kernels: instantiations sharing a handle continue the same RNG
// sequence, while each distinct state_handle restarts it from the beginning.
TEST_F(FunctionLibraryRuntimeTest, StateHandle) {
  auto T = DT_INT32;

  // The expected sequence of outputs from this function is [6, 4, 0, 1, ...].
  FunctionDef stateful_func = FDH::Define(
      // Name
      "RandomUniformWrapper",
      // Args
      {},
      // Return values
      {"y: int32"},
      // Attrs
      {},
      // Nodes
      {FDH::Const<int32>("shape", gtl::ArraySlice<int32>({1})),
       FDH::Const<int32>("minval", 0),
       FDH::Const<int32>("maxval", 10),
       // A stateful node.
       {{"y"},
        "RandomUniformInt",
        {"shape", "minval", "maxval"},
        {{"seed", 37}, {"seed2", 48}, {"Tout", T}, {"T", T}}}});
  Init({stateful_func});

  FunctionLibraryRuntime::Handle handle;
  TF_CHECK_OK(Instantiate(flr0_, "RandomUniformWrapper", {}, &handle));

  FunctionLibraryRuntime::Options opts;
  Tensor y;
  {
    // Simple case: instantiating with no state_handle.
    for (int32_t expected : {6, 4}) {
      TF_CHECK_OK(Run(flr0_, handle, opts, {}, {&y}));
      test::ExpectTensorEqual<int>(y, test::AsTensor<int32>({expected}));
    }
  }

  {
    // Instantiating again with no state_handle should yield the same handle and
    // the continuation of the same sequence.
    FunctionLibraryRuntime::Handle handle_non_isolated;
    TF_CHECK_OK(
        Instantiate(flr0_, "RandomUniformWrapper", {}, &handle_non_isolated));
    EXPECT_EQ(handle, handle_non_isolated);
    for (int32_t expected : {0, 1}) {
      TF_CHECK_OK(Run(flr0_, handle_non_isolated, opts, {}, {&y}));
      test::ExpectTensorEqual<int>(y, test::AsTensor<int32>({expected}));
    }
  }

  {
    // Instantiating with a given state handle will create new state and yield
    // the original sequence.
    FunctionLibraryRuntime::InstantiateOptions options;
    FunctionLibraryRuntime::Handle handle_isolated;
    options.state_handle = "handle_1";
    TF_CHECK_OK(Instantiate(flr0_, "RandomUniformWrapper", {}, options,
                            &handle_isolated));
    EXPECT_NE(handle, handle_isolated);
    for (int32_t expected : {6, 4, 0, 1}) {
      TF_CHECK_OK(Run(flr0_, handle_isolated, opts, {}, {&y}));
      test::ExpectTensorEqual<int>(y, test::AsTensor<int32>({expected}));
    }
  }

  {
    // Instantiating with a different given state handle will create new state
    // and yield the original sequence.
    FunctionLibraryRuntime::InstantiateOptions options;
    FunctionLibraryRuntime::Handle handle_isolated;
    options.state_handle = "handle_2";
    TF_CHECK_OK(Instantiate(flr0_, "RandomUniformWrapper", {}, options,
                            &handle_isolated));
    EXPECT_NE(handle, handle_isolated);
    for (int32_t expected : {6, 4, 0, 1}) {
      TF_CHECK_OK(Run(flr0_, handle_isolated, opts, {}, {&y}));
      test::ExpectTensorEqual<int>(y, test::AsTensor<int32>({expected}));
    }
  }

  {
    // Reinstantiating after releasing a handle will yield the original sequence
    // multiple times.
    FunctionLibraryRuntime::InstantiateOptions options;
    FunctionLibraryRuntime::Handle handle_isolated;
    options.state_handle = "handle_3";

    for (int i = 0; i < 2; ++i) {
      TF_CHECK_OK(Instantiate(flr0_, "RandomUniformWrapper", {}, options,
                              &handle_isolated));
      EXPECT_NE(handle, handle_isolated);
      for (int32_t expected : {6, 4, 0, 1}) {
        TF_CHECK_OK(Run(flr0_, handle_isolated, opts, {}, {&y}));
        test::ExpectTensorEqual<int>(y, test::AsTensor<int32>({expected}));
      }
      TF_CHECK_OK(flr0_->ReleaseHandle(handle_isolated));
    }
  }
}
713 
namespace {
// Registers a "DUMMY" executor type whose factory always fails. Tests below
// assert on the exact error message to prove that the DUMMY factory (rather
// than the default executor) was selected and invoked.
class DummyExecutorRegistrar {
 public:
  DummyExecutorRegistrar() {
    // Registration is process-global; the factory object is intentionally
    // never freed so it outlives all lookups through the registry.
    ExecutorFactory::Register("DUMMY", new Factory());
  }

 private:
  class Factory : public ExecutorFactory {
    // Always fails with INTERNAL; never produces an executor.
    Status NewExecutor(const LocalExecutorParams& params, const Graph& graph,
                       std::unique_ptr<Executor>* out_executor) override {
      return errors::Internal("This is a dummy.");
    }
  };
};
// Static initializer performs the registration at program startup.
static DummyExecutorRegistrar registrar;
}  // namespace
731 
TEST_F(FunctionLibraryRuntimeTest,ExecutorFactory)732 TEST_F(FunctionLibraryRuntimeTest, ExecutorFactory) {
733   Init({test::function::XTimesTwo()});
734 
735   auto x = test::AsTensor<float>({1, 2, 3, 4});
736   Tensor y;
737 
738   // Test that the default executor works.
739   {
740     FunctionLibraryRuntime::InstantiateOptions options;
741     options.executor_type = "";
742     TF_CHECK_OK(InstantiateAndRun(flr0_, "XTimesTwo", {{"T", DT_FLOAT}},
743                                   options, {x}, {&y}));
744     test::ExpectTensorEqual<float>(y, test::AsTensor<float>({2, 4, 6, 8}));
745   }
746 
747   // Test the explicit registration for the default executor.
748   {
749     FunctionLibraryRuntime::InstantiateOptions options;
750     options.executor_type = "DEFAULT";
751     TF_CHECK_OK(InstantiateAndRun(flr0_, "XTimesTwo", {{"T", DT_FLOAT}},
752                                   options, {x}, {&y}));
753     test::ExpectTensorEqual<float>(y, test::AsTensor<float>({2, 4, 6, 8}));
754   }
755 
756   // Test that a non-default executor factory can be invoked.
757   {
758     FunctionLibraryRuntime::InstantiateOptions options;
759     options.executor_type = "DUMMY";
760     HasError(InstantiateAndRun(flr0_, "XTimesTwo", {{"T", DT_FLOAT}}, options,
761                                {x}, {&y}),
762              error::INTERNAL, "This is a dummy.");
763   }
764 
765   // Test that a non-default executor factory can be invoked via an attr.
766   {
767     FunctionLibraryRuntime::InstantiateOptions options;
768     HasError(InstantiateAndRun(flr0_, "XTimesTwo",
769                                {{"T", DT_FLOAT}, {"_executor", "DUMMY"}},
770                                options, {x}, {&y}),
771              error::INTERNAL, "This is a dummy.");
772   }
773 
774   // Test that a non-default executor factory specified via an
775   // `InstantiateOptions` supersedes the attr when both are present.
776   {
777     FunctionLibraryRuntime::InstantiateOptions options;
778     options.executor_type = "DUMMY";
779     HasError(
780         InstantiateAndRun(flr0_, "XTimesTwo",
781                           {{"T", DT_FLOAT}, {"_executor", "UNKNOWN_EXECUTOR"}},
782                           options, {x}, {&y}),
783         error::INTERNAL, "This is a dummy.");
784   }
785 
786   // Test that non-existent executor types trigger an error.
787   {
788     FunctionLibraryRuntime::InstantiateOptions options;
789     options.executor_type = "UNKNOWN_EXECUTOR";
790     HasError(InstantiateAndRun(flr0_, "XTimesTwo", {{"T", DT_FLOAT}}, options,
791                                {x}, {&y}),
792              error::NOT_FOUND,
793              "No executor factory registered for the given executor "
794              "type: UNKNOWN_EXECUTOR");
795   }
796   {
797     FunctionLibraryRuntime::InstantiateOptions options;
798     HasError(
799         InstantiateAndRun(flr0_, "XTimesTwo",
800                           {{"T", DT_FLOAT}, {"_executor", "UNKNOWN_EXECUTOR"}},
801                           options, {x}, {&y}),
802         error::NOT_FOUND,
803         "No executor factory registered for the given executor "
804         "type: UNKNOWN_EXECUTOR");
805   }
806 }
807 
// Verifies iterative function inlining: each ExpandInlineFunctions pass
// replaces one layer of call nodes with the callee bodies, wiring them in via
// Identity nodes named "Func/<call>/input/_N" and "Func/<call>/output/_N".
TEST_F(FunctionLibraryRuntimeTest, ExpandInlineFunctions) {
  Init({test::function::XTimesTwo(), test::function::XTimesFour(),
        test::function::XTimes16()});
  std::unique_ptr<Graph> g = GetFuncBody(flr0_, "XTimes16", {{"T", DT_FLOAT}});
  ASSERT_TRUE(g != nullptr);

  // Before any inlining: the XTimes16 body is two chained XTimesFour calls.
  {
    Scope s = Scope::NewRootScope();
    TF_ASSERT_OK(s.graph()->AddFunctionLibrary(fdef_lib_));
    auto arg = ops::_Arg(s.WithOpName("x"), DT_FLOAT, 0);
    auto a = test::function::Call(&s, "x4", "XTimesFour", {arg});
    auto b = test::function::Call(&s, "y", "XTimesFour", {a});
    auto ret = ops::_Retval(s.WithOpName("y_RetVal"), b, 0);
    GraphDef expected;
    TF_ASSERT_OK(s.ToGraphDef(&expected));

    GraphDef actual;
    g->ToGraphDef(&actual);
    TF_EXPECT_GRAPH_EQ(expected, actual);
  }

  // First pass: each XTimesFour call is replaced by two XTimesTwo calls
  // bridged with Identity input/output nodes.
  ExpandInlineFunctions(flr0_, g.get());
  {
    Scope s = Scope::NewRootScope();
    TF_ASSERT_OK(s.graph()->AddFunctionLibrary(fdef_lib_));
    auto x = ops::_Arg(s.WithOpName("x"), DT_FLOAT, 0);
    auto func0 = ops::Identity(s.WithOpName("Func/x4/input/_0"), x);
    auto x4_x2 = test::function::Call(&s, "x4/x2", "XTimesTwo", {func0});
    auto x4_y = test::function::Call(&s, "x4/y", "XTimesTwo", {x4_x2});
    auto func1 = ops::Identity(s.WithOpName("Func/x4/output/_1"), x4_y);
    auto func2 = ops::Identity(s.WithOpName("Func/y/input/_2"), func1);
    auto y_x2 = test::function::Call(&s, "y/x2", "XTimesTwo", {func2});
    auto y_y = test::function::Call(&s, "y/y", "XTimesTwo", {y_x2});
    auto func3 = ops::Identity(s.WithOpName("Func/y/output/_3"), y_y);
    auto ret = ops::_Retval(s.WithOpName("y_RetVal"), func3, 0);
    GraphDef expected;
    TF_ASSERT_OK(s.ToGraphDef(&expected));

    GraphDef actual;
    g->ToGraphDef(&actual);
    TF_EXPECT_GRAPH_EQ(expected, actual);
  }

  // Second pass: the XTimesTwo calls are expanded into their primitive body
  // (Const two -> Cast scale -> Mul). `e2` is kept to re-check stability below.
  ExpandInlineFunctions(flr0_, g.get());
  GraphDef e2;
  {
    Scope s = Scope::NewRootScope();
    auto x = ops::_Arg(s.WithOpName("x"), DT_FLOAT, 0);
    auto x4_x2_two = ops::Const<int64_t>(s.WithOpName("x4/x2/two"), int64_t{2});
    auto x4_y_two = ops::Const<int64_t>(s.WithOpName("x4/y/two"), int64_t{2});
    auto y_x2_two = ops::Const<int64_t>(s.WithOpName("y/x2/two"), int64_t{2});
    auto y_y_two = ops::Const<int64_t>(s.WithOpName("y/y/two"), int64_t{2});
    auto x4_x2_scale =
        ops::Cast(s.WithOpName("x4/x2/scale"), x4_x2_two, DT_FLOAT);
    auto x4_y_scale = ops::Cast(s.WithOpName("x4/y/scale"), x4_y_two, DT_FLOAT);
    auto y_x2_scale = ops::Cast(s.WithOpName("y/x2/scale"), y_x2_two, DT_FLOAT);
    auto y_y_scale = ops::Cast(s.WithOpName("y/y/scale"), y_y_two, DT_FLOAT);
    auto func0 = ops::Identity(s.WithOpName("Func/x4/input/_0"), x);
    auto func4 = ops::Identity(s.WithOpName("Func/x4/x2/input/_4"), func0);
    auto x4_x2_y = ops::Mul(s.WithOpName("x4/x2/y"), func4, x4_x2_scale);
    auto func5 = ops::Identity(s.WithOpName("Func/x4/x2/output/_5"), x4_x2_y);
    auto func6 = ops::Identity(s.WithOpName("Func/x4/y/input/_6"), func5);
    auto x4_y_y = ops::Mul(s.WithOpName("x4/y/y"), func6, x4_y_scale);
    auto func7 = ops::Identity(s.WithOpName("Func/x4/y/output/_7"), x4_y_y);
    auto func1 = ops::Identity(s.WithOpName("Func/x4/output/_1"), func7);
    auto func2 = ops::Identity(s.WithOpName("Func/y/input/_2"), func1);
    auto func8 = ops::Identity(s.WithOpName("Func/y/x2/input/_8"), func2);
    auto y_x2_y = ops::Mul(s.WithOpName("y/x2/y"), func8, y_x2_scale);
    auto func9 = ops::Identity(s.WithOpName("Func/y/x2/output/_9"), y_x2_y);
    auto func10 = ops::Identity(s.WithOpName("Func/y/y/input/_10"), func9);
    auto y_y_y = ops::Mul(s.WithOpName("y/y/y"), func10, y_y_scale);
    auto func11 = ops::Identity(s.WithOpName("Func/y/y/output/_11"), y_y_y);
    auto func3 = ops::Identity(s.WithOpName("Func/y/output/_3"), func11);
    auto ret = ops::_Retval(s.WithOpName("y_RetVal"), func3, 0);
    TF_ASSERT_OK(s.ToGraphDef(&e2));

    GraphDef actual;
    g->ToGraphDef(&actual);
    TF_EXPECT_GRAPH_EQ(e2, actual);
  }

  // No further inlining.
  ExpandInlineFunctions(flr0_, g.get());
  {
    GraphDef actual;
    g->ToGraphDef(&actual);
    TF_EXPECT_GRAPH_EQ(e2, actual);
  }

  // Get rid of redundant Identity nodes.
  RemoveIdentityNodes(g.get());
  {
    Scope s = Scope::NewRootScope();
    auto x = ops::_Arg(s.WithOpName("x"), DT_FLOAT, 0);
    auto x4_x2_two = ops::Const<int64_t>(s.WithOpName("x4/x2/two"), int64_t{2});
    auto x4_y_two = ops::Const<int64_t>(s.WithOpName("x4/y/two"), int64_t{2});
    auto y_x2_two = ops::Const<int64_t>(s.WithOpName("y/x2/two"), int64_t{2});
    auto y_y_two = ops::Const<int64_t>(s.WithOpName("y/y/two"), int64_t{2});
    auto x4_x2_scale =
        ops::Cast(s.WithOpName("x4/x2/scale"), x4_x2_two, DT_FLOAT);
    auto x4_y_scale = ops::Cast(s.WithOpName("x4/y/scale"), x4_y_two, DT_FLOAT);
    auto y_x2_scale = ops::Cast(s.WithOpName("y/x2/scale"), y_x2_two, DT_FLOAT);
    auto y_y_scale = ops::Cast(s.WithOpName("y/y/scale"), y_y_two, DT_FLOAT);
    auto x4_x2_y = ops::Mul(s.WithOpName("x4/x2/y"), x, x4_x2_scale);
    auto x4_y_y = ops::Mul(s.WithOpName("x4/y/y"), x4_x2_y, x4_y_scale);
    auto y_x2_y = ops::Mul(s.WithOpName("y/x2/y"), x4_y_y, y_x2_scale);
    auto y_y_y = ops::Mul(s.WithOpName("y/y/y"), y_x2_y, y_y_scale);
    auto ret = ops::_Retval(s.WithOpName("y_RetVal"), y_y_y, 0);
    GraphDef expected;
    TF_ASSERT_OK(s.ToGraphDef(&expected));

    GraphDef actual;
    g->ToGraphDef(&actual);
    TF_EXPECT_GRAPH_EQ(expected, actual);
  }
}
924 
925 // Verifies that control dependencies on the caller are added as control
926 // dependencies on any function calls created by inlining.
TEST_F(FunctionLibraryRuntimeTest, ExpandInlineFunctionsWithInputControlEdges) {
  Init({test::function::XTimesTwo(), test::function::XTimesFour()});

  // Build: a = Arg; c = NoOp; b = XTimesFour(a) with control edge c -> b.
  std::unique_ptr<Graph> g(new Graph(OpRegistry::Global()));
  {
    Scope s = Scope::NewRootScope();
    TF_ASSERT_OK(s.graph()->AddFunctionLibrary(fdef_lib_));
    auto a = ops::_Arg(s.WithOpName("a"), DT_FLOAT, 0);
    auto c = ops::NoOp(s.WithOpName("c"));
    auto b = test::function::Call(&s, "b", "XTimesFour", {a});
    s.graph()->AddControlEdge(c.operation.node(), b.node());
    auto ret = ops::_Retval(s.WithOpName("b_RetVal"), b, 0);
    TF_ASSERT_OK(s.ToGraph(g.get()));
  }

  // First pass: the caller's control dependency on `c` is funneled through an
  // "input_control_node" NoOp, which then gates every node of the inlined body
  // (here, the two XTimesTwo call nodes and the input Identity).
  ExpandInlineFunctions(flr0_, g.get());
  {
    Scope s = Scope::NewRootScope();
    TF_ASSERT_OK(s.graph()->AddFunctionLibrary(fdef_lib_));
    auto a = ops::_Arg(s.WithOpName("a"), DT_FLOAT, 0);
    auto c = ops::NoOp(s.WithOpName("c"));
    auto func0 = ops::NoOp(s.WithOpName("Func/b/input_control_node/_0")
                               .WithControlDependencies({c}));
    auto func1 = ops::Identity(
        s.WithOpName("Func/b/input/_1").WithControlDependencies({func0}), a);
    auto b_x2 = test::function::Call(&s, "b/x2", "XTimesTwo", {func1});
    s.graph()->AddControlEdge(func0.operation.node(), b_x2.node());
    auto b_y = test::function::Call(&s, "b/y", "XTimesTwo", {b_x2});
    s.graph()->AddControlEdge(func0.operation.node(), b_y.node());
    auto func2 = ops::Identity(s.WithOpName("Func/b/output/_2"), b_y);
    auto ret = ops::_Retval(s.WithOpName("b_RetVal"), func2, 0);
    GraphDef expected;
    TF_ASSERT_OK(s.ToGraphDef(&expected));

    GraphDef actual;
    g->ToGraphDef(&actual);
    TF_EXPECT_GRAPH_EQ(expected, actual);
  }

  // Second pass: each inlined XTimesTwo call gets its own input_control_node,
  // itself depending on the outer one, so the original c -> b dependency is
  // preserved transitively down to every primitive node.
  ExpandInlineFunctions(flr0_, g.get());
  {
    Scope s = Scope::NewRootScope();
    TF_ASSERT_OK(s.graph()->AddFunctionLibrary(fdef_lib_));
    auto a = ops::_Arg(s.WithOpName("a"), DT_FLOAT, 0);
    auto c = ops::NoOp(s.WithOpName("c"));
    auto func0 = ops::NoOp(s.WithOpName("Func/b/input_control_node/_0")
                               .WithControlDependencies({c}));
    auto func1 = ops::Identity(
        s.WithOpName("Func/b/input/_1").WithControlDependencies({func0}), a);

    auto func3 = ops::NoOp(s.WithOpName("Func/b/x2/input_control_node/_3")
                               .WithControlDependencies({func0}));
    auto func4 = ops::Identity(
        s.WithOpName("Func/b/x2/input/_4").WithControlDependencies({func3}),
        func1);
    auto b_x2_two = ops::Const(
        s.WithOpName("b/x2/two").WithControlDependencies({func3}), int64_t{2});
    auto b_x2_scale = ops::Cast(s.WithOpName("b/x2/scale"), b_x2_two, DT_FLOAT);
    auto b_x2_y = ops::Mul(s.WithOpName("b/x2/y"), func4, b_x2_scale);
    auto func5 = ops::Identity(s.WithOpName("Func/b/x2/output/_5"), b_x2_y);

    auto func6 = ops::NoOp(s.WithOpName("Func/b/y/input_control_node/_6")
                               .WithControlDependencies({func0}));
    auto func7 = ops::Identity(
        s.WithOpName("Func/b/y/input/_7").WithControlDependencies({func6}),
        func5);
    auto b_y_two = ops::Const(
        s.WithOpName("b/y/two").WithControlDependencies({func6}), int64_t{2});
    auto b_y_scale = ops::Cast(s.WithOpName("b/y/scale"), b_y_two, DT_FLOAT);
    auto b_y_y = ops::Mul(s.WithOpName("b/y/y"), func7, b_y_scale);
    auto func8 = ops::Identity(s.WithOpName("Func/b/y/output/_8"), b_y_y);

    auto func2 = ops::Identity(s.WithOpName("Func/b/output/_2"), func8);
    auto ret = ops::_Retval(s.WithOpName("b_RetVal"), func2, 0);

    GraphDef expected;
    TF_ASSERT_OK(s.ToGraphDef(&expected));

    GraphDef actual;
    g->ToGraphDef(&actual);
    TF_EXPECT_GRAPH_EQ(expected, actual);
  }
}
1010 
TEST_F(FunctionLibraryRuntimeTest,ExpandInlineFunctionsWithOutputControlEdges)1011 TEST_F(FunctionLibraryRuntimeTest,
1012        ExpandInlineFunctionsWithOutputControlEdges) {
1013   using test::function::NDef;
1014 
1015   // `add` node is not required to compute regular output `o`, but it must
1016   // execute because it is in `control_ret`.
1017   const FunctionDef func =
1018       FDH::Create("AddAndMul", {"i: float"}, {"o: float"}, {},
1019                   {{{"add"}, "Add", {"i", "i"}, {{"T", DT_FLOAT}}},
1020                    {{"ret"}, "Mul", {"i", "i"}, {{"T", DT_FLOAT}}}},
1021                   /*ret_def=*/{{"o", "ret:z:0"}},
1022                   /*control_ret_def=*/{{"must_execute", "add"}});
1023 
1024   Init({func});
1025 
1026   // Construct a graph for the function call:
1027   //
1028   //   a = Arg[dtype=DT_FLOAT]
1029   //   b = AddAndMul(a)
1030   //   c = NoOp(^b)
1031   //   ret = RetVal(b, ^c)
1032   const auto init_graph = [this](std::unique_ptr<Graph>* g) -> void {
1033     *g = std::make_unique<Graph>(OpRegistry::Global());
1034 
1035     Scope s = Scope::NewRootScope();
1036     TF_ASSERT_OK(s.graph()->AddFunctionLibrary(fdef_lib_));
1037     auto a = ops::_Arg(s.WithOpName("a"), DT_FLOAT, 0);
1038     auto b = test::function::Call(&s, "b", "AddAndMul", {a});
1039     auto c = ops::NoOp(s.WithOpName("c"));
1040     auto ret = ops::_Retval(s.WithOpName("ret"), b, 0);
1041     s.graph()->AddControlEdge(b.node(), c.operation.node());
1042     s.graph()->AddControlEdge(c.operation.node(), ret.operation.node());
1043     TF_ASSERT_OK(s.ToGraph(g->get()));
1044   };
1045 
1046   std::unique_ptr<Graph> g;
1047   ExpandInlineFunctionsOptions opts;
1048 
1049   const string input_node = "Func/b/input/_0";
1050   const string output_node = "Func/b/output/_1";
1051   const string output_control_node = "Func/b/output_control_node/_2";
1052 
1053   // Use data outputs as output control source.
1054   opts.native_options.output_control_src = OutputControlSrc::kDataOutputs;
1055 
1056   init_graph(&g);
1057   ExpandInlineFunctions(flr0_, g.get(), opts);
1058   {
1059     GraphDef expected = test::function::GDef(
1060         {NDef("a", "_Arg", {}, {{"T", DT_FLOAT}, {"index", 0}}),
1061          NDef(input_node, "Identity", {"a"}, {{"T", DT_FLOAT}}),
1062          NDef("b/add", "Add", {input_node, input_node}, {{"T", DT_FLOAT}}),
1063          NDef("b/ret", "Mul", {input_node, input_node}, {{"T", DT_FLOAT}}),
1064          NDef(output_node, "Identity", {"b/ret"}, {{"T", DT_FLOAT}}),
1065          NDef(output_control_node, "NoOp", {"^Func/b/output/_1"}, {}),
1066          NDef("c", "NoOp", {"^" + output_control_node}, {}),
1067          NDef("ret", "_Retval", {output_node, "^c"},
1068               {{"T", DT_FLOAT}, {"index", 0}})},
1069         {func});
1070 
1071     GraphDef actual;
1072     g->ToGraphDef(&actual);
1073     TF_EXPECT_GRAPH_EQ(expected, actual);
1074   }
1075 
1076   // Use control outputs as output control source.
1077   opts.native_options.output_control_src = OutputControlSrc::kControlOutputs;
1078 
1079   init_graph(&g);
1080   ExpandInlineFunctions(flr0_, g.get(), opts);
1081   {
1082     GraphDef expected = test::function::GDef(
1083         {NDef("a", "_Arg", {}, {{"T", DT_FLOAT}, {"index", 0}}),
1084          NDef(input_node, "Identity", {"a"}, {{"T", DT_FLOAT}}),
1085          NDef("b/add", "Add", {input_node, input_node}, {{"T", DT_FLOAT}}),
1086          NDef("b/ret", "Mul", {input_node, input_node}, {{"T", DT_FLOAT}}),
1087          NDef(output_node, "Identity", {"b/ret"}, {{"T", DT_FLOAT}}),
1088          NDef(output_control_node, "NoOp", {"^b/add"}, {}),
1089          NDef("c", "NoOp", {"^" + output_control_node}, {}),
1090          NDef("ret", "_Retval", {output_node, "^c"},
1091               {{"T", DT_FLOAT}, {"index", 0}})},
1092         {func});
1093 
1094     GraphDef actual;
1095     g->ToGraphDef(&actual);
1096     TF_EXPECT_GRAPH_EQ(expected, actual);
1097   }
1098 }
1099 
// Verifies the KeepCallerNode options: after inlining, the caller node can be
// kept as a fetchable IdentityN (forwarding the data outputs) or as a
// targetable NoOp (carrying only the control dependency).
TEST_F(FunctionLibraryRuntimeTest, ExpandInlineFunctionsAndKeepCallerNode) {
  using test::function::NDef;
  using KeepCallerNode = InlineFunctionBodyOptions::KeepCallerNode;

  // `add` is in control_ret, so it becomes the output control source below.
  const FunctionDef func =
      FDH::Create("AddAndMul", {"i: float"}, {"o: float"}, {},
                  {{{"add"}, "Add", {"i", "i"}, {{"T", DT_FLOAT}}},
                   {{"ret"}, "Mul", {"i", "i"}, {{"T", DT_FLOAT}}}},
                  /*ret_def=*/{{"o", "ret:z:0"}},
                  /*control_ret_def=*/{{"must_execute", "add"}});
  Init({func});

  // Construct a graph:
  //   a = Arg[dtype=DT_FLOAT]
  //   b = FunctionWithControlOutputs(a)
  auto construct_graph = [this](std::unique_ptr<Graph>* g) -> Status {
    Scope s = Scope::NewRootScope();
    TF_RETURN_IF_ERROR(s.graph()->AddFunctionLibrary(fdef_lib_));
    auto a = ops::_Arg(s.WithOpName("a"), DT_FLOAT, 0);
    auto b = test::function::Call(&s, "b", "AddAndMul", {a});
    TF_RETURN_IF_ERROR(s.ToGraph(g->get()));
    return OkStatus();
  };

  // Names the inliner assigns to the glue nodes for call site "b".
  const string input_node = "Func/b/input/_0";
  const string output_node = "Func/b/output/_1";
  const string output_control_node = "Func/b/output_control_node/_2";

  // Construct expected graph after function inlining; only the kept caller
  // node differs between the two cases below.
  auto expected_graph = [&](const NodeDef& caller) -> GraphDef {
    return test::function::GDef(
        {
            NDef("a", "_Arg", {}, {{"T", DT_FLOAT}, {"index", 0}}),
            NDef(input_node, "Identity", {"a"}, {{"T", DT_FLOAT}}),
            NDef("b/add", "Add", {input_node, input_node}, {{"T", DT_FLOAT}}),
            NDef("b/ret", "Mul", {input_node, input_node}, {{"T", DT_FLOAT}}),
            NDef(output_node, "Identity", {"b/ret"}, {{"T", DT_FLOAT}}),
            NDef(output_control_node, "NoOp", {"^b/add"}, {}),
            caller,  // Keep node in a graph with the same name as caller node.
        },
        {func});
  };

  ExpandInlineFunctionsOptions opts;
  opts.native_options.output_control_src = OutputControlSrc::kControlOutputs;

  // Keep inlined function call node fetchable: "b" becomes an IdentityN that
  // forwards the data output and depends on the output control node.
  {
    opts.native_options.keep_caller_node = KeepCallerNode::kFetchable;

    std::unique_ptr<Graph> g = std::make_unique<Graph>(OpRegistry::Global());
    TF_ASSERT_OK(construct_graph(&g));

    ExpandInlineFunctions(flr0_, g.get(), opts);
    GraphDef expected =
        expected_graph(/*caller=*/
                       NDef("b", "IdentityN",
                            {output_node, "^" + output_control_node},
                            {{"T", DataTypeSlice{DT_FLOAT}}}));

    GraphDef actual;
    g->ToGraphDef(&actual);
    TF_EXPECT_GRAPH_EQ(expected, actual);
  }

  // Keep inlined function call node targetable: "b" becomes a NoOp with only
  // the control dependency (no data outputs).
  {
    opts.native_options.keep_caller_node = KeepCallerNode::kTargetable;

    std::unique_ptr<Graph> g = std::make_unique<Graph>(OpRegistry::Global());
    TF_ASSERT_OK(construct_graph(&g));

    ExpandInlineFunctions(flr0_, g.get(), opts);
    GraphDef expected =
        expected_graph(/*caller=*/
                       NDef("b", "NoOp", {"^" + output_control_node}, {}));

    GraphDef actual;
    g->ToGraphDef(&actual);
    TF_EXPECT_GRAPH_EQ(expected, actual);
  }
}
1182 
// Verifies the three InlinedFunctionBodyPlacer strategies: which device each
// inlined node (arg, input Identity, body, output Identity, control output)
// ends up requesting.
TEST_F(FunctionLibraryRuntimeTest, ExpandInlineFunctionsAndPlaceInlinedNodes) {
  using test::function::NDef;
  using KeepCallerNode = InlineFunctionBodyOptions::KeepCallerNode;

  // Three distinct devices so each placement decision is observable.
  const string arg_device = "/job:arg/replica:0/task:0/device:GPU";
  const string call_device = "/job:call/replica:0/task:1/device:GPU";
  const string body_device = "/job:body/replica:0/task:1/device:CPU";

  // Single-node function whose body node carries an explicit device.
  const FunctionDef func = FDH::Create(
      "AddFunc", {"i: float"}, {"o: float"}, {},
      {{{"ret"}, "Add", {"i", "i"}, {{"T", DT_FLOAT}}, {}, body_device}},
      /*ret_def=*/{{"o", "ret:z:0"}});
  Init({func});

  // Construct a graph:
  //   a = Arg[dtype=DT_FLOAT, _device=arg_device]
  //   b = AddFunc[_device=call_device](a)
  auto construct_graph = [&](std::unique_ptr<Graph>* g) -> Status {
    Scope s = Scope::NewRootScope();
    TF_RETURN_IF_ERROR(s.graph()->AddFunctionLibrary(fdef_lib_));
    auto a = ops::_Arg(s.WithOpName("a").WithDevice(arg_device), DT_FLOAT, 0);
    auto b = test::function::Call(&s, "b", "AddFunc", {a});
    TF_RETURN_IF_ERROR(s.ToGraph(g->get()));
    // The Scope API is not used to set the call node's device, so patch the
    // requested device on "b" after conversion.
    for (Node* node : (*g)->op_nodes()) {
      if (node->name() == "b") node->set_requested_device(call_device);
    }
    return OkStatus();
  };

  // Names the inliner assigns to the glue nodes for call site "b".
  const string input_node = "Func/b/input/_0";
  const string output_node = "Func/b/output/_1";
  const string output_control_node = "Func/b/output_control_node/_2";

  // Construct expected graph after function inlining; `placed` lists the
  // expected device string for each of the five nodes, in graph order.
  auto expected_graph = [&](const std::vector<string>& placed) -> GraphDef {
    return test::function::GDef(
        {
            NDef("a", "_Arg", {}, {{"T", DT_FLOAT}, {"index", 0}}, placed[0]),
            NDef(input_node, "Identity", {"a"}, {{"T", DT_FLOAT}}, placed[1]),
            NDef("b/ret", "Add", {input_node, input_node}, {{"T", DT_FLOAT}},
                 placed[2]),
            NDef(output_node, "Identity", {"b/ret"}, {{"T", DT_FLOAT}},
                 placed[3]),
            NDef(output_control_node, "NoOp", {"^" + output_node}, {},
                 placed[4]),
        },
        {func});
  };

  ExpandInlineFunctionsOptions opts;
  opts.native_options.keep_caller_node = KeepCallerNode::kDoNotKeep;

  // Place only input nodes to match input device.
  {
    opts.native_options.inlined_function_body_placer =
        InlinedFunctionBodyPlacer::Default();

    auto g = std::make_unique<Graph>(OpRegistry::Global());
    TF_ASSERT_OK(construct_graph(&g));

    ExpandInlineFunctions(flr0_, g.get(), opts);
    GraphDef expected = expected_graph({/*a*/ arg_device,       //
                                        /*input*/ arg_device,   //
                                        /*body*/ body_device,   //
                                        /*output*/ "",          //
                                        /*control_output*/ ""}  //
    );

    GraphDef actual;
    g->ToGraphDef(&actual);
    TF_EXPECT_GRAPH_EQ(expected, actual);
  }

  // Place all nodes on the call node device.
  {
    opts.native_options.inlined_function_body_placer =
        InlinedFunctionBodyPlacer::SingleDevice();

    auto g = std::make_unique<Graph>(OpRegistry::Global());
    TF_ASSERT_OK(construct_graph(&g));

    ExpandInlineFunctions(flr0_, g.get(), opts);
    GraphDef expected = expected_graph({/*a*/ arg_device,                //
                                        /*input*/ call_device,           //
                                        /*body*/ call_device,            //
                                        /*output*/ call_device,          //
                                        /*control_output*/ call_device}  //
    );

    GraphDef actual;
    g->ToGraphDef(&actual);
    TF_EXPECT_GRAPH_EQ(expected, actual);
  }

  // Multi device function placement.
  {
    opts.native_options.inlined_function_body_placer =
        InlinedFunctionBodyPlacer::MultiDevice();

    auto g = std::make_unique<Graph>(OpRegistry::Global());
    TF_ASSERT_OK(construct_graph(&g));

    // The body device merged with the call device: task from the body spec,
    // wildcard device ordinal.
    const string merged_device = "/job:body/replica:0/task:1/device:CPU:*";

    ExpandInlineFunctions(flr0_, g.get(), opts);
    GraphDef expected = expected_graph({/*a*/ arg_device,                //
                                        /*input*/ arg_device,            //
                                        /*body*/ merged_device,          //
                                        /*output*/ "",                   //
                                        /*control_output*/ call_device}  //
    );

    GraphDef actual;
    g->ToGraphDef(&actual);
    TF_EXPECT_GRAPH_EQ(expected, actual);
  }
}
1300 
// Verifies that executing a function prunes body nodes that its return value
// does not depend on — except stateful nodes, which are always kept.
TEST_F(FunctionLibraryRuntimeTest, PruneBody) {
  auto T = DT_INT32;
  FunctionDef stateful_func = FDH::Define(
      // Name
      "SquareAndAddOneWithStatefulNodes",
      // Args
      {"x: int32", "y: float32"},
      // Return values
      {"z: int32"},
      // Attrs
      {},
      // Nodes
      {// a = Square<T>(x)
       {{"a"}, "Square", {"x"}, {{"T", T}}},
       // 1
       FDH::Const("o", 1),
       // A bunch of extra arithmetic that y doesn't depend on
       {{"x1"}, "Add", {"o", "o"}, {{"T", T}}},
       {{"x2"}, "Mul", {"a", "x1"}, {{"T", T}}},
       {{"x3"}, "Mul", {"x1", "x2"}, {{"T", T}}},
       FDH::Const<int32>("shape", {1, 2}),
       // A stateful node.
       {{"keep_me"},
        "RandomUniform",
        {"shape"},
        {{"T", T}, {"dtype", DT_FLOAT}}},
       // z = Add<T>(a, o)
       {{"z"}, "Add", {"a", "o"}, {{"T", T}}}});
  Init({stateful_func});

  auto x = test::AsTensor<int32>({1, 2, 3, 4});
  auto y = test::AsTensor<float>({1.0, 2.0, 3.0, 4.0});
  Tensor z;

  FunctionLibraryRuntime::Handle handle;
  TF_CHECK_OK(
      Instantiate(flr0_, "SquareAndAddOneWithStatefulNodes", {}, &handle));

  // Run once with a stats collector attached so we can observe exactly which
  // nodes were executed.
  StepStats stats;
  StepStatsCollector stats_collector(&stats);
  FunctionLibraryRuntime::Options opts;
  opts.stats_collector = &stats_collector;
  TF_CHECK_OK(Run(flr0_, handle, opts, {x, y}, {&z}));
  TF_CHECK_OK(flr0_->ReleaseHandle(handle));

  // z = x^2 + 1 for each element.
  TF_CHECK_OK(InstantiateAndRun(flr0_, "SquareAndAddOneWithStatefulNodes", {},
                                {x, y}, {&z}));
  test::ExpectTensorEqual<int>(z, test::AsTensor<int32>({2, 5, 10, 17}));

  stats_collector.FinalizeAndSwap(&stats);

  // Note that we do not expect the nodes named "y", "x1", "x2", or "x3" to
  // execute.
  std::set<string> expected_node_names(
      {"_SOURCE", "shape", "x", "o", "a", "keep_me", "z", "z_RetVal"});
  std::set<string> executed_node_names;
  for (const auto& node_stats : stats.dev_stats()[0].node_stats()) {
    executed_node_names.insert(node_stats.node_name());
  }
  EXPECT_EQ(expected_node_names, executed_node_names);
}
1362 
TEST_F(FunctionLibraryRuntimeTest, DoNotPruneControlOutputsFromBody) {
  // `add` node is not required to compute regular output `o`, but it must
  // execute because it is in `control_ret`.
  const FunctionDef func =
      FDH::Create("FunctionWithControlOutputs", {"i: float"}, {"o: float"}, {},
                  {
                      {{"add"}, "Add", {"i", "i"}, {{"T", DT_FLOAT}}},
                      {{"ret"}, "Mul", {"i", "i"}, {{"T", DT_FLOAT}}},
                  },
                  /*ret_def=*/{{"o", "ret:z:0"}},
                  /*control_ret_def=*/{{"must_execute", "add"}});

  Init({func});

  auto x = test::AsTensor<float>({1.25});
  Tensor z;

  // NOTE(review): this test uses flr1_ rather than flr0_ — presumably a
  // runtime for a different device created by the fixture; confirm in the
  // fixture's Init() (outside this chunk).
  FunctionLibraryRuntime::Handle handle;
  TF_CHECK_OK(Instantiate(flr1_, "FunctionWithControlOutputs", {}, &handle));

  // Run with a stats collector attached so we can observe exactly which nodes
  // were executed.
  StepStats stats;
  StepStatsCollector stats_collector(&stats);
  FunctionLibraryRuntime::Options opts;
  opts.stats_collector = &stats_collector;
  TF_CHECK_OK(Run(flr1_, handle, opts, {x}, {&z}));
  TF_CHECK_OK(flr1_->ReleaseHandle(handle));

  // o = i * i.
  TF_CHECK_OK(
      InstantiateAndRun(flr1_, "FunctionWithControlOutputs", {}, {x}, {&z}));
  test::ExpectTensorEqual<float>(z, test::AsTensor<float>({1.25 * 1.25}));

  stats_collector.FinalizeAndSwap(&stats);

  // "add" must appear in the executed set even though "o" does not depend on
  // it, because it is a control return.
  std::set<string> expected_node_names(
      {"_SOURCE", "i", "add", "ret", "o_RetVal"});
  std::set<string> executed_node_names;
  for (const auto& node_stats : stats.dev_stats()[0].node_stats()) {
    executed_node_names.insert(node_stats.node_name());
  }
  EXPECT_EQ(expected_node_names, executed_node_names);
}
1404 
// Verifies that OptimizeGraph on a fully inlined XTimes16 body constant-folds
// the four per-call (Const two -> Cast) pairs into a single shared float
// constant feeding all four Muls.
TEST_F(FunctionLibraryRuntimeTest, OptimizeGraph) {
  Init({test::function::XTimesTwo(), test::function::XTimesFour(),
        test::function::XTimes16()});
  std::unique_ptr<Graph> g = GetFuncBody(flr0_, "XTimes16", {{"T", DT_FLOAT}});
  ASSERT_TRUE(g != nullptr);
  ExpandInlineFunctions(flr0_, g.get());
  OptimizeGraph(flr0_, &g);
  {
    Scope s = Scope::NewRootScope();
    auto x = ops::_Arg(s.WithOpName("x"), DT_FLOAT, 0);
    // The "__cf__" suffix and explicit device come from constant folding.
    auto x4_x2_scale = ops::Const<float>(
        s.WithOpName("x4/x2/scale/_12__cf__0")
            .WithDevice("/job:localhost/replica:0/task:0/device:CPU:0"),
        2.0f);
    auto x4_x2_y = ops::Mul(s.WithOpName("x4/x2/y"), x, x4_x2_scale);
    auto x4_y_y = ops::Mul(s.WithOpName("x4/y/y"), x4_x2_y, x4_x2_scale);
    auto y_x2_y = ops::Mul(s.WithOpName("y/x2/y"), x4_y_y, x4_x2_scale);
    auto y_y_y = ops::Mul(s.WithOpName("y/y/y"), y_x2_y, x4_x2_scale);
    auto ret = ops::_Retval(s.WithOpName("y_RetVal"), y_y_y, 0);
    GraphDef expected;
    TF_ASSERT_OK(s.ToGraphDef(&expected));

    GraphDef actual;
    g->ToGraphDef(&actual);
    TF_EXPECT_GRAPH_EQ(expected, actual);
  }
}
1432 
// A chain of six Swap ops followed by an Identity is a no-op overall; the
// optimizer should collapse the whole function body so the first argument is
// returned directly.
TEST_F(FunctionLibraryRuntimeTest, ManySwapsNodeDef) {
  auto func = FDH::Create(  // Creates a FunctionDef using NodeDefs
                            // Name
      "ManySwapsNodeDef",
      // Input
      {"x: float", "y: float"},
      // Output
      {"o: float"},
      // Attr
      {},
      // Nodes
      {{{"a"}, "Swap", {"x", "y"}, {{"T", DT_FLOAT}}},
       {{"b"}, "Swap", {"a:o0", "a:o1"}, {{"T", DT_FLOAT}}},
       {{"c"}, "Swap", {"b:o0", "b:o1"}, {{"T", DT_FLOAT}}},
       {{"d"}, "Swap", {"c:o0", "c:o1"}, {{"T", DT_FLOAT}}},
       {{"e"}, "Swap", {"d:o0", "d:o1"}, {{"T", DT_FLOAT}}},
       {{"f"}, "Swap", {"e:o0", "e:o1"}, {{"T", DT_FLOAT}}},
       {{"g"}, "Identity", {"f:o0"}, {{"T", DT_FLOAT}}}},
      // Return
      {{"o", "g:output"}});
  Init({test::function::Swap(), func});
  std::unique_ptr<Graph> g = GetFuncBody(flr0_, "ManySwapsNodeDef", {});
  ASSERT_TRUE(g != nullptr);
  OptimizeGraph(flr0_, &g);
  // An empty body: output n2 is the first input, untouched.
  const char* e0 = R"P(
(n2:float, n3:float) -> (n2:float) {
}
)P";
  EXPECT_EQ(e0, DebugString(g.get()));
}
1463 
// Optimization must preserve the control dependencies declared in the
// function body (swap a0 depends on x2; y2 depends on swap a1) even after
// the swaps themselves are optimized away.
TEST_F(FunctionLibraryRuntimeTest, ControlDeps) {
  auto func = FDH::Create(
      // Name
      "ManySwapsFirst",
      // Args
      {"x: float", "y: float"},
      // Return values
      {"o: float"},
      // attr def
      {},
      // Nodes
      //
      // o = x*x + y*y.  Furthermore, The 1st swap depends on x2, and
      // y2 depends on the 2nd swap.  The 2nd swap has data dependency
      // on the 1st swap. The optimization should maintain the control
      // dependencies.
      {{{"a0"}, "Swap", {"x", "y"}, {{"T", DT_FLOAT}}, {"x2"}},
       {{"a1"}, "Swap", {"a0:o0:0", "a0:o1:0"}, {{"T", DT_FLOAT}}},
       {{"x2"}, "Mul", {"x", "x"}, {{"T", DT_FLOAT}}},
       {{"y2"}, "Mul", {"y", "y"}, {{"T", DT_FLOAT}}, {"a1"}},
       {{"o"}, "Add", {"x2:z:0", "y2:z:0"}, {{"T", DT_FLOAT}}}},
      {{"o", "o:z:0"}});
  Init({test::function::Swap(), func});
  std::unique_ptr<Graph> g = GetFuncBody(flr0_, "ManySwapsFirst", {});
  ASSERT_TRUE(g != nullptr);
  OptimizeGraph(flr0_, &g);

  // NOTE: We can remove func0, func1, func2, func9 with a control edge
  // n8->n5. But we don't have a pass doing that.
  {
    Scope s = Scope::NewRootScope();
    auto x = ops::_Arg(s.WithOpName("x"), DT_FLOAT, 0);
    auto y = ops::_Arg(s.WithOpName("y"), DT_FLOAT, 1);
    auto x2 = ops::Mul(s.WithOpName("x2"), x, x);
    auto func0 = ops::NoOp(s.WithOpName("Func/a0/input_control_node/_0")
                               .WithControlDependencies(x2));
    auto func1 = ops::Identity(
        s.WithOpName("Func/a0/input/_1").WithControlDependencies({func0}), x);
    auto func2 = ops::Identity(
        s.WithOpName("Func/a0/input/_2").WithControlDependencies({func0}), y);
    auto func9 = ops::NoOp(
        s.WithOpName("Func/a1/output_control_node/_9")
            .WithControlDependencies({func1.output.op(), func2.output.op()}));
    auto y2 =
        ops::Mul(s.WithOpName("y2").WithControlDependencies({func9}), y, y);
    auto o = ops::Add(s.WithOpName("o"), x2, y2);
    auto ret = ops::_Retval(s.WithOpName("o_RetVal"), o, 0);
    GraphDef expected;
    TF_ASSERT_OK(s.ToGraphDef(&expected));

    GraphDef actual;
    g->ToGraphDef(&actual);
    TF_EXPECT_GRAPH_EQ(expected, actual);
  }
}
1519 
// Instantiating an undefined function must fail with NOT_FOUND.
TEST_F(FunctionLibraryRuntimeTest, Error_NotFound) {
  Init({test::function::XTimesTwo(), test::function::XTimesFour()});
  auto x = test::AsTensor<float>({1, 2, 3, 4});
  Tensor y;
  HasError(InstantiateAndRun(flr0_, "Foo", {{"T", DT_FLOAT}}, {x}, {&y}),
           error::NOT_FOUND, "Function Foo is not defined.");
}
1527 
// A function whose body is malformed (the Add node uses attr key "no_T"
// instead of "T") fails at instantiation time.  Callers of the bad function
// instantiate fine; the error surfaces only when they run.
TEST_F(FunctionLibraryRuntimeTest, Error_InstantiationError) {
  auto bad_x_times_two = FDH::Define(
      // Name
      "XTimesTwo",
      // Args
      {"x: T"},
      // Return values
      {"y: T"},
      // Attr def
      {"T: {float, double, int32, int64}"},
      // Nodes
      {
          {{"y"}, "Add", {"x", "x"}, {{"no_T", "$T"}}},
      });
  Init({bad_x_times_two, test::function::XTimesFour(),
        test::function::XTimes16()});

  // Instantiating "XTimesTwo" should fail.
  FunctionLibraryRuntime::Handle handle;
  HasError(flr0_->Instantiate(
               "XTimesTwo", test::function::Attrs({{"T", DT_FLOAT}}), &handle),
           error::NOT_FOUND, "type attr not found");

  // But XTimesFour and XTimes16 instantiation should succeed. Only
  // when they run, they fail because XTimesTwo is bad.
  TF_CHECK_OK(flr0_->Instantiate(
      "XTimesFour", test::function::Attrs({{"T", DT_FLOAT}}), &handle));
  TF_CHECK_OK(flr0_->Instantiate(
      "XTimes16", test::function::Attrs({{"T", DT_FLOAT}}), &handle));

  auto x = test::AsTensor<float>({1, 2, 3, 4});
  Tensor y;
  HasError(InstantiateAndRun(flr0_, "XTimes16", {{"T", DT_FLOAT}}, {x}, {&y}),
           error::NOT_FOUND, "type attr not found");
}
1563 
// A function with inconsistent control-flow frames (an op mixing inputs from
// inside and outside a while frame) is rejected with INVALID_ARGUMENT.
TEST_F(FunctionLibraryRuntimeTest, Error_BadControlFlow) {
  Init({test::function::InvalidControlFlow()});
  auto x = test::AsTensor<int32>({0});
  DCHECK_EQ(x.dtype(), DT_INT32);
  Tensor y;
  HasError(InstantiateAndRun(flr0_, "InvalidControlFlow", {}, {x}, {&y}),
           error::INVALID_ARGUMENT,
           "{{node add}} has inputs from different frames. The input"
           " {{node enter}} is in frame 'while'. The input {{node i}} is in"
           " frame ''.");
}
1575 
TEST_F(FunctionLibraryRuntimeTest,Gradient_XTimesTwo)1576 TEST_F(FunctionLibraryRuntimeTest, Gradient_XTimesTwo) {
1577   Init({test::function::XTimesTwo(), test::function::XTimesFour(),
1578         test::function::XTimes16()});
1579   std::unique_ptr<Graph> f = GetFuncBody(flr0_, "XTimesTwo", {{"T", DT_FLOAT}});
1580   {
1581     Scope s = Scope::NewRootScope();
1582     auto x = ops::_Arg(s.WithOpName("x"), DT_FLOAT, 0);
1583     auto two = ops::Const(s.WithOpName("two"), int64_t{2});
1584     auto scale = ops::Cast(s.WithOpName("scale"), two, DT_FLOAT);
1585     auto y = ops::Mul(s.WithOpName("y"), x, scale);
1586     auto ret = ops::_Retval(s.WithOpName("y_RetVal"), y, 0);
1587     GraphDef expected;
1588     TF_ASSERT_OK(s.ToGraphDef(&expected));
1589 
1590     GraphDef actual;
1591     f->ToGraphDef(&actual);
1592     TF_EXPECT_GRAPH_EQ(expected, actual);
1593   }
1594 
1595   std::unique_ptr<Graph> g = GetGradBody(flr0_, "XTimesTwo", {{"T", DT_FLOAT}});
1596 
1597   {
1598     Scope s = Scope::NewRootScope();
1599     auto x = ops::_Arg(s.WithOpName("x"), DT_FLOAT, 0);
1600     auto func0 = ops::_Arg(s.WithOpName("Func/_0"), DT_FLOAT, 1);
1601     auto two = ops::Const(s.WithOpName("two"), int64_t{2});
1602     auto scale = ops::Cast(s.WithOpName("scale"), two, DT_FLOAT);
1603     auto y = ops::Mul(s.WithOpName("y"), x, scale);
1604     NameAttrList fn0;
1605     fn0.set_name("Mul");
1606     (*fn0.mutable_attr())["T"].set_type(DT_FLOAT);
1607     auto func1 = ops::SymbolicGradient(
1608         s.WithOpName("Func/_1"), std::initializer_list<Input>{x, scale, func0},
1609         {DT_FLOAT, DT_FLOAT}, fn0);
1610     NameAttrList fn1;
1611     fn1.set_name("Cast");
1612     (*fn1.mutable_attr())["SrcT"].set_type(DT_INT64);
1613     (*fn1.mutable_attr())["DstT"].set_type(DT_FLOAT);
1614     (*fn1.mutable_attr())["Truncate"].set_b(false);
1615     auto func2 = ops::SymbolicGradient(
1616         s.WithOpName("Func/_2"),
1617         std::initializer_list<Input>{two, func1.output[1]}, {DT_INT64}, fn1);
1618     auto func3 = ops::_Retval(s.WithOpName("Func/_3"), func1[0], 0);
1619     GraphDef expected;
1620     TF_ASSERT_OK(s.ToGraphDef(&expected));
1621 
1622     GraphDef actual;
1623     g->ToGraphDef(&actual);
1624     TF_EXPECT_GRAPH_EQ(expected, actual);
1625   }
1626 
1627   OptimizeGraph(flr0_, &g);
1628   {
1629     Scope s = Scope::NewRootScope();
1630     auto x = ops::_Arg(s.WithOpName("x"), DT_FLOAT, 0);
1631     auto func0 = ops::_Arg(s.WithOpName("Func/_0"), DT_FLOAT, 1);
1632     auto scale = ops::Const(
1633         s.WithOpName("scale/_6__cf__1")
1634             .WithDevice("/job:localhost/replica:0/task:0/device:CPU:0"),
1635         2.0f);
1636     auto func1_gx = ops::Mul(s.WithOpName("Func/_1/gx"), func0, scale);
1637     auto func1_sx = ops::Shape(s.WithOpName("Func/_1/sx"), x);
1638     auto const0 = ops::Const(
1639         s.WithOpName("Func/_1/sy/_5__cf__0")
1640             .WithDevice("/job:localhost/replica:0/task:0/device:CPU:0"),
1641         0, {0});
1642     auto func1_rx = ops::internal::BroadcastGradientArgs(
1643         s.WithOpName("Func/_1/rx"), func1_sx, const0);
1644     auto func1_sum_gx =
1645         ops::Sum(s.WithOpName("Func/_1/sum_gx"), func1_gx, func1_rx.r0);
1646     auto func1_dx =
1647         ops::Reshape(s.WithOpName("Func/_1/dx"), func1_sum_gx, func1_sx);
1648     auto func2 = ops::_Retval(s.WithOpName("Func/_3"), func1_dx, 0);
1649     GraphDef expected;
1650     TF_ASSERT_OK(s.ToGraphDef(&expected));
1651 
1652     GraphDef actual;
1653     g->ToGraphDef(&actual);
1654     TF_EXPECT_GRAPH_EQ(expected, actual);
1655   }
1656 }
1657 
// A SymbolicGradient over a user-defined function containing Select ops can
// be instantiated and run (result value is not checked here; the test only
// asserts the run succeeds).
TEST_F(FunctionLibraryRuntimeTest, Gradient_Select) {
  FunctionDef my_select = FunctionDefHelper::Create(
      "MySelect",
      // Args
      {"condition: bool", "t: float32", "e: float32"},
      // Return values
      {"z: float32"},
      // Attrs
      {},
      // Nodes
      {
          {{"select0"}, "Select", {"condition", "t", "e"}, {{"T", DT_FLOAT}}},
          {{"select1"}, "Select", {"condition", "t", "e"}, {{"T", DT_FLOAT}}},
          {{"add"},
           "Add",
           {"select0:output", "select1:output"},
           {{"T", DT_FLOAT}}},
      },
      // Output mapping
      {{"z", "add:z"}});
  FunctionDef select_grad = FunctionDefHelper::Create(
      "MySelectGrad",
      // Args
      {"condition: bool", "t:float32", "e: float32", "dz: float32"},
      // Return values
      {"dt: float32"},
      // Attrs
      {},
      // Nodes
      {{
          {"grad"},
          "SymbolicGradient",
          {"condition", "t", "e", "dz"},
          {
              {"f", FunctionDefHelper::FunctionRef("MySelect")},
              {"Tin", DataTypeSlice({DT_BOOL, DT_FLOAT, DT_FLOAT, DT_FLOAT})},
              {"Tout", DataTypeSlice({DT_BOOL, DT_FLOAT, DT_FLOAT})},
          },
      }},
      // Output mapping
      {{"dt", "grad:output:1"}});
  Init({my_select, select_grad});

  auto condition = test::AsTensor<bool>({false});
  auto t = test::AsTensor<float>({13.0});
  auto e = test::AsTensor<float>({15.0});
  auto dz = test::AsTensor<float>({1.0});
  Tensor y;
  TF_EXPECT_OK(InstantiateAndRun(flr0_, "MySelectGrad", {},
                                 {condition, t, e, dz}, {&y}));
}
1709 
// The gradient function body for Add: both incoming gradients are Identity
// of dz, then reduced/reshaped per broadcast shapes of x and y.
TEST_F(FunctionLibraryRuntimeTest, Gradient_Add) {
  Init({});
  auto T = DT_FLOAT;
  std::unique_ptr<Graph> g = GetFuncBody(
      flr0_, "SymbolicGradient", {{"f", FDH::FunctionRef("Add", {{"T", T}})}});
  {
    Scope s = Scope::NewRootScope();
    auto x = ops::_Arg(s.WithOpName("x"), DT_FLOAT, 0);
    auto y = ops::_Arg(s.WithOpName("y"), DT_FLOAT, 1);
    auto dz = ops::_Arg(s.WithOpName("dz"), DT_FLOAT, 2);
    auto gx = ops::Identity(s.WithOpName("gx"), dz);
    auto gy = ops::Identity(s.WithOpName("gy"), dz);
    auto sx = ops::Shape(s.WithOpName("sx"), x);
    auto sy = ops::Shape(s.WithOpName("sy"), y);
    auto rx = ops::internal::BroadcastGradientArgs(s.WithOpName("rx"), sx, sy);
    auto sum_gx = ops::Sum(s.WithOpName("sum_gx"), gx, rx.r0);
    auto sum_gy = ops::Sum(s.WithOpName("sum_gy"), gy, rx.r1);
    auto dx = ops::Reshape(s.WithOpName("dx"), sum_gx, sx);
    auto dy = ops::Reshape(s.WithOpName("dy"), sum_gy, sy);
    auto dx_ret = ops::_Retval(s.WithOpName("dx_RetVal"), dx, 0);
    auto dy_ret = ops::_Retval(s.WithOpName("dy_RetVal"), dy, 1);
    GraphDef expected;
    TF_ASSERT_OK(s.ToGraphDef(&expected));

    GraphDef actual;
    g->ToGraphDef(&actual);
    TF_EXPECT_GRAPH_EQ(expected, actual);
  }
}
1739 
// The gradient function body for Mul: gx = dz*y and gy = x*dz, then
// reduced/reshaped per broadcast shapes of x and y.
TEST_F(FunctionLibraryRuntimeTest, Gradient_Mul) {
  Init({});
  auto T = DT_FLOAT;
  std::unique_ptr<Graph> g = GetFuncBody(
      flr0_, "SymbolicGradient", {{"f", FDH::FunctionRef("Mul", {{"T", T}})}});
  {
    Scope s = Scope::NewRootScope();
    auto x = ops::_Arg(s.WithOpName("x"), DT_FLOAT, 0);
    auto y = ops::_Arg(s.WithOpName("y"), DT_FLOAT, 1);
    auto dz = ops::_Arg(s.WithOpName("dz"), DT_FLOAT, 2);
    auto gx = ops::Mul(s.WithOpName("gx"), dz, y);
    auto sx = ops::Shape(s.WithOpName("sx"), x);
    auto gy = ops::Mul(s.WithOpName("gy"), x, dz);
    auto sy = ops::Shape(s.WithOpName("sy"), y);
    auto rx = ops::internal::BroadcastGradientArgs(s.WithOpName("rx"), sx, sy);
    auto sum_gx = ops::Sum(s.WithOpName("sum_gx"), gx, rx.r0);
    auto sum_gy = ops::Sum(s.WithOpName("sum_gy"), gy, rx.r1);
    auto dx = ops::Reshape(s.WithOpName("dx"), sum_gx, sx);
    auto dy = ops::Reshape(s.WithOpName("dy"), sum_gy, sy);
    auto dx_ret = ops::_Retval(s.WithOpName("dx_RetVal"), dx, 0);
    auto dy_ret = ops::_Retval(s.WithOpName("dy_RetVal"), dy, 1);
    GraphDef expected;
    TF_ASSERT_OK(s.ToGraphDef(&expected));

    GraphDef actual;
    g->ToGraphDef(&actual);
    TF_EXPECT_GRAPH_EQ(expected, actual);
  }
}
1769 
TEST_F(FunctionLibraryRuntimeTest,Gradient_AddSum)1770 TEST_F(FunctionLibraryRuntimeTest, Gradient_AddSum) {
1771   // Sum(Add(x, y))
1772   auto T = DT_FLOAT;
1773   auto test = FDH::Define("Test", {"x:float", "y:float"}, {"l:float"}, {},
1774                           {
1775                               {{"z"}, "Add", {"x", "y"}, {{"T", T}}},
1776                               FDH::Const("zero", 0),
1777                               FDH::Const("one", 1),
1778                               {{"r"}, "Rank", {"z"}, {{"T", T}}},
1779                               {{"indices"}, "Range", {"zero", "r", "one"}},
1780                               {{"l"}, "Sum", {"z", "indices"}, {{"T", T}}},
1781                           });
1782 
1783   // TestGrad = Test'(x, y)
1784   auto grad = FDH::Define("TestGrad", {"x:float", "y:float"},
1785                           {"dx:float", "dy:float"}, {},
1786                           {FDH::Const<float>("dz", 1),
1787                            {{"grad0", "grad1"},
1788                             "SymbolicGradient",
1789                             {"x", "y", "dz"},
1790                             {
1791                                 {"f", FDH::FunctionRef("Test")},
1792                                 {"Tin", DataTypeSlice{T, T, T}},
1793                                 {"Tout", DataTypeSlice{T, T}},
1794                             }},
1795                            {{"dx"}, "Identity", {"grad0"}, {{"T", DT_FLOAT}}},
1796                            {{"dy"}, "Identity", {"grad1"}, {{"T", DT_FLOAT}}}});
1797 
1798   Init({test, grad});
1799 
1800   std::unique_ptr<Graph> g = GetFuncBody(flr0_, "TestGrad", {});
1801   ASSERT_TRUE(g != nullptr);
1802   {
1803     Scope s = Scope::NewRootScope();
1804     auto x = ops::_Arg(s.WithOpName("x"), DT_FLOAT, 0);
1805     auto y = ops::_Arg(s.WithOpName("y"), DT_FLOAT, 1);
1806     auto dz = ops::Const(s.WithOpName("dz"), 1.0f);
1807     NameAttrList fn;
1808     fn.set_name("Test");
1809     auto grad0 = ops::SymbolicGradient(s.WithOpName("grad0"),
1810                                        std::initializer_list<Input>{x, y, dz},
1811                                        {DT_FLOAT, DT_FLOAT}, fn);
1812     auto dx = ops::Identity(s.WithOpName("dx"), grad0[0]);
1813     auto dy = ops::Identity(s.WithOpName("dy"), grad0[1]);
1814     auto dx_retval = ops::_Retval(s.WithOpName("dx_RetVal"), dx, 0);
1815     auto dy_retval = ops::_Retval(s.WithOpName("dy_RetVal"), dy, 1);
1816     GraphDef expected;
1817     TF_ASSERT_OK(s.ToGraphDef(&expected));
1818 
1819     GraphDef actual;
1820     g->ToGraphDef(&actual);
1821     TF_EXPECT_GRAPH_EQ(expected, actual);
1822   }
1823 
1824   ExpandInlineFunctions(flr0_, g.get());
1825   {
1826     Scope s = Scope::NewRootScope();
1827     auto x = ops::_Arg(s.WithOpName("x"), DT_FLOAT, 0);
1828     auto y = ops::_Arg(s.WithOpName("y"), DT_FLOAT, 1);
1829     auto dz = ops::Const(s.WithOpName("dz"), 1.0f);
1830     auto grad0_zero = ops::Const(s.WithOpName("grad0/zero"), 0);
1831     auto grad0_one = ops::Const(s.WithOpName("grad0/one"), 1);
1832     auto func0 = ops::Identity(s.WithOpName("Func/grad0/input/_0"), x);
1833     auto func1 = ops::Identity(s.WithOpName("Func/grad0/input/_1"), y);
1834     auto func2 = ops::Identity(s.WithOpName("Func/grad0/input/_2"), dz);
1835     auto grad0_z = ops::Add(s.WithOpName("grad0/z"), func0, func1);
1836     auto grad0_r = ops::Rank(s.WithOpName("grad0/r"), grad0_z);
1837     auto grad0_indices = ops::Range(s.WithOpName("grad0/indices"), grad0_zero,
1838                                     grad0_r, grad0_one);
1839     auto grad0_l = ops::Sum(s.WithOpName("grad0/l"), grad0_z, grad0_indices);
1840 
1841     NameAttrList sum;
1842     sum.set_name("Sum");
1843     (*sum.mutable_attr())["T"].set_type(DT_FLOAT);
1844     (*sum.mutable_attr())["Tidx"].set_type(DT_INT32);
1845     (*sum.mutable_attr())["keep_dims"].set_b(false);
1846     auto grad0_func1 = ops::SymbolicGradient(
1847         s.WithOpName("grad0/Func/_1"),
1848         std::initializer_list<Input>{grad0_z, grad0_indices, func2},
1849         {DT_FLOAT, DT_INT32}, sum);
1850 
1851     auto grad0_func2 =
1852         ops::ZerosLike(s.WithOpName("grad0/Func/_2"), grad0_zero);
1853     auto grad0_func3 = ops::ZerosLike(s.WithOpName("grad0/Func/_3"), grad0_r);
1854     auto grad0_func4 = ops::ZerosLike(s.WithOpName("grad0/Func/_4"), grad0_one);
1855 
1856     NameAttrList add;
1857     add.set_name("Add");
1858     (*add.mutable_attr())["T"].set_type(DT_FLOAT);
1859     auto grad0_func5 = ops::SymbolicGradient(
1860         s.WithOpName("grad0/Func/_5"),
1861         std::initializer_list<Input>{func0, func1, grad0_func1[0]},
1862         {DT_FLOAT, DT_FLOAT}, add);
1863 
1864     auto func3 =
1865         ops::Identity(s.WithOpName("Func/grad0/output/_3"), grad0_func5[0]);
1866     auto func4 =
1867         ops::Identity(s.WithOpName("Func/grad0/output/_4"), grad0_func5[1]);
1868     auto dx = ops::Identity(s.WithOpName("dx"), func3);
1869     auto dy = ops::Identity(s.WithOpName("dy"), func4);
1870     auto dx_retval = ops::_Retval(s.WithOpName("dx_RetVal"), dx, 0);
1871     auto dy_retval = ops::_Retval(s.WithOpName("dy_RetVal"), dy, 1);
1872 
1873     GraphDef expected;
1874     TF_ASSERT_OK(s.ToGraphDef(&expected));
1875 
1876     GraphDef actual;
1877     g->ToGraphDef(&actual);
1878     TF_EXPECT_GRAPH_EQ(expected, actual);
1879   }
1880 
1881   OptimizeGraph(flr0_, &g);
1882   {
1883     Scope s = Scope::NewRootScope();
1884     auto x = ops::_Arg(s.WithOpName("x"), DT_FLOAT, 0);
1885     auto y = ops::_Arg(s.WithOpName("y"), DT_FLOAT, 1);
1886     auto dz = ops::Const(s.WithOpName("dz"), 1.0f);
1887     auto grad0_zero = ops::Const(s.WithOpName("grad0/zero"), 0);
1888     auto grad0_one = ops::Const(s.WithOpName("grad0/one"), 1);
1889     auto grad0_z = ops::Add(s.WithOpName("grad0/z"), x, y);
1890     auto grad0_r = ops::Rank(s.WithOpName("grad0/r"), grad0_z);
1891     auto grad0_indices = ops::Range(s.WithOpName("grad0/indices"), grad0_zero,
1892                                     grad0_r, grad0_one);
1893     auto i_shape =
1894         ops::Shape(s.WithOpName("grad0/Func/_1/i_shape"), grad0_indices);
1895     auto stitch_val = ops::Fill(s.WithOpName("grad0/Func/_1/stitch_val1"),
1896                                 i_shape, grad0_one);
1897     auto x_shape = ops::Shape(s.WithOpName("grad0/Func/_1/x_shape"), grad0_z);
1898     auto y_shape = ops::DynamicStitch(
1899         s.WithOpName("grad0/Func/_1/y_shape"),
1900         std::initializer_list<Input>{grad0_indices, grad0_indices},
1901         std::initializer_list<Input>{x_shape, stitch_val});
1902     auto dy_reshaped =
1903         ops::Reshape(s.WithOpName("grad0/Func/_1/dy_reshaped"), dz, y_shape);
1904     auto tile_scaling =
1905         ops::Div(s.WithOpName("grad0/Func/_1/tile_scaling"), x_shape, y_shape);
1906     auto func1_dx =
1907         ops::Tile(s.WithOpName("grad0/Func/_1/dx"), dy_reshaped, tile_scaling);
1908 
1909     auto sx = ops::Shape(s.WithOpName("grad0/Func/_3/sx"), x);
1910     auto sy = ops::Shape(s.WithOpName("grad0/Func/_3/sy"), y);
1911     auto rx = ops::internal::BroadcastGradientArgs(
1912         s.WithOpName("grad0/Func/_3/rx"), sx, sy);
1913     auto sum_gx =
1914         ops::Sum(s.WithOpName("grad0/Func/_3/sum_gx"), func1_dx, rx.r0);
1915     auto sum_gy =
1916         ops::Sum(s.WithOpName("grad0/Func/_3/sum_gy"), func1_dx, rx.r1);
1917     auto dx = ops::Reshape(s.WithOpName("grad0/Func/_3/dx"), sum_gx, sx);
1918     auto dy = ops::Reshape(s.WithOpName("grad0/Func/_3/dy"), sum_gy, sy);
1919 
1920     auto dx_retval = ops::_Retval(s.WithOpName("dx_RetVal"), dx, 0);
1921     auto dy_retval = ops::_Retval(s.WithOpName("dy_RetVal"), dy, 1);
1922 
1923     GraphDef expected;
1924     TF_ASSERT_OK(s.ToGraphDef(&expected));
1925 
1926     GraphDef actual;
1927     g->ToGraphDef(&actual);
1928     // The optimizer is non-deterministic, so we only check that the number of
1929     // nodes is not greater than expected.
1930     EXPECT_LE(actual.node_size(), expected.node_size());
1931   }
1932 }
1933 
// A function instantiated with target /device:CPU:1 must execute on that
// device even when invoked through a different device's FLR (flr1_, flr2_).
TEST_F(FunctionLibraryRuntimeTest, CrossDevice) {
  Init({test::function::FindDevice()});
  FunctionLibraryRuntime::InstantiateOptions instantiate_opts;
  instantiate_opts.target = "/device:CPU:1";
  FunctionLibraryRuntime::Handle handle;
  TF_CHECK_OK(Instantiate(flr0_, "FindDevice", {}, instantiate_opts, &handle));

  Tensor y;
  FunctionLibraryRuntime::Options opts;
  PrivateIntraProcessRendezvous rendezvous(device_mgr_.get());
  opts.rendezvous = &rendezvous;
  opts.source_device = "/device:CPU:1";
  // Run on flr1_, flr2_ and make sure that the device it ran on was cpu:1.
  TF_CHECK_OK(Run(flr1_, handle, opts, {}, {&y}));
  test::ExpectTensorEqual<tstring>(
      y,
      test::AsTensor<tstring>({"/job:localhost/replica:0/task:0/device:CPU:1"},
                              TensorShape({})));
  opts.remote_execution = true;
  opts.source_device = "/job:localhost/replica:0/task:0/cpu:2";
  TF_CHECK_OK(Run(flr2_, handle, opts, {}, {&y}));
  test::ExpectTensorEqual<tstring>(
      y,
      test::AsTensor<tstring>({"/job:localhost/replica:0/task:0/device:CPU:1"},
                              TensorShape({})));
}
1960 
1961 class AreAllKernelsInlineOp : public OpKernel {
1962  public:
1963   using OpKernel::OpKernel;
1964 
Compute(OpKernelContext * ctx)1965   void Compute(OpKernelContext* ctx) override {
1966     Tensor* output;
1967     OP_REQUIRES_OK(ctx, ctx->allocate_output(0, {}, &output));
1968     output->scalar<bool>()() = ctx->run_all_kernels_inline();
1969   }
1970 };
1971 
1972 REGISTER_OP("AreAllKernelsInline").Output("result : bool").SetIsStateful();
1973 REGISTER_KERNEL_BUILDER(Name("AreAllKernelsInline").Device(DEVICE_CPU),
1974                         AreAllKernelsInlineOp);
1975 
// The `run_all_kernels_inline` option must propagate from the caller's
// Options through a nested function call (G calls F) down to the kernel.
TEST_F(FunctionLibraryRuntimeTest, RunAllKernelsInline) {
  // Create a function "F" that includes an AreAllKernelsInline op, and a
  // function "G" that calls "F".
  auto f = FDH::Create(
      // Name
      "F",
      // Args
      {},
      // Return values
      {"ret: bool"},
      // Attrs
      {},
      // Nodes
      {// y = AreAllKernelsInline()
       {{"y"}, "AreAllKernelsInline", {}, {}}},
      {{"ret", "y:result:0"}});

  auto g = FDH::Create(
      // Name
      "G",
      // Args
      {},
      // Return values
      {"ret: bool"},
      // Attrs
      {},
      // Nodes
      {// y = F()
       {{"y"}, "F", {}, {}}},
      {{"ret", "y:ret:0"}});

  Init({f, g});
  FunctionLibraryRuntime::Handle handle;
  TF_CHECK_OK(Instantiate(flr0_, "G", {}, &handle));

  // Test that the `run_all_kernels_inline` flag is inherited by the kernel
  // running inside the called function.
  for (bool inline_option : {false, true}) {
    FunctionLibraryRuntime::Options opts;
    opts.run_all_kernels_inline = inline_option;
    Tensor result;
    TF_ASSERT_OK(Run(flr0_, handle, opts, {}, {&result}));
    EXPECT_EQ(result.scalar<bool>()(), inline_option);
  }
}
2021 
2022 class UserIntraOpThreadPoolOp : public OpKernel {
2023  public:
2024   using OpKernel::OpKernel;
2025 
2026   class DummyThreadPool : public thread::ThreadPoolInterface {
2027    public:
Schedule(std::function<void ()> fn)2028     void Schedule(std::function<void()> fn) override { fn(); }
NumThreads() const2029     int NumThreads() const override { return 1; }
CurrentThreadId() const2030     int CurrentThreadId() const override { return -1; }
2031   };
2032 
dummy_thread_pool()2033   static DummyThreadPool& dummy_thread_pool() {
2034     static DummyThreadPool& thread_pool = *new DummyThreadPool();
2035     return thread_pool;
2036   }
2037 
Compute(OpKernelContext * ctx)2038   void Compute(OpKernelContext* ctx) override {
2039     Tensor* result;
2040     OP_REQUIRES_OK(ctx, ctx->allocate_output(0, {}, &result));
2041     result->scalar<bool>()() =
2042         ctx->device()
2043             ->tensorflow_cpu_worker_threads()
2044             ->workers->AsEigenThreadPool() == &dummy_thread_pool();
2045   }
2046 };
2047 
2048 REGISTER_OP("UserIntraOpThreadPool").Output("result: bool").SetIsStateful();
2049 REGISTER_KERNEL_BUILDER(Name("UserIntraOpThreadPool").Device(DEVICE_CPU),
2050                         UserIntraOpThreadPoolOp);
2051 
// A user-supplied intra-op thread pool set in Options must be the pool the
// kernel observes on its device while the function runs.
TEST_F(FunctionLibraryRuntimeTest, RunUserIntraOpThreadPool) {
  // Create a function "F" that includes a UserIntraOpThreadPool op, which
  // reports whether the device's intra-op pool is the dummy pool.
  auto f = FDH::Create(
      // Name
      "F",
      // Args
      {},
      // Return values
      {"ret: bool"},
      // Attrs
      {},
      // Nodes
      {// y = UserIntraOpThreadPool()
       {{"y"}, "UserIntraOpThreadPool", {}, {}}},
      {{"ret", "y:result:0"}});

  Init({f});
  FunctionLibraryRuntime::Handle handle;
  TF_CHECK_OK(Instantiate(flr0_, "F", {}, &handle));

  FunctionLibraryRuntime::Options opts;
  opts.user_intra_op_threadpool = &UserIntraOpThreadPoolOp::dummy_thread_pool();

  Tensor result;
  TF_ASSERT_OK(Run(flr0_, handle, opts, {}, {&result}));
  EXPECT_TRUE(result.scalar<bool>()());
}
2080 
2081 namespace {
2082 
DoNothing(Graph * g)2083 bool DoNothing(Graph* g) { return false; }
2084 
Optimize(const std::function<bool (Graph * g)> & pass,const FunctionDef & fdef)2085 GraphDef Optimize(const std::function<bool(Graph* g)>& pass,
2086                   const FunctionDef& fdef) {
2087   InstantiationResult result;
2088   TF_CHECK_OK(InstantiateFunction(fdef, AttrSlice(), GetOpSig, &result));
2089   std::unique_ptr<Graph> g(new Graph(OpRegistry::Global()));
2090   GraphConstructorOptions opts;
2091   opts.allow_internal_ops = true;
2092   opts.expect_device_spec = false;
2093   TF_CHECK_OK(ConvertNodeDefsToGraph(opts, result.nodes, g.get()));
2094   pass(g.get());
2095   std::unique_ptr<Graph> g1(new Graph(OpRegistry::Global()));
2096   CopyGraph(*g, g1.get());
2097   g = nullptr;
2098   GraphDef gdef;
2099   g1->ToGraphDef(&gdef);
2100   return gdef;
2101 }
2102 
2103 }  // end namespace
2104 
// RemoveDeadNodes must keep nodes the output depends on plus stateful nodes
// (RandomUniform); in this graph every node is either live or stateful, so
// the pass changes nothing and matches the DoNothing baseline.
TEST(OptimizationTest, RemoveDeadNodes) {
  auto T = DT_INT32;
  auto func = FDH::Define(
      // Name
      "F",
      // Args
      {"x: int32"},
      // Return values
      {"y: int32"},
      // Attrs
      {},
      // Nodes
      {// a = Square<T>(x)
       {{"a"}, "Square", {"x"}, {{"T", T}}},
       // 1
       FDH::Const("o", 1),
       // A bunch of extra arithmetic that y doesn't depend on
       {{"x1"}, "Add", {"o", "o"}, {{"T", T}}},
       {{"x2"}, "Mul", {"a", "x1"}, {{"T", T}}},
       {{"x3"}, "Mul", {"x1", "x2"}, {{"T", T}}},
       // A stateful node.
       {{"keep_me"}, "RandomUniform", {"o"}, {{"T", T}, {"dtype", DT_FLOAT}}},
       // y = Add<T>(a, o)
       {{"y"}, "Add", {"a", "o"}, {{"T", T}}}});

  GraphDef expected;
  {
    Scope s = Scope::DisabledShapeInferenceScope();
    auto x = ops::_Arg(s.WithOpName("x"), DT_INT32, 0);
    auto o = ops::Const(s.WithOpName("o"), 1);
    auto keep_me = ops::RandomUniform(s.WithOpName("keep_me"), {o}, DT_FLOAT);
    auto x1 = ops::Add(s.WithOpName("x1"), o, o);
    auto a = ops::Square(s.WithOpName("a"), x);
    auto y = ops::Add(s.WithOpName("y"), a, o);
    auto x2 = ops::Mul(s.WithOpName("x2"), a, x1);
    auto x3 = ops::Mul(s.WithOpName("x3"), x1, x2);
    auto ret = ops::_Retval(s.WithOpName("y_RetVal"), y, 0);
    TF_ASSERT_OK(s.ToGraphDef(&expected));
  }
  TF_EXPECT_GRAPH_EQ(expected, Optimize(DoNothing, func));

  // TODO(zhifengc): Comes up another test case.
  TF_EXPECT_GRAPH_EQ(expected, Optimize(::tensorflow::RemoveDeadNodes, func));
}
2149 
// Verifies that RemoveIdentityNodes does NOT remove an Identity node that
// reads a ref-typed input (a VariableV2): the Identity dereferences the
// variable, so eliminating it would change semantics.
TEST(OptimizationTest, RemoveIdentityNodes_Ref) {
  auto T = DT_FLOAT;
  auto func = FDH::Define(
      // Name
      "F",
      // Args
      {},
      // Return values
      {"ret: float"},
      // Attrs
      {},
      // Nodes
      {// variable
       {{"v"}, "VariableV2", {}, {{"dtype", T}, {"shape", TensorShape({})}}},
       // read the variable. Shouldn't be removed.
       {{"v_read"}, "Identity", {"v"}, {{"T", T}}},
       // returns v + v
       {{"ret"}, "Add", {"v_read", "v_read"}, {{"T", T}}}});

  GraphDef expected;
  {
    Scope s = Scope::NewRootScope();
    auto v = ops::Variable(s.WithOpName("v"), PartialTensorShape({}), DT_FLOAT);
    auto v_read = ops::Identity(s.WithOpName("v_read"), v);
    auto ret = ops::Add(s.WithOpName("ret"), v_read, v_read);
    auto ret_retval = ops::_Retval(s.WithOpName("ret_RetVal"), ret, 0);
    TF_ASSERT_OK(s.ToGraphDef(&expected));
  }
  // Both the unoptimized graph and the RemoveIdentityNodes-optimized graph
  // must be identical: the ref-reading Identity survives the pass.
  TF_EXPECT_GRAPH_EQ(expected, Optimize(DoNothing, func));
  TF_EXPECT_GRAPH_EQ(expected,
                     Optimize(::tensorflow::RemoveIdentityNodes, func));
}
2182 
TEST(OptimizationTest,RemoveIdentityNodes)2183 TEST(OptimizationTest, RemoveIdentityNodes) {
2184   auto T = DT_INT32;
2185   auto func = FDH::Define(
2186       // Name
2187       "F",
2188       // Args
2189       {"x: int32"},
2190       // Return values
2191       {"y: int32"},
2192       // Attrs
2193       {},
2194       // Nodes
2195       {// a = Square<T>(x)
2196        {{"a"}, "Square", {"x"}, {{"T", T}}},
2197        // 1
2198        FDH::Const("o", 1),
2199        // A bunch of extra arithmetic that y doesn't depend on
2200        {{"x1"}, "Identity", {"a"}, {{"T", T}}},
2201        {{"x2"}, "Identity", {"x1"}, {{"T", T}}},
2202        {{"x3"}, "Identity", {"x2"}, {{"T", T}}},
2203        // A stateful node.
2204        {{"keep_me"},
2205         "RandomUniform",
2206         {"o"},
2207         {{"T", T}, {"dtype", DT_FLOAT}},
2208         {"x3"}},
2209        // y = Add<T>(a, o)
2210        {{"y"}, "Add", {"a", "o"}, {{"T", T}}}});
2211 
2212   {
2213     Scope s = Scope::DisabledShapeInferenceScope();
2214     auto x = ops::_Arg(s.WithOpName("x"), DT_INT32, 0);
2215     auto o = ops::Const(s.WithOpName("o"), 1);
2216     auto a = ops::Square(s.WithOpName("a"), x);
2217     auto y = ops::Add(s.WithOpName("y"), a, o);
2218     auto x1 = ops::Identity(s.WithOpName("x1"), a);
2219     auto x2 = ops::Identity(s.WithOpName("x2"), x1);
2220     auto x3 = ops::Identity(s.WithOpName("x3"), x2);
2221     auto keep_me = ops::RandomUniform(
2222         s.WithOpName("keep_me").WithControlDependencies(x3), {o}, DT_FLOAT);
2223     auto ret = ops::_Retval(s.WithOpName("y_RetVal"), y, 0);
2224     GraphDef expected;
2225     TF_ASSERT_OK(s.ToGraphDef(&expected));
2226     TF_EXPECT_GRAPH_EQ(expected, Optimize(DoNothing, func));
2227   }
2228 
2229   {
2230     Scope s = Scope::DisabledShapeInferenceScope();
2231     auto x = ops::_Arg(s.WithOpName("x"), DT_INT32, 0);
2232     auto o = ops::Const(s.WithOpName("o"), 1);
2233     auto a = ops::Square(s.WithOpName("a"), x);
2234     auto y = ops::Add(s.WithOpName("y"), a, o);
2235     auto keep_me = ops::RandomUniform(
2236         s.WithOpName("keep_me").WithControlDependencies(a), {o}, DT_FLOAT);
2237     auto ret = ops::_Retval(s.WithOpName("y_RetVal"), y, 0);
2238     GraphDef expected;
2239     TF_ASSERT_OK(s.ToGraphDef(&expected));
2240     TF_EXPECT_GRAPH_EQ(expected,
2241                        Optimize(::tensorflow::RemoveIdentityNodes, func));
2242   }
2243 }
2244 
TEST(OptimizationTest,RemoveListArrayConverter)2245 TEST(OptimizationTest, RemoveListArrayConverter) {
2246   auto func = FDH::Create(
2247       // Name
2248       "Test",
2249       // Args
2250       {"i: float"},
2251       // Return signature
2252       {"o: float"},
2253       // Attrs
2254       {},
2255       // Nodes
2256       {FDH::Const("zero", 0),
2257        {{"s"},
2258         "Split",
2259         {"zero:output:0", "i"},
2260         {{"num_split", 4}, {"T", DT_FLOAT}}},
2261        {{"a"},
2262         "_ArrayToList",
2263         {"s:output"},
2264         {{"N", 4},
2265          {"T", DT_FLOAT},
2266          {"out_types", DataTypeSlice{DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT}}}},
2267        {{"l"}, "Mul", {"a:output:0", "a:output:1"}, {{"T", DT_FLOAT}}},
2268        {{"r"}, "Mul", {"a:output:2", "a:output:3"}, {{"T", DT_FLOAT}}},
2269        {{"x"},
2270         "_ListToArray",
2271         {"l:z", "r:z"},
2272         {{"N", 2},
2273          {"T", DT_FLOAT},
2274          {"Tin", DataTypeSlice{DT_FLOAT, DT_FLOAT}}}},
2275        {{"o"}, "AddN", {"x:output"}, {{"N", 2}, {"T", DT_FLOAT}}}},
2276       // Return values
2277       {{"o", "o:sum"}});
2278 
2279   {
2280     Scope scope = Scope::DisabledShapeInferenceScope();
2281     auto i = ops::_Arg(scope.WithOpName("i"), DT_FLOAT, 0);
2282     auto zero = ops::Const(scope.WithOpName("zero"), 0);
2283     auto s = ops::Split(scope.WithOpName("s"), zero, i, 4);
2284     auto a = ops::_ArrayToList(scope.WithOpName("a"), s.output,
2285                                {DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT});
2286     auto r = ops::Mul(scope.WithOpName("r"), a[2], a[3]);
2287     auto l = ops::Mul(scope.WithOpName("l"), a[0], a[1]);
2288     auto x = ops::_ListToArray(scope.WithOpName("x"),
2289                                std::initializer_list<Input>{l, r}, DT_FLOAT, 2);
2290     auto o = ops::AddN(scope.WithOpName("o"), x.output);
2291     auto o_ret = ops::_Retval(scope.WithOpName("o_RetVal"), o, 0);
2292     GraphDef expected;
2293     TF_ASSERT_OK(scope.ToGraphDef(&expected));
2294     TF_EXPECT_GRAPH_EQ(expected, Optimize(DoNothing, func));
2295   }
2296 
2297   {
2298     Scope scope = Scope::NewRootScope();
2299     auto i = ops::_Arg(scope.WithOpName("i"), DT_FLOAT, 0);
2300     auto zero = ops::Const(scope.WithOpName("zero"), 0);
2301     auto s = ops::Split(scope.WithOpName("s"), zero, i, 4);
2302     auto func_0 = ops::Identity(scope.WithOpName("Func/a/input/_0"), s[0]);
2303     auto func_1 = ops::Identity(scope.WithOpName("Func/a/input/_1"), s[1]);
2304     auto func_2 = ops::Identity(scope.WithOpName("Func/a/input/_2"), s[2]);
2305     auto func_3 = ops::Identity(scope.WithOpName("Func/a/input/_3"), s[3]);
2306     auto r = ops::Mul(scope.WithOpName("r"), func_2, func_3);
2307     auto l = ops::Mul(scope.WithOpName("l"), func_0, func_1);
2308     auto func_4 = ops::Identity(scope.WithOpName("Func/x/input/_4"), l);
2309     auto func_5 = ops::Identity(scope.WithOpName("Func/x/input/_5"), r);
2310     auto o = ops::AddN(scope.WithOpName("o"),
2311                        std::initializer_list<Input>{func_4, func_5});
2312     auto o_ret = ops::_Retval(scope.WithOpName("o_RetVal"), o, 0);
2313     GraphDef expected;
2314     TF_ASSERT_OK(scope.ToGraphDef(&expected));
2315     TF_EXPECT_GRAPH_EQ(expected, Optimize(RemoveListArrayConverter, func));
2316   }
2317 
2318   {
2319     Scope scope = Scope::NewRootScope();
2320     auto i = ops::_Arg(scope.WithOpName("i"), DT_FLOAT, 0);
2321     auto zero = ops::Const(scope.WithOpName("zero"), 0);
2322     auto s = ops::Split(scope.WithOpName("s"), zero, i, 4);
2323     auto r = ops::Mul(scope.WithOpName("r"), s[2], s[3]);
2324     auto l = ops::Mul(scope.WithOpName("l"), s[0], s[1]);
2325     auto o =
2326         ops::AddN(scope.WithOpName("o"), std::initializer_list<Input>{l, r});
2327     auto o_ret = ops::_Retval(scope.WithOpName("o_RetVal"), o, 0);
2328     GraphDef expected;
2329     TF_ASSERT_OK(scope.ToGraphDef(&expected));
2330 
2331     auto remove_listarray_and_identity = [](Graph* g) {
2332       return RemoveListArrayConverter(g) && RemoveIdentityNodes(g);
2333     };
2334     TF_EXPECT_GRAPH_EQ(expected, Optimize(remove_listarray_and_identity, func));
2335   }
2336 }
2337 
TEST(OptimizationTest,RemoveListArrayConverter_WithControlDeps)2338 TEST(OptimizationTest, RemoveListArrayConverter_WithControlDeps) {
2339   auto func = FDH::Create(
2340       // Name
2341       "Test",
2342       // Args
2343       {"i: float"},
2344       // Return values
2345       {"o: float"},
2346       // Attrs
2347       {},
2348       // Nodes
2349       {FDH::Const("dummy", 0),
2350        {{"x"},
2351         "_ListToArray",
2352         {"i", "i"},
2353         {{"N", 2}, {"T", DT_FLOAT}, {"Tin", DataTypeSlice{DT_FLOAT, DT_FLOAT}}},
2354         // Control dep
2355         {"dummy"}},
2356        {{"o"},
2357         "AddN",
2358         {"x:output"},
2359         {{"N", 2}, {"T", DT_FLOAT}},
2360         // Control dep
2361         {"x"}}},
2362       {{"o", "o:sum"}});
2363 
2364   {
2365     Scope s = Scope::DisabledShapeInferenceScope();
2366     auto i = ops::_Arg(s.WithOpName("i"), DT_FLOAT, 0);
2367     auto dummy = ops::Const(s.WithOpName("dummy"), 0);
2368     auto x = ops::_ListToArray(s.WithOpName("x").WithControlDependencies(dummy),
2369                                std::initializer_list<Input>{i, i}, DT_FLOAT, 2);
2370     auto o =
2371         ops::AddN(s.WithOpName("o").WithControlDependencies({x.output[0].op()}),
2372                   x.output);
2373     auto o_ret = ops::_Retval(s.WithOpName("o_RetVal"), o, 0);
2374     GraphDef expected;
2375     TF_ASSERT_OK(s.ToGraphDef(&expected));
2376     TF_EXPECT_GRAPH_EQ(expected, Optimize(DoNothing, func));
2377   }
2378 
2379   GraphDef expected;
2380   {
2381     Scope s = Scope::NewRootScope();
2382     auto i = ops::_Arg(s.WithOpName("i"), DT_FLOAT, 0);
2383     auto dummy = ops::Const(s.WithOpName("dummy"), 0);
2384     auto func_2 = ops::NoOp(s.WithOpName("Func/x/input_control_node/_2")
2385                                 .WithControlDependencies(dummy));
2386     auto func_0 = ops::Identity(
2387         s.WithOpName("Func/x/input/_0").WithControlDependencies({func_2}), i);
2388     auto func_1 = ops::Identity(
2389         s.WithOpName("Func/x/input/_1").WithControlDependencies({func_2}), i);
2390     auto func_3 = ops::NoOp(
2391         s.WithOpName("Func/x/output_control_node/_3")
2392             .WithControlDependencies({func_0.output.op(), func_1.output.op()}));
2393     auto o = ops::AddN(s.WithOpName("o").WithControlDependencies({func_3}),
2394                        std::initializer_list<Input>{func_0, func_1});
2395     auto o_ret = ops::_Retval(s.WithOpName("o_RetVal"), o, 0);
2396     TF_ASSERT_OK(s.ToGraphDef(&expected));
2397   }
2398   TF_EXPECT_GRAPH_EQ(expected, Optimize(RemoveListArrayConverter, func));
2399 
2400   auto remove_listarray_and_identity = [](Graph* g) {
2401     return RemoveListArrayConverter(g) && RemoveIdentityNodes(g);
2402   };
2403   // NOTE: We are not removing Identity nodes with any control
2404   // dependencies yet.
2405   TF_EXPECT_GRAPH_EQ(expected, Optimize(remove_listarray_and_identity, func));
2406 }
2407 
2408 }  // namespace
2409 }  // namespace tensorflow
2410