#pragma once
#include <ATen/native/DispatchStub.h>
#include <c10/util/ArrayRef.h>

namespace at {
class Tensor;
class TensorBase;
struct TensorIterator;
struct TensorIteratorBase;
} // namespace at

namespace c10 {
class Scalar;
} // namespace c10

namespace at::native {

// Kernel function-pointer signatures for the index/index_put/put/take/flip
// and masked_* operations implemented on top of TensorIterator.
using index_fn = void(*)(TensorIteratorBase &, IntArrayRef indexed_sizes, IntArrayRef indexed_strides);
using index_fill_fn = void(*)(TensorIterator & iter, int64_t dim, int64_t self_dim_size, int64_t self_dim_stride, const Scalar& source);
using index_copy_fn = void(*)(TensorIterator & iter, int64_t dim, int64_t self_dim_size, int64_t self_dim_stride);
using index_put_fn = void(*)(TensorIterator &, IntArrayRef indexed_sizes, IntArrayRef indexed_strides, bool accumulate);
using put_fn = void(*)(TensorIterator & iter, const TensorBase& self, const bool accumulate);
using take_fn = void(*)(TensorIterator & iter, const TensorBase& input);
using flip_fn = void(*)(TensorIterator &, const bool);
using masked_fill_fn = void(*)(TensorIterator &, const Scalar& scalar);
using masked_select_fn = void(*)(TensorIterator &, int64_t orig_stride);
using masked_scatter_fn = void(*)(TensorIterator &, const TensorBase &);

// Per-device-type dispatch stubs for the kernels above; backend
// implementations are registered against these via REGISTER_DISPATCH.
DECLARE_DISPATCH(index_fn, index_stub);
DECLARE_DISPATCH(index_fill_fn, index_fill_stub);
DECLARE_DISPATCH(index_copy_fn, index_copy_stub);
DECLARE_DISPATCH(index_put_fn, index_put_stub);
DECLARE_DISPATCH(put_fn, put_stub);
DECLARE_DISPATCH(take_fn, take_stub);
DECLARE_DISPATCH(flip_fn, flip_stub);
DECLARE_DISPATCH(masked_fill_fn, masked_fill_stub);
DECLARE_DISPATCH(masked_select_fn, masked_select_serial_stub);
DECLARE_DISPATCH(masked_select_fn, masked_select_stub);
DECLARE_DISPATCH(masked_scatter_fn, masked_scatter_stub);

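// A minimal usage sketch (illustrative only, not part of this header),
// assuming the standard DEFINE_DISPATCH/REGISTER_DISPATCH pattern from
// ATen/native/DispatchStub.h; `cpu_flip_kernel` is a hypothetical name:
//
//   // In the operator's .cpp file, give the stub its single definition:
//   DEFINE_DISPATCH(flip_stub);
//
//   // In a per-backend kernel file, register an implementation matching
//   // flip_fn (the trailing bool is assumed to flag a quantized input):
//   static void cpu_flip_kernel(TensorIterator& iter, const bool quantized) {
//     // ... iterate over `iter` and write elements in reversed order ...
//   }
//   REGISTER_DISPATCH(flip_stub, &cpu_flip_kernel);
//
//   // At the call site, the stub dispatches on the iterator's device type:
//   flip_stub(iter.device_type(), iter, /*quantized=*/false);
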
} // namespace at::native