/*! \file
  \brief Epilogue for threadblock scoped GEMMs using Tensor Ops.

  The epilogue rearranges the result of a matrix product through shared memory
  to match canonical tensor layouts in global memory. Epilogues support
  conversion and reduction operations.

  This is a copy of cutlass/epilogue/threadblock/epilogue.h that can
  handle "row_id" as a first argument, and uses it to get the corresponding
  `m_prime` / `s_prime` to rescale the output.
*/

#pragma once

#if defined(__CUDACC_RTC__)
#include <cuda/std/cassert>
#else
#include <cassert>
#endif

#include <cutlass/aligned_buffer.h>
#include <cutlass/array.h>
#include <cutlass/cutlass.h>
#include <cutlass/functional.h>
#include <cutlass/layout/tensor.h>
#include <cutlass/layout/vector.h>
#include <cutlass/numeric_types.h>
#include <cutlass/tensor_coord.h>

#include <cutlass/gemm/gemm.h>

#include <cutlass/transform/pitch_linear_thread_map.h>
#include <cutlass/transform/threadblock/regular_tile_iterator.h>

#include <cutlass/epilogue/threadblock/epilogue_base.h>
#include <cutlass/epilogue/threadblock/predicated_tile_iterator.h>

#include <cutlass/epilogue/thread/scale_type.h>
#include <cutlass/numeric_conversion.h>

#include <ATen/native/transformers/cuda/mem_eff_attention/epilogue/epilogue_pipelined.h>

/////////////////////////////////////////////////////////////////////////////////////////////////

namespace cutlass {
namespace epilogue {
namespace thread {

/////////////////////////////////////////////////////////////////////////////////////////////////

/// Applies a linear combination operator to an array of elements.
// output <- alpha * accumulator + beta * source
//   with:
//     alpha = 1 / s_prime (to normalize the output when isLast=True, 1 otherwise)
//     beta = alpha * m_prime (to renormalize the output when the max changes)
//     source is the current output
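//
// Worked example (illustrative numbers only; this assumes, as the comment
// above describes, that m_prime holds the multiplicative correction factor
// applied when the running max changes, e.g. exp(m_old - m_new), computed by
// the calling kernel): for a row with s_prime = 4, m_prime = 0.5 and
// isLast=true, the functor computes per element
//   D = (1 / 4) * accumulator + (0.5 / 4) * source
// i.e. the previously written output (source) is corrected for the new running
// max and both terms are normalized by the softmax row sum.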
template <
    typename ElementOutput_, ///< Data type used to store tensors
    typename ElementSource_, ///< Data type for source (usually matches
                             ///< `ElementOutput`)
    int Count, ///< Number of elements computed per operation.
               ///< Usually it is 128/sizeof_bits<ElementOutput_>,
               ///< but we sometimes use 64 or 32 when there is not enough
               ///< data to store
    typename ElementAccumulator_, ///< Accumulator data type
    typename ElementCompute_, ///< Data type used to compute linear combination
    bool isFirst,
    bool isLast,
    typename FragmentAlphaBeta_,
    FloatRoundStyle Round = FloatRoundStyle::round_to_nearest>
class MemoryEfficientAttentionNormalize {
 public:
  using ElementOutput = ElementOutput_;
  using ElementSource = ElementSource_;
  using ElementAccumulator = ElementAccumulator_;
  using ElementCompute = ElementCompute_;

  static int const kCount = Count;

  using FragmentOutput = Array<ElementOutput, kCount>;
  using FragmentSource = Array<ElementSource, kCount>;
  using FragmentAccumulator = Array<ElementAccumulator, kCount>;
  using ComputeFragment = Array<ElementCompute, kCount>;
  using FragmentAlphaBeta = FragmentAlphaBeta_;

  static FloatRoundStyle const kRound = Round;

 private:
  //
  // Data members
  //

  FragmentAlphaBeta const& s_prime_;
  FragmentAlphaBeta const& m_prime_;

 public:
  /// Constructs the function object, possibly loading from pointers in host
  /// memory
  CUTLASS_HOST_DEVICE
  MemoryEfficientAttentionNormalize(
      FragmentAlphaBeta const& s_prime,
      FragmentAlphaBeta const& m_prime)
      : s_prime_(s_prime), m_prime_(m_prime) {}

  /// Returns true if source is needed
  CUTLASS_HOST_DEVICE
  bool is_source_needed() const {
    return !isFirst;
  }

  /// Functionally required for serial reduction in the epilogue
  CUTLASS_HOST_DEVICE
  void set_k_partition(int k_partition, int k_partition_count) {}

  /// Computes linear scaling: D = alpha * accumulator + beta * source
  CUTLASS_HOST_DEVICE
  FragmentOutput operator()(
      int row,
      FragmentAccumulator const& accumulator,
      FragmentSource const& source) const {
    assert(!isFirst);

    // Convert source to internal compute numeric type
    NumericArrayConverter<ElementCompute, ElementSource, kCount, Round>
        source_converter;
    NumericArrayConverter<ElementCompute, ElementAccumulator, kCount, Round>
        accumulator_converter;

    // Convert to destination numeric type
    NumericArrayConverter<ElementOutput, ElementCompute, kCount, Round>
        destination_converter;

    ComputeFragment converted_source = source_converter(source);
    ComputeFragment converted_accumulator = accumulator_converter(accumulator);

    // Perform binary operations
    ComputeFragment intermediate;

    multiplies<ComputeFragment> mul_add_source;
    multiply_add<ComputeFragment> mul_add_accumulator;

    // Row sums for fully masked out rows are 0; we set them to 1
    // in order to avoid NaNs in the output and instead set those outputs to 0.
    ElementCompute denom = s_prime_[row] == 0 ? 1 : s_prime_[row];
    ElementCompute alpha = isLast ? (1 / denom) : 1;
    ElementCompute beta = alpha * m_prime_[row];

    intermediate = mul_add_source(beta, converted_source); // X = beta * C

    intermediate = mul_add_accumulator(
        alpha, converted_accumulator, intermediate); // D = alpha * Accum + X

    return destination_converter(intermediate);
  }

  /// Computes linear scaling: D = alpha * accumulator
  CUTLASS_HOST_DEVICE
  FragmentOutput operator()(int row, FragmentAccumulator const& accumulator)
      const {
    assert(isFirst);

    // Convert accumulator to internal compute numeric type
    NumericArrayConverter<ElementCompute, ElementAccumulator, kCount, Round>
        accumulator_converter;

    // Convert to destination numeric type
    NumericArrayConverter<ElementOutput, ElementCompute, kCount, Round>
        destination_converter;

    ComputeFragment converted_accumulator = accumulator_converter(accumulator);

    ComputeFragment intermediate;
    multiplies<ComputeFragment> mul_accumulator;

    // Row sums for fully masked out rows are 0; we set them to 1
    // in order to avoid NaNs in the output and instead set those outputs to 0.
    ElementCompute denom = s_prime_[row] == 0 ? 1 : s_prime_[row];
    ElementCompute alpha = isLast ? (1 / denom) : 1;

    intermediate = mul_accumulator(
        alpha, converted_accumulator); // X = alpha * Accum

    return destination_converter(intermediate);
  }
};

} // namespace thread
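// Usage sketch (illustrative only): this functor is meant to serve as the
// output op of the pipelined epilogue included above. The element types, the
// fragment width (8) and the names `kQueriesPerBlock`, `accum_fragment` and
// `source_fragment` below are hypothetical stand-ins chosen for the example,
// not requirements of this header.
//
//   using AlphaBetaVec = cutlass::Array<float, kQueriesPerBlock>;
//   using OutputOp = cutlass::epilogue::thread::MemoryEfficientAttentionNormalize<
//       cutlass::half_t, // ElementOutput
//       cutlass::half_t, // ElementSource
//       8,               // Count: elements computed per operation
//       float,           // ElementAccumulator
//       float,           // ElementCompute
//       /*isFirst=*/false,
//       /*isLast=*/true,
//       AlphaBetaVec>;
//
//   AlphaBetaVec s_prime, m_prime; // filled by the attention kernel
//   OutputOp output_op(s_prime, m_prime);
//   // out = alpha * accum_fragment + beta * source_fragment, per the formulas above
//   OutputOp::FragmentOutput out = output_op(row_id, accum_fragment, source_fragment);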
namespace threadblock {
template <
    typename EO,
    typename ES,
    int Count,
    typename EA,
    typename EC,
    bool F,
    bool L,
    typename FAB,
    FloatRoundStyle R>
struct ApplyEpilogueOp<thread::MemoryEfficientAttentionNormalize<
    EO,
    ES,
    Count,
    EA,
    EC,
    F,
    L,
    FAB,
    R>> {
  using Op = thread::
      MemoryEfficientAttentionNormalize<EO, ES, Count, EA, EC, F, L, FAB, R>;
  static CUTLASS_DEVICE typename Op::FragmentOutput apply(
      Op const& output_op,
      int row_id,
      typename Op::FragmentAccumulator const& accum,
      typename Op::FragmentSource const& source) {
    return output_op(row_id, accum, source);
  }
  static CUTLASS_DEVICE typename Op::FragmentOutput apply(
      Op const& output_op,
      int row_id,
      typename Op::FragmentAccumulator const& accum) {
    return output_op(row_id, accum);
  }
};

/////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace threadblock
} // namespace epilogue
} // namespace cutlass

/////////////////////////////////////////////////////////////////////////////////////////////////