1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #define EIGEN_USE_THREADS
16
17 #include "tensorflow/compiler/xla/tests/local_client_test_base.h"
18
19 #include <memory>
20 #include <vector>
21
22 #include "absl/strings/string_view.h"
23 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
24 #include "tensorflow/compiler/xla/client/local_client.h"
25 #include "tensorflow/compiler/xla/client/xla_computation.h"
26 #include "tensorflow/compiler/xla/map_util.h"
27 #include "tensorflow/compiler/xla/service/hlo_module_config.h"
28 #include "tensorflow/compiler/xla/service/hlo_parser.h"
29 #include "tensorflow/compiler/xla/shape_util.h"
30 #include "tensorflow/compiler/xla/status_macros.h"
31 #include "tensorflow/compiler/xla/statusor.h"
32 #include "tensorflow/compiler/xla/test_helpers.h"
33 #include "tensorflow/core/lib/core/errors.h"
34 #include "tensorflow/core/lib/core/threadpool.h"
35 #include "tensorflow/core/platform/env.h"
36 #include "tensorflow/core/platform/logging.h"
37
38 namespace xla {
39
// Backing storage for the lazily-created, process-wide TestAllocator
// singleton; initialized by GetOrCreateAllocator() below.
/* static */ TestAllocator* LocalClientTestBase::allocator_;
41
Allocate(int device_ordinal,uint64_t size,bool retry_on_failure,int64_t memory_space)42 StatusOr<se::OwningDeviceMemory> TestAllocator::Allocate(int device_ordinal,
43 uint64_t size,
44 bool retry_on_failure,
45 int64_t memory_space) {
46 VLOG(2) << "Allocate(" << device_ordinal << ", " << size << ")";
47 {
48 absl::MutexLock lock(&count_mutex_);
49 allocation_count_++;
50 device_allocation_count_[device_ordinal]++;
51 }
52 return se::StreamExecutorMemoryAllocator::Allocate(
53 device_ordinal, size, retry_on_failure, memory_space);
54 }
55
Deallocate(int device_ordinal,se::DeviceMemoryBase mem)56 Status TestAllocator::Deallocate(int device_ordinal, se::DeviceMemoryBase mem) {
57 VLOG(2) << "Deallocate(" << device_ordinal << ")";
58 {
59 absl::MutexLock lock(&count_mutex_);
60 deallocation_count_++;
61 device_deallocation_count_[device_ordinal]++;
62 }
63 return se::StreamExecutorMemoryAllocator::Deallocate(device_ordinal, mem);
64 }
65
allocation_count() const66 int64_t TestAllocator::allocation_count() const {
67 absl::MutexLock lock(&count_mutex_);
68 return allocation_count_;
69 }
70
allocation_count(int device_ordinal) const71 int64_t TestAllocator::allocation_count(int device_ordinal) const {
72 absl::MutexLock lock(&count_mutex_);
73 auto it = device_allocation_count_.find(device_ordinal);
74 if (it == device_allocation_count_.end()) {
75 return 0;
76 } else {
77 return it->second;
78 }
79 }
80
deallocation_count() const81 int64_t TestAllocator::deallocation_count() const {
82 absl::MutexLock lock(&count_mutex_);
83 return deallocation_count_;
84 }
85
deallocation_count(int device_ordinal) const86 int64_t TestAllocator::deallocation_count(int device_ordinal) const {
87 absl::MutexLock lock(&count_mutex_);
88 auto it = device_deallocation_count_.find(device_ordinal);
89 if (it == device_deallocation_count_.end()) {
90 return 0;
91 } else {
92 return it->second;
93 }
94 }
95
GetOrCreateAllocator(se::Platform * platform)96 /* static */ TestAllocator* LocalClientTestBase::GetOrCreateAllocator(
97 se::Platform* platform) {
98 static absl::Mutex mu(absl::kConstInit);
99 absl::MutexLock lock(&mu);
100
101 if (allocator_ == nullptr) {
102 allocator_ = new TestAllocator(
103 platform == nullptr ? PlatformUtil::GetDefaultPlatform().ValueOrDie()
104 : platform);
105 }
106 return allocator_;
107 }
108
109 // Define this in .cc file to avoid having to include eigen or forward declare
110 // these types in the header.
111 struct LocalClientTestBase::EigenThreadPoolWrapper {
EigenThreadPoolWrapperxla::LocalClientTestBase::EigenThreadPoolWrapper112 explicit EigenThreadPoolWrapper()
113 : pool(new tensorflow::thread::ThreadPool(
114 tensorflow::Env::Default(), "XLAEigenTest", /*num_threads=*/2)),
115 device(new Eigen::ThreadPoolDevice(pool->AsEigenThreadPool(),
116 pool->NumThreads())) {}
117
118 std::unique_ptr<tensorflow::thread::ThreadPool> pool;
119 std::unique_ptr<Eigen::ThreadPoolDevice> device;
120 };
121
LocalClientTestBase(se::Platform * platform)122 LocalClientTestBase::LocalClientTestBase(se::Platform* platform)
123 : local_client_(
124 ClientLibrary::GetOrCreateLocalClient(platform).ValueOrDie()),
125 thread_pool_wrapper_(new EigenThreadPoolWrapper()) {
126 // Take the first executor, since it's the default one.
127 stream_executor_ = PlatformUtil::GetStreamExecutors(local_client_->platform())
128 .ValueOrDie()
129 .front();
130 transfer_manager_ =
131 TransferManager::GetForPlatform(local_client_->platform()).ValueOrDie();
132 }
133
~LocalClientTestBase()134 LocalClientTestBase::~LocalClientTestBase() {}
135
LiteralToShapedBuffer(const Literal & literal)136 ScopedShapedBuffer LocalClientTestBase::LiteralToShapedBuffer(
137 const Literal& literal) {
138 return local_client_
139 ->LiteralToShapedBuffer(literal, local_client_->default_device_ordinal())
140 .value();
141 }
142
ShapedBufferToLiteral(const ShapedBuffer & shaped_buffer)143 Literal LocalClientTestBase::ShapedBufferToLiteral(
144 const ShapedBuffer& shaped_buffer) {
145 return local_client_->ShapedBufferToLiteral(shaped_buffer).value();
146 }
147
DefaultExecutableBuildOptions() const148 ExecutableBuildOptions LocalClientTestBase::DefaultExecutableBuildOptions()
149 const {
150 return ExecutableBuildOptions();
151 }
152
DefaultExecutableRunOptions() const153 ExecutableRunOptions LocalClientTestBase::DefaultExecutableRunOptions() const {
154 ExecutableRunOptions run_options;
155 run_options.set_intra_op_thread_pool(thread_pool_wrapper_->device.get());
156 run_options.set_allocator(GetOrCreateAllocator(local_client_->platform()));
157 return run_options;
158 }
159
ExecuteLocallyOrDie(const XlaComputation & computation,absl::Span<const ShapedBuffer * const> arguments)160 ScopedShapedBuffer LocalClientTestBase::ExecuteLocallyOrDie(
161 const XlaComputation& computation,
162 absl::Span<const ShapedBuffer* const> arguments) {
163 return ExecuteLocally(computation, arguments, DefaultExecutableBuildOptions(),
164 DefaultExecutableRunOptions())
165 .value();
166 }
167
ExecuteLocallyOrDie(const XlaComputation & computation,absl::Span<const ShapedBuffer * const> arguments,const ExecutableBuildOptions & build_options,const ExecutableRunOptions & run_options)168 ScopedShapedBuffer LocalClientTestBase::ExecuteLocallyOrDie(
169 const XlaComputation& computation,
170 absl::Span<const ShapedBuffer* const> arguments,
171 const ExecutableBuildOptions& build_options,
172 const ExecutableRunOptions& run_options) {
173 return ExecuteLocally(computation, arguments, build_options, run_options)
174 .value();
175 }
176
ExecuteLocally(const XlaComputation & computation,absl::Span<const ShapedBuffer * const> arguments)177 StatusOr<ScopedShapedBuffer> LocalClientTestBase::ExecuteLocally(
178 const XlaComputation& computation,
179 absl::Span<const ShapedBuffer* const> arguments) {
180 return ExecuteLocally(computation, arguments, DefaultExecutableBuildOptions(),
181 DefaultExecutableRunOptions());
182 }
183
ExecuteLocally(const XlaComputation & computation,absl::Span<const ShapedBuffer * const> arguments,const ExecutableBuildOptions & build_options,const ExecutableRunOptions & run_options)184 StatusOr<ScopedShapedBuffer> LocalClientTestBase::ExecuteLocally(
185 const XlaComputation& computation,
186 absl::Span<const ShapedBuffer* const> arguments,
187 const ExecutableBuildOptions& build_options,
188 const ExecutableRunOptions& run_options) {
189 std::vector<const Shape*> argument_layouts(arguments.size());
190 for (int i = 0; i < arguments.size(); ++i) {
191 argument_layouts[i] = &arguments[i]->on_device_shape();
192 }
193 TF_ASSIGN_OR_RETURN(
194 auto executables,
195 local_client_->Compile(computation, argument_layouts, build_options));
196 TF_RET_CHECK(executables.size() == 1);
197 TF_ASSIGN_OR_RETURN(auto ret, executables[0]->Run(arguments, run_options));
198
199 auto device_ordinal =
200 build_options.device_ordinal() == -1 ? 0 : build_options.device_ordinal();
201 auto* stream = run_options.stream();
202 if (!stream) {
203 stream = local_client_->mutable_backend()
204 ->BorrowStream(device_ordinal)
205 .ValueOrDie()
206 .get();
207 }
208 TF_RETURN_IF_ERROR(stream->BlockHostUntilDone());
209 return std::move(ret);
210 }
211
212 StatusOr<std::unique_ptr<VerifiedHloModule>>
ParseAndReturnVerifiedModule(absl::string_view hlo_text)213 LocalClientTestBase::ParseAndReturnVerifiedModule(absl::string_view hlo_text) {
214 return ParseAndReturnVerifiedModule(hlo_text, HloModuleConfig());
215 }
216
217 StatusOr<std::unique_ptr<VerifiedHloModule>>
ParseAndReturnVerifiedModule(absl::string_view hlo_text,const HloModuleConfig & config)218 LocalClientTestBase::ParseAndReturnVerifiedModule(
219 absl::string_view hlo_text, const HloModuleConfig& config) {
220 auto module = std::make_unique<VerifiedHloModule>(
221 TestName(), config, /*verifier_layout_sensitive=*/false,
222 /*allow_mixed_precision_in_hlo_verifier=*/true,
223 local_client_->backend().compiler()->ShapeSizeBytesFunction());
224 TF_RETURN_IF_ERROR(module->ParseHloStringAndVerifyModule(hlo_text));
225 return std::move(module);
226 }
227
228 } // namespace xla
229