/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/util/ragged_to_dense_util.h"

namespace tensorflow {

using errors::InvalidArgument;
using shape_inference::DimensionHandle;
using shape_inference::InferenceContext;
using shape_inference::ShapeHandle;

namespace {
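// Validates the row_partition_types / row_partition_tensors attrs of
// RaggedTensorToTensor. The checks below enforce that:
//   * only FIRST_DIM_SIZE, VALUE_ROWIDS, and ROW_SPLITS are accepted;
//   * FIRST_DIM_SIZE may only appear first and must be followed by
//     VALUE_ROWIDS, while VALUE_ROWIDS may not be the first entry;
//   * num_row_partition_tensors matches the number of partition types;
//   * FIRST_DIM_SIZE tensors are scalars and all other partition tensors
//     are vectors (when their ranks are known).
// For example, {FIRST_DIM_SIZE, VALUE_ROWIDS} and {ROW_SPLITS, ROW_SPLITS}
// pass validation, while {VALUE_ROWIDS} or {ROW_SPLITS, FIRST_DIM_SIZE} are
// rejected (illustrative cases, not an exhaustive list).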
tensorflow::Status ValidateRowPartitionTypesAndShapes(
    const std::vector<RowPartitionType>& row_partition_types,
    InferenceContext* c) {
  // Note: the allowed types may be extended in the future.
  for (RowPartitionType row_partition_type : row_partition_types) {
    switch (row_partition_type) {
      case RowPartitionType::FIRST_DIM_SIZE:
      case RowPartitionType::VALUE_ROWIDS:
      case RowPartitionType::ROW_SPLITS:
        break;
      default:
        return InvalidArgument("Unsupported partition type: ",
                               RowPartitionTypeToString(row_partition_type));
    }
  }

  if (row_partition_types.empty()) {
    return InvalidArgument("Partition info types should not be empty");
  }
  for (int i = 1; i < row_partition_types.size(); ++i) {
    if (row_partition_types[i] == RowPartitionType::FIRST_DIM_SIZE) {
      return InvalidArgument("FIRST_DIM_SIZE must be first");
    }
  }
  if (row_partition_types[0] == RowPartitionType::FIRST_DIM_SIZE &&
      (row_partition_types.size() < 2 ||
       row_partition_types[1] != RowPartitionType::VALUE_ROWIDS)) {
    return InvalidArgument("FIRST_DIM_SIZE must be followed by VALUE_ROWIDS");
  }
  if (row_partition_types[0] == RowPartitionType::VALUE_ROWIDS) {
    return InvalidArgument("VALUE_ROWIDS cannot be first");
  }

  int num_row_partition_tensors;
  TF_RETURN_IF_ERROR(
      c->GetAttr("num_row_partition_tensors", &num_row_partition_tensors));
  if (num_row_partition_tensors != row_partition_types.size()) {
    return InvalidArgument(
        "Number of row partition tensors (", num_row_partition_tensors,
        ") does not equal the number of row partition types (",
        row_partition_types.size(), ").");
  }

  for (int i = 0; i < num_row_partition_tensors; ++i) {
    TensorShapeProto partition_shape;
    c->ShapeHandleToProto(c->input(3 + i), &partition_shape);
    if (partition_shape.unknown_rank()) {
      continue;
    }
    if (row_partition_types[i] == RowPartitionType::FIRST_DIM_SIZE) {
      if (partition_shape.dim_size() != 0) {
        return InvalidArgument("FIRST_DIM_SIZE must be a scalar.");
      }
    } else {
      if (partition_shape.dim_size() != 1) {
        return InvalidArgument("Row partition must be a vector.");
      }
    }
  }
  return OkStatus();
}

}  // namespace

Status RaggedTensorToSparseShapeFn(InferenceContext* c);
Status RaggedTensorToVariantShapeFn(InferenceContext* c);
Status RaggedTensorFromVariantShapeFn(InferenceContext* c);
Status RaggedTensorToVariantGradientShapeFn(InferenceContext* c);
Status RaggedTensorToTensorShapeFn(InferenceContext* c);

//==============================================================================
// Registered Ops
//==============================================================================

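// RaggedTensorToSparse converts a ragged tensor, given as RAGGED_RANK nested
// row-splits vectors plus a flat values tensor, into SparseTensor components
// (indices, values, dense_shape).
//
// Illustrative example (not part of the original source): with
//   rt_nested_splits = [[0, 2, 2, 3]] and rt_dense_values = [10, 20, 30]
// (i.e. rows [10, 20], [], [30]), the op would produce
//   sparse_indices     = [[0, 0], [0, 1], [2, 0]]
//   sparse_values      = [10, 20, 30]
//   sparse_dense_shape = [3, 2]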
REGISTER_OP("RaggedTensorToSparse")
    .Input("rt_nested_splits: RAGGED_RANK * Tsplits")
    .Input("rt_dense_values: T")
    .Output("sparse_indices: int64")
    .Output("sparse_values: T")
    .Output("sparse_dense_shape: int64")
    .Attr("RAGGED_RANK: int >= 1")
    .Attr("T: type")
    .Attr("Tsplits: {int32, int64} = DT_INT64")
    .SetShapeFn(RaggedTensorToSparseShapeFn);

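// RaggedTensorToVariant encodes a ragged tensor (nested row splits plus dense
// values) into a DT_VARIANT tensor. Per the shape function below, the output
// is a scalar variant when batched_input is false, and a vector with one
// variant per row of the outermost ragged dimension when batched_input is
// true (a summary of the registered shapes, not new behavior).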
REGISTER_OP("RaggedTensorToVariant")
    .Input("rt_nested_splits: RAGGED_RANK * Tsplits")
    .Input("rt_dense_values: Tvalues")
    .Output("encoded_ragged: variant")
    .Attr("RAGGED_RANK: int >= 0")
    .Attr("Tvalues: type")
    .Attr("Tsplits: {int32, int64} = DT_INT64")
    .Attr("batched_input: bool")
    .SetTypeConstructor(full_type::Unary(TFT_RAGGED, "Tvalues"))
    .SetShapeFn(RaggedTensorToVariantShapeFn);

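// RaggedTensorFromVariant is the inverse of RaggedTensorToVariant: it decodes
// a variant-encoded ragged tensor (or a batch of them) back into
// output_ragged_rank nested row-splits vectors plus a dense-values tensor.
// As registered below, each splits output is a vector of unknown size and the
// values output has unknown shape.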
REGISTER_OP("RaggedTensorFromVariant")
    .Input("encoded_ragged: variant")
    .Output("output_nested_splits: output_ragged_rank * Tsplits")
    .Output("output_dense_values: Tvalues")
    .Attr("input_ragged_rank: int >= -1")
    .Attr("output_ragged_rank: int >= 0")
    .Attr("Tvalues: type")
    .Attr("Tsplits: {int32, int64} = DT_INT64")
    .SetShapeFn(RaggedTensorFromVariantShapeFn);

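// RaggedTensorToVariantGradient computes the gradient flowing back to the
// dense values of RaggedTensorToVariant; its output shape is taken directly
// from the dense_values_shape input (see the shape function below).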
REGISTER_OP("RaggedTensorToVariantGradient")
    .Input("encoded_ragged_grad: variant")
    .Input("row_splits: Tsplits")
    .Input("dense_values_shape: int32")
    .Output("dense_values_grad: Tvalues")
    .Attr("Tvalues: type")
    .Attr("Tsplits: {int32, int64} = DT_INT64")
    .SetShapeFn(RaggedTensorToVariantGradientShapeFn);

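// RaggedTensorToTensor converts a ragged tensor, described by
// row_partition_tensors / row_partition_types plus a flat values tensor, into
// a dense tensor of the requested shape, padding with default_value.
//
// Illustrative example (not part of the original source): with
//   row_partition_types   = ["ROW_SPLITS"]
//   row_partition_tensors = [[0, 2, 2, 3]]
//   values                = [10, 20, 30]
//   default_value         = 0
//   shape                 = [3, 2]
// the result would be [[10, 20], [0, 0], [30, 0]].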
REGISTER_OP("RaggedTensorToTensor")
    .Attr("T: type")
    .Attr("Tindex: {int64, int32}")
    .Attr("Tshape: {int64, int32}")
    .Attr("num_row_partition_tensors: int")
    .Attr("row_partition_types: list(string)")
    .Input("shape: Tshape")
    .Input("values: T")
    .Input("default_value: T")
    .Input("row_partition_tensors: num_row_partition_tensors * Tindex")
    .Output("result: T")
    .SetShapeFn(RaggedTensorToTensorShapeFn);

//==============================================================================
// Shape Functions
//==============================================================================

Status RaggedTensorToSparseShapeFn(InferenceContext* c) {
  int64_t num_splits;
  TF_RETURN_IF_ERROR(c->GetAttr<int64_t>("RAGGED_RANK", &num_splits));
  // TODO(b/112274756): Allow ragged_rank to be 0.
  if (num_splits < 1) {
    return errors::InvalidArgument("Requires RAGGED_RANK>0");
  }
  ShapeHandle rt_dense_values = c->input(num_splits);
  TF_RETURN_IF_ERROR(c->WithRankAtLeast(rt_dense_values, 1, &rt_dense_values));

  // Check that all rt_nested_splits have rank 1.
  for (int64_t i = 0; i < num_splits; ++i) {
    ShapeHandle splits = c->input(i);
    TF_RETURN_IF_ERROR(c->WithRank(splits, 1, &splits));
  }

  DimensionHandle dense_dims =
      c->RankKnown(rt_dense_values)
          ? c->MakeDim(c->Rank(rt_dense_values) + num_splits)
          : c->UnknownDim();
  DimensionHandle num_values = c->NumElements(rt_dense_values);

  c->set_output(0, c->Matrix(num_values, dense_dims));  // indices
  c->set_output(1, c->Vector(num_values));              // values
  c->set_output(2, c->Vector(dense_dims));              // dense_shape

  return OkStatus();
}

Status RaggedTensorToVariantShapeFn(InferenceContext* c) {
  int64_t num_splits;
  TF_RETURN_IF_ERROR(c->GetAttr<int64_t>("RAGGED_RANK", &num_splits));
  bool batched;
  TF_RETURN_IF_ERROR(c->GetAttr<bool>("batched_input", &batched));
  shape_inference::ShapeHandle rt_dense_values = c->input(num_splits);
  TF_RETURN_IF_ERROR(c->WithRankAtLeast(rt_dense_values, 1, &rt_dense_values));
  for (int64_t i = 0; i < num_splits; ++i) {
    shape_inference::ShapeHandle splits = c->input(i);
    TF_RETURN_IF_ERROR(c->WithRank(splits, 1, &splits));
  }
  if (batched) {
    auto num_first_splits = c->Dim(c->input(0), 0);
    shape_inference::DimensionHandle num_rows;
    TF_RETURN_IF_ERROR(c->Subtract(num_first_splits, 1, &num_rows));
    c->set_output(0, c->Vector(num_rows));
  } else {
    c->set_output(0, c->Scalar());
  }
  return OkStatus();
}

Status RaggedTensorToVariantGradientShapeFn(InferenceContext* c) {
  ShapeHandle shape;
  TF_RETURN_IF_ERROR(
      c->MakeShapeFromShapeTensorTreatScalarAsUnknownShape(2, &shape));
  c->set_output(0, shape);
  return OkStatus();
}

Status RaggedTensorFromVariantShapeFn(InferenceContext* c) {
  int64_t input_ragged_rank;
  TF_RETURN_IF_ERROR(
      c->GetAttr<int64_t>("input_ragged_rank", &input_ragged_rank));
  int64_t output_ragged_rank;
  TF_RETURN_IF_ERROR(
      c->GetAttr<int64_t>("output_ragged_rank", &output_ragged_rank));
  shape_inference::ShapeHandle encoded_ragged = c->input(0);
  if (c->RankKnown(encoded_ragged) && input_ragged_rank >= 0) {
    shape_inference::ShapeHandle unused;
    TF_RETURN_IF_ERROR(c->WithRank(
        encoded_ragged, output_ragged_rank - input_ragged_rank, &unused));
  }
  for (int64_t i = 0; i < output_ragged_rank; i++) {
    c->set_output(i, c->UnknownShapeOfRank(1));
  }
  c->set_output(output_ragged_rank, c->UnknownShape());
  return OkStatus();
}

tensorflow::Status RaggedTensorToTensorShapeFn(InferenceContext* c) {
  TensorShapeProto shape;
  {
    ShapeHandle shape_handle;
    TF_RETURN_IF_ERROR(
        c->MakeShapeFromShapeTensorTreatScalarAsUnknownShape(0, &shape_handle));
    c->ShapeHandleToProto(shape_handle, &shape);
  }

  std::vector<RowPartitionType> row_partition_types;
  TF_RETURN_IF_ERROR(GetRowPartitionTypes(c, &row_partition_types));
  int ragged_rank = GetRaggedRank(row_partition_types);
  TF_RETURN_IF_ERROR(
      ValidateRowPartitionTypesAndShapes(row_partition_types, c));

  TensorShapeProto value_shape;
  c->ShapeHandleToProto(c->input(1), &value_shape);

  TensorShapeProto default_value_shape;
  c->ShapeHandleToProto(c->input(2), &default_value_shape);

  TF_RETURN_IF_ERROR(
      ValidateDefaultValueShape(default_value_shape, value_shape));

  // TODO(martinz): Theoretically, we could check the first dimension of
  // value_shape against the first dimension of the last row_partition_tensor
  // assuming it is a VALUE_ROWIDS type.
  // TODO(martinz): Although we normally don't know the first dimension of the
  // output, we could infer it from the first dimension of the first
  // row_partition_tensor if it is ROW_SPLITS type.
  // TODO(martinz): If the shape is provided, but the value_shape has missing
  // dimensions, we can check the default_value_shape against the shape.
  TensorShapeProto output_shape;
  TF_RETURN_IF_ERROR(CombineRaggedTensorToTensorShapes(
      ragged_rank, shape, value_shape, &output_shape));

  ShapeHandle output_shape_handle;
  TF_RETURN_IF_ERROR(
      c->MakeShapeFromShapeProto(output_shape, &output_shape_handle));
  c->set_output(0, output_shape_handle);
  return OkStatus();
}

}  // namespace tensorflow