xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/tf2xla/graph_compiler.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/tf2xla/graph_compiler.h"
17 
18 #include <deque>
19 #include <numeric>
20 #include <utility>
21 #include <vector>
22 
23 #include "tensorflow/compiler/tf2xla/const_analysis.h"
24 #include "tensorflow/compiler/tf2xla/literal_util.h"
25 #include "tensorflow/compiler/tf2xla/shape_util.h"
26 #include "tensorflow/compiler/tf2xla/side_effect_util.h"
27 #include "tensorflow/compiler/tf2xla/type_util.h"
28 #include "tensorflow/compiler/tf2xla/xla_compiler.h"
29 #include "tensorflow/compiler/tf2xla/xla_context.h"
30 #include "tensorflow/compiler/tf2xla/xla_expression.h"
31 #include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
32 #include "tensorflow/compiler/xla/client/client_library.h"
33 #include "tensorflow/compiler/xla/client/xla_builder.h"
34 #include "tensorflow/core/common_runtime/device.h"
35 #include "tensorflow/core/common_runtime/executor.h"
36 #include "tensorflow/core/common_runtime/function.h"
37 #include "tensorflow/core/common_runtime/graph_constructor.h"
38 #include "tensorflow/core/common_runtime/graph_optimizer.h"
39 #include "tensorflow/core/framework/attr_value.pb.h"
40 #include "tensorflow/core/framework/attr_value_util.h"
41 #include "tensorflow/core/framework/function.h"
42 #include "tensorflow/core/framework/node_def_util.h"
43 #include "tensorflow/core/framework/op_kernel.h"
44 #include "tensorflow/core/graph/algorithm.h"
45 #include "tensorflow/core/graph/node_builder.h"
46 #include "tensorflow/core/graph/validate.h"
47 #include "tensorflow/core/lib/core/errors.h"
48 #include "tensorflow/core/lib/gtl/cleanup.h"
49 #include "tensorflow/core/lib/hash/hash.h"
50 #include "tensorflow/core/platform/logging.h"
51 #include "tensorflow/core/public/version.h"
52 #include "tensorflow/core/util/dump_graph.h"
53 
54 namespace tensorflow {
55 
56 namespace {
PrepareArguments(XlaOpKernelContext * ctx,Graph * graph,const std::vector<const XlaExpression * > & expressions,const NameAttrList & func,std::vector<XlaCompiler::Argument> * args)57 Status PrepareArguments(XlaOpKernelContext* ctx, Graph* graph,
58                         const std::vector<const XlaExpression*>& expressions,
59                         const NameAttrList& func,
60                         std::vector<XlaCompiler::Argument>* args) {
61   auto client = ctx->compiler()->client();
62   std::vector<bool> arg_must_be_compile_time_constant(expressions.size());
63 
64   TF_RETURN_IF_ERROR(BackwardsConstAnalysis(
65       *graph, &arg_must_be_compile_time_constant,
66       /*compile_time_const_nodes=*/nullptr, ctx->function_library()));
67 
68   args->resize(expressions.size());
69   for (int i = 0, end = args->size(); i < end; ++i) {
70     XlaCompiler::Argument& arg = (*args)[i];
71     arg.type = ctx->input_type(i);
72     arg.shape = ctx->InputShape(i);
73 
74     switch (expressions[i]->kind()) {
75       case XlaExpression::Kind::kConstant:
76         arg.kind = XlaCompiler::Argument::kConstant;
77         arg.constant_value = *expressions[i]->constant_value();
78         break;
79       case XlaExpression::Kind::kXlaOp:
80         if (arg_must_be_compile_time_constant[i]) {
81           TF_ASSIGN_OR_RETURN(std::optional<Tensor> value,
82                               expressions[i]->ResolveConstant(client));
83           if (value.has_value()) {
84             arg.kind = XlaCompiler::Argument::kConstant;
85             arg.constant_value = *value;
86           } else {
87             arg.kind = XlaCompiler::Argument::kParameter;
88           }
89 
90         } else {
91           arg.kind = XlaCompiler::Argument::kParameter;
92         }
93         break;
94       case XlaExpression::Kind::kResource: {
95         XlaResource* resource = expressions[i]->resource();
96         XlaCompiler::PopulateArgumentFromResource(*resource, &arg);
97         break;
98       }
99       case XlaExpression::Kind::kTensorList: {
100         arg.kind = XlaCompiler::Argument::kTensorList;
101         const xla::XlaOp& tensor_list = expressions[i]->handle();
102         arg.shape = tensor_list.builder()->GetShape(tensor_list).ValueOrDie();
103         break;
104       }
105       case XlaExpression::Kind::kInvalid:
106         return errors::InvalidArgument("Invalid function argument");
107     }
108   }
109   return OkStatus();
110 }
111 }  // namespace
Compile()112 Status GraphCompiler::Compile() {
113   // Check that the graph has no illegal cycles.
114   TF_RETURN_IF_ERROR(graph::ValidateGraphHasNoCycle(*graph_));
115   // Maintain a mapping from node id to node outputs.
116   using NodeOutputs = std::vector<TensorValue>;
117   std::vector<NodeOutputs> output_registry(graph_->num_node_ids());
118   auto output_registry_cleanup = gtl::MakeCleanup([&output_registry] {
119     for (const NodeOutputs& outputs : output_registry) {
120       for (const TensorValue& value : outputs) {
121         CHECK(!value.is_ref());
122         delete value.tensor;
123       }
124     }
125   });
126 
127   // XLA requires determinism, generate a stable ordering from DFS.
128   std::vector<Node*> topo_sorted_nodes;
129   GetReversePostOrder(*graph_, &topo_sorted_nodes,
130                       /*stable_comparator=*/NodeComparatorName());
131 
132   OpKernelContext::Params params;
133   PartiallySetupParams(&params);
134 
135   for (Node* n : topo_sorted_nodes) {
136     OpKernel* op_kernel_raw = nullptr;
137     // The kernel is not actually run for functional ops, we just need it
138     // for metadata.
139     Status s = flib_->CreateKernel(n->properties(), &op_kernel_raw);
140     // Transfer ownership of the kernel to a local smart pointer.
141     std::unique_ptr<OpKernel> op_kernel(op_kernel_raw);
142 
143     if (!s.ok()) {
144       s = AttachDef(s, *n);
145       LOG(ERROR) << "Executor failed to create kernel. " << s;
146       return s;
147     }
148 
149     TF_RET_CHECK(!n->IsRecv() && !n->IsSend() && !n->IsSwitch())
150         << "Not supported node: " << n->DebugString();
151     params.op_kernel = op_kernel.get();
152     absl::InlinedVector<AllocatorAttributes, 4> output_attr(n->num_outputs());
153     params.output_attr_array = output_attr.data();
154 
155     // tensor_inputs_ is a buffer reused across graph traversal. We clean up and
156     // reinitialize the buffer before we visit a new node.
157     tensor_inputs_.clear();
158     tensor_inputs_.resize(n->num_inputs());
159 
160     // Set up inputs from outputs of previous nodes.
161     for (auto* e : n->in_edges()) {
162       if (e->IsControlEdge()) continue;
163       const Node* src = e->src();
164       const int output_registry_size = output_registry.size();
165       TF_RET_CHECK(src->id() < output_registry_size);
166       const NodeOutputs& src_outputs = output_registry[src->id()];
167 
168       tensor_inputs_.at(e->dst_input()) = src_outputs.at(e->src_output());
169     }
170     params.inputs = tensor_inputs_;
171 
172     OpKernelContext op_context(&params, n->num_outputs());
173     VLOG(3) << "Translating " << params.op_kernel->name();
174     if (IsFunctionCall(*flib_->GetFunctionLibraryDefinition(), *n)) {
175       TF_RETURN_IF_ERROR(CompileFunctionalNode(n, &op_context));
176     } else {
177       device_->Compute(CHECK_NOTNULL(params.op_kernel), &op_context);
178       Status s = op_context.status();
179       if (!s.ok()) {
180         return AttachDef(s, n->def());
181       }
182     }
183 
184     // Set up outputs. Also check if outputs from the previous computation is
185     // valid.
186     NodeOutputs& outputs = output_registry[n->id()];
187     outputs.resize(n->num_outputs());
188     for (int o = 0; o < n->num_outputs(); ++o) {
189       outputs[o] = op_context.release_output(o);
190       if (outputs[o].tensor == nullptr) {
191         return errors::Internal("Missing xla_context ", o, "-th output from ",
192                                 FormatNodeForError(*n));
193       }
194     }
195   }
196   return OkStatus();
197 }
198 
199 namespace {
200 
GetFunctionNameAndAttr(const FunctionLibraryRuntime & flib,const Node & node,NameAttrList * func)201 Status GetFunctionNameAndAttr(const FunctionLibraryRuntime& flib,
202                               const Node& node, NameAttrList* func) {
203   if (node.IsPartitionedCall()) {
204     const AttrValue* attr_value;
205     TF_RETURN_IF_ERROR(
206         node.attrs().Find(FunctionLibraryDefinition::kFuncAttr, &attr_value));
207     if (!attr_value->has_func()) {
208       return errors::InvalidArgument(
209           "The attribute value for attribute 'f' in node ", node.DebugString(),
210           " does not have 'func' field set");
211     }
212     *func = attr_value->func();
213     return OkStatus();
214   }
215 
216   if (flib.GetFunctionLibraryDefinition()->Find(node.def().op())) {
217     func->set_name(node.type_string());
218   } else {
219     func->set_name(FunctionLibraryDefinition::kGradientOp);
220   }
221   *func->mutable_attr() = node.def().attr();
222   return OkStatus();
223 }
224 
225 }  // namespace
226 
// Compiles a function-call node by compiling the callee into an XLA
// computation, emitting an xla::Call, and then distributing the call's
// tuple result into this node's outputs, resource updates, and
// (optionally) a token used to order side-effecting ops.
Status GraphCompiler::CompileFunctionalNode(Node* n,
                                            OpKernelContext* op_context) {
  TF_RET_CHECK(IsFunctionCall(*flib_->GetFunctionLibraryDefinition(), *n));
  // For functional nodes, compile them using compiler from the context and call
  // into the functions.
  XlaOpKernelContext xla_op_context(op_context);

  XlaContext& context = XlaContext::Get(op_context);
  auto* b = context.builder();

  XlaCompiler* compiler = xla_op_context.compiler();

  NameAttrList func;
  TF_RETURN_IF_ERROR(GetFunctionNameAndAttr(*flib_, *n, &func));

  // Recover the XlaExpression stored in each input tensor's data buffer.
  std::vector<const XlaExpression*> expressions;

  for (auto tensor : tensor_inputs_) {
    auto expression =
        reinterpret_cast<const XlaExpression*>(tensor->tensor_data().data());
    expressions.push_back(expression);
  }

  // Prepare the arguments and compile the function.
  std::vector<XlaCompiler::Argument> arguments;
  const FunctionBody* fbody;
  TF_RETURN_IF_ERROR(compiler->FindFunctionBody(func, &fbody));

  auto graph = compiler->GetGraph(fbody);

  TF_RETURN_IF_ERROR(PrepareArguments(&xla_op_context, graph.get(), expressions,
                                      func, &arguments));

  // Functions carrying the token-input-nodes attribute thread an extra
  // token value through the computation to order side effects.
  bool add_token_input_output =
      func.attr().find(kXlaTokenInputNodesAttrName) != func.attr().end();

  XlaCompiler::CompileOptions compile_options;
  compile_options.is_entry_computation = false;
  compile_options.add_token_input_output = add_token_input_output;
  XlaCompiler::CompilationResult result;
  TF_RETURN_IF_ERROR(
      compiler->CompileFunction(compile_options, func, arguments, &result));

  TF_RET_CHECK(arguments.size() == expressions.size());

  // Collect the runtime operands of the call: constant arguments were
  // folded into the compiled computation and are skipped; resources pass
  // their current value; everything else passes its XLA handle.
  std::vector<xla::XlaOp> handles;
  for (int64_t i = 0, end = expressions.size(); i < end; ++i) {
    if (arguments[i].kind == XlaCompiler::Argument::kConstant) {
      continue;
    }
    if (arguments[i].kind == XlaCompiler::Argument::kResource) {
      handles.push_back(expressions[i]->resource()->value());
    } else {
      handles.push_back(expressions[i]->handle());
    }
  }
  if (add_token_input_output) {
    // Merge the tokens of all listed producer nodes into a single token
    // that is appended as the last call operand.
    std::vector<string> token_input_nodes;
    TF_RETURN_IF_ERROR(GetNodeAttr(AttrSlice(&func.attr()),
                                   kXlaTokenInputNodesAttrName,
                                   &token_input_nodes));
    std::vector<xla::XlaOp> token_inputs;
    for (const string& node_name : token_input_nodes) {
      auto token_or = compiler->GetNodeToken(node_name);
      TF_RETURN_IF_ERROR(token_or.status());
      token_inputs.push_back(std::move(token_or).value());
    }
    xla::XlaOp token_input = xla::AfterAll(b, token_inputs);
    handles.push_back(token_input);
  }

  auto output_handle = xla::Call(b, *result.computation, handles);
  // The output handle of `Call` computation is a tuple type. Unzip it so
  // that it can fit into future computations.
  // `computation_output` tracks the tuple position; constant outputs do
  // not occupy a tuple slot.
  int computation_output = 0;
  for (int64_t i = 0; i < n->num_outputs(); ++i) {
    if (result.outputs[i].is_constant) {
      xla_op_context.SetConstantOutput(i, result.outputs[i].constant_value);
    } else {
      if (result.outputs[i].is_tensor_list) {
        xla_op_context.SetTensorListOutput(
            i, xla::GetTupleElement(output_handle, computation_output));
      } else {
        xla_op_context.SetOutput(
            i, xla::GetTupleElement(output_handle, computation_output));
      }
      ++computation_output;
    }
  }

  // Write modified resource values back into their resources. The value
  // for update i is read from tuple index i + n->num_outputs().
  for (int64_t i = 0, end = result.resource_updates.size(); i < end; i++) {
    if (result.resource_updates[i].modified) {
      XlaResource* resource =
          expressions[result.resource_updates[i].input_index]->resource();
      xla::XlaOp updated_value =
          xla::GetTupleElement(output_handle, i + n->num_outputs());
      TF_RETURN_IF_ERROR(resource->SetValue(updated_value));
    }
  }

  if (add_token_input_output) {
    // Publish this node's output token (read from tuple index
    // `computation_output`) for downstream token consumers, keyed by the
    // original outside-compilation node name when that attribute is set,
    // otherwise by the node's own name.
    std::string node_name;
    if (!GetNodeAttr(n->attrs(), kXlaOriginalOutsideCompilationNodeName,
                     &node_name)
             .ok())
      node_name = n->name();
    TF_RETURN_IF_ERROR(compiler->SetNodeToken(
        node_name, xla::GetTupleElement(output_handle, computation_output)));
  }
  // Surface any error the builder recorded while ops were being emitted.
  return b->first_error();
}
338 
PartiallySetupParams(OpKernelContext::Params * params)339 void GraphCompiler::PartiallySetupParams(OpKernelContext::Params* params) {
340   params->device = device_;
341   params->step_container = step_container_;
342   params->resource_manager = device_->resource_manager();
343   params->function_library = flib_;
344 }
345 
346 }  // namespace tensorflow
347