1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_SERVICE_H_ 17 #define TENSORFLOW_COMPILER_XLA_SERVICE_SERVICE_H_ 18 19 #include <functional> 20 #include <memory> 21 #include <set> 22 #include <string> 23 #include <vector> 24 25 #include "absl/types/span.h" 26 #include "tensorflow/compiler/xla/debug_options_flags.h" 27 #include "tensorflow/compiler/xla/executable_run_options.h" 28 #include "tensorflow/compiler/xla/service/allocation_tracker.h" 29 #include "tensorflow/compiler/xla/service/backend.h" 30 #include "tensorflow/compiler/xla/service/channel_tracker.h" 31 #include "tensorflow/compiler/xla/service/compilation_cache.h" 32 #include "tensorflow/compiler/xla/service/executable.h" 33 #include "tensorflow/compiler/xla/service/execution_tracker.h" 34 #include "tensorflow/compiler/xla/service/hlo_execution_profile.h" 35 #include "tensorflow/compiler/xla/service/hlo_module.h" 36 #include "tensorflow/compiler/xla/service/hlo_module_config.h" 37 #include "tensorflow/compiler/xla/service_interface.h" 38 #include "tensorflow/compiler/xla/statusor.h" 39 #include "tensorflow/compiler/xla/types.h" 40 #include "tensorflow/compiler/xla/xla.pb.h" 41 #include "tensorflow/compiler/xla/xla_data.pb.h" 42 #include "tensorflow/core/platform/logging.h" 43 #include "tensorflow/core/platform/stream_executor_no_cuda.h" 44 #include "tensorflow/stream_executor/device_memory_allocator.h" 45 46 namespace xla { 47 48 // Options to configure the service when it is created. 49 class ServiceOptions { 50 public: 51 // Set the platform backing the service, or nullptr for the default platform. 52 ServiceOptions& set_platform(se::Platform* platform); 53 se::Platform* platform() const; 54 55 // Set the default number of replicas to use when compiling replicated 56 // programs. 57 ServiceOptions& set_number_of_replicas(int number_of_replicas); 58 int number_of_replicas() const; 59 60 // Sets the thread pool size for parallel execution of an individual operator. 61 ServiceOptions& set_intra_op_parallelism_threads(int num_threads); 62 int intra_op_parallelism_threads() const; 63 64 // Sets the allowed_devices set for selectively constructing stream executors 65 // on the platform. 66 ServiceOptions& set_allowed_devices( 67 const std::optional<std::set<int>>& allowed_devices); 68 const std::optional<std::set<int>>& allowed_devices() const; 69 70 private: 71 se::Platform* platform_ = nullptr; 72 int number_of_replicas_ = 1; 73 int intra_op_parallelism_threads_ = -1; 74 std::optional<std::set<int>> allowed_devices_; 75 }; 76 77 // The XLA service object, which is the same across all platforms. It maintains 78 // the service state of computations and allocations, and delegates 79 // target-specific requests to the target-specific infrastructure 80 // (target-specific compiler, StreamExecutor). 81 class Service : public ServiceInterface { 82 public: 83 // Factory method for creating a new Service. 84 static StatusOr<std::unique_ptr<Service>> NewService( 85 se::Platform* platform = nullptr); 86 static StatusOr<std::unique_ptr<Service>> NewService( 87 const ServiceOptions& options); 88 89 // Unregisters a previously-allocated global handle. 90 // 91 // If the handle given is not currently allocated, a NOT_FOUND status is 92 // returned. 93 Status Unregister(const UnregisterRequest* arg, 94 UnregisterResponse* result) override; 95 96 // Deconstructs a tuple. Returns a newly created GlobalDataHandle for each 97 // element in the tuple. 98 Status DeconstructTuple(const DeconstructTupleRequest* arg, 99 DeconstructTupleResponse* result) override; 100 101 // Compiles a computation into an executable. The request contains the whole 102 // computation graph. Returns the handle to the executable. 103 Status Compile(const CompileRequest* arg, CompileResponse* result) override; 104 105 // Executes an executable with the provided global data passes as immutable 106 // arguments. The request contains the handle to the executable. Returns 107 // global data output and execution timing. 108 Status Execute(const ExecuteRequest* arg, ExecuteResponse* result) override; 109 110 // Executes one or more computations in parallel with the provided global data 111 // passed as immutable arguments. Returns global data output for each 112 // computation. 113 Status ExecuteGraphParallel(const ExecuteGraphParallelRequest* arg, 114 ExecuteParallelResponse* result) override; 115 116 // Requests one or more device handles from the target. 117 // 118 // When N device handles are requested and the number of replicas is R, at 119 // least N * R devices must be available. The devices are assigned based on 120 // the device ordinals such that the first R available devices are assigned to 121 // the first set of replicas, and the next R devices to the second set of 122 // replicas, etc. Each returned device handle represents the device with the 123 // replica id 0. 124 Status GetDeviceHandles(const GetDeviceHandlesRequest* arg, 125 GetDeviceHandlesResponse* result) override; 126 127 // Waits until the specified execution is complete and returns the result. 128 // Calling this API multiple times with the same execution handle returns the 129 // method with an error since the execution handle is destroyed after the 130 // first call. 131 Status WaitForExecution(const WaitForExecutionRequest* arg, 132 WaitForExecutionResponse* result) override; 133 134 // Requests that global data be transferred to the client in literal form. 135 Status TransferToClient(const TransferToClientRequest* arg, 136 TransferToClientResponse* result) override; 137 138 // Transfers data from a literal provided by the client, into device memory. 139 Status TransferToServer(const TransferToServerRequest* arg, 140 TransferToServerResponse* result) override; 141 142 // Transfers data from a literal provided by the client, into the Infeed 143 // buffer of the device. 144 Status TransferToInfeed(const TransferToInfeedRequest* arg, 145 TransferToInfeedResponse* result) override; 146 147 // Transfers data from the Outfeed othe device to the literal provided by the 148 // client. 149 Status TransferFromOutfeed(const TransferFromOutfeedRequest* arg, 150 TransferFromOutfeedResponse* result) override; 151 152 // Resets devices, clearing all existing state on all the devices associated 153 // with this service (including memory allocated on the devices). 154 // 155 // ResetDevice may only be called where no previous Execution state on the 156 // device is used by the next Execution. 157 // 158 // ResetDevice should be called before an Execution that expect the device to 159 // be in the reset state. For example, if the prior Execution modifies device 160 // state (e.g., architectural state) that the next Execution depends on. 161 Status ResetDevice(const ResetDeviceRequest* arg, 162 ResetDeviceResponse* result) override; 163 164 Status ComputeConstantGraph(const ComputeConstantGraphRequest* arg, 165 ComputeConstantResponse* result) override; 166 167 // Returns the shape (with layout) of an array associated with a given data 168 // handle. 169 Status GetShape(const GetShapeRequest* arg, 170 GetShapeResponse* result) override; 171 172 // Retrieves the statistics of a computation. 173 Status GetComputationGraphStats(const ComputationGraphStatsRequest* arg, 174 ComputationStatsResponse* result) override; 175 176 // Creates a unique channel handle that can be used for Send/Recv 177 // instructions. 178 Status CreateChannelHandle(const CreateChannelHandleRequest* arg, 179 CreateChannelHandleResponse* result) override; 180 181 // Returns the backend used to execute computations. backend()182 const Backend& backend() const { return *execute_backend_; } mutable_backend()183 Backend* mutable_backend() { return execute_backend_.get(); } 184 185 // Create a Hlo module config for the given program shape and arguments. 186 // aot_options is optional; if not given a default is used. 187 StatusOr<std::unique_ptr<HloModuleConfig>> CreateModuleConfig( 188 const ProgramShape& program_shape, 189 absl::Span<const Shape* const> argument_shapes, 190 const ExecutionOptions* execution_options, 191 const AotCompilationOptions* aot_options = nullptr); 192 193 private: 194 // A private overload for Service itself, used by other methods within this 195 // class. 196 StatusOr<std::unique_ptr<HloModuleConfig>> CreateModuleConfig( 197 const ProgramShape& program_shape, 198 absl::Span<const ShapedBuffer* const> arguments, 199 const ExecutionOptions& execution_options, 200 const AotCompilationOptions* aot_options = nullptr); 201 202 // Prepare the executors for executing parallel. 203 StatusOr<std::vector<se::StreamExecutor*>> GetExecutors( 204 const ExecutionOptions& execution_options, int64_t requests_size, 205 int64_t request_index) const; 206 207 // Prepare the arguments for executing parallel. 208 StatusOr<std::vector<std::vector<const ShapedBuffer*>>> GetArguments( 209 const ExecutionOptions& execution_options, 210 absl::Span<const GlobalDataHandle* const> arguments) const; 211 212 protected: 213 friend class LocalExecutable; 214 215 // The constructor is private. Use the NewService factory to create new 216 // service objects. 217 Service(const ServiceOptions& options, 218 std::unique_ptr<Backend> execute_backend); 219 220 // Resolves the given argument handles in the allocation tracker and returns 221 // the corresponding allocations for every replica. The function also verifies 222 // that each allocation matches the execution platform and device ordinal of 223 // the corresponding replica. 224 StatusOr<std::vector<std::vector<const ShapedBuffer*>>> 225 ResolveAndValidateArguments( 226 absl::Span<const GlobalDataHandle* const> arguments, 227 absl::Span<se::StreamExecutor* const> stream_executors) const; 228 229 // Builds an Executable for the given parameters. 230 // 231 // If device_allocator is not null, the compiler may use it to allocate temp 232 // buffers, which the compiler is responsible for freeing. The allocator 233 // given here need not match the allocator used when running the executable. 234 StatusOr<std::unique_ptr<Executable>> BuildExecutable( 235 const HloModuleProto& module_proto, 236 std::unique_ptr<HloModuleConfig> module_config, Backend* backend, 237 se::StreamExecutor* executor, const Compiler::CompileOptions& options, 238 bool run_backend_only = false); 239 240 // Same as BuildExecutable() above, but builds a list of Executables for the 241 // given computations that may interact with each other. 242 StatusOr<std::vector<std::unique_ptr<Executable>>> BuildExecutables( 243 const std::vector<const HloModuleProto*>& module_protos, 244 std::vector<std::unique_ptr<HloModuleConfig>> module_configs, 245 Backend* backend, std::vector<std::vector<se::StreamExecutor*>> executors, 246 const Compiler::CompileOptions& options, bool run_backend_only = false); 247 248 // Same as BuildExecutable() above, but builds a list of 249 // AotCompilationResult(s), which can be persisted to later load Executable 250 // objects. 251 StatusOr<std::vector<std::unique_ptr<AotCompilationResult>>> BuildAotResults( 252 const std::vector<const HloModuleProto*>& module_protos, 253 std::vector<std::unique_ptr<HloModuleConfig>> module_configs, 254 Backend* backend, std::vector<std::vector<se::StreamExecutor*>> executors, 255 const Compiler::CompileOptions& options, bool run_backend_only = false); 256 257 // Runs the given executable with the given arguments and register the result 258 // in the allocation tracker. The handle of the result from the tracker is 259 // returned. If the parameter "profile" is not null, it points to an 260 // ExecutionProfile object which will be filled in with profile data. 261 StatusOr<GlobalDataHandle> ExecuteAndRegisterResult( 262 Executable* executable, 263 absl::Span<const std::vector<const ShapedBuffer*>> arguments, 264 Backend* backend, const DeviceHandle& device_handle, 265 const std::string& result_tag, ExecutionProfile* profile); 266 267 // Runs the given executables with the given arguments and register the result 268 // from each executable in the allocation tracker. The handles of the result 269 // from the tracker are returned. 270 StatusOr<std::vector<GlobalDataHandle>> ExecuteParallelAndRegisterResult( 271 absl::Span<Executable* const> executables, 272 absl::Span<const std::vector<std::vector<const ShapedBuffer*>>> arguments, 273 Backend* backend, absl::Span<const DeviceHandle> device_handles, 274 absl::Span<const std::string> result_tags, ExecutionProfile* profile); 275 276 // Convenience function which checks whether the given client_shape 277 // (presumably passed by the client to set the result layout) is valid for the 278 // given computation result shape. 279 Status ValidateResultShape(const Shape& client_shape, 280 const Shape& result_shape) const; 281 282 // Returns the stream executors assigned to the replicas represented by the 283 // given device handle. Each device_handle is a virtual replicated device that 284 // represents a set of physical devices for the replicas. 285 StatusOr<std::vector<se::StreamExecutor*>> Replicas( 286 const Backend& backend, const DeviceHandle& device_handle) const; 287 288 // Returns the device handle that represents the replicated device for a 289 // single computation that is not model-parallelized. 290 DeviceHandle SingleComputationDeviceHandle() const; 291 292 ServiceOptions options_; 293 294 // Cache containing previously built Executables. 295 CompilationCache compilation_cache_; 296 297 // Tracks channels created via the API. 298 ChannelTracker channel_tracker_; 299 300 // Tracks allocations made via the API and computation execution. 301 AllocationTracker allocation_tracker_; 302 303 // Tracks asynchronously launched executions via the API. 304 ExecutionTracker execution_tracker_; 305 306 // Backend to compile and execute computations on. 307 std::unique_ptr<Backend> execute_backend_; 308 309 Service(const Service&) = delete; 310 Service& operator=(const Service&) = delete; 311 }; 312 313 } // namespace xla 314 315 #endif // TENSORFLOW_COMPILER_XLA_SERVICE_SERVICE_H_ 316