xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/tests/hlo_test_base.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
17 
18 #include <functional>
19 #include <memory>
20 #include <set>
21 #include <string>
22 #include <utility>
23 
24 #include "absl/algorithm/container.h"
25 #include "absl/types/span.h"
26 #include "tensorflow/compiler/xla/debug_options_flags.h"
27 #include "tensorflow/compiler/xla/layout_util.h"
28 #include "tensorflow/compiler/xla/service/hlo_module.h"
29 #include "tensorflow/compiler/xla/service/hlo_module_util.h"
30 #include "tensorflow/compiler/xla/service/hlo_parser.h"
31 #include "tensorflow/compiler/xla/service/platform_util.h"
32 #include "tensorflow/compiler/xla/shape_util.h"
33 #include "tensorflow/compiler/xla/statusor.h"
34 #include "tensorflow/compiler/xla/tests/filecheck.h"
35 #include "tensorflow/compiler/xla/tests/literal_test_util.h"
36 #include "tensorflow/compiler/xla/tests/test_utils.h"
37 #include "tensorflow/compiler/xla/types.h"
38 #include "tensorflow/core/lib/core/status_test_util.h"
39 #include "tensorflow/core/platform/logging.h"
40 #include "tensorflow/core/platform/test.h"
41 
42 namespace xla {
43 
44 namespace {
45 
46 using absl::string_view;
47 using std::optional;
48 
49 constexpr char kInterpreter[] = "interpreter";
50 
ProgramShapesEqual(const ProgramShape & lhs,const ProgramShape & rhs)51 bool ProgramShapesEqual(const ProgramShape& lhs, const ProgramShape& rhs) {
52   if (lhs.parameters_size() != rhs.parameters_size()) {
53     return false;
54   }
55   for (int i = 0; i < lhs.parameters_size(); i++) {
56     if (!ShapeUtil::Equal(lhs.parameters(i), rhs.parameters(i))) {
57       return false;
58     }
59   }
60   return ShapeUtil::Equal(lhs.result(), rhs.result());
61 }
62 
GetProgramShapeWithLayout(const HloModule & module)63 ProgramShape GetProgramShapeWithLayout(const HloModule& module) {
64   ProgramShape program_shape;
65   const auto* entry = module.entry_computation();
66   for (const auto* param : entry->parameter_instructions()) {
67     *program_shape.add_parameters() = param->shape();
68     *program_shape.add_parameter_names() = param->name();
69   }
70   *program_shape.mutable_result() = entry->root_instruction()->shape();
71   return program_shape;
72 }
73 
74 }  // namespace
75 
HloTestBase(bool verifier_layout_sensitive,bool allow_mixed_precision_in_hlo_verifier,HloPredicate instruction_can_change_layout_func)76 HloTestBase::HloTestBase(bool verifier_layout_sensitive,
77                          bool allow_mixed_precision_in_hlo_verifier,
78                          HloPredicate instruction_can_change_layout_func)
79     : HloTestBase(GetTestPlatform(), GetReferencePlatform(),
80                   verifier_layout_sensitive,
81                   allow_mixed_precision_in_hlo_verifier,
82                   instruction_can_change_layout_func) {}
83 
HloTestBase(se::Platform * test_platform,se::Platform * reference_platform,bool verifier_layout_sensitive,bool allow_mixed_precision_in_hlo_verifier,HloPredicate instruction_can_change_layout_func)84 HloTestBase::HloTestBase(se::Platform* test_platform,
85                          se::Platform* reference_platform,
86                          bool verifier_layout_sensitive,
87                          bool allow_mixed_precision_in_hlo_verifier,
88                          HloPredicate instruction_can_change_layout_func)
89     : test_runner_(test_platform),
90       reference_runner_(reference_platform),
91       verifier_layout_sensitive_(verifier_layout_sensitive),
92       allow_mixed_precision_in_hlo_verifier_(
93           allow_mixed_precision_in_hlo_verifier) {
94   hlo_verifier_ = std::make_unique<HloVerifier>(
95       /*layout_sensitive=*/verifier_layout_sensitive,
96       /*allow_mixed_precision=*/allow_mixed_precision_in_hlo_verifier,
97       instruction_can_change_layout_func);
98 }
99 
GetReferencePlatform()100 /*static*/ se::Platform* HloTestBase::GetReferencePlatform() {
101   auto result = PlatformUtil::GetPlatform(kInterpreter);
102   TF_CHECK_OK(result.status()) << "could not get interpreter platform";
103   return result.ValueOrDie();
104 }
105 
GetTestPlatform()106 /*static*/ se::Platform* HloTestBase::GetTestPlatform() {
107   auto result = PlatformUtil::GetDefaultPlatform();
108   TF_CHECK_OK(result.status()) << "could not get test platform";
109   return result.ValueOrDie();
110 }
111 
CreateNewUnverifiedModule(const std::string & name)112 std::unique_ptr<HloModule> HloTestBase::CreateNewUnverifiedModule(
113     const std::string& name) {
114   return std::make_unique<HloModule>(name, GetModuleConfigForTest());
115 }
116 
CreateNewVerifiedModule(const std::string & name,int64_t replica_count)117 std::unique_ptr<VerifiedHloModule> HloTestBase::CreateNewVerifiedModule(
118     const std::string& name, int64_t replica_count) {
119   return std::make_unique<VerifiedHloModule>(
120       name, GetModuleConfigForTest(replica_count), verifier_layout_sensitive_,
121       allow_mixed_precision_in_hlo_verifier_,
122       backend().compiler()->ShapeSizeBytesFunction());
123 }
124 
125 StatusOr<std::unique_ptr<VerifiedHloModule>>
ParseAndReturnVerifiedModule(absl::string_view hlo_text,int64_t replica_count,int64_t num_partitions)126 HloTestBase::ParseAndReturnVerifiedModule(absl::string_view hlo_text,
127                                           int64_t replica_count,
128                                           int64_t num_partitions) {
129   TF_ASSIGN_OR_RETURN(
130       auto module,
131       ParseAndReturnVerifiedModule(
132           hlo_text, GetModuleConfigForTest(replica_count, num_partitions)));
133   UpdateEntryComputationLayout(module.get());
134   return module;
135 }
136 
137 StatusOr<std::unique_ptr<VerifiedHloModule>>
ParseAndReturnVerifiedModule(absl::string_view hlo_text,const HloModuleConfig & config)138 HloTestBase::ParseAndReturnVerifiedModule(absl::string_view hlo_text,
139                                           const HloModuleConfig& config) {
140   auto module = std::make_unique<VerifiedHloModule>(
141       TestName(), config, verifier_layout_sensitive_,
142       allow_mixed_precision_in_hlo_verifier_,
143       backend().compiler()->ShapeSizeBytesFunction());
144   TF_RETURN_IF_ERROR(module->ParseHloStringAndVerifyModule(hlo_text));
145   UpdateEntryComputationLayout(module.get());
146   return std::move(module);
147 }
148 
AddEntryComputationAndUpdateEntryComputationLayout(HloModule * module,std::unique_ptr<HloComputation> computation)149 HloComputation* HloTestBase::AddEntryComputationAndUpdateEntryComputationLayout(
150     HloModule* module, std::unique_ptr<HloComputation> computation) {
151   auto comp = module->AddEntryComputation(std::move(computation));
152   UpdateEntryComputationLayout(module);
153   return comp;
154 }
155 
UpdateEntryComputationLayout(HloModule * module)156 void HloTestBase::UpdateEntryComputationLayout(HloModule* module) {
157   xla::UpdateEntryComputationLayout(
158       module, test_runner_.device_shape_representation_fn());
159 }
160 
161 /* static */
RunHloPass(HloPassInterface * hlo_pass,HloModule * module)162 StatusOr<bool> HloTestBase::RunHloPass(HloPassInterface* hlo_pass,
163                                        HloModule* module) {
164   const std::string module_str_before_run =
165       module->ToProto().ShortDebugString();
166   const auto status_or = hlo_pass->Run(module);
167   if (status_or.status().ok()) {
168     const std::string module_str_after_run =
169         module->ToProto().ShortDebugString();
170     const bool passChangedHlo = status_or.ValueOrDie();
171     if (passChangedHlo) {
172       // Check that the proto actually changed.
173       EXPECT_NE(module_str_after_run, module_str_before_run);
174     } else {
175       // Check that the proto remains same.
176       EXPECT_EQ(module_str_after_run, module_str_before_run);
177     }
178   }
179   return status_or;
180 }
181 
182 /* static */
DefaultPrecisionConfig(int operands)183 PrecisionConfig HloTestBase::DefaultPrecisionConfig(int operands) {
184   PrecisionConfig precision_config;
185   precision_config.mutable_operand_precision()->Resize(
186       operands, PrecisionConfig::DEFAULT);
187   return precision_config;
188 }
189 
SetAotFastMathDebugOptions(DebugOptions * options)190 void HloTestBase::SetAotFastMathDebugOptions(DebugOptions* options) {
191   options->set_xla_cpu_enable_fast_math(true);
192   options->set_xla_gpu_enable_fast_min_max(true);
193   options->set_xla_cpu_enable_fast_min_max(true);
194   options->set_xla_cpu_fast_math_honor_nans(false);
195   options->set_xla_cpu_fast_math_honor_infs(false);
196   options->set_xla_cpu_fast_math_honor_functions(false);
197   options->set_xla_cpu_fast_math_honor_division(false);
198 }
199 
GetDebugOptionsForTest()200 DebugOptions HloTestBase::GetDebugOptionsForTest() {
201   auto debug_options = GetDebugOptionsFromFlags();
202   // TODO(b/38354253): Change tests to use Parameters instead of Constants.
203   debug_options.add_xla_disable_hlo_passes("constant_folding");
204   debug_options.set_xla_gpu_max_kernel_unroll_factor(1);
205   debug_options.set_xla_hlo_evaluator_use_fast_path(true);
206   return debug_options;
207 }
208 
RunAndFilecheckHloRewrite(absl::string_view hlo,HloPassInterface && hlo_pass,std::optional<absl::string_view> expected,std::function<void (HloModule *)> after_pass_checks)209 void HloTestBase::RunAndFilecheckHloRewrite(
210     absl::string_view hlo, HloPassInterface&& hlo_pass,
211     std::optional<absl::string_view> expected,
212     std::function<void(HloModule*)> after_pass_checks) {
213   TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
214                           ParseAndReturnVerifiedModule(hlo));
215   TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloPass(&hlo_pass, module.get()));
216   EXPECT_EQ(changed, expected.has_value());
217   if (changed) {
218     TF_ASSERT_OK_AND_ASSIGN(
219         bool filecheck_matches,
220         RunFileCheck(
221             module->ToString(HloPrintOptions{}.set_print_operand_shape(false)),
222             *expected));
223     EXPECT_TRUE(filecheck_matches);
224     if (after_pass_checks) {
225       after_pass_checks(module.get());
226     }
227   }
228 }
229 
Execute(std::unique_ptr<HloModule> module,absl::Span<Literal * const> arguments)230 StatusOr<Literal> HloTestBase::Execute(std::unique_ptr<HloModule> module,
231                                        absl::Span<Literal* const> arguments) {
232   return test_runner_.Execute(std::move(module), arguments);
233 }
234 
ExecuteNoHloPasses(std::unique_ptr<HloModule> module,absl::Span<Literal * const> arguments)235 Literal HloTestBase::ExecuteNoHloPasses(std::unique_ptr<HloModule> module,
236                                         absl::Span<Literal* const> arguments) {
237   return test_runner_
238       .Execute(std::move(module), arguments,
239                /*run_hlo_passes=*/false)
240       .ValueOrDie();
241 }
242 
ExecuteAndTransfer(std::unique_ptr<HloModule> module,absl::Span<Literal * const> arguments)243 Literal HloTestBase::ExecuteAndTransfer(std::unique_ptr<HloModule> module,
244                                         absl::Span<Literal* const> arguments) {
245   return test_runner_.Execute(std::move(module), arguments).ValueOrDie();
246 }
247 
ExecuteReplicated(std::unique_ptr<HloModule> module,absl::Span<Literal * const> arguments,int64_t num_replicas,bool use_threads,bool run_hlo_passes)248 StatusOr<std::vector<Literal>> HloTestBase::ExecuteReplicated(
249     std::unique_ptr<HloModule> module, absl::Span<Literal* const> arguments,
250     int64_t num_replicas, bool use_threads, bool run_hlo_passes) {
251   HloRunner::ReplicatedExecuteOptions options;
252   options.num_replicas = num_replicas;
253   options.run_hlo_passes = run_hlo_passes;
254   options.use_threads = use_threads;
255   for (auto argument : arguments) {
256     options.arguments.push_back(argument);
257   }
258   return test_runner_.ExecuteReplicated(std::move(module), options);
259 }
260 
ExecuteReplicated(std::unique_ptr<HloModule> module,absl::Span<Literal * const> arguments,int64_t num_replicas,DeviceAssignment * device_assignment,bool run_hlo_passes,bool use_threads)261 StatusOr<std::vector<Literal>> HloTestBase::ExecuteReplicated(
262     std::unique_ptr<HloModule> module, absl::Span<Literal* const> arguments,
263     int64_t num_replicas, DeviceAssignment* device_assignment,
264     bool run_hlo_passes, bool use_threads) {
265   HloRunner::ReplicatedExecuteOptions options;
266   options.num_replicas = num_replicas;
267   options.run_hlo_passes = run_hlo_passes;
268   options.use_threads = use_threads;
269   for (auto argument : arguments) {
270     options.arguments.push_back(argument);
271   }
272   return test_runner_.ExecuteReplicated(std::move(module), options,
273                                         device_assignment);
274 }
275 
ExecuteReplicated(std::function<Executable * (int64_t)> executable_provider,std::function<int64_t (int64_t)> argument_count_provider,std::function<const Literal * (int64_t,int64_t)> argument_provider,int64_t num_replicas,bool run_hlo_passes,DeviceAssignment * device_assignment)276 StatusOr<std::vector<Literal>> HloTestBase::ExecuteReplicated(
277     std::function<Executable*(int64_t)> executable_provider,
278     std::function<int64_t(int64_t)> argument_count_provider,
279     std::function<const Literal*(int64_t, int64_t)> argument_provider,
280     int64_t num_replicas, bool run_hlo_passes,
281     DeviceAssignment* device_assignment) {
282   HloRunner::ReplicatedExecuteOptions options;
283   options.num_replicas = num_replicas;
284   options.run_hlo_passes = run_hlo_passes;
285   options.use_threads = true;
286   return test_runner_.ExecuteReplicated(
287       executable_provider, argument_count_provider, argument_provider, options,
288       device_assignment);
289 }
290 
MakeReferenceModule(const HloModule & test_module,const std::function<void (HloModule *)> & reference_preprocessor)291 StatusOr<std::unique_ptr<HloModule>> HloTestBase::MakeReferenceModule(
292     const HloModule& test_module,
293     const std::function<void(HloModule*)>& reference_preprocessor) {
294   std::unique_ptr<HloModule> reference_module = test_module.Clone();
295   const auto& program_shape = GetProgramShapeWithLayout(test_module);
296 
297   if (reference_preprocessor != nullptr) {
298     reference_preprocessor(reference_module.get());
299     if (!ProgramShapesEqual(program_shape,
300                             GetProgramShapeWithLayout(*reference_module))) {
301       return InvalidArgument(
302           "reference preprocessor must not modify the program shape");
303     }
304   }
305   TF_RETURN_IF_ERROR(hlo_verifier_->Run(reference_module.get()).status());
306   return std::move(reference_module);
307 }
308 
RunAndCompareInternal(std::unique_ptr<HloModule> module,const absl::Span<Literal * const> arguments,const optional<ErrorSpec> & error,bool run_hlo_passes,const std::function<void (HloModule *)> & reference_preprocessor)309 StatusOr<::testing::AssertionResult> HloTestBase::RunAndCompareInternal(
310     std::unique_ptr<HloModule> module,
311     const absl::Span<Literal* const> arguments,
312     const optional<ErrorSpec>& error, bool run_hlo_passes,
313     const std::function<void(HloModule*)>& reference_preprocessor) {
314   TF_RETURN_IF_ERROR(hlo_verifier_->Run(module.get()).status());
315   TF_ASSIGN_OR_RETURN(auto reference_module,
316                       MakeReferenceModule(*module, reference_preprocessor));
317 
318   // Execute on two backends.
319   TF_ASSIGN_OR_RETURN(
320       auto test,
321       test_runner_.Execute(std::move(module), arguments, run_hlo_passes));
322   TF_ASSIGN_OR_RETURN(auto reference,
323                       reference_runner_.Execute(std::move(reference_module),
324                                                 arguments, run_hlo_passes));
325   if (reference.IsAll(0)) {
326     LOG(WARNING) << "Reference value is only zeros.";
327   }
328 
329   return LiteralTestUtil::NearOrEqual(/*expected=*/reference, /*actual=*/test,
330                                       error);
331 }
332 
RunAndCompare(std::unique_ptr<HloModule> module,const absl::Span<Literal * const> arguments,const optional<ErrorSpec> & error,const std::function<void (HloModule *)> & reference_preprocessor)333 ::testing::AssertionResult HloTestBase::RunAndCompare(
334     std::unique_ptr<HloModule> module,
335     const absl::Span<Literal* const> arguments,
336     const optional<ErrorSpec>& error,
337     const std::function<void(HloModule*)>& reference_preprocessor) {
338   auto result =
339       RunAndCompareInternal(std::move(module), arguments, error,
340                             /*run_hlo_passes=*/true, reference_preprocessor);
341   if (!result.ok()) {
342     return ::testing::AssertionFailure() << result.status();
343   }
344   return result.ValueOrDie();
345 }
346 
RunAndCompareNoHloPasses(std::unique_ptr<HloModule> module,const absl::Span<Literal * const> arguments,const optional<ErrorSpec> & error,const std::function<void (HloModule *)> & reference_preprocessor)347 ::testing::AssertionResult HloTestBase::RunAndCompareNoHloPasses(
348     std::unique_ptr<HloModule> module,
349     const absl::Span<Literal* const> arguments,
350     const optional<ErrorSpec>& error,
351     const std::function<void(HloModule*)>& reference_preprocessor) {
352   auto result =
353       RunAndCompareInternal(std::move(module), arguments, error,
354                             /*run_hlo_passes=*/false, reference_preprocessor);
355   if (!result.ok()) {
356     return ::testing::AssertionFailure() << result.status();
357   }
358   return result.ValueOrDie();
359 }
360 
RunAndCompare(std::unique_ptr<HloModule> module,const optional<ErrorSpec> & error,const std::function<void (HloModule *)> & reference_preprocessor)361 ::testing::AssertionResult HloTestBase::RunAndCompare(
362     std::unique_ptr<HloModule> module, const optional<ErrorSpec>& error,
363     const std::function<void(HloModule*)>& reference_preprocessor) {
364   auto fake_arguments = MakeFakeArguments(module.get()).value();
365 
366   std::vector<Literal*> fake_argument_ptrs;
367   absl::c_transform(
368       fake_arguments, std::back_inserter(fake_argument_ptrs),
369       [](const Literal& literal) { return const_cast<Literal*>(&literal); });
370 
371   return RunAndCompare(std::move(module), fake_argument_ptrs, error,
372                        reference_preprocessor);
373 }
374 
RunAndCompareNoHloPasses(std::unique_ptr<HloModule> module,const optional<ErrorSpec> & error,const std::function<void (HloModule *)> & reference_preprocessor)375 ::testing::AssertionResult HloTestBase::RunAndCompareNoHloPasses(
376     std::unique_ptr<HloModule> module, const optional<ErrorSpec>& error,
377     const std::function<void(HloModule*)>& reference_preprocessor) {
378   const auto fake_arguments = MakeFakeArguments(module.get()).value();
379   std::vector<Literal*> fake_argument_ptrs;
380   absl::c_transform(
381       fake_arguments, std::back_inserter(fake_argument_ptrs),
382       [](const Literal& literal) { return const_cast<Literal*>(&literal); });
383 
384   return RunAndCompareNoHloPasses(std::move(module), fake_argument_ptrs, error,
385                                   reference_preprocessor);
386 }
387 
Run(std::unique_ptr<HloModule> module,bool run_hlo_passes)388 ::testing::AssertionResult HloTestBase::Run(std::unique_ptr<HloModule> module,
389                                             bool run_hlo_passes) {
390   const auto fake_arguments = MakeFakeArguments(module.get()).value();
391   const auto change = hlo_verifier_->Run(module.get());
392   if (!change.ok()) {
393     return ::testing::AssertionFailure() << change.status();
394   }
395 
396   const auto output =
397       test_runner_.Execute(std::move(module), fake_arguments, run_hlo_passes);
398   return output.ok()
399              ? ::testing::AssertionSuccess()
400              : ::testing::AssertionFailure() << output.status().error_message();
401 }
402 
RunAndCompare(string_view hlo_string,const std::optional<ErrorSpec> & error,const std::function<void (HloModule *)> & reference_preprocessor)403 ::testing::AssertionResult HloTestBase::RunAndCompare(
404     string_view hlo_string, const std::optional<ErrorSpec>& error,
405     const std::function<void(HloModule*)>& reference_preprocessor) {
406   auto module_or_status = ParseAndReturnVerifiedModule(hlo_string);
407   if (!module_or_status.ok()) {
408     return ::testing::AssertionFailure()
409            << "Error while parsing HLO text format: "
410            << module_or_status.status().ToString();
411   }
412   return RunAndCompare(std::move(module_or_status).value(), error,
413                        reference_preprocessor);
414 }
415 
416 StatusOr<::testing::AssertionResult>
RunAndCompareTwoModulesInternal(std::unique_ptr<HloModule> module_0,std::unique_ptr<HloModule> module_1,const absl::Span<Literal * const> arguments,const std::optional<ErrorSpec> & error,bool run_hlo_passes)417 HloTestBase::RunAndCompareTwoModulesInternal(
418     std::unique_ptr<HloModule> module_0, std::unique_ptr<HloModule> module_1,
419     const absl::Span<Literal* const> arguments,
420     const std::optional<ErrorSpec>& error, bool run_hlo_passes) {
421   TF_RETURN_IF_ERROR(hlo_verifier_->Run(module_0.get()).status());
422   TF_RETURN_IF_ERROR(hlo_verifier_->Run(module_1.get()).status());
423 
424   // Execute the two modules.
425   TF_ASSIGN_OR_RETURN(
426       auto test_0,
427       test_runner_.Execute(std::move(module_0), arguments, run_hlo_passes));
428   TF_ASSIGN_OR_RETURN(
429       auto test_1,
430       test_runner_.Execute(std::move(module_1), arguments, run_hlo_passes));
431 
432   return LiteralTestUtil::NearOrEqual(/*expected=*/test_0, /*actual=*/test_1,
433                                       error);
434 }
435 
RunAndCompareTwoModules(std::unique_ptr<HloModule> module_0,std::unique_ptr<HloModule> module_1,const absl::Span<Literal * const> arguments,const optional<ErrorSpec> & error)436 ::testing::AssertionResult HloTestBase::RunAndCompareTwoModules(
437     std::unique_ptr<HloModule> module_0, std::unique_ptr<HloModule> module_1,
438     const absl::Span<Literal* const> arguments,
439     const optional<ErrorSpec>& error) {
440   auto result = RunAndCompareTwoModulesInternal(
441       std::move(module_0), std::move(module_1), arguments, error,
442       /*run_hlo_passes=*/true);
443   if (!result.ok()) {
444     return ::testing::AssertionFailure() << result.status();
445   }
446   return result.ValueOrDie();
447 }
448 
RunAndCompareTwoModules(std::unique_ptr<HloModule> module_0,std::unique_ptr<HloModule> module_1,const optional<ErrorSpec> & error)449 ::testing::AssertionResult HloTestBase::RunAndCompareTwoModules(
450     std::unique_ptr<HloModule> module_0, std::unique_ptr<HloModule> module_1,
451     const optional<ErrorSpec>& error) {
452   const auto params_0 = module_0->entry_computation()->parameter_instructions();
453   const auto params_1 = module_1->entry_computation()->parameter_instructions();
454   for (int i = 0; i < params_0.size(); ++i) {
455     const HloModuleConfig& module_config_0 = module_0->config();
456     const Shape& param_shape_0 =
457         (module_config_0.has_entry_computation_layout() &&
458          module_config_0.entry_computation_layout()
459              .parameter_layout(i)
460              .shape()
461              .is_static())
462             ? module_config_0.entry_computation_layout()
463                   .parameter_layout(i)
464                   .shape()
465             : params_0[i]->shape();
466 
467     const HloModuleConfig& module_config_1 = module_1->config();
468     const Shape& param_shape_1 =
469         (module_config_1.has_entry_computation_layout() &&
470          module_config_1.entry_computation_layout()
471              .parameter_layout(i)
472              .shape()
473              .is_static())
474             ? module_config_1.entry_computation_layout()
475                   .parameter_layout(i)
476                   .shape()
477             : params_1[i]->shape();
478 
479     if (!ShapeUtil::Equal(param_shape_0, param_shape_1)) {
480       return ::testing::AssertionFailure()
481              << "Error : mismatching parameter shapes: "
482              << param_shape_0.ToString() << " Vs. " << param_shape_1.ToString();
483     }
484   }
485 
486   auto fake_arguments = MakeFakeArguments(module_0.get()).value();
487 
488   std::vector<Literal*> fake_argument_ptrs;
489   absl::c_transform(
490       fake_arguments, std::back_inserter(fake_argument_ptrs),
491       [](const Literal& literal) { return const_cast<Literal*>(&literal); });
492 
493   return RunAndCompareTwoModules(std::move(module_0), std::move(module_1),
494                                  fake_argument_ptrs, error);
495 }
496 
RunAndCompareTwoModules(string_view hlo_string_module_0,string_view hlo_string_module_1,const std::optional<ErrorSpec> & error)497 ::testing::AssertionResult HloTestBase::RunAndCompareTwoModules(
498     string_view hlo_string_module_0, string_view hlo_string_module_1,
499     const std::optional<ErrorSpec>& error) {
500   auto module_0_or_status = ParseAndReturnVerifiedModule(hlo_string_module_0);
501   if (!module_0_or_status.ok()) {
502     return ::testing::AssertionFailure()
503            << "Error while parsing HLO text format: "
504            << module_0_or_status.status().ToString();
505   }
506 
507   auto module_1_or_status = ParseAndReturnVerifiedModule(hlo_string_module_1);
508   if (!module_1_or_status.ok()) {
509     return ::testing::AssertionFailure()
510            << "Error while parsing HLO text format: "
511            << module_1_or_status.status().ToString();
512   }
513   return RunAndCompareTwoModules(std::move(module_0_or_status).value(),
514                                  std::move(module_1_or_status).value(), error);
515 }
516 
Run(string_view hlo_string,bool run_hlo_passes,ExecutionProfile * profile,const tensorflow::protobuf::Message * backend_config)517 ::testing::AssertionResult HloTestBase::Run(
518     string_view hlo_string, bool run_hlo_passes, ExecutionProfile* profile,
519     const tensorflow::protobuf::Message* backend_config) {
520   auto module_or_status = ParseAndReturnVerifiedModule(hlo_string);
521   if (!module_or_status.ok()) {
522     return ::testing::AssertionFailure()
523            << "Error while parsing HLO text format: "
524            << module_or_status.status().ToString();
525   }
526 
527   std::unique_ptr<HloModule> module = std::move(module_or_status.ValueOrDie());
528   const auto fake_arguments = MakeFakeArguments(module.get()).value();
529   std::vector<Literal*> fake_argument_ptrs;
530   absl::c_transform(
531       fake_arguments, std::back_inserter(fake_argument_ptrs),
532       [](const Literal& literal) { return const_cast<Literal*>(&literal); });
533 
534   if (profile != nullptr) {
535     // We have to enable HLO profiling since otherwise currently the
536     // ExecutionProfile is not correct.
537     //
538     // TODO(b/119432044): Fix collection of the ExecutionProfile
539     // so that this is not necessary.
540     HloModuleConfig config = module->config();
541     DebugOptions debug_options = config.debug_options();
542     debug_options.set_xla_hlo_profile(true);
543     config.set_debug_options(debug_options);
544     module->set_config(config);
545   }
546 
547   if (backend_config) {
548     // Set backend configuration if it is given.
549     HloInstruction* instruction =
550         module->entry_computation()->root_instruction();
551     Status s = instruction->set_backend_config(*backend_config);
552     return s.ok() ? ::testing::AssertionSuccess()
553                   : ::testing::AssertionFailure() << s.error_message();
554   }
555 
556   auto output = test_runner_.Execute(std::move(module), fake_argument_ptrs,
557                                      /*run_hlo_passes=*/run_hlo_passes,
558                                      /*profile=*/profile);
559 
560   return output.ok()
561              ? ::testing::AssertionSuccess()
562              : ::testing::AssertionFailure() << output.status().error_message();
563 }
564 
RunReplicated(string_view hlo_string,bool run_hlo_passes,int64_t num_replicas,const tensorflow::protobuf::Message * backend_config)565 ::testing::AssertionResult HloTestBase::RunReplicated(
566     string_view hlo_string, bool run_hlo_passes, int64_t num_replicas,
567     const tensorflow::protobuf::Message* backend_config) {
568   auto module_or_status =
569       ParseAndReturnVerifiedModule(hlo_string, num_replicas);
570   if (!module_or_status.ok()) {
571     return ::testing::AssertionFailure()
572            << "Error while parsing HLO text format: "
573            << module_or_status.status().ToString();
574   }
575 
576   std::unique_ptr<HloModule> module = std::move(module_or_status.ValueOrDie());
577   const auto fake_arguments = MakeFakeArguments(module.get()).value();
578   std::vector<Literal*> fake_argument_ptrs;
579   absl::c_transform(
580       fake_arguments, std::back_inserter(fake_argument_ptrs),
581       [](const Literal& literal) { return const_cast<Literal*>(&literal); });
582 
583   if (backend_config) {
584     // Set backend configuration if it is given.
585     HloInstruction* instruction =
586         module->entry_computation()->root_instruction();
587     Status s = instruction->set_backend_config(*backend_config);
588     return s.ok() ? ::testing::AssertionSuccess()
589                   : ::testing::AssertionFailure() << s.error_message();
590   }
591 
592   HloRunner::ReplicatedExecuteOptions options;
593   options.num_replicas = num_replicas;
594   options.run_hlo_passes = run_hlo_passes;
595   options.use_threads = true;
596   for (auto argument : fake_argument_ptrs) {
597     options.arguments.push_back(argument);
598   }
599   auto output = test_runner_.ExecuteReplicated(std::move(module), options);
600 
601   return output.ok()
602              ? ::testing::AssertionSuccess()
603              : ::testing::AssertionFailure() << output.status().error_message();
604 }
605 
RunMultipleTimes(string_view hlo_string,bool run_hlo_passes,std::vector<ExecutionProfile> * profiles,const tensorflow::protobuf::Message * backend_config,bool assert_determinism)606 ::testing::AssertionResult HloTestBase::RunMultipleTimes(
607     string_view hlo_string, bool run_hlo_passes,
608     std::vector<ExecutionProfile>* profiles,
609     const tensorflow::protobuf::Message* backend_config,
610     bool assert_determinism) {
611   int n = profiles->size();
612   std::vector<std::vector<Literal*>> fake_argument_ptrs(n);
613   std::vector<std::vector<Literal>> fake_arguments(n);
614   std::vector<std::unique_ptr<Executable>> executables(n);
615 
616   for (int i = 0; i < n; ++i) {
617     auto module_or_status = ParseAndReturnVerifiedModule(hlo_string);
618     if (!module_or_status.ok()) {
619       return ::testing::AssertionFailure()
620              << "Error while parsing HLO text format: "
621              << module_or_status.status().ToString();
622     }
623     std::unique_ptr<HloModule> module =
624         std::move(module_or_status.ValueOrDie());
625 
626     fake_arguments[i] = MakeFakeArguments(module.get()).value();
627 
628     if (profiles != nullptr) {
629       // We have to enable HLO profiling since otherwise currently the
630       // ExecutionProfile is not correct.
631       //
632       // TODO(b/119432044): Fix collection of the ExecutionProfile
633       // so that this is not necessary.
634       HloModuleConfig config = module->config();
635       DebugOptions debug_options = config.debug_options();
636       debug_options.set_xla_hlo_profile(true);
637       config.set_debug_options(debug_options);
638       module->set_config(config);
639     }
640 
641     if (backend_config) {
642       // Set backend configuration if it is given.
643       HloInstruction* instruction =
644           module->entry_computation()->root_instruction();
645       Status s = instruction->set_backend_config(*backend_config);
646       return s.ok() ? ::testing::AssertionSuccess()
647                     : ::testing::AssertionFailure() << s.error_message();
648     }
649 
650     auto executable =
651         test_runner_.CreateExecutable(std::move(module), run_hlo_passes);
652     if (!executable.ok()) {
653       return ::testing::AssertionFailure()
654              << executable.status().error_message();
655     }
656     executables[i] = std::move(executable.ValueOrDie());
657   }
658 
659   std::optional<Literal> canonical_output;
660   for (int i = 0; i < n; ++i) {
661     StatusOr<Literal> output = test_runner_.ExecuteWithExecutable(
662         executables[i].get(), fake_arguments[i],
663         /*profile=*/&((*profiles)[i]));
664     if (!output.ok()) {
665       return ::testing::AssertionFailure() << output.status().error_message();
666     }
667 
668     if (assert_determinism) {
669       if (!canonical_output.has_value()) {
670         canonical_output = std::move(output).value();
671       } else {
672         if (*canonical_output != output.ValueOrDie()) {
673           return ::testing::AssertionFailure()
674                  << "Successive runs have returned different results: "
675                  << *canonical_output << " vs. " << output.ValueOrDie();
676         }
677       }
678     }
679   }
680 
681   return ::testing::AssertionSuccess();
682 }
683 
RunAndCompareFromFile(const std::string & filename,const std::optional<ErrorSpec> & error,const std::function<void (HloModule *)> & reference_preprocessor)684 ::testing::AssertionResult HloTestBase::RunAndCompareFromFile(
685     const std::string& filename, const std::optional<ErrorSpec>& error,
686     const std::function<void(HloModule*)>& reference_preprocessor) {
687   auto module_or_status =
688       HloRunner::ReadModuleFromHloTextFile(filename, GetDebugOptionsForTest());
689   if (!module_or_status.ok()) {
690     return ::testing::AssertionFailure()
691            << "failed reading hlo module from file";
692   }
693   return RunAndCompare(std::move(module_or_status).value(), error,
694                        reference_preprocessor);
695 }
696 
RunAndCompareNoHloPasses(string_view hlo_string,const std::optional<ErrorSpec> & error,const std::function<void (HloModule *)> & reference_preprocessor)697 ::testing::AssertionResult HloTestBase::RunAndCompareNoHloPasses(
698     string_view hlo_string, const std::optional<ErrorSpec>& error,
699     const std::function<void(HloModule*)>& reference_preprocessor) {
700   auto module_or_status = ParseAndReturnVerifiedModule(hlo_string);
701   if (!module_or_status.ok()) {
702     return ::testing::AssertionFailure()
703            << "Error while parsing HLO text format: "
704            << module_or_status.status().ToString();
705   }
706   return RunAndCompareNoHloPasses(std::move(module_or_status).value(), error,
707                                   reference_preprocessor);
708 }
709 
RunAndCompareNoHloPassesFromFile(const std::string & filename,const std::optional<ErrorSpec> & error,const std::function<void (HloModule *)> & reference_preprocessor)710 ::testing::AssertionResult HloTestBase::RunAndCompareNoHloPassesFromFile(
711     const std::string& filename, const std::optional<ErrorSpec>& error,
712     const std::function<void(HloModule*)>& reference_preprocessor) {
713   auto module_or_status =
714       HloRunner::ReadModuleFromHloTextFile(filename, GetDebugOptionsForTest());
715   if (!module_or_status.ok()) {
716     return ::testing::AssertionFailure()
717            << "failed reading hlo module from file";
718   }
719   return RunAndCompareNoHloPasses(std::move(module_or_status).value(), error,
720                                   reference_preprocessor);
721 }
722 
FindComputation(HloModule * module,absl::string_view name)723 HloComputation* HloTestBase::FindComputation(HloModule* module,
724                                              absl::string_view name) {
725   auto computations = module->computations();
726   auto it = absl::c_find_if(
727       computations, [&](HloComputation* c) { return c->name() == name; });
728   if (it == computations.end()) {
729     return nullptr;
730   }
731   return *it;
732 }
733 
FindInstruction(HloModule * module,absl::string_view name)734 HloInstruction* HloTestBase::FindInstruction(HloModule* module,
735                                              absl::string_view name) {
736   for (const HloComputation* c : module->computations()) {
737     auto instructions = c->instructions();
738     auto it = absl::c_find_if(
739         instructions, [&](HloInstruction* i) { return i->name() == name; });
740     if (it != instructions.end()) {
741       return *it;
742     }
743   }
744   return nullptr;
745 }
746 
FindInstruction(HloModule * module,HloOpcode opcode)747 HloInstruction* HloTestBase::FindInstruction(HloModule* module,
748                                              HloOpcode opcode) {
749   for (const HloComputation* c : module->computations()) {
750     auto instructions = c->instructions();
751     auto it = absl::c_find_if(
752         instructions, [&](HloInstruction* i) { return i->opcode() == opcode; });
753     if (it != instructions.end()) {
754       return *it;
755     }
756   }
757   return nullptr;
758 }
759 
backend()760 Backend& HloTestBase::backend() { return test_runner_.backend(); }
761 
762 /* static */
TestName()763 std::string HloTestBase::TestName() {
764   return ::testing::UnitTest::GetInstance()->current_test_info()->name();
765 }
766 
767 }  // namespace xla
768