xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/tests/outfeed_in_nested_computation_test.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/tests/local_client_test_base.h"
17 #include "tensorflow/compiler/xla/tests/test_macros.h"
18 #include "tensorflow/core/lib/core/status_test_util.h"
19 
20 namespace xla {
21 namespace {
22 
23 // Tests that ensure outfeed instructions that are contained in nested
24 // computations in non-root positions are executed.
25 
26 class OutfeedInNestedComputationTest : public LocalClientTestBase {};
27 
XLA_TEST_F(OutfeedInNestedComputationTest,OutfeedInWhile)28 XLA_TEST_F(OutfeedInNestedComputationTest, OutfeedInWhile) {
29   XlaBuilder b(TestName());
30 
31   Shape state_tuple_array_shape = ShapeUtil::MakeShape(xla::S32, {10, 5});
32   Shape int_shape = ShapeUtil::MakeShape(xla::S32, {});
33   Shape state_tuple_shape =
34       ShapeUtil::MakeTupleShape({int_shape, state_tuple_array_shape});
35   Shape xfeed_shape = ShapeUtil::MakeShape(xla::S32, {2});
36 
37   XlaOp some_buffer = Broadcast(ConstantR0<int32_t>(&b, 0), {10, 5});
38   XlaOp num_iter = Infeed(&b, int_shape);
39   XlaOp init_tuple = Tuple(&b, {num_iter, some_buffer});
40 
41   TF_ASSERT_OK_AND_ASSIGN(XlaComputation loop_cond, [&] {
42     // Condition: iteration variable > 0
43     XlaBuilder cond_builder("loop_condition");
44     XlaOp state_tuple = Parameter(&cond_builder, 0, state_tuple_shape, "state");
45     XlaOp loop_counter = GetTupleElement(state_tuple, 0);
46     Outfeed(loop_counter, int_shape, "");
47     Gt(loop_counter, ConstantR0<int32_t>(&cond_builder, 0));
48     return cond_builder.Build();
49   }());
50 
51   TF_ASSERT_OK_AND_ASSIGN(XlaComputation loop_body, [&] {
52     XlaBuilder body_builder("loop_body");
53     XlaOp state_tuple = Parameter(&body_builder, 0, state_tuple_shape, "state");
54     XlaOp loop_counter = GetTupleElement(state_tuple, 0);
55     XlaOp buffer_inside = GetTupleElement(state_tuple, 1);
56 
57     // Read some stuff from Infeed.
58     XlaOp some_input = Infeed(&body_builder, xfeed_shape);
59     XlaOp sum = Add(some_input, Broadcast(loop_counter, {2}));
60     Outfeed(sum, xfeed_shape, "");
61 
62     XlaOp iter_left = Sub(loop_counter, ConstantR0<int32_t>(&body_builder, 1));
63 
64     Tuple(&body_builder, {iter_left, buffer_inside});
65     return body_builder.Build();
66   }());
67 
68   // Build loop.
69   XlaOp result_tuple = While(loop_cond, loop_body, init_tuple);
70   GetTupleElement(result_tuple, 0);
71   TF_ASSERT_OK_AND_ASSIGN(XlaComputation computation, b.Build());
72 
73   Literal comp_result;
74   std::unique_ptr<tensorflow::Thread> thread(
75       tensorflow::Env::Default()->StartThread(
76           tensorflow::ThreadOptions(), "execute_thread", [&] {
77             comp_result =
78                 local_client_->ExecuteAndTransfer(computation, {}).value();
79           }));
80 
81   VLOG(1) << "Transferring trip count to computation";
82   // Transfer number of iterations to Infeed.
83   TF_ASSERT_OK(
84       local_client_->TransferToInfeed(LiteralUtil::CreateR0<int32_t>(1)));
85 
86   // Pick up value from outfeed
87   {
88     VLOG(1) << "Reading from condition outfeed";
89     TF_ASSERT_OK_AND_ASSIGN(Literal r,
90                             local_client_->TransferFromOutfeed(&int_shape));
91     EXPECT_EQ(r.Get<int32_t>({}), 1);
92   }
93 
94   VLOG(1) << "Writing data to infeed";
95   // Transfer some stuff to Infeed for use inside of loop.
96   TF_ASSERT_OK(local_client_->TransferToInfeed(
97       LiteralUtil::CreateR1<int32_t>({10, 20})));
98 
99   // Pick up value from outfeed
100   {
101     VLOG(1) << "Reading from body outfeed";
102     TF_ASSERT_OK_AND_ASSIGN(Literal r,
103                             local_client_->TransferFromOutfeed(&xfeed_shape));
104     EXPECT_EQ(r.Get<int32_t>({0}), 11);
105     EXPECT_EQ(r.Get<int32_t>({1}), 21);
106   }
107 
108   {
109     VLOG(1) << "Reading from condition outfeed";
110     TF_ASSERT_OK_AND_ASSIGN(Literal r,
111                             local_client_->TransferFromOutfeed(&int_shape));
112     EXPECT_EQ(r.Get<int32_t>({}), 0);
113   }
114 
115   // Joins the thread
116   thread.reset();
117 
118   EXPECT_EQ(comp_result.Get<int32_t>({}), 0);
119 }
120 
// Runs a conditional whose true branch contains an outfeed in a non-root
// position and verifies that the outfed value reaches the host.
//
// The predicate is read from infeed and is also passed as the operand of both
// branches. The true branch outfeeds its operand and returns `param || param`;
// the false branch returns its operand unchanged. The host feeds `true`, so it
// expects to read `true` from the outfeed and to get `true` as the result.
XLA_TEST_F(OutfeedInNestedComputationTest, OutfeedInConditional) {
  XlaBuilder b(TestName());

  Shape condition_shape = ShapeUtil::MakeShape(xla::PRED, {});
  Shape result_shape = ShapeUtil::MakeShape(xla::PRED, {});

  TF_ASSERT_OK_AND_ASSIGN(XlaComputation true_computation, [&] {
    XlaBuilder inner_builder("true_computation");
    XlaOp param = Parameter(&inner_builder, 0, result_shape, "param");
    Outfeed(param, result_shape, "");
    // Added after the outfeed so that the outfeed is not the root instruction.
    Or(param, param);
    return inner_builder.Build();
  }());

  TF_ASSERT_OK_AND_ASSIGN(XlaComputation false_computation, [&] {
    // Identity: the parameter is the only instruction in this branch.
    XlaBuilder inner_builder("false_computation");
    Parameter(&inner_builder, 0, result_shape, "param");
    return inner_builder.Build();
  }());

  XlaOp pred = Infeed(&b, condition_shape);
  Conditional(/*predicate=*/pred, /*true_operand=*/pred,
              /*true_computation=*/true_computation, /*false_operand=*/pred,
              /*false_computation=*/false_computation);

  TF_ASSERT_OK_AND_ASSIGN(XlaComputation computation, b.Build());

  // Execute on a separate thread: the computation blocks on infeed/outfeed
  // while this thread plays the host side of the exchange.
  Literal comp_result;
  std::unique_ptr<tensorflow::Thread> thread(
      tensorflow::Env::Default()->StartThread(
          tensorflow::ThreadOptions(), "execute_thread", [&] {
            comp_result =
                local_client_->ExecuteAndTransfer(computation, {}).value();
          }));

  // Select the true branch, which is the one containing the outfeed.
  TF_ASSERT_OK(
      local_client_->TransferToInfeed(LiteralUtil::CreateR0<bool>(true)));

  TF_ASSERT_OK_AND_ASSIGN(Literal r,
                          local_client_->TransferFromOutfeed(&result_shape));

  EXPECT_EQ(r.Get<bool>({}), true);

  // Destroying the thread object joins it, so comp_result is safe to read.
  thread.reset();

  // The true branch returns `param || param` with param == true. Checking the
  // result mirrors OutfeedInWhile, which also verifies its result after the
  // join.
  EXPECT_EQ(comp_result.Get<bool>({}), true);
}
167 
168 }  // namespace
169 }  // namespace xla
170