/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_XLA_TESTS_HLO_TEST_BASE_H_
#define TENSORFLOW_COMPILER_XLA_TESTS_HLO_TEST_BASE_H_

#include <functional>
#include <memory>
#include <optional>
#include <string>
#include <vector>

#include "absl/base/attributes.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/service/backend.h"
#include "tensorflow/compiler/xla/service/computation_layout.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_runner.h"
#include "tensorflow/compiler/xla/service/hlo_verifier.h"
#include "tensorflow/compiler/xla/service/platform_util.h"
#include "tensorflow/compiler/xla/shape_layout.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
#include "tensorflow/compiler/xla/tests/manifest_checking_test.h"
#include "tensorflow/compiler/xla/tests/verified_hlo_module.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
#include "tensorflow/core/platform/test.h"

namespace xla {

// A base class for tests which build and/or run HLO code. The class includes
// support for running an HLO module on two platforms and comparing the
// results. This is a lower level of abstraction than using the client
// interface and enables, for one, explicitly building a graph of HLO
// instructions to run.
//
// This can also be used to write text/file-based test cases. Note that the
// test target is responsible for linking the needed backends. A convenient
// way to do this is to make it an xla_test: it will generate test targets
// linking with the respective backends, which will be used as the test
// backend; the interpreter backend is already linked with hlo_test_base so it
// will be the default reference backend. For example, if you want to compare
// both cpu vs. interpreter, and gpu vs. interpreter, you can:
//
//  xla_test (
//    name = "sample_text_test",
//    srcs = ["sample_text_test.cc"],
//    backends = [
//      "cpu",
//      "gpu",
//    ],
//    deps = [
//      "//third_party/tensorflow/compiler/xla/tests:hlo_test_base",
//      ...
//    ],
//  )
//
// For a more detailed example, see "../tests/sample_text_test.cc".
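//
// As a quick orientation, the body of a text-based test typically looks like
// the following sketch (the fixture name, HLO text, and error bound are
// illustrative, not part of this API):
//
//  class SampleTextTest : public HloTestBase {};
//
//  TEST_F(SampleTextTest, AddsTwoScalars) {
//    const char* const kHloText = R"(
//      HloModule add
//      ENTRY main {
//        x = f32[] parameter(0)
//        y = f32[] parameter(1)
//        ROOT sum = f32[] add(x, y)
//      }
//    )";
//    // Runs the module on the test and reference backends with automatically
//    // generated (fake) inputs and expects the results to match within the
//    // given error bound.
//    EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{1e-5, 1e-5}));
//  }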
class HloTestBase : public ManifestCheckingTest {
 public:
  // Creates a new HLO module for a test. The module created will have
  // TestName() for its name; it will also automatically populate its debug
  // options from command-line flags. If you want a fresh HloModule object and
  // then add HloComputations to it, it's recommended to use this method in
  // your tests.
  //
  // This returns a vanilla HloModule that doesn't run the HLO verifier on
  // destruction.
  ABSL_DEPRECATED("Use CreateNewVerifiedModule instead.")
  std::unique_ptr<HloModule> CreateNewUnverifiedModule(
      const std::string& name = TestName());

  // Like CreateNewUnverifiedModule, except the HloModule returned here runs
  // the HLO verifier on destruction.
  std::unique_ptr<VerifiedHloModule> CreateNewVerifiedModule(
      const std::string& name = TestName(), int64_t replica_count = 1);

  // Parses the given string and returns the module as a VerifiedHloModule.
  StatusOr<std::unique_ptr<VerifiedHloModule>> ParseAndReturnVerifiedModule(
      absl::string_view hlo_text, int64_t replica_count = 1,
      int64_t num_partitions = 1);
  StatusOr<std::unique_ptr<VerifiedHloModule>> ParseAndReturnVerifiedModule(
      absl::string_view hlo_text, const HloModuleConfig& config);

  // Runs the hlo_pass with the provided module and returns the result. This
  // function also verifies that the module remains unchanged when hlo_pass
  // returns false as the StatusOr value.
  //
  // These three overloads all do the same thing. The && overload lets you do
  // `RunHloPass(MyPass(), module)` all in one line. The reason for the
  // overload that takes a pointer is that, at one point in the past, non-const
  // lvalue references were banned in Google code.
  static StatusOr<bool> RunHloPass(HloPassInterface* hlo_pass,
                                   HloModule* module);
  static StatusOr<bool> RunHloPass(HloPassInterface& hlo_pass,
                                   HloModule* module) {
    return RunHloPass(&hlo_pass, module);
  }
  static StatusOr<bool> RunHloPass(HloPassInterface&& hlo_pass,
                                   HloModule* module) {
    return RunHloPass(&hlo_pass, module);
  }

  static PrecisionConfig DefaultPrecisionConfig(int operands);

  // Sets most fast math options to be enabled to model the fast math flags
  // generally used for CPU:AOT compilation.
  static void SetAotFastMathDebugOptions(DebugOptions* options);

 protected:
  // This uses the interpreter backend as the reference backend and
  // automatically finds another supported backend as the test backend. If the
  // interpreter is the only supported backend, it will be both the test
  // backend and the reference backend.
  explicit HloTestBase(bool verifier_layout_sensitive = false,
                       bool allow_mixed_precision_in_hlo_verifier = true,
                       HloPredicate instruction_can_change_layout_func = {});

  // If your test doesn't use interpreter as the reference backend, you can use
  // this constructor. Note that your test target is responsible for linking in
  // both needed backends.
  HloTestBase(se::Platform* test_platform, se::Platform* reference_platform,
              bool verifier_layout_sensitive = false,
              bool allow_mixed_precision_in_hlo_verifier = true,
              HloPredicate instruction_can_change_layout_func = {});

  ~HloTestBase() override {}

  // Runs pass `hlo_pass` on input HLO module `hlo`, and FileChecks the result
  // against `expected`.
  //
  // If the rewrite has changed the module, also runs `after_pass_checks` on
  // the result.
  void RunAndFilecheckHloRewrite(
      absl::string_view hlo, HloPassInterface&& hlo_pass,
      std::optional<absl::string_view> expected,
      std::function<void(HloModule*)> after_pass_checks = nullptr);
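
  // A minimal sketch of a typical call (`kHloText`, `MyRewritePass`, and the
  // CHECK pattern are placeholders, not part of this API):
  //
  //  RunAndFilecheckHloRewrite(kHloText, MyRewritePass(), R"(
  //    CHECK: ROOT {{.+}} = f32[8]{0} add
  //  )");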

  // Populates debug options from command-line flags and adjusts the options
  // for testing. It is recommended to use this when you need to pass in
  // DebugOptions, e.g. when creating a module from a string or a file.
  //
  // This function is virtual so tests can specify an alternative set of debug
  // options (e.g. disabling additional passes).
  virtual DebugOptions GetDebugOptionsForTest();

  // Gets an HloModuleConfig with options appropriate for tests.
  HloModuleConfig GetModuleConfigForTest(int64_t replica_count = 1,
                                         int64_t num_partitions = 1) {
    HloModuleConfig config;
    config.set_debug_options(GetDebugOptionsForTest());
    config.set_replica_count(replica_count);
    config.set_num_partitions(num_partitions);
    return config;
  }

  // Executes the given module and returns the result as a Literal.
  StatusOr<Literal> Execute(std::unique_ptr<HloModule> module,
                            absl::Span<Literal* const> arguments);

  // Same as above, except the module will be executed without running any HLO
  // passes on it.
  Literal ExecuteNoHloPasses(std::unique_ptr<HloModule> module,
                             absl::Span<Literal* const> arguments);

  Literal ExecuteAndTransfer(std::unique_ptr<HloModule> module,
                             absl::Span<Literal* const> arguments);

  // Executes the given module on multiple replicas.
  //
  // use_threads indicates whether this replicated computation will be executed
  // with a thread-per-replica, vs using an implicitly async call such as
  // Executable::ExecuteOnStreams.
  StatusOr<std::vector<Literal>> ExecuteReplicated(
      std::unique_ptr<HloModule> module, absl::Span<Literal* const> arguments,
      int64_t num_replicas, bool use_threads, bool run_hlo_passes = false);

  // Same as above, but uses specified device assignment.
  StatusOr<std::vector<Literal>> ExecuteReplicated(
      std::unique_ptr<HloModule> module, absl::Span<Literal* const> arguments,
      int64_t num_replicas, DeviceAssignment* device_assignment,
      bool run_hlo_passes, bool use_threads);

  // Same as above, but allows passing different programs for replicas.
  StatusOr<std::vector<Literal>> ExecuteReplicated(
      std::function<Executable*(int64_t)> executable_provider,
      std::function<int64_t(int64_t)> argument_count_provider,
      std::function<const Literal*(int64_t, int64_t)> argument_provider,
      int64_t num_replicas, bool run_hlo_passes,
      DeviceAssignment* device_assignment = nullptr);

  // Executes the given hlo module on two backends and compares results.
  //
  // 'arguments': the input of the hlo module.
  //
  // 'error': if it has a value, expects the results to be near (within the
  // error bound). Otherwise, expects the results to be equal.
  //
  // 'reference_preprocessor': the module should be ready to run on the test
  // backend, but it might need to be tailored so that it is able to run on the
  // reference backend. Note that the program shape of the module must not be
  // modified.
  [[nodiscard]] ::testing::AssertionResult RunAndCompare(
      std::unique_ptr<HloModule> module,
      const absl::Span<Literal* const> arguments,
      const std::optional<ErrorSpec>& error,
      const std::function<void(HloModule*)>& reference_preprocessor = nullptr);
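
  // For example, a parsed module can be run on both backends and the results
  // compared (a sketch; `kHloText` is a placeholder, and
  // TF_ASSERT_OK_AND_ASSIGN / LiteralUtil are the usual XLA test utilities,
  // not part of this class):
  //
  //  TF_ASSERT_OK_AND_ASSIGN(auto module,
  //                          ParseAndReturnVerifiedModule(kHloText));
  //  Literal arg = LiteralUtil::CreateR1<float>({1.0f, 2.0f});
  //  EXPECT_TRUE(RunAndCompare(std::move(module), {&arg}, ErrorSpec{1e-5}));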

  // Same as above, except that the module will be executed without Hlo
  // optimization.
  [[nodiscard]] ::testing::AssertionResult RunAndCompareNoHloPasses(
      std::unique_ptr<HloModule> module,
      const absl::Span<Literal* const> arguments,
      const std::optional<ErrorSpec>& error,
      const std::function<void(HloModule*)>& reference_preprocessor = nullptr);

  // Executes an hlo module with fake inputs and compares the results.
  [[nodiscard]] ::testing::AssertionResult RunAndCompare(
      std::unique_ptr<HloModule> module, const std::optional<ErrorSpec>& error,
      const std::function<void(HloModule*)>& reference_preprocessor = nullptr);

  // Same as above, except that the module will be executed without Hlo
  // optimization.
  [[nodiscard]] ::testing::AssertionResult RunAndCompareNoHloPasses(
      std::unique_ptr<HloModule> module, const std::optional<ErrorSpec>& error,
      const std::function<void(HloModule*)>& reference_preprocessor = nullptr);

  // Executes an hlo module with fake inputs and checks that the execution is
  // successful.
  [[nodiscard]] ::testing::AssertionResult Run(
      std::unique_ptr<HloModule> module, bool run_hlo_passes);

  // Convenient wrappers for executing and comparing an hlo module with fake
  // input. Module can be passed in directly, or parsed from an hlo_string,
  // or loaded from a file.
  [[nodiscard]] ::testing::AssertionResult RunAndCompare(
      const absl::string_view hlo_string, const std::optional<ErrorSpec>& error,
      const std::function<void(HloModule*)>& reference_preprocessor = nullptr);
  [[nodiscard]] ::testing::AssertionResult Run(
      const absl::string_view hlo_string, bool run_hlo_passes = true,
      ExecutionProfile* profile = nullptr,
      const tensorflow::protobuf::Message* backend_config = nullptr);

  // Same as below, except requires passing fake arguments.
  ::testing::AssertionResult RunAndCompareTwoModules(
      std::unique_ptr<HloModule> module_0, std::unique_ptr<HloModule> module_1,
      const absl::Span<Literal* const> arguments,
      const std::optional<ErrorSpec>& error);

  // Same as below, except requires passing the modules.
  ::testing::AssertionResult RunAndCompareTwoModules(
      std::unique_ptr<HloModule> module_0, std::unique_ptr<HloModule> module_1,
      const std::optional<ErrorSpec>& error);

  // Convenient wrapper for executing and comparing results of two unoptimized
  // hlo modules with fake input.
  ::testing::AssertionResult RunAndCompareTwoModules(
      absl::string_view hlo_string_module_0,
      absl::string_view hlo_string_module_1,
      const std::optional<ErrorSpec>& error);

  // Executes an hlo module with fake inputs on multiple replicas.
  [[nodiscard]] ::testing::AssertionResult RunReplicated(
      const absl::string_view hlo_string, bool run_hlo_passes = true,
      int64_t num_replicas = 1,
      const tensorflow::protobuf::Message* backend_config = nullptr);
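
  // For example (a sketch; `kCollectiveHloText` stands for a module that
  // contains collective ops such as all-reduce):
  //
  //  EXPECT_TRUE(RunReplicated(kCollectiveHloText, /*run_hlo_passes=*/true,
  //                            /*num_replicas=*/2));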

  // If assert_determinism is true, the assertion will fail unless all runs
  // produce exactly the same output.
  [[nodiscard]] ::testing::AssertionResult RunMultipleTimes(
      const absl::string_view hlo_string, bool run_hlo_passes,
      std::vector<ExecutionProfile>* profiles,
      const tensorflow::protobuf::Message* backend_config = nullptr,
      bool assert_determinism = false);
  [[nodiscard]] ::testing::AssertionResult RunAndCompareFromFile(
      const std::string& filename, const std::optional<ErrorSpec>& error,
      const std::function<void(HloModule*)>& reference_preprocessor = nullptr);
  [[nodiscard]] ::testing::AssertionResult RunAndCompareNoHloPasses(
      const absl::string_view hlo_string, const std::optional<ErrorSpec>& error,
      const std::function<void(HloModule*)>& reference_preprocessor = nullptr);
  [[nodiscard]] ::testing::AssertionResult RunAndCompareNoHloPassesFromFile(
      const std::string& filename, const std::optional<ErrorSpec>& error,
      const std::function<void(HloModule*)>& reference_preprocessor = nullptr);

  // Convenience method to force the layout of a given parameter in a module.
  // The layout of parameter number 'param_no' in the 'module' is set to
  // 'layout'.
  void ForceParameterLayout(HloModule* module, int64_t param_no,
                            const Layout& layout) {
    ASSERT_LT(param_no,
              module->mutable_entry_computation_layout()->parameter_count());
    module->mutable_entry_computation_layout()
        ->mutable_parameter_layout(param_no)
        ->ResetLayout(layout);
  }

  // Convenience method to force the layout of the computation result in a
  // module. The result layout of 'module' is set to 'layout'.
  void ForceResultLayout(HloModule* module, const Layout& layout) {
    module->mutable_entry_computation_layout()
        ->mutable_result_layout()
        ->ResetLayout(layout);
  }

  void ForceResultLayout(HloModule* module, const Layout& layout,
                         ShapeIndexView shape_index) {
    module->mutable_entry_computation_layout()
        ->mutable_result_layout()
        ->ResetLayout(layout, shape_index);
  }

  // Convenience method to clear the layout of the computation result in
  // 'module'.
  void ForceClearResultLayout(HloModule* module) {
    module->mutable_entry_computation_layout()
        ->mutable_result_layout()
        ->Clear();
  }

  // Gets the computation/instruction from the given module with the given
  // name.
  //
  // This is useful for tests which create HLOs from a string and then want to
  // inspect a particular computation or instruction.
  HloComputation* FindComputation(HloModule* module, absl::string_view name);
  HloInstruction* FindInstruction(HloModule* module, absl::string_view name);
  // Gets the instruction from the given module with the given opcode.
  HloInstruction* FindInstruction(HloModule* module, HloOpcode opcode);
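
  // For example, after parsing a module and running a pass, a test might
  // inspect the rewritten graph (a sketch; the instruction name "fusion" is
  // illustrative):
  //
  //  HloInstruction* fusion = FindInstruction(module.get(), "fusion");
  //  ASSERT_NE(fusion, nullptr);
  //  EXPECT_EQ(fusion->opcode(), HloOpcode::kFusion);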

  // Return an HLO verifier constructed for the test backend.
  HloVerifier& verifier() const { return *hlo_verifier_; }

  static std::string TestName();

  // Returns the backend owned by the test runner.
  Backend& backend();

  HloRunner test_runner_;
  HloRunner reference_runner_;

  bool verifier_layout_sensitive_;
  bool allow_mixed_precision_in_hlo_verifier_;
  std::unique_ptr<HloVerifier> hlo_verifier_;

  ErrorSpec error_spec_{0.0001};

  HloComputation* AddEntryComputationAndUpdateEntryComputationLayout(
      HloModule*, std::unique_ptr<HloComputation> computation);
  void UpdateEntryComputationLayout(HloModule* module);

 protected:
  // Helper functions to get test and reference platforms.
  static se::Platform* GetReferencePlatform();
  static se::Platform* GetTestPlatform();

 private:
  // Given the test module, makes a reference module that is ready to run on
  // the reference platform. This assumes that the given module is ready to
  // run on the test platform.
  StatusOr<std::unique_ptr<HloModule>> MakeReferenceModule(
      const HloModule& test_module,
      const std::function<void(HloModule*)>& reference_preprocessor);

  // Runs the module on two platforms with or without running hlo passes and
  // compares the results. Returns whether the results are near or equal. If
  // any error happens before the results are computed, returns the error
  // status.
  StatusOr<::testing::AssertionResult> RunAndCompareInternal(
      std::unique_ptr<HloModule> module,
      const absl::Span<Literal* const> arguments,
      const std::optional<ErrorSpec>& error, bool run_hlo_passes,
      const std::function<void(HloModule*)>& reference_preprocessor);

  // Runs the two modules with or without running hlo passes and compares the
  // results. Returns whether the results are near or equal. If any error
  // happens before the results are computed, returns the error status.
  StatusOr<::testing::AssertionResult> RunAndCompareTwoModulesInternal(
      std::unique_ptr<HloModule> module_0, std::unique_ptr<HloModule> module_1,
      const absl::Span<Literal* const> arguments,
      const std::optional<ErrorSpec>& error, bool run_hlo_passes);
};

}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_TESTS_HLO_TEST_BASE_H_