1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/common_runtime/function.h"
17
18 #include <atomic>
19 #include <functional>
20 #include <utility>
21
22 #include "absl/memory/memory.h"
23 #include "absl/strings/numbers.h"
24 #include "absl/strings/str_split.h"
25 #include "tensorflow/cc/ops/array_ops_internal.h"
26 #include "tensorflow/cc/ops/function_ops.h"
27 #include "tensorflow/cc/ops/functional_ops.h"
28 #include "tensorflow/cc/ops/sendrecv_ops.h"
29 #include "tensorflow/cc/ops/standard_ops.h"
30 #include "tensorflow/core/common_runtime/constant_folding.h"
31 #include "tensorflow/core/common_runtime/device.h"
32 #include "tensorflow/core/common_runtime/device_factory.h"
33 #include "tensorflow/core/common_runtime/executor.h"
34 #include "tensorflow/core/common_runtime/executor_factory.h"
35 #include "tensorflow/core/common_runtime/function_testlib.h"
36 #include "tensorflow/core/common_runtime/graph_constructor.h"
37 #include "tensorflow/core/common_runtime/rendezvous_mgr.h"
38 #include "tensorflow/core/common_runtime/step_stats_collector.h"
39 #include "tensorflow/core/framework/function.h"
40 #include "tensorflow/core/framework/function_testlib.h"
41 #include "tensorflow/core/framework/metrics.h"
42 #include "tensorflow/core/framework/op.h"
43 #include "tensorflow/core/framework/op_kernel.h"
44 #include "tensorflow/core/framework/op_requires.h"
45 #include "tensorflow/core/framework/tensor_testutil.h"
46 #include "tensorflow/core/framework/versions.pb.h"
47 #include "tensorflow/core/lib/core/notification.h"
48 #include "tensorflow/core/lib/core/status.h"
49 #include "tensorflow/core/lib/core/status_test_util.h"
50 #include "tensorflow/core/lib/core/threadpool.h"
51 #include "tensorflow/core/lib/strings/str_util.h"
52 #include "tensorflow/core/platform/errors.h"
53 #include "tensorflow/core/platform/test.h"
54 #include "tensorflow/core/platform/threadpool_interface.h"
55 #include "tensorflow/core/protobuf/error_codes.pb.h"
56 #include "tensorflow/core/public/session_options.h"
57 #include "tensorflow/core/public/version.h"
58 #include "tensorflow/core/util/equal_graph_def.h"
59
60 namespace tensorflow {
61 namespace {
62
63 using FDH = ::tensorflow::FunctionDefHelper;
64
65 using OutputControlSrc = InlineFunctionBodyOptions::OutputControlSource;
66
GetOpSig(const string & op,const OpDef ** sig)67 Status GetOpSig(const string& op, const OpDef** sig) {
68 return OpRegistry::Global()->LookUpOpDef(op, sig);
69 }
70
HasError(const Status & s,const error::Code code,StringPiece substr)71 void HasError(const Status& s, const error::Code code, StringPiece substr) {
72 EXPECT_EQ(s.code(), code) << s;
73 EXPECT_TRUE(absl::StrContains(s.error_message(), substr))
74 << s << ", expected substring " << substr;
75 }
76
// Test fixture that instantiates a FunctionDef into a Graph and runs it
// directly on a local Executor backed by a single CPU device.
class FunctionTest : public ::testing::Test {
 protected:
  FunctionTest()
      : device_(DeviceFactory::NewDevice("CPU", {},
                                         "/job:localhost/replica:0/task:0")) {}

  // Instantiates `fdef` with `attrs`, converts the resulting NodeDefs into a
  // Graph, and builds a local Executor for it in `exec_`. CHECK-fails on any
  // error, so a failure aborts the test immediately.
  void Create(const FunctionDef& fdef, test::function::Attrs attrs) {
    exec_ = nullptr;
    InstantiationResult result;
    TF_CHECK_OK(InstantiateFunction(fdef, attrs, GetOpSig, &result));

    arg_types_ = result.arg_types;
    ret_types_ = result.ret_types;

    std::unique_ptr<Graph> g(new Graph(OpRegistry::Global()));
    GraphConstructorOptions opts;
    // Function bodies contain internal ops such as _Arg/_Retval.
    opts.allow_internal_ops = true;
    opts.expect_device_spec = false;
    TF_CHECK_OK(ConvertNodeDefsToGraph(opts, result.nodes, g.get()));

    const int version = g->versions().producer();
    LocalExecutorParams params;
    params.device = device_.get();
    // Kernels are created fresh (non-cached) against the test CPU device.
    params.create_kernel =
        [this, version](const std::shared_ptr<const NodeProperties>& props,
                        OpKernel** kernel) {
          return CreateNonCachedKernel(device_.get(), nullptr, props, version,
                                       kernel);
        };
    params.delete_kernel = [](OpKernel* kernel) {
      DeleteNonCachedKernel(kernel);
    };
    Executor* exec;
    TF_CHECK_OK(NewLocalExecutor(params, *g, &exec));
    exec_.reset(exec);
  }

  // Runs the executor built by Create() with `args` and copies the function's
  // return values into `rets`. The number of returned tensors must equal
  // rets.size().
  void Run(const std::vector<Tensor>& args, std::vector<Tensor*> rets) {
    FunctionCallFrame frame(arg_types_, ret_types_);
    TF_CHECK_OK(frame.SetArgs(args));
    Executor::Args exec_args;
    exec_args.call_frame = &frame;
    exec_args.runner = test::function::FunctionTestSchedClosure;
    TF_CHECK_OK(exec_->Run(exec_args));
    std::vector<Tensor> computed;
    TF_CHECK_OK(frame.GetRetvals(&computed));
    CHECK_EQ(computed.size(), rets.size());
    for (int i = 0; i < rets.size(); ++i) {
      *(rets[i]) = computed[i];
    }
  }

  std::unique_ptr<Device> device_;  // CPU device the executor runs on.
  std::unique_ptr<Executor> exec_;  // Executor for the last Create()d function.
  DataTypeVector arg_types_;        // Argument types from instantiation.
  DataTypeVector ret_types_;        // Return types from instantiation.
};
134
// XTimesTwo(x) doubles every element of its float input.
TEST_F(FunctionTest, XTimesTwo) {
  Create(test::function::XTimesTwo(), {{"T", DT_FLOAT}});
  const auto input = test::AsTensor<float>({1, 2, 3, 4});
  Tensor output;
  Run({input}, {&output});
  test::ExpectTensorEqual<float>(output, test::AsTensor<float>({2, 4, 6, 8}));
}
142
// WXPlusB computes w * x + b for 2x2 float matrices and a length-2 bias.
TEST_F(FunctionTest, WXPlusB) {
  Create(test::function::WXPlusB(), {{"T", DT_FLOAT}});
  auto weights = test::AsTensor<float>({1., 2., 3., 4.}, {2, 2});
  auto inputs = test::AsTensor<float>({1., 3., 2., 4.}, {2, 2});
  auto bias = test::AsTensor<float>({0.5, 2.5}, {2});
  Tensor result;
  Run({weights, inputs, bias}, {&result});
  test::ExpectTensorEqual<float>(
      result, test::AsTensor<float>({5.5, 13.5, 11.5, 27.5}, {2, 2}));
}
153
// Test fixture that exercises FunctionLibraryRuntime end-to-end via a
// ProcessFunctionLibraryRuntime spanning three CPU devices (flr0_..flr2_).
class FunctionLibraryRuntimeTest : public ::testing::Test {
 protected:
  // Builds three CPU devices, installs `flib` into a fresh function library,
  // and constructs the process-level runtime plus per-device FLR pointers.
  void Init(const std::vector<FunctionDef>& flib) {
    SessionOptions options;
    auto* device_count = options.config.mutable_device_count();
    device_count->insert({"CPU", 3});
    std::vector<std::unique_ptr<Device>> devices;
    TF_CHECK_OK(DeviceFactory::AddDevices(
        options, "/job:localhost/replica:0/task:0", &devices));

    FunctionDefLibrary proto;
    for (const auto& fdef : flib) *(proto.add_function()) = fdef;
    lib_def_.reset(new FunctionLibraryDefinition(OpRegistry::Global(), proto));
    OptimizerOptions opts;
    device_mgr_ = std::make_unique<StaticDeviceMgr>(std::move(devices));
    pflr_.reset(new ProcessFunctionLibraryRuntime(
        device_mgr_.get(), Env::Default(), &options.config,
        TF_GRAPH_DEF_VERSION, lib_def_.get(), opts, /*thread_pool=*/nullptr,
        /*parent=*/nullptr, /*session_metadata=*/nullptr,
        Rendezvous::Factory{
            [](const int64_t, const DeviceMgr* device_mgr, Rendezvous** r) {
              *r = new IntraProcessRendezvous(device_mgr);
              return OkStatus();
            }}));
    flr0_ = pflr_->GetFLR("/job:localhost/replica:0/task:0/cpu:0");
    flr1_ = pflr_->GetFLR("/job:localhost/replica:0/task:0/cpu:1");
    flr2_ = pflr_->GetFLR("/job:localhost/replica:0/task:0/cpu:2");
    fdef_lib_ = lib_def_->ToProto();
  }

  // Runs an already-instantiated function `handle` on `flr` with `args` and
  // copies the outputs into `rets`. Returns the run status; the output count
  // must match rets.size().
  Status Run(FunctionLibraryRuntime* flr, FunctionLibraryRuntime::Handle handle,
             FunctionLibraryRuntime::Options opts,
             const std::vector<Tensor>& args, std::vector<Tensor*> rets) {
    std::function<void(std::function<void()>)> runner =
        [](std::function<void()> fn) {
          test::function::FunctionTestSchedClosure(fn);
        };
    opts.runner = &runner;
    Notification done;
    std::vector<Tensor> out;
    Status status;
    // FLR::Run is asynchronous; block until the done-callback fires.
    flr->Run(opts, handle, args, &out, [&status, &done](const Status& s) {
      status = s;
      done.Notify();
    });
    done.WaitForNotification();
    if (!status.ok()) {
      return status;
    }
    CHECK_EQ(rets.size(), out.size());
    for (size_t i = 0; i < rets.size(); ++i) {
      *rets[i] = out[i];
    }
    return OkStatus();
  }

  // Convenience overload: instantiate with default InstantiateOptions.
  Status Instantiate(FunctionLibraryRuntime* flr, const string& name,
                     test::function::Attrs attrs,
                     FunctionLibraryRuntime::Handle* handle) {
    return flr->Instantiate(name, attrs, handle);
  }

  // Instantiates `name` with explicit InstantiateOptions.
  Status Instantiate(FunctionLibraryRuntime* flr, const string& name,
                     test::function::Attrs attrs,
                     const FunctionLibraryRuntime::InstantiateOptions& options,
                     FunctionLibraryRuntime::Handle* handle) {
    return flr->Instantiate(name, attrs, options, handle);
  }

  // Convenience overload of InstantiateAndRun with default options.
  Status InstantiateAndRun(FunctionLibraryRuntime* flr, const string& name,
                           test::function::Attrs attrs,
                           const std::vector<Tensor>& args,
                           std::vector<Tensor*> rets) {
    return InstantiateAndRun(flr, name, attrs,
                             FunctionLibraryRuntime::InstantiateOptions(), args,
                             std::move(rets));
  }

  // Instantiates and runs `name`, then releases the handle and verifies that
  // a second Run on the released handle fails with NOT_FOUND.
  Status InstantiateAndRun(
      FunctionLibraryRuntime* flr, const string& name,
      test::function::Attrs attrs,
      const FunctionLibraryRuntime::InstantiateOptions& options,
      const std::vector<Tensor>& args, std::vector<Tensor*> rets) {
    FunctionLibraryRuntime::Handle handle;
    Status status = flr->Instantiate(name, attrs, options, &handle);
    if (!status.ok()) {
      return status;
    }
    FunctionLibraryRuntime::Options opts;
    status = Run(flr, handle, opts, args, rets);
    if (!status.ok()) return status;

    // Release the handle and try running again. It should not succeed.
    status = flr->ReleaseHandle(handle);
    if (!status.ok()) return status;

    Status status2 = Run(flr, handle, opts, args, std::move(rets));
    EXPECT_TRUE(errors::IsNotFound(status2))
        << "Actual status: " << status2.ToString();
    EXPECT_TRUE(absl::StrContains(status2.error_message(), "Handle"));
    EXPECT_TRUE(absl::StrContains(status2.error_message(), "not found"));

    return status;
  }

  // Runs `handle` via the CallFrameInterface overload of FLR::Run. Arguments
  // and return values are carried by `frame`.
  Status Run(FunctionLibraryRuntime* flr, FunctionLibraryRuntime::Handle handle,
             FunctionLibraryRuntime::Options opts, CallFrameInterface* frame) {
    std::function<void(std::function<void()>)> runner =
        [](std::function<void()> fn) {
          test::function::FunctionTestSchedClosure(fn);
        };
    opts.runner = &runner;
    Notification done;
    Status status;
    flr->Run(opts, handle, frame, [&status, &done](const Status& s) {
      status = s;
      done.Notify();
    });
    done.WaitForNotification();
    if (!status.ok()) {
      return status;
    }

    return OkStatus();
  }

  // Like InstantiateAndRun, but exercises the CallFrameInterface-based Run
  // path; also verifies that a released handle can no longer be run.
  Status InstantiateAndRunViaCallFrameInterface(FunctionLibraryRuntime* flr,
                                                const string& name,
                                                test::function::Attrs attrs,
                                                const std::vector<Tensor>& args,
                                                std::vector<Tensor*> rets) {
    FunctionLibraryRuntime::Handle handle;
    Status status = flr->Instantiate(name, attrs, &handle);
    if (!status.ok()) {
      return status;
    }
    const FunctionBody* fbody = flr->GetFunctionBody(handle);
    FunctionCallFrame frame(fbody->arg_types, fbody->ret_types);
    TF_RETURN_IF_ERROR(frame.SetArgs(args));

    FunctionLibraryRuntime::Options opts;
    status = Run(flr, handle, opts, &frame);
    if (!status.ok()) return status;

    std::vector<Tensor> retvals;
    TF_RETURN_IF_ERROR(frame.GetRetvals(&retvals));
    CHECK_EQ(rets.size(), retvals.size());
    for (size_t i = 0; i < rets.size(); ++i) {
      *rets[i] = retvals[i];
    }

    // Release the handle and try running again. It should not succeed.
    status = flr->ReleaseHandle(handle);
    if (!status.ok()) return status;

    Status status2 = Run(flr, handle, opts, args, std::move(rets));
    EXPECT_TRUE(errors::IsNotFound(status2));
    EXPECT_TRUE(absl::StrContains(status2.error_message(), "Handle"));
    EXPECT_TRUE(absl::StrContains(status2.error_message(), "not found"));

    return status;
  }

  // Returns a copy of the instantiated body graph of `name`, or nullptr on
  // instantiation failure (the error is logged).
  std::unique_ptr<Graph> GetFuncBody(FunctionLibraryRuntime* flr,
                                     const string& name,
                                     test::function::Attrs attrs) {
    FunctionLibraryRuntime::Handle handle;
    Status status = flr->Instantiate(name, attrs, &handle);
    if (!status.ok()) {
      LOG(ERROR) << status;
      return nullptr;
    }
    const FunctionBody* fbody = flr->GetFunctionBody(handle);
    CHECK_NOTNULL(fbody);
    std::unique_ptr<Graph> ret(new Graph(lib_def_.get()));
    CopyGraph(*fbody->graph, ret.get());
    return ret;
  }

  // Returns a copy of the symbolic-gradient graph of `func`, or nullptr on
  // instantiation failure (the error is logged).
  std::unique_ptr<Graph> GetGradBody(FunctionLibraryRuntime* flr,
                                     const string& func,
                                     test::function::Attrs attrs) {
    FunctionLibraryRuntime::Handle handle;
    Status status = flr->Instantiate(func, attrs, &handle);
    if (!status.ok()) {
      LOG(ERROR) << status;
      return nullptr;
    }
    const FunctionBody* fbody = flr->GetFunctionBody(handle);
    CHECK_NOTNULL(fbody);
    std::unique_ptr<FunctionBody> gbody(SymbolicGradient(*fbody));
    CHECK_NOTNULL(gbody);
    std::unique_ptr<Graph> ret(new Graph(lib_def_.get()));
    CopyGraph(*gbody->graph, ret.get());
    return ret;
  }

  FunctionLibraryRuntime* flr0_;  // FLR for cpu:0 (owned by pflr_).
  FunctionLibraryRuntime* flr1_;  // FLR for cpu:1 (owned by pflr_).
  FunctionLibraryRuntime* flr2_;  // FLR for cpu:2 (owned by pflr_).
  std::unique_ptr<DeviceMgr> device_mgr_;
  std::unique_ptr<FunctionLibraryDefinition> lib_def_;
  std::unique_ptr<ProcessFunctionLibraryRuntime> pflr_;
  FunctionDefLibrary fdef_lib_;  // Proto snapshot of lib_def_ after Init().
};
359
// Stateful-ness is determined from the op registry: variable ops are
// stateful, pure math ops are not.
TEST_F(FunctionLibraryRuntimeTest, IsStateful) {
  Init({});
  EXPECT_TRUE(flr0_->IsStateful("Variable"));
  EXPECT_TRUE(flr0_->IsStateful("VariableV2"));
  // Fix: the registered op is spelled "MatMul". The previous spelling
  // "Matmul" named a nonexistent op, so the EXPECT_FALSE passed vacuously
  // (unknown ops report non-stateful) rather than testing a real math op.
  EXPECT_FALSE(flr0_->IsStateful("MatMul"));
}
366
// Runs XTimesTwo through both the Tensor-vector Run path and the
// CallFrameInterface path; both must double the input.
TEST_F(FunctionLibraryRuntimeTest, XTimesTwo) {
  Init({test::function::XTimesTwo()});
  auto input = test::AsTensor<float>({1, 2, 3, 4});
  const auto expected = test::AsTensor<float>({2, 4, 6, 8});
  Tensor output;
  TF_CHECK_OK(InstantiateAndRun(flr0_, "XTimesTwo", {{"T", DT_FLOAT}}, {input},
                                {&output}));
  test::ExpectTensorEqual<float>(output, expected);
  TF_CHECK_OK(InstantiateAndRunViaCallFrameInterface(
      flr0_, "XTimesTwo", {{"T", DT_FLOAT}}, {input}, {&output}));
  test::ExpectTensorEqual<float>(output, expected);
}
378
// Verifies that stack traces registered with AddFunctionDef are propagated
// onto the corresponding nodes of the instantiated function body.
TEST_F(FunctionLibraryRuntimeTest, InstantiationStackTraceCopying) {
  // Minimal AbstractStackTrace whose string form is a fixed marker.
  class DummyStackTrace : public AbstractStackTrace {
    absl::Span<StackFrame const> ToFrames() const override { return {}; }

    std::string ToString(const TracePrintingOptions& opts) const override {
      return "DummyStackTrace";
    }

    StackFrame LastUserFrame() const override { return StackFrame{}; }
  };

  FunctionDef func = test::function::XTimesTwo();
  Init({});

  // Attach the dummy trace to the node named "two" in XTimesTwo's body.
  StackTracesMap stack_traces;
  stack_traces["two"] = std::make_shared<DummyStackTrace>();

  TF_CHECK_OK(lib_def_->AddFunctionDef(func, stack_traces));

  FunctionLibraryRuntime::Handle handle;
  TF_CHECK_OK(Instantiate(flr0_, "XTimesTwo", {{"T", DT_FLOAT}}, {}, &handle));

  // The instantiated graph's "two" node should carry the registered trace.
  const FunctionBody* func_body = flr0_->GetFunctionBody(handle);
  for (const Node* node : func_body->graph->nodes()) {
    if (node->name() == "two") {
      EXPECT_EQ(node->GetStackTrace()->ToString({}), "DummyStackTrace");
    }
  }
  TF_CHECK_OK(flr0_->ReleaseHandle(handle));
}
409
// XTimesTwo must also work when instantiated as a multi-device function.
TEST_F(FunctionLibraryRuntimeTest, XTimesTwo_MultiDeviceBacked) {
  Init({test::function::XTimesTwo()});
  auto input = test::AsTensor<float>({1, 2, 3, 4});
  Tensor output;

  FunctionLibraryRuntime::InstantiateOptions inst_opts;
  inst_opts.is_multi_device_function = true;

  TF_CHECK_OK(InstantiateAndRun(flr0_, "XTimesTwo", {{"T", DT_FLOAT}},
                                inst_opts, {input}, {&output}));
  test::ExpectTensorEqual<float>(output, test::AsTensor<float>({2, 4, 6, 8}));
}
422
// A minimal CallFrameInterface with exactly one argument and one return
// value. The argument may only be *consumed* (moved out), never read in
// place, which lets tests verify that the executor forwards the input
// buffer instead of copying it.
class ConsumeArgumentCallFrame : public CallFrameInterface {
 public:
  // Both pointers are borrowed and must outlive the frame.
  ConsumeArgumentCallFrame(Tensor* arg, Tensor* retval)
      : arg_(arg), retval_(retval) {}

  size_t num_args() const override { return 1; }
  size_t num_retvals() const override { return 1; }

  // Non-consuming reads are disallowed: CanConsumeArg() is true for index 0,
  // so a correct executor must call ConsumeArg() instead.
  Status GetArg(int index, const Tensor** val) override {
    LOG(FATAL) << "Should not be called.";
  }

  bool CanConsumeArg(int index) const override { return index == 0; }

  // Moves the argument tensor out, leaving *arg_ uninitialized.
  void ConsumeArg(int index, Tensor* val) override { *val = std::move(*arg_); }

  Status SetRetval(int index, const Tensor& val) override {
    CHECK_EQ(index, 0);
    *retval_ = val;
    return OkStatus();
  }

 private:
  Tensor* const arg_;     // Consumed by ConsumeArg().
  Tensor* const retval_;  // Written by SetRetval().
};
449
// With the default executor, a consumable argument's buffer should be
// forwarded to the output rather than copied, and only the "default"
// executor metric should tick.
TEST_F(FunctionLibraryRuntimeTest, XTimesTwo_ConsumeArgument_DefaultExecutor) {
  Init({test::function::XTimesTwo()});
  auto default_executor = metrics::TestDelta("flr_executor", "default");
  auto single_threaded = metrics::TestDelta("flr_executor", "single_threaded");
  FunctionLibraryRuntime::Handle handle;
  TF_CHECK_OK(flr0_->Instantiate(
      "XTimesTwo", test::function::Attrs({{"T", DT_FLOAT}}), &handle));

  auto x = test::AsTensor<float>({1, 2, 3, 4});
  // Remember where x's buffer lives so we can check it was forwarded.
  float* x_base_ptr = &x.flat<float>()(0);
  Tensor y;
  ConsumeArgumentCallFrame frame(&x, &y);

  FunctionLibraryRuntime::Options opts;
  TF_CHECK_OK(Run(flr0_, handle, opts, &frame));

  test::ExpectTensorEqual<float>(y, test::AsTensor<float>({2, 4, 6, 8}));

  // Expect that the buffer for `x` has been forwarded to and used as the buffer
  // for `y`.
  float* y_base_ptr = &y.flat<float>()(0);
  EXPECT_EQ(x_base_ptr, y_base_ptr);
  EXPECT_FALSE(x.IsInitialized());

  TF_CHECK_OK(flr0_->ReleaseHandle(handle));
  // Only the default executor should have been exercised.
  EXPECT_GT(default_executor.Get(), 0);
  EXPECT_EQ(single_threaded.Get(), 0);
}
478
// Same buffer-forwarding check as above, but with the single-threaded
// executor selected via InstantiateOptions::executor_type; only the
// "single_threaded" executor metric should tick.
TEST_F(FunctionLibraryRuntimeTest,
       XTimesTwo_ConsumeArgument_SingleThreadedExecutor) {
  Init({test::function::XTimesTwo()});
  auto default_executor = metrics::TestDelta("flr_executor", "default");
  auto single_threaded = metrics::TestDelta("flr_executor", "single_threaded");
  FunctionLibraryRuntime::InstantiateOptions instantiate_opts;
  instantiate_opts.executor_type = "SINGLE_THREADED_EXECUTOR";
  FunctionLibraryRuntime::Handle handle;
  TF_CHECK_OK(flr0_->Instantiate("XTimesTwo",
                                 test::function::Attrs({{"T", DT_FLOAT}}),
                                 instantiate_opts, &handle));

  auto x = test::AsTensor<float>({1, 2, 3, 4});
  // Remember where x's buffer lives so we can check it was forwarded.
  float* x_base_ptr = &x.flat<float>()(0);
  Tensor y;
  ConsumeArgumentCallFrame frame(&x, &y);

  FunctionLibraryRuntime::Options opts;
  TF_CHECK_OK(Run(flr0_, handle, opts, &frame));

  test::ExpectTensorEqual<float>(y, test::AsTensor<float>({2, 4, 6, 8}));

  // Expect that the buffer for `x` has been forwarded to and used as the buffer
  // for `y`.
  float* y_base_ptr = &y.flat<float>()(0);
  EXPECT_EQ(x_base_ptr, y_base_ptr);
  EXPECT_FALSE(x.IsInitialized());

  TF_CHECK_OK(flr0_->ReleaseHandle(handle));
  // Only the single-threaded executor should have been exercised.
  EXPECT_EQ(default_executor.Get(), 0);
  EXPECT_GT(single_threaded.Get(), 0);
}
511
// Composed functions: XTimesFour and XTimes16 are built out of XTimesTwo;
// each must multiply the input by its factor.
TEST_F(FunctionLibraryRuntimeTest, XTimesN) {
  Init({test::function::XTimesTwo(), test::function::XTimesFour(),
        test::function::XTimes16()});
  auto input = test::AsTensor<float>({1, 2, 3, 4});
  Tensor output;
  struct Case {
    const char* func;
    std::vector<float> expected;
  };
  const std::vector<Case> cases = {{"XTimesTwo", {2, 4, 6, 8}},
                                   {"XTimesFour", {4, 8, 12, 16}},
                                   {"XTimes16", {16, 32, 48, 64}}};
  for (const auto& c : cases) {
    TF_CHECK_OK(InstantiateAndRun(flr0_, c.func, {{"T", DT_FLOAT}}, {input},
                                  {&output}));
    test::ExpectTensorEqual<float>(output, test::AsTensor<float>(c.expected));
  }
}
527
// Functions supplied via InstantiateOptions::lib_def should be usable
// without being (or becoming) installed in the runtime's base library.
TEST_F(FunctionLibraryRuntimeTest, XTimesNInLibDef) {
  Init({});
  FunctionDefLibrary proto;
  *proto.add_function() = test::function::XTimesTwo();
  *proto.add_function() = test::function::XTimesFour();
  *proto.add_function() = test::function::XTimes16();
  std::unique_ptr<FunctionLibraryDefinition> lib_def(
      new FunctionLibraryDefinition(OpRegistry::Global(), proto));

  // Route instantiation through the overlay library instead of the base one.
  FunctionLibraryRuntime::InstantiateOptions options;
  options.lib_def = lib_def.get();

  auto x = test::AsTensor<float>({1, 2, 3, 4});
  Tensor y;

  // Ensure that the function is not installed in the base library.
  HasError(InstantiateAndRun(flr0_, "XTimesTwo", {{"T", DT_FLOAT}},
                             {} /* options */, {x}, {&y}),
           error::NOT_FOUND, "Function XTimesTwo is not defined.");

  TF_CHECK_OK(InstantiateAndRun(flr0_, "XTimesTwo", {{"T", DT_FLOAT}}, options,
                                {x}, {&y}));
  test::ExpectTensorEqual<float>(y, test::AsTensor<float>({2, 4, 6, 8}));
  TF_CHECK_OK(InstantiateAndRun(flr0_, "XTimesFour", {{"T", DT_FLOAT}}, options,
                                {x}, {&y}));
  test::ExpectTensorEqual<float>(y, test::AsTensor<float>({4, 8, 12, 16}));
  TF_CHECK_OK(InstantiateAndRun(flr0_, "XTimes16", {{"T", DT_FLOAT}}, options,
                                {x}, {&y}));
  test::ExpectTensorEqual<float>(y, test::AsTensor<float>({16, 32, 48, 64}));

  // Ensure that the function is still not installed in the base library.
  HasError(InstantiateAndRun(flr0_, "XTimesTwo", {{"T", DT_FLOAT}},
                             {} /* options */, {x}, {&y}),
           error::NOT_FOUND, "Function XTimesTwo is not defined.");
}
563
// Verifies that a function called indirectly through PartitionedCall (whose
// instantiation is deferred until first Compute) resolves its callee from the
// overlay library provided in InstantiateOptions::lib_def, and that swapping
// the overlay swaps the callee's body.
TEST_F(FunctionLibraryRuntimeTest, XTimesNInLibDefAndDelayedInstantiation) {
  using FDH = ::tensorflow::FunctionDefHelper;

  Init({});

  // Call XTimesFour via PartitionedCall which delays functions instantiation
  // to the first call to Compute/ComputeAsync.
  FunctionDef my_xt4 = FunctionDefHelper::Create(
      "MyXTimesFour", {"x:float"}, {"z:float"}, {},
      {{{"x_times_four"},
        "PartitionedCall",
        {"x"},
        {{"Tin", DataTypeSlice({DT_FLOAT})},
         {"Tout", DataTypeSlice({DT_FLOAT})},
         {"f", FDH::FunctionRef("XTimesFour", {{"T", DT_FLOAT}})}}}},
      /* Mapping between function returns and function node outputs. */
      {{"z", "x_times_four:output:0"}});

  FunctionDefLibrary lib;
  *lib.add_function() = test::function::XTimesTwo();
  *lib.add_function() = test::function::XTimesFour();
  *lib.add_function() = my_xt4;
  std::unique_ptr<FunctionLibraryDefinition> lib_def(
      new FunctionLibraryDefinition(OpRegistry::Global(), lib));

  FunctionLibraryRuntime::InstantiateOptions options;
  options.lib_def = lib_def.get();

  auto x = test::AsTensor<float>({1, 2, 3, 4});
  Tensor y;

  // When we instantiate with `options` we should get x*4.
  TF_CHECK_OK(InstantiateAndRun(flr0_, "MyXTimesFour", {}, options, {x}, {&y}));
  test::ExpectTensorEqual<float>(y, test::AsTensor<float>({4, 8, 12, 16}));

  // Create options that override XTimesFour body with XTimesTwo body.
  FunctionDef xt4_override = test::function::XTimesTwo();
  xt4_override.mutable_signature()->set_name("XTimesFour");
  FunctionDefLibrary lib_override;
  *lib_override.add_function() = xt4_override;
  *lib_override.add_function() = my_xt4;
  std::unique_ptr<FunctionLibraryDefinition> lib_def_override(
      new FunctionLibraryDefinition(OpRegistry::Global(), lib_override));
  options.lib_def = lib_def_override.get();

  // When we instantiate with `options` we should get x*2.
  TF_CHECK_OK(InstantiateAndRun(flr0_, "MyXTimesFour", {}, options, {x}, {&y}));
  test::ExpectTensorEqual<float>(y, test::AsTensor<float>({2, 4, 6, 8}));
}
613
// Verifies how InstantiateOptions::state_handle isolates per-instantiation
// kernel state: the same (or no) handle shares the stateful RNG kernel and
// continues its sequence, while distinct handles get fresh state and replay
// the sequence from the start.
TEST_F(FunctionLibraryRuntimeTest, StateHandle) {
  auto T = DT_INT32;

  // The expected sequence of outputs from this function is [6, 4, 0, 1, ...].
  FunctionDef stateful_func = FDH::Define(
      // Name
      "RandomUniformWrapper",
      // Args
      {},
      // Return values
      {"y: int32"},
      // Attrs
      {},
      // Nodes
      {FDH::Const<int32>("shape", gtl::ArraySlice<int32>({1})),
       FDH::Const<int32>("minval", 0),
       FDH::Const<int32>("maxval", 10),
       // A stateful node.
       {{"y"},
        "RandomUniformInt",
        {"shape", "minval", "maxval"},
        // Fixed seeds make the output sequence deterministic per kernel state.
        {{"seed", 37}, {"seed2", 48}, {"Tout", T}, {"T", T}}}});
  Init({stateful_func});

  FunctionLibraryRuntime::Handle handle;
  TF_CHECK_OK(Instantiate(flr0_, "RandomUniformWrapper", {}, &handle));

  FunctionLibraryRuntime::Options opts;
  Tensor y;
  {
    // Simple case: instantiating with no state_handle.
    for (int32_t expected : {6, 4}) {
      TF_CHECK_OK(Run(flr0_, handle, opts, {}, {&y}));
      test::ExpectTensorEqual<int>(y, test::AsTensor<int32>({expected}));
    }
  }

  {
    // Instantiating again with no state_handle should yield the same handle and
    // the continuation of the same sequence.
    FunctionLibraryRuntime::Handle handle_non_isolated;
    TF_CHECK_OK(
        Instantiate(flr0_, "RandomUniformWrapper", {}, &handle_non_isolated));
    EXPECT_EQ(handle, handle_non_isolated);
    for (int32_t expected : {0, 1}) {
      TF_CHECK_OK(Run(flr0_, handle_non_isolated, opts, {}, {&y}));
      test::ExpectTensorEqual<int>(y, test::AsTensor<int32>({expected}));
    }
  }

  {
    // Instantiating with a given state handle will create new state and yield
    // the original sequence.
    FunctionLibraryRuntime::InstantiateOptions options;
    FunctionLibraryRuntime::Handle handle_isolated;
    options.state_handle = "handle_1";
    TF_CHECK_OK(Instantiate(flr0_, "RandomUniformWrapper", {}, options,
                            &handle_isolated));
    EXPECT_NE(handle, handle_isolated);
    for (int32_t expected : {6, 4, 0, 1}) {
      TF_CHECK_OK(Run(flr0_, handle_isolated, opts, {}, {&y}));
      test::ExpectTensorEqual<int>(y, test::AsTensor<int32>({expected}));
    }
  }

  {
    // Instantiating with a different given state handle will create new state
    // and yield the original sequence.
    FunctionLibraryRuntime::InstantiateOptions options;
    FunctionLibraryRuntime::Handle handle_isolated;
    options.state_handle = "handle_2";
    TF_CHECK_OK(Instantiate(flr0_, "RandomUniformWrapper", {}, options,
                            &handle_isolated));
    EXPECT_NE(handle, handle_isolated);
    for (int32_t expected : {6, 4, 0, 1}) {
      TF_CHECK_OK(Run(flr0_, handle_isolated, opts, {}, {&y}));
      test::ExpectTensorEqual<int>(y, test::AsTensor<int32>({expected}));
    }
  }

  {
    // Reinstantiating after releasing a handle will yield the original sequence
    // multiple times.
    FunctionLibraryRuntime::InstantiateOptions options;
    FunctionLibraryRuntime::Handle handle_isolated;
    options.state_handle = "handle_3";

    for (int i = 0; i < 2; ++i) {
      TF_CHECK_OK(Instantiate(flr0_, "RandomUniformWrapper", {}, options,
                              &handle_isolated));
      EXPECT_NE(handle, handle_isolated);
      for (int32_t expected : {6, 4, 0, 1}) {
        TF_CHECK_OK(Run(flr0_, handle_isolated, opts, {}, {&y}));
        test::ExpectTensorEqual<int>(y, test::AsTensor<int32>({expected}));
      }
      TF_CHECK_OK(flr0_->ReleaseHandle(handle_isolated));
    }
  }
}
713
714 namespace {
// Registers a "DUMMY" executor type whose factory always fails, letting the
// ExecutorFactory test below observe that the factory was actually reached.
class DummyExecutorRegistrar {
 public:
  DummyExecutorRegistrar() {
    // The registry takes ownership of the factory instance.
    ExecutorFactory::Register("DUMMY", new Factory());
  }

 private:
  class Factory : public ExecutorFactory {
    // Always fails; the error text "This is a dummy." is asserted by tests.
    Status NewExecutor(const LocalExecutorParams& params, const Graph& graph,
                       std::unique_ptr<Executor>* out_executor) override {
      return errors::Internal("This is a dummy.");
    }
  };
};
// Registration happens at static-initialization time, before any test runs.
static DummyExecutorRegistrar registrar;
730 } // namespace
731
// Exercises executor selection: the empty/default type, the explicit
// "DEFAULT" name, the test-registered "DUMMY" factory (selected via either
// InstantiateOptions or the "_executor" attr, with options taking
// precedence), and the error for unregistered types.
TEST_F(FunctionLibraryRuntimeTest, ExecutorFactory) {
  Init({test::function::XTimesTwo()});

  auto x = test::AsTensor<float>({1, 2, 3, 4});
  Tensor y;

  // Test that the default executor works.
  {
    FunctionLibraryRuntime::InstantiateOptions options;
    options.executor_type = "";
    TF_CHECK_OK(InstantiateAndRun(flr0_, "XTimesTwo", {{"T", DT_FLOAT}},
                                  options, {x}, {&y}));
    test::ExpectTensorEqual<float>(y, test::AsTensor<float>({2, 4, 6, 8}));
  }

  // Test the explicit registration for the default executor.
  {
    FunctionLibraryRuntime::InstantiateOptions options;
    options.executor_type = "DEFAULT";
    TF_CHECK_OK(InstantiateAndRun(flr0_, "XTimesTwo", {{"T", DT_FLOAT}},
                                  options, {x}, {&y}));
    test::ExpectTensorEqual<float>(y, test::AsTensor<float>({2, 4, 6, 8}));
  }

  // Test that a non-default executor factory can be invoked.
  {
    FunctionLibraryRuntime::InstantiateOptions options;
    options.executor_type = "DUMMY";
    HasError(InstantiateAndRun(flr0_, "XTimesTwo", {{"T", DT_FLOAT}}, options,
                               {x}, {&y}),
             error::INTERNAL, "This is a dummy.");
  }

  // Test that a non-default executor factory can be invoked via an attr.
  {
    FunctionLibraryRuntime::InstantiateOptions options;
    HasError(InstantiateAndRun(flr0_, "XTimesTwo",
                               {{"T", DT_FLOAT}, {"_executor", "DUMMY"}},
                               options, {x}, {&y}),
             error::INTERNAL, "This is a dummy.");
  }

  // Test that a non-default executor factory specified via an
  // `InstantiateOptions` supersedes the attr when both are present.
  {
    FunctionLibraryRuntime::InstantiateOptions options;
    options.executor_type = "DUMMY";
    HasError(
        InstantiateAndRun(flr0_, "XTimesTwo",
                          {{"T", DT_FLOAT}, {"_executor", "UNKNOWN_EXECUTOR"}},
                          options, {x}, {&y}),
        error::INTERNAL, "This is a dummy.");
  }

  // Test that non-existent executor types trigger an error.
  {
    FunctionLibraryRuntime::InstantiateOptions options;
    options.executor_type = "UNKNOWN_EXECUTOR";
    HasError(InstantiateAndRun(flr0_, "XTimesTwo", {{"T", DT_FLOAT}}, options,
                               {x}, {&y}),
             error::NOT_FOUND,
             "No executor factory registered for the given executor "
             "type: UNKNOWN_EXECUTOR");
  }
  // The same NOT_FOUND error applies when the unknown type comes via attr.
  {
    FunctionLibraryRuntime::InstantiateOptions options;
    HasError(
        InstantiateAndRun(flr0_, "XTimesTwo",
                          {{"T", DT_FLOAT}, {"_executor", "UNKNOWN_EXECUTOR"}},
                          options, {x}, {&y}),
        error::NOT_FOUND,
        "No executor factory registered for the given executor "
        "type: UNKNOWN_EXECUTOR");
  }
}
807
// Verifies that ExpandInlineFunctions rewrites function-call nodes into their
// bodies one nesting level per invocation: XTimes16 -> two XTimesFour calls
// -> four XTimesTwo calls -> primitive Const/Cast/Mul nodes, that a further
// call is a no-op once nothing is left to inline, and that the redundant
// Func/.../input|output Identity nodes introduced by inlining can be removed.
TEST_F(FunctionLibraryRuntimeTest, ExpandInlineFunctions) {
  Init({test::function::XTimesTwo(), test::function::XTimesFour(),
        test::function::XTimes16()});
  std::unique_ptr<Graph> g = GetFuncBody(flr0_, "XTimes16", {{"T", DT_FLOAT}});
  ASSERT_TRUE(g != nullptr);

  // Before any inlining: the body of XTimes16 is just two XTimesFour calls.
  {
    Scope s = Scope::NewRootScope();
    TF_ASSERT_OK(s.graph()->AddFunctionLibrary(fdef_lib_));
    auto arg = ops::_Arg(s.WithOpName("x"), DT_FLOAT, 0);
    auto a = test::function::Call(&s, "x4", "XTimesFour", {arg});
    auto b = test::function::Call(&s, "y", "XTimesFour", {a});
    auto ret = ops::_Retval(s.WithOpName("y_RetVal"), b, 0);
    GraphDef expected;
    TF_ASSERT_OK(s.ToGraphDef(&expected));

    GraphDef actual;
    g->ToGraphDef(&actual);
    TF_EXPECT_GRAPH_EQ(expected, actual);
  }

  // First expansion: each XTimesFour call is replaced by its body — two
  // XTimesTwo calls bracketed by Func/<caller>/input|output Identity nodes.
  ExpandInlineFunctions(flr0_, g.get());
  {
    Scope s = Scope::NewRootScope();
    TF_ASSERT_OK(s.graph()->AddFunctionLibrary(fdef_lib_));
    auto x = ops::_Arg(s.WithOpName("x"), DT_FLOAT, 0);
    auto func0 = ops::Identity(s.WithOpName("Func/x4/input/_0"), x);
    auto x4_x2 = test::function::Call(&s, "x4/x2", "XTimesTwo", {func0});
    auto x4_y = test::function::Call(&s, "x4/y", "XTimesTwo", {x4_x2});
    auto func1 = ops::Identity(s.WithOpName("Func/x4/output/_1"), x4_y);
    auto func2 = ops::Identity(s.WithOpName("Func/y/input/_2"), func1);
    auto y_x2 = test::function::Call(&s, "y/x2", "XTimesTwo", {func2});
    auto y_y = test::function::Call(&s, "y/y", "XTimesTwo", {y_x2});
    auto func3 = ops::Identity(s.WithOpName("Func/y/output/_3"), y_y);
    auto ret = ops::_Retval(s.WithOpName("y_RetVal"), func3, 0);
    GraphDef expected;
    TF_ASSERT_OK(s.ToGraphDef(&expected));

    GraphDef actual;
    g->ToGraphDef(&actual);
    TF_EXPECT_GRAPH_EQ(expected, actual);
  }

  // Second expansion: the four XTimesTwo calls are inlined down to primitive
  // Const (two), Cast (scale), and Mul nodes. The expected graph is kept in
  // `e2` so it can be reused to show the third expansion below is a no-op.
  ExpandInlineFunctions(flr0_, g.get());
  GraphDef e2;
  {
    Scope s = Scope::NewRootScope();
    auto x = ops::_Arg(s.WithOpName("x"), DT_FLOAT, 0);
    auto x4_x2_two = ops::Const<int64_t>(s.WithOpName("x4/x2/two"), int64_t{2});
    auto x4_y_two = ops::Const<int64_t>(s.WithOpName("x4/y/two"), int64_t{2});
    auto y_x2_two = ops::Const<int64_t>(s.WithOpName("y/x2/two"), int64_t{2});
    auto y_y_two = ops::Const<int64_t>(s.WithOpName("y/y/two"), int64_t{2});
    auto x4_x2_scale =
        ops::Cast(s.WithOpName("x4/x2/scale"), x4_x2_two, DT_FLOAT);
    auto x4_y_scale = ops::Cast(s.WithOpName("x4/y/scale"), x4_y_two, DT_FLOAT);
    auto y_x2_scale = ops::Cast(s.WithOpName("y/x2/scale"), y_x2_two, DT_FLOAT);
    auto y_y_scale = ops::Cast(s.WithOpName("y/y/scale"), y_y_two, DT_FLOAT);
    auto func0 = ops::Identity(s.WithOpName("Func/x4/input/_0"), x);
    auto func4 = ops::Identity(s.WithOpName("Func/x4/x2/input/_4"), func0);
    auto x4_x2_y = ops::Mul(s.WithOpName("x4/x2/y"), func4, x4_x2_scale);
    auto func5 = ops::Identity(s.WithOpName("Func/x4/x2/output/_5"), x4_x2_y);
    auto func6 = ops::Identity(s.WithOpName("Func/x4/y/input/_6"), func5);
    auto x4_y_y = ops::Mul(s.WithOpName("x4/y/y"), func6, x4_y_scale);
    auto func7 = ops::Identity(s.WithOpName("Func/x4/y/output/_7"), x4_y_y);
    auto func1 = ops::Identity(s.WithOpName("Func/x4/output/_1"), func7);
    auto func2 = ops::Identity(s.WithOpName("Func/y/input/_2"), func1);
    auto func8 = ops::Identity(s.WithOpName("Func/y/x2/input/_8"), func2);
    auto y_x2_y = ops::Mul(s.WithOpName("y/x2/y"), func8, y_x2_scale);
    auto func9 = ops::Identity(s.WithOpName("Func/y/x2/output/_9"), y_x2_y);
    auto func10 = ops::Identity(s.WithOpName("Func/y/y/input/_10"), func9);
    auto y_y_y = ops::Mul(s.WithOpName("y/y/y"), func10, y_y_scale);
    auto func11 = ops::Identity(s.WithOpName("Func/y/y/output/_11"), y_y_y);
    auto func3 = ops::Identity(s.WithOpName("Func/y/output/_3"), func11);
    auto ret = ops::_Retval(s.WithOpName("y_RetVal"), func3, 0);
    TF_ASSERT_OK(s.ToGraphDef(&e2));

    GraphDef actual;
    g->ToGraphDef(&actual);
    TF_EXPECT_GRAPH_EQ(e2, actual);
  }

  // No further inlining: with only primitive ops left, another expansion
  // must leave the graph unchanged.
  ExpandInlineFunctions(flr0_, g.get());
  {
    GraphDef actual;
    g->ToGraphDef(&actual);
    TF_EXPECT_GRAPH_EQ(e2, actual);
  }

  // Get rid of redundant Identity nodes: after removal only the arithmetic
  // chain x -> Mul*4 -> retval remains, wired directly.
  RemoveIdentityNodes(g.get());
  {
    Scope s = Scope::NewRootScope();
    auto x = ops::_Arg(s.WithOpName("x"), DT_FLOAT, 0);
    auto x4_x2_two = ops::Const<int64_t>(s.WithOpName("x4/x2/two"), int64_t{2});
    auto x4_y_two = ops::Const<int64_t>(s.WithOpName("x4/y/two"), int64_t{2});
    auto y_x2_two = ops::Const<int64_t>(s.WithOpName("y/x2/two"), int64_t{2});
    auto y_y_two = ops::Const<int64_t>(s.WithOpName("y/y/two"), int64_t{2});
    auto x4_x2_scale =
        ops::Cast(s.WithOpName("x4/x2/scale"), x4_x2_two, DT_FLOAT);
    auto x4_y_scale = ops::Cast(s.WithOpName("x4/y/scale"), x4_y_two, DT_FLOAT);
    auto y_x2_scale = ops::Cast(s.WithOpName("y/x2/scale"), y_x2_two, DT_FLOAT);
    auto y_y_scale = ops::Cast(s.WithOpName("y/y/scale"), y_y_two, DT_FLOAT);
    auto x4_x2_y = ops::Mul(s.WithOpName("x4/x2/y"), x, x4_x2_scale);
    auto x4_y_y = ops::Mul(s.WithOpName("x4/y/y"), x4_x2_y, x4_y_scale);
    auto y_x2_y = ops::Mul(s.WithOpName("y/x2/y"), x4_y_y, y_x2_scale);
    auto y_y_y = ops::Mul(s.WithOpName("y/y/y"), y_x2_y, y_y_scale);
    auto ret = ops::_Retval(s.WithOpName("y_RetVal"), y_y_y, 0);
    GraphDef expected;
    TF_ASSERT_OK(s.ToGraphDef(&expected));

    GraphDef actual;
    g->ToGraphDef(&actual);
    TF_EXPECT_GRAPH_EQ(expected, actual);
  }
}
924
925 // Verifies that control dependencies on the caller are added as control
926 // dependencies on any function calls created by inlining.
// Verifies that a control dependency on a function-call node is preserved
// through inlining: the caller's control input becomes a synthesized
// Func/<caller>/input_control_node NoOp, and every node of the inlined body
// (including nested calls, so the property holds recursively) picks up a
// control edge from that NoOp.
TEST_F(FunctionLibraryRuntimeTest, ExpandInlineFunctionsWithInputControlEdges) {
  Init({test::function::XTimesTwo(), test::function::XTimesFour()});

  // Build: b = XTimesFour(a) with a control edge c -> b.
  std::unique_ptr<Graph> g(new Graph(OpRegistry::Global()));
  {
    Scope s = Scope::NewRootScope();
    TF_ASSERT_OK(s.graph()->AddFunctionLibrary(fdef_lib_));
    auto a = ops::_Arg(s.WithOpName("a"), DT_FLOAT, 0);
    auto c = ops::NoOp(s.WithOpName("c"));
    auto b = test::function::Call(&s, "b", "XTimesFour", {a});
    s.graph()->AddControlEdge(c.operation.node(), b.node());
    auto ret = ops::_Retval(s.WithOpName("b_RetVal"), b, 0);
    TF_ASSERT_OK(s.ToGraph(g.get()));
  }

  // First expansion: ^c is routed through the input_control_node NoOp, which
  // gates both inner XTimesTwo calls and the input Identity.
  ExpandInlineFunctions(flr0_, g.get());
  {
    Scope s = Scope::NewRootScope();
    TF_ASSERT_OK(s.graph()->AddFunctionLibrary(fdef_lib_));
    auto a = ops::_Arg(s.WithOpName("a"), DT_FLOAT, 0);
    auto c = ops::NoOp(s.WithOpName("c"));
    auto func0 = ops::NoOp(s.WithOpName("Func/b/input_control_node/_0")
                               .WithControlDependencies({c}));
    auto func1 = ops::Identity(
        s.WithOpName("Func/b/input/_1").WithControlDependencies({func0}), a);
    auto b_x2 = test::function::Call(&s, "b/x2", "XTimesTwo", {func1});
    s.graph()->AddControlEdge(func0.operation.node(), b_x2.node());
    auto b_y = test::function::Call(&s, "b/y", "XTimesTwo", {b_x2});
    s.graph()->AddControlEdge(func0.operation.node(), b_y.node());
    auto func2 = ops::Identity(s.WithOpName("Func/b/output/_2"), b_y);
    auto ret = ops::_Retval(s.WithOpName("b_RetVal"), func2, 0);
    GraphDef expected;
    TF_ASSERT_OK(s.ToGraphDef(&expected));

    GraphDef actual;
    g->ToGraphDef(&actual);
    TF_EXPECT_GRAPH_EQ(expected, actual);
  }

  // Second expansion: inlining the XTimesTwo calls creates one nested
  // input_control_node per call, each gated on the outer one, so the
  // original ^c dependency still dominates every primitive body node.
  ExpandInlineFunctions(flr0_, g.get());
  {
    Scope s = Scope::NewRootScope();
    TF_ASSERT_OK(s.graph()->AddFunctionLibrary(fdef_lib_));
    auto a = ops::_Arg(s.WithOpName("a"), DT_FLOAT, 0);
    auto c = ops::NoOp(s.WithOpName("c"));
    auto func0 = ops::NoOp(s.WithOpName("Func/b/input_control_node/_0")
                               .WithControlDependencies({c}));
    auto func1 = ops::Identity(
        s.WithOpName("Func/b/input/_1").WithControlDependencies({func0}), a);

    // Inlined body of b/x2 (first XTimesTwo call).
    auto func3 = ops::NoOp(s.WithOpName("Func/b/x2/input_control_node/_3")
                               .WithControlDependencies({func0}));
    auto func4 = ops::Identity(
        s.WithOpName("Func/b/x2/input/_4").WithControlDependencies({func3}),
        func1);
    auto b_x2_two = ops::Const(
        s.WithOpName("b/x2/two").WithControlDependencies({func3}), int64_t{2});
    auto b_x2_scale = ops::Cast(s.WithOpName("b/x2/scale"), b_x2_two, DT_FLOAT);
    auto b_x2_y = ops::Mul(s.WithOpName("b/x2/y"), func4, b_x2_scale);
    auto func5 = ops::Identity(s.WithOpName("Func/b/x2/output/_5"), b_x2_y);

    // Inlined body of b/y (second XTimesTwo call).
    auto func6 = ops::NoOp(s.WithOpName("Func/b/y/input_control_node/_6")
                               .WithControlDependencies({func0}));
    auto func7 = ops::Identity(
        s.WithOpName("Func/b/y/input/_7").WithControlDependencies({func6}),
        func5);
    auto b_y_two = ops::Const(
        s.WithOpName("b/y/two").WithControlDependencies({func6}), int64_t{2});
    auto b_y_scale = ops::Cast(s.WithOpName("b/y/scale"), b_y_two, DT_FLOAT);
    auto b_y_y = ops::Mul(s.WithOpName("b/y/y"), func7, b_y_scale);
    auto func8 = ops::Identity(s.WithOpName("Func/b/y/output/_8"), b_y_y);

    auto func2 = ops::Identity(s.WithOpName("Func/b/output/_2"), func8);
    auto ret = ops::_Retval(s.WithOpName("b_RetVal"), func2, 0);

    GraphDef expected;
    TF_ASSERT_OK(s.ToGraphDef(&expected));

    GraphDef actual;
    g->ToGraphDef(&actual);
    TF_EXPECT_GRAPH_EQ(expected, actual);
  }
}
1010
// Verifies that outgoing control edges from an inlined call node are rewired
// through a synthesized output_control_node NoOp, and that the source of that
// NoOp's own control inputs depends on `output_control_src`: either the data
// outputs of the function body (kDataOutputs) or the nodes listed in the
// function's control_ret (kControlOutputs).
TEST_F(FunctionLibraryRuntimeTest,
       ExpandInlineFunctionsWithOutputControlEdges) {
  using test::function::NDef;

  // `add` node is not required to compute regular output `o`, but it must
  // execute because it is in `control_ret`.
  const FunctionDef func =
      FDH::Create("AddAndMul", {"i: float"}, {"o: float"}, {},
                  {{{"add"}, "Add", {"i", "i"}, {{"T", DT_FLOAT}}},
                   {{"ret"}, "Mul", {"i", "i"}, {{"T", DT_FLOAT}}}},
                  /*ret_def=*/{{"o", "ret:z:0"}},
                  /*control_ret_def=*/{{"must_execute", "add"}});

  Init({func});

  // Construct a graph for the function call:
  //
  //   a = Arg[dtype=DT_FLOAT]
  //   b = AddAndMul(a)
  //   c = NoOp(^b)
  //   ret = RetVal(b, ^c)
  const auto init_graph = [this](std::unique_ptr<Graph>* g) -> void {
    *g = std::make_unique<Graph>(OpRegistry::Global());

    Scope s = Scope::NewRootScope();
    TF_ASSERT_OK(s.graph()->AddFunctionLibrary(fdef_lib_));
    auto a = ops::_Arg(s.WithOpName("a"), DT_FLOAT, 0);
    auto b = test::function::Call(&s, "b", "AddAndMul", {a});
    auto c = ops::NoOp(s.WithOpName("c"));
    auto ret = ops::_Retval(s.WithOpName("ret"), b, 0);
    s.graph()->AddControlEdge(b.node(), c.operation.node());
    s.graph()->AddControlEdge(c.operation.node(), ret.operation.node());
    TF_ASSERT_OK(s.ToGraph(g->get()));
  };

  std::unique_ptr<Graph> g;
  ExpandInlineFunctionsOptions opts;

  // Names the inliner assigns to the synthesized bracketing nodes.
  const string input_node = "Func/b/input/_0";
  const string output_node = "Func/b/output/_1";
  const string output_control_node = "Func/b/output_control_node/_2";

  // Use data outputs as output control source: the output_control_node
  // depends on the output Identity (i.e. on `ret`), so `add` is NOT gated.
  opts.native_options.output_control_src = OutputControlSrc::kDataOutputs;

  init_graph(&g);
  ExpandInlineFunctions(flr0_, g.get(), opts);
  {
    GraphDef expected = test::function::GDef(
        {NDef("a", "_Arg", {}, {{"T", DT_FLOAT}, {"index", 0}}),
         NDef(input_node, "Identity", {"a"}, {{"T", DT_FLOAT}}),
         NDef("b/add", "Add", {input_node, input_node}, {{"T", DT_FLOAT}}),
         NDef("b/ret", "Mul", {input_node, input_node}, {{"T", DT_FLOAT}}),
         NDef(output_node, "Identity", {"b/ret"}, {{"T", DT_FLOAT}}),
         NDef(output_control_node, "NoOp", {"^Func/b/output/_1"}, {}),
         NDef("c", "NoOp", {"^" + output_control_node}, {}),
         NDef("ret", "_Retval", {output_node, "^c"},
              {{"T", DT_FLOAT}, {"index", 0}})},
        {func});

    GraphDef actual;
    g->ToGraphDef(&actual);
    TF_EXPECT_GRAPH_EQ(expected, actual);
  }

  // Use control outputs as output control source: the output_control_node
  // depends on ^b/add (the control_ret node), forcing `add` to run before
  // anything downstream of the call's control edge.
  opts.native_options.output_control_src = OutputControlSrc::kControlOutputs;

  init_graph(&g);
  ExpandInlineFunctions(flr0_, g.get(), opts);
  {
    GraphDef expected = test::function::GDef(
        {NDef("a", "_Arg", {}, {{"T", DT_FLOAT}, {"index", 0}}),
         NDef(input_node, "Identity", {"a"}, {{"T", DT_FLOAT}}),
         NDef("b/add", "Add", {input_node, input_node}, {{"T", DT_FLOAT}}),
         NDef("b/ret", "Mul", {input_node, input_node}, {{"T", DT_FLOAT}}),
         NDef(output_node, "Identity", {"b/ret"}, {{"T", DT_FLOAT}}),
         NDef(output_control_node, "NoOp", {"^b/add"}, {}),
         NDef("c", "NoOp", {"^" + output_control_node}, {}),
         NDef("ret", "_Retval", {output_node, "^c"},
              {{"T", DT_FLOAT}, {"index", 0}})},
        {func});

    GraphDef actual;
    g->ToGraphDef(&actual);
    TF_EXPECT_GRAPH_EQ(expected, actual);
  }
}
1099
// Verifies the `keep_caller_node` inlining option: instead of deleting the
// caller node, inlining can replace it with (a) an IdentityN forwarding the
// function outputs so the node stays fetchable by name, or (b) a NoOp gated
// on the output control node so the node stays usable as a run target.
TEST_F(FunctionLibraryRuntimeTest, ExpandInlineFunctionsAndKeepCallerNode) {
  using test::function::NDef;
  using KeepCallerNode = InlineFunctionBodyOptions::KeepCallerNode;

  // Function with both a data output (`o` via Mul) and a control output
  // (`add` via control_ret).
  const FunctionDef func =
      FDH::Create("AddAndMul", {"i: float"}, {"o: float"}, {},
                  {{{"add"}, "Add", {"i", "i"}, {{"T", DT_FLOAT}}},
                   {{"ret"}, "Mul", {"i", "i"}, {{"T", DT_FLOAT}}}},
                  /*ret_def=*/{{"o", "ret:z:0"}},
                  /*control_ret_def=*/{{"must_execute", "add"}});
  Init({func});

  // Construct a graph:
  //   a = Arg[dtype=DT_FLOAT]
  //   b = FunctionWithControlOutputs(a)
  auto construct_graph = [this](std::unique_ptr<Graph>* g) -> Status {
    Scope s = Scope::NewRootScope();
    TF_RETURN_IF_ERROR(s.graph()->AddFunctionLibrary(fdef_lib_));
    auto a = ops::_Arg(s.WithOpName("a"), DT_FLOAT, 0);
    auto b = test::function::Call(&s, "b", "AddAndMul", {a});
    TF_RETURN_IF_ERROR(s.ToGraph(g->get()));
    return OkStatus();
  };

  // Names the inliner assigns to the synthesized bracketing nodes.
  const string input_node = "Func/b/input/_0";
  const string output_node = "Func/b/output/_1";
  const string output_control_node = "Func/b/output_control_node/_2";

  // Construct expected graph after function inlining. Only the node standing
  // in for the original caller `b` differs between the two cases.
  auto expected_graph = [&](const NodeDef& caller) -> GraphDef {
    return test::function::GDef(
        {
            NDef("a", "_Arg", {}, {{"T", DT_FLOAT}, {"index", 0}}),
            NDef(input_node, "Identity", {"a"}, {{"T", DT_FLOAT}}),
            NDef("b/add", "Add", {input_node, input_node}, {{"T", DT_FLOAT}}),
            NDef("b/ret", "Mul", {input_node, input_node}, {{"T", DT_FLOAT}}),
            NDef(output_node, "Identity", {"b/ret"}, {{"T", DT_FLOAT}}),
            NDef(output_control_node, "NoOp", {"^b/add"}, {}),
            caller,  // Keep node in a graph with the same name as caller node.
        },
        {func});
  };

  ExpandInlineFunctionsOptions opts;
  opts.native_options.output_control_src = OutputControlSrc::kControlOutputs;

  // Keep inlined function call node fetchable: caller becomes an IdentityN
  // over the data outputs, with a control edge from the output control node.
  {
    opts.native_options.keep_caller_node = KeepCallerNode::kFetchable;

    std::unique_ptr<Graph> g = std::make_unique<Graph>(OpRegistry::Global());
    TF_ASSERT_OK(construct_graph(&g));

    ExpandInlineFunctions(flr0_, g.get(), opts);
    GraphDef expected =
        expected_graph(/*caller=*/
                       NDef("b", "IdentityN",
                            {output_node, "^" + output_control_node},
                            {{"T", DataTypeSlice{DT_FLOAT}}}));

    GraphDef actual;
    g->ToGraphDef(&actual);
    TF_EXPECT_GRAPH_EQ(expected, actual);
  }

  // Keep inlined function call node targetable: caller becomes a NoOp that
  // only carries the control dependency (no data outputs).
  {
    opts.native_options.keep_caller_node = KeepCallerNode::kTargetable;

    std::unique_ptr<Graph> g = std::make_unique<Graph>(OpRegistry::Global());
    TF_ASSERT_OK(construct_graph(&g));

    ExpandInlineFunctions(flr0_, g.get(), opts);
    GraphDef expected =
        expected_graph(/*caller=*/
                       NDef("b", "NoOp", {"^" + output_control_node}, {}));

    GraphDef actual;
    g->ToGraphDef(&actual);
    TF_EXPECT_GRAPH_EQ(expected, actual);
  }
}
1182
// Verifies the three inlined-function-body placement strategies. The graph
// has an argument on `arg_device`, a call node on `call_device`, and a
// function body node requesting `body_device`; each placer assigns different
// requested devices to the synthesized input/output/control nodes and to the
// body node after inlining.
TEST_F(FunctionLibraryRuntimeTest, ExpandInlineFunctionsAndPlaceInlinedNodes) {
  using test::function::NDef;
  using KeepCallerNode = InlineFunctionBodyOptions::KeepCallerNode;

  const string arg_device = "/job:arg/replica:0/task:0/device:GPU";
  const string call_device = "/job:call/replica:0/task:1/device:GPU";
  const string body_device = "/job:body/replica:0/task:1/device:CPU";

  // Single-node function whose body node carries its own requested device.
  const FunctionDef func = FDH::Create(
      "AddFunc", {"i: float"}, {"o: float"}, {},
      {{{"ret"}, "Add", {"i", "i"}, {{"T", DT_FLOAT}}, {}, body_device}},
      /*ret_def=*/{{"o", "ret:z:0"}});
  Init({func});

  // Construct a graph:
  //   a = Arg[dtype=DT_FLOAT, _device=arg_device]
  //   b = AddFunc[_device=call_device](a)
  auto construct_graph = [&](std::unique_ptr<Graph>* g) -> Status {
    Scope s = Scope::NewRootScope();
    TF_RETURN_IF_ERROR(s.graph()->AddFunctionLibrary(fdef_lib_));
    auto a = ops::_Arg(s.WithOpName("a").WithDevice(arg_device), DT_FLOAT, 0);
    auto b = test::function::Call(&s, "b", "AddFunc", {a});
    TF_RETURN_IF_ERROR(s.ToGraph(g->get()));
    // Scope has no API for requesting a device on a function-call node, so
    // set it directly on the constructed graph.
    for (Node* node : (*g)->op_nodes()) {
      if (node->name() == "b") node->set_requested_device(call_device);
    }
    return OkStatus();
  };

  // Names the inliner assigns to the synthesized bracketing nodes.
  const string input_node = "Func/b/input/_0";
  const string output_node = "Func/b/output/_1";
  const string output_control_node = "Func/b/output_control_node/_2";

  // Construct expected graph after function inlining; `placed` lists the
  // expected requested device for (a, input, body, output, control_output).
  auto expected_graph = [&](const std::vector<string>& placed) -> GraphDef {
    return test::function::GDef(
        {
            NDef("a", "_Arg", {}, {{"T", DT_FLOAT}, {"index", 0}}, placed[0]),
            NDef(input_node, "Identity", {"a"}, {{"T", DT_FLOAT}}, placed[1]),
            NDef("b/ret", "Add", {input_node, input_node}, {{"T", DT_FLOAT}},
                 placed[2]),
            NDef(output_node, "Identity", {"b/ret"}, {{"T", DT_FLOAT}},
                 placed[3]),
            NDef(output_control_node, "NoOp", {"^" + output_node}, {},
                 placed[4]),
        },
        {func});
  };

  ExpandInlineFunctionsOptions opts;
  opts.native_options.keep_caller_node = KeepCallerNode::kDoNotKeep;

  // Place only input nodes to match input device.
  {
    opts.native_options.inlined_function_body_placer =
        InlinedFunctionBodyPlacer::Default();

    auto g = std::make_unique<Graph>(OpRegistry::Global());
    TF_ASSERT_OK(construct_graph(&g));

    ExpandInlineFunctions(flr0_, g.get(), opts);
    GraphDef expected = expected_graph({/*a*/ arg_device,       //
                                        /*input*/ arg_device,   //
                                        /*body*/ body_device,   //
                                        /*output*/ "",          //
                                        /*control_output*/ ""}  //
    );

    GraphDef actual;
    g->ToGraphDef(&actual);
    TF_EXPECT_GRAPH_EQ(expected, actual);
  }

  // Place all nodes on the call node device.
  {
    opts.native_options.inlined_function_body_placer =
        InlinedFunctionBodyPlacer::SingleDevice();

    auto g = std::make_unique<Graph>(OpRegistry::Global());
    TF_ASSERT_OK(construct_graph(&g));

    ExpandInlineFunctions(flr0_, g.get(), opts);
    GraphDef expected = expected_graph({/*a*/ arg_device,                //
                                        /*input*/ call_device,           //
                                        /*body*/ call_device,            //
                                        /*output*/ call_device,          //
                                        /*control_output*/ call_device}  //
    );

    GraphDef actual;
    g->ToGraphDef(&actual);
    TF_EXPECT_GRAPH_EQ(expected, actual);
  }

  // Multi device function placement: the body device is merged with the call
  // device spec (yielding a wildcard device id), inputs stay with the arg.
  {
    opts.native_options.inlined_function_body_placer =
        InlinedFunctionBodyPlacer::MultiDevice();

    auto g = std::make_unique<Graph>(OpRegistry::Global());
    TF_ASSERT_OK(construct_graph(&g));

    const string merged_device = "/job:body/replica:0/task:1/device:CPU:*";

    ExpandInlineFunctions(flr0_, g.get(), opts);
    GraphDef expected = expected_graph({/*a*/ arg_device,                //
                                        /*input*/ arg_device,            //
                                        /*body*/ merged_device,          //
                                        /*output*/ "",                   //
                                        /*control_output*/ call_device}  //
    );

    GraphDef actual;
    g->ToGraphDef(&actual);
    TF_EXPECT_GRAPH_EQ(expected, actual);
  }
}
1300
// Verifies that dead nodes in a function body are pruned before execution —
// arithmetic not feeding the return value ("x1", "x2", "x3") and the unused
// argument "y" must not run — while stateful nodes ("keep_me") are kept even
// though nothing consumes them.
TEST_F(FunctionLibraryRuntimeTest, PruneBody) {
  auto T = DT_INT32;
  FunctionDef stateful_func = FDH::Define(
      // Name
      "SquareAndAddOneWithStatefulNodes",
      // Args
      {"x: int32", "y: float32"},
      // Return values
      {"z: int32"},
      // Attrs
      {},
      // Nodes
      {// a = Square<T>(x)
       {{"a"}, "Square", {"x"}, {{"T", T}}},
       // 1
       FDH::Const("o", 1),
       // A bunch of extra arithmetic that y doesn't depend on
       {{"x1"}, "Add", {"o", "o"}, {{"T", T}}},
       {{"x2"}, "Mul", {"a", "x1"}, {{"T", T}}},
       {{"x3"}, "Mul", {"x1", "x2"}, {{"T", T}}},
       FDH::Const<int32>("shape", {1, 2}),
       // A stateful node.
       {{"keep_me"},
        "RandomUniform",
        {"shape"},
        {{"T", T}, {"dtype", DT_FLOAT}}},
       // z = Add<T>(a, o)
       {{"z"}, "Add", {"a", "o"}, {{"T", T}}}});
  Init({stateful_func});

  auto x = test::AsTensor<int32>({1, 2, 3, 4});
  auto y = test::AsTensor<float>({1.0, 2.0, 3.0, 4.0});
  Tensor z;

  FunctionLibraryRuntime::Handle handle;
  TF_CHECK_OK(
      Instantiate(flr0_, "SquareAndAddOneWithStatefulNodes", {}, &handle));

  // Run once with a stats collector attached so we can record exactly which
  // nodes executed.
  StepStats stats;
  StepStatsCollector stats_collector(&stats);
  FunctionLibraryRuntime::Options opts;
  opts.stats_collector = &stats_collector;
  TF_CHECK_OK(Run(flr0_, handle, opts, {x, y}, {&z}));
  TF_CHECK_OK(flr0_->ReleaseHandle(handle));

  // Sanity-check the function result: z = x^2 + 1.
  TF_CHECK_OK(InstantiateAndRun(flr0_, "SquareAndAddOneWithStatefulNodes", {},
                                {x, y}, {&z}));
  test::ExpectTensorEqual<int>(z, test::AsTensor<int32>({2, 5, 10, 17}));

  stats_collector.FinalizeAndSwap(&stats);

  // Note that we do not expect the nodes named "y", "x1", "x2", or "x3" to
  // execute.
  std::set<string> expected_node_names(
      {"_SOURCE", "shape", "x", "o", "a", "keep_me", "z", "z_RetVal"});
  std::set<string> executed_node_names;
  for (const auto& node_stats : stats.dev_stats()[0].node_stats()) {
    executed_node_names.insert(node_stats.node_name());
  }
  EXPECT_EQ(expected_node_names, executed_node_names);
}
1362
TEST_F(FunctionLibraryRuntimeTest, DoNotPruneControlOutputsFromBody) {
  // The "add" node does not feed the data output "o", but it is listed in
  // control_ret, so pruning must keep it and it must actually execute.
  const FunctionDef func =
      FDH::Create("FunctionWithControlOutputs", {"i: float"}, {"o: float"}, {},
                  {
                      {{"add"}, "Add", {"i", "i"}, {{"T", DT_FLOAT}}},
                      {{"ret"}, "Mul", {"i", "i"}, {{"T", DT_FLOAT}}},
                  },
                  /*ret_def=*/{{"o", "ret:z:0"}},
                  /*control_ret_def=*/{{"must_execute", "add"}});
  Init({func});

  const auto input = test::AsTensor<float>({1.25});
  Tensor output;

  // Run once with a stats collector attached to record which nodes executed.
  FunctionLibraryRuntime::Handle handle;
  TF_CHECK_OK(Instantiate(flr1_, "FunctionWithControlOutputs", {}, &handle));

  StepStats stats;
  StepStatsCollector collector(&stats);
  FunctionLibraryRuntime::Options run_opts;
  run_opts.stats_collector = &collector;
  TF_CHECK_OK(Run(flr1_, handle, run_opts, {input}, {&output}));
  TF_CHECK_OK(flr1_->ReleaseHandle(handle));

  // Sanity-check the data output: o = i * i.
  TF_CHECK_OK(InstantiateAndRun(flr1_, "FunctionWithControlOutputs", {},
                                {input}, {&output}));
  test::ExpectTensorEqual<float>(output, test::AsTensor<float>({1.25 * 1.25}));

  collector.FinalizeAndSwap(&stats);

  // "add" must be among the executed nodes despite having no data consumers.
  const std::set<string> expected_nodes = {"_SOURCE", "i", "add", "ret",
                                           "o_RetVal"};
  std::set<string> ran_nodes;
  for (const auto& node_stats : stats.dev_stats()[0].node_stats()) {
    ran_nodes.insert(node_stats.node_name());
  }
  EXPECT_EQ(expected_nodes, ran_nodes);
}
1404
// Verifies that OptimizeGraph on the fully inlined XTimes16 body constant-
// folds the four Const->Cast(scale) chains into one folded float constant
// (named with the "__cf__" constant-folding suffix) shared by all four Muls.
TEST_F(FunctionLibraryRuntimeTest, OptimizeGraph) {
  Init({test::function::XTimesTwo(), test::function::XTimesFour(),
        test::function::XTimes16()});
  std::unique_ptr<Graph> g = GetFuncBody(flr0_, "XTimes16", {{"T", DT_FLOAT}});
  ASSERT_TRUE(g != nullptr);
  ExpandInlineFunctions(flr0_, g.get());
  OptimizeGraph(flr0_, &g);
  {
    Scope s = Scope::NewRootScope();
    auto x = ops::_Arg(s.WithOpName("x"), DT_FLOAT, 0);
    // The single constant produced by constant folding; its name and device
    // are assigned by the folding pass.
    auto x4_x2_scale = ops::Const<float>(
        s.WithOpName("x4/x2/scale/_12__cf__0")
            .WithDevice("/job:localhost/replica:0/task:0/device:CPU:0"),
        2.0f);
    auto x4_x2_y = ops::Mul(s.WithOpName("x4/x2/y"), x, x4_x2_scale);
    auto x4_y_y = ops::Mul(s.WithOpName("x4/y/y"), x4_x2_y, x4_x2_scale);
    auto y_x2_y = ops::Mul(s.WithOpName("y/x2/y"), x4_y_y, x4_x2_scale);
    auto y_y_y = ops::Mul(s.WithOpName("y/y/y"), y_x2_y, x4_x2_scale);
    auto ret = ops::_Retval(s.WithOpName("y_RetVal"), y_y_y, 0);
    GraphDef expected;
    TF_ASSERT_OK(s.ToGraphDef(&expected));

    GraphDef actual;
    g->ToGraphDef(&actual);
    TF_EXPECT_GRAPH_EQ(expected, actual);
  }
}
1432
TEST_F(FunctionLibraryRuntimeTest, ManySwapsNodeDef) {
  // o = Identity(Swap^6(x, y))[0]. Every pair of swaps cancels out, so the
  // optimizer should collapse the whole body to "return x" (an empty body
  // whose return value aliases argument 0).
  auto fdef = FDH::Create(  // Creates a FunctionDef using NodeDefs
      /*function_name=*/"ManySwapsNodeDef",
      /*in_def=*/{"x: float", "y: float"},
      /*out_def=*/{"o: float"},
      /*attr_def=*/{},
      /*node_def=*/
      {{{"a"}, "Swap", {"x", "y"}, {{"T", DT_FLOAT}}},
       {{"b"}, "Swap", {"a:o0", "a:o1"}, {{"T", DT_FLOAT}}},
       {{"c"}, "Swap", {"b:o0", "b:o1"}, {{"T", DT_FLOAT}}},
       {{"d"}, "Swap", {"c:o0", "c:o1"}, {{"T", DT_FLOAT}}},
       {{"e"}, "Swap", {"d:o0", "d:o1"}, {{"T", DT_FLOAT}}},
       {{"f"}, "Swap", {"e:o0", "e:o1"}, {{"T", DT_FLOAT}}},
       {{"g"}, "Identity", {"f:o0"}, {{"T", DT_FLOAT}}}},
      /*ret_def=*/{{"o", "g:output"}});
  Init({test::function::Swap(), fdef});
  std::unique_ptr<Graph> body = GetFuncBody(flr0_, "ManySwapsNodeDef", {});
  ASSERT_TRUE(body != nullptr);
  OptimizeGraph(flr0_, &body);
  // Expected optimized body: forwards argument 0 straight to the output.
  const char* kExpected = R"P(
(n2:float, n3:float) -> (n2:float) {
}
)P";
  EXPECT_EQ(kExpected, DebugString(body.get()));
}
1463
// Verifies that function inlining plus optimization preserves control
// dependencies declared on function body nodes: the pure swap nodes can be
// optimized away, but their ordering constraints survive as control edges
// through the synthesized Func/... NoOp/Identity nodes.
TEST_F(FunctionLibraryRuntimeTest, ControlDeps) {
  auto func = FDH::Create(
      // Name
      "ManySwapsFirst",
      // Args
      {"x: float", "y: float"},
      // Return values
      {"o: float"},
      // attr def
      {},
      // Nodes
      //
      // o = x*x + y*y. Furthermore, The 1st swap depends on x2, and
      // y2 depends on the 2nd swap. The 2nd swap has data dependency
      // on the 1st swap. The optimization should maintain the control
      // dependencies.
      {{{"a0"}, "Swap", {"x", "y"}, {{"T", DT_FLOAT}}, {"x2"}},
       {{"a1"}, "Swap", {"a0:o0:0", "a0:o1:0"}, {{"T", DT_FLOAT}}},
       {{"x2"}, "Mul", {"x", "x"}, {{"T", DT_FLOAT}}},
       {{"y2"}, "Mul", {"y", "y"}, {{"T", DT_FLOAT}}, {"a1"}},
       {{"o"}, "Add", {"x2:z:0", "y2:z:0"}, {{"T", DT_FLOAT}}}},
      {{"o", "o:z:0"}});
  Init({test::function::Swap(), func});
  std::unique_ptr<Graph> g = GetFuncBody(flr0_, "ManySwapsFirst", {});
  ASSERT_TRUE(g != nullptr);
  OptimizeGraph(flr0_, &g);

  // NOTE: We can remove func0, func1, func2, func9 with a control edge
  // n8->n5. But we don't have a pass doing that.
  {
    Scope s = Scope::NewRootScope();
    auto x = ops::_Arg(s.WithOpName("x"), DT_FLOAT, 0);
    auto y = ops::_Arg(s.WithOpName("y"), DT_FLOAT, 1);
    auto x2 = ops::Mul(s.WithOpName("x2"), x, x);
    // Leftover inlining scaffolding: the swaps' input control node (gated on
    // x2) and input Identities, plus the output control node that y2 waits on.
    auto func0 = ops::NoOp(s.WithOpName("Func/a0/input_control_node/_0")
                               .WithControlDependencies(x2));
    auto func1 = ops::Identity(
        s.WithOpName("Func/a0/input/_1").WithControlDependencies({func0}), x);
    auto func2 = ops::Identity(
        s.WithOpName("Func/a0/input/_2").WithControlDependencies({func0}), y);
    auto func9 = ops::NoOp(
        s.WithOpName("Func/a1/output_control_node/_9")
            .WithControlDependencies({func1.output.op(), func2.output.op()}));
    auto y2 =
        ops::Mul(s.WithOpName("y2").WithControlDependencies({func9}), y, y);
    auto o = ops::Add(s.WithOpName("o"), x2, y2);
    auto ret = ops::_Retval(s.WithOpName("o_RetVal"), o, 0);
    GraphDef expected;
    TF_ASSERT_OK(s.ToGraphDef(&expected));

    GraphDef actual;
    g->ToGraphDef(&actual);
    TF_EXPECT_GRAPH_EQ(expected, actual);
  }
}
1519
TEST_F(FunctionLibraryRuntimeTest, Error_NotFound) {
  // Invoking a function that was never registered in the library must fail
  // with NOT_FOUND and a clear message.
  Init({test::function::XTimesTwo(), test::function::XTimesFour()});
  const auto input = test::AsTensor<float>({1, 2, 3, 4});
  Tensor result;
  HasError(
      InstantiateAndRun(flr0_, "Foo", {{"T", DT_FLOAT}}, {input}, {&result}),
      error::NOT_FOUND, "Function Foo is not defined.");
}
1527
TEST_F(FunctionLibraryRuntimeTest, Error_InstantiationError) {
  // Redefine "XTimesTwo" with a broken node: its "Add" carries a bogus
  // "no_T" attr instead of the required "T". Instantiating it directly must
  // fail; functions that merely call it instantiate fine and only surface
  // the error when actually run.
  auto bad_x_times_two = FDH::Define(
      /*name=*/"XTimesTwo",
      /*arg_def=*/{"x: T"},
      /*ret_def=*/{"y: T"},
      /*attr_def=*/{"T: {float, double, int32, int64}"},
      /*node_def=*/
      {
          {{"y"}, "Add", {"x", "x"}, {{"no_T", "$T"}}},
      });
  Init({bad_x_times_two, test::function::XTimesFour(),
        test::function::XTimes16()});

  // Direct instantiation of the broken function fails immediately.
  FunctionLibraryRuntime::Handle handle;
  HasError(flr0_->Instantiate(
               "XTimesTwo", test::function::Attrs({{"T", DT_FLOAT}}), &handle),
           error::NOT_FOUND, "type attr not found");

  // Callers of the broken function still instantiate successfully...
  TF_CHECK_OK(flr0_->Instantiate(
      "XTimesFour", test::function::Attrs({{"T", DT_FLOAT}}), &handle));
  TF_CHECK_OK(flr0_->Instantiate(
      "XTimes16", test::function::Attrs({{"T", DT_FLOAT}}), &handle));

  // ...but running one of them propagates the underlying error.
  const auto input = test::AsTensor<float>({1, 2, 3, 4});
  Tensor result;
  HasError(InstantiateAndRun(flr0_, "XTimes16", {{"T", DT_FLOAT}}, {input},
                             {&result}),
           error::NOT_FOUND, "type attr not found");
}
1563
TEST_F(FunctionLibraryRuntimeTest, Error_BadControlFlow) {
  // A function body with malformed control flow — a node mixing inputs from
  // different execution frames — must be rejected with INVALID_ARGUMENT.
  Init({test::function::InvalidControlFlow()});
  const auto input = test::AsTensor<int32>({0});
  DCHECK_EQ(DT_INT32, input.dtype());
  Tensor result;
  HasError(
      InstantiateAndRun(flr0_, "InvalidControlFlow", {}, {input}, {&result}),
      error::INVALID_ARGUMENT,
      "{{node add}} has inputs from different frames. The input"
      " {{node enter}} is in frame 'while'. The input {{node i}} is in"
      " frame ''.");
}
1575
TEST_F(FunctionLibraryRuntimeTest, Gradient_XTimesTwo) {
  // Verifies (a) the instantiated body of XTimesTwo, (b) the symbolic
  // gradient body derived from it, and (c) the gradient body after
  // OptimizeGraph, against hand-built expected graphs.
  Init({test::function::XTimesTwo(), test::function::XTimesFour(),
        test::function::XTimes16()});
  std::unique_ptr<Graph> f = GetFuncBody(flr0_, "XTimesTwo", {{"T", DT_FLOAT}});
  {
    // Expected forward body: y = x * Cast<float>(Const<int64>(2)).
    Scope s = Scope::NewRootScope();
    auto x = ops::_Arg(s.WithOpName("x"), DT_FLOAT, 0);
    auto two = ops::Const(s.WithOpName("two"), int64_t{2});
    auto scale = ops::Cast(s.WithOpName("scale"), two, DT_FLOAT);
    auto y = ops::Mul(s.WithOpName("y"), x, scale);
    auto ret = ops::_Retval(s.WithOpName("y_RetVal"), y, 0);
    GraphDef expected;
    TF_ASSERT_OK(s.ToGraphDef(&expected));

    GraphDef actual;
    f->ToGraphDef(&actual);
    TF_EXPECT_GRAPH_EQ(expected, actual);
  }

  std::unique_ptr<Graph> g = GetGradBody(flr0_, "XTimesTwo", {{"T", DT_FLOAT}});

  {
    // Expected gradient body: the forward nodes plus SymbolicGradient nodes
    // for Mul and Cast. Func/_0 is the incoming gradient (dy).
    Scope s = Scope::NewRootScope();
    auto x = ops::_Arg(s.WithOpName("x"), DT_FLOAT, 0);
    auto func0 = ops::_Arg(s.WithOpName("Func/_0"), DT_FLOAT, 1);
    auto two = ops::Const(s.WithOpName("two"), int64_t{2});
    auto scale = ops::Cast(s.WithOpName("scale"), two, DT_FLOAT);
    auto y = ops::Mul(s.WithOpName("y"), x, scale);
    NameAttrList fn0;
    fn0.set_name("Mul");
    (*fn0.mutable_attr())["T"].set_type(DT_FLOAT);
    auto func1 = ops::SymbolicGradient(
        s.WithOpName("Func/_1"), std::initializer_list<Input>{x, scale, func0},
        {DT_FLOAT, DT_FLOAT}, fn0);
    NameAttrList fn1;
    fn1.set_name("Cast");
    (*fn1.mutable_attr())["SrcT"].set_type(DT_INT64);
    (*fn1.mutable_attr())["DstT"].set_type(DT_FLOAT);
    (*fn1.mutable_attr())["Truncate"].set_b(false);
    // Gradient of the Cast; its output is unused here but the node is
    // expected to be present in the generated graph.
    auto func2 = ops::SymbolicGradient(
        s.WithOpName("Func/_2"),
        std::initializer_list<Input>{two, func1.output[1]}, {DT_INT64}, fn1);
    auto func3 = ops::_Retval(s.WithOpName("Func/_3"), func1[0], 0);
    GraphDef expected;
    TF_ASSERT_OK(s.ToGraphDef(&expected));

    GraphDef actual;
    g->ToGraphDef(&actual);
    TF_EXPECT_GRAPH_EQ(expected, actual);
  }

  OptimizeGraph(flr0_, &g);
  {
    // Expected optimized gradient: constant folding collapses the
    // Cast(Const(2)) chain into float constants and the SymbolicGradient
    // nodes are expanded/inlined into Mul/Shape/Sum/Reshape.
    Scope s = Scope::NewRootScope();
    auto x = ops::_Arg(s.WithOpName("x"), DT_FLOAT, 0);
    auto func0 = ops::_Arg(s.WithOpName("Func/_0"), DT_FLOAT, 1);
    auto scale = ops::Const(
        s.WithOpName("scale/_6__cf__1")
            .WithDevice("/job:localhost/replica:0/task:0/device:CPU:0"),
        2.0f);
    auto func1_gx = ops::Mul(s.WithOpName("Func/_1/gx"), func0, scale);
    auto func1_sx = ops::Shape(s.WithOpName("Func/_1/sx"), x);
    auto const0 = ops::Const(
        s.WithOpName("Func/_1/sy/_5__cf__0")
            .WithDevice("/job:localhost/replica:0/task:0/device:CPU:0"),
        0, {0});
    auto func1_rx = ops::internal::BroadcastGradientArgs(
        s.WithOpName("Func/_1/rx"), func1_sx, const0);
    auto func1_sum_gx =
        ops::Sum(s.WithOpName("Func/_1/sum_gx"), func1_gx, func1_rx.r0);
    auto func1_dx =
        ops::Reshape(s.WithOpName("Func/_1/dx"), func1_sum_gx, func1_sx);
    auto func2 = ops::_Retval(s.WithOpName("Func/_3"), func1_dx, 0);
    GraphDef expected;
    TF_ASSERT_OK(s.ToGraphDef(&expected));

    GraphDef actual;
    g->ToGraphDef(&actual);
    TF_EXPECT_GRAPH_EQ(expected, actual);
  }
}
1657
TEST_F(FunctionLibraryRuntimeTest, Gradient_Select) {
  // MySelect(condition, t, e) = Select(condition, t, e) + Select(condition,
  // t, e). The gradient function wraps a SymbolicGradient of MySelect and
  // returns only d(z)/d(t) (output index 1).
  FunctionDef fn_select = FunctionDefHelper::Create(
      "MySelect",
      // Args
      {"condition: bool", "t: float32", "e: float32"},
      // Return values
      {"z: float32"},
      // Attrs
      {},
      // Nodes
      {
          {{"select0"}, "Select", {"condition", "t", "e"}, {{"T", DT_FLOAT}}},
          {{"select1"}, "Select", {"condition", "t", "e"}, {{"T", DT_FLOAT}}},
          {{"add"},
           "Add",
           {"select0:output", "select1:output"},
           {{"T", DT_FLOAT}}},
      },
      // Output mapping
      {{"z", "add:z"}});
  FunctionDef fn_select_grad = FunctionDefHelper::Create(
      "MySelectGrad",
      // Args
      {"condition: bool", "t:float32", "e: float32", "dz: float32"},
      // Return values
      {"dt: float32"},
      // Attrs
      {},
      // Nodes
      {{
          {"grad"},
          "SymbolicGradient",
          {"condition", "t", "e", "dz"},
          {
              {"f", FunctionDefHelper::FunctionRef("MySelect")},
              {"Tin", DataTypeSlice({DT_BOOL, DT_FLOAT, DT_FLOAT, DT_FLOAT})},
              {"Tout", DataTypeSlice({DT_BOOL, DT_FLOAT, DT_FLOAT})},
          },
      }},
      // Output mapping
      {{"dt", "grad:output:1"}});
  Init({fn_select, fn_select_grad});

  // Only checks that instantiation and execution succeed.
  auto dz = test::AsTensor<float>({1.0});
  auto condition = test::AsTensor<bool>({false});
  auto e = test::AsTensor<float>({15.0});
  auto t = test::AsTensor<float>({13.0});
  Tensor dt;
  TF_EXPECT_OK(InstantiateAndRun(flr0_, "MySelectGrad", {},
                                 {condition, t, e, dz}, {&dt}));
}
1709
TEST_F(FunctionLibraryRuntimeTest, Gradient_Add) {
  // Checks the graph generated for SymbolicGradient of the "Add" op:
  // each incoming gradient dz is passed through Identity, then summed over
  // the broadcast axes (from BroadcastGradientArgs) and reshaped back to
  // the corresponding input's shape.
  Init({});
  auto T = DT_FLOAT;
  std::unique_ptr<Graph> g = GetFuncBody(
      flr0_, "SymbolicGradient", {{"f", FDH::FunctionRef("Add", {{"T", T}})}});
  {
    Scope s = Scope::NewRootScope();
    auto x = ops::_Arg(s.WithOpName("x"), DT_FLOAT, 0);
    auto y = ops::_Arg(s.WithOpName("y"), DT_FLOAT, 1);
    auto dz = ops::_Arg(s.WithOpName("dz"), DT_FLOAT, 2);
    auto gx = ops::Identity(s.WithOpName("gx"), dz);
    auto gy = ops::Identity(s.WithOpName("gy"), dz);
    auto sx = ops::Shape(s.WithOpName("sx"), x);
    auto sy = ops::Shape(s.WithOpName("sy"), y);
    auto rx = ops::internal::BroadcastGradientArgs(s.WithOpName("rx"), sx, sy);
    auto sum_gx = ops::Sum(s.WithOpName("sum_gx"), gx, rx.r0);
    auto sum_gy = ops::Sum(s.WithOpName("sum_gy"), gy, rx.r1);
    auto dx = ops::Reshape(s.WithOpName("dx"), sum_gx, sx);
    auto dy = ops::Reshape(s.WithOpName("dy"), sum_gy, sy);
    auto dx_ret = ops::_Retval(s.WithOpName("dx_RetVal"), dx, 0);
    auto dy_ret = ops::_Retval(s.WithOpName("dy_RetVal"), dy, 1);
    GraphDef expected;
    TF_ASSERT_OK(s.ToGraphDef(&expected));

    GraphDef actual;
    g->ToGraphDef(&actual);
    TF_EXPECT_GRAPH_EQ(expected, actual);
  }
}
1739
TEST_F(FunctionLibraryRuntimeTest, Gradient_Mul) {
  // Checks the graph generated for SymbolicGradient of the "Mul" op:
  // gx = dz * y, gy = x * dz, followed by the same broadcast-reduction
  // (Sum over BroadcastGradientArgs axes, then Reshape) as in Gradient_Add.
  Init({});
  auto T = DT_FLOAT;
  std::unique_ptr<Graph> g = GetFuncBody(
      flr0_, "SymbolicGradient", {{"f", FDH::FunctionRef("Mul", {{"T", T}})}});
  {
    Scope s = Scope::NewRootScope();
    auto x = ops::_Arg(s.WithOpName("x"), DT_FLOAT, 0);
    auto y = ops::_Arg(s.WithOpName("y"), DT_FLOAT, 1);
    auto dz = ops::_Arg(s.WithOpName("dz"), DT_FLOAT, 2);
    auto gx = ops::Mul(s.WithOpName("gx"), dz, y);
    auto sx = ops::Shape(s.WithOpName("sx"), x);
    auto gy = ops::Mul(s.WithOpName("gy"), x, dz);
    auto sy = ops::Shape(s.WithOpName("sy"), y);
    auto rx = ops::internal::BroadcastGradientArgs(s.WithOpName("rx"), sx, sy);
    auto sum_gx = ops::Sum(s.WithOpName("sum_gx"), gx, rx.r0);
    auto sum_gy = ops::Sum(s.WithOpName("sum_gy"), gy, rx.r1);
    auto dx = ops::Reshape(s.WithOpName("dx"), sum_gx, sx);
    auto dy = ops::Reshape(s.WithOpName("dy"), sum_gy, sy);
    auto dx_ret = ops::_Retval(s.WithOpName("dx_RetVal"), dx, 0);
    auto dy_ret = ops::_Retval(s.WithOpName("dy_RetVal"), dy, 1);
    GraphDef expected;
    TF_ASSERT_OK(s.ToGraphDef(&expected));

    GraphDef actual;
    g->ToGraphDef(&actual);
    TF_EXPECT_GRAPH_EQ(expected, actual);
  }
}
1769
TEST_F(FunctionLibraryRuntimeTest, Gradient_AddSum) {
  // Sum(Add(x, y))
  // Verifies the graph for the gradient of a composed function in three
  // stages: (1) the instantiated TestGrad body, (2) after
  // ExpandInlineFunctions, (3) after OptimizeGraph (node-count bound only).
  auto T = DT_FLOAT;
  auto test = FDH::Define("Test", {"x:float", "y:float"}, {"l:float"}, {},
                          {
                              {{"z"}, "Add", {"x", "y"}, {{"T", T}}},
                              FDH::Const("zero", 0),
                              FDH::Const("one", 1),
                              {{"r"}, "Rank", {"z"}, {{"T", T}}},
                              {{"indices"}, "Range", {"zero", "r", "one"}},
                              {{"l"}, "Sum", {"z", "indices"}, {{"T", T}}},
                          });

  // TestGrad = Test'(x, y)
  auto grad = FDH::Define("TestGrad", {"x:float", "y:float"},
                          {"dx:float", "dy:float"}, {},
                          {FDH::Const<float>("dz", 1),
                           {{"grad0", "grad1"},
                            "SymbolicGradient",
                            {"x", "y", "dz"},
                            {
                                {"f", FDH::FunctionRef("Test")},
                                {"Tin", DataTypeSlice{T, T, T}},
                                {"Tout", DataTypeSlice{T, T}},
                            }},
                           {{"dx"}, "Identity", {"grad0"}, {{"T", DT_FLOAT}}},
                           {{"dy"}, "Identity", {"grad1"}, {{"T", DT_FLOAT}}}});

  Init({test, grad});

  std::unique_ptr<Graph> g = GetFuncBody(flr0_, "TestGrad", {});
  ASSERT_TRUE(g != nullptr);
  {
    // Stage 1: the raw TestGrad body, with the SymbolicGradient still a
    // single node.
    Scope s = Scope::NewRootScope();
    auto x = ops::_Arg(s.WithOpName("x"), DT_FLOAT, 0);
    auto y = ops::_Arg(s.WithOpName("y"), DT_FLOAT, 1);
    auto dz = ops::Const(s.WithOpName("dz"), 1.0f);
    NameAttrList fn;
    fn.set_name("Test");
    auto grad0 = ops::SymbolicGradient(s.WithOpName("grad0"),
                                       std::initializer_list<Input>{x, y, dz},
                                       {DT_FLOAT, DT_FLOAT}, fn);
    auto dx = ops::Identity(s.WithOpName("dx"), grad0[0]);
    auto dy = ops::Identity(s.WithOpName("dy"), grad0[1]);
    auto dx_retval = ops::_Retval(s.WithOpName("dx_RetVal"), dx, 0);
    auto dy_retval = ops::_Retval(s.WithOpName("dy_RetVal"), dy, 1);
    GraphDef expected;
    TF_ASSERT_OK(s.ToGraphDef(&expected));

    GraphDef actual;
    g->ToGraphDef(&actual);
    TF_EXPECT_GRAPH_EQ(expected, actual);
  }

  ExpandInlineFunctions(flr0_, g.get());
  {
    // Stage 2: the SymbolicGradient has been inlined; the forward body of
    // Test appears under the "grad0/" prefix with Identity nodes bridging
    // the inlined function's inputs/outputs, plus per-op SymbolicGradient
    // nodes for Sum and Add and ZerosLike nodes for the constants.
    Scope s = Scope::NewRootScope();
    auto x = ops::_Arg(s.WithOpName("x"), DT_FLOAT, 0);
    auto y = ops::_Arg(s.WithOpName("y"), DT_FLOAT, 1);
    auto dz = ops::Const(s.WithOpName("dz"), 1.0f);
    auto grad0_zero = ops::Const(s.WithOpName("grad0/zero"), 0);
    auto grad0_one = ops::Const(s.WithOpName("grad0/one"), 1);
    auto func0 = ops::Identity(s.WithOpName("Func/grad0/input/_0"), x);
    auto func1 = ops::Identity(s.WithOpName("Func/grad0/input/_1"), y);
    auto func2 = ops::Identity(s.WithOpName("Func/grad0/input/_2"), dz);
    auto grad0_z = ops::Add(s.WithOpName("grad0/z"), func0, func1);
    auto grad0_r = ops::Rank(s.WithOpName("grad0/r"), grad0_z);
    auto grad0_indices = ops::Range(s.WithOpName("grad0/indices"), grad0_zero,
                                    grad0_r, grad0_one);
    auto grad0_l = ops::Sum(s.WithOpName("grad0/l"), grad0_z, grad0_indices);

    NameAttrList sum;
    sum.set_name("Sum");
    (*sum.mutable_attr())["T"].set_type(DT_FLOAT);
    (*sum.mutable_attr())["Tidx"].set_type(DT_INT32);
    (*sum.mutable_attr())["keep_dims"].set_b(false);
    auto grad0_func1 = ops::SymbolicGradient(
        s.WithOpName("grad0/Func/_1"),
        std::initializer_list<Input>{grad0_z, grad0_indices, func2},
        {DT_FLOAT, DT_INT32}, sum);

    auto grad0_func2 =
        ops::ZerosLike(s.WithOpName("grad0/Func/_2"), grad0_zero);
    auto grad0_func3 = ops::ZerosLike(s.WithOpName("grad0/Func/_3"), grad0_r);
    auto grad0_func4 = ops::ZerosLike(s.WithOpName("grad0/Func/_4"), grad0_one);

    NameAttrList add;
    add.set_name("Add");
    (*add.mutable_attr())["T"].set_type(DT_FLOAT);
    auto grad0_func5 = ops::SymbolicGradient(
        s.WithOpName("grad0/Func/_5"),
        std::initializer_list<Input>{func0, func1, grad0_func1[0]},
        {DT_FLOAT, DT_FLOAT}, add);

    auto func3 =
        ops::Identity(s.WithOpName("Func/grad0/output/_3"), grad0_func5[0]);
    auto func4 =
        ops::Identity(s.WithOpName("Func/grad0/output/_4"), grad0_func5[1]);
    auto dx = ops::Identity(s.WithOpName("dx"), func3);
    auto dy = ops::Identity(s.WithOpName("dy"), func4);
    auto dx_retval = ops::_Retval(s.WithOpName("dx_RetVal"), dx, 0);
    auto dy_retval = ops::_Retval(s.WithOpName("dy_RetVal"), dy, 1);

    GraphDef expected;
    TF_ASSERT_OK(s.ToGraphDef(&expected));

    GraphDef actual;
    g->ToGraphDef(&actual);
    TF_EXPECT_GRAPH_EQ(expected, actual);
  }

  OptimizeGraph(flr0_, &g);
  {
    // Stage 3: after optimization the per-op SymbolicGradients are expanded
    // into concrete ops (Fill/DynamicStitch/Tile for the Sum gradient and
    // the broadcast-reduction pattern for the Add gradient).
    Scope s = Scope::NewRootScope();
    auto x = ops::_Arg(s.WithOpName("x"), DT_FLOAT, 0);
    auto y = ops::_Arg(s.WithOpName("y"), DT_FLOAT, 1);
    auto dz = ops::Const(s.WithOpName("dz"), 1.0f);
    auto grad0_zero = ops::Const(s.WithOpName("grad0/zero"), 0);
    auto grad0_one = ops::Const(s.WithOpName("grad0/one"), 1);
    auto grad0_z = ops::Add(s.WithOpName("grad0/z"), x, y);
    auto grad0_r = ops::Rank(s.WithOpName("grad0/r"), grad0_z);
    auto grad0_indices = ops::Range(s.WithOpName("grad0/indices"), grad0_zero,
                                    grad0_r, grad0_one);
    auto i_shape =
        ops::Shape(s.WithOpName("grad0/Func/_1/i_shape"), grad0_indices);
    auto stitch_val = ops::Fill(s.WithOpName("grad0/Func/_1/stitch_val1"),
                                i_shape, grad0_one);
    auto x_shape = ops::Shape(s.WithOpName("grad0/Func/_1/x_shape"), grad0_z);
    auto y_shape = ops::DynamicStitch(
        s.WithOpName("grad0/Func/_1/y_shape"),
        std::initializer_list<Input>{grad0_indices, grad0_indices},
        std::initializer_list<Input>{x_shape, stitch_val});
    auto dy_reshaped =
        ops::Reshape(s.WithOpName("grad0/Func/_1/dy_reshaped"), dz, y_shape);
    auto tile_scaling =
        ops::Div(s.WithOpName("grad0/Func/_1/tile_scaling"), x_shape, y_shape);
    auto func1_dx =
        ops::Tile(s.WithOpName("grad0/Func/_1/dx"), dy_reshaped, tile_scaling);

    auto sx = ops::Shape(s.WithOpName("grad0/Func/_3/sx"), x);
    auto sy = ops::Shape(s.WithOpName("grad0/Func/_3/sy"), y);
    auto rx = ops::internal::BroadcastGradientArgs(
        s.WithOpName("grad0/Func/_3/rx"), sx, sy);
    auto sum_gx =
        ops::Sum(s.WithOpName("grad0/Func/_3/sum_gx"), func1_dx, rx.r0);
    auto sum_gy =
        ops::Sum(s.WithOpName("grad0/Func/_3/sum_gy"), func1_dx, rx.r1);
    auto dx = ops::Reshape(s.WithOpName("grad0/Func/_3/dx"), sum_gx, sx);
    auto dy = ops::Reshape(s.WithOpName("grad0/Func/_3/dy"), sum_gy, sy);

    auto dx_retval = ops::_Retval(s.WithOpName("dx_RetVal"), dx, 0);
    auto dy_retval = ops::_Retval(s.WithOpName("dy_RetVal"), dy, 1);

    GraphDef expected;
    TF_ASSERT_OK(s.ToGraphDef(&expected));

    GraphDef actual;
    g->ToGraphDef(&actual);
    // The optimizer is non-deterministic, so we only check that the number of
    // nodes is not greater than expected.
    EXPECT_LE(actual.node_size(), expected.node_size());
  }
}
1933
TEST_F(FunctionLibraryRuntimeTest, CrossDevice) {
  // A function instantiated with target CPU:1 must execute on CPU:1 even
  // when the Run call is issued through a runtime bound to another device.
  Init({test::function::FindDevice()});
  FunctionLibraryRuntime::InstantiateOptions inst_opts;
  inst_opts.target = "/device:CPU:1";
  FunctionLibraryRuntime::Handle handle;
  TF_CHECK_OK(Instantiate(flr0_, "FindDevice", {}, inst_opts, &handle));

  FunctionLibraryRuntime::Options run_opts;
  PrivateIntraProcessRendezvous rendez(device_mgr_.get());
  run_opts.rendezvous = &rendez;
  run_opts.source_device = "/device:CPU:1";
  Tensor result;
  // Running through flr1_ still reports CPU:1 as the executing device.
  TF_CHECK_OK(Run(flr1_, handle, run_opts, {}, {&result}));
  test::ExpectTensorEqual<tstring>(
      result,
      test::AsTensor<tstring>({"/job:localhost/replica:0/task:0/device:CPU:1"},
                              TensorShape({})));
  // Same through flr2_, this time exercising the remote-execution path.
  run_opts.remote_execution = true;
  run_opts.source_device = "/job:localhost/replica:0/task:0/cpu:2";
  TF_CHECK_OK(Run(flr2_, handle, run_opts, {}, {&result}));
  test::ExpectTensorEqual<tstring>(
      result,
      test::AsTensor<tstring>({"/job:localhost/replica:0/task:0/device:CPU:1"},
                              TensorShape({})));
}
1960
1961 class AreAllKernelsInlineOp : public OpKernel {
1962 public:
1963 using OpKernel::OpKernel;
1964
Compute(OpKernelContext * ctx)1965 void Compute(OpKernelContext* ctx) override {
1966 Tensor* output;
1967 OP_REQUIRES_OK(ctx, ctx->allocate_output(0, {}, &output));
1968 output->scalar<bool>()() = ctx->run_all_kernels_inline();
1969 }
1970 };
1971
// Registers the test op and its CPU kernel. Marked stateful so it is never
// constant-folded away.
REGISTER_OP("AreAllKernelsInline").Output("result : bool").SetIsStateful();
REGISTER_KERNEL_BUILDER(Name("AreAllKernelsInline").Device(DEVICE_CPU),
                        AreAllKernelsInlineOp);
1975
TEST_F(FunctionLibraryRuntimeTest, RunAllKernelsInline) {
  // "F" contains an AreAllKernelsInline op; "G" calls "F". Running "G" lets
  // us observe whether the flag propagates through a nested function call.
  auto fdef_f = FDH::Create(
      // Name
      "F",
      // Args
      {},
      // Return values
      {"ret: bool"},
      // Attrs
      {},
      // Nodes
      {// y = AreAllKernelsInline()
       {{"y"}, "AreAllKernelsInline", {}, {}}},
      {{"ret", "y:result:0"}});

  auto fdef_g = FDH::Create(
      // Name
      "G",
      // Args
      {},
      // Return values
      {"ret: bool"},
      // Attrs
      {},
      // Nodes
      {// y = F()
       {{"y"}, "F", {}, {}}},
      {{"ret", "y:ret:0"}});

  Init({fdef_f, fdef_g});
  FunctionLibraryRuntime::Handle handle;
  TF_CHECK_OK(Instantiate(flr0_, "G", {}, &handle));

  // The `run_all_kernels_inline` flag must be inherited by the kernel that
  // executes inside the nested function call, whichever way it is set.
  for (bool inline_kernels : {false, true}) {
    FunctionLibraryRuntime::Options opts;
    opts.run_all_kernels_inline = inline_kernels;
    Tensor observed;
    TF_ASSERT_OK(Run(flr0_, handle, opts, {}, {&observed}));
    EXPECT_EQ(observed.scalar<bool>()(), inline_kernels);
  }
}
2021
2022 class UserIntraOpThreadPoolOp : public OpKernel {
2023 public:
2024 using OpKernel::OpKernel;
2025
2026 class DummyThreadPool : public thread::ThreadPoolInterface {
2027 public:
Schedule(std::function<void ()> fn)2028 void Schedule(std::function<void()> fn) override { fn(); }
NumThreads() const2029 int NumThreads() const override { return 1; }
CurrentThreadId() const2030 int CurrentThreadId() const override { return -1; }
2031 };
2032
dummy_thread_pool()2033 static DummyThreadPool& dummy_thread_pool() {
2034 static DummyThreadPool& thread_pool = *new DummyThreadPool();
2035 return thread_pool;
2036 }
2037
Compute(OpKernelContext * ctx)2038 void Compute(OpKernelContext* ctx) override {
2039 Tensor* result;
2040 OP_REQUIRES_OK(ctx, ctx->allocate_output(0, {}, &result));
2041 result->scalar<bool>()() =
2042 ctx->device()
2043 ->tensorflow_cpu_worker_threads()
2044 ->workers->AsEigenThreadPool() == &dummy_thread_pool();
2045 }
2046 };
2047
// Registers the test op and its CPU kernel. Marked stateful so it is never
// constant-folded away.
REGISTER_OP("UserIntraOpThreadPool").Output("result: bool").SetIsStateful();
REGISTER_KERNEL_BUILDER(Name("UserIntraOpThreadPool").Device(DEVICE_CPU),
                        UserIntraOpThreadPoolOp);
2051
TEST_F(FunctionLibraryRuntimeTest, RunUserIntraOpThreadPool) {
  // Create a function "F" that includes a UserIntraOpThreadPool op, and run
  // it with a caller-supplied intra-op thread pool; the kernel verifies the
  // pool it observes is the one passed through Options.
  auto f = FDH::Create(
      // Name
      "F",
      // Args
      {},
      // Return values
      {"ret: bool"},
      // Attrs
      {},
      // Nodes
      {// y = UserIntraOpThreadPool()
       {{"y"}, "UserIntraOpThreadPool", {}, {}}},
      {{"ret", "y:result:0"}});

  Init({f});
  FunctionLibraryRuntime::Handle handle;
  TF_CHECK_OK(Instantiate(flr0_, "F", {}, &handle));

  FunctionLibraryRuntime::Options opts;
  opts.user_intra_op_threadpool = &UserIntraOpThreadPoolOp::dummy_thread_pool();

  Tensor result;
  TF_ASSERT_OK(Run(flr0_, handle, opts, {}, {&result}));
  EXPECT_TRUE(result.scalar<bool>()());
}
2080
2081 namespace {
2082
// Identity "pass" for Optimize() below: leaves the graph untouched.
bool DoNothing(Graph* g) { return false; }
2084
Optimize(const std::function<bool (Graph * g)> & pass,const FunctionDef & fdef)2085 GraphDef Optimize(const std::function<bool(Graph* g)>& pass,
2086 const FunctionDef& fdef) {
2087 InstantiationResult result;
2088 TF_CHECK_OK(InstantiateFunction(fdef, AttrSlice(), GetOpSig, &result));
2089 std::unique_ptr<Graph> g(new Graph(OpRegistry::Global()));
2090 GraphConstructorOptions opts;
2091 opts.allow_internal_ops = true;
2092 opts.expect_device_spec = false;
2093 TF_CHECK_OK(ConvertNodeDefsToGraph(opts, result.nodes, g.get()));
2094 pass(g.get());
2095 std::unique_ptr<Graph> g1(new Graph(OpRegistry::Global()));
2096 CopyGraph(*g, g1.get());
2097 g = nullptr;
2098 GraphDef gdef;
2099 g1->ToGraphDef(&gdef);
2100 return gdef;
2101 }
2102
2103 } // end namespace
2104
TEST(OptimizationTest, RemoveDeadNodes) {
  // RemoveDeadNodes must keep nodes reachable from the return value and any
  // stateful node (keep_me), and here removes nothing else because x1/x2/x3
  // feed... NOTE(review): expected graphs are identical for DoNothing and
  // RemoveDeadNodes, i.e. the pass removes nothing on this input.
  auto T = DT_INT32;
  auto func = FDH::Define(
      // Name
      "F",
      // Args
      {"x: int32"},
      // Return values
      {"y: int32"},
      // Attrs
      {},
      // Nodes
      {// a = Square<T>(x)
       {{"a"}, "Square", {"x"}, {{"T", T}}},
       // 1
       FDH::Const("o", 1),
       // A bunch of extra arithmetic that y doesn't depend on
       {{"x1"}, "Add", {"o", "o"}, {{"T", T}}},
       {{"x2"}, "Mul", {"a", "x1"}, {{"T", T}}},
       {{"x3"}, "Mul", {"x1", "x2"}, {{"T", T}}},
       // A stateful node.
       {{"keep_me"}, "RandomUniform", {"o"}, {{"T", T}, {"dtype", DT_FLOAT}}},
       // y = Add<T>(a, o)
       {{"y"}, "Add", {"a", "o"}, {{"T", T}}}});

  GraphDef expected;
  {
    Scope s = Scope::DisabledShapeInferenceScope();
    auto x = ops::_Arg(s.WithOpName("x"), DT_INT32, 0);
    auto o = ops::Const(s.WithOpName("o"), 1);
    auto keep_me = ops::RandomUniform(s.WithOpName("keep_me"), {o}, DT_FLOAT);
    auto x1 = ops::Add(s.WithOpName("x1"), o, o);
    auto a = ops::Square(s.WithOpName("a"), x);
    auto y = ops::Add(s.WithOpName("y"), a, o);
    auto x2 = ops::Mul(s.WithOpName("x2"), a, x1);
    auto x3 = ops::Mul(s.WithOpName("x3"), x1, x2);
    auto ret = ops::_Retval(s.WithOpName("y_RetVal"), y, 0);
    TF_ASSERT_OK(s.ToGraphDef(&expected));
  }
  TF_EXPECT_GRAPH_EQ(expected, Optimize(DoNothing, func));

  // TODO(zhifengc): Come up with another test case.
  TF_EXPECT_GRAPH_EQ(expected, Optimize(::tensorflow::RemoveDeadNodes, func));
}
2149
TEST(OptimizationTest, RemoveIdentityNodes_Ref) {
  // RemoveIdentityNodes must not remove an Identity that reads a ref-typed
  // input (a VariableV2 read): the expected graph is the same before and
  // after the pass.
  auto T = DT_FLOAT;
  auto func = FDH::Define(
      // Name
      "F",
      // Args
      {},
      // Return values
      {"ret: float"},
      // Attrs
      {},
      // Nodes
      {// variable
       {{"v"}, "VariableV2", {}, {{"dtype", T}, {"shape", TensorShape({})}}},
       // read the variable. Shouldn't be removed.
       {{"v_read"}, "Identity", {"v"}, {{"T", T}}},
       // returns v + v
       {{"ret"}, "Add", {"v_read", "v_read"}, {{"T", T}}}});

  GraphDef expected;
  {
    Scope s = Scope::NewRootScope();
    auto v = ops::Variable(s.WithOpName("v"), PartialTensorShape({}), DT_FLOAT);
    auto v_read = ops::Identity(s.WithOpName("v_read"), v);
    auto ret = ops::Add(s.WithOpName("ret"), v_read, v_read);
    auto ret_retval = ops::_Retval(s.WithOpName("ret_RetVal"), ret, 0);
    TF_ASSERT_OK(s.ToGraphDef(&expected));
  }
  TF_EXPECT_GRAPH_EQ(expected, Optimize(DoNothing, func));
  TF_EXPECT_GRAPH_EQ(expected,
                     Optimize(::tensorflow::RemoveIdentityNodes, func));
}
2182
TEST(OptimizationTest, RemoveIdentityNodes) {
  // RemoveIdentityNodes must drop the x1/x2/x3 Identity chain and rewire
  // keep_me's control dependency from x3 directly to a.
  auto T = DT_INT32;
  auto func = FDH::Define(
      // Name
      "F",
      // Args
      {"x: int32"},
      // Return values
      {"y: int32"},
      // Attrs
      {},
      // Nodes
      {// a = Square<T>(x)
       {{"a"}, "Square", {"x"}, {{"T", T}}},
       // 1
       FDH::Const("o", 1),
       // A bunch of extra arithmetic that y doesn't depend on
       {{"x1"}, "Identity", {"a"}, {{"T", T}}},
       {{"x2"}, "Identity", {"x1"}, {{"T", T}}},
       {{"x3"}, "Identity", {"x2"}, {{"T", T}}},
       // A stateful node.
       {{"keep_me"},
        "RandomUniform",
        {"o"},
        {{"T", T}, {"dtype", DT_FLOAT}},
        {"x3"}},
       // y = Add<T>(a, o)
       {{"y"}, "Add", {"a", "o"}, {{"T", T}}}});

  {
    // Baseline: without the pass, the Identity chain is still present.
    Scope s = Scope::DisabledShapeInferenceScope();
    auto x = ops::_Arg(s.WithOpName("x"), DT_INT32, 0);
    auto o = ops::Const(s.WithOpName("o"), 1);
    auto a = ops::Square(s.WithOpName("a"), x);
    auto y = ops::Add(s.WithOpName("y"), a, o);
    auto x1 = ops::Identity(s.WithOpName("x1"), a);
    auto x2 = ops::Identity(s.WithOpName("x2"), x1);
    auto x3 = ops::Identity(s.WithOpName("x3"), x2);
    auto keep_me = ops::RandomUniform(
        s.WithOpName("keep_me").WithControlDependencies(x3), {o}, DT_FLOAT);
    auto ret = ops::_Retval(s.WithOpName("y_RetVal"), y, 0);
    GraphDef expected;
    TF_ASSERT_OK(s.ToGraphDef(&expected));
    TF_EXPECT_GRAPH_EQ(expected, Optimize(DoNothing, func));
  }

  {
    // With the pass: x1/x2/x3 are gone; keep_me now control-depends on a.
    Scope s = Scope::DisabledShapeInferenceScope();
    auto x = ops::_Arg(s.WithOpName("x"), DT_INT32, 0);
    auto o = ops::Const(s.WithOpName("o"), 1);
    auto a = ops::Square(s.WithOpName("a"), x);
    auto y = ops::Add(s.WithOpName("y"), a, o);
    auto keep_me = ops::RandomUniform(
        s.WithOpName("keep_me").WithControlDependencies(a), {o}, DT_FLOAT);
    auto ret = ops::_Retval(s.WithOpName("y_RetVal"), y, 0);
    GraphDef expected;
    TF_ASSERT_OK(s.ToGraphDef(&expected));
    TF_EXPECT_GRAPH_EQ(expected,
                       Optimize(::tensorflow::RemoveIdentityNodes, func));
  }
}
2244
TEST(OptimizationTest, RemoveListArrayConverter) {
  // RemoveListArrayConverter must replace _ArrayToList/_ListToArray nodes
  // with Identity bridges; combined with RemoveIdentityNodes, consumers end
  // up wired directly to the producers.
  auto func = FDH::Create(
      // Name
      "Test",
      // Args
      {"i: float"},
      // Return signature
      {"o: float"},
      // Attrs
      {},
      // Nodes
      {FDH::Const("zero", 0),
       {{"s"},
        "Split",
        {"zero:output:0", "i"},
        {{"num_split", 4}, {"T", DT_FLOAT}}},
       {{"a"},
        "_ArrayToList",
        {"s:output"},
        {{"N", 4},
         {"T", DT_FLOAT},
         {"out_types", DataTypeSlice{DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT}}}},
       {{"l"}, "Mul", {"a:output:0", "a:output:1"}, {{"T", DT_FLOAT}}},
       {{"r"}, "Mul", {"a:output:2", "a:output:3"}, {{"T", DT_FLOAT}}},
       {{"x"},
        "_ListToArray",
        {"l:z", "r:z"},
        {{"N", 2},
         {"T", DT_FLOAT},
         {"Tin", DataTypeSlice{DT_FLOAT, DT_FLOAT}}}},
       {{"o"}, "AddN", {"x:output"}, {{"N", 2}, {"T", DT_FLOAT}}}},
      // Return values
      {{"o", "o:sum"}});

  {
    // Baseline: converter nodes still present.
    Scope scope = Scope::DisabledShapeInferenceScope();
    auto i = ops::_Arg(scope.WithOpName("i"), DT_FLOAT, 0);
    auto zero = ops::Const(scope.WithOpName("zero"), 0);
    auto s = ops::Split(scope.WithOpName("s"), zero, i, 4);
    auto a = ops::_ArrayToList(scope.WithOpName("a"), s.output,
                               {DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT});
    auto r = ops::Mul(scope.WithOpName("r"), a[2], a[3]);
    auto l = ops::Mul(scope.WithOpName("l"), a[0], a[1]);
    auto x = ops::_ListToArray(scope.WithOpName("x"),
                               std::initializer_list<Input>{l, r}, DT_FLOAT, 2);
    auto o = ops::AddN(scope.WithOpName("o"), x.output);
    auto o_ret = ops::_Retval(scope.WithOpName("o_RetVal"), o, 0);
    GraphDef expected;
    TF_ASSERT_OK(scope.ToGraphDef(&expected));
    TF_EXPECT_GRAPH_EQ(expected, Optimize(DoNothing, func));
  }

  {
    // After the converter pass: each converter becomes per-element Identity
    // nodes ("Func/.../input/_N").
    Scope scope = Scope::NewRootScope();
    auto i = ops::_Arg(scope.WithOpName("i"), DT_FLOAT, 0);
    auto zero = ops::Const(scope.WithOpName("zero"), 0);
    auto s = ops::Split(scope.WithOpName("s"), zero, i, 4);
    auto func_0 = ops::Identity(scope.WithOpName("Func/a/input/_0"), s[0]);
    auto func_1 = ops::Identity(scope.WithOpName("Func/a/input/_1"), s[1]);
    auto func_2 = ops::Identity(scope.WithOpName("Func/a/input/_2"), s[2]);
    auto func_3 = ops::Identity(scope.WithOpName("Func/a/input/_3"), s[3]);
    auto r = ops::Mul(scope.WithOpName("r"), func_2, func_3);
    auto l = ops::Mul(scope.WithOpName("l"), func_0, func_1);
    auto func_4 = ops::Identity(scope.WithOpName("Func/x/input/_4"), l);
    auto func_5 = ops::Identity(scope.WithOpName("Func/x/input/_5"), r);
    auto o = ops::AddN(scope.WithOpName("o"),
                       std::initializer_list<Input>{func_4, func_5});
    auto o_ret = ops::_Retval(scope.WithOpName("o_RetVal"), o, 0);
    GraphDef expected;
    TF_ASSERT_OK(scope.ToGraphDef(&expected));
    TF_EXPECT_GRAPH_EQ(expected, Optimize(RemoveListArrayConverter, func));
  }

  {
    // Converter pass + identity removal: consumers wired straight to Split.
    Scope scope = Scope::NewRootScope();
    auto i = ops::_Arg(scope.WithOpName("i"), DT_FLOAT, 0);
    auto zero = ops::Const(scope.WithOpName("zero"), 0);
    auto s = ops::Split(scope.WithOpName("s"), zero, i, 4);
    auto r = ops::Mul(scope.WithOpName("r"), s[2], s[3]);
    auto l = ops::Mul(scope.WithOpName("l"), s[0], s[1]);
    auto o =
        ops::AddN(scope.WithOpName("o"), std::initializer_list<Input>{l, r});
    auto o_ret = ops::_Retval(scope.WithOpName("o_RetVal"), o, 0);
    GraphDef expected;
    TF_ASSERT_OK(scope.ToGraphDef(&expected));

    auto remove_listarray_and_identity = [](Graph* g) {
      return RemoveListArrayConverter(g) && RemoveIdentityNodes(g);
    };
    TF_EXPECT_GRAPH_EQ(expected, Optimize(remove_listarray_and_identity, func));
  }
}
2337
TEST(OptimizationTest, RemoveListArrayConverter_WithControlDeps) {
  // Like RemoveListArrayConverter, but the converter node carries control
  // dependencies: the pass must preserve them via NoOp input/output control
  // nodes, and the resulting Identity nodes (which now carry control deps)
  // must NOT be removed by RemoveIdentityNodes.
  auto func = FDH::Create(
      // Name
      "Test",
      // Args
      {"i: float"},
      // Return values
      {"o: float"},
      // Attrs
      {},
      // Nodes
      {FDH::Const("dummy", 0),
       {{"x"},
        "_ListToArray",
        {"i", "i"},
        {{"N", 2}, {"T", DT_FLOAT}, {"Tin", DataTypeSlice{DT_FLOAT, DT_FLOAT}}},
        // Control dep
        {"dummy"}},
       {{"o"},
        "AddN",
        {"x:output"},
        {{"N", 2}, {"T", DT_FLOAT}},
        // Control dep
        {"x"}}},
      {{"o", "o:sum"}});

  {
    // Baseline: converter node present with its control dependencies.
    Scope s = Scope::DisabledShapeInferenceScope();
    auto i = ops::_Arg(s.WithOpName("i"), DT_FLOAT, 0);
    auto dummy = ops::Const(s.WithOpName("dummy"), 0);
    auto x = ops::_ListToArray(s.WithOpName("x").WithControlDependencies(dummy),
                               std::initializer_list<Input>{i, i}, DT_FLOAT, 2);
    auto o =
        ops::AddN(s.WithOpName("o").WithControlDependencies({x.output[0].op()}),
                  x.output);
    auto o_ret = ops::_Retval(s.WithOpName("o_RetVal"), o, 0);
    GraphDef expected;
    TF_ASSERT_OK(s.ToGraphDef(&expected));
    TF_EXPECT_GRAPH_EQ(expected, Optimize(DoNothing, func));
  }

  GraphDef expected;
  {
    // After the pass: incoming control dep funnels through an input NoOp
    // (_2), each element becomes an Identity depending on it, and the
    // outgoing control dep funnels through an output NoOp (_3).
    Scope s = Scope::NewRootScope();
    auto i = ops::_Arg(s.WithOpName("i"), DT_FLOAT, 0);
    auto dummy = ops::Const(s.WithOpName("dummy"), 0);
    auto func_2 = ops::NoOp(s.WithOpName("Func/x/input_control_node/_2")
                                .WithControlDependencies(dummy));
    auto func_0 = ops::Identity(
        s.WithOpName("Func/x/input/_0").WithControlDependencies({func_2}), i);
    auto func_1 = ops::Identity(
        s.WithOpName("Func/x/input/_1").WithControlDependencies({func_2}), i);
    auto func_3 = ops::NoOp(
        s.WithOpName("Func/x/output_control_node/_3")
            .WithControlDependencies({func_0.output.op(), func_1.output.op()}));
    auto o = ops::AddN(s.WithOpName("o").WithControlDependencies({func_3}),
                       std::initializer_list<Input>{func_0, func_1});
    auto o_ret = ops::_Retval(s.WithOpName("o_RetVal"), o, 0);
    TF_ASSERT_OK(s.ToGraphDef(&expected));
  }
  TF_EXPECT_GRAPH_EQ(expected, Optimize(RemoveListArrayConverter, func));

  auto remove_listarray_and_identity = [](Graph* g) {
    return RemoveListArrayConverter(g) && RemoveIdentityNodes(g);
  };
  // NOTE: We are not removing Identity nodes with any control
  // dependencies yet.
  TF_EXPECT_GRAPH_EQ(expected, Optimize(remove_listarray_and_identity, func));
}
2407
2408 } // namespace
2409 } // namespace tensorflow
2410