/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/tpu/kernels/tpu_compile_op_support.h"

#include <string>

#include "tensorflow/compiler/xla/debug_options_flags.h"
#include "tensorflow/compiler/xla/service/computation_layout.h"
#include "tensorflow/compiler/xla/service/computation_placer.h"
#include "tensorflow/compiler/xla/service/dump.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_key.h"
#include "tensorflow/core/tpu/kernels/tpu_executable_info.pb.h"
#include "tensorflow/stream_executor/tpu/proto_helper.h"

namespace tensorflow {
namespace tpu {
using ::stream_executor::port::Status;
using ::stream_executor::port::StatusOr;
using ::xla::ComputationLayout;
using ::xla::DebugOptions;
using ::xla::DeviceAssignment;
using ::xla::HloModuleConfig;
using ::xla::HloSharding;
using ::xla::InvalidArgument;
using ::xla::ProgramShape;
using ::xla::Shape;
using ::xla::ShapeTree;
using ::xla::ShapeUtil;

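// Validates that `client_shape` is a well-formed shape (layout optional) and
// is compatible with the computation's `result_shape`.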
Status ValidateResultShape(const Shape& client_shape,
                           const Shape& result_shape) {
  TF_RETURN_IF_ERROR(
      xla::ShapeUtil::ValidateShapeWithOptionalLayout(client_shape));
  if (!xla::ShapeUtil::Compatible(client_shape, result_shape)) {
    return InvalidArgument(
        "Shape used to set computation result layout %s is not compatible "
        "with result shape %s",
        xla::ShapeUtil::HumanStringWithLayout(client_shape),
        xla::ShapeUtil::HumanString(result_shape));
  }
  return OkStatus();
}

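// Builds an HloModuleConfig for `program_shape`: copies the argument and
// (optional) result layouts into the entry computation layout and applies the
// replica/partition counts plus any of the optional settings (seed, launch id,
// debug options, aliasing, fusion config) that are provided.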
StatusOr<std::unique_ptr<HloModuleConfig>> CreateModuleConfig(
    const ProgramShape& program_shape, absl::Span<const Shape> argument_shapes,
    absl::optional<const Shape> result_layout,
    absl::optional<const DeviceAssignment> device_assignment, int replica_count,
    int num_partitions, const DebugOptions* debug_options, const int* seed,
    const int* launch_id, const bool* alias_passthrough_params,
    const xla::FusionConfigCollection* fusion_config_collection,
    const std::vector<std::vector<bool>>* fusion_config) {
  auto config = absl::make_unique<HloModuleConfig>(program_shape);
  ComputationLayout* computation_layout =
      config->mutable_entry_computation_layout();
  if (program_shape.parameters_size() != argument_shapes.size()) {
    return InvalidArgument("computation takes %d parameters, but %u given",
                           program_shape.parameters_size(),
                           argument_shapes.size());
  }
  for (int i = 0; i < argument_shapes.size(); ++i) {
    // Verify that shape of arguments matches the shape of the arguments in the
    // ProgramShape.
    if (!ShapeUtil::Compatible(argument_shapes[i],
                               program_shape.parameters(i))) {
      return InvalidArgument(
          "Argument does not match shape of computation parameter %d: want "
          "%s, got %s",
          i, ShapeUtil::HumanString(program_shape.parameters(i)),
          ShapeUtil::HumanString(argument_shapes[i]));
    }
    TF_RETURN_IF_ERROR(
        computation_layout->mutable_parameter_layout(i)->CopyLayoutFromShape(
            argument_shapes[i]));
  }

  if (result_layout.has_value()) {
    TF_RETURN_IF_ERROR(
        ValidateResultShape(result_layout.value(), program_shape.result()));
    TF_RETURN_IF_ERROR(
        computation_layout->mutable_result_layout()->CopyLayoutFromShape(
            result_layout.value()));
  } else {
    // If the result layout is not set, then choose the default.
    computation_layout->mutable_result_layout()->SetToDefaultLayout();
  }

  config->set_replica_count(replica_count);
  config->set_num_partitions(num_partitions);
  if (seed != nullptr) {
    config->set_seed(*seed);
  }
  if (launch_id != nullptr) {
    config->set_launch_id(*launch_id);
  }
  if (debug_options != nullptr) {
    config->set_debug_options(*debug_options);
  } else {
    config->set_debug_options(xla::GetDebugOptionsFromFlags());
  }

  // TODO(henrytan): set intra_op_parallelism_threads.
  // Reference:
  // tensorflow/compiler/xla/service/service.cc?l=324.

  if (device_assignment.has_value()) {
    config->set_static_device_assignment(device_assignment.value());
  }

  if (alias_passthrough_params != nullptr) {
    config->set_alias_passthrough_params(*alias_passthrough_params);
  }

  if (fusion_config_collection != nullptr && fusion_config != nullptr &&
      *fusion_config_collection != xla::FusionConfigCollection::kOff) {
    config->set_fusion_config_collection(*fusion_config_collection);
    *config->mutable_fusion_config() = *fusion_config;
  }

  return std::move(config);
}

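// Convenience overload that forwards to the full CreateModuleConfig above with
// all optional settings left unset.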
StatusOr<std::unique_ptr<xla::HloModuleConfig>> CreateModuleConfig(
    const xla::ProgramShape& program_shape,
    absl::Span<const Shape> argument_shapes,
    absl::optional<const Shape> result_layout,
    absl::optional<const DeviceAssignment> device_assignment, int replica_count,
    int num_partitions, const DebugOptions* debug_options) {
  return CreateModuleConfig(program_shape, argument_shapes, result_layout,
                            device_assignment, replica_count, num_partitions,
                            debug_options, /*seed=*/nullptr,
                            /*launch_id=*/nullptr,
                            /*alias_passthrough_params=*/nullptr,
                            /*fusion_config_collection=*/nullptr,
                            /*fusion_config=*/nullptr);
}

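// Extracts the sharding subtree rooted at tuple element `element_index` of
// `tuple_shape_tree` as its own ShapeTree.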
ShapeTree<HloSharding> GetSubtree(
    const ShapeTree<HloSharding>& tuple_shape_tree, int element_index) {
  ShapeTree<HloSharding> element_shape_tree(
      xla::ShapeUtil::GetTupleElementShape(tuple_shape_tree.shape(),
                                           element_index),
      HloSharding::Replicate());

  xla::ShapeIndex src_index;
  src_index.push_back(element_index);
  element_shape_tree.CopySubtreeFrom(tuple_shape_tree, src_index, {});
  return element_shape_tree;
}

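// Returns the shape of the data owned by `device` for `shape` under
// `sharding`: recurses into tuples, returns the full shape for tile-maximal
// shardings, and otherwise computes the tile size from the device's tile
// offsets and limits.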
Shape GetPerDeviceShape(const Shape& shape, const HloSharding& sharding,
                        int64_t device) {
  if (shape.IsTuple()) {
    ShapeTree<HloSharding> tuple_shape_tree = sharding.GetAsShapeTree(shape);
    std::vector<Shape> arg_shapes;
    for (int64_t i = 0; i < xla::ShapeUtil::TupleElementCount(shape); ++i) {
      Shape element_shape = xla::ShapeUtil::GetTupleElementShape(shape, i);
      HloSharding element_sharding = tuple_shape_tree.element({i});
      if (element_shape.IsTuple()) {
        element_sharding = HloSharding::Tuple(GetSubtree(tuple_shape_tree, i));
      }
      if (element_sharding.UsesDevice(device)) {
        arg_shapes.push_back(
            GetPerDeviceShape(element_shape, element_sharding, device));
      }
    }
    return xla::ShapeUtil::MakeTupleShape(arg_shapes);
  }

  if (sharding.IsTileMaximal()) {
    return shape;
  }

  std::vector<int64_t> dimensions;
  std::vector<int64_t> offset = sharding.TileOffsetForDevice(shape, device);
  std::vector<int64_t> limit = sharding.TileLimitForDevice(shape, device);
  dimensions.resize(limit.size());
  for (int64_t i = 0; i < limit.size(); ++i) {
    dimensions[i] = limit[i] - offset[i];
  }
  if (shape.has_layout()) {
    return xla::ShapeUtil::MakeShapeWithLayout(shape.element_type(), dimensions,
                                               shape.layout().minor_to_major());
  }
  return xla::ShapeUtil::MakeShape(shape.element_type(), dimensions);
}

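// For every VARIABLE argument in `metadata`, records per core whether the
// variable may be modified, the shape of its (possibly sharded) updated value
// in the XLA output tuple, and an (argument index, updated) entry in
// `per_core_variable_indices`.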
Status AddVariableUpdatesToCores(
    const TPUCompileMetadataProto& metadata,
    const XlaCompiler::CompilationResult& compilation_result,
    const std::vector<ShardingAndIndex>& arg_core_mapping,
    std::vector<bool>* may_modify_variables,
    std::vector<std::vector<xla::Shape>>* per_core_output_shapes,
    std::vector<std::vector<std::pair<int, bool>>>* per_core_variable_indices) {
  // Add all variables to the corresponding core.
  may_modify_variables->resize(metadata.num_cores_per_replica(), false);
  int resource_update_pos = 0;
  for (int i = 0; i < metadata.args_size(); ++i) {
    const tpu::TPUCompileMetadataProto::Arg& proto_arg = metadata.args(i);
    if (proto_arg.kind() == tpu::TPUCompileMetadataProto::Arg::VARIABLE) {
      const auto& sharding = proto_arg.sharding();
      bool updated = false;
      if (resource_update_pos < compilation_result.resource_updates.size()) {
        const XlaCompiler::ResourceUpdate& update =
            compilation_result.resource_updates[resource_update_pos];
        if (update.input_index == i) {
          updated = true;
          int pos = compilation_result.outputs.size() + resource_update_pos;
          xla::Shape shape = xla::ShapeUtil::GetTupleElementShape(
              compilation_result.xla_output_shape, pos);
          auto add_to_core = [&](int64_t core,
                                 const xla::Shape& per_core_shape) {
            (*per_core_output_shapes)[core].push_back(per_core_shape);
            (*may_modify_variables)[core] =
                (*may_modify_variables)[core] || update.modified;
          };
          if (sharding.type() == xla::OpSharding::MAXIMAL) {
            add_to_core(sharding.tile_assignment_devices(0), shape);
          } else if (sharding.type() == xla::OpSharding::OTHER) {
            auto sharding_or =
                xla::HloSharding::FromProto(proto_arg.sharding());
            TF_RET_CHECK(sharding_or.ok());
            for (int64_t core :
                 proto_arg.sharding().tile_assignment_devices()) {
              xla::Shape per_core_shape =
                  GetPerDeviceShape(shape, sharding_or.ValueOrDie(), core);
              add_to_core(core, per_core_shape);
            }
          } else {
            TF_RET_CHECK(sharding.type() == xla::OpSharding::REPLICATED);
            for (int64_t core = 0; core < metadata.num_cores_per_replica();
                 ++core) {
              add_to_core(core, shape);
            }
          }
          ++resource_update_pos;
        }
      }
      if (sharding.type() == xla::OpSharding::MAXIMAL) {
        (*per_core_variable_indices)[sharding.tile_assignment_devices(0)]
            .push_back(
                std::pair<int, bool>(arg_core_mapping[i].indices[0], updated));
      } else if (sharding.type() == xla::OpSharding::OTHER) {
        for (int core : sharding.tile_assignment_devices()) {
          (*per_core_variable_indices)[core].push_back(
              std::pair<int, bool>(arg_core_mapping[i].indices[core], updated));
        }
      } else {
        TF_RET_CHECK(sharding.type() == xla::OpSharding::REPLICATED);
        for (int64_t core = 0; core < metadata.num_cores_per_replica();
             ++core) {
          (*per_core_variable_indices)[core].push_back(
              std::pair<int, bool>(arg_core_mapping[i].indices[core], updated));
        }
      }
    }
  }
  return OkStatus();
}

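// Appends each return value's shape to the cores named by its sharding: one
// core for MAXIMAL, the per-device tile shapes for OTHER, and every core for
// REPLICATED.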
Status ComputeOutputShapesForEachCore(
    const tpu::TPUCompileMetadataProto& metadata,
    const XlaCompiler::CompilationResult& compilation_result,
    std::vector<std::vector<xla::Shape>>* per_core_output_shapes) {
  for (int i = 0; i < metadata.retvals_size(); ++i) {
    const tpu::TPUCompileMetadataProto::Retval& retval = metadata.retvals(i);
    TF_RET_CHECK(!compilation_result.outputs[i].is_constant)
        << "TPU compilation output " << i
        << " has a compile-time constant value. "
           "This should never happen.";

    xla::Shape shape = xla::ShapeUtil::GetTupleElementShape(
        compilation_result.xla_output_shape, i);
    auto add_shape_to_core = [&](int core, xla::Shape per_core_shape) {
      (*per_core_output_shapes)[core].push_back(std::move(per_core_shape));
    };
    if (retval.sharding().type() == xla::OpSharding::MAXIMAL) {
      add_shape_to_core(retval.sharding().tile_assignment_devices(0),
                        std::move(shape));
    } else if (retval.sharding().type() == xla::OpSharding::OTHER) {
      auto sharding_or = xla::HloSharding::FromProto(retval.sharding());
      TF_RET_CHECK(sharding_or.ok());
      for (int64_t core : retval.sharding().tile_assignment_devices()) {
        xla::Shape per_core_shape =
            GetPerDeviceShape(shape, sharding_or.ValueOrDie(), core);
        add_shape_to_core(core, std::move(per_core_shape));
      }
    } else {
      TF_RET_CHECK(retval.sharding().type() == xla::OpSharding::REPLICATED)
          << "Not all of the constant tensors were consumed.";
      for (int core = 0; core < per_core_output_shapes->size(); ++core) {
        add_shape_to_core(core, shape);
      }
    }
  }
  return OkStatus();
}

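// Rebuilds an xla::HloModule from the compiled computation proto, using a
// module config derived from `metadata`, the compilation result's input/output
// shapes, and the optional device assignment, and appends it to `hlo_modules`.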
Status CreateHloModules(
    const TPUCompileMetadataProto& metadata,
    const tensorflow::XlaCompiler::CompilationResult& compilation_result,
    const absl::optional<xla::DeviceAssignment>& device_assignment,
    std::vector<std::unique_ptr<xla::HloModule>>* hlo_modules) {
  TF_RET_CHECK(
      compilation_result.computation->proto().has_host_program_shape());

  auto debug_options = xla::DebugOptions();
  debug_options.set_xla_step_marker_location(metadata.step_marker_location());
  TF_ASSIGN_OR_RETURN(
      std::unique_ptr<xla::HloModuleConfig> module_config,
      CreateModuleConfig(
          xla::ProgramShape(
              compilation_result.computation->proto().host_program_shape()),
          compilation_result.xla_input_shapes,
          compilation_result.xla_output_shape, device_assignment,
          metadata.num_replicas(), metadata.num_cores_per_replica(),
          &debug_options));

  TF_ASSIGN_OR_RETURN(
      std::unique_ptr<xla::HloModule> hlo_module,
      xla::HloModule::CreateFromProto(compilation_result.computation->proto(),
                                      *module_config));
  DumpHloModuleIfEnabled(*hlo_module, "before_optimizations");
  hlo_modules->push_back(std::move(hlo_module));

  return OkStatus();
}

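// Assembles a TpuCompilationRequestProto from either a serialized MLIR module
// or a function (with its FunctionDefLibrary, graph def version, and
// guaranteed constants), together with the compile metadata and argument
// shapes.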
StatusOr<TpuCompilationRequestProto> CreateTpuCompilationRequest(
    const absl::variant<MlirToHloArgs, FunctionToHloArgs>& computation,
    const TPUCompileMetadataProto& metadata,
    const std::vector<TensorShape>& arg_shapes) {
  VLOG(1) << "CreateTpuCompilationRequest.";
  TpuCompilationRequestProto compilation_request;
  bool use_mlir = computation.index() == 0;
  compilation_request.set_use_mlir(use_mlir);
  if (use_mlir) {
    VLOG(1) << "Serializing MlirModule";
    const MlirToHloArgs& mlir_computation = absl::get<0>(computation);
    *compilation_request.mutable_mlir_module() =
        string(mlir_computation.mlir_module);
  } else {
    VLOG(1) << "Serializing FunctionDefinitionLibrary";
    const FunctionToHloArgs& function_computation = absl::get<1>(computation);
    *compilation_request.mutable_fdef_lib() =
        function_computation.flib_def->ToProto();
    compilation_request.set_graph_def_version(
        function_computation.graph_def_version);
    *compilation_request.mutable_function() = *function_computation.function;
    // TODO(b/160937500): serializing and copying large guaranteed_constants can
    // be a perf hit. There is a future work to refactor the compilation layer
    // to avoid passing guaranteed_constants over C_API.
    if (function_computation.guaranteed_constants.index() == 0) {
      absl::Span<const TensorProto* const> guaranteed_constants =
          absl::get<0>(function_computation.guaranteed_constants);
      for (const TensorProto* constant : guaranteed_constants) {
        *compilation_request.add_guaranteed_constants() = *constant;
      }
    } else {
      CHECK_EQ(function_computation.guaranteed_constants.index(), 1);
      const OpInputList& guaranteed_constants =
          *absl::get<1>(function_computation.guaranteed_constants);
      for (const Tensor& constant : guaranteed_constants) {
        constant.AsProtoTensorContent(
            compilation_request.add_guaranteed_constants());
      }
    }
  }

  for (const TensorShape& shape : arg_shapes) {
    shape.AsProto(compilation_request.add_arg_shapes());
  }

  *(compilation_request.mutable_metadata()) = metadata;

  VLOG(1) << "TpuCompilationRequest:\n" << compilation_request.DebugString();
  return compilation_request;
}

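// Reads the compile op's attributes: parses `metadata` from the "metadata"
// attribute, optionally fetches "function" and "mlir_module", and checks that
// "num_computations" and any serialized device assignment are consistent with
// the metadata.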
Status CompileOpMetadataFromContext(OpKernelConstruction* ctx,
                                    TPUCompileMetadataProto* metadata,
                                    NameAttrList* function_name,
                                    std::string* mlir_module) {
  CHECK_NE(metadata, nullptr);

  int num_computations;
  TF_RETURN_IF_ERROR(ctx->GetAttr("num_computations", &num_computations));

  std::string metadata_string;
  TF_RETURN_IF_ERROR(ctx->GetAttr("metadata", &metadata_string));
  if (!metadata->ParsePartialFromString(metadata_string)) {
    return errors::InvalidArgument("Unable to parse TPUCompileMetadataProto");
  }

  if (function_name != nullptr) {
    TF_RETURN_IF_ERROR(ctx->GetAttr("function", function_name));
  }

  if (mlir_module != nullptr) {
    TF_RETURN_IF_ERROR(ctx->GetAttr("mlir_module", mlir_module));
  }

  if (num_computations != metadata->num_cores_per_replica()) {
    return errors::InvalidArgument(
        "num_computations must be equal to "
        "num_cores_per_replica in the 'metadata' "
        "attribute (",
        num_computations, " vs ", metadata->num_cores_per_replica(), ")");
  }

  if (metadata->has_device_assignment()) {
    StatusOr<std::unique_ptr<DeviceAssignment>> device_assignment_or_error =
        DeviceAssignment::Deserialize(metadata->device_assignment());
    TF_RETURN_IF_ERROR(device_assignment_or_error.status());
    const DeviceAssignment& device_assignment =
        *device_assignment_or_error.ValueOrDie();
    const int num_replicas = metadata->num_replicas();
    if (device_assignment.replica_count() != num_replicas) {
      return errors::InvalidArgument(
          "Device assignment replica_count != num_replicas; ",
          device_assignment.replica_count(), " vs ", num_replicas);
    }
    if (device_assignment.computation_count() !=
        metadata->num_cores_per_replica()) {
      return errors::InvalidArgument(
          "Device assignment computation_count != num_cores_per_replica; ",
          device_assignment.computation_count(), " vs ",
          metadata->num_cores_per_replica());
    }
  }
  return OkStatus();
}

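// Determines the shape of every non-constant argument: fully defined shapes
// come from `metadata`, the rest are taken from `dynamic_shapes` in order and
// checked for compatibility; all dynamic shapes must be consumed.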
Status ComputeArgumentShapes(const tpu::TPUCompileMetadataProto& metadata,
                             const std::vector<TensorShape>& dynamic_shapes,
                             std::vector<TensorShape>* arg_shapes) {
  arg_shapes->resize(metadata.args_size());
  int dynamic_shape_pos = 0;
  for (int i = 0; i < metadata.args_size(); ++i) {
    const tpu::TPUCompileMetadataProto::Arg& arg = metadata.args(i);
    // The XLA compiler determines the shape of each constant by inspecting the
    // value of its corresponding host-memory tensor. As a result, we don't need
    // to give the compiler graph-inferred shapes for constant arguments.
    if (arg.kind() == tpu::TPUCompileMetadataProto::Arg::GUARANTEED_CONSTANT) {
      continue;
    }
    TF_RETURN_IF_ERROR(PartialTensorShape::IsValidShape(arg.shape()));
    PartialTensorShape static_shape(arg.shape());

    TensorShape& shape = (*arg_shapes)[i];
    if (static_shape.IsFullyDefined()) {
      TF_RET_CHECK(static_shape.AsTensorShape(&shape));
    } else {
      TF_RET_CHECK(dynamic_shape_pos < dynamic_shapes.size())
          << "Too few dynamic shapes";
      shape = dynamic_shapes[dynamic_shape_pos++];
      if (!static_shape.IsCompatibleWith(shape)) {
        return errors::InvalidArgument(
            "Mismatch between static and dynamic shape for argument. Static "
            "shape: ",
            static_shape.DebugString(),
            "; dynamic shape: ", shape.DebugString());
      }
    }
  }
  // Checks we consumed all of the dynamic shapes.
  TF_RET_CHECK(dynamic_shape_pos == dynamic_shapes.size())
      << "Too many dynamic shapes";
  return OkStatus();
}
}  // namespace tpu
}  // namespace tensorflow