1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_COMPILER_XLA_CLIENT_LOCAL_CLIENT_H_ 17 #define TENSORFLOW_COMPILER_XLA_CLIENT_LOCAL_CLIENT_H_ 18 19 #include <memory> 20 #include <string> 21 #include <vector> 22 23 #include "absl/types/span.h" 24 #include "tensorflow/compiler/xla/client/client.h" 25 #include "tensorflow/compiler/xla/client/executable_build_options.h" 26 #include "tensorflow/compiler/xla/client/xla_computation.h" 27 #include "tensorflow/compiler/xla/executable_run_options.h" 28 #include "tensorflow/compiler/xla/service/compiler.h" 29 #include "tensorflow/compiler/xla/service/executable.h" 30 #include "tensorflow/compiler/xla/service/hlo.pb.h" 31 #include "tensorflow/compiler/xla/service/local_service.h" 32 #include "tensorflow/compiler/xla/service/maybe_owning_device_memory.h" 33 #include "tensorflow/compiler/xla/service/shaped_buffer.h" 34 #include "tensorflow/compiler/xla/shape_tree.h" 35 #include "tensorflow/compiler/xla/statusor.h" 36 #include "tensorflow/compiler/xla/xla_data.pb.h" 37 #include "tensorflow/core/platform/stream_executor_no_cuda.h" 38 #include "tensorflow/stream_executor/device_memory_allocator.h" 39 40 namespace xla { 41 42 class LocalExecutable { 43 public: 44 // Low-level constructor; LocalClient::Compile() is the usual way to create 45 // executables. 46 LocalExecutable(std::unique_ptr<Executable> executable, Backend* backend, 47 ExecutableBuildOptions build_options); 48 49 // Run the compiled computation with the given arguments and options and 50 // return the result. 51 StatusOr<ScopedShapedBuffer> Run( 52 const absl::Span<const ShapedBuffer* const> arguments, 53 ExecutableRunOptions run_options); 54 55 // Similar to Run(), but allows for donating argument buffers to the 56 // executable. 57 StatusOr<ExecutionOutput> Run(std::vector<ExecutionInput> arguments, 58 ExecutableRunOptions run_options); 59 60 // Similar to Run(), but need not block the host waiting for the computation 61 // to complete before returning. 62 StatusOr<ScopedShapedBuffer> RunAsync( 63 const absl::Span<const ShapedBuffer* const> arguments, 64 ExecutableRunOptions run_options); 65 66 // Similar to RunAsync(), but allows for donating argument buffers to the 67 // executable. 68 StatusOr<ExecutionOutput> RunAsync(std::vector<ExecutionInput> arguments, 69 ExecutableRunOptions run_options); 70 71 // Return the options used to build the executable. build_options()72 const ExecutableBuildOptions& build_options() const { return build_options_; } 73 74 // Return the built executable. executable()75 Executable* executable() const { return executable_.get(); } 76 77 private: 78 StatusOr<ExecutionOutput> RunAsync( 79 absl::Span<Shape const* const> argument_host_shapes, 80 std::vector<ExecutionInput> arguments, ExecutableRunOptions run_options); 81 82 // Validates that the given arguments and options satisfy various constraints 83 // of the computation. 84 // 85 // The given ExecutableRunOptions override any values from TF_XLA_FLAGS 86 // environment variable. 87 Status ValidateExecutionOptions(const ExecutableRunOptions& run_options, 88 const Backend& backend); 89 90 // Returns a literal containing the contents of the given ShapedBuffer. 91 StatusOr<Literal> LiteralFromShapedBuffer(const ShapedBuffer& shaped_buffer); 92 93 StatusOr<std::pair<ServiceExecutableRunOptions, StreamPool::Ptr>> RunHelper( 94 const absl::Span<const Shape* const> argument_shapes, 95 ExecutableRunOptions run_options); 96 97 // The ordinal of the device which this executable was compiled for. The 98 // executable can run on all equivalent devices (as determined by 99 // Backend::devices_equivalent). build_device_ordinal()100 int build_device_ordinal() const { return build_options_.device_ordinal(); } 101 102 template <typename T> AsyncCallAndBlockHostUntilDone(absl::Span<Shape const * const> argument_shapes,const ExecutableRunOptions & run_options,std::function<StatusOr<T> (const ExecutableRunOptions &)> async_callback)103 StatusOr<T> AsyncCallAndBlockHostUntilDone( 104 absl::Span<Shape const* const> argument_shapes, 105 const ExecutableRunOptions& run_options, 106 std::function<StatusOr<T>(const ExecutableRunOptions&)> async_callback) { 107 TF_ASSIGN_OR_RETURN(auto options_and_stream, 108 RunHelper(argument_shapes, run_options)); 109 ExecutableRunOptions options = options_and_stream.first.run_options(); 110 options.set_device_ordinal(-1); 111 StatusOr<T> result = async_callback(options); 112 Status block_status = options.stream()->BlockHostUntilDone(); 113 TF_RETURN_IF_ERROR(result.status()); 114 TF_RETURN_IF_ERROR(block_status); 115 return result; 116 } 117 118 // Compiled computation. 119 std::unique_ptr<Executable> executable_; 120 121 // Execution backend. 122 Backend* backend_ = nullptr; 123 124 // Options used to build the executable. 125 const ExecutableBuildOptions build_options_; 126 }; 127 128 // An XLA Client specialization for use when the client and service run in 129 // the same process. 130 class LocalClient : public Client { 131 public: LocalClient(LocalService * service)132 explicit LocalClient(LocalService* service) 133 : Client(service), local_service_(service) {} 134 135 LocalClient(const LocalClient&) = delete; 136 void operator=(const LocalClient&) = delete; 137 138 // Build and return LocalExecutable objects (one per partition, as specified 139 // by the build options). The executable is compiled using the given 140 // XlaComputation, argument layouts and options. 141 // 142 // The given ExecutableBuildOptions overrides any values from XLA_FLAGS 143 // environment variable. 144 StatusOr<std::vector<std::unique_ptr<LocalExecutable>>> Compile( 145 const XlaComputation& computation, 146 const absl::Span<const Shape* const> argument_layouts, 147 const ExecutableBuildOptions& options); 148 149 // Same as Compile() above, but return AotCompilationResult objects (instead 150 // of LocalExecutable objects), which can be persisted to later load 151 // LocalExecutable(s) using the Load() method below. 152 StatusOr<std::vector<std::unique_ptr<AotCompilationResult>>> 153 CompileAheadOfTime(const XlaComputation& computation, 154 const absl::Span<const Shape* const> argument_layouts, 155 const ExecutableBuildOptions& options); 156 157 // Return a LocalExecutable object loaded from a serialized 158 // AotCompilationResult. 159 StatusOr<std::unique_ptr<LocalExecutable>> Load( 160 const std::string& serialized_aot_result, 161 const ExecutableBuildOptions& options); 162 163 // Copy the literal data to the device with the given ordinal and return as a 164 // ScopedShapedBuffer. If non-null the given memory allocator is used for 165 // device memory allocation. If null, the default memory allocator for the 166 // device is used. 167 StatusOr<ScopedShapedBuffer> LiteralToShapedBuffer( 168 const LiteralSlice& literal, int device_ordinal, 169 se::DeviceMemoryAllocator* allocator = nullptr); 170 171 // Transfer the BorrowingLiteral to the device with the given ordinal. 172 StatusOr<TransferToServerResponse> TransferToLocalServer( 173 const ::xla::BorrowingLiteral& literal, int device_ordinal); 174 175 // Copy the data from the device contained in the given ShapedBuffer and 176 // return as a Literal. 177 StatusOr<Literal> ShapedBufferToLiteral(const ShapedBuffer& shaped_buffer); 178 179 // Converts a GlobalDataHandle into a pointer to a ShapedBuffer that's valid 180 // as long as the handle is valid. 181 StatusOr<const ShapedBuffer*> GlobalDataToShapedBuffer( 182 const GlobalDataHandle& data, int replica_number); 183 184 // Transfer the given literal to the infeed queue of the given device. 185 // TODO(b/69670845): Remove the 'Local' from the name when LocalClient does 186 // not inherit from Client and there is no possibility of confusion with 187 // Client::TransferToInfeed. 188 Status TransferToInfeedLocal(const LiteralSlice& literal, int device_ordinal); 189 190 // Transfer and return a value from the outfeed of the given device. The 191 // shape of the object to transfer is determined by `literal`'s shape. 192 // TODO(b/69670845): Remove the 'Local' from the name when LocalClient does 193 // not inherit from Client and there is no possibility of confusion with 194 // Client::TransferFromOutfeed. 195 Status TransferFromOutfeedLocal(int device_ordinal, 196 MutableBorrowingLiteral literal); 197 198 // Returns the device ordinal that corresponds to the given replica number. 199 // 200 // This returns an error if there is not a one-to-one correspondence of 201 // replicas to device ordinals, but is useful as a short term mechanism for 202 // the "easy" case where a single replica is a single device. 203 StatusOr<int> ReplicaNumberToDeviceOrdinal(int replica_number); 204 205 // Returns the platform that the underlying service targets. 206 se::Platform* platform() const; 207 208 // Returns the number of devices on the system of the service platform 209 // type. Not all devices may be supported by the service (see 210 // device_ordinal_supported method). 211 int device_count() const; 212 213 // Returns the default device ordinal that the service will run computations 214 // on if no device ordinal is specified in execute options. 215 int default_device_ordinal() const; 216 217 // Returns whether the device with the given ordinal can be used by the 218 // service to execute computations. Not all devices of a particular platform 219 // may be usable by the service (eg, a GPU with insufficient CUDA compute 220 // capability). 221 bool device_ordinal_supported(int device_ordinal) const; 222 223 // Returns the backend used to execute computations. 224 const Backend& backend() const; 225 Backend* mutable_backend(); 226 227 private: 228 LocalService* local_service_; 229 }; 230 231 } // namespace xla 232 233 #endif // TENSORFLOW_COMPILER_XLA_CLIENT_LOCAL_CLIENT_H_ 234