xref: /aosp_15_r20/external/tensorflow/tensorflow/core/tpu/kernels/tpu_functional_ops.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/tpu/kernels/tpu_functional_ops.h"
17 
18 #include <algorithm>
19 #include <memory>
20 
21 #include "absl/strings/match.h"
22 #include "tensorflow/core/framework/cancellation.h"
23 #include "tensorflow/core/framework/op_kernel.h"
24 #include "tensorflow/core/protobuf/tpu/topology.pb.h"
25 #include "tensorflow/stream_executor/tpu/c_api_decl.h"
26 #include "tensorflow/stream_executor/tpu/tpu_platform_interface.h"
27 
28 #define EIGEN_USE_THREADS
29 
30 #include "absl/base/call_once.h"
31 #include "absl/strings/str_cat.h"
32 #include "absl/synchronization/mutex.h"
33 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
34 #include "tensorflow/compiler/tf2xla/sharding_util.h"
35 #include "tensorflow/compiler/tf2xla/side_effect_util.h"
36 #include "tensorflow/compiler/xla/xla_data.pb.h"
37 #include "tensorflow/core/common_runtime/function_body.h"
38 #include "tensorflow/core/common_runtime/graph_constructor.h"
39 #include "tensorflow/core/common_runtime/placer.h"
40 #include "tensorflow/core/common_runtime/rendezvous_mgr.h"
41 #include "tensorflow/core/framework/graph_to_functiondef.h"
42 #include "tensorflow/core/framework/metrics.h"
43 #include "tensorflow/core/framework/node_def.pb.h"
44 #include "tensorflow/core/framework/node_def_util.h"
45 #include "tensorflow/core/framework/resource_mgr.h"
46 #include "tensorflow/core/framework/resource_var.h"
47 #include "tensorflow/core/framework/tensor.h"
48 #include "tensorflow/core/framework/tensor.pb.h"
49 #include "tensorflow/core/framework/tensor_shape.h"
50 #include "tensorflow/core/graph/graph_partition.h"
51 #include "tensorflow/core/graph/node_builder.h"
52 #include "tensorflow/core/lib/core/errors.h"
53 #include "tensorflow/core/lib/hash/hash.h"
54 #include "tensorflow/core/lib/strings/str_util.h"
55 #include "tensorflow/core/platform/blocking_counter.h"
56 #include "tensorflow/core/platform/errors.h"
57 #include "tensorflow/core/platform/fingerprint.h"
58 #include "tensorflow/core/platform/refcount.h"
59 #include "tensorflow/core/profiler/lib/traceme.h"
60 #include "tensorflow/core/tpu/kernels/tpu_compile_op_support.h"
61 #include "tensorflow/core/tpu/kernels/tpu_fingerprint_lookup.h"
62 #include "tensorflow/core/tpu/kernels/tpu_op_consts.h"
63 #include "tensorflow/core/tpu/kernels/tpu_op_util.h"
64 #include "tensorflow/core/tpu/kernels/tpu_util.h"
65 #include "tensorflow/core/tpu/tpu_configuration.h"
66 #include "tensorflow/core/tpu/tpu_defs.h"
67 #include "tensorflow/core/util/dump_graph.h"
68 
69 namespace tensorflow {
70 namespace {
71 
72 constexpr char kTpuReplicateAttr[] = "_tpu_replicate";
73 constexpr int kLastDimOfTpuInputFastPath = 128;
74 constexpr int kOtherDimOfTpuInputFastPath = 8;
75 
76 constexpr char kXLAShardingAttrName[] = "sharding";
77 constexpr char kXLAShardingAttrAltName[] = "_XlaSharding";
78 
79 tpu::TopologyProto GetTPUTopology() {
80   const tpu::TpuTopologyExternal& topology =
81       tpu::TpuPlatformInterface::GetRegisteredPlatform()->topology();
82 
83   tpu::TopologyProto topology_proto;
84   topology_proto.set_num_tasks(topology.HostCount());
85   topology_proto.set_num_tpu_devices_per_task(
86       topology.LogicalDevicesPerHost(TpuCoreTypeEnum::kTensorCore));
87 
88   // mesh shape.
89   int devices_per_chip =
90       topology.LogicalDevicesPerChip(TpuCoreTypeEnum::kTensorCore);
91   topology_proto.add_mesh_shape(topology.chip_bounds().x);
92   topology_proto.add_mesh_shape(topology.chip_bounds().y);
93   topology_proto.add_mesh_shape(topology.chip_bounds().z);
94   topology_proto.add_mesh_shape(devices_per_chip);
95 
96   // device coordinates.
97   for (const tpu::TpuCoreLocationExternal& core :
98        topology.cores(TpuCoreTypeEnum::kTensorCore)) {
99     const tpu::TpuDimensionsExternal coords = core.chip_coordinates();
100     topology_proto.add_device_coordinates(coords.x);
101     topology_proto.add_device_coordinates(coords.y);
102     topology_proto.add_device_coordinates(coords.z);
103     topology_proto.add_device_coordinates(core.index());
104   }
105 
106   return topology_proto;
107 }
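// For illustration (hypothetical topology): on a single-host 2x2x1 chip slice
// with one logical TensorCore per chip, the proto built above would carry
// num_tasks: 1, num_tpu_devices_per_task: 4, mesh_shape: [2, 2, 1, 1], and one
// (x, y, z, core_index) quadruple in device_coordinates per core.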
108 
109 struct TPUVariableInfo {
110   TPUVariableInfo(int device_ordinal_id, bool use_fast_mem)
111       : device_ordinal(device_ordinal_id), fast_mem(use_fast_mem) {}
112   // The TPU core which the variable will be placed on.
113   int device_ordinal;
114   // If true, try to place the variable in fast memory space if the
115   // hardware supports it.
116   bool fast_mem;
117 };
118 
119 // Check the descendants to parse the placement information for the input node.
120 // num_cores_per_replica describes how many cores the single model uses.
121 Status ParseTPUVariableInfor(const Node* node, const int num_cores_per_replica,
122                              TPUVariableInfo* var_info) {
123   int core = 0;
124   bool use_fast_mem = false;
125   VLOG(3) << "Parse tpu variable information for " << node->name();
126   for (const Edge* edge : node->out_edges()) {
127     if (edge->IsControlEdge()) continue;
128     Node* next = edge->dst();
129     VLOG(3) << "Neighbor node " << next->name();
130     // Looking through Enter/Switch/ReadVariableOp nodes.
131     while (next->IsEnter() || next->IsSwitch() ||
132            next->type_string() == "ReadVariableOp") {
133       Node* new_node = nullptr;
134       for (const Edge* e : next->out_edges()) {
135         if (!e->IsControlEdge()) {
136           new_node = e->dst();
137           break;
138         }
139       }
140       if (new_node == nullptr) break;
141       next = new_node;
142     }
143     if (next != edge->dst()) {
144       VLOG(3) << "Looked through Enter/Switch node " << next->DebugString();
145     }
146     TF_ASSIGN_OR_RETURN(absl::optional<xla::OpSharding> sharding,
147                         ParseShardingFromDevice(*next, num_cores_per_replica,
148                                                 /*add_metadata=*/false));
149     if (sharding.has_value() && sharding->tile_assignment_devices_size() > 0) {
150       core = sharding->tile_assignment_devices(0);
151       VLOG(3) << next->name() << " is placed on core " << core;
152     }
153     if (next->attrs().Find(TPU_FAST_MEM_ATTR) != nullptr) {
154       use_fast_mem = true;
155       VLOG(3) << next->name() << " has " << TPU_FAST_MEM_ATTR << " attribute";
156     }
157   }
158   VLOG(1) << "Place " << node->name() << " to core: " << core
159           << " fast_mem: " << use_fast_mem;
160   var_info->device_ordinal = core;
161   var_info->fast_mem = use_fast_mem;
162 
163   return OkStatus();
164 }
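// For example (hypothetical graph): if a variable's looked-through consumer
// resolves to a sharding whose first tile_assignment_device is 2, the variable
// is placed on core 2; fast_mem is set only when some consumer carries the
// TPU_FAST_MEM_ATTR attribute.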
165 
166 // Helper to instantiate function "func" in the library "lib".
167 Status Instantiate(FunctionLibraryRuntime* lib, const NameAttrList& func,
168                    FunctionLibraryRuntime::Handle* handle) {
169   return lib->Instantiate(func.name(), AttrSlice(&func.attr()), handle);
170 }
171 
172 static constexpr const char* const kDeviceOrdinalAttr = "device_ordinal";
173 
174 static constexpr const char* const kTPUExecuteOp = "TPUExecute";
175 static constexpr const char* const kInfeedEnqueueOp = "InfeedEnqueue";
176 static constexpr const char* const kInfeedEnqueueTupleOp = "InfeedEnqueueTuple";
177 static constexpr const char* const kOutfeedDequeueOp = "OutfeedDequeue";
178 static constexpr const char* const kOutfeedDequeueTupleOp =
179     "OutfeedDequeueTuple";
180 static constexpr const char* const kOutfeedDequeueV2Op = "OutfeedDequeueV2";
181 static constexpr const char* const kOutfeedDequeueTupleV2Op =
182     "OutfeedDequeueTupleV2";
183 static constexpr const char* const kVarHandleOp = "VarHandleOp";
184 
185 static constexpr const char* const kTPUDeviceNamePrefix = "/device:TPU:";
186 static constexpr const int kTPUDefaultDeviceOrdinal = 0;
187 
188 bool IsSupportedTPUOp(const string& op_name) {
189   return op_name == kTPUExecuteOp || op_name == kInfeedEnqueueOp ||
190          op_name == kInfeedEnqueueTupleOp || op_name == kOutfeedDequeueOp ||
191          op_name == kOutfeedDequeueTupleOp || op_name == kOutfeedDequeueV2Op ||
192          op_name == kOutfeedDequeueTupleV2Op;
193 }
194 
195 // Sets the sharding attributes for an XlaSharding node.
196 void SetXlaShardingNodeAttr(Node* xla_sharding_node, int num_cores_per_replica,
197                             int rank, int shard_dim) {
198   auto sharding = absl::make_optional<xla::OpSharding>();
199   sharding->set_type(xla::OpSharding::OTHER);
200 
201   std::vector<int64_t> dims(rank, 1LL);
202   dims[shard_dim] = num_cores_per_replica;
203   for (auto dim : dims) {
204     sharding->add_tile_assignment_dimensions(dim);
205   }
206 
207   // Sets up tile_assignment_devices.
208   for (int d = 0; d < num_cores_per_replica; ++d) {
209     sharding->add_tile_assignment_devices(d);
210   }
211 
212   xla_sharding_node->ClearAttr(kXLAShardingAttrName);
213   xla_sharding_node->ClearAttr(kXLAShardingAttrAltName);
214   xla_sharding_node->AddAttr(kXLAShardingAttrName,
215                              sharding->SerializeAsString());
216   xla_sharding_node->AddAttr(kXLAShardingAttrAltName,
217                              sharding->SerializeAsString());
218 }
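// For example, with num_cores_per_replica = 2, rank = 2 and shard_dim = 1, the
// OpSharding built above has type: OTHER, tile_assignment_dimensions: [1, 2]
// and tile_assignment_devices: [0, 1], i.e. the tensor is split across two
// cores along its last dimension.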
219 
220 // If 'device_name' is a TPU device, set its device_ordinal to 'device_ordinal'
221 // and set '*rewritten' to true. Otherwise, do nothing.
222 Status UpdateTPUDeviceOrdinal(int device_ordinal, string* device_name,
223                               bool* rewritten) {
224   DeviceNameUtils::ParsedName device;
225   if (!DeviceNameUtils::ParseFullName(*device_name, &device)) {
226     return errors::InvalidArgument("Unable to parse device name ",
227                                    *device_name);
228   }
229   if (device.type == DEVICE_TPU_NODE) {
230     device.id = device_ordinal;
231     *rewritten = true;
232   }
233   *device_name = DeviceNameUtils::ParsedNameToString(device);
234   return OkStatus();
235 }
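// For example (hypothetical name): "/job:worker/replica:0/task:0/device:TPU:0"
// with device_ordinal 3 becomes "/job:worker/replica:0/task:0/device:TPU:3" and
// *rewritten is set to true; a non-TPU device name is re-serialized unchanged
// and *rewritten is left as-is.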
236 
237 const Edge* FindHostToDeviceEdge(Node* arg_node) {
238   const Edge* candidate_edge = nullptr;
239   for (const Edge* edge : arg_node->out_edges())
240     if (!edge->IsControlEdge()) {
241       // Find CPU -> TPU input edge.
242       const Edge* original_edge;
243       while (edge->src()->attrs().Find(kTpuReplicateAttr) != nullptr ||
244              edge->dst()->attrs().Find(kTpuReplicateAttr) == nullptr) {
245         const Node* new_src = edge->dst();
246         original_edge = edge;
247         for (const Edge* new_edge : new_src->out_edges())
248           if (!new_edge->IsControlEdge()) {
249             original_edge = edge;
250             edge = new_edge;
251             break;
252           }
253         if (original_edge == edge) break;
254       }
255       // TPU input edge: src is on CPU and dest is on TPU.
256       if (edge->src()->attrs().Find(kTpuReplicateAttr) != nullptr ||
257           edge->dst()->attrs().Find(kTpuReplicateAttr) == nullptr)
258         continue;
259       // Won't work with GuaranteeConst.
260       if (edge->src()->type_string() == "GuaranteeConst") break;
261       candidate_edge = edge;
262     }
263   return candidate_edge;
264 }
265 
266 Status CreateInputProxy(Graph* graph, const Edge* candidate_edge,
267                         const Edge** tpu_input_edge) {
268   std::vector<const Edge*> edges_to_replace;
269   for (const Edge* input_edge : candidate_edge->src()->out_edges()) {
270     if (!input_edge->IsControlEdge() &&
271         input_edge->dst()->attrs().Find(kTpuReplicateAttr) != nullptr)
272       edges_to_replace.push_back(input_edge);
273   }
274   // Build an Identity node as the proxy of the original edge source.
275   Node* input_identity_node = nullptr;
276   TF_RETURN_IF_ERROR(
277       NodeBuilder(strings::StrCat(candidate_edge->src()->name(), "/proxy"),
278                   "Identity")
279           .Input(candidate_edge->src())
280           .Attr("T", candidate_edge->src()->output_type(0))
281           .Attr(kTpuReplicateAttr,
282                 candidate_edge->dst()->attrs().Find(kTpuReplicateAttr)->s())
283           .Finalize(graph, &input_identity_node));
284   // Find the tpu input edge from original source to proxy identity.
285   for (const Edge* input_edge : input_identity_node->in_edges())
286     if (input_edge->src() == candidate_edge->src()) {
287       *tpu_input_edge = input_edge;
288       break;
289     }
290   // Replace original input edges with proxy's output.
291   for (const Edge* input_edge : edges_to_replace) {
292     graph->RemoveEdge(input_edge);
293     graph->AddEdge(input_identity_node, 0, input_edge->dst(),
294                    input_edge->dst_input());
295   }
296   return OkStatus();
297 }
298 
299 Status GetClusterName(Graph* graph, string* cluster_name) {
300   *cluster_name = "";
301   for (const Node* node : graph->nodes()) {
302     if (node->attrs().Find(kTpuReplicateAttr) == nullptr) continue;
303     if (cluster_name->empty())
304       *cluster_name = node->attrs().Find(kTpuReplicateAttr)->s();
305     // When optimization is turned on, the graph should only have one TPU
306     // cluster.
307     if (*cluster_name != node->attrs().Find(kTpuReplicateAttr)->s())
308       return errors::FailedPrecondition(
309           "Only one cluster is allowed when optimization is turned on for "
310           "TPUPartitionedCall. Found ",
311           node->attrs().Find(kTpuReplicateAttr)->s(), " and ", *cluster_name);
312   }
313   return OkStatus();
314 }
315 
316 // Removes nodes that have no effect and directly descend from _Arg nodes.
317 //
318 // This is currently used for removing TPUReplicatedInput and XlaSharding
319 // nodes, which are always descendants of _Arg nodes. During optimization, we
320 // try to insert nodes in between _Arg and _Arg's children, where some of the
321 // nodes inserted are TPU nodes. We will add the TPUReplicatedInput and
322 // XlaSharding op nodes back where necessary.
323 //
324 // Returns the number of nodes that were removed.
325 int64_t RemoveDescendantNodeOfArg(
326     Graph* graph, const std::string& node_type_to_remove,
327     const std::set<std::string>& must_be_child_of) {
328   int64_t nodes_removed = 0;
329   std::vector<std::pair<const Edge*, std::vector<const Edge*>>> edges_to_remove;
330 
331   for (Node* node : graph->nodes()) {
332     if (node_type_to_remove != node->type_string()) continue;
333     if (!must_be_child_of.empty()) {
334       bool has_arg_parent = false;
335       for (const Edge* edge : node->in_edges()) {
336         if (must_be_child_of.count(edge->src()->type_string()) > 0) {
337           has_arg_parent = true;
338         }
339       }
340       if (!has_arg_parent) continue;
341     }
342     nodes_removed++;
343 
344     const Edge* input_edge = nullptr;
345     std::vector<const Edge*> output_edges;
346     for (const Edge* edge : node->in_edges())
347       if (!edge->IsControlEdge()) {
348         input_edge = edge;
349         break;
350       }
351     for (const Edge* edge : node->out_edges())
352       if (!edge->IsControlEdge()) {
353         output_edges.push_back(edge);
354       }
355     if (input_edge != nullptr && !output_edges.empty())
356       edges_to_remove.push_back(std::make_pair(input_edge, output_edges));
357   }
358   for (const auto& it : edges_to_remove) {
359     for (const Edge* output_edge : it.second) {
360       graph->RemoveEdge(output_edge);
361       graph->AddEdge(it.first->src(), it.first->src_output(),
362                      output_edge->dst(), output_edge->dst_input());
363     }
364     graph->RemoveNode(it.first->dst());
365   }
366   return nodes_removed;
367 }
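// For example, calling this with node_type_to_remove = "TPUReplicatedInput" and
// must_be_child_of = {"_Arg"} turns the (hypothetical) chain
// _Arg -> TPUReplicatedInput -> X into _Arg -> X and returns 1.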
368 
369 uint64 GetInputHash(OpKernelContext* ctx) {
370   uint64 input_hash = 0;  // initialization for determinism.
371   // Use the number of elements to compute the hash.
372   // TODO(chiachenc): use the full shape to compute the hash.
373   for (int i = 0; i < ctx->num_inputs(); ++i) {
374     VLOG(4) << "InputHash, combine input " << i
375             << ", NumElements: " << ctx->input(i).NumElements();
376     input_hash = Hash64Combine(input_hash, ctx->input(i).NumElements());
377   }
378   return input_hash;
379 }
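// For example, two inputs with 6 and 10 elements hash to
// Hash64Combine(Hash64Combine(0, 6), 10); only element counts, not full shapes,
// feed the hash (see the TODO above).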
380 
381 string HashShapeAndType(const string prefix, const std::vector<int>& input_dims,
382                         const DataType& dtype, const bool input_shape_opt) {
383   string hash = strings::StrCat(prefix, dtype, "_dims");
384   // We will concat at the last dimension.
385   for (int d = 0; d < input_dims.size() - 1; ++d) {
386     strings::StrAppend(&hash, "_", input_dims.at(d));
387   }
388 
389   if (input_shape_opt) {
390     if (input_dims.back() % kLastDimOfTpuInputFastPath == 0) {
391       strings::StrAppend(&hash, "_last_", kLastDimOfTpuInputFastPath, "n");
392     } else {
393       strings::StrAppend(&hash, "_last_other");
394     }
395   }
396   return hash;
397 }
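// For example, HashShapeAndType("input_tensors_dtype_", {8, 128, 256}, DT_FLOAT,
// /*input_shape_opt=*/true) yields a key like
// "input_tensors_dtype_1_dims_8_128_last_128n": DT_FLOAT stringifies as its
// enum value 1, the leading dims 8 and 128 are appended, and the last dim 256
// is a multiple of kLastDimOfTpuInputFastPath (128).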
398 
399 // Get the information for input and output tensors (shapes, dtypes, etc).
400 Status GetInputOutputInfo(
401     Graph* graph, GraphShapeInfo& tpu_inferred_info,
402     std::map<int, InferredShape>& arg_shapes, EdgeShapes& tpu_input_shapes,
403     absl::flat_hash_map<const Edge*, DataType>& tpu_input_dtypes,
404     OpKernelContext* ctx) {
405   // Search for the host-to-device (CPU-to-TPU) input edges.
406   for (Node* node : graph->op_nodes()) {
407     if (!node->IsArg()) continue;
408     const DataType dtype = node->attrs().Find("T")->type();
409     const int arg_index = node->attrs().Find("index")->i();
410     if (dtype != DT_INT32 && dtype != DT_BFLOAT16 && dtype != DT_FLOAT &&
411         dtype != DT_BOOL && dtype != DT_QINT8 && dtype != DT_QUINT8)
412       continue;
413     VLOG(3) << "Argnode: " << node->DebugString();
414     const Tensor& tensor = ctx->input(arg_index);
415 
416     // Search for the cross-device edge from arg node.
417     const Edge* candidate_edge = FindHostToDeviceEdge(node);
418     if (candidate_edge == nullptr) continue;
419 
420     // Make a proxy and get the sole tpu_input_edge for transferring the
421     // input tensor corresponding to the current _Arg node.
422     const Edge* tpu_input_edge = nullptr;
423     TF_RETURN_IF_ERROR(
424         CreateInputProxy(graph, candidate_edge, &tpu_input_edge));
425     if (tpu_input_edge == nullptr)
426       return errors::NotFound("Couldn't find TPU input edge for ", node->name());
427 
428     // Optimize edge: original source to proxy identity.
429     VLOG(3) << "Input: " << tpu_input_edge->src()->name();
430     std::vector<int>& input_shapes = tpu_input_shapes[tpu_input_edge];
431     input_shapes.clear();
432     for (int d = 0; d < tensor.dims(); ++d) {
433       input_shapes.push_back(tensor.dim_size(d));
434       VLOG(3) << "Input Tensor: Dim[" << d << "] = " << tensor.dim_size(d);
435     }
436     tpu_input_dtypes[tpu_input_edge] = tensor.dtype();
437 
438     // Collect shapes for non-resource-variable args.
439     PartialTensorShape partial_tensor_shape;
440     auto partial_shape = PartialTensorShape::MakePartialShape(
441         input_shapes.data(), input_shapes.size(), &partial_tensor_shape);
442     InferredShape inferred_shape = {partial_tensor_shape};
443     arg_shapes[arg_index] = inferred_shape;
444   }
445   return OkStatus();
446 }
447 
448 // Converts an integer vector that represents the shapes to a TensorShape.
449 Status ConvertEdgeShapesToTensorShapes(
450     const std::map<std::string, std::vector<int>>& named_input_shapes,
451     std::vector<TensorShape>* shapes) {
452   shapes->resize(named_input_shapes.size());
453   int32_t i = 0;
454   // keys in tpu_input_shapes may be stale.
455   for (const auto& iter : named_input_shapes) {
456     VLOG(2) << iter.first << ", rank: " << iter.second.size();
457     const int64_t rank = iter.second.size();
458     std::vector<int64_t> dims(rank);
459     for (int64_t d = 0; d < rank; ++d) {
460       VLOG(2) << " dim[" << d << "]: " << iter.second.at(d);
461       dims[d] = iter.second.at(d);
462     }
463     TF_RETURN_IF_ERROR(TensorShapeUtils::MakeShape(dims, &(*shapes)[i]));
464     i++;
465   }
466   return OkStatus();
467 }
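// For example, {"a": [2, 3], "b": [5]} becomes the shapes [2, 3] and [5], in
// the (sorted) key order of the std::map.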
468 
469 // Get the TF fingerprint with the information from the TPUCompileOp or
470 // _TPUCompileMlirOp.
471 Status MaybeRegisterFingerprint(
472     Graph* graph,
473     const std::map<std::string, std::vector<int>>& named_input_shapes,
474     uint64 input_hash) {
475   // Find the compiler metadata.
476   tpu::TPUCompileMetadataProto metadata_proto;
477   std::map<std::string, std::vector<int>> inputs_to_keep;
478   int num_dynamic_shapes = -1;
479   tensorflow::uint64 fingerprint = 0;
480 
481   for (Node* node : graph->op_nodes()) {
482     if (node->type_string() == "TPUCompile" ||
483         node->type_string() == "_TPUCompileMlir") {
484       num_dynamic_shapes = node->attrs().Find("NumDynamicShapes")->i();
485       if (num_dynamic_shapes <= 0) {
486         break;
487       }
488       int visited = 0;
489       // TPUCompileOp/_TPUCompileMlirOp take Shape nodes as inputs.
490       // The number of Shape nodes matches the number of dynamic shaped inputs.
491       // The Shape nodes come from the input nodes:
492       //   [TPU Input] --> [Input Shape] --> [TPUCompileOp]
493       for (auto in_node : node->in_nodes()) {
494         if (in_node->type_string() != "Shape") {
495           continue;
496         }
497         for (auto input_node : in_node->in_nodes()) {
498           auto iter = named_input_shapes.find(input_node->name());
499           if (iter != named_input_shapes.end()) {
500             inputs_to_keep[iter->first] = iter->second;
501           }
502         }
503         visited++;
504         if (visited == num_dynamic_shapes) {
505           break;
506         }
507       }
508       std::string metadata = node->attrs().Find("metadata")->s();
509       metadata_proto.ParseFromString(metadata);
510 
511       if (node->type_string() == "_TPUCompileMlir") {
512         std::string mlir_module = node->attrs().Find("mlir_module")->s();
513         fingerprint = tensorflow::Fingerprint64(mlir_module);
514       } else {
515         fingerprint = metadata_proto.function_library_fingerprint();
516       }
517 
518       break;
519     }
520   }
521   VLOG(2) << "inputs_to_keep size: " << inputs_to_keep.size();
522   if (inputs_to_keep.size() != num_dynamic_shapes) {
523     VLOG(2) << "Cannot match all inputs shapes. Skip fingerprint registration.";
524     return OkStatus();
525   }
526 
527   std::vector<TensorShape> input_shapes;
528   TF_RETURN_IF_ERROR(
529       ConvertEdgeShapesToTensorShapes(inputs_to_keep, &input_shapes));
530 
531   std::vector<TensorShape> arg_shapes;
532   auto status =
533       tpu::ComputeArgumentShapes(metadata_proto, input_shapes, &arg_shapes);
534   if (!status.ok()) {
535     VLOG(2) << status.error_message();
536     return OkStatus();
537   }
538   uint64 tf_fingerprint =
539       tpu::CreateFingerprintWithNameAndShapes(fingerprint, arg_shapes);
540   VLOG(2) << "fingerprint: " << fingerprint;
541   VLOG(2) << "TF fingerprint: " << tf_fingerprint;
542 
543   ResourceMgr* rm = GetTPUConfigResourceMgr();
544   tpu::TpuFingerprintLookup* fingerprint_lookup;
545   TF_RETURN_IF_ERROR(rm->Lookup<tpu::TpuFingerprintLookup>(
546       rm->default_container(), tpu::kFingerprintLookupResourceName,
547       &fingerprint_lookup));
548   fingerprint_lookup->RegisterKeyAndIntermediatePair(input_hash,
549                                                      tf_fingerprint);
550   return OkStatus();
551 }
552 
553 bool FindTpuReplicatedInputAndXlaSharding(
554     const Graph* graph, XlaShardingInfoMap& xla_sharding_ops,
555     TpuReplicatedInputInfoMap& tpu_replicated_input_ops) {
556   bool xla_spmd_input_sharded = false;
557   // Detect whether there is XLA sharding on the inputs; if there is, we
558   // cannot remove the replicated inputs or the XlaSharding ops.
559   for (Node* xla_sharding_node : graph->nodes()) {
560     if (xla_sharding_node->type_string() == "XlaSharding") {
561       for (const Edge* edge : xla_sharding_node->in_edges()) {
562         if (edge->src()->type_string() == "TPUReplicatedInput") {
563           Node* tpu_replicated_input_node = edge->src();
564           Node* tpu_replicated_metadata_node = nullptr;
565           for (const Edge* input_edge : tpu_replicated_input_node->in_edges()) {
566             if (input_edge->IsControlEdge()) {
567               tpu_replicated_metadata_node = input_edge->src();
568             }
569           }
570 
571           for (const Edge* input_edge : tpu_replicated_input_node->in_edges()) {
572             if (input_edge->src()->type_string() == "_Arg") {
573               Node* arg_node = input_edge->src();
574 
575               xla_sharding_ops[arg_node->name()] = std::make_tuple(
576                   xla_sharding_node->attrs().Find("T")->type(),
577                   xla_sharding_node->attrs().Find("sharding")->s(),
578                   xla_sharding_node->attrs().Find("_tpu_replicate")->s());
579 
580               tpu_replicated_input_ops[arg_node->name()] = std::make_tuple(
581                   tpu_replicated_input_node->attrs().Find("T")->type(),
582                   tpu_replicated_metadata_node);
583 
584               VLOG(2) << "Detected input is sharded. XlaSharding node: "
585                       << xla_sharding_node->DebugString()
586                       << ", TPUReplicatedInput node: "
587                       << edge->src()->DebugString()
588                       << ", _Arg node: " << arg_node->DebugString();
589               xla_spmd_input_sharded = true;
590               break;
591             }
592           }
593         }
594       }
595     }
596   }
597   return xla_spmd_input_sharded;
598 }
599 
600 // Returns the name of the framework that rewrote the graph to support
601 // inference on TPUs. This name is accessed later during metric collection.
602 string GetProducerName(const string& function_name) {
603   if (absl::StrContains(function_name, "tpu_fn_icv2_")) {
604     if (absl::StrContains(function_name, "_tf_quant")) {
605       return "TPU_INFERENCE_CONVERTER_V2_TF_QUANTIZER";
606     }
607     return "TPU_INFERENCE_CONVERTER_V2";
608   }
609   if (absl::StrContains(function_name, "tpu_func_0") ||
610       absl::StrContains(function_name, "_with_batch") ||
611       absl::StrContains(function_name, "_optim")) {
612     if (absl::StrContains(function_name, "_tf_quant")) {
613       return "TPU_INFERENCE_CONVERTER_TF_QUANTIZER";
614     }
615     return "TPU_INFERENCE_CONVERTER";
616   }
617   return "UNKNOWN";
618 }
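// For example (hypothetical function names): "tpu_fn_icv2_model_tf_quant" maps
// to "TPU_INFERENCE_CONVERTER_V2_TF_QUANTIZER", "model_tpu_func_0" maps to
// "TPU_INFERENCE_CONVERTER", and anything else maps to "UNKNOWN".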
619 
620 // Gets the proper tensor dimension from XLA OpSharding.
621 // "replicate_on_last_tile_dim" and "last_tile_dims" should be deducted from the
622 // real Tensor dimensions when tiled.
623 // For example:
624 // f32[8,512](sharding={devices=[1,1,2]0,1 last_tile_dims={REPLICATED})
625 // also means a replicated tensor over all devices.
626 //
627 // See xla_data.proto for detailed explanations on the fields.
628 int GetDimsFromXLAShardingTiled(const xla::OpSharding xla_sharding) {
629   return xla_sharding.tile_assignment_dimensions_size() -
630          (xla_sharding.replicate_on_last_tile_dim() ? 1 : 0) -
631          xla_sharding.last_tile_dims_size();
632 }
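// For the example above: devices=[1,1,2] with last_tile_dims={REPLICATED} has
// tile_assignment_dimensions_size() == 3, replicate_on_last_tile_dim == false
// and last_tile_dims_size() == 1, so the real tensor rank is 3 - 0 - 1 == 2,
// matching f32[8,512].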
633 
634 }  // end namespace
635 
636 namespace tpu_functional_internal {
637 
638 // An optimization pass that separates tensors to leverage the fast path in
639 // TPU input preparation. The algorithm is as follows:
640 // (1) Group all tensors that have the same dimensions except the last one. A
641 // group of tensors will be concatenated along the last dimension in a later pass.
642 // (2) Check all groups of tensors and find groups whose dimensions after concat
643 // cannot leverage the fast path.
644 // (3) For groups of tensors that don't leverage the fast path, split tensors
645 // into two sub-groups such that one sub-group of tensors can leverage the fast
646 // path.
647 // An exception to (2) is when the concatenated tensors are small, in which
648 // case separating them would introduce data-transfer overhead to the device.
649 // This optimization takes effect when both --input_shape_opt and
650 // --group_tensors_for_packing are true.
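// For example (hypothetical shapes): float inputs of shape [8, 128] and [8, 60]
// first land in the same group because they agree on all but the last dim. The
// concatenated last dim is 188, which is neither a multiple of 128 nor small
// (188 * 8 >= 128 * 8), so the group is re-hashed with input_shape_opt = true
// and splits into a "_last_128n" sub-group ([8, 128]) and a "_last_other"
// sub-group ([8, 60]).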
651 GroupedEdges GroupTensorsForInputPacking(
652     const EdgeShapes& tpu_input_shapes,
653     const absl::flat_hash_map<const Edge*, DataType>& tpu_input_dtypes,
654     bool input_shape_opt, bool group_tensors_for_packing) {
655   GroupedEdges grouped_input_edges;
656   for (const auto& iter : tpu_input_shapes) {
657     if (iter.second.empty()) continue;
658     DataType dtype = tpu_input_dtypes.find(iter.first)->second;
659     string hash_key = HashShapeAndType("input_tensors_dtype_", iter.second,
660                                        dtype, /*input_shape_opt*/ false);
661     grouped_input_edges[hash_key].push_back(iter.first);
662   }
663   // Apply grouping when both are true.
664   if (!input_shape_opt || !group_tensors_for_packing)
665     return grouped_input_edges;
666 
667   GroupedEdges grouped_input_edges_opt;
668   for (const auto& iter : grouped_input_edges) {
669     int sum_last_dim = 0;
670     int product_other_dims = 0;
671     VLOG(3) << "group name: " << iter.first;
672     for (const auto& edge : iter.second) {
673       const std::vector<int>& input_shapes =
674           tpu_input_shapes.find(edge)->second;
675       sum_last_dim += input_shapes.back();
676       if (product_other_dims == 0) {
677         product_other_dims = 1;
678         for (int d = 0; d < input_shapes.size() - 1; ++d)
679           product_other_dims *= input_shapes.at(d);
680       }
681     }
682     VLOG(3) << "sum_last_dim: " << sum_last_dim;
683     VLOG(3) << "product_other_dims: " << product_other_dims;
684     // Already uses fast path, skip further grouping.
685     if ((sum_last_dim % kLastDimOfTpuInputFastPath) == 0 &&
686         (product_other_dims % kOtherDimOfTpuInputFastPath) == 0) {
687       grouped_input_edges_opt[iter.first] = iter.second;
688       continue;
689     }
690     // Tensors are small, skip further grouping.
691     if ((sum_last_dim * product_other_dims) <
692         (kLastDimOfTpuInputFastPath * kOtherDimOfTpuInputFastPath)) {
693       grouped_input_edges_opt[iter.first] = iter.second;
694       continue;
695     }
696     VLOG(3) << "Splitting tensors.";
697     for (const auto& edge : iter.second) {
698       auto tpu_input_shape = tpu_input_shapes.find(edge)->second;
699       string hash_key =
700           HashShapeAndType("input_tensors_dtype_", tpu_input_shape,
701                            tpu_input_dtypes.find(edge)->second,
702                            /*input_shape_opt*/ true);
703       grouped_input_edges_opt[hash_key].push_back(edge);
704     }
705   }
706   return grouped_input_edges_opt;
707 }
708 
709 GroupedEdges GroupTensorsForOutputPacking(Graph* graph,
710                                           EdgeShapes& tpu_output_shapes,
711                                           GraphShapeInfo* shape_info) {
712   GroupedEdges shape_to_output;
713   for (const Edge* edge : graph->edges()) {
714     if (edge->IsControlEdge()) continue;
715 
716     // TPU output edge: src is on TPU and dst is on CPU.
717     if (edge->dst()->type_string() != "TPUReplicatedOutput") continue;
718     if (!shape_info->count(edge->src()->name())) continue;
719 
720     // output shapes for hashing
721     std::vector<int>& output_shapes = tpu_output_shapes[edge];
722     output_shapes.clear();
723 
724     int output_id = edge->src_output();
725     auto inferred_shape_vec = shape_info->at(edge->src()->name());
726 
727     for (int d : inferred_shape_vec.at(output_id).shape.dim_sizes()) {
728       output_shapes.push_back(d);
729     }
730 
731     // Hash Shape and Types.
732     DataType dtype = edge->src()->input_type(output_id);
733     string hash_key =
734         HashShapeAndType("output_tensors_dtype_", output_shapes, dtype, false);
735 
736     shape_to_output[hash_key].push_back(edge);
737   }
738   return shape_to_output;
739 }
740 
741 // Concatenates input tensors on CPU along the last dimension if all other
742 // dimensions are the same, and splits them on TPU to reduce input overhead.
743 // `tpu_input_shapes` maps an edge to the shape of its output tensor.
744 // `grouped_input_edges` maps tensor name to all edges output from this tensor.
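// Sketch of the rewrite for one group (hypothetical nodes A and B feeding TPU
// consumers tpu_op_a and tpu_op_b):
//   before:  A:0 -> tpu_op_a            B:0 -> tpu_op_b
//   after:   A:0, B:0 -> ConcatV2 (CPU, last dim) -> SplitV (TPU, last dim)
//            SplitV:0 -> tpu_op_a       SplitV:1 -> tpu_op_b
// When the inputs are SPMD-sharded, TPUReplicatedInput and XlaSharding nodes are
// re-inserted between the ConcatV2 and the SplitV (see below).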
745 Status CreateConcatAndSplitNodesForInputTensor(
746     Graph* graph, const string& cluster_name, EdgeShapes* tpu_input_shapes,
747     const absl::flat_hash_map<std::string, std::vector<const Edge*>>&
748         grouped_input_edges,
749     int32_t minimum_input_tensors_packing, bool xla_spmd_input_sharded,
750     const XlaShardingInfoMap& xla_sharding_info,
751     const TpuReplicatedInputInfoMap& tpu_replicated_input_info) {
752   for (const auto& iter : grouped_input_edges) {
753     std::vector<int> last_dim_vec;
754     std::vector<NodeBuilder::NodeOut> concat_nodeouts;
755     absl::flat_hash_map<std::string, int> tensor_to_split_output;
756     int rank;
757     DataType dtype = DT_INVALID;
758     std::string src_name;
759     for (const Edge* edge : iter.second) {
760       src_name = edge->src()->name();
761       string tensor_name =
762           absl::StrCat(edge->src()->name(), ":", edge->src_output());
763       // Create a Concat / Split pair for a tensor if one does not exist yet.
764       if (tensor_to_split_output.contains(tensor_name)) continue;
765       tensor_to_split_output[tensor_name] = concat_nodeouts.size();
766       concat_nodeouts.push_back(
767           NodeBuilder::NodeOut(edge->src(), edge->src_output()));
768       dtype = edge->src()->output_type(edge->src_output());
769       rank = tpu_input_shapes->at(edge).size();
770       last_dim_vec.push_back(tpu_input_shapes->at(edge).back());
771     }
772 
773     const int num_tensors = tensor_to_split_output.size();
774     VLOG(3) << iter.first << " num_tensors: " << num_tensors;
775     if (num_tensors < minimum_input_tensors_packing) {
776       VLOG(3) << "skip concat/split " << iter.first;
777       continue;
778     }
779 
780     Node* concat_axis_node = nullptr;
781     TensorShape t_shape;
782     Tensor dim_tensor(DT_INT32, t_shape);
783     // Concat and Split at the last dim.
784     dim_tensor.flat<int>()(0) = rank - 1;
785     TF_RETURN_IF_ERROR(
786         NodeBuilder(strings::StrCat(iter.first, "/concat/axis"), "Const")
787             .Attr("dtype", DT_INT32)
788             .Attr("value", dim_tensor)
789             .Finalize(graph, &concat_axis_node));
790 
791     Node* concat_node = nullptr;
792     TF_RETURN_IF_ERROR(
793         NodeBuilder(strings::StrCat(iter.first, "/concat"), "ConcatV2")
794             .Input(concat_nodeouts)
795             .Input(concat_axis_node)
796             .Attr("T", dtype)
797             .Attr("Tidx", DT_INT32)
798             .Attr("N", num_tensors)
799             .Finalize(graph, &concat_node));
800 
801     Node* split_dim_node = nullptr;
802     TF_RETURN_IF_ERROR(
803         NodeBuilder(strings::StrCat(iter.first, "/split/split_dim"), "Const")
804             .Attr("dtype", DT_INT32)
805             .Attr("value", dim_tensor)
806             .Attr(kTpuReplicateAttr, cluster_name)
807             .Finalize(graph, &split_dim_node));
808 
809     Node* split_vec_node = nullptr;
810     TensorShape split_vec_shape;
811     split_vec_shape.AddDim(1);
812     split_vec_shape.set_dim(0, last_dim_vec.size());
813 
814     Tensor split_vec_tensor(DT_INT32, split_vec_shape);
815     for (int i = 0; i < last_dim_vec.size(); ++i) {
816       split_vec_tensor.flat<int>()(i) = last_dim_vec[i];
817     }
818     VLOG(3) << "split_vec_tensor: " << split_vec_tensor.DebugString();
819 
820     TF_RETURN_IF_ERROR(
821         NodeBuilder(strings::StrCat(iter.first, "/split/vec"), "Const")
822             .Attr("dtype", DT_INT32)
823             .Attr("value", split_vec_tensor)
824             .Attr(kTpuReplicateAttr, cluster_name)
825             .Finalize(graph, &split_vec_node));
826 
827     Node* split_node = nullptr;
828     Node* input_to_split_node = concat_node;
829     Node* output_from_concat_node = nullptr;
830     if (xla_spmd_input_sharded &&
831         tpu_replicated_input_info.count(src_name) > 0 &&
832         xla_sharding_info.count(src_name) > 0) {
833       // Create new TPUReplicatedInput and XLAShardingOp nodes
834       //
835       // Rewrite the graph from:
836       //   Concat -> Split
837       // to
838       //   Concat -> TPUReplicatedInput -> XlaSharding -> Split
839       Node* tpu_replicated_input = nullptr;
840       Node* xla_sharding_op = nullptr;
841 
842       std::vector<NodeBuilder::NodeOut> replicated_input;
843       replicated_input.push_back(NodeBuilder::NodeOut(concat_node));
844 
845       // TODO(b/183060455): Add TPUReplicatedInput to all graphs.
846       TF_RETURN_IF_ERROR(
847           NodeBuilder(strings::StrCat(iter.first, "/TPUReplicatedInput"),
848                       "TPUReplicatedInput")
849               .Input(replicated_input)
850               .ControlInput(std::get<1>(tpu_replicated_input_info.at(src_name)))
851               .Attr("N", 1)
852               .Attr("T", std::get<0>(tpu_replicated_input_info.at(src_name)))
853               .Attr("index", -1)
854               .Attr("is_mirrored_variable", false)
855               .Attr("is_packed", false)
856               .Finalize(graph, &tpu_replicated_input));
857       VLOG(2) << "Created new TPUReplicatedInput node "
858               << tpu_replicated_input->DebugString();
859 
860       TF_RETURN_IF_ERROR(
861           NodeBuilder(strings::StrCat(iter.first, "/XlaSharding"),
862                       "XlaSharding")
863               .Input(tpu_replicated_input)
864               .Attr("T", std::get<0>(xla_sharding_info.at(src_name)))
865               .Attr("sharding", std::get<1>(xla_sharding_info.at(src_name)))
866               .Attr("_XlaSharding", std::get<1>(xla_sharding_info.at(src_name)))
867               .Attr("_tpu_replicate",
868                     std::get<2>(xla_sharding_info.at(src_name)))
869               .Finalize(graph, &xla_sharding_op));
870       VLOG(2) << "Created new XLA sharding node "
871               << xla_sharding_op->DebugString();
872 
873       input_to_split_node = xla_sharding_op;
874       output_from_concat_node = tpu_replicated_input;
875     }
876     // Update the `tpu_input_shapes` mapping: Add the new edge
877     // from concat to split.
878     TF_RETURN_IF_ERROR(
879         NodeBuilder(strings::StrCat(iter.first, "/split"), "SplitV")
880             .Input(input_to_split_node)
881             .Input(split_vec_node)
882             .Input(split_dim_node)
883             .Attr("T", dtype)
884             .Attr("num_split", num_tensors)
885             .Attr(kTpuReplicateAttr, cluster_name)
886             .Finalize(graph, &split_node));
887 
888     if (output_from_concat_node == nullptr)
889       output_from_concat_node = split_node;
890 
891     const Edge* concat_to_split;
892     for (const Edge* edge : concat_node->out_edges())
893       if (edge->dst() == output_from_concat_node) {
894         concat_to_split = edge;
895         break;
896       }
897     if (rank > 1) {
898       for (int d = 0; d < rank - 1; ++d)
899         (*tpu_input_shapes)[concat_to_split].push_back(
900             tpu_input_shapes->at(iter.second.back()).at(d));
901     }
902     (*tpu_input_shapes)[concat_to_split].push_back(
903         std::accumulate(last_dim_vec.begin(), last_dim_vec.end(), 0));
904 
905     // Connect split node to original tensor output.
906     for (const Edge* edge : iter.second) {
907       string tensor_name =
908           absl::StrCat(edge->src()->name(), ":", edge->src_output());
909       int output_index = tensor_to_split_output.at(tensor_name);
910       graph->RemoveEdge(edge);
911       graph->AddEdge(split_node, output_index, edge->dst(), edge->dst_input());
912       // Update the `tpu_input_shapes` mapping: Remove old edges.
913       tpu_input_shapes->erase(edge);
914     }
915     VLOG(3) << "Concat node: " << concat_node->DebugString();
916   }
917   return OkStatus();
918 }
919 
920 // Concatenates output tensors on TPU along the last dimension if all other
921 // dimensions are the same, and splits them on CPU to reduce outfeed overhead.
922 // `tpu_inferred_info` maps an edge to the inferred shape of its output tensor.
923 // `shape_to_output` maps tensor name to all edges output from this tensor.
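// Sketch of the rewrite for one group (hypothetical TPU outputs X and Y):
//   before:  X:0 -> TPUReplicatedOutput -> cpu_op_x   (likewise for Y)
//   after:   X:0, Y:0 -> ConcatV2 (TPU, last dim) -> TPUReplicatedOutput
//            -> SplitV (CPU, last dim) -> the original CPU consumers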
924 Status CreateConcatAndSplitNodesForOutputTensor(
925     Graph* graph, const string& cluster_name, EdgeShapes* tpu_output_shapes,
926     GraphShapeInfo* tpu_inferred_info, GroupedEdges shape_to_output,
927     int32_t minimum_output_tensors_packing) {
928   for (const auto& iter : shape_to_output) {
929     std::vector<int> last_dim_vec;
930     std::vector<NodeBuilder::NodeOut> concat_nodeouts;
931     absl::flat_hash_map<std::string, int> tensor_to_split_output;
932     int rank;
933     DataType dtype = DT_INVALID;
934     for (const Edge* edge : iter.second) {
935       string tensor_name =
936           absl::StrCat(edge->src()->name(), ":", edge->src_output());
937 
938       // Create a Concat / Split pair for a tensor if one does not exist yet.
939       if (tensor_to_split_output.contains(tensor_name)) continue;
940       tensor_to_split_output[tensor_name] = concat_nodeouts.size();
941 
942       concat_nodeouts.push_back(
943           NodeBuilder::NodeOut(edge->src(), edge->src_output()));
944       dtype = edge->src()->output_type(edge->src_output());
945       rank = tpu_output_shapes->at(edge).size();
946       last_dim_vec.push_back(tpu_output_shapes->at(edge).back());
947     }
948 
949     const int num_tensors = tensor_to_split_output.size();
950     if (num_tensors < minimum_output_tensors_packing) {
951       VLOG(3) << "skip concat/split " << iter.first;
952       continue;
953     }
954 
955     Node* concat_axis_node = nullptr;
956     TensorShape t_shape;
957     Tensor dim_tensor(DT_INT32, t_shape);
958     // Concat and Split at the last dim.
959     dim_tensor.flat<int>()(0) = rank - 1;
960     TF_RETURN_IF_ERROR(
961         NodeBuilder(strings::StrCat(iter.first, "/concat/axis"), "Const")
962             .Attr("dtype", DT_INT32)
963             .Attr("value", dim_tensor)
964             .Attr(kTpuReplicateAttr, cluster_name)
965             .Finalize(graph, &concat_axis_node));
966 
967     Node* concat_node = nullptr;
968     TF_RETURN_IF_ERROR(
969         NodeBuilder(strings::StrCat(iter.first, "/concat"), "ConcatV2")
970             .Input(concat_nodeouts)
971             .Input(concat_axis_node)
972             .Attr("T", dtype)
973             .Attr("Tidx", DT_INT32)
974             .Attr("N", num_tensors)
975             .Attr(kTpuReplicateAttr, cluster_name)
976             .Finalize(graph, &concat_node));
977 
978     Node* tpu_replicated_output_node = nullptr;
979     TF_RETURN_IF_ERROR(
980         NodeBuilder(strings::StrCat(iter.first, "/tpu_replicated_output"),
981                     "TPUReplicatedOutput")
982             .Input(concat_node)
983             .Attr("T", dtype)
984             .Attr("num_replicas", 1)
985             .Finalize(graph, &tpu_replicated_output_node));
986 
987     Node* split_dim_node = nullptr;
988     TF_RETURN_IF_ERROR(
989         NodeBuilder(strings::StrCat(iter.first, "/split/split_dim"), "Const")
990             .Attr("dtype", DT_INT32)
991             .Attr("value", dim_tensor)
992             .Finalize(graph, &split_dim_node));
993 
994     Node* split_vec_node = nullptr;
995     TensorShape split_vec_shape;
996     split_vec_shape.AddDim(1);
997     split_vec_shape.set_dim(0, last_dim_vec.size());
998 
999     Tensor split_vec_tensor(DT_INT32, split_vec_shape);
1000     for (int i = 0; i < last_dim_vec.size(); ++i) {
1001       split_vec_tensor.flat<int>()(i) = last_dim_vec[i];
1002     }
1003     VLOG(3) << "split_vec_tensor: " << split_vec_tensor.DebugString();
1004 
1005     TF_RETURN_IF_ERROR(
1006         NodeBuilder(strings::StrCat(iter.first, "/split/vec"), "Const")
1007             .Attr("dtype", DT_INT32)
1008             .Attr("value", split_vec_tensor)
1009             .Finalize(graph, &split_vec_node));
1010 
1011     Node* split_node = nullptr;
1012     TF_RETURN_IF_ERROR(
1013         NodeBuilder(strings::StrCat(iter.first, "/split"), "SplitV")
1014             .Input(tpu_replicated_output_node)
1015             .Input(split_vec_node)
1016             .Input(split_dim_node)
1017             .Attr("T", dtype)
1018             .Attr("num_split", num_tensors)
1019             .Finalize(graph, &split_node));
1020 
1021     // Update the `tpu_output_shapes` mapping: Add the new edge
1022     // from concat to split.
1023     const Edge* concat_to_split;
1024     for (const Edge* edge : concat_node->out_edges())
1025       if (edge->dst() == split_node) {
1026         concat_to_split = edge;
1027         break;
1028       }
1029 
1030     if (rank > 1) (*tpu_output_shapes)[concat_to_split].push_back(-1);
1031     for (int d = 1; d < rank - 1; ++d)
1032       (*tpu_output_shapes)[concat_to_split].push_back(
1033           tpu_output_shapes->at(iter.second.back()).at(d));
1034     (*tpu_output_shapes)[concat_to_split].push_back(
1035         std::accumulate(last_dim_vec.begin(), last_dim_vec.end(), 0));
1036 
1037     for (const Edge* edge : iter.second) {
1038       // 1. Find the old TPUReplicatedOutput output edges.
1039       std::vector<const Edge*> output_edge_vec;
1040       for (const Edge* output_edge : edge->dst()->out_edges())
1041         output_edge_vec.push_back(output_edge);
1042 
1043       string tensor_name =
1044           absl::StrCat(edge->src()->name(), ":", edge->src_output());
1045       int output_index = tensor_to_split_output.at(tensor_name);
1046       VLOG(3) << "output_index: " << output_index;
1047 
1048       // Connect split node to original tensor output.
1049       for (const Edge* output_edge : output_edge_vec) {
1050         VLOG(3) << "output_edge" << output_edge->DebugString();
1051         graph->RemoveEdge(output_edge);
1052         graph->AddEdge(split_node, output_index, output_edge->dst(),
1053                        output_edge->dst_input());
1054         // Update the `tpu_output_shapes` mapping: Remove old edges.
1055         tpu_output_shapes->erase(output_edge);
1056       }
1057       graph->RemoveNode(edge->dst());
1058     }
1059     VLOG(3) << "Concat node: " << concat_node->DebugString();
1060   }
1061   return OkStatus();
1062 }
1063 
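// Inserts a Reshape pair around every recorded TPU input edge: a CPU-side
// Reshape flattens the tensor to rank 1 and a TPU-side Reshape restores the
// original shape. If the input passes through TPUReplicatedInput + XlaSharding,
// the restoring Reshape is placed after the XlaSharding node and the sharding
// attribute is rewritten for the flattened rank-1 tensor.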
1064 Status InsertReshapeNodePairs(Graph* graph, const string& cluster_name,
1065                               EdgeShapes* tpu_input_shapes,
1066                               int num_cores_per_replica) {
1067   std::vector<const Edge*> tpu_input_edges_original;
1068   for (const auto& it : *tpu_input_shapes)
1069     if (!it.second.empty()) tpu_input_edges_original.push_back(it.first);
1070   for (const Edge* edge : tpu_input_edges_original) {
1071     VLOG(3) << "Reshape input: " << edge->DebugString();
1072 
1073     // Check if there is a TPUReplicatedInput and XlaSharding in the middle
1074     bool xla_sharded_input = false;
1075     Node* xla_sharding_node = nullptr;
1076     if (edge->dst()->type_string() == "TPUReplicatedInput" &&
1077         edge->dst()->out_nodes().begin()->type_string() == "XlaSharding") {
1078       VLOG(3) << "Detected TPUReplicatedInput " << edge->dst()->DebugString()
1079               << " and XlaSharding "
1080               << edge->dst()->out_nodes().begin()->DebugString()
1081               << ", setting xla_sharded_input = true";
1082       xla_sharded_input = true;
1083       xla_sharding_node = *(edge->dst()->out_nodes().begin());
1084     }
1085 
1086     // 1. Build Reshape node for flatten.
1087 
1088     // 1.1 Build Const node for shape
1089     Node* flatten_reshape_shape_node = nullptr;
1090     Tensor flattened_input_shape_tensor;
1091     flattened_input_shape_tensor =
1092         Tensor(DT_INT32, TensorShape({static_cast<int64_t>(1)}));
1093     flattened_input_shape_tensor.flat<int>()(0) = -1;
1094     TF_RETURN_IF_ERROR(
1095         NodeBuilder(absl::StrCat(edge->src()->name(), "/flatten/Reshape/shape"),
1096                     "Const")
1097             .Attr("dtype", DT_INT32)
1098             .Attr("value", flattened_input_shape_tensor)
1099             .Finalize(graph, &flatten_reshape_shape_node));
1100 
1101     // 1.2 Build Reshape node for flatten.
1102     Node* flatten_reshape_node = nullptr;
1103     TF_RETURN_IF_ERROR(
1104         NodeBuilder(absl::StrCat(edge->src()->name(), "/flatten/Reshape"),
1105                     "Reshape")
1106             .Input(edge->src(), edge->src_output())
1107             .Input(flatten_reshape_shape_node)
1108             .Attr("T", edge->src()->output_type(edge->src_output()))
1109             .Attr("Tshape", DT_INT32)
1110             .Finalize(graph, &flatten_reshape_node));
1111 
1112     // 2. Build Reshape node for recover.
1113 
1114     // 2.1 Build Const node for shape.
1115     Node* recover_reshape_shape_node = nullptr;
1116     Tensor original_input_shape_tensor(
1117         DT_INT32,
1118         TensorShape({static_cast<int64_t>(tpu_input_shapes->at(edge).size())}));
1119     original_input_shape_tensor.flat<int>()(0) = -1;
1120     for (int d = 1; d < tpu_input_shapes->at(edge).size(); ++d)
1121       original_input_shape_tensor.flat<int>()(d) =
1122           tpu_input_shapes->at(edge).at(d);
1123     TF_RETURN_IF_ERROR(
1124         NodeBuilder(absl::StrCat(edge->src()->name(), "/recover/Reshape/shape"),
1125                     "Const")
1126             .Attr("dtype", DT_INT32)
1127             .Attr("value", original_input_shape_tensor)
1128             .Attr(kTpuReplicateAttr, cluster_name)  // This node is on TPU.
1129             .Finalize(graph, &recover_reshape_shape_node));
1130 
1131     // 2.2 Build Reshape node for recover.
1132     Node* recover_reshape_input_node = flatten_reshape_node;
1133     const Edge* original_recover_reshape_input_edge = nullptr;
1134     if (xla_sharded_input) {
1135       // We want to find the node after the XlaSharding node
1136       original_recover_reshape_input_edge =
1137           *(edge->dst()->out_nodes().begin()->out_edges().begin());
1138       recover_reshape_input_node = *(edge->dst()->out_nodes().begin());
1139       VLOG(3) << "Recover reshape input node: "
1140               << recover_reshape_input_node->DebugString()
1141               << ", recover reshape input edge: "
1142               << original_recover_reshape_input_edge->DebugString();
1143     }
1144 
1145     Node* recover_reshape_node = nullptr;
1146     TF_RETURN_IF_ERROR(
1147         NodeBuilder(absl::StrCat(edge->src()->name(), "/recover/Reshape"),
1148                     "Reshape")
1149             .Input(recover_reshape_input_node)
1150             .Input(recover_reshape_shape_node)
1151             .Attr("T", edge->src()->output_type(edge->src_output()))
1152             .Attr("Tshape", DT_INT32)
1153             .Attr(kTpuReplicateAttr, cluster_name)  // This node is on TPU.
1154             .Finalize(graph, &recover_reshape_node));
1155 
1156     // 3. Rewrite XlaSharding attribute if necessary
1157     if (xla_sharding_node != nullptr) {
1158       // The flattened tensor always has rank = 1 and we want to shard the only
1159       // dimension (0).
1160       SetXlaShardingNodeAttr(xla_sharding_node, num_cores_per_replica, 1, 0);
1161     }
1162 
1163     // 4. Connect / disconnect nodes.
1164     if (xla_sharded_input) {
1165       graph->AddEdge(flatten_reshape_node, 0, edge->dst(), edge->dst_input());
1166     }
1167 
1168     if (original_recover_reshape_input_edge != nullptr) {
1169       graph->AddEdge(recover_reshape_node, 0,
1170                      original_recover_reshape_input_edge->dst(),
1171                      original_recover_reshape_input_edge->dst_input());
1172     } else {
1173       graph->AddEdge(recover_reshape_node, 0, edge->dst(), edge->dst_input());
1174     }
1175 
1176     graph->RemoveEdge(edge);
1177     if (original_recover_reshape_input_edge != nullptr) {
1178       graph->RemoveEdge(original_recover_reshape_input_edge);
1179     }
1180 
1181     // 5. Update EdgeShapes.
1182     int dimension = 1;
1183     for (auto& it : (*tpu_input_shapes)[edge]) {
1184       dimension *= it;
1185     }
1186     VLOG(3) << "Dimension after reshape: " << dimension;
1187     for (const Edge* out_edge : flatten_reshape_node->out_edges()) {
1188       if (out_edge->dst() == recover_reshape_node) {
1189         (*tpu_input_shapes)[out_edge].push_back(dimension);
1190         tpu_input_shapes->erase(edge);
1191         break;
1192       }
1193     }
1194     VLOG(3) << "Reshape optimization done for " << edge->src()->name();
1195   }
1196   return OkStatus();
1197 }
1198 }  // namespace tpu_functional_internal
1199 
1200 void TPUPartitionedCallOp::ComputeAsync(OpKernelContext* ctx,
1201                                         DoneCallback done) {
1202   Status init_status;
1203   absl::call_once(once_, [&]() {
1204     library_runtime_ = ctx->function_library();
1205     if (library_runtime_ == nullptr) {
1206       init_status = errors::Internal("No function library is provided.");
1207       return;
1208     }
1209     flib_def_ = std::make_unique<FunctionLibraryDefinition>(
1210         *library_runtime_->GetFunctionLibraryDefinition());
1211     device_mgr_ = library_runtime_->device_mgr();
1212     for (auto d : device_mgr_->ListDevices()) {
1213       device_set_.AddDevice(d);
1214     }
1215 
1216     DeviceNameUtils::ParsedName tpu_device_name;
1217     tpu_device_name.has_type = true;
1218     tpu_device_name.type = "TPU";
1219     std::vector<Device*> tpu_devices;
1220     device_set_.FindMatchingDevices(tpu_device_name, &tpu_devices_);
1221   });
1222   OP_REQUIRES_OK_ASYNC(ctx, init_status, done);
1223 
1224   // Initialize the ordinal selector with information from the graph if it is
1225   // the first time we are running this op.
1226   absl::call_once(ordinal_selector_once_, [&]() {
1227     std::unique_ptr<Graph> graph(new Graph(flib_def_.get()));
1228     bool enable_spmd_xla_partitioning = false;
1229     TPUMetadata tpu_metadata;
1230     {
1231       absl::MutexLock l(&mu_);
1232       OP_REQUIRES_OK_ASYNC(
1233           ctx,
1234           GetGraphFromFunction(graph.get(), /*device_ordinal=*/0,
1235                                &enable_spmd_xla_partitioning, &tpu_metadata),
1236           done);
1237     }
1238     if (enable_spmd_xla_partitioning) {
1239       ordinal_selector_ = std::make_shared<tpu::TPUOrdinalSelector>(
1240           tpu_metadata.num_cores_per_replica);
1241     } else {
1242       ordinal_selector_ = std::make_shared<tpu::TPUOrdinalSelector>();
1243     }
1244 
1245     metrics::RecordTPUXlaSpmdCoresPerReplica(
1246         tpu_metadata.num_cores_per_replica);
1247   });
1248   OP_REQUIRES_ASYNC(
1249       ctx, ordinal_selector_ != nullptr,
1250       errors::Internal("The TPUOrdinalSelector is not initialized."), done);
1251 
1252   uint64 input_hash = GetInputHash(ctx);
1253   int64_t ordinal_selector_req_id = -1;
1254   // Select a TPU core.
1255   int32_t device_ordinal = 0;
1256   OP_REQUIRES_OK_ASYNC(
1257       ctx,
1258       GetTpuCoreOrdinal(ctx, input_hash, &ordinal_selector_req_id,
1259                         &device_ordinal),
1260       done);
1261   uint64 cache_hash = Hash64Combine(input_hash, device_ordinal);
1262   absl::ReleasableMutexLock lock(&mu_);
1263 
1264   const std::vector<DeviceAndFHandle>* functions;
1265 
1266   bool cache_miss = !partition_cache_.count(cache_hash);
1267   if (cache_miss) {
1268     VLOG(3) << "Cache Miss: partitioning function " << func_.name()
1269             << " cache_hash: " << cache_hash
1270             << " device_ordinal: " << device_ordinal;
1271 
1272     profiler::TraceMe trace_me(
1273         "TPUPartitionedCallOp-RewriteAndInstantiateFunctions");
1274     std::unique_ptr<Graph> graph(new Graph(flib_def_.get()));
1275     bool enable_spmd_xla_partitioning = false;
1276     TPUMetadata tpu_metadata;
1277     OP_REQUIRES_OK_ASYNC(
1278         ctx,
1279         GetGraphFromFunction(graph.get(), device_ordinal,
1280                              &enable_spmd_xla_partitioning, &tpu_metadata),
1281         done);
1282 
1283     VLOG(1) << DumpGraphToFile("before_input_output_optimizations", *graph,
1284                                flib_def_.get());
1285 
1286     std::map<std::string, std::vector<int>> named_input_shapes;
1287     OP_REQUIRES_OK_ASYNC(
1288         ctx,
1289         OptimizeTpuInputOutputTensors(graph.get(), enable_spmd_xla_partitioning,
1290                                       tpu_metadata.num_cores_per_replica,
1291                                       named_input_shapes, ctx),
1292         done);
1293 
1294     VLOG(1) << DumpGraphToFile(
1295         "before_replace_resource_args_with_var_handle_ops", *graph,
1296         flib_def_.get());
1297     OP_REQUIRES_OK_ASYNC(ctx,
1298                          ReplaceResourceArgsWithVarHandleOps(
1299                              graph.get(), ctx, device_ordinal,
1300                              enable_spmd_xla_partitioning, tpu_metadata),
1301                          done);
1302 
1303     VLOG(1) << DumpGraphToFile(
1304         "after_replace_resource_args_with_var_handle_ops", *graph,
1305         flib_def_.get());
1306 
1307     // Graph rewrite passes.
1308     GraphOptimizationPassOptions optimization_options;
1309     // TODO(akshayka): Thread the SessionOptions into this kernel, or make
1310     // it possible to specify the relevant options via attributes.
1311     SessionOptions session_options;
1312     session_options.config.mutable_experimental()
1313         ->set_xla_fusion_autotuner_thresh(autotuner_thresh_);
1314 
1315     session_options.env = ctx->env();
1316     optimization_options.session_handle = ctx->session_handle();
1317     optimization_options.session_options = &session_options;
1318     optimization_options.graph = &graph;
1319     optimization_options.flib_def = flib_def_.get();
1320     optimization_options.device_set = &device_set_;
1321     OP_REQUIRES_OK_ASYNC(
1322         ctx, PlacementHelper(device_set_, optimization_options, func_.name()),
1323         done);
1324 
1325     if (!enable_spmd_xla_partitioning ||
1326         tpu_metadata.num_cores_per_replica == 1) {
1327       OP_REQUIRES_OK_ASYNC(
1328           ctx,
1329           MaybeRegisterFingerprint(graph.get(), named_input_shapes, input_hash),
1330           done);
1331     }
1332     // `subgraphs` maps from device names to their partitioned subgraphs.
1333     std::unordered_map<std::string, std::unique_ptr<Graph>> subgraphs;
1334     optimization_options.graph = nullptr;
1335     optimization_options.device_set = nullptr;
1336     optimization_options.partition_graphs = &subgraphs;
1337     VLOG(1) << DumpGraphToFile("before_partition_helper.pbtxt", *graph,
1338                                flib_def_.get());
1339     OP_REQUIRES_OK_ASYNC(ctx,
1340                          PartitionHelper(device_set_, optimization_options,
1341                                          graph.get(), &subgraphs),
1342                          done);
1343     OP_REQUIRES_OK_ASYNC(
1344         ctx,
1345         InstantiateFunctionsFromSubgraphs(
1346             device_set_, device_ordinal, cache_hash,
1347             tpu_metadata.num_cores_per_replica, std::move(subgraphs)),
1348         done);
1349   }
1350   functions = &partition_cache_[cache_hash];
1351   lock.Release();
1352 
1353   ExecuteFunctions(*functions, ctx, device_ordinal, ordinal_selector_req_id,
1354                    std::move(done));
1355 }
1356 
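// [Editor's illustrative sketch; not part of the original file.] ComputeAsync
// keys its partition cache on the combination of the input signature hash and
// the selected device ordinal, so the same function dispatched to a different
// core gets its own cache entry. The helper name below is hypothetical.
inline uint64 ExamplePartitionCacheKey(uint64 input_hash,
                                       int32_t device_ordinal) {
  return Hash64Combine(input_hash, device_ordinal);
}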
1357 Status TPUPartitionedCallOp::GetTpuCoreOrdinal(OpKernelContext* ctx,
1358                                                uint64 input_hash,
1359                                                int64_t* ordinal_selector_req_id,
1360                                                int32_t* core_ordinal) {
1361   profiler::TraceMe trace_me("TPUPartitionedCallOp-GetTpuCoreOrdinal");
1362   const Tensor* device_ordinal_t;
1363   TF_RETURN_IF_ERROR(ctx->input(kDeviceOrdinalAttr, &device_ordinal_t));
1364   int device_ordinal = device_ordinal_t->scalar<int>()();
1365   if (device_ordinal == tpu::kDeferredCoreSelectionReserved) {
1366     device_ordinal =
1367         ordinal_selector_->GetOrdinal(input_hash, ordinal_selector_req_id);
1368   }
1369   *core_ordinal = device_ordinal;
1370   return OkStatus();
1371 }
1372 
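// [Editor's illustrative sketch; not part of the original file.] It restates
// the core-selection rule in GetTpuCoreOrdinal above: a concrete ordinal from
// the `device_ordinal` input is used as-is, while the reserved sentinel defers
// the choice to the round-robin TPUOrdinalSelector. The helper name is
// hypothetical.
inline int32_t ExampleResolveCoreOrdinal(int requested_ordinal,
                                         uint64 input_hash,
                                         tpu::TPUOrdinalSelector* selector,
                                         int64_t* ordinal_selector_req_id) {
  if (requested_ordinal == tpu::kDeferredCoreSelectionReserved) {
    return selector->GetOrdinal(input_hash, ordinal_selector_req_id);
  }
  return requested_ordinal;
}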
1373 Status TPUPartitionedCallOp::InitializeVarOnTPU(
1374     OpKernelContext* ctx, const core::RefCountPtr<Var>& var, NodeDef* ndef,
1375     int device_ordinal, bool fast_mem) {
1376   const string device = strings::StrCat(kTPUDeviceNamePrefix, device_ordinal);
1377   Status status;
1378   std::unique_ptr<Graph> init_graph(new Graph(OpRegistry::Global()));
1379   TF_ASSIGN_OR_RETURN(Node * init_handle, init_graph->AddNode(*ndef));
1380   init_handle->set_assigned_device_name(device);
1381 
1382   NodeDef init_const_ndef;
1383   init_const_ndef.set_name("initial_value");
1384 #if defined(LIBTPU_ON_GCE)  // TODO(b/217559071) - Remove once _TPUConst is OSS
1385   init_const_ndef.set_op("Const");
1386 #else
1387   init_const_ndef.set_op("_TPUConst");
1388   AddNodeAttr("memory_space", "HBM", &init_const_ndef);
1389 #endif
1390   init_const_ndef.set_device(device);
1391   AddNodeAttr("dtype", var->tensor()->dtype(), &init_const_ndef);
1392   AddNodeAttr("value", *var->tensor(), &init_const_ndef);
1393 
1394   TF_ASSIGN_OR_RETURN(Node * init_const, init_graph->AddNode(init_const_ndef));
1395 
1396   NodeDef assign_node_def;
1397   assign_node_def.set_name("Assign");
1398   assign_node_def.set_op("AssignVariableOp");
1399   assign_node_def.set_device(device);
1400   AddNodeAttr("dtype", var->tensor()->dtype(), &assign_node_def);
1401   TF_ASSIGN_OR_RETURN(Node * init_assign, init_graph->AddNode(assign_node_def));
1402 
1403   init_graph->AddEdge(init_handle, 0, init_assign, 0);
1404   init_graph->AddEdge(init_const, 0, init_assign, 1);
1405   FHandle fhandle;
1406   const string fname =
1407       strings::StrCat(ndef->name(), "_init_ord_", device_ordinal);
1408 
1409   TF_RETURN_IF_ERROR(
1410       InstantiatePartition(*init_graph, fname, device, &fhandle, nullptr));
1411 
1412   FunctionLibraryRuntime::Options opts;
1413   opts.step_container = ctx->step_container();
1414   opts.cancellation_manager = ctx->cancellation_manager();
1415   opts.stats_collector = ctx->stats_collector();
1416 
1417   // Blocking on threads in the same thread pool is disallowed because
1418   // concurrent warm-up requests can exhaust the default thread pool.
1419   // Create a new thread pool to initialize variables on TPU.
1420   std::function<void(std::function<void()>)> runner =
1421       [this](std::function<void()> fn) { pool_.Schedule(fn); };
1422   opts.runner = &runner;
1423 
1424   opts.source_device = local_device_name_;
1425   PrivateIntraProcessRendezvous rendez(device_mgr_);
1426   opts.rendezvous = &rendez;
1427   opts.remote_execution = true;
1428 
1429   std::vector<Tensor> dummy_args;
1430   std::vector<Tensor>* dummy_rets = new std::vector<Tensor>;
1431   Notification done;
1432   profiler::TraceMe trace_me("TPUPartitionedCallOp-InitializeVarOnTPU");
1433   library_runtime_->Run(opts, fhandle, dummy_args, dummy_rets,
1434                         [dummy_rets, &done, ctx](const Status& status) {
1435                           if (!status.ok()) {
1436                             ctx->SetStatus(status);
1437                           }
1438                           delete dummy_rets;
1439                           done.Notify();
1440                         });
1441   done.WaitForNotification();
1442   // We don't actually want the variable initialization functions
1443   // in the function library definition and the function library
1444   // runtime, because flib_def_ is used for the graph rewrite passes.
1445   // The TPU distributed rewrite pass computes a fingerprint for
1446     // flib_def_, which will throw a length error if there are
1447   // many variables whose initialization functions are added
1448   // to the library definition.
1449   TF_RETURN_IF_ERROR(flib_def_->RemoveFunction(fname));
1450   TF_RETURN_IF_ERROR(library_runtime_->ReleaseHandle(fhandle));
1451   return OkStatus();
1452 }
1453 
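// [Editor's illustrative sketch; not part of the original file.] The
// initialization graph built in InitializeVarOnTPU has exactly three nodes: the
// on-TPU VarHandleOp, a constant holding the CPU value, and an AssignVariableOp
// taking the handle as input 0 and the value as input 1. This hypothetical
// helper shows that wiring in isolation.
inline Status ExampleWireVarInit(Graph* graph, Node* var_handle,
                                 Node* init_value, const string& device,
                                 DataType dtype) {
  NodeDef assign_def;
  assign_def.set_name("example/Assign");
  assign_def.set_op("AssignVariableOp");
  assign_def.set_device(device);
  AddNodeAttr("dtype", dtype, &assign_def);
  TF_ASSIGN_OR_RETURN(Node * assign, graph->AddNode(assign_def));
  graph->AddEdge(var_handle, 0, assign, 0);  // Input 0: the resource handle.
  graph->AddEdge(init_value, 0, assign, 1);  // Input 1: the initial value.
  return OkStatus();
}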
1454 Status TPUPartitionedCallOp::InitializeShardedVarOnTPU(
1455     OpKernelContext* ctx, const core::RefCountPtr<Var>& var,
1456     std::vector<NodeDef>& ndefs, int split_dim,
1457     const std::vector<string>& tpu_devices) {
1458   std::unique_ptr<Graph> init_graph(new Graph(OpRegistry::Global()));
1459   int num_cores = ndefs.size();
1460   string cpu_device = "/device:CPU:0";
1461 
1462   Status status;
1463   std::vector<std::string> devices;
1464   std::vector<Node*> init_handles;
1465   for (int i = 0; i < num_cores; i++) {
1466     TF_ASSIGN_OR_RETURN(Node * init_handle, init_graph->AddNode(ndefs[i]));
1467     string device = tpu_devices[i];
1468     init_handle->set_assigned_device_name(device);
1469     devices.push_back(device);
1470     init_handles.push_back(init_handle);
1471   }
1472 
1473   NodeDef init_const_ndef;
1474   init_const_ndef.set_name("initial_value");
1475   init_const_ndef.set_op("Const");
1476   init_const_ndef.set_device(cpu_device);
1477   AddNodeAttr("dtype", var->tensor()->dtype(), &init_const_ndef);
1478   AddNodeAttr("value", *var->tensor(), &init_const_ndef);
1479   TF_ASSIGN_OR_RETURN(Node * init_const, init_graph->AddNode(init_const_ndef));
1480   init_const->set_assigned_device_name(cpu_device);
1481 
1482   Node* assign_value_node = init_const;
1483     // If the variable is sharded, we insert a "Split" node between the
1484     // initial value and the AssignVariableOps, so the variable on each TPU
1485     // device gets assigned its split of the value.
1486   //
1487   // initial_value--Split--AssignVariableOp ("/device:TPU:0")
1488   //                  |
1489   //            AssignVariableOp ("/device:TPU:1")
1490   if (split_dim >= 0) {
1491     // Add a split dimension node.
1492     NodeDef split_dim_def;
1493     split_dim_def.set_name("initial_value_split_dim");
1494     split_dim_def.set_op("Const");
1495     split_dim_def.set_device(cpu_device);
1496     AddNodeAttr("dtype", DT_INT32, &split_dim_def);
1497     TensorProto tensor_proto;
1498     tensor_proto.set_dtype(DT_INT32);
1499     tensor_proto.add_int_val(split_dim);
1500     TensorShape shape({});
1501     shape.AsProto(tensor_proto.mutable_tensor_shape());
1502     AddNodeAttr("value", tensor_proto, &split_dim_def);
1503     TF_ASSIGN_OR_RETURN(Node * split_dim_node,
1504                         init_graph->AddNode(split_dim_def));
1505     split_dim_node->set_assigned_device_name(cpu_device);
1506 
1507     // Add a split node.
1508     NodeDef split_def;
1509     int split_num = ndefs.size();
1510     split_def.set_name("initial_value_split");
1511     split_def.set_op("Split");
1512     split_def.set_device(cpu_device);
1513     AddNodeAttr("num_split", split_num, &split_def);
1514     AddNodeAttr("T", var->tensor()->dtype(), &split_def);
1515     split_def.add_input(absl::StrCat(split_dim_node->name(), ":0"));
1516     split_def.add_input(absl::StrCat(init_const->name(), ":0"));
1517     TF_ASSIGN_OR_RETURN(Node * split_node, init_graph->AddNode(split_def));
1518     split_node->set_assigned_device_name(cpu_device);
1519 
1520     init_graph->AddEdge(split_dim_node, 0, split_node, 0);
1521     init_graph->AddEdge(init_const, 0, split_node, 1);
1522 
1523     assign_value_node = split_node;
1524   }
1525 
1526   for (int i = 0; i < num_cores; i++) {
1527     NodeDef assign_node_def;
1528     assign_node_def.set_name(absl::StrCat("Assign_", i));
1529     assign_node_def.set_op("AssignVariableOp");
1530     assign_node_def.set_device(devices[i]);
1531     AddNodeAttr("dtype", var->tensor()->dtype(), &assign_node_def);
1532     TF_ASSIGN_OR_RETURN(Node * init_assign,
1533                         init_graph->AddNode(assign_node_def));
1534     init_assign->set_assigned_device_name(devices[i]);
1535 
1536     init_graph->AddEdge(init_handles[i], 0, init_assign, 0);
1537     if (split_dim >= 0) {
1538       init_graph->AddEdge(assign_value_node, i, init_assign, 1);
1539     } else {
1540       init_graph->AddEdge(assign_value_node, 0, init_assign, 1);
1541     }
1542   }
1543 
1544   GraphOptimizationPassOptions optimization_options;
1545   SessionOptions session_options;
1546   session_options.env = ctx->env();
1547   optimization_options.session_handle = ctx->session_handle();
1548   optimization_options.session_options = &session_options;
1549   optimization_options.flib_def = flib_def_.get();
1550   optimization_options.graph = nullptr;
1551   optimization_options.device_set = nullptr;
1552   std::unordered_map<std::string, std::unique_ptr<Graph>> subgraphs;
1553   optimization_options.partition_graphs = &subgraphs;
1554   TF_RETURN_IF_ERROR(PartitionHelper(device_set_, optimization_options,
1555                                      init_graph.get(), &subgraphs));
1556 
1557   std::vector<DeviceAndFHandle> functions;
1558   std::vector<std::string> function_names;
1559   for (auto& pair : subgraphs) {
1560     string target = pair.first;
1561     Device* device;
1562     TF_RETURN_IF_ERROR(
1563         library_runtime_->device_mgr()->LookupDevice(target, &device));
1564     Graph* subgraph = pair.second.get();
1565     string function_name = flib_def_->UniqueFunctionName(
1566         strings::StrCat(func_.name(), "_hash_", pair.first));
1567     function_names.push_back(function_name);
1568     FHandle handle;
1569     TF_RETURN_IF_ERROR(InstantiatePartition(*subgraph, function_name, target,
1570                                             &handle, nullptr));
1571     functions.push_back(DeviceAndFHandle{.device = target, .handle = handle});
1572   }
1573 
1574   FunctionLibraryRuntime::Options opts;
1575 
1576   // Blocking on threads in the same thread pool is disallowed because
1577   // concurrent warm-up requests can exhaust the default thread pool.
1578   // Create a new thread pool to initialize variables on TPU.
1579   std::function<void(std::function<void()>)> runner =
1580       [this](std::function<void()> fn) { pool_.Schedule(fn); };
1581   opts.runner = &runner;
1582 
1583   opts.step_container = ctx->step_container();
1584   opts.cancellation_manager = ctx->cancellation_manager();
1585   opts.stats_collector = ctx->stats_collector();
1586   opts.source_device = local_device_name_;
1587   opts.run_all_kernels_inline = ctx->run_all_kernels_inline();
1588 
1589   OpInputList arguments;
1590   TF_RETURN_IF_ERROR(ctx->input_list("args", &arguments));
1591 
1592   PrivateIntraProcessRendezvous rendez(device_mgr_);
1593   opts.rendezvous = &rendez;
1594 
1595   BlockingCounter bcount(functions.size());
1596   for (const DeviceAndFHandle& entry : functions) {
1597     const string& target_device = entry.device;
1598     FHandle handle = entry.handle;
1599 
1600     TF_RETURN_IF_ERROR(
1601         ShouldUseRemoteExecutionForFn(target_device, &(opts.remote_execution)));
1602     std::vector<Tensor> dummy_args;
1603     std::vector<Tensor>* dummy_rets = new std::vector<Tensor>;
1604 
1605     profiler::TraceMe trace_me(
1606         "TPUPartitionedCallOp-InitializeShardedVarOnTPU");
1607     library_runtime_->Run(opts, handle, dummy_args, dummy_rets,
1608                           [dummy_rets, &bcount, ctx](const Status& status) {
1609                             if (!status.ok()) {
1610                               ctx->SetStatus(status);
1611                             }
1612                             delete dummy_rets;
1613                             bcount.DecrementCount();
1614                           });
1615   }
1616   bcount.Wait();
1617 
1618   for (int i = 0; i < functions.size(); i++) {
1619     TF_RETURN_IF_ERROR(flib_def_->RemoveFunction(function_names[i]));
1620     TF_RETURN_IF_ERROR(library_runtime_->ReleaseHandle(functions[i].handle));
1621   }
1622   return OkStatus();
1623 }
1624 
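// [Editor's illustrative sketch; not part of the original file.] For a sharded
// variable, InitializeShardedVarOnTPU wires the split-dimension constant into
// Split input 0, the full CPU value into Split input 1, and Split output i into
// the AssignVariableOp for core i. A hypothetical helper capturing just that
// edge wiring:
inline void ExampleWireSplitToAssigns(Graph* graph, Node* split_dim_node,
                                      Node* init_value, Node* split_node,
                                      const std::vector<Node*>& assign_nodes) {
  graph->AddEdge(split_dim_node, 0, split_node, 0);  // Dimension to split on.
  graph->AddEdge(init_value, 0, split_node, 1);      // Tensor to split.
  for (int i = 0; i < static_cast<int>(assign_nodes.size()); ++i) {
    graph->AddEdge(split_node, i, assign_nodes[i], 1);  // Shard i -> core i.
  }
}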
1625 bool TPUPartitionedCallOp::IsInputToTPUReplicate(Node* node) {
1626   for (Node* successor : node->out_nodes()) {
1627     if (successor->attrs().Find(kTpuReplicateAttr) != nullptr) {
1628       return true;
1629     }
1630   }
1631   return false;
1632 }
1633 
1634 Status TPUPartitionedCallOp::ReplaceResourceArgsWithVarHandleOps(
1635     Graph* graph, OpKernelContext* ctx, int device_ordinal,
1636     bool enable_spmd_xla_partitioning, const TPUMetadata& tpu_metadata) {
1637   // Currently variable deduplication is not supported for XLA SPMD
1638   // partitioning. It is possible that it could be supported in the future.
1639   bool enable_variable_deduplication =
1640       runtime_params_.enable_variable_deduplication;
1641   if (enable_spmd_xla_partitioning && tpu_metadata.num_cores_per_replica > 1) {
1642     // If enable_spmd_xla_partitioning is true, the user set the
1643     // enable_auto_xla_input_sharding flag. Warn them that only one of the flags
1644     // can be set safely when num_cores_per_replica > 1. If
1645     // num_cores_per_replica==1, enable_spmd_xla_partitioning is effectively a
1646     // no-op so we can skip this check.
1647     LOG(WARNING) << "Disabling variable deduplication because it is not "
1648                     "compatible with enable_auto_xla_input_sharding.";
1649     enable_variable_deduplication = false;
1650   }
1651   std::vector<Node*> tpu_resource_args;
1652   std::vector<int> arg_indices;
1653   absl::flat_hash_map<const Node*, xla::OpSharding> variable_to_xla_sharding;
1654   for (Node* node : graph->op_nodes()) {
1655     if (node->IsArg()) {
1656       const AttrValue* attr_value;
1657       TF_RETURN_IF_ERROR(node->attrs().Find("T", &attr_value));
1658       DataType dtype = attr_value->type();
1659       if (dtype == DT_RESOURCE && IsInputToTPUReplicate(node)) {
1660         // If this VarHandleOp is used by a TPU computation,
1661         // we need to create a TPU version of the variable.
1662         TF_RETURN_IF_ERROR(node->attrs().Find("index", &attr_value));
1663         int index = attr_value->i();
1664         tpu_resource_args.push_back(node);
1665         arg_indices.push_back(index);
1666         replaced_input_indices_[index] = true;
1667       }
1668     }
1669   }
1670 
1671   VLOG(3) << "tpu_resource_args.size(): " << tpu_resource_args.size();
1672   // Create a mapping from ResourceHandle to variable node. When a
1673   // ResourceHandle backs several variable nodes, the variable nodes refer to
1674   // the same underlying resource. In that case, only one variable node needs
1675   // to be mirrored to the TPU for that resource.
1676   absl::flat_hash_map<uint64, Node*> tpu_variables;
1677   for (int i = 0; i < tpu_resource_args.size(); i++) {
1678     Node* node = tpu_resource_args[i];
1679     ResourceHandle handle = HandleFromInput(ctx, arg_indices[i]);
1680 
1681     if (tpu_metadata.num_cores_per_replica > 1 &&
1682         enable_spmd_xla_partitioning) {
1683       TF_RETURN_IF_ERROR(ReplaceAndPartitionXLAShardingVariable(
1684           graph, ctx, device_ordinal, handle, node, tpu_metadata));
1685       continue;
1686     }
1687     TPUVariableInfo var_info(/*device_ordinal_id=*/0, /*use_fast_mem=*/false);
1688     TF_RETURN_IF_ERROR(ParseTPUVariableInfor(
1689         node, tpu_metadata.num_cores_per_replica, &var_info));
1690     // Only respect the graph's placement when model parallelism is enabled.
1691     if (tpu_metadata.num_cores_per_replica > 1)
1692       device_ordinal = var_info.device_ordinal;
1693 
1694     const uint64 handle_fp =
1695         Fingerprint64(strings::StrCat(handle.container(), handle.name()));
1696     if (enable_variable_deduplication && tpu_variables.contains(handle_fp) &&
1697         tpu_metadata.num_cores_per_replica == 1) {
1698       Node* tpu_variable = tpu_variables.at(handle_fp);
1699       std::vector<Node*> dst_nodes;
1700       std::vector<int> src_indices;
1701       std::vector<int> dst_indices;
1702       for (const Edge* edge : node->out_edges()) {
1703         dst_nodes.push_back(edge->dst());
1704         src_indices.push_back(edge->src_output());
1705         dst_indices.push_back(edge->dst_input());
1706       }
1707       graph->RemoveNode(node);
1708       for (int i = 0; i < dst_nodes.size(); i++) {
1709         graph->AddEdge(tpu_variable, src_indices[i], dst_nodes[i],
1710                        dst_indices[i]);
1711       }
1712     } else {
1713       uint64 fp =
1714           Fingerprint64(strings::StrCat(handle.container(), handle.name(), i));
1715       NodeDef ndef;
1716       ndef.set_name(strings::StrCat(handle.name(), fp));
1717       ndef.set_op(kVarHandleOp);
1718       if (tpu_metadata.num_cores_per_replica > 1) {
1719         ndef.set_device(strings::StrCat(kTPUDeviceNamePrefix, device_ordinal));
1720       } else {
1721         // Assign this new VarHandleOp to TPU:0 so the partitioner only
1722         // partitions the graph into two subgraphs, one on CPU and one on TPU.
1723         // The actual device ordinal on which this VarHandleOp runs is assigned
1724         // after partitioning (in SetDeviceOrdinal).
1725         ndef.set_device(
1726             strings::StrCat(kTPUDeviceNamePrefix, kTPUDefaultDeviceOrdinal));
1727       }
1728 
1729       // Replace each _Arg node of type DT_RESOURCE that goes into a TPU node
1730       // by a VarHandleOp on TPU with shared_name "v_tpu_x" where "v" is the
1731       // shared_name of the variable on CPU and "x" is the rewritten device
1732       // ordinal.
1733       const string sname =
1734           strings::StrCat(handle.name(), "_tpu_", device_ordinal);
1735       AddNodeAttr("shared_name", sname, &ndef);
1736       const string cname = ctx->resource_manager()->default_container();
1737       AddNodeAttr("container", cname, &ndef);
1738       core::RefCountPtr<Var> var;
1739       TF_RETURN_IF_ERROR(LookupResource(ctx, handle, &var));
1740       AddNodeAttr("dtype", var->tensor()->dtype(), &ndef);
1741       TensorShapeProto proto;
1742       var->tensor()->shape().AsProto(&proto);
1743       AddNodeAttr("shape", proto, &ndef);
1744       TF_ASSIGN_OR_RETURN(Node * new_node, graph->AddNode(ndef));
1745       std::vector<const Edge*> in_edges(node->in_edges().begin(),
1746                                         node->in_edges().end());
1747       for (const Edge* edge : in_edges) {
1748         graph->AddEdge(edge->src(), edge->src_output(), new_node,
1749                        edge->dst_input());
1750       }
1751       std::vector<Node*> dst_nodes;
1752       std::vector<int> src_indices;
1753       std::vector<int> dst_indices;
1754       for (const Edge* edge : node->out_edges()) {
1755         dst_nodes.push_back(edge->dst());
1756         src_indices.push_back(edge->src_output());
1757         dst_indices.push_back(edge->dst_input());
1758       }
1759       graph->RemoveNode(node);
1760       for (int i = 0; i < dst_nodes.size(); i++) {
1761         graph->AddEdge(new_node, src_indices[i], dst_nodes[i], dst_indices[i]);
1762       }
1763       // Don't initialize the variable on TPU if it has already been done
1764       // for this ordinal.
1765       if (seen_ordinals_.contains(device_ordinal)) continue;
1766 
1767       Device* d;
1768       TF_RETURN_IF_ERROR(library_runtime_->device_mgr()->LookupDevice(
1769           strings::StrCat(kTPUDeviceNamePrefix, device_ordinal), &d));
1770       Var* tpu_var;
1771       Status status = d->resource_manager()->Lookup(cname, sname, &tpu_var);
1772       if (!status.ok()) {
1773         TF_RETURN_IF_ERROR(InitializeVarOnTPU(ctx, var, &ndef, device_ordinal,
1774                                               var_info.fast_mem));
1775         VLOG(3) << "Initialized variable on TPU: " << sname
1776                 << " device_ordinal: " << device_ordinal;
1777       }
1778       tpu_variables[handle_fp] = new_node;
1779     }
1780   }
1781 
1782   // Adjust the index attr of the remaining arg nodes.
1783   int new_index = 0;
1784   for (Node* node : graph->op_nodes()) {
1785     if (node->IsArg()) {
1786       node->ClearAttr("index");
1787       node->AddAttr("index", new_index);
1788       new_index++;
1789     }
1790   }
1791 
1792   seen_ordinals_.insert(device_ordinal);
1793 
1794   return OkStatus();
1795 }
1796 
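// [Editor's illustrative sketch; not part of the original file.] The mirrored
// TPU variables created above reuse the CPU variable's shared_name with a
// "_tpu_<ordinal>" suffix, which is how a later lookup on the TPU device's
// resource manager can tell whether the variable was already initialized. The
// helper name is hypothetical.
inline string ExampleTpuVariableSharedName(const string& cpu_shared_name,
                                           int device_ordinal) {
  return strings::StrCat(cpu_shared_name, "_tpu_", device_ordinal);
}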
1797 Status TPUPartitionedCallOp::ReplaceAndPartitionXLAShardingVariable(
1798     Graph* graph, OpKernelContext* ctx, int device_ordinal,
1799     ResourceHandle& handle, Node* variable, const TPUMetadata& tpu_metadata) {
1800   TF_ASSIGN_OR_RETURN(
1801       auto sharding,
1802       GetShardingFromNodeDef(variable->def(), /*add_metadata=*/false));
1803   xla::OpSharding xla_sharding;
1804   bool is_var_sharded = false;
1805   if (sharding.has_value() &&
1806       sharding.value().type() == xla::OpSharding::OTHER) {
1807     xla_sharding = sharding.value();
1808     for (int dim = 0; dim < GetDimsFromXLAShardingTiled(xla_sharding); dim++) {
1809       is_var_sharded |= xla_sharding.tile_assignment_dimensions(dim) > 1;
1810     }
1811   } else {
1812     xla_sharding.set_type(xla::OpSharding::REPLICATED);
1813     is_var_sharded = false;
1814   }
1815   VLOG(3) << "Replace and partition variable " << variable->name()
1816           << " with xla_sharding: " << xla_sharding.DebugString();
1817 
1818   core::RefCountPtr<Var> var;
1819   TF_RETURN_IF_ERROR(LookupResource(ctx, handle, &var));
1820 
1821   int split_dim = -1;
1822   int split_size = 0;
1823 
1824   if (is_var_sharded) {
1825     for (int dim = 0; dim < GetDimsFromXLAShardingTiled(xla_sharding); dim++) {
1826       if (xla_sharding.tile_assignment_dimensions(dim) > 1) {
1827         if (split_dim != -1) {
1828           return errors::InvalidArgument(
1829               "Currently we only support inference with one split dimension, "
1830               "however got sharding: ",
1831               xla_sharding.DebugString());
1832         }
1833         split_dim = dim;
1834         split_size = xla_sharding.tile_assignment_dimensions(dim);
1835       }
1836     }
1837     if (split_dim == -1 || split_dim >= var->tensor()->dims()) {
1838       return errors::InvalidArgument(
1839           "sharding split_dim ", split_dim, " for variable: ", variable->name(),
1840           " is -1 or large than the number of dimensions ",
1841           var->tensor()->dims());
1842     }
1843   }
1844 
1845   const auto& topology = tpu_metadata.topology;
1846   int num_cores_per_replica = tpu_metadata.num_cores_per_replica;
1847   xla::Array4D<int> mapping(topology.mesh_shape(0), topology.mesh_shape(1),
1848                             topology.mesh_shape(2), topology.mesh_shape(3), -1);
1849   int pos = 0;
1850   // The topology should only have one task.
1851   for (int device = 0; device < topology.num_tpu_devices_per_task(); device++) {
1852     int32_t x = topology.device_coordinates(pos++);
1853     int32_t y = topology.device_coordinates(pos++);
1854     int32_t z = topology.device_coordinates(pos++);
1855     int32_t core = topology.device_coordinates(pos++);
1856     mapping(x, y, z, core) = device;
1857   }
1858 
1859   const string cname = ctx->resource_manager()->default_container();
1860   std::vector<Node*> per_core_vars;
1861   std::vector<string> tpu_devices;
1862   for (int i = 0; i < num_cores_per_replica; i++) {
1863     int offset = i * 4;
1864     int device_index = mapping(tpu_metadata.device_assignment[offset],
1865                                tpu_metadata.device_assignment[offset + 1],
1866                                tpu_metadata.device_assignment[offset + 2],
1867                                tpu_metadata.device_assignment[offset + 3]);
1868 
1869     NodeDef ndef;
1870     uint64 fp = Fingerprint64(
1871         strings::StrCat(handle.container(), handle.name(), "_", device_index));
1872     ndef.set_name(strings::StrCat(handle.name(), fp));
1873     ndef.set_op(kVarHandleOp);
1874     string tpu_device = strings::StrCat(kTPUDeviceNamePrefix, device_index);
1875     ndef.set_device(tpu_device);
1876     tpu_devices.push_back(tpu_device);
1877 
1878     // Replace each _Arg node of type DT_RESOURCE that goes into a TPU node
1879     // by a VarHandleOp on TPU with shared_name "v_tpu_x" where "v" is the
1880     // shared_name of the variable on CPU and "x" is the rewritten device
1881     // ordinal.
1882     const string sname = strings::StrCat(handle.name(), "_tpu_", device_index);
1883     AddNodeAttr("shared_name", sname, &ndef);
1884     AddNodeAttr("container", cname, &ndef);
1885     AddNodeAttr("dtype", var->tensor()->dtype(), &ndef);
1886 
1887     TensorShapeProto proto;
1888     var->tensor()->shape().AsProto(&proto);
1889 
1890     if (is_var_sharded) {
1891       int dim_size = proto.dim(split_dim).size();
1892       if (dim_size % split_size != 0) {
1893         return errors::InvalidArgument("dimension size ", dim_size,
1894                                        " cannot be divisible by split size ",
1895                                        split_size);
1896       }
1897       proto.mutable_dim(split_dim)->set_size(dim_size / split_size);
1898     }
1899     AddNodeAttr("shape", proto, &ndef);
1900 
1901     TF_ASSIGN_OR_RETURN(Node * new_node, graph->AddNode(ndef));
1902     per_core_vars.push_back(new_node);
1903   }
1904 
1905   // Insert TPUPartitionedInput op.
1906   NodeDefBuilder builder(absl::StrCat(handle.name(), "/tpu_partitioned_input"),
1907                          "TPUPartitionedInput");
1908   builder.Attr("N", num_cores_per_replica);
1909   builder.Attr("T", DT_RESOURCE);
1910   builder.Attr("partition_dim", split_dim);
1911   builder.Attr("_XlaSharding", xla_sharding.SerializeAsString());
1912   std::vector<NodeDefBuilder::NodeOut> inputs;
1913   inputs.reserve(num_cores_per_replica);
1914   for (int i = 0; i < num_cores_per_replica; i++) {
1915     inputs.push_back({per_core_vars[i]->name(), 0, DT_RESOURCE});
1916   }
1917   builder.Input(inputs);
1918   NodeDef node_def;
1919   TF_RETURN_IF_ERROR(builder.Finalize(&node_def));
1920   TF_ASSIGN_OR_RETURN(Node * tpu_partitioned_input_node,
1921                       graph->AddNode(node_def));
1922 
1923   for (int i = 0; i < num_cores_per_replica; i++) {
1924     graph->AddEdge(per_core_vars[i], 0, tpu_partitioned_input_node, i);
1925   }
1926 
1927   // Insert TPUReplicatedInput op.
1928   NodeDefBuilder replicated_builder(
1929       absl::StrCat(handle.name(), "/tpu_replicated_input"),
1930       "TPUReplicatedInput");
1931   replicated_builder.Attr("N", 1);
1932   replicated_builder.Attr("T", DT_RESOURCE);
1933   replicated_builder.Attr("is_mirrored_variable", true);
1934   std::vector<NodeDefBuilder::NodeOut> replicated_inputs;
1935   replicated_inputs.push_back(
1936       {tpu_partitioned_input_node->name(), 0, DT_RESOURCE});
1937   replicated_builder.Input(replicated_inputs);
1938   NodeDef replicated_node_def;
1939   TF_RETURN_IF_ERROR(replicated_builder.Finalize(&replicated_node_def));
1940   Status replicated_s;
1941   Node* tpu_replicated_input_node =
1942       graph->AddNode(replicated_node_def, &replicated_s);
1943   if (!replicated_s.ok()) {
1944     return replicated_s;
1945   }
1946   graph->AddEdge(tpu_partitioned_input_node, 0, tpu_replicated_input_node, 0);
1947 
1948   // Connect the TPUReplicatedInput node to the previous output nodes of the
1949   // variable, and remove the variable node.
1950   std::vector<Node*> dst_nodes;
1951   std::vector<int> src_indices;
1952   std::vector<int> dst_indices;
1953   for (const Edge* edge : variable->out_edges()) {
1954     dst_nodes.push_back(edge->dst());
1955     src_indices.push_back(edge->src_output());
1956     dst_indices.push_back(edge->dst_input());
1957   }
1958   for (int i = 0; i < dst_nodes.size(); i++) {
1959     graph->AddEdge(tpu_replicated_input_node, src_indices[i], dst_nodes[i],
1960                    dst_indices[i]);
1961   }
1962 
1963   graph->RemoveNode(variable);
1964 
1965   std::vector<NodeDef> ndefs;
1966   Status status;
1967   for (int i = 0; i < num_cores_per_replica; i++) {
1968     Device* d;
1969     TF_RETURN_IF_ERROR(
1970         library_runtime_->device_mgr()->LookupDevice(tpu_devices[i], &d));
1971     string sname;
1972     const NodeDef& ndef = per_core_vars[i]->def();
1973     TF_RETURN_IF_ERROR(GetNodeAttr(ndef, "shared_name", &sname));
1974     ndefs.push_back(ndef);
1975     Var* tpu_var;
1976     status = d->resource_manager()->Lookup(cname, sname, &tpu_var);
1977   }
1978 
1979   if (!status.ok()) {
1980     TF_RETURN_IF_ERROR(
1981         InitializeShardedVarOnTPU(ctx, var, ndefs, split_dim, tpu_devices));
1982     if (VLOG_IS_ON(4)) {
1983       for (int i = 0; i < num_cores_per_replica; i++) {
1984         string sname;
1985         TF_RETURN_IF_ERROR(GetNodeAttr(ndefs[i], "shared_name", &sname));
1986         LOG(INFO) << "Initialized sharded variable on TPU: " << sname
1987                   << " device: " << tpu_devices[i];
1988       }
1989     }
1990   }
1991 
1992   return OkStatus();
1993 }
1994 
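// [Editor's illustrative sketch; not part of the original file.] When a
// variable is sharded along a single dimension, each per-core VarHandleOp above
// gets the full shape with that dimension divided by the split size, and the
// split size must divide the dimension evenly. A hypothetical helper with the
// same check:
inline Status ExampleShardShapeProto(TensorShapeProto* shape, int split_dim,
                                     int split_size) {
  const int64_t dim_size = shape->dim(split_dim).size();
  if (dim_size % split_size != 0) {
    return errors::InvalidArgument("dimension size ", dim_size,
                                   " is not divisible by split size ",
                                   split_size);
  }
  shape->mutable_dim(split_dim)->set_size(dim_size / split_size);
  return OkStatus();
}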
1995 Status TPUPartitionedCallOp::InferShapesWithResourceVar(
1996     Graph* graph, OpKernelContext* ctx,
1997     std::map<int, InferredShape>& arg_shapes,
1998     GraphShapeInfo* tpu_inferred_info) {
1999   auto shape_inference_graph_interim =
2000       absl::make_unique<Graph>(graph->flib_def());
2001   CopyGraph(*graph, shape_inference_graph_interim.get());
2002 
2003   for (Node* node : shape_inference_graph_interim->nodes()) {
2004     if (node->type_string() != "_Arg" ||
2005         node->attrs().Find("T")->type() != DT_RESOURCE)
2006       continue;
2007 
2008     std::vector<std::function<void()>> to_remove;
2009 
2010     for (const Edge* out_edge : node->out_edges()) {
2011       Node* read_node = out_edge->dst();
2012       if (read_node->type_string() != "ReadVariableOp") continue;
2013 
2014       for (const Edge* variable_edge : read_node->out_edges()) {
2015         // We are delaying these modifications as we cannot do in-place
2016         // modification of EdgeSets.
2017         to_remove.push_back(
2018             [variable_edge, graph = shape_inference_graph_interim.get(), node] {
2019               Node* dst = variable_edge->dst();
2020               graph->RemoveEdge(variable_edge);
2021               graph->AddEdge(node, variable_edge->src_output(), dst,
2022                              variable_edge->dst_input());
2023             });
2024       }
2025       to_remove.push_back(
2026           [graph = shape_inference_graph_interim.get(), out_edge, read_node] {
2027             graph->RemoveEdge(out_edge);
2028             graph->RemoveNode(read_node);
2029           });
2030     }
2031 
2032     for (auto& func : to_remove) {
2033       func();
2034     }
2035 
2036     int resource_arg_index = node->attrs().Find("index")->i();
2037 
2038     // Get resource variable tensor
2039     core::RefCountPtr<Var> variable;
2040     const ResourceHandle& handle = HandleFromInput(ctx, resource_arg_index);
2041     TF_RETURN_IF_ERROR(LookupResource(ctx, handle, &variable));
2042 
2043     const Tensor* variable_tensor = variable->tensor();
2044     std::vector<int> variable_tensor_vec;
2045 
2046     variable_tensor_vec.reserve(variable_tensor->dims());
2047     for (int d = 0; d < variable_tensor->dims(); ++d) {
2048       variable_tensor_vec.push_back(variable_tensor->dim_size(d));
2049     }
2050 
2051     PartialTensorShape partial_tensor_shape;
2052     auto partial_shape = PartialTensorShape::MakePartialShape(
2053         variable_tensor_vec.data(), variable_tensor_vec.size(),
2054         &partial_tensor_shape);
2055     InferredShape inferred_shape = {partial_tensor_shape};
2056     arg_shapes.emplace(resource_arg_index, inferred_shape);
2057   }
2058 
2059   TF_RETURN_IF_ERROR(tensorflow::InferShapes(
2060       shape_inference_graph_interim.get(), arg_shapes,
2061       &shape_inference_graph_interim->flib_def(), tpu_inferred_info));
2062   return OkStatus();
2063 }
2064 
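// [Editor's illustrative sketch; not part of the original file.] For each
// DT_RESOURCE argument, InferShapesWithResourceVar seeds shape inference with
// the shape of the variable's current value, expressed as an InferredShape. A
// hypothetical helper doing just that conversion:
inline Status ExampleInferredShapeFromVariable(const Tensor& value,
                                               InferredShape* out) {
  std::vector<int> dims;
  dims.reserve(value.dims());
  for (int d = 0; d < value.dims(); ++d) dims.push_back(value.dim_size(d));
  PartialTensorShape partial_shape;
  TF_RETURN_IF_ERROR(PartialTensorShape::MakePartialShape(
      dims.data(), dims.size(), &partial_shape));
  InferredShape inferred = {partial_shape};
  *out = inferred;
  return OkStatus();
}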
2065 Status TPUPartitionedCallOp::ShardInputsWithXlaSharding(
2066     Graph* graph, const std::string& cluster_name, int num_cores_per_replica,
2067     OpKernelContext* ctx) {
2068   for (Node* replicated_input_node : graph->nodes()) {
2069     if (replicated_input_node->type_string() != "TPUReplicatedInput") continue;
2070 
2071     Node* arg_node;
2072     auto input_node_status = replicated_input_node->input_node(0, &arg_node);
2073     if (!input_node_status.ok()) {
2074       VLOG(2) << "Skip because cannot retrieve input node 0 of "
2075               << replicated_input_node->name() << " because "
2076               << input_node_status.ToString();
2077       continue;
2078     }
2079 
2080     // Check if this TPUReplicatedInput can qualify because it has _Arg
2081     // as input and doesn't have XlaSharding already as an output, then
2082     // try to shard inputs automatically.
2083     //
2084     // In short, we want to see the following graph:
2085     //    _Arg -> TPUReplicatedInput -> (not XlaSharding op)
2086     // and transform it to:
2087     //    _Arg -> TPUReplicatedInput -> XlaSharding -> (not XlaSharding op)
2088     if (arg_node->IsArg() &&
2089         replicated_input_node->out_nodes().begin()->type_string() !=
2090             "XlaSharding") {
2091       int arg_id;
2092       if (!absl::SimpleAtoi(absl::StripPrefix(arg_node->name(), "arg_"),
2093                             &arg_id)) {
2094         VLOG(3) << "Skip auto-sharding because we are unable to extract "
2095                    "argument number from "
2096                 << arg_node->name();
2097         continue;
2098       }
2099 
2100       auto shape = ctx->input(arg_id).shape();
2101 
2102       VLOG(3) << "Identified arg node " << arg_node->DebugString()
2103               << " for TPUReplicatedInput "
2104               << replicated_input_node->DebugString();
2105       VLOG(3) << "Shape within TPUReplicatedInput is: " << shape.DebugString();
2106 
2107       int rank = shape.dims();
2108       int shard_dim =
2109           (runtime_params_.auto_xla_input_sharding_dim + rank) % rank;
2110 
2111       if (shape.dim_size(shard_dim) % num_cores_per_replica != 0) {
2112         VLOG(3) << "Skip auto-sharding " << replicated_input_node->name()
2113                 << " because the specified sharding dimension " << shard_dim
2114                 << " cannot be evenly split by " << num_cores_per_replica;
2115         continue;
2116       }
2117 
2118       auto sharding = absl::make_optional<xla::OpSharding>();
2119       sharding->set_type(xla::OpSharding::OTHER);
2120 
2121       // Sets up tile_assignment_dimensions.
2122       std::vector<int64_t> dims(rank, 1LL);
2123       dims[shard_dim] = num_cores_per_replica;
2124       for (auto dim : dims) {
2125         sharding->add_tile_assignment_dimensions(dim);
2126       }
2127 
2128       // Sets up tile_assignment_devices.
2129       for (int d = 0; d < num_cores_per_replica; ++d) {
2130         sharding->add_tile_assignment_devices(d);
2131       }
2132 
2133       std::vector<const Edge*> edges_to_remove;
2134       for (const Edge* edge : replicated_input_node->out_edges()) {
2135         if (edge->IsControlEdge()) continue;
2136         edges_to_remove.push_back(edge);
2137       }
2138 
2139       // Create XlaSharding Op.
2140       Node* sharding_op = nullptr;
2141       TF_RETURN_IF_ERROR(
2142           NodeBuilder(absl::StrCat(replicated_input_node->name(), "/sharding"),
2143                       "XlaSharding")
2144               .Input(replicated_input_node, 0)
2145               .Attr("T", replicated_input_node->output_type(0))
2146               .Attr(kXLAShardingAttrName, sharding->SerializeAsString())
2147               .Attr(kXLAShardingAttrAltName, sharding->SerializeAsString())
2148               .Attr("_tpu_replicate", cluster_name)
2149               .Finalize(graph, &sharding_op));
2150       for (const Edge* edge : edges_to_remove) {
2151         VLOG(3) << "XlaSharding op creation output edge "
2152                 << edge->DebugString();
2153         graph->RemoveEdge(edge);
2154         graph->AddEdge(sharding_op, 0, edge->dst(), edge->dst_input());
2155       }
2156 
2157       VLOG(3) << "Auto shard " << replicated_input_node->name() << " by dim "
2158               << shard_dim << " into " << num_cores_per_replica << " slices";
2159 
2160       VLOG(3) << "Created XlaSharding Op " << sharding_op->DebugString();
2161     }
2162   }
2163 
2164   return OkStatus();
2165 }
2166 
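// [Editor's illustrative sketch; not part of the original file.] The
// auto-sharding logic above builds an xla::OpSharding that tiles one dimension
// across all cores and leaves the others unsplit. A hypothetical helper that
// produces the same proto:
inline xla::OpSharding ExampleDimSharding(int rank, int shard_dim,
                                          int num_cores_per_replica) {
  xla::OpSharding sharding;
  sharding.set_type(xla::OpSharding::OTHER);
  for (int d = 0; d < rank; ++d) {
    // Only the sharded dimension is tiled; all others keep a tile count of 1.
    sharding.add_tile_assignment_dimensions(
        d == shard_dim ? num_cores_per_replica : 1);
  }
  for (int core = 0; core < num_cores_per_replica; ++core) {
    sharding.add_tile_assignment_devices(core);
  }
  return sharding;
}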
2167 // OptimizeTpuInputOutputTensors does the following things:
2168 //  (1) Detect input arguments, and add an XlaSharding op to the arguments if
2169 //  enable_auto_xla_input_sharding is turned on.
2170 //  (2) Pack multiple input tensors into one tensor by a concat to avoid PCIe
2171 //  transfer overheads for small tensors.
2172 //  (3) Reshape input tensors to R1 to leverage the fast path in TPU input
2173 //  preparation done by runtime.
2174 //  (4) Pack multiple output tensors into one tensor by a concat.
2175 //
2176 // (1) is controlled by --enable_auto_xla_input_sharding and
2177 // --auto_xla_input_sharding_dim
2178 // (2) and (3) are controlled by flags --minimum_input_tensors_packing
2179 // and --input_shape_opt, respectively, while (4) is controlled by
2180 // --minimum_output_tensors_packing.
2181 Status TPUPartitionedCallOp::OptimizeTpuInputOutputTensors(
2182     Graph* graph, bool enable_spmd_xla_partitioning, int num_cores_per_replica,
2183     std::map<std::string, std::vector<int>>& named_input_shapes,
2184     OpKernelContext* ctx) {
2185   std::string cluster_name;
2186   TF_RETURN_IF_ERROR(GetClusterName(graph, &cluster_name));
2187 
2188   if (runtime_params_.enable_auto_xla_input_sharding) {
2189     VLOG(2) << DumpGraphToFile("before_enable_auto_xla_input_sharding", *graph,
2190                                flib_def_.get());
2191 
2192     TF_RETURN_IF_ERROR(ShardInputsWithXlaSharding(graph, cluster_name,
2193                                                   num_cores_per_replica, ctx));
2194   }
2195 
2196   GraphShapeInfo tpu_inferred_info;
2197   std::map<int, InferredShape> arg_shapes;
2198   EdgeShapes tpu_input_shapes;
2199   absl::flat_hash_map<const Edge*, DataType> tpu_input_dtypes;
2200 
2201   // Contains attrs "T", "sharding", "_tpu_replicate" for each XlaSharding op.
2202   XlaShardingInfoMap xla_sharding_ops;
2203 
2204   // Contains attrs "T", and a pointer to tpu_replicated_metadata for ctrl dep
2205   TpuReplicatedInputInfoMap tpu_replicated_input_ops;
2206 
2207   bool xla_spmd_input_sharded = false;
2208 
2209   if (enable_spmd_xla_partitioning) {
2210     xla_spmd_input_sharded = FindTpuReplicatedInputAndXlaSharding(
2211         graph, xla_sharding_ops, tpu_replicated_input_ops);
2212   }
2213 
2214   VLOG(1) << "xla_spmd_input_sharded: " << xla_spmd_input_sharded;
2215   VLOG(2) << DumpGraphToFile("before_remove_descendant_nodes", *graph,
2216                              flib_def_.get());
2217 
2218   if (!xla_spmd_input_sharded ||
2219       runtime_params_.minimum_input_tensors_packing > 1 ||
2220       runtime_params_.enable_auto_xla_input_sharding) {
2221     // Currently we remove `TPUReplicatedInput` nodes when the input tensors
2222     // are not sharded, when input tensor packing is enabled, or when auto
2223     // XLA input sharding is on; otherwise downstream rewrites would get
2224     // confused.
2225     RemoveDescendantNodeOfArg(graph, "TPUReplicatedInput",
2226                               /*must_be_child_of=*/{});
2227   }
2228 
2229   if (xla_spmd_input_sharded) {
2230     // We are setting must_be_child_of to {"_Arg"} because we do not want
2231     // to remove other XlaSharding ops that might be in the graph. We only
2232     // want the XlaSharding ops that are directly attached to the input
2233     // arguments to be removed.
2234     RemoveDescendantNodeOfArg(graph, "XlaSharding",
2235                               /*must_be_child_of=*/{"_Arg"});
2236   }
2237 
2238   VLOG(2) << DumpGraphToFile("before_get_input_output_info", *graph,
2239                              flib_def_.get());
2240 
2241   TF_RETURN_IF_ERROR(GetInputOutputInfo(graph, tpu_inferred_info, arg_shapes,
2242                                         tpu_input_shapes, tpu_input_dtypes,
2243                                         ctx));
2244 
2245   VLOG(2) << DumpGraphToFile("before_optimize_tpu_input_output_tensors", *graph,
2246                              flib_def_.get());
2247 
2248   if (runtime_params_.minimum_output_tensors_packing > 1) {
2249     // Copy graph to shape_inference_graph
2250     EdgeShapes tpu_output_shapes;
2251     TF_RETURN_IF_ERROR(
2252         InferShapesWithResourceVar(graph, ctx, arg_shapes, &tpu_inferred_info));
2253 
2254     // Find TPU -> CPU output edges.
2255     GroupedEdges shape_to_output =
2256         tpu_functional_internal::GroupTensorsForOutputPacking(
2257             graph, tpu_output_shapes, &tpu_inferred_info);
2258 
2259     TF_RETURN_IF_ERROR(
2260         tpu_functional_internal::CreateConcatAndSplitNodesForOutputTensor(
2261             graph, cluster_name, &tpu_output_shapes, &tpu_inferred_info,
2262             shape_to_output, runtime_params_.minimum_output_tensors_packing));
2263   }
2264 
2265   if (runtime_params_.minimum_input_tensors_packing > 1) {
2266     GroupedEdges grouped_input_edges =
2267         tpu_functional_internal::GroupTensorsForInputPacking(
2268             tpu_input_shapes, tpu_input_dtypes, runtime_params_.input_shape_opt,
2269             runtime_params_.group_tensors_for_packing);
2270     TF_RETURN_IF_ERROR(
2271         tpu_functional_internal::CreateConcatAndSplitNodesForInputTensor(
2272             graph, cluster_name, &tpu_input_shapes, grouped_input_edges,
2273             runtime_params_.minimum_input_tensors_packing,
2274             xla_spmd_input_sharded, xla_sharding_ops,
2275             tpu_replicated_input_ops));
2276   }
2277   if (runtime_params_.input_shape_opt) {
2278     TF_RETURN_IF_ERROR(tpu_functional_internal::InsertReshapeNodePairs(
2279         graph, cluster_name, &tpu_input_shapes, num_cores_per_replica));
2280   }
2281   VLOG(1) << DumpGraphToFile("optim_result", *graph);
2282 
2283   // With or without optimizations, collect the input names and shapes.
2284   for (const auto& iter : tpu_input_shapes) {
2285     std::string name = iter.first->src()->name();
2286     named_input_shapes[name] = iter.second;
2287   }
2288   return OkStatus();
2289 }
2290 
2291 Status TPUPartitionedCallOp::GetGraphFromFunction(
2292     Graph* graph, int device_ordinal, bool* use_spmd_for_xla_partitioning,
2293     TPUMetadata* tpu_metadata) {
2294   FunctionLibraryRuntime::InstantiateOptions opts;
2295   FHandle handle;
2296   TF_RETURN_IF_ERROR(library_runtime_->Instantiate(
2297       func_.name(), AttrSlice(&func_.attr()), opts, &handle));
2298   const FunctionBody* fbody = library_runtime_->GetFunctionBody(handle);
2299   if (fbody == nullptr) {
2300     return errors::Internal("Could not find handle ", handle);
2301   }
2302   CopyGraph(*fbody->graph, graph);
2303 
2304   // Pin the inputs and outputs to the local device to simplify the
2305   // function-dispatching logic.
2306   local_device_name_ = library_runtime_->device()->name();
2307   replaced_input_indices_.resize(fbody->arg_nodes.size(), false);
2308   for (Node* node : graph->op_nodes()) {
2309     if (node->IsArg() || node->IsRetval()) {
2310       node->set_assigned_device_name(local_device_name_);
2311     } else if (node->type_string() == "TPUReplicateMetadata") {
2312       // Record the producer name so it can be accessed later during metric
2313       // collection.
2314       string producer_name = GetProducerName(func_.name());
2315       node->AddAttr("_producer_name", producer_name);
2316 
2317       TF_RETURN_IF_ERROR(GetNodeAttr(node->attrs(), "num_cores_per_replica",
2318                                      &tpu_metadata->num_cores_per_replica));
2319       TF_RETURN_IF_ERROR(GetNodeAttr(node->attrs(),
2320                                      "use_spmd_for_xla_partitioning",
2321                                      use_spmd_for_xla_partitioning));
2322       VLOG(1) << "num_cores_per_replica = "
2323               << tpu_metadata->num_cores_per_replica
2324               << ", use_spmd_for_xla_partitioning = "
2325               << *use_spmd_for_xla_partitioning;
2326 
2327       if (tpu_metadata->num_cores_per_replica > 1) {
2328         int num_replicas;
2329         TF_RETURN_IF_ERROR(
2330             GetNodeAttr(node->attrs(), "num_replicas", &num_replicas));
2331         if (num_replicas > 1) {
2332           return errors::InvalidArgument(
2333               "num_replicas shouldn't be large than 1, however it is: ",
2334               num_replicas);
2335         }
2336 
2337         TF_RETURN_IF_ERROR(GetNodeAttr(node->attrs(), "device_assignment",
2338                                        &tpu_metadata->device_assignment));
2339 
2340         if (!tpu_metadata->device_assignment.empty() && device_ordinal > 0) {
2341           return errors::InvalidArgument(
2342               "`device_assignment` shouldn't be set manually in the graph when "
2343               "round-robin core selection is enabled.");
2344         }
2345 
2346         tpu_metadata->topology = GetTPUTopology();
2347         VLOG(1) << "TPU topology: " << tpu_metadata->topology.DebugString();
2348         std::string topology_str;
2349         TF_RETURN_IF_ERROR(
2350             GetNodeAttr(node->attrs(), "topology", &topology_str));
2351         if (!topology_str.empty()) {
2352           LOG(WARNING)
2353               << "Ignore the `topology` value set in TPUReplicateMetadata "
2354                  "node, the TPU topology is queried in the runtime.";
2355         }
2356         node->ClearAttr("topology");
2357         node->AddAttr("topology", tpu_metadata->topology.SerializeAsString());
2358 
2359         if (tpu_metadata->topology.num_tasks() > 1) {
2360           return errors::InvalidArgument(
2361               "TPUPartitionedCallOp is only supported in single-host setup, "
2362               "however num_task is: ",
2363               tpu_metadata->topology.num_tasks());
2364         }
2365 
2366         if (tpu_metadata->device_assignment.empty()) {
2367           VLOG(1) << "Auto assigning device assignment";
2368 
2369           // The auto generated device assignment should be the same as or a
2370           // slice of TPU topology device_coordinates. This guarantees the
2371           // logical device IDs order the same as the physical device IDs order.
2372           // It is important for round-robin core selection, as we assume
2373           // the TPU device group for one inference request is
2374           // [TPU:device_ordinal, TPU:device_ordinal + num_cores_per_replica].
2375 
2376           auto coordinates_start =
2377               tpu_metadata->topology.device_coordinates().begin() +
2378               device_ordinal * 4;
2379           auto coordinates_end =
2380               tpu_metadata->topology.device_coordinates().begin() +
2381               (device_ordinal + tpu_metadata->num_cores_per_replica) * 4;
2382 
2383           node->ClearAttr("device_assignment");
2384           tpu_metadata->device_assignment.insert(
2385               tpu_metadata->device_assignment.begin(), coordinates_start,
2386               coordinates_end);
2387           node->AddAttr("device_assignment", tpu_metadata->device_assignment);
2388         }
2389       }
2390     }
2391   }
2392   return OkStatus();
2393 }
2394 
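// [Editor's illustrative sketch; not part of the original file.] When no
// device_assignment is set, GetGraphFromFunction takes the contiguous slice of
// the topology's device_coordinates covering cores [device_ordinal,
// device_ordinal + num_cores_per_replica), with four coordinates (x, y, z,
// core) per logical device. A hypothetical helper for that slice:
inline std::vector<int> ExampleAutoDeviceAssignment(
    const tpu::TopologyProto& topology, int device_ordinal,
    int num_cores_per_replica) {
  auto begin = topology.device_coordinates().begin() + device_ordinal * 4;
  auto end = topology.device_coordinates().begin() +
             (device_ordinal + num_cores_per_replica) * 4;
  return std::vector<int>(begin, end);
}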
2395 Status TPUPartitionedCallOp::PlacementHelper(
2396     const DeviceSet& device_set,
2397     const GraphOptimizationPassOptions& optimization_options,
2398     const string& function_name) {
2399   TF_RETURN_IF_ERROR(OptimizationPassRegistry::Global()->RunGrouping(
2400       OptimizationPassRegistry::PRE_PLACEMENT, optimization_options));
2401   Placer placer(optimization_options.graph->get(), function_name,
2402                 optimization_options.flib_def, &device_set);
2403   TF_RETURN_IF_ERROR(placer.Run());
2404   TF_RETURN_IF_ERROR(OptimizationPassRegistry::Global()->RunGrouping(
2405       OptimizationPassRegistry::POST_PLACEMENT, optimization_options));
2406   TF_RETURN_IF_ERROR(OptimizationPassRegistry::Global()->RunGrouping(
2407       OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, optimization_options));
2408   return OkStatus();
2409 }
2410 
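// PartitionHelper (added commentary): splits the placed graph into one
// GraphDef per assigned device name, converts each piece back into a Graph,
// and stores it in `subgraphs` keyed by device. Cross-device edges are
// rewired through generated _Send/_Recv pairs; the `new_name` callback below
// names those edges "<prefix>/_1", "<prefix>/_2", and so on.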
2411 Status TPUPartitionedCallOp::PartitionHelper(
2412     const DeviceSet& device_set,
2413     const GraphOptimizationPassOptions& optimization_options, Graph* graph,
2414     std::unordered_map<std::string, std::unique_ptr<Graph>>* subgraphs) {
2415   PartitionOptions partition_options;
2416   partition_options.node_to_loc = [](const Node* node) {
2417     // TODO(akshayka): To better support the distributed case, first split
2418     // the graph by worker (e.g., using the master session's
2419     // `SplitByWorker` policy), and then recursively partition the
2420     // per-worker shards at the remote worker(s).
2421     return node->assigned_device_name();
2422   };
2423   int64_t edge_name_counter = 0;
2424   partition_options.new_name = [&edge_name_counter](const string& prefix) {
2425     return strings::StrCat(prefix, "/_", ++edge_name_counter);
2426   };
2427   partition_options.get_incarnation = [&device_set](const string& name) {
2428     const Device* d = device_set.FindDeviceByName(name);
2429     if (d == nullptr) {
2430       return PartitionOptions::kIllegalIncarnation;
2431     } else {
2432       return d->attributes().incarnation();
2433     }
2434   };
2435   partition_options.control_flow_added = false;
2436   std::unordered_map<std::string, GraphDef> partitions;
2437   TF_RETURN_IF_ERROR(Partition(partition_options, graph, &partitions));
2438 
2439   VLOG(3) << "Partitioned function '" << func_.name() << "', yielding "
2440           << partitions.size() << " shards.";
2441 
2442   const FunctionLibraryDefinition* flib_def = &graph->flib_def();
2443   for (auto& partition : partitions) {
2444     std::unique_ptr<Graph> subgraph(new Graph(flib_def));
2445     GraphConstructorOptions opts;
2446     opts.allow_internal_ops = true;
2447     opts.expect_device_spec = true;
2448     const string& device = partition.first;
2449     GraphDef& graph_def = partition.second;
2450     TF_RETURN_IF_ERROR(
2451         ConvertGraphDefToGraph(opts, std::move(graph_def), subgraph.get()));
2452     subgraphs->emplace(device, std::move(subgraph));
2453   }
2454 
2455   TF_RETURN_IF_ERROR(OptimizationPassRegistry::Global()->RunGrouping(
2456       OptimizationPassRegistry::POST_PARTITIONING, optimization_options));
2457 
2458   return OkStatus();
2459 }
2460 
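// InstantiatePartition (added commentary): converts one per-device subgraph
// into a FunctionDef, adds it to the op's function library, and instantiates
// it on `target_device`. When `out_flib_def` is provided, instantiation uses
// a private copy of the library so concurrent instantiations do not race on
// the shared flib_def_.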
2461 Status TPUPartitionedCallOp::InstantiatePartition(
2462     const Graph& graph, const string& function_name,
2463     const string& target_device, FHandle* handle,
2464     std::unique_ptr<FunctionLibraryDefinition>* out_flib_def) {
2465   FunctionDef shard;
2466   TF_RETURN_IF_ERROR(GraphToFunctionDef(graph, function_name, &shard));
2467   TF_RETURN_IF_ERROR(flib_def_->AddFunctionDef(shard));
2468   FunctionLibraryRuntime::InstantiateOptions opts;
2469   opts.target = target_device;
2470   if (out_flib_def) {
2471     *out_flib_def = std::make_unique<FunctionLibraryDefinition>(*flib_def_);
2472     opts.lib_def = out_flib_def->get();
2473   } else {
2474     opts.lib_def = flib_def_.get();
2475   }
2476   return library_runtime_->Instantiate(function_name, AttrSlice(&shard.attr()),
2477                                        opts, handle);
2478 }
2479 
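// SetDeviceOrdinal (added commentary): rewrites a subgraph so that it targets
// `device_ordinal`: TPU-bound VarHandleOps get their assigned device updated,
// outside-compilation (host-transfer) nodes and nodes carrying the device
// ordinal attribute are retargeted, and _Send/_Recv nodes have their
// send/recv device strings and send incarnations refreshed. Sets *modified
// when anything changed and rejects graphs that mix more than one ordinal.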
2480 Status TPUPartitionedCallOp::SetDeviceOrdinal(const DeviceSet& device_set,
2481                                               int device_ordinal, Graph* graph,
2482                                               bool* modified) {
2483   int ordinal = -1;
2484   for (Node* node : graph->op_nodes()) {
2485     if (node->type_string() == kVarHandleOp) {
2486       if (IsInputToTPUReplicate(node)) {
2487         // If this VarHandleOp is going to a TPU computation,
2488         // it refers to the TPU variable that we created when replacing the
2489         // resource arguments with VarHandleOps.
2490         node->set_assigned_device_name(
2491             strings::StrCat(kTPUDeviceNamePrefix, device_ordinal));
2492       }
2493       continue;
2494     }
2495     if (HasNodeAttr(node->def(), kXlaHasHostTransferAttrName)) {
2496       // Outside compilation related node.
2497       TF_RETURN_IF_ERROR(
2498           SetDeviceOrdinalAttributeForNode(node, device_ordinal));
2499       *modified = true;
2500       continue;
2501     }
2502     const AttrValue* attr = node->attrs().Find(kDeviceOrdinalAttr);
2503     if (attr != nullptr) {
2504       if (!IsSupportedTPUOp(node->type_string())) {
2505         return errors::InvalidArgument("Node ", node->type_string(),
2506                                        " is not yet supported.");
2507       }
2508       if (ordinal == -1) {
2509         ordinal = attr->i();
2510       } else {
2511         if (ordinal != attr->i()) {
2512           return errors::InvalidArgument(
2513               "Can only partition graphs that use a single device ordinal.");
2514         }
2515       }
2516       node->ClearAttr(kDeviceOrdinalAttr);
2517       node->AddAttr(kDeviceOrdinalAttr, device_ordinal);
2518       VLOG(3) << "Set device ordinal of " << node->type_string() << " to "
2519               << device_ordinal;
2520       *modified = true;
2521     }
2522     if (node->IsSend() || node->IsRecv()) {
2523       static const char* kSendDevice = "send_device";
2524       static const char* kSendDeviceIncarnation = "send_device_incarnation";
2525       static const char* kRecvDevice = "recv_device";
2526       const AttrValue* attr = node->attrs().Find(kSendDevice);
2527       if (attr != nullptr) {
2528         string device = attr->s();
2529         TF_RETURN_IF_ERROR(
2530             UpdateTPUDeviceOrdinal(device_ordinal, &device, modified));
2531         node->ClearAttr(kSendDevice);
2532         node->AddAttr(kSendDevice, device);
2533         node->ClearAttr(kSendDeviceIncarnation);
2534         const Device* d = device_set.FindDeviceByName(device);
2535         int64_t send_incarnation = (d == nullptr)
2536                                        ? PartitionOptions::kIllegalIncarnation
2537                                        : d->attributes().incarnation();
2538         node->AddAttr(kSendDeviceIncarnation, send_incarnation);
2539       }
2540       attr = node->attrs().Find(kRecvDevice);
2541       if (attr != nullptr) {
2542         string device = attr->s();
2543         TF_RETURN_IF_ERROR(
2544             UpdateTPUDeviceOrdinal(device_ordinal, &device, modified));
2545         node->ClearAttr(kRecvDevice);
2546         node->AddAttr(kRecvDevice, device);
2547       }
2548     }
2549   }
2550   return OkStatus();
2551 }
2552 
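// InstantiateFunctionsFromSubgraphs (added commentary): for each per-device
// subgraph, determines the device ordinal (the replica id, or the device id
// parsed from the target when num_cores_per_replica > 1), optionally rewrites
// the ordinal into the subgraph, instantiates the shard as a function, and
// records the resulting (device, handle, flib_def) entry under `cache_hash`
// in partition_cache_ so later calls with the same hash can reuse it.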
2553 Status TPUPartitionedCallOp::InstantiateFunctionsFromSubgraphs(
2554     const DeviceSet& device_set, int replica_id, uint64 cache_hash,
2555     int num_cores_per_replica,
2556     std::unordered_map<std::string, std::unique_ptr<Graph>> subgraphs) {
2557   const Device* reference_device = nullptr;
2558   auto entry =
2559       partition_cache_.emplace(cache_hash, std::vector<DeviceAndFHandle>());
2560 
2561   bool rewritten = false;
2562   for (auto& pair : subgraphs) {
2563     string target = pair.first;
2564     int device_ordinal = replica_id;
2565     if (num_cores_per_replica > 1) {
2566       DeviceNameUtils::ParsedName parsed_device;
2567       if (!DeviceNameUtils::ParseFullName(target, &parsed_device)) {
2568         return errors::InvalidArgument("Malformed assigned device '", target,
2569                                        "'");
2570       }
2571       device_ordinal = parsed_device.id;
2572     }
2573     Device* device;
2574     TF_RETURN_IF_ERROR(
2575         library_runtime_->device_mgr()->LookupDevice(target, &device));
2576     if (reference_device == nullptr) {
2577       reference_device = device;
2578     } else {
2579       if (!DeviceNameUtils::IsSameAddressSpace(
2580               device->parsed_name(), reference_device->parsed_name())) {
2581         return errors::InvalidArgument(
2582             "TPUPartitionedCallOp does not yet support inter-process "
2583             "execution.");
2584       }
2585     }
2586     TF_RETURN_IF_ERROR(device->MaybeRewriteGraph(&pair.second));
2587     Graph* subgraph = pair.second.get();
2588     // For model-parallelism inference, only num_replicas == 1 is supported,
2589     // so there is no need to update the device_ordinal here.
2590     if (num_cores_per_replica == 1) {
2591       TF_RETURN_IF_ERROR(
2592           SetDeviceOrdinal(device_set, device_ordinal, subgraph, &rewritten));
2593     } else {
2594       VLOG(1) << "Skip SetDeviceOrdinal()";
2595     }
2596     string function_name = flib_def_->UniqueFunctionName(
2597         strings::StrCat(func_.name(), "_hash_", cache_hash));
2598     TF_RETURN_IF_ERROR(
2599         UpdateTPUDeviceOrdinal(device_ordinal, &target, &rewritten));
2600     FHandle handle;
2601     // Use a copy of the current `flib_def_` to instantiate the function to
2602     // avoid races.
2603     std::unique_ptr<FunctionLibraryDefinition> sub_flib_def;
2604     TF_RETURN_IF_ERROR(InstantiatePartition(*subgraph, function_name, target,
2605                                             &handle, &sub_flib_def));
2606     // Add handle to the cache entry.
2607     entry.first->second.push_back(
2608         DeviceAndFHandle{.device = target,
2609                          .handle = handle,
2610                          .flib_def = std::move(sub_flib_def)});
2611   }
2612 
2613   if (!rewritten) {
2614     // For regular use cases, TPUPartitionedCallOp only works when the
2615     // function being called is rewritten for TPU. If we don't see any signs
2616     // of this rewriting, warn the user about it.
2617     // We don't raise an error because we want to support the use case of
2618     // running tpu.initialize_system eagerly. In that case, we can't use
2619     // tpu.rewrite because it would add compilation ops that require the TPU
2620     // to already be initialized, i.e., there is a chicken-and-egg problem.
2621     // We run tpu.initialize_system through TPUPartitionedCallOp because it
2622     // invokes the graph rewrite passes that are necessary for initialization
2623     // to work.
2624     LOG(INFO) << "Function body was not rewritten for TPU. "
2625               << "This is probably a bug unless you are initializing "
2626               << "TPUs eagerly.";
2627   }
2628   return OkStatus();
2629 }
2630 
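// ExecuteRemoteFunction (added commentary): runs a shard whose target device
// is not the op's local device (the caller sets opts.remote_execution). The
// shard is invoked with no explicit arguments or results; tensors flow
// through the rendezvous set up by ExecuteFunctions, so only the completion
// status is propagated to `done`.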
2631 void TPUPartitionedCallOp::ExecuteRemoteFunction(
2632     const FunctionLibraryRuntime::Options& opts, FHandle handle,
2633     OpKernelContext* ctx, ReffedStatusCallback* done) {
2634   std::vector<Tensor> dummy_args;
2635   std::vector<Tensor>* dummy_rets = new std::vector<Tensor>;
2636 
2637   profiler::TraceMe trace_me("TPUPartitionedCallOp-ExecuteRemote");
2638   library_runtime_->Run(opts, handle, dummy_args, dummy_rets,
2639                         [dummy_rets, done, ctx](const Status& status) {
2640                           if (!status.ok()) {
2641                             done->UpdateStatus(status);
2642                           }
2643                           delete dummy_rets;
2644                           done->Unref();
2645                         });
2646 }
2647 
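// ExecuteLocalFunction (added commentary): runs the shard assigned to the
// local device. Resource arguments that were replaced by TPU VarHandleOps are
// filtered out of `arguments` (see replaced_input_indices_), and the shard's
// return values become the op's outputs via ctx->set_output.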
2648 void TPUPartitionedCallOp::ExecuteLocalFunction(
2649     const FunctionLibraryRuntime::Options& opts, const OpInputList& arguments,
2650     FHandle handle, OpKernelContext* ctx, ReffedStatusCallback* done) {
2651   std::vector<Tensor> args;
2652 
2653   for (int i = 0; i < arguments.size(); ++i) {
2654     if (!replaced_input_indices_[i]) {
2655       // _Arg nodes of type DT_RESOURCE that feed a TPU node have been
2656       // replaced by TPU VarHandleOp nodes, so they no longer need to be
2657       // passed as inputs.
2658       args.push_back(arguments[i]);
2659     }
2660   }
2661   auto* rets = new std::vector<Tensor>;
2662 
2663   profiler::TraceMe trace_me("TPUPartitionedCallOp-ExecuteLocal");
2664   library_runtime_->Run(opts, handle, args, rets,
2665                         [rets, done, ctx](const Status& status) {
2666                           if (!status.ok()) {
2667                             done->UpdateStatus(status);
2668                           } else {
2669                             for (int i = 0; i < rets->size(); ++i) {
2670                               ctx->set_output(i, (*rets)[i]);
2671                             }
2672                           }
2673                           delete rets;
2674                           done->Unref();
2675                         });
2676 }
2677 
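// ExecuteFunctions (added commentary): fans all instantiated shards out
// through the function library runtime. A single ReffedStatusCallback is
// shared by the shards (one extra Ref per additional shard); once every shard
// finishes, the wrapped callback releases the rendezvous and cancellation
// manager, reports any failure to the kernel context, signals `done`, and,
// when round-robin core selection is active, dequeues the request from the
// core selector.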
2678 void TPUPartitionedCallOp::ExecuteFunctions(
2679     const std::vector<DeviceAndFHandle>& functions, OpKernelContext* ctx,
2680     int device_ordinal, int64_t ordinal_selector_req_id, DoneCallback done) {
2681   profiler::TraceMe trace_me("TPUPartitionedCallOp-ExecuteFunctions");
2682   FunctionLibraryRuntime::Options opts;
2683   opts.step_container = ctx->step_container();
2684   opts.stats_collector = ctx->stats_collector();
2685   // TODO(akshayka): Consider selecting a runner on a per-device basis,
2686   // i.e., using device-specific threadpools when available.
2687   opts.runner = ctx->runner();
2688   opts.source_device = local_device_name_;
2689   opts.run_all_kernels_inline = ctx->run_all_kernels_inline();
2690 
2691   OpInputList arguments;
2692   OP_REQUIRES_OK_ASYNC(ctx, ctx->input_list("args", &arguments), done);
2693 
2694   auto* local_cm = new CancellationManager(ctx->cancellation_manager());
2695   auto* rendez = new RefCountedIntraProcessRendezvous(device_mgr_);
2696   opts.cancellation_manager = local_cm;
2697   opts.rendezvous = rendez;
2698 
2699   StatusCallback callback(
2700       [rendez = rendez, local_cm, done = std::move(done),
2701        device_ordinal = device_ordinal, req_id = ordinal_selector_req_id, ctx,
2702        ordinal_selector = ordinal_selector_](const Status& status) {
2703         delete local_cm;
2704         rendez->Unref();
2705         if (!status.ok()) {
2706           ctx->SetStatus(status);
2707         }
2708         done();
2709         if (req_id >= 0) {
2710           ordinal_selector->DequeueFromCoreSelector(device_ordinal, req_id);
2711         }
2712       });
2713 
2714   auto* refcounted_done = new ReffedStatusCallback(std::move(callback));
2715   for (int i = 1; i < functions.size(); ++i) {
2716     refcounted_done->Ref();
2717   }
2718   for (const DeviceAndFHandle& entry : functions) {
2719     const string& target_device = entry.device;
2720     FHandle handle = entry.handle;
2721     VLOG(3) << "Running function shard on device " << target_device
2722             << " with local device name " << local_device_name_;
2723     if (target_device == local_device_name_) {
2724       opts.remote_execution = false;
2725       ExecuteLocalFunction(opts, arguments, handle, ctx, refcounted_done);
2726     } else {
2727       opts.remote_execution = true;
2728       ExecuteRemoteFunction(opts, handle, ctx, refcounted_done);
2729     }
2730   }
2731 }
2732 
2733 REGISTER_KERNEL_BUILDER(Name("TPUPartitionedCall").Device(DEVICE_CPU),
2734                         TPUPartitionedCallOp);
2735 
2736 }  // end namespace tensorflow
2737