1 #pragma once 2 #include <ATen/native/DispatchStub.h> 3 #include <c10/util/ArrayRef.h> 4 5 namespace at { 6 class Tensor; 7 class TensorBase; 8 struct TensorIterator; 9 struct TensorIteratorBase; 10 } 11 12 namespace c10 { 13 class Scalar; 14 } 15 16 namespace at::native { 17 18 using index_fn = void(*)(TensorIteratorBase &, IntArrayRef indexed_sizes, IntArrayRef indexed_strides); 19 using index_fill_fn = void(*)(TensorIterator & iter, int64_t dim, int64_t self_dim_size, int64_t self_dim_stride, const Scalar& source); 20 using index_copy_fn = void(*)(TensorIterator & iter, int64_t dim, int64_t self_dim_size, int64_t self_dim_stride); 21 using index_put_fn = void(*)(TensorIterator &, IntArrayRef indexed_sizes, IntArrayRef indexed_strides, bool accumulate); 22 using put_fn = void(*)(TensorIterator & iter, const TensorBase& self, const bool accumulate); 23 using take_fn = void(*)(TensorIterator & iter, const TensorBase& input); 24 using flip_fn = void(*)(TensorIterator &, const bool); 25 using masked_fill_fn = void(*)(TensorIterator &, const Scalar& scalar); 26 using masked_select_fn = void(*)(TensorIterator &, int64_t orig_stride); 27 using masked_scatter_fn = void(*)(TensorIterator &, const TensorBase &); 28 29 DECLARE_DISPATCH(index_fn, index_stub); 30 DECLARE_DISPATCH(index_fill_fn, index_fill_stub); 31 DECLARE_DISPATCH(index_copy_fn, index_copy_stub); 32 DECLARE_DISPATCH(index_put_fn, index_put_stub); 33 DECLARE_DISPATCH(put_fn, put_stub); 34 DECLARE_DISPATCH(take_fn, take_stub); 35 DECLARE_DISPATCH(flip_fn, flip_stub); 36 DECLARE_DISPATCH(masked_fill_fn, masked_fill_stub); 37 DECLARE_DISPATCH(masked_select_fn, masked_select_serial_stub); 38 DECLARE_DISPATCH(masked_select_fn, masked_select_stub); 39 DECLARE_DISPATCH(masked_scatter_fn, masked_scatter_stub); 40 41 } // namespace at::native 42