1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/hlo_module_group.h"
17
18 #include "tensorflow/compiler/xla/service/hlo.pb.h"
19 #include "tensorflow/compiler/xla/service/hlo_matchers.h"
20 #include "tensorflow/compiler/xla/service/hlo_module_group_metadata.h"
21 #include "tensorflow/compiler/xla/test.h"
22 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
23 #include "tensorflow/core/lib/core/status_test_util.h"
24
25 namespace xla {
26
27 namespace {
28
29 namespace op = ::xla::testing::opcode_matchers;
30 using ::testing::Property;
31 using ::testing::StrEq;
32
33 class HloModuleGroupTest : public HloTestBase {
34 protected:
35 HloModuleGroupTest() = default;
36 };
37
// Verifies that a group wrapping a single module exposes that module, survives
// a ToProto()/CreateFromProto() round trip, and releases it (leaving the group
// empty) via ConsumeModules().
TEST_F(HloModuleGroupTest, SingleModule) {
  const std::string text = R"(
HloModule simple_module

ENTRY %entry (x: f32[], y: f32[]) -> f32[] {
  %x = f32[] parameter(0)
  %y = f32[] parameter(1)
  ROOT %add = f32[] add(%x, %y)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnVerifiedModule(text));
  HloModuleGroup group(std::move(module));

  // Fatal check: module(0) below indexes into the group, so stop here if the
  // module count is wrong rather than indexing a possibly-empty group.
  ASSERT_EQ(group.modules().size(), 1);
  EXPECT_THAT(
      group.module(0).entry_computation()->instructions(),
      ::testing::ElementsAre(op::Parameter(), op::Parameter(), op::Add()));

  // Rebuild the group from its proto, reusing the original module's config.
  TF_ASSERT_OK_AND_ASSIGN(HloModuleGroup group_copy,
                          HloModuleGroup::CreateFromProto(
                              group.ToProto(), {group.module(0).config()}));
  ASSERT_EQ(group_copy.modules().size(), 1);
  EXPECT_THAT(
      group_copy.module(0).entry_computation()->instructions(),
      ::testing::ElementsAre(op::Parameter(), op::Parameter(), op::Add()));

  // ConsumeModules() hands ownership to the caller and empties the group.
  std::vector<std::unique_ptr<HloModule>> modules = group.ConsumeModules();
  EXPECT_EQ(modules.size(), 1);
  EXPECT_EQ(group.modules().size(), 0);
}
69
// Verifies that a group constructed from a span of modules keeps them in order
// with their entry computations intact, both directly and after a
// ToProto()/CreateFromProto() round trip.
TEST_F(HloModuleGroupTest, MultipleModules) {
  const std::string text_0 = R"(
HloModule module0

ENTRY %entry (x: f32[], y: f32[]) -> f32[] {
  %x = f32[] parameter(0)
  %y = f32[] parameter(1)
  ROOT %add = f32[] add(%x, %y)
}
)";
  const std::string text_1 = R"(
HloModule module1

ENTRY %entry (a: f32[]) -> f32[] {
  ROOT %a = f32[] parameter(0)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module_0,
                          ParseAndReturnVerifiedModule(text_0));
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module_1,
                          ParseAndReturnVerifiedModule(text_1));
  std::vector<std::unique_ptr<HloModule>> modules;
  modules.push_back(std::move(module_0));
  modules.push_back(std::move(module_1));
  HloModuleGroup group(TestName(), absl::MakeSpan(modules));

  // Fatal check: module(0)/module(1) below index into the group.
  ASSERT_EQ(group.modules().size(), 2);
  EXPECT_THAT(
      group.module(0).entry_computation()->instructions(),
      ::testing::ElementsAre(op::Parameter(), op::Parameter(), op::Add()));
  EXPECT_THAT(group.module(1).entry_computation()->instructions(),
              ::testing::ElementsAre(op::Parameter()));

  // Rebuild the group from its proto with one config per module, then check
  // that the copy matches the original module-for-module, not just in size.
  TF_ASSERT_OK_AND_ASSIGN(HloModuleGroup group_copy,
                          HloModuleGroup::CreateFromProto(
                              group.ToProto(), {group.module(0).config(),
                                                group.module(1).config()}));
  ASSERT_EQ(group_copy.modules().size(), 2);
  EXPECT_THAT(
      group_copy.module(0).entry_computation()->instructions(),
      ::testing::ElementsAre(op::Parameter(), op::Parameter(), op::Add()));
  EXPECT_THAT(group_copy.module(1).entry_computation()->instructions(),
              ::testing::ElementsAre(op::Parameter()));
}
108
// Verifies that modules appended one at a time with push_back() are stored in
// insertion order with their entry computations intact.
TEST_F(HloModuleGroupTest, BuildModuleGroupByPushBack) {
  const std::string text_0 = R"(
HloModule module0

ENTRY %entry (x: f32[], y: f32[]) -> f32[] {
  %x = f32[] parameter(0)
  %y = f32[] parameter(1)
  ROOT %add = f32[] add(%x, %y)
}
)";
  const std::string text_1 = R"(
HloModule module1

ENTRY %entry (a: f32[]) -> f32[] {
  ROOT %a = f32[] parameter(0)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module_0,
                          ParseAndReturnVerifiedModule(text_0));
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module_1,
                          ParseAndReturnVerifiedModule(text_1));
  HloModuleGroup group(TestName());
  group.push_back(std::move(module_0));
  group.push_back(std::move(module_1));

  // Fatal check: module(0)/module(1) below index into the group.
  ASSERT_EQ(group.modules().size(), 2);
  EXPECT_THAT(
      group.module(0).entry_computation()->instructions(),
      ::testing::ElementsAre(op::Parameter(), op::Parameter(), op::Add()));
  EXPECT_THAT(group.module(1).entry_computation()->instructions(),
              ::testing::ElementsAre(op::Parameter()));
}
141
142 // Tests that the order of companion instructions in the companion set doesn't
143 // change across runs.
TEST_F(HloModuleGroupTest,ModuleGroupCompanionOrder)144 TEST_F(HloModuleGroupTest, ModuleGroupCompanionOrder) {
145 // A simple while loop template for core i sending to core i+1.
146 constexpr char text[] = R"(
147 HloModule module_%d
148
149 while_cond {
150 param = s32[] parameter(0)
151 ROOT p = pred[] constant(true)
152 }
153
154 while_body {
155 param = s32[] parameter(0)
156 token.s = token[] after-all()
157 token.r = token[] after-all()
158 send = (s32[], u32[], token[]) send(param, token.s), channel_id=%d
159 send-done = token[] send-done(send), channel_id=%d
160 recv = (s32[], u32[], token[]) recv(token.r), channel_id=%d
161 recv-done = (s32[], token[]) recv-done(recv), channel_id=%d
162 ROOT data = s32[] get-tuple-element(recv-done), index=0
163 }
164
165 ENTRY entry {
166 while_init = s32[] constant(1)
167 ROOT while = s32[] while(while_init), condition=while_cond, body=while_body
168 }
169 )";
170
171 // Try creating the module and the metadata kTrialCount times and check the
172 // companion instructions remain in the same order.
173 const int64_t kTrialCount = 5;
174 const int64_t kDeviceCount = 10;
175 std::vector<int64_t> companion_order;
176
177 for (int64_t t = 0; t < kTrialCount; ++t) {
178 HloModuleGroup group(TestName());
179 for (int64_t i = 0; i < kDeviceCount; ++i) {
180 const int64_t send_channel = i;
181 const int64_t recv_channel = i == 0 ? kDeviceCount - 1 : i - 1;
182 TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
183 ParseAndReturnVerifiedModule(absl::StrFormat(
184 text, i, send_channel, send_channel,
185 recv_channel, recv_channel)));
186 group.push_back(std::move(module));
187 }
188 ASSERT_EQ(group.modules().size(), kDeviceCount);
189
190 TF_ASSERT_OK_AND_ASSIGN(auto metadata,
191 HloModuleGroupMetadata::Build(group.modules()));
192 ASSERT_EQ(metadata->companion_sets().size(), 1);
193
194 std::vector<int64_t> module_ids;
195 const auto& companion_sets = *metadata->companion_sets()[0];
196 module_ids.reserve(companion_sets.size());
197 for (HloInstruction* companion : companion_sets) {
198 module_ids.push_back(metadata->GetModuleId(companion->GetModule()));
199 }
200
201 if (t == 0) {
202 companion_order = module_ids;
203 } else {
204 EXPECT_TRUE(absl::c_equal(companion_order, module_ids));
205 }
206 }
207 }
208
209 // Test that metadata is transferred when a module is replaced.
TEST_F(HloModuleGroupTest,ReplaceModuleMetadata)210 TEST_F(HloModuleGroupTest, ReplaceModuleMetadata) {
211 auto old_module = CreateNewVerifiedModule();
212 int old_module_id = old_module->unique_id();
213 old_module->metadata()->RecordPassStart();
214 TF_EXPECT_OK(old_module->metadata()->set_current_pass_name("fake pass"));
215
216 HloModuleGroup group(std::move(old_module));
217 EXPECT_EQ(group.module(0).metadata()->proto().module_group_name(),
218 group.name());
219
220 auto new_module = CreateNewVerifiedModule();
221 group.ReplaceModule(0, std::move(new_module));
222
223 EXPECT_NE(group.module(0).unique_id(), old_module_id);
224 const HloModuleMetadataProto& module_metadata =
225 group.module(0).metadata()->proto();
226 EXPECT_EQ(module_metadata.canonical_module_id(), old_module_id);
227
228 const HloPassMetadata& pass_metadata =
229 *module_metadata.pass_metadata().rbegin();
230 EXPECT_THAT(pass_metadata,
231 Property(&HloPassMetadata::pass_name, StrEq("fake pass")));
232 }
233
234 } // namespace
235
236 } // namespace xla
237