1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/tests/local_client_test_base.h"
17 #include "tensorflow/compiler/xla/tests/test_macros.h"
18 #include "tensorflow/core/lib/core/status_test_util.h"
19
20 namespace xla {
21 namespace {
22
23 // Tests that ensure outfeed instructions that are contained in nested
24 // computations in non-root positions are executed.
25
26 class OutfeedInNestedComputationTest : public LocalClientTestBase {};
27
// Builds a While loop whose condition and body both contain Outfeed ops that
// are not the root of their computation. The host thread then drives the
// loop: it infeeds a trip count of 1 and the body's 2-element input, and it
// checks every value that comes back through outfeed along with the final
// counter.
XLA_TEST_F(OutfeedInNestedComputationTest, OutfeedInWhile) {
  XlaBuilder builder(TestName());

  const Shape scalar_s32 = ShapeUtil::MakeShape(xla::S32, {});
  const Shape buffer_shape = ShapeUtil::MakeShape(xla::S32, {10, 5});
  const Shape loop_state_shape =
      ShapeUtil::MakeTupleShape({scalar_s32, buffer_shape});
  const Shape vec2_s32 = ShapeUtil::MakeShape(xla::S32, {2});

  // Loop state is (counter, zero-filled buffer). The initial counter is read
  // from infeed.
  XlaOp zero_buffer = Broadcast(ConstantR0<int32_t>(&builder, 0), {10, 5});
  XlaOp trip_count = Infeed(&builder, scalar_s32);
  XlaOp initial_state = Tuple(&builder, {trip_count, zero_buffer});

  // Condition: outfeed the counter, then keep looping while it is > 0.
  // Gt is added last so that it, not the Outfeed, is the root.
  TF_ASSERT_OK_AND_ASSIGN(XlaComputation condition, [&] {
    XlaBuilder cond_b("loop_condition");
    XlaOp state = Parameter(&cond_b, 0, loop_state_shape, "state");
    XlaOp counter = GetTupleElement(state, 0);
    Outfeed(counter, scalar_s32, "");
    Gt(counter, ConstantR0<int32_t>(&cond_b, 0));
    return cond_b.Build();
  }());

  // Body: infeed a 2-vector, outfeed (vector + counter), then decrement the
  // counter. The trailing Tuple is the root and becomes the next loop state.
  TF_ASSERT_OK_AND_ASSIGN(XlaComputation body, [&] {
    XlaBuilder body_b("loop_body");
    XlaOp state = Parameter(&body_b, 0, loop_state_shape, "state");
    XlaOp counter = GetTupleElement(state, 0);
    XlaOp buffer = GetTupleElement(state, 1);

    XlaOp fed = Infeed(&body_b, vec2_s32);
    Outfeed(Add(fed, Broadcast(counter, {2})), vec2_s32, "");

    XlaOp decremented = Sub(counter, ConstantR0<int32_t>(&body_b, 1));
    Tuple(&body_b, {decremented, buffer});
    return body_b.Build();
  }());

  // The entry computation's root is the final counter value.
  GetTupleElement(While(condition, body, initial_state), 0);
  TF_ASSERT_OK_AND_ASSIGN(XlaComputation computation, builder.Build());

  // Execution blocks on infeed, so it runs on a worker thread while this
  // thread feeds inputs and drains outfeeds.
  Literal final_result;
  std::unique_ptr<tensorflow::Thread> executor(
      tensorflow::Env::Default()->StartThread(
          tensorflow::ThreadOptions(), "execute_thread", [&] {
            final_result =
                local_client_->ExecuteAndTransfer(computation, {}).value();
          }));

  VLOG(1) << "Transferring trip count to computation";
  TF_ASSERT_OK(
      local_client_->TransferToInfeed(LiteralUtil::CreateR0<int32_t>(1)));

  // First condition evaluation outfeeds the initial counter (1).
  {
    VLOG(1) << "Reading from condition outfeed";
    TF_ASSERT_OK_AND_ASSIGN(Literal outfed,
                            local_client_->TransferFromOutfeed(&scalar_s32));
    EXPECT_EQ(outfed.Get<int32_t>({}), 1);
  }

  VLOG(1) << "Writing data to infeed";
  TF_ASSERT_OK(local_client_->TransferToInfeed(
      LiteralUtil::CreateR1<int32_t>({10, 20})));

  // The body outfeeds {10, 20} + counter(1).
  {
    VLOG(1) << "Reading from body outfeed";
    TF_ASSERT_OK_AND_ASSIGN(Literal outfed,
                            local_client_->TransferFromOutfeed(&vec2_s32));
    EXPECT_EQ(outfed.Get<int32_t>({0}), 11);
    EXPECT_EQ(outfed.Get<int32_t>({1}), 21);
  }

  // Second condition evaluation outfeeds the decremented counter (0), which
  // ends the loop.
  {
    VLOG(1) << "Reading from condition outfeed";
    TF_ASSERT_OK_AND_ASSIGN(Literal outfed,
                            local_client_->TransferFromOutfeed(&scalar_s32));
    EXPECT_EQ(outfed.Get<int32_t>({}), 0);
  }

  // Destroying the thread object joins it, after which final_result is safe
  // to read.
  executor.reset();

  EXPECT_EQ(final_result.Get<int32_t>({}), 0);
}
120
// Builds a Conditional whose true branch contains an Outfeed that is not the
// branch's root. It then checks that, with an infed predicate of `true`, the
// outfeed fires with the operand value and the conditional returns the true
// branch's result.
XLA_TEST_F(OutfeedInNestedComputationTest, OutfeedInConditional) {
  XlaBuilder b(TestName());

  Shape condition_shape = ShapeUtil::MakeShape(xla::PRED, {});
  Shape result_shape = ShapeUtil::MakeShape(xla::PRED, {});

  // True branch: outfeed the operand, then return (param | param), i.e. the
  // operand itself. The Or is added last so that it, not the Outfeed, is the
  // root.
  TF_ASSERT_OK_AND_ASSIGN(XlaComputation true_computation, [&] {
    XlaBuilder inner_builder("true_computation");
    XlaOp param = Parameter(&inner_builder, 0, result_shape, "param");
    Outfeed(param, result_shape, "");
    Or(param, param);
    return inner_builder.Build();
  }());

  // False branch: returns its parameter unchanged. This test never takes it.
  TF_ASSERT_OK_AND_ASSIGN(XlaComputation false_computation, [&] {
    XlaBuilder inner_builder("false_computation");
    Parameter(&inner_builder, 0, result_shape, "param");
    return inner_builder.Build();
  }());

  // The infed predicate is used both to select the branch and as the operand
  // of each branch.
  XlaOp pred = Infeed(&b, condition_shape);
  Conditional(/*predicate=*/pred, /*true_operand=*/pred,
              /*true_computation=*/true_computation, /*false_operand=*/pred,
              /*false_computation=*/false_computation);

  TF_ASSERT_OK_AND_ASSIGN(XlaComputation computation, b.Build());

  // Execution blocks on infeed, so it runs on a worker thread while this
  // thread feeds the predicate and drains the outfeed.
  // NOTE(review): if an ASSERT below fails before the infeed is written,
  // destroying `thread` joins a worker that may still be waiting on infeed.
  Literal comp_result;
  std::unique_ptr<tensorflow::Thread> thread(
      tensorflow::Env::Default()->StartThread(
          tensorflow::ThreadOptions(), "execute_thread", [&] {
            comp_result =
                local_client_->ExecuteAndTransfer(computation, {}).value();
          }));

  VLOG(1) << "Writing predicate to infeed";
  TF_ASSERT_OK(
      local_client_->TransferToInfeed(LiteralUtil::CreateR0<bool>(true)));

  VLOG(1) << "Reading from true-branch outfeed";
  TF_ASSERT_OK_AND_ASSIGN(Literal r,
                          local_client_->TransferFromOutfeed(&result_shape));

  EXPECT_EQ(r.Get<bool>({}), true);

  // Join the thread; comp_result is safe to read only after this.
  thread.reset();

  // The true branch returns (true | true) == true. Previously the executed
  // result was never checked.
  EXPECT_EQ(comp_result.Get<bool>({}), true);
}
167
168 } // namespace
169 } // namespace xla
170