// aten/src/ATen/native/SegmentReduce.cpp
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/native/SegmentReduce.h>

#include <ATen/core/Tensor.h>
#include <ATen/Dispatch.h>
#include <ATen/NumericUtils.h>
#include <ATen/TensorOperators.h>
#include <c10/util/irange.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/_segment_reduce_backward_native.h>
#include <ATen/ops/all.h>
#include <ATen/ops/empty.h>
#include <ATen/ops/segment_reduce_native.h>
#include <ATen/ops/zeros.h>
#endif

namespace at::native {

DEFINE_DISPATCH(_segment_reduce_lengths_stub);
DEFINE_DISPATCH(_segment_reduce_offsets_stub);
DEFINE_DISPATCH(_segment_reduce_lengths_backward_stub);
DEFINE_DISPATCH(_segment_reduce_offsets_backward_stub);

namespace {

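// Shared forward kernel for the "lengths" and "offsets" layouts. When
// is_offsets_like is false, lengths_data holds per-segment lengths and the
// segment boundaries are accumulated as a running sum; when it is true,
// lengths_data holds segment_count + 1 offsets and segment i spans
// [offsets[i], offsets[i + 1]). The loops walk outer dims x segments x
// inner dims and reduce data along `axis` into `output`.
// Illustrative example (sketch, not exercised here): for a 1-D tensor of
// 5 elements split into two segments, lengths = {2, 3} and the equivalent
// offsets = {0, 2, 5} describe the same segmentation.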
template <typename T, bool is_offsets_like=false>
void _segment_reduce_lengths_cpu_kernel1(
    ReductionType reduction,
    const Tensor& data,
    const T* lengths_data,
    int64_t axis,
    const std::optional<Scalar>& initial,
    Tensor& output,
    int64_t segment_count,
    int64_t lengths_stride_axis) {
  // outer_offset is the size of the outer dimensions of output (before axis)
  // inner_offset is the size of the inner dimensions of output (after axis)
  int64_t outer_offset = 1, inner_offset = 1;
  for (int64_t d = 0; d < axis; d++)
      outer_offset *= output.size(d);
  for (int64_t d = axis + 1; d < output.dim(); d++)
      inner_offset *= output.size(d);
  int64_t lengths_size_axis = is_offsets_like ? segment_count + 1 : segment_count;
  auto data_stride_axis = data.stride(axis);
  auto data_size_axis = data.size(axis);
  auto output_stride_axis = output.stride(axis);
  auto output_size_axis = output.size(axis);
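  // The dispatch body below addresses both tensors with flat indices of the
  // form outer_idx * stride(axis) * size(axis) + pos * stride(axis) + inner_idx,
  // which is valid because data and output are contiguous here.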
  AT_DISPATCH_FLOATING_TYPES_AND2(
      kBFloat16, kHalf, data.scalar_type(), "_segment_reduce_cpu", [&]() {
        auto* output_data = output.data_ptr<scalar_t>();
        const auto* values_data = data.const_data_ptr<scalar_t>();
        for (const auto outer_idx : c10::irange(outer_offset)) {
          int64_t segment_start, segment_length;
          int64_t segment_end = is_offsets_like ?
                                lengths_data[outer_idx * lengths_stride_axis * lengths_size_axis] :
                                0;
          for (const auto dim_idx : c10::irange(segment_count)) {
            segment_start = segment_end;
            auto lengths_idx = outer_idx * lengths_stride_axis * lengths_size_axis + dim_idx;
            if (is_offsets_like) {
              segment_end = lengths_data[lengths_idx + 1];
              segment_length = segment_end - segment_start;
            } else {
              segment_length = lengths_data[lengths_idx];
              segment_end += segment_length;
            }
            for (const auto inner_idx : c10::irange(inner_offset)) {
              // ===== step1: initialize starting value
              scalar_t initial_value;
              if (initial.has_value()) {
                initial_value = initial.value().to<scalar_t>();
              } else if (reduction == ReductionType::MAX) {
                initial_value = -std::numeric_limits<scalar_t>::infinity();
              } else if (
                  reduction == ReductionType::MEAN ||
                  reduction == ReductionType::SUM) {
                initial_value = 0;
              } else if (reduction == ReductionType::MIN) {
                initial_value = std::numeric_limits<scalar_t>::infinity();
              } else if (reduction == ReductionType::PROD) {
                initial_value = 1;
              }

              // ===== step2: apply reduction
              for (const auto j : c10::irange(segment_start, segment_end)) {
                int64_t data_index = outer_idx * data_stride_axis * data_size_axis
                                     + j * data_stride_axis + inner_idx;
                const auto val = values_data[data_index];
                if (reduction == ReductionType::MAX) {
                  initial_value = at::_isnan(val)
                      ? val
                      : std::max<scalar_t>(initial_value, val);
                } else if (
                    reduction == ReductionType::MEAN ||
                    reduction == ReductionType::SUM) {
                  initial_value = initial_value + val;
                } else if (reduction == ReductionType::MIN) {
                  initial_value = at::_isnan(val)
                      ? val
                      : std::min<scalar_t>(initial_value, val);
                } else if (reduction == ReductionType::PROD) {
                  initial_value = initial_value * val;
                }
              }

              // ===== step3: finalize reduction
              TORCH_CHECK(segment_length >= 0);

              if (segment_length == 0 && !initial.has_value() &&
                  reduction == ReductionType::MEAN) {
                initial_value = static_cast<scalar_t>(NAN);
              } else if (
                  reduction == ReductionType::MEAN &&
                  segment_length > 0 && !at::_isnan(initial_value)) {
                initial_value = initial_value / segment_length;
              }
              int64_t output_index = outer_idx * output_stride_axis * output_size_axis
                                     + dim_idx * output_stride_axis + inner_idx;
              output_data[output_index] = initial_value;
            }
          }
        }
      });
}

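// CPU entry point for the lengths layout: builds an output whose size along
// the reduction axis equals the number of segments (lengths.size(-1)) and
// dispatches on the index dtype of `lengths`.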
Tensor _segment_reduce_lengths_cpu_kernel(
    ReductionType reduction,
    const Tensor& data,
    const Tensor& lengths,
    int64_t axis,
    const std::optional<Scalar>& initial) {
  // data and lengths should be contiguous from the call to .contiguous in segment_reduce_kernel
  TORCH_CHECK(data.is_contiguous(), "Expected data to be contiguous.");
  TORCH_CHECK(lengths.is_contiguous(), "Expected lengths to be contiguous.");
  // reduction axis should always be the last dimension of lengths
  axis = lengths.dim() - 1;
  int64_t segment_count = lengths.size(axis);
  int64_t lengths_stride_axis = lengths.stride(axis);
  auto output_shape = data.sizes().vec();
  output_shape[axis] = segment_count;
  auto output = at::empty(output_shape, data.options());

  AT_DISPATCH_INDEX_TYPES(lengths.scalar_type(), "_segment_reduce_lengths_cpu_kernel1", [&]() {
    const auto* lengths_data = lengths.const_data_ptr<index_t>();
    _segment_reduce_lengths_cpu_kernel1(
        reduction, data, lengths_data, axis, initial, output, segment_count, lengths_stride_axis);
  });

  return output;
}

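// CPU entry point for the offsets layout: offsets.size(-1) is segment_count + 1,
// so segment i covers [offsets[..., i], offsets[..., i + 1]) along the
// reduction axis.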
Tensor _segment_reduce_offsets_cpu_kernel(
    ReductionType reduction,
    const Tensor& data,
    const Tensor& offsets,
    int64_t axis,
    const std::optional<Scalar>& initial) {
  // data and offsets should be contiguous from the call to .contiguous in segment_reduce_kernel
  TORCH_CHECK(data.is_contiguous(), "Expected data to be contiguous.");
  TORCH_CHECK(offsets.is_contiguous(), "Expected offsets to be contiguous.");
  // reduction axis should always be the last dimension of offsets
  axis = offsets.dim() - 1;
  int64_t segment_count = offsets.size(axis) - 1;
  int64_t offsets_stride_axis = offsets.stride(axis);
  auto output_shape = data.sizes().vec();
  output_shape[axis] = segment_count;
  auto output = at::empty(output_shape, data.options());

  AT_DISPATCH_INDEX_TYPES(offsets.scalar_type(), "_segment_reduce_offsets_cpu_kernel1", [&]() {
    const auto* offsets_data = offsets.const_data_ptr<index_t>();
    _segment_reduce_lengths_cpu_kernel1<index_t, /*is_offsets_like=*/true>(
        reduction, data, offsets_data, axis, initial, output, segment_count, offsets_stride_axis);
  });

  return output;
}

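// Shared backward kernel for the lengths and offsets layouts. Per reduction:
//   MAX/MIN: the segment's gradient is written to every element equal to the
//            forward output (or NaN); when more than one element matches, the
//            matching (positive) entries are divided by the match count.
//   MEAN:    each element of the segment receives grad / segment_length.
//   SUM:     each element receives the segment's gradient unchanged.
//   PROD:    each element receives grad * output / x; when x is 0 or NaN the
//            exclusive product over the rest of the segment is computed
//            explicitly instead.
// Illustrative example (sketch, not exercised here): for MAX over the segment
// {3, 3, 1} with upstream gradient 1, the tie on the two 3s yields
// grad_input = {0.5, 0.5, 0}.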
template <typename T, bool is_offsets_like = false>
void _segment_reduce_cpu_lengths_backward_kernel1(
    const Tensor& grad_contig,
    const Tensor& output_contig,
    const Tensor& data_contig,
    ReductionType reduction,
    const T* lengths_data,
    int64_t axis,
    const std::optional<Scalar>& initial,
    Tensor& grad_input,
    int64_t segment_count,
    int64_t lengths_stride_axis) {
  // outer_offset is the size of the outer dimensions of output (before axis)
  // inner_offset is the size of the inner dimensions of output (after axis)
  int64_t outer_offset = 1, inner_offset = 1;
  for (int64_t d = 0; d < axis; d++)
      outer_offset *= output_contig.size(d);
  for (int64_t d = axis + 1; d < output_contig.dim(); d++)
      inner_offset *= output_contig.size(d);
  int64_t lengths_size_axis = is_offsets_like ? segment_count + 1 : segment_count;
  auto data_stride_axis = data_contig.stride(axis);
  auto data_size_axis = data_contig.size(axis);
  auto output_stride_axis = output_contig.stride(axis);
  auto output_size_axis = output_contig.size(axis);
  // TODO: Switch to TensorIterator for better maintainability and
  // readability
  AT_DISPATCH_FLOATING_TYPES_AND2(
      kBFloat16,
      kHalf,
      data_contig.scalar_type(),
      "_segment_reduce_cpu",
      [&]() {
        auto* output_data = output_contig.const_data_ptr<scalar_t>();
        auto* grad_data = grad_contig.const_data_ptr<scalar_t>();
        auto* grad_input_data = grad_input.mutable_data_ptr<scalar_t>();
        const auto* values_data = data_contig.const_data_ptr<scalar_t>();
        // Used to calculate exclusive prod
        scalar_t initial_prod_value;
        if (reduction == ReductionType::PROD) {
          if (initial.has_value()) {
            initial_prod_value = initial.value().to<scalar_t>();
          } else {
            initial_prod_value = 1;
          }
        }

        for (const auto outer_idx : c10::irange(outer_offset)) {
          // int64_t lengths_cum_sum = 0;
          int64_t segment_start, segment_length;
          int64_t segment_end = is_offsets_like ?
                                lengths_data[outer_idx * lengths_stride_axis * lengths_size_axis] :
                                0;
          for (const auto dim_idx : c10::irange(segment_count)) {
            // int64_t segment_length = lengths_data[outer_idx * lengths_stride_axis * segment_count + dim_idx];
            segment_start = segment_end;
            auto lengths_idx = outer_idx * lengths_stride_axis * lengths_size_axis + dim_idx;
            if (is_offsets_like) {
              segment_end = lengths_data[lengths_idx + 1];
              segment_length = segment_end - segment_start;
            } else {
              segment_length = lengths_data[lengths_idx];
              segment_end += segment_length;
            }
            if (segment_length == 0) {
              continue;
            }
            for (const auto inner_idx : c10::irange(inner_offset)) {
              int64_t output_index = outer_idx * output_stride_axis * output_size_axis
                                     + dim_idx * output_stride_axis + inner_idx;
              if (reduction == ReductionType::MAX ||
                  reduction == ReductionType::MIN) {
                int64_t counter = 0;
                for (const auto j : c10::irange(segment_start, segment_end)) {
                  int64_t data_index = outer_idx * data_stride_axis * data_size_axis
                                       + j * data_stride_axis + inner_idx;
                  if (at::_isnan(values_data[data_index]) ||
                      values_data[data_index] == output_data[output_index]) {
                    grad_input_data[data_index] = grad_data[output_index];
                    counter++;
                  }
                }
                // Average gradient based on the number of maximum (or
                // minimum) elements in the segment
                if (counter < 2) {
                  continue;
                }
                for (const auto j : c10::irange(segment_start, segment_end)) {
                  int64_t data_index = outer_idx * data_stride_axis * data_size_axis
                                       + j * data_stride_axis + inner_idx;
                  if (grad_input_data[data_index] > 0) {
                    grad_input_data[data_index] =
                        grad_input_data[data_index] / counter;
                  }
                }
              } else if (reduction == ReductionType::MEAN) {
                auto grad_val = grad_data[output_index] / segment_length;
                for (const auto j : c10::irange(segment_start, segment_end)) {
                  int64_t data_index = outer_idx * data_stride_axis * data_size_axis
                                       + j * data_stride_axis + inner_idx;
                  grad_input_data[data_index] = grad_val;
                }
              } else if (reduction == ReductionType::SUM) {
                const auto& grad_val = grad_data[output_index];
                for (const auto j : c10::irange(segment_start, segment_end)) {
                  int64_t data_index = outer_idx * data_stride_axis * data_size_axis
                                       + j * data_stride_axis + inner_idx;
                  grad_input_data[data_index] = grad_val;
                }
              } else if (reduction == ReductionType::PROD) {
                const auto& grad_val = grad_data[output_index] * output_data[output_index];
                for (const auto j : c10::irange(segment_start, segment_end)) {
                  int64_t data_index = outer_idx * data_stride_axis * data_size_axis
                                       + j * data_stride_axis + inner_idx;
                  if (at::_isnan(values_data[data_index]) ||
                      values_data[data_index] == 0) {
                    // explicitly compute exclusive prod
                    scalar_t exclusive_prod = initial_prod_value;
                    int64_t idx;
                    for (const auto k : c10::irange(segment_start, segment_end)) {
                      if (k != j) {
                        idx = outer_idx * data_stride_axis * data_size_axis
                              + k * data_stride_axis + inner_idx;
                        exclusive_prod *= values_data[idx];
                      }
                    }
                    grad_input_data[data_index] = grad_data[output_index] * exclusive_prod;
                  } else {
                    grad_input_data[data_index] = grad_val / values_data[data_index];
                  }
                }
              }
            }
          }
        }
      });
}

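// Backward entry point for the lengths layout: grad_input starts out as zeros
// shaped like data, so segments of length zero (skipped above) contribute no
// gradient.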
Tensor _segment_reduce_cpu_lengths_backward_kernel(
    const Tensor& grad_contig,
    const Tensor& output_contig,
    const Tensor& data_contig,
    ReductionType reduction,
    const Tensor& lengths_contig,
    int64_t axis,
    const std::optional<Scalar>& initial) {
  axis = lengths_contig.dim() - 1;
  int64_t segment_count = lengths_contig.size(axis);
  int64_t lengths_stride_axis = lengths_contig.stride(axis);
  auto grad_input = at::zeros({data_contig.sizes()}, grad_contig.options());

  AT_DISPATCH_INDEX_TYPES(
      lengths_contig.scalar_type(), "_segment_reduce_cpu_lengths_backward_kernel1", [&] {
        const auto* lengths_data = lengths_contig.const_data_ptr<index_t>();
        _segment_reduce_cpu_lengths_backward_kernel1(
            grad_contig,
            output_contig,
            data_contig,
            reduction,
            lengths_data,
            axis,
            initial,
            grad_input,
            segment_count,
            lengths_stride_axis);
      });

  return grad_input;
}


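// Backward entry point for the offsets layout; reuses the shared backward
// kernel with is_offsets_like = true.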
Tensor _segment_reduce_cpu_offsets_backward_kernel(
    const Tensor& grad_contig,
    const Tensor& output_contig,
    const Tensor& data_contig,
    ReductionType reduction,
    const Tensor& offsets_contig,
    int64_t axis,
    const std::optional<Scalar>& initial) {
  axis = offsets_contig.dim() - 1;
  int64_t segment_count = offsets_contig.size(axis) - 1;
  int64_t offsets_stride_axis = offsets_contig.stride(axis);
  auto grad_input = at::zeros({data_contig.sizes()}, grad_contig.options());

  AT_DISPATCH_INDEX_TYPES(
      offsets_contig.scalar_type(), "_segment_reduce_cpu_offsets_backward_kernel1", [&] {
        const auto* offsets_data = offsets_contig.const_data_ptr<index_t>();
        _segment_reduce_cpu_lengths_backward_kernel1<index_t, /*is_offsets_like=*/true>(
            grad_contig,
            output_contig,
            data_contig,
            reduction,
            offsets_data,
            axis,
            initial,
            grad_input,
            segment_count,
            offsets_stride_axis);
      });

  return grad_input;
}

} // namespace

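// Entry point for the segment_reduce operator: validates the arguments, makes
// data and lengths/offsets contiguous, and dispatches to the device-specific
// lengths or offsets stub.
// Illustrative example (sketch, not exercised here): with
//   data    = {1., 2., 3., 4., 5.}
//   lengths = {2, 3}
// reduce = "sum" produces {3., 12.} and reduce = "max" produces {2., 5.}.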
Tensor segment_reduce_kernel(
    const Tensor& data,
    c10::string_view reduce,
    const std::optional<Tensor>& lengths,
    const std::optional<Tensor>& indices,
    const std::optional<Tensor>& offsets,
    int64_t axis,
    bool unsafe,
    const std::optional<Scalar>& initial) {
  axis = maybe_wrap_dim(axis, data.ndimension());
  TORCH_CHECK(data.numel() >= 0);

  // check that one of lengths or offsets is defined
  auto lengths_has_value = lengths.has_value();
  auto offsets_has_value = offsets.has_value();
  TORCH_CHECK(
    !indices.has_value(),
    "segment_reduce(): indices based reduction is not supported yet.");
  TORCH_CHECK(
      lengths_has_value || offsets_has_value,
      "segment_reduce(): Either lengths or offsets must be defined.")

  auto reduction = get_reduction_enum(reduce);
  const auto data_contig = data.contiguous();

  if (offsets_has_value) {
    const auto& offsets_value = offsets.value();

    // offsets related checks
    TORCH_CHECK(data.get_device() == offsets_value.get_device());
    TORCH_CHECK(data.dim() >= offsets_value.dim());
    TORCH_CHECK(axis == offsets_value.dim() - 1,
                "segment_reduce(): Expected axis to be the last dimension of offsets but got ", axis, ".");

    // TODO: add checks when !unsafe

    const auto offsets_contig = offsets_value.contiguous();

    return _segment_reduce_offsets_stub(
      data_contig.device().type(),
      reduction,
      data_contig,
      offsets_contig,
      axis,
      initial);

  } else {
    const auto& lengths_value = lengths.value();

    // length related checks
    TORCH_CHECK(data.get_device() == lengths_value.get_device());
    TORCH_CHECK(data.dim() >= lengths_value.dim());
    TORCH_CHECK(axis == lengths_value.dim() - 1,
                "segment_reduce(): Expected axis to be the last dimension of lengths but got ", axis, ".");

    if (!unsafe) {
      auto min_length = lengths_value.min().item<int64_t>();
      TORCH_CHECK((min_length >= 0), "lengths contains negative value!");
      TORCH_CHECK(all(lengths_value.sum({-1}) == data.size(axis)).item<bool>(),
                  "segment_reduce(): Expected all rows of lengths along axis ",
                  "to sum to data.size(lengths.dim()-1) when !unsafe.");
    }

    const auto lengths_contig = lengths_value.contiguous();

    return _segment_reduce_lengths_stub(
      data_contig.device().type(),
      reduction,
      data_contig,
      lengths_contig,
      axis,
      initial);
  }
}

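// lengths dispatches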
REGISTER_ARCH_DISPATCH(
    _segment_reduce_lengths_stub,
    DEFAULT,
    &_segment_reduce_lengths_cpu_kernel);
REGISTER_AVX2_DISPATCH(_segment_reduce_lengths_stub, &_segment_reduce_lengths_cpu_kernel);
REGISTER_AVX512_DISPATCH(_segment_reduce_lengths_stub, &_segment_reduce_lengths_cpu_kernel);
REGISTER_VSX_DISPATCH(_segment_reduce_lengths_stub, &_segment_reduce_lengths_cpu_kernel);
REGISTER_ZVECTOR_DISPATCH(_segment_reduce_lengths_stub, &_segment_reduce_lengths_cpu_kernel);

// offsets dispatches
REGISTER_ARCH_DISPATCH(
    _segment_reduce_offsets_stub,
    DEFAULT,
    &_segment_reduce_offsets_cpu_kernel);
REGISTER_AVX2_DISPATCH(_segment_reduce_offsets_stub, &_segment_reduce_offsets_cpu_kernel);
REGISTER_AVX512_DISPATCH(_segment_reduce_offsets_stub, &_segment_reduce_offsets_cpu_kernel);
REGISTER_VSX_DISPATCH(_segment_reduce_offsets_stub, &_segment_reduce_offsets_cpu_kernel);
REGISTER_ZVECTOR_DISPATCH(_segment_reduce_offsets_stub, &_segment_reduce_offsets_cpu_kernel);

// Currently some computation is being duplicated across forward and backward.
// TODO: Cache indices in forward pass to re-use in backward
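// Backward entry point registered for segment_reduce's derivative: mirrors
// the forward dispatch, routing to the offsets backward stub when offsets
// were given and to the lengths backward stub otherwise.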
Tensor _segment_reduce_backward_kernel(
    const Tensor& grad,
    const Tensor& output,
    const Tensor& data,
    c10::string_view reduce,
    const std::optional<Tensor>& lengths,
    const std::optional<Tensor>& offsets,
    int64_t axis,
    const std::optional<Scalar>& initial) {
  axis = maybe_wrap_dim(axis, data.ndimension());
  // check that one of lengths or offsets is defined
  // codegen for derivatives.yaml passes an undefined Tensor for None rather than a std::optional
  // so checking .has_value() doesn't work unlike in the forward pass
  auto lengths_has_value = lengths.has_value() && lengths.value().defined();
  auto offsets_has_value = offsets.has_value() && offsets.value().defined();
  TORCH_CHECK(
      lengths_has_value || offsets_has_value,
      "segment_reduce(): Either lengths or offsets must be defined.");

  const auto grad_contig = grad.contiguous();
  const auto output_contig = output.contiguous();
  const auto data_contig = data.contiguous();
  auto reduction = get_reduction_enum(reduce);

  if (offsets_has_value) {
    const auto& offsets_value = offsets.value();
    const auto offsets_contig = offsets_value.contiguous();
    return _segment_reduce_offsets_backward_stub(
      grad_contig.device().type(),
      grad_contig,
      output_contig,
      data_contig,
      reduction,
      offsets_contig,
      axis,
      initial);
  } else {
    const auto& lengths_value = lengths.value();
    const auto lengths_contig = lengths_value.contiguous();
    return _segment_reduce_lengths_backward_stub(
      grad_contig.device().type(),
      grad_contig,
      output_contig,
      data_contig,
      reduction,
      lengths_contig,
      axis,
      initial);
  }
}

REGISTER_ARCH_DISPATCH(
    _segment_reduce_lengths_backward_stub,
    DEFAULT,
    &_segment_reduce_cpu_lengths_backward_kernel);
REGISTER_AVX512_DISPATCH(
    _segment_reduce_lengths_backward_stub,
    &_segment_reduce_cpu_lengths_backward_kernel);
REGISTER_AVX2_DISPATCH(
    _segment_reduce_lengths_backward_stub,
    &_segment_reduce_cpu_lengths_backward_kernel);
REGISTER_VSX_DISPATCH(
    _segment_reduce_lengths_backward_stub,
    &_segment_reduce_cpu_lengths_backward_kernel);
REGISTER_ZVECTOR_DISPATCH(
    _segment_reduce_lengths_backward_stub,
    &_segment_reduce_cpu_lengths_backward_kernel);

REGISTER_ARCH_DISPATCH(
    _segment_reduce_offsets_backward_stub,
    DEFAULT,
    &_segment_reduce_cpu_offsets_backward_kernel);
REGISTER_AVX512_DISPATCH(
    _segment_reduce_offsets_backward_stub,
    &_segment_reduce_cpu_offsets_backward_kernel);
REGISTER_AVX2_DISPATCH(
    _segment_reduce_offsets_backward_stub,
    &_segment_reduce_cpu_offsets_backward_kernel);
REGISTER_VSX_DISPATCH(
    _segment_reduce_offsets_backward_stub,
    &_segment_reduce_cpu_offsets_backward_kernel);
REGISTER_ZVECTOR_DISPATCH(
    _segment_reduce_offsets_backward_stub,
    &_segment_reduce_cpu_offsets_backward_kernel);

} // namespace at::native