1 #pragma once 2 3 #include <ATen/native/DispatchStub.h> 4 #include <cstdint> 5 6 namespace at { 7 class Tensor; 8 9 namespace native { 10 11 using forward_fn = void (*)(const Tensor&, const Tensor&); 12 using backward_fn = void(*)(const Tensor &, const Tensor &, const Tensor&); 13 14 DECLARE_DISPATCH(forward_fn, softmax_lastdim_kernel); 15 DECLARE_DISPATCH(forward_fn, log_softmax_lastdim_kernel); 16 DECLARE_DISPATCH(backward_fn, softmax_backward_lastdim_kernel); 17 DECLARE_DISPATCH(backward_fn, log_softmax_backward_lastdim_kernel); 18 19 using forward_fn_with_dim = void(*)(const Tensor &, const Tensor &, const int64_t); 20 using backward_fn_with_dim = 21 void (*)(const Tensor&, const Tensor&, const Tensor&, const int64_t); 22 23 DECLARE_DISPATCH(forward_fn_with_dim, softmax_kernel); 24 DECLARE_DISPATCH(forward_fn_with_dim, log_softmax_kernel); 25 DECLARE_DISPATCH(backward_fn_with_dim, softmax_backward_kernel); 26 DECLARE_DISPATCH(backward_fn_with_dim, log_softmax_backward_kernel); 27 } 28 } 29