1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_RUNNER_INTERFACE_H_ 17 #define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_RUNNER_INTERFACE_H_ 18 19 #include <map> 20 #include <memory> 21 #include <set> 22 #include <string> 23 #include <vector> 24 25 #include "absl/types/span.h" 26 #include "tensorflow/compiler/xla/service/computation_placer.h" 27 #include "tensorflow/compiler/xla/service/executable.h" 28 #include "tensorflow/compiler/xla/service/hlo_computation.h" 29 #include "tensorflow/compiler/xla/service/hlo_module.h" 30 #include "tensorflow/compiler/xla/status_macros.h" 31 #include "tensorflow/compiler/xla/statusor.h" 32 #include "tensorflow/compiler/xla/types.h" 33 #include "tensorflow/compiler/xla/util.h" 34 #include "tensorflow/compiler/xla/xla_data.pb.h" 35 36 namespace xla { 37 38 // A base class for running an HloModule. This executes the given HloModule on a 39 // certain backend directly without using the client interface. HloModule can be 40 // explicitly built, or loaded from a serialization file (e.g., hlo proto 41 // file), or parsed from a hlo textual IR string. 42 class HloRunnerInterface { 43 public: 44 // The options used to configure an ExecuteReplicated() call. 45 struct ReplicatedExecuteOptions { 46 // The number of devices the HLO module should be replicated onto. 47 int64_t num_replicas = 1; 48 49 // The arguments to be fed to each replica. Since this is used for a 50 // replicated execution, all the arguments are the same for all replicas. 51 std::vector<const Literal*> arguments; 52 53 // If the HLO module being run has an infeed instruction, this will be the 54 // data which will be fed to it, for as many as infeed_steps steps. 55 std::vector<const Literal*> infeed_values; 56 57 // The number of times the infeed literal should be fed to the HLO module. 58 // For a clean exit, this should match the iterations-per-loop parameter 59 // used when generating the HLO module proto (that is usually the main 60 // while boundary counter). A value higher then iterations-per-loop would 61 // lead to infeed threads feeding to a gone computation, while a lower 62 // value would trigger a stuck ExecuteReplicated() call (the computation 63 // will be trying to infeed data which will never come). 64 int64_t infeed_steps = -1; 65 66 // The shape of the outfeed operation. If empty, the HLO module does not 67 // generate any outfeed. 68 Shape outfeed_shape; 69 70 // A pointer to a vector where the outfeed values will be stored. If 71 // nullptr, the values will be read and discarded. 72 std::vector<Literal>* outfeed_values = nullptr; 73 74 // Whether the HLO passes should be run on the input module. Usually 75 // saved modules are coming from after the HLO pass pipeline, so triggering 76 // another run will likely cause errors. 77 bool run_hlo_passes = false; 78 79 // If true, executes on multiple threads using se::Stream::ExecuteOnStream. 80 // Otherwise, executes using xla::Executable::ExecuteOnStreams. 81 bool use_threads = false; 82 }; 83 84 HloRunnerInterface() = default; 85 86 virtual ~HloRunnerInterface() = default; 87 88 // Converts an HloModule from the given hlo textual IR string (in 89 // HloModule::ToString format). 90 static StatusOr<std::unique_ptr<HloModule>> CreateModuleFromString( 91 const absl::string_view hlo_string, const DebugOptions& debug_options); 92 93 // Reads the proto file in xla.HloProto format, creates and returns the 94 // HloModule. 95 static StatusOr<std::unique_ptr<HloModule>> ReadModuleFromBinaryProtoFile( 96 const std::string& filename, const DebugOptions& debug_options); 97 static StatusOr<std::unique_ptr<HloModule>> ReadModuleFromTextProtoFile( 98 const std::string& filename, const DebugOptions& debug_options); 99 100 // Reads the proto file in xla.HloModule format, creates and returns the 101 // HloModule. 102 static StatusOr<std::unique_ptr<HloModule>> 103 ReadModuleFromModuleBinaryProtofile(const std::string& filename, 104 const DebugOptions& debug_options); 105 106 // Reads the hlo text dump file in HloModule::ToString format, creates and 107 // returns the HloModule. 108 static StatusOr<std::unique_ptr<HloModule>> ReadModuleFromHloTextFile( 109 const std::string& filename, const DebugOptions& debug_options); 110 111 // Creates an executable object given an HLO module. If run_hlo_passes is 112 // true, the HLO passes will be run as part of compilation. 113 virtual StatusOr<std::unique_ptr<Executable>> CreateExecutable( 114 std::unique_ptr<HloModule> module, bool run_hlo_passes) = 0; 115 116 // Executes the given module with given literals as input and returns the 117 // result as a Literal. 118 // 119 // If run_hlo_passes is false, the module will be executed without Hlo 120 // optimization 121 StatusOr<Literal> Execute(std::unique_ptr<HloModule> module, 122 absl::Span<const Literal* const> arguments, 123 bool run_hlo_passes = true) { 124 return Execute(std::move(module), arguments, run_hlo_passes, nullptr); 125 } 126 127 StatusOr<Literal> Execute(std::unique_ptr<HloModule> module, 128 absl::Span<const Literal> arguments, 129 bool run_hlo_passes = true, 130 ExecutionProfile* profile = nullptr); 131 132 virtual StatusOr<Literal> Execute(std::unique_ptr<HloModule> module, 133 absl::Span<const Literal* const> arguments, 134 bool run_hlo_passes, 135 ExecutionProfile* profile) = 0; 136 137 // Same as above, but with Executable as input. 138 StatusOr<Literal> ExecuteWithExecutable(Executable* executable, 139 absl::Span<const Literal> arguments, 140 ExecutionProfile* profile = nullptr); 141 ExecuteWithExecutable(Executable * executable,absl::Span<const Literal * const> arguments)142 StatusOr<Literal> ExecuteWithExecutable( 143 Executable* executable, absl::Span<const Literal* const> arguments) { 144 return ExecuteWithExecutable(executable, arguments, nullptr); 145 } 146 147 virtual StatusOr<Literal> ExecuteWithExecutable( 148 Executable* executable, absl::Span<const Literal* const> arguments, 149 ExecutionProfile* profile) = 0; 150 151 // Executes a given HLO module into a set of replicas, and returns a map 152 // with the replica number as key, and the corresponding returned literal as 153 // value. 154 // TODO(b/172931928): change to non-virtual function. 155 virtual StatusOr<std::vector<Literal>> ExecuteReplicated( 156 std::unique_ptr<HloModule> module, 157 const ReplicatedExecuteOptions& options) = 0; 158 159 // Same as above, but with specified device assignment. 160 virtual StatusOr<std::vector<Literal>> ExecuteReplicated( 161 std::unique_ptr<HloModule> module, 162 const ReplicatedExecuteOptions& options, 163 DeviceAssignment* device_assignment) = 0; 164 165 virtual StatusOr<std::vector<Literal>> ExecuteReplicated( 166 std::function<Executable*(int64_t)> executable_provider, 167 std::function<int64_t(int64_t)> argument_count_provider, 168 std::function<const Literal*(int64_t, int64_t)> argument_provider, 169 const ReplicatedExecuteOptions& options, 170 DeviceAssignment* device_assignment) = 0; 171 172 // Returns the name of this runner. 173 virtual absl::string_view Name() const = 0; 174 175 typedef std::function<Shape(const Shape&)> DeviceShapeRepresentationFn; 176 }; 177 178 } // namespace xla 179 180 #endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_RUNNER_INTERFACE_H_ 181