/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#define EIGEN_USE_THREADS

#include "tensorflow/compiler/xla/service/backend.h"

#include <algorithm>
#include <memory>
#include <string>
#include <utility>

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/compiler/xla/service/compiler.h"
#include "tensorflow/compiler/xla/service/platform_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/platform/cpu_info.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"

namespace xla {

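// Builder-style setters on BackendOptions return *this so that several
// options can be set in one chained expression.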
BackendOptions& BackendOptions::set_platform(se::Platform* platform) {
  platform_ = platform;
  return *this;
}

se::Platform* BackendOptions::platform() const { return platform_; }

BackendOptions& BackendOptions::set_intra_op_parallelism_threads(
    int num_threads) {
  intra_op_parallelism_threads_ = num_threads;
  return *this;
}

int BackendOptions::intra_op_parallelism_threads() const {
  return intra_op_parallelism_threads_;
}

BackendOptions& BackendOptions::set_allowed_devices(
    const std::optional<std::set<int>>& allowed_devices) {
  allowed_devices_ = allowed_devices;
  return *this;
}

const std::optional<std::set<int>>& BackendOptions::allowed_devices() const {
  return allowed_devices_;
}

// Define this in .cc file to avoid having to include eigen or forward declare
// these types in the header.
struct Backend::IntraOpThreadPool {
  explicit IntraOpThreadPool(const int num_threads)
      : pool(new tensorflow::thread::ThreadPool(tensorflow::Env::Default(),
                                                "XLAEigen", num_threads)),
        device(new Eigen::ThreadPoolDevice(pool->AsEigenThreadPool(),
                                           pool->NumThreads())) {}

  std::unique_ptr<tensorflow::thread::ThreadPool> pool;
  std::unique_ptr<Eigen::ThreadPoolDevice> device;
};

/* static */ StatusOr<std::unique_ptr<Backend>> Backend::CreateBackend(
    const BackendOptions& options) {
  se::Platform* platform = options.platform();
  TF_ASSIGN_OR_RETURN(auto compiler, Compiler::GetForPlatform(platform));
  TF_ASSIGN_OR_RETURN(
      auto stream_executors,
      PlatformUtil::GetStreamExecutors(platform, options.allowed_devices()));
  TF_ASSIGN_OR_RETURN(auto transfer_manager,
                      TransferManager::GetForPlatform(platform));
  TF_ASSIGN_OR_RETURN(auto computation_placer,
                      ComputationPlacer::GetForPlatform(platform));
  std::unique_ptr<Backend> backend(
      new Backend(platform, compiler, stream_executors, transfer_manager,
                  computation_placer, options.intra_op_parallelism_threads()));
  return std::move(backend);
}

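// Creates a backend for the platform returned by
// PlatformUtil::GetDefaultPlatform(), with default options.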
/* static */ StatusOr<std::unique_ptr<Backend>>
Backend::CreateDefaultBackend() {
  TF_ASSIGN_OR_RETURN(se::Platform * platform,
                      PlatformUtil::GetDefaultPlatform());
  BackendOptions backend_options;
  backend_options.set_platform(platform);
  return CreateBackend(backend_options);
}

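// Resolves `device_ordinal` to its StreamExecutor, then borrows a stream
// from that executor's pool.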
StatusOr<StreamPool::Ptr> Backend::BorrowStream(int device_ordinal) {
  TF_ASSIGN_OR_RETURN(auto executor, stream_executor(device_ordinal));
  return BorrowStream(executor);
}

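// Lazily creates a StreamPool for `executor` on first use, then borrows a
// stream from it. `mu_` serializes access to the pool map.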
StatusOr<StreamPool::Ptr> Backend::BorrowStream(se::StreamExecutor* executor) {
  absl::MutexLock l(&mu_);
  if (!stream_pools_.contains(executor)) {
    stream_pools_.emplace(executor, std::make_unique<StreamPool>());
  }
  return stream_pools_.at(executor)->BorrowStream(executor);
}

Backend::Backend(se::Platform* platform, Compiler* compiler,
                 absl::Span<se::StreamExecutor* const> stream_executors,
                 TransferManager* transfer_manager,
                 ComputationPlacer* computation_placer,
                 int intra_op_parallelism_threads)
    : platform_(platform),
      compiler_(compiler),
      transfer_manager_(transfer_manager),
      computation_placer_(computation_placer),
      stream_executors_(stream_executors.begin(), stream_executors.end()) {
  // Create a memory allocator for the valid stream executors.
  memory_allocator_ = std::make_shared<se::StreamExecutorMemoryAllocator>(
      platform, stream_executors_);
  CHECK(!stream_executors_.empty())
      << "Service found no devices for backend " << platform_->Name() << '.';

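  // The CPU (host) backend evaluates Eigen expressions on a dedicated
  // intra-op thread pool; size it from the options, falling back to the
  // machine's available parallelism.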
  if (platform->id() == se::host::kHostPlatformId) {
    const int num_threads = intra_op_parallelism_threads > 0
                                ? intra_op_parallelism_threads
                                : tensorflow::port::MaxParallelism();
    intra_op_thread_pool_ = std::make_unique<IntraOpThreadPool>(num_threads);
  }
}

Backend::~Backend() {}

int Backend::default_device_ordinal() const {
  return default_stream_executor()->device_ordinal();
}

const Eigen::ThreadPoolDevice* Backend::eigen_intra_op_thread_pool_device()
    const {
  if (intra_op_thread_pool_ == nullptr) {
    return nullptr;
  }
  return intra_op_thread_pool_->device.get();
}

tensorflow::thread::ThreadPool* Backend::eigen_intra_op_thread_pool() const {
  if (intra_op_thread_pool_ == nullptr) {
    return nullptr;
  }
  return intra_op_thread_pool_->pool.get();
}

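// Returns the StreamExecutor for `device_ordinal`. The ordinal must be
// non-negative, no larger than the highest ordinal in `stream_executors_`,
// and must correspond to a device supported by this backend.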
StatusOr<se::StreamExecutor*> Backend::stream_executor(
    int device_ordinal) const {
  if (device_ordinal < 0 ||
      device_ordinal > stream_executors_.back()->device_ordinal()) {
    return InvalidArgument(
        "Invalid device ordinal value (%d). Valid range is [0, %d].",
        device_ordinal, stream_executors_.back()->device_ordinal());
  }
  for (auto* executor : stream_executors_) {
    if (executor->device_ordinal() == device_ordinal) {
      return executor;
    }
  }
  return InvalidArgument("device %s not supported by XLA service",
                         device_name(device_ordinal));
}

StatusOr<bool> Backend::devices_equivalent(int device_ordinal_a,
                                           int device_ordinal_b) {
  // Use the name from the device description to determine equivalence. This
  // is a bit crude, but it works for GPUs, which is the important case: we
  // compile an executable for one GPU and want to know whether it will run
  // (well) on another.
  TF_ASSIGN_OR_RETURN(se::StreamExecutor * executor_a,
                      stream_executor(device_ordinal_a));
  TF_ASSIGN_OR_RETURN(se::StreamExecutor * executor_b,
                      stream_executor(device_ordinal_b));
  return (executor_a->GetDeviceDescription().name() ==
          executor_b->GetDeviceDescription().name());
}

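// Delegates to the TransferManager to reset every device backing this
// backend.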
Status Backend::ResetDevices() {
  return transfer_manager_->ResetDevices(stream_executors_);
}

}  // namespace xla