xref: /aosp_15_r20/external/tensorflow/tensorflow/core/kernels/ragged_tensor_from_variant_op.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include <utility>
16 #include <vector>
17 
18 #include "tensorflow/core/framework/op_kernel.h"
19 #include "tensorflow/core/framework/register_types.h"
20 #include "tensorflow/core/framework/tensor.h"
21 #include "tensorflow/core/framework/variant.h"
22 #include "tensorflow/core/framework/variant_encode_decode.h"
23 #include "tensorflow/core/kernels/ragged_tensor_variant.h"
24 #include "tensorflow/core/lib/core/errors.h"
25 #include "tensorflow/core/lib/core/status.h"
26 
27 namespace tensorflow {
28 namespace {
29 
30 /* Extracts the components of the variant-encoded tensor `encoded_variant`
31  * into a flat vector of `RaggedTensorVariant` objects. */
RaggedComponentsFromVariant(const Tensor & encoded_variant,int input_ragged_rank,int output_ragged_rank,DataType value_dtype,DataType split_dtype,std::vector<RaggedTensorVariant> * decoded_ragged)32 Status RaggedComponentsFromVariant(
33     const Tensor& encoded_variant, int input_ragged_rank,
34     int output_ragged_rank, DataType value_dtype, DataType split_dtype,
35     std::vector<RaggedTensorVariant>* decoded_ragged) {
36   const auto& flat_variants = encoded_variant.flat<Variant>();
37   decoded_ragged->reserve(flat_variants.size());
38 
39   for (int i = 0; i < flat_variants.size(); i++) {
40     const auto& flat_variant = flat_variants(i);
41     const RaggedTensorVariant* decoded =
42         flat_variant.get<RaggedTensorVariant>();
43     if (decoded == nullptr) {
44       return errors::InvalidArgument(
45           "Input Variant element at index ", i,
46           " doesn't hold a RaggedTensorVariant: ", flat_variant.DebugString());
47     }
48     decoded_ragged->push_back(*decoded);
49     decoded = &decoded_ragged->back();
50     // Check ragged rank & types
51     if (decoded->ragged_rank() != input_ragged_rank) {
52       return errors::InvalidArgument(
53           "Encoded input RaggedTensorVariant has ragged_rank=",
54           decoded->ragged_rank(), ".  Expected ragged_rank=", input_ragged_rank,
55           ".");
56     }
57     if (decoded->values().dtype() != value_dtype) {
58       return errors::InvalidArgument(
59           "Expected values Tensor dtype: ", DataTypeString(value_dtype),
60           ", found: ", DataTypeString(decoded->values().dtype()));
61     }
62     if (decoded->values().dims() < 1 && output_ragged_rank != 0) {
63       return errors::InvalidArgument(
64           "Ragged values must have rank >= 1; encoded scalar element at index ",
65           i, " has values Tensor: ", decoded->values().DebugString());
66     }
67     for (const auto& splits : decoded->nested_splits()) {
68       if (splits.dtype() != split_dtype) {
69         return errors::InvalidArgument(
70             "Expected row_splits Tensor dtype: ", DataTypeString(split_dtype),
71             ", found: ", DataTypeString(splits.dtype()));
72       }
73       if (splits.dims() != 1) {
74         return errors::InvalidArgument(
75             "Ragged splits must have rank 1; encoded scalar element at index ",
76             i, " has splits Tensor ", splits.DebugString());
77       }
78     }
79   }
80   return OkStatus();
81 }
82 
83 /* Takes a set of RaggedTensorVariants for non-ragged tensors, stacks
84  * their flat_values, and sets output_ragged's flat_values to that stacked
85  * value.  I.e.:
86  *
87  * output_ragged.values = stack([c.values for c in ragged_components])
88  *
89  * Requires that elements of `ragged_components` have no splits.
90  *
91  * This should only be used when input_ragged_rank=0 and output_ragged_rank=0.
92  */
93 template <typename VALUE_TYPE>
StackNonRaggedTensors(const std::vector<RaggedTensorVariant> & ragged_components,RaggedTensorVariant * output_ragged)94 Status StackNonRaggedTensors(
95     const std::vector<RaggedTensorVariant>& ragged_components,
96     RaggedTensorVariant* output_ragged) {
97   if (ragged_components.empty()) {
98     output_ragged->set_values(Tensor(DataTypeToEnum<VALUE_TYPE>::value, {0}));
99     return Status::OK();
100   }
101 
102   TensorShape component_values_shape = ragged_components[0].values().shape();
103   TensorShape result_shape = component_values_shape;
104   result_shape.InsertDim(0, ragged_components.size());
105 
106   output_ragged->set_values(
107       Tensor(DataTypeToEnum<VALUE_TYPE>::value, result_shape));
108   auto output_values_flat = output_ragged->mutable_values()->flat<VALUE_TYPE>();
109   int values_index = 0;
110   for (int i = 0; i < ragged_components.size(); i++) {
111     auto& component_values = ragged_components[i].values();
112     if (component_values.shape() != component_values_shape) {
113       return errors::InvalidArgument(
114           "All flat_values must have compatible shapes.  Shape at index 0: ",
115           component_values_shape, ".  Shape at index ", i, ": ",
116           component_values.shape());
117     }
118     auto component_values_flat = component_values.flat<VALUE_TYPE>();
119     for (int j = 0; j < component_values_flat.size(); j++) {
120       output_values_flat(values_index++) = component_values_flat(j);
121     }
122   }
123   return Status::OK();
124 }
125 
// Nested-stacks `ragged_components` (each with ragged rank
// `input_ragged_rank`) along the dense dimensions described by
// `nested_dim_sizes`, producing a single RaggedTensorVariant with ragged rank
// `output_ragged_rank`.  The output's nested_splits are, in order:
//   * `dims - 1` uniform splits encoding the dense batch dimensions,
//   * one split partitioning the stacked components,
//   * `input_ragged_rank` splits stitched together from the components'
//     own splits.
// Returns InvalidArgument if the components' values are incompatible.
template <typename VALUE_TYPE, typename SPLIT_TYPE>
Status NestedStackRaggedTensors(
    const std::vector<RaggedTensorVariant>& ragged_components,
    const std::vector<int>& nested_dim_sizes, const int input_ragged_rank,
    const int output_ragged_rank, RaggedTensorVariant* output_ragged) {
  output_ragged->mutable_nested_splits()->reserve(output_ragged_rank);
  const int dims = nested_dim_sizes.size();

  // Scalar output: there are no splits at all; just stack the values.
  if (output_ragged_rank == 0) {
    if (input_ragged_rank > 0) {
      return errors::InvalidArgument(
          "Expected input_ragged_rank=0 if output_ragged_rank==0.  "
          "Got input_ragged_rank=",
          input_ragged_rank);
    }
    return StackNonRaggedTensors<VALUE_TYPE>(ragged_components, output_ragged);
  }

  // Populate first `dims - 1` splits.  These encode the uniform (dense)
  // batch dimensions: row j of dimension i starts at j * nested_dim_sizes[i+1].
  for (int i = 0; i < dims - 1; i++) {
    int dims_splits_size = nested_dim_sizes[i] + 1;
    output_ragged->append_splits(Tensor(DataTypeToEnum<SPLIT_TYPE>::value,
                                        TensorShape({dims_splits_size})));
    auto splits_vec = output_ragged->mutable_splits(i)->vec<SPLIT_TYPE>();
    int split_diff = nested_dim_sizes[i + 1];
    for (int j = 0; j < dims_splits_size; j++) {
      splits_vec(j) = j * split_diff;
    }
  }

  // Populate `dims`-th split: one row per stacked component.  A component's
  // row length is the number of rows in its first splits tensor (if ragged)
  // or the outer dimension of its values (if not).
  int splits_size = ragged_components.size() + 1;
  output_ragged->append_splits(
      Tensor(DataTypeToEnum<SPLIT_TYPE>::value, TensorShape({splits_size})));
  auto dims_splits_vec =
      output_ragged->mutable_splits(dims - 1)->vec<SPLIT_TYPE>();
  dims_splits_vec(0) = 0;
  for (int i = 0; i < ragged_components.size(); i++) {
    int split_val = ragged_components[i].values().shape().dim_size(0);
    if (input_ragged_rank != 0 && ragged_components[i].ragged_rank() > 0) {
      split_val = ragged_components[i].splits(0).NumElements() - 1;
    }
    dims_splits_vec(i + 1) = dims_splits_vec(i) + split_val;
  }

  // Populate last `input_ragged_rank` splits by concatenating the components'
  // splits at each level, shifting each component's offsets by the running
  // total (`last_split_value`) so they index into the stitched values.
  for (int i = 0; i < input_ragged_rank; i++) {
    int split_index = dims + i;
    // Each component contributes all but its leading 0 entry; +1 for the
    // shared leading 0 of the combined splits tensor.
    int split_size = 1;
    for (int j = 0; j < ragged_components.size(); j++) {
      if (!ragged_components[j].nested_splits().empty()) {
        split_size += ragged_components[j].splits(i).NumElements() - 1;
      }
    }
    output_ragged->append_splits(
        Tensor(DataTypeToEnum<SPLIT_TYPE>::value, TensorShape({split_size})));
    auto splits_vec =
        output_ragged->mutable_splits(split_index)->vec<SPLIT_TYPE>();
    splits_vec(0) = 0;
    SPLIT_TYPE last_split_value = 0;
    int index = 1;
    for (int j = 0; j < ragged_components.size(); j++) {
      if (ragged_components[j].nested_splits().empty()) {
        // Corner case: empty row. e.g [ [[x], [x]], [] ]
        continue;
      }
      auto component_splits_vec =
          ragged_components[j].splits(i).vec<SPLIT_TYPE>();
      for (int k = 1; k < component_splits_vec.size(); k++, index++) {
        splits_vec(index) = component_splits_vec(k) + last_split_value;
      }
      last_split_value = splits_vec(index - 1);
    }
  }

  // If the variant tensor input is empty, then we have no way to determine
  // the correct shape for the dense_values.  (It must have rank>=1, and its
  // outer dimension must be 0, but we don't know its shape beyond that.)
  // For now, we just use a shape of `[0]` in this case.
  // TODO(edloper): Update this op with an attribute containing information
  // about dense_values shape.  If it's `None`, then we'll probably still have
  // to use shape=[0] here, but if we have more info, then we can use it.
  // E.g., in map_fn, we may have shape info from the RaggedTensorSpec.
  TensorShape component_values_shape;
  if (ragged_components.empty()) {
    component_values_shape = TensorShape({0});
  } else {
    component_values_shape = ragged_components[0].values().shape();
  }

  // Populate values.  First compute the total outer dimension across all
  // components, checking that every component's values rank matches.
  int values_size = component_values_shape.dim_size(0);
  for (int i = 1; i < ragged_components.size(); i++) {
    if (ragged_components[i].values().dims() != component_values_shape.dims()) {
      return errors::InvalidArgument(
          "Rank of values must match for all "
          "components; values shape at index 0: ",
          component_values_shape.DebugString(), ", values shape at index ", i,
          ": ", ragged_components[i].values().shape().DebugString());
    }
    values_size += ragged_components[i].values().shape().dim_size(0);
  }
  component_values_shape.set_dim(0, values_size);
  output_ragged->set_values(
      Tensor(DataTypeToEnum<VALUE_TYPE>::value, component_values_shape));
  auto output_values_flat =
      output_ragged->mutable_values()->flat_outer_dims<VALUE_TYPE, 2>();
  int values_index = 0;

  // Inner (non-outer) dimensions must be identical across components.
  TensorShape expected_value_shape = component_values_shape;
  expected_value_shape.RemoveDim(0);

  for (int i = 0; i < ragged_components.size(); i++) {
    // Check that the flat_values tensor shape is compatible.
    TensorShape value_shape = ragged_components[i].values().shape();
    value_shape.RemoveDim(0);
    if (value_shape != expected_value_shape) {
      return errors::InvalidArgument(
          "All flat_values must have compatible shapes.  Shape at index 0: ",
          expected_value_shape, ".  Shape at index ", i, ": ", value_shape,
          ".  If you are using tf.map_fn, then you may need to specify an "
          "explicit fn_output_signature with appropriate ragged_rank, and/or "
          "convert output tensors to RaggedTensors.");
    }

    // Copy this component's rows into the stacked values, viewing both
    // tensors as 2-D (outer dim x flattened inner dims).
    auto component_values_flat =
        ragged_components[i].values().flat_outer_dims<VALUE_TYPE, 2>();
    int num_inner_elements = ragged_components[i].values().NumElements();
    if (ragged_components[i].values().dim_size(0) > 0) {
      num_inner_elements /= ragged_components[i].values().dim_size(0);
    }
    for (int j = 0; j < ragged_components[i].values().dim_size(0);
         j++, values_index++) {
      for (int k = 0; k < num_inner_elements; k++) {
        output_values_flat(values_index, k) = component_values_flat(j, k);
      }
    }
  }
  return OkStatus();
}
266 }  // namespace
267 
// Op kernel that decodes a variant Tensor of RaggedTensorVariant components
// and nested-stacks them into a single RaggedTensor, emitted as
// `output_nested_splits` (a list of `output_ragged_rank` splits tensors)
// followed by a dense values tensor.
template <typename VALUE_TYPE, typename SPLIT_TYPE>
class RaggedTensorFromVariantOp : public OpKernel {
 public:
  explicit RaggedTensorFromVariantOp(OpKernelConstruction* context)
      : OpKernel(context) {
    // input_ragged_rank may be -1, meaning "infer from output_ragged_rank
    // and the input tensor's rank" (see Compute).
    OP_REQUIRES_OK(context, context->GetAttr("input_ragged_rank",
                                             &input_ragged_rank_attr_));
    OP_REQUIRES_OK(
        context, context->GetAttr("output_ragged_rank", &output_ragged_rank_));
  }

  void Compute(OpKernelContext* context) override {
    // Read input Tensor.
    const Tensor& encoded_variant = context->input(0);
    auto input_ragged_rank_ = input_ragged_rank_attr_;

    if (input_ragged_rank_ == -1) {  // Infer input_ragged_rank_.
      input_ragged_rank_ = output_ragged_rank_ - encoded_variant.dims();
      // Scalar output: components are necessarily non-ragged.
      if (output_ragged_rank_ == 0 && input_ragged_rank_ < 0) {
        input_ragged_rank_ = 0;
      }
      OP_REQUIRES(context, input_ragged_rank_ >= 0,
                  errors::InvalidArgument(
                      "Inferred input_ragged_rank (output_ragged_rank - "
                      "encoded_variant.dims()) must be >= 0, found "
                      "output_ragged_rank: ",
                      output_ragged_rank_,
                      ", encoded_variant.dims(): ", encoded_variant.dims(),
                      ", inferred input_ragged_rank: ", input_ragged_rank_));
    }
    // Invariant: output rank = input rank + number of dense batch dims
    // (except the all-scalar case where both ranks are 0).
    OP_REQUIRES(
        context,
        (output_ragged_rank_ == 0 && input_ragged_rank_ == 0) ||
            (output_ragged_rank_ ==
             encoded_variant.dims() + input_ragged_rank_),
        errors::InvalidArgument(
            "output_ragged_rank must be equal to input_ragged_rank + "
            "encoded_ragged.dims(); output_ragged_rank: ",
            output_ragged_rank_, ", input_ragged_rank: ", input_ragged_rank_,
            ", encoded_variant.dims(): ", encoded_variant.dims(), "."));

    // Decode all variants.
    const auto value_dtype = DataTypeToEnum<VALUE_TYPE>::v();
    const auto split_dtype = DataTypeToEnum<SPLIT_TYPE>::v();
    std::vector<RaggedTensorVariant> decoded_components;
    OP_REQUIRES_OK(context,
                   RaggedComponentsFromVariant(
                       encoded_variant, input_ragged_rank_, output_ragged_rank_,
                       value_dtype, split_dtype, &decoded_components));

    // Corner case: input is a scalar.  A scalar Tensor always holds exactly
    // one element, so decoded_components has exactly one entry here.
    if (encoded_variant.dims() == 0) {
      ReturnRaggedTensor(context, decoded_components[0]);
      return;
    }

    // Nested-Stack Ragged components into a batched RaggedTensor.
    std::vector<int> encoded_dim_sizes(encoded_variant.dims(), 0);
    for (int i = 0; i < encoded_variant.dims(); i++) {
      encoded_dim_sizes[i] = encoded_variant.dim_size(i);
    }
    RaggedTensorVariant output_ragged;
    OP_REQUIRES_OK(
        context, NestedStackRaggedTensors<VALUE_TYPE, SPLIT_TYPE>(
                     decoded_components, encoded_dim_sizes, input_ragged_rank_,
                     output_ragged_rank_, &output_ragged));

    // Set output.
    ReturnRaggedTensor(context, output_ragged);
  }

 private:
  int input_ragged_rank_attr_;  // Raw attr value; -1 means "infer".
  int output_ragged_rank_;

  // Emits `ragged_tensor` to the op's outputs: each splits tensor goes to
  // the "output_nested_splits" list; values go to the output slot after the
  // last splits tensor (index == ragged_rank).
  void ReturnRaggedTensor(OpKernelContext* context,
                          const RaggedTensorVariant& ragged_tensor) {
    int ragged_rank = ragged_tensor.ragged_rank();
    OpOutputList splits_out;
    OP_REQUIRES_OK(context,
                   context->output_list("output_nested_splits", &splits_out));
    for (int i = 0; i < ragged_rank; i++) {
      splits_out.set(i, ragged_tensor.splits(i));
    }
    context->set_output(ragged_rank, ragged_tensor.values());
  }
};
355 
// Register the CPU kernel for every supported value type crossed with both
// row-splits types (int32 and int64).
#define REGISTER_KERNELS_WITH_SPLIT_TYPE(value_type, split_type)      \
  REGISTER_KERNEL_BUILDER(Name("RaggedTensorFromVariant")             \
                              .Device(DEVICE_CPU)                     \
                              .TypeConstraint<value_type>("Tvalues")  \
                              .TypeConstraint<split_type>("Tsplits"), \
                          RaggedTensorFromVariantOp<value_type, split_type>);
#define REGISTER_KERNELS(value_type)                  \
  REGISTER_KERNELS_WITH_SPLIT_TYPE(value_type, int32) \
  REGISTER_KERNELS_WITH_SPLIT_TYPE(value_type, int64_t)
TF_CALL_POD_TYPES(REGISTER_KERNELS);
TF_CALL_tstring(REGISTER_KERNELS);
TF_CALL_QUANTIZED_TYPES(REGISTER_KERNELS);
TF_CALL_quint16(REGISTER_KERNELS);
TF_CALL_qint16(REGISTER_KERNELS);
#undef REGISTER_KERNELS
#undef REGISTER_KERNELS_WITH_SPLIT_TYPE
}  // namespace tensorflow
373