/******************************************************************************
 * Copyright (c) 2024, Tri Dao.
 ******************************************************************************/

#pragma once

#include <cmath>

#include <cute/tensor.hpp>

#include <cutlass/numeric_types.h>

#include <ATen/native/transformers/cuda/flash_attn/philox.cuh>
#include <ATen/native/transformers/cuda/flash_attn/utils.h>

namespace pytorch_flash {

using namespace cute;

#define UNFUSE_FMA
////////////////////////////////////////////////////////////////////////////////////////////////////

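// Thread-local reduction: collapse each row of a 2D register tensor into a 1D summary,
// summary(mi) = op(tensor(mi, 0), ..., tensor(mi, ncols-1)); with zero_init=false the existing
// summary value is folded in, so the reduction can continue across successive tiles.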
template<bool zero_init=true, typename Engine0, typename Layout0, typename Engine1, typename Layout1, typename Operator>
__device__ __forceinline__ void thread_reduce_(Tensor<Engine0, Layout0> const &tensor, Tensor<Engine1, Layout1> &summary, Operator &op) {
    static_assert(Layout0::rank == 2, "Only support 2D Tensor");
    static_assert(Layout1::rank == 1, "Only support 1D Tensor");
    CUTE_STATIC_ASSERT_V(size<0>(summary) == size<0>(tensor));
    #pragma unroll
    for (int mi = 0; mi < size<0>(tensor); mi++) {
        summary(mi) = zero_init ? tensor(mi, 0) : op(summary(mi), tensor(mi, 0));
        #pragma unroll
        for (int ni = 1; ni < size<1>(tensor); ni++) {
            summary(mi) = op(summary(mi), tensor(mi, ni));
        }
    }
}

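// Cross-thread reduction: each entry is allreduced over the group of four threads that share an
// accumulator row (Allreduce<4> from utils.h, presumably warp-shuffle based), so every thread in
// the quad ends up with the combined value.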
template<typename Engine0, typename Layout0, typename Engine1, typename Layout1, typename Operator>
__device__ __forceinline__ void quad_allreduce_(Tensor<Engine0, Layout0> &dst, Tensor<Engine1, Layout1> &src, Operator &op) {
    CUTE_STATIC_ASSERT_V(size(dst) == size(src));
    #pragma unroll
    for (int i = 0; i < size(dst); i++){
        dst(i) = Allreduce<4>::run(src(i), op);
    }
}

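// Full row-wise reduction over the tile: reduce within the thread first, then across the quad.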
template<bool zero_init=true, typename Engine0, typename Layout0, typename Engine1, typename Layout1, typename Operator>
__device__ __forceinline__ void reduce_(Tensor<Engine0, Layout0> const& tensor, Tensor<Engine1, Layout1> &summary, Operator &op) {
    thread_reduce_<zero_init>(tensor, summary, op);
    quad_allreduce_(summary, summary, op);
}

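// Row-wise max of the tile: thread-local reduction followed by the quad allreduce.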
template<bool zero_init=true, typename Engine0, typename Layout0, typename Engine1, typename Layout1>
__device__ __forceinline__ void reduce_max(Tensor<Engine0, Layout0> const& tensor, Tensor<Engine1, Layout1> &max){
    MaxOp<float> max_op;
    reduce_<zero_init>(tensor, max, max_op);
}

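// Row-wise sum, reduced only within the thread: the cross-thread (quad) reduction is deliberately
// deferred until normalize_softmax_lse, where row_sum is finally allreduced.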
template<bool zero_init=true, typename Engine0, typename Layout0, typename Engine1, typename Layout1>
__device__ __forceinline__ void reduce_sum(Tensor<Engine0, Layout0> const& tensor, Tensor<Engine1, Layout1> &sum){
    SumOp<float> sum_op;
    thread_reduce_<zero_init>(tensor, sum, sum_op);
}

// Apply the exp to all the elements.
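// Given a per-row running max, each element x is mapped to exp2(x * scale - max * scale)
// (when Scale_max is false, the max is multiplied by log2(e) instead of scale). With
// scale = softmax_scale * log2(e), as softmax_rescale_o below passes via softmax_scale_log2,
// this equals exp((x - max) * softmax_scale). Computing via exp2 lets the compiler fuse the
// multiply and subtract into one FFMA; with UNFUSE_FMA defined (as above), __fmul_rn keeps the
// multiply separate, presumably to keep numerics consistent across code paths.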
template <bool Scale_max=true, typename Engine0, typename Layout0, typename Engine1, typename Layout1>
__forceinline__ __device__ void scale_apply_exp2(Tensor<Engine0, Layout0> &tensor, Tensor<Engine1, Layout1> const &max, const float scale) {
    static_assert(Layout0::rank == 2, "Only support 2D Tensor");
    static_assert(Layout1::rank == 1, "Only support 1D Tensor");
    CUTE_STATIC_ASSERT_V(size<0>(max) == size<0>(tensor));
    #pragma unroll
    for (int mi = 0; mi < size<0>(tensor); ++mi) {
        // If max is -inf, then all elements must have been -inf (possibly due to masking).
        // We don't want (-inf - (-inf)) since that would give NaN.
        // If we don't wrap M_LOG2E in float(), the multiplication is done in fp64.
        const float max_scaled = max(mi) == -INFINITY ? 0.f : max(mi) * (Scale_max ? scale : float(M_LOG2E));
        #pragma unroll
        for (int ni = 0; ni < size<1>(tensor); ++ni) {
            // Instead of computing exp(x - max), we compute exp2(x * log_2(e) -
            // max * log_2(e)). This allows the compiler to use the ffma
            // instruction instead of fadd and fmul separately.
#ifdef UNFUSE_FMA
            tensor(mi, ni) = exp2f(__fmul_rn(tensor(mi, ni), scale) - max_scaled);
#else
            tensor(mi, ni) = exp2f(tensor(mi, ni) * scale - max_scaled);
#endif
        }
    }
}

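// Fused variant of reduce_max + scale_apply_exp2 + reduce_sum: one pass per row computes the max,
// exponentiates with exp2, and accumulates the sum, allreducing both max and sum across the quad.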
// Fused row-wise max + exp2 + sum over the elements.
template <bool zero_init=true, typename Engine0, typename Layout0, typename Engine1, typename Layout1>
__forceinline__ __device__ void max_scale_exp2_sum(Tensor<Engine0, Layout0> &tensor, Tensor<Engine1, Layout1> &max, Tensor<Engine1, Layout1> &sum, const float scale) {
    static_assert(Layout0::rank == 2, "Only support 2D Tensor");
    static_assert(Layout1::rank == 1, "Only support 1D Tensor");
    CUTE_STATIC_ASSERT_V(size<0>(max) == size<0>(tensor));
    #pragma unroll
    for (int mi = 0; mi < size<0>(tensor); ++mi) {
        MaxOp<float> max_op;
        max(mi) = zero_init ? tensor(mi, 0) : max_op(max(mi), tensor(mi, 0));
        #pragma unroll
        for (int ni = 1; ni < size<1>(tensor); ni++) {
            max(mi) = max_op(max(mi), tensor(mi, ni));
        }
        max(mi) = Allreduce<4>::run(max(mi), max_op);
        // If max is -inf, then all elements must have been -inf (possibly due to masking).
        // We don't want (-inf - (-inf)) since that would give NaN.
        const float max_scaled = max(mi) == -INFINITY ? 0.f : max(mi) * scale;
        sum(mi) = 0;
        #pragma unroll
        for (int ni = 0; ni < size<1>(tensor); ++ni) {
            // Instead of computing exp(x - max), we compute exp2(x * log_2(e) -
            // max * log_2(e)). This allows the compiler to use the ffma
            // instruction instead of fadd and fmul separately.
            tensor(mi, ni) = exp2f(tensor(mi, ni) * scale - max_scaled);
            sum(mi) += tensor(mi, ni);
        }
        SumOp<float> sum_op;
        sum(mi) = Allreduce<4>::run(sum(mi), sum_op);
    }
}

////////////////////////////////////////////////////////////////////////////////////////////////////

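// Online-softmax state for one thread's kNRows accumulator rows: row_max holds the running
// maximum and row_sum the running (unnormalized) denominator, updated block by block as new
// tiles of attention scores arrive.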
template <int kNRows>
struct Softmax {

    using TensorT = decltype(make_tensor<float>(Shape<Int<kNRows>>{}));
    TensorT row_max, row_sum;

    __forceinline__ __device__ Softmax() {};

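    // One online-softmax step. For the first block, compute the row max, exponentiate, and start
    // row_sum. For later blocks, update the running max, then rescale the previous partial output
    // and partial sum by
    //     scores_scale = exp2((max_prev - max_cur) * softmax_scale_log2)
    //                  = exp((max_prev - max_cur) * softmax_scale)
    // before folding in the new block's exponentiated scores.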
    template<bool Is_first, bool Check_inf=false, typename Tensor0, typename Tensor1>
    __forceinline__ __device__ void softmax_rescale_o(Tensor0 &acc_s, Tensor1 &acc_o, float softmax_scale_log2) {
        // Reshape acc_s from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N))
        Tensor scores = make_tensor(acc_s.data(), pytorch_flash::convert_layout_acc_rowcol(acc_s.layout()));
        static_assert(decltype(size<0>(scores))::value == kNRows);
        if (Is_first) {
            pytorch_flash::template reduce_max</*zero_init=*/true>(scores, row_max);
            pytorch_flash::scale_apply_exp2(scores, row_max, softmax_scale_log2);
            pytorch_flash::reduce_sum</*zero_init=*/true>(scores, row_sum);
        } else {
            Tensor scores_max_prev = make_fragment_like(row_max);
            cute::copy(row_max, scores_max_prev);
            pytorch_flash::template reduce_max</*zero_init=*/false>(scores, row_max);
            // Reshape acc_o from (MMA=4, MMA_M, MMA_K) to (nrow=(2, MMA_M), ncol=(2, MMA_K))
            Tensor acc_o_rowcol = make_tensor(acc_o.data(), pytorch_flash::convert_layout_acc_rowcol(acc_o.layout()));
            static_assert(decltype(size<0>(acc_o_rowcol))::value == kNRows);
            #pragma unroll
            for (int mi = 0; mi < size(row_max); ++mi) {
                float scores_max_cur = !Check_inf
                    ? row_max(mi)
                    : (row_max(mi) == -INFINITY ? 0.0f : row_max(mi));
                float scores_scale = exp2f((scores_max_prev(mi) - scores_max_cur) * softmax_scale_log2);
                row_sum(mi) *= scores_scale;
                #pragma unroll
                for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni) { acc_o_rowcol(mi, ni) *= scores_scale; }
            }
            pytorch_flash::scale_apply_exp2(scores, row_max, softmax_scale_log2);
            // We don't do the reduce across threads here since we don't need to use the row_sum.
            // We do that reduce at the end when we need to normalize the softmax.
            pytorch_flash::reduce_sum</*zero_init=*/false>(scores, row_sum);
        }
    };

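    // Final normalization after the last block: allreduce row_sum across the quad, scale the
    // accumulated output by 1/row_sum (times rp_dropout, presumably 1/(1 - p_dropout), when
    // Is_dropout), and return the per-row log-sum-exp, lse = row_max * softmax_scale + log(row_sum).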
    template<bool Is_dropout=false, bool Split=false, typename Tensor0>
    __forceinline__ __device__ TensorT normalize_softmax_lse(Tensor0 &acc_o, float softmax_scale, float rp_dropout=1.0) {
        SumOp<float> sum_op;
        quad_allreduce_(row_sum, row_sum, sum_op);
        TensorT lse = make_fragment_like(row_sum);
        Tensor acc_o_rowcol = make_tensor(acc_o.data(), pytorch_flash::convert_layout_acc_rowcol(acc_o.layout()));
        static_assert(decltype(size<0>(acc_o_rowcol))::value == kNRows);
        #pragma unroll
        for (int mi = 0; mi < size<0>(acc_o_rowcol); ++mi) {
            float sum = row_sum(mi);
            float inv_sum = (sum == 0.f || sum != sum) ? 1.f : 1.f / sum;
            lse(mi) = (sum == 0.f || sum != sum) ? (Split ? -INFINITY : INFINITY) : row_max(mi) * softmax_scale + __logf(sum);
            float scale = !Is_dropout ? inv_sum : inv_sum * rp_dropout;
            #pragma unroll
            for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni) { acc_o_rowcol(mi, ni) *= scale; }
        }
        return lse;
    };
};
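
// A minimal usage sketch, not part of this header: the forward kernel is expected to keep one
// Softmax object per tile of query rows and drive it once per K/V block. acc_s, acc_o, n_blocks,
// and params.scale_softmax(_log2) are hypothetical kernel-local names used only for illustration.
//
//   Softmax<2 * size<1>(acc_o)> softmax;
//   for (int block = n_blocks - 1; block >= 0; --block) {
//       // ... acc_s = raw Q*K^T scores for this block, with masking already applied ...
//       if (block == n_blocks - 1) {
//           softmax.template softmax_rescale_o</*Is_first=*/true>(acc_s, acc_o, params.scale_softmax_log2);
//       } else {
//           softmax.template softmax_rescale_o</*Is_first=*/false, /*Check_inf=*/true>(acc_s, acc_o, params.scale_softmax_log2);
//       }
//       // ... acc_o += P * V for this block ...
//   }
//   Tensor lse = softmax.template normalize_softmax_lse</*Is_dropout=*/false>(acc_o, params.scale_softmax);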

} // namespace pytorch_flash