/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/tpu/graph_rewrite/host_training_loop_optimization_util.h"

#include <deque>
#include <map>
#include <unordered_map>
#include <unordered_set>

#include "absl/container/flat_hash_set.h"
#include "absl/container/node_hash_set.h"
#include "tensorflow/compiler/tf2xla/functionalize_control_flow_util.h"
#include "tensorflow/compiler/tf2xla/tf2xla_util.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/protobuf/tpu/compile_metadata.pb.h"
#include "tensorflow/core/tpu/graph_rewrite/distributed_tpu_rewrite_pass_internal.h"

namespace tensorflow {
namespace tpu {

namespace {

constexpr char kDefaultShardingValue[] = "";

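// Returns the edge connecting `src` to `dst`, or nullptr if no such edge
// exists.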
const Edge* FindEdgeConnecting(const Node* src, const Node* dst) {
  for (const auto e : src->out_edges()) {
    if (e->dst()->name() == dst->name()) return &(*e);
  }
  return nullptr;
}

// Contains TPUExecute node and its DT_RESOURCE input nodes that
// correspond to model weights.
struct ExecuteNodeInfo {
  Node* execute_node;
  std::vector<const Edge*> var_inputs;
};

// Returns whether `node` is in `execute_nodes` or `(identity) -> execute`.
bool IsExecuteNodeOrIdentityToExecuteNode(
    const Graph& graph, const std::unordered_set<Node*>& loop_nodes,  // NOLINT
    const absl::flat_hash_set<Node*>& execute_nodes, Node* node) {
  if (execute_nodes.find(node) != execute_nodes.end()) return true;
  if (loop_nodes.find(node) == loop_nodes.end()) return false;
  if (node->IsNextIteration()) return true;
  if (!node->IsIdentity()) return false;

  for (const Edge* e : node->out_edges()) {
    if (e->IsControlEdge()) continue;

    Node* dst_node = e->dst();
    if (!IsExecuteNodeOrIdentityToExecuteNode(graph, loop_nodes, execute_nodes,
                                              dst_node)) {
      return false;
    }
  }

  return true;
}

// Given an input node to the TPUExecute op, finds the corresponding Enter
// node by traversing the following pattern of nodes:
//   Enter ----> (identity) ---> While body input
// Returns nullptr if no Enter node is found.
xla::StatusOr<Node*> FindEnterNodeFromTPUExecuteNodeInput(Node* input_node) {
  Node* node = input_node;
  while (node->IsIdentity()) {
    TF_RETURN_IF_ERROR(node->input_node(0, &node));
  }

  if (node->IsEnter()) {
    return node;
  }
  return nullptr;
}

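// Returns true if the resource variable fed through `enter_node` is consumed
// only by the TPUExecute nodes in `execute_nodes` (possibly through Identity
// nodes) within the loop body.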
xla::StatusOr<bool> ResourceOnlyUsedForTPUExecuteInLoop(
    const Graph& graph, const std::unordered_set<Node*>& loop_nodes,  // NOLINT
    const Node* enter_node, const absl::flat_hash_set<Node*>& execute_nodes) {
  for (const Edge* output_edge : enter_node->out_edges()) {
    Node* output_node = output_edge->dst();
    if (output_edge->IsControlEdge() || output_node->IsExit()) continue;

    // If the output node is not an execute node (or an Identity leading to
    // one), the resource has other uses inside the while loop body.
    if (!IsExecuteNodeOrIdentityToExecuteNode(graph, loop_nodes, execute_nodes,
                                              output_node)) {
      return false;
    }
  }
  return true;
}

// Given a TPUCompile node, finds all TPUExecute nodes that execute the
// compiled program, along with their model weight variable inputs.
// The TPUCompileMetadataProto of the TPUCompile node must be reset to
// `new_metadata` if new reshard ops are added.
Status ExtractExecuteNodeInfo(const Node* compile_node, const Graph& graph,
                              const std::unordered_set<Node*>& loop_nodes,  // NOLINT
                              std::vector<ExecuteNodeInfo>* execute_node_info,
                              TPUCompileMetadataProto* new_metadata) {
  string metadata_string;
  TF_RETURN_IF_ERROR(
      GetNodeAttr(compile_node->attrs(), "metadata", &metadata_string));
  new_metadata->ParsePartialFromString(metadata_string);
  if (new_metadata->num_cores_per_replica() != 1) {
    // We do not support model parallelism yet.
    return OkStatus();
  }

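  // Collect all TPUExecute nodes that consume the compilation result of this
  // TPUCompile node.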
  execute_node_info->clear();
  for (Node* node : compile_node->out_nodes()) {
    if (node->type_string() == "TPUExecute") {
      execute_node_info->push_back({node});
    }
  }
  if (execute_node_info->empty()) {
    return OkStatus();
  }
  TF_RET_CHECK(execute_node_info->size() == new_metadata->num_replicas())
      << "Number of replicas does not equal number of execute nodes: "
      << new_metadata->num_replicas() << " vs " << execute_node_info->size();
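  // All TPUExecute replicas produced by the same TPUCompile node share the
  // same argument signature, so it suffices to inspect the first one.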
  DataTypeVector arg_types;
  TF_RETURN_IF_ERROR(GetNodeAttr((*execute_node_info)[0].execute_node->attrs(),
                                 "Targs", &arg_types));
  for (int64_t i = 0; i < arg_types.size(); ++i) {
    if (arg_types[i] != DT_RESOURCE) {
      continue;
    }
    const auto sharding_config = new_metadata->args(i).enable_xla_sharding();
    if (sharding_config != TPUCompileMetadataProto::Arg::TENTATIVE &&
        sharding_config != TPUCompileMetadataProto::Arg::ALLOWED) {
      continue;
    }
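    // For every replica's execute node, find the edge feeding argument `i`
    // and trace it back through Identity nodes to the Enter node of the
    // corresponding loop variable.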
    std::vector<const Edge*> edges(execute_node_info->size());
    bool is_supported = true;
    std::unordered_map<Node*, absl::flat_hash_set<Node*>>
        enter_to_execute_nodes;
    for (int64_t j = 0; j < edges.size(); ++j) {
      auto execute = (*execute_node_info)[j].execute_node;
      TF_RETURN_IF_ERROR(execute->input_edge(i, &edges[j]));
      TF_RET_CHECK(edges[j]->src()->output_type(edges[j]->src_output()) ==
                   arg_types[i])
          << "Execute op has an unexpected input type.";
      // Traverse backwards to find the Enter node from which the input is
      // passed.
      // This makes sure that we are checking the usages of all potential
      // aliases of the input node as well.
      TF_ASSIGN_OR_RETURN(auto enter_node, FindEnterNodeFromTPUExecuteNodeInput(
                                               edges[j]->src()));
      if (enter_node == nullptr) {
        is_supported = false;
        enter_to_execute_nodes.clear();
        break;
      }
      enter_to_execute_nodes[enter_node].insert(edges[j]->dst());
    }

    for (const auto& it : enter_to_execute_nodes) {
      // Size of execute nodes should be either 1 (per-replica variables) or
      // num_replicas (distributed variables).
      if ((it.second.size() != 1) &&
          (it.second.size() != new_metadata->num_replicas())) {
        is_supported = false;
        break;
      }
      TF_ASSIGN_OR_RETURN(bool no_other_use,
                          ResourceOnlyUsedForTPUExecuteInLoop(
                              graph, loop_nodes, it.first, it.second));
      if (!no_other_use) {
        is_supported = false;
        break;
      }
    }

    // Add the variable input edges only when they are supported for all
    // executes.
    if (is_supported) {
      for (int64_t j = 0; j < edges.size(); ++j) {
        (*execute_node_info)[j].var_inputs.push_back(edges[j]);
      }
      new_metadata->mutable_args(i)->set_enable_xla_sharding(
          TPUCompileMetadataProto::Arg::ALLOWED);
    }
  }

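  // Sanity check: the number of args marked ALLOWED in the metadata must
  // match the number of variable input edges recorded for each execute node.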
  int64_t total = 0;
  for (const auto& a : new_metadata->args()) {
    if (a.enable_xla_sharding() == TPUCompileMetadataProto::Arg::ALLOWED) {
      total++;
    }
  }
  TF_RET_CHECK(total == (*execute_node_info)[0].var_inputs.size())
      << " total " << total << " var_inputs "
      << (*execute_node_info)[0].var_inputs.size();
  if (total == 0) {
    // We don't need to process anything if no input is added.
    execute_node_info->clear();
  }
  return OkStatus();
}

bool IsTPUCompileOp(const Node& n) { return n.type_string() == "TPUCompile"; }

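// Walks the while-loop frames from the innermost frame to the outermost one
// and records a HostTrainingLoopInfo for every TPUCompile node found inside
// a frame.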
void FindTPUCompileNodes(
    const std::string* current_function_name,
    const AttrValueMap* current_function_attr,
    const std::unordered_map<string, WhileLoopFrame>& frames,
    std::vector<HostTrainingLoopInfo>* host_training_loops_info) {
  // Adds frames with no children (i.e., the innermost frames) to a worklist.
  std::deque<const WhileLoopFrame*> worklist;

  for (auto& frame : frames) {
    if (frame.second.num_children == 0) {
      worklist.push_back(&frame.second);
    }
  }

  // Check for TPUCompile nodes from the innermost while loop to the outermost
  // while loop.
  while (!worklist.empty()) {
    const WhileLoopFrame* frame = worklist.front();
    worklist.pop_front();

    for (const auto& n : frame->nodes) {
      if (!IsTPUCompileOp(*n)) continue;

      HostTrainingLoopInfo host_training_loop_info;
      host_training_loop_info.compile_node_name = n->name();
      host_training_loop_info.loop_cond_node_name = frame->loop_cond->name();
      host_training_loop_info.while_loop_name = frame->name;

      for (const auto arg : frame->args) {
        LoopArgInfo arg_info;
        arg_info.enter_node_name = arg.enter->name();
        if (arg.exit) arg_info.exit_node_name = arg.exit->name();

        host_training_loop_info.loop_arguments.push_back(std::move(arg_info));
      }
      host_training_loop_info.loop_nodes = frame->nodes;

      if (current_function_name) {
        host_training_loop_info.encapsulating_function_name =
            *current_function_name;
      }
      if (current_function_attr) {
        host_training_loop_info.encapsulating_function_attrs =
            *current_function_attr;
      }

      host_training_loops_info->emplace_back(
          std::move(host_training_loop_info));
    }

    // If the parent has no remaining children, add it to the worklist.
    --frame->parent->num_children;
    if (frame->parent->num_children == 0) {
      worklist.push_back(frame->parent);
    }
  }
}

// Given the while loop cond node, finds all loop exit nodes by traversing
// the following pattern of nodes:
//   LoopCond -----> Switch -----> Exit
std::vector<Node*> FindLoopExitNodes(const Node& loop_cond) {
  std::vector<Node*> loop_exit_nodes;
  for (const auto e_cond : loop_cond.out_edges()) {
    if (e_cond->IsControlEdge() || !e_cond->dst()->IsSwitch()) continue;
    auto switch_node = e_cond->dst();

    for (const auto e_switch : switch_node->out_edges()) {
      if (e_switch->IsControlEdge() || !e_switch->dst()->IsExit()) continue;

      loop_exit_nodes.push_back(e_switch->dst());
    }
  }
  return loop_exit_nodes;
}

// Returns or creates a node that is executed before each iteration of the
// while loop.
// TODO(mdan): Inject this node between the Enter and Merge nodes instead.
// See AddNoOpAfterLastIteration for an example.
Status GetOrCreateBeforeEachIterationNode(const Node& loop_cond_node,
                                          Graph* graph, Node** node_out) {
  Node* loop_switch_node = nullptr;
  for (auto n : loop_cond_node.out_nodes()) {
    if (n->IsSwitch()) {
      loop_switch_node = n;
      break;
    }
  }
  TF_RET_CHECK(loop_switch_node != nullptr)
      << "Unable to find any switch nodes.";

  // If the while loop Switch node already has an outgoing data edge on the
  // true branch of the switch op, then reuse that node.
  for (const auto out_edge : loop_switch_node->out_edges()) {
    if (out_edge->src_output() == 1) {
      *node_out = out_edge->dst();
      return OkStatus();
    }
  }

  // Create Identity node that represents execution at every loop iteration.
  NodeDef at_loop_iteration_nodedef;
  at_loop_iteration_nodedef.set_op("Identity");
  DataType dtype;
  TF_RETURN_IF_ERROR(GetNodeAttr(loop_switch_node->def(), "T", &dtype));

  AddNodeAttr("T", dtype, &at_loop_iteration_nodedef);
  at_loop_iteration_nodedef.set_name(graph->NewName(strings::StrCat(
      "TPUVariableReshard/before_iteration", "/_", internal::GetNodeId())));

  TF_ASSIGN_OR_RETURN(Node * at_loop_iteration_node,
                      graph->AddNode(at_loop_iteration_nodedef));

  graph->AddEdge(loop_switch_node, 1, at_loop_iteration_node, 0);
  *node_out = at_loop_iteration_node;
  return OkStatus();
}

// Injects a NoOp node that is executed after the very last iteration of the
// while loop but before the while loop exit node.
// This node is positioned between the false output of all Switch nodes
// (meaning it executes after the loop has ended all its iterations) and their
// corresponding Exit nodes (meaning before the loop finally completes).
// See the white paper for details:
// http://download.tensorflow.org/paper/white_paper_tf_control_flow_implementation_2017_11_1.pdf
Status AddNoOpAfterLastIteration(const Node& loop_cond_node, Graph* graph,
                                 Node** node_out) {
  NodeDef after_last_iteration;
  after_last_iteration.set_op("NoOp");

  after_last_iteration.set_name(graph->NewName(strings::StrCat(
      "TPUVariableReshard/after_last_iteration", "/_", internal::GetNodeId())));

  TF_ASSIGN_OR_RETURN(Node * after_last_iteration_node,
                      graph->AddNode(after_last_iteration));

  for (auto switch_node : loop_cond_node.out_nodes()) {
    if (!switch_node->IsSwitch()) {
      continue;
    }

    NodeDef switch_exit;
    switch_exit.set_op("Identity");

    DataType dtype;
    TF_RETURN_IF_ERROR(GetNodeAttr(switch_node->def(), "T", &dtype));
    AddNodeAttr("T", dtype, &switch_exit);
    auto name = strings::StrCat("TPUVariableReshard/switch_exit", "/_",
                                internal::GetNodeId());
    switch_exit.set_name(graph->NewName(name));
    // Introducing identity nodes risks a device copy, which isn't guaranteed
    // to be available for all types. Hence the colocation constraint.
    AddNodeAttr(kColocationAttrName,
                std::vector<string>{
                    absl::StrCat(kColocationGroupPrefix, switch_node->name())},
                &switch_exit);

    TF_ASSIGN_OR_RETURN(Node * after_switch_node, graph->AddNode(switch_exit));

    graph->AddEdge(switch_node, 0, after_switch_node, 0);
    graph->AddControlEdge(after_switch_node, after_last_iteration_node);

    for (const auto out_node : switch_node->out_nodes()) {
      if (out_node->IsExit()) {
        graph->AddControlEdge(after_last_iteration_node, out_node);
      }
    }
  }

  *node_out = after_last_iteration_node;
  return OkStatus();
}

}  // namespace

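// Detects host training loops (i.e. while loops containing TPUCompile and
// TPUExecute nodes) in `graph` and, recursively, in the bodies of functions
// reachable from it, and appends information about each detected loop to
// `host_training_loops_info`.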
Status DetectHostTrainingLoop(
    const std::string* current_function_name,
    const AttrValueMap* current_function_attr,
    const FunctionLibraryDefinition* library, Graph* graph,
    FunctionLibraryRuntime* flr,
    std::vector<HostTrainingLoopInfo>* host_training_loops_info) {
  std::vector<AssociatedFunctionInfo> associated_function_list;
  for (const auto* n : graph->nodes()) {
    const auto associated_functions = GetAssociatedFunctions(*n, library);
    if (associated_functions.empty()) continue;

    associated_function_list.insert(associated_function_list.end(),
                                    associated_functions.begin(),
                                    associated_functions.end());
  }

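  // Recurse into every associated function so that host training loops
  // defined inside function bodies are detected as well.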
  Status ret_status = OkStatus();
  for (const auto& function : associated_function_list) {
    if (function.type() != AssociatedFunctionInfo::kFunctionAttr) continue;

    // Convert the function to Graph.
    FunctionLibraryRuntime::Handle handle;
    TF_RETURN_IF_ERROR(flr->Instantiate(function.func_name(),
                                        AttrSlice(&function.attrs()), &handle));
    auto cleanup_handle = gtl::MakeCleanup([&]() {
      auto s = flr->ReleaseHandle(handle);
      if (!s.ok()) {
        ret_status.Update(s);
      }
    });
    const FunctionBody* body = flr->GetFunctionBody(handle);
    Graph* function_graph = body->graph;
    TF_RETURN_IF_ERROR(DetectHostTrainingLoop(
        &function.func_name(), &function.attrs(), library, function_graph, flr,
        host_training_loops_info));
  }

  // BuildControlFlowInfo() requires that the graph's source node is connected
  // to all source nodes in the graph. Many graphs violate this invariant.
  // Therefore, add edges to the source/sink nodes so that the invariant holds.
  FixupSourceAndSinkEdges(graph);
  std::vector<ControlFlowInfo> cf_info;
  TF_RETURN_IF_ERROR(
      BuildControlFlowInfo(graph, &cf_info, /*unreachable_nodes=*/nullptr));

  std::unordered_map<string, WhileLoopFrame> frames;
  TF_RETURN_IF_ERROR(ExtractWhileLoopFrames(cf_info, graph, &frames));
  FindTPUCompileNodes(current_function_name, current_function_attr, frames,
                      host_training_loops_info);
  return ret_status;
}

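// Adds TPUReshardVariables ops to the host training loop described by
// `host_loop_info` so that the model weight variables consumed by its
// TPUExecute nodes are resharded before each loop iteration and unsharded
// after the last iteration.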
Status AddReshardOp(Graph* graph, const HostTrainingLoopInfo& host_loop_info) {
  const auto& compile_node_name = host_loop_info.compile_node_name;
  const auto node_name_map = graph->BuildNodeNameIndex();
  const auto node_it = node_name_map.find(compile_node_name);
  TF_RET_CHECK(node_it != node_name_map.end())
      << "Unable to find compile node : " << compile_node_name;

  const auto compile_node = node_it->second;
  std::vector<ExecuteNodeInfo> execute_nodes_info;

  Status status;
  TPUCompileMetadataProto metadata;
  status =
      ExtractExecuteNodeInfo(compile_node, *graph, host_loop_info.loop_nodes,
                             &execute_nodes_info, &metadata);
  if (!status.ok()) {
    LOG(ERROR) << "Encountered error when trying to extract execute nodes, "
                  "skipping host loop optimization. Status: "
               << status.ToString();
    return OkStatus();
  }

  if (execute_nodes_info.empty()) {
    return OkStatus();
  }

  // Update the TPUCompileMetadata such that sharding config of the
  // sharded resource variable inputs is set to ALLOWED instead of
  // TENTATIVE.
  string new_metadata_string;
  metadata.SerializeToString(&new_metadata_string);
  compile_node->ClearAttr("metadata");
  compile_node->AddAttr("metadata", new_metadata_string);

  // Unsharding of the model weight variables must happen only at the very
  // last loop iteration. Therefore, add the while loop condition predicate as
  // an input to the sharding switch node. If the loop condition is true, we
  // do not unshard.
  const auto& cond_node_name = host_loop_info.loop_cond_node_name;
  auto loop_cond_node_it = node_name_map.find(cond_node_name);
  TF_RET_CHECK(loop_cond_node_it != node_name_map.end())
      << "Cannot find loop condition node : " << cond_node_name;
  auto* loop_condition_node = loop_cond_node_it->second;

  // To make sure that the shard/unshard operations are invoked at the start
  // of every loop iteration and at the end of the last iteration,
  // respectively, create a pair of guiding nodes that are guaranteed to
  // execute before each iteration and after all iterations.

  Node* after_last_iteration_node;
  TF_RETURN_IF_ERROR(AddNoOpAfterLastIteration(*loop_condition_node, graph,
                                               &after_last_iteration_node));

  Node* before_loop_iteration_node;
  TF_RETURN_IF_ERROR(GetOrCreateBeforeEachIterationNode(
      *loop_condition_node, graph, &before_loop_iteration_node));

  // Create const op that represents default sharding value
  // (i.e. no-op sharding).
  NodeDef default_sharding;
  default_sharding.set_op("Const");
  default_sharding.set_name(graph->NewName(strings::StrCat(
      "TPUVariableReshard/default_shard_state", "/_", internal::GetNodeId())));
  AddNodeAttr("dtype", DT_STRING, &default_sharding);

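  // Build a 3-element string tensor filled with the default sharding value;
  // this constant is fed to the unshard op below as its format key and
  // represents the default (no-op) sharding state.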
  Tensor t(DT_STRING, {3});
  t.vec<tstring>()(0) = kDefaultShardingValue;
  t.vec<tstring>()(1) = kDefaultShardingValue;
  t.vec<tstring>()(2) = kDefaultShardingValue;
  t.AsProtoTensorContent(
      (*default_sharding.mutable_attr())["value"].mutable_tensor());

  TF_ASSIGN_OR_RETURN(Node * default_sharding_node,
                      graph->AddNode(default_sharding));
  // Add a control edge from the loop condition node to make sure that
  // default_sharding_node is inside the while loop frame.
  graph->AddControlEdge(loop_condition_node, default_sharding_node);

  // Build a no-op node used to add control edges after unshard nodes.
  NodeDef after_unshard;
  after_unshard.set_op("NoOp");
  after_unshard.set_name(graph->NewName(strings::StrCat(
      "TPUVariableReshard/last_iteration", "/_", internal::GetNodeId())));
  TF_ASSIGN_OR_RETURN(auto after_unshard_node, graph->AddNode(after_unshard));

  for (const auto& info : execute_nodes_info) {
    auto execute_node = info.execute_node;
    // Create Reshard op that optionally shards model weight variables
    // prior to program execution.
    NodeDef reshard_node_def;
    reshard_node_def.set_name(graph->NewName(strings::StrCat(
        "TPUVariableReshard/reshard", "/_", internal::GetNodeId())));
    reshard_node_def.set_op("TPUReshardVariables");
    AddNodeAttr("N", static_cast<int>(info.var_inputs.size()),
                &reshard_node_def);
    TF_ASSIGN_OR_RETURN(Node * reshard_op_node,
                        graph->AddNode(reshard_node_def));

    reshard_op_node->set_assigned_device_name(
        execute_node->assigned_device_name());

    // Reshard op must execute at every loop iteration prior to
    // TPUExecute node.
    graph->AddControlEdge(before_loop_iteration_node, reshard_op_node);
    graph->AddControlEdge(reshard_op_node, execute_node);

    for (int i = 0; i < info.var_inputs.size(); ++i) {
      const auto variable_edge = info.var_inputs[i];
      graph->AddEdge(variable_edge->src(), variable_edge->src_output(),
                     reshard_op_node, i);
    }

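    // Inputs of the TPUReshardVariables op are laid out as: [0, N) the
    // variable handles, N the format key, and N + 1 the format state
    // variable.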
    const int new_key_input = info.var_inputs.size();
    // Add program input edge from the compiler (i.e. the compilation key).
    const auto compilation_key_edge =
        FindEdgeConnecting(compile_node, execute_node);
    graph->AddEdge(compile_node, compilation_key_edge->src_output(),
                   reshard_op_node, new_key_input);

    // Create VarHandleOp to store the sharding state. The sharding state
    // holds the string compilation key, which identifies whether the graph
    // has been re-compiled and the variables need to be sharded again.
    NodeDef var_handle_def;
    var_handle_def.set_op("VarHandleOp");
    var_handle_def.set_name(graph->NewName(strings::StrCat(
        "TPUVariableReshard/reshard_state", "/_", internal::GetNodeId())));
    AddNodeAttr("dtype", DT_STRING, &var_handle_def);
    AddNodeAttr("shape", TensorShape({}), &var_handle_def);
    TF_ASSIGN_OR_RETURN(Node * var_handle_node, graph->AddNode(var_handle_def));

    // Add a control edge between the `var_handle_def` node and the while loop
    // condition so that `var_handle_def` is inside the same while loop frame.
    // TODO(hongjunchoi): Consider adding control edge from another node--such
    // as input control node.
    graph->AddControlEdge(loop_condition_node, var_handle_node);

    // Connect data edge between var handle op and reshard op.
    const int format_state_input = new_key_input + 1;
    graph->AddEdge(var_handle_node, 0, reshard_op_node, format_state_input);

    // Create Reshard op that represents unsharding after TPUExecute.
    NodeDef unshard_node_def;
    unshard_node_def.set_name(graph->NewName(strings::StrCat(
        "TPUVariableReshard/unshard", "/_", internal::GetNodeId())));
    unshard_node_def.set_op("TPUReshardVariables");
    AddNodeAttr("N", static_cast<int>(info.var_inputs.size()),
                &unshard_node_def);
    TF_ASSIGN_OR_RETURN(Node * unshard_op_node,
                        graph->AddNode(unshard_node_def));

    unshard_op_node->set_assigned_device_name(
        execute_node->assigned_device_name());

    for (int i = 0; i < info.var_inputs.size(); ++i) {
      const auto variable_edge = info.var_inputs[i];
      // Connect model weight resource variables to the unshard op. The
      // unshard op must only be invoked after the very last loop iteration,
      // so for each while loop input we traverse backwards to find the
      // corresponding Enter node of the host training loop and connect it to
      // the unshard op. The control edge from `after_last_iteration_node`
      // added below ensures that unsharding happens only after the loop has
      // finished all iterations.
      TF_ASSIGN_OR_RETURN(
          Node * enter_node,
          FindEnterNodeFromTPUExecuteNodeInput(variable_edge->src()));
      graph->AddEdge(enter_node, 0, unshard_op_node, i);
    }

    // Add control dependencies between the unshard node and the before/after
    // control nodes.
    graph->AddControlEdge(after_last_iteration_node, unshard_op_node);
    graph->AddControlEdge(unshard_op_node, after_unshard_node);

    graph->AddEdge(default_sharding_node, 0, unshard_op_node, new_key_input);

    // Add data edge from sharding state var handle op to unshard op.
    graph->AddEdge(var_handle_node, 0, unshard_op_node, format_state_input);
  }
  // Add control dependency from after_unshard_node to all exit nodes. This is
  // to make sure that the unshard ops will be executed as long as any of the
  // exits are used.
  for (auto exit : FindLoopExitNodes(*loop_condition_node)) {
    graph->AddControlEdge(after_unshard_node, exit);
  }
  return OkStatus();
}

}  // namespace tpu
}  // namespace tensorflow