1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/tf2xla/xla_compiler.h"
17
18 #include <numeric>
19 #include <vector>
20
21 #include "tensorflow/compiler/mlir/tf2xla/mlir_bridge_rollout_policy.h"
22 #include "absl/container/flat_hash_map.h"
23 #include "absl/memory/memory.h"
24 #include "absl/types/variant.h"
25 #include "tensorflow/compiler/jit/defs.h"
26 #include "tensorflow/compiler/jit/flags.h"
27 #include "tensorflow/compiler/jit/shape_inference.h"
28 #include "tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.h"
29 #include "tensorflow/compiler/mlir/utils/array_container_utils.h"
30 #include "tensorflow/compiler/tf2xla/graph_compiler.h"
31 #include "tensorflow/compiler/tf2xla/layout_util.h"
32 #include "tensorflow/compiler/tf2xla/rearrange_function_argument.h"
33 #include "tensorflow/compiler/tf2xla/shape_util.h"
34 #include "tensorflow/compiler/tf2xla/sharding_util.h"
35 #include "tensorflow/compiler/tf2xla/side_effect_util.h"
36 #include "tensorflow/compiler/tf2xla/tf2xla_util.h"
37 #include "tensorflow/compiler/tf2xla/type_util.h"
38 #include "tensorflow/compiler/tf2xla/xla_compilation_device.h"
39 #include "tensorflow/compiler/tf2xla/xla_context.h"
40 #include "tensorflow/compiler/tf2xla/xla_helpers.h"
41 #include "tensorflow/compiler/xla/client/client_library.h"
42 #include "tensorflow/compiler/xla/client/xla_builder.h"
43 #include "tensorflow/compiler/xla/client/xla_computation.h"
44 #include "tensorflow/compiler/xla/protobuf_util.h"
45 #include "tensorflow/compiler/xla/shape_util.h"
46 #include "tensorflow/compiler/xla/util.h"
47 #include "tensorflow/core/common_runtime/device.h"
48 #include "tensorflow/core/common_runtime/executor.h"
49 #include "tensorflow/core/common_runtime/function.h"
50 #include "tensorflow/core/common_runtime/graph_constructor.h"
51 #include "tensorflow/core/common_runtime/graph_optimizer.h"
52 #include "tensorflow/core/framework/attr_value_util.h"
53 #include "tensorflow/core/framework/function.h"
54 #include "tensorflow/core/framework/node_def_util.h"
55 #include "tensorflow/core/framework/types.h"
56 #include "tensorflow/core/graph/node_builder.h"
57 #include "tensorflow/core/lib/core/errors.h"
58 #include "tensorflow/core/lib/gtl/cleanup.h"
59 #include "tensorflow/core/lib/hash/hash.h"
60 #include "tensorflow/core/platform/logging.h"
61 #include "tensorflow/core/protobuf/error_codes.pb.h"
62 #include "tensorflow/core/protobuf/graph_debug_info.pb.h"
63 #include "tensorflow/core/util/dump_graph.h"
64
65 namespace tensorflow {
66 namespace {
67
68 // Checks that arguments `args` match types `types`.
69 Status CheckSignature(const DataTypeVector& types,
70 absl::Span<const XlaCompiler::Argument> args) {
71 if (args.size() != types.size()) {
72 return errors::Internal("Compilation arguments have ", args.size(),
73 " elements while function has ", types.size());
74 }
75 for (int i = 0, end = types.size(); i < end; ++i) {
76 // Don't perform type checks on resource variables and tensor
77 // lists (DT_VARIANT) as we have to trick the type system in order to
78 // plumb them through. DT_VARIANTS are wrapped in a DT_UINT8 tensor.
79 if (types[i] != args[i].type && types[i] != DT_RESOURCE &&
80 types[i] != DT_VARIANT) {
81 return errors::Internal(
82 "Argument ", i, " has declared type ", DataTypeString(args[i].type),
83 " but function parameter has type ", DataTypeString(types[i]));
84 }
85 }
86 return OkStatus();
87 }
88
89 // Uses the _Arg and _Retval nodes in the graph to determine an OpSharding for
90 // each argument and return value.
91 StatusOr<
92 std::pair<std::map<int, xla::OpSharding>, std::map<int, xla::OpSharding>>>
93 ComputeArgAndRetvalShardings(const Graph& graph) {
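  // Note: std::numeric_limits<int32>::max() is passed as the core count below,
  // which effectively disables the upper-bound check on the device ordinal, so
  // any non-negative core id parsed from the node's assigned device is accepted.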
94 auto get_sharding_for_node =
95 [](const Node* n) -> StatusOr<std::optional<xla::OpSharding>> {
96 TF_ASSIGN_OR_RETURN(
97 auto sharding,
98 ParseShardingFromDevice(*n, std::numeric_limits<int32>::max(),
99 /*add_metadata=*/false));
100 return sharding;
101 };
102 std::map<int, xla::OpSharding> arg_shardings;
103 std::map<int, xla::OpSharding> retval_shardings;
104 for (const Node* n : graph.nodes()) {
105 if (n->IsArg()) {
106 TF_ASSIGN_OR_RETURN(auto sharding, get_sharding_for_node(n));
107 if (!sharding.has_value()) continue;
108 int index;
109 TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "index", &index));
110 TF_RET_CHECK(index >= 0) << "Negative _Arg index";
111 arg_shardings[index] = std::move(*sharding);
112 } else if (n->IsRetval()) {
113 TF_ASSIGN_OR_RETURN(auto sharding, get_sharding_for_node(n));
114 if (!sharding.has_value()) continue;
115 int index;
116 TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "index", &index));
117 TF_RET_CHECK(index >= 0) << "Negative _Retval index";
118 retval_shardings[index] = std::move(*sharding);
119 }
120 }
121 return std::make_pair(std::move(arg_shardings), std::move(retval_shardings));
122 }
123
124 Status ExecuteGraph(XlaContext* xla_context, std::unique_ptr<Graph> graph,
125 XlaCompilationDevice* device, FunctionLibraryRuntime* flib,
126 int64_t step_id) {
127   // Resource cleanup is a bit messy. XlaContext is a ref-counted resource; the
128 // resource manager takes ownership via Create, and unrefs via Cleanup. We
129 // explicitly add a reference to ensure the refcount at entry is maintained at
130 // all exit points; Create and Cleanup are always called in this function.
131 //
132 // The Executor requires us to use ScopedStepContainer. We wrap it in a
133   // unique_ptr so we can capture the cleanup status at the end.
134 xla_context->Ref();
135 Status status;
136 auto step_container = std::make_unique<ScopedStepContainer>(
137 step_id, [&status, device](const string& name) {
138 status = device->resource_manager()->Cleanup(name);
139 });
140 TF_RETURN_IF_ERROR(step_container->Create(device->resource_manager(),
141 XlaContext::kXlaContextResourceName,
142 xla_context));
143
144 GraphCompiler graph_compiler(device, graph.get(), flib, step_container.get());
145 TF_RETURN_IF_ERROR(graph_compiler.Compile());
146 // Explicitly clean up the step container, to capture the cleanup status.
147 step_container.reset();
148 return status;
149 }
150
151 // Builds the XLA computation.
152 // - `args` is the list of input arguments
153 // - `retvals` is the list of retvals produced by _Retval operators, in index
154 // order.
155 // - `arg_shardings` and `retval_shardings` are mapping from arg/return indices
156 // to sharding.
157 // - If `return_updated_values_for_all_resources` is true, all resources will be
158 // included in `resource_updates`, regardless of whether their value changed.
159 // - Sets `*num_nonconst_outputs` to the number of outputs of the `computation`.
160 // - Sets `*resource_updates` to a description of resources whose values are
161 // written by the computation; the variable writes are the last
162 //   `resource_updates.size()` return values from the computation. Each entry in
163 // `resource_updates` is a ResourceUpdate, whose `index` is the index of a
164 // resource variable argument to the computation to be updated, and `type` is
165 // the type of the final output.
166 Status BuildComputation(
167 const std::vector<XlaCompiler::Argument>& args,
168 const std::vector<XlaExpression>& retvals,
169 const std::map<int, xla::OpSharding>& arg_shardings,
170 const std::map<int, xla::OpSharding>& retval_shardings,
171 const std::vector<std::unique_ptr<XlaResource>>& resources,
172 std::unique_ptr<xla::XlaOp> token_output,
173 const XlaShapeLayoutHelpers::ShapeDeterminationFns& shape_determination_fns,
174 bool is_entry_computation, bool return_updated_values_for_all_resources,
175 bool always_return_tuple, bool use_tuple_arg, bool alias_resource_update,
176 xla::XlaBuilder* builder, xla::XlaComputation* computation,
177 int* num_computation_outputs, int* num_nonconst_outputs,
178 std::vector<XlaCompiler::OutputDescription>* outputs,
179 std::vector<XlaCompiler::ResourceUpdate>* resource_updates,
180 xla::Shape* output_shape, absl::Span<int const> input_mapping) {
181 // Attach a common operator name as metadata. This has no semantic effect — it
182 // merely makes the HLO graph more readable when visualized via TensorBoard,
183 // since TensorBoard forms groups out of operators with similar names.
184 xla::OpMetadata retval_metadata;
185 retval_metadata.set_op_name("XLA_Retvals");
186 builder->SetOpMetadata(retval_metadata);
187 VLOG(1) << "Building new computation";
188 auto cleanup = gtl::MakeCleanup([builder]() { builder->ClearOpMetadata(); });
189
190 // Builds a no-op XLA computation. We need to set the sharding of outputs, but
191 // cannot change the sharding of the existing output op. To do this, we build
192 // a new identity op to which shardings can be applied.
193 auto identity_op = [builder](xla::XlaOp op,
194 const std::optional<xla::OpSharding>& sharding) {
195 xla::XlaScopedShardingAssignment assign_sharding(builder, sharding);
196 return xla::Copy(op);
197 };
198
199 std::vector<xla::XlaOp> elems;
200 elems.reserve(retvals.size());
201
202 // Keeps track of sharding of each retval. If a retval is not in this list,
203 // replicate sharding is used. The first element is the output index, second
204 // element is the sharding.
205 std::unordered_map<int, xla::OpSharding> retval_index_and_sharding;
206 for (int i = 0, end = retvals.size(); i < end; ++i) {
207 XlaCompiler::OutputDescription& output = (*outputs)[i];
208 const XlaExpression& retval = retvals[i];
209 output.type = retval.dtype();
210 switch (retval.kind()) {
211 case XlaExpression::Kind::kConstant:
212 output.is_constant = true;
213 output.constant_value = *retval.constant_value();
214 output.shape = output.constant_value.shape();
215 break;
216
217 case XlaExpression::Kind::kTensorList: {
218 output.is_tensor_list = true;
219 xla::XlaOp value = retval.handle();
220 elems.push_back(value);
221 break;
222 }
223
224 case XlaExpression::Kind::kXlaOp: {
225 output.is_constant = false;
226 TF_ASSIGN_OR_RETURN(output.shape, retval.GetShape());
227 xla::XlaOp value = retval.handle();
228 auto it = retval_shardings.find(i);
229 std::optional<xla::OpSharding> sharding =
230 it == retval_shardings.end() ? std::optional<xla::OpSharding>()
231 : it->second;
232 if (it != retval_shardings.end()) {
233 retval_index_and_sharding[elems.size()] = it->second;
234 }
235 if (shape_determination_fns.shape_representation_fn) {
236 TF_ASSIGN_OR_RETURN(auto original_shape, builder->GetShape(value));
237 TF_ASSIGN_OR_RETURN(value,
238 ReshapeWithCorrectRepresentationAndSharding(
239 builder, value, original_shape,
240 shape_determination_fns, sharding,
241 /*fast_mem=*/false));
242 }
243 if (it != retval_shardings.end()) {
244 // Apply the sharding to the output, if there is a core assignment.
245 value = identity_op(value, sharding);
246 }
247
248 elems.push_back(value);
249 break;
250 }
251
252 case XlaExpression::Kind::kResource:
253 // Resources will be pushed into elems later when processing resource
254 // arguments below.
255 output.is_constant = false;
256 output.input_index = retval.resource()->arg_num();
257 output.shape = retval.resource()->shape();
258 break;
259
260 case XlaExpression::Kind::kInvalid:
261 return errors::InvalidArgument(
262 "Invalid expression returned by computation. "
263 "This probably means a return value was not set.");
264 }
265 }
266 *num_nonconst_outputs = elems.size();
267
268 // Add return values for resources whose values have changed.
269 std::vector<const XlaResource*> arg_resources;
270 arg_resources.reserve(resources.size());
271 for (const auto& resource : resources) {
272 if (resource->arg_num() >= 0) {
273 arg_resources.push_back(resource.get());
274 }
275 }
276 std::sort(arg_resources.begin(), arg_resources.end(),
277 [](const XlaResource* a, const XlaResource* b) {
278 return a->arg_num() < b->arg_num();
279 });
280
281 absl::flat_hash_map<int, int> argument_to_xla_arg;
282 for (int xla_arg = 0; xla_arg < input_mapping.size(); xla_arg++) {
283 argument_to_xla_arg[input_mapping[xla_arg]] = xla_arg;
284 }
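  // `input_mapping[i]` is the index of the TF argument that became XLA
  // parameter i (compile-time constants are elided), so the inverted map above
  // translates an argument index back to its XLA parameter position.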
285
286 std::vector<xla::XlaBuilder::InputOutputAlias> aliases;
287 for (const XlaResource* resource : arg_resources) {
288 DCHECK_LT(resource->arg_num(), args.size());
289 const XlaCompiler::Argument& arg = args[resource->arg_num()];
290 auto it = arg_shardings.find(resource->arg_num());
291 bool modified = !resource->value().IsIdenticalTo(resource->initial_value());
292 // TensorArray gradients were modified if their values changed or there are
293 // any newly created gradients.
294 for (const auto& grad : resource->tensor_array_gradients()) {
295 modified =
296 modified ||
297 !grad.second->value().IsIdenticalTo(grad.second->initial_value()) ||
298 arg.tensor_array_gradients.count(grad.first) == 0;
299 }
300
301 if (return_updated_values_for_all_resources || modified ||
302 arg.requires_broadcast) {
303 resource_updates->emplace_back();
304 XlaCompiler::ResourceUpdate& update = resource_updates->back();
305 update.input_index = resource->arg_num();
306 update.type = resource->type();
307 update.shape = resource->shape();
308 update.modified = modified;
309 int param_num = use_tuple_arg ? 0 : update.input_index;
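      // Where possible, alias the updated resource value to the XLA parameter
      // that carried it in, so the backend may reuse (donate) the input buffer
      // for the output instead of materializing a separate copy.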
310 if (is_entry_computation &&
311 arg.resource_kind != XlaResource::kTensorArray &&
312 alias_resource_update && argument_to_xla_arg.count(param_num)) {
313 // Assuming tuple arg and results are used.
314 xla::ShapeIndex param_index =
315 use_tuple_arg ? xla::ShapeIndex({update.input_index})
316 : xla::ShapeIndex{};
317 int xla_param_num = argument_to_xla_arg[param_num];
318 int64_t output_index_num = elems.size();
319 xla::ShapeIndex output_index = xla::ShapeIndex({output_index_num});
320 VLOG(3) << "Storing alias: " << output_index.ToString() << ": ("
321 << xla_param_num << ", " << param_index.ToString() << ")";
322 aliases.push_back({output_index, xla_param_num, param_index});
323 }
324 for (const auto& grad : resource->tensor_array_gradients()) {
325 update.tensor_array_gradients_accessed.insert(grad.first);
326 }
327
328 xla::XlaOp handle;
329 TF_RETURN_IF_ERROR(resource->Pack(&handle, builder));
330 auto sharding = it == arg_shardings.end()
331 ? std::optional<xla::OpSharding>()
332 : it->second;
333 // Set layout of the retval to device representation layout.
334 if (shape_determination_fns.layout_preference_fn &&
335 shape_determination_fns.shape_representation_fn) {
336 TF_ASSIGN_OR_RETURN(auto original_shape, builder->GetShape(handle));
337 TF_ASSIGN_OR_RETURN(
338 handle, ReshapeWithCorrectRepresentationAndSharding(
339 builder, handle, original_shape,
340 shape_determination_fns, sharding, arg.fast_mem));
341 }
342
343 // Request that the value be returned on a specific core.
344 if (it != arg_shardings.end()) {
345 retval_index_and_sharding[elems.size()] = it->second;
346 }
347 // Ensures the correct sharding is applied to the output.
348 handle = identity_op(handle, sharding);
349 elems.push_back(handle);
350 }
351 }
352
353 // If we have token output, append it as the last one.
354 if (token_output) {
355 elems.push_back(*token_output);
356 }
357
358 *num_computation_outputs = elems.size();
359
360 // Builds the XLA computation. We *always* form a tuple here to ensure that
361 // the output value is the last thing added into the XLA computation, even
362 // if there is only one output value.
363 xla::XlaOp tuple;
364 if (retval_index_and_sharding.empty() || !is_entry_computation) {
365 tuple = xla::Tuple(builder, elems);
366 } else {
367 std::vector<xla::Shape> elem_shapes;
368 for (const auto& elem : elems) {
369 TF_ASSIGN_OR_RETURN(xla::Shape elem_shape,
370 elem.builder()->GetShape(elem));
371 elem_shapes.push_back(elem_shape);
372 }
373 xla::Shape shape = xla::ShapeUtil::MakeTupleShape(elem_shapes);
374 // Copy specified sharding from retval_index_and_sharding.
375 std::vector<xla::HloSharding> sharding_elems;
376 for (int i = 0, end = elems.size(); i < end; i++) {
377 const auto& iter = retval_index_and_sharding.find(i);
378 TF_RET_CHECK(iter != retval_index_and_sharding.end());
379 const xla::OpSharding& sub_op_sharding = iter->second;
380 TF_ASSIGN_OR_RETURN(xla::HloSharding sub_sharding,
381 xla::HloSharding::FromProto(sub_op_sharding));
382 if (elem_shapes[i].IsTuple()) {
383 const std::vector<xla::HloSharding> sub_sharding_elems =
384 sub_sharding.tuple_elements();
385 const int64_t sub_sharding_elems_size = sub_sharding_elems.size();
386 TF_RET_CHECK(sub_sharding_elems_size ==
387 xla::ShapeUtil::GetLeafCount(elem_shapes[i]));
388 for (const auto& sub_sharding_elem : sub_sharding_elems) {
389 sharding_elems.push_back(sub_sharding_elem);
390 }
391 } else {
392 sharding_elems.push_back(sub_sharding);
393 }
394 }
395 xla::HloSharding modified_sharding =
396 xla::HloSharding::Tuple(shape, sharding_elems);
397 xla::OpSharding op_sharding = modified_sharding.ToProto();
398 // Assign proper sharding to the tuple instruction.
399 xla::XlaScopedShardingAssignment assign_sharding(builder, op_sharding);
400 tuple = xla::Tuple(builder, elems);
401 }
402 bool returns_tuple = always_return_tuple || elems.size() != 1;
403 VLOG(3) << "Computation returns a tuple=" << returns_tuple;
404 if (!returns_tuple) {
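    // XlaBuilder treats the most recently added instruction as the computation
    // root (unless a root is passed to Build), so emitting this otherwise
    // unused GetTupleElement makes the single element, rather than the
    // one-element tuple, the value returned by the computation.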
405 xla::GetTupleElement(tuple, 0);
406
407 for (xla::XlaBuilder::InputOutputAlias& alias : aliases) {
408 if (alias.output_index == xla::ShapeIndex({0})) {
409 VLOG(3) << "For aliased parameter " << alias.param_number << ": "
410 << alias.param_index.ToString()
411 << " normalizing output_index from {0} to {}, as a scalar is "
412 "returned from the cluster";
413 alias.output_index = xla::ShapeIndex({});
414 }
415 }
416 }
417
418 for (xla::XlaBuilder::InputOutputAlias& alias : aliases) {
419 builder->SetUpAlias(alias.output_index, alias.param_number,
420 alias.param_index);
421 }
422 TF_ASSIGN_OR_RETURN(*computation, builder->Build());
423
424 TF_ASSIGN_OR_RETURN(auto program_shape, computation->GetProgramShape());
425 *output_shape = program_shape.result();
426 return OkStatus();
427 }
428
429 } // namespace
430
431
432 string XlaCompiler::Argument::HumanString() const {
433 string common;
434 if (!name.empty()) {
435 common = absl::StrCat(" name=", name);
436 }
437 absl::StrAppend(&common, " type=", DataTypeString(type),
438 " shape=", ShapeHumanString());
439 absl::StrAppend(
440 &common, " is_same_data_across_replicas=", is_same_data_across_replicas);
441 switch (kind) {
442 case kInvalid:
443 return "invalid";
444 case kConstant:
445 return absl::StrCat("kind=constant", common,
446 " value=", constant_value.DebugString());
447 case kConstantResource:
448 return absl::StrCat("kind=constant-resource", common,
449 " value=", constant_value.DebugString());
450 case kResource: {
451 string output = absl::StrCat(
452 "kind=resource", common,
453 " resource_kind=", XlaResource::KindToString(resource_kind),
454 " initialized=", initialized, " is_fast_mem=", fast_mem);
455 if (max_array_size >= 0) {
456 absl::StrAppend(&output, " max_array_size=", max_array_size);
457 }
458 if (!tensor_array_gradients.empty()) {
459 absl::StrAppend(&output, " tensor_array_gradients=",
460 absl::StrJoin(tensor_array_gradients, ","));
461 }
462 return output;
463 }
464 case kParameter:
465 return absl::StrCat("kind=parameter", common);
466 case kTensorList:
467 return absl::StrCat("kind=tensorlist", common);
468 case kToken:
469 return absl::StrCat("token", common);
470 }
471 }
472
473 std::vector<int64_t> XlaCompiler::Argument::DimensionSizes() const {
474 if (absl::holds_alternative<TensorShape>(shape)) {
475 return xla::InlinedVectorToVector(std::get<TensorShape>(shape).dim_sizes());
476 } else {
477 return xla::SpanToVector(std::get<xla::Shape>(shape).dimensions());
478 }
479 }
480
481 absl::InlinedVector<int64_t, 4>
482 XlaCompiler::Argument::DimensionSizesAsInlinedVector() const {
483 if (absl::holds_alternative<TensorShape>(shape)) {
484 return std::get<TensorShape>(shape).dim_sizes();
485 } else {
486 auto v = std::get<xla::Shape>(shape).dimensions();
487 return absl::InlinedVector<int64_t, 4>(v.begin(), v.end());
488 }
489 }
490
491 string XlaCompiler::Argument::ShapeHumanString() const {
492 if (absl::holds_alternative<TensorShape>(shape)) {
493 return std::get<TensorShape>(shape).DebugString();
494 } else {
495 return std::get<xla::Shape>(shape).DebugString();
496 }
497 }
498
499 XlaCompiler::XlaCompiler(XlaCompiler::Options options)
500 : options_(options),
501 initialization_status_(OkStatus()),
502 next_step_id_(1),
503 device_(new XlaCompilationDevice(SessionOptions(), options_.device_type)),
504 device_mgr_(absl::WrapUnique(device_)) {
505 CHECK(!options_.device_type.type_string().empty());
506 if (options_.populate_resource_manager) {
507 initialization_status_ =
508 (*options_.populate_resource_manager)(device_->resource_manager());
509 }
510
511 local_flib_def_.reset(new FunctionLibraryDefinition(OpRegistry::Global(),
512 FunctionDefLibrary{}));
513 local_pflr_.reset(new ProcessFunctionLibraryRuntime(
514 &device_mgr_, Env::Default(), /*config=*/nullptr,
515 options.graph_def_version, local_flib_def_.get(), OptimizerOptions()));
516 pflr_.reset(new ProcessFunctionLibraryRuntime(
517 &device_mgr_, Env::Default(), /*config=*/nullptr,
518 options.graph_def_version, options.flib_def, OptimizerOptions()));
519
520 local_flib_runtime_ = local_pflr_->GetFLR(device_->name());
521 flib_runtime_ = pflr_->GetFLR(device_->name());
522
523 // The default layout preference is no preference and the default shape
524 // representation function is the identity.
525 XlaShapeLayoutHelpers::ShapeDeterminationFns& shape_determination_fns =
526 options_.shape_determination_fns;
527 if (!shape_determination_fns.shape_representation_fn) {
528 shape_determination_fns.shape_representation_fn =
529 IdentityShapeRepresentationFn();
530 }
531 if (!shape_determination_fns.layout_preference_fn) {
532 shape_determination_fns.layout_preference_fn = UseNoPreferenceLayoutFn();
533 }
534 }
535
536 XlaCompiler::~XlaCompiler() = default;
537
538 int64_t XlaCompiler::NextStepId() { return next_step_id_++; }
539
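// Hashes only the canonicalized function name; the compilation cache's key
// equality over the full (name, argument list) pair disambiguates collisions.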
540 uint64 XlaCompiler::SignatureHash::operator()(
541 const std::pair<string, std::vector<Argument>>& signature) const {
542 return std::hash<string>()(signature.first);
543 }
544
545 static Status GetFunctionBody(const NameAttrList& function,
546 FunctionLibraryRuntime* flib_runtime,
547 const FunctionBody** fbody) {
548 FunctionLibraryRuntime::Handle handle;
549 TF_RETURN_IF_ERROR(flib_runtime->Instantiate(
550 function.name(), AttrSlice(&function.attr()), &handle));
551
552 *fbody = flib_runtime->GetFunctionBody(handle);
553 TF_RET_CHECK(*fbody);
554 return OkStatus();
555 }
556
557 Status XlaCompiler::FindFunctionBody(const NameAttrList& function,
558 const FunctionBody** fbody,
559 const ConfigProto** config_proto) {
560 // The function may be in either the local_flib_runtime_ or flib_runtime_.
561   // Look it up in local_flib_runtime_ first; if it is not found there, fall
562   // back to flib_runtime_.
563 auto status = GetFunctionBody(function, local_flib_runtime_, fbody);
564 if (!status.ok()) {
565 if (!errors::IsNotFound(status)) {
566 return status;
567 }
568 TF_RETURN_WITH_CONTEXT_IF_ERROR(
569 GetFunctionBody(function, flib_runtime_, fbody),
570 "Local lookup failed with: ", status.error_message());
571 if (config_proto) {
572 *config_proto = flib_runtime_->config_proto();
573 }
574 VLOG(4) << "Function " << function.name() << " in flib_runtime_";
575 } else {
576 if (config_proto) {
577 *config_proto = local_flib_runtime_->config_proto();
578 }
579 VLOG(4) << "Function " << function.name() << " in local_flib_runtime_";
580 }
581 return OkStatus();
582 }
583
584 std::unique_ptr<Graph> XlaCompiler::GetGraph(const FunctionBody* fbody) {
585 std::unique_ptr<Graph> graph(new Graph(options_.flib_def));
586 CopyGraph(*fbody->graph, graph.get());
587
588 bool is_inside_mustcompile = false;
589 TryGetNodeAttr(AttrSlice(&fbody->fdef.attr()), kXlaMustCompileAttr,
590 &is_inside_mustcompile);
591
592 // Performs a first function inlining pass before shape inference, since
593 // otherwise shape inference can't see inside functions and a comprehensive
594 // shape_map, including function ops, is needed to constant-propagate Shape
595 // Ops below.
596 auto flags = GetBuildXlaOpsPassFlags();
597 OptimizerOptions opts;
598 opts.set_opt_level(OptimizerOptions::L0);
599 opts.set_do_common_subexpression_elimination(false);
600 opts.set_do_function_inlining(true);
601 opts.set_do_constant_folding(!flags->tf_xla_disable_constant_folding);
602 GraphOptimizer optimizer(opts);
603 // Do not constant fold nodes that output DT_VARIANT type tensors.
604 // XLA does not support Const nodes of Variant type since it needs
605 // to know the original ops to be able to compile them to the relevant
606 // XLA form.
607 // TODO(srbs): This filter is a little conservative. E.g. a subgraph of
608 // the form:
609 // Const
610 // |
611 // EmptyTensorList -> TensorListPushBack -> TensorListPopBack -> Op
612 // |
613 // (Discard popped list)
614 //
615 // Would have been reduced to "Const -> Op" without this filter.
616 // However since we are only allowed to specify the filter at the "Node"
617 // level there is no good way to allow the above behavior. So we
618 // disallow any sort of constant folding on Variant nodes for now.
619 //
620   // Also do not constant fold Shape ops. When a tensor has a dynamic
621   // dimension, TF2XLA currently represents it with its static upper-bound
622   // shape; constant folding that shape would lose the information that the
623   // dimension is actually dynamic.
624 auto cf_consider_fn = [](const Node* n) {
625 for (const auto& output_arg : n->op_def().output_arg()) {
626 if (output_arg.type() == DT_VARIANT) {
627 return false;
628 }
629 }
630 const auto& ts = n->type_string();
631 // XLA has special logic to handle dynamic shapes, don't constant fold
632 // them.
633 if (ts == "Shape" || ts == "ShapeN" || ts == "Size") {
634 return false;
635 }
636 return true;
637 };
638 GraphOptimizer::Options graph_optimizer_options;
639 graph_optimizer_options.cf_consider_fn = cf_consider_fn;
640 graph_optimizer_options.inline_multi_device_functions = true;
641 graph_optimizer_options.inline_impl_selection_group_functions = true;
642 graph_optimizer_options.inline_with_single_device_body_placer = true;
643 graph_optimizer_options.ignore_noinline = is_inside_mustcompile;
644
645 {
646 GraphShapeInfo shape_info;
647 InferShapes(graph.get(), /*arg_shapes=*/{},
648 flib_runtime_->GetFunctionLibraryDefinition(), &shape_info)
649 .IgnoreError();
650 auto node_name_index = graph->BuildNodeNameIndex();
651 std::unordered_map<string, std::vector<PartialTensorShape>> shape_map;
652 for (const auto& node_shape_info : shape_info) {
653 const string& node_name = node_shape_info.first;
654 const std::vector<InferredShape>& output_shapes = node_shape_info.second;
655 const auto& node_iter = node_name_index.find(node_name);
656 if (node_iter != node_name_index.end()) {
657 auto& partial_shapes = shape_map[node_name];
658 for (const auto& inferred_shape : output_shapes) {
659 partial_shapes.push_back(inferred_shape.shape);
660 }
661 }
662 }
663 graph_optimizer_options.shape_map = &shape_map;
664 optimizer.Optimize(flib_runtime_, flib_runtime_->env(),
665 /*device=*/nullptr, &graph, graph_optimizer_options);
666 }
667
668 // Run shape inference on the graph and optimize the graph again.
669 GraphShapeInfo shape_info;
670 InferShapes(graph.get(), /*arg_shapes=*/{},
671 flib_runtime_->GetFunctionLibraryDefinition(), &shape_info)
672 .IgnoreError();
673 auto node_name_index = graph->BuildNodeNameIndex();
674 std::unordered_map<string, std::vector<PartialTensorShape>> shape_map;
675 for (const auto& node_shape_info : shape_info) {
676 const string& node_name = node_shape_info.first;
677 const std::vector<InferredShape>& output_shapes = node_shape_info.second;
678 const auto& node_iter = node_name_index.find(node_name);
679 if (node_iter != node_name_index.end()) {
680 auto& partial_shapes = shape_map[node_name];
681 for (const auto& inferred_shape : output_shapes) {
682 partial_shapes.push_back(inferred_shape.shape);
683 }
684 }
685 }
686 graph_optimizer_options.shape_map = &shape_map;
687 optimizer.Optimize(flib_runtime_, flib_runtime_->env(),
688 /*device=*/nullptr, &graph, graph_optimizer_options);
689
690 return graph;
691 }
692
693 // Collects all control rets from `orig_control_ret_nodes` that are still valid,
694 // keeping the same order.
695 std::vector<std::string> GetValidControlRets(
696 absl::Span<Node* const> orig_control_ret_nodes, const Graph& graph) {
697 // Build map from control ret node name to index.
698 // We use Node name instead of Node* here to index into the map as we populate
699 // the map with nodes in FunctionDef control_ret_nodes and later query it
700 // using the nodes in `graph`. The Node pointers would be different but the
701 // Node name is expected to remain the same between the two.
702 absl::flat_hash_map<const string, int> control_ret_nodes_map;
703 for (int i = 0; i < orig_control_ret_nodes.size(); ++i) {
704 const Node* n = orig_control_ret_nodes[i];
705 control_ret_nodes_map[n->name()] = i;
706 }
707 // Check which control rets are still valid.
708 std::vector<bool> is_valid_control_ret(orig_control_ret_nodes.size(), false);
709 int num_valid_control_rets = 0;
710 for (const Node* n : graph.nodes()) {
711 auto iter = control_ret_nodes_map.find(n->name());
712 if (iter != control_ret_nodes_map.end()) {
713 ++num_valid_control_rets;
714 is_valid_control_ret[iter->second] = true;
715 }
716 }
717 // Return valid control rets in same order as they appear in
718 // `orig_control_ret_nodes`.
719 std::vector<std::string> valid_control_rets;
720 valid_control_rets.reserve(num_valid_control_rets);
721 for (int i = 0; i < orig_control_ret_nodes.size(); ++i) {
722 if (is_valid_control_ret[i]) {
723 valid_control_rets.push_back(orig_control_ret_nodes[i]->name());
724 }
725 }
726 return valid_control_rets;
727 }
728
729 Status XlaCompiler::CompileFunction(
730 const XlaCompiler::CompileOptions& options,
731 const NameAttrList& fn_name_attrs,
732 absl::Span<const XlaCompiler::Argument> args,
733 XlaCompiler::CompilationResult* result) {
734 const string function_id =
735 Canonicalize(fn_name_attrs.name(), AttrSlice(&fn_name_attrs.attr()));
736 VLOG(1) << "XlaCompiler::CompileFunction " << function_id;
737
738 const std::vector<XlaCompiler::Argument> arg_vector(args.begin(), args.end());
739 auto it = cache_.find({function_id, arg_vector});
740 if (it != cache_.end()) {
741 *result = it->second;
742 return OkStatus();
743 }
744
745 const FunctionBody* fbody;
746 const ConfigProto* config = nullptr;
747 TF_RETURN_IF_ERROR(FindFunctionBody(fn_name_attrs, &fbody, &config));
748
749 std::optional<ConfigProto> config_proto;
750 if (config) {
751 config_proto = *config;
752 }
753
754 TF_RETURN_WITH_CONTEXT_IF_ERROR(
755 CheckSignature(fbody->arg_types, args),
756 "Signature check failure while compiling: ", fn_name_attrs.name());
757
758 // Set shapes for _Arg nodes. They are useful for constant folding (e.g. an
759   // Xla op requires a compile-time constant input, and that input is the shape
760   // of an _Arg node).
761 for (int i = 0, end = args.size(); i < end; i++) {
762 // Skip resource variables and tensor lists.
763 DataType dtype;
764 TF_RETURN_IF_ERROR(GetNodeAttr(fbody->arg_nodes[i]->def(), "T", &dtype));
765 if (dtype == DT_RESOURCE || dtype == DT_VARIANT) {
766 continue;
767 }
768
769 if (absl::holds_alternative<xla::Shape>(args[i].shape)) {
770 xla::Shape xla_shape = std::get<xla::Shape>(args[i].shape);
771 TensorShape tensor_shape;
772 // If xla_shape is dynamic, prevent constant folding by not setting
773 // output_shapes.
774 if (XLAShapeToTensorShape(xla_shape, &tensor_shape).ok() &&
775 xla_shape.is_static()) {
776 fbody->arg_nodes[i]->ClearAttr("_output_shapes");
777 fbody->arg_nodes[i]->AddAttr("_output_shapes",
778 std::vector<TensorShape>{tensor_shape});
779 }
780 } else {
781 TensorShape tensor_shape = std::get<TensorShape>(args[i].shape);
782 fbody->arg_nodes[i]->ClearAttr("_output_shapes");
783 fbody->arg_nodes[i]->AddAttr("_output_shapes",
784 std::vector<TensorShape>{tensor_shape});
785 }
786 }
787
788 std::unique_ptr<Graph> graph = GetGraph(fbody);
789
790 // _Arg and _Retval nodes don't exist in the stored subgraph for the function;
791   // they are added when the function body is instantiated. Therefore, they
792   // don't have core assignments here.
793 // Attempt to assign a core to each _Retval and _Arg. Chooses the
794 // lowest-numbered core that consumes the argument. We choose the
795 // lowest-numbered core so the assignment is deterministic.
796 for (Node* n : graph->nodes()) {
797 if (n->IsArg()) {
798 TF_RETURN_IF_ERROR(SetNodeShardingFromNeighbors(n, /*out_edges=*/true));
799 }
800 }
801 // Do _Retval as a second loop, in case the retval's input is an _Arg (which
802 // may have gotten a device assignment from the first loop).
803 for (Node* n : graph->nodes()) {
804 if (n->IsRetval()) {
805 TF_RETURN_IF_ERROR(SetNodeShardingFromNeighbors(n, /*out_edges=*/false));
806 }
807 }
808
809 if (VLOG_IS_ON(2)) {
810 VLOG(2) << "XlaCompiler::CompileFunction: "
811 << DumpGraphToFile(
812 absl::StrCat("xla_compile_function_", function_id), *graph);
813 }
814
815 VLOG(1) << "====================================================";
816
817 auto state = ConfigProto::Experimental::MLIR_BRIDGE_ROLLOUT_DISABLED;
818 if (options.is_entry_computation) {
819 state = GetMlirBridgeRolloutState(config_proto);
820 }
821
822 if (state == ConfigProto::Experimental::MLIR_BRIDGE_ROLLOUT_ENABLED) {
823 GraphDebugInfo debug_info;
824 VLOG(1) << "Using the MLIR bridge to compile the function.";
825 std::vector<std::string> valid_control_rets =
826 GetValidControlRets(fbody->control_ret_nodes, *graph);
827 auto mlir_result = CompileGraphToXlaHlo(
828 std::move(*graph), mlir::SpanToArrayRef<XlaCompiler::Argument>(args),
829 valid_control_rets, options_.device_type.type_string(),
830 options.use_tuple_arg, /*analyse_graph=*/false, *options_.flib_def,
831 debug_info, options_.shape_determination_fns, result);
832 if (mlir_result.ok()) {
833       VLOG(1) << "MLIR bridge was successful";
834 } else {
835 VLOG(1) << "MLIR failed, no fallback";
836 return mlir_result;
837 }
838 } else {
839 VLOG(1) << "MLIR bridge off. Using the old bridge to compile the function";
840 TF_RETURN_IF_ERROR(
841 CompileGraph(options, function_id, std::move(graph), args, result));
842 }
843 VLOG(1) << "====================================================";
844
845 cache_[{function_id, arg_vector}] = *result;
846 return OkStatus();
847 }
848
849 // Computes the XLA shape for argument 'arg'.
850 Status XlaCompiler::XLAShapeForArgument(
851 const XlaCompiler::Argument& arg, bool is_entry_computation,
852 const std::optional<xla::HloSharding>& arg_sharding,
853 xla::Shape* xla_shape) const {
854 switch (arg.kind) {
855 case XlaCompiler::Argument::kConstant:
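      // Constant arguments are materialized directly into the computation
      // rather than passed as runtime parameters, so no XLA shape should ever
      // be requested for them.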
856 LOG(FATAL) << "Unreachable case";
857 case XlaCompiler::Argument::kParameter: {
858 if (is_entry_computation) {
859 TensorShape shape;
860 if (absl::holds_alternative<TensorShape>(arg.shape)) {
861 shape = std::get<TensorShape>(arg.shape);
862 } else {
863 TF_RETURN_IF_ERROR(
864 XLAShapeToTensorShape(std::get<xla::Shape>(arg.shape), &shape));
865 }
866 auto layout_preference =
867 options_.shape_determination_fns.layout_preference_fn(
868 shape, arg.type, arg.kind);
869 TF_ASSIGN_OR_RETURN(
870 *xla_shape,
871 options_.shape_determination_fns.shape_representation_fn(
872 shape, arg.type,
873 /*use_fast_memory=*/false, layout_preference));
874 TF_RETURN_IF_ERROR(RewriteLayoutWithShardedShape(
875 arg_sharding, /*use_fast_memory=*/false,
876 options_.shape_determination_fns, xla_shape));
877 } else {
878 if (absl::holds_alternative<xla::Shape>(arg.shape)) {
879 *xla_shape = std::get<xla::Shape>(arg.shape);
880 } else {
881 TF_RETURN_IF_ERROR(TensorShapeToXLAShape(
882 arg.type, std::get<TensorShape>(arg.shape), xla_shape));
883 }
884 }
885 return OkStatus();
886 }
887 case XlaCompiler::Argument::kTensorList: {
888 TF_RET_CHECK(absl::holds_alternative<xla::Shape>(arg.shape));
889 *xla_shape = std::get<xla::Shape>(arg.shape);
890 return OkStatus();
891 }
892 case XlaCompiler::Argument::kConstantResource:
893 case XlaCompiler::Argument::kResource: {
894 TF_RET_CHECK(arg.initialized);
895
896 switch (arg.resource_kind) {
897 case XlaResource::kVariable: {
898 TF_RET_CHECK(absl::holds_alternative<TensorShape>(arg.shape));
899 auto layout_preference =
900 options_.shape_determination_fns.layout_preference_fn(
901 std::get<TensorShape>(arg.shape), arg.type, arg.kind);
902 TF_ASSIGN_OR_RETURN(
903 *xla_shape,
904 options_.shape_determination_fns.shape_representation_fn(
905 std::get<TensorShape>(arg.shape), arg.type,
906 /*use_fast_memory=*/arg.fast_mem, layout_preference));
907 TF_RETURN_IF_ERROR(RewriteLayoutWithShardedShape(
908 arg_sharding, arg.fast_mem, options_.shape_determination_fns,
909 xla_shape));
910 return OkStatus();
911 }
912 case XlaResource::kTensorArray: {
913 if (arg.max_array_size < 0) {
914 return errors::InvalidArgument(
915 "Negative max_array_size in XLAShapeForArgument");
916 }
917 TF_RET_CHECK(absl::holds_alternative<TensorShape>(arg.shape));
918 TensorShape shape;
919 shape.AddDim(arg.max_array_size);
920 shape.AppendShape(std::get<TensorShape>(arg.shape));
921 TF_RETURN_IF_ERROR(TensorShapeToXLAShape(arg.type, shape, xla_shape));
922
923 if (!arg.tensor_array_gradients.empty()) {
924 std::vector<xla::Shape> tuple_shape(
925 arg.tensor_array_gradients.size() + 1, *xla_shape);
926 *xla_shape = xla::ShapeUtil::MakeTupleShape(tuple_shape);
927 }
928 return OkStatus();
929 }
930 case XlaResource::kStack: {
931 if (arg.max_array_size < 0) {
932 return errors::InvalidArgument(
933 "Negative max_array_size in XLAShapeForArgument");
934 }
935 TF_RET_CHECK(absl::holds_alternative<TensorShape>(arg.shape));
936 TensorShape shape;
937 shape.AddDim(arg.max_array_size);
938 shape.AppendShape(std::get<TensorShape>(arg.shape));
939 xla::Shape buffer_shape;
940 TF_RETURN_IF_ERROR(
941 TensorShapeToXLAShape(arg.type, shape, &buffer_shape));
942 *xla_shape = xla::ShapeUtil::MakeTupleShape(
943 {buffer_shape, xla::ShapeUtil::MakeShape(xla::S32, {})});
944 return OkStatus();
945 }
946
947 case XlaResource::kInvalid:
948 return errors::Internal(
949 "Invalid resource type in XLAShapeForArgument()");
950 }
951 }
952 case XlaCompiler::Argument::kToken: {
953 *xla_shape = xla::ShapeUtil::MakeTokenShape();
954 return OkStatus();
955 }
956 case XlaCompiler::Argument::kInvalid:
957 return errors::Internal("Invalid argument type in XLAShapeForArgument()");
958 }
959 }
960
961 /* static */
962 void XlaCompiler::PopulateArgumentFromResource(const XlaResource& resource,
963 Argument* arg) {
964 arg->initialized = resource.initialized();
965 arg->kind = XlaCompiler::Argument::kResource;
966 arg->resource_kind = resource.kind();
967
968 arg->type = resource.type();
969 arg->shape = resource.shape();
970 arg->max_array_size = resource.max_array_size();
971 for (const auto& gradient : resource.tensor_array_gradients()) {
972 arg->tensor_array_gradients.insert(gradient.first);
973 }
974 arg->name = resource.name();
975 }
976
977 // Builds XLA computations for each of the arguments to the computation.
978 // `args` are the arguments to the computation.
979 Status XlaCompiler::BuildArguments(
980 const Graph& graph, const std::vector<XlaCompiler::Argument>& args,
981 bool use_tuple_arg, xla::XlaBuilder* builder, XlaContext* context,
982 const std::map<int, xla::OpSharding>& arg_shardings,
983 std::vector<XlaExpression>* arg_expressions,
984 std::vector<int>* input_to_args, std::vector<xla::Shape>* input_shapes,
985 bool is_entry_computation) {
986 arg_expressions->resize(args.size());
987
988 // Argument numbers of arguments and resources that are to be passed to the
989 // XLA computation as runtime parameters. `input_to_args[a] = b` means that
990   // the a'th XLA input corresponds to the b'th original arg index.
991 input_to_args->clear();
992 input_to_args->reserve(args.size());
993
994 // Fills in constant arguments, and computes non-constant argument order.
995 for (std::vector<XlaCompiler::Argument>::size_type i = 0; i < args.size();
996 ++i) {
997 const XlaCompiler::Argument& arg = args[i];
998 XlaExpression& arg_expression = (*arg_expressions)[i];
999 switch (arg.kind) {
1000 case XlaCompiler::Argument::kConstantResource:
1001 case XlaCompiler::Argument::kResource: {
1002 TF_RET_CHECK(arg.resource_kind != XlaResource::kInvalid);
1003 TF_RET_CHECK(absl::holds_alternative<TensorShape>(arg.shape));
1004 // TODO(phawkins): this code assumes that resource arguments do not
1005 // alias.
1006 XlaResource* resource =
1007 context->AddResource(std::make_unique<XlaResource>(
1008 arg.resource_kind, i, arg.name, arg.type,
1009 std::get<TensorShape>(arg.shape), xla::XlaOp(),
1010 /*max_array_size=*/arg.max_array_size,
1011 /*tensor_array_gradients=*/arg.tensor_array_gradients,
1012 /*tensor_array_multiple_writes_aggregate=*/true,
1013 arg.definition_stack_trace));
1014 arg_expression =
1015 arg.kind == XlaCompiler::Argument::kResource
1016 ? XlaExpression::Resource(resource)
1017 : XlaExpression::ConstantResource(arg.constant_value, resource);
1018 if (arg.initialized) {
1019 input_to_args->push_back(i);
1020 }
1021 break;
1022 }
1023 case XlaCompiler::Argument::kParameter:
1024 case XlaCompiler::Argument::kTensorList:
1025 case XlaCompiler::Argument::kToken: {
1026 input_to_args->push_back(i);
1027 break;
1028 }
1029 case XlaCompiler::Argument::kConstant:
1030 arg_expression = XlaExpression::Constant(arg.constant_value);
1031 break;
1032 case XlaCompiler::Argument::kInvalid:
1033 return errors::Internal(
1034 "Unreachable case in BuildArguments() while filling constant args");
1035 }
1036 }
1037
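  // Nothing more to build if the computation takes no runtime parameters and
  // no (possibly empty) tuple argument needs to be materialized.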
1038 if (input_to_args->empty() && !use_tuple_arg) {
1039 return OkStatus();
1040 }
1041
1042 // `arg_to_inputs[c] = d` means that the c'th original arg index corresponds
1043 // to the d'th XLA input. Note that the value -1 corresponds to constants, or
1044 // other args that don't correspond to an input.
1045 std::vector<int> arg_to_inputs(args.size(), -1);
1046 for (int i = 0, end = input_to_args->size(); i < end; i++) {
1047 arg_to_inputs[input_to_args->at(i)] = i;
1048 }
1049
1050 std::vector<xla::Shape> arg_shapes(input_to_args->size());
1051 for (std::vector<int>::size_type i = 0; i < input_to_args->size(); ++i) {
1052 // Computes the shapes of non-constant arguments.
1053 auto arg_sharding = arg_shardings.find((*input_to_args)[i]);
1054 std::optional<xla::HloSharding> sharding;
1055 if (arg_sharding != arg_shardings.end()) {
1056 TF_ASSIGN_OR_RETURN(auto hlo_sharding,
1057 xla::HloSharding::FromProto(arg_sharding->second));
1058 sharding = hlo_sharding;
1059 }
1060 TF_RETURN_IF_ERROR(XLAShapeForArgument(args[(*input_to_args)[i]],
1061 is_entry_computation, sharding,
1062 &arg_shapes[i]));
1063 }
1064
1065 if (use_tuple_arg) {
1066 input_shapes->push_back(xla::ShapeUtil::MakeTupleShape(arg_shapes));
1067 } else {
1068 *input_shapes = arg_shapes;
1069 }
1070
1071 // Attach a common operator name as metadata. This has no semantic effect — it
1072 // merely makes the HLO graph more readable when visualized via TensorBoard,
1073 // since TensorBoard forms groups out of operators with similar names.
1074 xla::OpMetadata arg_metadata;
1075 arg_metadata.set_op_name("XLA_Args");
1076 builder->SetOpMetadata(arg_metadata);
1077
1078 // Build parameter handles for non-constant arguments.
1079 std::vector<xla::XlaOp> arg_handles(input_to_args->size());
1080 if (use_tuple_arg) {
1081 xla::XlaOp tuple;
1082 if (is_entry_computation) {
1083 xla::OpSharding tuple_sharding;
1084 tuple_sharding.set_type(xla::OpSharding::TUPLE);
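      // Every tuple element needs a sharding entry; arguments without an
      // explicit sharding default to maximal (single-device) sharding on
      // device 0 below.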
1085 for (int64_t parameter : *input_to_args) {
1086 auto it = arg_shardings.find(parameter);
1087 *tuple_sharding.add_tuple_shardings() =
1088 it == arg_shardings.end() ? xla::sharding_builder::AssignDevice(0)
1089 : it->second;
1090 }
1091 std::vector<bool> is_same_across_replicas;
1092 for (int i = 0, end = input_to_args->size(); i < end; ++i) {
1093 // Add an entry to is_same_across_replicas for every leaf buffer.
1094 is_same_across_replicas.insert(
1095 is_same_across_replicas.end(),
1096 xla::ShapeUtil::GetLeafCount(arg_shapes[i]),
1097 args[input_to_args->at(i)].is_same_data_across_replicas);
1098 }
1099 xla::XlaScopedShardingAssignment assign_tuple_sharding(
1100 builder, input_to_args->empty() ? std::optional<xla::OpSharding>()
1101 : tuple_sharding);
1102 tuple = xla::Parameter(builder, 0, (*input_shapes)[0], "arg_tuple",
1103 is_same_across_replicas);
1104 } else {
1105 tuple = xla::Parameter(builder, 0, (*input_shapes)[0], "arg_tuple");
1106 }
1107
1108 for (std::vector<int>::size_type i = 0; i < input_to_args->size(); ++i) {
1109 auto it = arg_shardings.find(i);
1110 xla::XlaScopedShardingAssignment assign_sharding(
1111 builder, it == arg_shardings.end() ? std::optional<xla::OpSharding>()
1112 : it->second);
1113 auto& arg = args[input_to_args->at(i)];
1114
1115 xla::OpMetadata arg_metadata;
1116 arg_metadata.set_op_name(arg.node_name);
1117 builder->SetOneShotOpMetadata(arg_metadata);
1118 arg_handles[i] = xla::GetTupleElement(tuple, i);
1119 }
1120 } else {
1121 for (std::vector<int>::size_type i = 0; i < input_to_args->size(); ++i) {
1122 auto it = arg_shardings.find(i);
1123 xla::XlaScopedShardingAssignment assign_sharding(
1124 builder, it == arg_shardings.end() ? std::optional<xla::OpSharding>()
1125 : it->second);
1126 if (is_entry_computation) {
1127 // Add an entry to is_same_across_replicas for every leaf buffer.
1128 std::vector<bool> is_same_across_replicas(
1129 xla::ShapeUtil::GetLeafCount((*input_shapes)[i]),
1130 args[input_to_args->at(i)].is_same_data_across_replicas);
1131 arg_handles[i] =
1132 xla::Parameter(builder, i, (*input_shapes)[i],
1133 absl::StrCat("arg", i), is_same_across_replicas);
1134 } else {
1135 arg_handles[i] = xla::Parameter(builder, i, (*input_shapes)[i],
1136 absl::StrCat("arg", i));
1137 }
1138 }
1139 }
1140
1141 builder->ClearOpMetadata();
1142
1143 // Fill in the handles in non-constant arguments, and reshape parameters
1144 // back to their correct shapes.
1145 VLOG(2) << "XLA computation inputs:";
1146 for (std::vector<int>::size_type i = 0; i < input_to_args->size(); ++i) {
1147 const XlaCompiler::Argument& arg = args[input_to_args->at(i)];
1148 VLOG(2) << " XLA arg " << i
1149 << " shape: " << xla::ShapeUtil::HumanString(arg_shapes[i])
1150 << " name: " << arg.name << " TF arg " << input_to_args->at(i)
1151 << " node name: " << arg.node_name
1152 << (arg_shardings.find(i) == arg_shardings.end()
1153 ? ""
1154 : absl::StrCat(" sharding: ",
1155 arg_shardings.at(i).DebugString()));
1156 XlaExpression& arg_expression = (*arg_expressions)[input_to_args->at(i)];
1157 switch (arg.kind) {
1158 case XlaCompiler::Argument::kConstantResource:
1159 case XlaCompiler::Argument::kResource: {
1160 TF_RET_CHECK(arg.initialized);
1161 XlaResource* resource = arg_expression.resource();
1162 TF_RETURN_IF_ERROR(resource->SetFromPack(arg.tensor_array_gradients,
1163 arg_handles[i], builder));
1164 VLOG(2) << " resource: num_gradients: "
1165 << arg.tensor_array_gradients.size();
1166 break;
1167 }
1168 case XlaCompiler::Argument::kParameter:
1169 // Reshape parameters back to their correct shapes.
1170 // TODO(b/76097077): propagate device assignments onto arguments and
1171 // return values of functions, and then reshape unconditionally.
1172 if (is_entry_computation) {
1173 arg_expression = XlaExpression::XlaOp(
1174 xla::Reshape(arg_handles[i], arg.DimensionSizes()), arg.type);
1175 } else {
1176 arg_expression = XlaExpression::XlaOp(arg_handles[i], arg.type);
1177 if (arg.value_bound) {
1178 TF_RET_CHECK(arg.value_dynamism);
1179 // Propagate upper bound and value dynamism to arg_expression.
1180 arg_expression.set_value_bound(arg.value_bound.value());
1181 arg_expression.set_value_dynamism(arg.value_dynamism.value());
1182 }
1183 }
1184 break;
1185 case XlaCompiler::Argument::kTensorList: {
1186 arg_expression = XlaExpression::TensorList(arg_handles[i]);
1187 break;
1188 }
1189 case XlaCompiler::Argument::kToken: {
1190 arg_expression = XlaExpression::XlaOp(arg_handles[i], arg.type);
1191 break;
1192 }
1193 case XlaCompiler::Argument::kConstant:
1194 case XlaCompiler::Argument::kInvalid:
1195 return errors::Internal(
1196 "Unreachable case in BuildArguments() while filling handles");
1197 }
1198 }
1199
1200 return OkStatus();
1201 }
1202
1203 namespace {
1204
1205 // Check that the ops of all non-functional nodes have been registered.
1206 Status ValidateFunctionDef(const FunctionDef* fdef,
1207 const FunctionLibraryDefinition& flib_def) {
1208 for (const NodeDef& node : fdef->node_def()) {
1209 const string& op = node.op();
1210 if (op == FunctionLibraryDefinition::kGradientOp || flib_def.Find(op)) {
1211 continue;
1212 }
1213 const OpDef* op_def;
1214 TF_RETURN_IF_ERROR(OpRegistry::Global()->LookUpOpDef(op, &op_def));
1215 }
1216 return OkStatus();
1217 }
1218
1219 // If node is PartitionedCall or StatefulPartitionedCall, returns the
1220 // name from the "f" attr, else returns node.def().op().
1221 // Returned pointer points to the internal string either in node's attributes
1222 // or in its NodeDef. This pointer is valid as long as the node has not been
1223 // modified.
1224 Status GetPotentialFunctionName(const Node& node, const string** name) {
1225 if (node.IsPartitionedCall()) {
1226 const AttrValue* attr_value;
1227 TF_RETURN_IF_ERROR(
1228 node.attrs().Find(FunctionLibraryDefinition::kFuncAttr, &attr_value));
1229 if (!attr_value->has_func()) {
1230 return errors::InvalidArgument(
1231 "The attribute value for attribute 'f' in node ", node.DebugString(),
1232 " does not have 'func' field set");
1233 }
1234 *name = &attr_value->func().name();
1235 return OkStatus();
1236 }
1237 *name = &node.type_string();
1238 return OkStatus();
1239 }
1240
1241 // Check that the graph doesn't have any invalid nodes (e.g. incompatible with
1242 // given device_type, invalid data type, missing attributes...)
1243 Status ValidateGraph(const Graph* graph,
1244 const FunctionLibraryDefinition& flib_def,
1245 const DeviceType& device_type, const string& name) {
1246 // Make sure the XLA compilation kernels are registered. This operation is
1247 // idempotent so it is fine if someone called it already.
1248 XlaOpRegistry::RegisterCompilationKernels();
1249
1250 auto maybe_error = [&](const Node* node, const Status& s) -> Status {
1251 if (!s.ok()) {
1252 std::string errmsg = absl::StrCat(
1253 "Detected unsupported operations when trying to compile graph ", name,
1254 " on ", device_type.type_string(), ": ", node->def().op(), " (",
1255 s.error_message(), ")", FormatNodeForError(*node));
1256 if (absl::StrContains(device_type.type_string(), "TPU")) {
1257 absl::StrAppend(&errmsg,
1258 "\nOne approach is to outside compile the unsupported "
1259 "ops to run on CPUs by enabling soft placement "
1260 "`tf.config.set_soft_device_placement(True)`."
1261 " This has a potential performance penalty.\n");
1262 }
1263 if (std::shared_ptr<AbstractStackTrace> stack_trace =
1264 node->GetStackTrace()) {
1265 absl::StrAppend(
1266 &errmsg, "\nThe op is created at: \n",
1267 stack_trace->ToString({/*show_line_contents =*/true,
1268 /*filter_common_prefix =*/true,
1269 /*drop_internal_frames =*/true}));
1270 }
1271
1272 return errors::InvalidArgument(errmsg);
1273 }
1274 return OkStatus();
1275 };
1276
1277 for (const Node* node : graph->nodes()) {
1278 if (node->type_string() == FunctionLibraryDefinition::kGradientOp) {
1279 continue;
1280 }
1281 const string* function_name;
1282 TF_RETURN_IF_ERROR(GetPotentialFunctionName(*node, &function_name));
1283 const FunctionDef* fdef = flib_def.Find(*function_name);
1284 Status s;
1285 if (fdef) {
1286 s = ValidateFunctionDef(fdef, flib_def);
1287 TF_RETURN_IF_ERROR(maybe_error(node, s));
1288 continue;
1289 }
1290 const OpDef* op_def;
1291 s = OpRegistry::Global()->LookUpOpDef(node->def().op(), &op_def);
1292 TF_RETURN_IF_ERROR(maybe_error(node, s));
1293 TF_RETURN_IF_ERROR(ValidateNodeDef(node->def(), *op_def));
1294 s = FindKernelDef(device_type, node->def(), nullptr, nullptr);
1295 TF_RETURN_IF_ERROR(maybe_error(node, s));
1296 }
1297 return OkStatus();
1298 }
1299
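// Replaces compile-time constant expressions with equivalent XlaOp expressions
// by materializing each constant in `builder`, so later code can treat all
// expressions uniformly as ops.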
1300 void ConvertConstantsToExpressions(xla::XlaBuilder* builder,
1301 absl::Span<XlaExpression> expressions) {
1302 for (XlaExpression& expression : expressions) {
1303 if (expression.kind() == XlaExpression::Kind::kConstant) {
1304 expression =
1305 XlaExpression::XlaOp(expression.AsXlaOp(builder), expression.dtype());
1306 }
1307 }
1308 }
1309
1310 } // namespace
1311
1312 Status XlaCompiler::CompileGraph(
1313 const XlaCompiler::CompileOptions& options, string const& name,
1314 std::unique_ptr<Graph> graph, absl::Span<const XlaCompiler::Argument> args,
1315 CompilationResult* result) {
1316   VLOG(1) << "Executing graph symbolically to populate XlaBuilder: " << name;
1317
1318 TF_RETURN_IF_ERROR(PropagateConstIntoFunctionalNodes(
1319 graph.get(), options_.flib_def, local_flib_def_.get()));
1320 TF_RETURN_IF_ERROR(RearrangeFunctionArguments(
1321 [this](const NameAttrList& function, const FunctionBody** fbody) {
1322 return FindFunctionBody(function, fbody);
1323 },
1324 graph.get(), local_flib_def_.get(),
1325 pflr_->GetFunctionLibraryDefinition()));
1326
1327 if (VLOG_IS_ON(2)) {
1328 VLOG(2) << "XlaCompiler::CompileGraph: "
1329 << DumpGraphToFile(absl::StrCat("xla_compile_graph_", name), *graph,
1330 flib_runtime_->GetFunctionLibraryDefinition());
1331 }
1332
1333 // Report the error here if initialization failed.
1334 TF_RETURN_IF_ERROR(initialization_status_);
1335
1336 // Detect invalid nodes.
1337 // FunctionalizeControlFlow may remove some nodes from the graph.
1338 TF_RETURN_IF_ERROR(ValidateGraph(graph.get(), *options_.flib_def,
1339 options_.device_type, name));
1340 xla::XlaBuilder builder(name);
1341 XlaContext* context = new XlaContext(this, &builder, graph.get());
1342 core::ScopedUnref context_unref(context);
1343
1344 std::vector<XlaCompiler::Argument> real_args(args.begin(), args.end());
1345 int token_input_index = -1;
1346 std::unique_ptr<xla::XlaOp> token_output;
1347 if (options.add_token_input_output) {
1348 // Add extra token input.
1349 token_input_index = real_args.size();
1350
1351 XlaCompiler::Argument token_arg;
1352 token_arg.kind = XlaCompiler::Argument::kToken;
1353 real_args.push_back(token_arg);
1354 }
1355
1356 std::map<int, xla::OpSharding> arg_shardings;
1357 std::map<int, xla::OpSharding> retval_shardings;
1358 TF_ASSIGN_OR_RETURN(std::tie(arg_shardings, retval_shardings),
1359 ComputeArgAndRetvalShardings(*graph));
1360
1361 std::vector<XlaExpression> arg_expressions;
1362 TF_RETURN_IF_ERROR(BuildArguments(
1363 *graph, real_args, options.use_tuple_arg, &builder, context,
1364 arg_shardings, &arg_expressions, &result->input_mapping,
1365 &result->xla_input_shapes, options.is_entry_computation));
1366 context->set_args(std::move(arg_expressions));
1367
1368 PushNodeTokenMapping();
1369 // Use std::set instead of std::unordered_set to ensure determinism.
1370 std::set<std::string> output_node_token_inputs;
1371 if (token_input_index != -1) {
1372 // Original token comes from input.
1373 auto arg_expression = context->args()[token_input_index];
1374 TF_RETURN_IF_ERROR(
1375 SetNodeToken(kXlaTokenArgNodeName, arg_expression.handle()));
1376
1377 // Calculate token inputs for output token.
1378 output_node_token_inputs = CalculateTokenInputsForOutputToken(*graph);
1379
1380 // If there's no side-effecting op in the graph, use token input as token
1381 // output.
1382 if (output_node_token_inputs.empty()) {
1383 output_node_token_inputs.insert(kXlaTokenArgNodeName);
1384 }
1385 } else if (options.is_entry_computation) {
1386 // Original token is manually created.
1387 if (HasSideEffectingNodes(*graph)) {
1388 TF_RETURN_IF_ERROR(
1389 SetNodeToken(kXlaTokenArgNodeName, xla::CreateToken(&builder)));
1390 }
1391 }
1392
1393 Status execute_status = ExecuteGraph(context, std::move(graph), device_,
1394 flib_runtime_, NextStepId());
1395 if (!execute_status.ok()) {
1396 VLOG(1) << "Failed executing graph " << name;
1397 return execute_status;
1398 }
1399 if (token_input_index != -1) {
1400 // Add extra token output.
1401 std::vector<xla::XlaOp> token_inputs;
1402 for (const auto& node_name : output_node_token_inputs) {
1403 auto token_or = GetNodeToken(node_name);
1404 TF_RETURN_IF_ERROR(token_or.status());
1405 token_inputs.push_back(token_or.ValueOrDie());
1406 }
1407 token_output.reset(new xla::XlaOp(xla::AfterAll(&builder, token_inputs)));
1408 }
1409 TF_RETURN_IF_ERROR(PopNodeTokenMapping());
1410
1411 int num_nonconst_outputs;
1412 int num_computation_outputs;
1413 result->computation = std::make_shared<xla::XlaComputation>();
1414 result->outputs.resize(context->retvals().size());
1415 std::vector<XlaExpression> retvals = context->retvals();
1416 ConvertConstantsToExpressions(&builder, absl::Span<XlaExpression>(retvals));
1417 XlaShapeLayoutHelpers::ShapeDeterminationFns shape_determination_fns{
1418 UseNoPreferenceLayoutFn(), IdentityShapeRepresentationFn()};
1419 TF_RETURN_IF_ERROR(BuildComputation(
1420 real_args, retvals, arg_shardings, retval_shardings, context->resources(),
1421 std::move(token_output),
1422 options.is_entry_computation ? options_.shape_determination_fns
1423 : shape_determination_fns,
1424 options.is_entry_computation,
1425 options.return_updated_values_for_all_resources,
1426 options.always_return_tuple, options.use_tuple_arg,
1427 options.alias_resource_update, &builder, result->computation.get(),
1428 &num_computation_outputs, &num_nonconst_outputs, &result->outputs,
1429 &result->resource_updates, &result->xla_output_shape,
1430 result->input_mapping));
1431
1432 VLOG(2) << "Outputs: total: " << context->retvals().size()
1433 << " nonconstant: " << num_nonconst_outputs;
1434 VLOG(2) << "XLA output shape: "
1435 << xla::ShapeUtil::HumanStringWithLayout(result->xla_output_shape);
1436 result->collective_info = context->GetCollectiveInfo();
1437 return OkStatus();
1438 }
1439
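// Returns the ChannelHandle registered for `key`, creating a new one via the
// client on first use so that repeated lookups of the same key share a
// channel.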
Status XlaCompiler::GetChannelHandle(const string& key,
                                     xla::ChannelHandle* channel) {
  auto result = channels_.emplace(key, xla::ChannelHandle());
  if (result.second) {
    TF_ASSIGN_OR_RETURN(result.first->second, client()->CreateChannelHandle());
  }
  *channel = result.first->second;
  VLOG(1) << "Channel: " << key << " " << channel->DebugString();
  return OkStatus();
}

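// Same as GetChannelHandle, but creates a host-to-device channel handle on
// first use of `key`.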
Status XlaCompiler::GetHostToDeviceChannelHandle(const string& key,
                                                 xla::ChannelHandle* channel) {
  auto result = channels_.emplace(key, xla::ChannelHandle());
  if (result.second) {
    TF_ASSIGN_OR_RETURN(result.first->second,
                        client()->CreateHostToDeviceChannelHandle());
  }
  *channel = result.first->second;
  VLOG(1) << "Host to device channel: " << key << " " << channel->DebugString();
  return OkStatus();
}

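// Same as GetChannelHandle, but creates a device-to-host channel handle on
// first use of `key`.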
Status XlaCompiler::GetDeviceToHostChannelHandle(const string& key,
                                                 xla::ChannelHandle* channel) {
  auto result = channels_.emplace(key, xla::ChannelHandle());
  if (result.second) {
    TF_ASSIGN_OR_RETURN(result.first->second,
                        client()->CreateDeviceToHostChannelHandle());
  }
  *channel = result.first->second;
  VLOG(1) << "Device to host channel: " << key << " " << channel->DebugString();
  return OkStatus();
}

namespace {

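// Fills in `transfer` with `key` and one TensorMetadata entry per element of
// `types`/`shapes`. The two spans must have the same size.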
void SetTransfer(const string& key, absl::Span<const DataType> types,
                 absl::Span<const TensorShape> shapes,
                 tf2xla::HostTransferMetadata* transfer) {
  transfer->set_key(key);
  CHECK(types.size() == shapes.size());
  for (int i = 0, end = types.size(); i < end; ++i) {
    tf2xla::TensorMetadata* metadata = transfer->add_metadata();
    metadata->set_type(types[i]);
    shapes[i].AsProto(metadata->mutable_shape());
  }
}

}  // namespace

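// Registers the device-to-host transfer metadata for `key`. Calling this
// again with the same key is only allowed if the types and shapes match the
// previously registered metadata exactly.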
Status XlaCompiler::SetDeviceToHostMetadata(
    const string& key, absl::Span<const DataType> types,
    absl::Span<const TensorShape> shapes) {
  if (host_compute_sends_.find(key) != host_compute_sends_.end()) {
    tf2xla::HostTransferMetadata& existing_transfer = host_compute_sends_[key];
    tf2xla::HostTransferMetadata new_transfer;
    SetTransfer(key, types, shapes, &new_transfer);
    if (xla::protobuf_util::ProtobufEquals(existing_transfer, new_transfer)) {
      return OkStatus();
    } else {
      return errors::InvalidArgument(
          "Duplicate calls to SetDeviceToHostMetadata with key ", key);
    }
  }
  tf2xla::HostTransferMetadata& transfer = host_compute_sends_[key];
  SetTransfer(key, types, shapes, &transfer);
  return OkStatus();
}

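// Returns in `shapes` the tensor shapes previously registered for `key` via
// SetDeviceToHostMetadata; fails if no metadata is registered for the key.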
Status XlaCompiler::GetDeviceToHostShapes(
    const string& key, std::vector<TensorShape>* shapes) const {
  const auto iter = host_compute_sends_.find(key);
  if (iter == host_compute_sends_.end()) {
    return errors::InvalidArgument(
        "No host compute send shapes registered for key ", key);
  }
  shapes->clear();
  for (int i = 0; i < iter->second.metadata_size(); ++i) {
    TensorShape shape(iter->second.metadata(i).shape());
    shapes->push_back(shape);
  }
  return OkStatus();
}

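// Registers the host-to-device transfer metadata for `key`. As with
// SetDeviceToHostMetadata, repeated calls with the same key are only allowed
// if the metadata is unchanged.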
Status XlaCompiler::SetHostToDeviceMetadata(
    const string& key, absl::Span<const DataType> types,
    absl::Span<const TensorShape> shapes) {
  if (host_compute_recvs_.find(key) != host_compute_recvs_.end()) {
    tf2xla::HostTransferMetadata& existing_transfer = host_compute_recvs_[key];
    tf2xla::HostTransferMetadata new_transfer;
    SetTransfer(key, types, shapes, &new_transfer);
    if (xla::protobuf_util::ProtobufEquals(existing_transfer, new_transfer)) {
      return OkStatus();
    } else {
      return errors::InvalidArgument(
          "Duplicate calls to SetHostToDeviceMetadata with key ", key);
    }
  }
  tf2xla::HostTransferMetadata& transfer = host_compute_recvs_[key];
  SetTransfer(key, types, shapes, &transfer);
  return OkStatus();
}

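// Looks up the control handle registered for the host compute op
// `host_compute_name` and returns it in `handle`.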
Status XlaCompiler::GetHostComputeControlDependency(
    const string& host_compute_name, xla::XlaOp* handle) {
  const auto iter = host_compute_control_output_.find(host_compute_name);
  if (iter == host_compute_control_output_.end()) {
    return errors::InvalidArgument(
        "No registered control handle for host compute Op '", host_compute_name,
        "'");
  } else {
    *handle = iter->second;
  }
  return OkStatus();
}

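// Registers `handle` as the control handle for the host compute op
// `host_compute_name`; at most one handle may be registered per op.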
Status XlaCompiler::SetHostComputeControlDependency(
    const string& host_compute_name, const xla::XlaOp& handle) {
  if (host_compute_control_output_.find(host_compute_name) !=
      host_compute_control_output_.end()) {
    return errors::InvalidArgument(
        "Duplicate control handles registered for host compute Op ",
        host_compute_name);
  }
  host_compute_control_output_[host_compute_name] = handle;
  return OkStatus();
}

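// Pushes a fresh, empty node-name-to-token mapping onto the stack used while
// compiling a (sub)graph.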
void XlaCompiler::PushNodeTokenMapping() {
  node_token_mapping_stack_.emplace(std::map<string, xla::XlaOp>{});
}

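// Pops the innermost node-to-token mapping; it is an error to call this when
// the stack is empty.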
Status XlaCompiler::PopNodeTokenMapping() {
  if (node_token_mapping_stack_.empty()) {
    return errors::FailedPrecondition(
        "Calling PopNodeTokenMapping() when node_token_mapping_stack_ is "
        "empty.");
  }
  node_token_mapping_stack_.pop();
  return OkStatus();
}

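// Records `op` as the token produced by node `node_name` in the innermost
// mapping; each node may have at most one token registered.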
Status XlaCompiler::SetNodeToken(const string& node_name,
                                 const xla::XlaOp& op) {
  if (node_token_mapping_stack_.empty()) {
    return errors::FailedPrecondition(
        "Calling SetNodeToken() when node_token_mapping_stack_ is "
        "empty.");
  }
  auto insert_result = node_token_mapping_stack_.top().insert({node_name, op});
  if (!insert_result.second) {
    return errors::FailedPrecondition("Token mapping already exists for node ",
                                      node_name);
  }
  return OkStatus();
}

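// Returns the token previously registered for `node_name` in the innermost
// mapping, or an error if none has been registered.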
StatusOr<xla::XlaOp> XlaCompiler::GetNodeToken(const string& node_name) {
  if (node_token_mapping_stack_.empty()) {
    return errors::FailedPrecondition(
        "Calling GetNodeToken() when node_token_mapping_stack_ is "
        "empty.");
  }
  auto iter = node_token_mapping_stack_.top().find(node_name);
  if (iter == node_token_mapping_stack_.top().end()) {
    return errors::FailedPrecondition("Cannot find token mapping for node ",
                                      node_name);
  }
  return iter->second;
}

}  // namespace tensorflow