1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_BACKEND_H_ 17 #define TENSORFLOW_COMPILER_XLA_SERVICE_BACKEND_H_ 18 19 #include <map> 20 #include <memory> 21 #include <set> 22 #include <string> 23 #include <vector> 24 25 #include "absl/container/flat_hash_map.h" 26 #include "absl/strings/str_cat.h" 27 #include "absl/types/span.h" 28 #include "tensorflow/compiler/xla/service/compiler.h" 29 #include "tensorflow/compiler/xla/service/computation_placer.h" 30 #include "tensorflow/compiler/xla/service/stream_pool.h" 31 #include "tensorflow/compiler/xla/service/transfer_manager.h" 32 #include "tensorflow/compiler/xla/statusor.h" 33 #include "tensorflow/compiler/xla/types.h" 34 #include "tensorflow/core/platform/stream_executor_no_cuda.h" 35 #include "tensorflow/stream_executor/device_memory_allocator.h" 36 37 namespace Eigen { 38 struct ThreadPoolDevice; 39 } 40 41 namespace xla { 42 43 // Options to configure the backend when it is created. 44 class BackendOptions { 45 public: 46 // Set the platform backing the backend, or nullptr for the default platform. 47 BackendOptions& set_platform(se::Platform* platform); 48 se::Platform* platform() const; 49 50 // Sets the thread pool size for parallel execution of an individual operator. 51 // The default value of -1 will result in initializing the thread pool with 52 // the number of threads equal to the number of cores in the system. 53 BackendOptions& set_intra_op_parallelism_threads(int num_threads); 54 int intra_op_parallelism_threads() const; 55 56 // Sets the allowed_devices for selectively constructing stream executors 57 // on the platform. 58 BackendOptions& set_allowed_devices( 59 const std::optional<std::set<int>>& allowed_devices); 60 const std::optional<std::set<int>>& allowed_devices() const; 61 62 private: 63 se::Platform* platform_ = nullptr; 64 int intra_op_parallelism_threads_ = -1; 65 std::optional<std::set<int>> allowed_devices_; 66 }; 67 68 // Class which encapsulates an XLA backend. It includes everything necessary 69 // to compile and execute computations on a particular platform. 70 // 71 // It also offers a pooling API for creation/use of initialized streams: 72 // 73 // StreamPool::Ptr stream = backend->BorrowStream().value(); 74 class Backend { 75 public: 76 // Creates a new backend. 77 static StatusOr<std::unique_ptr<Backend>> CreateBackend( 78 const BackendOptions& options); 79 80 // Creates a backend for the default platform. The default platform is defined 81 // in PlatformUtil. 82 static StatusOr<std::unique_ptr<Backend>> CreateDefaultBackend(); 83 84 ~Backend(); 85 86 // Accessors for the various objects. platform()87 se::Platform* platform() const { return platform_; } compiler()88 Compiler* compiler() const { return compiler_; } memory_allocator()89 se::DeviceMemoryAllocator* memory_allocator() const { 90 return memory_allocator_.get(); 91 } shared_memory_allocator()92 std::shared_ptr<se::DeviceMemoryAllocator> shared_memory_allocator() const { 93 return memory_allocator_; 94 } transfer_manager()95 TransferManager* transfer_manager() const { return transfer_manager_; } computation_placer()96 ComputationPlacer* computation_placer() const { return computation_placer_; } 97 98 // Returns the number of devices of the platform type which are visible. Not 99 // all of these devices may be usable by XLA. device_count()100 int device_count() const { return stream_executors_.size(); } 101 102 // Returns the device ordinal number of the default device. 103 int default_device_ordinal() const; 104 105 // Returns stream executors of all supported devices for this backend. The 106 // executors are ordered by the device ordinal. stream_executors()107 const std::vector<se::StreamExecutor*>& stream_executors() const { 108 return stream_executors_; 109 } 110 111 // Returns the stream executor for the given device ordinal. 112 StatusOr<se::StreamExecutor*> stream_executor(int device_ordinal) const; 113 114 // Returns the stream executor for the default device ordinal. This stream 115 // executor can only be used when the number of computations is 1 (replication 116 // can be > 1). default_stream_executor()117 se::StreamExecutor* default_stream_executor() const { 118 CHECK(!stream_executors_.empty()); 119 return stream_executors_[0]; 120 } 121 122 // Borrows a stream for use by the caller, either by grabbing it from an 123 // internal pool, or by constructing/initializating it, and returns the result 124 // to the caller. 125 StatusOr<StreamPool::Ptr> BorrowStream(int device_ordinal); 126 StatusOr<StreamPool::Ptr> BorrowStream(se::StreamExecutor* executor); 127 128 // Returns a function to borrow a stream, as `BorrowStream` above does. 129 // Purely for convenience, the caller could rather make this anonymous 130 // function itself. StreamBorrower()131 std::function<StatusOr<StreamPool::Ptr>(int)> StreamBorrower() { 132 return [this](int device_ordinal) { return BorrowStream(device_ordinal); }; 133 } 134 135 // Returns whether the given device ordinal of the backend is supported. device_ordinal_supported(int device_ordinal)136 bool device_ordinal_supported(int device_ordinal) const { 137 return (device_ordinal >= 0 && device_ordinal < device_count() && 138 stream_executors_[device_ordinal] != nullptr); 139 } 140 141 // Return a string identifier for the given device, eg: "GPU:3". device_name(int device_ordinal)142 std::string device_name(int device_ordinal) const { 143 return absl::StrCat(platform_->Name(), ":", device_ordinal); 144 } 145 146 // Returns true if the devices with the given ordinals are equivalent from 147 // XLA's perspective. That is, an executable compiled for one device would 148 // be equivalent to an executable compiled for the other. 149 StatusOr<bool> devices_equivalent(int device_ordinal_a, int device_ordinal_b); 150 151 // For the host platform, returns the configured eigen threadpool device to be 152 // used for scheduling work. For other platforms, returns NULL. 153 const Eigen::ThreadPoolDevice* eigen_intra_op_thread_pool_device() const; 154 tensorflow::thread::ThreadPool* eigen_intra_op_thread_pool() const; 155 156 // Resets the devices associated with this backend. 157 Status ResetDevices(); 158 159 private: 160 Backend(se::Platform* platform, Compiler* compiler, 161 absl::Span<se::StreamExecutor* const> stream_executors, 162 TransferManager* transfer_manager, 163 ComputationPlacer* computation_placer, 164 int intra_op_parallelism_threads); 165 Backend(const Backend&) = delete; 166 Backend& operator=(const Backend&) = delete; 167 168 se::Platform* platform_; 169 Compiler* compiler_; 170 TransferManager* transfer_manager_; 171 ComputationPlacer* computation_placer_; 172 173 // Vector of stream executors. stream_executors_[0] is the default executor. 174 std::vector<se::StreamExecutor*> stream_executors_; 175 176 absl::Mutex mu_; 177 178 // Mapping from stream executor to stream pools, used by `BorrowStream` above. 179 absl::flat_hash_map<se::StreamExecutor*, std::unique_ptr<StreamPool>> 180 stream_pools_ ABSL_GUARDED_BY(mu_); 181 182 // The default memory allocator to use. 183 // This must be a shared_ptr, as this is passed all the way down to the 184 // cluster compilation. This allows asynchronous compilation to hold a 185 // referecence until the compilation is finished. 186 std::shared_ptr<se::StreamExecutorMemoryAllocator> memory_allocator_; 187 188 // For the CPU backend, an Eigen threadpool device for use by Eigen code. 189 struct IntraOpThreadPool; 190 std::unique_ptr<IntraOpThreadPool> intra_op_thread_pool_; 191 }; 192 193 } // namespace xla 194 195 #endif // TENSORFLOW_COMPILER_XLA_SERVICE_BACKEND_H_ 196