1 /* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/dtensor/mlir/expansions/save_restore_spmd_expander.h"
17 
18 #include <algorithm>
19 #include <memory>
20 #include <string>
21 #include <utility>
22 
23 #include "absl/strings/str_cat.h"
24 #include "absl/strings/str_join.h"
25 #include "absl/strings/str_split.h"
26 #include "llvm/ADT/STLExtras.h"
27 #include "llvm/ADT/SetVector.h"
28 #include "llvm/ADT/SmallVector.h"
29 #include "llvm/Support/Casting.h"
30 #include "llvm/Support/Debug.h"
31 #include "llvm/Support/FormatVariadic.h"
32 #include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
33 #include "mlir/IR/Builders.h"  // from @llvm-project
34 #include "mlir/IR/BuiltinAttributes.h"  // from @llvm-project
35 #include "mlir/IR/Matchers.h"  // from @llvm-project
36 #include "mlir/IR/Operation.h"  // from @llvm-project
37 #include "mlir/Support/LLVM.h"  // from @llvm-project
38 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_attributes.h"
39 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
40 #include "tensorflow/compiler/mlir/tensorflow/transforms/collection_ops_util.h"
41 #include "tensorflow/core/platform/errors.h"
42 #include "tensorflow/core/platform/path.h"
43 #include "tensorflow/dtensor/cc/dstatus.h"
44 #include "tensorflow/dtensor/cc/dtensor_utils.h"
45 #include "tensorflow/dtensor/cc/save_restore_util.h"
46 #include "tensorflow/dtensor/cc/tensor_layout.h"
47 #include "tensorflow/dtensor/mlir/device_utils.h"
48 #include "tensorflow/dtensor/mlir/dtensor_send_recv.h"
49 #include "tensorflow/dtensor/mlir/ir/tf_dtensor.h"
50 #include "tensorflow/dtensor/mlir/layout_parsing.h"
51 #include "tensorflow/dtensor/mlir/op_utils.h"
52 #include "tensorflow/dtensor/mlir/shape_utils.h"
53 #include "tensorflow/dtensor/mlir/spmd_expander_common.h"
54 #include "tensorflow/dtensor/mlir/value_utils.h"
55 
56 namespace tensorflow {
57 namespace dtensor {
58 
59 namespace {
60 
61 // Given a string tensor `prefix` of shape [k], produces a new string tensor
62 // of shape [k*n], where n = number of devices in `mesh`, by appending
63 // device ids from [0, n) to `prefix`.
64 //
65 // For example:
66 //   before:
67 //     prefix = tf.Constant(["alice", "bob"])
68 //     mesh.num_devices() = 2
69 //   after:
70 //     result = tf.Constant(["alice_device_0", "bob_device_0", "alice_device_1",
71 //     "bob_device_1"])
72 //
73 // This is needed for DTensorCheckpointV2 tf.MergeV2Checkpoint SPMD expansion
74 // to generate all the candidate checkpoint prefix strings that were
75 // generated during tf.SaveV2 SPMD expansion.
76 mlir::Value GetAllCandidateCheckpointPrefixes(mlir::OpBuilder& builder,
77                                               mlir::Value prefix,
78                                               const Mesh& mesh) {
79   if (mesh.num_devices() == 0) return prefix;
80 
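  // Append the device-0 suffix to every entry of `prefix` (tf.Add performs
  // element-wise string concatenation); the loop below then appends a suffixed
  // copy of `prefix` for every remaining device along dimension 0.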
81   mlir::Value new_prefix =
82       builder
83           .create<mlir::TF::AddOp>(
84               prefix.getLoc(),
85               prefix.getType().dyn_cast<mlir::RankedTensorType>(), prefix,
86               StringConst(builder, prefix.getLoc(),
87                           llvm::SmallVector<llvm::StringRef>(
88                               {DeviceSuffix(0, mesh.num_devices())})))
89           .z();
90 
91   for (int64_t device_id = 1; device_id < mesh.num_devices(); ++device_id) {
92     mlir::Value prefix_plus_dtensor_suffix =
93         builder
94             .create<mlir::TF::AddOp>(
95                 prefix.getLoc(),
96                 prefix.getType().dyn_cast<mlir::RankedTensorType>(), prefix,
97                 StringConst(builder, prefix.getLoc(),
98                             llvm::SmallVector<llvm::StringRef>(
99                                 {DeviceSuffix(device_id, mesh.num_devices())})))
100             .z();
101 
102     new_prefix = builder
103                      .create<mlir::TF::ConcatOp>(
104                          prefix.getLoc(),
105                          /*output=*/prefix.getType(),
106                          /*concat_dim=*/
107                          IntConst(builder, prefix.getLoc(), /*values=*/{0}),
108                          llvm::SmallVector<mlir::Value, 4>{
109                              new_prefix, prefix_plus_dtensor_suffix})
110                      .getResult();
111   }
112   return new_prefix;
113 }
114 
115 // Maps a device_id to a 0-based switch-case branch index.
116 //
117 // For Save/Restore ops, constructing a switch-case on all global devices is not
118 // going to scale to larger slices as the function grows with the number of
119 // devices. Instead, we only need to look at devices that are local to the
120 // current host and generate SPMD for those. This keeps the SPMD expansion
121 // O(variables), since the number of local devices is constant for all device types.
122 //
123 // The challenge is that the switch-case op branch index is 0-based, meaning
124 // that we cannot use the device_id directly as we would in a switch over the
125 // global devices. To deal with that, this function maps a local device id on
126 // the host onto a 0-based index, by constructing a 1D tensor with all local
127 // device ids and using the position within that tensor as the branch index.
128 //
129 // A concrete example would be:
130 //
131 // local_device_ids = [1, 2, 4, 5, 6] -- We shouldn't assume continuity in
132 // device_ids.
133 //
134 // switching device_id = [4]
135 //
136 // branch_index = idx_of(local_device_ids) = 2
137 //
138 // The tf op equivalent would be:
139 // tf.reshape(tf.where(tf.equal(local_device_ids, device_id)), ())
140 mlir::Value DeviceIdToLocalBranchIndex(
141     const mlir::Location& location,
142     const llvm::ArrayRef<int64_t>& local_device_ids, mlir::Value device_id,
143     mlir::OpBuilder& builder) {
144   mlir::Value local_device_id_tensors =
145       IntConst(builder, location,
146                llvm::SmallVector<int32_t>(local_device_ids.begin(),
147                                           local_device_ids.end()));
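  // The position of `device_id` within this rank-1 constant of host-local
  // device ids becomes the 0-based branch index (see the tf op equivalent in
  // the comment above).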
148   mlir::Value condition = builder.create<mlir::TF::EqualOp>(
149       location, local_device_id_tensors, device_id,
150       /*incompatible_shape_error=*/builder.getBoolAttr(true));
151   auto where_op = builder.create<mlir::TF::WhereOp>(
152       location, mlir::RankedTensorType::get({1, 1}, builder.getI64Type()),
153       condition);
154   // Cast to int32 as where_op returns an int64 array.
155   auto cast_op = builder.create<mlir::TF::CastOp>(
156       location, mlir::RankedTensorType::get({1, 1}, builder.getI32Type()),
157       where_op.getResult());
158 
159   // Reshape the output to an i32 scalar.
160   auto size_type = mlir::RankedTensorType::get({}, builder.getI32Type());
161   mlir::Value scalar_shape = mlir::TF::collection_ops_util::GetR1Const(
162       size_type.getShape(), builder, location);
163   auto branch_index_scalar = builder.create<mlir::TF::ReshapeOp>(
164       location, mlir::ArrayRef<mlir::Type>{size_type},
165       mlir::ArrayRef<mlir::Value>{cast_op.getResult(), scalar_shape},
166       mlir::ArrayRef<mlir::NamedAttribute>{});
167 
168   return branch_index_scalar.getResult();
169 }
170 
171 // Builds a switch case function that only conditionally runs save with its
172 // slice_specs on sharded tensors.
173 //
174 // Note that this would generate multiple prefixes for saving rather than the
175 // single one passed in from the original op.
176 // DTensor uses DTensorShardedPrefix to query the generated ones and uses those
177 // in MergeV2.
178 StatusOr<mlir::TF::CaseOp> ConditionalSave(
179     mlir::TF::SaveV2Op original_save, const Mesh& mesh,
180     const absl::flat_hash_map<
181         int64_t, absl::flat_hash_map<int64_t, std::vector<std::string>>>&
182         saving_specs) {
183   mlir::ModuleOp module = original_save->getParentOfType<mlir::ModuleOp>();
184   if (!module)
185     return errors::Internal("SaveV2 op isn't enclosed inside a mlir::ModuleOp");
186 
187   mlir::SymbolTable symbol_table(module);
188 
189   mlir::OpBuilder builder(original_save);
190   const auto& location = original_save.getLoc();
191 
192   llvm::SmallVector<mlir::func::FuncOp, 8> branch_funs;
193 
194   // Try to extract the prefix out as a constant and build the new shard
195   // prefixes based on it.
196   TF_ASSIGN_OR_RETURN(std::string prefix, ExtractConstScalarStringFromValue(
197                                               original_save.prefix()));
198 
199   // Best-effort extraction of shape_and_slices to verify that they are empty.
200   // If the extraction fails, just ignore those values and proceed as if they
201   // were empty.
202   llvm::SmallVector<std::string, 4> original_shape_and_slices;
203   const Status extraction_status = ExtractConstStringVectorFromValue(
204       original_save.shape_and_slices(), original_shape_and_slices);
205   if (extraction_status.ok()) {
206     for (const std::string& shape_and_slice : original_shape_and_slices) {
207       if (!shape_and_slice.empty())
208         return errors::InvalidArgument(
209             absl::StrCat("DTensor SaveV2 requires shape_and_slices() field to "
210                          "be empty for tensors, but get : ",
211                          shape_and_slice));
212     }
213   } else {
214     VLOG(2) << "Failed to extract and verify shape_and_slices() from "
215                "original SaveV2 op. SaveV2 SPMD would proceed as if "
216                "shape_and_slices are empty for all the tensors.";
217   }
218 
219   // Branch functions share a function type, whose inputs are simply all the
220   // inputs from the original SaveV2 and which has no outputs.
221   auto func_type = mlir::FunctionType::get(builder.getContext(),
222                                            original_save.getOperandTypes(),
223                                            /*results=*/{});
224   // Only generate save functions for devices that are local to the client.
225   // This means that we will run different functions on different clients,
226   // but that is fine as we're running on CPU for this.
227   for (int device_id : mesh.local_device_ids()) {
228     // If saving_specs doesn't contain the device_id, then the save is a
229     // no-op for that device_id.
230     const auto& it = saving_specs.find(device_id);
231     if (it == saving_specs.end()) {
232       // Builds a placeholder no-op function, which takes the exact same
233       // args as the original save op and returns nothing.
234       mlir::func::FuncOp no_op = mlir::func::FuncOp::create(
235           location,
236           llvm::formatv("{0}_no_op_on_device_{1}_{2}", OpName(original_save),
237                         device_id, OpHash(original_save))
238               .str(),
239           func_type, llvm::ArrayRef<mlir::NamedAttribute>{});
240       // Set function visibility to private to indicate that it is only used in
241       // this module.
242       no_op.setVisibility(mlir::SymbolTable::Visibility::Private);
243       symbol_table.insert(no_op);
244 
245       mlir::Block* fn_block = no_op.addEntryBlock();
246       mlir::OpBuilder fn_builder = mlir::OpBuilder::atBlockBegin(fn_block);
247       fn_builder.create<mlir::TF::NoOp>(location);
248       fn_builder.create<mlir::func::ReturnOp>(location);
249 
250       branch_funs.push_back(no_op);
251     } else {
252       const absl::flat_hash_map<int64_t, std::vector<std::string>>&
253           per_device_specs = it->second;
254 
255       // Build the new SaveV2 that contains the proper SliceSpecs on this device.
256       // tensor_names and slice_specs are concatenated into a 1-D string tensor.
257       mlir::func::FuncOp new_save = mlir::func::FuncOp::create(
258           location,
259           llvm::formatv("{0}_save_op_on_device_{1}_{2}", OpName(original_save),
260                         device_id, OpHash(original_save))
261               .str(),
262           func_type, llvm::ArrayRef<mlir::NamedAttribute>{});
263       // Set function visibility to private to indicate that it is only used in
264       // this module.
265       new_save.setVisibility(mlir::SymbolTable::Visibility::Private);
266       symbol_table.insert(new_save);
267 
268       mlir::Block* fn_block = new_save.addEntryBlock();
269       mlir::OpBuilder fn_builder = mlir::OpBuilder::atBlockBegin(fn_block);
270 
271       mlir::Value tensor_names = new_save.getArgument(1);
272       // It is currently unsupported if the user passes in shape_and_slices.
273       // TODO(hthu): Implement this.
274       // mlir::Value shape_and_slices = new_save.getArgument(2);
275 
276       // First run a split op on tensor_names so that we can use the proper
277       // split output (one of the tensor names) to reconstruct the tensor_names
278       // field in the new SaveV2 op.
279       TF_ASSIGN_OR_RETURN(
280           llvm::ArrayRef<int64_t> tensor_names_shape,
281           GetGlobalShapeOfValueFromDTensorLayout(original_save.tensor_names()));
282       if (tensor_names_shape.size() != 1)
283         return errors::Internal(
284             llvm::formatv("SaveV2 op got `tensor_names` with rank {0}) but "
285                           "expects rank to be 1.",
286                           tensor_names_shape.size())
287                 .str());
288       mlir::TF::SplitOp name_splits;
289       TF_RETURN_IF_ERROR(CreateSplitOp(/*num_split=*/tensor_names_shape[0],
290                                        /*split_dimension=*/0, location,
291                                        /*src_input=*/tensor_names, &fn_builder,
292                                        &name_splits));
293 
294       // Builds the per-device saving spec, which takes care of the tensor_name
295       // uniqueness requirement. Each save op should use new_tensor_indices and
296       // new_specs to map the corresponding saving tensor and its slice spec.
297       SaveOpSpecs specs = BuildPerDeviceSave(per_device_specs, device_id,
298                                              prefix, mesh.num_devices());
299       const std::vector<std::vector<int>>& new_tensor_indices =
300           specs.tensor_indices;
301       const std::vector<std::vector<std::string>>& new_specs =
302           specs.shape_and_slice_spec;
303 
304       // Prepare corresponding SaveOp arguments.
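      // A SaveV2 op requires unique tensor names, so when a tensor needs
      // multiple slice specs BuildPerDeviceSave splits the work across several
      // save ops; save_op_index below selects one of them.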
305       for (int save_op_index = 0; save_op_index < new_tensor_indices.size();
306            ++save_op_index) {
307         llvm::SmallVector<mlir::Value, 4> new_tensor_names;
308         llvm::SmallVector<std::string, 4> new_shape_and_slices;
309         llvm::SmallVector<mlir::Value, 4> new_tensors;
310 
311         // Per_device_specs records the index of the tensor_names from the
312         // original save, and all slice_specs needed to save that tensor.
313         // The corresponding saving tensor can be found in the original save op
314         // by adding 3 to the index (as 0, 1, and 2 are fixed inputs for prefix,
315         // tensor_names and shape_and_slices).
316         for (int i = 0; i < new_tensor_indices[save_op_index].size(); ++i) {
317           int tensor_name_index = new_tensor_indices[save_op_index][i];
318           int tensor_index = 3 + tensor_name_index;
319           new_tensor_names.push_back(name_splits.getResult(tensor_name_index));
320           new_shape_and_slices.push_back(new_specs[save_op_index][i]);
321           new_tensors.push_back(new_save.getArgument(tensor_index));
322         }
323         // Build the new SaveV2 op.
324         mlir::Value tensor_names = new_tensor_names[0];
325         if (new_tensor_names.size() > 1) {
326           // For tensor_names that have more than one entry, we concat the list
327           // of names into a 1-D vector.
328           tensor_names =
329               fn_builder
330                   .create<mlir::TF::ConcatOp>(
331                       location,
332                       /*output=*/original_save.tensor_names().getType(),
333                       /*concat_dim=*/
334                       IntConst(fn_builder, location, /*values=*/{0}),
335                       new_tensor_names)
336                   .getResult();
337         }
338 
339         // Builds a unique prefix for this device and this save_op.
340         std::string new_prefix =
341             prefix +
342             llvm::formatv("_device_{0}_save_op_{1}", device_id, save_op_index)
343                 .str();
344 
345         fn_builder.create<mlir::TF::SaveV2Op>(
346             location,
347             StringConst(fn_builder, location,
348                         {specs.new_prefixes[save_op_index]}),
349             /*tensor_name=*/tensor_names,
350             /*shape_and_slices=*/
351             StringConst(
352                 fn_builder, location,
353                 llvm::SmallVector<llvm::StringRef>(new_shape_and_slices.begin(),
354                                                    new_shape_and_slices.end())),
355             new_tensors);
356       }
357       branch_funs.push_back(new_save);
358       fn_builder.create<mlir::func::ReturnOp>(location);
359     }
360   }
361 
362   llvm::SmallVector<mlir::Attribute, 4> symbols;
363   for (auto& func : branch_funs)
364     symbols.push_back(mlir::SymbolRefAttr::get(func));
365 
366   TF_ASSIGN_OR_RETURN(mlir::Value device_id, DeviceId(original_save));
367   llvm::SmallVector<int64_t> local_device_ids(mesh.local_device_ids().begin(),
368                                               mesh.local_device_ids().end());
369   mlir::Value branch_index = DeviceIdToLocalBranchIndex(
370       location, local_device_ids, device_id, builder);
371 
372   auto case_op = builder.create<mlir::TF::CaseOp>(
373       location,
374       // SaveV2 doesn't return a value.
375       /*output=*/llvm::ArrayRef<mlir::Type>{},
376       /*branch_index=*/branch_index,
377       /*input=*/original_save.getOperands(),
378       /*branches=*/builder.getArrayAttr(symbols),
379       /*is_stateless=*/builder.getBoolAttr(false));
380 
381   return case_op;
382 }
383 
384 StatusOr<mlir::Operation*> ExpandSaveV2Op(mlir::Operation* op) {
385   if (!llvm::isa<mlir::TF::SaveV2Op>(op)) {
386     return errors::InvalidArgument(
387         llvm::formatv("Expecting SaveV2Op but got {0}", OpName(op)).str());
388   }
389 
390   TF_ASSIGN_OR_RETURN(Mesh mesh, ExtractDeviceMeshEnclosingCluster(op));
391   auto save_v2 = mlir::cast<mlir::TF::SaveV2Op>(op);
392 
393   mlir::OpBuilder builder(save_v2);
394 
395   absl::flat_hash_map<int64_t, std::pair<std::vector<int64_t>, Layout>>
396       tensor_shape_layout_map;
397   std::vector<SavingTensorMetadata> metadata;
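  // Collect the global shape and layout of every tensor being saved; these
  // drive the per-device saving specs computed by BuildSavingSpec below.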
398   for (const auto& it : llvm::enumerate(save_v2.tensors())) {
399     mlir::Value tensor = it.value();
400     // We use index to select the tensor names and shape_and_slices from the
401     // inputs. This is generic regardless of whether the inputs are constants or
402     // just arguments.
403     int index = it.index();
404     TF_ASSIGN_OR_RETURN(absl::optional<Layout> layout,
405                         ExtractLayoutFromOperand(tensor));
406     if (!layout)
407       return errors::InvalidArgument(
408           "layout is required when saving a DTensor but find no layout "
409           "attached");
410 
411     TF_ASSIGN_OR_RETURN(llvm::ArrayRef<int64_t> tensor_shape,
412                         GetGlobalShapeOfValueFromDTensorLayout(it.value()));
413 
414     metadata.push_back(SavingTensorMetadata(
415         index, std::vector<int64_t>(tensor_shape.begin(), tensor_shape.end()),
416         *layout));
417   }
418   TF_ASSIGN_OR_RETURN(auto saving_specs, BuildSavingSpec(metadata));
419 
420   // Now we have a complete map on device_id and its saving tensors and specs.
421   // Build a switch case conditioned on device_id and do saves properly.
422   TF_ASSIGN_OR_RETURN(mlir::TF::CaseOp case_op,
423                       ConditionalSave(save_v2, mesh, saving_specs));
424 
425   save_v2->replaceAllUsesWith(case_op);
426   save_v2->erase();
427 
428   return case_op.getOperation();
429 }
430 
431 // SPMD Expander for MergeV2.
432 //
433 // The op is expected to have one and only one prefix input, which is
434 // used as a key to query all the saved shard prefixes generated in SaveV2 op
435 // SPMD expansion.
436 //
437 // The expanded MergeV2 contains all the generated shard prefixes, and only
438 // runs on device 0.
439 StatusOr<mlir::Operation*> ExpandMergeV2Op(mlir::Operation* op) {
440   mlir::TF::MergeV2CheckpointsOp merge_v2 =
441       mlir::dyn_cast<mlir::TF::MergeV2CheckpointsOp>(op);
442   if (!merge_v2) {
443     return errors::InvalidArgument(
444         llvm::formatv("Expecting MergeV2CheckpointsOp but got {0}", OpName(op))
445             .str());
446   }
447 
448   // Build an if op that only runs MergeV2 on device 0. Note that the if
449   // condition evaluates to false when device_id == 0, so the `then` branch is
450   // a no-op while the `else` branch is the real MergeV2 op that runs on
451   // device 0.
452   auto module = merge_v2->getParentOfType<mlir::ModuleOp>();
453   mlir::SymbolTable symbol_table(module);
454   auto location = merge_v2.getLoc();
455   mlir::OpBuilder builder(merge_v2);
456 
457   auto func_type =
458       mlir::FunctionType::get(builder.getContext(), merge_v2.getOperandTypes(),
459                               llvm::ArrayRef<mlir::Type>{});
460   // Build then_func, the branch taken when device_id != 0, which only
461   // contains a single NoOp.
462   mlir::func::FuncOp then_func = mlir::func::FuncOp::create(
463       location,
464       llvm::formatv("{0}_then_func_{1}", OpName(merge_v2), OpHash(merge_v2))
465           .str(),
466       func_type, llvm::ArrayRef<mlir::NamedAttribute>{});
467   // Set function visibility to private to indicate that it is only used in
468   // this module.
469   then_func.setVisibility(mlir::SymbolTable::Visibility::Private);
470   mlir::Block* then_fn_block = then_func.addEntryBlock();
471   mlir::OpBuilder then_fn_builder =
472       mlir::OpBuilder::atBlockBegin(then_fn_block);
473   then_fn_builder.create<mlir::TF::NoOp>(location);
474   then_fn_builder.create<mlir::func::ReturnOp>(location);
475 
476   // Build else_func, the branch taken when device_id == 0.
477   // The else func is just the original MergeV2 itself.
478   mlir::func::FuncOp else_func = mlir::func::FuncOp::create(
479       location,
480       llvm::formatv("{0}_else_func_{1}", OpName(merge_v2), OpHash(merge_v2))
481           .str(),
482       func_type, llvm::ArrayRef<mlir::NamedAttribute>{});
483   // Set function visibility to private to indicate that it is only used in
484   // this module.
485   else_func.setVisibility(mlir::SymbolTable::Visibility::Private);
486 
487   mlir::Block* else_fn_block = else_func.addEntryBlock();
488   mlir::OpBuilder else_fn_builder =
489       mlir::OpBuilder::atBlockBegin(else_fn_block);
490   mlir::Value checkpoint_prefixes = else_fn_block->getArgument(0);
491 
492   bool allow_missing_files = false;
493 
494   // If DTensorCheckpointV2 is enabled, then each string in
495   // `checkpoint_prefixes` tensor is missing a "device_id_" suffix that we
496   // generated from SaveV2 SPMD Expansion. So, generate all the possible
497   // suffixes and use those as the `checkpoint_prefixes` argument.
498   if (DTensorCheckpointV2Enabled()) {
499     allow_missing_files = true;
500     TF_ASSIGN_OR_RETURN(Mesh mesh, ExtractDeviceMeshEnclosingCluster(op));
501     checkpoint_prefixes = GetAllCandidateCheckpointPrefixes(
502         else_fn_builder, checkpoint_prefixes, mesh);
503   }
504 
505   mlir::Value destination_prefixes = else_fn_block->getArgument(1);
506 
507   else_fn_builder.create<mlir::TF::MergeV2CheckpointsOp>(
508       location, checkpoint_prefixes, destination_prefixes,
509       /*delete_old_dirs=*/
510       else_fn_builder.getBoolAttr(merge_v2.delete_old_dirs()),
511       /*allow_missing_files=*/else_fn_builder.getBoolAttr(allow_missing_files));
512 
513   else_fn_builder.create<mlir::func::ReturnOp>(location);
514 
515   symbol_table.insert(then_func);
516   symbol_table.insert(else_func);
517 
518   TF_ASSIGN_OR_RETURN(mlir::Value device_id, DeviceId(merge_v2));
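  // `device_id` doubles as the If condition: any non-zero id takes the no-op
  // then-branch, while device 0 takes the else-branch that performs the
  // actual merge.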
519   auto if_op = builder.create<mlir::TF::IfOp>(
520       location, then_func.getFunctionType().getResults(), /*cond=*/device_id,
521       /*input=*/merge_v2.getOperands(),
522       /*then_branch=*/then_func.getSymName(),
523       /*else_branch=*/else_func.getSymName(), /*is_stateless=*/false);
524 
525   merge_v2->replaceAllUsesWith(if_op);
526   merge_v2.erase();
527   return if_op.getOperation();
528 }
529 
530 // SPMD Expander for RestoreV2 op.
531 //
532 // Both tf.RestoreV2 and DTensorRestoreV2 ops will be expanded the same way.
533 // That is, they will be updated to only restore the slice for the
534 // given device_id. For replicated tensors, that would be the full tensor slice.
535 // For sharded tensors, we compute their slices using device coordinates and tensor
536 // layout.
537 //
538 // `global_shapes` refers to the global shapes of the outputs of the op.
539 // `layouts` refers to the output layouts of the op.
540 StatusOr<mlir::Operation*> ExpandRestoreV2OpHelper(
541     mlir::Operation* op, std::vector<std::vector<int64_t>> global_shapes,
542     std::vector<Layout> layouts, std::vector<mlir::Type> output_types,
543     mlir::MutableOperandRange shapes_and_slices_mutable) {
544   TF_ASSIGN_OR_RETURN(Mesh mesh, ExtractDeviceMeshEnclosingCluster(op));
545 
546   // Prepare for building CaseOp.
547   mlir::ModuleOp module = op->template getParentOfType<mlir::ModuleOp>();
548   if (!module)
549     return errors::Internal(
550         "DTensorRestoreV2 op isn't enclosed inside a mlir::ModuleOp");
551 
552   mlir::SymbolTable symbol_table(module);
553 
554   mlir::OpBuilder builder(op);
555   const auto& location = op->getLoc();
556 
557   // Tracks case branch functions for each local_device_id.
558   llvm::SmallVector<mlir::func::FuncOp> branch_funcs(
559       mesh.local_device_ids().size());
560   // Stores the restore ops for each device_id in a function that is suitable
561   // for feeding into a CaseOp.
562   //
563   // Branch functions share the same function type as the original restore_v2.
564   const auto func_type =
565       mlir::FunctionType::get(builder.getContext(), op->getOperandTypes(),
566                               mlir::TypeRange(output_types));
567 
568   for (int local_device_idx = 0;
569        local_device_idx < mesh.local_device_ids().size(); ++local_device_idx) {
570     int device_id = mesh.local_device_ids()[local_device_idx];
571     TF_ASSIGN_OR_RETURN(const DeviceLocation& coords,
572                         mesh.device_location(device_id));
573 
574     llvm::SmallVector<std::string> new_shapes_and_slices(op->getNumResults());
575 
576     // For each tensor, build its restore shape_and_slice.
577     for (const auto& it : llvm::enumerate(llvm::zip(global_shapes, layouts))) {
578       std::vector<int64_t> global_shape = std::get<0>(it.value());
579       Layout layout = std::get<1>(it.value());
580       // A fully replicated tensor does not need a slice spec field, so we
581       // simply leave it as an empty string. Note that a non-DTensor restore will
582       // use the replicated layout from the SaveSpec.
583       if (layout.IsFullyReplicated()) {
584         new_shapes_and_slices[it.index()] = "";
585         continue;
586       }
587 
588       TF_ASSIGN_OR_RETURN(
589           std::vector<std::string> slice_specs,
590           SliceSpecOnDevice(layout, mesh, coords, global_shape));
591 
592       // Concat shape and slice specs
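      // For example, a global shape [2, 4] with per-dimension slice specs
      // {"0,1", "0,4"} would produce the string "2 4 0,1:0,4".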
593       new_shapes_and_slices[it.index()] =
594           llvm::formatv("{0} {1}", absl::StrJoin(global_shape, " "),
595                         absl::StrJoin(slice_specs, ":"))
596               .str();
597     }
598 
599     // Builds the restore op on device_id.
600     mlir::OpBuilder builder(op);
601     shapes_and_slices_mutable.assign(StringConst(
602         builder, op->getLoc(),
603         llvm::SmallVector<llvm::StringRef>(new_shapes_and_slices.begin(),
604                                            new_shapes_and_slices.end())));
605     mlir::func::FuncOp device_restore_fn = mlir::func::FuncOp::create(
606         location,
607         llvm::formatv("{0}_on_device_{1}_{2}", OpName(op), device_id,
608                       OpHash(op))
609             .str(),
610         func_type, llvm::ArrayRef<mlir::NamedAttribute>{});
611     // Set function visibility to private to indicate that it is only used in
612     // this module.
613     device_restore_fn.setVisibility(mlir::SymbolTable::Visibility::Private);
614     symbol_table.insert(device_restore_fn);
615 
616     mlir::Block* fn_block = device_restore_fn.addEntryBlock();
617     mlir::OpBuilder fn_builder = mlir::OpBuilder::atBlockBegin(fn_block);
618     mlir::Value prefix = device_restore_fn.getArgument(0);
619     mlir::Value tensor_names = device_restore_fn.getArgument(1);
620     // Constructs shapes and slices ourselves while reusing all other
621     // arguments.
622     auto new_restore_v2 = fn_builder.create<mlir::TF::RestoreV2Op>(
623         location, mlir::TypeRange(output_types), prefix, tensor_names,
624         StringConst(
625             fn_builder, location,
626             llvm::SmallVector<llvm::StringRef>(new_shapes_and_slices.begin(),
627                                                new_shapes_and_slices.end())));
628     fn_builder.create<mlir::func::ReturnOp>(location,
629                                             new_restore_v2.getResults());
630 
631     branch_funcs[local_device_idx] = device_restore_fn;
632   }
633 
634   // Builds the final case op.
635   llvm::SmallVector<mlir::Attribute, 4> symbols;
636   for (auto& func : branch_funcs)
637     symbols.push_back(mlir::SymbolRefAttr::get(func));
638 
639   TF_ASSIGN_OR_RETURN(mlir::Value device_id, DeviceId(op));
640   llvm::SmallVector<int64_t> local_device_ids(mesh.local_device_ids().begin(),
641                                               mesh.local_device_ids().end());
642   mlir::Value branch_index = DeviceIdToLocalBranchIndex(
643       location, local_device_ids, device_id, builder);
644 
645   auto case_op = builder.create<mlir::TF::CaseOp>(
646       location,
647       /*output=*/mlir::TypeRange(output_types),
648       /*branch_index=*/branch_index,
649       /*input=*/op->getOperands(),
650       /*branches=*/builder.getArrayAttr(symbols),
651       /*is_stateless=*/builder.getBoolAttr(false));
652 
653   op->replaceAllUsesWith(case_op);
654   op->erase();
655 
656   return case_op.getOperation();
657 }
658 
659 // DTensorRestoreV2 op has layouts and shapes as attributes of the op
660 // itself. We extract those attributes and call the helper expander.
661 StatusOr<mlir::Operation*> ExpandDTensorRestoreV2Op(mlir::Operation* op) {
662   mlir::TF::DTensorRestoreV2Op restore_v2 =
663       mlir::dyn_cast<mlir::TF::DTensorRestoreV2Op>(op);
664   if (!restore_v2) {
665     return errors::InvalidArgument(
666         llvm::formatv("Expecting DTensorRestoreV2Op but got {0}", OpName(op))
667             .str());
668   }
669 
670   mlir::ArrayAttr input_shapes_attr =
671       restore_v2->getAttrOfType<mlir::ArrayAttr>("input_shapes");
672   if (!input_shapes_attr) {
673     return errors::InvalidArgument(
674         "DTensorRestoreV2Op requires input_shapes attributes.");
675   }
676 
677   std::vector<std::vector<int64_t>> input_shapes;
678   input_shapes.reserve(input_shapes_attr.size());
679   for (const auto& shape : input_shapes_attr) {
680     mlir::TF::ShapeAttr shape_attr = shape.cast<mlir::TF::ShapeAttr>();
681     if (!shape_attr.hasStaticShape()) {
682       return errors::InvalidArgument(
683           llvm::formatv("DTensorRestoreV2Op requires statically known input "
684                         "shape, but got non-static shape: {0}.",
685                         shape_attr)
686               .str());
687     }
688     input_shapes.push_back(std::vector<int64_t>(shape_attr.getShape().begin(),
689                                                 shape_attr.getShape().end()));
690   }
691 
692   mlir::ArrayAttr input_layouts_attr = restore_v2.input_layouts();
693   if (!input_layouts_attr) {
694     return errors::InvalidArgument(
695         "DTensorRestoreV2Op requires input_layouts attributes.");
696   }
697   std::vector<Layout> input_layouts;
698   input_layouts.reserve(input_layouts_attr.size());
699   for (const auto& layout : input_layouts_attr.getValue().vec()) {
700     input_layouts.push_back(
701         Layout::FromString(layout.cast<mlir::StringAttr>().getValue().str())
702             .ValueOrDie());
703   }
704 
705   return ExpandRestoreV2OpHelper(
706       op, input_shapes, input_layouts,
707       std::vector<mlir::Type>(op->getResultTypes().begin(),
708                               op->getResultTypes().end()),
709       restore_v2.shape_and_slicesMutable());
710 }
711 
712 // Extract the layouts and shapes the normal way. By this time, we should
713 // have all the necessary DTensorLayout ops as the outputs of each op
714 // and the correct type shapes and dtypes as the outputs of the tf.RestoreV2
715 // op.
716 //
717 // Call the helper expander function with those shapes and layouts.
718 StatusOr<mlir::Operation*> ExpandRestoreV2Op(mlir::Operation* op) {
719   // Fetch the shape of each output.
720   std::vector<std::vector<int64_t>> global_shapes;
721   global_shapes.reserve(op->getNumResults());
722 
723   // This is subtle. The tf.train.Checkpoint.save_counter scalar variable
724   // may not yet be created by the time we call
725   // Checkpoint.restore.
726   //
727   // In this case, the tf.RestoreV2 is called eagerly, and thus there is no
728   // tf.AssignVariable op. This means that we cannot infer the shapes and layouts
729   // from the previous pass, CreateDTensorInferShapesForRestoreV2Op.
730   //
731   // But for save_counter, we know this is always replicated, and we can just
732   // return the op itself. For now, we will do it this hacky way, but eventually
733   // we need to generalize restoring variables that are not yet created.
734   //
735   // TODO(b/235373719) Generalize support for checkpoint restoration for
736   // variables that are not yet created.
737   if (op->getNumResults() == 1 && !GetShapeOfValue(op->getResult(0)).ok()) {
738     return op;
739   }
740 
741   for (auto result : op->getResults()) {
742     global_shapes.push_back(GetShapeOfValue(result).ValueOrDie());
743   }
744 
745   // Fetch the layout of each output.
746   TF_ASSIGN_OR_RETURN(std::vector<Layout> layouts,
747                       ExtractRequiredLayoutFromOp(op));
748 
749   // Calculate the new local type range needed for the new RestoreV2Op we will
750   // emit.
751   std::vector<mlir::Type> new_types;
752   new_types.reserve(op->getNumResults());
753 
754   for (const auto& it :
755        llvm::zip(op->getResultTypes(), global_shapes, layouts)) {
756     mlir::Type type = std::get<0>(it);
757     std::vector<int64_t>& shape = std::get<1>(it);
758     Layout& layout = std::get<2>(it);
759     new_types.push_back(mlir::RankedTensorType::get(
760         layout.LocalShapeFromGlobalShape(shape),
761         type.dyn_cast<mlir::RankedTensorType>().getElementType()));
762   }
763 
764   return ExpandRestoreV2OpHelper(
765       op, global_shapes, layouts, new_types,
766       mlir::dyn_cast<mlir::TF::RestoreV2Op>(op).shape_and_slicesMutable());
767 }
768 
769 }  // namespace
770 
771 StatusOr<mlir::Operation*> SaveRestoreSPMDExpander::ExpandOp(
772     mlir::Operation* op) {
773   if (llvm::isa<mlir::TF::SaveV2Op>(op)) {
774     return ExpandSaveV2Op(op);
775   }
776   if (llvm::isa<mlir::TF::MergeV2CheckpointsOp>(op)) {
777     return ExpandMergeV2Op(op);
778   }
779   if (llvm::isa<mlir::TF::DTensorRestoreV2Op>(op)) {
780     return ExpandDTensorRestoreV2Op(op);
781   }
782   if (llvm::isa<mlir::TF::RestoreV2Op>(op)) {
783     return ExpandRestoreV2Op(op);
784   }
785 
786   return errors::Unimplemented(
787       llvm::formatv("SPMD for op : {0} is not implemented ", OpName(op)).str());
788 }
789 
790 // Find all the resource tensor layouts attached to the AssignVariableOp
791 // that `restore_op` is restoring to.
792 StatusOr<llvm::SmallVector<Layout>> GetLayoutsFromAssignVariableOps(
793     mlir::ModuleOp module, mlir::TF::RestoreV2Op* restore_op) {
794   llvm::SmallVector<Layout> layouts(restore_op->getNumResults());
795 
796   for (auto result : restore_op->getResults()) {
797     // Find the AssignVariableOp connected to this output. There should only
798     // be at most one IdentityOp and one DTensorSend between this result
799     // and the AssignVariableOp.
800     for (auto consuming_op : result.getUsers()) {
801       // To get to the AssignVariableOp that consumes `result`, we expect
802       // an IdentityOp or a DTensorSend op on the path. So, skip past
803       // these ops first.
804       while (llvm::isa<mlir::TF::IdentityOp, mlir::TF::DTensorSend>(
805           consuming_op)) {
806         if (auto send_op =
807                 mlir::dyn_cast_or_null<mlir::TF::DTensorSend>(consuming_op)) {
808           TF_ASSIGN_OR_RETURN(
809               consuming_op, GetCorrespondingDTensorSendRecvOp(module, send_op));
810         }
811         auto next_op = consuming_op->getResult(0).getUsers();
812         if (next_op.empty()) {
813           return errors::Internal(
814               "Expected a result of an identity op to be consumed by another "
815               "op, but was empty during RestoreV2 Expansion.");
816         }
817         consuming_op = *next_op.begin();
818       }
819       // We skipped past ops like Identity and Send. There might be an
820       // AssignVariableOp now.
821       if (auto assign_op = llvm::dyn_cast_or_null<mlir::TF::AssignVariableOp>(
822               consuming_op)) {
823         TF_ASSIGN_OR_RETURN(auto layout, ExtractRequiredLayoutFromOperand(
824                                              assign_op.resource()));
825         layouts[result.getResultNumber()] = layout;
826         break;
827       }
828     }
829   }
830   return layouts;
831 }
832 
833 StatusOr<llvm::DenseMap<int, Layout>>
834 SaveRestoreSPMDExpander::ComputeLayoutForward(
835     mlir::Operation* op, const llvm::DenseMap<int, Layout>& input_layouts,
836     const llvm::DenseMap<int, Layout>& output_layouts) {
837   // Save ops don't have return values.
838   if (llvm::isa<mlir::TF::SaveV2Op, mlir::TF::MergeV2CheckpointsOp>(op)) {
839     return llvm::DenseMap<int, Layout>();
840   }
841   if (llvm::isa<mlir::TF::RestoreV2Op>(op)) {
842     // If there are already output layouts specified, this means that
843     // we are in the Late Variable Creation restoration. For this path,
844     // the output layout is already specified, through the default layout
845     // scope. So just return that layout.
846     if (!output_layouts.empty()) return output_layouts;
847 
848     mlir::ModuleOp module_op = op->getParentOfType<mlir::ModuleOp>();
849     mlir::TF::RestoreV2Op restore_v2 = mlir::cast<mlir::TF::RestoreV2Op>(op);
850     TF_ASSIGN_OR_RETURN(Mesh mesh, ExtractDeviceMeshEnclosingCluster(op));
851     if (!mesh.is_cpu_mesh()) {
852       return errors::InvalidArgument(
853           llvm::formatv(
854               "RestoreV2Op must run on a CPU mesh, but was running on: {0}",
855               mesh.ToString())
856               .str());
857     }
858     // Extract the layout of each resource tensor from the AssignVariableOp
859     // consuming each result. This layout sharding will be used as the
860     // output layout for each result tensor.
861     TF_ASSIGN_OR_RETURN(
862         auto layouts, GetLayoutsFromAssignVariableOps(module_op, &restore_v2));
863     if (layouts.size() != restore_v2.getNumResults()) {
864       return errors::Internal(llvm::formatv("Failed to get {0} output layouts "
865                                             "for RestoreV2Op. Got {1} layouts.",
866                                             restore_v2.getNumResults(),
867                                             layouts.size())
868                                   .str());
869     }
870     llvm::DenseMap<int, Layout> output_layouts(restore_v2.getNumResults());
871 
872     // Change the mesh of each layout to `mesh` since RestoreOp always runs on
873     // the CPU.
874     for (int i = 0; i < layouts.size(); ++i) {
875       Layout host_mesh_layout = layouts[i];
876       host_mesh_layout.set_mesh(mesh);
877       output_layouts[i] = host_mesh_layout;
878     }
879     return output_layouts;
880   }
881   if (llvm::isa<mlir::TF::DTensorRestoreV2Op>(op)) {
882     mlir::TF::DTensorRestoreV2Op restore_v2 =
883         mlir::cast<mlir::TF::DTensorRestoreV2Op>(op);
884     llvm::DenseMap<int, Layout> output_layouts(restore_v2.getNumResults());
885     // Output layouts are simply the layouts parsed from the input_layouts attribute.
886     for (const auto& it : llvm::enumerate(restore_v2.input_layouts())) {
887       TF_ASSIGN_OR_RETURN(
888           Layout layout,
889           Layout::FromString(
890               it.value().cast<mlir::StringAttr>().getValue().str()));
891       output_layouts[it.index()] = layout;
892     }
893     return output_layouts;
894   }
895   return errors::Unimplemented(
896       llvm::formatv("Layout propagation for op : {0} is not implemented",
897                     OpName(op))
898           .str());
899 }
900 
901 StatusOr<llvm::DenseMap<int, Layout>>
902 SaveRestoreSPMDExpander::ComputeLayoutBackward(
903     mlir::Operation* op, const llvm::DenseMap<int, Layout>& output_layouts) {
904   return llvm::DenseMap<int, Layout>();
905 }
906 
907 }  // namespace dtensor
908 }  // namespace tensorflow
909