xref: /aosp_15_r20/external/tensorflow/tensorflow/dtensor/cc/dtensor_device_util.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/dtensor/cc/dtensor_device_util.h"
17 
18 #include <cstddef>
19 #include <string>
20 #include <utility>
21 
22 #include "absl/container/flat_hash_map.h"
23 #include "absl/strings/str_cat.h"
24 #include "tensorflow/c/eager/c_api_internal.h"
25 #include "tensorflow/c/eager/tfe_tensorhandle_internal.h"
26 #include "tensorflow/c/tf_status.h"
27 #include "tensorflow/compiler/xla/status_macros.h"
28 #include "tensorflow/core/common_runtime/graph_constructor.h"
29 #include "tensorflow/core/common_runtime/shape_refiner.h"
30 #include "tensorflow/core/framework/attr_value.pb.h"
31 #include "tensorflow/core/framework/function.h"
32 #include "tensorflow/core/framework/node_def.pb.h"
33 #include "tensorflow/core/framework/node_def_util.h"
34 #include "tensorflow/core/framework/tensor.pb.h"
35 #include "tensorflow/core/framework/types.pb.h"
36 #include "tensorflow/core/graph/graph.h"
37 #include "tensorflow/core/lib/strings/proto_serialization.h"
38 #include "tensorflow/core/platform/errors.h"
39 #include "tensorflow/core/platform/fingerprint.h"
40 #include "tensorflow/core/public/version.h"
41 #include "tensorflow/dtensor/cc/constants.h"
42 #include "tensorflow/dtensor/cc/dstatus.h"
43 #include "tensorflow/dtensor/cc/small_constant_optimization.h"
44 
45 namespace tensorflow {
46 namespace dtensor {
47 namespace {
48 // Represents an input node during graph construction.
49 // When executing a Function, `output` is used to align graph inputs
50 // with the inputs to the function call.
51 struct FunctionArgument {
52   Node* node;
53   NodeDefBuilder::NodeOut output;
54 };
55 
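// Copies `tensor` to every local device of `mesh` and packs the per-device
// copies into a single ParallelTensor. Returns nullptr and sets `status` on
// failure.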
56 std::unique_ptr<parallel_device::ParallelTensor>
57 BroadcastTensorHandleToParallelTensor(TFE_Context* context,
58                                       TFE_TensorHandle* tensor,
59                                       const MeshWithParallelDevice& mesh,
60                                       TF_Status* status) {
61   // Broadcast tensor value to local devices.
62   const Mesh& target_mesh = mesh.mesh_config();
63   absl::Span<const std::string> local_devices = target_mesh.local_devices();
64   const int num_local_devices = local_devices.size();
65 
66   std::vector<parallel_device::TensorHandlePtr> components;
67   components.reserve(num_local_devices);
68   for (int i = 0; i < num_local_devices; ++i) {
69     // Create a copy of the tensor on each local device specified by `target_mesh`.
70     components.emplace_back(TFE_TensorHandleCopyToDevice(
71         tensor, context, local_devices[i].c_str(), status));
72     if (TF_GetCode(status) != TF_OK) {
73       TF_SetStatus(
74           status, TF_INTERNAL,
75           absl::StrCat(
76               "Unable to copy tensor value for broadcast. Original message: ",
77               TF_Message(status))
78               .c_str());
79       return nullptr;
80     }
81   }
82 
83   std::unique_ptr<parallel_device::ParallelTensor> parallel_tensor =
84       parallel_device::ParallelTensor::FromTensorHandles(
85           mesh.parallel_device(), std::move(components), status);
86   if (TF_GetCode(status) != TF_OK) return nullptr;
87   return parallel_tensor;
88 }
89 
90 // Broadcast a single non-parallel resource tensor onto `mesh` with a fully
91 // replicated sharding spec. Does not take ownership of `tensor`.
92 std::unique_ptr<TensorWithLayout> BroadcastResourceTensor(
93     TFE_Context* context, TFE_TensorHandle* tensor,
94     const MeshWithParallelDevice& mesh, const std::string& dtensor_device_name,
95     TF_Status* status) {
96   // Only broadcast resource tensors that point to scalars since they are
97   // always replicated. We also still want to catch honest user errors so
98   // error out on non-scalars.
99   // Resolve the tensor as a resource handle and get the shape and dtype
100   // of the tensor it points to.
101   std::unique_ptr<TF_Tensor, decltype(&TF_DeleteTensor)> tf_tensor(
102       TFE_TensorHandleResolve(tensor, status), TF_DeleteTensor);
103   Tensor t;
104   Status convert_status = TF_TensorToTensor(tf_tensor.get(), &t);
105   if (!convert_status.ok() || t.dtype() != DataType::DT_RESOURCE) {
106     TF_SetStatus(status, TF_INTERNAL, convert_status.error_message().c_str());
107     return nullptr;
108   }
109   // Replicate this resource handle to all devices without changing the
110   // associated device of the resource itself.
111   ResourceHandle r = t.flat<ResourceHandle>()(0);
112   if (r.dtypes_and_shapes().empty()) {
113     TF_SetStatus(status, TF_INTERNAL,
114                  "Expected resource handle to have at least one underlying "
115                  "dtype and shape during broadcasting.");
116     return nullptr;
117   }
118   PartialTensorShape partial_shape = r.dtypes_and_shapes().begin()->shape;
119   int64_t num_elements = partial_shape.num_elements();
120 
121   // Only broadcast scalar resource tensors onto a CPU mesh. Copying
122   // resource tensors to a non-CPU device is not supported.
123   if (num_elements != 1 || !mesh.mesh_config().is_cpu_mesh()) {
124     std::string error_message =
125         "Using a non-DTensor variable with DTensor is only supported for "
126         "scalar variables copied to a CPU mesh. If you are using a "
127         "scope-based API, create "
128         "variables inside the DTensor scope.\n";
129 
130     // Get the stack_trace and Summaries from the resource tensor.
131     absl::StrAppend(
132         &error_message, "Offending variable summary: ", r.SummarizeValue(),
133         "\nStack trace: ", DefinitionLocationMsg(r.definition_stack_trace()));
134     TF_SetStatus(status, TF_INVALID_ARGUMENT, error_message.c_str());
135     return nullptr;
136   }
137 
138   LOG(INFO) << "Broadcasting resource tensor to a dtensor resource tensor.";
139   if (mesh.mesh_config().is_remote()) {
140     TF_DataType dtype = TFE_TensorHandleDataType(tensor);
141     std::vector<int64_t> shape(TensorShapeAsVector(tensor, status));
142     if (TF_GetCode(status) != TF_OK) return nullptr;
143     auto layout = Layout::ReplicatedOnMesh(mesh.mesh_config(), shape.size());
144 
145     auto ret = TensorWithLayout::Dummy(shape, dtype, mesh, layout);
146     return ret;
147   }
148 
149   std::unique_ptr<parallel_device::ParallelTensor> parallel_tensor =
150       BroadcastTensorHandleToParallelTensor(context, tensor, mesh, status);
151   if (TF_GetCode(status) != TF_OK) return nullptr;
152 
153   StatusOr<std::unique_ptr<TensorWithLayout>> result = TensorWithLayout::Wrap(
154       std::move(parallel_tensor), mesh,
155       Layout::ReplicatedOnMesh(mesh.mesh_config(), partial_shape.dims()));
156   if (!result.ok()) {
157     TF_SetStatus(
158         status, TF_INTERNAL,
159         absl::StrCat("Error creating a TensorWithLayout from a resource tensor "
160                      "during broadcasting with original error message:",
161                      result.status().error_message())
162             .c_str());
163     return nullptr;
164   }
165   // Set the shape/type of the tensor that the resource points to
166   // so that the graph has correct shape/type information that we can use.
167   (*result)->UpdateShapeAndDType(partial_shape.AsProto(),
168                                  r.dtypes_and_shapes().begin()->dtype, status);
169   if (TF_GetCode(status) != TF_OK) {
170     TF_SetStatus(status, TF_INTERNAL,
171                  "Error updating shape and dtype for resource tensor during "
172                  "broadcasting.");
173     return nullptr;
174   }
175   return std::move(*result);
176 }
177 
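// Two optional layouts are compatible when both are absent, or both are
// present and equal.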
178 bool LayoutsAreCompatible(absl::optional<Layout> first_layout,
179                           absl::optional<Layout> second_layout) {
180   if (!first_layout.has_value() && !second_layout.has_value()) {
181     return true;
182   }
183   if (!first_layout.has_value() || !second_layout.has_value()) {
184     return false;
185   }
186   return first_layout.value() == second_layout.value();
187 }
188 
189 // Parses a pair of (indices, layouts) attributes into a map.
190 Status ParseAttrMap(const Node& node, absl::string_view indices_attr,
191                     absl::string_view layout_attr,
192                     std::map<int, Layout>* indices_layout_map) {
193   std::vector<std::string> layouts;
194   if (!TryGetNodeAttr(node.attrs(), layout_attr, &layouts)) {
195     return OkStatus();
196   }
197   const TensorProto* indices;
198   if (!TryGetNodeAttr(node.attrs(), indices_attr, &indices)) {
199     return errors::Internal(
200         "Arg indices must be set when setting inferred resource layouts.");
201   }
202   if (indices->int_val_size() != layouts.size()) {
203     return errors::Internal(
204         "Arg indices for inferred resource argument must match the "
205         "size of inferred resource layout.");
206   }
207   for (int i = 0; i < indices->int_val_size(); ++i) {
208     const auto arg_index = indices->int_val(i);
209     const auto& arg_layout = layouts[i];
210     indices_layout_map->emplace(
211         arg_index,
212         tensorflow::dtensor::Layout::FromString(arg_layout).ValueOrDie());
213   }
214   return OkStatus();
215 }
216 
217 Status ParseResourceArgumentLayouts(
218     const Node& node, std::map<int, Layout>* inferred_resource_input_layouts) {
219   return ParseAttrMap(node, kNewResourceLayoutIndices, kNewResourceArgLayouts,
220                       inferred_resource_input_layouts);
221 }
222 
223 Status ParseShapeInputLayouts(const Node& node,
224                               std::map<int, Layout>* shape_output_metadata) {
225   return ParseAttrMap(node, kShapeOpInputLayoutIndices, kShapeOpInputLayout,
226                       shape_output_metadata);
227 }
228 
229 // Gets the layout attached to a specific node at a given index, ignoring any
230 // Identity ops.
231 StatusOr<Layout> GetLayoutThroughIdentityOps(Node* op, int output_index) {
232   while (op->op_def().name() == "Identity" ||
233          op->op_def().name() == "IdentityN") {
234     const Edge* edge;
235     TF_RETURN_IF_ERROR(op->input_edge(output_index, &edge));
236     op = edge->src();
237     output_index = edge->src_output();
238   }
239   const auto serialized_layouts = op->attrs().Find(kLayoutAttr);
240 
241   if (!serialized_layouts) {
242     return errors::InvalidArgument(
243         op->op_def().name(), " doesn't contain attribute : ", kLayoutAttr);
244   }
245 
246   // We assume that there is one layout for each output.
247   if (serialized_layouts->list().s_size() != op->num_outputs()) {
248     return errors::InvalidArgument(
249         "Number of outputs to ", op->op_def().name(),
250         " does not match number of layouts attached");
251   }
252 
253   return Layout::FromString(serialized_layouts->list().s(output_index));
254 }
255 
256 }  // namespace
257 
258 tensorflow::Fprint128 TensorWithLayout::CacheKey() const {
259   tensorflow::Fprint128 f = tensorflow::Fingerprint128(layout_.ToString());
260   // Use exact shape to compute the key.
261   for (const int64_t dim : local_shape()) {
262     f = FingerprintCat128(f, dim);
263   }
264   if (const_value_.has_value()) {
265     std::string serialized;
266     SerializeToStringDeterministic(const_value_.value(), &serialized);
267     f = FingerprintCat128(f, tensorflow::Fingerprint128(serialized));
268   }
269   return f;
270 }
271 
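// Broadcasts a regular eager tensor onto `mesh` with a fully replicated
// layout. Resource tensors are delegated to BroadcastResourceTensor, and
// remote meshes receive a dummy placeholder tensor instead of real copies.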
272 std::unique_ptr<TensorWithLayout> TensorWithLayout::Broadcast(
273     TFE_Context* context, TFE_TensorHandle* tensor,
274     const MeshWithParallelDevice& mesh, const std::string& dtensor_device_name,
275     TF_Status* status) {
276   const char* input_device = TFE_TensorHandleDeviceName(tensor, status);
277   if (TF_GetCode(status) != TF_OK) return nullptr;
278 
279   if (dtensor_device_name == input_device) {
280     TF_SetStatus(status, TF_INVALID_ARGUMENT,
281                  "Input to Broadcast must be eager tensor.");
282     return nullptr;
283   }
284 
285   // Handle resource tensor broadcasting to the mesh.
286   if (TFE_TensorHandleDataType(tensor) == TF_RESOURCE) {
287     return BroadcastResourceTensor(context, tensor, mesh, dtensor_device_name,
288                                    status);
289   }
290 
291   if (mesh.mesh_config().is_remote()) {
292     TF_DataType dtype = TFE_TensorHandleDataType(tensor);
293     std::vector<int64_t> shape(TensorShapeAsVector(tensor, status));
294     if (TF_GetCode(status) != TF_OK) return nullptr;
295     auto layout = Layout::ReplicatedOnMesh(mesh.mesh_config(), shape.size());
296 
297     auto ret = TensorWithLayout::Dummy(shape, dtype, mesh, layout);
298     absl::optional<NodeDef> const_value =
299         ExtractSmallTensorValue(context, tensor, layout, status);
300     if (TF_GetCode(status) != TF_OK) return nullptr;
301     if (const_value) {
302       ret->set_const_value(const_value.value());
303     }
304     return ret;
305   }
306 
307   std::unique_ptr<parallel_device::ParallelTensor> parallel_tensor =
308       BroadcastTensorHandleToParallelTensor(context, tensor, mesh, status);
309   if (TF_GetCode(status) != TF_OK) return nullptr;
310 
311   const std::vector<int64_t>* shape;
312   Status s = parallel_tensor->Shape(&shape);
313   if (!s.ok()) {
314     TF_SetStatus(status, static_cast<TF_Code>(s.code()),
315                  s.error_message().c_str());
316     return nullptr;
317   }
318   size_t num_dims = shape->size();
319   const Layout layout = Layout::ReplicatedOnMesh(mesh.mesh_config(), num_dims);
320 
321   absl::optional<NodeDef> const_value =
322       ExtractSmallTensorValue(context, tensor, layout, status);
323   if (TF_GetCode(status) != TF_OK) return nullptr;
324 
325   std::unique_ptr<TensorWithLayout> result(new TensorWithLayout(
326       std::move(parallel_tensor), mesh, std::move(layout), *shape,
327       /*dtype=*/absl::nullopt, std::move(const_value)));
328   return result;
329 }
330 
331 StatusOr<std::unique_ptr<TensorWithLayout>> TensorWithLayout::Wrap(
332     std::unique_ptr<parallel_device::ParallelTensor> tensor,
333     const MeshWithParallelDevice& mesh, const Layout& layout) {
334   const std::vector<int64_t>* shape;
335   TF_RETURN_IF_ERROR(tensor->Shape(&shape));
336 
337   if (tensor->dtype() != TF_RESOURCE) {
338     return std::unique_ptr<TensorWithLayout>(
339         new TensorWithLayout(std::move(tensor), mesh, layout, *shape));
340   } else {
341     return std::unique_ptr<TensorWithLayout>(
342         new ResourceHandleWithLayout(std::move(tensor), mesh, layout, *shape));
343   }
344 }
345 
346 std::unique_ptr<TensorWithLayout> TensorWithLayout::Dummy(
347     const std::vector<int64_t>& local_shape, const TF_DataType dtype,
348     const MeshWithParallelDevice& mesh, const Layout& layout) {
349   if (dtype != TF_RESOURCE) {
350     return std::unique_ptr<TensorWithLayout>(new TensorWithLayout(
351         /*tensor=*/nullptr, mesh, layout, local_shape, dtype));
352   } else {
353     return std::unique_ptr<TensorWithLayout>(new ResourceHandleWithLayout(
354         /*tensor=*/nullptr, mesh, layout, local_shape));
355   }
356 }
357 
358 std::string TensorWithLayout::SummarizeValue() const {
359   std::string value_summary;
360   Status status;
361   if (layout().IsFullyReplicated()) {
362     status =
363         tensorflow::unwrap(tensor()->tensor(0))->SummarizeValue(value_summary);
364   } else {
365     // Note that this just prints the local values for sharded tensors. We could
366     // instead run a collective here to relayout to replicated.
367     status = tensor()->SummarizeValue(value_summary);
368   }
369   if (!status.ok()) {
370     value_summary = "<error computing value>";
371   }
372   return absl::StrCat(value_summary, ", layout=\"", layout().ToString(), "\"");
373 }
374 
375 std::string TensorWithLayout::DebugString() const {
376   auto dtype = static_cast<DataType>(tensor()->dtype());
377 
378   const auto& shape_vector = global_shape();
379   return absl::StrCat("DTensor(", SummarizeValue(),
380                       ", shape=", ShapeToDebugString(shape_vector),
381                       ", type=", DataTypeString(dtype), ")");
382 }
383 
384 void ResourceHandleWithLayout::EncodeAttributes(
385     tensorflow::NodeDefBuilder& builder) const {
386   // If set, attach shape and dtype to the given node def.
387   if (dereferenced_shape().has_value()) {
388     builder.Attr("_handle_shapes", {*dereferenced_shape()});
389   }
390   if (dereferenced_dtype().has_value()) {
391     builder.Attr("_handle_dtypes", {*dereferenced_dtype()});
392   }
393 }
394 
395 tensorflow::Fprint128 ResourceHandleWithLayout::CacheKey() const {
396   tensorflow::Fprint128 f = tensorflow::Fingerprint128(layout().ToString());
397   if (dereferenced_shape().has_value()) {
398     std::string serialized;
399     SerializeToStringDeterministic(dereferenced_shape().value(), &serialized);
400     f = FingerprintCat128(f, tensorflow::Fingerprint128(serialized));
401   }
402   if (dereferenced_dtype().has_value()) {
403     f = FingerprintCat128(f, dereferenced_dtype().value());
404   }
405   return f;
406 }
407 
408 void ResourceHandleWithLayout::UpdateLayout(const Layout& new_layout,
409                                             TF_Status* status) {
410   // Only set the value for the dereferenced layout if the incoming layout is
411   // not empty. This is still hacky, as we use an empty layout as a placeholder
412   // for an eagerly placed VarHandleOp.
413   if (!dereferenced_layout_.has_value() && new_layout.IsEmpty()) return;
414   if (dereferenced_layout_.has_value() &&
415       !LayoutsAreCompatible(dereferenced_layout_, new_layout)) {
416     // TODO(xiejw, allenl): Consider allowing variables to switch layouts.
417     RETURN_STATUS(status, TF_INVALID_ARGUMENT,
418                   "Attempted to overwrite an existing Layout.");
419   }
420   dereferenced_layout_.emplace(new_layout);
421 }
422 
423 void ResourceHandleWithLayout::UpdateAttrs(const EmbeddingResourceAttrs& attrs,
424                                            TF_Status* status) {
425   if (attrs_.has_value()) {
426     RETURN_STATUS(status, TF_INVALID_ARGUMENT,
427                   "Attempted to overwrite an existing embedding resource "
428                   "attribute.");
429   }
430   attrs_.emplace(attrs);
431 }
432 
433 StatusOr<std::unique_ptr<TensorWithLayout>> SparseTensorWithLayout::Wrap(
434     std::unique_ptr<parallel_device::ParallelTensor> indices_tensor,
435     std::unique_ptr<parallel_device::ParallelTensor> values_tensor,
436     std::unique_ptr<parallel_device::ParallelTensor> shapes_tensor,
437     const MeshWithParallelDevice& mesh, const Layout& layout,
438     std::vector<int64_t> local_shape) {
439   return std::unique_ptr<TensorWithLayout>(new SparseTensorWithLayout(
440       std::move(indices_tensor), std::move(values_tensor),
441       std::move(shapes_tensor), mesh, layout, local_shape));
442 }
443 
444 std::string SparseTensorWithLayout::SummarizeValue() const {
445   std::string indices_summary;
446   std::string values_summary;
447   std::string dense_shapes_summary;
448 
449   Status indices_status;
450   Status values_status;
451   Status dense_shapes_status;
452 
453   if (layout().IsFullyReplicated()) {
454     indices_status = tensorflow::unwrap(indices_->tensor(0))
455                          ->SummarizeValue(indices_summary);
456     values_status =
457         tensorflow::unwrap(values_->tensor(0))->SummarizeValue(values_summary);
458     dense_shapes_status = tensorflow::unwrap(dense_shapes_->tensor(0))
459                               ->SummarizeValue(dense_shapes_summary);
460   } else {
461     indices_status = indices_->SummarizeValue(indices_summary);
462     values_status = values_->SummarizeValue(values_summary);
463     dense_shapes_status = dense_shapes_->SummarizeValue(dense_shapes_summary);
464   }
465 
466   if (!indices_status.ok())
467     indices_summary = "<error computing summary for indices>";
468   if (!values_status.ok())
469     values_summary = "<error computing summary for values>";
470   if (!dense_shapes_status.ok())
471     dense_shapes_summary = "<error computing summary for dense_shapes>";
472 
473   return absl::StrCat("indices: ", indices_summary, ", ",
474                       "values: ", values_summary, ", ",
475                       "dense_shapes: ", dense_shapes_summary, ", layout=\"",
476                       layout().ToString(), "\"");
477 }
478 
479 std::string SparseTensorWithLayout::DebugString() const {
480   auto dtype = static_cast<DataType>(values_->dtype());
481 
482   const auto& shape_vector = global_shape();
483   return absl::StrCat("DTensor(", SummarizeValue(),
484                       ", shape=", ShapeToDebugString(shape_vector),
485                       ", type=", DataTypeString(dtype), ")");
486 }
487 
488 TF_DataType SparseTensorWithLayout::dtype() const {
489   if (dtype_.has_value()) {
490     return dtype_.value();
491   } else {
492     return values_->dtype();
493   }
494 }
495 
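// The underlying handles are stored as three consecutive groups of
// `num_tensors() / 3` components: [indices..., values..., dense_shapes...].
// For example, with two local devices, index 3 resolves to values()->tensor(1).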
496 TFE_TensorHandle* SparseTensorWithLayout::get_tensor(size_t index) const {
497   int num_sparse_tensors = num_tensors() / 3;
498   if (index < num_sparse_tensors) {
499     return indices()->tensor(index);
500   } else if (index < 2 * num_sparse_tensors) {
501     return values()->tensor(index % num_sparse_tensors);
502   } else {
503     return dense_shapes()->tensor(index % num_sparse_tensors);
504   }
505 }
506 
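// Returns a map from input position to the NodeDef of every input that
// carries a small constant value and is therefore a constant-folding
// candidate.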
507 absl::flat_hash_map<int, NodeDef> GetConstantFoldableTensors(
508     const std::vector<TensorWithLayout*>& inputs) {
509   absl::flat_hash_map<int, NodeDef> small_tensors;
510   for (auto index = 0; index < inputs.size(); ++index) {
511     if (inputs[index]->const_value().has_value()) {
512       small_tensors.insert({index, inputs[index]->const_value().value()});
513     }
514   }
515   return small_tensors;
516 }
517 
518 // Thread unsafe method. go/thread-unsafe
519 // Cache key computation should consider all features of an op that affects
520 // the SPMD lowering. The cache keys of two ops must be different if the
521 // translated functions are different.
522 // - op name and attr
523 // - input shapes and layouts
524 // - default layout of outputs.
525 // - values of constant foldable inputs.
526 tensorflow::Fprint128 FunctionManager::CacheKeyForGraph(
527     const DTensorOperation& doperation, const NameAttrList& attributes,
528     const std::vector<TensorWithLayout*>& inputs,
529     const std::vector<const Layout*>& output_layouts) {
530   tensorflow::Fprint128 cache_key = tensorflow::Fingerprint128(doperation.name);
531   std::string serialized;
532   SerializeToStringDeterministic(attributes, &serialized);
533   cache_key =
534       FingerprintCat128(cache_key, tensorflow::Fingerprint128(serialized));
535   // Higher level cache based on operation name and input shapes.
536   for (auto i = 0; i < inputs.size(); ++i) {
537     if (!IsConstantFoldable(doperation, i)) {
538       inputs[i]->reset_const_value();
539     }
540     cache_key = FingerprintCat128(cache_key, inputs[i]->CacheKey());
541   }
542   for (int output_index = 0; output_index < output_layouts.size();
543        ++output_index) {
544     if (output_layouts[output_index]) {
545       cache_key = FingerprintCat128(cache_key, output_index);
546       cache_key = FingerprintCat128(
547           cache_key,
548           tensorflow::Fingerprint128(output_layouts[output_index]->ToString()));
549     }
550   }
551   return cache_key;
552 }
553 
554 // Thread-unsafe method go/thread-unsafe.
555 std::pair<tensorflow::Fprint128, const ExecutionFunctions*>
556 FunctionManager::GetCachedFunction(
557     const DTensorOperation& doperation, const NameAttrList& attributes,
558     const std::vector<TensorWithLayout*>& inputs,
559     const std::vector<const Layout*>& output_layouts) {
560   tensorflow::Fprint128 cache_key =
561       CacheKeyForGraph(doperation, attributes, inputs, output_layouts);
562   auto iter = function_cache_.find(cache_key);
563 
564   // Early return if we have a cache hit.
565   if (iter != function_cache_.end()) {
566     return std::pair<Fprint128, ExecutionFunctions*>(cache_key, &iter->second);
567   }
568 
569   // For eager ops, we return the cache miss early and do not attempt further
570   // optimizations.
571   if (!doperation.is_func()) {
572     return std::pair<Fprint128, std::nullptr_t>(cache_key, nullptr);
573   }
574 
575   const tensorflow::Fprint128 doperation_hash =
576       CacheKeyForDTensorOperation(doperation);
577 
578   // Save the constant folded inputs to this doperation if we have not seen this
579   // before. This is needed so that in the next call to this operation, we
580   // can compare these inputs to confirm which one is indeed a constant.
581   auto doperation_iter = dtensor_op_and_small_inputs_.find(doperation_hash);
582   if (doperation_iter == dtensor_op_and_small_inputs_.end()) {
583     dtensor_op_and_small_inputs_.insert(
584         {doperation_hash, GetConstantFoldableTensors(inputs)});
585     return std::pair<Fprint128, std::nullptr_t>(cache_key, nullptr);
586   }
587 
588   // If we are here, then we have run this function before but constant folded
589   // some input(s) that were not actually constant, i.e. one of the small input
590   // values to this function changed. So mark those changed values as
591   // non-constant.
592   absl::flat_hash_map<int, NodeDef>& previous_small_inputs =
593       doperation_iter->second;
594   std::vector<int> non_constant_indices;
595 
596   for (auto const& [index, previous_small_input] : previous_small_inputs) {
597     if (inputs[index]->const_value().has_value()) {
598       if (NodeDefsHaveDifferentTensorProto(
599               previous_small_input, inputs[index]->const_value().value())) {
600         inputs[index]->reset_const_value();
601         non_constant_indices.push_back(index);
602       }
603     }
604   }
605   for (int non_constant_index : non_constant_indices) {
606     previous_small_inputs.erase(non_constant_index);
607   }
608   // Generate a new cache key since we updated small const inputs which change
609   // the cache key.
610   cache_key = CacheKeyForGraph(doperation, attributes, inputs, output_layouts);
611   return std::pair<Fprint128, std::nullptr_t>(cache_key, nullptr);
612 }
613 
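// Stores `function` in the cache under `cache_key` and returns a pointer to
// the cached entry.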
614 const ExecutionFunctions* FunctionManager::AddCachedFunction(
615     const DTensorOperation& op, tensorflow::Fprint128 cache_key,
616     ExecutionFunctions function) {
617   return &function_cache_.insert({cache_key, std::move(function)})
618               .first->second;
619 }
620 
621 bool FunctionManager::IsConstantFoldable(const DTensorOperation& doperation,
622                                          const int input_index) const {
623   // For eager ops, assume the inputs are constant foldable.
624   if (!doperation.is_func()) return true;
625   const tensorflow::Fprint128 doperation_hash =
626       CacheKeyForDTensorOperation(doperation);
627   // If we haven't seen this doperation before, then optimistically assume it
628   // is foldable. Otherwise, the input at `input_index` is foldable only if it
629   // is one of the indices we have saved as the small inputs.
630   auto doperation_iter = dtensor_op_and_small_inputs_.find(doperation_hash);
631   return doperation_iter == dtensor_op_and_small_inputs_.end() ||
632          doperation_iter->second.contains(input_index);
633 }
634 
635 const tensorflow::Fprint128 FunctionManager::CacheKeyForDTensorOperation(
636     const DTensorOperation& doperation) const {
637   return tensorflow::Fingerprint128(doperation.name);
638 }
639 
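// Returns the dimensions of `tensor` as a vector, or an empty vector with
// `status` set on failure.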
640 std::vector<int64_t> TensorShapeAsVector(TFE_TensorHandle* tensor,
641                                          TF_Status* status) {
642   std::vector<int64_t> shape(TFE_TensorHandleNumDims(tensor, status));
643   if (TF_GetCode(status) != TF_OK) return {};
644   for (int i = 0; i < shape.size(); ++i) {
645     shape[i] = TFE_TensorHandleDim(tensor, i, status);
646     if (TF_GetCode(status) != TF_OK) return {};
647   }
648   return shape;
649 }
650 
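// Builds a graph consisting of a device-id _Arg node, one _Arg node per
// DTensor input (or a Const node for small constant-foldable inputs), the
// eager op or function call itself, and one _Retval per output, running
// shape inference along the way to compute `global_output_shapes` and
// `output_layouts`.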
651 Status PrepareGraphForMlir(
652     const FunctionManager& function_manager,
653     const std::vector<TensorWithLayout*>& inputs,
654     const DTensorOperation& doperation,
655     const tensorflow::FunctionLibraryDefinition& flib_def,
656     const NameAttrList& attributes,
657     const absl::optional<Layout>& default_layout, tensorflow::Graph* graph,
658     std::vector<PartialTensorShape>* global_output_shapes,
659     std::vector<const Layout*>* output_layouts) {
660   // We run shape inference on the graph to find output shapes, which may
661   // determine default layouts.
662   ShapeRefiner shape_refiner(TF_GRAPH_DEF_VERSION, &flib_def);
663   shape_refiner.set_function_library_for_shape_inference(&flib_def);
664   tensorflow::Status status;
665   {
666     // We include an _Arg node for the device ID, but this isn't used by the
667     // initial function. It will be provided a value, though, so it's available
668     // for use in rewrites.
669     tensorflow::NodeDefBuilder builder("device_id", "_Arg");
670     tensorflow::PartialTensorShape partial_shape;
671     TF_RETURN_IF_ERROR(tensorflow::PartialTensorShape::MakePartialShape(
672         static_cast<int*>(nullptr), 0, &partial_shape));
673     tensorflow::NodeDef arg_node_def;
674     TF_RETURN_IF_ERROR(builder.Attr("shape", partial_shape)
675                            .Attr("T", tensorflow::DT_INT32)
676                            .Attr("index", 0)
677                            .Finalize(&arg_node_def, /*consume=*/true));
678     tensorflow::Node* arg_node = graph->AddNode(arg_node_def, &status);
679     TF_RETURN_IF_ERROR(status);
680     graph->AddControlEdge(graph->source_node(), arg_node);
681     TF_RETURN_IF_ERROR(shape_refiner.AddNode(arg_node));
682   }
683   std::vector<FunctionArgument> graph_op_inputs;
684   graph_op_inputs.reserve(inputs.size());
685   for (int i = 0; i < inputs.size(); ++i) {
686     const TensorWithLayout* input = inputs[i];
687     // TODO(allenl): This will block until async execution is complete, which
688     // will be slow. We should find a non-blocking way of fetching the shape,
689     // or at least pre-cache it.
690     // The shape passed into MLIR transformation represents the global shape of
691     // the tensor. Ideally, the local shape on each parallel device should not
692     // be consulted at all and we should use the shape on our input tensor
693     // directly.
694     const auto& shape = input->global_shape();
695     std::vector<tensorflow::int64> cast_shape(shape.begin(), shape.end());
696     tensorflow::PartialTensorShape partial_shape;
697     // For resource tensors, the `shape` attribute should not be specified, as
698     // the shape of a resource tensor comes from the resource shape subtype --
699     // not the shape attribute.
700     auto* resource = dynamic_cast<const ResourceHandleWithLayout*>(input);
701     if (!resource) {
702       TF_RETURN_IF_ERROR(tensorflow::PartialTensorShape::MakePartialShape(
703           cast_shape.data(), cast_shape.size(), &partial_shape));
704     }
705 
706     tensorflow::NodeDef arg_node_def;
707     auto dtype = static_cast<tensorflow::DataType>(input->dtype());
708     tensorflow::NodeDefBuilder builder(absl::StrCat("op_input_", i), "_Arg");
709 
710     // Delegate TensorWithLayout to encode attributes if applicable.
711     input->EncodeAttributes(builder);
712 
713     // Here we set each arg node's `index` attribute to the position of
714     // the dtensor inputs. This is important for later use when we create
715     // a mapping from the graph argument node to the corresponding argument
716     // index of the list of dtensor inputs. Thus, even if the argument node
717     // orderings change within the graph, we can always correctly
718     // find the dtensor input corresponding to that arg node.
719     //
720     // This assumes that the dtensor inputs stay unchanged in ordering,
721     // and if there is an ordering change of dtensor inputs, then special
722     // care must be taken.
723     TF_RETURN_IF_ERROR(
724         builder.Attr("shape", partial_shape)
725             .Attr("T", dtype)
726             .Attr("index", i + 1)  // Indices are offset by 1 for device_id
727             .Attr(kLayoutAttr, input->layout().ToString())
728             .Attr(kMeshAttr, input->mesh().mesh_config().ToString())
729             .Finalize(&arg_node_def, /*consume=*/true));
730     Node* arg_node = graph->AddNode(arg_node_def, &status);
731     TF_RETURN_IF_ERROR(status);
732     TF_RETURN_IF_ERROR(shape_refiner.AddNode(arg_node));
733 
734     shape_inference::InferenceContext* inference_context =
735         shape_refiner.GetContext(arg_node);
736     shape_inference::ShapeHandle shape_handle;
737     TF_RETURN_IF_ERROR(inference_context->MakeShapeFromPartialTensorShape(
738         partial_shape, &shape_handle));
739     TF_RETURN_IF_ERROR(shape_refiner.SetShape(arg_node, 0, shape_handle));
740 
741     // Small constants are converted into constant graph nodes, instead of being
742     // passed in as input arguments. This provides more information to the SPMD
743     // and layout propagation passes.
744     if (!input->const_value().has_value() ||
745         !function_manager.IsConstantFoldable(doperation, i)) {
746       graph_op_inputs.push_back(FunctionArgument{
747           arg_node, NodeDefBuilder::NodeOut{arg_node->name(), i, dtype}});
748       graph->AddControlEdge(graph->source_node(), arg_node);
749     } else {
750       // TODO(xiejw): Refactor the TensorWithLayout representation to avoid
751       // special code here.
752       NodeDef const_node = input->const_value().value();
753       const_node.set_name(absl::StrCat("input_", i, "_const_value"));
754       Node* const_value_n = graph->AddNode(const_node, &status);
755       TF_RETURN_IF_ERROR(status);
756       TF_RETURN_IF_ERROR(shape_refiner.AddNode(const_value_n));
757       graph_op_inputs.push_back(FunctionArgument{
758           const_value_n, tensorflow::NodeDefBuilder::NodeOut{
759                              const_value_n->name(), i, dtype}});
760     }
761   }
762 
763   tensorflow::NodeDef op_node_def;
764   const FunctionDef* function_def = doperation.function_def;
765   if (function_def) {
766     AttrValue func_attr;
767     func_attr.mutable_func()->set_name(doperation.name);
768     std::vector<tensorflow::NodeDefBuilder::NodeOut> func_inputs;
769     std::vector<tensorflow::DataType> inputs_types;
770     for (const auto& in : graph_op_inputs) {
771       func_inputs.emplace_back(in.output);
772       inputs_types.emplace_back(in.output.data_type);
773     }
774 
775     std::vector<tensorflow::DataType> output_types;
776     for (const auto& out : function_def->signature().output_arg())
777       output_types.emplace_back(out.type());
778 
779     TF_RETURN_IF_ERROR(
780         NodeDefBuilder("eager_operation", "StatefulPartitionedCall")
781             .Attr("Tin", inputs_types)
782             .Attr("Tout", output_types)
783             .Attr("f", func_attr)
784             .Input(func_inputs)
785             .Finalize(&op_node_def, true));
786   } else {
787     op_node_def.set_op(doperation.name);
788     op_node_def.set_name("eager_operation");
789   }
790 
791   op_node_def.mutable_attr()->insert(attributes.attr().begin(),
792                                      attributes.attr().end());
793 
794   tensorflow::Node* op_node = graph->AddNode(op_node_def, &status);
795   TF_RETURN_IF_ERROR(status);
796 
797   for (int i = 0; i < graph_op_inputs.size(); ++i) {
798     graph->AddEdge(graph_op_inputs[i].node, 0, op_node, i);
799   }
800   TF_RETURN_IF_ERROR(shape_refiner.AddNode(op_node));
801 
802   output_layouts->clear();
803   output_layouts->reserve(op_node->num_outputs());
804   global_output_shapes->reserve(op_node->num_outputs());
805   for (int output_index = 0; output_index < op_node->num_outputs();
806        ++output_index) {
807     tensorflow::NodeDefBuilder builder(absl::StrCat("op_output_", output_index),
808                                        "_Retval");
809     tensorflow::NodeDef ret_node_def;
810     tensorflow::DataType output_type = op_node->output_type(output_index);
811 
812     TF_RETURN_IF_ERROR(builder.Attr("T", output_type)
813                            .Attr("index", output_index)
814                            .Input("eager_operation", output_index, output_type)
815                            .Finalize(&ret_node_def, /*consume=*/true));
816     tensorflow::Node* ret_node = graph->AddNode(ret_node_def, &status);
817     TF_RETURN_IF_ERROR(status);
818     graph->AddEdge(op_node, output_index, ret_node, 0);
819     graph->AddControlEdge(ret_node, graph->sink_node());
820 
821     shape_inference::InferenceContext* inference_context =
822         shape_refiner.GetContext(op_node);
823     shape_inference::ShapeHandle output_shape_handle =
824         inference_context->output(output_index);
825     TensorShapeProto output_shape_proto;
826     inference_context->ShapeHandleToProto(output_shape_handle,
827                                           &output_shape_proto);
828     PartialTensorShape global_output_shape(output_shape_proto);
829     VLOG(3) << "Inferred shape for operation '" << doperation.name
830             << "':" << global_output_shape.DebugString();
831     global_output_shapes->push_back(global_output_shape);
832 
833     const Layout* layout = nullptr;
834     if (default_layout.has_value() && output_index == 0) {
835       // Record the user's requested output layout. The scope currently only
836       // covers the first output of an op.
837       layout = &default_layout.value();
838       ret_node->AddAttr(kDefaultLayoutAttr, layout->ToString());
839     }
840     output_layouts->push_back(layout);
841   }
842   return OkStatus();
843 }
844 
845 // Returns the set of functions to run to execute the DTensor computation.
846 StatusOr<ExecutionFunctions> IdentifyAllFunctionsToExecute(
847     const tensorflow::Graph& graph,
848     const std::vector<PartialTensorShape>& global_output_shapes) {
849   ExecutionFunctions execution_functions;
850   execution_functions.function_list = std::vector<TranslatedFunction>();
851   for (Node* node : graph.nodes()) {
852     if (node->op_def().name() != "StatefulPartitionedCall") continue;
853     // Extract mesh to execute the function.
854     std::string serialized_mesh;
855     TF_RETURN_IF_ERROR(GetNodeAttr(node->attrs(), kMeshAttr, &serialized_mesh));
856     Mesh mesh;
857     TF_ASSIGN_OR_RETURN(mesh, Mesh::FromString(serialized_mesh));
858 
859     TranslatedFunction function;
860     function.function_mesh = std::move(mesh);
861     function.node_to_execute = node;
862 
863     // Identify input arg information.
864     TF_RETURN_IF_ERROR(
865         ParseResourceArgumentLayouts(*node, &function.resource_input_layouts));
866 
867     TF_RETURN_IF_ERROR(
868         ParseShapeInputLayouts(*node, &function.shape_output_metadata));
869 
870     function.input_index_map.resize(node->num_inputs());
871     // Identify the mapping between local mesh function input index and
872     // global input index.
873     for (int in_index = 0; in_index < node->num_inputs(); ++in_index) {
874       Node* input_node;
875 
876       TF_RETURN_IF_ERROR(node->input_node(in_index, &input_node));
877       if (!input_node->IsArg())
878         return errors::InvalidArgument(
879             "Input node to mesh computation must be arg node.");
880 
881       int global_index;
882       TF_RETURN_IF_ERROR(
883           GetNodeAttr(input_node->attrs(), "index", &global_index));
884       function.input_index_map[in_index] = global_index;
885     }
886 
887     // Identify output mappings and layouts for each output.
888     std::map<int, const Edge*> output_edges;
889     for (const Edge* out_edge : node->out_edges()) {
890       if (out_edge->IsControlEdge()) continue;
891 
892       const Node* retval_or_identity_node = out_edge->dst();
893       while (retval_or_identity_node->IsIdentity()) {
894         retval_or_identity_node =
895             *(retval_or_identity_node->out_nodes().begin());
896       }
897 
898       TF_RET_CHECK(retval_or_identity_node->IsRetval());
899       int global_index;
900       TF_RETURN_IF_ERROR(GetNodeAttr(retval_or_identity_node->attrs(), "index",
901                                      &global_index));
902       output_edges[global_index] = out_edge;
903     }
904 
905     for (auto it = output_edges.begin(); it != output_edges.end(); it++) {
906       const int global_index = it->first;
907       function.output_index_map.emplace_back(global_index);
908 
909       const Edge* retval_edge = it->second;
910       const int output_index = retval_edge->src_output();
911 
912       // Add output layout and shape information.
913       TF_ASSIGN_OR_RETURN(
914           const Layout output_layout,
915           GetLayoutThroughIdentityOps(retval_edge->src(), output_index));
916 
917       function.output_layouts.emplace_back(output_layout);
918       function.local_output_shapes.emplace_back(
919           output_layout.LocalShapeFromGlobalShape(
920               global_output_shapes[global_index]));
921     }
922 
923     execution_functions.function_list.emplace_back(std::move(function));
924   }
925 
926   if (execution_functions.function_list.empty()) {
927     return errors::InvalidArgument(
928         "MLIR transformed graph does not have any functions to execute for "
929         "mesh.");
930   }
931 
932   return execution_functions;
933 }
934 
935 // For functions with control outputs, add identity nodes between
936 // StatefulPartitionedCall and _Retvals, in order to preserve control output
937 // dependencies after StatefulPartitionedCall is inlined at runtime.
938 // Consider calling this in PrepareGraphForMlir, once the identity nodes won't
939 // be dropped during MLIR lowering.
940 // TODO(b/171265131): fix the underlying issue to avoid inserting identity
941 // nodes.
942 Status MaybeInsertIdentityNodes(const FunctionDef* function_def, Graph* graph) {
943   if (function_def == nullptr || function_def->control_ret().empty()) {
944     return OkStatus();
945   }
946   tensorflow::Status status;
947   for (Node* n : graph->nodes()) {
948     if (!n->IsRetval()) {
949       continue;
950     }
951     const Edge* edge;
952     TF_RETURN_IF_ERROR(n->input_edge(0, &edge));
953     int ret_index;
954     TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "index", &ret_index));
955     tensorflow::NodeDefBuilder identity_builder(
956         absl::StrCat("op_output_identity_", ret_index), "Identity");
957     tensorflow::NodeDef ret_identity_node_def;
958     tensorflow::DataType output_type = n->input_type(0);
959     TF_RETURN_IF_ERROR(
960         identity_builder.Attr("T", output_type)
961             .Input(edge->src()->name(), edge->src_output(), output_type)
962             .Finalize(&ret_identity_node_def, /*consume=*/true));
963     Node* ret_identity_node = graph->AddNode(ret_identity_node_def, &status);
964     TF_RETURN_IF_ERROR(status);
965     // Delete the edge between StatefulPartitionedCall and _Retval.
966     graph->RemoveEdge(edge);
967     // Add an edge between StatefulPartitionedCall and Identity.
968     graph->AddEdge(edge->src(), edge->src_output(), ret_identity_node, 0);
969     graph->AddControlEdge(edge->src(), ret_identity_node);
970     // Add an edge between Identity and _Retval.
971     graph->AddEdge(ret_identity_node, 0, n, 0);
972   }
973   return OkStatus();
974 }
975 
976 void AddDTensorFunctionAttr(FunctionDef& function_def) {
977   // Do not XLA-compile the function returned by the DTensor MLIR graph
978   // transformation, as it already returns a compiled graph.
979   AttrValue xla_must_compile_val;
980   xla_must_compile_val.set_b(false);
981   function_def.mutable_attr()->insert(
982       {"_XlaMustCompile", xla_must_compile_val});
983 
984   // Explicitly place function outputs on the default function device to avoid
985   // redundant host <-> device copies (Placer may place outputs on the host
986   // CPU).
987   AttrValue outputs_on_op_device;
988   outputs_on_op_device.set_b(true);
989   function_def.mutable_attr()->insert(
990       {"_OutputsOnOpDevice", outputs_on_op_device});
991 }
992 
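// Collects the parallel tensors that back TPU embedding resource inputs,
// grouped by their embedding table id.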
993 StatusOr<std::vector<parallel_device::ParallelTensor*>> PrepareEmbeddingInputs(
994     const std::vector<TensorWithLayout*>& inputs) {
995   absl::flat_hash_map<int64_t, std::vector<int64_t>> table_vars_input_index;
996   for (int64_t i = 0; i < inputs.size(); ++i) {
997     if (inputs[i]->tensor_type() != kResource) continue;
998 
999     const absl::optional<EmbeddingResourceAttrs>& resource_attrs =
1000         inputs[i]->attrs();
1001     if (resource_attrs.has_value()) {
1002       table_vars_input_index[resource_attrs->table_id].push_back(i);
1003     }
1004   }
1005 
1006   // Error out if no embedding resource input was found.
1007   if (table_vars_input_index.empty()) {
1008     return errors::Internal("No TPU embedding resource inputs were found.");
1009   }
1010   std::vector<parallel_device::ParallelTensor*> parallel_inputs;
1011   // Ensure the parallel inputs follow the numeric order of their table ids.
1012   for (const auto& [table_id, table_vars_indices] : table_vars_input_index) {
1013     for (const int64_t input_index : table_vars_indices) {
1014       parallel_inputs.push_back(inputs[input_index]->tensor());
1015     }
1016   }
1017   return parallel_inputs;
1018 }
1019 
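// Scans `graph` for _Arg nodes annotated with a TPU embedding table id,
// records the embedding attributes on the corresponding dense inputs, and
// returns a map from table id to the matching _Arg nodes.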
1020 StatusOr<std::map<int64_t, std::vector<Node*>>> GetTPUEmbeddingInputNodes(
1021     TF_Status* s, const Graph& graph,
1022     const std::vector<TensorWithLayout*>& inputs) {
1023   // After the graph is lowered, the sparse tensors live at the end of the
1024   // argument list, so only the dense dtensor inputs are processed here to
1025   // keep the indexing correct.
1026   std::vector<TensorWithLayout*> non_sparse_inputs;
1027   non_sparse_inputs.reserve(inputs.size());
1028   for (TensorWithLayout* input : inputs) {
1029     if (input->tensor_type() != TensorType::kSparse) {
1030       non_sparse_inputs.push_back(input);
1031     }
1032   }
1033   std::map<int64_t, std::vector<Node*>> table_id_node_map;
1034   for (Node* node : graph.nodes()) {
1035     if (!node->IsArg()) continue;
1036 
1037     const int64_t& arg_id = node->attrs().Find("index")->i();
1038     const AttrValue* embedding_attr =
1039         node->attrs().Find("_tpu_embedding_table_id");
1040 
1041     if (embedding_attr == nullptr) continue;
1042     EmbeddingResourceAttrs embedding_input_attrs;
1043 
1044     // Add embedding table id.
1045     const int64_t table_id = embedding_attr->i();
1046     embedding_input_attrs.table_id = table_id;
1047 
1048     // Add embedding slot id if there is one.
1049     const AttrValue* embedding_slot_attr =
1050         node->attrs().Find("_tpu_embedding_slot_id");
1051     if (embedding_slot_attr != nullptr) {
1052       const int64_t slot_id = embedding_slot_attr->i();
1053       embedding_input_attrs.slot_id = slot_id;
1054     }
1055 
1056     table_id_node_map[table_id].push_back(node);
1057 
1058     // Arg input offset due to device id.
1059     if (non_sparse_inputs[arg_id - 1]->attrs().has_value()) continue;
1060     non_sparse_inputs[arg_id - 1]->UpdateAttrs(embedding_input_attrs, s);
1061     if (!s->status.ok()) {
1062       return errors::Internal(
1063           "Failed to set embedding resource attrs. \n Got error: ",
1064           s->status.error_message());
1065     }
1066   }
1067   return table_id_node_map;
1068 }
1069 
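// Verifies that all embedding resource inputs live on the same mesh and
// returns that mesh serialized as a string (empty if there are none).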
1070 StatusOr<std::string> ValidateResourceMeshConsistency(
1071     const std::vector<TensorWithLayout*>& inputs) {
1072   std::string mesh_str;
1073   for (TensorWithLayout* inp : inputs) {
1074     if ((inp->tensor_type() != kResource) || !inp->attrs().has_value())
1075       continue;
1076     const std::string& input_mesh_str = inp->layout().mesh().ToString();
1077     if (mesh_str.empty()) {
1078       mesh_str = input_mesh_str;
1079     } else if (mesh_str != input_mesh_str) {
1080       return errors::Internal(absl::StrCat(
1081           "All embedding resource inputs must be on the same mesh, but got: ",
1082           mesh_str, " != ", input_mesh_str));
1083     }
1084   }
1085   VLOG(1) << "Resource input mesh is : " << mesh_str;
1086   return mesh_str;
1087 }
1088 
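// Adds a StatefulPartitionedCall node that invokes the embedding checkpoint
// function (load or retrieve), wiring every embedding table resource _Arg
// node into it as an input.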
1089 Status InsertFunctionForTPUEmbeddingCheckpoint(
1090     TF_Status* status, Graph* graph,
1091     const std::vector<TensorWithLayout*>& inputs,
1092     const std::string& checkpoint_fn_name) {
1093   if (checkpoint_fn_name != kLoadEmbeddingFn &&
1094       checkpoint_fn_name != kRetrieveEmbeddingFn) {
1095     return errors::InvalidArgument(absl::StrCat(
1096         "Found wrong function name: ", checkpoint_fn_name,
1097         " \n expects : ", kLoadEmbeddingFn, " or ", kRetrieveEmbeddingFn));
1098   }
1099 
1100   StatusOr<std::map<int64_t, std::vector<Node*>>> table_id_node_map =
1101       GetTPUEmbeddingInputNodes(status, *graph, inputs);
1102   if (!table_id_node_map.ok()) {
1103     return errors::Internal(table_id_node_map.status().error_message());
1104   }
1105 
1106   StatusOr<std::string> mesh_str = ValidateResourceMeshConsistency(inputs);
1107 
1108   const int64_t& num_tables = table_id_node_map->size();
1109   NodeDef func_node_def;
1110   std::vector<NodeDefBuilder::NodeOut> func_inputs;
1111   std::vector<DataType> input_types, output_types;
1112 
1113   func_inputs.reserve(num_tables);
1114   input_types.reserve(num_tables);
1115 
1116   for (int i = 0; i < num_tables; ++i) {
1117     auto node_vec_ptr = table_id_node_map->find(i);
1118     if (node_vec_ptr == table_id_node_map->end()) {
1119       return errors::Internal(
1120           absl::StrCat("Embedding table id ", i, " is not found."));
1121     }
1122     for (const Node* n : node_vec_ptr->second) {
1123       const std::string& node_name = n->name();
1124       func_inputs.push_back({node_name, i, DT_RESOURCE});
1125       input_types.push_back(DT_RESOURCE);
1126     }
1127   }
1128 
1129   AttrValue mesh_attr;
1130   *mesh_attr.mutable_s() = *mesh_str;
1131   NameAttrList func_attr;
1132   func_attr.set_name(checkpoint_fn_name);
1133   TF_RETURN_IF_ERROR(
1134       NodeDefBuilder(checkpoint_fn_name, "StatefulPartitionedCall")
1135           .Attr("Tin", input_types)
1136           .Attr("Tout", output_types)
1137           .Attr("f", func_attr)
1138           .Attr(kMeshAttr, mesh_attr)
1139           .Attr("config", mesh_attr)
1140           .Input(func_inputs)
1141           .Finalize(&func_node_def, true));
1142 
1143   TF_ASSIGN_OR_RETURN(Node * func_node, graph->AddNode(func_node_def));
1144   for (int i = 0; i < num_tables; ++i) {
1145     const std::vector<Node*>& node_vec = table_id_node_map->find(i)->second;
1146     for (int j = 0; j < node_vec.size(); ++j) {
1147       graph->AddEdge(node_vec[j], 0, func_node, j + i);
1148     }
1149   }
1150 
1151   return OkStatus();
1152 }
1153 
1154 }  // namespace dtensor
1155 }  // namespace tensorflow
1156