xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/cpu/SoftmaxKernel.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <ATen/native/DispatchStub.h>
4 #include <cstdint>
5 
6 namespace at {
7 class Tensor;
8 
9 namespace native {
10 
11 using forward_fn = void (*)(const Tensor&, const Tensor&);
12 using backward_fn = void(*)(const Tensor &, const Tensor &, const Tensor&);
13 
14 DECLARE_DISPATCH(forward_fn, softmax_lastdim_kernel);
15 DECLARE_DISPATCH(forward_fn, log_softmax_lastdim_kernel);
16 DECLARE_DISPATCH(backward_fn, softmax_backward_lastdim_kernel);
17 DECLARE_DISPATCH(backward_fn, log_softmax_backward_lastdim_kernel);
18 
19 using forward_fn_with_dim = void(*)(const Tensor &, const Tensor &, const int64_t);
20 using backward_fn_with_dim =
21     void (*)(const Tensor&, const Tensor&, const Tensor&, const int64_t);
22 
23 DECLARE_DISPATCH(forward_fn_with_dim, softmax_kernel);
24 DECLARE_DISPATCH(forward_fn_with_dim, log_softmax_kernel);
25 DECLARE_DISPATCH(backward_fn_with_dim, softmax_backward_kernel);
26 DECLARE_DISPATCH(backward_fn_with_dim, log_softmax_backward_kernel);
27 }
28 }
29