xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/SegmentReduce.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <ATen/native/DispatchStub.h>
4 #include <ATen/native/ReductionType.h>
5 #include <c10/core/Scalar.h>
6 #include <optional>
7 
8 namespace at {
// Forward declaration of at::Tensor: the declarations below use Tensor only as
// a return type and as const-reference parameters of function-pointer types.
9 class Tensor;
10 
11 namespace native {
12 
// Function-pointer type paired with `_segment_reduce_lengths_stub` below.
// The pointed-to function returns a Tensor and takes, in order: a
// ReductionType, two const Tensor references, an int64_t, and an optional
// Scalar passed by const reference.
// NOTE(review): the parameters are unnamed in this header; presumably they are
// (reduction, data, lengths, axis, initial) — confirm against the kernel
// definitions before relying on that order.
13 using segment_reduce_lengths_fn = Tensor (*)(
14     ReductionType,
15     const Tensor&,
16     const Tensor&,
17     int64_t,
18     const std::optional<Scalar>&);
// Declares the dispatch stub `_segment_reduce_lengths_stub` for the signature
// above, using the DECLARE_DISPATCH macro (ATen/native/DispatchStub.h is
// included at the top of this header).
// NOTE(review): kernel implementations are presumably registered in other
// translation units; none are visible here.
19 DECLARE_DISPATCH(segment_reduce_lengths_fn, _segment_reduce_lengths_stub);
20 
// Function-pointer type paired with `_segment_reduce_offsets_stub` below.
// Its parameter list is type-for-type the same as segment_reduce_lengths_fn:
// a ReductionType, two const Tensor references, an int64_t, and an optional
// Scalar passed by const reference; it returns a Tensor.
// NOTE(review): the parameters are unnamed in this header; presumably they are
// (reduction, data, offsets, axis, initial), with the second tensor holding
// segment offsets rather than lengths — confirm against the kernel definitions.
21 using segment_reduce_offsets_fn = Tensor (*)(
22     ReductionType,
23     const Tensor&,
24     const Tensor&,
25     int64_t,
26     const std::optional<Scalar>&);
// Declares the dispatch stub `_segment_reduce_offsets_stub` for the signature
// above, using the DECLARE_DISPATCH macro (ATen/native/DispatchStub.h is
// included at the top of this header).
// NOTE(review): kernel implementations are presumably registered in other
// translation units; none are visible here.
27 DECLARE_DISPATCH(segment_reduce_offsets_fn, _segment_reduce_offsets_stub);
28 
// Function-pointer type paired with `_segment_reduce_lengths_backward_stub`
// below. The pointed-to function returns a Tensor and takes, in order: three
// const Tensor references, a ReductionType, a fourth const Tensor reference,
// an int64_t, and an optional Scalar passed by const reference.
// NOTE(review): the parameters are unnamed in this header; presumably they are
// (grad, output, data, reduction, lengths, axis, initial) — confirm against
// the backward kernel definitions before relying on that order.
29 using segment_reduce_lengths_backward_fn = Tensor (*)(
30     const Tensor&,
31     const Tensor&,
32     const Tensor&,
33     ReductionType,
34     const Tensor&,
35     int64_t,
36     const std::optional<Scalar>&);
// Declares the dispatch stub `_segment_reduce_lengths_backward_stub` for the
// signature above, using the DECLARE_DISPATCH macro (ATen/native/DispatchStub.h
// is included at the top of this header).
// NOTE(review): kernel implementations are presumably registered in other
// translation units; none are visible here.
37 DECLARE_DISPATCH(segment_reduce_lengths_backward_fn, _segment_reduce_lengths_backward_stub);
38 
// Function-pointer type paired with `_segment_reduce_offsets_backward_stub`
// below. Its parameter list is type-for-type the same as
// segment_reduce_lengths_backward_fn: three const Tensor references, a
// ReductionType, a fourth const Tensor reference, an int64_t, and an optional
// Scalar passed by const reference; it returns a Tensor.
// NOTE(review): the parameters are unnamed in this header; presumably they are
// (grad, output, data, reduction, offsets, axis, initial), with the fourth
// tensor holding segment offsets rather than lengths — confirm against the
// backward kernel definitions.
39 using segment_reduce_offsets_backward_fn = Tensor (*)(
40     const Tensor&,
41     const Tensor&,
42     const Tensor&,
43     ReductionType,
44     const Tensor&,
45     int64_t,
46     const std::optional<Scalar>&);
// Declares the dispatch stub `_segment_reduce_offsets_backward_stub` for the
// signature above, using the DECLARE_DISPATCH macro (ATen/native/DispatchStub.h
// is included at the top of this header).
// NOTE(review): kernel implementations are presumably registered in other
// translation units; none are visible here.
47 DECLARE_DISPATCH(segment_reduce_offsets_backward_fn, _segment_reduce_offsets_backward_stub);
48 
49 } // namespace native
50 } // namespace at
51