1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/tpu/kernels/tpu_functional_ops.h"
17
18 #include <algorithm>
19 #include <memory>
20
21 #include "absl/strings/match.h"
22 #include "tensorflow/core/framework/cancellation.h"
23 #include "tensorflow/core/framework/op_kernel.h"
24 #include "tensorflow/core/protobuf/tpu/topology.pb.h"
25 #include "tensorflow/stream_executor/tpu/c_api_decl.h"
26 #include "tensorflow/stream_executor/tpu/tpu_platform_interface.h"
27
28 #define EIGEN_USE_THREADS
29
30 #include "absl/base/call_once.h"
31 #include "absl/strings/str_cat.h"
32 #include "absl/synchronization/mutex.h"
33 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
34 #include "tensorflow/compiler/tf2xla/sharding_util.h"
35 #include "tensorflow/compiler/tf2xla/side_effect_util.h"
36 #include "tensorflow/compiler/xla/xla_data.pb.h"
37 #include "tensorflow/core/common_runtime/function_body.h"
38 #include "tensorflow/core/common_runtime/graph_constructor.h"
39 #include "tensorflow/core/common_runtime/placer.h"
40 #include "tensorflow/core/common_runtime/rendezvous_mgr.h"
41 #include "tensorflow/core/framework/graph_to_functiondef.h"
42 #include "tensorflow/core/framework/metrics.h"
43 #include "tensorflow/core/framework/node_def.pb.h"
44 #include "tensorflow/core/framework/node_def_util.h"
45 #include "tensorflow/core/framework/resource_mgr.h"
46 #include "tensorflow/core/framework/resource_var.h"
47 #include "tensorflow/core/framework/tensor.h"
48 #include "tensorflow/core/framework/tensor.pb.h"
49 #include "tensorflow/core/framework/tensor_shape.h"
50 #include "tensorflow/core/graph/graph_partition.h"
51 #include "tensorflow/core/graph/node_builder.h"
52 #include "tensorflow/core/lib/core/errors.h"
53 #include "tensorflow/core/lib/hash/hash.h"
54 #include "tensorflow/core/lib/strings/str_util.h"
55 #include "tensorflow/core/platform/blocking_counter.h"
56 #include "tensorflow/core/platform/errors.h"
57 #include "tensorflow/core/platform/fingerprint.h"
58 #include "tensorflow/core/platform/refcount.h"
59 #include "tensorflow/core/profiler/lib/traceme.h"
60 #include "tensorflow/core/tpu/kernels/tpu_compile_op_support.h"
61 #include "tensorflow/core/tpu/kernels/tpu_fingerprint_lookup.h"
62 #include "tensorflow/core/tpu/kernels/tpu_op_consts.h"
63 #include "tensorflow/core/tpu/kernels/tpu_op_util.h"
64 #include "tensorflow/core/tpu/kernels/tpu_util.h"
65 #include "tensorflow/core/tpu/tpu_configuration.h"
66 #include "tensorflow/core/tpu/tpu_defs.h"
67 #include "tensorflow/core/util/dump_graph.h"
68
69 namespace tensorflow {
70 namespace {
71
72 constexpr char kTpuReplicateAttr[] = "_tpu_replicate";
73 constexpr int kLastDimOfTpuInputFastPath = 128;
74 constexpr int kOtherDimOfTpuInputFastPath = 8;
75
76 constexpr char kXLAShardingAttrName[] = "sharding";
77 constexpr char kXLAShardingAttrAltName[] = "_XlaSharding";
78
79 tpu::TopologyProto GetTPUTopology() {
80 const tpu::TpuTopologyExternal& topology =
81 tpu::TpuPlatformInterface::GetRegisteredPlatform()->topology();
82
83 tpu::TopologyProto topology_proto;
84 topology_proto.set_num_tasks(topology.HostCount());
85 topology_proto.set_num_tpu_devices_per_task(
86 topology.LogicalDevicesPerHost(TpuCoreTypeEnum::kTensorCore));
87
88 // mesh shape.
89 int devices_per_chip =
90 topology.LogicalDevicesPerChip(TpuCoreTypeEnum::kTensorCore);
91 topology_proto.add_mesh_shape(topology.chip_bounds().x);
92 topology_proto.add_mesh_shape(topology.chip_bounds().y);
93 topology_proto.add_mesh_shape(topology.chip_bounds().z);
94 topology_proto.add_mesh_shape(devices_per_chip);
95
96 // device coordinates.
97 for (const tpu::TpuCoreLocationExternal& core :
98 topology.cores(TpuCoreTypeEnum::kTensorCore)) {
99 const tpu::TpuDimensionsExternal coords = core.chip_coordinates();
100 topology_proto.add_device_coordinates(coords.x);
101 topology_proto.add_device_coordinates(coords.y);
102 topology_proto.add_device_coordinates(coords.z);
103 topology_proto.add_device_coordinates(core.index());
104 }
105
106 return topology_proto;
107 }
108
109 struct TPUVariableInfo {
110 TPUVariableInfo(int device_ordinal_id, bool use_fast_mem)
111 : device_ordinal(device_ordinal_id), fast_mem(use_fast_mem) {}
112 // The TPU core which the variable will be placed on.
113 int device_ordinal;
114 // If true, try to place the variable in fast memory space if the hardware
115 // supports it.
116 bool fast_mem;
117 };
118
119 // Check the descendants to parse the placement information for the input node.
120 // num_cores_per_replica describes how many cores a single model uses.
121 Status ParseTPUVariableInfor(const Node* node, const int num_cores_per_replica,
122 TPUVariableInfo* var_info) {
123 int core = 0;
124 bool use_fast_mem = false;
125 VLOG(3) << "Parse tpu variable information for " << node->name();
126 for (const Edge* edge : node->out_edges()) {
127 if (edge->IsControlEdge()) continue;
128 Node* next = edge->dst();
129 VLOG(3) << "Neighbor node " << next->name();
130 // Looking through Enter/Switch/ReadVariableOp nodes.
131 while (next->IsEnter() || next->IsSwitch() ||
132 next->type_string() == "ReadVariableOp") {
133 Node* new_node = nullptr;
134 for (const Edge* e : next->out_edges()) {
135 if (!e->IsControlEdge()) {
136 new_node = e->dst();
137 break;
138 }
139 }
140 if (new_node == nullptr) break;
141 next = new_node;
142 }
143 if (next != edge->dst()) {
144 VLOG(3) << "Looked through Enter/Switch node " << next->DebugString();
145 }
146 TF_ASSIGN_OR_RETURN(absl::optional<xla::OpSharding> sharding,
147 ParseShardingFromDevice(*next, num_cores_per_replica,
148 /*add_metadata=*/false));
149 if (sharding.has_value() && sharding->tile_assignment_devices_size() > 0) {
150 core = sharding->tile_assignment_devices(0);
151 VLOG(3) << next->name() << " is placed on core " << core;
152 }
153 if (next->attrs().Find(TPU_FAST_MEM_ATTR) != nullptr) {
154 use_fast_mem = true;
155 VLOG(3) << next->name() << " has " << TPU_FAST_MEM_ATTR << " attribute";
156 }
157 }
158 VLOG(1) << "Place " << node->name() << " to core: " << core
159 << " fast_mem: " << use_fast_mem;
160 var_info->device_ordinal = core;
161 var_info->fast_mem = use_fast_mem;
162
163 return OkStatus();
164 }
165
166 // Helper to instantiate function "func" in the library "lib".
167 Status Instantiate(FunctionLibraryRuntime* lib, const NameAttrList& func,
168 FunctionLibraryRuntime::Handle* handle) {
169 return lib->Instantiate(func.name(), AttrSlice(&func.attr()), handle);
170 }
171
172 static constexpr const char* const kDeviceOrdinalAttr = "device_ordinal";
173
174 static constexpr const char* const kTPUExecuteOp = "TPUExecute";
175 static constexpr const char* const kInfeedEnqueueOp = "InfeedEnqueue";
176 static constexpr const char* const kInfeedEnqueueTupleOp = "InfeedEnqueueTuple";
177 static constexpr const char* const kOutfeedDequeueOp = "OutfeedDequeue";
178 static constexpr const char* const kOutfeedDequeueTupleOp =
179 "OutfeedDequeueTuple";
180 static constexpr const char* const kOutfeedDequeueV2Op = "OutfeedDequeueV2";
181 static constexpr const char* const kOutfeedDequeueTupleV2Op =
182 "OutfeedDequeueTupleV2";
183 static constexpr const char* const kVarHandleOp = "VarHandleOp";
184
185 static constexpr const char* const kTPUDeviceNamePrefix = "/device:TPU:";
186 static constexpr const int kTPUDefaultDeviceOrdinal = 0;
187
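// Returns true if `op_name` is one of the TPU ops handled specially here:
// TPUExecute, the infeed enqueue ops, and the outfeed dequeue ops.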
188 bool IsSupportedTPUOp(const string& op_name) {
189 return op_name == kTPUExecuteOp || op_name == kInfeedEnqueueOp ||
190 op_name == kInfeedEnqueueTupleOp || op_name == kOutfeedDequeueOp ||
191 op_name == kOutfeedDequeueTupleOp || op_name == kOutfeedDequeueV2Op ||
192 op_name == kOutfeedDequeueTupleV2Op;
193 }
194
195 // Sets the sharding attributes for an XlaSharding node.
196 void SetXlaShardingNodeAttr(Node* xla_sharding_node, int num_cores_per_replica,
197 int rank, int shard_dim) {
198 auto sharding = absl::make_optional<xla::OpSharding>();
199 sharding->set_type(xla::OpSharding::OTHER);
200
201 std::vector<int64_t> dims(rank, 1LL);
202 dims[shard_dim] = num_cores_per_replica;
203 for (auto dim : dims) {
204 sharding->add_tile_assignment_dimensions(dim);
205 }
206
207 // Sets up tile_assignment_devices.
208 for (int d = 0; d < num_cores_per_replica; ++d) {
209 sharding->add_tile_assignment_devices(d);
210 }
211
212 xla_sharding_node->ClearAttr(kXLAShardingAttrName);
213 xla_sharding_node->ClearAttr(kXLAShardingAttrAltName);
214 xla_sharding_node->AddAttr(kXLAShardingAttrName,
215 sharding->SerializeAsString());
216 xla_sharding_node->AddAttr(kXLAShardingAttrAltName,
217 sharding->SerializeAsString());
218 }
219
220 // If 'device_name' is a TPU device, set its device_ordinal to 'device_ordinal'
221 // and set '*rewritten' to true. Otherwise, do nothing.
222 Status UpdateTPUDeviceOrdinal(int device_ordinal, string* device_name,
223 bool* rewritten) {
224 DeviceNameUtils::ParsedName device;
225 if (!DeviceNameUtils::ParseFullName(*device_name, &device)) {
226 return errors::InvalidArgument("Unable to parse device name ",
227 *device_name);
228 }
229 if (device.type == DEVICE_TPU_NODE) {
230 device.id = device_ordinal;
231 *rewritten = true;
232 }
233 *device_name = DeviceNameUtils::ParsedNameToString(device);
234 return OkStatus();
235 }
236
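// Starting from `arg_node`, follows chains of non-control out-edges to find an
// edge that crosses from the host into the TPU cluster (source without the
// _tpu_replicate attribute, destination with it). Gives up on paths through
// GuaranteeConst and returns nullptr if no such edge is found.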
237 const Edge* FindHostToDeviceEdge(Node* arg_node) {
238 const Edge* candidate_edge = nullptr;
239 for (const Edge* edge : arg_node->out_edges())
240 if (!edge->IsControlEdge()) {
241 // Find CPU -> TPU input edge.
242 const Edge* original_edge;
243 while (edge->src()->attrs().Find(kTpuReplicateAttr) != nullptr ||
244 edge->dst()->attrs().Find(kTpuReplicateAttr) == nullptr) {
245 const Node* new_src = edge->dst();
246 original_edge = edge;
247 for (const Edge* new_edge : new_src->out_edges())
248 if (!new_edge->IsControlEdge()) {
249 original_edge = edge;
250 edge = new_edge;
251 break;
252 }
253 if (original_edge == edge) break;
254 }
255 // TPU input edge: src is on CPU and dest is on TPU.
256 if (edge->src()->attrs().Find(kTpuReplicateAttr) != nullptr ||
257 edge->dst()->attrs().Find(kTpuReplicateAttr) == nullptr)
258 continue;
259 // Won't work with GuaranteeConst.
260 if (edge->src()->type_string() == "GuaranteeConst") break;
261 candidate_edge = edge;
262 }
263 return candidate_edge;
264 }
265
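// Inserts an Identity node ("<src name>/proxy", placed in the TPU cluster)
// between `candidate_edge`'s source and all of its TPU-bound consumers so that
// a single host-to-TPU edge carries the tensor. That edge (original source ->
// proxy Identity) is returned through `tpu_input_edge`.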
266 Status CreateInputProxy(Graph* graph, const Edge* candidate_edge,
267 const Edge** tpu_input_edge) {
268 std::vector<const Edge*> edges_to_replace;
269 for (const Edge* input_edge : candidate_edge->src()->out_edges()) {
270 if (!input_edge->IsControlEdge() &&
271 input_edge->dst()->attrs().Find(kTpuReplicateAttr) != nullptr)
272 edges_to_replace.push_back(input_edge);
273 }
274 // Build an Identity node as the proxy of the original edge source.
275 Node* input_identity_node = nullptr;
276 TF_RETURN_IF_ERROR(
277 NodeBuilder(strings::StrCat(candidate_edge->src()->name(), "/proxy"),
278 "Identity")
279 .Input(candidate_edge->src())
280 .Attr("T", candidate_edge->src()->output_type(0))
281 .Attr(kTpuReplicateAttr,
282 candidate_edge->dst()->attrs().Find(kTpuReplicateAttr)->s())
283 .Finalize(graph, &input_identity_node));
284 // Find the tpu input edge from original source to proxy identity.
285 for (const Edge* input_edge : input_identity_node->in_edges())
286 if (input_edge->src() == candidate_edge->src()) {
287 *tpu_input_edge = input_edge;
288 break;
289 }
290 // Replace original input edges with proxy's output.
291 for (const Edge* input_edge : edges_to_replace) {
292 graph->RemoveEdge(input_edge);
293 graph->AddEdge(input_identity_node, 0, input_edge->dst(),
294 input_edge->dst_input());
295 }
296 return OkStatus();
297 }
298
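// Extracts the _tpu_replicate cluster name from `graph`. Returns an error if
// more than one distinct cluster name is found, since the optimization only
// supports a single TPU cluster.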
299 Status GetClusterName(Graph* graph, string* cluster_name) {
300 *cluster_name = "";
301 for (const Node* node : graph->nodes()) {
302 if (node->attrs().Find(kTpuReplicateAttr) == nullptr) continue;
303 if (cluster_name->empty())
304 *cluster_name = node->attrs().Find(kTpuReplicateAttr)->s();
305 // When optimization is turned on, the graph should only have one TPU
306 // cluster.
307 if (*cluster_name != node->attrs().Find(kTpuReplicateAttr)->s())
308 return errors::FailedPrecondition(
309 "Only one cluster is allowed when optimization is turned on for "
310 "TPUPartitionedCall. Found ",
311 node->attrs().Find(kTpuReplicateAttr)->s(), " and ", *cluster_name);
312 }
313 return OkStatus();
314 }
315
316 // Removes nodes that have no effect and that directly descend from _Arg nodes.
317 //
318 // This is currently used for removing TPUReplicatedInput and XlaSharding nodes,
319 // which are always descendants of _Arg nodes. During optimization, we try to
320 // insert nodes in between _Arg and _Arg's children, where some of the nodes
321 // inserted are TPU nodes. We will add the TPUReplicatedInput and XlaSharding op
322 // nodes back where necessary.
323 //
324 // Returns the number of nodes that were removed.
325 int64_t RemoveDescendantNodeOfArg(
326 Graph* graph, const std::string& node_type_to_remove,
327 const std::set<std::string>& must_be_child_of) {
328 int64_t nodes_removed = 0;
329 std::vector<std::pair<const Edge*, std::vector<const Edge*>>> edges_to_remove;
330
331 for (Node* node : graph->nodes()) {
332 if (node_type_to_remove != node->type_string()) continue;
333 if (!must_be_child_of.empty()) {
334 bool has_arg_parent = false;
335 for (const Edge* edge : node->in_edges()) {
336 if (must_be_child_of.count(edge->src()->type_string()) > 0) {
337 has_arg_parent = true;
338 }
339 }
340 if (!has_arg_parent) continue;
341 }
342 nodes_removed++;
343
344 const Edge* input_edge = nullptr;
345 std::vector<const Edge*> output_edges;
346 for (const Edge* edge : node->in_edges())
347 if (!edge->IsControlEdge()) {
348 input_edge = edge;
349 break;
350 }
351 for (const Edge* edge : node->out_edges())
352 if (!edge->IsControlEdge()) {
353 output_edges.push_back(edge);
354 }
355 if (input_edge != nullptr && !output_edges.empty())
356 edges_to_remove.push_back(std::make_pair(input_edge, output_edges));
357 }
358 for (const auto& it : edges_to_remove) {
359 for (const Edge* output_edge : it.second) {
360 graph->RemoveEdge(output_edge);
361 graph->AddEdge(it.first->src(), it.first->src_output(),
362 output_edge->dst(), output_edge->dst_input());
363 }
364 graph->RemoveNode(it.first->dst());
365 }
366 return nodes_removed;
367 }
368
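// Hashes the number of elements of every op input. The result is used as the
// key for TPU core (ordinal) selection and for the partitioned-function cache.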
369 uint64 GetInputHash(OpKernelContext* ctx) {
370 uint64 input_hash = 0; // initialization for determinism.
371 // Use the number of elements to compute hash.
372 // TODO(chiachenc): use the full shape to compute the hash.
373 for (int i = 0; i < ctx->num_inputs(); ++i) {
374 VLOG(4) << "InputHash, combine input " << i
375 << ", NumElements: " << ctx->input(i).NumElements();
376 input_hash = Hash64Combine(input_hash, ctx->input(i).NumElements());
377 }
378 return input_hash;
379 }
380
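// Builds a grouping key from `dtype` and all but the last input dimension; when
// `input_shape_opt` is true, the last dimension is additionally bucketed by
// whether it is a multiple of kLastDimOfTpuInputFastPath. For a float input of
// shape [8, 128] with `input_shape_opt` set, the key looks roughly like
// "input_tensors_dtype_<dtype>_dims_8_last_128n" (sketch; the exact dtype
// rendering depends on StrCat).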
381 string HashShapeAndType(const string prefix, const std::vector<int>& input_dims,
382 const DataType& dtype, const bool input_shape_opt) {
383 string hash = strings::StrCat(prefix, dtype, "_dims");
384 // We will concat at the last dimension.
385 for (int d = 0; d < input_dims.size() - 1; ++d) {
386 strings::StrAppend(&hash, "_", input_dims.at(d));
387 }
388
389 if (input_shape_opt) {
390 if (input_dims.back() % kLastDimOfTpuInputFastPath == 0) {
391 strings::StrAppend(&hash, "_last_", kLastDimOfTpuInputFastPath, "n");
392 } else {
393 strings::StrAppend(&hash, "_last_other");
394 }
395 }
396 return hash;
397 }
398
399 // Get the information for input and output tensors (shapes, dtypes, etc).
400 Status GetInputOutputInfo(
401 Graph* graph, GraphShapeInfo& tpu_inferred_info,
402 std::map<int, InferredShape>& arg_shapes, EdgeShapes& tpu_input_shapes,
403 absl::flat_hash_map<const Edge*, DataType>& tpu_input_dtypes,
404 OpKernelContext* ctx) {
405 // Search for the host-to-device (CPU-to-TPU) edges from each _Arg node.
406 for (Node* node : graph->op_nodes()) {
407 if (!node->IsArg()) continue;
408 const DataType dtype = node->attrs().Find("T")->type();
409 const int arg_index = node->attrs().Find("index")->i();
410 if (dtype != DT_INT32 && dtype != DT_BFLOAT16 && dtype != DT_FLOAT &&
411 dtype != DT_BOOL && dtype != DT_QINT8 && dtype != DT_QUINT8)
412 continue;
413 VLOG(3) << "Argnode: " << node->DebugString();
414 const Tensor& tensor = ctx->input(arg_index);
415
416 // Search for the cross-device edge from arg node.
417 const Edge* candidate_edge = FindHostToDeviceEdge(node);
418 if (candidate_edge == nullptr) continue;
419
420 // Make a proxy and get the sole tpu_input_edge for transferring the input
421 // tensor corresponding to the current _Arg node.
422 const Edge* tpu_input_edge = nullptr;
423 TF_RETURN_IF_ERROR(
424 CreateInputProxy(graph, candidate_edge, &tpu_input_edge));
425 if (tpu_input_edge == nullptr)
426 return errors::NotFound("Couldn't find TPU input edge for ", node->name());
427
428 // Optimize edge: original source to proxy identity.
429 VLOG(3) << "Input: " << tpu_input_edge->src()->name();
430 std::vector<int>& input_shapes = tpu_input_shapes[tpu_input_edge];
431 input_shapes.clear();
432 for (int d = 0; d < tensor.dims(); ++d) {
433 input_shapes.push_back(tensor.dim_size(d));
434 VLOG(3) << "Input Tensor: Dim[" << d << "] = " << tensor.dim_size(d);
435 }
436 tpu_input_dtypes[tpu_input_edge] = tensor.dtype();
437
438 // Collect shapes for non-resource-variable args.
439 PartialTensorShape partial_tensor_shape;
440 auto partial_shape = PartialTensorShape::MakePartialShape(
441 input_shapes.data(), input_shapes.size(), &partial_tensor_shape);
442 InferredShape inferred_shape = {partial_tensor_shape};
443 arg_shapes[arg_index] = inferred_shape;
444 }
445 return OkStatus();
446 }
447
448 // Converts the integer vectors that represent shapes into TensorShapes.
449 Status ConvertEdgeShapesToTensorShapes(
450 const std::map<std::string, std::vector<int>>& named_input_shapes,
451 std::vector<TensorShape>* shapes) {
452 shapes->resize(named_input_shapes.size());
453 int32_t i = 0;
454 // keys in tpu_input_shapes may be stale.
455 for (const auto& iter : named_input_shapes) {
456 VLOG(2) << iter.first << ", rank: " << iter.second.size();
457 const int64_t rank = iter.second.size();
458 std::vector<int64_t> dims(rank);
459 for (int64_t d = 0; d < rank; ++d) {
460 VLOG(2) << " dim[" << d << "]: " << iter.second.at(d);
461 dims[d] = iter.second.at(d);
462 }
463 TF_RETURN_IF_ERROR(TensorShapeUtils::MakeShape(dims, &(*shapes)[i]));
464 i++;
465 }
466 return OkStatus();
467 }
468
469 // Get the TF fingerprint with the information from the TPUCompileOp or
470 // _TPUCompileMlirOp.
471 Status MaybeRegisterFingerprint(
472 Graph* graph,
473 const std::map<std::string, std::vector<int>>& named_input_shapes,
474 uint64 input_hash) {
475 // Find the compiler metadata.
476 tpu::TPUCompileMetadataProto metadata_proto;
477 std::map<std::string, std::vector<int>> inputs_to_keep;
478 int num_dynamic_shapes = -1;
479 tensorflow::uint64 fingerprint = 0;
480
481 for (Node* node : graph->op_nodes()) {
482 if (node->type_string() == "TPUCompile" ||
483 node->type_string() == "_TPUCompileMlir") {
484 num_dynamic_shapes = node->attrs().Find("NumDynamicShapes")->i();
485 if (num_dynamic_shapes <= 0) {
486 break;
487 }
488 int visited = 0;
489 // TPUCompileOp/_TPUCompileMlirOp take Shape nodes as inputs.
490 // The number of Shape nodes matches the number of dynamic shaped inputs.
491 // The Shape nodes come from the input nodes:
492 // [TPU Input] --> [Input Shape] --> [TPUCompileOp]
493 for (auto in_node : node->in_nodes()) {
494 if (in_node->type_string() != "Shape") {
495 continue;
496 }
497 for (auto input_node : in_node->in_nodes()) {
498 auto iter = named_input_shapes.find(input_node->name());
499 if (iter != named_input_shapes.end()) {
500 inputs_to_keep[iter->first] = iter->second;
501 }
502 }
503 visited++;
504 if (visited == num_dynamic_shapes) {
505 break;
506 }
507 }
508 std::string metadata = node->attrs().Find("metadata")->s();
509 metadata_proto.ParseFromString(metadata);
510
511 if (node->type_string() == "_TPUCompileMlir") {
512 std::string mlir_module = node->attrs().Find("mlir_module")->s();
513 fingerprint = tensorflow::Fingerprint64(mlir_module);
514 } else {
515 fingerprint = metadata_proto.function_library_fingerprint();
516 }
517
518 break;
519 }
520 }
521 VLOG(2) << "inputs_to_keep size: " << inputs_to_keep.size();
522 if (inputs_to_keep.size() != num_dynamic_shapes) {
523 VLOG(2) << "Cannot match all inputs shapes. Skip fingerprint registration.";
524 return OkStatus();
525 }
526
527 std::vector<TensorShape> input_shapes;
528 TF_RETURN_IF_ERROR(
529 ConvertEdgeShapesToTensorShapes(inputs_to_keep, &input_shapes));
530
531 std::vector<TensorShape> arg_shapes;
532 auto status =
533 tpu::ComputeArgumentShapes(metadata_proto, input_shapes, &arg_shapes);
534 if (!status.ok()) {
535 VLOG(2) << status.error_message();
536 return OkStatus();
537 }
538 uint64 tf_fingerprint =
539 tpu::CreateFingerprintWithNameAndShapes(fingerprint, arg_shapes);
540 VLOG(2) << "fingerprint: " << fingerprint;
541 VLOG(2) << "TF fingerprint: " << tf_fingerprint;
542
543 ResourceMgr* rm = GetTPUConfigResourceMgr();
544 tpu::TpuFingerprintLookup* fingerprint_lookup;
545 TF_RETURN_IF_ERROR(rm->Lookup<tpu::TpuFingerprintLookup>(
546 rm->default_container(), tpu::kFingerprintLookupResourceName,
547 &fingerprint_lookup));
548 fingerprint_lookup->RegisterKeyAndIntermediatePair(input_hash,
549 tf_fingerprint);
550 return OkStatus();
551 }
552
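// Scans the graph for the _Arg -> TPUReplicatedInput -> XlaSharding pattern.
// For each matching _Arg, records the attributes needed to recreate the
// XlaSharding and TPUReplicatedInput nodes later (dtype, sharding string,
// cluster name, and the control-input metadata node). Returns true if any
// sharded input was found.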
553 bool FindTpuReplicatedInputAndXlaSharding(
554 const Graph* graph, XlaShardingInfoMap& xla_sharding_ops,
555 TpuReplicatedInputInfoMap& tpu_replicated_input_ops) {
556 bool xla_spmd_input_sharded = false;
557 // Detect whether there is XLA sharding on the inputs; if there is, then
558 // we cannot remove the replicated inputs or the XlaSharding ops.
559 for (Node* xla_sharding_node : graph->nodes()) {
560 if (xla_sharding_node->type_string() == "XlaSharding") {
561 for (const Edge* edge : xla_sharding_node->in_edges()) {
562 if (edge->src()->type_string() == "TPUReplicatedInput") {
563 Node* tpu_replicated_input_node = edge->src();
564 Node* tpu_replicated_metadata_node = nullptr;
565 for (const Edge* input_edge : tpu_replicated_input_node->in_edges()) {
566 if (input_edge->IsControlEdge()) {
567 tpu_replicated_metadata_node = input_edge->src();
568 }
569 }
570
571 for (const Edge* input_edge : tpu_replicated_input_node->in_edges()) {
572 if (input_edge->src()->type_string() == "_Arg") {
573 Node* arg_node = input_edge->src();
574
575 xla_sharding_ops[arg_node->name()] = std::make_tuple(
576 xla_sharding_node->attrs().Find("T")->type(),
577 xla_sharding_node->attrs().Find("sharding")->s(),
578 xla_sharding_node->attrs().Find("_tpu_replicate")->s());
579
580 tpu_replicated_input_ops[arg_node->name()] = std::make_tuple(
581 tpu_replicated_input_node->attrs().Find("T")->type(),
582 tpu_replicated_metadata_node);
583
584 VLOG(2) << "Detected input is sharded. XlaSharding node: "
585 << xla_sharding_node->DebugString()
586 << ", TPUReplicatedInput node: "
587 << edge->src()->DebugString()
588 << ", _Arg node: " << arg_node->DebugString();
589 xla_spmd_input_sharded = true;
590 break;
591 }
592 }
593 }
594 }
595 }
596 }
597 return xla_spmd_input_sharded;
598 }
599
600 // Returns the name of the framework that rewrote the graph to support
601 // inference on TPUs. This name is accessed later during metric collection.
602 string GetProducerName(const string& function_name) {
603 if (absl::StrContains(function_name, "tpu_fn_icv2_")) {
604 if (absl::StrContains(function_name, "_tf_quant")) {
605 return "TPU_INFERENCE_CONVERTER_V2_TF_QUANTIZER";
606 }
607 return "TPU_INFERENCE_CONVERTER_V2";
608 }
609 if (absl::StrContains(function_name, "tpu_func_0") ||
610 absl::StrContains(function_name, "_with_batch") ||
611 absl::StrContains(function_name, "_optim")) {
612 if (absl::StrContains(function_name, "_tf_quant")) {
613 return "TPU_INFERENCE_CONVERTER_TF_QUANTIZER";
614 }
615 return "TPU_INFERENCE_CONVERTER";
616 }
617 return "UNKNOWN";
618 }
619
620 // Gets the proper tensor dimension from XLA OpSharding.
621 // "replicate_on_last_tile_dim" and "last_tile_dims" should be deducted from the
622 // real Tensor dimensions when tiled.
623 // For example:
624 // f32[8,512](sharding={devices=[1,1,2]0,1 last_tile_dims={REPLICATED})
625 // also means a replicated tensor over all devices.
626 //
627 // See xla_data.proto for detailed explanations on the fields.
628 int GetDimsFromXLAShardingTiled(const xla::OpSharding xla_sharding) {
629 return xla_sharding.tile_assignment_dimensions_size() -
630 (xla_sharding.replicate_on_last_tile_dim() ? 1 : 0) -
631 xla_sharding.last_tile_dims_size();
632 }
633
634 } // end namespace
635
636 namespace tpu_functional_internal {
637
638 // An optimization pass that separates tensors to leverage the fast path in
639 // TPU input preparation. The algorithm is as follows:
640 // (1) Group all tensors that have the same dimensions except the last one. A
641 // group of tensors will be concatenated along the last dimension in a later pass.
642 // (2) Check all groups of tensors and find groups whose dimensions after concat
643 // cannot leverage the fast path.
644 // (3) For groups of tensors that don't leverage the fast path, split tensors
645 // into two sub-groups such that one sub-group of tensors can leverage the fast
646 // path.
647 // An exception to (2) is when the concatenated tensors are small; separating
648 // such tensors would only introduce data-transfer overheads to the device.
649 // This optimization takes effect when both --input_shape_opt and
650 // --group_tensors_for_packing are true.
651 GroupedEdges GroupTensorsForInputPacking(
652 const EdgeShapes& tpu_input_shapes,
653 const absl::flat_hash_map<const Edge*, DataType>& tpu_input_dtypes,
654 bool input_shape_opt, bool group_tensors_for_packing) {
655 GroupedEdges grouped_input_edges;
656 for (const auto& iter : tpu_input_shapes) {
657 if (iter.second.empty()) continue;
658 DataType dtype = tpu_input_dtypes.find(iter.first)->second;
659 string hash_key = HashShapeAndType("input_tensors_dtype_", iter.second,
660 dtype, /*input_shape_opt*/ false);
661 grouped_input_edges[hash_key].push_back(iter.first);
662 }
663 // Apply grouping when both are true.
664 if (!input_shape_opt || !group_tensors_for_packing)
665 return grouped_input_edges;
666
667 GroupedEdges grouped_input_edges_opt;
668 for (const auto& iter : grouped_input_edges) {
669 int sum_last_dim = 0;
670 int product_other_dims = 0;
671 VLOG(3) << "group name: " << iter.first;
672 for (const auto& edge : iter.second) {
673 const std::vector<int>& input_shapes =
674 tpu_input_shapes.find(edge)->second;
675 sum_last_dim += input_shapes.back();
676 if (product_other_dims == 0) {
677 product_other_dims = 1;
678 for (int d = 0; d < input_shapes.size() - 1; ++d)
679 product_other_dims *= input_shapes.at(d);
680 }
681 }
682 VLOG(3) << "sum_last_dim: " << sum_last_dim;
683 VLOG(3) << "product_other_dims: " << product_other_dims;
684 // Already uses fast path, skip further grouping.
685 if ((sum_last_dim % kLastDimOfTpuInputFastPath) == 0 &&
686 (product_other_dims % kOtherDimOfTpuInputFastPath) == 0) {
687 grouped_input_edges_opt[iter.first] = iter.second;
688 continue;
689 }
690 // Tensors are small, skip further grouping.
691 if ((sum_last_dim * product_other_dims) <
692 (kLastDimOfTpuInputFastPath * kOtherDimOfTpuInputFastPath)) {
693 grouped_input_edges_opt[iter.first] = iter.second;
694 continue;
695 }
696 VLOG(3) << "Splitting tensors.";
697 for (const auto& edge : iter.second) {
698 auto tpu_input_shape = tpu_input_shapes.find(edge)->second;
699 string hash_key =
700 HashShapeAndType("input_tensors_dtype_", tpu_input_shape,
701 tpu_input_dtypes.find(edge)->second,
702 /*input_shape_opt*/ true);
703 grouped_input_edges_opt[hash_key].push_back(edge);
704 }
705 }
706 return grouped_input_edges_opt;
707 }
708
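// Groups TPU output edges (edges whose destination is a TPUReplicatedOutput
// node) by the hash of their inferred shape and dtype so that identically
// shaped outputs can later be concatenated before leaving the TPU.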
709 GroupedEdges GroupTensorsForOutputPacking(Graph* graph,
710 EdgeShapes& tpu_output_shapes,
711 GraphShapeInfo* shape_info) {
712 GroupedEdges shape_to_output;
713 for (const Edge* edge : graph->edges()) {
714 if (edge->IsControlEdge()) continue;
715
716 // TPU output edge: src is on TPU and dst is on CPU.
717 if (edge->dst()->type_string() != "TPUReplicatedOutput") continue;
718 if (!shape_info->count(edge->src()->name())) continue;
719
720 // output shapes for hashing
721 std::vector<int>& output_shapes = tpu_output_shapes[edge];
722 output_shapes.clear();
723
724 int output_id = edge->src_output();
725 auto inferred_shape_vec = shape_info->at(edge->src()->name());
726
727 for (int d : inferred_shape_vec.at(output_id).shape.dim_sizes()) {
728 output_shapes.push_back(d);
729 }
730
731 // Hash Shape and Types.
732 DataType dtype = edge->src()->input_type(output_id);
733 string hash_key =
734 HashShapeAndType("output_tensors_dtype_", output_shapes, dtype, false);
735
736 shape_to_output[hash_key].push_back(edge);
737 }
738 return shape_to_output;
739 }
740
741 // Concatenates input tensors on CPU along the last dimension if all other
742 // dimensions are the same, and splits them on TPU to reduce input overhead.
743 // `tpu_input_shapes` maps an edge to the shape of its output tensor.
744 // `grouped_input_edges` maps tensor name to all edges output from this tensor.
745 Status CreateConcatAndSplitNodesForInputTensor(
746 Graph* graph, const string& cluster_name, EdgeShapes* tpu_input_shapes,
747 const absl::flat_hash_map<std::string, std::vector<const Edge*>>&
748 grouped_input_edges,
749 int32_t minimum_input_tensors_packing, bool xla_spmd_input_sharded,
750 const XlaShardingInfoMap& xla_sharding_info,
751 const TpuReplicatedInputInfoMap& tpu_replicated_input_info) {
752 for (const auto& iter : grouped_input_edges) {
753 std::vector<int> last_dim_vec;
754 std::vector<NodeBuilder::NodeOut> concat_nodeouts;
755 absl::flat_hash_map<std::string, int> tensor_to_split_output;
756 int rank;
757 DataType dtype = DT_INVALID;
758 std::string src_name;
759 for (const Edge* edge : iter.second) {
760 src_name = edge->src()->name();
761 string tensor_name =
762 absl::StrCat(edge->src()->name(), ":", edge->src_output());
763 // Create Concat / Split pair for a tensor if not exist yet.
764 if (tensor_to_split_output.contains(tensor_name)) continue;
765 tensor_to_split_output[tensor_name] = concat_nodeouts.size();
766 concat_nodeouts.push_back(
767 NodeBuilder::NodeOut(edge->src(), edge->src_output()));
768 dtype = edge->src()->output_type(edge->src_output());
769 rank = tpu_input_shapes->at(edge).size();
770 last_dim_vec.push_back(tpu_input_shapes->at(edge).back());
771 }
772
773 const int num_tensors = tensor_to_split_output.size();
774 VLOG(3) << iter.first << " num_tensors: " << num_tensors;
775 if (num_tensors < minimum_input_tensors_packing) {
776 VLOG(3) << "skip concat/split " << iter.first;
777 continue;
778 }
779
780 Node* concat_axis_node = nullptr;
781 TensorShape t_shape;
782 Tensor dim_tensor(DT_INT32, t_shape);
783 // Concat and Split at the last dim.
784 dim_tensor.flat<int>()(0) = rank - 1;
785 TF_RETURN_IF_ERROR(
786 NodeBuilder(strings::StrCat(iter.first, "/concat/axis"), "Const")
787 .Attr("dtype", DT_INT32)
788 .Attr("value", dim_tensor)
789 .Finalize(graph, &concat_axis_node));
790
791 Node* concat_node = nullptr;
792 TF_RETURN_IF_ERROR(
793 NodeBuilder(strings::StrCat(iter.first, "/concat"), "ConcatV2")
794 .Input(concat_nodeouts)
795 .Input(concat_axis_node)
796 .Attr("T", dtype)
797 .Attr("Tidx", DT_INT32)
798 .Attr("N", num_tensors)
799 .Finalize(graph, &concat_node));
800
801 Node* split_dim_node = nullptr;
802 TF_RETURN_IF_ERROR(
803 NodeBuilder(strings::StrCat(iter.first, "/split/split_dim"), "Const")
804 .Attr("dtype", DT_INT32)
805 .Attr("value", dim_tensor)
806 .Attr(kTpuReplicateAttr, cluster_name)
807 .Finalize(graph, &split_dim_node));
808
809 Node* split_vec_node = nullptr;
810 TensorShape split_vec_shape;
811 split_vec_shape.AddDim(1);
812 split_vec_shape.set_dim(0, last_dim_vec.size());
813
814 Tensor split_vec_tensor(DT_INT32, split_vec_shape);
815 for (int i = 0; i < last_dim_vec.size(); ++i) {
816 split_vec_tensor.flat<int>()(i) = last_dim_vec[i];
817 }
818 VLOG(3) << "split_vec_tensor: " << split_vec_tensor.DebugString();
819
820 TF_RETURN_IF_ERROR(
821 NodeBuilder(strings::StrCat(iter.first, "/split/vec"), "Const")
822 .Attr("dtype", DT_INT32)
823 .Attr("value", split_vec_tensor)
824 .Attr(kTpuReplicateAttr, cluster_name)
825 .Finalize(graph, &split_vec_node));
826
827 Node* split_node = nullptr;
828 Node* input_to_split_node = concat_node;
829 Node* output_from_concat_node = nullptr;
830 if (xla_spmd_input_sharded &&
831 tpu_replicated_input_info.count(src_name) > 0 &&
832 xla_sharding_info.count(src_name) > 0) {
833 // Create new TPUReplicatedInput and XLAShardingOp nodes
834 //
835 // Rewrite the graph from:
836 // Concat -> Split
837 // to
838 // Concat -> TPUReplicatedInput -> XlaSharding -> Split
839 Node* tpu_replicated_input = nullptr;
840 Node* xla_sharding_op = nullptr;
841
842 std::vector<NodeBuilder::NodeOut> replicated_input;
843 replicated_input.push_back(NodeBuilder::NodeOut(concat_node));
844
845 // TODO(b/183060455): Add TPUReplicatedInput to all graphs.
846 TF_RETURN_IF_ERROR(
847 NodeBuilder(strings::StrCat(iter.first, "/TPUReplicatedInput"),
848 "TPUReplicatedInput")
849 .Input(replicated_input)
850 .ControlInput(std::get<1>(tpu_replicated_input_info.at(src_name)))
851 .Attr("N", 1)
852 .Attr("T", std::get<0>(tpu_replicated_input_info.at(src_name)))
853 .Attr("index", -1)
854 .Attr("is_mirrored_variable", false)
855 .Attr("is_packed", false)
856 .Finalize(graph, &tpu_replicated_input));
857 VLOG(2) << "Created new TPUReplicatedInput node "
858 << tpu_replicated_input->DebugString();
859
860 TF_RETURN_IF_ERROR(
861 NodeBuilder(strings::StrCat(iter.first, "/XlaSharding"),
862 "XlaSharding")
863 .Input(tpu_replicated_input)
864 .Attr("T", std::get<0>(xla_sharding_info.at(src_name)))
865 .Attr("sharding", std::get<1>(xla_sharding_info.at(src_name)))
866 .Attr("_XlaSharding", std::get<1>(xla_sharding_info.at(src_name)))
867 .Attr("_tpu_replicate",
868 std::get<2>(xla_sharding_info.at(src_name)))
869 .Finalize(graph, &xla_sharding_op));
870 VLOG(2) << "Created new XLA sharding node "
871 << xla_sharding_op->DebugString();
872
873 input_to_split_node = xla_sharding_op;
874 output_from_concat_node = tpu_replicated_input;
875 }
876 // Update the `tpu_input_shapes` mapping: Add the new edge
877 // from concat to split.
878 TF_RETURN_IF_ERROR(
879 NodeBuilder(strings::StrCat(iter.first, "/split"), "SplitV")
880 .Input(input_to_split_node)
881 .Input(split_vec_node)
882 .Input(split_dim_node)
883 .Attr("T", dtype)
884 .Attr("num_split", num_tensors)
885 .Attr(kTpuReplicateAttr, cluster_name)
886 .Finalize(graph, &split_node));
887
888 if (output_from_concat_node == nullptr)
889 output_from_concat_node = split_node;
890
891 const Edge* concat_to_split;
892 for (const Edge* edge : concat_node->out_edges())
893 if (edge->dst() == output_from_concat_node) {
894 concat_to_split = edge;
895 break;
896 }
897 if (rank > 1) {
898 for (int d = 0; d < rank - 1; ++d)
899 (*tpu_input_shapes)[concat_to_split].push_back(
900 tpu_input_shapes->at(iter.second.back()).at(d));
901 }
902 (*tpu_input_shapes)[concat_to_split].push_back(
903 std::accumulate(last_dim_vec.begin(), last_dim_vec.end(), 0));
904
905 // Connect split node to original tensor output.
906 for (const Edge* edge : iter.second) {
907 string tensor_name =
908 absl::StrCat(edge->src()->name(), ":", edge->src_output());
909 int output_index = tensor_to_split_output.at(tensor_name);
910 graph->RemoveEdge(edge);
911 graph->AddEdge(split_node, output_index, edge->dst(), edge->dst_input());
912 // Update the `tpu_input_shapes` mapping: Remove old edges.
913 tpu_input_shapes->erase(edge);
914 }
915 VLOG(3) << "Concat node: " << concat_node->DebugString();
916 }
917 return OkStatus();
918 }
919
920 // Concatenates output tensors on TPU along the last dimension if all other
921 // dimensions are the same, and splits them on CPU to reduce outfeed overhead.
922 // `tpu_inferred_info` maps an edge to the inferred shape of its output tensor.
923 // `shape_to_output` maps tensor name to all edges output from this tensor.
924 Status CreateConcatAndSplitNodesForOutputTensor(
925 Graph* graph, const string& cluster_name, EdgeShapes* tpu_output_shapes,
926 GraphShapeInfo* tpu_inferred_info, GroupedEdges shape_to_output,
927 int32_t minimum_output_tensors_packing) {
928 for (const auto& iter : shape_to_output) {
929 std::vector<int> last_dim_vec;
930 std::vector<NodeBuilder::NodeOut> concat_nodeouts;
931 absl::flat_hash_map<std::string, int> tensor_to_split_output;
932 int rank;
933 DataType dtype = DT_INVALID;
934 for (const Edge* edge : iter.second) {
935 string tensor_name =
936 absl::StrCat(edge->src()->name(), ":", edge->src_output());
937
938 // Create Concat / Split pair for a tensor if not exist yet.
939 if (tensor_to_split_output.contains(tensor_name)) continue;
940 tensor_to_split_output[tensor_name] = concat_nodeouts.size();
941
942 concat_nodeouts.push_back(
943 NodeBuilder::NodeOut(edge->src(), edge->src_output()));
944 dtype = edge->src()->output_type(edge->src_output());
945 rank = tpu_output_shapes->at(edge).size();
946 last_dim_vec.push_back(tpu_output_shapes->at(edge).back());
947 }
948
949 const int num_tensors = tensor_to_split_output.size();
950 if (num_tensors < minimum_output_tensors_packing) {
951 VLOG(3) << "skip concat/split " << iter.first;
952 continue;
953 }
954
955 Node* concat_axis_node = nullptr;
956 TensorShape t_shape;
957 Tensor dim_tensor(DT_INT32, t_shape);
958 // Concat and Split at the last dim.
959 dim_tensor.flat<int>()(0) = rank - 1;
960 TF_RETURN_IF_ERROR(
961 NodeBuilder(strings::StrCat(iter.first, "/concat/axis"), "Const")
962 .Attr("dtype", DT_INT32)
963 .Attr("value", dim_tensor)
964 .Attr(kTpuReplicateAttr, cluster_name)
965 .Finalize(graph, &concat_axis_node));
966
967 Node* concat_node = nullptr;
968 TF_RETURN_IF_ERROR(
969 NodeBuilder(strings::StrCat(iter.first, "/concat"), "ConcatV2")
970 .Input(concat_nodeouts)
971 .Input(concat_axis_node)
972 .Attr("T", dtype)
973 .Attr("Tidx", DT_INT32)
974 .Attr("N", num_tensors)
975 .Attr(kTpuReplicateAttr, cluster_name)
976 .Finalize(graph, &concat_node));
977
978 Node* tpu_replicated_output_node = nullptr;
979 TF_RETURN_IF_ERROR(
980 NodeBuilder(strings::StrCat(iter.first, "/tpu_replicated_output"),
981 "TPUReplicatedOutput")
982 .Input(concat_node)
983 .Attr("T", dtype)
984 .Attr("num_replicas", 1)
985 .Finalize(graph, &tpu_replicated_output_node));
986
987 Node* split_dim_node = nullptr;
988 TF_RETURN_IF_ERROR(
989 NodeBuilder(strings::StrCat(iter.first, "/split/split_dim"), "Const")
990 .Attr("dtype", DT_INT32)
991 .Attr("value", dim_tensor)
992 .Finalize(graph, &split_dim_node));
993
994 Node* split_vec_node = nullptr;
995 TensorShape split_vec_shape;
996 split_vec_shape.AddDim(1);
997 split_vec_shape.set_dim(0, last_dim_vec.size());
998
999 Tensor split_vec_tensor(DT_INT32, split_vec_shape);
1000 for (int i = 0; i < last_dim_vec.size(); ++i) {
1001 split_vec_tensor.flat<int>()(i) = last_dim_vec[i];
1002 }
1003 VLOG(3) << "split_vec_tensor: " << split_vec_tensor.DebugString();
1004
1005 TF_RETURN_IF_ERROR(
1006 NodeBuilder(strings::StrCat(iter.first, "/split/vec"), "Const")
1007 .Attr("dtype", DT_INT32)
1008 .Attr("value", split_vec_tensor)
1009 .Finalize(graph, &split_vec_node));
1010
1011 Node* split_node = nullptr;
1012 TF_RETURN_IF_ERROR(
1013 NodeBuilder(strings::StrCat(iter.first, "/split"), "SplitV")
1014 .Input(tpu_replicated_output_node)
1015 .Input(split_vec_node)
1016 .Input(split_dim_node)
1017 .Attr("T", dtype)
1018 .Attr("num_split", num_tensors)
1019 .Finalize(graph, &split_node));
1020
1021 // Update the `tpu_output_shapes` mapping: Add the new edge
1022 // from concat to split.
1023 const Edge* concat_to_split;
1024 for (const Edge* edge : concat_node->out_edges())
1025 if (edge->dst() == split_node) {
1026 concat_to_split = edge;
1027 break;
1028 }
1029
1030 if (rank > 1) (*tpu_output_shapes)[concat_to_split].push_back(-1);
1031 for (int d = 1; d < rank - 1; ++d)
1032 (*tpu_output_shapes)[concat_to_split].push_back(
1033 tpu_output_shapes->at(iter.second.back()).at(d));
1034 (*tpu_output_shapes)[concat_to_split].push_back(
1035 std::accumulate(last_dim_vec.begin(), last_dim_vec.end(), 0));
1036
1037 for (const Edge* edge : iter.second) {
1038 // 1. Find old TPUReplicatedOutput output edges
1039 std::vector<const Edge*> output_edge_vec;
1040 for (const Edge* output_edge : edge->dst()->out_edges())
1041 output_edge_vec.push_back(output_edge);
1042
1043 string tensor_name =
1044 absl::StrCat(edge->src()->name(), ":", edge->src_output());
1045 int output_index = tensor_to_split_output.at(tensor_name);
1046 VLOG(3) << "output_index: " << output_index;
1047
1048 // Connect split node to original tensor output.
1049 for (const Edge* output_edge : output_edge_vec) {
1050 VLOG(3) << "output_edge" << output_edge->DebugString();
1051 graph->RemoveEdge(output_edge);
1052 graph->AddEdge(split_node, output_index, output_edge->dst(),
1053 output_edge->dst_input());
1054 // Update the `tpu_output_shapes` mapping: Remove old edges.
1055 tpu_output_shapes->erase(output_edge);
1056 }
1057 graph->RemoveNode(edge->dst());
1058 }
1059 VLOG(3) << "Concat node: " << concat_node->DebugString();
1060 }
1061 return OkStatus();
1062 }
1063
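// For each TPU input edge with a known shape, inserts a Reshape on the host
// side that flattens the tensor to rank 1 and a matching Reshape inside the
// TPU cluster that restores the original shape, then rewires the edges,
// updates `tpu_input_shapes`, and adjusts any XlaSharding node to shard the
// flattened dimension.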
1064 Status InsertReshapeNodePairs(Graph* graph, const string& cluster_name,
1065 EdgeShapes* tpu_input_shapes,
1066 int num_cores_per_replica) {
1067 std::vector<const Edge*> tpu_input_edges_original;
1068 for (const auto& it : *tpu_input_shapes)
1069 if (!it.second.empty()) tpu_input_edges_original.push_back(it.first);
1070 for (const Edge* edge : tpu_input_edges_original) {
1071 VLOG(3) << "Reshape input: " << edge->DebugString();
1072
1073 // Check if there is a TPUReplicatedInput and XlaSharding in the middle
1074 bool xla_sharded_input = false;
1075 Node* xla_sharding_node = nullptr;
1076 if (edge->dst()->type_string() == "TPUReplicatedInput" &&
1077 edge->dst()->out_nodes().begin()->type_string() == "XlaSharding") {
1078 VLOG(3) << "Detected TPUReplicatedInput " << edge->dst()->DebugString()
1079 << " and XlaSharding "
1080 << edge->dst()->out_nodes().begin()->DebugString()
1081 << ", setting xla_sharded_input = true";
1082 xla_sharded_input = true;
1083 xla_sharding_node = *(edge->dst()->out_nodes().begin());
1084 }
1085
1086 // 1. Build Reshape node for flatten.
1087
1088 // 1.1 Build Const node for shape
1089 Node* flatten_reshape_shape_node = nullptr;
1090 Tensor flattened_input_shape_tensor;
1091 flattened_input_shape_tensor =
1092 Tensor(DT_INT32, TensorShape({static_cast<int64_t>(1)}));
1093 flattened_input_shape_tensor.flat<int>()(0) = -1;
1094 TF_RETURN_IF_ERROR(
1095 NodeBuilder(absl::StrCat(edge->src()->name(), "/flatten/Reshape/shape"),
1096 "Const")
1097 .Attr("dtype", DT_INT32)
1098 .Attr("value", flattened_input_shape_tensor)
1099 .Finalize(graph, &flatten_reshape_shape_node));
1100
1101 // 1.2 Build Reshape node for flatten.
1102 Node* flatten_reshape_node = nullptr;
1103 TF_RETURN_IF_ERROR(
1104 NodeBuilder(absl::StrCat(edge->src()->name(), "/flatten/Reshape"),
1105 "Reshape")
1106 .Input(edge->src(), edge->src_output())
1107 .Input(flatten_reshape_shape_node)
1108 .Attr("T", edge->src()->output_type(edge->src_output()))
1109 .Attr("Tshape", DT_INT32)
1110 .Finalize(graph, &flatten_reshape_node));
1111
1112 // 2. Build Reshape node for recover.
1113
1114 // 2.1 Build Const node for shape.
1115 Node* recover_reshape_shape_node = nullptr;
1116 Tensor original_input_shape_tensor(
1117 DT_INT32,
1118 TensorShape({static_cast<int64_t>(tpu_input_shapes->at(edge).size())}));
1119 original_input_shape_tensor.flat<int>()(0) = -1;
1120 for (int d = 1; d < tpu_input_shapes->at(edge).size(); ++d)
1121 original_input_shape_tensor.flat<int>()(d) =
1122 tpu_input_shapes->at(edge).at(d);
1123 TF_RETURN_IF_ERROR(
1124 NodeBuilder(absl::StrCat(edge->src()->name(), "/recover/Reshape/shape"),
1125 "Const")
1126 .Attr("dtype", DT_INT32)
1127 .Attr("value", original_input_shape_tensor)
1128 .Attr(kTpuReplicateAttr, cluster_name) // This node is on TPU.
1129 .Finalize(graph, &recover_reshape_shape_node));
1130
1131 // 2.2 Build Reshape node for recover.
1132 Node* recover_reshape_input_node = flatten_reshape_node;
1133 const Edge* original_recover_reshape_input_edge = nullptr;
1134 if (xla_sharded_input) {
1135 // We want to find the node after the XlaSharding node
1136 original_recover_reshape_input_edge =
1137 *(edge->dst()->out_nodes().begin()->out_edges().begin());
1138 recover_reshape_input_node = *(edge->dst()->out_nodes().begin());
1139 VLOG(3) << "Recover reshape input node: "
1140 << recover_reshape_input_node->DebugString()
1141 << ", recover reshape input edge: "
1142 << original_recover_reshape_input_edge->DebugString();
1143 }
1144
1145 Node* recover_reshape_node = nullptr;
1146 TF_RETURN_IF_ERROR(
1147 NodeBuilder(absl::StrCat(edge->src()->name(), "/recover/Reshape"),
1148 "Reshape")
1149 .Input(recover_reshape_input_node)
1150 .Input(recover_reshape_shape_node)
1151 .Attr("T", edge->src()->output_type(edge->src_output()))
1152 .Attr("Tshape", DT_INT32)
1153 .Attr(kTpuReplicateAttr, cluster_name) // This node is on TPU.
1154 .Finalize(graph, &recover_reshape_node));
1155
1156 // 3. Rewrite XlaSharding attribute if necessary
1157 if (xla_sharding_node != nullptr) {
1158 // The flattened tensor always has rank = 1 and we want to shard the only
1159 // dimension (0).
1160 SetXlaShardingNodeAttr(xla_sharding_node, num_cores_per_replica, 1, 0);
1161 }
1162
1163 // 4. Connect / disconnect nodes.
1164 if (xla_sharded_input) {
1165 graph->AddEdge(flatten_reshape_node, 0, edge->dst(), edge->dst_input());
1166 }
1167
1168 if (original_recover_reshape_input_edge != nullptr) {
1169 graph->AddEdge(recover_reshape_node, 0,
1170 original_recover_reshape_input_edge->dst(),
1171 original_recover_reshape_input_edge->dst_input());
1172 } else {
1173 graph->AddEdge(recover_reshape_node, 0, edge->dst(), edge->dst_input());
1174 }
1175
1176 graph->RemoveEdge(edge);
1177 if (original_recover_reshape_input_edge != nullptr) {
1178 graph->RemoveEdge(original_recover_reshape_input_edge);
1179 }
1180
1181 // 5. Update EdgeShapes.
1182 int dimension = 1;
1183 for (auto& it : (*tpu_input_shapes)[edge]) {
1184 dimension *= it;
1185 }
1186 VLOG(3) << "Dimension after reshape: " << dimension;
1187 for (const Edge* out_edge : flatten_reshape_node->out_edges()) {
1188 if (out_edge->dst() == recover_reshape_node) {
1189 (*tpu_input_shapes)[out_edge].push_back(dimension);
1190 tpu_input_shapes->erase(edge);
1191 break;
1192 }
1193 }
1194 VLOG(3) << "Reshape optimization done for " << edge->src()->name();
1195 }
1196 return OkStatus();
1197 }
1198 } // namespace tpu_functional_internal
1199
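// On the first call for a given (input hash, device ordinal) key, builds the
// function graph, selects a TPU core, optimizes input/output tensors, replaces
// resource args with VarHandleOps, places and partitions the graph, and
// instantiates one function per device; the result is cached. Subsequent calls
// with the same key reuse the cached partitions and go straight to execution.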
1200 void TPUPartitionedCallOp::ComputeAsync(OpKernelContext* ctx,
1201 DoneCallback done) {
1202 Status init_status;
1203 absl::call_once(once_, [&]() {
1204 library_runtime_ = ctx->function_library();
1205 if (library_runtime_ == nullptr) {
1206 init_status = errors::Internal("No function library is provided.");
1207 return;
1208 }
1209 flib_def_ = std::make_unique<FunctionLibraryDefinition>(
1210 *library_runtime_->GetFunctionLibraryDefinition());
1211 device_mgr_ = library_runtime_->device_mgr();
1212 for (auto d : device_mgr_->ListDevices()) {
1213 device_set_.AddDevice(d);
1214 }
1215
1216 DeviceNameUtils::ParsedName tpu_device_name;
1217 tpu_device_name.has_type = true;
1218 tpu_device_name.type = "TPU";
1219 std::vector<Device*> tpu_devices;
1220 device_set_.FindMatchingDevices(tpu_device_name, &tpu_devices_);
1221 });
1222 OP_REQUIRES_OK_ASYNC(ctx, init_status, done);
1223
1224 // Initialize the ordinal selector with information from the graph if it is
1225 // the first time we are running this op.
1226 absl::call_once(ordinal_selector_once_, [&]() {
1227 std::unique_ptr<Graph> graph(new Graph(flib_def_.get()));
1228 bool enable_spmd_xla_partitioning = false;
1229 TPUMetadata tpu_metadata;
1230 {
1231 absl::MutexLock l(&mu_);
1232 OP_REQUIRES_OK_ASYNC(
1233 ctx,
1234 GetGraphFromFunction(graph.get(), /*device_ordinal=*/0,
1235 &enable_spmd_xla_partitioning, &tpu_metadata),
1236 done);
1237 }
1238 if (enable_spmd_xla_partitioning) {
1239 ordinal_selector_ = std::make_shared<tpu::TPUOrdinalSelector>(
1240 tpu_metadata.num_cores_per_replica);
1241 } else {
1242 ordinal_selector_ = std::make_shared<tpu::TPUOrdinalSelector>();
1243 }
1244
1245 metrics::RecordTPUXlaSpmdCoresPerReplica(
1246 tpu_metadata.num_cores_per_replica);
1247 });
1248 OP_REQUIRES_ASYNC(
1249 ctx, ordinal_selector_ != nullptr,
1250 errors::Internal("The TPUOrdinalSelector is not initialized."), done);
1251
1252 uint64 input_hash = GetInputHash(ctx);
1253 int64_t ordinal_selector_req_id = -1;
1254 // Select a TPU core.
1255 int32_t device_ordinal = 0;
1256 OP_REQUIRES_OK_ASYNC(
1257 ctx,
1258 GetTpuCoreOrdinal(ctx, input_hash, &ordinal_selector_req_id,
1259 &device_ordinal),
1260 done);
1261 uint64 cache_hash = Hash64Combine(input_hash, device_ordinal);
1262 absl::ReleasableMutexLock lock(&mu_);
1263
1264 const std::vector<DeviceAndFHandle>* functions;
1265
1266 bool cache_miss = !partition_cache_.count(cache_hash);
1267 if (cache_miss) {
1268 VLOG(3) << "Cache Miss: partitioning function " << func_.name()
1269 << " cache_hash: " << cache_hash
1270 << " device_ordinal: " << device_ordinal;
1271
1272 profiler::TraceMe trace_me(
1273 "TPUPartitionedCallOp-RewriteAndInstantiateFunctions");
1274 std::unique_ptr<Graph> graph(new Graph(flib_def_.get()));
1275 bool enable_spmd_xla_partitioning = false;
1276 TPUMetadata tpu_metadata;
1277 OP_REQUIRES_OK_ASYNC(
1278 ctx,
1279 GetGraphFromFunction(graph.get(), device_ordinal,
1280 &enable_spmd_xla_partitioning, &tpu_metadata),
1281 done);
1282
1283 VLOG(1) << DumpGraphToFile("before_input_output_optimizations", *graph,
1284 flib_def_.get());
1285
1286 std::map<std::string, std::vector<int>> named_input_shapes;
1287 OP_REQUIRES_OK_ASYNC(
1288 ctx,
1289 OptimizeTpuInputOutputTensors(graph.get(), enable_spmd_xla_partitioning,
1290 tpu_metadata.num_cores_per_replica,
1291 named_input_shapes, ctx),
1292 done);
1293
1294 VLOG(1) << DumpGraphToFile(
1295 "before_replace_resource_args_with_var_handle_ops", *graph,
1296 flib_def_.get());
1297 OP_REQUIRES_OK_ASYNC(ctx,
1298 ReplaceResourceArgsWithVarHandleOps(
1299 graph.get(), ctx, device_ordinal,
1300 enable_spmd_xla_partitioning, tpu_metadata),
1301 done);
1302
1303 VLOG(1) << DumpGraphToFile(
1304 "after_replace_resource_args_with_var_handle_ops", *graph,
1305 flib_def_.get());
1306
1307 // Graph rewrite passes.
1308 GraphOptimizationPassOptions optimization_options;
1309 // TODO(akshayka): Thread the SessionOptions into this kernel, or make
1310 // it possible to specify the relevant options via attributes.
1311 SessionOptions session_options;
1312 session_options.config.mutable_experimental()
1313 ->set_xla_fusion_autotuner_thresh(autotuner_thresh_);
1314
1315 session_options.env = ctx->env();
1316 optimization_options.session_handle = ctx->session_handle();
1317 optimization_options.session_options = &session_options;
1318 optimization_options.graph = &graph;
1319 optimization_options.flib_def = flib_def_.get();
1320 optimization_options.device_set = &device_set_;
1321 OP_REQUIRES_OK_ASYNC(
1322 ctx, PlacementHelper(device_set_, optimization_options, func_.name()),
1323 done);
1324
1325 if (!enable_spmd_xla_partitioning ||
1326 tpu_metadata.num_cores_per_replica == 1) {
1327 OP_REQUIRES_OK_ASYNC(
1328 ctx,
1329 MaybeRegisterFingerprint(graph.get(), named_input_shapes, input_hash),
1330 done);
1331 }
1332 // `subgraphs` maps from device names to functions.
1333 std::unordered_map<std::string, std::unique_ptr<Graph>> subgraphs;
1334 optimization_options.graph = nullptr;
1335 optimization_options.device_set = nullptr;
1336 optimization_options.partition_graphs = &subgraphs;
1337 VLOG(1) << DumpGraphToFile("before_partition_helper.pbtxt", *graph,
1338 flib_def_.get());
1339 OP_REQUIRES_OK_ASYNC(ctx,
1340 PartitionHelper(device_set_, optimization_options,
1341 graph.get(), &subgraphs),
1342 done);
1343 OP_REQUIRES_OK_ASYNC(
1344 ctx,
1345 InstantiateFunctionsFromSubgraphs(
1346 device_set_, device_ordinal, cache_hash,
1347 tpu_metadata.num_cores_per_replica, std::move(subgraphs)),
1348 done);
1349 }
1350 functions = &partition_cache_[cache_hash];
1351 lock.Release();
1352
1353 ExecuteFunctions(*functions, ctx, device_ordinal, ordinal_selector_req_id,
1354 std::move(done));
1355 }
1356
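// Reads the device_ordinal input tensor; if it equals
// tpu::kDeferredCoreSelectionReserved, asks the TPUOrdinalSelector to pick a
// core based on `input_hash` and returns the selector's request id through
// `ordinal_selector_req_id`.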
1357 Status TPUPartitionedCallOp::GetTpuCoreOrdinal(OpKernelContext* ctx,
1358 uint64 input_hash,
1359 int64_t* ordinal_selector_req_id,
1360 int32_t* core_ordinal) {
1361 profiler::TraceMe trace_me("TPUPartitionedCallOp-GetTpuCoreOrdinal");
1362 const Tensor* device_ordinal_t;
1363 TF_RETURN_IF_ERROR(ctx->input(kDeviceOrdinalAttr, &device_ordinal_t));
1364 int device_ordinal = device_ordinal_t->scalar<int>()();
1365 if (device_ordinal == tpu::kDeferredCoreSelectionReserved) {
1366 device_ordinal =
1367 ordinal_selector_->GetOrdinal(input_hash, ordinal_selector_req_id);
1368 }
1369 *core_ordinal = device_ordinal;
1370 return OkStatus();
1371 }
1372
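// Builds a tiny init graph on the target TPU device: the variable handle node
// from `ndef`, a Const (or _TPUConst in HBM) holding the host variable's
// current value, and an AssignVariableOp that writes it. The graph is
// instantiated as a temporary function, run on a dedicated thread pool, and
// then removed from the library so it does not perturb flib_def_'s fingerprint.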
1373 Status TPUPartitionedCallOp::InitializeVarOnTPU(
1374 OpKernelContext* ctx, const core::RefCountPtr<Var>& var, NodeDef* ndef,
1375 int device_ordinal, bool fast_mem) {
1376 const string device = strings::StrCat(kTPUDeviceNamePrefix, device_ordinal);
1377 Status status;
1378 std::unique_ptr<Graph> init_graph(new Graph(OpRegistry::Global()));
1379 TF_ASSIGN_OR_RETURN(Node * init_handle, init_graph->AddNode(*ndef));
1380 init_handle->set_assigned_device_name(device);
1381
1382 NodeDef init_const_ndef;
1383 init_const_ndef.set_name("initial_value");
1384 #if defined(LIBTPU_ON_GCE) // TODO(b/217559071) - Remove once _TPUConst is OSS
1385 init_const_ndef.set_op("Const");
1386 #else
1387 init_const_ndef.set_op("_TPUConst");
1388 AddNodeAttr("memory_space", "HBM", &init_const_ndef);
1389 #endif
1390 init_const_ndef.set_device(device);
1391 AddNodeAttr("dtype", var->tensor()->dtype(), &init_const_ndef);
1392 AddNodeAttr("value", *var->tensor(), &init_const_ndef);
1393
1394 TF_ASSIGN_OR_RETURN(Node * init_const, init_graph->AddNode(init_const_ndef));
1395
1396 NodeDef assign_node_def;
1397 assign_node_def.set_name("Assign");
1398 assign_node_def.set_op("AssignVariableOp");
1399 assign_node_def.set_device(device);
1400 AddNodeAttr("dtype", var->tensor()->dtype(), &assign_node_def);
1401 TF_ASSIGN_OR_RETURN(Node * init_assign, init_graph->AddNode(assign_node_def));
1402
1403 init_graph->AddEdge(init_handle, 0, init_assign, 0);
1404 init_graph->AddEdge(init_const, 0, init_assign, 1);
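  // At this point the single-variable init graph is simply:
  //   VarHandleOp --> AssignVariableOp <-- Const(initial value)
  // Running it once as a function on the target TPU device copies the host
  // value of the variable into the freshly created TPU resource.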
1405 FHandle fhandle;
1406 const string fname =
1407 strings::StrCat(ndef->name(), "_init_ord_", device_ordinal);
1408
1409 TF_RETURN_IF_ERROR(
1410 InstantiatePartition(*init_graph, fname, device, &fhandle, nullptr));
1411
1412 FunctionLibraryRuntime::Options opts;
1413 opts.step_container = ctx->step_container();
1414 opts.cancellation_manager = ctx->cancellation_manager();
1415 opts.stats_collector = ctx->stats_collector();
1416
1417 // Blocking on threads in the same thread pool is disallowed because
1418 // concurrent warm-up requests can exhaust the default thread pool.
1419 // Create a new thread pool to initialize variables on TPU.
1420 std::function<void(std::function<void()>)> runner =
1421 [this](std::function<void()> fn) { pool_.Schedule(fn); };
1422 opts.runner = &runner;
1423
1424 opts.source_device = local_device_name_;
1425 PrivateIntraProcessRendezvous rendez(device_mgr_);
1426 opts.rendezvous = &rendez;
1427 opts.remote_execution = true;
1428
1429 std::vector<Tensor> dummy_args;
1430 std::vector<Tensor>* dummy_rets = new std::vector<Tensor>;
1431 Notification done;
1432 profiler::TraceMe trace_me("TPUPartitionedCallOp-InitializeVarOnTPU");
1433 library_runtime_->Run(opts, fhandle, dummy_args, dummy_rets,
1434 [dummy_rets, &done, ctx](const Status& status) {
1435 if (!status.ok()) {
1436 ctx->SetStatus(status);
1437 }
1438 delete dummy_rets;
1439 done.Notify();
1440 });
1441 done.WaitForNotification();
1442 // We don't actually want the variable initialization functions
1443 // in the function library definition and the function library
1444 // runtime, because flib_def_ is used for the graph rewrite passes.
1445 // The TPU distributed rewrite pass computes a fingerprint for
1446   // flib_def_, which will throw a length error if there are
1447 // many variables whose initialization functions are added
1448 // to the library definition.
1449 TF_RETURN_IF_ERROR(flib_def_->RemoveFunction(fname));
1450 TF_RETURN_IF_ERROR(library_runtime_->ReleaseHandle(fhandle));
1451 return OkStatus();
1452 }
1453
1454 Status TPUPartitionedCallOp::InitializeShardedVarOnTPU(
1455 OpKernelContext* ctx, const core::RefCountPtr<Var>& var,
1456 std::vector<NodeDef>& ndefs, int split_dim,
1457 const std::vector<string>& tpu_devices) {
1458 std::unique_ptr<Graph> init_graph(new Graph(OpRegistry::Global()));
1459 int num_cores = ndefs.size();
1460 string cpu_device = "/device:CPU:0";
1461
1462 Status status;
1463 std::vector<std::string> devices;
1464 std::vector<Node*> init_handles;
1465 for (int i = 0; i < num_cores; i++) {
1466 TF_ASSIGN_OR_RETURN(Node * init_handle, init_graph->AddNode(ndefs[i]));
1467 string device = tpu_devices[i];
1468 init_handle->set_assigned_device_name(device);
1469 devices.push_back(device);
1470 init_handles.push_back(init_handle);
1471 }
1472
1473 NodeDef init_const_ndef;
1474 init_const_ndef.set_name("initial_value");
1475 init_const_ndef.set_op("Const");
1476 init_const_ndef.set_device(cpu_device);
1477 AddNodeAttr("dtype", var->tensor()->dtype(), &init_const_ndef);
1478 AddNodeAttr("value", *var->tensor(), &init_const_ndef);
1479 TF_ASSIGN_OR_RETURN(Node * init_const, init_graph->AddNode(init_const_ndef));
1480 init_const->set_assigned_device_name(cpu_device);
1481
1482 Node* assign_value_node = init_const;
1483   // If the variable is sharded, we insert a "Split" node between the initial
1484   // value and the AssignVariableOps, so the variable on each TPU device gets
1485   // assigned its slice of the split value.
1486 //
1487 // initial_value--Split--AssignVariableOp ("/device:TPU:0")
1488 // |
1489 // AssignVariableOp ("/device:TPU:1")
1490 if (split_dim >= 0) {
1491 // Add a split dimension node.
1492 NodeDef split_dim_def;
1493 split_dim_def.set_name("initial_value_split_dim");
1494 split_dim_def.set_op("Const");
1495 split_dim_def.set_device(cpu_device);
1496 AddNodeAttr("dtype", DT_INT32, &split_dim_def);
1497 TensorProto tensor_proto;
1498 tensor_proto.set_dtype(DT_INT32);
1499 tensor_proto.add_int_val(split_dim);
1500 TensorShape shape({});
1501 shape.AsProto(tensor_proto.mutable_tensor_shape());
1502 AddNodeAttr("value", tensor_proto, &split_dim_def);
1503 TF_ASSIGN_OR_RETURN(Node * split_dim_node,
1504 init_graph->AddNode(split_dim_def));
1505 split_dim_node->set_assigned_device_name(cpu_device);
1506
1507 // Add a split node.
1508 NodeDef split_def;
1509 int split_num = ndefs.size();
1510 split_def.set_name("initial_value_split");
1511 split_def.set_op("Split");
1512 split_def.set_device(cpu_device);
1513 AddNodeAttr("num_split", split_num, &split_def);
1514 AddNodeAttr("T", var->tensor()->dtype(), &split_def);
1515 split_def.add_input(absl::StrCat(split_dim_node->name(), ":0"));
1516 split_def.add_input(absl::StrCat(init_const->name(), ":0"));
1517 TF_ASSIGN_OR_RETURN(Node * split_node, init_graph->AddNode(split_def));
1518 split_node->set_assigned_device_name(cpu_device);
1519
1520 init_graph->AddEdge(split_dim_node, 0, split_node, 0);
1521 init_graph->AddEdge(init_const, 0, split_node, 1);
1522
1523 assign_value_node = split_node;
1524 }
1525
1526 for (int i = 0; i < num_cores; i++) {
1527 NodeDef assign_node_def;
1528 assign_node_def.set_name(absl::StrCat("Assign_", i));
1529 assign_node_def.set_op("AssignVariableOp");
1530 assign_node_def.set_device(devices[i]);
1531 AddNodeAttr("dtype", var->tensor()->dtype(), &assign_node_def);
1532 TF_ASSIGN_OR_RETURN(Node * init_assign,
1533 init_graph->AddNode(assign_node_def));
1534 init_assign->set_assigned_device_name(devices[i]);
1535
1536 init_graph->AddEdge(init_handles[i], 0, init_assign, 0);
1537 if (split_dim >= 0) {
1538 init_graph->AddEdge(assign_value_node, i, init_assign, 1);
1539 } else {
1540 init_graph->AddEdge(assign_value_node, 0, init_assign, 1);
1541 }
1542 }
1543
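  // Sketch of the init graph built above for split_dim >= 0 and two cores:
  //
  //   Const(initial_value) -> Split -+-> AssignVariableOp <- VarHandleOp (TPU:0)
  //                                  +-> AssignVariableOp <- VarHandleOp (TPU:1)
  //
  // Below, the graph is partitioned per device and each partition is
  // instantiated and run as its own function.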
1544 GraphOptimizationPassOptions optimization_options;
1545 SessionOptions session_options;
1546 session_options.env = ctx->env();
1547 optimization_options.session_handle = ctx->session_handle();
1548 optimization_options.session_options = &session_options;
1549 optimization_options.flib_def = flib_def_.get();
1550 optimization_options.graph = nullptr;
1551 optimization_options.device_set = nullptr;
1552 std::unordered_map<std::string, std::unique_ptr<Graph>> subgraphs;
1553 optimization_options.partition_graphs = &subgraphs;
1554 TF_RETURN_IF_ERROR(PartitionHelper(device_set_, optimization_options,
1555 init_graph.get(), &subgraphs));
1556
1557 std::vector<DeviceAndFHandle> functions;
1558 std::vector<std::string> function_names;
1559 for (auto& pair : subgraphs) {
1560 string target = pair.first;
1561 Device* device;
1562 TF_RETURN_IF_ERROR(
1563 library_runtime_->device_mgr()->LookupDevice(target, &device));
1564 Graph* subgraph = pair.second.get();
1565 string function_name = flib_def_->UniqueFunctionName(
1566 strings::StrCat(func_.name(), "_hash_", pair.first));
1567 function_names.push_back(function_name);
1568 FHandle handle;
1569 TF_RETURN_IF_ERROR(InstantiatePartition(*subgraph, function_name, target,
1570 &handle, nullptr));
1571 functions.push_back(DeviceAndFHandle{.device = target, .handle = handle});
1572 }
1573
1574 FunctionLibraryRuntime::Options opts;
1575
1576 // Blocking on threads in the same thread pool is disallowed because
1577 // concurrent warm-up requests can exhaust the default thread pool.
1578 // Create a new thread pool to initialize variables on TPU.
1579 std::function<void(std::function<void()>)> runner =
1580 [this](std::function<void()> fn) { pool_.Schedule(fn); };
1581 opts.runner = &runner;
1582
1583 opts.step_container = ctx->step_container();
1584 opts.cancellation_manager = ctx->cancellation_manager();
1585 opts.stats_collector = ctx->stats_collector();
1586 opts.source_device = local_device_name_;
1587 opts.run_all_kernels_inline = ctx->run_all_kernels_inline();
1588
1589 OpInputList arguments;
1590 TF_RETURN_IF_ERROR(ctx->input_list("args", &arguments));
1591
1592 PrivateIntraProcessRendezvous rendez(device_mgr_);
1593 opts.rendezvous = &rendez;
1594
1595 BlockingCounter bcount(functions.size());
1596 for (const DeviceAndFHandle& entry : functions) {
1597 const string& target_device = entry.device;
1598 FHandle handle = entry.handle;
1599
1600 TF_RETURN_IF_ERROR(
1601 ShouldUseRemoteExecutionForFn(target_device, &(opts.remote_execution)));
1602 std::vector<Tensor> dummy_args;
1603 std::vector<Tensor>* dummy_rets = new std::vector<Tensor>;
1604
1605 profiler::TraceMe trace_me(
1606 "TPUPartitionedCallOp-InitializeShardedVarOnTPU");
1607 library_runtime_->Run(opts, handle, dummy_args, dummy_rets,
1608 [dummy_rets, &bcount, ctx](const Status& status) {
1609 if (!status.ok()) {
1610 ctx->SetStatus(status);
1611 }
1612 delete dummy_rets;
1613 bcount.DecrementCount();
1614 });
1615 }
1616 bcount.Wait();
1617
1618 for (int i = 0; i < functions.size(); i++) {
1619 TF_RETURN_IF_ERROR(flib_def_->RemoveFunction(function_names[i]));
1620 TF_RETURN_IF_ERROR(library_runtime_->ReleaseHandle(functions[i].handle));
1621 }
1622 return OkStatus();
1623 }
1624
1625 bool TPUPartitionedCallOp::IsInputToTPUReplicate(Node* node) {
1626 for (Node* successor : node->out_nodes()) {
1627 if (successor->attrs().Find(kTpuReplicateAttr) != nullptr) {
1628 return true;
1629 }
1630 }
1631 return false;
1632 }
1633
1634 Status TPUPartitionedCallOp::ReplaceResourceArgsWithVarHandleOps(
1635 Graph* graph, OpKernelContext* ctx, int device_ordinal,
1636 bool enable_spmd_xla_partitioning, const TPUMetadata& tpu_metadata) {
1637 // Currently variable deduplication is not supported for XLA SPMD
1638 // partitioning. It is possible that it could be supported in the future.
1639 bool enable_variable_deduplication =
1640 runtime_params_.enable_variable_deduplication;
1641 if (enable_spmd_xla_partitioning && tpu_metadata.num_cores_per_replica > 1) {
1642 // If enable_spmd_xla_partitioning is true, the user set the
1643 // enable_auto_xla_input_sharding flag. Warn them that only one of the flags
1644 // can be set safely when num_cores_per_replica > 1. If
1645 // num_cores_per_replica==1, enable_spmd_xla_partitioning is effectively a
1646 // no-op so we can skip this check.
1647 LOG(WARNING) << "Disabling variable deduplication because it is not "
1648 "compatible with enable_auto_xla_input_sharding.";
1649 enable_variable_deduplication = false;
1650 }
1651 std::vector<Node*> tpu_resource_args;
1652 std::vector<int> arg_indices;
1653 absl::flat_hash_map<const Node*, xla::OpSharding> variable_to_xla_sharding;
1654 for (Node* node : graph->op_nodes()) {
1655 if (node->IsArg()) {
1656 const AttrValue* attr_value;
1657 TF_RETURN_IF_ERROR(node->attrs().Find("T", &attr_value));
1658 DataType dtype = attr_value->type();
1659 if (dtype == DT_RESOURCE && IsInputToTPUReplicate(node)) {
1660         // If this resource arg feeds a TPU computation,
1661         // we need to create a TPU version of the variable.
1662 TF_RETURN_IF_ERROR(node->attrs().Find("index", &attr_value));
1663 int index = attr_value->i();
1664 tpu_resource_args.push_back(node);
1665 arg_indices.push_back(index);
1666 replaced_input_indices_[index] = true;
1667 }
1668 }
1669 }
1670
1671 VLOG(3) << "tpu_resource_args.size(): " << tpu_resource_args.size();
1672 // Create a mapping from ResourceHandle to variable node. When a
1673 // ResourceHandle backs several variable nodes, the variable nodes refer to
1674 // the same underlying resource. In that case, only one variable node needs
1675 // to be mirrored to the TPU for that resource.
1676 absl::flat_hash_map<uint64, Node*> tpu_variables;
1677 for (int i = 0; i < tpu_resource_args.size(); i++) {
1678 Node* node = tpu_resource_args[i];
1679 ResourceHandle handle = HandleFromInput(ctx, arg_indices[i]);
1680
1681 if (tpu_metadata.num_cores_per_replica > 1 &&
1682 enable_spmd_xla_partitioning) {
1683 TF_RETURN_IF_ERROR(ReplaceAndPartitionXLAShardingVariable(
1684 graph, ctx, device_ordinal, handle, node, tpu_metadata));
1685 continue;
1686 }
1687 TPUVariableInfo var_info(/*device_ordinal_id=*/0, /*use_fast_mem=*/false);
1688 TF_RETURN_IF_ERROR(ParseTPUVariableInfor(
1689 node, tpu_metadata.num_cores_per_replica, &var_info));
1690     // Only respect the graph's placement when model parallelism is enabled.
1691 if (tpu_metadata.num_cores_per_replica > 1)
1692 device_ordinal = var_info.device_ordinal;
1693
1694 const uint64 handle_fp =
1695 Fingerprint64(strings::StrCat(handle.container(), handle.name()));
1696 if (enable_variable_deduplication && tpu_variables.contains(handle_fp) &&
1697 tpu_metadata.num_cores_per_replica == 1) {
1698 Node* tpu_variable = tpu_variables.at(handle_fp);
1699 std::vector<Node*> dst_nodes;
1700 std::vector<int> src_indices;
1701 std::vector<int> dst_indices;
1702 for (const Edge* edge : node->out_edges()) {
1703 dst_nodes.push_back(edge->dst());
1704 src_indices.push_back(edge->src_output());
1705 dst_indices.push_back(edge->dst_input());
1706 }
1707 graph->RemoveNode(node);
1708 for (int i = 0; i < dst_nodes.size(); i++) {
1709 graph->AddEdge(tpu_variable, src_indices[i], dst_nodes[i],
1710 dst_indices[i]);
1711 }
1712 } else {
1713 uint64 fp =
1714 Fingerprint64(strings::StrCat(handle.container(), handle.name(), i));
1715 NodeDef ndef;
1716 ndef.set_name(strings::StrCat(handle.name(), fp));
1717 ndef.set_op(kVarHandleOp);
1718 if (tpu_metadata.num_cores_per_replica > 1) {
1719 ndef.set_device(strings::StrCat(kTPUDeviceNamePrefix, device_ordinal));
1720 } else {
1721 // Assign this new VarHandleOp to TPU:0 so the partitioner only
1722         // partitions the graph into two subgraphs, one on CPU and one on TPU.
1723 // The actual device ordinal on which this VarHandleOp runs is assigned
1724 // after partitioning (in SetDeviceOrdinal).
1725 ndef.set_device(
1726 strings::StrCat(kTPUDeviceNamePrefix, kTPUDefaultDeviceOrdinal));
1727 }
1728
1729 // Replace each _Arg node of type DT_RESOURCE that goes into a TPU node
1730 // by a VarHandleOp on TPU with shared_name "v_tpu_x" where "v" is the
1731 // shared_name of the variable on CPU and "x" is the rewritten device
1732 // ordinal.
1733 const string sname =
1734 strings::StrCat(handle.name(), "_tpu_", device_ordinal);
1735 AddNodeAttr("shared_name", sname, &ndef);
1736 const string cname = ctx->resource_manager()->default_container();
1737 AddNodeAttr("container", cname, &ndef);
1738 core::RefCountPtr<Var> var;
1739 TF_RETURN_IF_ERROR(LookupResource(ctx, handle, &var));
1740 AddNodeAttr("dtype", var->tensor()->dtype(), &ndef);
1741 TensorShapeProto proto;
1742 var->tensor()->shape().AsProto(&proto);
1743 AddNodeAttr("shape", proto, &ndef);
1744 TF_ASSIGN_OR_RETURN(Node * new_node, graph->AddNode(ndef));
1745 std::vector<const Edge*> in_edges(node->in_edges().begin(),
1746 node->in_edges().end());
1747 for (const Edge* edge : in_edges) {
1748 graph->AddEdge(edge->src(), edge->src_output(), new_node,
1749 edge->dst_input());
1750 }
1751 std::vector<Node*> dst_nodes;
1752 std::vector<int> src_indices;
1753 std::vector<int> dst_indices;
1754 for (const Edge* edge : node->out_edges()) {
1755 dst_nodes.push_back(edge->dst());
1756 src_indices.push_back(edge->src_output());
1757 dst_indices.push_back(edge->dst_input());
1758 }
1759 graph->RemoveNode(node);
1760 for (int i = 0; i < dst_nodes.size(); i++) {
1761 graph->AddEdge(new_node, src_indices[i], dst_nodes[i], dst_indices[i]);
1762 }
1763       // Don't initialize variables on TPU if this has already been done for
1764       // the ordinal.
1765 if (seen_ordinals_.contains(device_ordinal)) continue;
1766
1767 Device* d;
1768 TF_RETURN_IF_ERROR(library_runtime_->device_mgr()->LookupDevice(
1769 strings::StrCat(kTPUDeviceNamePrefix, device_ordinal), &d));
1770 Var* tpu_var;
1771 Status status = d->resource_manager()->Lookup(cname, sname, &tpu_var);
1772 if (!status.ok()) {
1773 TF_RETURN_IF_ERROR(InitializeVarOnTPU(ctx, var, &ndef, device_ordinal,
1774 var_info.fast_mem));
1775 VLOG(3) << "Initialized variable on TPU: " << sname
1776 << " device_ordinal: " << device_ordinal;
1777 }
1778 tpu_variables[handle_fp] = new_node;
1779 }
1780 }
1781
1782   // Adjust the index attr of the remaining non-resource arg nodes.
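  // (The DT_RESOURCE args routed to TPU were removed above, so the surviving
  // _Arg nodes are renumbered densely from 0. replaced_input_indices_ records
  // which original inputs were dropped so execution can skip passing them.)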
1783 int new_index = 0;
1784 for (Node* node : graph->op_nodes()) {
1785 if (node->IsArg()) {
1786 node->ClearAttr("index");
1787 node->AddAttr("index", new_index);
1788 new_index++;
1789 }
1790 }
1791
1792 seen_ordinals_.insert(device_ordinal);
1793
1794 return OkStatus();
1795 }
1796
1797 Status TPUPartitionedCallOp::ReplaceAndPartitionXLAShardingVariable(
1798 Graph* graph, OpKernelContext* ctx, int device_ordinal,
1799 ResourceHandle& handle, Node* variable, const TPUMetadata& tpu_metadata) {
1800 TF_ASSIGN_OR_RETURN(
1801 auto sharding,
1802 GetShardingFromNodeDef(variable->def(), /*add_metadata=*/false));
1803 xla::OpSharding xla_sharding;
1804 bool is_var_sharded = false;
1805 if (sharding.has_value() &&
1806 sharding.value().type() == xla::OpSharding::OTHER) {
1807 xla_sharding = sharding.value();
1808 for (int dim = 0; dim < GetDimsFromXLAShardingTiled(xla_sharding); dim++) {
1809 is_var_sharded |= xla_sharding.tile_assignment_dimensions(dim) > 1;
1810 }
1811 } else {
1812 xla_sharding.set_type(xla::OpSharding::REPLICATED);
1813 is_var_sharded = false;
1814 }
1815 VLOG(3) << "Replace and partition variable " << variable->name()
1816 << " with xla_sharding: " << xla_sharding.DebugString();
1817
1818 core::RefCountPtr<Var> var;
1819 TF_RETURN_IF_ERROR(LookupResource(ctx, handle, &var));
1820
1821 int split_dim = -1;
1822 int split_size = 0;
1823
1824 if (is_var_sharded) {
1825 for (int dim = 0; dim < GetDimsFromXLAShardingTiled(xla_sharding); dim++) {
1826 if (xla_sharding.tile_assignment_dimensions(dim) > 1) {
1827 if (split_dim != -1) {
1828 return errors::InvalidArgument(
1829 "Currently we only support inference with one split dimension, "
1830 "however got sharding: ",
1831 xla_sharding.DebugString());
1832 }
1833 split_dim = dim;
1834 split_size = xla_sharding.tile_assignment_dimensions(dim);
1835 }
1836 }
1837 if (split_dim == -1 || split_dim >= var->tensor()->dims()) {
1838 return errors::InvalidArgument(
1839 "sharding split_dim ", split_dim, " for variable: ", variable->name(),
1840           " is -1 or larger than the number of dimensions ",
1841 var->tensor()->dims());
1842 }
1843 }
1844
1845 const auto& topology = tpu_metadata.topology;
1846 int num_cores_per_replica = tpu_metadata.num_cores_per_replica;
1847 xla::Array4D<int> mapping(topology.mesh_shape(0), topology.mesh_shape(1),
1848 topology.mesh_shape(2), topology.mesh_shape(3), -1);
1849 int pos = 0;
1850 // The topology should only have one task.
1851 for (int device = 0; device < topology.num_tpu_devices_per_task(); device++) {
1852 int32_t x = topology.device_coordinates(pos++);
1853 int32_t y = topology.device_coordinates(pos++);
1854 int32_t z = topology.device_coordinates(pos++);
1855 int32_t core = topology.device_coordinates(pos++);
1856 mapping(x, y, z, core) = device;
1857 }
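  // `mapping(x, y, z, core)` now resolves physical coordinates to a global TPU
  // device index. `device_assignment` stores four coordinates per logical
  // core, so the loop below maps each of the num_cores_per_replica logical
  // cores to its concrete /device:TPU:<index>.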
1858
1859 const string cname = ctx->resource_manager()->default_container();
1860 std::vector<Node*> per_core_vars;
1861 std::vector<string> tpu_devices;
1862 for (int i = 0; i < num_cores_per_replica; i++) {
1863 int offset = i * 4;
1864 int device_index = mapping(tpu_metadata.device_assignment[offset],
1865 tpu_metadata.device_assignment[offset + 1],
1866 tpu_metadata.device_assignment[offset + 2],
1867 tpu_metadata.device_assignment[offset + 3]);
1868
1869 NodeDef ndef;
1870 uint64 fp = Fingerprint64(
1871 strings::StrCat(handle.container(), handle.name(), "_", device_index));
1872 ndef.set_name(strings::StrCat(handle.name(), fp));
1873 ndef.set_op(kVarHandleOp);
1874 string tpu_device = strings::StrCat(kTPUDeviceNamePrefix, device_index);
1875 ndef.set_device(tpu_device);
1876 tpu_devices.push_back(tpu_device);
1877
1878 // Replace each _Arg node of type DT_RESOURCE that goes into a TPU node
1879 // by a VarHandleOp on TPU with shared_name "v_tpu_x" where "v" is the
1880 // shared_name of the variable on CPU and "x" is the rewritten device
1881 // ordinal.
1882 const string sname = strings::StrCat(handle.name(), "_tpu_", device_index);
1883 AddNodeAttr("shared_name", sname, &ndef);
1884 AddNodeAttr("container", cname, &ndef);
1885 AddNodeAttr("dtype", var->tensor()->dtype(), &ndef);
1886
1887 TensorShapeProto proto;
1888 var->tensor()->shape().AsProto(&proto);
1889
1890 if (is_var_sharded) {
1891 int dim_size = proto.dim(split_dim).size();
1892 if (dim_size % split_size != 0) {
1893 return errors::InvalidArgument("dimension size ", dim_size,
1894                                        " is not divisible by split size ",
1895 split_size);
1896 }
1897 proto.mutable_dim(split_dim)->set_size(dim_size / split_size);
1898 }
1899 AddNodeAttr("shape", proto, &ndef);
1900
1901 TF_ASSIGN_OR_RETURN(Node * new_node, graph->AddNode(ndef));
1902 per_core_vars.push_back(new_node);
1903 }
1904
1905 // Insert TPUPartitionedInput op.
1906 NodeDefBuilder builder(absl::StrCat(handle.name(), "/tpu_partitioned_input"),
1907 "TPUPartitionedInput");
1908 builder.Attr("N", num_cores_per_replica);
1909 builder.Attr("T", DT_RESOURCE);
1910 builder.Attr("partition_dim", split_dim);
1911 builder.Attr("_XlaSharding", xla_sharding.SerializeAsString());
1912 std::vector<NodeDefBuilder::NodeOut> inputs;
1913 inputs.reserve(num_cores_per_replica);
1914 for (int i = 0; i < num_cores_per_replica; i++) {
1915 inputs.push_back({per_core_vars[i]->name(), 0, DT_RESOURCE});
1916 }
1917 builder.Input(inputs);
1918 NodeDef node_def;
1919 TF_RETURN_IF_ERROR(builder.Finalize(&node_def));
1920 TF_ASSIGN_OR_RETURN(Node * tpu_partitioned_input_node,
1921 graph->AddNode(node_def));
1922
1923 for (int i = 0; i < num_cores_per_replica; i++) {
1924 graph->AddEdge(per_core_vars[i], 0, tpu_partitioned_input_node, i);
1925 }
1926
1927 // Insert TPUReplicatedInput op.
1928 NodeDefBuilder replicated_builder(
1929 absl::StrCat(handle.name(), "/tpu_replicated_input"),
1930 "TPUReplicatedInput");
1931 replicated_builder.Attr("N", 1);
1932 replicated_builder.Attr("T", DT_RESOURCE);
1933 replicated_builder.Attr("is_mirrored_variable", true);
1934 std::vector<NodeDefBuilder::NodeOut> replicated_inputs;
1935 replicated_inputs.push_back(
1936 {tpu_partitioned_input_node->name(), 0, DT_RESOURCE});
1937 replicated_builder.Input(replicated_inputs);
1938 NodeDef replicated_node_def;
1939 TF_RETURN_IF_ERROR(replicated_builder.Finalize(&replicated_node_def));
1940 Status replicated_s;
1941 Node* tpu_replicated_input_node =
1942 graph->AddNode(replicated_node_def, &replicated_s);
1943 if (!replicated_s.ok()) {
1944 return replicated_s;
1945 }
1946 graph->AddEdge(tpu_partitioned_input_node, 0, tpu_replicated_input_node, 0);
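  // The rewritten input chain is now, roughly:
  //   VarHandleOp (TPU:0..N-1) -> TPUPartitionedInput -> TPUReplicatedInput
  // TPUPartitionedInput presents the per-core resources as one logical sharded
  // resource, and the single-replica TPUReplicatedInput stands in for the
  // original variable node in front of its former consumers.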
1947
1948 // Connect the TPUReplicatedInput node to the previous output nodes of the
1949 // variable, and remove the variable node.
1950 std::vector<Node*> dst_nodes;
1951 std::vector<int> src_indices;
1952 std::vector<int> dst_indices;
1953 for (const Edge* edge : variable->out_edges()) {
1954 dst_nodes.push_back(edge->dst());
1955 src_indices.push_back(edge->src_output());
1956 dst_indices.push_back(edge->dst_input());
1957 }
1958 for (int i = 0; i < dst_nodes.size(); i++) {
1959 graph->AddEdge(tpu_replicated_input_node, src_indices[i], dst_nodes[i],
1960 dst_indices[i]);
1961 }
1962
1963 graph->RemoveNode(variable);
1964
1965 std::vector<NodeDef> ndefs;
1966 Status status;
1967 for (int i = 0; i < num_cores_per_replica; i++) {
1968 Device* d;
1969 TF_RETURN_IF_ERROR(
1970 library_runtime_->device_mgr()->LookupDevice(tpu_devices[i], &d));
1971 string sname;
1972 const NodeDef& ndef = per_core_vars[i]->def();
1973 TF_RETURN_IF_ERROR(GetNodeAttr(ndef, "shared_name", &sname));
1974 ndefs.push_back(ndef);
1975 Var* tpu_var;
1976 status = d->resource_manager()->Lookup(cname, sname, &tpu_var);
1977 }
1978
1979 if (!status.ok()) {
1980 TF_RETURN_IF_ERROR(
1981 InitializeShardedVarOnTPU(ctx, var, ndefs, split_dim, tpu_devices));
1982 if (VLOG_IS_ON(4)) {
1983 for (int i = 0; i < num_cores_per_replica; i++) {
1984 string sname;
1985 TF_RETURN_IF_ERROR(GetNodeAttr(ndefs[i], "shared_name", &sname));
1986 LOG(INFO) << "Initialized sharded variable on TPU: " << sname
1987 << " device: " << tpu_devices[i];
1988 }
1989 }
1990 }
1991
1992 return OkStatus();
1993 }
1994
1995 Status TPUPartitionedCallOp::InferShapesWithResourceVar(
1996 Graph* graph, OpKernelContext* ctx,
1997 std::map<int, InferredShape>& arg_shapes,
1998 GraphShapeInfo* tpu_inferred_info) {
1999 auto shape_inference_graph_interim =
2000 absl::make_unique<Graph>(graph->flib_def());
2001 CopyGraph(*graph, shape_inference_graph_interim.get());
2002
2003 for (Node* node : shape_inference_graph_interim->nodes()) {
2004 if (node->type_string() != "_Arg" ||
2005 node->attrs().Find("T")->type() != DT_RESOURCE)
2006 continue;
2007
2008 std::vector<std::function<void()>> to_remove;
2009
2010 for (const Edge* out_edge : node->out_edges()) {
2011 Node* read_node = out_edge->dst();
2012 if (read_node->type_string() != "ReadVariableOp") continue;
2013
2014 for (const Edge* variable_edge : read_node->out_edges()) {
2015 // We are delaying these modifications as we cannot do in-place
2016 // modification of EdgeSets.
2017 to_remove.push_back(
2018 [variable_edge, graph = shape_inference_graph_interim.get(), node] {
2019 Node* dst = variable_edge->dst();
2020 graph->RemoveEdge(variable_edge);
2021 graph->AddEdge(node, variable_edge->src_output(), dst,
2022 variable_edge->dst_input());
2023 });
2024 }
2025 to_remove.push_back(
2026 [graph = shape_inference_graph_interim.get(), out_edge, read_node] {
2027 graph->RemoveEdge(out_edge);
2028 graph->RemoveNode(read_node);
2029 });
2030 }
2031
2032 for (auto& func : to_remove) {
2033 func();
2034 }
2035
2036 int resource_arg_index = node->attrs().Find("index")->i();
2037
2038 // Get resource variable tensor
2039 core::RefCountPtr<Var> variable;
2040 const ResourceHandle& handle = HandleFromInput(ctx, resource_arg_index);
2041 TF_RETURN_IF_ERROR(LookupResource(ctx, handle, &variable));
2042
2043 const Tensor* variable_tensor = variable->tensor();
2044 std::vector<int> variable_tensor_vec;
2045
2046 variable_tensor_vec.reserve(variable_tensor->dims());
2047 for (int d = 0; d < variable_tensor->dims(); ++d) {
2048 variable_tensor_vec.push_back(variable_tensor->dim_size(d));
2049 }
2050
2051 PartialTensorShape partial_tensor_shape;
2052 auto partial_shape = PartialTensorShape::MakePartialShape(
2053 variable_tensor_vec.data(), variable_tensor_vec.size(),
2054 &partial_tensor_shape);
2055 InferredShape inferred_shape = {partial_tensor_shape};
2056 arg_shapes.emplace(resource_arg_index, inferred_shape);
2057 }
2058
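  // With the ReadVariableOp nodes bypassed and the concrete variable shapes
  // recorded in arg_shapes, shape inference below can propagate fully known
  // shapes from the resource args into the TPU computation.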
2059 TF_RETURN_IF_ERROR(tensorflow::InferShapes(
2060 shape_inference_graph_interim.get(), arg_shapes,
2061 &shape_inference_graph_interim->flib_def(), tpu_inferred_info));
2062 return OkStatus();
2063 }
2064
2065 Status TPUPartitionedCallOp::ShardInputsWithXlaSharding(
2066 Graph* graph, const std::string& cluster_name, int num_cores_per_replica,
2067 OpKernelContext* ctx) {
2068 for (Node* replicated_input_node : graph->nodes()) {
2069 if (replicated_input_node->type_string() != "TPUReplicatedInput") continue;
2070
2071 Node* arg_node;
2072 auto input_node_status = replicated_input_node->input_node(0, &arg_node);
2073 if (!input_node_status.ok()) {
2074 VLOG(2) << "Skip because cannot retrieve input node 0 of "
2075 << replicated_input_node->name() << " because "
2076 << input_node_status.ToString();
2077 continue;
2078 }
2079
2080     // Check whether this TPUReplicatedInput qualifies: it takes an _Arg as
2081     // input and does not already feed an XlaSharding op. If so, try to
2082     // shard the input automatically.
2083 //
2084 // In short, we want to see the following graph:
2085 // _Arg -> TPUReplicatedInput -> (not XlaSharding op)
2086 // and transform it to:
2087 // _Arg -> TPUReplicatedInput -> XlaSharding -> (not XlaSharding op)
2088 if (arg_node->IsArg() &&
2089 replicated_input_node->out_nodes().begin()->type_string() !=
2090 "XlaSharding") {
2091 int arg_id;
2092 if (!absl::SimpleAtoi(absl::StripPrefix(arg_node->name(), "arg_"),
2093 &arg_id)) {
2094 VLOG(3) << "Skip auto-sharding because we are unable to extract "
2095 "argument number from "
2096 << arg_node->name();
2097 continue;
2098 }
2099
2100 auto shape = ctx->input(arg_id).shape();
2101
2102 VLOG(3) << "Identified arg node " << arg_node->DebugString()
2103 << " for TPUReplicatedInput "
2104 << replicated_input_node->DebugString();
2105 VLOG(3) << "Shape within TPUReplicatedInput is: " << shape.DebugString();
2106
2107 int rank = shape.dims();
2108 int shard_dim =
2109 (runtime_params_.auto_xla_input_sharding_dim + rank) % rank;
2110
2111 if (shape.dim_size(shard_dim) % num_cores_per_replica != 0) {
2112 VLOG(3) << "Skip auto-sharding " << replicated_input_node->name()
2113 << " because the specified sharding dimension " << shard_dim
2114 << " cannot be evenly split by " << num_cores_per_replica;
2115 continue;
2116 }
2117
2118 auto sharding = absl::make_optional<xla::OpSharding>();
2119 sharding->set_type(xla::OpSharding::OTHER);
2120
2121 // Sets up tile_assignment_dimensions.
2122 std::vector<int64_t> dims(rank, 1LL);
2123 dims[shard_dim] = num_cores_per_replica;
2124 for (auto dim : dims) {
2125 sharding->add_tile_assignment_dimensions(dim);
2126 }
2127
2128 // Sets up tile_assignment_devices.
2129 for (int d = 0; d < num_cores_per_replica; ++d) {
2130 sharding->add_tile_assignment_devices(d);
2131 }
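      // Rough example: a rank-2 input split across two cores on dimension 0
      // gets type: OTHER, tile_assignment_dimensions: [2, 1],
      // tile_assignment_devices: [0, 1], i.e. core 0 receives the first half
      // of the rows and core 1 the second half.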
2132
2133 std::vector<const Edge*> edges_to_remove;
2134 for (const Edge* edge : replicated_input_node->out_edges()) {
2135 if (edge->IsControlEdge()) continue;
2136 edges_to_remove.push_back(edge);
2137 }
2138
2139 // Create XlaSharding Op.
2140 Node* sharding_op = nullptr;
2141 TF_RETURN_IF_ERROR(
2142 NodeBuilder(absl::StrCat(replicated_input_node->name(), "/sharding"),
2143 "XlaSharding")
2144 .Input(replicated_input_node, 0)
2145 .Attr("T", replicated_input_node->output_type(0))
2146 .Attr(kXLAShardingAttrName, sharding->SerializeAsString())
2147 .Attr(kXLAShardingAttrAltName, sharding->SerializeAsString())
2148 .Attr("_tpu_replicate", cluster_name)
2149 .Finalize(graph, &sharding_op));
2150 for (const Edge* edge : edges_to_remove) {
2151 VLOG(3) << "XlaSharding op creation output edge "
2152 << edge->DebugString();
2153 graph->RemoveEdge(edge);
2154 graph->AddEdge(sharding_op, 0, edge->dst(), edge->dst_input());
2155 }
2156
2157 VLOG(3) << "Auto shard " << replicated_input_node->name() << " by dim "
2158 << shard_dim << " into " << num_cores_per_replica << " slices";
2159
2160 VLOG(3) << "Created XlaSharding Op " << sharding_op->DebugString();
2161 }
2162 }
2163
2164 return OkStatus();
2165 }
2166
2167 // OptimizeTpuInputOutputTensors does the following things:
2168 // (1) Detect input arguments, and add an XlaSharding op to each argument if
2169 // enable_auto_xla_input_sharding is turned on.
2170 // (2) Pack multiple input tensors into one tensor by a concat to avoid PCIe
2171 // transfer overheads for small tensors.
2172 // (3) Reshape input tensors to R1 to leverage the fast path in TPU input
2173 // preparation done by the runtime.
2174 // (4) Pack multiple output tensors into one tensor by a concat.
2175 //
2176 // (1) is controlled by --enable_auto_xla_input_sharding and
2177 // --auto_xla_input_sharding_dim
2178 // (2) and (3) are controlled by flags --minimum_input_tensors_packing
2179 // and --input_shape_opt, respectively, while (4) is controlled by
2180 // --minimum_output_tensors_packing.
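// As a rough sketch of (2): small input tensors in the same dtype/shape group
// are concatenated into one tensor on the CPU side of the graph and split back
// into the original pieces inside the TPU cluster, trading a little extra
// compute for fewer, larger host-to-device transfers.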
2181 Status TPUPartitionedCallOp::OptimizeTpuInputOutputTensors(
2182 Graph* graph, bool enable_spmd_xla_partitioning, int num_cores_per_replica,
2183 std::map<std::string, std::vector<int>>& named_input_shapes,
2184 OpKernelContext* ctx) {
2185 std::string cluster_name;
2186 TF_RETURN_IF_ERROR(GetClusterName(graph, &cluster_name));
2187
2188 if (runtime_params_.enable_auto_xla_input_sharding) {
2189 VLOG(2) << DumpGraphToFile("before_enable_auto_xla_input_sharding", *graph,
2190 flib_def_.get());
2191
2192 TF_RETURN_IF_ERROR(ShardInputsWithXlaSharding(graph, cluster_name,
2193 num_cores_per_replica, ctx));
2194 }
2195
2196 GraphShapeInfo tpu_inferred_info;
2197 std::map<int, InferredShape> arg_shapes;
2198 EdgeShapes tpu_input_shapes;
2199 absl::flat_hash_map<const Edge*, DataType> tpu_input_dtypes;
2200
2201 // Contains attrs "T", "sharding", "_tpu_replicate" for each XlaSharding op.
2202 XlaShardingInfoMap xla_sharding_ops;
2203
2204 // Contains attrs "T", and a pointer to tpu_replicated_metadata for ctrl dep
2205 TpuReplicatedInputInfoMap tpu_replicated_input_ops;
2206
2207 bool xla_spmd_input_sharded = false;
2208
2209 if (enable_spmd_xla_partitioning) {
2210 xla_spmd_input_sharded = FindTpuReplicatedInputAndXlaSharding(
2211 graph, xla_sharding_ops, tpu_replicated_input_ops);
2212 }
2213
2214 VLOG(1) << "xla_spmd_input_sharded: " << xla_spmd_input_sharded;
2215 VLOG(2) << DumpGraphToFile("before_remove_descendant_nodes", *graph,
2216 flib_def_.get());
2217
2218 if (!xla_spmd_input_sharded ||
2219 runtime_params_.minimum_input_tensors_packing > 1 ||
2220 runtime_params_.enable_auto_xla_input_sharding) {
2221     // Currently we remove `TPUReplicatedInput` nodes when the input tensors are
2222     // not sharded, when input tensor packing is enabled, or when auto XLA input
2223     // sharding is enabled; otherwise downstream rewrites would be confused by
2224     // them.
2225 RemoveDescendantNodeOfArg(graph, "TPUReplicatedInput",
2226 /*must_be_child_of=*/{});
2227 }
2228
2229 if (xla_spmd_input_sharded) {
2230     // We are setting must_be_child_of to {"_Arg"} because we do not want
2231 // to remove other XlaSharding ops that might be in the graph. We only
2232 // want the XlaSharding ops that are directly attached to the input
2233 // arguments to be removed.
2234 RemoveDescendantNodeOfArg(graph, "XlaSharding",
2235 /*must_be_child_of=*/{"_Arg"});
2236 }
2237
2238 VLOG(2) << DumpGraphToFile("before_get_input_output_info", *graph,
2239 flib_def_.get());
2240
2241 TF_RETURN_IF_ERROR(GetInputOutputInfo(graph, tpu_inferred_info, arg_shapes,
2242 tpu_input_shapes, tpu_input_dtypes,
2243 ctx));
2244
2245 VLOG(2) << DumpGraphToFile("before_optimize_tpu_input_output_tensors", *graph,
2246 flib_def_.get());
2247
2248 if (runtime_params_.minimum_output_tensors_packing > 1) {
2249 // Copy graph to shape_inference_graph
2250 EdgeShapes tpu_output_shapes;
2251 TF_RETURN_IF_ERROR(
2252 InferShapesWithResourceVar(graph, ctx, arg_shapes, &tpu_inferred_info));
2253
2254 // Find TPU -> CPU output edges.
2255 GroupedEdges shape_to_output =
2256 tpu_functional_internal::GroupTensorsForOutputPacking(
2257 graph, tpu_output_shapes, &tpu_inferred_info);
2258
2259 TF_RETURN_IF_ERROR(
2260 tpu_functional_internal::CreateConcatAndSplitNodesForOutputTensor(
2261 graph, cluster_name, &tpu_output_shapes, &tpu_inferred_info,
2262 shape_to_output, runtime_params_.minimum_output_tensors_packing));
2263 }
2264
2265 if (runtime_params_.minimum_input_tensors_packing > 1) {
2266 GroupedEdges grouped_input_edges =
2267 tpu_functional_internal::GroupTensorsForInputPacking(
2268 tpu_input_shapes, tpu_input_dtypes, runtime_params_.input_shape_opt,
2269 runtime_params_.group_tensors_for_packing);
2270 TF_RETURN_IF_ERROR(
2271 tpu_functional_internal::CreateConcatAndSplitNodesForInputTensor(
2272 graph, cluster_name, &tpu_input_shapes, grouped_input_edges,
2273 runtime_params_.minimum_input_tensors_packing,
2274 xla_spmd_input_sharded, xla_sharding_ops,
2275 tpu_replicated_input_ops));
2276 }
2277 if (runtime_params_.input_shape_opt) {
2278 TF_RETURN_IF_ERROR(tpu_functional_internal::InsertReshapeNodePairs(
2279 graph, cluster_name, &tpu_input_shapes, num_cores_per_replica));
2280 }
2281 VLOG(1) << DumpGraphToFile("optim_result", *graph);
2282
2283 // With or without optimizations, collect the input names and shapes.
2284 for (const auto& iter : tpu_input_shapes) {
2285 std::string name = iter.first->src()->name();
2286 named_input_shapes[name] = iter.second;
2287 }
2288 return OkStatus();
2289 }
2290
2291 Status TPUPartitionedCallOp::GetGraphFromFunction(
2292 Graph* graph, int device_ordinal, bool* use_spmd_for_xla_partitioning,
2293 TPUMetadata* tpu_metadata) {
2294 FunctionLibraryRuntime::InstantiateOptions opts;
2295 FHandle handle;
2296 TF_RETURN_IF_ERROR(library_runtime_->Instantiate(
2297 func_.name(), AttrSlice(&func_.attr()), opts, &handle));
2298 const FunctionBody* fbody = library_runtime_->GetFunctionBody(handle);
2299 if (fbody == nullptr) {
2300 return errors::Internal("Could not find handle ", handle);
2301 }
2302 CopyGraph(*fbody->graph, graph);
2303
2304 // Pin the inputs and outputs to the local device to simplify the
2305 // function-dispatching logic.
2306 local_device_name_ = library_runtime_->device()->name();
2307 replaced_input_indices_.resize(fbody->arg_nodes.size(), false);
2308 for (Node* node : graph->op_nodes()) {
2309 if (node->IsArg() || node->IsRetval()) {
2310 node->set_assigned_device_name(local_device_name_);
2311 } else if (node->type_string() == "TPUReplicateMetadata") {
2312 // Record the producer name so it can be accessed later during metric
2313 // collection.
2314 string producer_name = GetProducerName(func_.name());
2315 node->AddAttr("_producer_name", producer_name);
2316
2317 TF_RETURN_IF_ERROR(GetNodeAttr(node->attrs(), "num_cores_per_replica",
2318 &tpu_metadata->num_cores_per_replica));
2319 TF_RETURN_IF_ERROR(GetNodeAttr(node->attrs(),
2320 "use_spmd_for_xla_partitioning",
2321 use_spmd_for_xla_partitioning));
2322 VLOG(1) << "num_core_per_replica = "
2323 << tpu_metadata->num_cores_per_replica
2324 << ", use_spmd_for_xla_partitioning = "
2325 << *use_spmd_for_xla_partitioning;
2326
2327 if (tpu_metadata->num_cores_per_replica > 1) {
2328 int num_replicas;
2329 TF_RETURN_IF_ERROR(
2330 GetNodeAttr(node->attrs(), "num_replicas", &num_replicas));
2331 if (num_replicas > 1) {
2332 return errors::InvalidArgument(
2333             "num_replicas shouldn't be larger than 1, however it is: ",
2334 num_replicas);
2335 }
2336
2337 TF_RETURN_IF_ERROR(GetNodeAttr(node->attrs(), "device_assignment",
2338 &tpu_metadata->device_assignment));
2339
2340 if (!tpu_metadata->device_assignment.empty() && device_ordinal > 0) {
2341 return errors::InvalidArgument(
2342 "`device_assignment` shouldn't be set manually in the graph when "
2343 "round-robin core selection is enabled.");
2344 }
2345
2346 tpu_metadata->topology = GetTPUTopology();
2347 VLOG(1) << "TPU topology: " << tpu_metadata->topology.DebugString();
2348 std::string topology_str;
2349 TF_RETURN_IF_ERROR(
2350 GetNodeAttr(node->attrs(), "topology", &topology_str));
2351 if (!topology_str.empty()) {
2352 LOG(WARNING)
2353 << "Ignore the `topology` value set in TPUReplicateMetadata "
2354 "node, the TPU topology is queried in the runtime.";
2355 }
2356 node->ClearAttr("topology");
2357 node->AddAttr("topology", tpu_metadata->topology.SerializeAsString());
2358
2359 if (tpu_metadata->topology.num_tasks() > 1) {
2360 return errors::InvalidArgument(
2361 "TPUPartitionedCallOp is only supported in single-host setup, "
2362 "however num_task is: ",
2363 tpu_metadata->topology.num_tasks());
2364 }
2365
2366 if (tpu_metadata->device_assignment.empty()) {
2367 VLOG(1) << "Auto assigning device assignment";
2368
2369           // The auto-generated device assignment should be the same as, or a
2370           // slice of, the TPU topology device_coordinates. This guarantees that
2371           // the logical device ID order matches the physical device ID order.
2372 // It is important for round-robin core selection, as we assume
2373 // the TPU device group for one inference request is
2374 // [TPU:device_ordinal, TPU:device_ordinal + num_cores_per_replica].
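          // For example, with device_ordinal == 2 and num_cores_per_replica
          // == 2, the slice below covers the coordinates of TPU:2 and TPU:3
          // (four topology coordinates per core).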
2375
2376 auto coordinates_start =
2377 tpu_metadata->topology.device_coordinates().begin() +
2378 device_ordinal * 4;
2379 auto coordinates_end =
2380 tpu_metadata->topology.device_coordinates().begin() +
2381 (device_ordinal + tpu_metadata->num_cores_per_replica) * 4;
2382
2383 node->ClearAttr("device_assignment");
2384 tpu_metadata->device_assignment.insert(
2385 tpu_metadata->device_assignment.begin(), coordinates_start,
2386 coordinates_end);
2387 node->AddAttr("device_assignment", tpu_metadata->device_assignment);
2388 }
2389 }
2390 }
2391 }
2392 return OkStatus();
2393 }
2394
2395 Status TPUPartitionedCallOp::PlacementHelper(
2396 const DeviceSet& device_set,
2397 const GraphOptimizationPassOptions& optimization_options,
2398 const string& function_name) {
2399 TF_RETURN_IF_ERROR(OptimizationPassRegistry::Global()->RunGrouping(
2400 OptimizationPassRegistry::PRE_PLACEMENT, optimization_options));
2401 Placer placer(optimization_options.graph->get(), function_name,
2402 optimization_options.flib_def, &device_set);
2403 TF_RETURN_IF_ERROR(placer.Run());
2404 TF_RETURN_IF_ERROR(OptimizationPassRegistry::Global()->RunGrouping(
2405 OptimizationPassRegistry::POST_PLACEMENT, optimization_options));
2406 TF_RETURN_IF_ERROR(OptimizationPassRegistry::Global()->RunGrouping(
2407 OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, optimization_options));
2408 return OkStatus();
2409 }
2410
2411 Status TPUPartitionedCallOp::PartitionHelper(
2412 const DeviceSet& device_set,
2413 const GraphOptimizationPassOptions& optimization_options, Graph* graph,
2414 std::unordered_map<std::string, std::unique_ptr<Graph>>* subgraphs) {
2415 PartitionOptions partition_options;
2416 partition_options.node_to_loc = [](const Node* node) {
2417 // TODO(akshayka): To better support the distributed case, first split
2418     // the graph by worker (e.g., using the master session's
2419 // `SplitByWorker` policy), and then recursively partition the
2420 // per-worker shards at the remote worker(s).
2421 return node->assigned_device_name();
2422 };
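  // With node_to_loc keyed by assigned device, a graph placed on, say,
  // /device:CPU:0 and /device:TPU:0 is split into one partition per device
  // name.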
2423 int64_t edge_name_counter = 0;
2424 partition_options.new_name = [&edge_name_counter](const string& prefix) {
2425 return strings::StrCat(prefix, "/_", ++edge_name_counter);
2426 };
2427 partition_options.get_incarnation = [&device_set](const string& name) {
2428 const Device* d = device_set.FindDeviceByName(name);
2429 if (d == nullptr) {
2430 return PartitionOptions::kIllegalIncarnation;
2431 } else {
2432 return d->attributes().incarnation();
2433 }
2434 };
2435 partition_options.control_flow_added = false;
2436 std::unordered_map<std::string, GraphDef> partitions;
2437 TF_RETURN_IF_ERROR(Partition(partition_options, graph, &partitions));
2438
2439 VLOG(3) << "Partitioned function '" << func_.name() << "', yielding "
2440 << partitions.size() << " shards.";
2441
2442 const FunctionLibraryDefinition* flib_def = &graph->flib_def();
2443 for (auto& partition : partitions) {
2444 std::unique_ptr<Graph> subgraph(new Graph(flib_def));
2445 GraphConstructorOptions opts;
2446 opts.allow_internal_ops = true;
2447 opts.expect_device_spec = true;
2448 const string& device = partition.first;
2449 GraphDef& graph_def = partition.second;
2450 TF_RETURN_IF_ERROR(
2451 ConvertGraphDefToGraph(opts, std::move(graph_def), subgraph.get()));
2452 subgraphs->emplace(device, std::move(subgraph));
2453 }
2454
2455 TF_RETURN_IF_ERROR(OptimizationPassRegistry::Global()->RunGrouping(
2456 OptimizationPassRegistry::POST_PARTITIONING, optimization_options));
2457
2458 return OkStatus();
2459 }
2460
2461 Status TPUPartitionedCallOp::InstantiatePartition(
2462 const Graph& graph, const string& function_name,
2463 const string& target_device, FHandle* handle,
2464 std::unique_ptr<FunctionLibraryDefinition>* out_flib_def) {
2465 FunctionDef shard;
2466 TF_RETURN_IF_ERROR(GraphToFunctionDef(graph, function_name, &shard));
2467 TF_RETURN_IF_ERROR(flib_def_->AddFunctionDef(shard));
2468 FunctionLibraryRuntime::InstantiateOptions opts;
2469 opts.target = target_device;
2470 if (out_flib_def) {
2471 *out_flib_def = std::make_unique<FunctionLibraryDefinition>(*flib_def_);
2472 opts.lib_def = out_flib_def->get();
2473 } else {
2474 opts.lib_def = flib_def_.get();
2475 }
2476 return library_runtime_->Instantiate(function_name, AttrSlice(&shard.attr()),
2477 opts, handle);
2478 }
2479
2480 Status TPUPartitionedCallOp::SetDeviceOrdinal(const DeviceSet& device_set,
2481 int device_ordinal, Graph* graph,
2482 bool* modified) {
2483 int ordinal = -1;
2484 for (Node* node : graph->op_nodes()) {
2485 if (node->type_string() == kVarHandleOp) {
2486 if (IsInputToTPUReplicate(node)) {
2487 // If this VarHandleOp is going to a TPU computation,
2488 // it refers to the TPU variable that we created when replacing the
2489 // resource arguments with VarHandleOps.
2490 node->set_assigned_device_name(
2491 strings::StrCat(kTPUDeviceNamePrefix, device_ordinal));
2492 }
2493 continue;
2494 }
2495 if (HasNodeAttr(node->def(), kXlaHasHostTransferAttrName)) {
2496 // Outside compilation related node.
2497 TF_RETURN_IF_ERROR(
2498 SetDeviceOrdinalAttributeForNode(node, device_ordinal));
2499 *modified = true;
2500 continue;
2501 }
2502 const AttrValue* attr = node->attrs().Find(kDeviceOrdinalAttr);
2503 if (attr != nullptr) {
2504 if (!IsSupportedTPUOp(node->type_string())) {
2505 return errors::InvalidArgument("Node ", node->type_string(),
2506 " is not yet supported.");
2507 }
2508 if (ordinal == -1) {
2509 ordinal = attr->i();
2510 } else {
2511 if (ordinal != attr->i()) {
2512 return errors::InvalidArgument(
2513 "Can only partition graphs that use a single device ordinal.");
2514 }
2515 }
2516 node->ClearAttr(kDeviceOrdinalAttr);
2517 node->AddAttr(kDeviceOrdinalAttr, device_ordinal);
2518 VLOG(3) << "Set device ordinal of " << node->type_string() << " to "
2519 << device_ordinal;
2520 *modified = true;
2521 }
2522 if (node->IsSend() || node->IsRecv()) {
2523 static const char* kSendDevice = "send_device";
2524 static const char* kSendDeviceIncarnation = "send_device_incarnation";
2525 static const char* kRecvDevice = "recv_device";
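      // Send/Recv pairs were created against the device names in use at
      // partition time; rewrite them (and refresh the send incarnation) so the
      // rendezvous keys match the TPU ordinal chosen for this request.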
2526 const AttrValue* attr = node->attrs().Find(kSendDevice);
2527 if (attr != nullptr) {
2528 string device = attr->s();
2529 TF_RETURN_IF_ERROR(
2530 UpdateTPUDeviceOrdinal(device_ordinal, &device, modified));
2531 node->ClearAttr(kSendDevice);
2532 node->AddAttr(kSendDevice, device);
2533 node->ClearAttr(kSendDeviceIncarnation);
2534 const Device* d = device_set.FindDeviceByName(device);
2535 int64_t send_incarnation = (d == nullptr)
2536 ? PartitionOptions::kIllegalIncarnation
2537 : d->attributes().incarnation();
2538 node->AddAttr(kSendDeviceIncarnation, send_incarnation);
2539 }
2540 attr = node->attrs().Find(kRecvDevice);
2541 if (attr != nullptr) {
2542 string device = attr->s();
2543 TF_RETURN_IF_ERROR(
2544 UpdateTPUDeviceOrdinal(device_ordinal, &device, modified));
2545 node->ClearAttr(kRecvDevice);
2546 node->AddAttr(kRecvDevice, device);
2547 }
2548 }
2549 }
2550 return OkStatus();
2551 }
2552
2553 Status TPUPartitionedCallOp::InstantiateFunctionsFromSubgraphs(
2554 const DeviceSet& device_set, int replica_id, uint64 cache_hash,
2555 int num_cores_per_replica,
2556 std::unordered_map<std::string, std::unique_ptr<Graph>> subgraphs) {
2557 const Device* reference_device = nullptr;
2558 auto entry =
2559 partition_cache_.emplace(cache_hash, std::vector<DeviceAndFHandle>());
2560
2561 bool rewritten = false;
2562 for (auto& pair : subgraphs) {
2563 string target = pair.first;
2564 int device_ordinal = replica_id;
2565 if (num_cores_per_replica > 1) {
2566 DeviceNameUtils::ParsedName parsed_device;
2567 if (!DeviceNameUtils::ParseFullName(target, &parsed_device)) {
2568 return errors::InvalidArgument("Malformed assigned device '", target,
2569 "'");
2570 }
2571 device_ordinal = parsed_device.id;
2572 }
2573 Device* device;
2574 TF_RETURN_IF_ERROR(
2575 library_runtime_->device_mgr()->LookupDevice(target, &device));
2576 if (reference_device == nullptr) {
2577 reference_device = device;
2578 } else {
2579 if (!DeviceNameUtils::IsSameAddressSpace(
2580 device->parsed_name(), reference_device->parsed_name())) {
2581 return errors::InvalidArgument(
2582             "TPUPartitionedCallOp does not yet support inter-process "
2583 "execution.");
2584 }
2585 }
2586 TF_RETURN_IF_ERROR(device->MaybeRewriteGraph(&pair.second));
2587 Graph* subgraph = pair.second.get();
2588     // For model-parallel inference we only support num_replicas == 1, so
2589     // there is no need to update the device_ordinal here.
2590 if (num_cores_per_replica == 1) {
2591 TF_RETURN_IF_ERROR(
2592 SetDeviceOrdinal(device_set, device_ordinal, subgraph, &rewritten));
2593 } else {
2594 VLOG(1) << "Skip SetDeviceOrdinal()";
2595 }
2596 string function_name = flib_def_->UniqueFunctionName(
2597 strings::StrCat(func_.name(), "_hash_", cache_hash));
2598 TF_RETURN_IF_ERROR(
2599 UpdateTPUDeviceOrdinal(device_ordinal, &target, &rewritten));
2600 FHandle handle;
2601 // Use a copy of the current `flib_def_` to instantiate the function to
2602 // avoid races.
2603 std::unique_ptr<FunctionLibraryDefinition> sub_flib_def;
2604 TF_RETURN_IF_ERROR(InstantiatePartition(*subgraph, function_name, target,
2605 &handle, &sub_flib_def));
2606 // Add handle to the cache entry.
2607 entry.first->second.push_back(
2608 DeviceAndFHandle{.device = target,
2609 .handle = handle,
2610 .flib_def = std::move(sub_flib_def)});
2611 }
2612
2613 if (!rewritten) {
2614 // For regular use cases, TPUPartitionedCallOp only works when the
2615     // function being called is rewritten for TPU. If we don't see any signs
2616 // of this rewriting, warn the user about it.
2617 // We don't raise an error because we want to support the use case of
2618 // running tpu.initialize_system eagerly. In this case, we can't use
2619 // tpu.rewrite because it will add compilation ops that require TPU
2620 // to be initialized, i.e. there is a chicken and egg problem.
2621 // We run tpu.initialize_system through TPUPartitionedCallOp because it
2622 // invokes graph rewrite passes that are necessary for initialization to
2623 // work.
2624 LOG(INFO) << "Function body was not rewritten for TPU. "
2625 << "This is probably a bug unless you are initializing "
2626 << "TPUs eagerly.";
2627 }
2628 return OkStatus();
2629 }
2630
2631 void TPUPartitionedCallOp::ExecuteRemoteFunction(
2632 const FunctionLibraryRuntime::Options& opts, FHandle handle,
2633 OpKernelContext* ctx, ReffedStatusCallback* done) {
2634 std::vector<Tensor> dummy_args;
2635 std::vector<Tensor>* dummy_rets = new std::vector<Tensor>;
2636
2637 profiler::TraceMe trace_me("TPUPartitionedCallOp-ExecuteRemote");
2638 library_runtime_->Run(opts, handle, dummy_args, dummy_rets,
2639 [dummy_rets, done, ctx](const Status& status) {
2640 if (!status.ok()) {
2641 done->UpdateStatus(status);
2642 }
2643 delete dummy_rets;
2644 done->Unref();
2645 });
2646 }
2647
2648 void TPUPartitionedCallOp::ExecuteLocalFunction(
2649 const FunctionLibraryRuntime::Options& opts, const OpInputList& arguments,
2650 FHandle handle, OpKernelContext* ctx, ReffedStatusCallback* done) {
2651 std::vector<Tensor> args;
2652
2653 for (int i = 0; i < arguments.size(); ++i) {
2654 if (!replaced_input_indices_[i]) {
2655 // _Arg nodes of type DT_RESOURCE that go into a TPU node have been
2656 // replaced by TPU VarHandleOp nodes. No longer need to pass them as
2657 // inputs.
2658 args.push_back(arguments[i]);
2659 }
2660 }
2661 auto* rets = new std::vector<Tensor>;
2662
2663 profiler::TraceMe trace_me("TPUPartitionedCallOp-ExecuteLocal");
2664 library_runtime_->Run(opts, handle, args, rets,
2665 [rets, done, ctx](const Status& status) {
2666 if (!status.ok()) {
2667 done->UpdateStatus(status);
2668 } else {
2669 for (int i = 0; i < rets->size(); ++i) {
2670 ctx->set_output(i, (*rets)[i]);
2671 }
2672 }
2673 delete rets;
2674 done->Unref();
2675 });
2676 }
2677
2678 void TPUPartitionedCallOp::ExecuteFunctions(
2679 const std::vector<DeviceAndFHandle>& functions, OpKernelContext* ctx,
2680 int device_ordinal, int64_t ordinal_selector_req_id, DoneCallback done) {
2681 profiler::TraceMe trace_me("TPUPartitionedCallOp-ExecuteFunctions");
2682 FunctionLibraryRuntime::Options opts;
2683 opts.step_container = ctx->step_container();
2684 opts.stats_collector = ctx->stats_collector();
2685 // TODO(akshayka): Consider selecting a runner on a per-device basis,
2686 // i.e., using device-specific threadpools when available.
2687 opts.runner = ctx->runner();
2688 opts.source_device = local_device_name_;
2689 opts.run_all_kernels_inline = ctx->run_all_kernels_inline();
2690
2691 OpInputList arguments;
2692 OP_REQUIRES_OK_ASYNC(ctx, ctx->input_list("args", &arguments), done);
2693
2694 auto* local_cm = new CancellationManager(ctx->cancellation_manager());
2695 auto* rendez = new RefCountedIntraProcessRendezvous(device_mgr_);
2696 opts.cancellation_manager = local_cm;
2697 opts.rendezvous = rendez;
2698
2699 StatusCallback callback(
2700 [rendez = rendez, local_cm, done = std::move(done),
2701 device_ordinal = device_ordinal, req_id = ordinal_selector_req_id, ctx,
2702 ordinal_selector = ordinal_selector_](const Status& status) {
2703 delete local_cm;
2704 rendez->Unref();
2705 if (!status.ok()) {
2706 ctx->SetStatus(status);
2707 }
2708 done();
2709 if (req_id >= 0) {
2710 ordinal_selector->DequeueFromCoreSelector(device_ordinal, req_id);
2711 }
2712 });
2713
2714 auto* refcounted_done = new ReffedStatusCallback(std::move(callback));
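  // ReffedStatusCallback starts with one reference; add one per additional
  // function shard so `callback` (and hence the final done()) runs only after
  // every shard has reported back.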
2715 for (int i = 1; i < functions.size(); ++i) {
2716 refcounted_done->Ref();
2717 }
2718 for (const DeviceAndFHandle& entry : functions) {
2719 const string& target_device = entry.device;
2720 FHandle handle = entry.handle;
2721 VLOG(3) << "Running function shard on device " << target_device
2722 << " with local device name " << local_device_name_;
2723 if (target_device == local_device_name_) {
2724 opts.remote_execution = false;
2725 ExecuteLocalFunction(opts, arguments, handle, ctx, refcounted_done);
2726 } else {
2727 opts.remote_execution = true;
2728 ExecuteRemoteFunction(opts, handle, ctx, refcounted_done);
2729 }
2730 }
2731 }
2732
2733 REGISTER_KERNEL_BUILDER(Name("TPUPartitionedCall").Device(DEVICE_CPU),
2734 TPUPartitionedCallOp);
2735
2736 } // end namespace tensorflow
2737