/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// Graph rewrite passes that expand the ops used to configure and shut down a
// distributed TPU system.

#include "tensorflow/core/tpu/graph_rewrite/distributed_tpu_configuration_rewrite_pass.h"

#include <unordered_map>

#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/core/common_runtime/device_set.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/graph_constructor.h"
#include "tensorflow/core/common_runtime/optimization_registry.h"
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/partial_tensor_shape.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/public/session_options.h"
#include "tensorflow/core/tpu/graph_rewrite/distributed_tpu_rewrite_helpers.h"
#include "tensorflow/core/tpu/kernels/tpu_compile_op_options.h"
#include "tensorflow/core/tpu/tpu_init_mode.h"
#include "tensorflow/core/util/device_name_utils.h"
#include "tensorflow/core/util/dump_graph.h"

namespace tensorflow {
namespace {

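// Op type and attribute names used by the rewrite passes below.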
constexpr char kIdentityOp[] = "Identity";
constexpr char kConfigureOp[] = "ConfigureDistributedTPU";
constexpr char kInternalConfigureOp[] = "_ConfigureDistributedTPU";
constexpr char kWaitOp[] = "_WaitForDistributedTPU";
constexpr char kHostConfigureOp[] = "_InitializeHostForDistributedTPU";
constexpr char kGlobalTPUArrayOp[] = "_SetGlobalTPUArray";
constexpr char kShutdownOp[] = "ShutdownDistributedTPU";
constexpr char kInternalShutdownOp[] = "_ShutdownDistributedTPU";
constexpr char kHostDisconnectOp[] = "_DisconnectHostFromDistributedTPUSystem";
constexpr char kEmbeddingConfigurationAttr[] = "embedding_config";
constexpr char kTpuCancellationClosesChipsAttr[] =
    "tpu_cancellation_closes_chips";
constexpr int kDefaultStartupTimeout = 20;

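// Adds the internal _ConfigureDistributedTPU node that configures the global
// TPU system on the given device, with one input per host.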
Status AddConfigurationNode(const string& configuration_device_name,
                            int number_of_hosts, Graph* graph,
                            bool enable_whole_mesh_compilations,
                            Node** configuration_node) {
  NodeDef config_def;
  config_def.set_name(graph->NewName("configure_distributed_tpu"));
  config_def.set_op(kInternalConfigureOp);
  config_def.set_device(configuration_device_name);
  AddNodeAttr("N", number_of_hosts, &config_def);
  AddNodeAttr("enable_whole_mesh_compilations", enable_whole_mesh_compilations,
              &config_def);
  // TODO(shikharagarwal): Fill with appropriate original node debug info.

  TF_ASSIGN_OR_RETURN(*configuration_node, graph->AddNode(config_def));
  (*configuration_node)->set_assigned_device_name(configuration_device_name);
  return OkStatus();
}

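// Adds an _InitializeHostForDistributedTPU node on the given host device,
// fed by the output of the global configuration node.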
Status AddHostConfigNode(const string& host_device_name,
                         Node* configuration_node, Graph* graph,
                         bool enable_whole_mesh_compilations,
                         int tpu_cancellation_closes_chips,
                         Node** host_configuration_node) {
  NodeDef host_config_def;
  host_config_def.set_name(graph->NewName("configure_tpu_host"));
  host_config_def.set_op(kHostConfigureOp);
  host_config_def.set_device(host_device_name);
  AddNodeAttr("enable_whole_mesh_compilations", enable_whole_mesh_compilations,
              &host_config_def);
  AddNodeAttr(kTpuCancellationClosesChipsAttr, tpu_cancellation_closes_chips,
              &host_config_def);
  MergeDebugInfo(NodeDebugInfo(configuration_node->def()), &host_config_def);

  TF_ASSIGN_OR_RETURN(*host_configuration_node,
                      graph->AddNode(host_config_def));
  (*host_configuration_node)->set_assigned_device_name(host_device_name);
  graph->AddEdge(configuration_node, 0, *host_configuration_node, 0);
  return OkStatus();
}

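// Adds a _WaitForDistributedTPU node on the configuration device that blocks
// until every host configuration node has produced its output.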
Status AddWaitNode(const string& configuration_device_name,
                   const std::vector<Node*>& host_configuration_nodes,
                   Graph* graph, Node** wait_node) {
  NodeDef wait_def;
  wait_def.set_name(graph->NewName("wait_for_distributed_tpu_system"));
  wait_def.set_op(kWaitOp);
  wait_def.set_device(configuration_device_name);
  AddNodeAttr("N", static_cast<int32>(host_configuration_nodes.size()),
              &wait_def);
  AddNodeAttr("startup_timeout_sec", kDefaultStartupTimeout, &wait_def);
  if (!host_configuration_nodes.empty()) {
    MergeDebugInfo(NodeDebugInfo(host_configuration_nodes[0]->def()),
                   &wait_def);
  }

  TF_ASSIGN_OR_RETURN(*wait_node, graph->AddNode(wait_def));
  (*wait_node)->set_assigned_device_name(configuration_device_name);
  // Get the inputs from the host configuration nodes.
  for (int i = 0; i < host_configuration_nodes.size(); ++i) {
    graph->AddEdge(host_configuration_nodes[i], 0, *wait_node, i);
  }
  return OkStatus();
}

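// Adds a _SetGlobalTPUArray node on the given host device, fed by the output
// of the wait node.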
Status AddGlobalTPUArrayNode(const string& host_device_name, Node* wait_node,
                             Graph* graph, Node** global_tpu_array_node) {
  NodeDef global_tpu_array_def;
  global_tpu_array_def.set_name(graph->NewName("set_global_tpu_array"));
  global_tpu_array_def.set_op(kGlobalTPUArrayOp);
  global_tpu_array_def.set_device(host_device_name);
  MergeDebugInfo(NodeDebugInfo(wait_node->def()), &global_tpu_array_def);

  TF_ASSIGN_OR_RETURN(*global_tpu_array_node,
                      graph->AddNode(global_tpu_array_def));
  (*global_tpu_array_node)->set_assigned_device_name(host_device_name);
  graph->AddEdge(wait_node, 0, *global_tpu_array_node, 0);
  return OkStatus();
}

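// Adds an Identity node that takes over the name of the original
// ConfigureDistributedTPU node, forwards the wait node's output, depends on
// all _SetGlobalTPUArray nodes via control edges, and inherits the original
// node's output edges.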
Status AddSynchronizationNode(
    const NodeDef& sync_node_def, const string& device_name,
    const std::vector<Node*>& global_array_id_nodes, Node* wait_node,
    const std::vector<DistributedTPURewriteHelpers::OutputDependency>&
        output_dependencies,
    Graph* graph) {
  NodeDef sync_def;
  sync_def.set_name(sync_node_def.name());
  sync_def.set_op(kIdentityOp);
  sync_def.set_device(device_name);
  AddNodeAttr("T", DT_STRING, &sync_def);
  MergeDebugInfo(NodeDebugInfo(sync_node_def), &sync_def);

  TF_ASSIGN_OR_RETURN(Node * sync_node, graph->AddNode(sync_def));
  sync_node->set_assigned_device_name(device_name);
  // Add control edges from the global array id nodes.
  for (auto node : global_array_id_nodes) {
    graph->AddControlEdge(node, sync_node);
  }
  // Forward the data from the wait node.
  graph->AddEdge(wait_node, 0, sync_node, 0);
  // Replace the output edges.
  for (const DistributedTPURewriteHelpers::OutputDependency& dep :
       output_dependencies) {
    if (dep.dst_input == Graph::kControlSlot) {
      graph->AddControlEdge(sync_node, dep.dst);
    } else {
      graph->AddEdge(sync_node, dep.src_output, dep.dst, dep.dst_input);
    }
  }
  return OkStatus();
}

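// Adds the internal _ShutdownDistributedTPU node in place of the original
// ShutdownDistributedTPU node and rewires its control-edge outputs.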
Status AddShutdownNode(
    const NodeDef& shutdown_node_def, const string& shutdown_device_name,
    const std::vector<DistributedTPURewriteHelpers::OutputDependency>&
        output_dependencies,
    Graph* graph, Node** shutdown_node) {
  NodeDef shutdown_def;
  shutdown_def.set_name(shutdown_node_def.name());
  shutdown_def.set_op(kInternalShutdownOp);
  shutdown_def.set_device(shutdown_device_name);
  MergeDebugInfo(NodeDebugInfo(shutdown_node_def), &shutdown_def);

  TF_ASSIGN_OR_RETURN(*shutdown_node, graph->AddNode(shutdown_def));
  (*shutdown_node)->set_assigned_device_name(shutdown_device_name);
  // Replace the output control edges.
  for (const DistributedTPURewriteHelpers::OutputDependency& dep :
       output_dependencies) {
    if (dep.dst_input != Graph::kControlSlot) {
      return errors::Internal("Shutdown node had non-control edge output");
    }
    graph->AddControlEdge(*shutdown_node, dep.dst);
  }
  return OkStatus();
}

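// Adds a _DisconnectHostFromDistributedTPUSystem node on the given host
// device, wired after the given input dependencies. Its output feeds
// post_disconnect_node at output_index, or is attached by a control edge
// when output_index is -1.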
Status AddHostDisconnectNode(const string& host_device_name,
                             const std::vector<Node*>& input_dependencies,
                             Node* post_disconnect_node, int output_index,
                             Graph* graph) {
  NodeDef host_disconnect_def;
  host_disconnect_def.set_name(graph->NewName("disconnect_tpu_host"));
  host_disconnect_def.set_op(kHostDisconnectOp);
  host_disconnect_def.set_device(host_device_name);
  MergeDebugInfo(NodeDebugInfo(post_disconnect_node->def()),
                 &host_disconnect_def);

  TF_ASSIGN_OR_RETURN(Node * host_disconnect_node,
                      graph->AddNode(host_disconnect_def));
  host_disconnect_node->set_assigned_device_name(host_device_name);
  // Replace the input control edges.
  for (Node* src_node : input_dependencies) {
    graph->AddControlEdge(src_node, host_disconnect_node);
  }
  if (output_index == -1) {
    graph->AddControlEdge(host_disconnect_node, post_disconnect_node);
  } else {
    graph->AddEdge(host_disconnect_node, 0, post_disconnect_node, output_index);
  }
  return OkStatus();
}

}  // namespace

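// Rewrites each ConfigureDistributedTPU node into the subgraph that brings up
// the distributed TPU system: a global _ConfigureDistributedTPU node fed by
// per-host disconnect nodes, per-host initialization nodes, a wait node,
// per-host _SetGlobalTPUArray nodes, and a synchronization Identity node that
// takes over the original node's name and output edges.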
Status DistributedTPUConfigurationRewritePass::Run(
    const GraphOptimizationPassOptions& options) {
  VLOG(1) << "DistributedTPUConfigurationRewritePass::Run";

  Graph* graph = options.graph->get();

  if (VLOG_IS_ON(1)) {
    DumpGraphToFile("distributed_tpu_configuration_before", *graph,
                    options.flib_def);
  }

  // This pass can only run in the session master, which is expected to fill
  // in the device_set field of the options.
  TF_RET_CHECK(options.device_set != nullptr);

  TF_RETURN_IF_ERROR(
      DistributedTPURewriteHelpers::ForConfigurationNodeMatchingType(
          kConfigureOp, graph, *options.device_set,
          [](const NodeDef& configuration_node_def,
             const string& configuration_device_name,
             const std::vector<Device*>& host_devices,
             const std::vector<Node*>& input_dependencies,
             const std::vector<DistributedTPURewriteHelpers::OutputDependency>&
                 output_dependencies,
             Graph* graph) -> Status {
            const std::string& embedding_attr_string = GetNodeAttrString(
                AttrSlice(configuration_node_def), kEmbeddingConfigurationAttr);

            if (!embedding_attr_string.empty()) {
              return errors::InvalidArgument("embedding_config must be empty.");
            }

            bool is_global_init = false;
            bool enable_whole_mesh_compilations = false;
            TF_RETURN_IF_ERROR(GetNodeAttr(configuration_node_def,
                                           "is_global_init", &is_global_init));
            TryGetNodeAttr(configuration_node_def,
                           "enable_whole_mesh_compilations",
                           &enable_whole_mesh_compilations);
            TF_RETURN_IF_ERROR(SetTPUInitMode(
                is_global_init ? TPUInitMode::kGlobal : TPUInitMode::kRegular));

            bool compilation_failure_closes_chips;
            TF_RETURN_IF_ERROR(GetNodeAttr(configuration_node_def,
                                           "compilation_failure_closes_chips",
                                           &compilation_failure_closes_chips));
            internal::SetTpuCompilationFailureClosesChips(
                compilation_failure_closes_chips);

            int tpu_cancellation_closes_chips;
            TF_RETURN_IF_ERROR(GetNodeAttr(configuration_node_def,
                                           kTpuCancellationClosesChipsAttr,
                                           &tpu_cancellation_closes_chips));

            // Add the global TPU system configuration node.
            Node* configuration_node;
            TF_RETURN_IF_ERROR(AddConfigurationNode(
                configuration_device_name, host_devices.size(), graph,
                enable_whole_mesh_compilations, &configuration_node));

            // Add the host disconnect nodes.
            for (int i = 0; i < host_devices.size(); ++i) {
              const auto host_device = host_devices[i];
              TF_RETURN_IF_ERROR(
                  AddHostDisconnectNode(host_device->name(), input_dependencies,
                                        configuration_node, i, graph));
            }

            // Add the host configuration nodes.
            std::vector<Node*> host_configuration_nodes;
            for (const auto host_device : host_devices) {
              Node* host_configuration_node;
              TF_RETURN_IF_ERROR(AddHostConfigNode(
                  host_device->name(), configuration_node, graph,
                  enable_whole_mesh_compilations, tpu_cancellation_closes_chips,
                  &host_configuration_node));
              host_configuration_nodes.push_back(host_configuration_node);
            }

            // Add the node to wait for the system configuration to
            // stabilize. Use the name of the original dummy Op in case it was
            // the target of a Session::Run call.
            Node* wait_node;
            TF_RETURN_IF_ERROR(AddWaitNode(configuration_device_name,
                                           host_configuration_nodes, graph,
                                           &wait_node));

            // Add the nodes to set the global TPU ids at each host.
            std::vector<Node*> global_array_id_nodes;
            for (const auto host_device : host_devices) {
              Node* global_array_id_node;
              TF_RETURN_IF_ERROR(AddGlobalTPUArrayNode(host_device->name(),
                                                       wait_node, graph,
                                                       &global_array_id_node));
              global_array_id_nodes.push_back(global_array_id_node);
            }

            if (host_devices.empty()) {
              return errors::InvalidArgument("TPU job contains no CPU devices");
            }
            TF_RET_CHECK(!host_devices.empty());

            TF_RETURN_IF_ERROR(AddSynchronizationNode(
                configuration_node_def, host_devices.front()->name(),
                global_array_id_nodes, wait_node, output_dependencies, graph));

            return OkStatus();
          }));

  if (VLOG_IS_ON(1)) {
    DumpGraphToFile("distributed_tpu_configuration_after", *graph,
                    options.flib_def);
  }

  VLOG(1) << "DistributedTPUConfigurationRewritePass::Run() finished";
  return OkStatus();
}

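// Rewrites each ShutdownDistributedTPU node into a _ShutdownDistributedTPU
// node plus per-host _DisconnectHostFromDistributedTPUSystem nodes.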
Status DistributedTPUShutdownRewritePass::Run(
    const GraphOptimizationPassOptions& options) {
  VLOG(1) << "DistributedTPUShutdownRewritePass::Run";

  Graph* graph = options.graph->get();

  if (VLOG_IS_ON(1)) {
    DumpGraphToFile("distributed_tpu_shutdown_before", *graph,
                    options.flib_def);
  }

  // This pass can only run in the session master, which is expected to fill
  // in the device_set field of the options.
  TF_RET_CHECK(options.device_set != nullptr);

  TF_RETURN_IF_ERROR(
      DistributedTPURewriteHelpers::ForConfigurationNodeMatchingType(
          kShutdownOp, graph, *options.device_set,
          [](const NodeDef& shutdown_node_def,
             const string& shutdown_device_name,
             const std::vector<Device*>& host_devices,
             const std::vector<Node*>& input_dependencies,
             const std::vector<DistributedTPURewriteHelpers::OutputDependency>&
                 output_dependencies,
             Graph* graph) -> Status {
            Node* shutdown_node;
            TF_RETURN_IF_ERROR(
                AddShutdownNode(shutdown_node_def, shutdown_device_name,
                                output_dependencies, graph, &shutdown_node));

            // Add the host disconnect nodes.
            for (const auto host_device : host_devices) {
              TF_RETURN_IF_ERROR(
                  AddHostDisconnectNode(host_device->name(), input_dependencies,
                                        shutdown_node, -1, graph));
            }

            return OkStatus();
          }));

  if (VLOG_IS_ON(1)) {
    DumpGraphToFile("distributed_tpu_shutdown_after", *graph, options.flib_def);
  }

  VLOG(1) << "DistributedTPUShutdownRewritePass::Run() finished";
  return OkStatus();
}

}  // namespace tensorflow