/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <string>

#include "absl/container/flat_hash_map.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/dtensor/cc/save_restore_util.h"
#include "tensorflow/dtensor/cc/tensor_layout.h"

namespace tensorflow {
namespace dtensor {

// An OpKernel that queries the prefixes of all SaveV2 ops generated by
// DTensor SPMD. This is needed in a distributed context to track every save
// op inserted by DTensor SPMD and to ensure a proper MergeV2Checkpoints op is
// built afterwards.
class DTensorShardedPrefixOpKernel : public OpKernel {
 public:
  explicit DTensorShardedPrefixOpKernel(OpKernelConstruction* ctx)
      : OpKernel(ctx) {}

 private:
  void Compute(OpKernelContext* ctx) override {
    const Tensor* prefix_tensor;
    const Tensor* tensor_names;
    const Tensor* shape_and_slices;
    const Tensor* mesh_tensor;
    const Tensor* layouts;

    OP_REQUIRES_OK(ctx, ctx->input("prefix", &prefix_tensor));
    OP_REQUIRES_OK(ctx, ctx->input("tensor_names", &tensor_names));
    OP_REQUIRES_OK(ctx, ctx->input("shape_and_slices", &shape_and_slices));
    OP_REQUIRES_OK(ctx, ctx->input("mesh", &mesh_tensor));
    OP_REQUIRES_OK(ctx, ctx->input("layouts", &layouts));

    const std::string& prefix = prefix_tensor->scalar<tstring>()();

    const auto& shape_and_slices_vec = shape_and_slices->flat<tstring>();
    for (int i = 0; i < shape_and_slices->NumElements(); ++i) {
      OP_REQUIRES(
          ctx, shape_and_slices_vec(i).empty(),
          errors::Unimplemented("DTensor save currently does not support "
                                "distributed save with shape_and_slices"));
    }

    const std::string& mesh_str = mesh_tensor->scalar<tstring>()();
    const auto& mesh_or = Mesh::FromString(mesh_str);
    OP_REQUIRES(ctx, mesh_or.ok(),
                errors::InvalidArgument(
                    absl::StrCat("Got invalid mesh string: ", mesh_str)));
    const Mesh& mesh = *mesh_or;

    const auto& layouts_flat = layouts->flat<tensorflow::tstring>();
    OP_REQUIRES(ctx, tensor_names->NumElements() == layouts->NumElements(),
                errors::InvalidArgument(absl::StrCat(
                    "tensor_names must match the size of layouts, "
                    "but got tensor_names size: ",
                    tensor_names->NumElements(),
                    " and layouts size: ", layouts->NumElements())));

    // (prefix, tensor_names, shape_and_slices, mesh, layouts) are fixed
    // inputs, while the tensors to save are variadic inputs.
    const int kFixedInputs = 5;
    const int num_tensors = static_cast<int>(tensor_names->NumElements());

    // Construct a map of tensor_index -> (tensor_shape, Layout) to build
    // the saving spec.
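    // Illustrative example (an added assumption, not from the original
    // source): a tensor with global shape [4, 4] whose layout shards dim 0
    // across 2 devices is materialized as a [2, 4] local tensor on each
    // device. The loop below reverses that mapping, e.g.:
    //
    //   local shape [2, 4] + layout sharded 2-ways on dim 0
    //       -> GlobalShapeFromLocalShape(local_shape) yields [4, 4]
    //
    // so the saving spec is computed against global shapes.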
    std::vector<SavingTensorMetadata> metadata;
    metadata.reserve(num_tensors);

    for (int i = 0; i < num_tensors; ++i) {
      const string& layout_string = layouts_flat(i);
      const Tensor& tensor = ctx->input(i + kFixedInputs);

      // Note that at runtime we always have the local shape, so here we
      // recover the global shape to compute the saving specs correctly.
      const TensorShape& shape = tensor.shape();
      std::vector<int64_t> local_shape;
      local_shape.reserve(shape.dims());
      for (int dim = 0; dim < shape.dims(); ++dim) {
        local_shape.push_back(shape.dim_size(dim));
      }

      const auto& layout_or = Layout::FromString(layout_string);
      OP_REQUIRES(ctx, layout_or.ok(),
                  errors::InvalidArgument(absl::StrCat(
                      "Tensor at index ", i,
                      " has invalid layout string: ", layout_string)));
      std::vector<int64_t> global_shape =
          layout_or->GlobalShapeFromLocalShape(local_shape);
      metadata.push_back(SavingTensorMetadata(i, std::move(global_shape),
                                              std::move(*layout_or)));
    }

    const auto& saving_specs_or = BuildSavingSpec(metadata);
    OP_REQUIRES(ctx, saving_specs_or.ok(),
                errors::Internal(absl::StrCat(
                    "Failed to build saving specs for the given shapes and "
                    "layouts. This should not happen. Message from stack: ",
                    saving_specs_or.status().error_message())));

    const absl::flat_hash_map<
        int64_t, absl::flat_hash_map<int64_t, std::vector<std::string>>>&
        saving_spec = *saving_specs_or;

    // Walk the mesh and build the per-device save ops.
    // We don't need to build the real save ops here; rather, we just query
    // the shards that would be generated by DTensor SPMD.
    std::vector<std::string> all_shard_prefixes;
    for (int device_id = 0; device_id < mesh.size(); ++device_id) {
      const auto& it = saving_spec.find(device_id);
      if (it == saving_spec.end()) continue;
      SaveOpSpecs specs =
          BuildPerDeviceSave(it->second, device_id, prefix, mesh.size());
      // Add all generated shards into a vector.
      for (const std::string& new_prefix : specs.new_prefixes) {
        all_shard_prefixes.push_back(new_prefix);
      }
    }

    Tensor* out;
    auto out_vector_size = all_shard_prefixes.size();
    OP_REQUIRES_OK(
        ctx,
        ctx->allocate_output(
            0, TensorShape({static_cast<int64_t>(out_vector_size)}), &out));
    for (size_t i = 0; i < out_vector_size; ++i) {
      out->flat<tstring>()(i) = all_shard_prefixes[i];
    }
  }
};

REGISTER_KERNEL_BUILDER(Name("DTensorShardedPrefix").Device(DEVICE_CPU),
                        DTensorShardedPrefixOpKernel);

}  // namespace dtensor
}  // namespace tensorflow
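// A hedged usage sketch (an assumption, not part of the original file): the
// string tensor produced by this kernel is intended to feed the
// checkpoint_prefixes input of a MergeV2Checkpoints op, conceptually:
//
//   prefixes = DTensorShardedPrefix(prefix, tensor_names, shape_and_slices,
//                                   mesh, layouts, tensors)
//   MergeV2Checkpoints(checkpoint_prefixes=prefixes,
//                      destination_prefix=prefix)
//
// so that every shard written by DTensor SPMD ends up merged into a single
// checkpoint under the original prefix.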