/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/pjrt/utils.h"

#include "absl/container/flat_hash_set.h"
#include "tensorflow/compiler/xla/client/executable_build_options.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/service/hlo.pb.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/hlo_sharding.h"
#include "tensorflow/compiler/xla/shape.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"

namespace xla {

namespace {
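// Returns the shape produced by applying `sharding` to `shape`: tuple
// shardings are applied element-wise, and non-tuple shardings tile the shape
// (e.g. tiling f32[8,16] two ways along dimension 0 yields f32[4,16]).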
StatusOr<Shape> GetShardedShape(const Shape& shape,
                                const OpSharding& sharding) {
  if (sharding.type() == OpSharding::TUPLE) {
    if (!shape.IsTuple()) {
      return InvalidArgument(
          "Got tuple OpSharding (%s) for non-tuple shape (%s)",
          sharding.DebugString(), shape.ToString());
    }
    if (sharding.tuple_shardings_size() != shape.tuple_shapes_size()) {
      return InvalidArgument(
          "Got mismatched OpSharding tuple size (%d) and shape tuple size (%d)."
          " (OpSharding: %s, shape: %s)",
          sharding.tuple_shardings_size(), shape.tuple_shapes_size(),
          sharding.DebugString(), shape.ToString());
    }
    std::vector<Shape> sharded_subshapes;
    const int tuple_shapes_size = shape.tuple_shapes_size();
    sharded_subshapes.reserve(tuple_shapes_size);
    for (int i = 0; i < tuple_shapes_size; ++i) {
      TF_ASSIGN_OR_RETURN(
          Shape sharded_subshape,
          GetShardedShape(shape.tuple_shapes(i), sharding.tuple_shardings(i)));
      sharded_subshapes.emplace_back(std::move(sharded_subshape));
    }
    return ShapeUtil::MakeTupleShape(sharded_subshapes);
  }
  TF_ASSIGN_OR_RETURN(HloSharding hlo_sharding,
                      HloSharding::FromProto(sharding));
  return hlo_sharding.TileShape(shape);
}

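// Returns the sharded shape of `instr`'s result, with any layout cleared. If
// the instruction carries no sharding annotation its shape is used unchanged.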
StatusOr<Shape> GetShardedShape(const HloInstructionProto& instr) {
  const Shape unsharded_shape(instr.shape());
  Shape sharded_shape;
  if (instr.has_sharding()) {
    TF_ASSIGN_OR_RETURN(sharded_shape,
                        GetShardedShape(unsharded_shape, instr.sharding()));
  } else {
    sharded_shape = unsharded_shape;
  }
  LayoutUtil::ClearLayout(&sharded_shape);
  return sharded_shape;
}

// Returns sharded (argument shapes, result shape) without layouts.
StatusOr<std::pair<std::vector<Shape>, Shape>> GetShardedProgramShapes(
    const XlaComputation& computation, const ProgramShape& program_shape) {
  std::vector<Shape> arg_shapes;
  arg_shapes.resize(program_shape.parameters_size());
  Shape result_shape;
  for (const HloComputationProto& comp : computation.proto().computations()) {
    if (comp.id() != computation.proto().entry_computation_id()) {
      continue;
    }
    for (const HloInstructionProto& instr : comp.instructions()) {
      if (instr.opcode() == HloOpcodeString(HloOpcode::kParameter)) {
        if (instr.parameter_number() >= program_shape.parameters_size()) {
          return InvalidArgument(
              "Got invalid parameter number %d, expected %d parameters",
              instr.parameter_number(), program_shape.parameters_size());
        }
        TF_ASSIGN_OR_RETURN(arg_shapes[instr.parameter_number()],
                            GetShardedShape(instr));
      }
      if (instr.id() == comp.root_id()) {
        if (result_shape.element_type() != PRIMITIVE_TYPE_INVALID) {
          return InvalidArgument("Found multiple root instructions");
        }
        TF_ASSIGN_OR_RETURN(result_shape, GetShardedShape(instr));
      }
    }
  }
  for (int i = 0; i < arg_shapes.size(); ++i) {
    if (arg_shapes[i].element_type() == PRIMITIVE_TYPE_INVALID) {
      return InvalidArgument("Couldn't find parameter %d", i);
    }
  }
  if (result_shape.element_type() == PRIMITIVE_TYPE_INVALID) {
    return InvalidArgument("Couldn't find root instruction");
  }
  return std::make_pair(arg_shapes, result_shape);
}
}  // namespace

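// Resolves the replica count, partition count and (for non-portable
// executables) the device assignment to compile with. A portable executable
// is compiled for a single replica and partition and must not specify a
// device assignment in its build options.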
Status ParseDeviceAssignmentCompileOptions(
    bool compile_portable_executable, ExecutableBuildOptions* build_options,
    std::function<StatusOr<DeviceAssignment>(int, int)>
        GetDefaultDeviceAssignmentFunction,
    int* num_replicas, int* num_partitions,
    std::shared_ptr<DeviceAssignment>* device_assignment) {
  if (compile_portable_executable) {
    if (build_options->has_device_assignment()) {
      return InvalidArgument(
          "CompileOptions requests portable executable but "
          "ExecutableBuildOptions includes a device assignment");
    }
    *num_replicas = 1;
    *num_partitions = 1;
  } else {
    if (!build_options->has_device_assignment()) {
      VLOG(2) << "Compile using default device_assignment.";
      TF_ASSIGN_OR_RETURN(
          DeviceAssignment device_assignment,
          GetDefaultDeviceAssignmentFunction(build_options->num_replicas(),
                                             build_options->num_partitions()));
      build_options->set_device_assignment(device_assignment);
    }
    VLOG(2) << "Compile device_assignment:\n"
            << build_options->device_assignment().ToString();
    *num_replicas = build_options->device_assignment().replica_count();
    *num_partitions = build_options->device_assignment().computation_count();
    *device_assignment =
        std::make_shared<DeviceAssignment>(build_options->device_assignment());
  }
  return OkStatus();
}

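// Populates `argument_layouts` from the program shape if the caller did not
// supply them, chooses layouts for any array subshapes that are missing one
// (via `choose_compact_layout_for_shape_function` applied to their sharded
// shapes), fills `argument_layout_pointers`, and sets the result layout on
// `build_options`.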
Status DetermineArgumentLayoutsFromCompileOptions(
    const XlaComputation& computation,
    std::function<StatusOr<Shape>(Shape)>
        choose_compact_layout_for_shape_function,
    std::optional<std::vector<Shape>>& argument_layouts,
    ExecutableBuildOptions* build_options,
    std::vector<const Shape*>* argument_layout_pointers) {
  TF_ASSIGN_OR_RETURN(ProgramShape program_shape,
                      computation.GetProgramShape());
  if (!argument_layouts) {
    argument_layouts.emplace(program_shape.parameters());
    for (Shape& shape : *argument_layouts) {
      LayoutUtil::ClearLayout(&shape);
    }
  } else if (argument_layouts->size() != program_shape.parameters_size()) {
    return InvalidArgument(
        "CompileOptions specify %d argument layouts, but computation has %d "
        "arguments",
        argument_layouts->size(), program_shape.parameters_size());
  }
  argument_layout_pointers->reserve(argument_layouts->size());

  // Assign a default layout based on `sharded_shape` to any array subshapes in
  // `dst_shape` that are missing layouts.
  auto assign_layouts = [&choose_compact_layout_for_shape_function](
                            const Shape& sharded_shape, Shape* dst_shape) {
    return ShapeUtil::ForEachMutableSubshapeWithStatus(
        dst_shape, [&](Shape* subshape, const ShapeIndex& idx) {
          if (subshape->IsArray() && !subshape->has_layout()) {
            CHECK(ShapeUtil::IndexIsValid(sharded_shape, idx));
            const Shape& sharded_subshape =
                ShapeUtil::GetSubshape(sharded_shape, idx);
            LayoutUtil::SetToDefaultLayout(subshape);
            TF_ASSIGN_OR_RETURN(
                Shape layout,
                choose_compact_layout_for_shape_function(sharded_subshape));
            *subshape->mutable_layout() = layout.layout();
          }
          return OkStatus();
        });
  };
  TF_ASSIGN_OR_RETURN(auto sharded_shapes,
                      GetShardedProgramShapes(computation, program_shape));

  CHECK_EQ(sharded_shapes.first.size(), argument_layouts->size());
  for (int i = 0; i < argument_layouts->size(); ++i) {
    Shape* layout = &(*argument_layouts)[i];
    argument_layout_pointers->push_back(layout);
    TF_RETURN_IF_ERROR(assign_layouts(sharded_shapes.first[i], layout));
  }

  Shape result_layout;
  if (build_options->result_layout()) {
    result_layout = *build_options->result_layout();
  } else {
    result_layout = program_shape.result();
    LayoutUtil::ClearLayout(&result_layout);
  }
  TF_RETURN_IF_ERROR(assign_layouts(sharded_shapes.second, &result_layout));
  build_options->set_result_layout(result_layout);
  return OkStatus();
}

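// Returns the sorted parameter numbers that must be donated to the executable:
// a parameter is donated in full if any buffer inside it is aliased with an
// output in the module's input/output alias config. With tupled inputs the
// returned numbers index the top-level elements of the single tuple parameter.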
StatusOr<std::vector<int>> ComputeParametersThatMustBeDonated(
    const HloModule& module, bool tuple_inputs) {
  HloComputation* computation = module.entry_computation();
  int number_of_parameters = [&]() -> int {
    if (tuple_inputs) {
      CHECK_EQ(computation->num_parameters(), 1);
      const Shape& input_tuple_shape =
          computation->parameter_instruction(0)->shape();
      CHECK(input_tuple_shape.IsTuple());
      return input_tuple_shape.tuple_shapes_size();
    } else {
      return computation->num_parameters();
    }
  }();
  // If any buffer in a parameter is aliased we will donate the entire input
  // parameter.
  std::vector<int> parameters_to_donate;
  parameters_to_donate.reserve(computation->num_parameters());
  const HloInputOutputAliasConfig& config = module.input_output_alias_config();
  TF_RETURN_IF_ERROR(config.ForEachAliasWithStatus(
      [&](const ShapeIndex& output_index,
          const HloInputOutputAliasConfig::Alias& alias) {
        if (tuple_inputs) {
          if (alias.parameter_number != 0) {
            return InvalidArgument(
                "Unexpected parameter number %d in alias config with tupled "
                "inputs",
                alias.parameter_number);
          }
          const ShapeIndex& index = alias.parameter_index;
          if (!index.empty()) {
            int this_parameter = index.data()[0];
            if (this_parameter >= number_of_parameters) {
              return InvalidArgument(
                  "Unexpected parameter index %s in alias config with tupled "
                  "inputs and %d parameters",
                  index.ToString(), number_of_parameters);
            }
            parameters_to_donate.push_back(this_parameter);
          }
        } else {
          int this_parameter = alias.parameter_number;
          if (this_parameter >= number_of_parameters) {
            return InvalidArgument(
                "Unexpected parameter number %d in alias config without tupled "
                "inputs and %d parameters",
                this_parameter, number_of_parameters);
          }
          parameters_to_donate.push_back(this_parameter);
        }
        return OkStatus();
      }));
  absl::c_sort(parameters_to_donate);
  return parameters_to_donate;
}

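// Returns a default thread pool size: the NPROC CPU reservation if it is set
// in the environment, otherwise the host's maximum parallelism.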
int DefaultThreadPoolSize() {
  // Google's CI system exposes an environment variable NPROC that describes
  // a CPU reservation for tests.
  // TODO(phawkins): expose a better thought-out set of knobs to control
  // parallelism.
  const char* nproc_str = std::getenv("NPROC");
  int nproc = 0;
  if (nproc_str && absl::SimpleAtoi(nproc_str, &nproc)) {
    return std::max(0, nproc);
  }
  return tensorflow::port::MaxParallelism();
}

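// Returns true if `byte_strides` describes a dense major-to-minor (row-major)
// layout for an array of element type `type` with dimensions `dims`. For
// example, an F32 array with dims {2, 3, 4} is major-to-minor iff its byte
// strides are {48, 16, 4}; zero-sized arrays and size-1 dimensions match
// regardless of their strides.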
bool HasMajorToMinorLayout(PrimitiveType type, absl::Span<int64_t const> dims,
                           absl::Span<int64_t const> byte_strides) {
  CHECK_EQ(dims.size(), byte_strides.size());
  // If the array is size 0, the strides are irrelevant.
  if (absl::c_find(dims, 0) != dims.end()) {
    return true;
  }
  int64_t stride = primitive_util::ByteWidth(type);
  for (int i = static_cast<int>(dims.size()) - 1; i >= 0; --i) {
    // If a dimension is of size 1, its stride is irrelevant.
    if (dims[i] != 1) {
      if (byte_strides[i] != stride) {
        return false;
      }
      stride *= dims[i];
    }
  }
  return true;
}

}  // namespace xla