/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <cstdint>
#include <string>
#include <utility>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "absl/strings/str_cat.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/dtensor/cc/save_restore_util.h"
#include "tensorflow/dtensor/cc/tensor_layout.h"

namespace tensorflow {
namespace dtensor {

// An OpKernel that queries the prefixes of all generated SaveV2 ops. This is
// needed in a distributed context to track all save ops inserted by DTensor
// SPMD and to ensure a proper MergeV2 op runs afterwards.
class DTensorShardedPrefixOpKernel : public OpKernel {
 public:
  explicit DTensorShardedPrefixOpKernel(OpKernelConstruction* ctx)
      : OpKernel(ctx) {}

 private:
  void Compute(OpKernelContext* ctx) override {
    const Tensor* prefix_tensor;
    const Tensor* tensor_names;
    const Tensor* shape_and_slices;
    const Tensor* mesh_tensor;
    const Tensor* layouts;

    OP_REQUIRES_OK(ctx, ctx->input("prefix", &prefix_tensor));
    OP_REQUIRES_OK(ctx, ctx->input("tensor_names", &tensor_names));
    OP_REQUIRES_OK(ctx, ctx->input("shape_and_slices", &shape_and_slices));
    OP_REQUIRES_OK(ctx, ctx->input("mesh", &mesh_tensor));
    OP_REQUIRES_OK(ctx, ctx->input("layouts", &layouts));

    const std::string& prefix = prefix_tensor->scalar<tstring>()();

    const auto& shape_and_slices_vec = shape_and_slices->flat<tstring>();
    for (int i = 0; i < shape_and_slices->NumElements(); ++i) {
      OP_REQUIRES(
          ctx, shape_and_slices_vec(i).empty(),
          errors::Unimplemented("DTensor save currently does not support "
                                "distributed save with shape_and_slices"));
    }

    const std::string& mesh_str = mesh_tensor->scalar<tstring>()();
    const auto& mesh_or = Mesh::FromString(mesh_str);
    OP_REQUIRES(ctx, mesh_or.ok(),
                errors::InvalidArgument(
                    absl::StrCat("Got invalid mesh string : ", mesh_str)));
    const Mesh& mesh = *mesh_or;

    const auto& layouts_flat = layouts->flat<tensorflow::tstring>();
    OP_REQUIRES(ctx, tensor_names->NumElements() == layouts->NumElements(),
                errors::InvalidArgument(absl::StrCat(
                    "tensor_names must match the size of layouts, "
                    "but got tensor_names size : ",
                    tensor_names->NumElements(),
                    " and layouts size : ", layouts->NumElements())));

    // (prefix, tensor names, shape_and_slices, mesh, layout) are fixed inputs
    // while tensors are variadic inputs.
    const int kFixedInputs = 5;
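    // The variadic tensor inputs follow the fixed inputs, so the tensor for
    // entry i of `tensor_names` is read below as ctx->input(i + kFixedInputs).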
    const int num_tensors = static_cast<int>(tensor_names->NumElements());

    // Construct a map of <tensor_index -> (tensor_shape, Layout)> to build
    // the saving spec.
    std::vector<SavingTensorMetadata> metadata;
    metadata.reserve(num_tensors);

    for (int i = 0; i < num_tensors; ++i) {
      const string& layout_string = layouts_flat(i);
      const Tensor& tensor = ctx->input(i + kFixedInputs);

      // Note that at runtime we always see the local shape, so here we
      // recover the global shape to compute the saving specs correctly.
      const TensorShape& shape = tensor.shape();
      std::vector<int64_t> local_shape;
      local_shape.reserve(shape.dims());
      for (int dim = 0; dim < shape.dims(); ++dim) {
        local_shape.push_back(shape.dim_size(dim));
      }

      const auto& layout_or = Layout::FromString(layout_string);
      OP_REQUIRES(ctx, layout_or.ok(),
                  errors::InvalidArgument(absl::StrCat(
                      "Tensor at index : ", i,
                      " has invalid layout string : ", layout_string)));
      std::vector<int64_t> global_shape =
          layout_or->GlobalShapeFromLocalShape(local_shape);
      metadata.push_back(SavingTensorMetadata(i, std::move(global_shape),
                                              std::move(*layout_or)));
    }
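    // For example (illustrative numbers only): if a tensor is sharded 4 ways
    // along its first dimension, a per-device local shape of {2, 8} recovers
    // to a global shape of {8, 8}.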

    const auto& saving_specs_or = BuildSavingSpec(metadata);
    OP_REQUIRES(ctx, saving_specs_or.ok(),
                errors::Internal(absl::StrCat(
                    "failed to build saving specs for given shapes and "
                    "layouts. This should not happen. Message from stack : ",
                    saving_specs_or.status().error_message())));

    const absl::flat_hash_map<
        int64_t, absl::flat_hash_map<int64_t, std::vector<std::string>>>&
        saving_spec = *saving_specs_or;
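    // The outer map is keyed by device id (see the lookup below); each inner
    // map holds the shape_and_slice strings that BuildPerDeviceSave turns
    // into the save ops for that device.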

    // Walks the mesh and builds the per-device save ops. We don't need to
    // build the real save ops here; rather, we just query the shard prefixes
    // that would be generated by DTensor SPMD.
    std::vector<std::string> all_shard_prefixes;
    for (int device_id = 0; device_id < mesh.size(); ++device_id) {
      const auto& it = saving_spec.find(device_id);
      if (it == saving_spec.end()) continue;
      SaveOpSpecs specs =
          BuildPerDeviceSave(it->second, device_id, prefix, mesh.size());
      // Collect every shard prefix generated for this device.
      for (const std::string& new_prefix : specs.new_prefixes) {
        all_shard_prefixes.push_back(new_prefix);
      }
    }

    Tensor* out;
    auto out_vector_size = all_shard_prefixes.size();
    OP_REQUIRES_OK(
        ctx,
        ctx->allocate_output(
            0, TensorShape({static_cast<int64_t>(out_vector_size)}), &out));
    for (size_t i = 0; i < out_vector_size; ++i) {
      out->flat<tstring>()(i) = all_shard_prefixes[i];
    }
  }
};

REGISTER_KERNEL_BUILDER(Name("DTensorShardedPrefix").Device(DEVICE_CPU),
                        DTensorShardedPrefixOpKernel);
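
// A minimal sketch of the intended wiring (an assumption based on the class
// comment above, not something defined in this file): the 1-D string output
// of DTensorShardedPrefix lists every shard prefix that the SPMD-expanded
// SaveV2 ops will write, and would typically feed the checkpoint_prefixes
// input of a downstream MergeV2Checkpoints op so that all shards end up in a
// single merged checkpoint.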

}  // namespace dtensor
}  // namespace tensorflow