/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/tpu/graph_rewrite/host_training_loop_optimization_util.h"

#include <deque>
#include <map>
#include <unordered_map>
#include <unordered_set>

#include "absl/container/flat_hash_set.h"
#include "absl/container/node_hash_set.h"
#include "tensorflow/compiler/tf2xla/functionalize_control_flow_util.h"
#include "tensorflow/compiler/tf2xla/tf2xla_util.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/protobuf/tpu/compile_metadata.pb.h"
#include "tensorflow/core/tpu/graph_rewrite/distributed_tpu_rewrite_pass_internal.h"

namespace tensorflow {
namespace tpu {

namespace {

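// Default sharding state value; the empty string denotes the unsharded
// (default) variable format.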
constexpr char kDefaultShardingValue[] = "";

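// Returns the first output edge of `src` whose destination is `dst`, or
// nullptr if no such edge exists.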
const Edge* FindEdgeConnecting(const Node* src, const Node* dst) {
  for (const Edge* e : src->out_edges()) {
    if (e->dst()->name() == dst->name()) return e;
  }
  return nullptr;
}

// Contains TPUExecute node and its DT_RESOURCE input nodes that
// correspond to model weights.
struct ExecuteNodeInfo {
  Node* execute_node;
  std::vector<const Edge*> var_inputs;
};

// Returns whether `node` is in `execute_nodes` or only feeds into them
// through Identity nodes, i.e. matches the pattern `(identity) -> execute`.
bool IsExecuteNodeOrIdentityToExecuteNode(
    const Graph& graph, const std::unordered_set<Node*>& loop_nodes,  // NOLINT
    const absl::flat_hash_set<Node*>& execute_nodes, Node* node) {
  if (execute_nodes.find(node) != execute_nodes.end()) return true;
  if (loop_nodes.find(node) == loop_nodes.end()) return false;
  if (node->IsNextIteration()) return true;
  if (!node->IsIdentity()) return false;

  for (const Edge* e : node->out_edges()) {
    if (e->IsControlEdge()) continue;

    Node* dst = e->dst();
    if (!IsExecuteNodeOrIdentityToExecuteNode(graph, loop_nodes, execute_nodes,
                                              dst)) {
      return false;
    }
  }

  return true;
}

// Starting from an input node of a TPUExecute op, finds the corresponding
// Enter node by traversing the following pattern of nodes:
// Enter ----> (identity) --->  While body input
// Returns nullptr if the Enter node is not found.
xla::StatusOr<Node*> FindEnterNodeFromTPUExecuteNodeInput(Node* input_node) {
  Node* node = input_node;
  while (node->IsIdentity()) {
    TF_RETURN_IF_ERROR(node->input_node(0, &node));
  }

  if (node->IsEnter()) {
    return node;
  }
  return nullptr;
}

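// Returns whether the resource variable introduced by `enter_node` is used
// within the loop only by the given TPUExecute nodes (possibly through
// Identity nodes), ignoring control edges and loop Exit nodes.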
xla::StatusOr<bool> ResourceOnlyUsedForTPUExecuteInLoop(
    const Graph& graph, const std::unordered_set<Node*>& loop_nodes,  // NOLINT
    const Node* enter_node, const absl::flat_hash_set<Node*>& execute_nodes) {
  for (const Edge* output_edge : enter_node->out_edges()) {
    Node* output_node = output_edge->dst();
    if (output_edge->IsControlEdge() || output_node->IsExit()) continue;

    // If the output node is not an execute node, it must be an output node
    // of the while loop body.
    if (!IsExecuteNodeOrIdentityToExecuteNode(graph, loop_nodes, execute_nodes,
                                              output_node)) {
      return false;
    }
  }
  return true;
}

// Given a TPUCompile node, finds all TPUExecute nodes that execute the
// compiled program, along with their model weight variable inputs.
// The TPUCompileMetadataProto of the TPUCompile node must be reset to
// `new_metadata` if new reshard ops are added.
Status ExtractExecuteNodeInfo(const Node* compile_node, const Graph& graph,
                              const std::unordered_set<Node*>& loop_nodes,  // NOLINT
                              std::vector<ExecuteNodeInfo>* execute_node_info,
                              TPUCompileMetadataProto* new_metadata) {
  string metadata_string;
  TF_RETURN_IF_ERROR(
      GetNodeAttr(compile_node->attrs(), "metadata", &metadata_string));
  new_metadata->ParsePartialFromString(metadata_string);
  if (new_metadata->num_cores_per_replica() != 1) {
    // We do not support model parallelism yet.
    return OkStatus();
  }

  execute_node_info->clear();
  for (Node* node : compile_node->out_nodes()) {
    if (node->type_string() == "TPUExecute") {
      execute_node_info->push_back({node});
    }
  }
  if (execute_node_info->empty()) {
    return OkStatus();
  }
  TF_RET_CHECK(execute_node_info->size() == new_metadata->num_replicas())
      << "Number of replicas does not equal number of execute nodes: "
      << new_metadata->num_replicas() << " vs " << execute_node_info->size();
  DataTypeVector arg_types;
  TF_RETURN_IF_ERROR(GetNodeAttr((*execute_node_info)[0].execute_node->attrs(),
                                 "Targs", &arg_types));
  for (int64_t i = 0; i < arg_types.size(); ++i) {
    if (arg_types[i] != DT_RESOURCE) {
      continue;
    }
    const auto sharding_config = new_metadata->args(i).enable_xla_sharding();
    if (sharding_config != TPUCompileMetadataProto::Arg::TENTATIVE &&
        sharding_config != TPUCompileMetadataProto::Arg::ALLOWED) {
      continue;
    }
    std::vector<const Edge*> edges(execute_node_info->size());
    bool is_supported = true;
    std::unordered_map<Node*, absl::flat_hash_set<Node*>>
        enter_to_execute_nodes;
    for (int64_t j = 0; j < edges.size(); ++j) {
      auto execute = (*execute_node_info)[j].execute_node;
      TF_RETURN_IF_ERROR(execute->input_edge(i, &edges[j]));
      TF_RET_CHECK(edges[j]->src()->output_type(edges[j]->src_output()) ==
                   arg_types[i])
          << "Execute op has an unexpected input type.";
      // Traverse backwards to find the Enter node from which the input is
      // passed.
      // This makes sure that we are checking the usages of all potential
      // aliases of the input node as well.
      TF_ASSIGN_OR_RETURN(auto enter_node, FindEnterNodeFromTPUExecuteNodeInput(
                                               edges[j]->src()));
      if (enter_node == nullptr) {
        is_supported = false;
        enter_to_execute_nodes.clear();
        break;
      }
      enter_to_execute_nodes[enter_node].insert(edges[j]->dst());
    }

    for (const auto& it : enter_to_execute_nodes) {
      // The number of execute nodes should be either 1 (per-replica
      // variables) or num_replicas (distributed variables).
      if ((it.second.size() != 1) &&
          (it.second.size() != new_metadata->num_replicas())) {
        is_supported = false;
        break;
      }
      TF_ASSIGN_OR_RETURN(bool no_other_use,
                          ResourceOnlyUsedForTPUExecuteInLoop(
                              graph, loop_nodes, it.first, it.second));
      if (!no_other_use) {
        is_supported = false;
        break;
      }
    }

    // Add the variable input edges only when they are supported for all
    // executes.
    if (is_supported) {
      for (int64_t j = 0; j < edges.size(); ++j) {
        (*execute_node_info)[j].var_inputs.push_back(edges[j]);
      }
      new_metadata->mutable_args(i)->set_enable_xla_sharding(
          TPUCompileMetadataProto::Arg::ALLOWED);
    }
  }

  int64_t total = 0;
  for (const auto& a : new_metadata->args()) {
    if (a.enable_xla_sharding() == TPUCompileMetadataProto::Arg::ALLOWED) {
      total++;
    }
  }
  TF_RET_CHECK(total == (*execute_node_info)[0].var_inputs.size())
      << " total " << total << " var_inputs "
      << (*execute_node_info)[0].var_inputs.size();
  if (total == 0) {
    // We don't need to process anything if no input is added.
    execute_node_info->clear();
  }
  return OkStatus();
}

bool IsTPUCompileOp(const Node& n) { return n.type_string() == "TPUCompile"; }

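// Walks the while loop frames from the innermost loop to the outermost one
// and records, for every TPUCompile node found, the loop information needed
// to rewrite its host training loop.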
void FindTPUCompileNodes(
    const std::string* current_function_name,
    const AttrValueMap* current_function_attr,
    const std::unordered_map<string, WhileLoopFrame>& frames,
    std::vector<HostTrainingLoopInfo>* host_training_loops_info) {
  // Adds frames with no children (i.e., the innermost frames) to a worklist.
  std::deque<const WhileLoopFrame*> worklist;

  for (auto& frame : frames) {
    if (frame.second.num_children == 0) {
      worklist.push_back(&frame.second);
    }
  }

  // Check TPUCompile nodes from the innermost while loop to the outermost
  // while loop.
  while (!worklist.empty()) {
    const WhileLoopFrame* frame = worklist.front();
    worklist.pop_front();

    for (const auto& n : frame->nodes) {
      if (!IsTPUCompileOp(*n)) continue;

      HostTrainingLoopInfo host_training_loop_info;
      host_training_loop_info.compile_node_name = n->name();
      host_training_loop_info.loop_cond_node_name = frame->loop_cond->name();
      host_training_loop_info.while_loop_name = frame->name;

      for (const auto arg : frame->args) {
        LoopArgInfo arg_info;
        arg_info.enter_node_name = arg.enter->name();
        if (arg.exit) arg_info.exit_node_name = arg.exit->name();

        host_training_loop_info.loop_arguments.push_back(std::move(arg_info));
      }
      host_training_loop_info.loop_nodes = frame->nodes;

      if (current_function_name) {
        host_training_loop_info.encapsulating_function_name =
            *current_function_name;
      }
      if (current_function_attr) {
        host_training_loop_info.encapsulating_function_attrs =
            *current_function_attr;
      }

      host_training_loops_info->emplace_back(
          std::move(host_training_loop_info));
    }

    // If the parent has no remaining children, add it to the worklist.
    --frame->parent->num_children;
    if (frame->parent->num_children == 0) {
      worklist.push_back(frame->parent);
    }
  }
}

// From a while loop cond node, finds all loop exit nodes by traversing the
// following pattern of nodes:
// LoopCond -----> Switch -----> Exit
std::vector<Node*> FindLoopExitNodes(const Node& loop_cond) {
  std::vector<Node*> loop_exit_nodes;
  for (const auto e_cond : loop_cond.out_edges()) {
    if (e_cond->IsControlEdge() || !e_cond->dst()->IsSwitch()) continue;
    auto switch_node = e_cond->dst();

    for (const auto e_switch : switch_node->out_edges()) {
      if (e_switch->IsControlEdge() || !e_switch->dst()->IsExit()) continue;

      loop_exit_nodes.push_back(e_switch->dst());
    }
  }
  return loop_exit_nodes;
}

// Returns or creates a node that is executed before each loop iteration
// in the while loop.
// TODO(mdan): Inject this node between the Enter and Merge nodes instead.
// See AddNoOpAfterLastIteration for an example.
Status GetOrCreateBeforeEachIterationNode(const Node& loop_cond_node,
                                          Graph* graph, Node** node_out) {
  Node* loop_switch_node = nullptr;
  for (auto n : loop_cond_node.out_nodes()) {
    if (n->IsSwitch()) {
      loop_switch_node = n;
      break;
    }
  }
  TF_RET_CHECK(loop_switch_node != nullptr)
      << "Unable to find any switch nodes.";

  // If the while loop switch node already has an outgoing data edge to the
  // true branch of the switch op, then reuse that node.
  for (const auto out_edge : loop_switch_node->out_edges()) {
    if (out_edge->src_output() == 1) {
      *node_out = out_edge->dst();
      return OkStatus();
    }
  }

  // Create an Identity node that represents execution at every loop
  // iteration.
  NodeDef at_loop_iteration_nodedef;
  at_loop_iteration_nodedef.set_op("Identity");
  DataType dtype;
  TF_RETURN_IF_ERROR(GetNodeAttr(loop_switch_node->def(), "T", &dtype));

  AddNodeAttr("T", dtype, &at_loop_iteration_nodedef);
  at_loop_iteration_nodedef.set_name(graph->NewName(strings::StrCat(
      "TPUVariableReshard/before_iteration", "/_", internal::GetNodeId())));

  TF_ASSIGN_OR_RETURN(Node * at_loop_iteration_node,
                      graph->AddNode(at_loop_iteration_nodedef));

  graph->AddEdge(loop_switch_node, 1, at_loop_iteration_node, 0);
  *node_out = at_loop_iteration_node;
  return OkStatus();
}

// Injects a NoOp node that is executed after the very last iteration
// of the while loop but before the while loop exit node.
// This node is positioned between the false output of all Switch nodes
// (meaning it executes after the loop has ended all its iterations) and
// their corresponding Exit nodes (meaning, before the loop has finally
// completed).
// See the white paper for details:
// http://download.tensorflow.org/paper/white_paper_tf_control_flow_implementation_2017_11_1.pdf
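//
// The wiring per Switch node is roughly (a sketch):
//
//   Switch:0 (false) --data--> Identity --ctrl--> NoOp --ctrl--> Exit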
Status AddNoOpAfterLastIteration(const Node& loop_cond_node, Graph* graph,
                                 Node** node_out) {
  NodeDef after_last_iteration;
  after_last_iteration.set_op("NoOp");

  after_last_iteration.set_name(graph->NewName(strings::StrCat(
      "TPUVariableReshard/after_last_iteration", "/_", internal::GetNodeId())));

  Status status;
  Node* after_last_iteration_node =
      graph->AddNode(after_last_iteration, &status);
  TF_RETURN_IF_ERROR(status);

  for (auto switch_node : loop_cond_node.out_nodes()) {
    if (!switch_node->IsSwitch()) {
      continue;
    }

    NodeDef switch_exit;
    switch_exit.set_op("Identity");

    DataType dtype;
    TF_RETURN_IF_ERROR(GetNodeAttr(switch_node->def(), "T", &dtype));
    AddNodeAttr("T", dtype, &switch_exit);
    auto name = strings::StrCat("TPUVariableReshard/switch_exit/", "/_",
                                internal::GetNodeId());
    switch_exit.set_name(graph->NewName(name));
    // Introducing identity nodes risks a device copy, which isn't guaranteed
    // to be available for all types. Hence the colocation constraint.
    AddNodeAttr(kColocationAttrName,
                std::vector<string>{
                    absl::StrCat(kColocationGroupPrefix, switch_node->name())},
                &switch_exit);

    TF_ASSIGN_OR_RETURN(Node * after_switch_node, graph->AddNode(switch_exit));

    graph->AddEdge(switch_node, 0, after_switch_node, 0);
    graph->AddControlEdge(after_switch_node, after_last_iteration_node);

    for (const auto out_node : switch_node->out_nodes()) {
      if (out_node->IsExit()) {
        graph->AddControlEdge(after_last_iteration_node, out_node);
      }
    }
  }

  *node_out = after_last_iteration_node;
  return OkStatus();
}

}  // namespace

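// Detects host training loops in `graph` and, recursively, in every function
// reachable from it, appending one HostTrainingLoopInfo entry per TPUCompile
// node found inside a while loop.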
Status DetectHostTrainingLoop(
    const std::string* current_function_name,
    const AttrValueMap* current_function_attr,
    const FunctionLibraryDefinition* library, Graph* graph,
    FunctionLibraryRuntime* flr,
    std::vector<HostTrainingLoopInfo>* host_training_loops_info) {
  std::vector<AssociatedFunctionInfo> associated_function_list;
  for (const auto* n : graph->nodes()) {
    const auto associated_functions = GetAssociatedFunctions(*n, library);
    if (associated_functions.empty()) continue;

    associated_function_list.insert(associated_function_list.end(),
                                    associated_functions.begin(),
                                    associated_functions.end());
  }

  Status ret_status = OkStatus();
  for (const auto& function : associated_function_list) {
    if (function.type() != AssociatedFunctionInfo::kFunctionAttr) continue;

    // Convert the function to a Graph.
    FunctionLibraryRuntime::Handle handle;
    TF_RETURN_IF_ERROR(flr->Instantiate(function.func_name(),
                                        AttrSlice(&function.attrs()), &handle));
    auto cleanup_handle = gtl::MakeCleanup([&]() {
      auto s = flr->ReleaseHandle(handle);
      if (!s.ok()) {
        ret_status.Update(s);
      }
    });
    const FunctionBody* body = flr->GetFunctionBody(handle);
    Graph* function_graph = body->graph;
    TF_RETURN_IF_ERROR(DetectHostTrainingLoop(
        &function.func_name(), &function.attrs(), library, function_graph, flr,
        host_training_loops_info));
  }

  // BuildControlFlowInfo() requires that the graph's source node is connected
  // to all source nodes in the graph. Many graphs violate this invariant.
  // Therefore, add edges to source/sink nodes so that this invariant is kept.
  FixupSourceAndSinkEdges(graph);
  std::vector<ControlFlowInfo> cf_info;
  TF_RETURN_IF_ERROR(
      BuildControlFlowInfo(graph, &cf_info, /*unreachable_nodes=*/nullptr));

  std::unordered_map<string, WhileLoopFrame> frames;
  TF_RETURN_IF_ERROR(ExtractWhileLoopFrames(cf_info, graph, &frames));
  FindTPUCompileNodes(current_function_name, current_function_attr, frames,
                      host_training_loops_info);
  return ret_status;
}

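// Rewrites the host training loop described by `host_loop_info`, inserting
// TPUReshardVariables ops so that model weight variables are sharded before
// each TPUExecute invocation and unsharded after the last loop iteration.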
Status AddReshardOp(Graph* graph, const HostTrainingLoopInfo& host_loop_info) {
  const auto& compile_node_name = host_loop_info.compile_node_name;
  const auto node_name_map = graph->BuildNodeNameIndex();
  const auto node_it = node_name_map.find(compile_node_name);
  TF_RET_CHECK(node_it != node_name_map.end())
      << "Unable to find compile node : " << compile_node_name;

  const auto compile_node = node_it->second;
  std::vector<ExecuteNodeInfo> execute_nodes_info;

  Status status;
  TPUCompileMetadataProto metadata;
  status =
      ExtractExecuteNodeInfo(compile_node, *graph, host_loop_info.loop_nodes,
                             &execute_nodes_info, &metadata);
  if (!status.ok()) {
    LOG(ERROR) << "Encountered error when trying to extract execute nodes, "
                  "skipping host loop optimization. Status: "
               << status.ToString();
    return OkStatus();
  }

  if (execute_nodes_info.empty()) {
    return OkStatus();
  }

  // Update the TPUCompileMetadata such that the sharding config of the
  // sharded resource variable inputs is set to ALLOWED instead of
  // TENTATIVE.
  string new_metadata_string;
  metadata.SerializeToString(&new_metadata_string);
  compile_node->ClearAttr("metadata");
  compile_node->AddAttr("metadata", new_metadata_string);

  // Unsharding of the model weight variables must happen only at the very
  // last loop iteration. Therefore, add the while loop condition predicate
  // as an input to the sharding switch node. If the loop condition is true,
  // we do not unshard.
  const auto& cond_node_name = host_loop_info.loop_cond_node_name;
  auto loop_cond_node_it = node_name_map.find(cond_node_name);
  TF_RET_CHECK(loop_cond_node_it != node_name_map.end())
      << "Cannot find loop condition node : " << cond_node_name;
  auto* loop_condition_node = loop_cond_node_it->second;

  // To make sure that shard/unshard operations are invoked at the start of
  // every loop body and at the end of the last iteration of the loop,
  // respectively, create a pair of guiding nodes, which are guaranteed to
  // execute before each iteration and after all iterations, respectively.

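  // Sketch of the intended structure (control edges, per execute node):
  //
  //   Switch:1 (true)  -> Identity ("before_iteration") -> Reshard -> Execute
  //   Switch:0 (false) -> ... -> NoOp ("after_last_iteration")
  //                           -> Reshard (unshard) -> NoOp -> Exit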
  Node* after_last_iteration_node;
  TF_RETURN_IF_ERROR(AddNoOpAfterLastIteration(*loop_condition_node, graph,
                                               &after_last_iteration_node));

  Node* before_loop_iteration_node;
  TF_RETURN_IF_ERROR(GetOrCreateBeforeEachIterationNode(
      *loop_condition_node, graph, &before_loop_iteration_node));

  // Create a Const op that represents the default sharding value
  // (i.e. no-op sharding).
  NodeDef default_sharding;
  default_sharding.set_op("Const");
  default_sharding.set_name(graph->NewName(strings::StrCat(
      "TPUVariableReshard/default_shard_state", "/_", internal::GetNodeId())));
  AddNodeAttr("dtype", DT_STRING, &default_sharding);

  Tensor t(DT_STRING, {3});
  t.vec<tstring>()(0) = kDefaultShardingValue;
  t.vec<tstring>()(1) = kDefaultShardingValue;
  t.vec<tstring>()(2) = kDefaultShardingValue;
  t.AsProtoTensorContent(
      (*default_sharding.mutable_attr())["value"].mutable_tensor());

  TF_ASSIGN_OR_RETURN(Node * default_sharding_node,
                      graph->AddNode(default_sharding));
  // Add a control edge from the loop condition to make sure that
  // default_sharding_node is inside the while loop frame.
  graph->AddControlEdge(loop_condition_node, default_sharding_node);

  // Build a no-op node used to add control edges after unshard nodes.
  NodeDef after_unshard;
  after_unshard.set_op("NoOp");
  after_unshard.set_name(graph->NewName(strings::StrCat(
      "TPUVariableReshard/last_iteration", "/_", internal::GetNodeId())));
  TF_ASSIGN_OR_RETURN(auto after_unshard_node, graph->AddNode(after_unshard));

  for (const auto& info : execute_nodes_info) {
    auto execute_node = info.execute_node;
    // Create a Reshard op that optionally shards model weight variables
    // prior to program execution.
    NodeDef reshard_node_def;
    reshard_node_def.set_name(graph->NewName(strings::StrCat(
        "TPUVariableReshard/reshard", "/_", internal::GetNodeId())));
    reshard_node_def.set_op("TPUReshardVariables");
    AddNodeAttr("N", static_cast<int>(info.var_inputs.size()),
                &reshard_node_def);
    TF_ASSIGN_OR_RETURN(Node * reshard_op_node,
                        graph->AddNode(reshard_node_def));

    reshard_op_node->set_assigned_device_name(
        execute_node->assigned_device_name());

    // The reshard op must execute at every loop iteration prior to the
    // TPUExecute node.
    graph->AddControlEdge(before_loop_iteration_node, reshard_op_node);
    graph->AddControlEdge(reshard_op_node, execute_node);

    for (int i = 0; i < info.var_inputs.size(); ++i) {
      const auto variable_edge = info.var_inputs[i];
      graph->AddEdge(variable_edge->src(), variable_edge->src_output(),
                     reshard_op_node, i);
    }

    const int new_key_input = info.var_inputs.size();
    // Add the program input edge from the compiler (i.e. the compilation
    // key).
    const auto compilation_key_edge =
        FindEdgeConnecting(compile_node, execute_node);
    graph->AddEdge(compile_node, compilation_key_edge->src_output(),
                   reshard_op_node, new_key_input);

    // Create a VarHandleOp to store the sharding state. The sharding state
    // holds the string compilation key that identifies whether the graph was
    // re-compiled and the variables need to be sharded again.
    NodeDef var_handle_def;
    var_handle_def.set_op("VarHandleOp");
    var_handle_def.set_name(graph->NewName(strings::StrCat(
        "TPUVariableReshard/reshard_state", "/_", internal::GetNodeId())));
    AddNodeAttr("dtype", DT_STRING, &var_handle_def);
    AddNodeAttr("shape", TensorShape({}), &var_handle_def);
    TF_ASSIGN_OR_RETURN(Node * var_handle_node, graph->AddNode(var_handle_def));

    // Add a control edge between the `var_handle_def` node and the while
    // loop condition so that `var_handle_def` is inside the same while loop
    // frame.
    // TODO(hongjunchoi): Consider adding control edge from another node--such
    // as input control node.
    graph->AddControlEdge(loop_condition_node, var_handle_node);

    // Connect a data edge between the var handle op and the reshard op.
    const int format_state_input = new_key_input + 1;
    graph->AddEdge(var_handle_node, 0, reshard_op_node, format_state_input);

    // Create a Reshard op that represents unsharding after TPUExecute.
    NodeDef unshard_node_def;
    unshard_node_def.set_name(graph->NewName(strings::StrCat(
        "TPUVariableReshard/unshard", "/_", internal::GetNodeId())));
    unshard_node_def.set_op("TPUReshardVariables");
    AddNodeAttr("N", static_cast<int>(info.var_inputs.size()),
                &unshard_node_def);
    TF_ASSIGN_OR_RETURN(Node * unshard_op_node,
                        graph->AddNode(unshard_node_def));

    unshard_op_node->set_assigned_device_name(
        execute_node->assigned_device_name());

    for (int i = 0; i < info.var_inputs.size(); ++i) {
      const auto variable_edge = info.var_inputs[i];
      // Connect model weight resource variables to the unshard op. Since the
      // unshard op must only be invoked after the very last loop iteration,
      // for each while loop input we traverse backwards to find the switch
      // node of the host training loop and connect the `output_false` field
      // of the switch node with the unshard op.
      TF_ASSIGN_OR_RETURN(
          Node * enter_node,
          FindEnterNodeFromTPUExecuteNodeInput(variable_edge->src()));
      graph->AddEdge(enter_node, 0, unshard_op_node, i);
    }

    // Add control dependencies before/after the unshard node and the control
    // nodes.
    graph->AddControlEdge(after_last_iteration_node, unshard_op_node);
    graph->AddControlEdge(unshard_op_node, after_unshard_node);

    graph->AddEdge(default_sharding_node, 0, unshard_op_node, new_key_input);

    // Add a data edge from the sharding state var handle op to the unshard
    // op.
    graph->AddEdge(var_handle_node, 0, unshard_op_node, format_state_input);
  }
  // Add a control dependency from after_unshard_node to all exit nodes. This
  // is to make sure that the unshard ops will be executed as long as any of
  // the exits are used.
  for (auto exit : FindLoopExitNodes(*loop_condition_node)) {
    graph->AddControlEdge(after_unshard_node, exit);
  }
  return OkStatus();
}

}  // namespace tpu
}  // namespace tensorflow