xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/tests/hlo_test_base.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_XLA_TESTS_HLO_TEST_BASE_H_
17 #define TENSORFLOW_COMPILER_XLA_TESTS_HLO_TEST_BASE_H_
18 
19 #include <functional>
20 #include <memory>
21 #include <optional>
22 #include <string>
23 #include <vector>
24 
25 #include "absl/types/span.h"
26 #include "tensorflow/compiler/xla/service/backend.h"
27 #include "tensorflow/compiler/xla/service/computation_layout.h"
28 #include "tensorflow/compiler/xla/service/hlo_module.h"
29 #include "tensorflow/compiler/xla/service/hlo_runner.h"
30 #include "tensorflow/compiler/xla/service/hlo_verifier.h"
31 #include "tensorflow/compiler/xla/service/platform_util.h"
32 #include "tensorflow/compiler/xla/shape_layout.h"
33 #include "tensorflow/compiler/xla/statusor.h"
34 #include "tensorflow/compiler/xla/tests/literal_test_util.h"
35 #include "tensorflow/compiler/xla/tests/manifest_checking_test.h"
36 #include "tensorflow/compiler/xla/tests/verified_hlo_module.h"
37 #include "tensorflow/compiler/xla/types.h"
38 #include "tensorflow/compiler/xla/xla_data.pb.h"
39 #include "tensorflow/core/platform/stream_executor_no_cuda.h"
40 #include "tensorflow/core/platform/test.h"
41 
42 namespace xla {
43 
44 // A base class for tests which build and/or run HLO code. The class includes
45 // support for running an HLO module on two platforms and compare the results.
46 // This is a lower level of abstraction than using the client interface and
47 // enables, for one, explicitly building a graph of HLO instructions to run.
48 //
49 // This can also be used to write text/file-based test cases. Note that the test
50 // target is responsible for linking the needed backends. A convenient way to do
51 // this is to make it an xla_test: it will generate test targets linking with
52 // the respective backends, which will be used as the test backend; the
53 // interpreter backend is already linked with hlo_test_base so it will be the
54 // default reference backend. For example, if you want to compare both cpu vs.
55 // interpreter, and gpu vs. interpreter, you can:
56 //
57 //  xla_test (
58 //    name = "sample_text_test",
59 //    srcs = ["sample_text_test.cc"],
60 //    backends = [
61 //      "cpu",
62 //      "gpu",
63 //    ],
64 //    deps = [
65 //      "//third_party/tensorflow/compiler/xla/tests:hlo_test_base",
66 //      ...
67 //    ],
68 //  )
69 //
70 // For a more detailed example, see "../tests/sample_text_test.cc".
71 class HloTestBase : public ManifestCheckingTest {
72  public:
73   // Creates a new HLO module for a test. The module created will have
74   // TestName() for its name; it will also automatically populate its debug
75   // options from command-line flags. If you want a fresh HloModule object and
76   // then add HloComputations to it, it's recommended to use this method in your
77   // tests.
78   //
79   // This returns a vanilla HloModule that doesn't run the HLO verifier on
80   // destruction.
81   ABSL_DEPRECATED("Use CreateNewVerifiedModule instead.")
82   std::unique_ptr<HloModule> CreateNewUnverifiedModule(
83       const std::string& name = TestName());
84 
85   // Like CreateNewUnverifiedModule, except the HloModule returned here runs the
86   // HLO verifier on destruction.
87   std::unique_ptr<VerifiedHloModule> CreateNewVerifiedModule(
88       const std::string& name = TestName(), int64_t replica_count = 1);
89 
90   // Parses the given string and returns module as a VerifiedHloModule.
91   StatusOr<std::unique_ptr<VerifiedHloModule>> ParseAndReturnVerifiedModule(
92       absl::string_view hlo_text, int64_t replica_count = 1,
93       int64_t num_partitions = 1);
94   StatusOr<std::unique_ptr<VerifiedHloModule>> ParseAndReturnVerifiedModule(
95       absl::string_view hlo_text, const HloModuleConfig& config);
96 
97   // Runs the hlo_pass with the provided module and returns the result. This
98   // function also verifies that the module remains unchanged when hlo_pass
99   // returns false as the StatusOr value.
100   //
101   // These three overloads all do the same thing.  The && overload lets you do
102   // `RunHloPass(MyPass(), module)` all in one line.  The reason for the
103   // overload that takes a pointer is that, at one point in the past, non-const
104   // lvalue references were banned in Google code.
105   static StatusOr<bool> RunHloPass(HloPassInterface* hlo_pass,
106                                    HloModule* module);
RunHloPass(HloPassInterface & hlo_pass,HloModule * module)107   static StatusOr<bool> RunHloPass(HloPassInterface& hlo_pass,
108                                    HloModule* module) {
109     return RunHloPass(&hlo_pass, module);
110   }
RunHloPass(HloPassInterface && hlo_pass,HloModule * module)111   static StatusOr<bool> RunHloPass(HloPassInterface&& hlo_pass,
112                                    HloModule* module) {
113     return RunHloPass(&hlo_pass, module);
114   }
115 
116   static PrecisionConfig DefaultPrecisionConfig(int operands);
117 
118   // Sets most fath math options to be enabled to model the fast math flags
119   // generally used for CPU:AOT compilation.
120   static void SetAotFastMathDebugOptions(DebugOptions* options);
121 
122  protected:
123   // This uses the interpreter backend as the reference backend and
124   // automatically finds another supported backend as the test backend. If the
125   // interpreter is the only supported backend, it will be both the test backend
126   // and the reference backend.
127   explicit HloTestBase(bool verifier_layout_sensitive = false,
128                        bool allow_mixed_precision_in_hlo_verifier = true,
129                        HloPredicate instruction_can_change_layout_func = {});
130 
131   // If your test doesn't use interpreter as the reference backend, you can use
132   // this constructor. Note that your test target is responsible for linking in
133   // both needed backends.
134   HloTestBase(se::Platform* test_platform, se::Platform* reference_platform,
135               bool verifier_layout_sensitive = false,
136               bool allow_mixed_precision_in_hlo_verifier = true,
137               HloPredicate instruction_can_change_layout_func = {});
138 
~HloTestBase()139   ~HloTestBase() override {}
140 
141   // Runs pass `hlo_pass` on input HLO module `hlo`, and FileChecks the result
142   // against `expected`.
143   //
144   // If the rewrite has changed the module, also runs `additional_checks` on the
145   // result.
146   void RunAndFilecheckHloRewrite(
147       absl::string_view hlo, HloPassInterface&& hlo_pass,
148       std::optional<absl::string_view> expected,
149       std::function<void(HloModule*)> after_pass_checks = nullptr);
150 
151   // Populates debug options from command-line flags and adjusts the options for
152   // testing. It is recommended to use this when you need to pass in
153   // DebugOptions, e.g. when creating a module from a string or a file.
154   //
155   // This function is virtual so tests can specify an alternative set of debug
156   // options (e.g. disabling additional passes).
157   virtual DebugOptions GetDebugOptionsForTest();
158 
159   // Gets an HloModuleConfig with options appropriate for tests.
160   HloModuleConfig GetModuleConfigForTest(int64_t replica_count = 1,
161                                          int64_t num_partitions = 1) {
162     HloModuleConfig config;
163     config.set_debug_options(GetDebugOptionsForTest());
164     config.set_replica_count(replica_count);
165     config.set_num_partitions(num_partitions);
166     return config;
167   }
168 
169   // Executes the given module and return the result as a Literal.
170   StatusOr<Literal> Execute(std::unique_ptr<HloModule> module,
171                             absl::Span<Literal* const> arguments);
172 
173   // Same as above, except the module will be executed without running any HLO
174   // passes on it.
175   Literal ExecuteNoHloPasses(std::unique_ptr<HloModule> module,
176                              absl::Span<Literal* const> arguments);
177 
178   Literal ExecuteAndTransfer(std::unique_ptr<HloModule> module,
179                              absl::Span<Literal* const> arguments);
180 
181   // Executes the given module on multiple replicas.
182   //
183   // use_threads indicates whether this replicated computation will be executed
184   // with a thread-per-replica, vs using an implicitly async call such as
185   // Executable::ExecuteOnStreams.
186   StatusOr<std::vector<Literal>> ExecuteReplicated(
187       std::unique_ptr<HloModule> module, absl::Span<Literal* const> arguments,
188       int64_t num_replicas, bool use_threads, bool run_hlo_passes = false);
189 
190   // Same as above, but uses specified device assignment.
191   StatusOr<std::vector<Literal>> ExecuteReplicated(
192       std::unique_ptr<HloModule> module, absl::Span<Literal* const> arguments,
193       int64_t num_replicas, DeviceAssignment* device_assignment,
194       bool run_hlo_passes, bool use_threads);
195 
196   // Same as above, but allows passing different programs for replicas.
197   StatusOr<std::vector<Literal>> ExecuteReplicated(
198       std::function<Executable*(int64_t)> executable_provider,
199       std::function<int64_t(int64_t)> argument_count_provider,
200       std::function<const Literal*(int64_t, int64_t)> argument_provider,
201       int64_t num_replicas, bool run_hlo_passes,
202       DeviceAssignment* device_assignment = nullptr);
203 
204   // Executes the given hlo module on two backends and compares results.
205   //
206   // 'arguments': the input of the hlo module.
207   //
208   // 'error': if has value, expects the results to be near (within the error
209   // bound). Otherwise, expects the results to be equal.
210   //
211   // 'reference_preprocessor': the module should be ready to run on the test
212   // backend, but it might need to be tailored so that it is able to run on the
213   // reference backend. Note that the program shape of the module must not be
214   // modified.
215   [[nodiscard]] ::testing::AssertionResult RunAndCompare(
216       std::unique_ptr<HloModule> module,
217       const absl::Span<Literal* const> arguments,
218       const std::optional<ErrorSpec>& error,
219       const std::function<void(HloModule*)>& reference_preprocessor = nullptr);
220 
221   // Same as above, except that the module will be executed without Hlo
222   // optimization.
223   [[nodiscard]] ::testing::AssertionResult RunAndCompareNoHloPasses(
224       std::unique_ptr<HloModule> module,
225       const absl::Span<Literal* const> arguments,
226       const std::optional<ErrorSpec>& error,
227       const std::function<void(HloModule*)>& reference_preprocessor = nullptr);
228 
229   // Executes an hlo module with fake inputs and compares the results.
230   [[nodiscard]] ::testing::AssertionResult RunAndCompare(
231       std::unique_ptr<HloModule> module, const std::optional<ErrorSpec>& error,
232       const std::function<void(HloModule*)>& reference_preprocessor = nullptr);
233 
234   // Same as above, except that the module will be executed without Hlo
235   // optimization.
236   [[nodiscard]] ::testing::AssertionResult RunAndCompareNoHloPasses(
237       std::unique_ptr<HloModule> module, const std::optional<ErrorSpec>& error,
238       const std::function<void(HloModule*)>& reference_preprocessor = nullptr);
239 
240   // Executes an hlo module with fake inputs and checks that the execution is
241   // successful.
242   [[nodiscard]] ::testing::AssertionResult Run(
243       std::unique_ptr<HloModule> module, bool run_hlo_passes);
244 
245   // Convenient wrappers for executing and comparing an hlo module with fake
246   // input. Module can be passed in directly, or parsed from an hlo_string,
247   // or loaded from a file.
248   [[nodiscard]] ::testing::AssertionResult RunAndCompare(
249       const absl::string_view hlo_string, const std::optional<ErrorSpec>& error,
250       const std::function<void(HloModule*)>& reference_preprocessor = nullptr);
251   [[nodiscard]] ::testing::AssertionResult Run(
252       const absl::string_view hlo_string, bool run_hlo_passes = true,
253       ExecutionProfile* profile = nullptr,
254       const tensorflow::protobuf::Message* backend_config = nullptr);
255 
256   // Same as below, except requires passing fake arguments.
257   ::testing::AssertionResult RunAndCompareTwoModules(
258       std::unique_ptr<HloModule> module_0, std::unique_ptr<HloModule> module_1,
259       const absl::Span<Literal* const> arguments,
260       const std::optional<ErrorSpec>& error);
261 
262   // Same as below, except requires passing the modules.
263   ::testing::AssertionResult RunAndCompareTwoModules(
264       std::unique_ptr<HloModule> module_0, std::unique_ptr<HloModule> module_1,
265       const std::optional<ErrorSpec>& error);
266 
267   // Convenient wrapper for executing and comparing results of two unoptimized
268   // hlo modules with fake input.
269   ::testing::AssertionResult RunAndCompareTwoModules(
270       absl::string_view hlo_string_module_0,
271       absl::string_view hlo_string_module_1,
272       const std::optional<ErrorSpec>& error);
273 
274   // Executes an hlo module with fake inputs on multiple replicas.
275   [[nodiscard]] ::testing::AssertionResult RunReplicated(
276       const absl::string_view hlo_string, bool run_hlo_passes = true,
277       int64_t num_replicas = 1,
278       const tensorflow::protobuf::Message* backend_config = nullptr);
279 
280   // If assert_determinism is true, the assertion will fail unless all runs
281   // produce exactly the same output.
282   [[nodiscard]] ::testing::AssertionResult RunMultipleTimes(
283       const absl::string_view hlo_string, bool run_hlo_passes,
284       std::vector<ExecutionProfile>* profiles,
285       const tensorflow::protobuf::Message* backend_config = nullptr,
286       bool assert_determinism = false);
287   [[nodiscard]] ::testing::AssertionResult RunAndCompareFromFile(
288       const std::string& filename, const std::optional<ErrorSpec>& error,
289       const std::function<void(HloModule*)>& reference_preprocessor = nullptr);
290   [[nodiscard]] ::testing::AssertionResult RunAndCompareNoHloPasses(
291       const absl::string_view hlo_string, const std::optional<ErrorSpec>& error,
292       const std::function<void(HloModule*)>& reference_preprocessor = nullptr);
293   [[nodiscard]] ::testing::AssertionResult RunAndCompareNoHloPassesFromFile(
294       const std::string& filename, const std::optional<ErrorSpec>& error,
295       const std::function<void(HloModule*)>& reference_preprocessor = nullptr);
296 
297   // Convenience method to force the layout of a given parameter in a module.
298   // The layout of parameter number 'param_no' in the 'module' is set to
299   // 'layout'.
ForceParameterLayout(HloModule * module,int64_t param_no,const Layout & layout)300   void ForceParameterLayout(HloModule* module, int64_t param_no,
301                             const Layout& layout) {
302     ASSERT_LT(param_no,
303               module->mutable_entry_computation_layout()->parameter_count());
304     module->mutable_entry_computation_layout()
305         ->mutable_parameter_layout(param_no)
306         ->ResetLayout(layout);
307   }
308 
309   // Convenience method to force the layout of the computation result in a
310   // module. The result layout of 'module' is set to 'layout'.
ForceResultLayout(HloModule * module,const Layout & layout)311   void ForceResultLayout(HloModule* module, const Layout& layout) {
312     module->mutable_entry_computation_layout()
313         ->mutable_result_layout()
314         ->ResetLayout(layout);
315   }
316 
ForceResultLayout(HloModule * module,const Layout & layout,ShapeIndexView shape_index)317   void ForceResultLayout(HloModule* module, const Layout& layout,
318                          ShapeIndexView shape_index) {
319     module->mutable_entry_computation_layout()
320         ->mutable_result_layout()
321         ->ResetLayout(layout, shape_index);
322   }
323 
324   // Convenience method to clear the layout of the computation result in
325   // 'module'.
ForceClearResultLayout(HloModule * module)326   void ForceClearResultLayout(HloModule* module) {
327     module->mutable_entry_computation_layout()
328         ->mutable_result_layout()
329         ->Clear();
330   }
331 
332   // Gets the computation/instruction from the given module with the given name.
333   //
334   // This is useful for tests which create HLOs from a string and then want to
335   // inspect a particular computation or instruction.
336   HloComputation* FindComputation(HloModule* module, absl::string_view name);
337   HloInstruction* FindInstruction(HloModule* module, absl::string_view name);
338   // Gets the instruction from the given module with the given opcode.
339   HloInstruction* FindInstruction(HloModule* module, HloOpcode opcode);
340 
341   // Return an HLO verifier constructed for the test backend.
verifier()342   HloVerifier& verifier() const { return *hlo_verifier_; }
343 
344   static std::string TestName();
345 
346   // Returns the backend owned by the test runner.
347   Backend& backend();
348 
349   HloRunner test_runner_;
350   HloRunner reference_runner_;
351 
352   bool verifier_layout_sensitive_;
353   bool allow_mixed_precision_in_hlo_verifier_;
354   std::unique_ptr<HloVerifier> hlo_verifier_;
355 
356   ErrorSpec error_spec_{0.0001};
357 
358   HloComputation* AddEntryComputationAndUpdateEntryComputationLayout(
359       HloModule*, std::unique_ptr<HloComputation> computation);
360   void UpdateEntryComputationLayout(HloModule* module);
361 
362  protected:
363   // Helper functions to get test and reference platforms.
364   static se::Platform* GetReferencePlatform();
365   static se::Platform* GetTestPlatform();
366 
367  private:
368   // Given the test module, makes a reference module that is ready to run on the
369   // reference platform. This assumes that the given module is ready to run on
370   // the test platform.
371   StatusOr<std::unique_ptr<HloModule>> MakeReferenceModule(
372       const HloModule& test_module,
373       const std::function<void(HloModule*)>& reference_preprocessor);
374 
375   // Runs the module on two platforms with or without running hlo passes and
376   // compares the results. Returns whether the results are near or equal. If any
377   // error happens before the results are computed, returns the error status.
378   StatusOr<::testing::AssertionResult> RunAndCompareInternal(
379       std::unique_ptr<HloModule> module,
380       const absl::Span<Literal* const> arguments,
381       const std::optional<ErrorSpec>& error, bool run_hlo_passes,
382       const std::function<void(HloModule*)>& reference_preprocessor);
383 
384   // Runs the two module on with or without running hlo passes and
385   // compares the results. Returns whether the results are near or equal. If any
386   // error happens before the results are computed, returns the error status.
387   StatusOr<::testing::AssertionResult> RunAndCompareTwoModulesInternal(
388       std::unique_ptr<HloModule> module_0, std::unique_ptr<HloModule> module_1,
389       const absl::Span<Literal* const> arguments,
390       const std::optional<ErrorSpec>& error, bool run_hlo_passes);
391 };
392 
393 }  // namespace xla
394 
395 #endif  // TENSORFLOW_COMPILER_XLA_TESTS_HLO_TEST_BASE_H_
396