/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// Configuration for distributed TPU jobs

#include "tensorflow/core/tpu/graph_rewrite/distributed_tpu_configuration_rewrite_pass.h"

#include <unordered_map>

#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/core/common_runtime/device_set.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/graph_constructor.h"
#include "tensorflow/core/common_runtime/optimization_registry.h"
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/partial_tensor_shape.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/public/session_options.h"
#include "tensorflow/core/tpu/graph_rewrite/distributed_tpu_rewrite_helpers.h"
#include "tensorflow/core/tpu/kernels/tpu_compile_op_options.h"
#include "tensorflow/core/tpu/tpu_init_mode.h"
#include "tensorflow/core/util/device_name_utils.h"
#include "tensorflow/core/util/dump_graph.h"

namespace tensorflow {
namespace {

constexpr char kIdentityOp[] = "Identity";
constexpr char kConfigureOp[] = "ConfigureDistributedTPU";
constexpr char kInternalConfigureOp[] = "_ConfigureDistributedTPU";
constexpr char kWaitOp[] = "_WaitForDistributedTPU";
constexpr char kHostConfigureOp[] = "_InitializeHostForDistributedTPU";
constexpr char kGlobalTPUArrayOp[] = "_SetGlobalTPUArray";
constexpr char kShutdownOp[] = "ShutdownDistributedTPU";
constexpr char kInternalShutdownOp[] = "_ShutdownDistributedTPU";
constexpr char kHostDisconnectOp[] = "_DisconnectHostFromDistributedTPUSystem";
constexpr char kEmbeddingConfigurationAttr[] = "embedding_config";
constexpr char kTpuCancellationClosesChipsAttr[] =
    "tpu_cancellation_closes_chips";
constexpr int kDefaultStartupTimeout = 20;

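// Adds a _ConfigureDistributedTPU node on `configuration_device_name` that
// configures the global TPU system for `number_of_hosts` hosts.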
Status AddConfigurationNode(const string& configuration_device_name,
                            int number_of_hosts, Graph* graph,
                            bool enable_whole_mesh_compilations,
                            Node** configuration_node) {
  NodeDef config_def;
  config_def.set_name(graph->NewName("configure_distributed_tpu"));
  config_def.set_op(kInternalConfigureOp);
  config_def.set_device(configuration_device_name);
  AddNodeAttr("N", number_of_hosts, &config_def);
  AddNodeAttr("enable_whole_mesh_compilations", enable_whole_mesh_compilations,
              &config_def);
  // TODO(shikharagarwal): Fill with appropriate original node debug info.

  TF_ASSIGN_OR_RETURN(*configuration_node, graph->AddNode(config_def));
  (*configuration_node)->set_assigned_device_name(configuration_device_name);
  return OkStatus();
}

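// Adds an _InitializeHostForDistributedTPU node on `host_device_name` and
// feeds it the output of the global configuration node.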
Status AddHostConfigNode(const string& host_device_name,
                         Node* configuration_node, Graph* graph,
                         bool enable_whole_mesh_compilations,
                         int tpu_cancellation_closes_chips,
                         Node** host_configuration_node) {
  NodeDef host_config_def;
  host_config_def.set_name(graph->NewName("configure_tpu_host"));
  host_config_def.set_op(kHostConfigureOp);
  host_config_def.set_device(host_device_name);
  AddNodeAttr("enable_whole_mesh_compilations", enable_whole_mesh_compilations,
              &host_config_def);
  AddNodeAttr(kTpuCancellationClosesChipsAttr, tpu_cancellation_closes_chips,
              &host_config_def);
  MergeDebugInfo(NodeDebugInfo(configuration_node->def()), &host_config_def);

  TF_ASSIGN_OR_RETURN(*host_configuration_node,
                      graph->AddNode(host_config_def));
  (*host_configuration_node)->set_assigned_device_name(host_device_name);
  graph->AddEdge(configuration_node, 0, *host_configuration_node, 0);
  return OkStatus();
}

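// Adds a _WaitForDistributedTPU node that takes one input from each per-host
// configuration node and blocks until the system configuration has
// stabilized.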
Status AddWaitNode(const string& configuration_device_name,
                   const std::vector<Node*>& host_configuration_nodes,
                   Graph* graph, Node** wait_node) {
  NodeDef wait_def;
  wait_def.set_name(graph->NewName("wait_for_distributed_tpu_system"));
  wait_def.set_op(kWaitOp);
  wait_def.set_device(configuration_device_name);
  AddNodeAttr("N", static_cast<int32>(host_configuration_nodes.size()),
              &wait_def);
  AddNodeAttr("startup_timeout_sec", kDefaultStartupTimeout, &wait_def);
  if (!host_configuration_nodes.empty()) {
    MergeDebugInfo(NodeDebugInfo(host_configuration_nodes[0]->def()),
                   &wait_def);
  }

  TF_ASSIGN_OR_RETURN(*wait_node, graph->AddNode(wait_def));
  (*wait_node)->set_assigned_device_name(configuration_device_name);
  // Get the inputs from the host configuration nodes.
  for (int i = 0; i < host_configuration_nodes.size(); ++i) {
    graph->AddEdge(host_configuration_nodes[i], 0, *wait_node, i);
  }
  return OkStatus();
}

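// Adds a _SetGlobalTPUArray node on `host_device_name`, fed by the wait
// node's output, to set the global TPU ids on that host.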
Status AddGlobalTPUArrayNode(const string& host_device_name, Node* wait_node,
                             Graph* graph, Node** global_tpu_array_node) {
  NodeDef global_tpu_array_def;
  global_tpu_array_def.set_name(graph->NewName("set_global_tpu_array"));
  global_tpu_array_def.set_op(kGlobalTPUArrayOp);
  global_tpu_array_def.set_device(host_device_name);
  MergeDebugInfo(NodeDebugInfo(wait_node->def()), &global_tpu_array_def);

  TF_ASSIGN_OR_RETURN(*global_tpu_array_node,
                      graph->AddNode(global_tpu_array_def));
  (*global_tpu_array_node)->set_assigned_device_name(host_device_name);
  graph->AddEdge(wait_node, 0, *global_tpu_array_node, 0);
  return OkStatus();
}

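// Adds an Identity node that reuses the name of the original
// ConfigureDistributedTPU node, forwards the wait node's output, and carries
// control edges from the global-array nodes, so that existing consumers of
// the original node only run once the system is fully configured.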
Status AddSynchronizationNode(
    const NodeDef& sync_node_def, const string& device_name,
    const std::vector<Node*>& global_array_id_nodes, Node* wait_node,
    const std::vector<DistributedTPURewriteHelpers::OutputDependency>&
        output_dependencies,
    Graph* graph) {
  NodeDef sync_def;
  sync_def.set_name(sync_node_def.name());
  sync_def.set_op(kIdentityOp);
  sync_def.set_device(device_name);
  AddNodeAttr("T", DT_STRING, &sync_def);
  MergeDebugInfo(NodeDebugInfo(sync_node_def), &sync_def);

  TF_ASSIGN_OR_RETURN(Node * sync_node, graph->AddNode(sync_def));
  sync_node->set_assigned_device_name(device_name);
  // Add control edges from the global array id nodes.
  for (auto node : global_array_id_nodes) {
    graph->AddControlEdge(node, sync_node);
  }
  // Forward the data from the wait node.
  graph->AddEdge(wait_node, 0, sync_node, 0);
  // Replace the output edges.
  for (const DistributedTPURewriteHelpers::OutputDependency& dep :
       output_dependencies) {
    if (dep.dst_input == Graph::kControlSlot) {
      graph->AddControlEdge(sync_node, dep.dst);
    } else {
      graph->AddEdge(sync_node, dep.src_output, dep.dst, dep.dst_input);
    }
  }
  return OkStatus();
}

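// Adds a _ShutdownDistributedTPU node that reuses the name of the original
// ShutdownDistributedTPU node and takes over its control-edge outputs.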
Status AddShutdownNode(
    const NodeDef& shutdown_node_def, const string& shutdown_device_name,
    const std::vector<DistributedTPURewriteHelpers::OutputDependency>&
        output_dependencies,
    Graph* graph, Node** shutdown_node) {
  NodeDef shutdown_def;
  shutdown_def.set_name(shutdown_node_def.name());
  shutdown_def.set_op(kInternalShutdownOp);
  shutdown_def.set_device(shutdown_device_name);
  MergeDebugInfo(NodeDebugInfo(shutdown_node_def), &shutdown_def);

  TF_ASSIGN_OR_RETURN(*shutdown_node, graph->AddNode(shutdown_def));
  (*shutdown_node)->set_assigned_device_name(shutdown_device_name);
  // Replace the output control edges.
  for (const DistributedTPURewriteHelpers::OutputDependency& dep :
       output_dependencies) {
    if (dep.dst_input != Graph::kControlSlot) {
      return errors::Internal("Shutdown node had non-control edge output");
    }
    graph->AddControlEdge(*shutdown_node, dep.dst);
  }
  return OkStatus();
}

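// Adds a _DisconnectHostFromDistributedTPUSystem node on `host_device_name`,
// gated on `input_dependencies` and feeding `post_disconnect_node` (via a
// control edge when `output_index` is -1, otherwise via a data edge into
// input `output_index`).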
Status AddHostDisconnectNode(const string& host_device_name,
                             const std::vector<Node*>& input_dependencies,
                             Node* post_disconnect_node, int output_index,
                             Graph* graph) {
  NodeDef host_disconnect_def;
  host_disconnect_def.set_name(graph->NewName("disconnect_tpu_host"));
  host_disconnect_def.set_op(kHostDisconnectOp);
  host_disconnect_def.set_device(host_device_name);
  MergeDebugInfo(NodeDebugInfo(post_disconnect_node->def()),
                 &host_disconnect_def);

  TF_ASSIGN_OR_RETURN(Node * host_disconnect_node,
                      graph->AddNode(host_disconnect_def));
  host_disconnect_node->set_assigned_device_name(host_device_name);
  // Replace the input control edges.
  for (Node* src_node : input_dependencies) {
    graph->AddControlEdge(src_node, host_disconnect_node);
  }
  if (output_index == -1) {
    graph->AddControlEdge(host_disconnect_node, post_disconnect_node);
  } else {
    graph->AddEdge(host_disconnect_node, 0, post_disconnect_node, output_index);
  }
  return OkStatus();
}

}  // namespace

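// Rewrites each ConfigureDistributedTPU node into the internal configuration
// sequence: disconnect each host from any previous system, configure the
// global system, initialize each host, wait for startup, and distribute the
// global TPU ids to every host.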
Status DistributedTPUConfigurationRewritePass::Run(
    const GraphOptimizationPassOptions& options) {
  VLOG(1) << "DistributedTPUConfigurationRewritePass::Run";

  Graph* graph = options.graph->get();

  if (VLOG_IS_ON(1)) {
    DumpGraphToFile("distributed_tpu_configuration_before", *graph,
                    options.flib_def);
  }

  // This pass can only run in the session master, which should fill
  // in the device_set field of the options.
  TF_RET_CHECK(options.device_set != nullptr);

  TF_RETURN_IF_ERROR(
      DistributedTPURewriteHelpers::ForConfigurationNodeMatchingType(
          kConfigureOp, graph, *options.device_set,
          [](const NodeDef& configuration_node_def,
             const string& configuration_device_name,
             const std::vector<Device*>& host_devices,
             const std::vector<Node*>& input_dependencies,
             const std::vector<DistributedTPURewriteHelpers::OutputDependency>&
                 output_dependencies,
             Graph* graph) -> Status {
            const std::string& embedding_attr_string = GetNodeAttrString(
                AttrSlice(configuration_node_def), kEmbeddingConfigurationAttr);

            if (!embedding_attr_string.empty()) {
              return errors::InvalidArgument("embedding_config must be empty.");
            }

            bool is_global_init = false;
            bool enable_whole_mesh_compilations = false;
            TF_RETURN_IF_ERROR(GetNodeAttr(configuration_node_def,
                                           "is_global_init", &is_global_init));
            TryGetNodeAttr(configuration_node_def,
                           "enable_whole_mesh_compilations",
                           &enable_whole_mesh_compilations);
            TF_RETURN_IF_ERROR(SetTPUInitMode(
                is_global_init ? TPUInitMode::kGlobal : TPUInitMode::kRegular));

            bool compilation_failure_closes_chips;
            TF_RETURN_IF_ERROR(GetNodeAttr(configuration_node_def,
                                           "compilation_failure_closes_chips",
                                           &compilation_failure_closes_chips));
            internal::SetTpuCompilationFailureClosesChips(
                compilation_failure_closes_chips);

            int tpu_cancellation_closes_chips;
            TF_RETURN_IF_ERROR(GetNodeAttr(configuration_node_def,
                                           kTpuCancellationClosesChipsAttr,
                                           &tpu_cancellation_closes_chips));

            // Add the global TPU system configuration node.
            Node* configuration_node;
            TF_RETURN_IF_ERROR(AddConfigurationNode(
                configuration_device_name, host_devices.size(), graph,
                enable_whole_mesh_compilations, &configuration_node));

            // Add the host disconnect nodes.
            for (int i = 0; i < host_devices.size(); ++i) {
              const auto host_device = host_devices[i];
              TF_RETURN_IF_ERROR(AddHostDisconnectNode(
                  host_device->name(), input_dependencies, configuration_node,
                  i, graph));
            }

            // Add the host configuration nodes.
            std::vector<Node*> host_configuration_nodes;
            for (const auto host_device : host_devices) {
              Node* host_configuration_node;
              TF_RETURN_IF_ERROR(AddHostConfigNode(
                  host_device->name(), configuration_node, graph,
                  enable_whole_mesh_compilations,
                  tpu_cancellation_closes_chips, &host_configuration_node));
              host_configuration_nodes.push_back(host_configuration_node);
            }

            // Add the node to wait for the system configuration to
            // stabilize. Use the name of the original dummy Op in case it was
            // the target of a Session::Run call.
            Node* wait_node;
            TF_RETURN_IF_ERROR(AddWaitNode(configuration_device_name,
                                           host_configuration_nodes, graph,
                                           &wait_node));

            // Add the nodes to set the global TPU ids at each host.
            std::vector<Node*> global_array_id_nodes;
            for (const auto host_device : host_devices) {
              Node* global_array_id_node;
              TF_RETURN_IF_ERROR(AddGlobalTPUArrayNode(host_device->name(),
                                                       wait_node, graph,
                                                       &global_array_id_node));
              global_array_id_nodes.push_back(global_array_id_node);
            }

            if (host_devices.empty()) {
              return errors::InvalidArgument("TPU job contains no CPU devices");
            }
            TF_RET_CHECK(!host_devices.empty());

            TF_RETURN_IF_ERROR(AddSynchronizationNode(
                configuration_node_def, host_devices.front()->name(),
                global_array_id_nodes, wait_node, output_dependencies, graph));

            return OkStatus();
          }));

  if (VLOG_IS_ON(1)) {
    DumpGraphToFile("distributed_tpu_configuration_after", *graph,
                    options.flib_def);
  }

  VLOG(1) << "DistributedTPUConfigurationRewritePass::Run() finished";
  return OkStatus();
}

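// Rewrites each ShutdownDistributedTPU node into per-host
// _DisconnectHostFromDistributedTPUSystem nodes feeding an internal
// _ShutdownDistributedTPU node.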
Status DistributedTPUShutdownRewritePass::Run(
    const GraphOptimizationPassOptions& options) {
  VLOG(1) << "DistributedTPUShutdownRewritePass::Run";

  Graph* graph = options.graph->get();

  if (VLOG_IS_ON(1)) {
    DumpGraphToFile("distributed_tpu_shutdown_before", *graph,
                    options.flib_def);
  }

  // This pass can only run in the session master, which should fill
  // in the device_set field of the options.
  TF_RET_CHECK(options.device_set != nullptr);

  TF_RETURN_IF_ERROR(
      DistributedTPURewriteHelpers::ForConfigurationNodeMatchingType(
          kShutdownOp, graph, *options.device_set,
          [](const NodeDef& shutdown_node_def,
             const string& shutdown_device_name,
             const std::vector<Device*>& host_devices,
             const std::vector<Node*>& input_dependencies,
             const std::vector<DistributedTPURewriteHelpers::OutputDependency>&
                 output_dependencies,
             Graph* graph) -> Status {
            Node* shutdown_node;
            TF_RETURN_IF_ERROR(
                AddShutdownNode(shutdown_node_def, shutdown_device_name,
                                output_dependencies, graph, &shutdown_node));

            // Add the host disconnect nodes.
            for (const auto host_device : host_devices) {
              TF_RETURN_IF_ERROR(AddHostDisconnectNode(
                  host_device->name(), input_dependencies, shutdown_node, -1,
                  graph));
            }

            return OkStatus();
          }));

  if (VLOG_IS_ON(1)) {
    DumpGraphToFile("distributed_tpu_shutdown_after", *graph, options.flib_def);
  }

  VLOG(1) << "DistributedTPUShutdownRewritePass::Run() finished";
  return OkStatus();
}

}  // namespace tensorflow