/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/math_ops.cc.
#include "tensorflow/core/kernels/segment_reduction_ops_impl.h"

namespace tensorflow {
namespace internal {
// Static routines not in the templated class to reduce code size
Status ValidateSegmentReduction(OpKernelContext* context, const Tensor& input,
                                const Tensor& segment_ids) {
  if (!TensorShapeUtils::IsVectorOrHigher(input.shape())) {
    return errors::InvalidArgument("input must be at least rank 1");
  }
  if (!TensorShapeUtils::IsVector(segment_ids.shape())) {
    return errors::InvalidArgument("segment_ids should be a vector.");
  }
  const int64_t num_indices = segment_ids.NumElements();
  if (num_indices != input.dim_size(0)) {
    return errors::InvalidArgument(
        "segment_ids should be the same size as dimension 0 of"
        " input.");
  }

  return OkStatus();
}

// check routines not in the templated class to reduce code size
Status ValidateUnsortedSegmentReduction(OpKernel* op_kernel,
                                        OpKernelContext* context,
                                        const Tensor& data,
                                        const Tensor& segment_ids,
                                        const Tensor& num_segments) {
  if (!TensorShapeUtils::IsScalar(num_segments.shape())) {
    return errors::InvalidArgument(
        "num_segments should be a scalar, not shape ",
        num_segments.shape().DebugString());
  }

  if (!TensorShapeUtils::StartsWith(data.shape(), segment_ids.shape())) {
    return errors::InvalidArgument("data.shape = ", data.shape().DebugString(),
                                   " does not start with segment_ids.shape = ",
                                   segment_ids.shape().DebugString());
  }

  return OkStatus();
}
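// Shape validation for sparse segment reductions: checks the optional
// num_segments scalar (input 3), requires indices and segment_ids to be
// vectors of equal length, and requires input to have rank >= 1.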
Status ValidateSparseSegmentReduction(OpKernelContext* context,
                                      const Tensor& input,
                                      const Tensor& indices,
                                      const Tensor& segment_ids,
                                      bool has_num_segments) {
  if (has_num_segments) {
    const Tensor& num_segments_t = context->input(3);
    if (!TensorShapeUtils::IsScalar(num_segments_t.shape())) {
      return errors::InvalidArgument(
          "num_segments should be a scalar, not shape ",
          num_segments_t.shape().DebugString());
    }
    int64_t output_rows =
        internal::SubtleMustCopy(num_segments_t.dtype() == DT_INT32
                                     ? num_segments_t.scalar<int32>()()
                                     : num_segments_t.scalar<int64_t>()());
    if (output_rows < 0) {
      return errors::InvalidArgument("segment ids must be >= 0");
    }
  }

  if (!TensorShapeUtils::IsVector(indices.shape())) {
    return errors::InvalidArgument("indices should be a vector.");
  }

  if (!TensorShapeUtils::IsVector(segment_ids.shape())) {
    return errors::InvalidArgument("segment_ids should be a vector.");
  }

  const int64_t num_indices = indices.NumElements();
  if (num_indices != segment_ids.NumElements()) {
    return errors::InvalidArgument(
        "segment_ids and indices should have same size.");
  }

  if (input.dims() < 1) {
    return errors::InvalidArgument("Shape must be at least rank 1");
  }

  return OkStatus();
}

}  // namespace internal

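// Registers a CPU kernel for a sorted segment reduction op: `name` is the op
// name, `functor` the Eigen reducer, `type`/`index_type` the value and index
// dtypes, and `default_value` is written to output rows whose segment id
// never appears in segment_ids.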
#define REGISTER_CPU_KERNEL_SEGMENT(name, functor, type, index_type, \
                                    default_value)                   \
  REGISTER_KERNEL_BUILDER(                                           \
      Name(name)                                                     \
          .Device(DEVICE_CPU)                                        \
          .TypeConstraint<type>("T")                                 \
          .TypeConstraint<index_type>("Tindices"),                   \
      SegmentReductionOp<CPUDevice, type, index_type, functor, default_value>)

#define REGISTER_REAL_CPU_KERNELS(type, index_type)                            \
  REGISTER_CPU_KERNEL_SEGMENT("SegmentSum", Eigen::internal::SumReducer<type>, \
                              type, index_type, 0);                            \
  REGISTER_CPU_KERNEL_SEGMENT(                                                 \
      "SegmentMean", Eigen::internal::MeanReducer<type>, type, index_type, 0); \
  REGISTER_CPU_KERNEL_SEGMENT(                                                 \
      "SegmentProd", Eigen::internal::ProdReducer<type>, type, index_type, 1); \
  REGISTER_CPU_KERNEL_SEGMENT("SegmentMin", Eigen::internal::MinReducer<type>, \
                              type, index_type, 0);                            \
  REGISTER_CPU_KERNEL_SEGMENT("SegmentMax", Eigen::internal::MaxReducer<type>, \
                              type, index_type, 0)

#define REGISTER_COMPLEX_CPU_KERNELS(type, index_type)                         \
  REGISTER_CPU_KERNEL_SEGMENT("SegmentSum", Eigen::internal::SumReducer<type>, \
                              type, index_type, 0);                            \
  REGISTER_CPU_KERNEL_SEGMENT(                                                 \
      "SegmentMean", Eigen::internal::MeanReducer<type>, type, index_type, 0); \
  REGISTER_CPU_KERNEL_SEGMENT(                                                 \
      "SegmentProd", Eigen::internal::ProdReducer<type>, type, index_type, 1);

#define REGISTER_REAL_CPU_KERNELS_ALL(type) \
  REGISTER_REAL_CPU_KERNELS(type, int32)

#define REGISTER_COMPLEX_CPU_KERNELS_ALL(type) \
  REGISTER_COMPLEX_CPU_KERNELS(type, int32)

TF_CALL_REAL_NUMBER_TYPES(REGISTER_REAL_CPU_KERNELS_ALL);
REGISTER_COMPLEX_CPU_KERNELS_ALL(complex64);
REGISTER_COMPLEX_CPU_KERNELS_ALL(complex128);
#undef REGISTER_CPU_KERNEL_SEGMENT
#undef REGISTER_REAL_CPU_KERNELS
#undef REGISTER_COMPLEX_CPU_KERNELS
#undef REGISTER_REAL_CPU_KERNELS_ALL
#undef REGISTER_COMPLEX_CPU_KERNELS_ALL

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
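// Registers a GPU kernel for a sorted segment reduction op. Each op supplies
// a functor for the accumulator's initial value, a functor for the value
// emitted for empty segments, and the reduction functor; is_mean enables the
// final divide-by-segment-size used by SegmentMean.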
#define REGISTER_GPU_KERNEL_SORTEDSEGMENT(                            \
    name, type, index_type, initial_value_functor,                    \
    empty_segment_value_functor, reduction_kernel_functor, is_mean)   \
  REGISTER_KERNEL_BUILDER(                                            \
      Name(name)                                                      \
          .Device(DEVICE_GPU)                                         \
          .TypeConstraint<type>("T")                                  \
          .TypeConstraint<index_type>("Tindices"),                    \
      SegmentReductionGPUOp<                                          \
          type, index_type,                                           \
          functor::SegmentReductionFunctor<                           \
              type, index_type, initial_value_functor,                \
              empty_segment_value_functor, reduction_kernel_functor>, \
          is_mean>)

#define REGISTER_GPU_SORTED_KERNELS(type, index_type)                         \
  REGISTER_GPU_KERNEL_SORTEDSEGMENT("SegmentSum", type, index_type,           \
                                    functor::Zero<type>, functor::Zero<type>, \
                                    functor::Sum, /*is_mean=*/false);         \
  REGISTER_GPU_KERNEL_SORTEDSEGMENT("SegmentMean", type, index_type,          \
                                    functor::Zero<type>, functor::Zero<type>, \
                                    functor::Sum, /*is_mean=*/true);          \
  REGISTER_GPU_KERNEL_SORTEDSEGMENT("SegmentProd", type, index_type,          \
                                    functor::One<type>, functor::One<type>,   \
                                    functor::Prod, /*is_mean=*/false);        \
  REGISTER_GPU_KERNEL_SORTEDSEGMENT(                                          \
      "SegmentMin", type, index_type, functor::Highest<type>,                 \
      functor::Zero<type>, functor::Min, /*is_mean=*/false);                  \
  REGISTER_GPU_KERNEL_SORTEDSEGMENT(                                          \
      "SegmentMax", type, index_type, functor::Lowest<type>,                  \
      functor::Zero<type>, functor::Max, /*is_mean=*/false);

#define REGISTER_GPU_SORTED_KERNELS_ALL(type) \
  REGISTER_GPU_SORTED_KERNELS(type, int32)

TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_SORTED_KERNELS_ALL);
#undef REGISTER_GPU_KERNEL_SORTEDSEGMENT
#undef REGISTER_GPU_SORTED_KERNELS
#undef REGISTER_GPU_SORTED_KERNELS_ALL
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

}  // namespace tensorflow