/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// Rewrites TPUReplicate nodes into replicated computations on TPU.
//
// To represent a distributed TPU computation, we use the
// TPUReplicate operator, which describes a subgraph (represented as a
// TensorFlow function) to replicate across a TPU pod.
//
// Model parallelism and data parallelism:
// ---------------------------------------
// We support two different kinds of parallelism on TPU:
// * data parallelism (replication), or parallelization across batches, and
// * model parallelism, or parallelization within a batch.
//
// The function passed to a TPUReplicate operator is replicated many
// times across a TPU pod (data parallelism). The `num_replicas` attribute
// controls how many replicas of the computation to create. Replicas are mostly
// independent; replicas can only communicate using the CrossReplicaSum
// operator, which is typically used to communicate gradients during training.
//
// Each replica may optionally use more than one TPU core (model
// parallelism). The `num_cores_per_replica` attribute controls how many cores
// there are per replica. For each core, there is a virtual TPU_REPLICATED_CORE
// device that is only valid within replicated TPU computations (e.g.,
// TPU_REPLICATED_CORE:0, TPU_REPLICATED_CORE:1, etc.); each
// TPU_REPLICATED_CORE device corresponds to one TPU core in every replica.
// Each replica runs its own copy of the computation assigned to each
// TPU_REPLICATED_CORE device.
//
// The Python code is responsible for providing a device_assignment that
// describes how the replicated logical cores map to physical cores on the TPU
// topology.
//
// Inputs to TPUReplicate:
// ------------------------------
// The TPUReplicate operator takes four kinds of inputs, in the
// following order:
// * per-replica inputs. If there are three per-replica inputs (A, B, C) and
//   two replicas, the first six arguments to TPUReplicate will be:
//     A0 B0 C0 A1 B1 C1
//   where Ai is the A input to the i-th replica.
// * distributed inputs. These inputs follow the per-replica inputs.
//   If there are two distributed inputs (E, F) and two replicas, the next
//   arguments to TPUReplicate will be: E F.
//   Each replica has its own local E and F.
// * broadcast inputs. These inputs follow the distributed inputs. All
//   replicas receive a copy of each of these inputs.
// * variables. Resource variables accessed by the computation follow the
//   broadcast inputs.
//
// For example, for a computation with two replicas, three per-replica inputs
// (A, B, C), two distributed inputs (E, F), two broadcast inputs (X, Y), and
// two variables (V, W), the arguments to TPUReplicate will be:
//   A0 B0 C0 A1 B1 C1 E F X Y V W
// and each replica will receive the following arguments:
//   A B C E F X Y V W
//
// Distributed TPU compilation requires that the shapes of all operators
// be known statically at compilation time, before any nodes have executed.
// Shapes are determined using shape information emitted by InferShapes. It
// is not possible to replicate TensorFlow operators with unknown or dynamic
// shapes for TPU at present.
//
// Graph rewrite:
// --------------
// Compilation replaces TPUReplicate operators with:
// * a single TPUCompile node that compiles the computations,
// * one TPUExecute node for each TPU device in the system that
//   executes the relevant computation,
// * one ReadVariableOp for each variable accessed by the replicated
//   computation,
// * one AssignVariableOp for each variable accessed by the replicated
//   computation. An assignment is built even if a variable is only read by
//   the computation. We do not know which variables are written until we
//   apply the XlaCompiler to the computation, but that does not happen until
//   after the rewrite. Conservatively, we write back the values of all
//   variables after the computation completes.
//   TODO(phawkins): only write back variables that the computation may write.
// * one Shape node for each Tensor or Variable input to the computation whose
//   shape is not statically known at rewrite time. The input shapes are fed
//   to the TPUCompile node.
//
// To ensure that the reads and writes seem to happen at the right time in the
// graph execution, we add control edges from all predecessors of the original
// TPUReplicate operator to each of the ReadVariableOp operators.
// Similarly, we add control edges from all of the AssignVariableOp operators
// to all of the successors of the TPUReplicate operator.
//
// The TPUReplicate rewrite must run before placement, since resource
// variable inputs will have DT_RESOURCE, which cannot be sent across devices,
// leading to objections from the placer. The pass rewrites the resource
// accesses into explicit ReadVariableOp and AssignVariableOp operators that
// the placer is free to colocate with the variables.

#ifndef TENSORFLOW_CORE_TPU_GRAPH_REWRITE_DISTRIBUTED_TPU_REWRITE_PASS_H_
#define TENSORFLOW_CORE_TPU_GRAPH_REWRITE_DISTRIBUTED_TPU_REWRITE_PASS_H_

#include <map>
#include <memory>
#include <set>
#include <string>
#include <unordered_map>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "absl/container/node_hash_map.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/jit/shape_inference.h"
#include "tensorflow/compiler/xla/service/computation_placer.h"
#include "tensorflow/core/common_runtime/optimization_registry.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/stream_executor/tpu/tpu_topology.h"

namespace tensorflow {
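
// Illustrative note (not part of the original header): GraphOptimizationPasses
// such as the one declared below are wired into graph construction through the
// optimization pass registry, normally from a separate registration .cc file.
// Consistent with the requirement above that this rewrite run before
// placement, a registration would use the PRE_PLACEMENT grouping; the phase
// number below is a placeholder, not necessarily the value used by the actual
// TPU build:
//
//   REGISTER_OPTIMIZATION(OptimizationPassRegistry::PRE_PLACEMENT,
//                         /*phase=*/20, DistributedTPURewritePass);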

// Replaces clusters assigned to TPU_SYSTEM devices with
// TPUCompile and TPUExecute nodes assigned to the corresponding
// TPU devices.
class DistributedTPURewritePass : public GraphOptimizationPass {
 public:
  static void SetDistributedTpuRewritePassOptions(
      bool distribute_vars, bool allow_xla_spmd_partition,
      bool replicate_inputs_outputs_by_default_for_xla_spmd,
      bool enable_cross_replica_sharding_mirrored_variables,
      bool enable_automatic_model_parallelism, bool enable_xla_param_broadcast,
      bool enable_multicore_locking, bool use_nd_sharding_ops);

  Status Run(const GraphOptimizationPassOptions& options) override;

  // The following methods are public only for the use of unit tests.

  // See comment at the top of the file for how the inputs are ordered.
  // Encapsulates the different TPU replicated node input and output
  // information, and provides common APIs over them.
  class ParameterInfo {
   public:
    ParameterInfo() {}
    ParameterInfo(int64_t num_replicas, int64_t num_per_replica_args,
                  int64_t num_distributed_args, int64_t num_broadcast_args,
                  int64_t num_variables, int64_t num_guaranteed_constants,
                  int64_t num_retvals_per_replica)
        : num_replicas_(num_replicas),
          num_per_replica_args_(num_per_replica_args),
          num_distributed_args_(num_distributed_args),
          num_broadcast_args_(num_broadcast_args),
          num_variables_(num_variables),
          num_guaranteed_constants_(num_guaranteed_constants),
          num_retvals_per_replica_(num_retvals_per_replica) {}

    int64_t NumReplicas() const { return num_replicas_; }

    int64_t NumPerReplicaArgs() const { return num_per_replica_args_; }

    int64_t NumDistributedArgs() const { return num_distributed_args_; }

    int64_t NumBroadcastArgs() const { return num_broadcast_args_; }

    int64_t NumVariables() const { return num_variables_; }

    int64_t NumGuaranteedConstants() const { return num_guaranteed_constants_; }

    int64_t NumRetvalsPerReplica() const { return num_retvals_per_replica_; }

    bool IsPerReplicaArg(int64_t index) const {
      return index < num_per_replica_args_;
    }

    bool IsDistributedArg(int64_t index) const {
      return index >= num_per_replica_args_ &&
             index < (num_per_replica_args_ + num_distributed_args_);
    }

    bool IsBroadcastArg(int64_t index) const {
      return (index >= num_per_replica_args_ + num_distributed_args_) &&
             index < (num_per_replica_args_ + num_distributed_args_ +
                      num_broadcast_args_);
    }

    bool IsVariableArg(int64_t index) const {
      return index >= (num_per_replica_args_ + num_distributed_args_ +
                       num_broadcast_args_) &&
             index < (num_per_replica_args_ + num_distributed_args_ +
                      num_broadcast_args_ + num_variables_);
    }

    bool IsConstantArg(int64_t index) const {
      return index >= (num_per_replica_args_ + num_distributed_args_ +
                       num_broadcast_args_ + num_variables_) &&
             index < (num_per_replica_args_ + num_distributed_args_ +
                      num_broadcast_args_ + num_variables_ +
                      num_guaranteed_constants_);
    }

    // Returns the number of inputs received from the host.
    int64_t NumInputsFromHost() const {
      return num_replicas_ * num_per_replica_args_ + num_distributed_args_ +
             num_broadcast_args_ + num_variables_ + num_guaranteed_constants_;
    }

    // Returns the number of inputs which will be sent to each replica.
    int64_t NumInputsToEachReplica() const {
      return num_per_replica_args_ + num_distributed_args_ +
             num_broadcast_args_ + num_variables_ + num_guaranteed_constants_;
    }

    // Returns the total number of output values returned to the host (for all
    // replicas).
    int64_t NumOutputsToHost() const {
      return num_replicas_ * num_retvals_per_replica_;
    }

    // Returns the position of the first broadcast argument within the set of
    // all host arguments. Broadcast arguments follow the per-replica and
    // distributed arguments.
    int64_t FirstBroadcastArgFromHost() const {
      return num_replicas_ * num_per_replica_args_ + num_distributed_args_;
    }

    // Indices of mirrored variables across replicas, which should be
    // categorized as per_replica_args.
    const std::set<int64_t>& mirrored_variable_indices() const {
      return mirrored_variable_indices_;
    }
    std::set<int64_t>* mutable_mirrored_variable_indices() {
      return &mirrored_variable_indices_;
    }

   private:
    int64_t num_replicas_ = 1;
    int64_t num_per_replica_args_ = 0;
    int64_t num_distributed_args_ = 0;
    int64_t num_broadcast_args_ = 0;
    int64_t num_variables_ = 0;
    int64_t num_guaranteed_constants_ = 0;
    int64_t num_retvals_per_replica_ = 0;
    std::set<int64_t> mirrored_variable_indices_;
  };
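
  // Illustrative example (not part of the original interface): using the
  // computation described at the top of this file -- two replicas, three
  // per-replica inputs (A, B, C), two distributed inputs (E, F), two
  // broadcast inputs (X, Y), and two variables (V, W) -- and assuming, purely
  // for the example, no guaranteed constants and one return value per
  // replica, ParameterInfo reports:
  //
  //   ParameterInfo info(/*num_replicas=*/2, /*num_per_replica_args=*/3,
  //                      /*num_distributed_args=*/2, /*num_broadcast_args=*/2,
  //                      /*num_variables=*/2, /*num_guaranteed_constants=*/0,
  //                      /*num_retvals_per_replica=*/1);
  //   info.NumInputsFromHost();          // 2*3 + 2 + 2 + 2 + 0 == 12
  //   info.NumInputsToEachReplica();     // 3 + 2 + 2 + 2 + 0 == 9
  //   info.FirstBroadcastArgFromHost();  // 2*3 + 2 == 8
  //   // Per-replica argument indices 0..8 classify as:
  //   info.IsPerReplicaArg(0);   // true: A, B, C occupy indices 0..2
  //   info.IsDistributedArg(3);  // true: E, F occupy indices 3..4
  //   info.IsBroadcastArg(5);    // true: X, Y occupy indices 5..6
  //   info.IsVariableArg(7);     // true: V, W occupy indices 7..8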

  // Mapping from TPUReplicate cluster name to tpu device names. Value is a
  // mapping from [replica][core] to a TF device name.
  typedef absl::flat_hash_map<string, std::vector<std::vector<string>>>
      TPUReplicateDeviceNamesMapping;

  // Determines which devices to use to run the computation.
  // Inputs:
  // * num_tpus_per_task: the number of TPU devices attached to each task
  // * tpu_devices: a [task][device] collection of TPU devices
  // * num_replicas: the number of replicas requested
  // * num_cores_per_replica: the number of cores in each computation instance
  // * topology_attr: the topology TPUReplicate attribute
  // * device_assignment_attr: the device_assignment TPUReplicate attribute
  // Outputs:
  // * tf_device_assignment: a mapping from [replica][core] to a TF device name
  // * devices_to_lock: a flat array of integer indices corresponding to
  //   devices that are used in this computation. They will be locked before
  //   the TPUExecute kernels are run, to ensure that the kernels from
  //   concurrent multi-core executions are enqueued consistently, i.e., all
  //   kernels from computation A before any kernel from computation B, thus
  //   preventing deadlock.
  // * xla_device_assignment: a mapping from [replica][core] to a linearized
  //   TPU coordinate.
  // TODO(phawkins): change tf_device_assignment to an xla::Array2D.
  static Status BuildDeviceAssignment(
      const tpu::TpuTopologyExternal& topology, int num_tpus_per_task,
      const std::vector<std::vector<Device*>>& tpu_devices, int num_replicas,
      int num_cores_per_replica, const string& topology_attr,
      absl::Span<const int> device_assignment_attr,
      std::vector<std::vector<string>>* tf_device_assignment,
      std::vector<int>* devices_to_lock,
      std::unique_ptr<xla::DeviceAssignment>* xla_device_assignment);

  // Fills in `computation` with the graph for `function`, the computation
  // attached to a TPUReplicate operator. `flr` is a FunctionLibraryRuntime to
  // use when instantiating the function body. Sets `*arg_types` and
  // `*retval_types` to the argument/return types of the function.
  static Status GetComputationForTPUReplicateOp(const NameAttrList& function,
                                                FunctionLibraryRuntime* flr,
                                                Graph* computation,
                                                DataTypeVector* arg_types,
                                                DataTypeVector* retval_types);

  // Returns the shapes of the argument tensors and return values of the
  // TPUReplicate operator `node` using the _output_shapes,
  // _output_handle_shapes, and _output_handle_types annotations on the input
  // nodes. Expects inputs in the following order (see comment at top of file):
  // * num_replicas * num_per_replica_args per-replica inputs,
  // * num_broadcast_args broadcast inputs,
  // * num_variables variable inputs.
  // Returns an error if the input shapes to `node` are not statically known.
  // Also verifies that all replicas have identical input shapes for their
  // per-replica inputs.
  static Status GetArgAndRetvalShapes(
      const GraphShapeInfo& shape_info, const Node& node,
      const ParameterInfo& params_info, std::vector<InferredShape>* arg_shapes,
      std::vector<InferredShape>* retval_shapes);

  // Assigns arguments and return values to cores. The assignment is
  // represented as an XLA op sharding, so that an argument can be replicated
  // across cores.
  // `arg_sharding` and `retval_sharding` are vectors of shardings indexed by
  // argument/retval number.
  // `arg_fast_mem` is a vector of fast_mem indicators, indexed by argument
  // number.
  static Status AssignArgsAndRetvalsToCores(
      int num_cores_per_replica, const ParameterInfo& params_info,
      const DataTypeVector& arg_types,
      const std::vector<InferredShape>& arg_shapes,
      const DataTypeVector& retval_types,
      const std::vector<InferredShape>& retval_shapes, const Graph& graph,
      const Node* replicate_node, FunctionLibraryRuntime* flr,
      bool allow_parameter_replication_for_spmd,
      std::vector<::xla::OpSharding>* arg_sharding,
      std::vector<bool>* arg_fast_mem,
      std::vector<::xla::OpSharding>* retval_sharding,
      std::vector<std::string>* arg_names);

  // Describes a variable input to the computation: the `index`-th output of
  // `node` carries the variable resource.
  struct VariableInput {
    Node* node;
    int index;

    // Type of the variable's value. Note that this is different from the type
    // of the output of 'variable', which is always DT_RESOURCE.
    DataType dtype;
  };

  // Populates `*variables` with the "variables" inputs to `node`.
  static Status FindVariableInputs(const Node& node,
                                   const NameRangeMap& input_range_map,
                                   std::vector<VariableInput>* variables);

  // Populates '*guaranteed_constants' with the "guaranteed_constants" inputs
  // to 'node'.
  static Status FindGuaranteedConstantInputs(
      const Node& node, const NameRangeMap& input_range_map,
      std::vector<Node*>* guaranteed_constants);

  // Builds Shape nodes that compute the shapes of arguments whose shapes are
  // not statically known.
  static Status BuildDynamicShapeNodes(
      const Node& replicate_node, const std::vector<InferredShape>& arg_shapes,
      const ParameterInfo& params_info,
      const std::vector<Node*>& variable_reads, Graph* graph,
      std::vector<Node*>* dynamic_shape_nodes);

  // Builds a TPUCompile node that compiles the computation in `function`.
  // TODO(b/33943292): at present, for model parallelism with Send/Recv to
  // work, the compiled computations must correspond to the cores TPU:0,
  // TPU:1, ... in order, since XLA hard-codes the chip IDs in the generated
  // executables.
  static Status BuildCompileNode(
      const Node* replicate_node, const NameAttrList& function,
      uint64 library_fingerprint, const ParameterInfo& params_info,
      const std::vector<InferredShape>& arg_shapes,
      const DataTypeVector& arg_types,
      const std::vector<Node*>& guaranteed_constant_nodes,
      const string& session_handle,
      const std::vector<::xla::OpSharding>& arg_sharding,
      const std::vector<bool>& arg_fast_mem,
      const std::vector<std::string>& arg_names,
      const std::vector<::xla::OpSharding>& retval_sharding,
      int num_cores_per_replica, const string& compile_device,
      const xla::DeviceAssignment* xla_device_assignment,
      const std::vector<Node*>& dynamic_shape_nodes, Graph* graph,
      Node** compile_node, int64_t autotuner_thresh);

  // Builds a TPUCompileSucceededAssert node that verifies that compilation
  // succeeded and replaces the TPUCompilationStatus node in the graph.
  static Status BuildCompilationStatusReturnNodes(
      Node* replicate_node, Node* compile_node,
      absl::Span<const int> devices_to_lock, Node** control_after_compilation,
      Node** multilock_acquire, Graph* graph);

  // Builds ReadVariableOp nodes that read `variables`, with control edges
  // that ensure they happen after `control_predecessor`.
  static Status BuildVariableReads(absl::Span<const VariableInput> variables,
                                   Node* control_predecessor, Graph* graph,
                                   std::vector<Node*>* variable_reads);

  // Returns true if the graph or functions contain a resource write op,
  // otherwise returns false.
  // TODO(b/137048563): Recognize unused resource rewrite op.
  static bool ContainsResourceWriteOp(const Graph& graph,
                                      const FunctionLibraryDefinition& fld);

  // Struct that describes a variable value to be written back from TPUExecute.
  struct VariableWrite {
    // A node:output pair containing a boolean tensor that determines whether
    // the value should be written back.
    Node* predicate;
    int predicate_output;

    // A node:output pair containing the value to be written back.
    Node* value;
    int value_output;
  };

  // Builds AssignVariableOp nodes that write `variables` with the values from
  // `variable_writes`, with control edges that ensure the writes happen before
  // `control_successor`.
  static Status BuildVariableWrites(
      absl::Span<const VariableInput> variables, Node* control_successor,
      absl::Span<const VariableWrite> variable_writes, Graph* graph);

  // Builds TPUExecute operators assigned to each TPU device
  // involved in the computation.
  // Arguments:
  // * `params_info` is the structure containing the information about the
  //   TPUReplicate node inputs and outputs.
  // * `num_tasks` is the number of TensorFlow tasks in the slice.
  // * `num_cores_per_replica` is the number of cores which are dedicated to
  //   each replica.
  // * `replicate_node` is the original TPUReplicate node.
  // * `arg_names` are the names of the arguments to the computation function
  //   passed as argument to TPUReplicate, including per-replica,
  //   broadcast, and variable arguments.
  // * `arg_types` are the corresponding types of the arguments.
  // * `arg_shapes` are the corresponding shapes (and handle types/shapes, if
  //   applicable).
  // * `arg_shardings` and `retval_shardings` are mappings from
  //   arguments/return indices to shardings, as returned by
  //   `AssignArgsAndRetvalsToCores`.
  // * `pod_devices` lists the devices to assign to each core of each replica.
  // * `variable_reads` is a vector of ReadVariableOp operators, one for each
  //   variable argument to the computation.
  // * The execute operators will have a control edge from
  //   `control_predecessor` and another control edge to `control_successor`.
  // Populates '*variable_writes' with information about variable values to
  // write back.
  static Status BuildExecuteNodes(
      const ParameterInfo& params_info, int num_tasks,
      int num_cores_per_replica, const Node& replicate_node,
      const std::vector<std::string>& arg_names,
      const DataTypeVector& arg_types,
      const std::vector<InferredShape>& arg_shapes,
      const DataTypeVector& retval_types,
      const std::vector<::xla::OpSharding>& arg_shardings,
      const std::vector<::xla::OpSharding>& retval_shardings,
      const std::vector<std::vector<string>>& tpu_device_names,
      Node* compile_node, const std::vector<Node*>& variable_reads,
      Node* control_predecessor, Node* control_successor,
      Node* multilock_acquire, std::vector<VariableWrite>* variable_writes,
      Graph* graph);

  // Connects the compile node to all the host transfer nodes, and removes the
  // key placeholder node that was previously standing in for it.
  // Arguments:
  // * `compile_node` is the TPUCompile node that has been added to the graph.
  // * `key_placeholder_node` is the placeholder node used to send the key to
  //   all the host transfer nodes in the original graph.
  // * `graph` is the graph being rewritten.
  static Status ConnectHostComputeNodes(Node* compile_node,
                                        Node* key_placeholder_node,
                                        Graph* graph);

  // Map from a Node in an outside_compilation cluster in the original graph to
  // the list of Nodes, one for each replica, that it is expanded into during
  // replication.
  typedef absl::node_hash_map<Node*, std::vector<Node*>> NodeToNodeReplicasMap;

  // Map from the name of an outside_compilation cluster to the model-parallel
  // core index that the HostCompute Op should be placed on in that cluster.
  typedef std::map<string, int> HostComputeCoreMap;

  // Map from the name of an outside_compilation cluster to the list of Nodes
  // that should run on the host for that cluster.
  typedef std::map<string, std::vector<Node*>> OutsideCompilationNodeMap;

  // Copies the outside_compilation nodes in a cluster to create replica
  // replica_index.
  static Status CopyOutsideCompilationNodes(
      int replica_index, const std::vector<Node*>& outside_compilation_nodes,
      const DeviceNameUtils::ParsedName& tpu_device,
      const DeviceNameUtils::ParsedName& partial_device,
      NodeToNodeReplicasMap* node_images, Graph* graph);

  // Replicates all the nodes in outside_compilation clusters in a compiled
  // computation.
  static Status ReplicateOutsideCompilationNodes(
      const std::vector<std::vector<string>>& tf_device_assignment,
      const HostComputeCoreMap& host_compute_core,
      const OutsideCompilationNodeMap& outside_compilation_nodes,
      NodeToNodeReplicasMap* node_images, Graph* graph);

  // Lifts the edges between original outside_compilation nodes in a cluster
  // onto their replicas.
  static Status CopyOutsideCompilationEdges(
      const std::vector<Node*>& outside_compilation_nodes,
      const NodeToNodeReplicasMap& node_images,
      const std::unordered_map<string, Node*> outside_compilation_inputs,
      Graph* graph);

  // Lifts all the edges in outside_compilation clusters in a compiled
  // computation to their replicas.
  static Status ReplicateOutsideCompilationEdges(
      const OutsideCompilationNodeMap& outside_compilation_nodes,
      const NodeToNodeReplicasMap& node_images,
      const std::unordered_map<string, Node*> outside_compilation_inputs,
      Graph* graph);

  // Removes all the original outside_compilation nodes from the graph,
  // following replication.
  static Status RemoveOutsideCompilationNodes(
      const NodeToNodeReplicasMap& node_images, Graph* graph);

  // Lowers outside compilation functional nodes (If/While/function call).
  // Otherwise, when we have multiple workers, the device placer will not be
  // able to place nodes if outside compilation has DT_RESOURCE inputs (e.g., a
  // DT_RESOURCE input fed into multiple While nodes on different devices).
  static Status LowerOutsideCompilationFunctionalNodes(
      Graph* g, FunctionLibraryDefinition& flib_def,
      const TPUReplicateDeviceNamesMapping& tpu_replicate_device_names_mapping);

  // Parses the 'host_compute_core' attribute on replicate_node to get the
  // replicated core id of each outside_compilation cluster.
  static Status ParseHostComputeCores(
      const Node& replicate_node,
      const OutsideCompilationNodeMap& outside_compilation_nodes,
      HostComputeCoreMap* host_compute_core);

  // Gets the physical topology information about the TPU system.
  static Status GetDeviceTopology(
      const DeviceSet& device_set, const Node& replicate_node,
      int* num_replicas, int* num_cores_per_replica, int* num_tasks,
      std::vector<std::vector<string>>* tf_device_assignment,
      std::vector<int>* devices_to_lock,
      std::unique_ptr<xla::DeviceAssignment>* xla_device_assignment,
      string* tpu_compilation_device);

  // Gets the types of args, retvals, and parameters.
  static Status GetIOTypes(
      int num_replicas, const Node& replicate_node, FunctionLibraryRuntime* flr,
      Graph* graph, NameRangeMap* input_name_map, const NameAttrList** function,
      std::unique_ptr<Graph>* computation, DataTypeVector* arg_types,
      DataTypeVector* retval_types, ParameterInfo* params_info);

  // Finds known constants and deals with variable reads.
  static Status DealWithConstantsAndVariables(
      const Node& replicate_node, const NameRangeMap& input_name_map,
      Graph* graph, Node* host_transfer_sequencer, Node* control_before,
      Node* control_after, absl::Span<const VariableInput> variable_nodes,
      std::vector<Node*>* guaranteed_constant_nodes,
      std::vector<Node*>* variable_reads);

  // Adds NoOp nodes for sequencing computation and variable reads/writes.
  static Status BuildSequencingNodes(const string& tpu_compilation_device,
                                     const Node& replicate_node, Graph* graph,
                                     Node** host_transfer_sequencer,
                                     Node** control_before,
                                     Node** control_after);

  // Performs the pass's rewrite on the TPUReplicate node `replicate_node`.
  static Status RewriteTPUReplicateNode(
      const string& session_handle, const DeviceSet& device_set,
      Node* replicate_node, FunctionLibraryDefinition* flib_def,
      FunctionLibraryRuntime* flr, Node* host_compute_key_placeholder_node,
      const OutsideCompilationNodeMap& outside_compilation_nodes,
      const std::vector<Node*>& head_tail_outside_compilation_nodes,
      NodeToNodeReplicasMap* outside_compilation_node_images, Graph* graph,
      const GraphShapeInfo& shape_info,
      TPUReplicateDeviceNamesMapping* tpu_replicate_device_names_mapping,
      int64_t autotuner_thresh);

  // Performs host training loop optimization. For example, when the TPUExecute
  // node is inside a while loop, model weight variables can be sharded in the
  // XLA-preferred layout and unsharded only at the very last iteration, to
  // reduce the number of all-gathers.
  static Status PerformHostTrainingLoopOptimization(
      Graph* graph, FunctionLibraryDefinition* flib_def,
      FunctionLibraryRuntime* flr);

  // Heuristically places some nodes with unassigned devices on TPUs for
  // performance reasons.
  static Status PlaceUnassignedDeviceNodesOnTPUIfPossible(Graph* graph);

  // Updates the head and tail outside compiled nodes so that nodes have the
  // correct device and removes the replication and outside compilation
  // attributes so that these nodes do not trigger further graph optimization
  // passes.
  static Status UpdateHeadTailOutsideCompilation(
      const std::vector<std::vector<string>>& tf_device_assignment,
      const std::vector<Node*>& head_tail_outside_compilation_nodes);

 private:
  static bool distribute_vars_;
  static bool allow_xla_spmd_partition_;
  static bool replicate_inputs_outputs_by_default_for_xla_spmd_;
  static bool enable_cross_replica_sharding_mirrored_variables_;
  static bool enable_automatic_model_parallelism_;
  static bool enable_xla_param_broadcast_;
  static bool enable_multicore_locking_;
  static bool use_nd_sharding_ops_;

  Status InternalRun(const GraphOptimizationPassOptions& options);
};

}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_TPU_GRAPH_REWRITE_DISTRIBUTED_TPU_REWRITE_PASS_H_