xref: /aosp_15_r20/external/tensorflow/tensorflow/core/ops/ragged_array_ops.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/core/framework/common_shape_fns.h"
16 #include "tensorflow/core/framework/op.h"
17 #include "tensorflow/core/framework/shape_inference.h"
18 
19 namespace tensorflow {
20 
21 using shape_inference::DimensionHandle;
22 using shape_inference::InferenceContext;
23 using shape_inference::ShapeHandle;
24 
25 Status RaggedGatherShapeFn(InferenceContext* c);
26 
27 //==============================================================================
28 // Registered Ops
29 //==============================================================================
30 
// RaggedGather op registration.
//
// The params are supplied as PARAMS_RAGGED_RANK `params_nested_splits`
// tensors (element type Tsplits) plus one `params_dense_values` tensor; the
// result is returned the same way, as OUTPUT_RAGGED_RANK
// `output_nested_splits` tensors plus one `output_dense_values` tensor.
// Shape inference is performed by RaggedGatherShapeFn, which requires
// `indices` to have rank OUTPUT_RAGGED_RANK - PARAMS_RAGGED_RANK + 1.
REGISTER_OP("RaggedGather")
    .Input("params_nested_splits: PARAMS_RAGGED_RANK * Tsplits")
    .Input("params_dense_values: Tvalues")
    .Input("indices: Tindices")
    .Output("output_nested_splits: OUTPUT_RAGGED_RANK * Tsplits")
    .Output("output_dense_values: Tvalues")
    .Attr("Tvalues: type")
    .Attr("Tindices: {int32, int64}")
    // Row-splits dtype, shared by the input and output splits; int64 unless
    // specified otherwise.
    .Attr("Tsplits: {int32, int64} = DT_INT64")
    // Number of `params_nested_splits` inputs (at least one).
    .Attr("PARAMS_RAGGED_RANK: int >= 1")
    // Number of `output_nested_splits` outputs (may be zero).
    .Attr("OUTPUT_RAGGED_RANK: int >= 0")
    .SetShapeFn(RaggedGatherShapeFn);
43 
44 REGISTER_OP("RaggedCross")
45     .Input("ragged_values: ragged_values_types")
46     .Input("ragged_row_splits: ragged_splits_types")
47     .Input("sparse_indices: Nsparse * int64")
48     .Input("sparse_values: sparse_values_types")
49     .Input("sparse_shape: Nsparse * int64")
50     .Input("dense_inputs: dense_types")
51     .Output("output_values: out_values_type")
52     .Output("output_row_splits: out_row_splits_type")
53     .Attr("Nsparse: int >= 0")
54     .Attr("input_order: string")
55     .Attr("hashed_output: bool")
56     .Attr("num_buckets: int >= 0")
57     .Attr("hash_key: int")
58     .Attr("ragged_values_types: list({int64, string}) >= 0")
59     .Attr("ragged_splits_types: list({int32, int64}) >= 0")
60     .Attr("sparse_values_types: list({int64, string}) >= 0")
61     .Attr("dense_types: list({int64, string}) >= 0")
62     .Attr("out_values_type: {int64, string}")
63     .Attr("out_row_splits_type: {int32, int64}")
__anon1fd7ceb10102(shape_inference::InferenceContext* c) 64     .SetShapeFn([](shape_inference::InferenceContext* c) {
65       std::vector<DataType> ragged_values_types;
66       std::vector<DataType> ragged_splits_types;
67       std::vector<DataType> sparse_values_types;
68       std::vector<DataType> dense_types;
69 
70       TF_RETURN_IF_ERROR(
71           c->GetAttr("ragged_values_types", &ragged_values_types));
72       TF_RETURN_IF_ERROR(
73           c->GetAttr("ragged_splits_types", &ragged_splits_types));
74       TF_RETURN_IF_ERROR(c->GetAttr("dense_types", &dense_types));
75       TF_RETURN_IF_ERROR(
76           c->GetAttr("sparse_values_types", &sparse_values_types));
77 
78       int num_ragged = ragged_values_types.size();
79       if (num_ragged != ragged_splits_types.size()) {
80         return errors::InvalidArgument(
81             "ragged values and splits must have the same length.");
82       }
83 
84       int num_sparse;
85       TF_RETURN_IF_ERROR(c->GetAttr("Nsparse", &num_sparse));
86       if (num_sparse != sparse_values_types.size()) {
87         return errors::InvalidArgument(
88             "sparse indices and values must have the same length");
89       }
90 
91       ShapeHandle out_values = c->UnknownShapeOfRank(1);
92       ShapeHandle out_splits = c->UnknownShapeOfRank(1);
93 
94       // Merge the shapes of row_splits from ragged inputs.  (This is one plus
95       // the batch size.)
96       int ragged_splits_start = num_ragged;
97       for (int i = 0; i < ragged_splits_types.size(); ++i) {
98         ShapeHandle row_splits = c->input(i + ragged_splits_start);
99         if (!c->Merge(out_splits, row_splits, &out_splits).ok()) {
100           return errors::InvalidArgument(
101               "inputs must all have the same batch dimension size.");
102         }
103       }
104 
105       // Merge the batch size of each dense input into out_splits.
106       int dense_start = num_ragged * 2 + num_sparse * 3;
107       for (int i = 0; i < dense_types.size(); ++i) {
108         ShapeHandle dense_input = c->input(i + dense_start);
109         int32 rank = c->Rank(dense_input);
110         if (rank == InferenceContext::kUnknownRank) {
111           continue;
112         } else if (rank != 2) {
113           return errors::InvalidArgument(
114               "tf.ragged.cross only supports inputs with rank=2");
115         }
116         int64_t batch_size = c->Value(c->Dim(dense_input, 0));
117         if (batch_size != InferenceContext::kUnknownDim) {
118           ShapeHandle row_splits = c->Vector(batch_size + 1);
119           if (!c->Merge(out_splits, row_splits, &out_splits).ok()) {
120             return errors::InvalidArgument(
121                 "inputs must all have the same batch dimension size.");
122           }
123         }
124       }
125 
126       c->set_output(0, out_values);
127       c->set_output(1, out_splits);
128       return OkStatus();
129     });
130 
131 //==============================================================================
132 // Shape Functions
133 //==============================================================================
134 
RaggedGatherShapeFn(InferenceContext * c)135 Status RaggedGatherShapeFn(InferenceContext* c) {
136   int num_splits;
137   int64_t PARAMS_RAGGED_RANK;
138   TF_RETURN_IF_ERROR(
139       c->GetAttr<int64_t>("PARAMS_RAGGED_RANK", &PARAMS_RAGGED_RANK));
140   TF_RETURN_IF_ERROR(c->GetAttr<int>("OUTPUT_RAGGED_RANK", &num_splits));
141 
142   // Check rank of `indices`.
143   ShapeHandle indices = c->input(PARAMS_RAGGED_RANK + 1);
144   TF_RETURN_IF_ERROR(
145       c->WithRank(indices, num_splits - PARAMS_RAGGED_RANK + 1, &indices));
146 
147   // Check that all params_nested_splits have rank 1.
148   for (int64_t i = 0; i < PARAMS_RAGGED_RANK; ++i) {
149     ShapeHandle splits = c->input(i);
150     TF_RETURN_IF_ERROR(c->WithRank(splits, 1, &splits));
151   }
152 
153   // Check that `params_dense_values` has rank>=1.
154   ShapeHandle params_dense_values = c->input(PARAMS_RAGGED_RANK);
155   TF_RETURN_IF_ERROR(
156       c->WithRankAtLeast(params_dense_values, 1, &params_dense_values));
157 
158   // Set the rank for the `splits` outputs.
159   for (int i = 0; i < num_splits; ++i) {
160     c->set_output(i, c->UnknownShapeOfRank(1));
161   }
162 
163   // Calculate the `values` shape.
164   ShapeHandle value = c->UnknownShape();
165   ShapeHandle values = c->UnknownShape();
166   TF_RETURN_IF_ERROR(c->Subshape(params_dense_values, 1, &value));
167   TF_RETURN_IF_ERROR(c->Concatenate(c->UnknownShapeOfRank(1), value, &values));
168   c->set_output(num_splits, values);
169 
170   return OkStatus();
171 }
172 
173 }  // namespace tensorflow
174