/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/tests/hlo_test_base.h"

#include <functional>
#include <memory>
#include <optional>
#include <set>
#include <string>
#include <utility>
#include <vector>

#include "absl/algorithm/container.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/debug_options_flags.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_module_util.h"
#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/service/platform_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/tests/filecheck.h"
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
#include "tensorflow/compiler/xla/tests/test_utils.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/test.h"

namespace xla {

namespace {

using absl::string_view;
using std::optional;

constexpr char kInterpreter[] = "interpreter";

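// Returns true iff the two program shapes have the same parameter shapes and
// result shape. Comparison is layout-sensitive (ShapeUtil::Equal).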
bool ProgramShapesEqual(const ProgramShape& lhs, const ProgramShape& rhs) {
  if (lhs.parameters_size() != rhs.parameters_size()) {
    return false;
  }
  for (int i = 0; i < lhs.parameters_size(); i++) {
    if (!ShapeUtil::Equal(lhs.parameters(i), rhs.parameters(i))) {
      return false;
    }
  }
  return ShapeUtil::Equal(lhs.result(), rhs.result());
}

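// Builds a ProgramShape, including layouts, from the shapes of the entry
// computation's parameters and root instruction.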
ProgramShape GetProgramShapeWithLayout(const HloModule& module) {
  ProgramShape program_shape;
  const auto* entry = module.entry_computation();
  for (const auto* param : entry->parameter_instructions()) {
    *program_shape.add_parameters() = param->shape();
    *program_shape.add_parameter_names() = param->name();
  }
  *program_shape.mutable_result() = entry->root_instruction()->shape();
  return program_shape;
}

}  // namespace

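// The default constructor runs tests on the platform returned by
// GetTestPlatform() and checks results against GetReferencePlatform().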
HloTestBase::HloTestBase(bool verifier_layout_sensitive,
                         bool allow_mixed_precision_in_hlo_verifier,
                         HloPredicate instruction_can_change_layout_func)
    : HloTestBase(GetTestPlatform(), GetReferencePlatform(),
                  verifier_layout_sensitive,
                  allow_mixed_precision_in_hlo_verifier,
                  instruction_can_change_layout_func) {}

HloTestBase::HloTestBase(se::Platform* test_platform,
                         se::Platform* reference_platform,
                         bool verifier_layout_sensitive,
                         bool allow_mixed_precision_in_hlo_verifier,
                         HloPredicate instruction_can_change_layout_func)
    : test_runner_(test_platform),
      reference_runner_(reference_platform),
      verifier_layout_sensitive_(verifier_layout_sensitive),
      allow_mixed_precision_in_hlo_verifier_(
          allow_mixed_precision_in_hlo_verifier) {
  hlo_verifier_ = std::make_unique<HloVerifier>(
      /*layout_sensitive=*/verifier_layout_sensitive,
      /*allow_mixed_precision=*/allow_mixed_precision_in_hlo_verifier,
      instruction_can_change_layout_func);
}

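// The reference platform is always the interpreter: slow, but a
// platform-independent source of truth for RunAndCompare-style tests.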
/*static*/ se::Platform* HloTestBase::GetReferencePlatform() {
  auto result = PlatformUtil::GetPlatform(kInterpreter);
  TF_CHECK_OK(result.status()) << "could not get interpreter platform";
  return result.ValueOrDie();
}

/*static*/ se::Platform* HloTestBase::GetTestPlatform() {
  auto result = PlatformUtil::GetDefaultPlatform();
  TF_CHECK_OK(result.status()) << "could not get test platform";
  return result.ValueOrDie();
}

std::unique_ptr<HloModule> HloTestBase::CreateNewUnverifiedModule(
    const std::string& name) {
  return std::make_unique<HloModule>(name, GetModuleConfigForTest());
}

std::unique_ptr<VerifiedHloModule> HloTestBase::CreateNewVerifiedModule(
    const std::string& name, int64_t replica_count) {
  return std::make_unique<VerifiedHloModule>(
      name, GetModuleConfigForTest(replica_count), verifier_layout_sensitive_,
      allow_mixed_precision_in_hlo_verifier_,
      backend().compiler()->ShapeSizeBytesFunction());
}

StatusOr<std::unique_ptr<VerifiedHloModule>>
HloTestBase::ParseAndReturnVerifiedModule(absl::string_view hlo_text,
                                          int64_t replica_count,
                                          int64_t num_partitions) {
  TF_ASSIGN_OR_RETURN(
      auto module,
      ParseAndReturnVerifiedModule(
          hlo_text, GetModuleConfigForTest(replica_count, num_partitions)));
  UpdateEntryComputationLayout(module.get());
  return module;
}

StatusOr<std::unique_ptr<VerifiedHloModule>>
HloTestBase::ParseAndReturnVerifiedModule(absl::string_view hlo_text,
                                          const HloModuleConfig& config) {
  auto module = std::make_unique<VerifiedHloModule>(
      TestName(), config, verifier_layout_sensitive_,
      allow_mixed_precision_in_hlo_verifier_,
      backend().compiler()->ShapeSizeBytesFunction());
  TF_RETURN_IF_ERROR(module->ParseHloStringAndVerifyModule(hlo_text));
  UpdateEntryComputationLayout(module.get());
  return std::move(module);
}

HloComputation* HloTestBase::AddEntryComputationAndUpdateEntryComputationLayout(
    HloModule* module, std::unique_ptr<HloComputation> computation) {
  auto comp = module->AddEntryComputation(std::move(computation));
  UpdateEntryComputationLayout(module);
  return comp;
}

void HloTestBase::UpdateEntryComputationLayout(HloModule* module) {
  xla::UpdateEntryComputationLayout(
      module, test_runner_.device_shape_representation_fn());
}

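// Serializes the module proto before and after running the pass so the
// pass's changed/unchanged return value can be cross-checked against the
// module's actual contents.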
/* static */
StatusOr<bool> HloTestBase::RunHloPass(HloPassInterface* hlo_pass,
                                       HloModule* module) {
  const std::string module_str_before_run =
      module->ToProto().ShortDebugString();
  const auto status_or = hlo_pass->Run(module);
  if (status_or.status().ok()) {
    const std::string module_str_after_run =
        module->ToProto().ShortDebugString();
    const bool pass_changed_hlo = status_or.ValueOrDie();
    if (pass_changed_hlo) {
      // Check that the proto actually changed.
      EXPECT_NE(module_str_after_run, module_str_before_run);
    } else {
      // Check that the proto remains the same.
      EXPECT_EQ(module_str_after_run, module_str_before_run);
    }
  }
  return status_or;
}

/* static */
PrecisionConfig HloTestBase::DefaultPrecisionConfig(int operands) {
  PrecisionConfig precision_config;
  precision_config.mutable_operand_precision()->Resize(
      operands, PrecisionConfig::DEFAULT);
  return precision_config;
}

void HloTestBase::SetAotFastMathDebugOptions(DebugOptions* options) {
  options->set_xla_cpu_enable_fast_math(true);
  options->set_xla_gpu_enable_fast_min_max(true);
  options->set_xla_cpu_enable_fast_min_max(true);
  options->set_xla_cpu_fast_math_honor_nans(false);
  options->set_xla_cpu_fast_math_honor_infs(false);
  options->set_xla_cpu_fast_math_honor_functions(false);
  options->set_xla_cpu_fast_math_honor_division(false);
}

DebugOptions HloTestBase::GetDebugOptionsForTest() {
  auto debug_options = GetDebugOptionsFromFlags();
  // TODO(b/38354253): Change tests to use Parameters instead of Constants.
  debug_options.add_xla_disable_hlo_passes("constant_folding");
  debug_options.set_xla_gpu_max_kernel_unroll_factor(1);
  debug_options.set_xla_hlo_evaluator_use_fast_path(true);
  return debug_options;
}

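// Parses `hlo`, runs `hlo_pass` over it, and FileChecks the result against
// `expected`; passing std::nullopt asserts that the pass makes no change.
// A sketch of typical usage (pass type and pattern are illustrative, not
// part of this file):
//
//   RunAndFilecheckHloRewrite(kHloText, MyRewritePass(),
//                             "// CHECK: fusion");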
void HloTestBase::RunAndFilecheckHloRewrite(
    absl::string_view hlo, HloPassInterface&& hlo_pass,
    std::optional<absl::string_view> expected,
    std::function<void(HloModule*)> after_pass_checks) {
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
                          ParseAndReturnVerifiedModule(hlo));
  TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloPass(&hlo_pass, module.get()));
  EXPECT_EQ(changed, expected.has_value());
  if (changed) {
    TF_ASSERT_OK_AND_ASSIGN(
        bool filecheck_matches,
        RunFileCheck(
            module->ToString(HloPrintOptions{}.set_print_operand_shape(false)),
            *expected));
    EXPECT_TRUE(filecheck_matches);
    if (after_pass_checks) {
      after_pass_checks(module.get());
    }
  }
}

StatusOr<Literal> HloTestBase::Execute(std::unique_ptr<HloModule> module,
                                       absl::Span<Literal* const> arguments) {
  return test_runner_.Execute(std::move(module), arguments);
}

Literal HloTestBase::ExecuteNoHloPasses(std::unique_ptr<HloModule> module,
                                        absl::Span<Literal* const> arguments) {
  return test_runner_
      .Execute(std::move(module), arguments,
               /*run_hlo_passes=*/false)
      .ValueOrDie();
}

Literal HloTestBase::ExecuteAndTransfer(std::unique_ptr<HloModule> module,
                                        absl::Span<Literal* const> arguments) {
  return test_runner_.Execute(std::move(module), arguments).ValueOrDie();
}

StatusOr<std::vector<Literal>> HloTestBase::ExecuteReplicated(
    std::unique_ptr<HloModule> module, absl::Span<Literal* const> arguments,
    int64_t num_replicas, bool use_threads, bool run_hlo_passes) {
  HloRunner::ReplicatedExecuteOptions options;
  options.num_replicas = num_replicas;
  options.run_hlo_passes = run_hlo_passes;
  options.use_threads = use_threads;
  for (auto argument : arguments) {
    options.arguments.push_back(argument);
  }
  return test_runner_.ExecuteReplicated(std::move(module), options);
}

StatusOr<std::vector<Literal>> HloTestBase::ExecuteReplicated(
    std::unique_ptr<HloModule> module, absl::Span<Literal* const> arguments,
    int64_t num_replicas, DeviceAssignment* device_assignment,
    bool run_hlo_passes, bool use_threads) {
  HloRunner::ReplicatedExecuteOptions options;
  options.num_replicas = num_replicas;
  options.run_hlo_passes = run_hlo_passes;
  options.use_threads = use_threads;
  for (auto argument : arguments) {
    options.arguments.push_back(argument);
  }
  return test_runner_.ExecuteReplicated(std::move(module), options,
                                        device_assignment);
}

StatusOr<std::vector<Literal>> HloTestBase::ExecuteReplicated(
    std::function<Executable*(int64_t)> executable_provider,
    std::function<int64_t(int64_t)> argument_count_provider,
    std::function<const Literal*(int64_t, int64_t)> argument_provider,
    int64_t num_replicas, bool run_hlo_passes,
    DeviceAssignment* device_assignment) {
  HloRunner::ReplicatedExecuteOptions options;
  options.num_replicas = num_replicas;
  options.run_hlo_passes = run_hlo_passes;
  options.use_threads = true;
  return test_runner_.ExecuteReplicated(
      executable_provider, argument_count_provider, argument_provider, options,
      device_assignment);
}

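// Clones `test_module` for execution on the reference backend. The optional
// `reference_preprocessor` may rewrite the clone, but it must preserve the
// program shape.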
StatusOr<std::unique_ptr<HloModule>> HloTestBase::MakeReferenceModule(
    const HloModule& test_module,
    const std::function<void(HloModule*)>& reference_preprocessor) {
  std::unique_ptr<HloModule> reference_module = test_module.Clone();
  const auto& program_shape = GetProgramShapeWithLayout(test_module);

  if (reference_preprocessor != nullptr) {
    reference_preprocessor(reference_module.get());
    if (!ProgramShapesEqual(program_shape,
                            GetProgramShapeWithLayout(*reference_module))) {
      return InvalidArgument(
          "reference preprocessor must not modify the program shape");
    }
  }
  TF_RETURN_IF_ERROR(hlo_verifier_->Run(reference_module.get()).status());
  return std::move(reference_module);
}

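// Verifies the module, executes it and its reference clone on the test and
// reference backends with identical arguments, and compares the results:
// LiteralTestUtil::Near when `error` is set, exact Equal otherwise.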
StatusOr<::testing::AssertionResult> HloTestBase::RunAndCompareInternal(
    std::unique_ptr<HloModule> module,
    const absl::Span<Literal* const> arguments,
    const optional<ErrorSpec>& error, bool run_hlo_passes,
    const std::function<void(HloModule*)>& reference_preprocessor) {
  TF_RETURN_IF_ERROR(hlo_verifier_->Run(module.get()).status());
  TF_ASSIGN_OR_RETURN(auto reference_module,
                      MakeReferenceModule(*module, reference_preprocessor));

  // Execute on two backends.
  TF_ASSIGN_OR_RETURN(
      auto test,
      test_runner_.Execute(std::move(module), arguments, run_hlo_passes));
  TF_ASSIGN_OR_RETURN(auto reference,
                      reference_runner_.Execute(std::move(reference_module),
                                                arguments, run_hlo_passes));
  if (reference.IsAll(0)) {
    LOG(WARNING) << "Reference value is only zeros.";
  }

  return LiteralTestUtil::NearOrEqual(/*expected=*/reference, /*actual=*/test,
                                      error);
}

::testing::AssertionResult HloTestBase::RunAndCompare(
    std::unique_ptr<HloModule> module,
    const absl::Span<Literal* const> arguments,
    const optional<ErrorSpec>& error,
    const std::function<void(HloModule*)>& reference_preprocessor) {
  auto result =
      RunAndCompareInternal(std::move(module), arguments, error,
                            /*run_hlo_passes=*/true, reference_preprocessor);
  if (!result.ok()) {
    return ::testing::AssertionFailure() << result.status();
  }
  return result.ValueOrDie();
}

::testing::AssertionResult HloTestBase::RunAndCompareNoHloPasses(
    std::unique_ptr<HloModule> module,
    const absl::Span<Literal* const> arguments,
    const optional<ErrorSpec>& error,
    const std::function<void(HloModule*)>& reference_preprocessor) {
  auto result =
      RunAndCompareInternal(std::move(module), arguments, error,
                            /*run_hlo_passes=*/false, reference_preprocessor);
  if (!result.ok()) {
    return ::testing::AssertionFailure() << result.status();
  }
  return result.ValueOrDie();
}

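// A sketch of how a test body typically drives the fake-argument overload
// below (HLO text and tolerances are illustrative):
//
//   TF_ASSERT_OK_AND_ASSIGN(auto module,
//                           ParseAndReturnVerifiedModule(kHloText));
//   EXPECT_TRUE(RunAndCompare(std::move(module), ErrorSpec(1e-5, 1e-5)));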
::testing::AssertionResult HloTestBase::RunAndCompare(
    std::unique_ptr<HloModule> module, const optional<ErrorSpec>& error,
    const std::function<void(HloModule*)>& reference_preprocessor) {
  auto fake_arguments = MakeFakeArguments(module.get()).value();

  std::vector<Literal*> fake_argument_ptrs;
  absl::c_transform(
      fake_arguments, std::back_inserter(fake_argument_ptrs),
      [](const Literal& literal) { return const_cast<Literal*>(&literal); });

  return RunAndCompare(std::move(module), fake_argument_ptrs, error,
                       reference_preprocessor);
}

::testing::AssertionResult HloTestBase::RunAndCompareNoHloPasses(
    std::unique_ptr<HloModule> module, const optional<ErrorSpec>& error,
    const std::function<void(HloModule*)>& reference_preprocessor) {
  const auto fake_arguments = MakeFakeArguments(module.get()).value();
  std::vector<Literal*> fake_argument_ptrs;
  absl::c_transform(
      fake_arguments, std::back_inserter(fake_argument_ptrs),
      [](const Literal& literal) { return const_cast<Literal*>(&literal); });

  return RunAndCompareNoHloPasses(std::move(module), fake_argument_ptrs, error,
                                  reference_preprocessor);
}

::testing::AssertionResult HloTestBase::Run(std::unique_ptr<HloModule> module,
                                            bool run_hlo_passes) {
  const auto fake_arguments = MakeFakeArguments(module.get()).value();
  const auto change = hlo_verifier_->Run(module.get());
  if (!change.ok()) {
    return ::testing::AssertionFailure() << change.status();
  }

  const auto output =
      test_runner_.Execute(std::move(module), fake_arguments, run_hlo_passes);
  return output.ok()
             ? ::testing::AssertionSuccess()
             : ::testing::AssertionFailure() << output.status().error_message();
}

::testing::AssertionResult HloTestBase::RunAndCompare(
    string_view hlo_string, const std::optional<ErrorSpec>& error,
    const std::function<void(HloModule*)>& reference_preprocessor) {
  auto module_or_status = ParseAndReturnVerifiedModule(hlo_string);
  if (!module_or_status.ok()) {
    return ::testing::AssertionFailure()
           << "Error while parsing HLO text format: "
           << module_or_status.status().ToString();
  }
  return RunAndCompare(std::move(module_or_status).value(), error,
                       reference_preprocessor);
}

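// Like RunAndCompareInternal, but runs both modules on the test backend and
// compares them against each other instead of against the interpreter.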
StatusOr<::testing::AssertionResult>
HloTestBase::RunAndCompareTwoModulesInternal(
    std::unique_ptr<HloModule> module_0, std::unique_ptr<HloModule> module_1,
    const absl::Span<Literal* const> arguments,
    const std::optional<ErrorSpec>& error, bool run_hlo_passes) {
  TF_RETURN_IF_ERROR(hlo_verifier_->Run(module_0.get()).status());
  TF_RETURN_IF_ERROR(hlo_verifier_->Run(module_1.get()).status());

  // Execute the two modules.
  TF_ASSIGN_OR_RETURN(
      auto test_0,
      test_runner_.Execute(std::move(module_0), arguments, run_hlo_passes));
  TF_ASSIGN_OR_RETURN(
      auto test_1,
      test_runner_.Execute(std::move(module_1), arguments, run_hlo_passes));

  return LiteralTestUtil::NearOrEqual(/*expected=*/test_0, /*actual=*/test_1,
                                      error);
}

::testing::AssertionResult HloTestBase::RunAndCompareTwoModules(
    std::unique_ptr<HloModule> module_0, std::unique_ptr<HloModule> module_1,
    const absl::Span<Literal* const> arguments,
    const optional<ErrorSpec>& error) {
  auto result = RunAndCompareTwoModulesInternal(
      std::move(module_0), std::move(module_1), arguments, error,
      /*run_hlo_passes=*/true);
  if (!result.ok()) {
    return ::testing::AssertionFailure() << result.status();
  }
  return result.ValueOrDie();
}

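// The two modules must take the same arguments: each pair of parameter shapes
// is compared (preferring the static entry-computation-layout shape when one
// is available) before fake arguments generated from module_0 are fed to
// both modules.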
::testing::AssertionResult HloTestBase::RunAndCompareTwoModules(
    std::unique_ptr<HloModule> module_0, std::unique_ptr<HloModule> module_1,
    const optional<ErrorSpec>& error) {
  const auto params_0 = module_0->entry_computation()->parameter_instructions();
  const auto params_1 = module_1->entry_computation()->parameter_instructions();
  for (int i = 0; i < params_0.size(); ++i) {
    const HloModuleConfig& module_config_0 = module_0->config();
    const Shape& param_shape_0 =
        (module_config_0.has_entry_computation_layout() &&
         module_config_0.entry_computation_layout()
             .parameter_layout(i)
             .shape()
             .is_static())
            ? module_config_0.entry_computation_layout()
                  .parameter_layout(i)
                  .shape()
            : params_0[i]->shape();

    const HloModuleConfig& module_config_1 = module_1->config();
    const Shape& param_shape_1 =
        (module_config_1.has_entry_computation_layout() &&
         module_config_1.entry_computation_layout()
             .parameter_layout(i)
             .shape()
             .is_static())
            ? module_config_1.entry_computation_layout()
                  .parameter_layout(i)
                  .shape()
            : params_1[i]->shape();

    if (!ShapeUtil::Equal(param_shape_0, param_shape_1)) {
      return ::testing::AssertionFailure()
             << "Error: mismatching parameter shapes: "
             << param_shape_0.ToString() << " vs. "
             << param_shape_1.ToString();
    }
  }

  auto fake_arguments = MakeFakeArguments(module_0.get()).value();

  std::vector<Literal*> fake_argument_ptrs;
  absl::c_transform(
      fake_arguments, std::back_inserter(fake_argument_ptrs),
      [](const Literal& literal) { return const_cast<Literal*>(&literal); });

  return RunAndCompareTwoModules(std::move(module_0), std::move(module_1),
                                 fake_argument_ptrs, error);
}

::testing::AssertionResult HloTestBase::RunAndCompareTwoModules(
    string_view hlo_string_module_0, string_view hlo_string_module_1,
    const std::optional<ErrorSpec>& error) {
  auto module_0_or_status = ParseAndReturnVerifiedModule(hlo_string_module_0);
  if (!module_0_or_status.ok()) {
    return ::testing::AssertionFailure()
           << "Error while parsing HLO text format: "
           << module_0_or_status.status().ToString();
  }

  auto module_1_or_status = ParseAndReturnVerifiedModule(hlo_string_module_1);
  if (!module_1_or_status.ok()) {
    return ::testing::AssertionFailure()
           << "Error while parsing HLO text format: "
           << module_1_or_status.status().ToString();
  }
  return RunAndCompareTwoModules(std::move(module_0_or_status).value(),
                                 std::move(module_1_or_status).value(), error);
}

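// Runs the module on freshly generated fake arguments. A non-null `profile`
// forces HLO profiling on so the ExecutionProfile is populated. Note that a
// non-null `backend_config` is only applied to the root instruction, and the
// function then returns without executing the module.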
::testing::AssertionResult HloTestBase::Run(
    string_view hlo_string, bool run_hlo_passes, ExecutionProfile* profile,
    const tensorflow::protobuf::Message* backend_config) {
  auto module_or_status = ParseAndReturnVerifiedModule(hlo_string);
  if (!module_or_status.ok()) {
    return ::testing::AssertionFailure()
           << "Error while parsing HLO text format: "
           << module_or_status.status().ToString();
  }

  std::unique_ptr<HloModule> module = std::move(module_or_status.ValueOrDie());
  const auto fake_arguments = MakeFakeArguments(module.get()).value();
  std::vector<Literal*> fake_argument_ptrs;
  absl::c_transform(
      fake_arguments, std::back_inserter(fake_argument_ptrs),
      [](const Literal& literal) { return const_cast<Literal*>(&literal); });

  if (profile != nullptr) {
    // We have to enable HLO profiling since otherwise the ExecutionProfile is
    // currently not correct.
    //
    // TODO(b/119432044): Fix collection of the ExecutionProfile
    // so that this is not necessary.
    HloModuleConfig config = module->config();
    DebugOptions debug_options = config.debug_options();
    debug_options.set_xla_hlo_profile(true);
    config.set_debug_options(debug_options);
    module->set_config(config);
  }

  if (backend_config) {
    // Set backend configuration if it is given.
    HloInstruction* instruction =
        module->entry_computation()->root_instruction();
    Status s = instruction->set_backend_config(*backend_config);
    return s.ok() ? ::testing::AssertionSuccess()
                  : ::testing::AssertionFailure() << s.error_message();
  }

  auto output = test_runner_.Execute(std::move(module), fake_argument_ptrs,
                                     /*run_hlo_passes=*/run_hlo_passes,
                                     /*profile=*/profile);

  return output.ok()
             ? ::testing::AssertionSuccess()
             : ::testing::AssertionFailure() << output.status().error_message();
}

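// Replicated variant of Run above: parses `hlo_string` with `num_replicas`,
// feeds the same fake arguments to every replica, and executes the replicas
// on separate threads. The `backend_config` early return mirrors Run.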
::testing::AssertionResult HloTestBase::RunReplicated(
    string_view hlo_string, bool run_hlo_passes, int64_t num_replicas,
    const tensorflow::protobuf::Message* backend_config) {
  auto module_or_status =
      ParseAndReturnVerifiedModule(hlo_string, num_replicas);
  if (!module_or_status.ok()) {
    return ::testing::AssertionFailure()
           << "Error while parsing HLO text format: "
           << module_or_status.status().ToString();
  }

  std::unique_ptr<HloModule> module = std::move(module_or_status.ValueOrDie());
  const auto fake_arguments = MakeFakeArguments(module.get()).value();
  std::vector<Literal*> fake_argument_ptrs;
  absl::c_transform(
      fake_arguments, std::back_inserter(fake_argument_ptrs),
      [](const Literal& literal) { return const_cast<Literal*>(&literal); });

  if (backend_config) {
    // Set backend configuration if it is given.
    HloInstruction* instruction =
        module->entry_computation()->root_instruction();
    Status s = instruction->set_backend_config(*backend_config);
    return s.ok() ? ::testing::AssertionSuccess()
                  : ::testing::AssertionFailure() << s.error_message();
  }

  HloRunner::ReplicatedExecuteOptions options;
  options.num_replicas = num_replicas;
  options.run_hlo_passes = run_hlo_passes;
  options.use_threads = true;
  for (auto argument : fake_argument_ptrs) {
    options.arguments.push_back(argument);
  }
  auto output = test_runner_.ExecuteReplicated(std::move(module), options);

  return output.ok()
             ? ::testing::AssertionSuccess()
             : ::testing::AssertionFailure() << output.status().error_message();
}

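// Compiles one executable per requested run and executes each with its own
// fake arguments. `profiles` must be non-null: its size determines the number
// of runs, and HLO profiling is enabled so each entry is populated. With
// `assert_determinism`, every run must produce the same literal.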
::testing::AssertionResult HloTestBase::RunMultipleTimes(
    string_view hlo_string, bool run_hlo_passes,
    std::vector<ExecutionProfile>* profiles,
    const tensorflow::protobuf::Message* backend_config,
    bool assert_determinism) {
  int n = profiles->size();
  std::vector<std::vector<Literal>> fake_arguments(n);
  std::vector<std::unique_ptr<Executable>> executables(n);

  for (int i = 0; i < n; ++i) {
    auto module_or_status = ParseAndReturnVerifiedModule(hlo_string);
    if (!module_or_status.ok()) {
      return ::testing::AssertionFailure()
             << "Error while parsing HLO text format: "
             << module_or_status.status().ToString();
    }
    std::unique_ptr<HloModule> module =
        std::move(module_or_status.ValueOrDie());

    fake_arguments[i] = MakeFakeArguments(module.get()).value();

    if (profiles != nullptr) {
      // We have to enable HLO profiling since otherwise the ExecutionProfile
      // is currently not correct.
      //
      // TODO(b/119432044): Fix collection of the ExecutionProfile
      // so that this is not necessary.
      HloModuleConfig config = module->config();
      DebugOptions debug_options = config.debug_options();
      debug_options.set_xla_hlo_profile(true);
      config.set_debug_options(debug_options);
      module->set_config(config);
    }

    if (backend_config) {
      // Set backend configuration if it is given.
      HloInstruction* instruction =
          module->entry_computation()->root_instruction();
      Status s = instruction->set_backend_config(*backend_config);
      return s.ok() ? ::testing::AssertionSuccess()
                    : ::testing::AssertionFailure() << s.error_message();
    }

    auto executable =
        test_runner_.CreateExecutable(std::move(module), run_hlo_passes);
    if (!executable.ok()) {
      return ::testing::AssertionFailure()
             << executable.status().error_message();
    }
    executables[i] = std::move(executable.ValueOrDie());
  }

  std::optional<Literal> canonical_output;
  for (int i = 0; i < n; ++i) {
    StatusOr<Literal> output = test_runner_.ExecuteWithExecutable(
        executables[i].get(), fake_arguments[i],
        /*profile=*/&((*profiles)[i]));
    if (!output.ok()) {
      return ::testing::AssertionFailure() << output.status().error_message();
    }

    if (assert_determinism) {
      if (!canonical_output.has_value()) {
        canonical_output = std::move(output).value();
      } else {
        if (*canonical_output != output.ValueOrDie()) {
          return ::testing::AssertionFailure()
                 << "Successive runs have returned different results: "
                 << *canonical_output << " vs. " << output.ValueOrDie();
        }
      }
    }
  }

  return ::testing::AssertionSuccess();
}

::testing::AssertionResult HloTestBase::RunAndCompareFromFile(
    const std::string& filename, const std::optional<ErrorSpec>& error,
    const std::function<void(HloModule*)>& reference_preprocessor) {
  auto module_or_status =
      HloRunner::ReadModuleFromHloTextFile(filename, GetDebugOptionsForTest());
  if (!module_or_status.ok()) {
    return ::testing::AssertionFailure()
           << "failed reading HLO module from file: "
           << module_or_status.status().ToString();
  }
  return RunAndCompare(std::move(module_or_status).value(), error,
                       reference_preprocessor);
}

::testing::AssertionResult HloTestBase::RunAndCompareNoHloPasses(
    string_view hlo_string, const std::optional<ErrorSpec>& error,
    const std::function<void(HloModule*)>& reference_preprocessor) {
  auto module_or_status = ParseAndReturnVerifiedModule(hlo_string);
  if (!module_or_status.ok()) {
    return ::testing::AssertionFailure()
           << "Error while parsing HLO text format: "
           << module_or_status.status().ToString();
  }
  return RunAndCompareNoHloPasses(std::move(module_or_status).value(), error,
                                  reference_preprocessor);
}

::testing::AssertionResult HloTestBase::RunAndCompareNoHloPassesFromFile(
    const std::string& filename, const std::optional<ErrorSpec>& error,
    const std::function<void(HloModule*)>& reference_preprocessor) {
  auto module_or_status =
      HloRunner::ReadModuleFromHloTextFile(filename, GetDebugOptionsForTest());
  if (!module_or_status.ok()) {
    return ::testing::AssertionFailure()
           << "failed reading HLO module from file: "
           << module_or_status.status().ToString();
  }
  return RunAndCompareNoHloPasses(std::move(module_or_status).value(), error,
                                  reference_preprocessor);
}

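// Lookup helpers: each returns the first match by name (or opcode), or
// nullptr if the module contains none. The FindInstruction overloads search
// every computation in the module.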
HloComputation* HloTestBase::FindComputation(HloModule* module,
                                             absl::string_view name) {
  auto computations = module->computations();
  auto it = absl::c_find_if(
      computations, [&](HloComputation* c) { return c->name() == name; });
  if (it == computations.end()) {
    return nullptr;
  }
  return *it;
}

HloInstruction* HloTestBase::FindInstruction(HloModule* module,
                                             absl::string_view name) {
  for (const HloComputation* c : module->computations()) {
    auto instructions = c->instructions();
    auto it = absl::c_find_if(
        instructions, [&](HloInstruction* i) { return i->name() == name; });
    if (it != instructions.end()) {
      return *it;
    }
  }
  return nullptr;
}

HloInstruction* HloTestBase::FindInstruction(HloModule* module,
                                             HloOpcode opcode) {
  for (const HloComputation* c : module->computations()) {
    auto instructions = c->instructions();
    auto it = absl::c_find_if(
        instructions, [&](HloInstruction* i) { return i->opcode() == opcode; });
    if (it != instructions.end()) {
      return *it;
    }
  }
  return nullptr;
}

Backend& HloTestBase::backend() { return test_runner_.backend(); }

/* static */
std::string HloTestBase::TestName() {
  return ::testing::UnitTest::GetInstance()->current_test_info()->name();
}

}  // namespace xla