1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 // Configuration for TPU Embedding.
17
18 #include "tensorflow/core/tpu/graph_rewrite/configure_tpu_embedding_rewrite_pass.h"
19
20 #include <string>
21
22 #include "tensorflow/compiler/xla/status_macros.h"
23 #include "tensorflow/core/common_runtime/device_set.h"
24 #include "tensorflow/core/common_runtime/optimization_registry.h"
25 #include "tensorflow/core/graph/graph.h"
26 #include "tensorflow/core/graph/graph_node_util.h"
27 #include "tensorflow/core/platform/errors.h"
28 #include "tensorflow/core/platform/statusor.h"
29 #include "tensorflow/core/tpu/graph_rewrite/distributed_tpu_rewrite_helpers.h"
30 #include "tensorflow/core/tpu/tpu_embedding_configuration_proto_rewrite.h"
31 #include "tensorflow/core/util/device_name_utils.h"
32 #include "tensorflow/core/util/dump_graph.h"
33
34 namespace tensorflow {
35 namespace {
36
// Op type of the synchronization node inserted at the end of the rewrite.
constexpr char kNoOp[] = "NoOp";
// Placeholder op that this pass replaces with the configuration subgraph.
constexpr char kConfigureOp[] = "ConfigureTPUEmbedding";
// Ops emitted by the rewrite, in roughly the order they execute:
// partitioner -> per-host memory config -> memory collation -> per-host
// embedding config -> host interconnection -> finalization.
constexpr char kExecutePartitionerOp[] = "ExecuteTPUEmbeddingPartitioner";
constexpr char kConfigureMemoryOp[] = "ConfigureTPUEmbeddingMemory";
constexpr char kCollateMemoryOp[] = "CollateTPUEmbeddingMemory";
constexpr char kConfigureHostOp[] = "ConfigureTPUEmbeddingHost";
constexpr char kConnectHostsOp[] = "ConnectTPUEmbeddingHosts";
constexpr char kFinalizeOp[] = "FinalizeTPUEmbedding";
// Attribute on the ConfigureTPUEmbedding node holding the serialized
// TPUEmbeddingConfiguration proto.
constexpr char kEmbeddingConfigurationAttr[] = "config";
46
AddSynchronizationNode(const NodeDef & sync_node_def,const string & device_name,absl::Span<Node * const> end_nodes,absl::Span<const DistributedTPURewriteHelpers::OutputDependency> output_dependencies,Graph * graph)47 Status AddSynchronizationNode(
48 const NodeDef& sync_node_def, const string& device_name,
49 absl::Span<Node* const> end_nodes,
50 absl::Span<const DistributedTPURewriteHelpers::OutputDependency>
51 output_dependencies,
52 Graph* graph) {
53 NodeDef sync_def;
54 sync_def.set_name(sync_node_def.name());
55 sync_def.set_op(kNoOp);
56 sync_def.set_device(device_name);
57 MergeDebugInfo(NodeDebugInfo(sync_node_def), &sync_def);
58
59 TF_ASSIGN_OR_RETURN(Node * sync_node, graph->AddNode(sync_def));
60 sync_node->set_assigned_device_name(device_name);
61
62 // Add control edges from the nodes which must complete execution.
63 for (Node* end_node : end_nodes) {
64 graph->AddControlEdge(end_node, sync_node);
65 }
66
67 // Replace the output edges.
68 for (const DistributedTPURewriteHelpers::OutputDependency& dep :
69 output_dependencies) {
70 if (dep.dst_input == Graph::kControlSlot) {
71 graph->AddControlEdge(sync_node, dep.dst);
72 } else {
73 graph->AddEdge(sync_node, dep.src_output, dep.dst, dep.dst_input);
74 }
75 }
76 return OkStatus();
77 }
78
AddSetupPropagationEmbeddingNode(const string & device_name,const string & node_name,const string & op_name,absl::Span<Node * const> input_nodes,Graph * graph,Node ** node)79 Status AddSetupPropagationEmbeddingNode(const string& device_name,
80 const string& node_name,
81 const string& op_name,
82 absl::Span<Node* const> input_nodes,
83 Graph* graph, Node** node) {
84 NodeDef node_def;
85 node_def.set_name(node_name);
86 node_def.set_op(op_name);
87 node_def.set_device(device_name);
88 AddNodeAttr("N", static_cast<int>(input_nodes.size()), &node_def);
89 if (!input_nodes.empty()) {
90 MergeDebugInfo(NodeDebugInfo(input_nodes[0]->def()), &node_def);
91 }
92
93 TF_ASSIGN_OR_RETURN(*node, graph->AddNode(node_def));
94 (*node)->set_assigned_device_name(device_name);
95 // Add inputs from the embedding nodes.
96 for (int i = 0; i < input_nodes.size(); ++i) {
97 graph->AddEdge(input_nodes[i], 0, *node, i);
98 }
99 return OkStatus();
100 }
101
AddExecutePartitionerNode(const string & configuration_device_name,const string & config,absl::Span<Node * const> input_dependencies,Graph * graph,Node ** partitioner_node)102 Status AddExecutePartitionerNode(const string& configuration_device_name,
103 const string& config,
104 absl::Span<Node* const> input_dependencies,
105 Graph* graph, Node** partitioner_node) {
106 NodeDef partitioner_def;
107 partitioner_def.set_name(graph->NewName("execute_embedding_partitioner"));
108 partitioner_def.set_op(kExecutePartitionerOp);
109 partitioner_def.set_device(configuration_device_name);
110 AddNodeAttr("config", config, &partitioner_def);
111
112 TF_ASSIGN_OR_RETURN(*partitioner_node, graph->AddNode(partitioner_def));
113 (*partitioner_node)->set_assigned_device_name(configuration_device_name);
114 // Replace the input control edges.
115 for (Node* src_node : input_dependencies) {
116 graph->AddControlEdge(src_node, *partitioner_node);
117 }
118
119 return OkStatus();
120 }
121
AddConfigureMemoryNode(const string & host_device_name,Node * partitioner_node,Graph * graph,Node ** embedding_node)122 Status AddConfigureMemoryNode(const string& host_device_name,
123 Node* partitioner_node, Graph* graph,
124 Node** embedding_node) {
125 NodeDef embedding_def;
126 embedding_def.set_name(graph->NewName("configure_tpu_embedding_memory"));
127 embedding_def.set_op(kConfigureMemoryOp);
128 embedding_def.set_device(host_device_name);
129
130 TF_ASSIGN_OR_RETURN(*embedding_node, graph->AddNode(embedding_def));
131 (*embedding_node)->set_assigned_device_name(host_device_name);
132 graph->AddEdge(partitioner_node, 0, *embedding_node, 0);
133 return OkStatus();
134 }
135
AddCollateMemoryNode(const string & configuration_device_name,absl::Span<Node * const> memory_nodes,Graph * graph,Node ** embedding_node)136 Status AddCollateMemoryNode(const string& configuration_device_name,
137 absl::Span<Node* const> memory_nodes, Graph* graph,
138 Node** embedding_node) {
139 return AddSetupPropagationEmbeddingNode(
140 /*device_name=*/configuration_device_name,
141 /*node_name=*/graph->NewName("collate_tpu_embedding_memory"),
142 /*op_name=*/kCollateMemoryOp, /*input_nodes=*/memory_nodes,
143 /*graph=*/graph,
144 /*node=*/embedding_node);
145 }
146
AddConfigureHostNode(const string & host_device_name,const string & config,Node * partitioner_node,Node * memory_node,Graph * graph,Node ** embedding_node)147 Status AddConfigureHostNode(const string& host_device_name,
148 const string& config, Node* partitioner_node,
149 Node* memory_node, Graph* graph,
150 Node** embedding_node) {
151 NodeDef embedding_def;
152 embedding_def.set_name(graph->NewName("configure_tpu_embedding_host"));
153 embedding_def.set_op(kConfigureHostOp);
154 embedding_def.set_device(host_device_name);
155 AddNodeAttr("config", config, &embedding_def);
156
157 TF_ASSIGN_OR_RETURN(*embedding_node, graph->AddNode(embedding_def));
158 (*embedding_node)->set_assigned_device_name(host_device_name);
159 // Add inputs from the partitioner node and the memory node.
160 graph->AddEdge(partitioner_node, 0, *embedding_node, 0);
161 graph->AddEdge(memory_node, 0, *embedding_node, 1);
162
163 return OkStatus();
164 }
165
AddConnectHostsNode(const string & host_device_name,absl::Span<Node * const> configure_host_nodes,Graph * graph,Node ** connect_node)166 Status AddConnectHostsNode(const string& host_device_name,
167 absl::Span<Node* const> configure_host_nodes,
168 Graph* graph, Node** connect_node) {
169 return AddSetupPropagationEmbeddingNode(
170 /*device_name=*/host_device_name,
171 /*node_name=*/graph->NewName("connect_tpu_embedding_hosts"),
172 /*op_name=*/kConnectHostsOp, /*input_nodes=*/configure_host_nodes,
173 /*graph=*/graph,
174 /*node=*/connect_node);
175 }
176
AddFinalizeNode(const string & configuration_device_name,Node * partitioner_node,Node * memory_node,Graph * graph,Node ** finalize_node)177 Status AddFinalizeNode(const string& configuration_device_name,
178 Node* partitioner_node, Node* memory_node, Graph* graph,
179 Node** finalize_node) {
180 NodeDef finalize_def;
181 finalize_def.set_name(graph->NewName("finalize_tpu_embedding"));
182 finalize_def.set_op(kFinalizeOp);
183 finalize_def.set_device(configuration_device_name);
184
185 TF_ASSIGN_OR_RETURN(*finalize_node, graph->AddNode(finalize_def));
186 (*finalize_node)->set_assigned_device_name(configuration_device_name);
187 // Add inputs from the partitioner node and the memory node.
188 graph->AddEdge(partitioner_node, 0, *finalize_node, 0);
189 graph->AddEdge(memory_node, 0, *finalize_node, 1);
190
191 return OkStatus();
192 }
193
194 } // namespace
195
Run(const GraphOptimizationPassOptions & options)196 Status ConfigureTPUEmbeddingRewritePass::Run(
197 const GraphOptimizationPassOptions& options) {
198 VLOG(1) << "ConfigureTPUEmbeddingRewritePass::Run";
199
200 Graph* graph = options.graph->get();
201
202 if (VLOG_IS_ON(1)) {
203 DumpGraphToFile("configure_tpu_embedding_before", *graph, options.flib_def);
204 }
205
206 // This pass can only run in the session master, which should fill
207 // in the device_set field to the options.
208 TF_RET_CHECK(options.device_set != nullptr);
209
210 TF_RETURN_IF_ERROR(
211 DistributedTPURewriteHelpers::ForConfigurationNodeMatchingType(
212 kConfigureOp, graph, *options.device_set,
213 [](const NodeDef& configuration_node_def,
214 const std::string& configuration_device_name,
215 const std::vector<Device*>& host_devices,
216 const std::vector<Node*>& input_dependencies,
217 const std::vector<DistributedTPURewriteHelpers::OutputDependency>&
218 output_dependencies,
219 Graph* graph) -> Status {
220 if (host_devices.empty()) {
221 return errors::InvalidArgument("TPU job contains no CPU devices");
222 }
223 TF_RET_CHECK(!host_devices.empty());
224
225 auto get_updated_device_name =
226 [](absl::string_view initial_device_name)
227 -> xla::StatusOr<std::string> {
228 DeviceNameUtils::ParsedName device_spec;
229 TF_RET_CHECK(DeviceNameUtils::ParseFullName(initial_device_name,
230 &device_spec));
231 // Keep job, replica, and task information, but change the
232 // '/device:TPU_SYSTEM:0' specification to '/device:CPU:0'.
233 device_spec.type = "CPU";
234 return DeviceNameUtils::ParsedNameToString(device_spec);
235 };
236
237 // Must not use embedding_attr_string beyond the lifetime of
238 // configuration_node_def.
239 const std::string& embedding_attr_string = GetNodeAttrString(
240 AttrSlice(configuration_node_def), kEmbeddingConfigurationAttr);
241 if (embedding_attr_string.empty()) {
242 return errors::InvalidArgument("TPU embedding config is empty.");
243 } else {
244 // Auto populate the feature descriptor so that we can make use
245 // of these fields later.
246 std::string updated_embedding_attr_string;
247 tpu::TPUEmbeddingConfiguration tpu_embedding_config;
248 tpu_embedding_config.ParseFromString(embedding_attr_string);
249 TF_RETURN_IF_ERROR(PopulateMissingFieldsInTPUEmbeddingConfig(
250 &tpu_embedding_config));
251 tpu_embedding_config.SerializeToString(
252 &updated_embedding_attr_string);
253
254 // Execute the TPU embedding partitioner if configured to do so.
255 Node* partitioner_node;
256 TF_ASSIGN_OR_RETURN(
257 const std::string configuration_device_string,
258 get_updated_device_name(configuration_device_name));
259 TF_RETURN_IF_ERROR(AddExecutePartitionerNode(
260 configuration_device_string, updated_embedding_attr_string,
261 input_dependencies, graph, &partitioner_node));
262
263 // Obtain the device strings for configuring the TPU embedding
264 // core on each host.
265 std::vector<std::string> host_device_strings(host_devices.size());
266 for (int i = 0; i < host_devices.size(); ++i) {
267 TF_ASSIGN_OR_RETURN(
268 host_device_strings[i],
269 get_updated_device_name(host_devices[i]->name()));
270 }
271
272 // Add nodes that configure the HBM memory at each host.
273 std::vector<Node*> memory_nodes;
274 memory_nodes.reserve(host_devices.size());
275 for (int i = 0; i < host_devices.size(); ++i) {
276 Node* memory_node;
277 TF_RETURN_IF_ERROR(AddConfigureMemoryNode(
278 host_device_strings[i], partitioner_node, graph,
279 &memory_node));
280 memory_nodes.push_back(memory_node);
281 }
282
283 // Add node to merge the HBM memory configurations.
284 Node* merged_memory_node;
285 TF_RETURN_IF_ERROR(AddCollateMemoryNode(
286 configuration_device_string, memory_nodes, graph,
287 &merged_memory_node));
288
289 // Add the nodes to configure the embeddings at each host.
290 std::vector<Node*> host_embedding_nodes;
291 host_embedding_nodes.reserve(host_devices.size());
292 for (int i = 0; i < host_devices.size(); ++i) {
293 Node* host_embedding_node;
294 TF_RETURN_IF_ERROR(AddConfigureHostNode(
295 host_device_strings[i], updated_embedding_attr_string,
296 partitioner_node, merged_memory_node, graph,
297 &host_embedding_node));
298 host_embedding_nodes.push_back(host_embedding_node);
299 }
300
301 // Add the nodes to specify the ports to connect to on each each.
302 // Note that each TPU worker needs to know how to connect to all
303 // other TPU workers in the system, so these are all-to-all
304 // communication links.
305 std::vector<Node*> connect_embedding_nodes;
306 connect_embedding_nodes.reserve(host_devices.size());
307 for (int i = 0; i < host_devices.size(); ++i) {
308 Node* connect_embedding_node;
309 TF_RETURN_IF_ERROR(AddConnectHostsNode(
310 host_device_strings[i], host_embedding_nodes, graph,
311 &connect_embedding_node));
312 connect_embedding_nodes.push_back(connect_embedding_node);
313 }
314
315 // Add the finalize node that checks that the HBM base addresses
316 // allocated are the same across all TPU worker tasks.
317 Node* finalize_node;
318 TF_RETURN_IF_ERROR(
319 AddFinalizeNode(configuration_device_string, partitioner_node,
320 merged_memory_node, graph, &finalize_node));
321
322 // Wait for the connect and finalize nodes to complete execution.
323 std::vector<Node*> end_nodes(connect_embedding_nodes);
324 end_nodes.push_back(finalize_node);
325
326 TF_RETURN_IF_ERROR(AddSynchronizationNode(
327 configuration_node_def, configuration_device_string,
328 end_nodes, output_dependencies, graph));
329 }
330
331 return OkStatus();
332 }));
333
334 if (VLOG_IS_ON(1)) {
335 DumpGraphToFile("configure_tpu_embedding_after", *graph, options.flib_def);
336 }
337
338 VLOG(1) << "ConfigureTPUEmbeddingRewritePass::Run() finished";
339 return OkStatus();
340 }
341
342 } // namespace tensorflow
343