xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/transformers/cuda/flash_attn/softmax.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 /******************************************************************************
2  * Copyright (c) 2024, Tri Dao.
3  ******************************************************************************/
4 
5 #pragma once
6 
7 #include <cmath>
8 
9 #include <cute/tensor.hpp>
10 
11 #include <cutlass/numeric_types.h>
12 
13 #include <ATen/native/transformers/cuda/flash_attn/philox.cuh>
14 #include <ATen/native/transformers/cuda/flash_attn/utils.h>
15 
16 namespace pytorch_flash {
17 
18 using namespace cute;
19 
// When defined, scale_apply_exp2 forms `x * scale` with __fmul_rn, an intrinsic the
// compiler never merges into a multiply-add, so the product is rounded on its own
// before the subtraction instead of being contracted into a single FFMA.
// Guarded so that a build which already passes -DUNFUSE_FMA does not trigger an
// incompatible macro-redefinition warning (an error under -Werror).
#ifndef UNFUSE_FMA
#define UNFUSE_FMA
#endif
21 ////////////////////////////////////////////////////////////////////////////////////////////////////
22 
// Thread-local, row-wise fold of the 2D `tensor` into the 1D `summary` using the
// binary functor `op`. No values are exchanged between threads here.
//   zero_init == true : summary(r) is overwritten, starting from tensor(r, 0).
//   zero_init == false: the current summary(r) is folded in first, so `summary`
//                       behaves as a running accumulator across calls.
template<bool zero_init=true, typename Engine0, typename Layout0, typename Engine1, typename Layout1, typename Operator>
__device__ __forceinline__ void thread_reduce_(Tensor<Engine0, Layout0> const &tensor, Tensor<Engine1, Layout1> &summary, Operator &op) {
    static_assert(Layout0::rank == 2, "Only support 2D Tensor");
    static_assert(Layout1::rank == 1, "Only support 1D Tensor");
    CUTE_STATIC_ASSERT_V(size<0>(summary) == size<0>(tensor));
    #pragma unroll
    for (int row = 0; row < size<0>(tensor); ++row) {
        // Seed the fold with column 0 (optionally merged with the previous value).
        if (zero_init) {
            summary(row) = tensor(row, 0);
        } else {
            summary(row) = op(summary(row), tensor(row, 0));
        }
        // Fold in the remaining columns left to right.
        #pragma unroll
        for (int col = 1; col < size<1>(tensor); ++col) {
            summary(row) = op(summary(row), tensor(row, col));
        }
    }
}
37 
// Element-wise cross-thread reduction: dst(i) = Allreduce<4>::run(src(i), op) for every i.
// NOTE(review): Allreduce is defined in utils.h; presumably it combines the value across
// a group of 4 threads — confirm against its definition.
// Callers in this file pass the same tensor as both `dst` and `src` (in-place use).
template<typename Engine0, typename Layout0, typename Engine1, typename Layout1, typename Operator>
__device__ __forceinline__ void quad_allreduce_(Tensor<Engine0, Layout0> &dst, Tensor<Engine1, Layout1> &src, Operator &op) {
    CUTE_STATIC_ASSERT_V(size(dst) == size(src));
    #pragma unroll
    for (int idx = 0; idx < size(dst); ++idx) {
        const auto reduced = Allreduce<4>::run(src(idx), op);
        dst(idx) = reduced;
    }
}
46 
// Row-wise reduction of the 2D `tensor` into the 1D `summary` that also combines the
// result across threads: a thread-local fold followed by quad_allreduce_ on `summary`.
// zero_init has the same meaning as in thread_reduce_.
template<bool zero_init=true, typename Engine0, typename Layout0, typename Engine1, typename Layout1, typename Operator>
__device__ __forceinline__ void reduce_(Tensor<Engine0, Layout0> const& tensor, Tensor<Engine1, Layout1> &summary, Operator &op) {
    // Step 1: fold each row inside this thread.
    thread_reduce_<zero_init>(tensor, summary, op);
    // Step 2: combine the per-thread partial results, in place.
    quad_allreduce_(summary, summary, op);
}
52 
// Row-wise maximum of `tensor` written to `max`, combined across threads via reduce_.
// With zero_init == false, the existing contents of `max` also take part in the maximum.
template<bool zero_init=true, typename Engine0, typename Layout0, typename Engine1, typename Layout1>
__device__ __forceinline__ void reduce_max(Tensor<Engine0, Layout0> const& tensor, Tensor<Engine1, Layout1> &max){
    MaxOp<float> op;
    reduce_<zero_init>(tensor, max, op);
}
58 
// Row-wise sum of `tensor` written to `sum`, thread-local only. Unlike reduce_max there
// is no cross-thread combine here, so each thread holds just its own partial sums.
// With zero_init == false, the existing contents of `sum` are accumulated into.
template<bool zero_init=true, typename Engine0, typename Layout0, typename Engine1, typename Layout1>
__device__ __forceinline__ void reduce_sum(Tensor<Engine0, Layout0> const& tensor, Tensor<Engine1, Layout1> &sum){
    SumOp<float> op;
    thread_reduce_<zero_init>(tensor, sum, op);
}
64 
// Exponentiate every element of the 2D `tensor` in place, in base 2:
//   tensor(r, c) = exp2(tensor(r, c) * scale - offset(r))
//   offset(r)    = max(r) * (Scale_max ? scale : log2(e)), or 0 when max(r) == -inf.
// This is the base-2 form of exp(x - max), since exp(y) == exp2(y * log2(e)).
// NOTE(review): callers are presumably expected to fold log2(e) into `scale` — confirm
// at call sites.
// A row max of -inf means every element of that row is -inf (e.g. fully masked).
// Using an offset of 0 there avoids computing (-inf) - (-inf) = NaN.
// With UNFUSE_FMA defined, the product is formed with __fmul_rn, which is never merged
// into a multiply-add, so x * scale is rounded before the subtraction. Without it, the
// compiler is free to contract the multiply and subtract into one FFMA.
template <bool Scale_max=true, typename Engine0, typename Layout0, typename Engine1, typename Layout1>
__forceinline__ __device__ void scale_apply_exp2(Tensor<Engine0, Layout0> &tensor, Tensor<Engine1, Layout1> const &max, const float scale) {
    static_assert(Layout0::rank == 2, "Only support 2D Tensor");
    static_assert(Layout1::rank == 1, "Only support 1D Tensor");
    CUTE_STATIC_ASSERT_V(size<0>(max) == size<0>(tensor));
    // Keep the float() cast: a bare M_LOG2E is a double and would make the multiply fp64.
    const float max_multiplier = Scale_max ? scale : float(M_LOG2E);
    #pragma unroll
    for (int row = 0; row < size<0>(tensor); ++row) {
        const float row_max = max(row);
        const float offset = (row_max == -INFINITY) ? 0.f : row_max * max_multiplier;
        #pragma unroll
        for (int col = 0; col < size<1>(tensor); ++col) {
#ifdef UNFUSE_FMA
            tensor(row, col) = exp2f(__fmul_rn(tensor(row, col), scale) - offset);
#else
            tensor(row, col) = exp2f(tensor(row, col) * scale - offset);
#endif
        }
    }
}
90 
// Single-call softmax numerator for the 2D `tensor`. For each row r:
//   1. max(r) = maximum of the row (also folding in the old max(r) when zero_init is
//      false), combined across threads with quad_allreduce_ (via reduce_max);
//   2. tensor(r, c) = exp2(tensor(r, c) * scale - max(r) * scale), using 0 in place of
//      max(r) * scale when max(r) == -inf so that fully -inf rows do not produce NaN;
//   3. sum(r) = sum over c of the new tensor(r, c), combined across threads with
//      quad_allreduce_.
// It is built from reduce_max / scale_apply_exp2 / reduce_sum rather than hand-inlined
// loops, so the exponentiation obeys UNFUSE_FMA exactly as scale_apply_exp2 does. Both
// paths therefore produce bitwise-identical exponentials for the same inputs.
template <bool zero_init=true, typename Engine0, typename Layout0, typename Engine1, typename Layout1>
__forceinline__ __device__ void max_scale_exp2_sum(Tensor<Engine0, Layout0> &tensor, Tensor<Engine1, Layout1> &max, Tensor<Engine1, Layout1> &sum, const float scale) {
    static_assert(Layout0::rank == 2, "Only support 2D Tensor");
    static_assert(Layout1::rank == 1, "Only support 1D Tensor");
    CUTE_STATIC_ASSERT_V(size<0>(max) == size<0>(tensor));
    // 1. Row maxima, reduced across threads.
    reduce_max<zero_init>(tensor, max);
    // 2. In-place exp2(x * scale - max * scale); -inf rows are handled inside.
    scale_apply_exp2</*Scale_max=*/true>(tensor, max, scale);
    // 3. Thread-local row sums of the exponentials (overwriting `sum`) ...
    reduce_sum</*zero_init=*/true>(tensor, sum);
    //    ... then combined across threads.
    SumOp<float> sum_op;
    quad_allreduce_(sum, sum, sum_op);
}
122 
123 ////////////////////////////////////////////////////////////////////////////////////////////////////
124 
// Per-thread state for the online (tile-by-tile) softmax. For each of the kNRows score
// rows held by this thread it keeps:
//   row_max - running maximum of the raw scores seen so far;
//   row_sum - running sum of exp2(score * softmax_scale_log2 - row_max * softmax_scale_log2).
// softmax_rescale_o() folds in one new tile of scores; normalize_softmax_lse() finishes
// the computation after the last tile.
template <int kNRows>
struct Softmax {

    using TensorT = decltype(make_tensor<float>(Shape<Int<kNRows>>{}));
    TensorT row_max, row_sum;

    __forceinline__ __device__ Softmax() {}

    // Fold one tile of scores `acc_s` into the running state and exponentiate it in place.
    //   Is_first  - first tile: initialize row_max / row_sum from this tile alone.
    //   Check_inf - on later tiles, treat a new row max of -inf as 0 when computing the
    //               rescale factor, so an all -inf row does not give (-inf) - (-inf) = NaN.
    // On later tiles, the existing row_sum and the rows of the output accumulator `acc_o`
    // are first multiplied by exp2((old_max - new_max) * softmax_scale_log2). The new
    // tile's contribution is then added. row_sum holds only this thread's partial sums;
    // the cross-thread reduction is deferred to normalize_softmax_lse().
    template<bool Is_first, bool Check_inf=false, typename Tensor0, typename Tensor1>
    __forceinline__ __device__ void softmax_rescale_o(Tensor0 &acc_s, Tensor1 &acc_o, float softmax_scale_log2) {
        // View acc_s (MMA=4, MMA_M, MMA_N) as rows x cols: (nrow=(2, MMA_M), ncol=(2, MMA_N)).
        Tensor s_rc = make_tensor(acc_s.data(), pytorch_flash::convert_layout_acc_rowcol(acc_s.layout()));
        static_assert(decltype(size<0>(s_rc))::value == kNRows);

        if (Is_first) {
            pytorch_flash::template reduce_max</*zero_init=*/true>(s_rc, row_max);
            pytorch_flash::scale_apply_exp2(s_rc, row_max, softmax_scale_log2);
            pytorch_flash::reduce_sum</*zero_init=*/true>(s_rc, row_sum);
            return;
        }

        // Snapshot the previous maxima, then extend them with this tile.
        Tensor prev_max = make_fragment_like(row_max);
        cute::copy(row_max, prev_max);
        pytorch_flash::template reduce_max</*zero_init=*/false>(s_rc, row_max);

        // View acc_o (MMA=4, MMA_M, MMA_K) as rows x cols: (nrow=(2, MMA_M), ncol=(2, MMA_K)).
        Tensor o_rc = make_tensor(acc_o.data(), pytorch_flash::convert_layout_acc_rowcol(acc_o.layout()));
        static_assert(decltype(size<0>(o_rc))::value == kNRows);

        #pragma unroll
        for (int r = 0; r < size(row_max); ++r) {
            float new_max = row_max(r);
            if (Check_inf && new_max == -INFINITY) { new_max = 0.0f; }
            const float correction = exp2f((prev_max(r) - new_max) * softmax_scale_log2);
            row_sum(r) *= correction;
            #pragma unroll
            for (int c = 0; c < size<1>(o_rc); ++c) { o_rc(r, c) *= correction; }
        }

        pytorch_flash::scale_apply_exp2(s_rc, row_max, softmax_scale_log2);
        // Thread-local only; the cross-thread sum happens in normalize_softmax_lse().
        pytorch_flash::reduce_sum</*zero_init=*/false>(s_rc, row_sum);
    }

    // Finish the softmax:
    //   - combine the per-thread partial row sums with quad_allreduce_;
    //   - scale each row of `acc_o` by 1 / row_sum (times rp_dropout when Is_dropout);
    //   - return, per row, row_max * softmax_scale + log(row_sum).
    // Rows whose sum is 0 or NaN are scaled by 1 (left unchanged) and get +inf
    // (or -inf when Split) as their returned value.
    // NOTE(review): for the returned value to be a natural-log log-sum-exp, softmax_scale
    // is presumably the natural-log scale (softmax_scale_log2 / log2(e)) — confirm with callers.
    template<bool Is_dropout=false, bool Split=false, typename Tensor0>
    __forceinline__ __device__ TensorT normalize_softmax_lse(Tensor0 &acc_o, float softmax_scale, float rp_dropout=1.0) {
        SumOp<float> sum_op;
        quad_allreduce_(row_sum, row_sum, sum_op);

        TensorT lse = make_fragment_like(row_sum);
        Tensor o_rc = make_tensor(acc_o.data(), pytorch_flash::convert_layout_acc_rowcol(acc_o.layout()));
        static_assert(decltype(size<0>(o_rc))::value == kNRows);

        #pragma unroll
        for (int r = 0; r < size<0>(o_rc); ++r) {
            const float total = row_sum(r);
            const bool is_degenerate = (total == 0.f) || (total != total);  // zero, or NaN
            const float inv_total = is_degenerate ? 1.f : 1.f / total;
            lse(r) = is_degenerate
                ? (Split ? -INFINITY : INFINITY)
                : row_max(r) * softmax_scale + __logf(total);
            const float out_scale = Is_dropout ? inv_total * rp_dropout : inv_total;
            #pragma unroll
            for (int c = 0; c < size<1>(o_rc); ++c) { o_rc(r, c) *= out_scale; }
        }
        return lse;
    }
};
185 
186 }  // namespace pytorch_flash
187