/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// Rewrites TPUReplicate nodes into replicated computations on TPU.
//
// To represent a distributed TPU computation, we use the
// TPUReplicate operator, which describes a subgraph (represented as a
// Tensorflow function) to replicate across a TPU pod.
//
// Model parallelism and data parallelism:
// ---------------------------------------
// We support two different kinds of parallelism on TPU:
// * data parallelism (replication), or parallelization across batches, and
// * model parallelism, or parallelization within a batch.
//
// The function passed to a TPUReplicate operator is replicated many
// times across a TPU pod (data parallelism). The `num_replicas` attribute
// controls how many replicas of the computation to create. Replicas are mostly
// independent; replicas can only communicate using the CrossReplicaSum
// operator, which is typically used to communicate gradients during training.
//
// Each replica may optionally use more than one TPU core (model
// parallelism). The `num_cores_per_replica` attribute controls how many cores
// there are per replica. For each core, there is a virtual TPU_REPLICATED_CORE
// device that is only valid within replicated TPU computations (e.g.,
// TPU_REPLICATED_CORE:0, TPU_REPLICATED_CORE:1, etc.); each TPU_REPLICATED_CORE
// device corresponds to one TPU core in every replica.
// Each replica runs its own copy of the computation assigned to each
// TPU_REPLICATED_CORE device.
//
// The Python code is responsible for providing a device_assignment that
// describes how the replicated logical cores map to physical cores on the TPU
// topology.
//
// Inputs to TPUReplicate:
// ------------------------------
// The TPUReplicate operator takes four kinds of inputs, in the
// following order:
// * per-replica inputs. If there are three per-replica inputs (A, B, C) and two
//   replicas, the first six arguments to TPUReplicate will be:
//   A0 B0 C0 A1 B1 C1
//   where Ai is the A input to the i-th replica.
// * distributed inputs. These inputs follow the per-replica inputs.
//   If there are two distributed inputs (E, F) and two replicas, the next two
//   arguments to TPUReplicate will be: E F.
//   Each replica nevertheless receives its own local copy of E and F.
// * broadcast inputs. These inputs follow the distributed inputs. All
//   replicas receive a copy of each of these inputs.
// * variables. Resource variables accessed by the computation follow the
//   broadcast inputs.
//
// For example, for a computation with two replicas, three per-replica inputs
// (A, B, C), two distributed inputs (E, F), two broadcast inputs (X, Y), and
// two
// variables (V, W), the arguments to TPUReplicate will be:
// A0 B0 C0 A1 B1 C1 E F X Y V W
// and each replica will receive the following arguments:
// A B C E F X Y V W
//
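// As an illustrative sketch (not part of this header's API; it assumes one
// return value per replica and no guaranteed constants), the example above
// corresponds to the ParameterInfo helper declared below:
//
//   ParameterInfo info(/*num_replicas=*/2, /*num_per_replica_args=*/3,
//                      /*num_distributed_args=*/2, /*num_broadcast_args=*/2,
//                      /*num_variables=*/2, /*num_guaranteed_constants=*/0,
//                      /*num_retvals_per_replica=*/1);
//   info.NumInputsFromHost();       // 2*3 + 2 + 2 + 2 + 0 == 12 (A0 .. W)
//   info.NumInputsToEachReplica();  // 3 + 2 + 2 + 2 + 0 == 9 (A .. W)
//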
// Distributed TPU compilation requires that the shapes of all operators
// be known statically at compilation time, before any nodes have executed.
// Shapes are determined using shape information emitted by InferShapes. It
// is not possible to replicate Tensorflow operators with unknown or dynamic
// shapes for TPU at present.
//
// Graph rewrite:
// --------------
// Compilation replaces TPUReplicate operators with:
// * a single TPUCompile node that compiles the computations,
// * one TPUExecute node for each TPU device in the system that
//   executes the relevant computation,
// * one ReadVariableOp for each variable accessed by the replicated
//   computation,
// * one AssignVariableOp for each variable accessed by the replicated
//   computation. An assignment is built even if a variable is only read by the
//   computation. We do not know which variables are written until we apply the
//   XlaCompiler to the computation, but that does not happen until after the
//   rewrite. Conservatively, we write back the values of all variables after
//   the computation completes.
//   TODO(phawkins): only write back variables that the computation may write.
// * one Shape node for each Tensor or Variable input to the computation whose
//   shape is not statically known at rewrite time. The input shapes are fed
//   to the TPUCompile node.
//
// To ensure that the reads and writes seem to happen at the right time in the
// graph execution, we add control edges from all predecessors of the original
// TPUReplicate operator to each of the ReadVariableOp operators.
// Similarly, we add control edges from all of the AssignVariableOp operators to
// all of the successors of the TPUReplicate operator.
//
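// As a rough illustrative sketch (one replica, one core, and a single
// variable V; node and edge layout simplified), the rewritten graph roughly
// contains:
//
//   {original predecessors}  ..control..>  ReadVariableOp(V)
//   dynamic Shape nodes  ----------------> TPUCompile
//   TPUCompile, other inputs, ReadVariableOp(V)  --> TPUExecute
//   TPUExecute  ------------------------->  AssignVariableOp(V)
//   AssignVariableOp(V)  ..control..>  {original successors}
//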
// The TPUReplicate rewrite must run before placement, since resource
// variable inputs will have DT_RESOURCE, which cannot be sent across devices,
// leading to objections from the placer. The rewrite converts the resource
// accesses into explicit ReadVariableOp and AssignVariableOp operators that the
// placer is free to colocate with the variables.

#ifndef TENSORFLOW_CORE_TPU_GRAPH_REWRITE_DISTRIBUTED_TPU_REWRITE_PASS_H_
#define TENSORFLOW_CORE_TPU_GRAPH_REWRITE_DISTRIBUTED_TPU_REWRITE_PASS_H_

#include <map>
#include <memory>
#include <set>
#include <string>
#include <unordered_map>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "absl/container/node_hash_map.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/jit/shape_inference.h"
#include "tensorflow/compiler/xla/service/computation_placer.h"
#include "tensorflow/core/common_runtime/optimization_registry.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/stream_executor/tpu/tpu_topology.h"

namespace tensorflow {

// Replaces clusters assigned to TPU_SYSTEM devices with
// TPUCompile and TPUExecute nodes assigned to the corresponding
// TPU devices.
class DistributedTPURewritePass : public GraphOptimizationPass {
 public:
  static void SetDistributedTpuRewritePassOptions(
      bool distribute_vars, bool allow_xla_spmd_partition,
      bool replicate_inputs_outputs_by_default_for_xla_spmd,
      bool enable_cross_replica_sharding_mirrored_variables,
      bool enable_automatic_model_parallelism, bool enable_xla_param_broadcast,
      bool enable_multicore_locking, bool use_nd_sharding_ops);
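
  // Illustrative sketch only: the flag values below are hypothetical, not
  // settings recommended by this pass. A caller would configure the pass
  // before it runs, e.g.:
  //
  //   DistributedTPURewritePass::SetDistributedTpuRewritePassOptions(
  //       /*distribute_vars=*/true, /*allow_xla_spmd_partition=*/true,
  //       /*replicate_inputs_outputs_by_default_for_xla_spmd=*/false,
  //       /*enable_cross_replica_sharding_mirrored_variables=*/true,
  //       /*enable_automatic_model_parallelism=*/false,
  //       /*enable_xla_param_broadcast=*/false,
  //       /*enable_multicore_locking=*/false,
  //       /*use_nd_sharding_ops=*/false);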

  Status Run(const GraphOptimizationPassOptions& options) override;

  // The following methods are public only for the use of unit tests.

  // See comment at the top of the file for how the inputs are ordered.
  // Encapsulates information about the different kinds of inputs and outputs
  // of a replicated TPU node, and provides common APIs over them.
  class ParameterInfo {
   public:
    ParameterInfo() {}
    ParameterInfo(int64_t num_replicas, int64_t num_per_replica_args,
                  int64_t num_distributed_args, int64_t num_broadcast_args,
                  int64_t num_variables, int64_t num_guaranteed_constants,
                  int64_t num_retvals_per_replica)
        : num_replicas_(num_replicas),
          num_per_replica_args_(num_per_replica_args),
          num_distributed_args_(num_distributed_args),
          num_broadcast_args_(num_broadcast_args),
          num_variables_(num_variables),
          num_guaranteed_constants_(num_guaranteed_constants),
          num_retvals_per_replica_(num_retvals_per_replica) {}

    int64_t NumReplicas() const { return num_replicas_; }

    int64_t NumPerReplicaArgs() const { return num_per_replica_args_; }

    int64_t NumDistributedArgs() const { return num_distributed_args_; }

    int64_t NumBroadcastArgs() const { return num_broadcast_args_; }

    int64_t NumVariables() const { return num_variables_; }

    int64_t NumGuaranteedConstants() const { return num_guaranteed_constants_; }

    int64_t NumRetvalsPerReplica() const { return num_retvals_per_replica_; }

    bool IsPerReplicaArg(int64_t index) const {
      return index < num_per_replica_args_;
    }

    bool IsDistributedArg(int64_t index) const {
      return index >= num_per_replica_args_ &&
             index < (num_per_replica_args_ + num_distributed_args_);
    }

    bool IsBroadcastArg(int64_t index) const {
      return (index >= num_per_replica_args_ + num_distributed_args_) &&
             index < (num_per_replica_args_ + num_distributed_args_ +
                      num_broadcast_args_);
    }

    bool IsVariableArg(int64_t index) const {
      return index >= (num_per_replica_args_ + num_distributed_args_ +
                       num_broadcast_args_) &&
             index < (num_per_replica_args_ + num_distributed_args_ +
                      num_broadcast_args_ + num_variables_);
    }

    bool IsConstantArg(int64_t index) const {
      return index >= (num_per_replica_args_ + num_distributed_args_ +
                       num_broadcast_args_ + num_variables_) &&
             index < (num_per_replica_args_ + num_distributed_args_ +
                      num_broadcast_args_ + num_variables_ +
                      num_guaranteed_constants_);
    }

    // Returns the total number of inputs provided by the host, i.e. the length
    // of the flattened TPUReplicate argument list (per-replica arguments are
    // counted once per replica).
    int64_t NumInputsFromHost() const {
      return num_replicas_ * num_per_replica_args_ + num_distributed_args_ +
             num_broadcast_args_ + num_variables_ + num_guaranteed_constants_;
    }

    // Returns the number of inputs which will be sent to each replica.
    int64_t NumInputsToEachReplica() const {
      return num_per_replica_args_ + num_distributed_args_ +
             num_broadcast_args_ + num_variables_ + num_guaranteed_constants_;
    }

    // Returns the total number of output values returned to the host (for all
    // replicas).
    int64_t NumOutputsToHost() const {
      return num_replicas_ * num_retvals_per_replica_;
    }

    // Returns the position of the first broadcast argument within the set
    // of all host-side arguments.
    // Broadcast arguments follow the distributed arguments.
    int64_t FirstBroadcastArgFromHost() const {
      return num_replicas_ * num_per_replica_args_ + num_distributed_args_;
    }

    // Indices of mirrored variables across replicas, which should be
    // categorized as per_replica_args.
    const std::set<int64_t>& mirrored_variable_indices() const {
      return mirrored_variable_indices_;
    }
    std::set<int64_t>* mutable_mirrored_variable_indices() {
      return &mirrored_variable_indices_;
    }

   private:
    int64_t num_replicas_ = 1;
    int64_t num_per_replica_args_ = 0;
    int64_t num_distributed_args_ = 0;
    int64_t num_broadcast_args_ = 0;
    int64_t num_variables_ = 0;
    int64_t num_guaranteed_constants_ = 0;
    int64_t num_retvals_per_replica_ = 0;
    std::set<int64_t> mirrored_variable_indices_;
  };
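
  // Illustrative sketch only; the values follow the example at the top of this
  // file (two replicas, per-replica inputs A B C, distributed E F, broadcast
  // X Y, variables V W), assuming one retval per replica and no guaranteed
  // constants. For a replica's argument list A B C E F X Y V W:
  //
  //   ParameterInfo info(2, 3, 2, 2, 2, 0, 1);
  //   info.IsPerReplicaArg(0);           // true  (A)
  //   info.IsDistributedArg(3);          // true  (E)
  //   info.IsBroadcastArg(5);            // true  (X)
  //   info.IsVariableArg(7);             // true  (V)
  //   info.FirstBroadcastArgFromHost();  // 2*3 + 2 == 8 (X in A0 .. W)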

  // Mapping from TPUReplicate cluster name to tpu device names. Value is a
  // mapping from [replica][core] to a TF device name.
  typedef absl::flat_hash_map<string, std::vector<std::vector<string>>>
      TPUReplicateDeviceNamesMapping;
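
  // For illustration only (the cluster name and device names below are
  // hypothetical): with two replicas and two cores per replica, one entry of
  // this mapping might look like
  //
  //   "cluster_0" -> {{"/job:worker/replica:0/task:0/device:TPU:0",
  //                    "/job:worker/replica:0/task:0/device:TPU:1"},
  //                   {"/job:worker/replica:0/task:1/device:TPU:0",
  //                    "/job:worker/replica:0/task:1/device:TPU:1"}}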

  // Determines which devices to use to run the computation.
  // Inputs:
  // * num_tpus_per_task: the number of TPU devices attached to each task
  // * tpu_devices: a [task][device] collection of TPU devices
  // * num_replicas: the number of replicas requested
  // * num_cores_per_replica: the number of cores in each computation instance
  // * topology_attr: the topology TPUReplicate attribute
  // * device_assignment_attr: the device_assignment TPUReplicate attribute
  // Outputs:
  // * tf_device_assignment: a mapping from [replica][core] to a TF device name
  // * devices_to_lock: a flat array of integer indices corresponding to devices
  //   that are used in this computation. They will be locked before the
  //   TPUExecute kernels are run, to ensure that the kernels from concurrent
  //   multi-core executions are enqueued consistently, i.e., all kernels from
  //   computation A before any kernel from computation B, thus preventing
  //   deadlock.
  // * xla_device_assignment: a mapping from [replica][core] to a linearized TPU
  //   coordinate.
  // TODO(phawkins): change tf_device_assignment to an xla::Array2D.
  static Status BuildDeviceAssignment(
      const tpu::TpuTopologyExternal& topology, int num_tpus_per_task,
      const std::vector<std::vector<Device*>>& tpu_devices, int num_replicas,
      int num_cores_per_replica, const string& topology_attr,
      absl::Span<const int> device_assignment_attr,
      std::vector<std::vector<string>>* tf_device_assignment,
      std::vector<int>* devices_to_lock,
      std::unique_ptr<xla::DeviceAssignment>* xla_device_assignment);

  // Returns the `computation` graph attached to TPUReplicate operator
  // `node`. `flr` is a FunctionLibraryRuntime to use when
  // instantiating the function body. Sets `*arg_types` and
  // `*retval_types` to the argument/return types of the function.
  static Status GetComputationForTPUReplicateOp(const NameAttrList& function,
                                                FunctionLibraryRuntime* flr,
                                                Graph* computation,
                                                DataTypeVector* arg_types,
                                                DataTypeVector* retval_types);

  // Returns the shapes of the argument tensors and return values of the
  // TPUReplicate operator `node` using the _output_shapes,
  // _output_handle_shapes, and _output_handle_types annotations on the input
  // nodes. Expects inputs in the following order (see comment at top of file):
  // * num_replicas * num_per_replica_args per-replica inputs,
  // * num_broadcast_args broadcast inputs,
  // * num_variables variable inputs.
  // Returns an error if the input shapes to `node` are not statically known.
  // Also verifies that all replicas have identical input shapes for their
  // per-replica inputs.
  static Status GetArgAndRetvalShapes(
      const GraphShapeInfo& shape_info, const Node& node,
      const ParameterInfo& params_info, std::vector<InferredShape>* arg_shapes,
      std::vector<InferredShape>* retval_shapes);

  // Assigns arguments and return values to cores. The assignment is represented
  // as an XLA op sharding, so that an argument can be replicated across cores.
  // `arg_sharding` and `retval_sharding` are vectors of shardings indexed by
  // argument/retval number.
  // `arg_fast_mem` is a vector of fast_mem indications, indexed by argument
  // number.
  static Status AssignArgsAndRetvalsToCores(
      int num_cores_per_replica, const ParameterInfo& params_info,
      const DataTypeVector& arg_types,
      const std::vector<InferredShape>& arg_shapes,
      const DataTypeVector& retval_types,
      const std::vector<InferredShape>& retval_shapes, const Graph& graph,
      const Node* replicate_node, FunctionLibraryRuntime* flr,
      bool allow_parameter_replication_for_spmd,
      std::vector<::xla::OpSharding>* arg_sharding,
      std::vector<bool>* arg_fast_mem,
      std::vector<::xla::OpSharding>* retval_sharding,
      std::vector<std::string>* arg_names);

  // Describes a variable input to the replicated computation: the node and
  // output index that produce the DT_RESOURCE handle.
  struct VariableInput {
    Node* node;
    int index;

    // Type of the variable's value. Note that this is different from the type
    // of the output of 'variable', which is always DT_RESOURCE.
    DataType dtype;
  };

  // Populates `*variables` with the "variables" inputs to `node`.
  static Status FindVariableInputs(const Node& node,
                                   const NameRangeMap& input_range_map,
                                   std::vector<VariableInput>* variables);

  // Populates '*guaranteed_constants' with the "guaranteed_constants" inputs
  // to 'node'.
  static Status FindGuaranteedConstantInputs(
      const Node& node, const NameRangeMap& input_range_map,
      std::vector<Node*>* guaranteed_constants);

  // Builds Shape nodes that compute the shapes of arguments whose shapes are
  // not statically known.
  static Status BuildDynamicShapeNodes(
      const Node& replicate_node, const std::vector<InferredShape>& arg_shapes,
      const ParameterInfo& params_info,
      const std::vector<Node*>& variable_reads, Graph* graph,
      std::vector<Node*>* dynamic_shape_nodes);

  // Builds a TPUCompile node that compiles the computation in `function`.
  // TODO(b/33943292): at present, for model parallelism with Send/Recv to work
  // the `nodes` must correspond to the computations assigned to TPU:0,
  // TPU:1, ... in order since XLA hard-codes the chip IDs in the generated
  // executables.
  static Status BuildCompileNode(
      const Node* replicate_node, const NameAttrList& function,
      uint64 library_fingerprint, const ParameterInfo& params_info,
      const std::vector<InferredShape>& arg_shapes,
      const DataTypeVector& arg_types,
      const std::vector<Node*>& guaranteed_constant_nodes,
      const string& session_handle,
      const std::vector<::xla::OpSharding>& arg_sharding,
      const std::vector<bool>& arg_fast_mem,
      const std::vector<std::string>& arg_names,
      const std::vector<::xla::OpSharding>& retval_sharding,
      int num_cores_per_replica, const string& compile_device,
      const xla::DeviceAssignment* xla_device_assignment,
      const std::vector<Node*>& dynamic_shape_nodes, Graph* graph,
      Node** compile_node, int64_t autotuner_thresh);

  // Builds a TPUCompileSucceededAssert node that verifies that compilation
  // succeeded and replaces the TPUCompilationStatus node in the graph.
  static Status BuildCompilationStatusReturnNodes(
      Node* replicate_node, Node* compile_node,
      absl::Span<const int> devices_to_lock, Node** control_after_compilation,
      Node** multilock_acquire, Graph* graph);

  // Builds ReadVariableOp nodes that read `variables`, with control
  // edges that ensure they happen after `control_predecessor`.
  static Status BuildVariableReads(absl::Span<const VariableInput> variables,
                                   Node* control_predecessor, Graph* graph,
                                   std::vector<Node*>* variable_reads);

  // Returns true if the graph or functions contain a resource write op;
  // otherwise returns false.
  // TODO(b/137048563): Recognize unused resource rewrite op.
  static bool ContainsResourceWriteOp(const Graph& graph,
                                      const FunctionLibraryDefinition& fld);

  // Struct that describes a variable value to be written back from TPUExecute.
  struct VariableWrite {
    // A node:output pair containing a boolean tensor that determines whether
    // the value should be written back.
    Node* predicate;
    int predicate_output;

    // A node:output pair containing the value to be written back.
    Node* value;
    int value_output;
  };

  // Builds AssignVariableOp nodes that write `variables` with the values from
  // `variable_writes`, with control edges that ensure the writes happen before
  // `control_successor`.
  static Status BuildVariableWrites(
      absl::Span<const VariableInput> variables, Node* control_successor,
      absl::Span<const VariableWrite> variable_writes, Graph* graph);

  // Builds TPUExecute operators assigned to each TPU device
  // involved in the computation.
  // Arguments:
  // * `params_info` is the structure containing the information about the
  //    TPUReplicate node inputs and outputs.
  // * `num_tasks` is the number of TensorFlow tasks in the slice.
  // * `num_cores_per_replica` is the number of cores which are dedicated to
  //    each replica.
  // * `replicate_node` is the original TPUReplicate node.
  // * `arg_names` are the names of the arguments to the computation function
  //    passed as argument to TPUReplicate, including per-replica,
  //    broadcast, and variable arguments.
  // * `arg_types` are the corresponding types of the arguments.
  // * `arg_shapes` are the corresponding shapes (and handle types/shapes, if
  //    applicable).
  // * `arg_shardings` and `retval_shardings` are mappings from
  //    arguments/return indices to shardings, as returned by
  //    `AssignArgsAndRetvalsToCores`.
  // * `pod_devices` lists the devices to assign to each core of each replica.
  // * `variable_reads` is a vector of ReadVariableOp operators, one for each
  //    variable argument to the computation.
  // * The execute operators will have a control edge from
  //   `control_predecessor` and another control edge to `control_successor`.
  // Populates '*variable_writes' with information about variable values to
  // write back.
  static Status BuildExecuteNodes(
      const ParameterInfo& params_info, int num_tasks,
      int num_cores_per_replica, const Node& replicate_node,
      const std::vector<std::string>& arg_names,
      const DataTypeVector& arg_types,
      const std::vector<InferredShape>& arg_shapes,
      const DataTypeVector& retval_types,
      const std::vector<::xla::OpSharding>& arg_shardings,
      const std::vector<::xla::OpSharding>& retval_shardings,
      const std::vector<std::vector<string>>& tpu_device_names,
      Node* compile_node, const std::vector<Node*>& variable_reads,
      Node* control_predecessor, Node* control_successor,
      Node* multilock_acquire, std::vector<VariableWrite>* variable_writes,
      Graph* graph);

  // Connects the compile node to all the host transfer nodes, and removes the
  // key placeholder node that was previously standing in for it.
  // Arguments:
  // * `compile_node` is the TPUCompile node that has been added to the graph.
  // * `key_placeholder_node` is the placeholder node used to send the key to
  //   all the host transfer nodes in the original graph.
  // * `graph` is the graph being rewritten.
  static Status ConnectHostComputeNodes(Node* compile_node,
                                        Node* key_placeholder_node,
                                        Graph* graph);

  // Map from a Node in an outside_compilation cluster in the original graph to
  // the list of Nodes, one for each replica, that it is expanded into during
  // replication.
  typedef absl::node_hash_map<Node*, std::vector<Node*>> NodeToNodeReplicasMap;

  // Map from the name of an outside_compilation cluster to the model-parallel
  // core index that the HostCompute Op should be placed on in that cluster.
  typedef std::map<string, int> HostComputeCoreMap;

  // Map from the name of an outside_compilation cluster to the list of Nodes
  // that should run on the host for that cluster.
  typedef std::map<string, std::vector<Node*>> OutsideCompilationNodeMap;

  // Copies the outside_compilation nodes in a cluster to create replica
  // replica_index.
  static Status CopyOutsideCompilationNodes(
      int replica_index, const std::vector<Node*>& outside_compilation_nodes,
      const DeviceNameUtils::ParsedName& tpu_device,
      const DeviceNameUtils::ParsedName& partial_device,
      NodeToNodeReplicasMap* node_images, Graph* graph);

  // Replicates all the nodes in outside_compilation clusters in a compiled
  // computation.
  static Status ReplicateOutsideCompilationNodes(
      const std::vector<std::vector<string>>& tf_device_assignment,
      const HostComputeCoreMap& host_compute_core,
      const OutsideCompilationNodeMap& outside_compilation_nodes,
      NodeToNodeReplicasMap* node_images, Graph* graph);

  // Lifts the edges between original outside_compilation nodes in a cluster
  // onto their replicas.
  static Status CopyOutsideCompilationEdges(
      const std::vector<Node*>& outside_compilation_nodes,
      const NodeToNodeReplicasMap& node_images,
      const std::unordered_map<string, Node*> outside_compilation_inputs,
      Graph* graph);

  // Lifts all the edges in outside_compilation clusters in a compiled
  // computation to their replicas.
  static Status ReplicateOutsideCompilationEdges(
      const OutsideCompilationNodeMap& outside_compilation_nodes,
      const NodeToNodeReplicasMap& node_images,
      const std::unordered_map<string, Node*> outside_compilation_inputs,
      Graph* graph);

  // Removes all the original outside_compilation nodes from the graph,
  // following replication.
  static Status RemoveOutsideCompilationNodes(
      const NodeToNodeReplicasMap& node_images, Graph* graph);

  // Lowers outside compilation functional nodes (If/While/function call).
  // Otherwise, when we have multiple workers, device placer will not be able to
  // place nodes if outside compilation has DT_RESOURCE inputs (e.g. a
  // DT_RESOURCE input fed into multiple While nodes on different devices).
  static Status LowerOutsideCompilationFunctionalNodes(
      Graph* g, FunctionLibraryDefinition& flib_def,
      const TPUReplicateDeviceNamesMapping& tpu_replicate_device_names_mapping);

  // Parses the 'host_compute_core' attribute on replicate_node to get the
  // replicated core id of each outside_compilation cluster.
  static Status ParseHostComputeCores(
      const Node& replicate_node,
      const OutsideCompilationNodeMap& outside_compilation_nodes,
      HostComputeCoreMap* host_compute_core);

  // Gets the physical topology information about the TPU system.
  static Status GetDeviceTopology(
      const DeviceSet& device_set, const Node& replicate_node,
      int* num_replicas, int* num_cores_per_replica, int* num_tasks,
      std::vector<std::vector<string>>* tf_device_assignment,
      std::vector<int>* devices_to_lock,
      std::unique_ptr<xla::DeviceAssignment>* xla_device_assignment,
      string* tpu_compilation_device);

  // Gets the types of args, retvals, and parameters.
  static Status GetIOTypes(
      int num_replicas, const Node& replicate_node, FunctionLibraryRuntime* flr,
      Graph* graph, NameRangeMap* input_name_map, const NameAttrList** function,
      std::unique_ptr<Graph>* computation, DataTypeVector* arg_types,
      DataTypeVector* retval_types, ParameterInfo* params_info);

  // Finds known constants and deals with variable reads.
  static Status DealWithConstantsAndVariables(
      const Node& replicate_node, const NameRangeMap& input_name_map,
      Graph* graph, Node* host_transfer_sequencer, Node* control_before,
      Node* control_after, absl::Span<const VariableInput> variable_nodes,
      std::vector<Node*>* guaranteed_constant_nodes,
      std::vector<Node*>* variable_reads);

  // Adds NoOp nodes for sequencing computation and variable reads/writes.
  static Status BuildSequencingNodes(const string& tpu_compilation_device,
                                     const Node& replicate_node, Graph* graph,
                                     Node** host_transfer_sequencer,
                                     Node** control_before,
                                     Node** control_after);

  // Performs the pass's rewrite on a TPUReplicate node `node`.
  static Status RewriteTPUReplicateNode(
      const string& session_handle, const DeviceSet& device_set,
      Node* replicate_node, FunctionLibraryDefinition* flib_def,
      FunctionLibraryRuntime* flr, Node* host_compute_key_placeholder_node,
      const OutsideCompilationNodeMap& outside_compilation_nodes,
      const std::vector<Node*>& head_tail_outside_compilation_nodes,
      NodeToNodeReplicasMap* outside_compilation_node_images, Graph* graph,
      const GraphShapeInfo& shape_info,
      TPUReplicateDeviceNamesMapping* tpu_replicate_device_names_mapping,
      int64_t autotuner_thresh);

  // Performs host training loop optimization. For example, when a TPUExecute
  // node is inside a while loop, model weight variables can be sharded in
  // XLA's preferred layout and unsharded only at the very last iteration,
  // reducing the number of all-gathers.
  static Status PerformHostTrainingLoopOptimization(
      Graph* graph, FunctionLibraryDefinition* flib_def,
      FunctionLibraryRuntime* flr);

  // Heuristically place some nodes with unassigned devices on TPUs for
  // performance reasons.
  static Status PlaceUnassignedDeviceNodesOnTPUIfPossible(Graph* graph);

  // Updates the head and tail outside compiled nodes so that nodes have the
  // correct device and removes the replication and outside compilation
  // attributes so that these nodes do not trigger further graph optimization
  // passes.
  static Status UpdateHeadTailOutsideCompilation(
      const std::vector<std::vector<string>>& tf_device_assignment,
      const std::vector<Node*>& head_tail_outside_compilation_nodes);

 private:
  static bool distribute_vars_;
  static bool allow_xla_spmd_partition_;
  static bool replicate_inputs_outputs_by_default_for_xla_spmd_;
  static bool enable_cross_replica_sharding_mirrored_variables_;
  static bool enable_automatic_model_parallelism_;
  static bool enable_xla_param_broadcast_;
  static bool enable_multicore_locking_;
  static bool use_nd_sharding_ops_;
  Status InternalRun(const GraphOptimizationPassOptions& options);
};

}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_TPU_GRAPH_REWRITE_DISTRIBUTED_TPU_REWRITE_PASS_H_