xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/service/hlo_module_group_test.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/hlo_module_group.h"
17 
18 #include "tensorflow/compiler/xla/service/hlo.pb.h"
19 #include "tensorflow/compiler/xla/service/hlo_matchers.h"
20 #include "tensorflow/compiler/xla/service/hlo_module_group_metadata.h"
21 #include "tensorflow/compiler/xla/test.h"
22 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
23 #include "tensorflow/core/lib/core/status_test_util.h"
24 
25 namespace xla {
26 
27 namespace {
28 
29 namespace op = ::xla::testing::opcode_matchers;
30 using ::testing::Property;
31 using ::testing::StrEq;
32 
33 class HloModuleGroupTest : public HloTestBase {
34  protected:
35   HloModuleGroupTest() = default;
36 };
37 
// Wraps one parsed module in an HloModuleGroup and checks that the module is
// reachable by index, that a ToProto()/CreateFromProto() round trip rebuilds
// it, and that ConsumeModules() hands it back while leaving the group empty.
TEST_F(HloModuleGroupTest, SingleModule) {
  const std::string text = R"(
HloModule simple_module

ENTRY %entry (x: f32[], y: f32[]) -> f32[] {
  %x = f32[] parameter(0)
  %y = f32[] parameter(1)
  ROOT %add = f32[] add(%x, %y)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> parsed,
                          ParseAndReturnVerifiedModule(text));
  HloModuleGroup group(std::move(parsed));

  // Two parameters feeding an add, in that instruction order.
  const auto expected_entry =
      ::testing::ElementsAre(op::Parameter(), op::Parameter(), op::Add());

  EXPECT_EQ(group.modules().size(), 1);
  EXPECT_THAT(group.module(0).entry_computation()->instructions(),
              expected_entry);

  // Serialize and rebuild the group; the copy must look the same.
  TF_ASSERT_OK_AND_ASSIGN(
      HloModuleGroup round_tripped,
      HloModuleGroup::CreateFromProto(group.ToProto(),
                                      {group.module(0).config()}));
  EXPECT_EQ(round_tripped.modules().size(), 1);
  EXPECT_THAT(round_tripped.module(0).entry_computation()->instructions(),
              expected_entry);

  // Taking ownership of the modules must leave the group empty.
  const std::vector<std::unique_ptr<HloModule>> released =
      group.ConsumeModules();
  EXPECT_EQ(released.size(), 1);
  EXPECT_EQ(group.modules().size(), 0);
}
69 
// Builds a group from a span of two modules and checks that both are reachable
// by index in insertion order. Then checks that a ToProto()/CreateFromProto()
// round trip reproduces each module, not merely the module count.
TEST_F(HloModuleGroupTest, MultipleModules) {
  const std::string text_0 = R"(
HloModule module0

ENTRY %entry (x: f32[], y: f32[]) -> f32[] {
  %x = f32[] parameter(0)
  %y = f32[] parameter(1)
  ROOT %add = f32[] add(%x, %y)
}
)";
  const std::string text_1 = R"(
HloModule module1

ENTRY %entry (a: f32[]) -> f32[] {
  ROOT %a = f32[] parameter(0)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module_0,
                          ParseAndReturnVerifiedModule(text_0));
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module_1,
                          ParseAndReturnVerifiedModule(text_1));
  std::vector<std::unique_ptr<HloModule>> modules;
  modules.push_back(std::move(module_0));
  modules.push_back(std::move(module_1));
  HloModuleGroup group(TestName(), absl::MakeSpan(modules));

  // Expected entry computations of module0 and module1 respectively.
  const auto expected_entry_0 =
      ::testing::ElementsAre(op::Parameter(), op::Parameter(), op::Add());
  const auto expected_entry_1 = ::testing::ElementsAre(op::Parameter());

  EXPECT_EQ(group.modules().size(), 2);
  EXPECT_THAT(group.module(0).entry_computation()->instructions(),
              expected_entry_0);
  EXPECT_THAT(group.module(1).entry_computation()->instructions(),
              expected_entry_1);

  TF_ASSERT_OK_AND_ASSIGN(HloModuleGroup group_copy,
                          HloModuleGroup::CreateFromProto(
                              group.ToProto(), {group.module(0).config(),
                                                group.module(1).config()}));
  // ASSERT (not EXPECT): the checks below index module(1).
  ASSERT_EQ(group_copy.modules().size(), 2);
  // A count check alone would not catch a module being dropped, duplicated or
  // reordered by the round trip, so compare each module's contents too.
  EXPECT_THAT(group_copy.module(0).entry_computation()->instructions(),
              expected_entry_0);
  EXPECT_THAT(group_copy.module(1).entry_computation()->instructions(),
              expected_entry_1);
}
108 
// Starts from an empty, named group and appends modules one at a time with
// push_back(); the modules must then be reachable in the order they were added.
TEST_F(HloModuleGroupTest, BuildModuleGroupByPushBack) {
  const std::string text_0 = R"(
HloModule module0

ENTRY %entry (x: f32[], y: f32[]) -> f32[] {
  %x = f32[] parameter(0)
  %y = f32[] parameter(1)
  ROOT %add = f32[] add(%x, %y)
}
)";
  const std::string text_1 = R"(
HloModule module1

ENTRY %entry (a: f32[]) -> f32[] {
  ROOT %a = f32[] parameter(0)
}
)";
  HloModuleGroup group(TestName());
  // Pointers avoid copying the HLO text into the initializer list.
  for (const std::string* hlo_text : {&text_0, &text_1}) {
    TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> parsed,
                            ParseAndReturnVerifiedModule(*hlo_text));
    group.push_back(std::move(parsed));
  }

  EXPECT_EQ(group.modules().size(), 2);
  EXPECT_THAT(
      group.module(0).entry_computation()->instructions(),
      ::testing::ElementsAre(op::Parameter(), op::Parameter(), op::Add()));
  EXPECT_THAT(group.module(1).entry_computation()->instructions(),
              ::testing::ElementsAre(op::Parameter()));
}
141 
142 // Tests that the order of companion instructions in the companion set doesn't
143 // change across runs.
TEST_F(HloModuleGroupTest,ModuleGroupCompanionOrder)144 TEST_F(HloModuleGroupTest, ModuleGroupCompanionOrder) {
145   // A simple while loop template for core i sending to core i+1.
146   constexpr char text[] = R"(
147 HloModule module_%d
148 
149 while_cond {
150   param = s32[] parameter(0)
151   ROOT p = pred[] constant(true)
152 }
153 
154 while_body {
155   param = s32[] parameter(0)
156   token.s = token[] after-all()
157   token.r = token[] after-all()
158   send = (s32[], u32[], token[]) send(param, token.s), channel_id=%d
159   send-done = token[] send-done(send), channel_id=%d
160   recv = (s32[], u32[], token[]) recv(token.r), channel_id=%d
161   recv-done = (s32[], token[]) recv-done(recv), channel_id=%d
162   ROOT data = s32[] get-tuple-element(recv-done), index=0
163 }
164 
165 ENTRY entry {
166   while_init = s32[] constant(1)
167   ROOT while = s32[] while(while_init), condition=while_cond, body=while_body
168 }
169 )";
170 
171   // Try creating the module and the metadata kTrialCount times and check the
172   // companion instructions remain in the same order.
173   const int64_t kTrialCount = 5;
174   const int64_t kDeviceCount = 10;
175   std::vector<int64_t> companion_order;
176 
177   for (int64_t t = 0; t < kTrialCount; ++t) {
178     HloModuleGroup group(TestName());
179     for (int64_t i = 0; i < kDeviceCount; ++i) {
180       const int64_t send_channel = i;
181       const int64_t recv_channel = i == 0 ? kDeviceCount - 1 : i - 1;
182       TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
183                               ParseAndReturnVerifiedModule(absl::StrFormat(
184                                   text, i, send_channel, send_channel,
185                                   recv_channel, recv_channel)));
186       group.push_back(std::move(module));
187     }
188     ASSERT_EQ(group.modules().size(), kDeviceCount);
189 
190     TF_ASSERT_OK_AND_ASSIGN(auto metadata,
191                             HloModuleGroupMetadata::Build(group.modules()));
192     ASSERT_EQ(metadata->companion_sets().size(), 1);
193 
194     std::vector<int64_t> module_ids;
195     const auto& companion_sets = *metadata->companion_sets()[0];
196     module_ids.reserve(companion_sets.size());
197     for (HloInstruction* companion : companion_sets) {
198       module_ids.push_back(metadata->GetModuleId(companion->GetModule()));
199     }
200 
201     if (t == 0) {
202       companion_order = module_ids;
203     } else {
204       EXPECT_TRUE(absl::c_equal(companion_order, module_ids));
205     }
206   }
207 }
208 
209 // Test that metadata is transferred when a module is replaced.
TEST_F(HloModuleGroupTest,ReplaceModuleMetadata)210 TEST_F(HloModuleGroupTest, ReplaceModuleMetadata) {
211   auto old_module = CreateNewVerifiedModule();
212   int old_module_id = old_module->unique_id();
213   old_module->metadata()->RecordPassStart();
214   TF_EXPECT_OK(old_module->metadata()->set_current_pass_name("fake pass"));
215 
216   HloModuleGroup group(std::move(old_module));
217   EXPECT_EQ(group.module(0).metadata()->proto().module_group_name(),
218             group.name());
219 
220   auto new_module = CreateNewVerifiedModule();
221   group.ReplaceModule(0, std::move(new_module));
222 
223   EXPECT_NE(group.module(0).unique_id(), old_module_id);
224   const HloModuleMetadataProto& module_metadata =
225       group.module(0).metadata()->proto();
226   EXPECT_EQ(module_metadata.canonical_module_id(), old_module_id);
227 
228   const HloPassMetadata& pass_metadata =
229       *module_metadata.pass_metadata().rbegin();
230   EXPECT_THAT(pass_metadata,
231               Property(&HloPassMetadata::pass_name, StrEq("fake pass")));
232 }
233 
234 }  // namespace
235 
236 }  // namespace xla
237