xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/client/executable_build_options.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_XLA_CLIENT_EXECUTABLE_BUILD_OPTIONS_H_
17 #define TENSORFLOW_COMPILER_XLA_CLIENT_EXECUTABLE_BUILD_OPTIONS_H_
18 
19 #include <optional>
20 #include <vector>
21 
22 #include "absl/strings/string_view.h"
23 #include "tensorflow/compiler/xla/pjrt/compile_options.pb.h"
24 #include "tensorflow/compiler/xla/service/computation_placer.h"
25 #include "tensorflow/compiler/xla/shape.h"
26 #include "tensorflow/compiler/xla/xla.pb.h"
27 #include "tensorflow/compiler/xla/xla_data.pb.h"
28 #include "tensorflow/core/platform/threadpool.h"
29 
namespace stream_executor {

// Forward-declared here to avoid pulling a StreamExecutor dependency into
// this header; the full definition is only needed by translation units that
// actually use the allocator.
class DeviceMemoryAllocator;

}  // namespace stream_executor
36 
37 namespace xla {
38 
39 // Class containing options for building an LocalExecutable with
40 // LocalClient::Compile.
41 class ExecutableBuildOptions {
42  public:
43   // If set, this is the device to build the computation for. Valid
44   // device_ordinal values are: 0 to # of devices - 1. These values are
45   // identical to the device ordinal values used by StreamExecutor. The built
46   // executable will be executable on any device equivalent to the specified
47   // device as determined by Backend::devices_equivalent(). A value of -1
48   // indicates this option has not been set.
49   ExecutableBuildOptions& set_device_ordinal(int device_ordinal);
50   int device_ordinal() const;
51 
52   // If set, this specifies the layout of the result of the computation. If not
53   // set, the service will chose the layout of the result. A Shape is used to
54   // store the layout to accommodate tuple result shapes. A value of nullptr
55   // indicates the option has not been set.
56   ExecutableBuildOptions& set_result_layout(const Shape& shape_with_layout);
57   const Shape* result_layout() const;
58 
59   // Expose access to the XLA debug options which will be passed to the
60   // compilation process.
has_debug_options()61   bool has_debug_options() const { return debug_options_.has_value(); }
debug_options()62   const DebugOptions& debug_options() const { return *debug_options_; }
63   DebugOptions* mutable_debug_options();
64 
65   // If set, this specifies an allocator that can be used to allocate temporary
66   // space on the device during compilation.  For example, the compiler might
67   // want to run various algorithms on the device and pick the fastest one -- it
68   // might allocate buffers for use by these algorithms using this allocator.
69   //
70   // This does not need to be the same as the se::DeviceMemoryAllocator passed
71   // when running the executable.
72   ExecutableBuildOptions& set_device_allocator(
73       se::DeviceMemoryAllocator* allocator);
74   se::DeviceMemoryAllocator* device_allocator() const;
75 
76   // Returns a string representation of the build options, suitable for
77   // debugging.
78   std::string ToString() const;
79 
80   // The number of replicas of this computation that are to be executed.
81   // Defaults to 1.
num_replicas()82   int num_replicas() const { return num_replicas_; }
83   ExecutableBuildOptions& set_num_replicas(int num_replicas);
84 
85   // The number of partitions in this computation. Defaults to 1.
num_partitions()86   int num_partitions() const { return num_partitions_; }
87   ExecutableBuildOptions& set_num_partitions(int num_partitions);
88 
89   // Indicates whether to use SPMD (true) or MPMD (false) partitioning when
90   // num_partitions > 1 and XLA is requested to partition the input program.
use_spmd_partitioning()91   bool use_spmd_partitioning() const { return use_spmd_partitioning_; }
92   ExecutableBuildOptions& set_use_spmd_partitioning(bool use_spmd_partitioning);
93 
94   // Whether to automatically generate XLA shardings for SPMD partitioner.
use_auto_spmd_partitioning()95   bool use_auto_spmd_partitioning() const {
96     return use_auto_spmd_partitioning_;
97   }
98   ExecutableBuildOptions& set_use_auto_spmd_partitioning(
99       bool use_auto_spmd_partitioning);
100 
auto_spmd_partitioning_mesh_shape()101   std::vector<int64_t> auto_spmd_partitioning_mesh_shape() const {
102     return auto_spmd_partitioning_mesh_shape_;
103   }
104   ExecutableBuildOptions& set_auto_spmd_partitioning_mesh_shape(
105       std::vector<int64_t> mesh_shape);
106 
auto_spmd_partitioning_mesh_ids()107   std::vector<int64_t> auto_spmd_partitioning_mesh_ids() const {
108     return auto_spmd_partitioning_mesh_ids_;
109   }
110   ExecutableBuildOptions& set_auto_spmd_partitioning_mesh_ids(
111       std::vector<int64_t> mesh_ids);
112 
deduplicate_hlo()113   bool deduplicate_hlo() const { return deduplicate_hlo_; }
114   ExecutableBuildOptions& set_deduplicate_hlo(bool deduplicate_hlo);
115 
116   // If set, this specifies a static device assignment for the computation.
117   // Otherwise, the computation will be compiled generically and can be run with
118   // any device assignment compatible with the computation's replica and
119   // partition counts.
has_device_assignment()120   bool has_device_assignment() const { return device_assignment_.has_value(); }
121   ExecutableBuildOptions& set_device_assignment(
122       const DeviceAssignment& device_assignment);
device_assignment()123   const DeviceAssignment& device_assignment() const {
124     CHECK(device_assignment_.has_value());
125     return device_assignment_.value();
126   }
127 
128   // Whether input and output buffers are aliased if the associated parameter is
129   // passed-through XLA modules without being changed.
alias_passthrough_params()130   bool alias_passthrough_params() const { return alias_passthrough_params_; }
set_alias_passthrough_params(bool alias_passthrough_params)131   void set_alias_passthrough_params(bool alias_passthrough_params) {
132     alias_passthrough_params_ = alias_passthrough_params;
133   }
134 
run_backend_only()135   bool run_backend_only() const { return run_backend_only_; }
136   // By default, XLA builds an executable by invoking standard compilation, i.e,
137   // running Compiler::Compile, or both Compiler::RunHloPasses and
138   // Compiler::RunBackend. When run_backend_only is set to true, XLA builds an
139   // executable by invoking only RunBackend and skip invoking RunHloPasses,
140   // which can be used to compile post-optimizations HLO modules.
set_run_backend_only(bool run_backend_only)141   ExecutableBuildOptions& set_run_backend_only(bool run_backend_only) {
142     run_backend_only_ = run_backend_only;
143     return *this;
144   }
145 
allow_spmd_sharding_propagation_to_output()146   bool allow_spmd_sharding_propagation_to_output() const {
147     return allow_spmd_sharding_propagation_to_output_;
148   }
149   // Allows sharding propagation to propagate to the outputs. This changes the
150   // output shape of the computation (which is undesirable), but it can be used
151   // to allow to run partial compilation to determine what would be the output
152   // sharding of a computation if XLA would be allowed to propagate the sharding
153   // which can be used by higher level framework as a way to query intermediate
154   // sharding of operations when multiple computation would be chained and
155   // merged together.
set_allow_spmd_sharding_propagation_to_output(bool allow_spmd_sharding_propagation_to_output)156   ExecutableBuildOptions& set_allow_spmd_sharding_propagation_to_output(
157       bool allow_spmd_sharding_propagation_to_output) {
158     allow_spmd_sharding_propagation_to_output_ =
159         allow_spmd_sharding_propagation_to_output;
160     return *this;
161   }
162 
163   // Thread pool for parallel compilation.
compile_thread_pool()164   tensorflow::thread::ThreadPool* compile_thread_pool() const {
165     return compile_thread_pool_;
166   }
set_compile_thread_pool(tensorflow::thread::ThreadPool * compile_thread_pool)167   ExecutableBuildOptions& set_compile_thread_pool(
168       tensorflow::thread::ThreadPool* compile_thread_pool) {
169     compile_thread_pool_ = compile_thread_pool;
170     return *this;
171   }
172 
173   StatusOr<ExecutableBuildOptionsProto> ToProto() const;
174 
175  private:
176   int device_ordinal_ = -1;
177   Shape result_layout_;
178   bool result_layout_set_ = false;
179   std::optional<DebugOptions> debug_options_;
180   se::DeviceMemoryAllocator* device_allocator_ = nullptr;
181   int num_replicas_ = 1;
182   int num_partitions_ = 1;
183   bool use_spmd_partitioning_ = false;
184   bool use_auto_spmd_partitioning_ = false;
185   std::vector<int64_t> auto_spmd_partitioning_mesh_shape_;
186   std::vector<int64_t> auto_spmd_partitioning_mesh_ids_;
187   bool deduplicate_hlo_ = false;
188   bool broadcast_replicated_params_ = false;
189   std::optional<DeviceAssignment> device_assignment_;
190   bool alias_passthrough_params_ = false;
191   bool run_backend_only_ = false;
192   bool allow_spmd_sharding_propagation_to_output_ = false;
193   tensorflow::thread::ThreadPool* compile_thread_pool_ = nullptr;
194 };
195 
// Deserializes an ExecutableBuildOptions from its proto representation;
// returns an error status if the proto cannot be converted.
StatusOr<ExecutableBuildOptions> ExecutableBuildOptionsFromProto(
    const ExecutableBuildOptionsProto& input);

// Creates an ExecutionOptions based on a given ExecutableBuildOptions and
// ProgramShape.
ExecutionOptions CreateExecutionOptions(
    const ExecutableBuildOptions& build_options,
    const ProgramShape* program_shape);
204 
205 }  // namespace xla
206 
207 #endif  // TENSORFLOW_COMPILER_XLA_CLIENT_EXECUTABLE_BUILD_OPTIONS_H_
208