1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_RUNNER_H_ 17 #define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_RUNNER_H_ 18 19 #include <map> 20 #include <memory> 21 #include <set> 22 #include <string> 23 #include <vector> 24 25 #include "absl/types/span.h" 26 #include "tensorflow/compiler/xla/service/backend.h" 27 #include "tensorflow/compiler/xla/service/compiler.h" 28 #include "tensorflow/compiler/xla/service/computation_placer.h" 29 #include "tensorflow/compiler/xla/service/executable.h" 30 #include "tensorflow/compiler/xla/service/hlo_computation.h" 31 #include "tensorflow/compiler/xla/service/hlo_module.h" 32 #include "tensorflow/compiler/xla/service/hlo_runner_interface.h" 33 #include "tensorflow/compiler/xla/status_macros.h" 34 #include "tensorflow/compiler/xla/statusor.h" 35 #include "tensorflow/compiler/xla/types.h" 36 #include "tensorflow/compiler/xla/util.h" 37 #include "tensorflow/compiler/xla/xla_data.pb.h" 38 #include "tensorflow/core/platform/stream_executor_no_cuda.h" 39 40 namespace xla { 41 42 // A base class for running an HloModule. This executes the given HloModule on a 43 // certain backend directly without using the client interface. HloModule can be 44 // explicitly built, or loaded from a serialization file (e.g., hlo proto 45 // file), or parsed from a hlo textual IR string. 46 class HloRunner : public HloRunnerInterface { 47 public: 48 // intra_op_parallelism_threads: For the CPU backend only. It is the thread 49 // pool size for parallel execution of an individual operator. The default 50 // value of -1 will result in initializing the thread pool with the number of 51 // threads equal to the number of 52 // cores in the system. 53 explicit HloRunner(se::Platform* platform, 54 int intra_op_parallelism_threads = -1); 55 56 ~HloRunner() override; 57 58 // Transfers data between the host and device. 59 StatusOr<ScopedShapedBuffer> TransferLiteralToDevice(const Literal& literal); 60 StatusOr<std::vector<ScopedShapedBuffer>> TransferLiteralsToDevice( 61 absl::Span<const Literal* const> literals); 62 StatusOr<std::vector<ScopedShapedBuffer>> TransferLiteralsToDevice( 63 absl::Span<const Literal> literals); 64 StatusOr<Literal> TransferLiteralFromDevice(const ShapedBuffer& buffer); 65 66 // Executes the given module with given literals as input and returns the 67 // result as a Literal. 68 // 69 // If run_hlo_passes is false, the module will be executed without Hlo 70 // optimization. 71 72 using HloRunnerInterface::Execute; 73 74 StatusOr<Literal> Execute(std::unique_ptr<HloModule> module, 75 absl::Span<const Literal* const> arguments, 76 bool run_hlo_passes, 77 ExecutionProfile* profile) override; 78 79 using HloRunnerInterface::ExecuteWithExecutable; 80 81 StatusOr<Literal> ExecuteWithExecutable( 82 Executable* executable, absl::Span<const Literal* const> arguments, 83 ExecutionProfile* profile) override; 84 85 // As Execute(), but accepts and returns device buffers instead of host 86 // buffers. 87 StatusOr<ExecutionOutput> ExecuteWithDeviceBuffers( 88 std::unique_ptr<HloModule> module, 89 absl::Span<ScopedShapedBuffer const> arguments, 90 bool run_hlo_passes = true, ExecutionProfile* profile = nullptr); 91 92 StatusOr<ExecutionOutput> ExecuteWithDeviceBuffers( 93 Executable* executable, absl::Span<ScopedShapedBuffer const> arguments, 94 ExecutionProfile* profile = nullptr); 95 96 // Creates an executable object given an HLO module. If run_hlo_passes is 97 // true, the HLO passes will be run as part of compilation. 98 StatusOr<std::unique_ptr<Executable>> CreateExecutable( 99 std::unique_ptr<HloModule> module, bool run_hlo_passes) override; 100 101 // Executes a given HLO module into a set of replicas, and returns a map 102 // with the replica number as key, and the corresponding returned literal as 103 // value. 104 StatusOr<std::vector<Literal>> ExecuteReplicated( 105 std::unique_ptr<HloModule> module, 106 const ReplicatedExecuteOptions& options) override; 107 108 // Same as above, but with specified device assignment. 109 StatusOr<std::vector<Literal>> ExecuteReplicated( 110 std::unique_ptr<HloModule> module, 111 const ReplicatedExecuteOptions& options, 112 DeviceAssignment* device_assignment) override; 113 114 // Same as above, but with a reusable Executable. This may update the profile 115 // information in *executable. 116 // 117 // Note that this call ignores ReplicatedExecutionOptions::run_hlo_passes, 118 // since we've already compiled the Executable. 119 StatusOr<std::vector<Literal>> ExecuteReplicated( 120 Executable* executable, const ReplicatedExecuteOptions& options, 121 DeviceAssignment* device_assignment, ExecutionProfile* profile = nullptr); 122 123 // Same as above, but with different reusable Executables. This may update the 124 // profile information in *executables. 125 // 126 // Note that this call ignores ReplicatedExecutionOptions::run_hlo_passes, 127 // since we've already compiled the Executable. 128 StatusOr<std::vector<Literal>> ExecuteReplicated( 129 std::function<Executable*(int64_t)> executable_provider, 130 std::function<int64_t(int64_t)> argument_count_provider, 131 std::function<const Literal*(int64_t, int64_t)> argument_provider, 132 const ReplicatedExecuteOptions& options, 133 DeviceAssignment* device_assignment = nullptr); 134 135 // If backend is not created in the constructor, creates and returns the 136 // default backend. If creation fails, crashes the program. 137 // 138 // This creates the backend lazily so it's possible to instantiate an 139 // HloRunner in a program without any backends linked in. 140 Backend& backend(); 141 const Backend& backend() const; 142 143 absl::string_view Name() const override; 144 device_shape_representation_fn()145 DeviceShapeRepresentationFn device_shape_representation_fn() { 146 return device_shape_representation_fn_; 147 } 148 149 private: 150 // Creates a ServiceExecutableRunOptions object to configure a run on device, 151 // using the provided stream object. If device_assignment is not nullptr, it 152 // will be used to configure the replication parameters. Replicated executions 153 // should pass the device_assignment parameter. 154 ServiceExecutableRunOptions GetServiceRunOptionsForDevice( 155 int64_t device, se::Stream* stream, DeviceAssignment* device_assignment, 156 RunId run_id); 157 158 // Common implementation code for ExecuteReplicated() above. 159 StatusOr<std::vector<Literal>> ExecuteReplicatedImpl( 160 std::function<StatusOr<std::vector<ScopedShapedBuffer>>( 161 const std::vector<ServiceExecutableRunOptions>&, 162 const std::vector<absl::Span<const ShapedBuffer* const>>&)> 163 execution_helper, 164 std::function<int64_t(int64_t)> argument_count_provider, 165 std::function<const Literal*(int64_t, int64_t)> argument_provider, 166 const ReplicatedExecuteOptions& options, 167 DeviceAssignment* device_assignment); 168 169 std::unique_ptr<Backend> backend_; 170 171 DeviceShapeRepresentationFn device_shape_representation_fn_; 172 }; 173 174 } // namespace xla 175 176 #endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_RUNNER_H_ 177