1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/tf2xla/xla_compiler.h"
17 
18 #include <numeric>
19 #include <vector>
20 
21 #include "tensorflow/compiler/mlir/tf2xla/mlir_bridge_rollout_policy.h"
22 #include "absl/container/flat_hash_map.h"
23 #include "absl/memory/memory.h"
24 #include "absl/types/variant.h"
25 #include "tensorflow/compiler/jit/defs.h"
26 #include "tensorflow/compiler/jit/flags.h"
27 #include "tensorflow/compiler/jit/shape_inference.h"
28 #include "tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.h"
29 #include "tensorflow/compiler/mlir/utils/array_container_utils.h"
30 #include "tensorflow/compiler/tf2xla/graph_compiler.h"
31 #include "tensorflow/compiler/tf2xla/layout_util.h"
32 #include "tensorflow/compiler/tf2xla/rearrange_function_argument.h"
33 #include "tensorflow/compiler/tf2xla/shape_util.h"
34 #include "tensorflow/compiler/tf2xla/sharding_util.h"
35 #include "tensorflow/compiler/tf2xla/side_effect_util.h"
36 #include "tensorflow/compiler/tf2xla/tf2xla_util.h"
37 #include "tensorflow/compiler/tf2xla/type_util.h"
38 #include "tensorflow/compiler/tf2xla/xla_compilation_device.h"
39 #include "tensorflow/compiler/tf2xla/xla_context.h"
40 #include "tensorflow/compiler/tf2xla/xla_helpers.h"
41 #include "tensorflow/compiler/xla/client/client_library.h"
42 #include "tensorflow/compiler/xla/client/xla_builder.h"
43 #include "tensorflow/compiler/xla/client/xla_computation.h"
44 #include "tensorflow/compiler/xla/protobuf_util.h"
45 #include "tensorflow/compiler/xla/shape_util.h"
46 #include "tensorflow/compiler/xla/util.h"
47 #include "tensorflow/core/common_runtime/device.h"
48 #include "tensorflow/core/common_runtime/executor.h"
49 #include "tensorflow/core/common_runtime/function.h"
50 #include "tensorflow/core/common_runtime/graph_constructor.h"
51 #include "tensorflow/core/common_runtime/graph_optimizer.h"
52 #include "tensorflow/core/framework/attr_value_util.h"
53 #include "tensorflow/core/framework/function.h"
54 #include "tensorflow/core/framework/node_def_util.h"
55 #include "tensorflow/core/framework/types.h"
56 #include "tensorflow/core/graph/node_builder.h"
57 #include "tensorflow/core/lib/core/errors.h"
58 #include "tensorflow/core/lib/gtl/cleanup.h"
59 #include "tensorflow/core/lib/hash/hash.h"
60 #include "tensorflow/core/platform/logging.h"
61 #include "tensorflow/core/protobuf/error_codes.pb.h"
62 #include "tensorflow/core/protobuf/graph_debug_info.pb.h"
63 #include "tensorflow/core/util/dump_graph.h"
64 
65 namespace tensorflow {
66 namespace {
67 
68 // Checks that arguments `args` match types `types`.
69 Status CheckSignature(const DataTypeVector& types,
70                       absl::Span<const XlaCompiler::Argument> args) {
71   if (args.size() != types.size()) {
72     return errors::Internal("Compilation arguments have ", args.size(),
73                             " elements while function has ", types.size());
74   }
75   for (int i = 0, end = types.size(); i < end; ++i) {
76     // Don't perform type checks on resource variables and tensor
77     // lists (DT_VARIANT) as we have to trick the type system in order to
78     // plumb them through. DT_VARIANTs are wrapped in a DT_UINT8 tensor.
79     if (types[i] != args[i].type && types[i] != DT_RESOURCE &&
80         types[i] != DT_VARIANT) {
81       return errors::Internal(
82           "Argument ", i, " has declared type ", DataTypeString(args[i].type),
83           " but function parameter has type ", DataTypeString(types[i]));
84     }
85   }
86   return OkStatus();
87 }
88 
89 // Uses the _Arg and _Retval nodes in the graph to determine an OpSharding for
90 // each argument and return value.
91 StatusOr<
92     std::pair<std::map<int, xla::OpSharding>, std::map<int, xla::OpSharding>>>
93 ComputeArgAndRetvalShardings(const Graph& graph) {
94   auto get_sharding_for_node =
95       [](const Node* n) -> StatusOr<std::optional<xla::OpSharding>> {
96     TF_ASSIGN_OR_RETURN(
97         auto sharding,
98         ParseShardingFromDevice(*n, std::numeric_limits<int32>::max(),
99                                 /*add_metadata=*/false));
100     return sharding;
101   };
102   std::map<int, xla::OpSharding> arg_shardings;
103   std::map<int, xla::OpSharding> retval_shardings;
104   for (const Node* n : graph.nodes()) {
105     if (n->IsArg()) {
106       TF_ASSIGN_OR_RETURN(auto sharding, get_sharding_for_node(n));
107       if (!sharding.has_value()) continue;
108       int index;
109       TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "index", &index));
110       TF_RET_CHECK(index >= 0) << "Negative _Arg index";
111       arg_shardings[index] = std::move(*sharding);
112     } else if (n->IsRetval()) {
113       TF_ASSIGN_OR_RETURN(auto sharding, get_sharding_for_node(n));
114       if (!sharding.has_value()) continue;
115       int index;
116       TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "index", &index));
117       TF_RET_CHECK(index >= 0) << "Negative _Retval index";
118       retval_shardings[index] = std::move(*sharding);
119     }
120   }
121   return std::make_pair(std::move(arg_shardings), std::move(retval_shardings));
122 }
123 
124 Status ExecuteGraph(XlaContext* xla_context, std::unique_ptr<Graph> graph,
125                     XlaCompilationDevice* device, FunctionLibraryRuntime* flib,
126                     int64_t step_id) {
127   // Resource cleanup is a bit messy. XlaContext is a ref-counted resource; the
128   // resource manager takes ownership via Create, and unrefs via Cleanup.  We
129   // explicitly add a reference to ensure the refcount at entry is maintained at
130   // all exit points; Create and Cleanup are always called in this function.
131   //
132   // The Executor requires us to use ScopedStepContainer. We wrap it in a
133   // unique_ptr so we can capture the cleanup status in the end.
134   xla_context->Ref();
135   Status status;
136   auto step_container = std::make_unique<ScopedStepContainer>(
137       step_id, [&status, device](const string& name) {
138         status = device->resource_manager()->Cleanup(name);
139       });
140   TF_RETURN_IF_ERROR(step_container->Create(device->resource_manager(),
141                                             XlaContext::kXlaContextResourceName,
142                                             xla_context));
143 
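  // GraphCompiler runs each node's XLA kernel on the compilation device; the
  // kernels emit XLA ops rather than computing values, accumulating the
  // computation in the XlaContext registered above.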
144   GraphCompiler graph_compiler(device, graph.get(), flib, step_container.get());
145   TF_RETURN_IF_ERROR(graph_compiler.Compile());
146   // Explicitly clean up the step container, to capture the cleanup status.
147   step_container.reset();
148   return status;
149 }
150 
151 // Builds the XLA computation.
152 // - `args` is the list of input arguments
153 // - `retvals` is the list of retvals produced by _Retval operators, in index
154 //   order.
155 // - `arg_shardings` and `retval_shardings` are mapping from arg/return indices
156 //   to sharding.
157 // - If `return_updated_values_for_all_resources` is true, all resources will be
158 //   included in `resource_updates`, regardless of whether their value changed.
159 // - Sets `*num_nonconst_outputs` to the number of outputs of the `computation`.
160 // - Sets `*resource_updates` to a description of resources whose values are
161 //   written by the computation; the variable writes are the last
162 //   `resource_updates.size()` return values from the computation. Each entry in
163 //   `resource_updates` is a ResourceUpdate, whose `index` is the index of a
164 //   resource variable argument to the computation to be updated, and `type` is
165 //   the type of the final output.
166 Status BuildComputation(
167     const std::vector<XlaCompiler::Argument>& args,
168     const std::vector<XlaExpression>& retvals,
169     const std::map<int, xla::OpSharding>& arg_shardings,
170     const std::map<int, xla::OpSharding>& retval_shardings,
171     const std::vector<std::unique_ptr<XlaResource>>& resources,
172     std::unique_ptr<xla::XlaOp> token_output,
173     const XlaShapeLayoutHelpers::ShapeDeterminationFns& shape_determination_fns,
174     bool is_entry_computation, bool return_updated_values_for_all_resources,
175     bool always_return_tuple, bool use_tuple_arg, bool alias_resource_update,
176     xla::XlaBuilder* builder, xla::XlaComputation* computation,
177     int* num_computation_outputs, int* num_nonconst_outputs,
178     std::vector<XlaCompiler::OutputDescription>* outputs,
179     std::vector<XlaCompiler::ResourceUpdate>* resource_updates,
180     xla::Shape* output_shape, absl::Span<int const> input_mapping) {
181   // Attach a common operator name as metadata. This has no semantic effect — it
182   // merely makes the HLO graph more readable when visualized via TensorBoard,
183   // since TensorBoard forms groups out of operators with similar names.
184   xla::OpMetadata retval_metadata;
185   retval_metadata.set_op_name("XLA_Retvals");
186   builder->SetOpMetadata(retval_metadata);
187   VLOG(1) << "Building new computation";
188   auto cleanup = gtl::MakeCleanup([builder]() { builder->ClearOpMetadata(); });
189 
190   // Builds a no-op XLA computation. We need to set the sharding of outputs, but
191   // cannot change the sharding of the existing output op. To do this, we build
192   // a new identity op to which shardings can be applied.
193   auto identity_op = [builder](xla::XlaOp op,
194                                const std::optional<xla::OpSharding>& sharding) {
195     xla::XlaScopedShardingAssignment assign_sharding(builder, sharding);
196     return xla::Copy(op);
197   };
198 
199   std::vector<xla::XlaOp> elems;
200   elems.reserve(retvals.size());
201 
202   // Keeps track of sharding of each retval. If a retval is not in this list,
203   // replicate sharding is used. The first element is the output index, second
204   // element is the sharding.
205   std::unordered_map<int, xla::OpSharding> retval_index_and_sharding;
206   for (int i = 0, end = retvals.size(); i < end; ++i) {
207     XlaCompiler::OutputDescription& output = (*outputs)[i];
208     const XlaExpression& retval = retvals[i];
209     output.type = retval.dtype();
210     switch (retval.kind()) {
211       case XlaExpression::Kind::kConstant:
212         output.is_constant = true;
213         output.constant_value = *retval.constant_value();
214         output.shape = output.constant_value.shape();
215         break;
216 
217       case XlaExpression::Kind::kTensorList: {
218         output.is_tensor_list = true;
219         xla::XlaOp value = retval.handle();
220         elems.push_back(value);
221         break;
222       }
223 
224       case XlaExpression::Kind::kXlaOp: {
225         output.is_constant = false;
226         TF_ASSIGN_OR_RETURN(output.shape, retval.GetShape());
227         xla::XlaOp value = retval.handle();
228         auto it = retval_shardings.find(i);
229         std::optional<xla::OpSharding> sharding =
230             it == retval_shardings.end() ? std::optional<xla::OpSharding>()
231                                          : it->second;
232         if (it != retval_shardings.end()) {
233           retval_index_and_sharding[elems.size()] = it->second;
234         }
235         if (shape_determination_fns.shape_representation_fn) {
236           TF_ASSIGN_OR_RETURN(auto original_shape, builder->GetShape(value));
237           TF_ASSIGN_OR_RETURN(value,
238                               ReshapeWithCorrectRepresentationAndSharding(
239                                   builder, value, original_shape,
240                                   shape_determination_fns, sharding,
241                                   /*fast_mem=*/false));
242         }
243         if (it != retval_shardings.end()) {
244           // Apply the sharding to the output, if there is a core assignment.
245           value = identity_op(value, sharding);
246         }
247 
248         elems.push_back(value);
249         break;
250       }
251 
252       case XlaExpression::Kind::kResource:
253         // Resources will be pushed into elems later when processing resource
254         // arguments below.
255         output.is_constant = false;
256         output.input_index = retval.resource()->arg_num();
257         output.shape = retval.resource()->shape();
258         break;
259 
260       case XlaExpression::Kind::kInvalid:
261         return errors::InvalidArgument(
262             "Invalid expression returned by computation. "
263             "This probably means a return value was not set.");
264     }
265   }
266   *num_nonconst_outputs = elems.size();
267 
268   // Add return values for resources whose values have changed.
269   std::vector<const XlaResource*> arg_resources;
270   arg_resources.reserve(resources.size());
271   for (const auto& resource : resources) {
272     if (resource->arg_num() >= 0) {
273       arg_resources.push_back(resource.get());
274     }
275   }
276   std::sort(arg_resources.begin(), arg_resources.end(),
277             [](const XlaResource* a, const XlaResource* b) {
278               return a->arg_num() < b->arg_num();
279             });
280 
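  // Invert input_mapping: map each original argument index to its position in
  // the XLA parameter list.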
281   absl::flat_hash_map<int, int> argument_to_xla_arg;
282   for (int xla_arg = 0; xla_arg < input_mapping.size(); xla_arg++) {
283     argument_to_xla_arg[input_mapping[xla_arg]] = xla_arg;
284   }
285 
286   std::vector<xla::XlaBuilder::InputOutputAlias> aliases;
287   for (const XlaResource* resource : arg_resources) {
288     DCHECK_LT(resource->arg_num(), args.size());
289     const XlaCompiler::Argument& arg = args[resource->arg_num()];
290     auto it = arg_shardings.find(resource->arg_num());
291     bool modified = !resource->value().IsIdenticalTo(resource->initial_value());
292     // TensorArray gradients were modified if their values changed or there are
293     // any newly created gradients.
294     for (const auto& grad : resource->tensor_array_gradients()) {
295       modified =
296           modified ||
297           !grad.second->value().IsIdenticalTo(grad.second->initial_value()) ||
298           arg.tensor_array_gradients.count(grad.first) == 0;
299     }
300 
301     if (return_updated_values_for_all_resources || modified ||
302         arg.requires_broadcast) {
303       resource_updates->emplace_back();
304       XlaCompiler::ResourceUpdate& update = resource_updates->back();
305       update.input_index = resource->arg_num();
306       update.type = resource->type();
307       update.shape = resource->shape();
308       update.modified = modified;
309       int param_num = use_tuple_arg ? 0 : update.input_index;
310       if (is_entry_computation &&
311           arg.resource_kind != XlaResource::kTensorArray &&
312           alias_resource_update && argument_to_xla_arg.count(param_num)) {
313         // Assuming tuple arg and results are used.
314         xla::ShapeIndex param_index =
315             use_tuple_arg ? xla::ShapeIndex({update.input_index})
316                           : xla::ShapeIndex{};
317         int xla_param_num = argument_to_xla_arg[param_num];
318         int64_t output_index_num = elems.size();
319         xla::ShapeIndex output_index = xla::ShapeIndex({output_index_num});
320         VLOG(3) << "Storing alias: " << output_index.ToString() << ": ("
321                 << xla_param_num << ", " << param_index.ToString() << ")";
322         aliases.push_back({output_index, xla_param_num, param_index});
323       }
324       for (const auto& grad : resource->tensor_array_gradients()) {
325         update.tensor_array_gradients_accessed.insert(grad.first);
326       }
327 
328       xla::XlaOp handle;
329       TF_RETURN_IF_ERROR(resource->Pack(&handle, builder));
330       auto sharding = it == arg_shardings.end()
331                           ? std::optional<xla::OpSharding>()
332                           : it->second;
333       // Set layout of the retval to device representation layout.
334       if (shape_determination_fns.layout_preference_fn &&
335           shape_determination_fns.shape_representation_fn) {
336         TF_ASSIGN_OR_RETURN(auto original_shape, builder->GetShape(handle));
337         TF_ASSIGN_OR_RETURN(
338             handle, ReshapeWithCorrectRepresentationAndSharding(
339                         builder, handle, original_shape,
340                         shape_determination_fns, sharding, arg.fast_mem));
341       }
342 
343       // Request that the value be returned on a specific core.
344       if (it != arg_shardings.end()) {
345         retval_index_and_sharding[elems.size()] = it->second;
346       }
347       // Ensures the correct sharding is applied to the output.
348       handle = identity_op(handle, sharding);
349       elems.push_back(handle);
350     }
351   }
352 
353   // If we have token output, append it as the last one.
354   if (token_output) {
355     elems.push_back(*token_output);
356   }
357 
358   *num_computation_outputs = elems.size();
359 
360   // Builds the XLA computation. We *always* form a tuple here to ensure that
361   // the output value is the last thing added into the XLA computation, even
362   // if there is only one output value.
363   xla::XlaOp tuple;
364   if (retval_index_and_sharding.empty() || !is_entry_computation) {
365     tuple = xla::Tuple(builder, elems);
366   } else {
367     std::vector<xla::Shape> elem_shapes;
368     for (const auto& elem : elems) {
369       TF_ASSIGN_OR_RETURN(xla::Shape elem_shape,
370                           elem.builder()->GetShape(elem));
371       elem_shapes.push_back(elem_shape);
372     }
373     xla::Shape shape = xla::ShapeUtil::MakeTupleShape(elem_shapes);
374     // Copy specified sharding from retval_index_and_sharding.
375     std::vector<xla::HloSharding> sharding_elems;
376     for (int i = 0, end = elems.size(); i < end; i++) {
377       const auto& iter = retval_index_and_sharding.find(i);
378       TF_RET_CHECK(iter != retval_index_and_sharding.end());
379       const xla::OpSharding& sub_op_sharding = iter->second;
380       TF_ASSIGN_OR_RETURN(xla::HloSharding sub_sharding,
381                           xla::HloSharding::FromProto(sub_op_sharding));
382       if (elem_shapes[i].IsTuple()) {
383         const std::vector<xla::HloSharding> sub_sharding_elems =
384             sub_sharding.tuple_elements();
385         const int64_t sub_sharding_elems_size = sub_sharding_elems.size();
386         TF_RET_CHECK(sub_sharding_elems_size ==
387                      xla::ShapeUtil::GetLeafCount(elem_shapes[i]));
388         for (const auto& sub_sharding_elem : sub_sharding_elems) {
389           sharding_elems.push_back(sub_sharding_elem);
390         }
391       } else {
392         sharding_elems.push_back(sub_sharding);
393       }
394     }
395     xla::HloSharding modified_sharding =
396         xla::HloSharding::Tuple(shape, sharding_elems);
397     xla::OpSharding op_sharding = modified_sharding.ToProto();
398     // Assign proper sharding to the tuple instruction.
399     xla::XlaScopedShardingAssignment assign_sharding(builder, op_sharding);
400     tuple = xla::Tuple(builder, elems);
401   }
402   bool returns_tuple = always_return_tuple || elems.size() != 1;
403   VLOG(3) << "Computation returns a tuple=" << returns_tuple;
404   if (!returns_tuple) {
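    // Emitting this GetTupleElement last makes the single value, rather than
    // the one-element tuple, the root of the computation when the builder is
    // finalized.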
405     xla::GetTupleElement(tuple, 0);
406 
407     for (xla::XlaBuilder::InputOutputAlias& alias : aliases) {
408       if (alias.output_index == xla::ShapeIndex({0})) {
409         VLOG(3) << "For aliased parameter " << alias.param_number << ": "
410                 << alias.param_index.ToString()
411                 << " normalizing output_index from {0} to {}, as a scalar is "
412                    "returned from the cluster";
413         alias.output_index = xla::ShapeIndex({});
414       }
415     }
416   }
417 
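  // Register the collected aliases with the builder so each aliased output
  // shares (donates) the buffer of the corresponding parameter.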
418   for (xla::XlaBuilder::InputOutputAlias& alias : aliases) {
419     builder->SetUpAlias(alias.output_index, alias.param_number,
420                         alias.param_index);
421   }
422   TF_ASSIGN_OR_RETURN(*computation, builder->Build());
423 
424   TF_ASSIGN_OR_RETURN(auto program_shape, computation->GetProgramShape());
425   *output_shape = program_shape.result();
426   return OkStatus();
427 }
428 
429 }  // namespace
430 
431 
432 string XlaCompiler::Argument::HumanString() const {
433   string common;
434   if (!name.empty()) {
435     common = absl::StrCat(" name=", name);
436   }
437   absl::StrAppend(&common, " type=", DataTypeString(type),
438                   " shape=", ShapeHumanString());
439   absl::StrAppend(
440       &common, " is_same_data_across_replicas=", is_same_data_across_replicas);
441   switch (kind) {
442     case kInvalid:
443       return "invalid";
444     case kConstant:
445       return absl::StrCat("kind=constant", common,
446                           " value=", constant_value.DebugString());
447     case kConstantResource:
448       return absl::StrCat("kind=constant-resource", common,
449                           " value=", constant_value.DebugString());
450     case kResource: {
451       string output = absl::StrCat(
452           "kind=resource", common,
453           " resource_kind=", XlaResource::KindToString(resource_kind),
454           " initialized=", initialized, " is_fast_mem=", fast_mem);
455       if (max_array_size >= 0) {
456         absl::StrAppend(&output, " max_array_size=", max_array_size);
457       }
458       if (!tensor_array_gradients.empty()) {
459         absl::StrAppend(&output, " tensor_array_gradients=",
460                         absl::StrJoin(tensor_array_gradients, ","));
461       }
462       return output;
463     }
464     case kParameter:
465       return absl::StrCat("kind=parameter", common);
466     case kTensorList:
467       return absl::StrCat("kind=tensorlist", common);
468     case kToken:
469       return absl::StrCat("token", common);
470   }
471 }
472 
473 std::vector<int64_t> XlaCompiler::Argument::DimensionSizes() const {
474   if (absl::holds_alternative<TensorShape>(shape)) {
475     return xla::InlinedVectorToVector(std::get<TensorShape>(shape).dim_sizes());
476   } else {
477     return xla::SpanToVector(std::get<xla::Shape>(shape).dimensions());
478   }
479 }
480 
481 absl::InlinedVector<int64_t, 4>
482 XlaCompiler::Argument::DimensionSizesAsInlinedVector() const {
483   if (absl::holds_alternative<TensorShape>(shape)) {
484     return std::get<TensorShape>(shape).dim_sizes();
485   } else {
486     auto v = std::get<xla::Shape>(shape).dimensions();
487     return absl::InlinedVector<int64_t, 4>(v.begin(), v.end());
488   }
489 }
490 
491 string XlaCompiler::Argument::ShapeHumanString() const {
492   if (absl::holds_alternative<TensorShape>(shape)) {
493     return std::get<TensorShape>(shape).DebugString();
494   } else {
495     return std::get<xla::Shape>(shape).DebugString();
496   }
497 }
498 
499 XlaCompiler::XlaCompiler(XlaCompiler::Options options)
500     : options_(options),
501       initialization_status_(OkStatus()),
502       next_step_id_(1),
503       device_(new XlaCompilationDevice(SessionOptions(), options_.device_type)),
504       device_mgr_(absl::WrapUnique(device_)) {
505   CHECK(!options_.device_type.type_string().empty());
506   if (options_.populate_resource_manager) {
507     initialization_status_ =
508         (*options_.populate_resource_manager)(device_->resource_manager());
509   }
510 
511   local_flib_def_.reset(new FunctionLibraryDefinition(OpRegistry::Global(),
512                                                       FunctionDefLibrary{}));
513   local_pflr_.reset(new ProcessFunctionLibraryRuntime(
514       &device_mgr_, Env::Default(), /*config=*/nullptr,
515       options.graph_def_version, local_flib_def_.get(), OptimizerOptions()));
516   pflr_.reset(new ProcessFunctionLibraryRuntime(
517       &device_mgr_, Env::Default(), /*config=*/nullptr,
518       options.graph_def_version, options.flib_def, OptimizerOptions()));
519 
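  // Two function library runtimes are kept: local_flib_runtime_ resolves
  // functions from the compiler's own (initially empty) local_flib_def_, while
  // flib_runtime_ resolves functions from the caller-provided options.flib_def.
  // FindFunctionBody() consults the local runtime first.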
520   local_flib_runtime_ = local_pflr_->GetFLR(device_->name());
521   flib_runtime_ = pflr_->GetFLR(device_->name());
522 
523   // The default layout preference is no preference and the default shape
524   // representation function is the identity.
525   XlaShapeLayoutHelpers::ShapeDeterminationFns& shape_determination_fns =
526       options_.shape_determination_fns;
527   if (!shape_determination_fns.shape_representation_fn) {
528     shape_determination_fns.shape_representation_fn =
529         IdentityShapeRepresentationFn();
530   }
531   if (!shape_determination_fns.layout_preference_fn) {
532     shape_determination_fns.layout_preference_fn = UseNoPreferenceLayoutFn();
533   }
534 }
535 
536 XlaCompiler::~XlaCompiler() = default;
537 
538 int64_t XlaCompiler::NextStepId() { return next_step_id_++; }
539 
540 uint64 XlaCompiler::SignatureHash::operator()(
541     const std::pair<string, std::vector<Argument>>& signature) const {
542   return std::hash<string>()(signature.first);
543 }
544 
545 static Status GetFunctionBody(const NameAttrList& function,
546                               FunctionLibraryRuntime* flib_runtime,
547                               const FunctionBody** fbody) {
548   FunctionLibraryRuntime::Handle handle;
549   TF_RETURN_IF_ERROR(flib_runtime->Instantiate(
550       function.name(), AttrSlice(&function.attr()), &handle));
551 
552   *fbody = flib_runtime->GetFunctionBody(handle);
553   TF_RET_CHECK(*fbody);
554   return OkStatus();
555 }
556 
557 Status XlaCompiler::FindFunctionBody(const NameAttrList& function,
558                                      const FunctionBody** fbody,
559                                      const ConfigProto** config_proto) {
560   // The function may be in either the local_flib_runtime_ or flib_runtime_.
561   // Look up the function in local first and if it is not found then look up the
562   // function in flib_runtime_.
563   auto status = GetFunctionBody(function, local_flib_runtime_, fbody);
564   if (!status.ok()) {
565     if (!errors::IsNotFound(status)) {
566       return status;
567     }
568     TF_RETURN_WITH_CONTEXT_IF_ERROR(
569         GetFunctionBody(function, flib_runtime_, fbody),
570         "Local lookup failed with: ", status.error_message());
571     if (config_proto) {
572       *config_proto = flib_runtime_->config_proto();
573     }
574     VLOG(4) << "Function " << function.name() << " in flib_runtime_";
575   } else {
576     if (config_proto) {
577       *config_proto = local_flib_runtime_->config_proto();
578     }
579     VLOG(4) << "Function " << function.name() << " in local_flib_runtime_";
580   }
581   return OkStatus();
582 }
583 
584 std::unique_ptr<Graph> XlaCompiler::GetGraph(const FunctionBody* fbody) {
585   std::unique_ptr<Graph> graph(new Graph(options_.flib_def));
586   CopyGraph(*fbody->graph, graph.get());
587 
588   bool is_inside_mustcompile = false;
589   TryGetNodeAttr(AttrSlice(&fbody->fdef.attr()), kXlaMustCompileAttr,
590                  &is_inside_mustcompile);
591 
592   // Performs a first function inlining pass before shape inference, since
593   // otherwise shape inference can't see inside functions and a comprehensive
594   // shape_map, including function ops, is needed to constant-propagate Shape
595   // Ops below.
596   auto flags = GetBuildXlaOpsPassFlags();
597   OptimizerOptions opts;
598   opts.set_opt_level(OptimizerOptions::L0);
599   opts.set_do_common_subexpression_elimination(false);
600   opts.set_do_function_inlining(true);
601   opts.set_do_constant_folding(!flags->tf_xla_disable_constant_folding);
602   GraphOptimizer optimizer(opts);
603   // Do not constant fold nodes that output DT_VARIANT type tensors.
604   // XLA does not support Const nodes of Variant type since it needs
605   // to know the original ops to be able to compile them to the relevant
606   // XLA form.
607   // TODO(srbs): This filter is a little conservative. E.g. a subgraph of
608   // the form:
609   //                          Const
610   //                            |
611   // EmptyTensorList -> TensorListPushBack -> TensorListPopBack -> Op
612   //                                                  |
613   //                                        (Discard popped list)
614   //
615   // Would have been reduced to "Const -> Op" without this filter.
616   // However since we are only allowed to specify the filter at the "Node"
617   // level there is no good way to allow the above behavior. So we
618   // disallow any sort of constant folding on Variant nodes for now.
619   //
620   // Also do not consider constant folding Shape ops. When there is a dynamic
621   // dimension in a tensor, TF2XLA currently represents it as the static
622   // upper-bound shape, which could be constant folded away and lose the
623   // information that this Shape is dynamic.
624   auto cf_consider_fn = [](const Node* n) {
625     for (const auto& output_arg : n->op_def().output_arg()) {
626       if (output_arg.type() == DT_VARIANT) {
627         return false;
628       }
629     }
630     const auto& ts = n->type_string();
631     // XLA has special logic to handle dynamic shapes, don't constant fold
632     // them.
633     if (ts == "Shape" || ts == "ShapeN" || ts == "Size") {
634       return false;
635     }
636     return true;
637   };
638   GraphOptimizer::Options graph_optimizer_options;
639   graph_optimizer_options.cf_consider_fn = cf_consider_fn;
640   graph_optimizer_options.inline_multi_device_functions = true;
641   graph_optimizer_options.inline_impl_selection_group_functions = true;
642   graph_optimizer_options.inline_with_single_device_body_placer = true;
643   graph_optimizer_options.ignore_noinline = is_inside_mustcompile;
644 
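  // First pass: infer shapes and optimize, inlining functions so that the
  // second shape-inference/optimization pass below can see inside them.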
645   {
646     GraphShapeInfo shape_info;
647     InferShapes(graph.get(), /*arg_shapes=*/{},
648                 flib_runtime_->GetFunctionLibraryDefinition(), &shape_info)
649         .IgnoreError();
650     auto node_name_index = graph->BuildNodeNameIndex();
651     std::unordered_map<string, std::vector<PartialTensorShape>> shape_map;
652     for (const auto& node_shape_info : shape_info) {
653       const string& node_name = node_shape_info.first;
654       const std::vector<InferredShape>& output_shapes = node_shape_info.second;
655       const auto& node_iter = node_name_index.find(node_name);
656       if (node_iter != node_name_index.end()) {
657         auto& partial_shapes = shape_map[node_name];
658         for (const auto& inferred_shape : output_shapes) {
659           partial_shapes.push_back(inferred_shape.shape);
660         }
661       }
662     }
663     graph_optimizer_options.shape_map = &shape_map;
664     optimizer.Optimize(flib_runtime_, flib_runtime_->env(),
665                        /*device=*/nullptr, &graph, graph_optimizer_options);
666   }
667 
668   // Run shape inference on the graph and optimize the graph again.
669   GraphShapeInfo shape_info;
670   InferShapes(graph.get(), /*arg_shapes=*/{},
671               flib_runtime_->GetFunctionLibraryDefinition(), &shape_info)
672       .IgnoreError();
673   auto node_name_index = graph->BuildNodeNameIndex();
674   std::unordered_map<string, std::vector<PartialTensorShape>> shape_map;
675   for (const auto& node_shape_info : shape_info) {
676     const string& node_name = node_shape_info.first;
677     const std::vector<InferredShape>& output_shapes = node_shape_info.second;
678     const auto& node_iter = node_name_index.find(node_name);
679     if (node_iter != node_name_index.end()) {
680       auto& partial_shapes = shape_map[node_name];
681       for (const auto& inferred_shape : output_shapes) {
682         partial_shapes.push_back(inferred_shape.shape);
683       }
684     }
685   }
686   graph_optimizer_options.shape_map = &shape_map;
687   optimizer.Optimize(flib_runtime_, flib_runtime_->env(),
688                      /*device=*/nullptr, &graph, graph_optimizer_options);
689 
690   return graph;
691 }
692 
693 // Collects all control rets from `orig_control_ret_nodes` that are still valid,
694 // keeping the same order.
695 std::vector<std::string> GetValidControlRets(
696     absl::Span<Node* const> orig_control_ret_nodes, const Graph& graph) {
697   // Build map from control ret node name to index.
698   // We use Node name instead of Node* here to index into the map as we populate
699   // the map with nodes in FunctionDef control_ret_nodes and later query it
700   // using the nodes in `graph`. The Node pointers would be different but the
701   // Node name is expected to remain the same between the two.
702   absl::flat_hash_map<const string, int> control_ret_nodes_map;
703   for (int i = 0; i < orig_control_ret_nodes.size(); ++i) {
704     const Node* n = orig_control_ret_nodes[i];
705     control_ret_nodes_map[n->name()] = i;
706   }
707   // Check which control rets are still valid.
708   std::vector<bool> is_valid_control_ret(orig_control_ret_nodes.size(), false);
709   int num_valid_control_rets = 0;
710   for (const Node* n : graph.nodes()) {
711     auto iter = control_ret_nodes_map.find(n->name());
712     if (iter != control_ret_nodes_map.end()) {
713       ++num_valid_control_rets;
714       is_valid_control_ret[iter->second] = true;
715     }
716   }
717   // Return valid control rets in same order as they appear in
718   // `orig_control_ret_nodes`.
719   std::vector<std::string> valid_control_rets;
720   valid_control_rets.reserve(num_valid_control_rets);
721   for (int i = 0; i < orig_control_ret_nodes.size(); ++i) {
722     if (is_valid_control_ret[i]) {
723       valid_control_rets.push_back(orig_control_ret_nodes[i]->name());
724     }
725   }
726   return valid_control_rets;
727 }
728 
729 Status XlaCompiler::CompileFunction(
730     const XlaCompiler::CompileOptions& options,
731     const NameAttrList& fn_name_attrs,
732     absl::Span<const XlaCompiler::Argument> args,
733     XlaCompiler::CompilationResult* result) {
734   const string function_id =
735       Canonicalize(fn_name_attrs.name(), AttrSlice(&fn_name_attrs.attr()));
736   VLOG(1) << "XlaCompiler::CompileFunction " << function_id;
737 
738   const std::vector<XlaCompiler::Argument> arg_vector(args.begin(), args.end());
739   auto it = cache_.find({function_id, arg_vector});
740   if (it != cache_.end()) {
741     *result = it->second;
742     return OkStatus();
743   }
744 
745   const FunctionBody* fbody;
746   const ConfigProto* config = nullptr;
747   TF_RETURN_IF_ERROR(FindFunctionBody(fn_name_attrs, &fbody, &config));
748 
749   std::optional<ConfigProto> config_proto;
750   if (config) {
751     config_proto = *config;
752   }
753 
754   TF_RETURN_WITH_CONTEXT_IF_ERROR(
755       CheckSignature(fbody->arg_types, args),
756       "Signature check failure while compiling: ", fn_name_attrs.name());
757 
758   // Set shapes for _Arg nodes. They are useful for constant folding (e.g. an
759   // Xla op requires a compile-time constant input, and that input is the shape
760   // of an _Arg node).
761   for (int i = 0, end = args.size(); i < end; i++) {
762     // Skip resource variables and tensor lists.
763     DataType dtype;
764     TF_RETURN_IF_ERROR(GetNodeAttr(fbody->arg_nodes[i]->def(), "T", &dtype));
765     if (dtype == DT_RESOURCE || dtype == DT_VARIANT) {
766       continue;
767     }
768 
769     if (absl::holds_alternative<xla::Shape>(args[i].shape)) {
770       xla::Shape xla_shape = std::get<xla::Shape>(args[i].shape);
771       TensorShape tensor_shape;
772       // If xla_shape is dynamic, prevent constant folding by not setting
773       // output_shapes.
774       if (XLAShapeToTensorShape(xla_shape, &tensor_shape).ok() &&
775           xla_shape.is_static()) {
776         fbody->arg_nodes[i]->ClearAttr("_output_shapes");
777         fbody->arg_nodes[i]->AddAttr("_output_shapes",
778                                      std::vector<TensorShape>{tensor_shape});
779       }
780     } else {
781       TensorShape tensor_shape = std::get<TensorShape>(args[i].shape);
782       fbody->arg_nodes[i]->ClearAttr("_output_shapes");
783       fbody->arg_nodes[i]->AddAttr("_output_shapes",
784                                    std::vector<TensorShape>{tensor_shape});
785     }
786   }
787 
788   std::unique_ptr<Graph> graph = GetGraph(fbody);
789 
790   // _Arg and _Retval nodes don't exist in the stored subgraph for the function;
791 // they are added when the function body is looked up.  Therefore, they don't have
792   // core assignments here.
793   // Attempt to assign a core to each _Retval and _Arg. Chooses the
794   // lowest-numbered core that consumes the argument. We choose the
795   // lowest-numbered core so the assignment is deterministic.
796   for (Node* n : graph->nodes()) {
797     if (n->IsArg()) {
798       TF_RETURN_IF_ERROR(SetNodeShardingFromNeighbors(n, /*out_edges=*/true));
799     }
800   }
801   // Do _Retval as a second loop, in case the retval's input is an _Arg (which
802   // may have gotten a device assignment from the first loop).
803   for (Node* n : graph->nodes()) {
804     if (n->IsRetval()) {
805       TF_RETURN_IF_ERROR(SetNodeShardingFromNeighbors(n, /*out_edges=*/false));
806     }
807   }
808 
809   if (VLOG_IS_ON(2)) {
810     VLOG(2) << "XlaCompiler::CompileFunction: "
811             << DumpGraphToFile(
812                    absl::StrCat("xla_compile_function_", function_id), *graph);
813   }
814 
815   VLOG(1) << "====================================================";
816 
817   auto state = ConfigProto::Experimental::MLIR_BRIDGE_ROLLOUT_DISABLED;
818   if (options.is_entry_computation) {
819     state = GetMlirBridgeRolloutState(config_proto);
820   }
821 
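  // The rollout state picks the lowering path: the MLIR bridge
  // (CompileGraphToXlaHlo) when enabled, otherwise the older bridge via
  // CompileGraph().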
822   if (state == ConfigProto::Experimental::MLIR_BRIDGE_ROLLOUT_ENABLED) {
823     GraphDebugInfo debug_info;
824     VLOG(1) << "Using the MLIR bridge to compile the function.";
825     std::vector<std::string> valid_control_rets =
826         GetValidControlRets(fbody->control_ret_nodes, *graph);
827     auto mlir_result = CompileGraphToXlaHlo(
828         std::move(*graph), mlir::SpanToArrayRef<XlaCompiler::Argument>(args),
829         valid_control_rets, options_.device_type.type_string(),
830         options.use_tuple_arg, /*analyse_graph=*/false, *options_.flib_def,
831         debug_info, options_.shape_determination_fns, result);
832     if (mlir_result.ok()) {
833       VLOG(1) << "MLIR bridge was successful";
834     } else {
835       VLOG(1) << "MLIR failed, no fallback";
836       return mlir_result;
837     }
838   } else {
839     VLOG(1) << "MLIR bridge off. Using the old bridge to compile the function";
840     TF_RETURN_IF_ERROR(
841         CompileGraph(options, function_id, std::move(graph), args, result));
842   }
843   VLOG(1) << "====================================================";
844 
845   cache_[{function_id, arg_vector}] = *result;
846   return OkStatus();
847 }
848 
849 // Computes the XLA shape for argument 'arg'.
850 Status XlaCompiler::XLAShapeForArgument(
851     const XlaCompiler::Argument& arg, bool is_entry_computation,
852     const std::optional<xla::HloSharding>& arg_sharding,
853     xla::Shape* xla_shape) const {
854   switch (arg.kind) {
855     case XlaCompiler::Argument::kConstant:
856       LOG(FATAL) << "Unreachable case";
857     case XlaCompiler::Argument::kParameter: {
858       if (is_entry_computation) {
859         TensorShape shape;
860         if (absl::holds_alternative<TensorShape>(arg.shape)) {
861           shape = std::get<TensorShape>(arg.shape);
862         } else {
863           TF_RETURN_IF_ERROR(
864               XLAShapeToTensorShape(std::get<xla::Shape>(arg.shape), &shape));
865         }
866         auto layout_preference =
867             options_.shape_determination_fns.layout_preference_fn(
868                 shape, arg.type, arg.kind);
869         TF_ASSIGN_OR_RETURN(
870             *xla_shape,
871             options_.shape_determination_fns.shape_representation_fn(
872                 shape, arg.type,
873                 /*use_fast_memory=*/false, layout_preference));
874         TF_RETURN_IF_ERROR(RewriteLayoutWithShardedShape(
875             arg_sharding, /*use_fast_memory=*/false,
876             options_.shape_determination_fns, xla_shape));
877       } else {
878         if (absl::holds_alternative<xla::Shape>(arg.shape)) {
879           *xla_shape = std::get<xla::Shape>(arg.shape);
880         } else {
881           TF_RETURN_IF_ERROR(TensorShapeToXLAShape(
882               arg.type, std::get<TensorShape>(arg.shape), xla_shape));
883         }
884       }
885       return OkStatus();
886     }
887     case XlaCompiler::Argument::kTensorList: {
888       TF_RET_CHECK(absl::holds_alternative<xla::Shape>(arg.shape));
889       *xla_shape = std::get<xla::Shape>(arg.shape);
890       return OkStatus();
891     }
892     case XlaCompiler::Argument::kConstantResource:
893     case XlaCompiler::Argument::kResource: {
894       TF_RET_CHECK(arg.initialized);
895 
896       switch (arg.resource_kind) {
897         case XlaResource::kVariable: {
898           TF_RET_CHECK(absl::holds_alternative<TensorShape>(arg.shape));
899           auto layout_preference =
900               options_.shape_determination_fns.layout_preference_fn(
901                   std::get<TensorShape>(arg.shape), arg.type, arg.kind);
902           TF_ASSIGN_OR_RETURN(
903               *xla_shape,
904               options_.shape_determination_fns.shape_representation_fn(
905                   std::get<TensorShape>(arg.shape), arg.type,
906                   /*use_fast_memory=*/arg.fast_mem, layout_preference));
907           TF_RETURN_IF_ERROR(RewriteLayoutWithShardedShape(
908               arg_sharding, arg.fast_mem, options_.shape_determination_fns,
909               xla_shape));
910           return OkStatus();
911         }
912         case XlaResource::kTensorArray: {
913           if (arg.max_array_size < 0) {
914             return errors::InvalidArgument(
915                 "Negative max_array_size in XLAShapeForArgument");
916           }
917           TF_RET_CHECK(absl::holds_alternative<TensorShape>(arg.shape));
918           TensorShape shape;
919           shape.AddDim(arg.max_array_size);
920           shape.AppendShape(std::get<TensorShape>(arg.shape));
921           TF_RETURN_IF_ERROR(TensorShapeToXLAShape(arg.type, shape, xla_shape));
922 
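          // When gradients are present, the TensorArray is represented as a
          // tuple of identical buffers: the value plus one per gradient. For
          // example, a float32 TensorArray with element shape [2],
          // max_array_size=8 and one gradient becomes (f32[8,2], f32[8,2]).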
923           if (!arg.tensor_array_gradients.empty()) {
924             std::vector<xla::Shape> tuple_shape(
925                 arg.tensor_array_gradients.size() + 1, *xla_shape);
926             *xla_shape = xla::ShapeUtil::MakeTupleShape(tuple_shape);
927           }
928           return OkStatus();
929         }
930         case XlaResource::kStack: {
931           if (arg.max_array_size < 0) {
932             return errors::InvalidArgument(
933                 "Negative max_array_size in XLAShapeForArgument");
934           }
935           TF_RET_CHECK(absl::holds_alternative<TensorShape>(arg.shape));
936           TensorShape shape;
937           shape.AddDim(arg.max_array_size);
938           shape.AppendShape(std::get<TensorShape>(arg.shape));
939           xla::Shape buffer_shape;
940           TF_RETURN_IF_ERROR(
941               TensorShapeToXLAShape(arg.type, shape, &buffer_shape));
942           *xla_shape = xla::ShapeUtil::MakeTupleShape(
943               {buffer_shape, xla::ShapeUtil::MakeShape(xla::S32, {})});
944           return OkStatus();
945         }
946 
947         case XlaResource::kInvalid:
948           return errors::Internal(
949               "Invalid resource type in XLAShapeForArgument()");
950       }
951     }
952     case XlaCompiler::Argument::kToken: {
953       *xla_shape = xla::ShapeUtil::MakeTokenShape();
954       return OkStatus();
955     }
956     case XlaCompiler::Argument::kInvalid:
957       return errors::Internal("Invalid argument type in XLAShapeForArgument()");
958   }
959 }
960 
961 /* static */
962 void XlaCompiler::PopulateArgumentFromResource(const XlaResource& resource,
963                                                Argument* arg) {
964   arg->initialized = resource.initialized();
965   arg->kind = XlaCompiler::Argument::kResource;
966   arg->resource_kind = resource.kind();
967 
968   arg->type = resource.type();
969   arg->shape = resource.shape();
970   arg->max_array_size = resource.max_array_size();
971   for (const auto& gradient : resource.tensor_array_gradients()) {
972     arg->tensor_array_gradients.insert(gradient.first);
973   }
974   arg->name = resource.name();
975 }
976 
977 // Builds XLA computations for each of the arguments to the computation.
978 // `args` are the arguments to the computation.
979 Status XlaCompiler::BuildArguments(
980     const Graph& graph, const std::vector<XlaCompiler::Argument>& args,
981     bool use_tuple_arg, xla::XlaBuilder* builder, XlaContext* context,
982     const std::map<int, xla::OpSharding>& arg_shardings,
983     std::vector<XlaExpression>* arg_expressions,
984     std::vector<int>* input_to_args, std::vector<xla::Shape>* input_shapes,
985     bool is_entry_computation) {
986   arg_expressions->resize(args.size());
987 
988   // Argument numbers of arguments and resources that are to be passed to the
989   // XLA computation as runtime parameters. `input_to_args[a] = b` means that
990   // the a'th XLA input corresponds to the b'th original arg index.
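  // For example, with args = {kConstant, kParameter, kResource (initialized)},
  // only indices 1 and 2 become XLA inputs, so input_to_args = {1, 2}.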
991   input_to_args->clear();
992   input_to_args->reserve(args.size());
993 
994   // Fills in constant arguments, and computes non-constant argument order.
995   for (std::vector<XlaCompiler::Argument>::size_type i = 0; i < args.size();
996        ++i) {
997     const XlaCompiler::Argument& arg = args[i];
998     XlaExpression& arg_expression = (*arg_expressions)[i];
999     switch (arg.kind) {
1000       case XlaCompiler::Argument::kConstantResource:
1001       case XlaCompiler::Argument::kResource: {
1002         TF_RET_CHECK(arg.resource_kind != XlaResource::kInvalid);
1003         TF_RET_CHECK(absl::holds_alternative<TensorShape>(arg.shape));
1004         // TODO(phawkins): this code assumes that resource arguments do not
1005         // alias.
1006         XlaResource* resource =
1007             context->AddResource(std::make_unique<XlaResource>(
1008                 arg.resource_kind, i, arg.name, arg.type,
1009                 std::get<TensorShape>(arg.shape), xla::XlaOp(),
1010                 /*max_array_size=*/arg.max_array_size,
1011                 /*tensor_array_gradients=*/arg.tensor_array_gradients,
1012                 /*tensor_array_multiple_writes_aggregate=*/true,
1013                 arg.definition_stack_trace));
1014         arg_expression =
1015             arg.kind == XlaCompiler::Argument::kResource
1016                 ? XlaExpression::Resource(resource)
1017                 : XlaExpression::ConstantResource(arg.constant_value, resource);
1018         if (arg.initialized) {
1019           input_to_args->push_back(i);
1020         }
1021         break;
1022       }
1023       case XlaCompiler::Argument::kParameter:
1024       case XlaCompiler::Argument::kTensorList:
1025       case XlaCompiler::Argument::kToken: {
1026         input_to_args->push_back(i);
1027         break;
1028       }
1029       case XlaCompiler::Argument::kConstant:
1030         arg_expression = XlaExpression::Constant(arg.constant_value);
1031         break;
1032       case XlaCompiler::Argument::kInvalid:
1033         return errors::Internal(
1034             "Unreachable case in BuildArguments() while filling constant args");
1035     }
1036   }
1037 
1038   if (input_to_args->empty() && !use_tuple_arg) {
1039     return OkStatus();
1040   }
1041 
1042   // `arg_to_inputs[c] = d` means that the c'th original arg index corresponds
1043   // to the d'th XLA input. Note that the value -1 corresponds to constants, or
1044   // other args that don't correspond to an input.
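  // Continuing the example above, arg_to_inputs would be {-1, 0, 1}.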
1045   std::vector<int> arg_to_inputs(args.size(), -1);
1046   for (int i = 0, end = input_to_args->size(); i < end; i++) {
1047     arg_to_inputs[input_to_args->at(i)] = i;
1048   }
1049 
1050   std::vector<xla::Shape> arg_shapes(input_to_args->size());
1051   for (std::vector<int>::size_type i = 0; i < input_to_args->size(); ++i) {
1052     // Computes the shapes of non-constant arguments.
1053     auto arg_sharding = arg_shardings.find((*input_to_args)[i]);
1054     std::optional<xla::HloSharding> sharding;
1055     if (arg_sharding != arg_shardings.end()) {
1056       TF_ASSIGN_OR_RETURN(auto hlo_sharding,
1057                           xla::HloSharding::FromProto(arg_sharding->second));
1058       sharding = hlo_sharding;
1059     }
1060     TF_RETURN_IF_ERROR(XLAShapeForArgument(args[(*input_to_args)[i]],
1061                                            is_entry_computation, sharding,
1062                                            &arg_shapes[i]));
1063   }
1064 
1065   if (use_tuple_arg) {
1066     input_shapes->push_back(xla::ShapeUtil::MakeTupleShape(arg_shapes));
1067   } else {
1068     *input_shapes = arg_shapes;
1069   }
1070 
1071   // Attach a common operator name as metadata. This has no semantic effect — it
1072   // merely makes the HLO graph more readable when visualized via TensorBoard,
1073   // since TensorBoard forms groups out of operators with similar names.
1074   xla::OpMetadata arg_metadata;
1075   arg_metadata.set_op_name("XLA_Args");
1076   builder->SetOpMetadata(arg_metadata);
1077 
1078   // Build parameter handles for non-constant arguments.
1079   std::vector<xla::XlaOp> arg_handles(input_to_args->size());
1080   if (use_tuple_arg) {
1081     xla::XlaOp tuple;
1082     if (is_entry_computation) {
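      // Build a TUPLE sharding with one entry per tuple element; arguments
      // without an explicit sharding are assigned to device 0.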
1083       xla::OpSharding tuple_sharding;
1084       tuple_sharding.set_type(xla::OpSharding::TUPLE);
1085       for (int64_t parameter : *input_to_args) {
1086         auto it = arg_shardings.find(parameter);
1087         *tuple_sharding.add_tuple_shardings() =
1088             it == arg_shardings.end() ? xla::sharding_builder::AssignDevice(0)
1089                                       : it->second;
1090       }
1091       std::vector<bool> is_same_across_replicas;
1092       for (int i = 0, end = input_to_args->size(); i < end; ++i) {
1093         // Add an entry to is_same_across_replicas for every leaf buffer.
1094         is_same_across_replicas.insert(
1095             is_same_across_replicas.end(),
1096             xla::ShapeUtil::GetLeafCount(arg_shapes[i]),
1097             args[input_to_args->at(i)].is_same_data_across_replicas);
1098       }
1099       xla::XlaScopedShardingAssignment assign_tuple_sharding(
1100           builder, input_to_args->empty() ? std::optional<xla::OpSharding>()
1101                                           : tuple_sharding);
1102       tuple = xla::Parameter(builder, 0, (*input_shapes)[0], "arg_tuple",
1103                              is_same_across_replicas);
1104     } else {
1105       tuple = xla::Parameter(builder, 0, (*input_shapes)[0], "arg_tuple");
1106     }
1107 
1108     for (std::vector<int>::size_type i = 0; i < input_to_args->size(); ++i) {
1109       auto it = arg_shardings.find(i);
1110       xla::XlaScopedShardingAssignment assign_sharding(
1111           builder, it == arg_shardings.end() ? std::optional<xla::OpSharding>()
1112                                              : it->second);
1113       auto& arg = args[input_to_args->at(i)];
1114 
1115       xla::OpMetadata arg_metadata;
1116       arg_metadata.set_op_name(arg.node_name);
1117       builder->SetOneShotOpMetadata(arg_metadata);
1118       arg_handles[i] = xla::GetTupleElement(tuple, i);
1119     }
1120   } else {
1121     for (std::vector<int>::size_type i = 0; i < input_to_args->size(); ++i) {
1122       auto it = arg_shardings.find(i);
1123       xla::XlaScopedShardingAssignment assign_sharding(
1124           builder, it == arg_shardings.end() ? std::optional<xla::OpSharding>()
1125                                              : it->second);
1126       if (is_entry_computation) {
1127         // Add an entry to is_same_across_replicas for every leaf buffer.
1128         std::vector<bool> is_same_across_replicas(
1129             xla::ShapeUtil::GetLeafCount((*input_shapes)[i]),
1130             args[input_to_args->at(i)].is_same_data_across_replicas);
1131         arg_handles[i] =
1132             xla::Parameter(builder, i, (*input_shapes)[i],
1133                            absl::StrCat("arg", i), is_same_across_replicas);
1134       } else {
1135         arg_handles[i] = xla::Parameter(builder, i, (*input_shapes)[i],
1136                                         absl::StrCat("arg", i));
1137       }
1138     }
1139   }
1140 
1141   builder->ClearOpMetadata();
1142 
1143   // Fill in the handles in non-constant arguments, and reshape parameters
1144   // back to their correct shapes.
1145   VLOG(2) << "XLA computation inputs:";
1146   for (std::vector<int>::size_type i = 0; i < input_to_args->size(); ++i) {
1147     const XlaCompiler::Argument& arg = args[input_to_args->at(i)];
1148     VLOG(2) << "  XLA arg " << i
1149             << " shape: " << xla::ShapeUtil::HumanString(arg_shapes[i])
1150             << " name: " << arg.name << " TF arg " << input_to_args->at(i)
1151             << " node name: " << arg.node_name
1152             << (arg_shardings.find(i) == arg_shardings.end()
1153                     ? ""
1154                     : absl::StrCat(" sharding: ",
1155                                    arg_shardings.at(i).DebugString()));
1156     XlaExpression& arg_expression = (*arg_expressions)[input_to_args->at(i)];
1157     switch (arg.kind) {
1158       case XlaCompiler::Argument::kConstantResource:
1159       case XlaCompiler::Argument::kResource: {
1160         TF_RET_CHECK(arg.initialized);
1161         XlaResource* resource = arg_expression.resource();
1162         TF_RETURN_IF_ERROR(resource->SetFromPack(arg.tensor_array_gradients,
1163                                                  arg_handles[i], builder));
1164         VLOG(2) << "    resource: num_gradients: "
1165                 << arg.tensor_array_gradients.size();
1166         break;
1167       }
1168       case XlaCompiler::Argument::kParameter:
1169         // Reshape parameters back to their correct shapes.
1170         // TODO(b/76097077): propagate device assignments onto arguments and
1171         // return values of functions, and then reshape unconditionally.
1172         if (is_entry_computation) {
1173           arg_expression = XlaExpression::XlaOp(
1174               xla::Reshape(arg_handles[i], arg.DimensionSizes()), arg.type);
1175         } else {
1176           arg_expression = XlaExpression::XlaOp(arg_handles[i], arg.type);
1177           if (arg.value_bound) {
1178             TF_RET_CHECK(arg.value_dynamism);
1179             // Propagate upper bound and value dynamism to arg_expression.
1180             arg_expression.set_value_bound(arg.value_bound.value());
1181             arg_expression.set_value_dynamism(arg.value_dynamism.value());
1182           }
1183         }
1184         break;
1185       case XlaCompiler::Argument::kTensorList: {
1186         arg_expression = XlaExpression::TensorList(arg_handles[i]);
1187         break;
1188       }
1189       case XlaCompiler::Argument::kToken: {
1190         arg_expression = XlaExpression::XlaOp(arg_handles[i], arg.type);
1191         break;
1192       }
1193       case XlaCompiler::Argument::kConstant:
1194       case XlaCompiler::Argument::kInvalid:
1195         return errors::Internal(
1196             "Unreachable case in BuildArguments() while filling handles");
1197     }
1198   }
1199 
1200   return OkStatus();
1201 }
1202 
1203 namespace {
1204 
1205 // Checks that the ops of all non-function-call nodes in `fdef` are registered.
1206 Status ValidateFunctionDef(const FunctionDef* fdef,
1207                            const FunctionLibraryDefinition& flib_def) {
1208   for (const NodeDef& node : fdef->node_def()) {
1209     const string& op = node.op();
1210     if (op == FunctionLibraryDefinition::kGradientOp || flib_def.Find(op)) {
1211       continue;
1212     }
1213     const OpDef* op_def;
1214     TF_RETURN_IF_ERROR(OpRegistry::Global()->LookUpOpDef(op, &op_def));
1215   }
1216   return OkStatus();
1217 }
1218 
1219 // If node is PartitionedCall or StatefulPartitionedCall, returns the
1220 // name from the "f" attr, else returns node.def().op().
1221 // The returned pointer points to an internal string, either in the node's
1222 // attributes or in its NodeDef, and remains valid only as long as the node
1223 // is not modified.
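     // For example, a StatefulPartitionedCall node whose "f" attr is {name: "my_fn"}
     // yields "my_fn", while a plain MatMul node yields "MatMul".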
1224 Status GetPotentialFunctionName(const Node& node, const string** name) {
1225   if (node.IsPartitionedCall()) {
1226     const AttrValue* attr_value;
1227     TF_RETURN_IF_ERROR(
1228         node.attrs().Find(FunctionLibraryDefinition::kFuncAttr, &attr_value));
1229     if (!attr_value->has_func()) {
1230       return errors::InvalidArgument(
1231           "The attribute value for attribute 'f' in node ", node.DebugString(),
1232           " does not have 'func' field set");
1233     }
1234     *name = &attr_value->func().name();
1235     return OkStatus();
1236   }
1237   *name = &node.type_string();
1238   return OkStatus();
1239 }
1240 
1241 // Check that the graph doesn't have any invalid nodes (e.g. incompatible with
1242 // given device_type, invalid data type, missing attributes...)
1243 Status ValidateGraph(const Graph* graph,
1244                      const FunctionLibraryDefinition& flib_def,
1245                      const DeviceType& device_type, const string& name) {
1246   // Make sure the XLA compilation kernels are registered.  This operation is
1247   // idempotent so it is fine if someone called it already.
1248   XlaOpRegistry::RegisterCompilationKernels();
1249 
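       // Wraps a failed per-node status into an InvalidArgument error that names
       // the graph, device type and op, adds a soft-placement hint on TPU, and
       // appends the op's creation stack trace when one is available.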
1250   auto maybe_error = [&](const Node* node, const Status& s) -> Status {
1251     if (!s.ok()) {
1252       std::string errmsg = absl::StrCat(
1253           "Detected unsupported operations when trying to compile graph ", name,
1254           " on ", device_type.type_string(), ": ", node->def().op(), " (",
1255           s.error_message(), ")", FormatNodeForError(*node));
1256       if (absl::StrContains(device_type.type_string(), "TPU")) {
1257         absl::StrAppend(&errmsg,
1258                         "\nOne approach is to outside-compile the unsupported "
1259                         "ops so that they run on the CPU, by enabling soft "
1260                         "device placement with `tf.config.set_soft_device_placement(True)`."
1261                         " This may incur a performance penalty.\n");
1262       }
1263       if (std::shared_ptr<AbstractStackTrace> stack_trace =
1264               node->GetStackTrace()) {
1265         absl::StrAppend(
1266             &errmsg, "\nThe op is created at: \n",
1267             stack_trace->ToString({/*show_line_contents=*/true,
1268                                    /*filter_common_prefix=*/true,
1269                                    /*drop_internal_frames=*/true}));
1270       }
1271 
1272       return errors::InvalidArgument(errmsg);
1273     }
1274     return OkStatus();
1275   };
1276 
1277   for (const Node* node : graph->nodes()) {
1278     if (node->type_string() == FunctionLibraryDefinition::kGradientOp) {
1279       continue;
1280     }
1281     const string* function_name;
1282     TF_RETURN_IF_ERROR(GetPotentialFunctionName(*node, &function_name));
1283     const FunctionDef* fdef = flib_def.Find(*function_name);
1284     Status s;
1285     if (fdef) {
1286       s = ValidateFunctionDef(fdef, flib_def);
1287       TF_RETURN_IF_ERROR(maybe_error(node, s));
1288       continue;
1289     }
1290     const OpDef* op_def;
1291     s = OpRegistry::Global()->LookUpOpDef(node->def().op(), &op_def);
1292     TF_RETURN_IF_ERROR(maybe_error(node, s));
1293     TF_RETURN_IF_ERROR(ValidateNodeDef(node->def(), *op_def));
1294     s = FindKernelDef(device_type, node->def(), nullptr, nullptr);
1295     TF_RETURN_IF_ERROR(maybe_error(node, s));
1296   }
1297   return OkStatus();
1298 }
1299 
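     // Rewrites every kConstant expression in `expressions` as an XlaOp expression
     // materialized on `builder`, so retvals are handled uniformly when the
     // computation is built.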
1300 void ConvertConstantsToExpressions(xla::XlaBuilder* builder,
1301                                    absl::Span<XlaExpression> expressions) {
1302   for (XlaExpression& expression : expressions) {
1303     if (expression.kind() == XlaExpression::Kind::kConstant) {
1304       expression =
1305           XlaExpression::XlaOp(expression.AsXlaOp(builder), expression.dtype());
1306     }
1307   }
1308 }
1309 
1310 }  // namespace
1311 
1312 Status XlaCompiler::CompileGraph(
1313     const XlaCompiler::CompileOptions& options, string const& name,
1314     std::unique_ptr<Graph> graph, absl::Span<const XlaCompiler::Argument> args,
1315     CompilationResult* result) {
1316   VLOG(1) << "Executing graph symbolically to populate XlaBuilder: " << name;
1317 
1318   TF_RETURN_IF_ERROR(PropagateConstIntoFunctionalNodes(
1319       graph.get(), options_.flib_def, local_flib_def_.get()));
1320   TF_RETURN_IF_ERROR(RearrangeFunctionArguments(
1321       [this](const NameAttrList& function, const FunctionBody** fbody) {
1322         return FindFunctionBody(function, fbody);
1323       },
1324       graph.get(), local_flib_def_.get(),
1325       pflr_->GetFunctionLibraryDefinition()));
1326 
1327   if (VLOG_IS_ON(2)) {
1328     VLOG(2) << "XlaCompiler::CompileGraph: "
1329             << DumpGraphToFile(absl::StrCat("xla_compile_graph_", name), *graph,
1330                                flib_runtime_->GetFunctionLibraryDefinition());
1331   }
1332 
1333   // Report the error here if initialization failed.
1334   TF_RETURN_IF_ERROR(initialization_status_);
1335 
1336   // Detect invalid nodes.
1337   // FunctionalizeControlFlow may remove some nodes from the graph.
1338   TF_RETURN_IF_ERROR(ValidateGraph(graph.get(), *options_.flib_def,
1339                                    options_.device_type, name));
1340   xla::XlaBuilder builder(name);
1341   XlaContext* context = new XlaContext(this, &builder, graph.get());
1342   core::ScopedUnref context_unref(context);
1343 
1344   std::vector<XlaCompiler::Argument> real_args(args.begin(), args.end());
1345   int token_input_index = -1;
1346   std::unique_ptr<xla::XlaOp> token_output;
1347   if (options.add_token_input_output) {
1348     // Add extra token input.
1349     token_input_index = real_args.size();
1350 
1351     XlaCompiler::Argument token_arg;
1352     token_arg.kind = XlaCompiler::Argument::kToken;
1353     real_args.push_back(token_arg);
1354   }
1355 
1356   std::map<int, xla::OpSharding> arg_shardings;
1357   std::map<int, xla::OpSharding> retval_shardings;
1358   TF_ASSIGN_OR_RETURN(std::tie(arg_shardings, retval_shardings),
1359                       ComputeArgAndRetvalShardings(*graph));
1360 
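       // Lower the arguments into XLA parameters (a single tuple parameter when
       // options.use_tuple_arg is set), recording the input mapping and per-input
       // XLA shapes in `result`.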
1361   std::vector<XlaExpression> arg_expressions;
1362   TF_RETURN_IF_ERROR(BuildArguments(
1363       *graph, real_args, options.use_tuple_arg, &builder, context,
1364       arg_shardings, &arg_expressions, &result->input_mapping,
1365       &result->xla_input_shapes, options.is_entry_computation));
1366   context->set_args(std::move(arg_expressions));
1367 
1368   PushNodeTokenMapping();
1369   // Use std::set instead of std::unordered_set to ensure determinism.
1370   std::set<std::string> output_node_token_inputs;
1371   if (token_input_index != -1) {
1372     // Original token comes from input.
1373     auto arg_expression = context->args()[token_input_index];
1374     TF_RETURN_IF_ERROR(
1375         SetNodeToken(kXlaTokenArgNodeName, arg_expression.handle()));
1376 
1377     // Calculate token inputs for output token.
1378     output_node_token_inputs = CalculateTokenInputsForOutputToken(*graph);
1379 
1380     // If there's no side-effecting op in the graph, use token input as token
1381     // output.
1382     if (output_node_token_inputs.empty()) {
1383       output_node_token_inputs.insert(kXlaTokenArgNodeName);
1384     }
1385   } else if (options.is_entry_computation) {
1386     // Original token is manually created.
1387     if (HasSideEffectingNodes(*graph)) {
1388       TF_RETURN_IF_ERROR(
1389           SetNodeToken(kXlaTokenArgNodeName, xla::CreateToken(&builder)));
1390     }
1391   }
1392 
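       // Symbolically execute the TF graph: each node's kernel runs on the XLA
       // compilation device and emits XLA ops into `builder` through the
       // XlaContext.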
1393   Status execute_status = ExecuteGraph(context, std::move(graph), device_,
1394                                        flib_runtime_, NextStepId());
1395   if (!execute_status.ok()) {
1396     VLOG(1) << "Failed executing graph " << name;
1397     return execute_status;
1398   }
1399   if (token_input_index != -1) {
1400     // Add extra token output.
1401     std::vector<xla::XlaOp> token_inputs;
1402     for (const auto& node_name : output_node_token_inputs) {
1403       auto token_or = GetNodeToken(node_name);
1404       TF_RETURN_IF_ERROR(token_or.status());
1405       token_inputs.push_back(token_or.ValueOrDie());
1406     }
1407     token_output.reset(new xla::XlaOp(xla::AfterAll(&builder, token_inputs)));
1408   }
1409   TF_RETURN_IF_ERROR(PopNodeTokenMapping());
1410 
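       // Assemble the final XlaComputation: materialize constant retvals, then let
       // BuildComputation emit the output tuple, resource updates and (optionally)
       // the token output.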
1411   int num_nonconst_outputs;
1412   int num_computation_outputs;
1413   result->computation = std::make_shared<xla::XlaComputation>();
1414   result->outputs.resize(context->retvals().size());
1415   std::vector<XlaExpression> retvals = context->retvals();
1416   ConvertConstantsToExpressions(&builder, absl::Span<XlaExpression>(retvals));
1417   XlaShapeLayoutHelpers::ShapeDeterminationFns shape_determination_fns{
1418       UseNoPreferenceLayoutFn(), IdentityShapeRepresentationFn()};
1419   TF_RETURN_IF_ERROR(BuildComputation(
1420       real_args, retvals, arg_shardings, retval_shardings, context->resources(),
1421       std::move(token_output),
1422       options.is_entry_computation ? options_.shape_determination_fns
1423                                    : shape_determination_fns,
1424       options.is_entry_computation,
1425       options.return_updated_values_for_all_resources,
1426       options.always_return_tuple, options.use_tuple_arg,
1427       options.alias_resource_update, &builder, result->computation.get(),
1428       &num_computation_outputs, &num_nonconst_outputs, &result->outputs,
1429       &result->resource_updates, &result->xla_output_shape,
1430       result->input_mapping));
1431 
1432   VLOG(2) << "Outputs: total: " << context->retvals().size()
1433           << " nonconstant: " << num_nonconst_outputs;
1434   VLOG(2) << "XLA output shape: "
1435           << xla::ShapeUtil::HumanStringWithLayout(result->xla_output_shape);
1436   result->collective_info = context->GetCollectiveInfo();
1437   return OkStatus();
1438 }
1439 
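     // Looks up (or creates on first use) the xla::ChannelHandle cached under
     // `key` and returns it in `*channel`. The host-to-device and device-to-host
     // variants below follow the same memoization pattern.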
1440 Status XlaCompiler::GetChannelHandle(const string& key,
1441                                      xla::ChannelHandle* channel) {
1442   auto result = channels_.emplace(key, xla::ChannelHandle());
1443   if (result.second) {
1444     TF_ASSIGN_OR_RETURN(result.first->second, client()->CreateChannelHandle());
1445   }
1446   *channel = result.first->second;
1447   VLOG(1) << "Channel: " << key << " " << channel->DebugString();
1448   return OkStatus();
1449 }
1450 
1451 Status XlaCompiler::GetHostToDeviceChannelHandle(const string& key,
1452                                                  xla::ChannelHandle* channel) {
1453   auto result = channels_.emplace(key, xla::ChannelHandle());
1454   if (result.second) {
1455     TF_ASSIGN_OR_RETURN(result.first->second,
1456                         client()->CreateHostToDeviceChannelHandle());
1457   }
1458   *channel = result.first->second;
1459   VLOG(1) << "Host to device channel: " << key << " " << channel->DebugString();
1460   return OkStatus();
1461 }
1462 
1463 Status XlaCompiler::GetDeviceToHostChannelHandle(const string& key,
1464                                                  xla::ChannelHandle* channel) {
1465   auto result = channels_.emplace(key, xla::ChannelHandle());
1466   if (result.second) {
1467     TF_ASSIGN_OR_RETURN(result.first->second,
1468                         client()->CreateDeviceToHostChannelHandle());
1469   }
1470   *channel = result.first->second;
1471   VLOG(1) << "Device to host channel: " << key << " " << channel->DebugString();
1472   return OkStatus();
1473 }
1474 
1475 namespace {
1476 
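     // Fills `transfer` with `key` and one TensorMetadata entry (type and shape)
     // per element of the parallel `types`/`shapes` spans.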
1477 void SetTransfer(const string& key, absl::Span<const DataType> types,
1478                  absl::Span<const TensorShape> shapes,
1479                  tf2xla::HostTransferMetadata* transfer) {
1480   transfer->set_key(key);
1481   CHECK_EQ(types.size(), shapes.size());
1482   for (int i = 0, end = types.size(); i < end; ++i) {
1483     tf2xla::TensorMetadata* metadata = transfer->add_metadata();
1484     metadata->set_type(types[i]);
1485     shapes[i].AsProto(metadata->mutable_shape());
1486   }
1487 }
1488 
1489 }  // namespace
1490 
1491 Status XlaCompiler::SetDeviceToHostMetadata(
1492     const string& key, absl::Span<const DataType> types,
1493     absl::Span<const TensorShape> shapes) {
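       // Re-registering identical metadata for an existing key is a no-op;
       // conflicting metadata for the same key is an InvalidArgument error.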
1494   if (host_compute_sends_.find(key) != host_compute_sends_.end()) {
1495     tf2xla::HostTransferMetadata& existing_transfer = host_compute_sends_[key];
1496     tf2xla::HostTransferMetadata new_transfer;
1497     SetTransfer(key, types, shapes, &new_transfer);
1498     if (xla::protobuf_util::ProtobufEquals(existing_transfer, new_transfer)) {
1499       return OkStatus();
1500     } else {
1501       return errors::InvalidArgument(
1502           "Duplicate calls to SetDeviceToHostMetadata with key ", key);
1503     }
1504   }
1505   tf2xla::HostTransferMetadata& transfer = host_compute_sends_[key];
1506   SetTransfer(key, types, shapes, &transfer);
1507   return OkStatus();
1508 }
1509 
1510 Status XlaCompiler::GetDeviceToHostShapes(
1511     const string& key, std::vector<TensorShape>* shapes) const {
1512   const auto iter = host_compute_sends_.find(key);
1513   if (iter == host_compute_sends_.end()) {
1514     return errors::InvalidArgument(
1515         "No host compute send shapes registered for key ", key);
1516   }
1517   shapes->clear();
1518   for (int i = 0; i < iter->second.metadata_size(); ++i) {
1519     TensorShape shape(iter->second.metadata(i).shape());
1520     shapes->push_back(shape);
1521   }
1522   return OkStatus();
1523 }
1524 
1525 Status XlaCompiler::SetHostToDeviceMetadata(
1526     const string& key, absl::Span<const DataType> types,
1527     absl::Span<const TensorShape> shapes) {
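       // Same duplicate-key policy as SetDeviceToHostMetadata above.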
1528   if (host_compute_recvs_.find(key) != host_compute_recvs_.end()) {
1529     tf2xla::HostTransferMetadata& existing_transfer = host_compute_recvs_[key];
1530     tf2xla::HostTransferMetadata new_transfer;
1531     SetTransfer(key, types, shapes, &new_transfer);
1532     if (xla::protobuf_util::ProtobufEquals(existing_transfer, new_transfer)) {
1533       return OkStatus();
1534     } else {
1535       return errors::InvalidArgument(
1536           "Duplicate calls to SetHostToDeviceMetadata with key ", key);
1537     }
1538   }
1539   tf2xla::HostTransferMetadata& transfer = host_compute_recvs_[key];
1540   SetTransfer(key, types, shapes, &transfer);
1541   return OkStatus();
1542 }
1543 
1544 Status XlaCompiler::GetHostComputeControlDependency(
1545     const string& host_compute_name, xla::XlaOp* handle) {
1546   const auto iter = host_compute_control_output_.find(host_compute_name);
1547   if (iter == host_compute_control_output_.end()) {
1548     return errors::InvalidArgument(
1549         "No registered control handle for host compute Op '", host_compute_name,
1550         "'");
1551   } else {
1552     *handle = iter->second;
1553   }
1554   return OkStatus();
1555 }
1556 
1557 Status XlaCompiler::SetHostComputeControlDependency(
1558     const string& host_compute_name, const xla::XlaOp& handle) {
1559   if (host_compute_control_output_.find(host_compute_name) !=
1560       host_compute_control_output_.end()) {
1561     return errors::InvalidArgument(
1562         "Duplicate control handles registered for host compute Op ",
1563         host_compute_name);
1564   }
1565   host_compute_control_output_[host_compute_name] = handle;
1566   return OkStatus();
1567 }
1568 
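     // node_token_mapping_stack_ holds one node-name -> token map per graph being
     // compiled: CompileGraph pushes a map before symbolically executing a graph
     // and pops it after the graph's token outputs have been collected.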
1569 void XlaCompiler::PushNodeTokenMapping() {
1570   node_token_mapping_stack_.emplace(std::map<string, xla::XlaOp>{});
1571 }
1572 
1573 Status XlaCompiler::PopNodeTokenMapping() {
1574   if (node_token_mapping_stack_.empty()) {
1575     return errors::FailedPrecondition(
1576         "Calling PopNodeTokenMapping() when node_token_mapping_stack_ is "
1577         "empty.");
1578   }
1579   node_token_mapping_stack_.pop();
1580   return OkStatus();
1581 }
1582 
1583 Status XlaCompiler::SetNodeToken(const string& node_name,
1584                                  const xla::XlaOp& op) {
1585   if (node_token_mapping_stack_.empty()) {
1586     return errors::FailedPrecondition(
1587         "Calling SetNodeToken() when node_token_mapping_stack_ is "
1588         "empty.");
1589   }
1590   auto insert_result = node_token_mapping_stack_.top().insert({node_name, op});
1591   if (!insert_result.second) {
1592     return errors::FailedPrecondition("Token mapping already exists for node ",
1593                                       node_name);
1594   }
1595   return OkStatus();
1596 }
1597 
1598 StatusOr<xla::XlaOp> XlaCompiler::GetNodeToken(const string& node_name) {
1599   if (node_token_mapping_stack_.empty()) {
1600     return errors::FailedPrecondition(
1601         "Calling GetNodeToken() when node_token_mapping_stack_ is "
1602         "empty.");
1603   }
1604   auto iter = node_token_mapping_stack_.top().find(node_name);
1605   if (iter == node_token_mapping_stack_.top().end()) {
1606     return errors::FailedPrecondition("Cannot find token mapping for node ",
1607                                       node_name);
1608   }
1609   return iter->second;
1610 }
1611 
1612 }  // namespace tensorflow
1613