/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/pjrt/utils.h"

#include "absl/algorithm/container.h"
#include "absl/container/flat_hash_set.h"
#include "absl/strings/numbers.h"
#include "tensorflow/compiler/xla/client/executable_build_options.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/primitive_util.h"
#include "tensorflow/compiler/xla/service/hlo.pb.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/hlo_sharding.h"
#include "tensorflow/compiler/xla/shape.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/cpu_info.h"

namespace xla {

namespace {
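// Returns the shape of the per-device shard of `shape` implied by `sharding`.
// Tuple shardings are applied element-wise to tuple shapes; any other sharding
// is converted to an HloSharding and used to tile the shape.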
StatusOr<Shape> GetShardedShape(const Shape& shape,
                                const OpSharding& sharding) {
  if (sharding.type() == OpSharding::TUPLE) {
    if (!shape.IsTuple()) {
      return InvalidArgument(
          "Got tuple OpSharding (%s) for non-tuple shape (%s)",
          sharding.DebugString(), shape.ToString());
    }
    if (sharding.tuple_shardings_size() != shape.tuple_shapes_size()) {
      return InvalidArgument(
          "Got mismatched OpSharding tuple size (%d) and shape tuple size (%d)."
          " (OpSharding: %s, shape: %s)",
          sharding.tuple_shardings_size(), shape.tuple_shapes_size(),
          sharding.DebugString(), shape.ToString());
    }
    std::vector<Shape> sharded_subshapes;
    const int tuple_shapes_size = shape.tuple_shapes_size();
    sharded_subshapes.reserve(tuple_shapes_size);
    for (int i = 0; i < tuple_shapes_size; ++i) {
      TF_ASSIGN_OR_RETURN(
          Shape sharded_subshape,
          GetShardedShape(shape.tuple_shapes(i), sharding.tuple_shardings(i)));
      sharded_subshapes.emplace_back(std::move(sharded_subshape));
    }
    return ShapeUtil::MakeTupleShape(sharded_subshapes);
  }
  TF_ASSIGN_OR_RETURN(HloSharding hlo_sharding,
                      HloSharding::FromProto(sharding));
  return hlo_sharding.TileShape(shape);
}

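// Returns the sharded shape of `instr`'s result, with any layout cleared.
// Instructions without a sharding keep their original (unsharded) shape.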
StatusOr<Shape> GetShardedShape(const HloInstructionProto& instr) {
  const Shape unsharded_shape(instr.shape());
  Shape sharded_shape;
  if (instr.has_sharding()) {
    TF_ASSIGN_OR_RETURN(sharded_shape,
                        GetShardedShape(unsharded_shape, instr.sharding()));
  } else {
    sharded_shape = unsharded_shape;
  }
  LayoutUtil::ClearLayout(&sharded_shape);
  return sharded_shape;
}

// Returns sharded (argument shapes, result shape) without layouts.
StatusOr<std::pair<std::vector<Shape>, Shape>> GetShardedProgramShapes(
    const XlaComputation& computation, const ProgramShape& program_shape) {
  std::vector<Shape> arg_shapes;
  arg_shapes.resize(program_shape.parameters_size());
  Shape result_shape;
  for (const HloComputationProto& comp : computation.proto().computations()) {
    if (comp.id() != computation.proto().entry_computation_id()) {
      continue;
    }
    for (const HloInstructionProto& instr : comp.instructions()) {
      if (instr.opcode() == HloOpcodeString(HloOpcode::kParameter)) {
        if (instr.parameter_number() >= program_shape.parameters_size()) {
          return InvalidArgument(
              "Got invalid parameter number %d, expected %d parameters",
              instr.parameter_number(), program_shape.parameters_size());
        }
        TF_ASSIGN_OR_RETURN(arg_shapes[instr.parameter_number()],
                            GetShardedShape(instr));
      }
      if (instr.id() == comp.root_id()) {
        if (result_shape.element_type() != PRIMITIVE_TYPE_INVALID) {
          return InvalidArgument("Found multiple root instructions");
        }
        TF_ASSIGN_OR_RETURN(result_shape, GetShardedShape(instr));
      }
    }
  }
  for (int i = 0; i < arg_shapes.size(); ++i) {
    if (arg_shapes[i].element_type() == PRIMITIVE_TYPE_INVALID) {
      return InvalidArgument("Couldn't find parameter %d", i);
    }
  }
  if (result_shape.element_type() == PRIMITIVE_TYPE_INVALID) {
    return InvalidArgument("Couldn't find root instruction");
  }
  return std::make_pair(arg_shapes, result_shape);
}

}  // namespace

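// Derives num_replicas, num_partitions and device_assignment from
// `build_options`. Portable executables must not carry a device assignment and
// are compiled for a single replica and partition; otherwise, when the build
// options lack a device assignment, a default one is obtained via
// GetDefaultDeviceAssignmentFunction.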
Status ParseDeviceAssignmentCompileOptions(
    bool compile_portable_executable, ExecutableBuildOptions* build_options,
    std::function<StatusOr<DeviceAssignment>(int, int)>
        GetDefaultDeviceAssignmentFunction,
    int* num_replicas, int* num_partitions,
    std::shared_ptr<DeviceAssignment>* device_assignment) {
  if (compile_portable_executable) {
    if (build_options->has_device_assignment()) {
      return InvalidArgument(
          "CompileOptions requests portable executable but "
          "ExecutableBuildOptions includes a device assignment");
    }
    *num_replicas = 1;
    *num_partitions = 1;
  } else {
    if (!build_options->has_device_assignment()) {
      VLOG(2) << "Compile using default device_assignment.";
      TF_ASSIGN_OR_RETURN(
          DeviceAssignment device_assignment,
          GetDefaultDeviceAssignmentFunction(build_options->num_replicas(),
                                             build_options->num_partitions()));
      build_options->set_device_assignment(device_assignment);
    }
    VLOG(2) << "Compile device_assignment:\n"
            << build_options->device_assignment().ToString();
    *num_replicas = build_options->device_assignment().replica_count();
    *num_partitions = build_options->device_assignment().computation_count();
    *device_assignment =
        std::make_shared<DeviceAssignment>(build_options->device_assignment());
  }
  return OkStatus();
}

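// Fills in missing argument and result layouts for `computation`. Any array
// subshape without a layout receives one chosen by
// `choose_compact_layout_for_shape_function`, applied to the corresponding
// sharded (per-device) subshape; the resulting result layout is written back
// into `build_options`.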
Status DetermineArgumentLayoutsFromCompileOptions(
    const XlaComputation& computation,
    std::function<StatusOr<Shape>(Shape)>
        choose_compact_layout_for_shape_function,
    std::optional<std::vector<Shape>>& argument_layouts,
    ExecutableBuildOptions* build_options,
    std::vector<const Shape*>* argument_layout_pointers) {
  TF_ASSIGN_OR_RETURN(ProgramShape program_shape,
                      computation.GetProgramShape());
  if (!argument_layouts) {
    argument_layouts.emplace(program_shape.parameters());
    for (Shape& shape : *argument_layouts) {
      LayoutUtil::ClearLayout(&shape);
    }
  } else if (argument_layouts->size() != program_shape.parameters_size()) {
    return InvalidArgument(
        "CompileOptions specify %d argument layouts, but computation has %d "
        "arguments",
        argument_layouts->size(), program_shape.parameters_size());
  }
  argument_layout_pointers->reserve(argument_layouts->size());

  // Assign a default layout based on `sharded_shape` to any array subshapes in
  // `dst_shape` that are missing layouts.
  auto assign_layouts = [&choose_compact_layout_for_shape_function](
                            const Shape& sharded_shape, Shape* dst_shape) {
    return ShapeUtil::ForEachMutableSubshapeWithStatus(
        dst_shape, [&](Shape* subshape, const ShapeIndex& idx) {
          if (subshape->IsArray() && !subshape->has_layout()) {
            CHECK(ShapeUtil::IndexIsValid(sharded_shape, idx));
            const Shape& sharded_subshape =
                ShapeUtil::GetSubshape(sharded_shape, idx);
            LayoutUtil::SetToDefaultLayout(subshape);
            TF_ASSIGN_OR_RETURN(
                Shape layout,
                choose_compact_layout_for_shape_function(sharded_subshape));
            *subshape->mutable_layout() = layout.layout();
          }
          return OkStatus();
        });
  };
  TF_ASSIGN_OR_RETURN(auto sharded_shapes,
                      GetShardedProgramShapes(computation, program_shape));

  CHECK_EQ(sharded_shapes.first.size(), argument_layouts->size());
  for (int i = 0; i < argument_layouts->size(); ++i) {
    Shape* layout = &(*argument_layouts)[i];
    argument_layout_pointers->push_back(layout);
    TF_RETURN_IF_ERROR(assign_layouts(sharded_shapes.first[i], layout));
  }

  Shape result_layout;
  if (build_options->result_layout()) {
    result_layout = *build_options->result_layout();
  } else {
    result_layout = program_shape.result();
    LayoutUtil::ClearLayout(&result_layout);
  }
  TF_RETURN_IF_ERROR(assign_layouts(sharded_shapes.second, &result_layout));
  build_options->set_result_layout(result_layout);
  return OkStatus();
}

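// Returns the sorted indices of the entry-computation parameters that must be
// donated, derived from the module's input/output alias config. With tupled
// inputs the indices refer to elements of the single tuple parameter; an alias
// anywhere inside a parameter donates the entire parameter.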
StatusOr<std::vector<int>> ComputeParametersThatMustBeDonated(
    const HloModule& module, bool tuple_inputs) {
  HloComputation* computation = module.entry_computation();
  int number_of_parameters = [&]() -> int {
    if (tuple_inputs) {
      CHECK_EQ(computation->num_parameters(), 1);
      const Shape& input_tuple_shape =
          computation->parameter_instruction(0)->shape();
      CHECK(input_tuple_shape.IsTuple());
      return input_tuple_shape.tuple_shapes_size();
    } else {
      return computation->num_parameters();
    }
  }();
  // If any buffer in a parameter is aliased we will donate the entire input
  // parameter.
  std::vector<int> parameters_to_donate;
  parameters_to_donate.reserve(computation->num_parameters());
  const HloInputOutputAliasConfig& config = module.input_output_alias_config();
  TF_RETURN_IF_ERROR(config.ForEachAliasWithStatus(
      [&](const ShapeIndex& output_index,
          const HloInputOutputAliasConfig::Alias& alias) {
        if (tuple_inputs) {
          if (alias.parameter_number != 0) {
            return InvalidArgument(
                "Unexpected parameter number %d in alias config with tupled "
                "inputs",
                alias.parameter_number);
          }
          const ShapeIndex& index = alias.parameter_index;
          if (!index.empty()) {
            int this_parameter = index.data()[0];
            if (this_parameter >= number_of_parameters) {
              return InvalidArgument(
                  "Unexpected parameter index %s in alias config with tupled "
                  "inputs and %d parameters",
                  index.ToString(), number_of_parameters);
            }
            parameters_to_donate.push_back(this_parameter);
          }
        } else {
          int this_parameter = alias.parameter_number;
          if (this_parameter >= number_of_parameters) {
            return InvalidArgument(
                "Unexpected parameter number %d in alias config without tupled "
                "inputs and %d parameters",
                this_parameter, number_of_parameters);
          }
          parameters_to_donate.push_back(this_parameter);
        }
        return OkStatus();
      }));
  absl::c_sort(parameters_to_donate);
  return parameters_to_donate;
}

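// Returns a default thread-pool size: the value of the NPROC environment
// variable when it parses as an integer, otherwise the host's maximum
// parallelism.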
int DefaultThreadPoolSize() {
  // Google's CI system exposes an environment variable NPROC that describes
  // a CPU reservation for tests.
  // TODO(phawkins): expose a better thought-out set of knobs to control
  // parallelism.
  const char* nproc_str = std::getenv("NPROC");
  int nproc = 0;
  if (nproc_str && absl::SimpleAtoi(nproc_str, &nproc)) {
    return std::max(0, nproc);
  }
  return tensorflow::port::MaxParallelism();
}

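// Returns true if `byte_strides` describes a dense major-to-minor (row-major)
// layout for an array with element type `type` and dimensions `dims`. For
// example, an F32 array with dims {2, 3} is major-to-minor iff byte_strides is
// {12, 4}. Zero-sized arrays and size-1 dimensions match any stride.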
bool HasMajorToMinorLayout(PrimitiveType type, absl::Span<int64_t const> dims,
                           absl::Span<int64_t const> byte_strides) {
  CHECK_EQ(dims.size(), byte_strides.size());
  // If the array is size 0, the strides are irrelevant.
  if (absl::c_find(dims, 0) != dims.end()) {
    return true;
  }
  int64_t stride = primitive_util::ByteWidth(type);
  for (int i = static_cast<int>(dims.size()) - 1; i >= 0; --i) {
    // If a dimension is of size 1, its stride is irrelevant.
    if (dims[i] != 1) {
      if (byte_strides[i] != stride) {
        return false;
      }
      stride *= dims[i];
    }
  }
  return true;
}

}  // namespace xla