// xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/service/hlo_runner.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_RUNNER_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_RUNNER_H_

#include <cstdint>
#include <functional>
#include <map>
#include <memory>
#include <set>
#include <string>
#include <vector>

#include "absl/types/span.h"
#include "tensorflow/compiler/xla/service/backend.h"
#include "tensorflow/compiler/xla/service/compiler.h"
#include "tensorflow/compiler/xla/service/computation_placer.h"
#include "tensorflow/compiler/xla/service/executable.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_runner_interface.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"

namespace xla {

// A base class for running an HloModule. This executes the given HloModule on a
// certain backend directly without using the client interface. HloModule can be
// explicitly built, or loaded from a serialization file (e.g., hlo proto
// file), or parsed from a hlo textual IR string.
46 class HloRunner : public HloRunnerInterface {
47  public:
48   // intra_op_parallelism_threads: For the CPU backend only. It is the thread
49   // pool size for parallel execution of an individual operator. The default
50   // value of -1 will result in initializing the thread pool with the number of
51   // threads equal to the number of
52   // cores in the system.
53   explicit HloRunner(se::Platform* platform,
54                      int intra_op_parallelism_threads = -1);
55 
56   ~HloRunner() override;
57 
58   // Transfers data between the host and device.
59   StatusOr<ScopedShapedBuffer> TransferLiteralToDevice(const Literal& literal);
60   StatusOr<std::vector<ScopedShapedBuffer>> TransferLiteralsToDevice(
61       absl::Span<const Literal* const> literals);
62   StatusOr<std::vector<ScopedShapedBuffer>> TransferLiteralsToDevice(
63       absl::Span<const Literal> literals);
64   StatusOr<Literal> TransferLiteralFromDevice(const ShapedBuffer& buffer);
65 
66   // Executes the given module with given literals as input and returns the
67   // result as a Literal.
68   //
69   // If run_hlo_passes is false, the module will be executed without Hlo
70   // optimization.
71 
72   using HloRunnerInterface::Execute;
73 
74   StatusOr<Literal> Execute(std::unique_ptr<HloModule> module,
75                             absl::Span<const Literal* const> arguments,
76                             bool run_hlo_passes,
77                             ExecutionProfile* profile) override;
78 
79   using HloRunnerInterface::ExecuteWithExecutable;
80 
81   StatusOr<Literal> ExecuteWithExecutable(
82       Executable* executable, absl::Span<const Literal* const> arguments,
83       ExecutionProfile* profile) override;
84 
85   // As Execute(), but accepts and returns device buffers instead of host
86   // buffers.
87   StatusOr<ExecutionOutput> ExecuteWithDeviceBuffers(
88       std::unique_ptr<HloModule> module,
89       absl::Span<ScopedShapedBuffer const> arguments,
90       bool run_hlo_passes = true, ExecutionProfile* profile = nullptr);
91 
92   StatusOr<ExecutionOutput> ExecuteWithDeviceBuffers(
93       Executable* executable, absl::Span<ScopedShapedBuffer const> arguments,
94       ExecutionProfile* profile = nullptr);
95 
96   // Creates an executable object given an HLO module. If run_hlo_passes is
97   // true, the HLO passes will be run as part of compilation.
98   StatusOr<std::unique_ptr<Executable>> CreateExecutable(
99       std::unique_ptr<HloModule> module, bool run_hlo_passes) override;
100 
101   // Executes a given HLO module into a set of replicas, and returns a map
102   // with the replica number as key, and the corresponding returned literal as
103   // value.
104   StatusOr<std::vector<Literal>> ExecuteReplicated(
105       std::unique_ptr<HloModule> module,
106       const ReplicatedExecuteOptions& options) override;
107 
108   // Same as above, but with specified device assignment.
109   StatusOr<std::vector<Literal>> ExecuteReplicated(
110       std::unique_ptr<HloModule> module,
111       const ReplicatedExecuteOptions& options,
112       DeviceAssignment* device_assignment) override;
113 
114   // Same as above, but with a reusable Executable.  This may update the profile
115   // information in *executable.
116   //
117   // Note that this call ignores ReplicatedExecutionOptions::run_hlo_passes,
118   // since we've already compiled the Executable.
119   StatusOr<std::vector<Literal>> ExecuteReplicated(
120       Executable* executable, const ReplicatedExecuteOptions& options,
121       DeviceAssignment* device_assignment, ExecutionProfile* profile = nullptr);
122 
123   // Same as above, but with different reusable Executables. This may update the
124   // profile information in *executables.
125   //
126   // Note that this call ignores ReplicatedExecutionOptions::run_hlo_passes,
127   // since we've already compiled the Executable.
128   StatusOr<std::vector<Literal>> ExecuteReplicated(
129       std::function<Executable*(int64_t)> executable_provider,
130       std::function<int64_t(int64_t)> argument_count_provider,
131       std::function<const Literal*(int64_t, int64_t)> argument_provider,
132       const ReplicatedExecuteOptions& options,
133       DeviceAssignment* device_assignment = nullptr);
134 
135   // If backend is not created in the constructor, creates and returns the
136   // default backend. If creation fails, crashes the program.
137   //
138   // This creates the backend lazily so it's possible to instantiate an
139   // HloRunner in a program without any backends linked in.
140   Backend& backend();
141   const Backend& backend() const;
142 
143   absl::string_view Name() const override;
144 
device_shape_representation_fn()145   DeviceShapeRepresentationFn device_shape_representation_fn() {
146     return device_shape_representation_fn_;
147   }
148 
149  private:
150   // Creates a ServiceExecutableRunOptions object to configure a run on device,
151   // using the provided stream object. If device_assignment is not nullptr, it
152   // will be used to configure the replication parameters. Replicated executions
153   // should pass the device_assignment parameter.
154   ServiceExecutableRunOptions GetServiceRunOptionsForDevice(
155       int64_t device, se::Stream* stream, DeviceAssignment* device_assignment,
156       RunId run_id);
157 
158   // Common implementation code for ExecuteReplicated() above.
159   StatusOr<std::vector<Literal>> ExecuteReplicatedImpl(
160       std::function<StatusOr<std::vector<ScopedShapedBuffer>>(
161           const std::vector<ServiceExecutableRunOptions>&,
162           const std::vector<absl::Span<const ShapedBuffer* const>>&)>
163           execution_helper,
164       std::function<int64_t(int64_t)> argument_count_provider,
165       std::function<const Literal*(int64_t, int64_t)> argument_provider,
166       const ReplicatedExecuteOptions& options,
167       DeviceAssignment* device_assignment);
168 
169   std::unique_ptr<Backend> backend_;
170 
171   DeviceShapeRepresentationFn device_shape_representation_fn_;
172 };

}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_RUNNER_H_