/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/dtensor/mlir/expansions/save_restore_spmd_expander.h"

#include <algorithm>
#include <memory>
#include <string>
#include <utility>

#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "absl/strings/str_split.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/FormatVariadic.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
#include "mlir/IR/Builders.h"  // from @llvm-project
#include "mlir/IR/BuiltinAttributes.h"  // from @llvm-project
#include "mlir/IR/Matchers.h"  // from @llvm-project
#include "mlir/IR/Operation.h"  // from @llvm-project
#include "mlir/Support/LLVM.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_attributes.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/collection_ops_util.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/path.h"
#include "tensorflow/dtensor/cc/dstatus.h"
#include "tensorflow/dtensor/cc/dtensor_utils.h"
#include "tensorflow/dtensor/cc/save_restore_util.h"
#include "tensorflow/dtensor/cc/tensor_layout.h"
#include "tensorflow/dtensor/mlir/device_utils.h"
#include "tensorflow/dtensor/mlir/dtensor_send_recv.h"
#include "tensorflow/dtensor/mlir/ir/tf_dtensor.h"
#include "tensorflow/dtensor/mlir/layout_parsing.h"
#include "tensorflow/dtensor/mlir/op_utils.h"
#include "tensorflow/dtensor/mlir/shape_utils.h"
#include "tensorflow/dtensor/mlir/value_utils.h"

namespace tensorflow {
namespace dtensor {

namespace {

// Given a string tensor `prefix` of shape [k], produces a new string tensor
// of shape [k*n], where n = the number of devices in `mesh`, by appending
// device_id from [0, n) to `prefix`.
//
// For example:
// before:
//   prefix = tf.Constant(["alice", "bob"])
//   mesh.num_devices() = 2
// after:
//   result = tf.Constant(["alice_device_0", "bob_device_0", "alice_device_1",
//                         "bob_device_1"])
//
// This is needed for the DTensorCheckpointV2 tf.MergeV2Checkpoints SPMD
// expansion to generate all candidate checkpoint prefix strings that were
// generated during the tf.SaveV2 SPMD expansion.
mlir::Value GetAllCandidateCheckpointPrefixes(mlir::OpBuilder& builder,
                                              mlir::Value prefix,
                                              const Mesh& mesh) {
  if (mesh.num_devices() == 0) return prefix;

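  // Seed the result with `prefix` suffixed for device 0; the loop below
  // appends the remaining devices' suffixed prefixes along dimension 0.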
  mlir::Value new_prefix =
      builder
          .create<mlir::TF::AddOp>(
              prefix.getLoc(),
              prefix.getType().dyn_cast<mlir::RankedTensorType>(), prefix,
              StringConst(builder, prefix.getLoc(),
                          llvm::SmallVector<llvm::StringRef>(
                              {DeviceSuffix(0, mesh.num_devices())})))
          .z();

  for (int64_t device_id = 1; device_id < mesh.num_devices(); ++device_id) {
    mlir::Value prefix_plus_dtensor_suffix =
        builder
            .create<mlir::TF::AddOp>(
                prefix.getLoc(),
                prefix.getType().dyn_cast<mlir::RankedTensorType>(), prefix,
                StringConst(builder, prefix.getLoc(),
                            llvm::SmallVector<llvm::StringRef>(
                                {DeviceSuffix(device_id, mesh.num_devices())})))
            .z();

    new_prefix = builder
                     .create<mlir::TF::ConcatOp>(
                         prefix.getLoc(),
                         /*output=*/prefix.getType(),
                         /*concat_dim=*/
                         IntConst(builder, prefix.getLoc(), /*values=*/{0}),
                         llvm::SmallVector<mlir::Value, 4>{
                             new_prefix, prefix_plus_dtensor_suffix})
                     .getResult();
  }
  return new_prefix;
}

// Maps a device_id to a 0-based switch-case branch index.
//
// For Save/Restore ops, constructing a switch-case on all global devices is
// not going to scale to larger slices, as the function grows with the number
// of devices. Instead, we only need to look at devices that are local to the
// current host and generate SPMD for those. This allows the SPMD to become
// O(variables), since the local devices are constant for all device types.
//
// The challenge is that the switch-case op branch index is 0-based, meaning
// that we cannot use the device_id the same way as in a global-devices
// switch. To deal with that, we use this function to map the local_device_id
// on the host to a 0-based index, by constructing a 1-D tensor with all local
// device ids and using the index into that tensor as the branch index.
//
// A concrete example would be:
//
// local_device_ids = [1, 2, 4, 5, 6] -- We shouldn't assume continuity in
// device_ids.
//
// switching device_id = [4]
//
// branch_index = idx_of(local_device_ids) = 2
//
// The tf op equivalent would be:
//   tf.reshape(tf.where(tf.equal(local_device_ids, device_id)), ())
mlir::Value DeviceIdToLocalBranchIndex(
    const mlir::Location& location,
    const llvm::ArrayRef<int64_t>& local_device_ids, mlir::Value device_id,
    mlir::OpBuilder& builder) {
  mlir::Value local_device_id_tensors =
      IntConst(builder, location,
               llvm::SmallVector<int32_t>(local_device_ids.begin(),
                                          local_device_ids.end()));
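  // Compare each local device id against the running device_id; WhereOp then
  // yields the (single) matching position, which becomes the branch index.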
  mlir::Value condition = builder.create<mlir::TF::EqualOp>(
      location, local_device_id_tensors, device_id,
      /*incompatible_shape_error=*/builder.getBoolAttr(true));
  auto where_op = builder.create<mlir::TF::WhereOp>(
      location, mlir::RankedTensorType::get({1, 1}, builder.getI64Type()),
      condition);
  // Cast to int32 as where_op returns an int64 array.
  auto cast_op = builder.create<mlir::TF::CastOp>(
      location, mlir::RankedTensorType::get({1, 1}, builder.getI32Type()),
      where_op.getResult());

  // Reshape the output to an i32 scalar.
  auto size_type = mlir::RankedTensorType::get({}, builder.getI32Type());
  mlir::Value scalar_shape = mlir::TF::collection_ops_util::GetR1Const(
      size_type.getShape(), builder, location);
  auto branch_index_scalar = builder.create<mlir::TF::ReshapeOp>(
      location, mlir::ArrayRef<mlir::Type>{size_type},
      mlir::ArrayRef<mlir::Value>{cast_op.getResult(), scalar_shape},
      mlir::ArrayRef<mlir::NamedAttribute>{});

  return branch_index_scalar.getResult();
}

// Builds a switch-case function that only conditionally runs save with its
// slice_specs on sharded tensors.
//
// Note that this generates multiple prefixes for saving rather than the
// single one passed in from the original op.
// DTensor uses DTensorShardedPrefix to query the generated ones and uses
// those in MergeV2.
StatusOr<mlir::TF::CaseOp> ConditionalSave(
    mlir::TF::SaveV2Op original_save, const Mesh& mesh,
    const absl::flat_hash_map<
        int64_t, absl::flat_hash_map<int64_t, std::vector<std::string>>>&
        saving_specs) {
  mlir::ModuleOp module = original_save->getParentOfType<mlir::ModuleOp>();
  if (!module)
    return errors::Internal("SaveV2 op isn't enclosed inside a mlir::ModuleOp");

  mlir::SymbolTable symbol_table(module);

  mlir::OpBuilder builder(original_save);
  const auto& location = original_save.getLoc();

  llvm::SmallVector<mlir::func::FuncOp, 8> branch_funs;

  // Try to extract the prefix as a constant and build the new shard prefixes
  // based on it.
  TF_ASSIGN_OR_RETURN(std::string prefix, ExtractConstScalarStringFromValue(
                                              original_save.prefix()));

  // Best-effort extraction of shape_and_slices to verify they are empty. If
  // the extraction fails, just ignore those values and proceed as if they
  // were empty.
  llvm::SmallVector<std::string, 4> original_shape_and_slices;
  const Status extraction_status = ExtractConstStringVectorFromValue(
      original_save.shape_and_slices(), original_shape_and_slices);
  if (extraction_status.ok()) {
    for (const std::string& shape_and_slice : original_shape_and_slices) {
      if (!shape_and_slice.empty())
        return errors::InvalidArgument(
            absl::StrCat("DTensor SaveV2 requires the shape_and_slices() "
                         "field to be empty for tensors, but got: ",
                         shape_and_slice));
    }
  } else {
    VLOG(2) << "Failed to extract and verify shape_and_slices() from the "
               "original SaveV2 op. SaveV2 SPMD will proceed as if "
               "shape_and_slices were empty for all the tensors.";
  }

  // Branch functions have a shared function type, whose inputs are simply all
  // the inputs from the original SaveV2 and which has no outputs.
  auto func_type = mlir::FunctionType::get(builder.getContext(),
                                           original_save.getOperandTypes(),
                                           /*results=*/{});
  // Only generate save functions for devices that are local to the client.
  // This means that we will run different functions on different clients,
  // but that is fine as we're running on CPU for this.
  for (int device_id : mesh.local_device_ids()) {
    // If saving_spec doesn't contain the device_id, then that device_id is a
    // no-op on the save.
    const auto& it = saving_specs.find(device_id);
    if (it == saving_specs.end()) {
      // Builds a placeholder for the no-op function, which takes the exact
      // same args as the original save op and returns nothing.
      mlir::func::FuncOp no_op = mlir::func::FuncOp::create(
          location,
          llvm::formatv("{0}_no_op_on_device_{1}_{2}", OpName(original_save),
                        device_id, OpHash(original_save))
              .str(),
          func_type, llvm::ArrayRef<mlir::NamedAttribute>{});
      // Set function visibility to private to indicate that it is only used
      // in this module.
      no_op.setVisibility(mlir::SymbolTable::Visibility::Private);
      symbol_table.insert(no_op);

      mlir::Block* fn_block = no_op.addEntryBlock();
      mlir::OpBuilder fn_builder = mlir::OpBuilder::atBlockBegin(fn_block);
      fn_builder.create<mlir::TF::NoOp>(location);
      fn_builder.create<mlir::func::ReturnOp>(location);

      branch_funs.push_back(no_op);
    } else {
      const absl::flat_hash_map<int64_t, std::vector<std::string>>&
          per_device_specs = it->second;

      // Build the new SaveV2 that contains the proper SliceSpec on this
      // device. tensor_names and slice_spec are concatenated into a 1-D
      // string tensor.
      mlir::func::FuncOp new_save = mlir::func::FuncOp::create(
          location,
          llvm::formatv("{0}_save_op_on_device_{1}_{2}", OpName(original_save),
                        device_id, OpHash(original_save))
              .str(),
          func_type, llvm::ArrayRef<mlir::NamedAttribute>{});
      // Set function visibility to private to indicate that it is only used
      // in this module.
      new_save.setVisibility(mlir::SymbolTable::Visibility::Private);
      symbol_table.insert(new_save);

      mlir::Block* fn_block = new_save.addEntryBlock();
      mlir::OpBuilder fn_builder = mlir::OpBuilder::atBlockBegin(fn_block);

      mlir::Value tensor_names = new_save.getArgument(1);
      // It is currently unsupported if the user passes in shape_and_slices.
      // TODO(hthu): Implement this.
      // mlir::Value shape_and_slices = new_save.getArgument(2);

      // First run a split op on tensor_names so that we can use the proper
      // split output (one of the tensor names) to reconstruct the
      // tensor_names field in the new SaveV2 op.
      TF_ASSIGN_OR_RETURN(
          llvm::ArrayRef<int64_t> tensor_names_shape,
          GetGlobalShapeOfValueFromDTensorLayout(original_save.tensor_names()));
      if (tensor_names_shape.size() != 1)
        return errors::Internal(
            llvm::formatv("SaveV2 op got `tensor_names` with rank {0} but "
                          "expects rank to be 1.",
                          tensor_names_shape.size())
                .str());
      mlir::TF::SplitOp name_splits;
      TF_RETURN_IF_ERROR(CreateSplitOp(/*num_split=*/tensor_names_shape[0],
                                       /*split_dimension=*/0, location,
                                       /*src_input=*/tensor_names, &fn_builder,
                                       &name_splits));

      // Builds the per-device saving spec, which takes care of the
      // tensor_name uniqueness requirement. Each save op should use
      // new_tensor_indices and new_specs to map the corresponding saving
      // tensor and its slice spec.
      SaveOpSpecs specs = BuildPerDeviceSave(per_device_specs, device_id,
                                             prefix, mesh.num_devices());
      const std::vector<std::vector<int>>& new_tensor_indices =
          specs.tensor_indices;
      const std::vector<std::vector<std::string>>& new_specs =
          specs.shape_and_slice_spec;

      // Prepare the corresponding SaveOp arguments.
      for (int save_op_index = 0; save_op_index < new_tensor_indices.size();
           ++save_op_index) {
        llvm::SmallVector<mlir::Value, 4> new_tensor_names;
        llvm::SmallVector<std::string, 4> new_shape_and_slices;
        llvm::SmallVector<mlir::Value, 4> new_tensors;

        // per_device_specs records the index of the tensor_names from the
        // original save, and all slice_specs needed to save that tensor.
        // The corresponding saving tensor can be found in the original save
        // op by adding 3 to the index (as 0, 1, 2 are fixed inputs for
        // prefix, tensor_names and shape_and_slices).
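        // For example, tensor_name_index 0 maps to operand 3 of the original
        // SaveV2, i.e. the first saved tensor.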
        for (int i = 0; i < new_tensor_indices[save_op_index].size(); ++i) {
          int tensor_name_index = new_tensor_indices[save_op_index][i];
          int tensor_index = 3 + tensor_name_index;
          new_tensor_names.push_back(name_splits.getResult(tensor_name_index));
          new_shape_and_slices.push_back(new_specs[save_op_index][i]);
          new_tensors.push_back(new_save.getArgument(tensor_index));
        }
        // Build the new SaveV2 op.
        mlir::Value tensor_names = new_tensor_names[0];
        if (new_tensor_names.size() > 1) {
          // When tensor_names has more than one entry, we concat the list of
          // names into a 1-D vector.
          tensor_names =
              fn_builder
                  .create<mlir::TF::ConcatOp>(
                      location,
                      /*output=*/original_save.tensor_names().getType(),
                      /*concat_dim=*/
                      IntConst(fn_builder, location, /*values=*/{0}),
                      new_tensor_names)
                  .getResult();
        }

        // Builds a unique prefix for this device and this save_op.
        std::string new_prefix =
            prefix +
            llvm::formatv("_device_{0}_save_op_{1}", device_id, save_op_index)
                .str();

        fn_builder.create<mlir::TF::SaveV2Op>(
            location,
            StringConst(fn_builder, location,
                        {specs.new_prefixes[save_op_index]}),
            /*tensor_name=*/tensor_names,
            /*shape_and_slices=*/
            StringConst(fn_builder, location,
                        llvm::SmallVector<llvm::StringRef>(
                            new_shape_and_slices.begin(),
                            new_shape_and_slices.end())),
            new_tensors);
      }
      branch_funs.push_back(new_save);
      fn_builder.create<mlir::func::ReturnOp>(location);
    }
  }

  llvm::SmallVector<mlir::Attribute, 4> symbols;
  for (auto& func : branch_funs)
    symbols.push_back(mlir::SymbolRefAttr::get(func));

  TF_ASSIGN_OR_RETURN(mlir::Value device_id, DeviceId(original_save));
  llvm::SmallVector<int64_t> local_device_ids(mesh.local_device_ids().begin(),
                                              mesh.local_device_ids().end());
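  // Map the runtime device_id to a 0-based branch index over the local
  // devices so it can drive the CaseOp below.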
  mlir::Value branch_index = DeviceIdToLocalBranchIndex(
      location, local_device_ids, device_id, builder);

  auto case_op = builder.create<mlir::TF::CaseOp>(
      location,
      // SaveV2 doesn't return a value.
      /*output=*/llvm::ArrayRef<mlir::Type>{},
      /*branch_index=*/branch_index,
      /*input=*/original_save.getOperands(),
      /*branches=*/builder.getArrayAttr(symbols),
      /*is_stateless=*/builder.getBoolAttr(false));

  return case_op;
}

StatusOr<mlir::Operation*> ExpandSaveV2Op(mlir::Operation* op) {
  if (!llvm::isa<mlir::TF::SaveV2Op>(op)) {
    return errors::InvalidArgument(
        llvm::formatv("Expecting SaveV2Op but got {0}", OpName(op)).str());
  }

  TF_ASSIGN_OR_RETURN(Mesh mesh, ExtractDeviceMeshEnclosingCluster(op));
  auto save_v2 = mlir::cast<mlir::TF::SaveV2Op>(op);

  mlir::OpBuilder builder(save_v2);

  absl::flat_hash_map<int64_t, std::pair<std::vector<int64_t>, Layout>>
      tensor_shape_layout_map;
  std::vector<SavingTensorMetadata> metadata;
  for (const auto& it : llvm::enumerate(save_v2.tensors())) {
    mlir::Value tensor = it.value();
    // We use the index to select the tensor names and shape_and_slices from
    // the inputs. This is generic regardless of whether the inputs are
    // constants or just arguments.
    int index = it.index();
    TF_ASSIGN_OR_RETURN(absl::optional<Layout> layout,
                        ExtractLayoutFromOperand(tensor));
    if (!layout)
      return errors::InvalidArgument(
          "a layout is required when saving a DTensor but no layout was "
          "attached");

    TF_ASSIGN_OR_RETURN(llvm::ArrayRef<int64_t> tensor_shape,
                        GetGlobalShapeOfValueFromDTensorLayout(it.value()));

    metadata.push_back(SavingTensorMetadata(
        index, std::vector<int64_t>(tensor_shape.begin(), tensor_shape.end()),
        *layout));
  }
  TF_ASSIGN_OR_RETURN(auto saving_specs, BuildSavingSpec(metadata));

  // Now we have a complete map from device_id to its saving tensors and
  // specs. Build a switch-case conditioned on device_id and do the saves
  // properly.
  TF_ASSIGN_OR_RETURN(mlir::TF::CaseOp case_op,
                      ConditionalSave(save_v2, mesh, saving_specs));

  save_v2->replaceAllUsesWith(case_op);
  save_v2->erase();

  return case_op.getOperation();
}

// SPMD Expander for MergeV2.
//
// The op is expected to have one and only one prefix input, which is used as
// a key to query all the saved shard prefixes generated in the SaveV2 op
// SPMD.
//
// The expanded MergeV2 contains all the shard prefixes generated, and only
// runs on device 0.
StatusOr<mlir::Operation*> ExpandMergeV2Op(mlir::Operation* op) {
  mlir::TF::MergeV2CheckpointsOp merge_v2 =
      mlir::dyn_cast<mlir::TF::MergeV2CheckpointsOp>(op);
  if (!merge_v2) {
    return errors::InvalidArgument(
        llvm::formatv("Expecting MergeV2CheckpointsOp but got {0}", OpName(op))
            .str());
  }

  // Build an If op that only runs MergeV2 on device 0. Note that the If
  // condition tests false when device_id == 0, so the `then` branch is a
  // no-op while the `else` branch is the real MergeV2 op that runs on
  // device 0.
  auto module = merge_v2->getParentOfType<mlir::ModuleOp>();
  mlir::SymbolTable symbol_table(module);
  auto location = merge_v2.getLoc();
  mlir::OpBuilder builder(merge_v2);

  auto func_type =
      mlir::FunctionType::get(builder.getContext(), merge_v2.getOperandTypes(),
                              llvm::ArrayRef<mlir::Type>{});
  // Build then_func, the branch for device_id != 0, which only contains a
  // single NoOp.
  mlir::func::FuncOp then_func = mlir::func::FuncOp::create(
      location,
      llvm::formatv("{0}_then_func_{1}", OpName(merge_v2), OpHash(merge_v2))
          .str(),
      func_type, llvm::ArrayRef<mlir::NamedAttribute>{});
  // Set function visibility to private to indicate that it is only used in
  // this module.
  then_func.setVisibility(mlir::SymbolTable::Visibility::Private);
  mlir::Block* then_fn_block = then_func.addEntryBlock();
  mlir::OpBuilder then_fn_builder =
      mlir::OpBuilder::atBlockBegin(then_fn_block);
  then_fn_builder.create<mlir::TF::NoOp>(location);
  then_fn_builder.create<mlir::func::ReturnOp>(location);

  // Build else_func, the branch for device_id == 0.
  // The else func is just the original MergeV2 itself.
  mlir::func::FuncOp else_func = mlir::func::FuncOp::create(
      location,
      llvm::formatv("{0}_else_func_{1}", OpName(merge_v2), OpHash(merge_v2))
          .str(),
      func_type, llvm::ArrayRef<mlir::NamedAttribute>{});
  // Set function visibility to private to indicate that it is only used in
  // this module.
  else_func.setVisibility(mlir::SymbolTable::Visibility::Private);

  mlir::Block* else_fn_block = else_func.addEntryBlock();
  mlir::OpBuilder else_fn_builder =
      mlir::OpBuilder::atBlockBegin(else_fn_block);
  mlir::Value checkpoint_prefixes = else_fn_block->getArgument(0);

  bool allow_missing_files = false;

  // If DTensorCheckpointV2 is enabled, then each string in the
  // `checkpoint_prefixes` tensor is missing the device-id suffix that was
  // generated during SaveV2 SPMD expansion. So, generate all the possible
  // suffixed prefixes and use those as the `checkpoint_prefixes` argument.
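  // See GetAllCandidateCheckpointPrefixes above for an example of the
  // candidate prefixes that get generated here.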
  if (DTensorCheckpointV2Enabled()) {
    allow_missing_files = true;
    TF_ASSIGN_OR_RETURN(Mesh mesh, ExtractDeviceMeshEnclosingCluster(op));
    checkpoint_prefixes = GetAllCandidateCheckpointPrefixes(
        else_fn_builder, checkpoint_prefixes, mesh);
  }

  mlir::Value destination_prefixes = else_fn_block->getArgument(1);

  else_fn_builder.create<mlir::TF::MergeV2CheckpointsOp>(
      location, checkpoint_prefixes, destination_prefixes,
      /*delete_old_dirs=*/
      else_fn_builder.getBoolAttr(merge_v2.delete_old_dirs()),
      /*allow_missing_files=*/else_fn_builder.getBoolAttr(allow_missing_files));

  else_fn_builder.create<mlir::func::ReturnOp>(location);

  symbol_table.insert(then_func);
  symbol_table.insert(else_func);

  TF_ASSIGN_OR_RETURN(mlir::Value device_id, DeviceId(merge_v2));
  auto if_op = builder.create<mlir::TF::IfOp>(
      location, then_func.getFunctionType().getResults(), /*cond=*/device_id,
      /*input=*/merge_v2.getOperands(),
      /*then_branch=*/then_func.getSymName(),
      /*else_branch=*/else_func.getSymName(), /*is_stateless=*/false);

  merge_v2->replaceAllUsesWith(if_op);
  merge_v2.erase();
  return if_op.getOperation();
}

// SPMD Expander for the RestoreV2 op.
//
// Both the tf.RestoreV2 and DTensorRestoreV2 ops are expanded the same way.
// That is, they are updated to only restore the slice for the given
// device_id. For replicated tensors, that is the full tensor slice. For
// sharded tensors, we compute the slice using the device coordinates and the
// tensor layout.
//
// `global_shapes` refers to the global shapes of the outputs of the op.
// `layouts` refers to the output layouts of the op.
StatusOr<mlir::Operation*> ExpandRestoreV2OpHelper(
    mlir::Operation* op, std::vector<std::vector<int64_t>> global_shapes,
    std::vector<Layout> layouts, std::vector<mlir::Type> output_types,
    mlir::MutableOperandRange shapes_and_slices_mutable) {
  TF_ASSIGN_OR_RETURN(Mesh mesh, ExtractDeviceMeshEnclosingCluster(op));

  // Prepare for building the CaseOp.
  mlir::ModuleOp module = op->template getParentOfType<mlir::ModuleOp>();
  if (!module)
    return errors::Internal(
        "DTensorRestoreV2 op isn't enclosed inside a mlir::ModuleOp");

  mlir::SymbolTable symbol_table(module);

  mlir::OpBuilder builder(op);
  const auto& location = op->getLoc();

  // Tracks case branch functions for each local_device_id.
  llvm::SmallVector<mlir::func::FuncOp> branch_funcs(
      mesh.local_device_ids().size());
  // Stores restore ops for each device_id in a function that is suitable for
  // feeding into a CaseOp.
  //
  // Branch functions share the function type of the original restore_v2.
  const auto func_type =
      mlir::FunctionType::get(builder.getContext(), op->getOperandTypes(),
                              mlir::TypeRange(output_types));

  for (int local_device_idx = 0;
       local_device_idx < mesh.local_device_ids().size(); ++local_device_idx) {
    int device_id = mesh.local_device_ids()[local_device_idx];
    TF_ASSIGN_OR_RETURN(const DeviceLocation& coords,
                        mesh.device_location(device_id));

    llvm::SmallVector<std::string> new_shapes_and_slices(op->getNumResults());

    // For each tensor, build its restore shape_and_slice.
    for (const auto& it : llvm::enumerate(llvm::zip(global_shapes, layouts))) {
      std::vector<int64_t> global_shape = std::get<0>(it.value());
      Layout layout = std::get<1>(it.value());
      // A fully replicated tensor does not need a slice spec, and we simply
      // leave it as an empty string. Note that a non-DTensor restore will use
      // the replicated layout from the SaveSpec.
      if (layout.IsFullyReplicated()) {
        new_shapes_and_slices[it.index()] = "";
        continue;
      }

      TF_ASSIGN_OR_RETURN(
          std::vector<std::string> slice_specs,
          SliceSpecOnDevice(layout, mesh, coords, global_shape));

      // Concat the shape and slice specs.
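      // The resulting string is the space-separated global shape followed by
      // the per-dimension slice specs joined by ':', e.g. something like
      // "2 4 -:0,2" for a [2, 4] tensor sharded on its last dimension.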
      new_shapes_and_slices[it.index()] =
          llvm::formatv("{0} {1}", absl::StrJoin(global_shape, " "),
                        absl::StrJoin(slice_specs, ":"))
              .str();
    }

    // Builds the restore op on device_id.
    mlir::OpBuilder builder(op);
    shapes_and_slices_mutable.assign(StringConst(
        builder, op->getLoc(),
        llvm::SmallVector<llvm::StringRef>(new_shapes_and_slices.begin(),
                                           new_shapes_and_slices.end())));
    mlir::func::FuncOp device_restore_fn = mlir::func::FuncOp::create(
        location,
        llvm::formatv("{0}_on_device_{1}_{2}", OpName(op), device_id,
                      OpHash(op))
            .str(),
        func_type, llvm::ArrayRef<mlir::NamedAttribute>{});
    // Set function visibility to private to indicate that it is only used in
    // this module.
    device_restore_fn.setVisibility(mlir::SymbolTable::Visibility::Private);
    symbol_table.insert(device_restore_fn);

    mlir::Block* fn_block = device_restore_fn.addEntryBlock();
    mlir::OpBuilder fn_builder = mlir::OpBuilder::atBlockBegin(fn_block);
    mlir::Value prefix = device_restore_fn.getArgument(0);
    mlir::Value tensor_names = device_restore_fn.getArgument(1);
    // Construct shapes and slices ourselves while reusing all other
    // arguments.
    auto new_restore_v2 = fn_builder.create<mlir::TF::RestoreV2Op>(
        location, mlir::TypeRange(output_types), prefix, tensor_names,
        StringConst(
            fn_builder, location,
            llvm::SmallVector<llvm::StringRef>(new_shapes_and_slices.begin(),
                                               new_shapes_and_slices.end())));
    fn_builder.create<mlir::func::ReturnOp>(location,
                                            new_restore_v2.getResults());

    branch_funcs[local_device_idx] = device_restore_fn;
  }

  // Builds the final case op.
  llvm::SmallVector<mlir::Attribute, 4> symbols;
  for (auto& func : branch_funcs)
    symbols.push_back(mlir::SymbolRefAttr::get(func));

  TF_ASSIGN_OR_RETURN(mlir::Value device_id, DeviceId(op));
  llvm::SmallVector<int64_t> local_device_ids(mesh.local_device_ids().begin(),
                                              mesh.local_device_ids().end());
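  // As in ConditionalSave, map the runtime device_id to a 0-based branch
  // index over the local devices for the CaseOp.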
  mlir::Value branch_index = DeviceIdToLocalBranchIndex(
      location, local_device_ids, device_id, builder);

  auto case_op = builder.create<mlir::TF::CaseOp>(
      location,
      /*output=*/mlir::TypeRange(output_types),
      /*branch_index=*/branch_index,
      /*input=*/op->getOperands(),
      /*branches=*/builder.getArrayAttr(symbols),
      /*is_stateless=*/builder.getBoolAttr(false));

  op->replaceAllUsesWith(case_op);
  op->erase();

  return case_op.getOperation();
}

// The DTensorRestoreV2 op has layouts and shapes as attributes of the op
// itself. We extract those attributes and call the helper expander.
StatusOr<mlir::Operation*> ExpandDTensorRestoreV2Op(mlir::Operation* op) {
  mlir::TF::DTensorRestoreV2Op restore_v2 =
      mlir::dyn_cast<mlir::TF::DTensorRestoreV2Op>(op);
  if (!restore_v2) {
    return errors::InvalidArgument(
        llvm::formatv("Expecting DTensorRestoreV2Op but got {0}", OpName(op))
            .str());
  }

  mlir::ArrayAttr input_shapes_attr =
      restore_v2->getAttrOfType<mlir::ArrayAttr>("input_shapes");
  if (!input_shapes_attr) {
    return errors::InvalidArgument(
        "DTensorRestoreV2Op requires the input_shapes attribute.");
  }

  std::vector<std::vector<int64_t>> input_shapes;
  input_shapes.reserve(input_shapes_attr.size());
  for (const auto& shape : input_shapes_attr) {
    mlir::TF::ShapeAttr shape_attr = shape.cast<mlir::TF::ShapeAttr>();
    if (!shape_attr.hasStaticShape()) {
      return errors::InvalidArgument(
          llvm::formatv("DTensorRestoreV2Op requires statically known input "
                        "shapes, but got non-static shape: {0}.",
                        shape_attr)
              .str());
    }
    input_shapes.push_back(std::vector<int64_t>(shape_attr.getShape().begin(),
                                                shape_attr.getShape().end()));
  }

  mlir::ArrayAttr input_layouts_attr = restore_v2.input_layouts();
  if (!input_layouts_attr) {
    return errors::InvalidArgument(
        "DTensorRestoreV2Op requires the input_layouts attribute.");
  }
  std::vector<Layout> input_layouts;
  input_layouts.reserve(input_layouts_attr.size());
  for (const auto& layout : input_layouts_attr.getValue().vec()) {
    input_layouts.push_back(
        Layout::FromString(layout.cast<mlir::StringAttr>().getValue().str())
            .ValueOrDie());
  }

  return ExpandRestoreV2OpHelper(
      op, input_shapes, input_layouts,
      std::vector<mlir::Type>(op->getResultTypes().begin(),
                              op->getResultTypes().end()),
      restore_v2.shape_and_slicesMutable());
}

// Extracts the layouts and shapes the normal way. By this time, we should
// have all the necessary DTensorLayout ops as the outputs of each op and the
// correct shapes and dtypes as the output types of the tf.RestoreV2 op.
//
// Calls the helper expander function with those shapes and layouts.
StatusOr<mlir::Operation*> ExpandRestoreV2Op(mlir::Operation* op) {
  // Fetch the shape of each output.
  std::vector<std::vector<int64_t>> global_shapes;
  global_shapes.reserve(op->getNumResults());

  // This is subtle. For the tf.train.Checkpoint.save_counter scalar variable,
  // the variable may not yet be created by the time we call
  // Checkpoint.restore.
  //
  // In that case, the tf.RestoreV2 is called eagerly, and thus there is no
  // tf.AssignVariable op. This means that we cannot infer the shapes and
  // layouts from the previous pass, CreateDTensorInferShapesForRestoreV2Op.
  //
  // But for save_counter, we know it is always replicated, and we can just
  // return the op itself. For now, we do it this hacky way, but eventually
  // we need to generalize restoring variables that are not yet created.
  //
  // TODO(b/235373719) Generalize support for checkpoint restoration for
  // variables that are not yet created.
  if (op->getNumResults() == 1 && !GetShapeOfValue(op->getResult(0)).ok()) {
    return op;
  }

  for (auto result : op->getResults()) {
    global_shapes.push_back(GetShapeOfValue(result).ValueOrDie());
  }

  // Fetch the layout of each output.
  TF_ASSIGN_OR_RETURN(std::vector<Layout> layouts,
                      ExtractRequiredLayoutFromOp(op));

  // Calculate the new local type range needed for the new RestoreV2Op we will
  // emit.
  std::vector<mlir::Type> new_types;
  new_types.reserve(op->getNumResults());

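  // Each output's local type keeps the original element type but uses the
  // local shape derived from its global shape and layout.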
  for (const auto& it :
       llvm::zip(op->getResultTypes(), global_shapes, layouts)) {
    mlir::Type type = std::get<0>(it);
    std::vector<int64_t>& shape = std::get<1>(it);
    Layout& layout = std::get<2>(it);
    new_types.push_back(mlir::RankedTensorType::get(
        layout.LocalShapeFromGlobalShape(shape),
        type.dyn_cast<mlir::RankedTensorType>().getElementType()));
  }

  return ExpandRestoreV2OpHelper(
      op, global_shapes, layouts, new_types,
      mlir::dyn_cast<mlir::TF::RestoreV2Op>(op).shape_and_slicesMutable());
}

}  // namespace

StatusOr<mlir::Operation*> SaveRestoreSPMDExpander::ExpandOp(
    mlir::Operation* op) {
  if (llvm::isa<mlir::TF::SaveV2Op>(op)) {
    return ExpandSaveV2Op(op);
  }
  if (llvm::isa<mlir::TF::MergeV2CheckpointsOp>(op)) {
    return ExpandMergeV2Op(op);
  }
  if (llvm::isa<mlir::TF::DTensorRestoreV2Op>(op)) {
    return ExpandDTensorRestoreV2Op(op);
  }
  if (llvm::isa<mlir::TF::RestoreV2Op>(op)) {
    return ExpandRestoreV2Op(op);
  }

  return errors::Unimplemented(
      llvm::formatv("SPMD for op: {0} is not implemented", OpName(op)).str());
}

// Finds all the resource tensor layouts attached to the AssignVariableOps
// that `restore_op` is restoring to.
StatusOr<llvm::SmallVector<Layout>> GetLayoutsFromAssignVariableOps(
    mlir::ModuleOp module, mlir::TF::RestoreV2Op* restore_op) {
  llvm::SmallVector<Layout> layouts(restore_op->getNumResults());

  for (auto result : restore_op->getResults()) {
    // Find the AssignVariableOp connected to this output. There should be at
    // most one IdentityOp and one DTensorSend between this result and the
    // AssignVariableOp.
    for (auto consuming_op : result.getUsers()) {
      // To get to the AssignVariableOp that consumes `result`, we expect
      // an IdentityOp or a DTensorSend op on the path. So, skip past
      // these ops first.
      while (llvm::isa<mlir::TF::IdentityOp, mlir::TF::DTensorSend>(
          consuming_op)) {
        if (auto send_op =
                mlir::dyn_cast_or_null<mlir::TF::DTensorSend>(consuming_op)) {
          TF_ASSIGN_OR_RETURN(
              consuming_op, GetCorrespondingDTensorSendRecvOp(module, send_op));
        }
        auto next_op = consuming_op->getResult(0).getUsers();
        if (next_op.empty()) {
          return errors::Internal(
              "Expected a result of an identity op to be consumed by another "
              "op, but it was empty during RestoreV2 expansion.");
        }
        consuming_op = *next_op.begin();
      }
      // We skipped past ops like Identity and Send. There might be an
      // AssignVariableOp now.
      if (auto assign_op = llvm::dyn_cast_or_null<mlir::TF::AssignVariableOp>(
              consuming_op)) {
        TF_ASSIGN_OR_RETURN(auto layout, ExtractRequiredLayoutFromOperand(
                                             assign_op.resource()));
        layouts[result.getResultNumber()] = layout;
        break;
      }
    }
  }
  return layouts;
}

StatusOr<llvm::DenseMap<int, Layout>>
SaveRestoreSPMDExpander::ComputeLayoutForward(
    mlir::Operation* op, const llvm::DenseMap<int, Layout>& input_layouts,
    const llvm::DenseMap<int, Layout>& output_layouts) {
  // Save ops don't have return values.
  if (llvm::isa<mlir::TF::SaveV2Op, mlir::TF::MergeV2CheckpointsOp>(op)) {
    return llvm::DenseMap<int, Layout>();
  }
  if (llvm::isa<mlir::TF::RestoreV2Op>(op)) {
    // If output layouts are already specified, this means that we are in the
    // late variable creation restoration path. For this path, the output
    // layouts are already specified through the default layout scope, so just
    // return them.
    if (!output_layouts.empty()) return output_layouts;

    mlir::ModuleOp module_op = op->getParentOfType<mlir::ModuleOp>();
    mlir::TF::RestoreV2Op restore_v2 = mlir::cast<mlir::TF::RestoreV2Op>(op);
    TF_ASSIGN_OR_RETURN(Mesh mesh, ExtractDeviceMeshEnclosingCluster(op));
    if (!mesh.is_cpu_mesh()) {
      return errors::InvalidArgument(
          llvm::formatv(
              "RestoreV2Op must run on a CPU mesh, but was running on: {0}",
              mesh.ToString())
              .str());
    }
    // Extract the layout of each resource tensor from the AssignVariableOp
    // consuming each result. This layout sharding will be used as the
    // output layout for each result tensor.
    TF_ASSIGN_OR_RETURN(
        auto layouts, GetLayoutsFromAssignVariableOps(module_op, &restore_v2));
    if (layouts.size() != restore_v2.getNumResults()) {
      return errors::Internal(llvm::formatv("Failed to get {0} output layouts "
                                            "for RestoreV2Op. Got {1} layouts.",
                                            restore_v2.getNumResults(),
                                            layouts.size())
                                  .str());
    }
    llvm::DenseMap<int, Layout> output_layouts(restore_v2.getNumResults());

    // Change the mesh of each layout to `mesh`, since the RestoreOp always
    // runs on the CPU.
    for (int i = 0; i < layouts.size(); ++i) {
      Layout host_mesh_layout = layouts[i];
      host_mesh_layout.set_mesh(mesh);
      output_layouts[i] = host_mesh_layout;
    }
    return output_layouts;
  }
  if (llvm::isa<mlir::TF::DTensorRestoreV2Op>(op)) {
    mlir::TF::DTensorRestoreV2Op restore_v2 =
        mlir::cast<mlir::TF::DTensorRestoreV2Op>(op);
    llvm::DenseMap<int, Layout> output_layouts(restore_v2.getNumResults());
    // The output layouts are simply the layouts from the input_layouts
    // attribute.
    for (const auto& it : llvm::enumerate(restore_v2.input_layouts())) {
      TF_ASSIGN_OR_RETURN(
          Layout layout,
          Layout::FromString(
              it.value().cast<mlir::StringAttr>().getValue().str()));
      output_layouts[it.index()] = layout;
    }
    return output_layouts;
  }
  return errors::Unimplemented(
      llvm::formatv("Layout propagation for op: {0} is not implemented",
                    OpName(op))
          .str());
}

StatusOr<llvm::DenseMap<int, Layout>>
SaveRestoreSPMDExpander::ComputeLayoutBackward(
    mlir::Operation* op, const llvm::DenseMap<int, Layout>& output_layouts) {
  return llvm::DenseMap<int, Layout>();
}

}  // namespace dtensor
}  // namespace tensorflow