1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
#include "tensorflow/compiler/xla/client/executable_build_options.h"

#include <limits>
#include <utility>

#include "absl/strings/str_format.h"
#include "tensorflow/compiler/xla/debug_options_flags.h"
#include "tensorflow/compiler/xla/execution_options_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
22
23 namespace xla {
24
set_device_allocator(se::DeviceMemoryAllocator * allocator)25 ExecutableBuildOptions& ExecutableBuildOptions::set_device_allocator(
26 se::DeviceMemoryAllocator* allocator) {
27 device_allocator_ = allocator;
28 return *this;
29 }
30
device_allocator() const31 se::DeviceMemoryAllocator* ExecutableBuildOptions::device_allocator() const {
32 return device_allocator_;
33 }
34
set_device_ordinal(int device_ordinal)35 ExecutableBuildOptions& ExecutableBuildOptions::set_device_ordinal(
36 int device_ordinal) {
37 CHECK_GE(device_ordinal, 0);
38 device_ordinal_ = device_ordinal;
39 return *this;
40 }
41
device_ordinal() const42 int ExecutableBuildOptions::device_ordinal() const { return device_ordinal_; }
43
mutable_debug_options()44 DebugOptions* ExecutableBuildOptions::mutable_debug_options() {
45 if (!has_debug_options()) {
46 debug_options_ = GetDebugOptionsFromFlags();
47 }
48 return &debug_options_.value();
49 }
50
set_result_layout(const Shape & shape_with_layout)51 ExecutableBuildOptions& ExecutableBuildOptions::set_result_layout(
52 const Shape& shape_with_layout) {
53 result_layout_set_ = true;
54 result_layout_ = shape_with_layout;
55 return *this;
56 }
57
result_layout() const58 const Shape* ExecutableBuildOptions::result_layout() const {
59 return result_layout_set_ ? &result_layout_ : nullptr;
60 }
61
set_num_replicas(int num_replicas)62 ExecutableBuildOptions& ExecutableBuildOptions::set_num_replicas(
63 int num_replicas) {
64 num_replicas_ = num_replicas;
65 return *this;
66 }
67
set_num_partitions(int num_partitions)68 ExecutableBuildOptions& ExecutableBuildOptions::set_num_partitions(
69 int num_partitions) {
70 num_partitions_ = num_partitions;
71 return *this;
72 }
73
set_use_spmd_partitioning(bool use_spmd_partitioning)74 ExecutableBuildOptions& ExecutableBuildOptions::set_use_spmd_partitioning(
75 bool use_spmd_partitioning) {
76 use_spmd_partitioning_ = use_spmd_partitioning;
77 return *this;
78 }
79
set_use_auto_spmd_partitioning(bool use_auto_spmd_partitioning)80 ExecutableBuildOptions& ExecutableBuildOptions::set_use_auto_spmd_partitioning(
81 bool use_auto_spmd_partitioning) {
82 use_auto_spmd_partitioning_ = use_auto_spmd_partitioning;
83 return *this;
84 }
85
86 ExecutableBuildOptions&
set_auto_spmd_partitioning_mesh_shape(std::vector<int64_t> mesh_shape)87 ExecutableBuildOptions::set_auto_spmd_partitioning_mesh_shape(
88 std::vector<int64_t> mesh_shape) {
89 auto_spmd_partitioning_mesh_shape_ = mesh_shape;
90 return *this;
91 }
92
93 ExecutableBuildOptions&
set_auto_spmd_partitioning_mesh_ids(std::vector<int64_t> mesh_ids)94 ExecutableBuildOptions::set_auto_spmd_partitioning_mesh_ids(
95 std::vector<int64_t> mesh_ids) {
96 auto_spmd_partitioning_mesh_ids_ = mesh_ids;
97 return *this;
98 }
99
set_deduplicate_hlo(bool deduplicate_hlo)100 ExecutableBuildOptions& ExecutableBuildOptions::set_deduplicate_hlo(
101 bool deduplicate_hlo) {
102 deduplicate_hlo_ = deduplicate_hlo;
103 return *this;
104 }
105
ToProto() const106 StatusOr<ExecutableBuildOptionsProto> ExecutableBuildOptions::ToProto() const {
107 ExecutableBuildOptionsProto output;
108 output.set_device_ordinal(device_ordinal());
109 if (result_layout()) {
110 *output.mutable_result_layout() = result_layout()->ToProto();
111 }
112 if (has_debug_options()) {
113 *output.mutable_debug_options() = debug_options();
114 }
115 output.set_num_replicas(num_replicas());
116 output.set_num_partitions(num_partitions());
117 output.set_use_spmd_partitioning(use_spmd_partitioning());
118 output.set_use_auto_spmd_partitioning(use_auto_spmd_partitioning());
119 output.set_deduplicate_hlo(deduplicate_hlo());
120 if (has_device_assignment()) {
121 TF_RETURN_IF_ERROR(
122 device_assignment().Serialize(output.mutable_device_assignment()));
123 }
124 output.set_alias_passthrough_params(alias_passthrough_params());
125 output.set_run_backend_only(run_backend_only());
126 output.set_allow_spmd_sharding_propagation_to_output(
127 allow_spmd_sharding_propagation_to_output());
128
129 return output;
130 }
131
ExecutableBuildOptionsFromProto(const ExecutableBuildOptionsProto & input)132 StatusOr<ExecutableBuildOptions> ExecutableBuildOptionsFromProto(
133 const ExecutableBuildOptionsProto& input) {
134 xla::ExecutableBuildOptions output;
135 if (input.device_ordinal() != -1) {
136 output.set_device_ordinal(input.device_ordinal());
137 }
138 if (input.has_result_layout()) {
139 output.set_result_layout(xla::Shape(input.result_layout()));
140 }
141 if (input.has_debug_options()) {
142 *output.mutable_debug_options() = input.debug_options();
143 }
144 output.set_num_replicas(input.num_replicas());
145 output.set_num_partitions(input.num_partitions());
146 output.set_use_spmd_partitioning(input.use_spmd_partitioning());
147 output.set_use_auto_spmd_partitioning(input.use_auto_spmd_partitioning());
148 output.set_deduplicate_hlo(input.deduplicate_hlo());
149 if (input.has_device_assignment()) {
150 TF_ASSIGN_OR_RETURN(
151 std::unique_ptr<xla::DeviceAssignment> assignment,
152 xla::DeviceAssignment::Deserialize(input.device_assignment()));
153 output.set_device_assignment(*assignment);
154 }
155 output.set_alias_passthrough_params(input.alias_passthrough_params());
156 output.set_run_backend_only(input.run_backend_only());
157 output.set_allow_spmd_sharding_propagation_to_output(
158 input.allow_spmd_sharding_propagation_to_output());
159 return output;
160 }
161
set_device_assignment(const DeviceAssignment & device_assignment)162 ExecutableBuildOptions& ExecutableBuildOptions::set_device_assignment(
163 const DeviceAssignment& device_assignment) {
164 device_assignment_ = device_assignment;
165 return *this;
166 }
167
ToString() const168 std::string ExecutableBuildOptions::ToString() const {
169 std::string result_layout = "nullopt";
170 if (result_layout_set_) {
171 result_layout = ShapeUtil::HumanStringWithLayout(result_layout_);
172 }
173 return absl::StrFormat(
174 "ExecutableBuildOptions{device_ordinal=%d, result_layout=%s, "
175 "num_replicas=%d}",
176 device_ordinal_, result_layout, num_replicas_);
177 }
178
CreateExecutionOptions(const ExecutableBuildOptions & build_options,const ProgramShape * program_shape)179 ExecutionOptions CreateExecutionOptions(
180 const ExecutableBuildOptions& build_options,
181 const ProgramShape* program_shape) {
182 ExecutionOptions execution_options = CreateDefaultExecutionOptions();
183 if (build_options.has_debug_options()) {
184 *execution_options.mutable_debug_options() = build_options.debug_options();
185 }
186 if (build_options.result_layout() != nullptr) {
187 *execution_options.mutable_shape_with_output_layout() =
188 build_options.result_layout()->ToProto();
189 } else {
190 Shape result_shape(program_shape->result());
191 LayoutUtil::SetToDefaultLayout(&result_shape);
192 *execution_options.mutable_shape_with_output_layout() =
193 result_shape.ToProto();
194 }
195 execution_options.set_num_replicas(build_options.num_replicas());
196 execution_options.set_num_partitions(build_options.num_partitions());
197 execution_options.set_use_spmd_partitioning(
198 build_options.use_spmd_partitioning());
199 execution_options.set_use_auto_spmd_partitioning(
200 build_options.use_auto_spmd_partitioning());
201 for (auto t : build_options.auto_spmd_partitioning_mesh_shape()) {
202 execution_options.mutable_auto_spmd_partitioning_mesh_shape()->Add(t);
203 }
204 for (auto t : build_options.auto_spmd_partitioning_mesh_ids()) {
205 execution_options.mutable_auto_spmd_partitioning_mesh_ids()->Add(t);
206 }
207 execution_options.set_deduplicate_hlo(build_options.deduplicate_hlo());
208 execution_options.set_allow_spmd_sharding_propagation_to_output(
209 build_options.allow_spmd_sharding_propagation_to_output());
210 if (build_options.has_device_assignment()) {
211 TF_CHECK_OK(build_options.device_assignment().Serialize(
212 execution_options.mutable_device_assignment()));
213 }
214 execution_options.set_alias_passthrough_params(
215 build_options.alias_passthrough_params());
216 return execution_options;
217 }
218
219 } // namespace xla
220