xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/service/backend.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #define EIGEN_USE_THREADS
17 
18 #include "tensorflow/compiler/xla/service/backend.h"
19 
20 #include <algorithm>
21 #include <memory>
22 #include <string>
23 #include <utility>
24 
25 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
26 #include "tensorflow/compiler/xla/service/compiler.h"
27 #include "tensorflow/compiler/xla/service/platform_util.h"
28 #include "tensorflow/compiler/xla/status_macros.h"
29 #include "tensorflow/compiler/xla/statusor.h"
30 #include "tensorflow/compiler/xla/types.h"
31 #include "tensorflow/compiler/xla/util.h"
32 #include "tensorflow/core/lib/core/errors.h"
33 #include "tensorflow/core/lib/core/threadpool.h"
34 #include "tensorflow/core/platform/cpu_info.h"
35 #include "tensorflow/core/platform/env.h"
36 #include "tensorflow/core/platform/logging.h"
37 #include "tensorflow/core/platform/stream_executor_no_cuda.h"
38 
39 namespace xla {
40 
// Sets the platform that the backend will target. Returns *this so option
// setters can be chained.
BackendOptions& BackendOptions::set_platform(se::Platform* platform) {
  platform_ = platform;
  return *this;
}
45 
platform() const46 se::Platform* BackendOptions::platform() const { return platform_; }
47 
// Sets the requested number of intra-op parallelism threads (used by the CPU
// backend's Eigen thread pool; see Backend's constructor). Returns *this so
// option setters can be chained.
BackendOptions& BackendOptions::set_intra_op_parallelism_threads(
    int num_threads) {
  intra_op_parallelism_threads_ = num_threads;
  return *this;
}
53 
// Returns the requested intra-op parallelism thread count; a non-positive
// value means "use the default" (see Backend's constructor).
int BackendOptions::intra_op_parallelism_threads() const {
  return intra_op_parallelism_threads_;
}
57 
// Restricts the backend to the given set of device ordinals; std::nullopt
// means all devices on the platform are allowed. Returns *this so option
// setters can be chained.
BackendOptions& BackendOptions::set_allowed_devices(
    const std::optional<std::set<int>>& allowed_devices) {
  allowed_devices_ = allowed_devices;
  return *this;
}
63 
// Returns the allowed-device set stored by set_allowed_devices().
const std::optional<std::set<int>>& BackendOptions::allowed_devices() const {
  return allowed_devices_;
}
67 
68 // Define this in .cc file to avoid having to include eigen or forward declare
69 // these types in the header.
70 struct Backend::IntraOpThreadPool {
IntraOpThreadPoolxla::Backend::IntraOpThreadPool71   explicit IntraOpThreadPool(const int num_threads)
72       : pool(new tensorflow::thread::ThreadPool(tensorflow::Env::Default(),
73                                                 "XLAEigen", num_threads)),
74         device(new Eigen::ThreadPoolDevice(pool->AsEigenThreadPool(),
75                                            pool->NumThreads())) {}
76 
77   std::unique_ptr<tensorflow::thread::ThreadPool> pool;
78   std::unique_ptr<Eigen::ThreadPoolDevice> device;
79 };
80 
// Builds a Backend for the platform in `options` by looking up the
// per-platform singletons (compiler, transfer manager, computation placer)
// and the stream executors for the allowed devices. Any failed lookup is
// propagated to the caller as a non-OK status via TF_ASSIGN_OR_RETURN.
/* static */ StatusOr<std::unique_ptr<Backend>> Backend::CreateBackend(
    const BackendOptions& options) {
  se::Platform* platform = options.platform();
  TF_ASSIGN_OR_RETURN(auto compiler, Compiler::GetForPlatform(platform));
  TF_ASSIGN_OR_RETURN(
      auto stream_executors,
      PlatformUtil::GetStreamExecutors(platform, options.allowed_devices()));
  TF_ASSIGN_OR_RETURN(auto transfer_manager,
                      TransferManager::GetForPlatform(platform));
  TF_ASSIGN_OR_RETURN(auto computation_placer,
                      ComputationPlacer::GetForPlatform(platform));
  // NOTE(review): raw `new` suggests Backend's constructor is not public, so
  // std::make_unique cannot be used here — confirm against backend.h.
  std::unique_ptr<Backend> backend(
      new Backend(platform, compiler, stream_executors, transfer_manager,
                  computation_placer, options.intra_op_parallelism_threads()));
  // std::move converts unique_ptr<Backend> into the StatusOr return type.
  return std::move(backend);
}
97 
// Convenience wrapper around CreateBackend() that targets the process-wide
// default platform with default options.
/* static */ StatusOr<std::unique_ptr<Backend>>
Backend::CreateDefaultBackend() {
  TF_ASSIGN_OR_RETURN(se::Platform * platform,
                      PlatformUtil::GetDefaultPlatform());
  BackendOptions backend_options;
  backend_options.set_platform(platform);
  return CreateBackend(backend_options);
}
106 
// Borrows a stream for the device with the given ordinal; resolves the
// ordinal to its StreamExecutor and delegates to the executor overload.
// Returns a non-OK status if the ordinal is invalid.
StatusOr<StreamPool::Ptr> Backend::BorrowStream(int device_ordinal) {
  TF_ASSIGN_OR_RETURN(auto executor, stream_executor(device_ordinal));
  return BorrowStream(executor);
}
111 
BorrowStream(se::StreamExecutor * executor)112 StatusOr<StreamPool::Ptr> Backend::BorrowStream(se::StreamExecutor* executor) {
113   absl::MutexLock l(&mu_);
114   if (!stream_pools_.contains(executor)) {
115     stream_pools_.emplace(executor, std::make_unique<StreamPool>());
116   }
117   return stream_pools_.at(executor)->BorrowStream(executor);
118 }
119 
Backend(se::Platform * platform,Compiler * compiler,absl::Span<se::StreamExecutor * const> stream_executors,TransferManager * transfer_manager,ComputationPlacer * computation_placer,int intra_op_parallelism_threads)120 Backend::Backend(se::Platform* platform, Compiler* compiler,
121                  absl::Span<se::StreamExecutor* const> stream_executors,
122                  TransferManager* transfer_manager,
123                  ComputationPlacer* computation_placer,
124                  int intra_op_parallelism_threads)
125     : platform_(platform),
126       compiler_(compiler),
127       transfer_manager_(transfer_manager),
128       computation_placer_(computation_placer),
129       stream_executors_(stream_executors.begin(), stream_executors.end()) {
130   // Create a memory allocator for the valid stream executors.
131   memory_allocator_ = std::make_shared<se::StreamExecutorMemoryAllocator>(
132       platform, stream_executors_);
133   CHECK(!stream_executors_.empty())
134       << "Service found no devices for backend " << platform_->Name() << '.';
135 
136   if (platform->id() == se::host::kHostPlatformId) {
137     const int num_threads = intra_op_parallelism_threads > 0
138                                 ? intra_op_parallelism_threads
139                                 : tensorflow::port::MaxParallelism();
140     intra_op_thread_pool_.reset(new IntraOpThreadPool(num_threads));
141   }
142 }
143 
~Backend()144 Backend::~Backend() {}
145 
// Returns the device ordinal of the default stream executor.
int Backend::default_device_ordinal() const {
  return default_stream_executor()->device_ordinal();
}
149 
eigen_intra_op_thread_pool_device() const150 const Eigen::ThreadPoolDevice* Backend::eigen_intra_op_thread_pool_device()
151     const {
152   if (intra_op_thread_pool_ == nullptr) {
153     return nullptr;
154   }
155   return intra_op_thread_pool_->device.get();
156 }
157 
eigen_intra_op_thread_pool() const158 tensorflow::thread::ThreadPool* Backend::eigen_intra_op_thread_pool() const {
159   if (intra_op_thread_pool_ == nullptr) {
160     return nullptr;
161   }
162   return intra_op_thread_pool_->pool.get();
163 }
164 
stream_executor(int device_ordinal) const165 StatusOr<se::StreamExecutor*> Backend::stream_executor(
166     int device_ordinal) const {
167   if (device_ordinal < 0 ||
168       device_ordinal > stream_executors_.back()->device_ordinal()) {
169     return InvalidArgument(
170         "Invalid device ordinal value (%d). Valid range is [0, %d].",
171         device_ordinal, stream_executors_.back()->device_ordinal());
172   }
173   for (auto* executor : stream_executors_) {
174     if (executor->device_ordinal() == device_ordinal) {
175       return executor;
176     }
177   }
178   return InvalidArgument("device %s not supported by XLA service",
179                          device_name(device_ordinal));
180 }
181 
// Returns true if the two device ordinals refer to devices that can run the
// same compiled executable. Propagates a non-OK status if either ordinal is
// invalid.
StatusOr<bool> Backend::devices_equivalent(int device_ordinal_a,
                                           int device_ordinal_b) {
  // Use the name from device description to determine equivalence. This is a
  // bit crude but works for GPUs which is the important case where we compile
  // an executable for one GPU and want to know if it will run (well) on
  // another.
  TF_ASSIGN_OR_RETURN(se::StreamExecutor * executor_a,
                      stream_executor(device_ordinal_a));
  TF_ASSIGN_OR_RETURN(se::StreamExecutor * executor_b,
                      stream_executor(device_ordinal_b));
  return (executor_a->GetDeviceDescription().name() ==
          executor_b->GetDeviceDescription().name());
}
195 
// Resets all devices managed by this backend via the transfer manager,
// returning its status.
Status Backend::ResetDevices() {
  return transfer_manager_->ResetDevices(stream_executors_);
}
199 
200 }  // namespace xla
201