#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/native/SegmentReduce.h>

#include <ATen/core/Tensor.h>
#include <ATen/Dispatch.h>
#include <ATen/NumericUtils.h>
#include <ATen/TensorOperators.h>
#include <c10/util/irange.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/_segment_reduce_backward_native.h>
#include <ATen/ops/all.h>
#include <ATen/ops/empty.h>
#include <ATen/ops/segment_reduce_native.h>
#include <ATen/ops/zeros.h>
#endif

namespace at::native {

DEFINE_DISPATCH(_segment_reduce_lengths_stub);
DEFINE_DISPATCH(_segment_reduce_offsets_stub);
DEFINE_DISPATCH(_segment_reduce_lengths_backward_stub);
DEFINE_DISPATCH(_segment_reduce_offsets_backward_stub);
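// The stubs above are filled in by the REGISTER_*_DISPATCH calls at the
// bottom of this file, which bind the CPU kernels defined below for each
// supported SIMD architecture; other backends register their own kernels
// against the same stubs in their respective source files.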

namespace {

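// Shared forward kernel for both the lengths and the offsets variants.
// When is_offsets_like is true, lengths_data actually holds segment
// boundaries (segment_count + 1 entries per row) and segment lengths are
// derived as differences of consecutive offsets; otherwise it holds the
// per-segment lengths directly.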
template <typename T, bool is_offsets_like = false>
void _segment_reduce_lengths_cpu_kernel1(
    ReductionType reduction,
    const Tensor& data,
    const T* lengths_data,
    int64_t axis,
    const std::optional<Scalar>& initial,
    Tensor& output,
    int64_t segment_count,
    int64_t lengths_stride_axis) {
  // outer_offset is the size of the outer dimensions of output (before axis)
  // inner_offset is the size of the inner dimensions of output (after axis)
  int64_t outer_offset = 1, inner_offset = 1;
  for (int64_t d = 0; d < axis; d++)
    outer_offset *= output.size(d);
  for (int64_t d = axis + 1; d < output.dim(); d++)
    inner_offset *= output.size(d);
  int64_t lengths_size_axis = is_offsets_like ? segment_count + 1 : segment_count;
  auto data_stride_axis = data.stride(axis);
  auto data_size_axis = data.size(axis);
  auto output_stride_axis = output.stride(axis);
  auto output_size_axis = output.size(axis);
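  // data and output are addressed as flattened (outer, axis, inner) views:
  // element (outer_idx, j, inner_idx) lives at
  //   outer_idx * stride_axis * size_axis + j * stride_axis + inner_idx,
  // which is valid here because both tensors are contiguous.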
  AT_DISPATCH_FLOATING_TYPES_AND2(
      kBFloat16, kHalf, data.scalar_type(), "_segment_reduce_cpu", [&]() {
        auto* output_data = output.data_ptr<scalar_t>();
        const auto* values_data = data.const_data_ptr<scalar_t>();
        for (const auto outer_idx : c10::irange(outer_offset)) {
          int64_t segment_start, segment_length;
          int64_t segment_end = is_offsets_like ?
              lengths_data[outer_idx * lengths_stride_axis * lengths_size_axis] :
              0;
          for (const auto dim_idx : c10::irange(segment_count)) {
            segment_start = segment_end;
            auto lengths_idx = outer_idx * lengths_stride_axis * lengths_size_axis + dim_idx;
            if (is_offsets_like) {
              segment_end = lengths_data[lengths_idx + 1];
              segment_length = segment_end - segment_start;
            } else {
              segment_length = lengths_data[lengths_idx];
              segment_end += segment_length;
            }
            for (const auto inner_idx : c10::irange(inner_offset)) {
              // ===== step1: initialize starting value
              scalar_t initial_value;
              if (initial.has_value()) {
                initial_value = initial.value().to<scalar_t>();
              } else if (reduction == ReductionType::MAX) {
                initial_value = -std::numeric_limits<scalar_t>::infinity();
              } else if (
                  reduction == ReductionType::MEAN ||
                  reduction == ReductionType::SUM) {
                initial_value = 0;
              } else if (reduction == ReductionType::MIN) {
                initial_value = std::numeric_limits<scalar_t>::infinity();
              } else if (reduction == ReductionType::PROD) {
                initial_value = 1;
              }

              // ===== step2: apply reduction
              for (const auto j : c10::irange(segment_start, segment_end)) {
                int64_t data_index = outer_idx * data_stride_axis * data_size_axis
                    + j * data_stride_axis + inner_idx;
                const auto val = values_data[data_index];
                if (reduction == ReductionType::MAX) {
                  initial_value = at::_isnan(val)
                      ? val
                      : std::max<scalar_t>(initial_value, val);
                } else if (
                    reduction == ReductionType::MEAN ||
                    reduction == ReductionType::SUM) {
                  initial_value = initial_value + val;
                } else if (reduction == ReductionType::MIN) {
                  initial_value = at::_isnan(val)
                      ? val
                      : std::min<scalar_t>(initial_value, val);
                } else if (reduction == ReductionType::PROD) {
                  initial_value = initial_value * val;
                }
              }

              // ===== step3: finalize reduction
              TORCH_CHECK(segment_length >= 0);

              if (segment_length == 0 && !initial.has_value() &&
                  reduction == ReductionType::MEAN) {
                initial_value = static_cast<scalar_t>(NAN);
              } else if (
                  reduction == ReductionType::MEAN &&
                  segment_length > 0 && !at::_isnan(initial_value)) {
                initial_value = initial_value / segment_length;
              }
              int64_t output_index = outer_idx * output_stride_axis * output_size_axis
                  + dim_idx * output_stride_axis + inner_idx;
              output_data[output_index] = initial_value;
            }
          }
        }
      });
}

Tensor _segment_reduce_lengths_cpu_kernel(
    ReductionType reduction,
    const Tensor& data,
    const Tensor& lengths,
    int64_t axis,
    const std::optional<Scalar>& initial) {
  // data and lengths should be contiguous from the call to .contiguous in segment_reduce_kernel
  TORCH_CHECK(data.is_contiguous(), "Expected data to be contiguous.");
  TORCH_CHECK(lengths.is_contiguous(), "Expected lengths to be contiguous.");
  // reduction axis should always be the last dimension of lengths
  axis = lengths.dim() - 1;
  int64_t segment_count = lengths.size(axis);
  int64_t lengths_stride_axis = lengths.stride(axis);
  auto output_shape = data.sizes().vec();
  output_shape[axis] = segment_count;
  auto output = at::empty(output_shape, data.options());

  AT_DISPATCH_INDEX_TYPES(lengths.scalar_type(), "_segment_reduce_lengths_cpu_kernel1", [&]() {
    const auto* lengths_data = lengths.const_data_ptr<index_t>();
    _segment_reduce_lengths_cpu_kernel1(
        reduction, data, lengths_data, axis, initial, output, segment_count, lengths_stride_axis);
  });

  return output;
}

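// The offsets variant reuses the kernel above with is_offsets_like = true,
// so offsets is expected to carry segment_count + 1 boundaries along its
// last dimension.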
Tensor _segment_reduce_offsets_cpu_kernel(
    ReductionType reduction,
    const Tensor& data,
    const Tensor& offsets,
    int64_t axis,
    const std::optional<Scalar>& initial) {
  // data and offsets should be contiguous from the call to .contiguous in segment_reduce_kernel
  TORCH_CHECK(data.is_contiguous(), "Expected data to be contiguous.");
  TORCH_CHECK(offsets.is_contiguous(), "Expected offsets to be contiguous.");
  // reduction axis should always be the last dimension of offsets
  axis = offsets.dim() - 1;
  int64_t segment_count = offsets.size(axis) - 1;
  int64_t offsets_stride_axis = offsets.stride(axis);
  auto output_shape = data.sizes().vec();
  output_shape[axis] = segment_count;
  auto output = at::empty(output_shape, data.options());

  AT_DISPATCH_INDEX_TYPES(offsets.scalar_type(), "_segment_reduce_offsets_cpu_kernel1", [&]() {
    const auto* offsets_data = offsets.const_data_ptr<index_t>();
    _segment_reduce_lengths_cpu_kernel1<index_t, /*is_offsets_like=*/true>(
        reduction, data, offsets_data, axis, initial, output, segment_count, offsets_stride_axis);
  });

  return output;
}

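// Backward pass. The gradient rules implemented below are:
//   MAX/MIN: the gradient is routed to the elements that match the reduced
//            output (or are NaN) and averaged when more than one matches;
//   MEAN:    grad / segment_length is broadcast over the segment;
//   SUM:     grad is broadcast over the segment;
//   PROD:    grad * output / value per element, with an explicit exclusive
//            product when the element is 0 or NaN.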
template <typename T, bool is_offsets_like = false>
void _segment_reduce_cpu_lengths_backward_kernel1(
    const Tensor& grad_contig,
    const Tensor& output_contig,
    const Tensor& data_contig,
    ReductionType reduction,
    const T* lengths_data,
    int64_t axis,
    const std::optional<Scalar>& initial,
    Tensor& grad_input,
    int64_t segment_count,
    int64_t lengths_stride_axis) {
  // outer_offset is the size of the outer dimensions of output (before axis)
  // inner_offset is the size of the inner dimensions of output (after axis)
  int64_t outer_offset = 1, inner_offset = 1;
  for (int64_t d = 0; d < axis; d++)
    outer_offset *= output_contig.size(d);
  for (int64_t d = axis + 1; d < output_contig.dim(); d++)
    inner_offset *= output_contig.size(d);
  int64_t lengths_size_axis = is_offsets_like ? segment_count + 1 : segment_count;
  auto data_stride_axis = data_contig.stride(axis);
  auto data_size_axis = data_contig.size(axis);
  auto output_stride_axis = output_contig.stride(axis);
  auto output_size_axis = output_contig.size(axis);
  // TODO: Switch to TensorIterator for better maintainability and
  // readability
  AT_DISPATCH_FLOATING_TYPES_AND2(
      kBFloat16,
      kHalf,
      data_contig.scalar_type(),
      "_segment_reduce_cpu",
      [&]() {
        auto* output_data = output_contig.const_data_ptr<scalar_t>();
        auto* grad_data = grad_contig.const_data_ptr<scalar_t>();
        auto* grad_input_data = grad_input.mutable_data_ptr<scalar_t>();
        const auto* values_data = data_contig.const_data_ptr<scalar_t>();
        // Used to calculate exclusive prod
        scalar_t initial_prod_value;
        if (reduction == ReductionType::PROD) {
          if (initial.has_value()) {
            initial_prod_value = initial.value().to<scalar_t>();
          } else {
            initial_prod_value = 1;
          }
        }

        for (const auto outer_idx : c10::irange(outer_offset)) {
          int64_t segment_start, segment_length;
          int64_t segment_end = is_offsets_like ?
              lengths_data[outer_idx * lengths_stride_axis * lengths_size_axis] :
              0;
          for (const auto dim_idx : c10::irange(segment_count)) {
            segment_start = segment_end;
            auto lengths_idx = outer_idx * lengths_stride_axis * lengths_size_axis + dim_idx;
            if (is_offsets_like) {
              segment_end = lengths_data[lengths_idx + 1];
              segment_length = segment_end - segment_start;
            } else {
              segment_length = lengths_data[lengths_idx];
              segment_end += segment_length;
            }
            if (segment_length == 0) {
              continue;
            }
            for (const auto inner_idx : c10::irange(inner_offset)) {
              int64_t output_index = outer_idx * output_stride_axis * output_size_axis
                  + dim_idx * output_stride_axis + inner_idx;
              if (reduction == ReductionType::MAX ||
                  reduction == ReductionType::MIN) {
                int64_t counter = 0;
                for (const auto j : c10::irange(segment_start, segment_end)) {
                  int64_t data_index = outer_idx * data_stride_axis * data_size_axis
                      + j * data_stride_axis + inner_idx;
                  if (at::_isnan(values_data[data_index]) ||
                      values_data[data_index] == output_data[output_index]) {
                    grad_input_data[data_index] = grad_data[output_index];
                    counter++;
                  }
                }
                // Average gradient based on number of maximum (or minimum)
                // elements in the segment
                if (counter < 2) {
                  continue;
                }
                for (const auto j : c10::irange(segment_start, segment_end)) {
                  int64_t data_index = outer_idx * data_stride_axis * data_size_axis
                      + j * data_stride_axis + inner_idx;
                  if (grad_input_data[data_index] > 0) {
                    grad_input_data[data_index] =
                        grad_input_data[data_index] / counter;
                  }
                }
              } else if (reduction == ReductionType::MEAN) {
                auto grad_val = grad_data[output_index] / segment_length;
                for (const auto j : c10::irange(segment_start, segment_end)) {
                  int64_t data_index = outer_idx * data_stride_axis * data_size_axis
                      + j * data_stride_axis + inner_idx;
                  grad_input_data[data_index] = grad_val;
                }
              } else if (reduction == ReductionType::SUM) {
                const auto& grad_val = grad_data[output_index];
                for (const auto j : c10::irange(segment_start, segment_end)) {
                  int64_t data_index = outer_idx * data_stride_axis * data_size_axis
                      + j * data_stride_axis + inner_idx;
                  grad_input_data[data_index] = grad_val;
                }
              } else if (reduction == ReductionType::PROD) {
                const auto& grad_val = grad_data[output_index] * output_data[output_index];
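                // For a value x_j that is neither 0 nor NaN,
                // d(prod)/d(x_j) = prod / x_j, so grad * output / value is
                // the gradient; the explicit exclusive product below handles
                // the remaining cases.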
                for (const auto j : c10::irange(segment_start, segment_end)) {
                  int64_t data_index = outer_idx * data_stride_axis * data_size_axis
                      + j * data_stride_axis + inner_idx;
                  if (at::_isnan(values_data[data_index]) ||
                      values_data[data_index] == 0) {
                    // explicitly compute exclusive prod
                    scalar_t exclusive_prod = initial_prod_value;
                    int64_t idx;
                    for (const auto k : c10::irange(segment_start, segment_end)) {
                      if (k != j) {
                        idx = outer_idx * data_stride_axis * data_size_axis
                            + k * data_stride_axis + inner_idx;
                        exclusive_prod *= values_data[idx];
                      }
                    }
                    grad_input_data[data_index] = grad_data[output_index] * exclusive_prod;
                  } else {
                    grad_input_data[data_index] = grad_val / values_data[data_index];
                  }
                }
              }
            }
          }
        }
      });
}

Tensor _segment_reduce_cpu_lengths_backward_kernel(
    const Tensor& grad_contig,
    const Tensor& output_contig,
    const Tensor& data_contig,
    ReductionType reduction,
    const Tensor& lengths_contig,
    int64_t axis,
    const std::optional<Scalar>& initial) {
  axis = lengths_contig.dim() - 1;
  int64_t segment_count = lengths_contig.size(axis);
  int64_t lengths_stride_axis = lengths_contig.stride(axis);
  auto grad_input = at::zeros({data_contig.sizes()}, grad_contig.options());

  AT_DISPATCH_INDEX_TYPES(
      lengths_contig.scalar_type(), "_segment_reduce_cpu_lengths_backward_kernel1", [&] {
        const auto* lengths_data = lengths_contig.const_data_ptr<index_t>();
        _segment_reduce_cpu_lengths_backward_kernel1(
            grad_contig,
            output_contig,
            data_contig,
            reduction,
            lengths_data,
            axis,
            initial,
            grad_input,
            segment_count,
            lengths_stride_axis);
      });

  return grad_input;
}

Tensor _segment_reduce_cpu_offsets_backward_kernel(
    const Tensor& grad_contig,
    const Tensor& output_contig,
    const Tensor& data_contig,
    ReductionType reduction,
    const Tensor& offsets_contig,
    int64_t axis,
    const std::optional<Scalar>& initial) {
  axis = offsets_contig.dim() - 1;
  int64_t segment_count = offsets_contig.size(axis) - 1;
  int64_t offsets_stride_axis = offsets_contig.stride(axis);
  auto grad_input = at::zeros({data_contig.sizes()}, grad_contig.options());

  AT_DISPATCH_INDEX_TYPES(
      offsets_contig.scalar_type(), "_segment_reduce_cpu_offsets_backward_kernel1", [&] {
        const auto* offsets_data = offsets_contig.const_data_ptr<index_t>();
        _segment_reduce_cpu_lengths_backward_kernel1<index_t, /*is_offsets_like=*/true>(
            grad_contig,
            output_contig,
            data_contig,
            reduction,
            offsets_data,
            axis,
            initial,
            grad_input,
            segment_count,
            offsets_stride_axis);
      });

  return grad_input;
}

} // namespace

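// Public entry point for the forward pass. Segments are taken along `axis`
// and are described either by per-segment lengths or by boundary offsets.
// For example, with reduce="sum", data = [1., 2., 3., 4., 5.] and
// lengths = [2, 3] (equivalently offsets = [0, 2, 5]), the result is
// [1. + 2., 3. + 4. + 5.] = [3., 12.].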
Tensor segment_reduce_kernel(
    const Tensor& data,
    c10::string_view reduce,
    const std::optional<Tensor>& lengths,
    const std::optional<Tensor>& indices,
    const std::optional<Tensor>& offsets,
    int64_t axis,
    bool unsafe,
    const std::optional<Scalar>& initial) {
  axis = maybe_wrap_dim(axis, data.ndimension());
  TORCH_CHECK(data.numel() >= 0);

  // check that one of lengths or offsets is defined
  auto lengths_has_value = lengths.has_value();
  auto offsets_has_value = offsets.has_value();
  TORCH_CHECK(
      !indices.has_value(),
      "segment_reduce(): indices based reduction is not supported yet.");
  TORCH_CHECK(
      lengths_has_value || offsets_has_value,
      "segment_reduce(): Either lengths or offsets must be defined.");

  auto reduction = get_reduction_enum(reduce);
  const auto data_contig = data.contiguous();

  if (offsets_has_value) {
    const auto& offsets_value = offsets.value();

    // offsets related checks
    TORCH_CHECK(data.get_device() == offsets_value.get_device());
    TORCH_CHECK(data.dim() >= offsets_value.dim());
    TORCH_CHECK(axis == offsets_value.dim() - 1,
        "segment_reduce(): Expected axis to be the last dimension of offsets but got ", axis, ".");

    // TODO: add checks when !unsafe

    const auto offsets_contig = offsets_value.contiguous();

    return _segment_reduce_offsets_stub(
        data_contig.device().type(),
        reduction,
        data_contig,
        offsets_contig,
        axis,
        initial);

  } else {
    const auto& lengths_value = lengths.value();

    // length related checks
    TORCH_CHECK(data.get_device() == lengths_value.get_device());
    TORCH_CHECK(data.dim() >= lengths_value.dim());
    TORCH_CHECK(axis == lengths_value.dim() - 1,
        "segment_reduce(): Expected axis to be the last dimension of lengths but got ", axis, ".");

    if (!unsafe) {
      auto min_length = lengths_value.min().item<int64_t>();
      TORCH_CHECK((min_length >= 0), "lengths contains negative value!");
      TORCH_CHECK(all(lengths_value.sum({-1}) == data.size(axis)).item<bool>(),
          "segment_reduce(): Expected all rows of lengths along axis ",
          "to sum to data.size(lengths.dim()-1) when !unsafe.");
    }

    const auto lengths_contig = lengths_value.contiguous();

    return _segment_reduce_lengths_stub(
        data_contig.device().type(),
        reduction,
        data_contig,
        lengths_contig,
        axis,
        initial);
  }
}

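// lengths dispatches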
REGISTER_ARCH_DISPATCH(
    _segment_reduce_lengths_stub,
    DEFAULT,
    &_segment_reduce_lengths_cpu_kernel);
REGISTER_AVX2_DISPATCH(_segment_reduce_lengths_stub, &_segment_reduce_lengths_cpu_kernel);
REGISTER_AVX512_DISPATCH(_segment_reduce_lengths_stub, &_segment_reduce_lengths_cpu_kernel);
REGISTER_VSX_DISPATCH(_segment_reduce_lengths_stub, &_segment_reduce_lengths_cpu_kernel);
REGISTER_ZVECTOR_DISPATCH(_segment_reduce_lengths_stub, &_segment_reduce_lengths_cpu_kernel);

// offsets dispatches
REGISTER_ARCH_DISPATCH(
    _segment_reduce_offsets_stub,
    DEFAULT,
    &_segment_reduce_offsets_cpu_kernel);
REGISTER_AVX2_DISPATCH(_segment_reduce_offsets_stub, &_segment_reduce_offsets_cpu_kernel);
REGISTER_AVX512_DISPATCH(_segment_reduce_offsets_stub, &_segment_reduce_offsets_cpu_kernel);
REGISTER_VSX_DISPATCH(_segment_reduce_offsets_stub, &_segment_reduce_offsets_cpu_kernel);
REGISTER_ZVECTOR_DISPATCH(_segment_reduce_offsets_stub, &_segment_reduce_offsets_cpu_kernel);

// Currently some computation is being duplicated across forward and backward.
// TODO: Cache indices in forward pass to re-use in backward
Tensor _segment_reduce_backward_kernel(
    const Tensor& grad,
    const Tensor& output,
    const Tensor& data,
    c10::string_view reduce,
    const std::optional<Tensor>& lengths,
    const std::optional<Tensor>& offsets,
    int64_t axis,
    const std::optional<Scalar>& initial) {
  axis = maybe_wrap_dim(axis, data.ndimension());
  // check that one of lengths or offsets is defined
  // codegen for derivatives.yaml passes an undefined Tensor for None rather than a std::optional
  // so checking .has_value() doesn't work unlike in the forward pass
  auto lengths_has_value = lengths.has_value() && lengths.value().defined();
  auto offsets_has_value = offsets.has_value() && offsets.value().defined();
  TORCH_CHECK(
      lengths_has_value || offsets_has_value,
      "segment_reduce(): Either lengths or offsets must be defined.");

  const auto grad_contig = grad.contiguous();
  const auto output_contig = output.contiguous();
  const auto data_contig = data.contiguous();
  auto reduction = get_reduction_enum(reduce);

  if (offsets_has_value) {
    const auto& offsets_value = offsets.value();
    const auto offsets_contig = offsets_value.contiguous();
    return _segment_reduce_offsets_backward_stub(
        grad_contig.device().type(),
        grad_contig,
        output_contig,
        data_contig,
        reduction,
        offsets_contig,
        axis,
        initial);
  } else {
    const auto& lengths_value = lengths.value();
    const auto lengths_contig = lengths_value.contiguous();
    return _segment_reduce_lengths_backward_stub(
        grad_contig.device().type(),
        grad_contig,
        output_contig,
        data_contig,
        reduction,
        lengths_contig,
        axis,
        initial);
  }
}

REGISTER_ARCH_DISPATCH(
    _segment_reduce_lengths_backward_stub,
    DEFAULT,
    &_segment_reduce_cpu_lengths_backward_kernel);
REGISTER_AVX512_DISPATCH(
    _segment_reduce_lengths_backward_stub,
    &_segment_reduce_cpu_lengths_backward_kernel);
REGISTER_AVX2_DISPATCH(
    _segment_reduce_lengths_backward_stub,
    &_segment_reduce_cpu_lengths_backward_kernel);
REGISTER_VSX_DISPATCH(
    _segment_reduce_lengths_backward_stub,
    &_segment_reduce_cpu_lengths_backward_kernel);
REGISTER_ZVECTOR_DISPATCH(
    _segment_reduce_lengths_backward_stub,
    &_segment_reduce_cpu_lengths_backward_kernel);

REGISTER_ARCH_DISPATCH(
    _segment_reduce_offsets_backward_stub,
    DEFAULT,
    &_segment_reduce_cpu_offsets_backward_kernel);
REGISTER_AVX512_DISPATCH(
    _segment_reduce_offsets_backward_stub,
    &_segment_reduce_cpu_offsets_backward_kernel);
REGISTER_AVX2_DISPATCH(
    _segment_reduce_offsets_backward_stub,
    &_segment_reduce_cpu_offsets_backward_kernel);
REGISTER_VSX_DISPATCH(
    _segment_reduce_offsets_backward_stub,
    &_segment_reduce_cpu_offsets_backward_kernel);
REGISTER_ZVECTOR_DISPATCH(
    _segment_reduce_offsets_backward_stub,
    &_segment_reduce_cpu_offsets_backward_kernel);

} // namespace at::native