xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/tests/local_client_test_base.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #define EIGEN_USE_THREADS
16 
17 #include "tensorflow/compiler/xla/tests/local_client_test_base.h"
18 
19 #include <memory>
20 #include <vector>
21 
22 #include "absl/strings/string_view.h"
23 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
24 #include "tensorflow/compiler/xla/client/local_client.h"
25 #include "tensorflow/compiler/xla/client/xla_computation.h"
26 #include "tensorflow/compiler/xla/map_util.h"
27 #include "tensorflow/compiler/xla/service/hlo_module_config.h"
28 #include "tensorflow/compiler/xla/service/hlo_parser.h"
29 #include "tensorflow/compiler/xla/shape_util.h"
30 #include "tensorflow/compiler/xla/status_macros.h"
31 #include "tensorflow/compiler/xla/statusor.h"
32 #include "tensorflow/compiler/xla/test_helpers.h"
33 #include "tensorflow/core/lib/core/errors.h"
34 #include "tensorflow/core/lib/core/threadpool.h"
35 #include "tensorflow/core/platform/env.h"
36 #include "tensorflow/core/platform/logging.h"
37 
38 namespace xla {
39 
40 /* static */ TestAllocator* LocalClientTestBase::allocator_;
41 
Allocate(int device_ordinal,uint64_t size,bool retry_on_failure,int64_t memory_space)42 StatusOr<se::OwningDeviceMemory> TestAllocator::Allocate(int device_ordinal,
43                                                          uint64_t size,
44                                                          bool retry_on_failure,
45                                                          int64_t memory_space) {
46   VLOG(2) << "Allocate(" << device_ordinal << ", " << size << ")";
47   {
48     absl::MutexLock lock(&count_mutex_);
49     allocation_count_++;
50     device_allocation_count_[device_ordinal]++;
51   }
52   return se::StreamExecutorMemoryAllocator::Allocate(
53       device_ordinal, size, retry_on_failure, memory_space);
54 }
55 
Deallocate(int device_ordinal,se::DeviceMemoryBase mem)56 Status TestAllocator::Deallocate(int device_ordinal, se::DeviceMemoryBase mem) {
57   VLOG(2) << "Deallocate(" << device_ordinal << ")";
58   {
59     absl::MutexLock lock(&count_mutex_);
60     deallocation_count_++;
61     device_deallocation_count_[device_ordinal]++;
62   }
63   return se::StreamExecutorMemoryAllocator::Deallocate(device_ordinal, mem);
64 }
65 
allocation_count() const66 int64_t TestAllocator::allocation_count() const {
67   absl::MutexLock lock(&count_mutex_);
68   return allocation_count_;
69 }
70 
allocation_count(int device_ordinal) const71 int64_t TestAllocator::allocation_count(int device_ordinal) const {
72   absl::MutexLock lock(&count_mutex_);
73   auto it = device_allocation_count_.find(device_ordinal);
74   if (it == device_allocation_count_.end()) {
75     return 0;
76   } else {
77     return it->second;
78   }
79 }
80 
deallocation_count() const81 int64_t TestAllocator::deallocation_count() const {
82   absl::MutexLock lock(&count_mutex_);
83   return deallocation_count_;
84 }
85 
deallocation_count(int device_ordinal) const86 int64_t TestAllocator::deallocation_count(int device_ordinal) const {
87   absl::MutexLock lock(&count_mutex_);
88   auto it = device_deallocation_count_.find(device_ordinal);
89   if (it == device_deallocation_count_.end()) {
90     return 0;
91   } else {
92     return it->second;
93   }
94 }
95 
GetOrCreateAllocator(se::Platform * platform)96 /* static */ TestAllocator* LocalClientTestBase::GetOrCreateAllocator(
97     se::Platform* platform) {
98   static absl::Mutex mu(absl::kConstInit);
99   absl::MutexLock lock(&mu);
100 
101   if (allocator_ == nullptr) {
102     allocator_ = new TestAllocator(
103         platform == nullptr ? PlatformUtil::GetDefaultPlatform().ValueOrDie()
104                             : platform);
105   }
106   return allocator_;
107 }
108 
109 // Define this in .cc file to avoid having to include eigen or forward declare
110 // these types in the header.
111 struct LocalClientTestBase::EigenThreadPoolWrapper {
EigenThreadPoolWrapperxla::LocalClientTestBase::EigenThreadPoolWrapper112   explicit EigenThreadPoolWrapper()
113       : pool(new tensorflow::thread::ThreadPool(
114             tensorflow::Env::Default(), "XLAEigenTest", /*num_threads=*/2)),
115         device(new Eigen::ThreadPoolDevice(pool->AsEigenThreadPool(),
116                                            pool->NumThreads())) {}
117 
118   std::unique_ptr<tensorflow::thread::ThreadPool> pool;
119   std::unique_ptr<Eigen::ThreadPoolDevice> device;
120 };
121 
LocalClientTestBase(se::Platform * platform)122 LocalClientTestBase::LocalClientTestBase(se::Platform* platform)
123     : local_client_(
124           ClientLibrary::GetOrCreateLocalClient(platform).ValueOrDie()),
125       thread_pool_wrapper_(new EigenThreadPoolWrapper()) {
126   // Take the first executor, since it's the default one.
127   stream_executor_ = PlatformUtil::GetStreamExecutors(local_client_->platform())
128                          .ValueOrDie()
129                          .front();
130   transfer_manager_ =
131       TransferManager::GetForPlatform(local_client_->platform()).ValueOrDie();
132 }
133 
~LocalClientTestBase()134 LocalClientTestBase::~LocalClientTestBase() {}
135 
LiteralToShapedBuffer(const Literal & literal)136 ScopedShapedBuffer LocalClientTestBase::LiteralToShapedBuffer(
137     const Literal& literal) {
138   return local_client_
139       ->LiteralToShapedBuffer(literal, local_client_->default_device_ordinal())
140       .value();
141 }
142 
ShapedBufferToLiteral(const ShapedBuffer & shaped_buffer)143 Literal LocalClientTestBase::ShapedBufferToLiteral(
144     const ShapedBuffer& shaped_buffer) {
145   return local_client_->ShapedBufferToLiteral(shaped_buffer).value();
146 }
147 
DefaultExecutableBuildOptions() const148 ExecutableBuildOptions LocalClientTestBase::DefaultExecutableBuildOptions()
149     const {
150   return ExecutableBuildOptions();
151 }
152 
DefaultExecutableRunOptions() const153 ExecutableRunOptions LocalClientTestBase::DefaultExecutableRunOptions() const {
154   ExecutableRunOptions run_options;
155   run_options.set_intra_op_thread_pool(thread_pool_wrapper_->device.get());
156   run_options.set_allocator(GetOrCreateAllocator(local_client_->platform()));
157   return run_options;
158 }
159 
ExecuteLocallyOrDie(const XlaComputation & computation,absl::Span<const ShapedBuffer * const> arguments)160 ScopedShapedBuffer LocalClientTestBase::ExecuteLocallyOrDie(
161     const XlaComputation& computation,
162     absl::Span<const ShapedBuffer* const> arguments) {
163   return ExecuteLocally(computation, arguments, DefaultExecutableBuildOptions(),
164                         DefaultExecutableRunOptions())
165       .value();
166 }
167 
ExecuteLocallyOrDie(const XlaComputation & computation,absl::Span<const ShapedBuffer * const> arguments,const ExecutableBuildOptions & build_options,const ExecutableRunOptions & run_options)168 ScopedShapedBuffer LocalClientTestBase::ExecuteLocallyOrDie(
169     const XlaComputation& computation,
170     absl::Span<const ShapedBuffer* const> arguments,
171     const ExecutableBuildOptions& build_options,
172     const ExecutableRunOptions& run_options) {
173   return ExecuteLocally(computation, arguments, build_options, run_options)
174       .value();
175 }
176 
ExecuteLocally(const XlaComputation & computation,absl::Span<const ShapedBuffer * const> arguments)177 StatusOr<ScopedShapedBuffer> LocalClientTestBase::ExecuteLocally(
178     const XlaComputation& computation,
179     absl::Span<const ShapedBuffer* const> arguments) {
180   return ExecuteLocally(computation, arguments, DefaultExecutableBuildOptions(),
181                         DefaultExecutableRunOptions());
182 }
183 
ExecuteLocally(const XlaComputation & computation,absl::Span<const ShapedBuffer * const> arguments,const ExecutableBuildOptions & build_options,const ExecutableRunOptions & run_options)184 StatusOr<ScopedShapedBuffer> LocalClientTestBase::ExecuteLocally(
185     const XlaComputation& computation,
186     absl::Span<const ShapedBuffer* const> arguments,
187     const ExecutableBuildOptions& build_options,
188     const ExecutableRunOptions& run_options) {
189   std::vector<const Shape*> argument_layouts(arguments.size());
190   for (int i = 0; i < arguments.size(); ++i) {
191     argument_layouts[i] = &arguments[i]->on_device_shape();
192   }
193   TF_ASSIGN_OR_RETURN(
194       auto executables,
195       local_client_->Compile(computation, argument_layouts, build_options));
196   TF_RET_CHECK(executables.size() == 1);
197   TF_ASSIGN_OR_RETURN(auto ret, executables[0]->Run(arguments, run_options));
198 
199   auto device_ordinal =
200       build_options.device_ordinal() == -1 ? 0 : build_options.device_ordinal();
201   auto* stream = run_options.stream();
202   if (!stream) {
203     stream = local_client_->mutable_backend()
204                  ->BorrowStream(device_ordinal)
205                  .ValueOrDie()
206                  .get();
207   }
208   TF_RETURN_IF_ERROR(stream->BlockHostUntilDone());
209   return std::move(ret);
210 }
211 
212 StatusOr<std::unique_ptr<VerifiedHloModule>>
ParseAndReturnVerifiedModule(absl::string_view hlo_text)213 LocalClientTestBase::ParseAndReturnVerifiedModule(absl::string_view hlo_text) {
214   return ParseAndReturnVerifiedModule(hlo_text, HloModuleConfig());
215 }
216 
217 StatusOr<std::unique_ptr<VerifiedHloModule>>
ParseAndReturnVerifiedModule(absl::string_view hlo_text,const HloModuleConfig & config)218 LocalClientTestBase::ParseAndReturnVerifiedModule(
219     absl::string_view hlo_text, const HloModuleConfig& config) {
220   auto module = std::make_unique<VerifiedHloModule>(
221       TestName(), config, /*verifier_layout_sensitive=*/false,
222       /*allow_mixed_precision_in_hlo_verifier=*/true,
223       local_client_->backend().compiler()->ShapeSizeBytesFunction());
224   TF_RETURN_IF_ERROR(module->ParseHloStringAndVerifyModule(hlo_text));
225   return std::move(module);
226 }
227 
228 }  // namespace xla
229