/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/tf2xla/graph_compiler.h"

#include <deque>
#include <numeric>
#include <utility>
#include <vector>

#include "tensorflow/compiler/tf2xla/const_analysis.h"
#include "tensorflow/compiler/tf2xla/literal_util.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/side_effect_util.h"
#include "tensorflow/compiler/tf2xla/type_util.h"
#include "tensorflow/compiler/tf2xla/xla_compiler.h"
#include "tensorflow/compiler/tf2xla/xla_context.h"
#include "tensorflow/compiler/tf2xla/xla_expression.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/xla/client/client_library.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/executor.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/graph_constructor.h"
#include "tensorflow/core/common_runtime/graph_optimizer.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/attr_value_util.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/graph/node_builder.h"
#include "tensorflow/core/graph/validate.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/public/version.h"
#include "tensorflow/core/util/dump_graph.h"

namespace tensorflow {

namespace {
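// Builds the XlaCompiler::Argument descriptors that tell the compiler how
// each operand of a function call should be passed to the compiled function.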
Status PrepareArguments(XlaOpKernelContext* ctx, Graph* graph,
                        const std::vector<const XlaExpression*>& expressions,
                        const NameAttrList& func,
                        std::vector<XlaCompiler::Argument>* args) {
  auto client = ctx->compiler()->client();
  std::vector<bool> arg_must_be_compile_time_constant(expressions.size());

  TF_RETURN_IF_ERROR(BackwardsConstAnalysis(
      *graph, &arg_must_be_compile_time_constant,
      /*compile_time_const_nodes=*/nullptr, ctx->function_library()));

  args->resize(expressions.size());
  for (int i = 0, end = args->size(); i < end; ++i) {
    XlaCompiler::Argument& arg = (*args)[i];
    arg.type = ctx->input_type(i);
    arg.shape = ctx->InputShape(i);

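    // Classify how this operand will be passed to the compiled function: as
    // an inlined constant, a runtime parameter, a resource, or a tensor list.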
    switch (expressions[i]->kind()) {
      case XlaExpression::Kind::kConstant:
        arg.kind = XlaCompiler::Argument::kConstant;
        arg.constant_value = *expressions[i]->constant_value();
        break;
      case XlaExpression::Kind::kXlaOp:
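        // Try to resolve operands that must be compile-time constants to a
        // concrete value; fall back to a runtime parameter if that fails.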
        if (arg_must_be_compile_time_constant[i]) {
          TF_ASSIGN_OR_RETURN(std::optional<Tensor> value,
                              expressions[i]->ResolveConstant(client));
          if (value.has_value()) {
            arg.kind = XlaCompiler::Argument::kConstant;
            arg.constant_value = *value;
          } else {
            arg.kind = XlaCompiler::Argument::kParameter;
          }
        } else {
          arg.kind = XlaCompiler::Argument::kParameter;
        }
        break;
      case XlaExpression::Kind::kResource: {
        XlaResource* resource = expressions[i]->resource();
        XlaCompiler::PopulateArgumentFromResource(*resource, &arg);
        break;
      }
      case XlaExpression::Kind::kTensorList: {
        arg.kind = XlaCompiler::Argument::kTensorList;
        const xla::XlaOp& tensor_list = expressions[i]->handle();
        arg.shape = tensor_list.builder()->GetShape(tensor_list).ValueOrDie();
        break;
      }
      case XlaExpression::Kind::kInvalid:
        return errors::InvalidArgument("Invalid function argument");
    }
  }
  return OkStatus();
}
}  // namespace

Status GraphCompiler::Compile() {
  // Check that the graph has no illegal cycles.
  TF_RETURN_IF_ERROR(graph::ValidateGraphHasNoCycle(*graph_));
  // Maintain a mapping from node id to node outputs.
  using NodeOutputs = std::vector<TensorValue>;
  std::vector<NodeOutputs> output_registry(graph_->num_node_ids());
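  // The registry owns any tensors released from the per-node
  // OpKernelContexts; free them once compilation finishes.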
  auto output_registry_cleanup = gtl::MakeCleanup([&output_registry] {
    for (const NodeOutputs& outputs : output_registry) {
      for (const TensorValue& value : outputs) {
        CHECK(!value.is_ref());
        delete value.tensor;
      }
    }
  });

  // XLA requires determinism; generate a stable ordering from a DFS.
  std::vector<Node*> topo_sorted_nodes;
  GetReversePostOrder(*graph_, &topo_sorted_nodes,
                      /*stable_comparator=*/NodeComparatorName());

  OpKernelContext::Params params;
  PartiallySetupParams(&params);

  for (Node* n : topo_sorted_nodes) {
    OpKernel* op_kernel_raw = nullptr;
    // The kernel is not actually run for functional ops; we just need it
    // for metadata.
    Status s = flib_->CreateKernel(n->properties(), &op_kernel_raw);
    // Transfer ownership of the kernel to a local smart pointer.
    std::unique_ptr<OpKernel> op_kernel(op_kernel_raw);

    if (!s.ok()) {
      s = AttachDef(s, *n);
      LOG(ERROR) << "Executor failed to create kernel. " << s;
      return s;
    }

    TF_RET_CHECK(!n->IsRecv() && !n->IsSend() && !n->IsSwitch())
        << "Unsupported node: " << n->DebugString();
    params.op_kernel = op_kernel.get();
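    // The kernels on this device produce symbolic XLA values rather than real
    // buffers, so default allocator attributes are assumed to be sufficient;
    // OpKernelContext still requires the array to be present.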
    absl::InlinedVector<AllocatorAttributes, 4> output_attr(n->num_outputs());
    params.output_attr_array = output_attr.data();

    // tensor_inputs_ is a buffer reused across the graph traversal. Clear and
    // reinitialize the buffer before visiting each node.
    tensor_inputs_.clear();
    tensor_inputs_.resize(n->num_inputs());

    // Set up inputs from outputs of previous nodes.
    for (auto* e : n->in_edges()) {
      if (e->IsControlEdge()) continue;
      const Node* src = e->src();
      const int output_registry_size = output_registry.size();
      TF_RET_CHECK(src->id() < output_registry_size);
      const NodeOutputs& src_outputs = output_registry[src->id()];

      tensor_inputs_.at(e->dst_input()) = src_outputs.at(e->src_output());
    }
    params.inputs = tensor_inputs_;

    OpKernelContext op_context(&params, n->num_outputs());
    VLOG(3) << "Translating " << params.op_kernel->name();
    if (IsFunctionCall(*flib_->GetFunctionLibraryDefinition(), *n)) {
      TF_RETURN_IF_ERROR(CompileFunctionalNode(n, &op_context));
    } else {
      device_->Compute(CHECK_NOTNULL(params.op_kernel), &op_context);
      Status s = op_context.status();
      if (!s.ok()) {
        return AttachDef(s, n->def());
      }
    }

    // Set up outputs, and check that the outputs from the computation are
    // valid.
    NodeOutputs& outputs = output_registry[n->id()];
    outputs.resize(n->num_outputs());
    for (int o = 0; o < n->num_outputs(); ++o) {
      outputs[o] = op_context.release_output(o);
      if (outputs[o].tensor == nullptr) {
        return errors::Internal("Missing xla_context ", o, "-th output from ",
                                FormatNodeForError(*n));
      }
    }
  }
  return OkStatus();
}

namespace {

Status GetFunctionNameAndAttr(const FunctionLibraryRuntime& flib,
                              const Node& node, NameAttrList* func) {
  if (node.IsPartitionedCall()) {
    const AttrValue* attr_value;
    TF_RETURN_IF_ERROR(
        node.attrs().Find(FunctionLibraryDefinition::kFuncAttr, &attr_value));
    if (!attr_value->has_func()) {
      return errors::InvalidArgument(
          "The attribute value for attribute 'f' in node ", node.DebugString(),
          " does not have 'func' field set");
    }
    *func = attr_value->func();
    return OkStatus();
  }

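  // If the node's op is itself a function in the library, call it directly;
  // otherwise assume this is a SymbolicGradient node and call the gradient op
  // with the node's attributes.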
  if (flib.GetFunctionLibraryDefinition()->Find(node.def().op())) {
    func->set_name(node.type_string());
  } else {
    func->set_name(FunctionLibraryDefinition::kGradientOp);
  }
  *func->mutable_attr() = node.def().attr();
  return OkStatus();
}

}  // namespace

Status GraphCompiler::CompileFunctionalNode(Node* n,
                                            OpKernelContext* op_context) {
  TF_RET_CHECK(IsFunctionCall(*flib_->GetFunctionLibraryDefinition(), *n));
  // For functional nodes, compile the called function with the compiler from
  // the context and emit a call into it.
  XlaOpKernelContext xla_op_context(op_context);

  XlaContext& context = XlaContext::Get(op_context);
  auto* b = context.builder();

  XlaCompiler* compiler = xla_op_context.compiler();

  NameAttrList func;
  TF_RETURN_IF_ERROR(GetFunctionNameAndAttr(*flib_, *n, &func));

  std::vector<const XlaExpression*> expressions;

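  // On the XLA compilation device, a Tensor's data buffer holds an
  // XlaExpression; recover the expression backing each input tensor.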
  for (auto tensor : tensor_inputs_) {
    auto expression =
        reinterpret_cast<const XlaExpression*>(tensor->tensor_data().data());
    expressions.push_back(expression);
  }

  // Prepare the arguments and compile the function.
  std::vector<XlaCompiler::Argument> arguments;
  const FunctionBody* fbody;
  TF_RETURN_IF_ERROR(compiler->FindFunctionBody(func, &fbody));

  auto graph = compiler->GetGraph(fbody);

  TF_RETURN_IF_ERROR(PrepareArguments(&xla_op_context, graph.get(), expressions,
                                      func, &arguments));

  bool add_token_input_output =
      func.attr().find(kXlaTokenInputNodesAttrName) != func.attr().end();

  XlaCompiler::CompileOptions compile_options;
  compile_options.is_entry_computation = false;
  compile_options.add_token_input_output = add_token_input_output;
  XlaCompiler::CompilationResult result;
  TF_RETURN_IF_ERROR(
      compiler->CompileFunction(compile_options, func, arguments, &result));

  TF_RET_CHECK(arguments.size() == expressions.size());

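  // Gather the runtime arguments for the call: constant arguments were baked
  // into the compiled function and are skipped, resource arguments pass the
  // resource's current value, and all other arguments pass their XlaOp handle.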
  std::vector<xla::XlaOp> handles;
  for (int64_t i = 0, end = expressions.size(); i < end; ++i) {
    if (arguments[i].kind == XlaCompiler::Argument::kConstant) {
      continue;
    }
    if (arguments[i].kind == XlaCompiler::Argument::kResource) {
      handles.push_back(expressions[i]->resource()->value());
    } else {
      handles.push_back(expressions[i]->handle());
    }
  }
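  // If the function carries side effects, merge its token inputs into a
  // single token with AfterAll and append it as the final argument.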
  if (add_token_input_output) {
    std::vector<string> token_input_nodes;
    TF_RETURN_IF_ERROR(GetNodeAttr(AttrSlice(&func.attr()),
                                   kXlaTokenInputNodesAttrName,
                                   &token_input_nodes));
    std::vector<xla::XlaOp> token_inputs;
    for (const string& node_name : token_input_nodes) {
      auto token_or = compiler->GetNodeToken(node_name);
      TF_RETURN_IF_ERROR(token_or.status());
      token_inputs.push_back(std::move(token_or).value());
    }
    xla::XlaOp token_input = xla::AfterAll(b, token_inputs);
    handles.push_back(token_input);
  }

  auto output_handle = xla::Call(b, *result.computation, handles);
  // The output of the `Call` computation is a tuple; unpack it so the
  // elements can feed into downstream computations. `computation_output`
  // indexes the non-constant elements of the tuple, since constant outputs
  // are not materialized in the computation.
  int computation_output = 0;
  for (int64_t i = 0; i < n->num_outputs(); ++i) {
    if (result.outputs[i].is_constant) {
      xla_op_context.SetConstantOutput(i, result.outputs[i].constant_value);
    } else {
      if (result.outputs[i].is_tensor_list) {
        xla_op_context.SetTensorListOutput(
            i, xla::GetTupleElement(output_handle, computation_output));
      } else {
        xla_op_context.SetOutput(
            i, xla::GetTupleElement(output_handle, computation_output));
      }
      ++computation_output;
    }
  }

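  // Updated resource values are appended to the call's output tuple after the
  // function's regular outputs; write each updated value back into its
  // resource.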
  for (int64_t i = 0, end = result.resource_updates.size(); i < end; i++) {
    if (result.resource_updates[i].modified) {
      XlaResource* resource =
          expressions[result.resource_updates[i].input_index]->resource();
      xla::XlaOp updated_value =
          xla::GetTupleElement(output_handle, i + n->num_outputs());
      TF_RETURN_IF_ERROR(resource->SetValue(updated_value));
    }
  }

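  // Record the token produced by this call so that later side-effecting nodes
  // can be sequenced after it, keyed by the original outside-compilation node
  // name when that attribute is present.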
  if (add_token_input_output) {
    std::string node_name;
    if (!GetNodeAttr(n->attrs(), kXlaOriginalOutsideCompilationNodeName,
                     &node_name)
             .ok())
      node_name = n->name();
    TF_RETURN_IF_ERROR(compiler->SetNodeToken(
        node_name, xla::GetTupleElement(output_handle, computation_output)));
  }
  return b->first_error();
}

void GraphCompiler::PartiallySetupParams(OpKernelContext::Params* params) {
  params->device = device_;
  params->step_container = step_container_;
  params->resource_manager = device_->resource_manager();
  params->function_library = flib_;
}

}  // namespace tensorflow