1 #pragma once 2 3 #include <ATen/native/DispatchStub.h> 4 #include <ATen/native/ReductionType.h> 5 #include <c10/core/Scalar.h> 6 #include <optional> 7 8 namespace at { 9 class Tensor; 10 11 namespace native { 12 13 using segment_reduce_lengths_fn = Tensor (*)( 14 ReductionType, 15 const Tensor&, 16 const Tensor&, 17 int64_t, 18 const std::optional<Scalar>&); 19 DECLARE_DISPATCH(segment_reduce_lengths_fn, _segment_reduce_lengths_stub); 20 21 using segment_reduce_offsets_fn = Tensor (*)( 22 ReductionType, 23 const Tensor&, 24 const Tensor&, 25 int64_t, 26 const std::optional<Scalar>&); 27 DECLARE_DISPATCH(segment_reduce_offsets_fn, _segment_reduce_offsets_stub); 28 29 using segment_reduce_lengths_backward_fn = Tensor (*)( 30 const Tensor&, 31 const Tensor&, 32 const Tensor&, 33 ReductionType, 34 const Tensor&, 35 int64_t, 36 const std::optional<Scalar>&); 37 DECLARE_DISPATCH(segment_reduce_lengths_backward_fn, _segment_reduce_lengths_backward_stub); 38 39 using segment_reduce_offsets_backward_fn = Tensor (*)( 40 const Tensor&, 41 const Tensor&, 42 const Tensor&, 43 ReductionType, 44 const Tensor&, 45 int64_t, 46 const std::optional<Scalar>&); 47 DECLARE_DISPATCH(segment_reduce_offsets_backward_fn, _segment_reduce_offsets_backward_stub); 48 49 } // namespace native 50 } // namespace at 51