xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/tests/llvm_compiler_test.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/llvm_compiler.h"
17 
18 #include <memory>
19 #include <utility>
20 
21 #include "tensorflow/compiler/xla/literal_util.h"
22 #include "tensorflow/compiler/xla/service/backend.h"
23 #include "tensorflow/compiler/xla/service/cpu/cpu_compiler.h"
24 #include "tensorflow/compiler/xla/service/gpu/gpu_compiler.h"
25 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
26 #include "tensorflow/compiler/xla/service/platform_util.h"
27 #include "tensorflow/compiler/xla/test_helpers.h"
28 #include "tensorflow/compiler/xla/tests/verified_hlo_module.h"
29 #include "tensorflow/core/platform/test.h"
30 #include "tensorflow/stream_executor/stream_executor.h"
31 
32 namespace xla {
33 namespace gpu {
34 
35 // Creating dummy data structure needed to initialize a GpuDummyCompiler
36 PLATFORM_DEFINE_ID(kDummyTestId);
37 constexpr char kDummyTriple[] = "dummy-triple";
38 constexpr char kDummyLayout[] = "e";
39 
40 // This class is a dummy implementation of GpuCompiler and is targeted for unit
41 // test only
42 class GpuDummyCompiler : public GpuCompiler {
43  public:
GpuDummyCompiler()44   GpuDummyCompiler() : GpuCompiler(kDummyTestId, kDummyTriple, kDummyLayout) {}
45 
OptimizeHloConvolutionCanonicalization(HloModule * hlo_module,se::StreamExecutor * stream_exec,se::DeviceMemoryAllocator * device_allocator)46   Status OptimizeHloConvolutionCanonicalization(
47       HloModule* hlo_module, se::StreamExecutor* stream_exec,
48       se::DeviceMemoryAllocator* device_allocator) {
49     return OkStatus();
50   }
51 
OptimizeHloPostLayoutAssignment(HloModule * hlo_module,se::StreamExecutor * stream_exec,se::DeviceMemoryAllocator * device_allocator)52   Status OptimizeHloPostLayoutAssignment(
53       HloModule* hlo_module, se::StreamExecutor* stream_exec,
54       se::DeviceMemoryAllocator* device_allocator) {
55     return OkStatus();
56   }
57 
GetGpuVersion(se::StreamExecutor *)58   GpuVersion GetGpuVersion(se::StreamExecutor*) override {
59     return se::CudaComputeCapability{0, 0};
60   }
61 
CompileTargetBinary(const HloModuleConfig & module_config,llvm::Module * llvm_module,GpuVersion gpu_version,se::StreamExecutor * stream_exec,bool relocatable,const HloModule * debug_module)62   StatusOr<std::pair<std::string, std::vector<uint8_t>>> CompileTargetBinary(
63       const HloModuleConfig& module_config, llvm::Module* llvm_module,
64       GpuVersion gpu_version, se::StreamExecutor* stream_exec, bool relocatable,
65       const HloModule* debug_module) {
66     std::vector<uint8_t> compiled_results;
67     return std::pair<std::string, std::vector<uint8_t>>(
68         "", std::move(compiled_results));
69   }
70 };
71 }  // namespace gpu
72 
73 namespace {
74 
75 class LLVMCompilerTest : public ::testing::Test {
76  public:
SetUp()77   void SetUp() override {
78     Platform* platform = FindPlatform();
79     ASSERT_NE(platform, nullptr);
80 
81     BackendOptions backend_options;
82     backend_options.set_platform(platform);
83     StatusOr<std::unique_ptr<Backend>> backend_or_status =
84         Backend::CreateBackend(backend_options);
85     ASSERT_IS_OK(backend_or_status.status());
86     backend_ = std::move(backend_or_status).value();
87   }
88 
~LLVMCompilerTest()89   ~LLVMCompilerTest() override {}
90 
91  protected:
92   using Platform = se::Platform;
93 
LLVMCompilerTest(std::string platform_name)94   explicit LLVMCompilerTest(std::string platform_name)
95       : platform_name_(std::move(platform_name)) {}
96 
TestCompilerHooks(LLVMCompiler * compiler)97   void TestCompilerHooks(LLVMCompiler* compiler) {
98     int pre_opt_hook_call_count = 0;
99     int post_opt_hook_call_count = 0;
100 
101     auto pre_opt_hook = [&pre_opt_hook_call_count](const llvm::Module&) {
102       ++pre_opt_hook_call_count;
103       return OkStatus();
104     };
105     auto post_opt_hook = [&post_opt_hook_call_count](const llvm::Module&) {
106       ++post_opt_hook_call_count;
107       return OkStatus();
108     };
109 
110     // Create HLO module, and run the compiler.
111     auto builder = HloComputation::Builder(TestName());
112     builder.AddInstruction(
113         HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0)));
114 
115     auto hlo_module = CreateNewVerifiedModule();
116     hlo_module->AddEntryComputation(builder.Build());
117 
118     compiler->SetPreOptimizationHook(pre_opt_hook);
119     compiler->SetPostOptimizationHook(post_opt_hook);
120 
121     ASSERT_TRUE(compiler
122                     ->RunBackend(std::move(hlo_module),
123                                  backend_->default_stream_executor(),
124                                  /*device_allocator=*/nullptr)
125                     .ok());
126 
127     // Test that hooks were called.
128     EXPECT_EQ(1, pre_opt_hook_call_count);
129     EXPECT_EQ(1, post_opt_hook_call_count);
130   }
131 
TestMultiModuleCompilation(LLVMCompiler * compiler)132   void TestMultiModuleCompilation(LLVMCompiler* compiler) {
133     HloComputation::Builder builder(TestName());
134     builder.AddInstruction(
135         HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0)));
136 
137     std::unique_ptr<HloModule> hlo_module = CreateNewVerifiedModule();
138     hlo_module->AddEntryComputation(builder.Build());
139 
140     auto module_group = std::make_unique<HloModuleGroup>("test_module_group");
141     module_group->push_back(hlo_module->Clone());
142     module_group->push_back(std::move(hlo_module));
143 
144     std::vector<std::vector<se::StreamExecutor*>> executors;
145     executors.push_back({backend_->default_stream_executor()});
146     executors.push_back({backend_->default_stream_executor()});
147 
148     EXPECT_IS_OK(compiler->Compile(std::move(module_group),
149                                    std::move(executors),
150                                    /*device_allocator=*/nullptr));
151   }
152 
153  private:
FindPlatform()154   Platform* FindPlatform() {
155     auto status_or_platform = PlatformUtil::GetPlatform(platform_name_);
156     return status_or_platform.ok() ? status_or_platform.ValueOrDie() : nullptr;
157   }
158 
159   std::string platform_name_;
160   std::unique_ptr<Backend> backend_;
161 
TestName()162   static std::string TestName() {
163     return ::testing::UnitTest::GetInstance()->current_test_info()->name();
164   }
165 
CreateNewVerifiedModule()166   std::unique_ptr<HloModule> CreateNewVerifiedModule() {
167     HloModuleConfig config;
168     config.set_debug_options(GetDebugOptionsFromFlags());
169     return std::make_unique<VerifiedHloModule>(
170         TestName(), config, /*verifier_layout_sensitive=*/false,
171         /*allow_mixed_precision_in_hlo_verifier=*/true,
172         backend_->compiler()->ShapeSizeBytesFunction());
173   }
174 };
175 
176 class CpuCompilerTest : public LLVMCompilerTest {
177  public:
CpuCompilerTest()178   CpuCompilerTest() : LLVMCompilerTest("Host") {}
179 };
180 
181 class GpuCompilerTest : public LLVMCompilerTest {
182  public:
GpuCompilerTest()183   GpuCompilerTest() : LLVMCompilerTest("GPU") {}
184 };
185 
// Runs the shared optimization-hook check against a real CPU compiler.
TEST_F(CpuCompilerTest, HooksTest) {
  cpu::CpuCompiler cpu_compiler;
  LLVMCompiler* const compiler_under_test = &cpu_compiler;
  TestCompilerHooks(compiler_under_test);
}
190 
// Runs the shared optimization-hook check against the dummy GPU compiler.
TEST_F(GpuCompilerTest, HooksTest) {
  gpu::GpuDummyCompiler dummy_gpu_compiler;
  LLVMCompiler* const compiler_under_test = &dummy_gpu_compiler;
  TestCompilerHooks(compiler_under_test);
}
195 
// Runs the shared multi-module compilation check against a real CPU compiler.
TEST_F(CpuCompilerTest, CpuMultiModuleCompilation) {
  cpu::CpuCompiler cpu_compiler;
  LLVMCompiler* const compiler_under_test = &cpu_compiler;
  TestMultiModuleCompilation(compiler_under_test);
}
200 
// Runs the shared multi-module compilation check against the dummy GPU
// compiler.
TEST_F(GpuCompilerTest, GpuMultModuleCompilation) {
  gpu::GpuDummyCompiler dummy_gpu_compiler;
  LLVMCompiler* const compiler_under_test = &dummy_gpu_compiler;
  TestMultiModuleCompilation(compiler_under_test);
}
205 }  // namespace
206 }  // namespace xla
207