/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/tpu/kernels/tpu_compile_op_support.h"

#include <string>

#include "tensorflow/compiler/xla/debug_options_flags.h"
#include "tensorflow/compiler/xla/service/computation_layout.h"
#include "tensorflow/compiler/xla/service/computation_placer.h"
#include "tensorflow/compiler/xla/service/dump.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_key.h"
#include "tensorflow/core/tpu/kernels/tpu_executable_info.pb.h"
#include "tensorflow/stream_executor/tpu/proto_helper.h"

namespace tensorflow {
namespace tpu {
using ::stream_executor::port::Status;
using ::stream_executor::port::StatusOr;
using ::xla::ComputationLayout;
using ::xla::DebugOptions;
using ::xla::DeviceAssignment;
using ::xla::HloModuleConfig;
using ::xla::HloSharding;
using ::xla::InvalidArgument;
using ::xla::ProgramShape;
using ::xla::Shape;
using ::xla::ShapeTree;
using ::xla::ShapeUtil;

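// Validates `client_shape` and checks that it is compatible with
// `result_shape`; returns an InvalidArgument error otherwise.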
Status ValidateResultShape(const Shape& client_shape,
                           const Shape& result_shape) {
  TF_RETURN_IF_ERROR(
      xla::ShapeUtil::ValidateShapeWithOptionalLayout(client_shape));
  if (!xla::ShapeUtil::Compatible(client_shape, result_shape)) {
    return InvalidArgument(
        "Shape used to set computation result layout %s is not compatible "
        "with result shape %s",
        xla::ShapeUtil::HumanStringWithLayout(client_shape),
        xla::ShapeUtil::HumanString(result_shape));
  }
  return OkStatus();
}

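// Builds an xla::HloModuleConfig from the program shape, per-argument shapes,
// and optional compile-time options. Argument shapes must be compatible with
// the corresponding parameters of `program_shape`; the optional pointer
// arguments are only applied when non-null.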
StatusOr<std::unique_ptr<HloModuleConfig>> CreateModuleConfig(
    const ProgramShape& program_shape, absl::Span<const Shape> argument_shapes,
    absl::optional<const Shape> result_layout,
    absl::optional<const DeviceAssignment> device_assignment, int replica_count,
    int num_partitions, const DebugOptions* debug_options, const int* seed,
    const int* launch_id, const bool* alias_passthrough_params,
    const xla::FusionConfigCollection* fusion_config_collection,
    const std::vector<std::vector<bool>>* fusion_config) {
  auto config = absl::make_unique<HloModuleConfig>(program_shape);
  ComputationLayout* computation_layout =
      config->mutable_entry_computation_layout();
  if (program_shape.parameters_size() != argument_shapes.size()) {
    return InvalidArgument("computation takes %d parameters, but %u given",
                           program_shape.parameters_size(),
                           argument_shapes.size());
  }
  for (int i = 0; i < argument_shapes.size(); ++i) {
    // Verify that the shape of each argument matches the shape of the
    // corresponding parameter in the ProgramShape.
    if (!ShapeUtil::Compatible(argument_shapes[i],
                               program_shape.parameters(i))) {
      return InvalidArgument(
          "Argument does not match shape of computation parameter %d: want "
          "%s, got %s",
          i, ShapeUtil::HumanString(program_shape.parameters(i)),
          ShapeUtil::HumanString(argument_shapes[i]));
    }
    TF_RETURN_IF_ERROR(
        computation_layout->mutable_parameter_layout(i)->CopyLayoutFromShape(
            argument_shapes[i]));
  }

  if (result_layout.has_value()) {
    TF_RETURN_IF_ERROR(
        ValidateResultShape(result_layout.value(), program_shape.result()));
    TF_RETURN_IF_ERROR(
        computation_layout->mutable_result_layout()->CopyLayoutFromShape(
            result_layout.value()));
  } else {
    // If the result layout is not set, then choose the default.
    computation_layout->mutable_result_layout()->SetToDefaultLayout();
  }

  config->set_replica_count(replica_count);
  config->set_num_partitions(num_partitions);
  if (seed != nullptr) {
    config->set_seed(*seed);
  }
  if (launch_id != nullptr) {
    config->set_launch_id(*launch_id);
  }
  if (debug_options != nullptr) {
    config->set_debug_options(*debug_options);
  } else {
    config->set_debug_options(xla::GetDebugOptionsFromFlags());
  }

  // TODO(henrytan): set intra_op_parallelism_threads.
  // Reference:
  // tensorflow/compiler/xla/service/service.cc?l=324.

  if (device_assignment.has_value()) {
    config->set_static_device_assignment(device_assignment.value());
  }

  if (alias_passthrough_params != nullptr) {
    config->set_alias_passthrough_params(*alias_passthrough_params);
  }

  if (fusion_config_collection != nullptr && fusion_config != nullptr &&
      *fusion_config_collection != xla::FusionConfigCollection::kOff) {
    config->set_fusion_config_collection(*fusion_config_collection);
    *config->mutable_fusion_config() = *fusion_config;
  }

  return std::move(config);
}

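// Convenience overload that forwards to the full CreateModuleConfig above
// with the optional pointer arguments left unset.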
StatusOr<std::unique_ptr<xla::HloModuleConfig>> CreateModuleConfig(
    const xla::ProgramShape& program_shape,
    absl::Span<const Shape> argument_shapes,
    absl::optional<const Shape> result_layout,
    absl::optional<const DeviceAssignment> device_assignment, int replica_count,
    int num_partitions, const DebugOptions* debug_options) {
  return CreateModuleConfig(program_shape, argument_shapes, result_layout,
                            device_assignment, replica_count, num_partitions,
                            debug_options, /*seed=*/nullptr,
                            /*launch_id=*/nullptr,
                            /*alias_passthrough_params=*/nullptr,
                            /*fusion_config_collection=*/nullptr,
                            /*fusion_config=*/nullptr);
}

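// Returns the sharding sub-tree rooted at tuple element `element_index` of
// `tuple_shape_tree`, with positions not copied from the source defaulting to
// replicated sharding.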
ShapeTree<HloSharding> GetSubtree(
    const ShapeTree<HloSharding>& tuple_shape_tree, int element_index) {
  ShapeTree<HloSharding> element_shape_tree(
      xla::ShapeUtil::GetTupleElementShape(tuple_shape_tree.shape(),
                                           element_index),
      HloSharding::Replicate());

  xla::ShapeIndex src_index;
  src_index.push_back(element_index);
  element_shape_tree.CopySubtreeFrom(tuple_shape_tree, src_index, {});
  return element_shape_tree;
}

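// Computes the shape of the portion of `shape` assigned to `device` under
// `sharding`. Tuple shapes are handled recursively; for tiled shardings the
// result is the tile owned by `device` (e.g. an f32[8,8] shape tiled in two
// along dimension 0 yields f32[4,8] per device).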
Shape GetPerDeviceShape(const Shape& shape, const HloSharding& sharding,
                        int64_t device) {
  if (shape.IsTuple()) {
    ShapeTree<HloSharding> tuple_shape_tree = sharding.GetAsShapeTree(shape);
    std::vector<Shape> arg_shapes;
    for (int64_t i = 0; i < xla::ShapeUtil::TupleElementCount(shape); ++i) {
      Shape element_shape = xla::ShapeUtil::GetTupleElementShape(shape, i);
      HloSharding element_sharding = tuple_shape_tree.element({i});
      if (element_shape.IsTuple()) {
        element_sharding = HloSharding::Tuple(GetSubtree(tuple_shape_tree, i));
      }
      if (element_sharding.UsesDevice(device)) {
        arg_shapes.push_back(
            GetPerDeviceShape(element_shape, element_sharding, device));
      }
    }
    return xla::ShapeUtil::MakeTupleShape(arg_shapes);
  }

  if (sharding.IsTileMaximal()) {
    return shape;
  }

  std::vector<int64_t> dimensions;
  std::vector<int64_t> offset = sharding.TileOffsetForDevice(shape, device);
  std::vector<int64_t> limit = sharding.TileLimitForDevice(shape, device);
  dimensions.resize(limit.size());
  for (int64_t i = 0; i < limit.size(); ++i) {
    dimensions[i] = limit[i] - offset[i];
  }
  if (shape.has_layout()) {
    return xla::ShapeUtil::MakeShapeWithLayout(shape.element_type(), dimensions,
                                               shape.layout().minor_to_major());
  }
  return xla::ShapeUtil::MakeShape(shape.element_type(), dimensions);
}

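// For each VARIABLE argument in `metadata`: if the compilation result updates
// it, appends the per-core shape of the updated value to
// `per_core_output_shapes` on every core named by the argument's sharding and
// marks those cores in `may_modify_variables` when the update modifies the
// variable; in all cases, records the variable's per-core argument index from
// `arg_core_mapping` together with whether it is updated.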
Status AddVariableUpdatesToCores(
    const TPUCompileMetadataProto& metadata,
    const XlaCompiler::CompilationResult& compilation_result,
    const std::vector<ShardingAndIndex>& arg_core_mapping,
    std::vector<bool>* may_modify_variables,
    std::vector<std::vector<xla::Shape>>* per_core_output_shapes,
    std::vector<std::vector<std::pair<int, bool>>>* per_core_variable_indices) {
  // Add all variables to the corresponding core.
  may_modify_variables->resize(metadata.num_cores_per_replica(), false);
  int resource_update_pos = 0;
  for (int i = 0; i < metadata.args_size(); ++i) {
    const tpu::TPUCompileMetadataProto::Arg& proto_arg = metadata.args(i);
    if (proto_arg.kind() == tpu::TPUCompileMetadataProto::Arg::VARIABLE) {
      const auto& sharding = proto_arg.sharding();
      bool updated = false;
      if (resource_update_pos < compilation_result.resource_updates.size()) {
        const XlaCompiler::ResourceUpdate& update =
            compilation_result.resource_updates[resource_update_pos];
        if (update.input_index == i) {
          updated = true;
          int pos = compilation_result.outputs.size() + resource_update_pos;
          xla::Shape shape = xla::ShapeUtil::GetTupleElementShape(
              compilation_result.xla_output_shape, pos);
          auto add_to_core = [&](int64_t core,
                                 const xla::Shape& per_core_shape) {
            (*per_core_output_shapes)[core].push_back(per_core_shape);
            (*may_modify_variables)[core] =
                (*may_modify_variables)[core] || update.modified;
          };
          if (sharding.type() == xla::OpSharding::MAXIMAL) {
            add_to_core(sharding.tile_assignment_devices(0), shape);
          } else if (sharding.type() == xla::OpSharding::OTHER) {
            auto sharding_or =
                xla::HloSharding::FromProto(proto_arg.sharding());
            TF_RET_CHECK(sharding_or.ok());
            for (int64_t core :
                 proto_arg.sharding().tile_assignment_devices()) {
              xla::Shape per_core_shape =
                  GetPerDeviceShape(shape, sharding_or.ValueOrDie(), core);
              add_to_core(core, per_core_shape);
            }
          } else {
            TF_RET_CHECK(sharding.type() == xla::OpSharding::REPLICATED);
            for (int64_t core = 0; core < metadata.num_cores_per_replica();
                 ++core) {
              add_to_core(core, shape);
            }
          }
          ++resource_update_pos;
        }
      }
      if (sharding.type() == xla::OpSharding::MAXIMAL) {
        (*per_core_variable_indices)[sharding.tile_assignment_devices(0)]
            .push_back(
                std::pair<int, bool>(arg_core_mapping[i].indices[0], updated));
      } else if (sharding.type() == xla::OpSharding::OTHER) {
        for (int core : sharding.tile_assignment_devices()) {
          (*per_core_variable_indices)[core].push_back(
              std::pair<int, bool>(arg_core_mapping[i].indices[core], updated));
        }
      } else {
        TF_RET_CHECK(sharding.type() == xla::OpSharding::REPLICATED);
        for (int64_t core = 0; core < metadata.num_cores_per_replica();
             ++core) {
          (*per_core_variable_indices)[core].push_back(
              std::pair<int, bool>(arg_core_mapping[i].indices[core], updated));
        }
      }
    }
  }
  return OkStatus();
}

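// Populates `per_core_output_shapes` with the XLA shape of each computation
// retval on every core that the retval's sharding assigns it to.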
Status ComputeOutputShapesForEachCore(
    const tpu::TPUCompileMetadataProto& metadata,
    const XlaCompiler::CompilationResult& compilation_result,
    std::vector<std::vector<xla::Shape>>* per_core_output_shapes) {
  for (int i = 0; i < metadata.retvals_size(); ++i) {
    const tpu::TPUCompileMetadataProto::Retval& retval = metadata.retvals(i);
    TF_RET_CHECK(!compilation_result.outputs[i].is_constant)
        << "TPU compilation output " << i
        << " has a compile-time constant value. "
           "This should never happen.";

    xla::Shape shape = xla::ShapeUtil::GetTupleElementShape(
        compilation_result.xla_output_shape, i);
    auto add_shape_to_core = [&](int core, xla::Shape per_core_shape) {
      (*per_core_output_shapes)[core].push_back(std::move(per_core_shape));
    };
    if (retval.sharding().type() == xla::OpSharding::MAXIMAL) {
      add_shape_to_core(retval.sharding().tile_assignment_devices(0),
                        std::move(shape));
    } else if (retval.sharding().type() == xla::OpSharding::OTHER) {
      auto sharding_or = xla::HloSharding::FromProto(retval.sharding());
      TF_RET_CHECK(sharding_or.ok());
      for (int64_t core : retval.sharding().tile_assignment_devices()) {
        xla::Shape per_core_shape =
            GetPerDeviceShape(shape, sharding_or.ValueOrDie(), core);
        add_shape_to_core(core, std::move(per_core_shape));
      }
    } else {
      TF_RET_CHECK(retval.sharding().type() == xla::OpSharding::REPLICATED)
          << "Not all of the constant tensors were consumed.";
      for (int core = 0; core < per_core_output_shapes->size(); ++core) {
        add_shape_to_core(core, shape);
      }
    }
  }
  return OkStatus();
}

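// Builds a pre-optimization xla::HloModule from the compilation result and
// appends it to `hlo_modules`. The module config is derived from `metadata`
// and the optional device assignment.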
Status CreateHloModules(
    const TPUCompileMetadataProto& metadata,
    const tensorflow::XlaCompiler::CompilationResult& compilation_result,
    const absl::optional<xla::DeviceAssignment>& device_assignment,
    std::vector<std::unique_ptr<xla::HloModule>>* hlo_modules) {
  TF_RET_CHECK(
      compilation_result.computation->proto().has_host_program_shape());

  auto debug_options = xla::DebugOptions();
  debug_options.set_xla_step_marker_location(metadata.step_marker_location());
  TF_ASSIGN_OR_RETURN(
      std::unique_ptr<xla::HloModuleConfig> module_config,
      CreateModuleConfig(
          xla::ProgramShape(
              compilation_result.computation->proto().host_program_shape()),
          compilation_result.xla_input_shapes,
          compilation_result.xla_output_shape, device_assignment,
          metadata.num_replicas(), metadata.num_cores_per_replica(),
          &debug_options));

  TF_ASSIGN_OR_RETURN(
      std::unique_ptr<xla::HloModule> hlo_module,
      xla::HloModule::CreateFromProto(compilation_result.computation->proto(),
                                      *module_config));
  DumpHloModuleIfEnabled(*hlo_module, "before_optimizations");
  hlo_modules->push_back(std::move(hlo_module));

  return OkStatus();
}

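// Assembles a TpuCompilationRequestProto from either an MLIR module or a
// function-library computation, together with the compile metadata, argument
// shapes, and any guaranteed constants.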
StatusOr<TpuCompilationRequestProto> CreateTpuCompilationRequest(
    const absl::variant<MlirToHloArgs, FunctionToHloArgs>& computation,
    const TPUCompileMetadataProto& metadata,
    const std::vector<TensorShape>& arg_shapes) {
  VLOG(1) << "CreateTpuCompilationRequest.";
  TpuCompilationRequestProto compilation_request;
  bool use_mlir = computation.index() == 0;
  compilation_request.set_use_mlir(use_mlir);
  if (use_mlir) {
    VLOG(1) << "Serializing MlirModule";
    const MlirToHloArgs& mlir_computation = absl::get<0>(computation);
    *compilation_request.mutable_mlir_module() =
        string(mlir_computation.mlir_module);
  } else {
    VLOG(1) << "Serializing FunctionDefinitionLibrary";
    const FunctionToHloArgs& function_computation = absl::get<1>(computation);
    *compilation_request.mutable_fdef_lib() =
        function_computation.flib_def->ToProto();
    compilation_request.set_graph_def_version(
        function_computation.graph_def_version);
    *compilation_request.mutable_function() = *function_computation.function;
    // TODO(b/160937500): serializing and copying large guaranteed_constants
    // can be a perf hit. There is future work to refactor the compilation
    // layer to avoid passing guaranteed_constants over the C API.
    if (function_computation.guaranteed_constants.index() == 0) {
      absl::Span<const TensorProto* const> guaranteed_constants =
          absl::get<0>(function_computation.guaranteed_constants);
      for (const TensorProto* constant : guaranteed_constants) {
        *compilation_request.add_guaranteed_constants() = *constant;
      }
    } else {
      CHECK_EQ(function_computation.guaranteed_constants.index(), 1);
      const OpInputList& guaranteed_constants =
          *absl::get<1>(function_computation.guaranteed_constants);
      for (const Tensor& constant : guaranteed_constants) {
        constant.AsProtoTensorContent(
            compilation_request.add_guaranteed_constants());
      }
    }
  }

  for (const TensorShape& shape : arg_shapes) {
    shape.AsProto(compilation_request.add_arg_shapes());
  }

  *(compilation_request.mutable_metadata()) = metadata;

  VLOG(1) << "TpuCompilationRequest:\n" << compilation_request.DebugString();
  return compilation_request;
}

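// Parses the TPUCompileMetadataProto (and, optionally, the function name and
// MLIR module) from the op kernel construction context and validates it
// against the 'num_computations' attribute and any embedded device
// assignment.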
Status CompileOpMetadataFromContext(OpKernelConstruction* ctx,
                                    TPUCompileMetadataProto* metadata,
                                    NameAttrList* function_name,
                                    std::string* mlir_module) {
  CHECK_NE(metadata, nullptr);

  int num_computations;
  TF_RETURN_IF_ERROR(ctx->GetAttr("num_computations", &num_computations));

  std::string metadata_string;
  TF_RETURN_IF_ERROR(ctx->GetAttr("metadata", &metadata_string));
  if (!metadata->ParsePartialFromString(metadata_string)) {
    return errors::InvalidArgument("Unable to parse TPUCompileMetadataProto");
  }

  if (function_name != nullptr) {
    TF_RETURN_IF_ERROR(ctx->GetAttr("function", function_name));
  }

  if (mlir_module != nullptr) {
    TF_RETURN_IF_ERROR(ctx->GetAttr("mlir_module", mlir_module));
  }

  if (num_computations != metadata->num_cores_per_replica()) {
    return errors::InvalidArgument(
        "num_computations must be equal to "
        "num_cores_per_replica in the 'metadata' "
        "attribute (",
        num_computations, " vs ", metadata->num_cores_per_replica(), ")");
  }

  if (metadata->has_device_assignment()) {
    StatusOr<std::unique_ptr<DeviceAssignment>> device_assignment_or_error =
        DeviceAssignment::Deserialize(metadata->device_assignment());
    TF_RETURN_IF_ERROR(device_assignment_or_error.status());
    const DeviceAssignment& device_assignment =
        *device_assignment_or_error.ValueOrDie();
    const int num_replicas = metadata->num_replicas();
    if (device_assignment.replica_count() != num_replicas) {
      return errors::InvalidArgument(
          "Device assignment replica_count != num_replicas; ",
          device_assignment.replica_count(), " vs ", num_replicas);
    }
    if (device_assignment.computation_count() !=
        metadata->num_cores_per_replica()) {
      return errors::InvalidArgument(
          "Device assignment computation_count != num_cores_per_replica; ",
          device_assignment.computation_count(), " vs ",
          metadata->num_cores_per_replica());
    }
  }
  return OkStatus();
}

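// Resolves the XLA argument shapes: fully-defined shapes come from
// `metadata`, and shapes with unknown dimensions are filled in from
// `dynamic_shapes` in order. Guaranteed-constant arguments are skipped.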
Status ComputeArgumentShapes(const tpu::TPUCompileMetadataProto& metadata,
                             const std::vector<TensorShape>& dynamic_shapes,
                             std::vector<TensorShape>* arg_shapes) {
  arg_shapes->resize(metadata.args_size());
  int dynamic_shape_pos = 0;
  for (int i = 0; i < metadata.args_size(); ++i) {
    const tpu::TPUCompileMetadataProto::Arg& arg = metadata.args(i);
    // The XLA compiler determines the shape of each constant by inspecting
    // the value of its corresponding host-memory tensor. As a result, we
    // don't need to give the compiler graph-inferred shapes for constant
    // arguments.
    if (arg.kind() == tpu::TPUCompileMetadataProto::Arg::GUARANTEED_CONSTANT) {
      continue;
    }
    TF_RETURN_IF_ERROR(PartialTensorShape::IsValidShape(arg.shape()));
    PartialTensorShape static_shape(arg.shape());

    TensorShape& shape = (*arg_shapes)[i];
    if (static_shape.IsFullyDefined()) {
      TF_RET_CHECK(static_shape.AsTensorShape(&shape));
    } else {
      TF_RET_CHECK(dynamic_shape_pos < dynamic_shapes.size())
          << "Too few dynamic shapes";
      shape = dynamic_shapes[dynamic_shape_pos++];
      if (!static_shape.IsCompatibleWith(shape)) {
        return errors::InvalidArgument(
            "Mismatch between static and dynamic shape for argument. Static "
            "shape: ",
            static_shape.DebugString(),
            "; dynamic shape: ", shape.DebugString());
      }
    }
  }
  // Check that we consumed all of the dynamic shapes.
  TF_RET_CHECK(dynamic_shape_pos == dynamic_shapes.size())
      << "Too many dynamic shapes";
  return OkStatus();
}
}  // namespace tpu
}  // namespace tensorflow