/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <cstdint>
#include <utility>
#include <vector>

#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/variant.h"
#include "tensorflow/core/framework/variant_encode_decode.h"
#include "tensorflow/core/framework/variant_op_registry.h"
#include "tensorflow/core/kernels/concat_lib.h"
#include "tensorflow/core/kernels/ragged_tensor_variant.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/util/tensor_ops_util.h"

namespace tensorflow {
namespace {

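// Unbatches a RaggedTensorVariant with ragged_rank 0 (i.e. a plain dense
// values tensor) along its first dimension: component i receives the i-th
// slice of `batched_ragged.values()`, copied element by element.  For
// example, a [2, 3] values tensor [[1, 2, 3], [4, 5, 6]] yields two
// components with values [1, 2, 3] and [4, 5, 6].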
template <typename VALUE_TYPE>
Status UnbatchDenseZerothDim(
    const RaggedTensorVariant& batched_ragged,
    std::vector<RaggedTensorVariant>* ragged_components) {
  Tensor batched_values = batched_ragged.values();
  TensorShape values_shape = batched_values.shape();
  if (values_shape.dims() < 1) {
    return errors::InvalidArgument("Can't unbatch rank-0 tensor.");
  }
  auto num_components = values_shape.dim_size(0);
  values_shape.RemoveDim(0);
  auto num_values = values_shape.num_elements();

  ragged_components->resize(num_components);
  const auto& batched_flat = batched_values.flat<VALUE_TYPE>();

  for (auto i = decltype(num_components){}; i < num_components; i++) {
    (*ragged_components)[i].set_values(
        Tensor(DataTypeToEnum<VALUE_TYPE>::value, values_shape));
    auto ragged_component_values_flat =
        (*ragged_components)[i].mutable_values()->flat<VALUE_TYPE>();
    for (auto j = decltype(num_values){}; j < num_values; j++) {
      ragged_component_values_flat(j) = batched_flat(j + i * num_values);
    }
  }

  return OkStatus();
}

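// Unbatches a RaggedTensorVariant along its outermost ragged dimension,
// producing one RaggedTensorVariant per row of the outermost splits vector.
// With splits(0) = [0, 3, 5] and values = [1, 2, 3, 4, 5], for example, the
// result is two components with values [1, 2, 3] and [4, 5].  For
// ragged_rank > 1, each component's inner splits are rewritten so that they
// start at 0.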
template <typename VALUE_TYPE, typename SPLIT_TYPE>
Status UnbatchRaggedZerothDim(
    const RaggedTensorVariant& batched_ragged,
    std::vector<RaggedTensorVariant>* ragged_components) {
  // Set up the component Ragged Tensors.
  int ragged_rank = batched_ragged.ragged_rank();
  if (ragged_rank == 0) {
    return UnbatchDenseZerothDim<VALUE_TYPE>(batched_ragged, ragged_components);
  }

  auto batched_splits_top_vec = batched_ragged.splits(0).vec<SPLIT_TYPE>();
  auto num_components = batched_splits_top_vec.size() - 1;

  if (num_components < 0) {
    return errors::Internal("Invalid split argument.");
  }

  int num_splits = ragged_rank - 1;
  ragged_components->resize(num_components);
  for (RaggedTensorVariant& ragged_component : *ragged_components) {
    ragged_component.mutable_nested_splits()->reserve(num_splits);
  }
  const auto& batched_flat = batched_ragged.values().flat<VALUE_TYPE>();
  auto num_inner_elems = batched_ragged.values().NumElements();
  if (batched_ragged.values().dim_size(0) > 1) {
    num_inner_elems /= batched_ragged.values().dim_size(0);
  }
  TensorShape values_shape = batched_ragged.values().shape();

  // Corner case: ragged_rank == 1, e.g. [[1, 2, 3], [4, 5]]
  if (num_splits == 0) {
    for (auto i = decltype(num_components){}; i < num_components; i++) {
      auto start = batched_splits_top_vec(i);
      auto limit = batched_splits_top_vec(i + 1);
      auto num_values = limit - start;
      values_shape.set_dim(0, num_values);
      (*ragged_components)[i].set_values(
          Tensor(DataTypeToEnum<VALUE_TYPE>::value, values_shape));
      auto ragged_component_values_flat =
          (*ragged_components)[i].mutable_values()->template flat<VALUE_TYPE>();
      for (auto j = decltype(num_values * num_inner_elems){};
           j < num_values * num_inner_elems; j++) {
        ragged_component_values_flat(j) =
            batched_flat(j + start * num_inner_elems);
      }
    }
    return OkStatus();
  }

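  // General case: ragged_rank > 1.  For each component, the splits at nesting
  // level j are a contiguous slice of the batched splits at level j + 1,
  // shifted so that every component's splits begin at 0.  `index[j]` tracks
  // the next unread position in the batched level-(j + 1) splits, and the
  // last entry of a component's innermost splits gives the number of value
  // rows belonging to that component.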
  // Unbatch nested splits.
  std::vector<typename TTypes<SPLIT_TYPE>::ConstVec> batched_splits_vec;
  batched_splits_vec.reserve(ragged_rank);
  for (int i = 0; i < ragged_rank; i++) {
    batched_splits_vec.push_back(batched_ragged.splits(i).vec<SPLIT_TYPE>());
  }
  std::vector<SPLIT_TYPE> index(num_splits, 1);
  std::vector<SPLIT_TYPE> ragged_component_values_size(num_components, 0);
  for (auto i = decltype(num_components){}; i < num_components; i++) {
    std::vector<typename TTypes<SPLIT_TYPE>::Vec> ragged_component_splits_vec;
    ragged_component_splits_vec.reserve(num_splits);
    SPLIT_TYPE split_size = -1;
    for (int j = 0; j < num_splits; j++) {
      if (j == 0) {
        split_size =
            batched_splits_top_vec(i + 1) - batched_splits_top_vec(i) + 1;
      } else {
        // Update split size based on previous split.
        SPLIT_TYPE last_index = ragged_component_splits_vec[j - 1].size() - 1;
        split_size = ragged_component_splits_vec[j - 1](last_index) + 1;
      }
      (*ragged_components)[i].append_splits(
          Tensor(DataTypeToEnum<SPLIT_TYPE>::value, TensorShape({split_size})));
      ragged_component_splits_vec.push_back((*ragged_components)[i]
                                                .mutable_splits(j)
                                                ->template vec<SPLIT_TYPE>());
      SPLIT_TYPE last_split_value = batched_splits_vec[j + 1](index[j] - 1);
      ragged_component_splits_vec[j](0) = 0;
      for (SPLIT_TYPE k = 1; k < split_size; k++, index[j]++) {
        ragged_component_splits_vec[j](k) =
            batched_splits_vec[j + 1](index[j]) - last_split_value;
      }
    }
    SPLIT_TYPE last_split_size =
        ragged_component_splits_vec[num_splits - 1].size();
    ragged_component_values_size[i] =
        ragged_component_splits_vec[num_splits - 1](last_split_size - 1);
  }

  // Unbatch values.
  int64_t value_index = 0;
  for (auto i = decltype(num_components){}; i < num_components; i++) {
    SPLIT_TYPE num_values = ragged_component_values_size[i];
    values_shape.set_dim(0, num_values);
    (*ragged_components)[i].set_values(
        Tensor(DataTypeToEnum<VALUE_TYPE>::value, values_shape));
    auto ragged_component_values_flat =
        (*ragged_components)[i].mutable_values()->template flat<VALUE_TYPE>();
    for (int64_t j = 0; j < num_values * num_inner_elems; j++, value_index++) {
      ragged_component_values_flat(j) = batched_flat(value_index);
    }
  }

  return OkStatus();
}
}  // namespace

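// Kernel for the RaggedTensorToVariant op.  The input ragged tensor arrives
// as a list of row-splits tensors (`rt_nested_splits`) followed by a flat
// values tensor.  When the `batched_input` attr is false, the whole ragged
// tensor is encoded as a single scalar variant; when it is true, the input is
// unbatched along its outermost dimension and each component is encoded as
// one element of a rank-1 variant tensor.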
template <typename VALUE_TYPE, typename SPLIT_TYPE>
class RaggedTensorToVariantOp : public OpKernel {
 public:
  explicit RaggedTensorToVariantOp(OpKernelConstruction* context)
      : OpKernel(context) {
    OP_REQUIRES_OK(context, context->GetAttr("batched_input", &batched_input_));
  }

  void Compute(OpKernelContext* context) override {
    // Read ragged_splits inputs.
    OpInputList ragged_nested_splits_in;
    OP_REQUIRES_OK(context, context->input_list("rt_nested_splits",
                                                &ragged_nested_splits_in));
    const int ragged_nested_splits_len = ragged_nested_splits_in.size();
    RaggedTensorVariant batched_ragged_input;
    // Read ragged_values input.
    batched_ragged_input.set_values(context->input(ragged_nested_splits_len));
    batched_ragged_input.mutable_nested_splits()->reserve(
        ragged_nested_splits_len);
    for (int i = 0; i < ragged_nested_splits_len; i++) {
      OP_REQUIRES(context, ragged_nested_splits_in[i].dims() == 1,
                  errors::InvalidArgument("Requires nested_row_splits[", i, "]",
                                          " to be rank 1 but is rank ",
                                          ragged_nested_splits_in[i].dims()));
      batched_ragged_input.append_splits(ragged_nested_splits_in[i]);
    }

    if (!batched_input_) {
      // Encode as a Scalar Variant Tensor.
      Tensor* encoded_scalar;
      OP_REQUIRES_OK(context, context->allocate_output(0, TensorShape({}),
                                                       &encoded_scalar));
      encoded_scalar->scalar<Variant>()() = std::move(batched_ragged_input);
      return;
    }

    // Checked here instead of at input in case batched_input_ is false.
    OP_REQUIRES(context, ragged_nested_splits_len > 0,
                errors::InvalidArgument(
                    "rt_nested_splits must be a list of one or more, but "
                    "received rt_nested_splits of length 0."));

    // Unbatch the Ragged Tensor and encode the components.
    std::vector<RaggedTensorVariant> unbatched_ragged_input;
    auto batched_splits_top_vec =
        batched_ragged_input.splits(0).vec<SPLIT_TYPE>();
    int num_components = batched_splits_top_vec.size() - 1;
    OP_REQUIRES(context, num_components >= 0,
                errors::Internal("Invalid split argument."));
    OP_REQUIRES_OK(context, UnbatchRaggedZerothDim<VALUE_TYPE, SPLIT_TYPE>(
                                batched_ragged_input, &unbatched_ragged_input));

    // Bundle the encoded scalar Variant Tensors into a rank-1 Variant Tensor.
    Tensor* encoded_vector;

    // output_size will be used for calling TensorShape(int64_t ...). We
    // cannot use `auto` type here, or there will be a narrowing error.
    int64_t output_size = unbatched_ragged_input.size();
    OP_REQUIRES_OK(context,
                   context->allocate_output(0, TensorShape({output_size}),
                                            &encoded_vector));
    auto encoded_vector_t = encoded_vector->vec<Variant>();
    for (auto i = decltype(output_size){}; i < output_size; i++) {
      encoded_vector_t(i) = unbatched_ragged_input[i];
    }
  }

 private:
  bool batched_input_;
};

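// Kernel for the RaggedTensorToVariantGradient op.  Given the variant-encoded
// component gradients (input 0), the outermost row splits (input 1), and the
// shape of the original flat values (input 2), it rebuilds the gradient for
// the values tensor by concatenating the per-component value tensors.
// Components whose variant is missing (default-constructed) contribute a
// block of zeros with the corresponding number of rows.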
template <typename VALUE_TYPE, typename SPLIT_TYPE>
class RaggedTensorToVariantGradientOp : public OpKernel {
 public:
  using OpKernel::OpKernel;

  void Compute(OpKernelContext* context) override {
    // Read inputs.
    Tensor encoded_variant = context->input(0);
    Tensor row_splits = context->input(1);
    auto flat_row_splits = row_splits.flat<SPLIT_TYPE>();
    TensorShape dense_values_shape;
    OP_REQUIRES_OK(context,
                   TensorShapeUtils::MakeShape(context->input(2).vec<int32>(),
                                               &dense_values_shape));

    const auto& flat_variants = encoded_variant.flat<Variant>();

    // Get a Tensor containing the flat_values for each variant.
    std::vector<Tensor> values;
    for (int i = 0; i < flat_variants.size(); ++i) {
      if (const auto* encoded = flat_variants(i).get<RaggedTensorVariant>()) {
        values.push_back(encoded->values());
      } else {
        // Missing value: this happens if only some of the variant values
        // generated by ragged_tensor_to_variant impacted the value that we're
        // calculating the gradient for.  In this case, we will see a
        // default-constructed variant; so treat it as a zero tensor with the
        // appropriate shape.
        const auto value_dtype = DataTypeToEnum<VALUE_TYPE>::v();
        auto piece_size = flat_row_splits(i + 1) - flat_row_splits(i);
        TensorShape zeros_shape = dense_values_shape;
        zeros_shape.set_dim(0, piece_size);
        Tensor zero(value_dtype, zeros_shape);
        zero.flat<VALUE_TYPE>().setZero();
        values.push_back(zero);
      }
    }

    if (values.size() == 1) {
      // Just one flat_value tensor: return as-is.
      context->set_output(0, values[0]);
    } else {
      Tensor* out = nullptr;
      OP_REQUIRES_OK(context,
                     context->allocate_output(0, dense_values_shape, &out));
      // ConcatCPU assumes non-empty output.
      if (dense_values_shape.num_elements() == 0) return;
      // Multiple flat_values tensors: concatenate them together.
      using Piece = typename TTypes<VALUE_TYPE, 2>::Matrix;
      using ConstPiece = typename TTypes<VALUE_TYPE, 2>::ConstMatrix;
      std::vector<std::unique_ptr<ConstPiece>> pieces;
      pieces.reserve(values.size());
      for (const Tensor& t : values) {
        // ConcatCPU assumes non-empty inputs.
        if (t.NumElements() == 0) continue;
        pieces.emplace_back(
            new ConstPiece(t.shaped<VALUE_TYPE, 2>({1, t.NumElements()})));
      }
      Piece out_flat =
          out->shaped<VALUE_TYPE, 2>({1, dense_values_shape.num_elements()});
      ConcatCPU<VALUE_TYPE>(context->device(), pieces, &out_flat);
    }
  }
};

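// Register CPU kernels for both ops, for every POD, string, and quantized
// value type, with int32 and int64 row-split types.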
#define REGISTER_KERNELS_WITH_SPLIT_TYPE(value_type, split_type)            \
  REGISTER_KERNEL_BUILDER(Name("RaggedTensorToVariant")                     \
                              .Device(DEVICE_CPU)                           \
                              .TypeConstraint<value_type>("Tvalues")        \
                              .TypeConstraint<split_type>("Tsplits"),       \
                          RaggedTensorToVariantOp<value_type, split_type>); \
  REGISTER_KERNEL_BUILDER(                                                  \
      Name("RaggedTensorToVariantGradient")                                 \
          .Device(DEVICE_CPU)                                               \
          .TypeConstraint<value_type>("Tvalues")                            \
          .TypeConstraint<split_type>("Tsplits"),                           \
      RaggedTensorToVariantGradientOp<value_type, split_type>);

#define REGISTER_KERNELS(value_type)                  \
  REGISTER_KERNELS_WITH_SPLIT_TYPE(value_type, int32) \
  REGISTER_KERNELS_WITH_SPLIT_TYPE(value_type, int64_t)
TF_CALL_POD_TYPES(REGISTER_KERNELS);
TF_CALL_tstring(REGISTER_KERNELS);
TF_CALL_QUANTIZED_TYPES(REGISTER_KERNELS);
TF_CALL_quint16(REGISTER_KERNELS);
TF_CALL_qint16(REGISTER_KERNELS);
#undef REGISTER_KERNELS
#undef REGISTER_KERNELS_WITH_SPLIT_TYPE
}  // namespace tensorflow