/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <cstdint>
#include <utility>
#include <vector>

#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/variant.h"
#include "tensorflow/core/framework/variant_encode_decode.h"
#include "tensorflow/core/framework/variant_op_registry.h"
#include "tensorflow/core/kernels/concat_lib.h"
#include "tensorflow/core/kernels/ragged_tensor_variant.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/util/tensor_ops_util.h"

namespace tensorflow {
namespace {

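// Unbatches a RaggedTensorVariant with ragged_rank == 0 (i.e. a uniform
// values tensor): component i receives a copy of batched_values[i, ...].
// Returns an InvalidArgument error for rank-0 values, since there is no
// outermost dimension to unbatch along.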
template <typename VALUE_TYPE>
Status UnbatchDenseZerothDim(
    const RaggedTensorVariant& batched_ragged,
    std::vector<RaggedTensorVariant>* ragged_components) {
  Tensor batched_values = batched_ragged.values();
  TensorShape values_shape = batched_values.shape();
  if (values_shape.dims() < 1) {
    return errors::InvalidArgument("Can't unbatch rank-0 tensor.");
  }
  auto num_components = values_shape.dim_size(0);
  values_shape.RemoveDim(0);
  auto num_values = values_shape.num_elements();

  ragged_components->resize(num_components);
  const auto& batched_flat = batched_values.flat<VALUE_TYPE>();

  for (auto i = decltype(num_components){}; i < num_components; i++) {
    (*ragged_components)[i].set_values(
        Tensor(DataTypeToEnum<VALUE_TYPE>::value, values_shape));
    auto ragged_component_values_flat =
        (*ragged_components)[i].mutable_values()->flat<VALUE_TYPE>();
    for (auto j = decltype(num_values){}; j < num_values; j++) {
      ragged_component_values_flat(j) = batched_flat(j + i * num_values);
    }
  }

  return OkStatus();
}

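// Unbatches `batched_ragged` along its outermost ragged dimension into one
// RaggedTensorVariant per row. The top-level row splits are consumed to slice
// the batch; each component keeps the remaining ragged_rank - 1 nested split
// vectors (rewritten to start at zero) plus its slice of the flat values.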
template <typename VALUE_TYPE, typename SPLIT_TYPE>
Status UnbatchRaggedZerothDim(
    const RaggedTensorVariant& batched_ragged,
    std::vector<RaggedTensorVariant>* ragged_components) {
  // Set up the component Ragged Tensors.
  int ragged_rank = batched_ragged.ragged_rank();
  if (ragged_rank == 0) {
    return UnbatchDenseZerothDim<VALUE_TYPE>(batched_ragged, ragged_components);
  }

  auto batched_splits_top_vec = batched_ragged.splits(0).vec<SPLIT_TYPE>();
  auto num_components = batched_splits_top_vec.size() - 1;

  if (num_components < 0) {
    return errors::Internal("Invalid split argument.");
  }

  int num_splits = ragged_rank - 1;
  ragged_components->resize(num_components);
  for (RaggedTensorVariant& ragged_component : *ragged_components) {
    ragged_component.mutable_nested_splits()->reserve(num_splits);
  }
  const auto& batched_flat = batched_ragged.values().flat<VALUE_TYPE>();
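  // num_inner_elems is the number of flat values contributed by each entry
  // along the outermost dimension of the values tensor (the product of the
  // inner dense dimensions when that dimension is non-empty).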
  auto num_inner_elems = batched_ragged.values().NumElements();
  if (batched_ragged.values().dim_size(0) > 1) {
    num_inner_elems /= batched_ragged.values().dim_size(0);
  }
  TensorShape values_shape = batched_ragged.values().shape();

  // Corner case: ragged_rank == 1, e.g. [[1, 2, 3], [4, 5]]
  if (num_splits == 0) {
    for (auto i = decltype(num_components){}; i < num_components; i++) {
      auto start = batched_splits_top_vec(i);
      auto limit = batched_splits_top_vec(i + 1);
      auto num_values = limit - start;
      values_shape.set_dim(0, num_values);
      (*ragged_components)[i].set_values(
          Tensor(DataTypeToEnum<VALUE_TYPE>::value, values_shape));
      auto ragged_component_values_flat =
          (*ragged_components)[i].mutable_values()->template flat<VALUE_TYPE>();
      for (auto j = decltype(num_values * num_inner_elems){};
           j < num_values * num_inner_elems; j++) {
        ragged_component_values_flat(j) =
            batched_flat(j + start * num_inner_elems);
      }
    }
    return OkStatus();
  }

  // Unbatch nested splits.
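  // For example, with ragged_rank == 2, splits(0) = [0, 2, 3],
  // splits(1) = [0, 2, 3, 6], and values = [1, 2, 3, 4, 5, 6] (i.e. the batch
  // [[[1, 2], [3]], [[4, 5, 6]]]), component 0 gets splits [0, 2, 3] with
  // values [1, 2, 3], and component 1 gets splits [0, 3] with values
  // [4, 5, 6]. Each component's splits are rebased so they start at zero.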
  std::vector<typename TTypes<SPLIT_TYPE>::ConstVec> batched_splits_vec;
  batched_splits_vec.reserve(ragged_rank);
  for (int i = 0; i < ragged_rank; i++) {
    batched_splits_vec.push_back(batched_ragged.splits(i).vec<SPLIT_TYPE>());
  }
  std::vector<SPLIT_TYPE> index(num_splits, 1);
  std::vector<SPLIT_TYPE> ragged_component_values_size(num_components, 0);
  for (auto i = decltype(num_components){}; i < num_components; i++) {
    std::vector<typename TTypes<SPLIT_TYPE>::Vec> ragged_component_splits_vec;
    ragged_component_splits_vec.reserve(num_splits);
    SPLIT_TYPE split_size = -1;
    for (int j = 0; j < num_splits; j++) {
      if (j == 0) {
        split_size =
            batched_splits_top_vec(i + 1) - batched_splits_top_vec(i) + 1;
      } else {
        // Update split size based on previous split.
        SPLIT_TYPE last_index = ragged_component_splits_vec[j - 1].size() - 1;
        split_size = ragged_component_splits_vec[j - 1](last_index) + 1;
      }
      (*ragged_components)[i].append_splits(
          Tensor(DataTypeToEnum<SPLIT_TYPE>::value, TensorShape({split_size})));
      ragged_component_splits_vec.push_back((*ragged_components)[i]
                                                .mutable_splits(j)
                                                ->template vec<SPLIT_TYPE>());
      SPLIT_TYPE last_split_value = batched_splits_vec[j + 1](index[j] - 1);
      ragged_component_splits_vec[j](0) = 0;
      for (SPLIT_TYPE k = 1; k < split_size; k++, index[j]++) {
        ragged_component_splits_vec[j](k) =
            batched_splits_vec[j + 1](index[j]) - last_split_value;
      }
    }
    SPLIT_TYPE last_split_size =
        ragged_component_splits_vec[num_splits - 1].size();
    ragged_component_values_size[i] =
        ragged_component_splits_vec[num_splits - 1](last_split_size - 1);
  }

  // Unbatch values.
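  // The flat values are stored in row-major order, so each component's values
  // form a contiguous slice of batched_flat; copy them out sequentially.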
  int64_t value_index = 0;
  for (auto i = decltype(num_components){}; i < num_components; i++) {
    SPLIT_TYPE num_values = ragged_component_values_size[i];
    values_shape.set_dim(0, num_values);
    (*ragged_components)[i].set_values(
        Tensor(DataTypeToEnum<VALUE_TYPE>::value, values_shape));
    auto ragged_component_values_flat =
        (*ragged_components)[i].mutable_values()->template flat<VALUE_TYPE>();
    for (int64_t j = 0; j < num_values * num_inner_elems; j++, value_index++) {
      ragged_component_values_flat(j) = batched_flat(value_index);
    }
  }

  return OkStatus();
}
}  // namespace

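// Kernel for the RaggedTensorToVariant op. When `batched_input` is false, the
// input RaggedTensor is encoded as a single scalar variant. When it is true,
// the input is treated as a batch: it is unbatched along its outermost
// dimension and each component is encoded into one element of a rank-1
// variant tensor.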
template <typename VALUE_TYPE, typename SPLIT_TYPE>
class RaggedTensorToVariantOp : public OpKernel {
 public:
  explicit RaggedTensorToVariantOp(OpKernelConstruction* context)
      : OpKernel(context) {
    OP_REQUIRES_OK(context, context->GetAttr("batched_input", &batched_input_));
  }

  void Compute(OpKernelContext* context) override {
    // Read ragged_splits inputs.
    OpInputList ragged_nested_splits_in;
    OP_REQUIRES_OK(context, context->input_list("rt_nested_splits",
                                                &ragged_nested_splits_in));
    const int ragged_nested_splits_len = ragged_nested_splits_in.size();
    RaggedTensorVariant batched_ragged_input;
    // Read ragged_values input.
    batched_ragged_input.set_values(context->input(ragged_nested_splits_len));
    batched_ragged_input.mutable_nested_splits()->reserve(
        ragged_nested_splits_len);
    for (int i = 0; i < ragged_nested_splits_len; i++) {
      OP_REQUIRES(context, ragged_nested_splits_in[i].dims() == 1,
                  errors::InvalidArgument("Requires nested_row_splits[", i, "]",
                                          " to be rank 1 but is rank ",
                                          ragged_nested_splits_in[i].dims()));
      batched_ragged_input.append_splits(ragged_nested_splits_in[i]);
    }

    if (!batched_input_) {
      // Encode as a Scalar Variant Tensor.
      Tensor* encoded_scalar;
      OP_REQUIRES_OK(context, context->allocate_output(0, TensorShape({}),
                                                       &encoded_scalar));
      encoded_scalar->scalar<Variant>()() = std::move(batched_ragged_input);
      return;
    }

    // Checked here instead of at input in case batched_input_ is false.
    OP_REQUIRES(context, ragged_nested_splits_len > 0,
                errors::InvalidArgument(
                    "rt_nested_splits must be a list of one or more, but "
                    "received rt_nested_splits of length 0."));

    // Unbatch the Ragged Tensor and encode the components.
    std::vector<RaggedTensorVariant> unbatched_ragged_input;
    auto batched_splits_top_vec =
        batched_ragged_input.splits(0).vec<SPLIT_TYPE>();
    int num_components = batched_splits_top_vec.size() - 1;
    OP_REQUIRES(context, num_components >= 0,
                errors::Internal("Invalid split argument."));
    OP_REQUIRES_OK(context, UnbatchRaggedZerothDim<VALUE_TYPE, SPLIT_TYPE>(
                                batched_ragged_input, &unbatched_ragged_input));

    // Bundle the encoded scalar Variant Tensors into a rank-1 Variant Tensor.
    Tensor* encoded_vector;

    // output_size will be used for calling TensorShape(int64_t ...). We
    // cannot use `auto` type here, or there will be a narrowing error.
    int64_t output_size = unbatched_ragged_input.size();
    OP_REQUIRES_OK(context,
                   context->allocate_output(0, TensorShape({output_size}),
                                            &encoded_vector));
    auto encoded_vector_t = encoded_vector->vec<Variant>();
    for (auto i = decltype(output_size){}; i < output_size; i++) {
      encoded_vector_t(i) = unbatched_ragged_input[i];
    }
  }

 private:
  bool batched_input_;
};

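// Gradient kernel for RaggedTensorToVariant: rebuilds the gradient w.r.t. the
// dense values input by concatenating the flat_values of each encoded variant
// component, substituting zero tensors (sized from the row splits) for
// components whose gradient is missing.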
template <typename VALUE_TYPE, typename SPLIT_TYPE>
class RaggedTensorToVariantGradientOp : public OpKernel {
 public:
  using OpKernel::OpKernel;

  void Compute(OpKernelContext* context) override {
    // Read inputs.
    Tensor encoded_variant = context->input(0);
    Tensor row_splits = context->input(1);
    auto flat_row_splits = row_splits.flat<SPLIT_TYPE>();
    TensorShape dense_values_shape;
    OP_REQUIRES_OK(context,
                   TensorShapeUtils::MakeShape(context->input(2).vec<int32>(),
                                               &dense_values_shape));

    const auto& flat_variants = encoded_variant.flat<Variant>();

    // Get a Tensor containing the flat_values for each variant.
    std::vector<Tensor> values;
    for (int i = 0; i < flat_variants.size(); ++i) {
      if (const auto* encoded = flat_variants(i).get<RaggedTensorVariant>()) {
        values.push_back(encoded->values());
      } else {
        // Missing value: this happens if only some of the variant values
        // generated by ragged_tensor_to_variant impacted the value that we're
        // calculating the gradient for. In this case, we will see a
        // default-constructed variant; so treat it as a zero tensor with the
        // appropriate shape.
        const auto value_dtype = DataTypeToEnum<VALUE_TYPE>::v();
        auto piece_size = flat_row_splits(i + 1) - flat_row_splits(i);
        TensorShape zeros_shape = dense_values_shape;
        zeros_shape.set_dim(0, piece_size);
        Tensor zero(value_dtype, zeros_shape);
        zero.flat<VALUE_TYPE>().setZero();
        values.push_back(zero);
      }
    }

    if (values.size() == 1) {
      // Just one flat_value tensor: return as-is.
      context->set_output(0, values[0]);
    } else {
      Tensor* out = nullptr;
      OP_REQUIRES_OK(context,
                     context->allocate_output(0, dense_values_shape, &out));
      // ConcatCPU assumes non-empty output.
      if (dense_values_shape.num_elements() == 0) return;
      // Multiple flat_values tensors: concatenate them together.
      using Piece = typename TTypes<VALUE_TYPE, 2>::Matrix;
      using ConstPiece = typename TTypes<VALUE_TYPE, 2>::ConstMatrix;
      std::vector<std::unique_ptr<ConstPiece>> pieces;
      pieces.reserve(values.size());
      for (const Tensor& t : values) {
        // ConcatCPU assumes non-empty inputs.
        if (t.NumElements() == 0) continue;
        pieces.emplace_back(
            new ConstPiece(t.shaped<VALUE_TYPE, 2>({1, t.NumElements()})));
      }
      Piece out_flat =
          out->shaped<VALUE_TYPE, 2>({1, dense_values_shape.num_elements()});
      ConcatCPU<VALUE_TYPE>(context->device(), pieces, &out_flat);
    }
  }
};

#define REGISTER_KERNELS_WITH_SPLIT_TYPE(value_type, split_type)            \
  REGISTER_KERNEL_BUILDER(Name("RaggedTensorToVariant")                     \
                              .Device(DEVICE_CPU)                           \
                              .TypeConstraint<value_type>("Tvalues")        \
                              .TypeConstraint<split_type>("Tsplits"),       \
                          RaggedTensorToVariantOp<value_type, split_type>); \
  REGISTER_KERNEL_BUILDER(                                                  \
      Name("RaggedTensorToVariantGradient")                                 \
          .Device(DEVICE_CPU)                                               \
          .TypeConstraint<value_type>("Tvalues")                            \
          .TypeConstraint<split_type>("Tsplits"),                           \
      RaggedTensorToVariantGradientOp<value_type, split_type>);

#define REGISTER_KERNELS(value_type)                  \
  REGISTER_KERNELS_WITH_SPLIT_TYPE(value_type, int32) \
  REGISTER_KERNELS_WITH_SPLIT_TYPE(value_type, int64_t)
TF_CALL_POD_TYPES(REGISTER_KERNELS);
TF_CALL_tstring(REGISTER_KERNELS);
TF_CALL_QUANTIZED_TYPES(REGISTER_KERNELS);
TF_CALL_quint16(REGISTER_KERNELS);
TF_CALL_qint16(REGISTER_KERNELS);
#undef REGISTER_KERNELS
#undef REGISTER_KERNELS_WITH_SPLIT_TYPE
}  // namespace tensorflow