/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_JIT_TESTS_XLA_COMPILATION_CACHE_TEST_HELPER_H_
#define TENSORFLOW_COMPILER_JIT_TESTS_XLA_COMPILATION_CACHE_TEST_HELPER_H_

#include <memory>
#include <utility>
#include <vector>

#include "absl/strings/string_view.h"
#include "tensorflow/compiler/jit/xla_activity_listener.h"
#include "tensorflow/core/graph/graph_def_builder.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/test.h"

namespace tensorflow {

// A listener to inspect the use of XLA's persistent compilation cache entries.
30 class JitCompilationListener : public XlaActivityListener { 31 public: Listen(const XlaAutoClusteringActivity & auto_clustering_activity)32 Status Listen( 33 const XlaAutoClusteringActivity& auto_clustering_activity) override { 34 return OkStatus(); 35 } 36 Listen(const XlaJitCompilationActivity & jit_compilation_activity)37 Status Listen( 38 const XlaJitCompilationActivity& jit_compilation_activity) override { 39 used_persistent_cache_.push_back( 40 jit_compilation_activity.used_persistent_cache()); 41 return OkStatus(); 42 } 43 Listen(const XlaOptimizationRemark & optimization_remark)44 Status Listen(const XlaOptimizationRemark& optimization_remark) override { 45 return OkStatus(); 46 } 47 ~JitCompilationListener()48 ~JitCompilationListener() override {} 49 VerifyListenerHistory(bool expect_persistent_cache_use)50 Status VerifyListenerHistory(bool expect_persistent_cache_use) { 51 for (bool used_persistent_cache : used_persistent_cache_) { 52 if (used_persistent_cache != expect_persistent_cache_use) { 53 return errors::FailedPrecondition("Unexpected listener history."); 54 } 55 } 56 return OkStatus(); 57 } 58 ClearListenerHistory()59 void ClearListenerHistory() { used_persistent_cache_.clear(); } 60 61 private: 62 std::vector<bool> used_persistent_cache_; 63 }; 64 65 // Fixture for testing XLA compilation cache serialization. 66 class XlaCompilationCacheSerializeTest : public ::testing::Test { 67 protected: XlaCompilationCacheSerializeTest()68 XlaCompilationCacheSerializeTest() { 69 auto listener = absl::make_unique<JitCompilationListener>(); 70 listener_ = listener.get(); 71 RegisterXlaActivityListener(std::move(listener)); 72 } 73 listener()74 JitCompilationListener* listener() const { return listener_; } 75 76 // Returns a test graph that will split into two XLA clusters (due to a node 77 // with _XlaCompile = false). 
78 GraphDef GetTestGraph(const PartialTensorShape& input_shape); 79 80 // Runs the graph using specified batch size both with and without XLA JIT 81 // compilation. Returns an error if the results between the two do not match. 82 Status ExecuteWithBatch(const GraphDef& graph, int batch); 83 84 // Adds the suffix "_altered" to the HLO module names of all of the persistent 85 // XLA compilation cache entries found at the specified directory. If none are 86 // found, returns NOT_FOUND error. 87 Status AlterPersistentCacheEntryHloModuleNames( 88 absl::string_view persistent_cache_dir_path, 89 absl::string_view file_prefix = "xla_compile_cache"); 90 91 private: 92 JitCompilationListener* listener_; 93 }; 94 95 } // namespace tensorflow 96 97 #endif // TENSORFLOW_COMPILER_JIT_TESTS_XLA_COMPILATION_CACHE_TEST_HELPER_H_ 98