/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15 #include "tensorflow/core/framework/common_shape_fns.h"
16 #include "tensorflow/core/framework/op.h"
17 #include "tensorflow/core/framework/shape_inference.h"
18 #include "tensorflow/core/util/ragged_to_dense_util.h"
19
20 namespace tensorflow {
21
22 using errors::InvalidArgument;
23 using shape_inference::DimensionHandle;
24 using shape_inference::InferenceContext;
25 using shape_inference::ShapeHandle;
26
27 namespace {
ValidateRowPartitionTypesAndShapes(const std::vector<RowPartitionType> & row_partition_types,InferenceContext * c)28 tensorflow::Status ValidateRowPartitionTypesAndShapes(
29 const std::vector<RowPartitionType>& row_partition_types,
30 InferenceContext* c) {
31 // Note: the allowed types may be extended in the future.
32 for (RowPartitionType row_partition_type : row_partition_types) {
33 switch (row_partition_type) {
34 case RowPartitionType::FIRST_DIM_SIZE:
35 case RowPartitionType::VALUE_ROWIDS:
36 case RowPartitionType::ROW_SPLITS:
37 break;
38 default:
39 return InvalidArgument("Unsupported partition type: ",
40 RowPartitionTypeToString(row_partition_type));
41 }
42 }
43
44 if (row_partition_types.empty()) {
45 return InvalidArgument("Partition info types should not be empty");
46 }
47 for (int i = 1; i < row_partition_types.size(); ++i) {
48 if (row_partition_types[i] == RowPartitionType::FIRST_DIM_SIZE) {
49 return InvalidArgument("FIRST_DIM_SIZE must be first");
50 }
51 }
52 if (row_partition_types[0] == RowPartitionType::FIRST_DIM_SIZE &&
53 (row_partition_types.size() < 2 ||
54 row_partition_types[1] != RowPartitionType::VALUE_ROWIDS)) {
55 return InvalidArgument("FIRST_DIM_SIZE must be followed by VALUE_ROWIDS");
56 }
57 if (row_partition_types[0] == RowPartitionType::VALUE_ROWIDS) {
58 return InvalidArgument("VALUE_ROWIDS cannot be first");
59 }
60
61 int num_row_partition_tensors;
62 TF_RETURN_IF_ERROR(
63 c->GetAttr("num_row_partition_tensors", &num_row_partition_tensors));
64 if (num_row_partition_tensors != row_partition_types.size()) {
65 return InvalidArgument(
66 "Number of row partition tensors (", num_row_partition_tensors,
67 ") does not equal the number of row partition types(",
68 row_partition_types.size(), ").");
69 }
70
71 for (int i = 0; i < num_row_partition_tensors; ++i) {
72 TensorShapeProto partition_shape;
73 c->ShapeHandleToProto(c->input(3 + i), &partition_shape);
74 if (partition_shape.unknown_rank()) {
75 continue;
76 }
77 if (row_partition_types[i] == RowPartitionType::FIRST_DIM_SIZE) {
78 if (partition_shape.dim_size() != 0) {
79 return InvalidArgument("FIRST_DIM_SIZE must be a scalar.");
80 }
81 } else {
82 if (partition_shape.dim_size() != 1) {
83 return InvalidArgument("Row partition must be a vector.");
84 }
85 }
86 }
87 return OkStatus();
88 }
89
90 } // namespace
91
92 Status RaggedTensorToSparseShapeFn(InferenceContext* c);
93 Status RaggedTensorToVariantShapeFn(InferenceContext* c);
94 Status RaggedTensorFromVariantShapeFn(InferenceContext* c);
95 Status RaggedTensorToVariantGradientShapeFn(InferenceContext* c);
96 Status RaggedTensorToTensorShapeFn(InferenceContext* c);
97
//==============================================================================
// Registered Ops
//==============================================================================
101
102 REGISTER_OP("RaggedTensorToSparse")
103 .Input("rt_nested_splits: RAGGED_RANK * Tsplits")
104 .Input("rt_dense_values: T")
105 .Output("sparse_indices: int64")
106 .Output("sparse_values: T")
107 .Output("sparse_dense_shape: int64")
108 .Attr("RAGGED_RANK: int >= 1")
109 .Attr("T: type")
110 .Attr("Tsplits: {int32, int64} = DT_INT64")
111 .SetShapeFn(RaggedTensorToSparseShapeFn);
112
113 REGISTER_OP("RaggedTensorToVariant")
114 .Input("rt_nested_splits: RAGGED_RANK * Tsplits")
115 .Input("rt_dense_values: Tvalues")
116 .Output("encoded_ragged: variant")
117 .Attr("RAGGED_RANK: int >= 0")
118 .Attr("Tvalues: type")
119 .Attr("Tsplits: {int32, int64} = DT_INT64")
120 .Attr("batched_input: bool")
121 .SetTypeConstructor(full_type::Unary(TFT_RAGGED, "Tvalues"))
122 .SetShapeFn(RaggedTensorToVariantShapeFn);
123
124 REGISTER_OP("RaggedTensorFromVariant")
125 .Input("encoded_ragged: variant")
126 .Output("output_nested_splits: output_ragged_rank * Tsplits")
127 .Output("output_dense_values: Tvalues")
128 .Attr("input_ragged_rank: int >= -1")
129 .Attr("output_ragged_rank: int >= 0")
130 .Attr("Tvalues: type")
131 .Attr("Tsplits: {int32, int64} = DT_INT64")
132 .SetShapeFn(RaggedTensorFromVariantShapeFn);
133
134 REGISTER_OP("RaggedTensorToVariantGradient")
135 .Input("encoded_ragged_grad: variant")
136 .Input("row_splits: Tsplits")
137 .Input("dense_values_shape: int32")
138 .Output("dense_values_grad: Tvalues")
139 .Attr("Tvalues: type")
140 .Attr("Tsplits: {int32, int64} = DT_INT64")
141 .SetShapeFn(RaggedTensorToVariantGradientShapeFn);
142
143 REGISTER_OP("RaggedTensorToTensor")
144 .Attr("T: type")
145 .Attr("Tindex: {int64, int32}")
146 .Attr("Tshape: {int64, int32}")
147 .Attr("num_row_partition_tensors: int")
148 .Attr("row_partition_types: list(string)")
149 .Input("shape: Tshape")
150 .Input("values: T")
151 .Input("default_value: T")
152 .Input("row_partition_tensors: num_row_partition_tensors * Tindex")
153 .Output("result: T")
154 .SetShapeFn(RaggedTensorToTensorShapeFn);
155
//==============================================================================
// Shape Functions
//==============================================================================
159
RaggedTensorToSparseShapeFn(InferenceContext * c)160 Status RaggedTensorToSparseShapeFn(InferenceContext* c) {
161 int64_t num_splits;
162 TF_RETURN_IF_ERROR(c->GetAttr<int64_t>("RAGGED_RANK", &num_splits));
163 // TODO(b/112274756): Allow ragged_rank to be 0.
164 if (num_splits < 1) {
165 return errors::InvalidArgument("Requires RAGGED_RANK>0");
166 }
167 ShapeHandle rt_dense_values = c->input(num_splits);
168 TF_RETURN_IF_ERROR(c->WithRankAtLeast(rt_dense_values, 1, &rt_dense_values));
169
170 // Check that all rt_nested_splits have rank 1.
171 for (int64_t i = 0; i < num_splits; ++i) {
172 ShapeHandle splits = c->input(i);
173 TF_RETURN_IF_ERROR(c->WithRank(splits, 1, &splits));
174 }
175
176 DimensionHandle dense_dims =
177 c->RankKnown(rt_dense_values)
178 ? c->MakeDim(c->Rank(rt_dense_values) + num_splits)
179 : c->UnknownDim();
180 DimensionHandle num_values = c->NumElements(rt_dense_values);
181
182 c->set_output(0, c->Matrix(num_values, dense_dims)); // indices
183 c->set_output(1, c->Vector(num_values)); // values
184 c->set_output(2, c->Vector(dense_dims)); // dense_shape
185
186 return OkStatus();
187 }
188
RaggedTensorToVariantShapeFn(InferenceContext * c)189 Status RaggedTensorToVariantShapeFn(InferenceContext* c) {
190 int64_t num_splits;
191 TF_RETURN_IF_ERROR(c->GetAttr<int64_t>("RAGGED_RANK", &num_splits));
192 bool batched;
193 TF_RETURN_IF_ERROR(c->GetAttr<bool>("batched_input", &batched));
194 shape_inference::ShapeHandle rt_dense_values = c->input(num_splits);
195 TF_RETURN_IF_ERROR(c->WithRankAtLeast(rt_dense_values, 1, &rt_dense_values));
196 for (int64_t i = 0; i < num_splits; ++i) {
197 shape_inference::ShapeHandle splits = c->input(i);
198 TF_RETURN_IF_ERROR(c->WithRank(splits, 1, &splits));
199 }
200 if (batched) {
201 auto num_first_splits = c->Dim(c->input(0), 0);
202 shape_inference::DimensionHandle num_rows;
203 TF_RETURN_IF_ERROR(c->Subtract(num_first_splits, 1, &num_rows));
204 c->set_output(0, c->Vector(num_rows));
205 } else {
206 c->set_output(0, c->Scalar());
207 }
208 return OkStatus();
209 }
210
RaggedTensorToVariantGradientShapeFn(InferenceContext * c)211 Status RaggedTensorToVariantGradientShapeFn(InferenceContext* c) {
212 ShapeHandle shape;
213 TF_RETURN_IF_ERROR(
214 c->MakeShapeFromShapeTensorTreatScalarAsUnknownShape(2, &shape));
215 c->set_output(0, shape);
216 return OkStatus();
217 }
218
RaggedTensorFromVariantShapeFn(InferenceContext * c)219 Status RaggedTensorFromVariantShapeFn(InferenceContext* c) {
220 int64_t input_ragged_rank;
221 TF_RETURN_IF_ERROR(
222 c->GetAttr<int64_t>("input_ragged_rank", &input_ragged_rank));
223 int64_t output_ragged_rank;
224 TF_RETURN_IF_ERROR(
225 c->GetAttr<int64_t>("output_ragged_rank", &output_ragged_rank));
226 shape_inference::ShapeHandle encoded_ragged = c->input(0);
227 if (c->RankKnown(encoded_ragged) && input_ragged_rank >= 0) {
228 shape_inference::ShapeHandle unused;
229 TF_RETURN_IF_ERROR(c->WithRank(
230 encoded_ragged, output_ragged_rank - input_ragged_rank, &unused));
231 }
232 for (int64_t i = 0; i < output_ragged_rank; i++) {
233 c->set_output(i, c->UnknownShapeOfRank(1));
234 }
235 c->set_output(output_ragged_rank, c->UnknownShape());
236 return OkStatus();
237 }
238
RaggedTensorToTensorShapeFn(InferenceContext * c)239 tensorflow::Status RaggedTensorToTensorShapeFn(InferenceContext* c) {
240 TensorShapeProto shape;
241 {
242 ShapeHandle shape_handle;
243 TF_RETURN_IF_ERROR(
244 c->MakeShapeFromShapeTensorTreatScalarAsUnknownShape(0, &shape_handle));
245 c->ShapeHandleToProto(shape_handle, &shape);
246 }
247
248 std::vector<RowPartitionType> row_partition_types;
249 TF_RETURN_IF_ERROR(GetRowPartitionTypes(c, &row_partition_types));
250 int ragged_rank = GetRaggedRank(row_partition_types);
251 TF_RETURN_IF_ERROR(
252 ValidateRowPartitionTypesAndShapes(row_partition_types, c));
253
254 TensorShapeProto value_shape;
255 c->ShapeHandleToProto(c->input(1), &value_shape);
256
257 TensorShapeProto default_value_shape;
258 c->ShapeHandleToProto(c->input(2), &default_value_shape);
259
260 TF_RETURN_IF_ERROR(
261 ValidateDefaultValueShape(default_value_shape, value_shape));
262
263 // TODO(martinz): Theoretically, we could check the first dimension of
264 // value_shape against the first dimension of the last row_partition_tensor
265 // assuming it is a VALUE_ROWIDS type.
266 // TODO(martinz): Although we normally don't know the first dimension of the
267 // output, we could infer it from the first dimension of the first
268 // row_partition_tensor if it is ROW_SPLITS type.
269 // TODO(martinz): If the shape is provided, but the value_shape has missing
270 // dimensions, we can check the default_value_shape against the shape.
271 TensorShapeProto output_shape;
272 TF_RETURN_IF_ERROR(CombineRaggedTensorToTensorShapes(
273 ragged_rank, shape, value_shape, &output_shape));
274
275 ShapeHandle output_shape_handle;
276 TF_RETURN_IF_ERROR(
277 c->MakeShapeFromShapeProto(output_shape, &output_shape_handle));
278 c->set_output(0, output_shape_handle);
279 return OkStatus();
280 }
281
282 } // namespace tensorflow
283