xref: /aosp_15_r20/external/tensorflow/tensorflow/core/grappler/optimizers/auto_parallel_test.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
#include "tensorflow/core/grappler/optimizers/auto_parallel.h"

#include <string>
#include <vector>

#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/grappler/utils.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
23 
24 namespace tensorflow {
25 namespace grappler {
26 namespace {
27 
28 class AutoParallelTest : public ::testing::Test {};
29 
TEST_F(AutoParallelTest,SimpleParallel)30 TEST_F(AutoParallelTest, SimpleParallel) {
31   tensorflow::Scope s = tensorflow::Scope::DisabledShapeInferenceScope();
32   Output constant_a = ops::Const(s.WithOpName("constant_a"), 1.0f, {1});
33   Output constant_b = ops::Const(s.WithOpName("constant_b"), 1, {1});
34   Output var = ops::Variable(s.WithOpName("var"), {1}, DT_FLOAT);
35   Output assign = ops::Assign(s.WithOpName("assign"), {var}, {constant_a});
36   Output identity = ops::Identity(s.WithOpName("identity"), {var});
37   Output fifo_queue = ops::FIFOQueue(s.WithOpName("fifo_queue"), {DT_FLOAT});
38   auto dequeue = ops::QueueDequeueMany(s.WithOpName("dequeue"), {fifo_queue},
39                                        {constant_b}, {DT_FLOAT});
40   Output add = ops::AddN(s.WithOpName("add"), {constant_a, dequeue[0]});
41   Output learning_rate = ops::Const(s.WithOpName("learning_rate"), 0.01f, {1});
42   Output apply_gradient = ops::ApplyGradientDescent(
43       s.WithOpName("apply_gradient"), {var}, {learning_rate}, {add});
44 
45   GrapplerItem item;
46   item.init_ops.push_back("assign");
47   item.fetch.push_back("apply_gradient");
48   item.init_ops.push_back("assign");
49   TF_CHECK_OK(s.ToGraphDef(&item.graph));
50 
51   AutoParallel parallel(2);
52   GraphDef output;
53   Status status = parallel.Optimize(nullptr, item, &output);
54   TF_EXPECT_OK(status);
55   EXPECT_EQ(21, output.node_size());
56 
57   const NodeDef& node_assign = output.node(0);
58   EXPECT_EQ("assign", node_assign.name());
59   EXPECT_EQ("AutoParallel-Replica-0/constant_a", node_assign.input(1));
60 
61   const NodeDef& node_constant_b = output.node(1);
62   EXPECT_EQ("constant_b", node_constant_b.name());
63 
64   const NodeDef& node_fifo_queue = output.node(2);
65   EXPECT_EQ("fifo_queue", node_fifo_queue.name());
66 
67   const NodeDef& node_identity = output.node(3);
68   EXPECT_EQ("identity", node_identity.name());
69   EXPECT_EQ("var", node_identity.input(0));
70 
71   const NodeDef& node_var = output.node(4);
72   EXPECT_EQ("var", node_var.name());
73 
74   const NodeDef& node_div_const0 = output.node(5);
75   EXPECT_EQ("AutoParallel-Replica-0/AutoParallel-Div-Const",
76             node_div_const0.name());
77 
78   const NodeDef& node_div0 = output.node(6);
79   EXPECT_EQ("AutoParallel-Replica-0/AutoParallel-Div-apply_gradient",
80             node_div0.name());
81   const NodeDef& node_add0 = output.node(7);
82   EXPECT_EQ("AutoParallel-Replica-0/add", node_add0.name());
83 
84   const NodeDef& node_gradient0 = output.node(8);
85   EXPECT_EQ("AutoParallel-Replica-0/apply_gradient", node_gradient0.name());
86 
87   const NodeDef& node_constant_a0 = output.node(9);
88   EXPECT_EQ("AutoParallel-Replica-0/constant_a", node_constant_a0.name());
89 
90   const NodeDef& node_dequeue0 = output.node(10);
91   EXPECT_EQ("AutoParallel-Replica-0/dequeue", node_dequeue0.name());
92 
93   const NodeDef& node_learning_rate0 = output.node(11);
94   EXPECT_EQ("AutoParallel-Replica-0/learning_rate", node_learning_rate0.name());
95 
96   const NodeDef& node_div_const1 = output.node(12);
97   EXPECT_EQ("AutoParallel-Replica-1/AutoParallel-Div-Const",
98             node_div_const1.name());
99 
100   const NodeDef& node_div1 = output.node(13);
101   EXPECT_EQ("AutoParallel-Replica-1/AutoParallel-Div-apply_gradient",
102             node_div1.name());
103 
104   const NodeDef& node_add1 = output.node(14);
105   EXPECT_EQ("AutoParallel-Replica-1/add", node_add1.name());
106 
107   const NodeDef& node_gradient1 = output.node(15);
108   EXPECT_EQ("AutoParallel-Replica-1/apply_gradient", node_gradient1.name());
109 
110   const NodeDef& node_constant_a1 = output.node(16);
111   EXPECT_EQ("AutoParallel-Replica-1/constant_a", node_constant_a1.name());
112 
113   const NodeDef& node_dequeue1 = output.node(17);
114   EXPECT_EQ("AutoParallel-Replica-1/dequeue", node_dequeue1.name());
115 
116   const NodeDef& node_learning_rate1 = output.node(18);
117   EXPECT_EQ("AutoParallel-Replica-1/learning_rate", node_learning_rate1.name());
118 
119   const NodeDef& node_fetch = output.node(19);
120   EXPECT_EQ("AutoParallel-Control-Fetch", node_fetch.name());
121   EXPECT_EQ("^AutoParallel-Replica-0/apply_gradient", node_fetch.input(0));
122   EXPECT_EQ("^AutoParallel-Replica-1/apply_gradient", node_fetch.input(1));
123 
124   const NodeDef& node_gradient = output.node(20);
125   EXPECT_EQ("apply_gradient", node_gradient.name());
126   EXPECT_EQ("^AutoParallel-Control-Fetch", node_gradient.input(0));
127 }
128 
// Same optimizer setup as SimpleParallel, but the graph has no input queue:
// the gradient step consumes AddN(constant_a, constant_c) directly. Only
// checks that AutoParallel with 2 replicas returns an OK status on such a
// graph.
TEST_F(AutoParallelTest, SimpleParallelNoDequeue) {
  tensorflow::Scope s = tensorflow::Scope::DisabledShapeInferenceScope();
  Output constant_a = ops::Const(s.WithOpName("constant_a"), 1.0f, {1});
  Output constant_c = ops::Const(s.WithOpName("constant_c"), 1.0f, {1});
  // NOTE(review): constant_b has no consumer in this graph (in SimpleParallel
  // it feeds the dequeue). Kept so the graph under test is unchanged.
  Output constant_b = ops::Const(s.WithOpName("constant_b"), 1, {1});
  Output var = ops::Variable(s.WithOpName("var"), {1}, DT_FLOAT);
  Output assign = ops::Assign(s.WithOpName("assign"), {var}, {constant_a});
  Output add = ops::AddN(s.WithOpName("add"), {constant_a, constant_c});
  Output learning_rate = ops::Const(s.WithOpName("learning_rate"), 0.01f, {1});
  Output apply_gradient = ops::ApplyGradientDescent(
      s.WithOpName("apply_gradient"), {var}, {learning_rate}, {add});

  GrapplerItem item;
  item.init_ops.push_back("assign");
  item.fetch.push_back("apply_gradient");
  TF_CHECK_OK(s.ToGraphDef(&item.graph));

  AutoParallel parallel(2);
  GraphDef output;
  Status status = parallel.Optimize(nullptr, item, &output);
  TF_EXPECT_OK(status);
}
152 
153 }  // namespace
154 }  // namespace grappler
155 }  // namespace tensorflow
156