1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // Configuration for TPU Embedding.
17 
18 #include "tensorflow/core/tpu/graph_rewrite/configure_tpu_embedding_rewrite_pass.h"
19 
20 #include <string>
21 
22 #include "tensorflow/compiler/xla/status_macros.h"
23 #include "tensorflow/core/common_runtime/device_set.h"
24 #include "tensorflow/core/common_runtime/optimization_registry.h"
25 #include "tensorflow/core/graph/graph.h"
26 #include "tensorflow/core/graph/graph_node_util.h"
27 #include "tensorflow/core/platform/errors.h"
28 #include "tensorflow/core/platform/statusor.h"
29 #include "tensorflow/core/tpu/graph_rewrite/distributed_tpu_rewrite_helpers.h"
30 #include "tensorflow/core/tpu/tpu_embedding_configuration_proto_rewrite.h"
31 #include "tensorflow/core/util/device_name_utils.h"
32 #include "tensorflow/core/util/dump_graph.h"
33 
34 namespace tensorflow {
35 namespace {
36 
// Op type used for the final synchronization barrier node.
constexpr char kNoOp[] = "NoOp";
// The single placeholder op found in user graphs; this pass replaces it with
// the multi-step configuration subgraph built from the op types below.
constexpr char kConfigureOp[] = "ConfigureTPUEmbedding";
constexpr char kExecutePartitionerOp[] = "ExecuteTPUEmbeddingPartitioner";
constexpr char kConfigureMemoryOp[] = "ConfigureTPUEmbeddingMemory";
constexpr char kCollateMemoryOp[] = "CollateTPUEmbeddingMemory";
constexpr char kConfigureHostOp[] = "ConfigureTPUEmbeddingHost";
constexpr char kConnectHostsOp[] = "ConnectTPUEmbeddingHosts";
constexpr char kFinalizeOp[] = "FinalizeTPUEmbedding";
// Attr on the ConfigureTPUEmbedding node that holds the serialized
// TPUEmbeddingConfiguration proto.
constexpr char kEmbeddingConfigurationAttr[] = "config";
46 
AddSynchronizationNode(const NodeDef & sync_node_def,const string & device_name,absl::Span<Node * const> end_nodes,absl::Span<const DistributedTPURewriteHelpers::OutputDependency> output_dependencies,Graph * graph)47 Status AddSynchronizationNode(
48     const NodeDef& sync_node_def, const string& device_name,
49     absl::Span<Node* const> end_nodes,
50     absl::Span<const DistributedTPURewriteHelpers::OutputDependency>
51         output_dependencies,
52     Graph* graph) {
53   NodeDef sync_def;
54   sync_def.set_name(sync_node_def.name());
55   sync_def.set_op(kNoOp);
56   sync_def.set_device(device_name);
57   MergeDebugInfo(NodeDebugInfo(sync_node_def), &sync_def);
58 
59   TF_ASSIGN_OR_RETURN(Node * sync_node, graph->AddNode(sync_def));
60   sync_node->set_assigned_device_name(device_name);
61 
62   // Add control edges from the nodes which must complete execution.
63   for (Node* end_node : end_nodes) {
64     graph->AddControlEdge(end_node, sync_node);
65   }
66 
67   // Replace the output edges.
68   for (const DistributedTPURewriteHelpers::OutputDependency& dep :
69        output_dependencies) {
70     if (dep.dst_input == Graph::kControlSlot) {
71       graph->AddControlEdge(sync_node, dep.dst);
72     } else {
73       graph->AddEdge(sync_node, dep.src_output, dep.dst, dep.dst_input);
74     }
75   }
76   return OkStatus();
77 }
78 
AddSetupPropagationEmbeddingNode(const string & device_name,const string & node_name,const string & op_name,absl::Span<Node * const> input_nodes,Graph * graph,Node ** node)79 Status AddSetupPropagationEmbeddingNode(const string& device_name,
80                                         const string& node_name,
81                                         const string& op_name,
82                                         absl::Span<Node* const> input_nodes,
83                                         Graph* graph, Node** node) {
84   NodeDef node_def;
85   node_def.set_name(node_name);
86   node_def.set_op(op_name);
87   node_def.set_device(device_name);
88   AddNodeAttr("N", static_cast<int>(input_nodes.size()), &node_def);
89   if (!input_nodes.empty()) {
90     MergeDebugInfo(NodeDebugInfo(input_nodes[0]->def()), &node_def);
91   }
92 
93   TF_ASSIGN_OR_RETURN(*node, graph->AddNode(node_def));
94   (*node)->set_assigned_device_name(device_name);
95   // Add inputs from the embedding nodes.
96   for (int i = 0; i < input_nodes.size(); ++i) {
97     graph->AddEdge(input_nodes[i], 0, *node, i);
98   }
99   return OkStatus();
100 }
101 
AddExecutePartitionerNode(const string & configuration_device_name,const string & config,absl::Span<Node * const> input_dependencies,Graph * graph,Node ** partitioner_node)102 Status AddExecutePartitionerNode(const string& configuration_device_name,
103                                  const string& config,
104                                  absl::Span<Node* const> input_dependencies,
105                                  Graph* graph, Node** partitioner_node) {
106   NodeDef partitioner_def;
107   partitioner_def.set_name(graph->NewName("execute_embedding_partitioner"));
108   partitioner_def.set_op(kExecutePartitionerOp);
109   partitioner_def.set_device(configuration_device_name);
110   AddNodeAttr("config", config, &partitioner_def);
111 
112   TF_ASSIGN_OR_RETURN(*partitioner_node, graph->AddNode(partitioner_def));
113   (*partitioner_node)->set_assigned_device_name(configuration_device_name);
114   // Replace the input control edges.
115   for (Node* src_node : input_dependencies) {
116     graph->AddControlEdge(src_node, *partitioner_node);
117   }
118 
119   return OkStatus();
120 }
121 
AddConfigureMemoryNode(const string & host_device_name,Node * partitioner_node,Graph * graph,Node ** embedding_node)122 Status AddConfigureMemoryNode(const string& host_device_name,
123                               Node* partitioner_node, Graph* graph,
124                               Node** embedding_node) {
125   NodeDef embedding_def;
126   embedding_def.set_name(graph->NewName("configure_tpu_embedding_memory"));
127   embedding_def.set_op(kConfigureMemoryOp);
128   embedding_def.set_device(host_device_name);
129 
130   TF_ASSIGN_OR_RETURN(*embedding_node, graph->AddNode(embedding_def));
131   (*embedding_node)->set_assigned_device_name(host_device_name);
132   graph->AddEdge(partitioner_node, 0, *embedding_node, 0);
133   return OkStatus();
134 }
135 
AddCollateMemoryNode(const string & configuration_device_name,absl::Span<Node * const> memory_nodes,Graph * graph,Node ** embedding_node)136 Status AddCollateMemoryNode(const string& configuration_device_name,
137                             absl::Span<Node* const> memory_nodes, Graph* graph,
138                             Node** embedding_node) {
139   return AddSetupPropagationEmbeddingNode(
140       /*device_name=*/configuration_device_name,
141       /*node_name=*/graph->NewName("collate_tpu_embedding_memory"),
142       /*op_name=*/kCollateMemoryOp, /*input_nodes=*/memory_nodes,
143       /*graph=*/graph,
144       /*node=*/embedding_node);
145 }
146 
AddConfigureHostNode(const string & host_device_name,const string & config,Node * partitioner_node,Node * memory_node,Graph * graph,Node ** embedding_node)147 Status AddConfigureHostNode(const string& host_device_name,
148                             const string& config, Node* partitioner_node,
149                             Node* memory_node, Graph* graph,
150                             Node** embedding_node) {
151   NodeDef embedding_def;
152   embedding_def.set_name(graph->NewName("configure_tpu_embedding_host"));
153   embedding_def.set_op(kConfigureHostOp);
154   embedding_def.set_device(host_device_name);
155   AddNodeAttr("config", config, &embedding_def);
156 
157   TF_ASSIGN_OR_RETURN(*embedding_node, graph->AddNode(embedding_def));
158   (*embedding_node)->set_assigned_device_name(host_device_name);
159   // Add inputs from the partitioner node and the memory node.
160   graph->AddEdge(partitioner_node, 0, *embedding_node, 0);
161   graph->AddEdge(memory_node, 0, *embedding_node, 1);
162 
163   return OkStatus();
164 }
165 
AddConnectHostsNode(const string & host_device_name,absl::Span<Node * const> configure_host_nodes,Graph * graph,Node ** connect_node)166 Status AddConnectHostsNode(const string& host_device_name,
167                            absl::Span<Node* const> configure_host_nodes,
168                            Graph* graph, Node** connect_node) {
169   return AddSetupPropagationEmbeddingNode(
170       /*device_name=*/host_device_name,
171       /*node_name=*/graph->NewName("connect_tpu_embedding_hosts"),
172       /*op_name=*/kConnectHostsOp, /*input_nodes=*/configure_host_nodes,
173       /*graph=*/graph,
174       /*node=*/connect_node);
175 }
176 
AddFinalizeNode(const string & configuration_device_name,Node * partitioner_node,Node * memory_node,Graph * graph,Node ** finalize_node)177 Status AddFinalizeNode(const string& configuration_device_name,
178                        Node* partitioner_node, Node* memory_node, Graph* graph,
179                        Node** finalize_node) {
180   NodeDef finalize_def;
181   finalize_def.set_name(graph->NewName("finalize_tpu_embedding"));
182   finalize_def.set_op(kFinalizeOp);
183   finalize_def.set_device(configuration_device_name);
184 
185   TF_ASSIGN_OR_RETURN(*finalize_node, graph->AddNode(finalize_def));
186   (*finalize_node)->set_assigned_device_name(configuration_device_name);
187   // Add inputs from the partitioner node and the memory node.
188   graph->AddEdge(partitioner_node, 0, *finalize_node, 0);
189   graph->AddEdge(memory_node, 0, *finalize_node, 1);
190 
191   return OkStatus();
192 }
193 
194 }  // namespace
195 
Run(const GraphOptimizationPassOptions & options)196 Status ConfigureTPUEmbeddingRewritePass::Run(
197     const GraphOptimizationPassOptions& options) {
198   VLOG(1) << "ConfigureTPUEmbeddingRewritePass::Run";
199 
200   Graph* graph = options.graph->get();
201 
202   if (VLOG_IS_ON(1)) {
203     DumpGraphToFile("configure_tpu_embedding_before", *graph, options.flib_def);
204   }
205 
206   // This pass can only run in the session master, which should fill
207   // in the device_set field to the options.
208   TF_RET_CHECK(options.device_set != nullptr);
209 
210   TF_RETURN_IF_ERROR(
211       DistributedTPURewriteHelpers::ForConfigurationNodeMatchingType(
212           kConfigureOp, graph, *options.device_set,
213           [](const NodeDef& configuration_node_def,
214              const std::string& configuration_device_name,
215              const std::vector<Device*>& host_devices,
216              const std::vector<Node*>& input_dependencies,
217              const std::vector<DistributedTPURewriteHelpers::OutputDependency>&
218                  output_dependencies,
219              Graph* graph) -> Status {
220             if (host_devices.empty()) {
221               return errors::InvalidArgument("TPU job contains no CPU devices");
222             }
223             TF_RET_CHECK(!host_devices.empty());
224 
225             auto get_updated_device_name =
226                 [](absl::string_view initial_device_name)
227                 -> xla::StatusOr<std::string> {
228               DeviceNameUtils::ParsedName device_spec;
229               TF_RET_CHECK(DeviceNameUtils::ParseFullName(initial_device_name,
230                                                           &device_spec));
231               // Keep job, replica, and task information, but change the
232               // '/device:TPU_SYSTEM:0' specification to '/device:CPU:0'.
233               device_spec.type = "CPU";
234               return DeviceNameUtils::ParsedNameToString(device_spec);
235             };
236 
237             // Must not use embedding_attr_string beyond the lifetime of
238             // configuration_node_def.
239             const std::string& embedding_attr_string = GetNodeAttrString(
240                 AttrSlice(configuration_node_def), kEmbeddingConfigurationAttr);
241             if (embedding_attr_string.empty()) {
242               return errors::InvalidArgument("TPU embedding config is empty.");
243             } else {
244               // Auto populate the feature descriptor so that we can make use
245               // of these fields later.
246               std::string updated_embedding_attr_string;
247               tpu::TPUEmbeddingConfiguration tpu_embedding_config;
248               tpu_embedding_config.ParseFromString(embedding_attr_string);
249               TF_RETURN_IF_ERROR(PopulateMissingFieldsInTPUEmbeddingConfig(
250                   &tpu_embedding_config));
251               tpu_embedding_config.SerializeToString(
252                   &updated_embedding_attr_string);
253 
254               // Execute the TPU embedding partitioner if configured to do so.
255               Node* partitioner_node;
256               TF_ASSIGN_OR_RETURN(
257                   const std::string configuration_device_string,
258                   get_updated_device_name(configuration_device_name));
259               TF_RETURN_IF_ERROR(AddExecutePartitionerNode(
260                   configuration_device_string, updated_embedding_attr_string,
261                   input_dependencies, graph, &partitioner_node));
262 
263               // Obtain the device strings for configuring the TPU embedding
264               // core on each host.
265               std::vector<std::string> host_device_strings(host_devices.size());
266               for (int i = 0; i < host_devices.size(); ++i) {
267                 TF_ASSIGN_OR_RETURN(
268                     host_device_strings[i],
269                     get_updated_device_name(host_devices[i]->name()));
270               }
271 
272               // Add nodes that configure the HBM memory at each host.
273               std::vector<Node*> memory_nodes;
274               memory_nodes.reserve(host_devices.size());
275               for (int i = 0; i < host_devices.size(); ++i) {
276                 Node* memory_node;
277                 TF_RETURN_IF_ERROR(AddConfigureMemoryNode(
278                     host_device_strings[i], partitioner_node, graph,
279                     &memory_node));
280                 memory_nodes.push_back(memory_node);
281               }
282 
283               // Add node to merge the HBM memory configurations.
284               Node* merged_memory_node;
285               TF_RETURN_IF_ERROR(AddCollateMemoryNode(
286                   configuration_device_string, memory_nodes, graph,
287                   &merged_memory_node));
288 
289               // Add the nodes to configure the embeddings at each host.
290               std::vector<Node*> host_embedding_nodes;
291               host_embedding_nodes.reserve(host_devices.size());
292               for (int i = 0; i < host_devices.size(); ++i) {
293                 Node* host_embedding_node;
294                 TF_RETURN_IF_ERROR(AddConfigureHostNode(
295                     host_device_strings[i], updated_embedding_attr_string,
296                     partitioner_node, merged_memory_node, graph,
297                     &host_embedding_node));
298                 host_embedding_nodes.push_back(host_embedding_node);
299               }
300 
301               // Add the nodes to specify the ports to connect to on each each.
302               // Note that each TPU worker needs to know how to connect to all
303               // other TPU workers in the system, so these are all-to-all
304               // communication links.
305               std::vector<Node*> connect_embedding_nodes;
306               connect_embedding_nodes.reserve(host_devices.size());
307               for (int i = 0; i < host_devices.size(); ++i) {
308                 Node* connect_embedding_node;
309                 TF_RETURN_IF_ERROR(AddConnectHostsNode(
310                     host_device_strings[i], host_embedding_nodes, graph,
311                     &connect_embedding_node));
312                 connect_embedding_nodes.push_back(connect_embedding_node);
313               }
314 
315               // Add the finalize node that checks that the HBM base addresses
316               // allocated are the same across all TPU worker tasks.
317               Node* finalize_node;
318               TF_RETURN_IF_ERROR(
319                   AddFinalizeNode(configuration_device_string, partitioner_node,
320                                   merged_memory_node, graph, &finalize_node));
321 
322               // Wait for the connect and finalize nodes to complete execution.
323               std::vector<Node*> end_nodes(connect_embedding_nodes);
324               end_nodes.push_back(finalize_node);
325 
326               TF_RETURN_IF_ERROR(AddSynchronizationNode(
327                   configuration_node_def, configuration_device_string,
328                   end_nodes, output_dependencies, graph));
329             }
330 
331             return OkStatus();
332           }));
333 
334   if (VLOG_IS_ON(1)) {
335     DumpGraphToFile("configure_tpu_embedding_after", *graph, options.flib_def);
336   }
337 
338   VLOG(1) << "ConfigureTPUEmbeddingRewritePass::Run() finished";
339   return OkStatus();
340 }
341 
342 }  // namespace tensorflow
343