1 /* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/dtensor/cc/dtensor_device_util.h"
17
18 #include <cstddef>
19 #include <string>
20 #include <utility>
21
22 #include "absl/container/flat_hash_map.h"
23 #include "absl/strings/str_cat.h"
24 #include "tensorflow/c/eager/c_api_internal.h"
25 #include "tensorflow/c/eager/tfe_tensorhandle_internal.h"
26 #include "tensorflow/c/tf_status.h"
27 #include "tensorflow/compiler/xla/status_macros.h"
28 #include "tensorflow/core/common_runtime/graph_constructor.h"
29 #include "tensorflow/core/common_runtime/shape_refiner.h"
30 #include "tensorflow/core/framework/attr_value.pb.h"
31 #include "tensorflow/core/framework/function.h"
32 #include "tensorflow/core/framework/node_def.pb.h"
33 #include "tensorflow/core/framework/node_def_util.h"
34 #include "tensorflow/core/framework/tensor.pb.h"
35 #include "tensorflow/core/framework/types.pb.h"
36 #include "tensorflow/core/graph/graph.h"
37 #include "tensorflow/core/lib/strings/proto_serialization.h"
38 #include "tensorflow/core/platform/errors.h"
39 #include "tensorflow/core/platform/fingerprint.h"
40 #include "tensorflow/core/public/version.h"
41 #include "tensorflow/dtensor/cc/constants.h"
42 #include "tensorflow/dtensor/cc/dstatus.h"
43 #include "tensorflow/dtensor/cc/small_constant_optimization.h"
44
45 namespace tensorflow {
46 namespace dtensor {
47 namespace {
48 // Represents an input node during graph construction.
49 // When executing a Function, `output` is used to align graph inputs
50 // with the inputs to the function call.
51 struct FunctionArgument {
52 Node* node;
53 NodeDefBuilder::NodeOut output;
54 };
55
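// Copies `tensor` to every local device of `mesh` and wraps the copies in a
// ParallelTensor. Returns nullptr and sets `status` on failure.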
56 std::unique_ptr<parallel_device::ParallelTensor>
57 BroadcastTensorHandleToParallelTensor(TFE_Context* context,
58 TFE_TensorHandle* tensor,
59 const MeshWithParallelDevice& mesh,
60 TF_Status* status) {
61 // Broadcast tensor value to local devices.
62 const Mesh& target_mesh = mesh.mesh_config();
63 absl::Span<const std::string> local_devices = target_mesh.local_devices();
64 const int num_local_devices = local_devices.size();
65
66 std::vector<parallel_device::TensorHandlePtr> components;
67 components.reserve(num_local_devices);
68 for (int i = 0; i < num_local_devices; ++i) {
69 // Create a copy of the tensor on each local device specified by `target_mesh`.
70 components.emplace_back(TFE_TensorHandleCopyToDevice(
71 tensor, context, local_devices[i].c_str(), status));
72 if (TF_GetCode(status) != TF_OK) {
73 TF_SetStatus(
74 status, TF_INTERNAL,
75 absl::StrCat(
76 "Unable to copy tensor value for broadcast. Original message: ",
77 TF_Message(status))
78 .c_str());
79 return nullptr;
80 }
81 }
82
83 std::unique_ptr<parallel_device::ParallelTensor> parallel_tensor =
84 parallel_device::ParallelTensor::FromTensorHandles(
85 mesh.parallel_device(), std::move(components), status);
86 if (TF_GetCode(status) != TF_OK) return nullptr;
87 return parallel_tensor;
88 }
89
90 // Broadcast a single non-parallel resource tensor onto `mesh` with a fully
91 // replicated sharding spec. Does not take ownership of `tensor`.
92 std::unique_ptr<TensorWithLayout> BroadcastResourceTensor(
93 TFE_Context* context, TFE_TensorHandle* tensor,
94 const MeshWithParallelDevice& mesh, const std::string& dtensor_device_name,
95 TF_Status* status) {
96 // Only broadcast resource tensors that point to scalars since they are
97 // always replicated. We also still want to catch honest user errors so
98 // error out on non-scalars.
99 // Resolve the Tensor as resource handle and get the shape and dtype
100 // of the tensor it points to.
101 std::unique_ptr<TF_Tensor, decltype(&TF_DeleteTensor)> tf_tensor(
102 TFE_TensorHandleResolve(tensor, status), TF_DeleteTensor);
103 Tensor t;
104 Status convert_status = TF_TensorToTensor(tf_tensor.get(), &t);
105 if (!convert_status.ok() || t.dtype() != DataType::DT_RESOURCE) {
106 TF_SetStatus(status, TF_INTERNAL, convert_status.error_message().c_str());
107 return nullptr;
108 }
109 // Replicate this resource handle to all devices without changing the
110 // associated device of the resource itself.
111 ResourceHandle r = t.flat<ResourceHandle>()(0);
112 if (r.dtypes_and_shapes().empty()) {
113 TF_SetStatus(status, TF_INTERNAL,
114 "Expected resource handle to have at least one underlying "
115 "dtype and shape during broadcasting.");
116 return nullptr;
117 }
118 PartialTensorShape partial_shape = r.dtypes_and_shapes().begin()->shape;
119 int64_t num_elements = partial_shape.num_elements();
120
121 // Only broadcast scalar resource tensors onto a CPU mesh. Copying
122 // resource tensors to a non-CPU device is not supported.
123 if (num_elements != 1 || !mesh.mesh_config().is_cpu_mesh()) {
124 std::string error_message =
125 "Using a non-DTensor variable with DTensor is only supported for "
126 "scalar variables copying to a CPU mesh. If you are using a scope "
127 "based API, create "
128 "variables inside the DTensor scope.\n";
129
130 // Get the stack_trace and Summaries from the resource tensor.
131 absl::StrAppend(
132 &error_message, "Offending variable summary: ", r.SummarizeValue(),
133 "\nStack trace: ", DefinitionLocationMsg(r.definition_stack_trace()));
134 TF_SetStatus(status, TF_INVALID_ARGUMENT, error_message.c_str());
135 return nullptr;
136 }
137
138 LOG(INFO) << "Broadcasting resource tensor to a dtensor resource tensor.";
139 if (mesh.mesh_config().is_remote()) {
140 TF_DataType dtype = TFE_TensorHandleDataType(tensor);
141 std::vector<int64_t> shape(TensorShapeAsVector(tensor, status));
142 if (TF_GetCode(status) != TF_OK) return nullptr;
143 auto layout = Layout::ReplicatedOnMesh(mesh.mesh_config(), shape.size());
144
145 auto ret = TensorWithLayout::Dummy(shape, dtype, mesh, layout);
146 return ret;
147 }
148
149 std::unique_ptr<parallel_device::ParallelTensor> parallel_tensor =
150 BroadcastTensorHandleToParallelTensor(context, tensor, mesh, status);
151 if (TF_GetCode(status) != TF_OK) return nullptr;
152
153 StatusOr<std::unique_ptr<TensorWithLayout>> result = TensorWithLayout::Wrap(
154 std::move(parallel_tensor), mesh,
155 Layout::ReplicatedOnMesh(mesh.mesh_config(), partial_shape.dims()));
156 if (!result.ok()) {
157 TF_SetStatus(
158 status, TF_INTERNAL,
159 absl::StrCat("Error creating a TensorWithLayout from a resource tensor "
160 "during broadcasting with original error message:",
161 result.status().error_message())
162 .c_str());
163 return nullptr;
164 }
165 // Set the shape/type of the tensor that the resource points to
166 // so that the graph has correct shape/type information that we can use.
167 (*result)->UpdateShapeAndDType(partial_shape.AsProto(),
168 r.dtypes_and_shapes().begin()->dtype, status);
169 if (TF_GetCode(status) != TF_OK) {
170 TF_SetStatus(status, TF_INTERNAL,
171 "Error updating shape and dtype for resource tensor during "
172 "broadcasting.");
173 return nullptr;
174 }
175 return std::move(*result);
176 }
177
178 bool LayoutsAreCompatible(absl::optional<Layout> first_layout,
179 absl::optional<Layout> second_layout) {
180 if (!first_layout.has_value() && !second_layout.has_value()) {
181 return true;
182 }
183 if (!first_layout.has_value() || !second_layout.has_value()) {
184 return false;
185 }
186 return first_layout.value() == second_layout.value();
187 }
188
189 // Parse a pair of attribute of (indices, layouts) into a map.
190 Status ParseAttrMap(const Node& node, absl::string_view indices_attr,
191 absl::string_view layout_attr,
192 std::map<int, Layout>* indices_layout_map) {
193 std::vector<std::string> layouts;
194 if (!TryGetNodeAttr(node.attrs(), layout_attr, &layouts)) {
195 return OkStatus();
196 }
197 const TensorProto* indices;
198 if (!TryGetNodeAttr(node.attrs(), indices_attr, &indices)) {
199 return errors::Internal(
200 "Arg indices must be set when setting inferred resource layouts.");
201 }
202 if (indices->int_val_size() != layouts.size()) {
203 return errors::Internal(
204 "Arg indices for inferred resource argument must match the "
205 "size of inferred resource layout.");
206 }
207 for (int i = 0; i < indices->int_val_size(); ++i) {
208 const auto arg_index = indices->int_val(i);
209 const auto& arg_layout = layouts[i];
210 indices_layout_map->emplace(
211 arg_index,
212 tensorflow::dtensor::Layout::FromString(arg_layout).ValueOrDie());
213 }
214 return OkStatus();
215 }
216
217 Status ParseResourceArgumentLayouts(
218 const Node& node, std::map<int, Layout>* inferred_resource_input_layouts) {
219 return ParseAttrMap(node, kNewResourceLayoutIndices, kNewResourceArgLayouts,
220 inferred_resource_input_layouts);
221 }
222
223 Status ParseShapeInputLayouts(const Node& node,
224 std::map<int, Layout>* shape_output_metadata) {
225 return ParseAttrMap(node, kShapeOpInputLayoutIndices, kShapeOpInputLayout,
226 shape_output_metadata);
227 }
228
229 // Gets the layout attached to a specific node at a given index, ignoring any
230 // Identity ops.
231 StatusOr<Layout> GetLayoutThroughIdentityOps(Node* op, int output_index) {
232 while (op->op_def().name() == "Identity" ||
233 op->op_def().name() == "IdentityN") {
234 const Edge* edge;
235 TF_RETURN_IF_ERROR(op->input_edge(output_index, &edge));
236 op = edge->src();
237 output_index = edge->src_output();
238 }
239 const auto serialized_layouts = op->attrs().Find(kLayoutAttr);
240
241 if (!serialized_layouts) {
242 return errors::InvalidArgument(
243 op->op_def().name(), " doesn't contain attribute : ", kLayoutAttr);
244 }
245
246 // We assume that there is one layout for each output.
247 if (serialized_layouts->list().s_size() != op->num_outputs()) {
248 return errors::InvalidArgument(
249 "Number of outputs to ", op->op_def().name(),
250 " does not match number of layouts attached");
251 }
252
253 return Layout::FromString(serialized_layouts->list().s(output_index));
254 }
255
256 } // namespace
257
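// Fingerprints the layout, the exact local shape and, if present, the small
// constant value of this tensor.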
258 tensorflow::Fprint128 TensorWithLayout::CacheKey() const {
259 tensorflow::Fprint128 f = tensorflow::Fingerprint128(layout_.ToString());
260 // Use exact shape to compute the key.
261 for (const int64_t dim : local_shape()) {
262 f = FingerprintCat128(f, dim);
263 }
264 if (const_value_.has_value()) {
265 std::string serialized;
266 SerializeToStringDeterministic(const_value_.value(), &serialized);
267 f = FingerprintCat128(f, tensorflow::Fingerprint128(serialized));
268 }
269 return f;
270 }
271
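// Broadcasts a single non-DTensor eager tensor onto `mesh` with a fully
// replicated layout. Resource tensors are delegated to BroadcastResourceTensor
// above; for remote meshes a dummy TensorWithLayout is returned, with any
// small constant value recorded on it.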
272 std::unique_ptr<TensorWithLayout> TensorWithLayout::Broadcast(
273 TFE_Context* context, TFE_TensorHandle* tensor,
274 const MeshWithParallelDevice& mesh, const std::string& dtensor_device_name,
275 TF_Status* status) {
276 const char* input_device = TFE_TensorHandleDeviceName(tensor, status);
277 if (TF_GetCode(status) != TF_OK) return nullptr;
278
279 if (dtensor_device_name == input_device) {
280 TF_SetStatus(status, TF_INVALID_ARGUMENT,
281 "Input to Broadcast must be eager tensor.");
282 return nullptr;
283 }
284
285 // Handle resource tensor broadcasting to the mesh.
286 if (TFE_TensorHandleDataType(tensor) == TF_RESOURCE) {
287 return BroadcastResourceTensor(context, tensor, mesh, dtensor_device_name,
288 status);
289 }
290
291 if (mesh.mesh_config().is_remote()) {
292 TF_DataType dtype = TFE_TensorHandleDataType(tensor);
293 std::vector<int64_t> shape(TensorShapeAsVector(tensor, status));
294 if (TF_GetCode(status) != TF_OK) return nullptr;
295 auto layout = Layout::ReplicatedOnMesh(mesh.mesh_config(), shape.size());
296
297 auto ret = TensorWithLayout::Dummy(shape, dtype, mesh, layout);
298 absl::optional<NodeDef> const_value =
299 ExtractSmallTensorValue(context, tensor, layout, status);
300 if (TF_GetCode(status) != TF_OK) return nullptr;
301 if (const_value) {
302 ret->set_const_value(const_value.value());
303 }
304 return ret;
305 }
306
307 std::unique_ptr<parallel_device::ParallelTensor> parallel_tensor =
308 BroadcastTensorHandleToParallelTensor(context, tensor, mesh, status);
309 if (TF_GetCode(status) != TF_OK) return nullptr;
310
311 const std::vector<int64_t>* shape;
312 Status s = parallel_tensor->Shape(&shape);
313 if (!s.ok()) {
314 TF_SetStatus(status, static_cast<TF_Code>(s.code()),
315 s.error_message().c_str());
316 return nullptr;
317 }
318 size_t num_dims = shape->size();
319 const Layout layout = Layout::ReplicatedOnMesh(mesh.mesh_config(), num_dims);
320
321 absl::optional<NodeDef> const_value =
322 ExtractSmallTensorValue(context, tensor, layout, status);
323 if (TF_GetCode(status) != TF_OK) return nullptr;
324
325 std::unique_ptr<TensorWithLayout> result(new TensorWithLayout(
326 std::move(parallel_tensor), mesh, std::move(layout), *shape,
327 /*dtype=*/absl::nullopt, std::move(const_value)));
328 return result;
329 }
330
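// Wraps an existing ParallelTensor (taking ownership) into a TensorWithLayout,
// or into a ResourceHandleWithLayout when the dtype is TF_RESOURCE.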
331 StatusOr<std::unique_ptr<TensorWithLayout>> TensorWithLayout::Wrap(
332 std::unique_ptr<parallel_device::ParallelTensor> tensor,
333 const MeshWithParallelDevice& mesh, const Layout& layout) {
334 const std::vector<int64_t>* shape;
335 TF_RETURN_IF_ERROR(tensor->Shape(&shape));
336
337 if (tensor->dtype() != TF_RESOURCE) {
338 return std::unique_ptr<TensorWithLayout>(
339 new TensorWithLayout(std::move(tensor), mesh, layout, *shape));
340 } else {
341 return std::unique_ptr<TensorWithLayout>(
342 new ResourceHandleWithLayout(std::move(tensor), mesh, layout, *shape));
343 }
344 }
345
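// Creates a TensorWithLayout without an underlying ParallelTensor, e.g. to
// represent a tensor that lives on a remote mesh.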
346 std::unique_ptr<TensorWithLayout> TensorWithLayout::Dummy(
347 const std::vector<int64_t>& local_shape, const TF_DataType dtype,
348 const MeshWithParallelDevice& mesh, const Layout& layout) {
349 if (dtype != TF_RESOURCE) {
350 return std::unique_ptr<TensorWithLayout>(new TensorWithLayout(
351 /*tensor=*/nullptr, mesh, layout, local_shape, dtype));
352 } else {
353 return std::unique_ptr<TensorWithLayout>(new ResourceHandleWithLayout(
354 /*tensor=*/nullptr, mesh, layout, local_shape));
355 }
356 }
357
358 std::string TensorWithLayout::SummarizeValue() const {
359 std::string value_summary;
360 Status status;
361 if (layout().IsFullyReplicated()) {
362 status =
363 tensorflow::unwrap(tensor()->tensor(0))->SummarizeValue(value_summary);
364 } else {
365 // Note that this just prints the local values for sharded tensors. We could
366 // instead run a collective here to relayout to replicated.
367 status = tensor()->SummarizeValue(value_summary);
368 }
369 if (!status.ok()) {
370 value_summary = "<error computing value>";
371 }
372 return absl::StrCat(value_summary, ", layout=\"", layout().ToString(), "\"");
373 }
374
375 std::string TensorWithLayout::DebugString() const {
376 auto dtype = static_cast<DataType>(tensor()->dtype());
377
378 const auto& shape_vector = global_shape();
379 return absl::StrCat("DTensor(", SummarizeValue(),
380 ", shape=", ShapeToDebugString(shape_vector),
381 ", type=", DataTypeString(dtype), ")");
382 }
383
384 void ResourceHandleWithLayout::EncodeAttributes(
385 tensorflow::NodeDefBuilder& builder) const {
386 // If set, attach shape and dtype to the given node def.
387 if (dereferenced_shape().has_value()) {
388 builder.Attr("_handle_shapes", {*dereferenced_shape()});
389 }
390 if (dereferenced_dtype().has_value()) {
391 builder.Attr("_handle_dtypes", {*dereferenced_dtype()});
392 }
393 }
394
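// Like TensorWithLayout::CacheKey, but fingerprints the dereferenced shape and
// dtype (when known) instead of a constant value.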
395 tensorflow::Fprint128 ResourceHandleWithLayout::CacheKey() const {
396 tensorflow::Fprint128 f = tensorflow::Fingerprint128(layout().ToString());
397 if (dereferenced_shape().has_value()) {
398 std::string serialized;
399 SerializeToStringDeterministic(dereferenced_shape().value(), &serialized);
400 f = FingerprintCat128(f, tensorflow::Fingerprint128(serialized));
401 }
402 if (dereferenced_dtype().has_value()) {
403 f = FingerprintCat128(f, dereferenced_dtype().value());
404 }
405 return f;
406 }
407
408 void ResourceHandleWithLayout::UpdateLayout(const Layout& new_layout,
409 TF_Status* status) {
410 // Only set the value for the dereferenced layout if the incoming layout is
411 // not empty. This is still hacky as we use an empty layout as a placeholder
412 // for an eagerly placed VarHandleOp.
413 if (!dereferenced_layout_.has_value() && new_layout.IsEmpty()) return;
414 if (dereferenced_layout_.has_value() &&
415 !LayoutsAreCompatible(dereferenced_layout_, new_layout)) {
416 // TODO(xiejw, allenl): Consider allowing variables to switch layouts.
417 RETURN_STATUS(status, TF_INVALID_ARGUMENT,
418 "Attempted to overwrite an existing Layout.");
419 }
420 dereferenced_layout_.emplace(new_layout);
421 }
422
423 void ResourceHandleWithLayout::UpdateAttrs(const EmbeddingResourceAttrs& attrs,
424 TF_Status* status) {
425 if (attrs_.has_value()) {
426 RETURN_STATUS(status, TF_INVALID_ARGUMENT,
427 "Attepted to overwrite an existing embedding resource "
428 "attribute.");
429 }
430 attrs_.emplace(attrs);
431 }
432
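// Wraps the three component ParallelTensors of a SparseTensor (indices,
// values and dense_shapes) into a single TensorWithLayout.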
433 StatusOr<std::unique_ptr<TensorWithLayout>> SparseTensorWithLayout::Wrap(
434 std::unique_ptr<parallel_device::ParallelTensor> indices_tensor,
435 std::unique_ptr<parallel_device::ParallelTensor> values_tensor,
436 std::unique_ptr<parallel_device::ParallelTensor> shapes_tensor,
437 const MeshWithParallelDevice& mesh, const Layout& layout,
438 std::vector<int64_t> local_shape) {
439 return std::unique_ptr<TensorWithLayout>(new SparseTensorWithLayout(
440 std::move(indices_tensor), std::move(values_tensor),
441 std::move(shapes_tensor), mesh, layout, local_shape));
442 }
443
444 std::string SparseTensorWithLayout::SummarizeValue() const {
445 std::string indices_summary;
446 std::string values_summary;
447 std::string dense_shapes_summary;
448
449 Status indices_status;
450 Status values_status;
451 Status dense_shapes_status;
452
453 if (layout().IsFullyReplicated()) {
454 indices_status = tensorflow::unwrap(indices_->tensor(0))
455 ->SummarizeValue(indices_summary);
456 values_status =
457 tensorflow::unwrap(values_->tensor(0))->SummarizeValue(values_summary);
458 dense_shapes_status = tensorflow::unwrap(dense_shapes_->tensor(0))
459 ->SummarizeValue(dense_shapes_summary);
460 } else {
461 indices_status = indices_->SummarizeValue(indices_summary);
462 values_status = values_->SummarizeValue(values_summary);
463 dense_shapes_status = dense_shapes_->SummarizeValue(dense_shapes_summary);
464 }
465
466 if (!indices_status.ok())
467 indices_summary = "<error computing summary for indices>";
468 if (!values_status.ok())
469 values_summary = "<error computing summary for values>";
470 if (!dense_shapes_status.ok())
471 dense_shapes_summary = "<error computing summary for dense_shapes>";
472
473 return absl::StrCat("indices: ", indices_summary, ", ",
474 "values: ", values_summary, ", ",
475 "dense_shapes: ", dense_shapes_summary, ", layout=\"",
476 layout().ToString(), "\"");
477 }
478
479 std::string SparseTensorWithLayout::DebugString() const {
480 auto dtype = static_cast<DataType>(values_->dtype());
481
482 const auto& shape_vector = global_shape();
483 return absl::StrCat("DTensor(", SummarizeValue(),
484 ", shape=", ShapeToDebugString(shape_vector),
485 ", type=", DataTypeString(dtype), ")");
486 }
487
488 TF_DataType SparseTensorWithLayout::dtype() const {
489 if (dtype_.has_value()) {
490 return dtype_.value();
491 } else {
492 return values_->dtype();
493 }
494 }
495
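// Component handles are laid out as [indices..., values..., dense_shapes...],
// one block per constituent tensor, so `index` is first mapped into the right
// block.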
496 TFE_TensorHandle* SparseTensorWithLayout::get_tensor(size_t index) const {
497 int num_sparse_tensors = num_tensors() / 3;
498 if (index < num_sparse_tensors) {
499 return indices()->tensor(index);
500 } else if (index < 2 * num_sparse_tensors) {
501 return values()->tensor(index % num_sparse_tensors);
502 } else {
503 return dense_shapes()->tensor(index % num_sparse_tensors);
504 }
505 }
506
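// Returns a map from input index to the NodeDef holding that input's small
// constant value, for every input that carries one.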
507 absl::flat_hash_map<int, NodeDef> GetConstantFoldableTensors(
508 const std::vector<TensorWithLayout*>& inputs) {
509 absl::flat_hash_map<int, NodeDef> small_tensors;
510 for (auto index = 0; index < inputs.size(); ++index) {
511 if (inputs[index]->const_value().has_value()) {
512 small_tensors.insert({index, inputs[index]->const_value().value()});
513 }
514 }
515 return small_tensors;
516 }
517
518 // Thread unsafe method. go/thread-unsafe
519 // Cache key computation should consider all features of an op that affects
520 // the SPMD lowering. The cache keys of two ops must be different if the
521 // translated functions are different.
522 // - op name and attr
523 // - input shapes and layouts
524 // - default layout of outputs.
525 // - values of constant foldable inputs.
526 tensorflow::Fprint128 FunctionManager::CacheKeyForGraph(
527 const DTensorOperation& doperation, const NameAttrList& attributes,
528 const std::vector<TensorWithLayout*>& inputs,
529 const std::vector<const Layout*>& output_layouts) {
530 tensorflow::Fprint128 cache_key = tensorflow::Fingerprint128(doperation.name);
531 std::string serialized;
532 SerializeToStringDeterministic(attributes, &serialized);
533 cache_key =
534 FingerprintCat128(cache_key, tensorflow::Fingerprint128(serialized));
535 // Higher level cache based on operation name and input shapes.
536 for (auto i = 0; i < inputs.size(); ++i) {
537 if (!IsConstantFoldable(doperation, i)) {
538 inputs[i]->reset_const_value();
539 }
540 cache_key = FingerprintCat128(cache_key, inputs[i]->CacheKey());
541 }
542 for (int output_index = 0; output_index < output_layouts.size();
543 ++output_index) {
544 if (output_layouts[output_index]) {
545 cache_key = FingerprintCat128(cache_key, output_index);
546 cache_key = FingerprintCat128(
547 cache_key,
548 tensorflow::Fingerprint128(output_layouts[output_index]->ToString()));
549 }
550 }
551 return cache_key;
552 }
553
554 // Thread-unsafe method go/thread-unsafe.
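// Returns the cached lowering for this operation if one exists. On a miss for
// a tf.function, records (or updates) which small inputs are treated as
// constants, recomputing the cache key if any of them changed, and returns
// nullptr.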
555 std::pair<tensorflow::Fprint128, const ExecutionFunctions*>
556 FunctionManager::GetCachedFunction(
557 const DTensorOperation& doperation, const NameAttrList& attributes,
558 const std::vector<TensorWithLayout*>& inputs,
559 const std::vector<const Layout*>& output_layouts) {
560 tensorflow::Fprint128 cache_key =
561 CacheKeyForGraph(doperation, attributes, inputs, output_layouts);
562 auto iter = function_cache_.find(cache_key);
563
564 // Early return if we have a cache hit.
565 if (iter != function_cache_.end()) {
566 return std::pair<Fprint128, ExecutionFunctions*>(cache_key, &iter->second);
567 }
568
569 // For eager ops we early return the cache miss and do not make further
570 // optimizations.
571 if (!doperation.is_func()) {
572 return std::pair<Fprint128, std::nullptr_t>(cache_key, nullptr);
573 }
574
575 const tensorflow::Fprint128 doperation_hash =
576 CacheKeyForDTensorOperation(doperation);
577
578 // Save the constant folded inputs to this doperation if we have not seen this
579 // before. This is needed so that in the next call to this operation, we
580 // can compare these inputs to confirm which one is indeed a constant.
581 auto doperation_iter = dtensor_op_and_small_inputs_.find(doperation_hash);
582 if (doperation_iter == dtensor_op_and_small_inputs_.end()) {
583 dtensor_op_and_small_inputs_.insert(
584 {doperation_hash, GetConstantFoldableTensors(inputs)});
585 return std::pair<Fprint128, std::nullptr_t>(cache_key, nullptr);
586 }
587
588 // If we are here, then we have run this function before but constant folded
589 // some input(s) that were not actually constant, i.e. one of the small input
590 // values to this function changed. So mark those changed values as
591 // non-constant.
592 absl::flat_hash_map<int, NodeDef>& previous_small_inputs =
593 doperation_iter->second;
594 std::vector<int> non_constant_indices;
595
596 for (auto const& [index, previous_small_input] : previous_small_inputs) {
597 if (inputs[index]->const_value().has_value()) {
598 if (NodeDefsHaveDifferentTensorProto(
599 previous_small_input, inputs[index]->const_value().value())) {
600 inputs[index]->reset_const_value();
601 non_constant_indices.push_back(index);
602 }
603 }
604 }
605 for (int non_constant_index : non_constant_indices) {
606 previous_small_inputs.erase(non_constant_index);
607 }
608 // Regenerate the cache key since the set of small const inputs changed.
610 cache_key = CacheKeyForGraph(doperation, attributes, inputs, output_layouts);
611 return std::pair<Fprint128, std::nullptr_t>(cache_key, nullptr);
612 }
613
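// Inserts the lowered `function` into the cache under `cache_key` and returns
// a pointer to the cached entry.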
614 const ExecutionFunctions* FunctionManager::AddCachedFunction(
615 const DTensorOperation& op, tensorflow::Fprint128 cache_key,
616 ExecutionFunctions function) {
617 return &function_cache_.insert({cache_key, std::move(function)})
618 .first->second;
619 }
620
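// Eager op inputs are always assumed foldable. For a tf.function, an input is
// foldable only if the function has not been seen before or the input was
// recorded among its small constant inputs.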
621 bool FunctionManager::IsConstantFoldable(const DTensorOperation& doperation,
622 const int input_index) const {
623 // For eager ops, assume the inputs are constant foldable.
624 if (!doperation.is_func()) return true;
625 const tensorflow::Fprint128 doperation_hash =
626 CacheKeyForDTensorOperation(doperation);
627 // If we haven't seen this doperation before, optimistically assume it is
628 // foldable. Otherwise, the input at `input_index` is foldable only if it is
629 // one of the indices we have saved as the small inputs.
630 auto doperation_iter = dtensor_op_and_small_inputs_.find(doperation_hash);
631 return doperation_iter == dtensor_op_and_small_inputs_.end() ||
632 doperation_iter->second.contains(input_index);
633 }
634
635 const tensorflow::Fprint128 FunctionManager::CacheKeyForDTensorOperation(
636 const DTensorOperation& doperation) const {
637 return tensorflow::Fingerprint128(doperation.name);
638 }
639
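// Returns the shape of `tensor` as a vector of dimension sizes. Returns an
// empty vector and sets `status` on error.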
640 std::vector<int64_t> TensorShapeAsVector(TFE_TensorHandle* tensor,
641 TF_Status* status) {
642 std::vector<int64_t> shape(TFE_TensorHandleNumDims(tensor, status));
643 if (TF_GetCode(status) != TF_OK) return {};
644 for (int i = 0; i < shape.size(); ++i) {
645 shape[i] = TFE_TensorHandleDim(tensor, i, status);
646 if (TF_GetCode(status) != TF_OK) return {};
647 }
648 return shape;
649 }
650
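// Builds a graph with one _Arg node per DTensor input (plus a leading
// device-id _Arg), the op or StatefulPartitionedCall node itself, and one
// _Retval node per output. Small constant inputs become Const nodes. Shape
// inference over this graph produces `global_output_shapes`, and any default
// output layout requested by the user is recorded in `output_layouts`.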
651 Status PrepareGraphForMlir(
652 const FunctionManager& function_manager,
653 const std::vector<TensorWithLayout*>& inputs,
654 const DTensorOperation& doperation,
655 const tensorflow::FunctionLibraryDefinition& flib_def,
656 const NameAttrList& attributes,
657 const absl::optional<Layout>& default_layout, tensorflow::Graph* graph,
658 std::vector<PartialTensorShape>* global_output_shapes,
659 std::vector<const Layout*>* output_layouts) {
660 // We run shape inference on the graph to find output shapes, which may
661 // determine default layouts.
662 ShapeRefiner shape_refiner(TF_GRAPH_DEF_VERSION, &flib_def);
663 shape_refiner.set_function_library_for_shape_inference(&flib_def);
664 tensorflow::Status status;
665 {
666 // We include an _Arg node for the device ID, but this isn't used by the
667 // initial function. It will be provided a value, though, so it's available
668 // for use in rewrites.
669 tensorflow::NodeDefBuilder builder("device_id", "_Arg");
670 tensorflow::PartialTensorShape partial_shape;
671 TF_RETURN_IF_ERROR(tensorflow::PartialTensorShape::MakePartialShape(
672 static_cast<int*>(nullptr), 0, &partial_shape));
673 tensorflow::NodeDef arg_node_def;
674 TF_RETURN_IF_ERROR(builder.Attr("shape", partial_shape)
675 .Attr("T", tensorflow::DT_INT32)
676 .Attr("index", 0)
677 .Finalize(&arg_node_def, /*consume=*/true));
678 tensorflow::Node* arg_node = graph->AddNode(arg_node_def, &status);
679 TF_RETURN_IF_ERROR(status);
680 graph->AddControlEdge(graph->source_node(), arg_node);
681 TF_RETURN_IF_ERROR(shape_refiner.AddNode(arg_node));
682 }
683 std::vector<FunctionArgument> graph_op_inputs;
684 graph_op_inputs.reserve(inputs.size());
685 for (int i = 0; i < inputs.size(); ++i) {
686 const TensorWithLayout* input = inputs[i];
687 // TODO(allenl): This will block until async execution is complete, which
688 // will be slow. We should find a non-blocking way of fetching the shape,
689 // at least pre-cache.
690 // The shape passed into MLIR transformation represents the global shape of
691 // the tensor. Ideally, the local shape on each parallel device should not
692 // be consulted at all and we should use the shape on our input tensor
693 // directly.
694 const auto& shape = input->global_shape();
695 std::vector<tensorflow::int64> cast_shape(shape.begin(), shape.end());
696 tensorflow::PartialTensorShape partial_shape;
697 // For resource tensors, the `shape` attribute should not be specified, as
698 // the shape of a resource tensor is carried by the resource shape subtype --
699 // not the shape attribute.
700 auto* resource = dynamic_cast<const ResourceHandleWithLayout*>(input);
701 if (!resource) {
702 TF_RETURN_IF_ERROR(tensorflow::PartialTensorShape::MakePartialShape(
703 cast_shape.data(), cast_shape.size(), &partial_shape));
704 }
705
706 tensorflow::NodeDef arg_node_def;
707 auto dtype = static_cast<tensorflow::DataType>(input->dtype());
708 tensorflow::NodeDefBuilder builder(absl::StrCat("op_input_", i), "_Arg");
709
710 // Delegate TensorWithLayout to encode attributes if applicable.
711 input->EncodeAttributes(builder);
712
713 // Here we set each arg node's `index` attribute to the position of
714 // the dtensor inputs. This is important for later use when we create
715 // a mapping from the graph argument node to the corresponding argument
716 // index of the list of dtensor inputs. Thus, even if the argument node
717 // orderings change within the graph, we can always correctly
718 // find the dtensor input corresponding to that arg node.
719 //
720 // This assumes that the dtensor inputs stay unchanged in ordering,
721 // and if there is an ordering change of dtensor inputs, then special
722 // care must be taken.
723 TF_RETURN_IF_ERROR(
724 builder.Attr("shape", partial_shape)
725 .Attr("T", dtype)
726 .Attr("index", i + 1) // Indices are offset by 1 for device_id
727 .Attr(kLayoutAttr, input->layout().ToString())
728 .Attr(kMeshAttr, input->mesh().mesh_config().ToString())
729 .Finalize(&arg_node_def, /*consume=*/true));
730 Node* arg_node = graph->AddNode(arg_node_def, &status);
731 TF_RETURN_IF_ERROR(status);
732 TF_RETURN_IF_ERROR(shape_refiner.AddNode(arg_node));
733
734 shape_inference::InferenceContext* inference_context =
735 shape_refiner.GetContext(arg_node);
736 shape_inference::ShapeHandle shape_handle;
737 TF_RETURN_IF_ERROR(inference_context->MakeShapeFromPartialTensorShape(
738 partial_shape, &shape_handle));
739 TF_RETURN_IF_ERROR(shape_refiner.SetShape(arg_node, 0, shape_handle));
740
741 // Small constants are converted into constant graph nodes, instead of being
742 // passed in as input arguments. This provides more information to the SPMD
743 // and layout propagation passes.
744 if (!input->const_value().has_value() ||
745 !function_manager.IsConstantFoldable(doperation, i)) {
746 graph_op_inputs.push_back(FunctionArgument{
747 arg_node, NodeDefBuilder::NodeOut{arg_node->name(), i, dtype}});
748 graph->AddControlEdge(graph->source_node(), arg_node);
749 } else {
750 // TODO(xiejw): Refactor the TensorWithLayout representation to avoid
751 // special code here.
752 NodeDef const_node = input->const_value().value();
753 const_node.set_name(absl::StrCat("input_", i, "_const_value"));
754 Node* const_value_n = graph->AddNode(const_node, &status);
755 TF_RETURN_IF_ERROR(status);
756 TF_RETURN_IF_ERROR(shape_refiner.AddNode(const_value_n));
757 graph_op_inputs.push_back(FunctionArgument{
758 const_value_n, tensorflow::NodeDefBuilder::NodeOut{
759 const_value_n->name(), i, dtype}});
760 }
761 }
762
763 tensorflow::NodeDef op_node_def;
764 const FunctionDef* function_def = doperation.function_def;
765 if (function_def) {
766 AttrValue func_attr;
767 func_attr.mutable_func()->set_name(doperation.name);
768 std::vector<tensorflow::NodeDefBuilder::NodeOut> func_inputs;
769 std::vector<tensorflow::DataType> inputs_types;
770 for (const auto& in : graph_op_inputs) {
771 func_inputs.emplace_back(in.output);
772 inputs_types.emplace_back(in.output.data_type);
773 }
774
775 std::vector<tensorflow::DataType> output_types;
776 for (const auto& out : function_def->signature().output_arg())
777 output_types.emplace_back(out.type());
778
779 TF_RETURN_IF_ERROR(
780 NodeDefBuilder("eager_operation", "StatefulPartitionedCall")
781 .Attr("Tin", inputs_types)
782 .Attr("Tout", output_types)
783 .Attr("f", func_attr)
784 .Input(func_inputs)
785 .Finalize(&op_node_def, true));
786 } else {
787 op_node_def.set_op(doperation.name);
788 op_node_def.set_name("eager_operation");
789 }
790
791 op_node_def.mutable_attr()->insert(attributes.attr().begin(),
792 attributes.attr().end());
793
794 tensorflow::Node* op_node = graph->AddNode(op_node_def, &status);
795 TF_RETURN_IF_ERROR(status);
796
797 for (int i = 0; i < graph_op_inputs.size(); ++i) {
798 graph->AddEdge(graph_op_inputs[i].node, 0, op_node, i);
799 }
800 TF_RETURN_IF_ERROR(shape_refiner.AddNode(op_node));
801
802 output_layouts->clear();
803 output_layouts->reserve(op_node->num_outputs());
804 global_output_shapes->reserve(op_node->num_outputs());
805 for (int output_index = 0; output_index < op_node->num_outputs();
806 ++output_index) {
807 tensorflow::NodeDefBuilder builder(absl::StrCat("op_output_", output_index),
808 "_Retval");
809 tensorflow::NodeDef ret_node_def;
810 tensorflow::DataType output_type = op_node->output_type(output_index);
811
812 TF_RETURN_IF_ERROR(builder.Attr("T", output_type)
813 .Attr("index", output_index)
814 .Input("eager_operation", output_index, output_type)
815 .Finalize(&ret_node_def, /*consume=*/true));
816 tensorflow::Node* ret_node = graph->AddNode(ret_node_def, &status);
817 TF_RETURN_IF_ERROR(status);
818 graph->AddEdge(op_node, output_index, ret_node, 0);
819 graph->AddControlEdge(ret_node, graph->sink_node());
820
821 shape_inference::InferenceContext* inference_context =
822 shape_refiner.GetContext(op_node);
823 shape_inference::ShapeHandle output_shape_handle =
824 inference_context->output(output_index);
825 TensorShapeProto output_shape_proto;
826 inference_context->ShapeHandleToProto(output_shape_handle,
827 &output_shape_proto);
828 PartialTensorShape global_output_shape(output_shape_proto);
829 VLOG(3) << "Inferred shape for operation '" << doperation.name
830 << "':" << global_output_shape.DebugString();
831 global_output_shapes->push_back(global_output_shape);
832
833 const Layout* layout = nullptr;
834 if (default_layout.has_value() && output_index == 0) {
835 // Record the user's requested output layout. The scope currently only
836 // covers the first output of an op.
837 layout = &default_layout.value();
838 ret_node->AddAttr(kDefaultLayoutAttr, layout->ToString());
839 }
840 output_layouts->push_back(layout);
841 }
842 return OkStatus();
843 }
844
845 // Returns the set of functions to run to execute the DTensor computation.
846 StatusOr<ExecutionFunctions> IdentifyAllFunctionsToExecute(
847 const tensorflow::Graph& graph,
848 const std::vector<PartialTensorShape>& global_output_shapes) {
849 ExecutionFunctions execution_functions;
850 execution_functions.function_list = std::vector<TranslatedFunction>();
851 for (Node* node : graph.nodes()) {
852 if (node->op_def().name() != "StatefulPartitionedCall") continue;
853 // Extract mesh to execute the function.
854 std::string serialized_mesh;
855 TF_RETURN_IF_ERROR(GetNodeAttr(node->attrs(), kMeshAttr, &serialized_mesh));
856 Mesh mesh;
857 TF_ASSIGN_OR_RETURN(mesh, Mesh::FromString(serialized_mesh));
858
859 TranslatedFunction function;
860 function.function_mesh = std::move(mesh);
861 function.node_to_execute = node;
862
863 // Identify input arg information.
864 TF_RETURN_IF_ERROR(
865 ParseResourceArgumentLayouts(*node, &function.resource_input_layouts));
866
867 TF_RETURN_IF_ERROR(
868 ParseShapeInputLayouts(*node, &function.shape_output_metadata));
869
870 function.input_index_map.resize(node->num_inputs());
871 // Identity mapping between local mesh function input index and global
872 // input index.
873 for (int in_index = 0; in_index < node->num_inputs(); ++in_index) {
874 Node* input_node;
875
876 TF_RETURN_IF_ERROR(node->input_node(in_index, &input_node));
877 if (!input_node->IsArg())
878 return errors::InvalidArgument(
879 "Input node to mesh computation must be arg node.");
880
881 int global_index;
882 TF_RETURN_IF_ERROR(
883 GetNodeAttr(input_node->attrs(), "index", &global_index));
884 function.input_index_map[in_index] = global_index;
885 }
886
887 // Identify output mappings and layouts for each output.
888 std::map<int, const Edge*> output_edges;
889 for (const Edge* out_edge : node->out_edges()) {
890 if (out_edge->IsControlEdge()) continue;
891
892 const Node* retval_or_identity_node = out_edge->dst();
893 while (retval_or_identity_node->IsIdentity()) {
894 retval_or_identity_node =
895 *(retval_or_identity_node->out_nodes().begin());
896 }
897
898 TF_RET_CHECK(retval_or_identity_node->IsRetval());
899 int global_index;
900 TF_RETURN_IF_ERROR(GetNodeAttr(retval_or_identity_node->attrs(), "index",
901 &global_index));
902 output_edges[global_index] = out_edge;
903 }
904
905 for (auto it = output_edges.begin(); it != output_edges.end(); it++) {
906 const int global_index = it->first;
907 function.output_index_map.emplace_back(global_index);
908
909 const Edge* retval_edge = it->second;
910 const int output_index = retval_edge->src_output();
911
912 // Add output layout and shape information.
913 TF_ASSIGN_OR_RETURN(
914 const Layout output_layout,
915 GetLayoutThroughIdentityOps(retval_edge->src(), output_index));
916
917 function.output_layouts.emplace_back(output_layout);
918 function.local_output_shapes.emplace_back(
919 output_layout.LocalShapeFromGlobalShape(
920 global_output_shapes[global_index]));
921 }
922
923 execution_functions.function_list.emplace_back(std::move(function));
924 }
925
926 if (execution_functions.function_list.empty()) {
927 return errors::InvalidArgument(
928 "MLIR transformed graph does not have any functions to execute for "
929 "mesh.");
930 }
931
932 return execution_functions;
933 }
934
935 // For functions with control outputs, add identity nodes between
936 // StatefulPartitionedCall and _Retvals, in order to preserve control output
937 // dependencies after StatefulPartitionedCall is inlined at runtime.
938 // Consider calling this in PrepareGraphForMlir, once the identity nodes won't
939 // be dropped during MLIR lowering.
940 // TODO(b/171265131): fix the underlying issue to avoid inserting identity
941 // nodes.
942 Status MaybeInsertIdentityNodes(const FunctionDef* function_def, Graph* graph) {
943 if (function_def == nullptr || function_def->control_ret().empty()) {
944 return OkStatus();
945 }
946 tensorflow::Status status;
947 for (Node* n : graph->nodes()) {
948 if (!n->IsRetval()) {
949 continue;
950 }
951 const Edge* edge;
952 TF_RETURN_IF_ERROR(n->input_edge(0, &edge));
953 int ret_index;
954 TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "index", &ret_index));
955 tensorflow::NodeDefBuilder identity_builder(
956 absl::StrCat("op_output_identity_", ret_index), "Identity");
957 tensorflow::NodeDef ret_identity_node_def;
958 tensorflow::DataType output_type = n->input_type(0);
959 TF_RETURN_IF_ERROR(
960 identity_builder.Attr("T", output_type)
961 .Input(edge->src()->name(), edge->src_output(), output_type)
962 .Finalize(&ret_identity_node_def, /*consume=*/true));
963 Node* ret_identity_node = graph->AddNode(ret_identity_node_def, &status);
964 TF_RETURN_IF_ERROR(status);
965 // Delete the edge between StatefulPartitionedCall and _Retval.
966 graph->RemoveEdge(edge);
967 // Add an edge between StatefulPartitionedCall and Identity.
968 graph->AddEdge(edge->src(), edge->src_output(), ret_identity_node, 0);
969 graph->AddControlEdge(edge->src(), ret_identity_node);
970 // Add an edge between Identity and _Retval.
971 graph->AddEdge(ret_identity_node, 0, n, 0);
972 }
973 return OkStatus();
974 }
975
976 void AddDTensorFunctionAttr(FunctionDef& function_def) {
977 // Do not xla compile function returned by DTensor MLIR graph transformation
978 // as it already returns compiled graph.
979 AttrValue xla_must_compile_val;
980 xla_must_compile_val.set_b(false);
981 function_def.mutable_attr()->insert(
982 {"_XlaMustCompile", xla_must_compile_val});
983
984 // Explicitly place function outputs on the default function device to avoid
985 // redundant host <-> device copies (Placer may place outputs on the host
986 // CPU).
987 AttrValue outputs_on_op_device;
988 outputs_on_op_device.set_b(true);
989 function_def.mutable_attr()->insert(
990 {"_OutputsOnOpDevice", outputs_on_op_device});
991 }
992
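// Collects the ParallelTensors of all embedding resource inputs, grouped by
// table id. Returns an error if no embedding resource input is present.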
993 StatusOr<std::vector<parallel_device::ParallelTensor*>> PrepareEmbeddingInputs(
994 const std::vector<TensorWithLayout*>& inputs) {
995 std::map<int64_t, std::vector<int64_t>> table_vars_input_index;  // keyed and ordered by table id
996 for (int64_t i = 0; i < inputs.size(); ++i) {
997 if (inputs[i]->tensor_type() != kResource) continue;
998
999 const absl::optional<EmbeddingResourceAttrs>& resource_attrs =
1000 inputs[i]->attrs();
1001 if (resource_attrs.has_value()) {
1002 table_vars_input_index[resource_attrs->table_id].push_back(i);
1003 }
1004 }
1005
1006 // Return an error if no embedding resource input was found.
1007 if (table_vars_input_index.empty()) {
1008 return errors::Internal("No TPU embedding resource input was found.");
1009 }
1010 std::vector<parallel_device::ParallelTensor*> parallel_inputs;
1011 // Ensure parallel inputs are ordered numerically by table id.
1012 for (const auto& [table_id, table_vars_indices] : table_vars_input_index) {
1013 for (const int64_t input_index : table_vars_indices) {
1014 parallel_inputs.push_back(inputs[input_index]->tensor());
1015 }
1016 }
1017 return parallel_inputs;
1018 }
1019
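// Finds the _Arg nodes annotated with a `_tpu_embedding_table_id` attribute,
// records the table (and slot) ids on the corresponding resource inputs, and
// returns a map from table id to its arg nodes.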
1020 StatusOr<std::map<int64_t, std::vector<Node*>>> GetTPUEmbeddingInputNodes(
1021 TF_Status* s, const Graph& graph,
1022 const std::vector<TensorWithLayout*>& inputs) {
1023 // After the graph is lowered, the sparse tensors live at the end of the
1024 // argument list, so only iterate over the dense dtensor inputs so that
1025 // we index correctly.
1026 std::vector<TensorWithLayout*> non_sparse_inputs;
1027 non_sparse_inputs.reserve(inputs.size());
1028 for (TensorWithLayout* input : inputs) {
1029 if (input->tensor_type() != TensorType::kSparse) {
1030 non_sparse_inputs.push_back(input);
1031 }
1032 }
1033 std::map<int64_t, std::vector<Node*>> table_id_node_map;
1034 for (Node* node : graph.nodes()) {
1035 if (!node->IsArg()) continue;
1036
1037 const int64_t& arg_id = node->attrs().Find("index")->i();
1038 const AttrValue* embedding_attr =
1039 node->attrs().Find("_tpu_embedding_table_id");
1040
1041 if (embedding_attr == nullptr) continue;
1042 EmbeddingResourceAttrs embedding_input_attrs;
1043
1044 // Add embedding table id.
1045 const int64_t table_id = embedding_attr->i();
1046 embedding_input_attrs.table_id = table_id;
1047
1048 // Add embedding slot id if there is one.
1049 const AttrValue* embedding_slot_attr =
1050 node->attrs().Find("_tpu_embedding_slot_id");
1051 if (embedding_slot_attr != nullptr) {
1052 const int64_t slot_id = embedding_slot_attr->i();
1053 embedding_input_attrs.slot_id = slot_id;
1054 }
1055
1056 table_id_node_map[table_id].push_back(node);
1057
1058 // Arg input offset due to device id.
1059 if (non_sparse_inputs[arg_id - 1]->attrs().has_value()) continue;
1060 non_sparse_inputs[arg_id - 1]->UpdateAttrs(embedding_input_attrs, s);
1061 if (!s->status.ok()) {
1062 return errors::Internal(
1063 "Failed to set embedding resource attrs. \n Got error: ",
1064 s->status.error_message());
1065 }
1066 }
1067 return table_id_node_map;
1068 }
1069
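// Checks that every embedding resource input lives on the same mesh and
// returns that mesh serialized as a string.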
1070 StatusOr<std::string> ValidateResourceMeshConsistency(
1071 const std::vector<TensorWithLayout*>& inputs) {
1072 std::string mesh_str;
1073 for (TensorWithLayout* inp : inputs) {
1074 if ((inp->tensor_type() != kResource) || !inp->attrs().has_value())
1075 continue;
1076 const std::string& input_mesh_str = inp->layout().mesh().ToString();
1077 if (mesh_str.empty()) {
1078 mesh_str = input_mesh_str;
1079 } else if (mesh_str != input_mesh_str) {
1080 return errors::Internal(absl::StrCat(
1081 "All inputs of embedding resource must be on same mesh. but get : ",
1082 mesh_str, " != ", input_mesh_str));
1083 }
1084 }
1085 VLOG(1) << "Resource input mesh is : " << mesh_str;
1086 return mesh_str;
1087 }
1088
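// Adds a StatefulPartitionedCall node invoking the TPU embedding load or
// retrieve checkpoint function, with every embedding table resource wired in
// as an input.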
1089 Status InsertFunctionForTPUEmbeddingCheckpoint(
1090 TF_Status* status, Graph* graph,
1091 const std::vector<TensorWithLayout*>& inputs,
1092 const std::string& checkpoint_fn_name) {
1093 if (checkpoint_fn_name != kLoadEmbeddingFn &&
1094 checkpoint_fn_name != kRetrieveEmbeddingFn) {
1095 return errors::InvalidArgument(absl::StrCat(
1096 "Found wrong function name: ", checkpoint_fn_name,
1097 " \n expects : ", kLoadEmbeddingFn, " or ", kRetrieveEmbeddingFn));
1098 }
1099
1100 StatusOr<std::map<int64_t, std::vector<Node*>>> table_id_node_map =
1101 GetTPUEmbeddingInputNodes(status, *graph, inputs);
1102 if (!table_id_node_map.ok()) {
1103 return errors::Internal(table_id_node_map.status().error_message());
1104 }
1105
1106 StatusOr<std::string> mesh_str = ValidateResourceMeshConsistency(inputs);
1107
1108 const int64_t& num_tables = table_id_node_map->size();
1109 NodeDef func_node_def;
1110 std::vector<NodeDefBuilder::NodeOut> func_inputs;
1111 std::vector<DataType> input_types, output_types;
1112
1113 func_inputs.reserve(num_tables);
1114 input_types.reserve(num_tables);
1115
1116 for (int i = 0; i < num_tables; ++i) {
1117 auto node_vec_ptr = table_id_node_map->find(i);
1118 if (node_vec_ptr == table_id_node_map->end()) {
1119 return errors::Internal(
1120 absl::StrCat("Embedding table id ", i, " is not found."));
1121 }
1122 for (const Node* n : node_vec_ptr->second) {
1123 const std::string& node_name = n->name();
1124 func_inputs.push_back({node_name, i, DT_RESOURCE});
1125 input_types.push_back(DT_RESOURCE);
1126 }
1127 }
1128
1129 AttrValue mesh_attr;
1130 *mesh_attr.mutable_s() = *mesh_str;
1131 NameAttrList func_attr;
1132 func_attr.set_name(checkpoint_fn_name);
1133 TF_RETURN_IF_ERROR(
1134 NodeDefBuilder(checkpoint_fn_name, "StatefulPartitionedCall")
1135 .Attr("Tin", input_types)
1136 .Attr("Tout", output_types)
1137 .Attr("f", func_attr)
1138 .Attr(kMeshAttr, mesh_attr)
1139 .Attr("config", mesh_attr)
1140 .Input(func_inputs)
1141 .Finalize(&func_node_def, true));
1142
1143 TF_ASSIGN_OR_RETURN(Node * func_node, graph->AddNode(func_node_def));
1144 for (int i = 0; i < num_tables; ++i) {
1145 const std::vector<Node*>& node_vec = table_id_node_map->find(i)->second;
1146 for (int j = 0; j < node_vec.size(); ++j) {
1147 graph->AddEdge(node_vec[j], 0, func_node, j + i);
1148 }
1149 }
1150
1151 return OkStatus();
1152 }
1153
1154 } // namespace dtensor
1155 } // namespace tensorflow
1156