xref: /aosp_15_r20/external/tensorflow/tensorflow/core/kernels/load_and_remap_matrix_op.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <algorithm>
#include <string>
#include <unordered_map>
#include <vector>

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/kernel_def_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/util/tensor_bundle/tensor_bundle.h"

namespace tensorflow {

namespace {
// Returns a Status instead of using OP_REQUIRES directly, since OP_REQUIRES
// does not seem to work outside the main OpKernel functions.
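//
// For example (an illustrative sketch): remapping = [3, -1, 5] produces
// id_present = [true, false, true] and
// old_id_to_new_id = {3: 0, 5: 2}.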
Status RemapVectorToMap(
    const TTypes<const int64_t>::Vec& remapping, std::vector<bool>* id_present,
    std::unordered_map<int64_t, int64_t>* old_id_to_new_id) {
  id_present->clear();
  id_present->resize(remapping.size(), false);
  for (int i = 0; i < remapping.size(); ++i) {
    const int64_t old_id = remapping(i);
    if (old_id < 0) continue;
    (*id_present)[i] = true;
    if (!gtl::InsertIfNotPresent(old_id_to_new_id, old_id, i)) {
      return errors::Unimplemented(
          strings::StrCat("Old ID ", old_id, " is mapped to both new ID ",
                          old_id_to_new_id->at(old_id), " and ", i,
                          ", which is not supported."));
    }
  }
  return OkStatus();
}
}  // anonymous namespace

// This op loads a rank-2 Tensor (matrix) from a TensorFlow checkpoint (V2) and
// swaps around the rows/columns according to row_remapping/col_remapping.
// "Missing" cells are initialized with values from initializing_values.
class LoadAndRemapMatrixOp : public OpKernel {
 public:
  explicit LoadAndRemapMatrixOp(OpKernelConstruction* context)
      : OpKernel(context) {
    OP_REQUIRES_OK(context, context->GetAttr("num_rows", &num_rows_));
    OP_REQUIRES_OK(context, context->GetAttr("num_cols", &num_cols_));
    OP_REQUIRES_OK(
        context, context->GetAttr("max_rows_in_memory", &max_rows_in_memory_));
  }

  void Compute(OpKernelContext* context) override {
    // Checks what we're remapping and inverts the relevant remapping Tensors
    // to be maps with key = old ID, value = new ID.
    std::unordered_map<int64_t, int64_t> old_row_to_new_row_map;
    std::vector<bool> row_id_present;
    const Tensor* row_remapping_t;
    OP_REQUIRES_OK(context, context->input("row_remapping", &row_remapping_t));
    OP_REQUIRES(
        context, row_remapping_t->dims() == 1,
        errors::InvalidArgument("The `row_remapping` tensor must be 1-D, got "
                                "a tensor of shape ",
                                row_remapping_t->shape().DebugString()));
    const auto row_remapping = row_remapping_t->vec<int64_t>();
    OP_REQUIRES(context, row_remapping.size() == num_rows_,
                errors::InvalidArgument(strings::StrCat(
                    "Size of row_remapping is ", row_remapping.size(),
                    " instead of being equal to num_rows=", num_rows_)));
    OP_REQUIRES_OK(context, RemapVectorToMap(row_remapping, &row_id_present,
                                             &old_row_to_new_row_map));

    // Calculates the min/max old row ID that we need to read, to save us from
    // reading some unnecessary slices of the old tensor.
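    // For example (an illustrative sketch): row_remapping = [9, -1, 4] yields
    // min_old_row = 4 and max_old_row = 9, so only old rows 4 through 9 are
    // read from the checkpoint.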
    int64_t min_old_row = -1;
    int64_t max_old_row = -1;
    for (int i = 0; i < row_remapping.size(); ++i) {
      if (min_old_row < 0 ||
          (row_remapping(i) >= 0 && row_remapping(i) < min_old_row)) {
        min_old_row = row_remapping(i);
      }
      if (max_old_row < 0 ||
          (row_remapping(i) >= 0 && row_remapping(i) > max_old_row)) {
        max_old_row = row_remapping(i);
      }
    }

    // Processes the remapping for columns.
    std::unordered_map<int64_t, int64_t> old_col_to_new_col_map;
    std::vector<bool> col_id_present;
    const Tensor* col_remapping_t;
    OP_REQUIRES_OK(context, context->input("col_remapping", &col_remapping_t));
    const auto col_remapping = col_remapping_t->vec<int64_t>();
    // Note that we always "remap rows", even when the row vocabulary does
    // not change, because partitioning requires a mapping from partitioned
    // Variables to the full checkpoints we load.
    const bool remap_cols = col_remapping.size() > 0;
    if (remap_cols) {
      OP_REQUIRES(
          context, col_remapping.size() == num_cols_,
          errors::InvalidArgument(strings::StrCat(
              "Provided col_remapping, but its size is ", col_remapping.size(),
              " instead of being equal to num_cols=", num_cols_)));
      OP_REQUIRES_OK(context, RemapVectorToMap(col_remapping, &col_id_present,
                                               &old_col_to_new_col_map));
    } else {
      col_id_present.clear();
      col_id_present.resize(num_cols_, true);
    }

    // Processes the checkpoint source and the provided Tensor name.
    const Tensor* ckpt_path_t;
    OP_REQUIRES_OK(context, context->input("ckpt_path", &ckpt_path_t));
    OP_REQUIRES(
        context, ckpt_path_t->NumElements() == 1,
        errors::InvalidArgument("The `ckpt_path` tensor must have exactly one "
                                "element, got tensor of shape ",
                                ckpt_path_t->shape().DebugString()));
    const string& ckpt_path = ckpt_path_t->scalar<tstring>()();
    const Tensor* old_tensor_name_t;
    OP_REQUIRES_OK(context,
                   context->input("old_tensor_name", &old_tensor_name_t));
    const string& old_tensor_name = old_tensor_name_t->scalar<tstring>()();

    LOG(INFO) << "Processing checkpoint: " << ckpt_path;
    BundleReader reader(context->env(), ckpt_path);
    OP_REQUIRES_OK(context, reader.status());

    DataType tensor_type;
    TensorShape tensor_shape;
    OP_REQUIRES_OK(context, reader.LookupDtypeAndShape(
                                old_tensor_name, &tensor_type, &tensor_shape));
    OP_REQUIRES(context, tensor_type == DT_FLOAT,
                errors::InvalidArgument(strings::StrCat(
                    "Tensor ", old_tensor_name, " has invalid type ",
                    DataTypeString(tensor_type), " instead of expected type ",
                    DataTypeString(DT_FLOAT))));
    // This op is limited to loading Tensors of rank 2 (matrices).
    OP_REQUIRES(
        context, tensor_shape.dims() == 2,
        errors::InvalidArgument(strings::StrCat(
            "Tensor ", old_tensor_name, " has shape ",
            tensor_shape.DebugString(), " of invalid rank ",
            tensor_shape.dims(), " instead of expected shape of rank 2.")));

    if (!remap_cols) {
      // TODO(weiho): Consider relaxing this restriction to allow partial column
      // loading (even when no column remapping is specified) if there turns out
      // to be a use case for it.
      OP_REQUIRES(context, num_cols_ == tensor_shape.dim_size(1),
                  errors::InvalidArgument(strings::StrCat(
                      "Tensor ", old_tensor_name, " has shape ",
                      tensor_shape.DebugString(),
                      ", where the size of its 2nd dimension is ",
                      tensor_shape.dim_size(1),
                      " instead of being equal to num_cols=", num_cols_)));
    }

    // Uses TensorSlice to potentially load the old tensor in chunks in case
    // memory usage is a concern.
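    // For example (an illustrative sketch): with min_old_row = 0,
    // max_old_row = 9, and max_rows_in_memory_ = 4, this produces three
    // slices covering rows [0, 4), [4, 8), and [8, 10).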
    std::vector<TensorSlice> tensor_slices;
    TensorSlice slice(tensor_shape.dims());
    if (min_old_row >= 0 && max_old_row >= 0) {
      int64_t row_start = min_old_row;
      // TODO(weiho): Given the list of old row IDs of interest (the keys of
      // old_row_to_new_row_map), we could also try something smarter to
      // find some minimal set of covering ranges for the list of old row IDs
      // such that the size of each range is less than max_rows_in_memory_.
      while (row_start <= max_old_row) {
        const int64_t slice_length =
            max_rows_in_memory_ <= 0
                // If max_rows_in_memory_ <= 0, we just load the entire chunk.
                ? max_old_row - row_start + 1
                : std::min(max_rows_in_memory_, max_old_row - row_start + 1);
        slice.set_start(0, row_start);
        slice.set_length(0, slice_length);
        tensor_slices.push_back(slice);
        row_start += slice_length;
      }
    }

    // Allocates the output matrix.
    Tensor* output_matrix_t = nullptr;
    OP_REQUIRES_OK(context,
                   context->allocate_output("output_matrix",
                                            TensorShape({num_rows_, num_cols_}),
                                            &output_matrix_t));
    auto output_matrix = output_matrix_t->matrix<float>();

    // Iterates through tensor slices and copies over values from the old tensor
    // to the output matrix.
    int64_t row_index = min_old_row;
    int64_t rows_copied = 0;
    Tensor loaded_tensor_t;
    for (const TensorSlice& tensor_slice : tensor_slices) {
      LOG(INFO) << "Loading slice " << tensor_slice.DebugString();
      TensorShape slice_shape;
      OP_REQUIRES_OK(context,
                     tensor_slice.SliceTensorShape(tensor_shape, &slice_shape));
      // Potentially re-allocates the tensor buffer since the last slice may
      // have fewer rows than the other slices.
      if (loaded_tensor_t.shape() != slice_shape) {
        loaded_tensor_t = Tensor(DT_FLOAT, slice_shape);
      }
      OP_REQUIRES_OK(context, reader.LookupSlice(old_tensor_name, tensor_slice,
                                                 &loaded_tensor_t));

      // Iterates through the old loaded tensor slice row-by-row.
      for (int row = 0; row < loaded_tensor_t.dim_size(0); ++row, ++row_index) {
        if (row_index % 500000 == min_old_row) {
          LOG(INFO) << "Processing old row " << row_index;
        }

        // If the old row ID is not found in old_row_to_new_row_map, continue
        // to the next row; otherwise, copy it to the output matrix.
        const int64_t* new_row_ptr =
            gtl::FindOrNull(old_row_to_new_row_map, row_index);
        if (new_row_ptr == nullptr) {
          continue;
        }
        ++rows_copied;
        const int64_t new_row = *new_row_ptr;

        // Copies over the row element-by-element, in case remapping is needed
        // along the column axis.
        const auto& loaded_tensor = loaded_tensor_t.matrix<float>();
        for (int old_col = 0; old_col < loaded_tensor_t.dim_size(1);
             ++old_col) {
          int64_t new_col = old_col;
          if (remap_cols) {
            const int64_t* new_col_ptr =
                gtl::FindOrNull(old_col_to_new_col_map, old_col);
            if (new_col_ptr == nullptr) {
              // Column remapping is specified, but this column is not found in
              // old_col_to_new_col_map, so we leave it uninitialized, to be
              // filled in with initializing_values later.
              continue;
            }
            new_col = *new_col_ptr;
          }

          OP_REQUIRES(context,
                      new_row < num_rows_ && new_col < num_cols_ &&
                          new_row >= 0 && new_col >= 0,
                      errors::Internal(strings::StrCat(
                          "new_row=", new_row, " and new_col=", new_col,
                          " should have been less than num_rows_=", num_rows_,
                          " and num_cols_=", num_cols_,
                          " and non-negative. This should never have happened "
                          "if the code were correct. Please file a bug.")));
          output_matrix(new_row, new_col) = loaded_tensor(row, old_col);
        }
      }
    }
    LOG(INFO) << "Copied " << rows_copied << " rows from old matrix (with "
              << tensor_shape.dim_size(0) << " rows) to new matrix (with "
              << num_rows_ << " rows).";

    // At this point, there are potentially whole rows/columns uninitialized
    // (corresponding to the indices where row_id_present/col_id_present are
    // false). We fill this in cell-by-cell using row_id_present and
    // col_id_present while dequeuing from the initializing_values vector.
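    // For example (an illustrative sketch): with num_rows_ = 2, num_cols_ = 3,
    // row_id_present = [true, false], and col_id_present = [true, true, false],
    // cells (0, 2), (1, 0), (1, 1), and (1, 2) are filled, in row-major order,
    // from the first four entries of initializing_values.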
    const Tensor* initializing_values_t;
    OP_REQUIRES_OK(
        context, context->input("initializing_values", &initializing_values_t));
    const auto initializing_values = initializing_values_t->flat<float>();
    int64_t initializing_values_index = 0;
    for (int i = 0; i < num_rows_; ++i) {
      for (int j = 0; j < num_cols_; ++j) {
        if (row_id_present[i] && col_id_present[j]) continue;
        OP_REQUIRES(
            context, initializing_values_index < initializing_values.size(),
            errors::InvalidArgument(
                "initializing_values contained ", initializing_values.size(),
                " elements, but more missing values remain."));
        output_matrix(i, j) = initializing_values(initializing_values_index);
        ++initializing_values_index;
      }
    }

    // Checks that we used all the given initializing values.
    OP_REQUIRES(
        context, initializing_values_index == initializing_values.size(),
        errors::InvalidArgument(
            "initializing_values contained ", initializing_values.size(),
            " elements, but only ", initializing_values_index,
            " elements were used to fill in missing values."));
  }

 private:
  int64_t num_rows_;
  int64_t num_cols_;
  int64_t max_rows_in_memory_;
};

REGISTER_KERNEL_BUILDER(Name("LoadAndRemapMatrix").Device(DEVICE_CPU),
                        LoadAndRemapMatrixOp);
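
// A minimal sketch of invoking this kernel through TensorFlow's generated
// Python wrapper (illustrative only; the checkpoint path and tensor name
// below are hypothetical):
//
//   remapped = tf.raw_ops.LoadAndRemapMatrix(
//       ckpt_path="/tmp/model.ckpt", old_tensor_name="old_matrix",
//       row_remapping=[1, 0], col_remapping=[], initializing_values=[],
//       num_rows=2, num_cols=2, max_rows_in_memory=-1)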

}  // namespace tensorflow