/*! \file
  \brief Epilogue for threadblock scoped GEMMs using Tensor Ops.

  The epilogue rearranges the result of a matrix product through shared memory
  to match canonical tensor layouts in global memory. Epilogues support
  conversion and reduction operations.

  This is a copy of cutlass/epilogue/threadblock/epilogue.h that can
  handle `row_id` as a first argument and uses it to fetch the corresponding
  `m_prime` / `s_prime` values to rescale the output.
*/
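
// Background sketch (an assumption drawn from the memory-efficient attention
// algorithm, not stated in this file): with a streaming softmax the kernel
// tracks, per query row, the running max m_i of the attention logits and the
// running sum s_prime_i = sum_j exp(S_ij - m_i). When a later key block
// raises the max from m_old to m_new, previously written partial outputs
// must be scaled by exp(m_old - m_new); that correction factor is what
// `m_prime` carries here, while `1 / s_prime` is the final softmax
// denominator applied when isLast=true.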

#pragma once

#if defined(__CUDACC_RTC__)
#include <cuda/std/cassert>
#else
#include <cassert>
#endif

#include <cutlass/aligned_buffer.h>
#include <cutlass/array.h>
#include <cutlass/cutlass.h>
#include <cutlass/functional.h>
#include <cutlass/layout/tensor.h>
#include <cutlass/layout/vector.h>
#include <cutlass/numeric_types.h>
#include <cutlass/tensor_coord.h>

#include <cutlass/gemm/gemm.h>

#include <cutlass/transform/pitch_linear_thread_map.h>
#include <cutlass/transform/threadblock/regular_tile_iterator.h>

#include <cutlass/epilogue/threadblock/epilogue_base.h>
#include <cutlass/epilogue/threadblock/predicated_tile_iterator.h>

#include <cutlass/epilogue/thread/scale_type.h>
#include <cutlass/numeric_conversion.h>

#include <ATen/native/transformers/cuda/mem_eff_attention/epilogue/epilogue_pipelined.h>

/////////////////////////////////////////////////////////////////////////////////////////////////

namespace cutlass {
namespace epilogue {
namespace thread {

/////////////////////////////////////////////////////////////////////////////////////////////////

/// Applies a linear combination operator to an array of elements.
// output <- alpha * accumulator + beta * source
//   with:
//     alpha = 1 / s_prime (to normalize when isLast=True, 1 otherwise)
//     beta = alpha * m_prime (renormalizes the output when the row max changes)
//     source is the current output
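//
// Worked example (hypothetical numbers, for illustration only): suppose
// m_prime[row] = 0.5 (the running max grew) and this is the last partial
// product (isLast=true) with s_prime[row] = 8. Then alpha = 1/8 and
// beta = alpha * 0.5 = 1/16, so D = accumulator / 8 + source / 16: the
// previously written output is renormalized before the new contribution
// is added.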
template <
    typename ElementOutput_, ///< Data type used to store tensors
    typename ElementSource_, ///< Data type for source (usually matches
                             ///< `ElementOutput`)
    int Count, ///< Number of elements computed per operation.
               ///< Usually it is 128/sizeof_bits<ElementOutput_>,
               ///< but we use 64 or 32 sometimes when there is not enough
               ///< data to store
    typename ElementAccumulator_, ///< Accumulator data type
    typename ElementCompute_, ///< Data type used to compute linear combination
    bool isFirst,
    bool isLast,
    typename FragmentAlphaBeta_,
    FloatRoundStyle Round = FloatRoundStyle::round_to_nearest>
class MemoryEfficientAttentionNormalize {
 public:
  using ElementOutput = ElementOutput_;
  using ElementSource = ElementSource_;
  using ElementAccumulator = ElementAccumulator_;
  using ElementCompute = ElementCompute_;

  static int const kCount = Count;

  using FragmentOutput = Array<ElementOutput, kCount>;
  using FragmentSource = Array<ElementSource, kCount>;
  using FragmentAccumulator = Array<ElementAccumulator, kCount>;
  using ComputeFragment = Array<ElementCompute, kCount>;
  using FragmentAlphaBeta = FragmentAlphaBeta_;

  static FloatRoundStyle const kRound = Round;

 private:
  //
  // Data members
  //

  FragmentAlphaBeta const& s_prime_;
  FragmentAlphaBeta const& m_prime_;

 public:
  /// Constructs the function object from references to the per-row
  /// `s_prime` / `m_prime` scale fragments
  CUTLASS_HOST_DEVICE
  MemoryEfficientAttentionNormalize(
      FragmentAlphaBeta const& s_prime,
      FragmentAlphaBeta const& m_prime)
      : s_prime_(s_prime), m_prime_(m_prime) {}

  /// Returns true if source is needed
  CUTLASS_HOST_DEVICE
  bool is_source_needed() const {
    return !isFirst;
  }

  /// Functionally required for serial reduction in the epilogue
  CUTLASS_HOST_DEVICE
  void set_k_partition(int k_partition, int k_partition_count) {}

  /// Computes linear scaling: D = alpha * accumulator + beta * source
  CUTLASS_HOST_DEVICE
  FragmentOutput operator()(
      int row,
      FragmentAccumulator const& accumulator,
      FragmentSource const& source) const {
    assert(!isFirst);

    // Convert source and accumulator to internal compute numeric type
    NumericArrayConverter<ElementCompute, ElementSource, kCount, Round>
        source_converter;
    NumericArrayConverter<ElementCompute, ElementAccumulator, kCount, Round>
        accumulator_converter;

    // Convert to destination numeric type
    NumericArrayConverter<ElementOutput, ElementCompute, kCount, Round>
        destination_converter;

    ComputeFragment converted_source = source_converter(source);
    ComputeFragment converted_accumulator = accumulator_converter(accumulator);

    // Perform binary operations
    ComputeFragment intermediate;

    multiplies<ComputeFragment> mul_source;
    multiply_add<ComputeFragment> mul_add_accumulator;

    // Row sums for fully masked-out rows are 0; we set the denominator to 1
    // to avoid NaNs in the output and instead set those rows to 0.
    ElementCompute denom = s_prime_[row] == 0 ? 1 : s_prime_[row];
    ElementCompute alpha = isLast ? (1 / denom) : 1;
    ElementCompute beta = alpha * m_prime_[row];

    intermediate = mul_source(beta, converted_source); // X = beta * C

    intermediate = mul_add_accumulator(
        alpha, converted_accumulator, intermediate); // D = alpha * Accum + X

    return destination_converter(intermediate);
  }

  /// Computes linear scaling: D = alpha * accumulator
  CUTLASS_HOST_DEVICE
  FragmentOutput operator()(int row, FragmentAccumulator const& accumulator)
      const {
    assert(isFirst);

    // Convert accumulator to internal compute numeric type
    NumericArrayConverter<ElementCompute, ElementAccumulator, kCount, Round>
        accumulator_converter;

    // Convert to destination numeric type
    NumericArrayConverter<ElementOutput, ElementCompute, kCount, Round>
        destination_converter;

    ComputeFragment converted_accumulator = accumulator_converter(accumulator);

    ComputeFragment intermediate;
    multiplies<ComputeFragment> mul_accumulator;

    // Row sums for fully masked-out rows are 0; we set the denominator to 1
    // to avoid NaNs in the output and instead set those rows to 0.
    ElementCompute denom = s_prime_[row] == 0 ? 1 : s_prime_[row];
    ElementCompute alpha = isLast ? (1 / denom) : 1;

    intermediate = mul_accumulator(
        alpha, converted_accumulator); // D = alpha * Accum

    return destination_converter(intermediate);
  }
};
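
// A minimal usage sketch (hypothetical values; `kQueriesPerBlock` stands in
// for the number of output rows a threadblock produces and is not defined in
// this file):
//
//   using AlphaBetaFragment = cutlass::Array<float, kQueriesPerBlock>;
//   using OutputOp = MemoryEfficientAttentionNormalize<
//       cutlass::half_t, // ElementOutput
//       cutlass::half_t, // ElementSource
//       128 / cutlass::sizeof_bits<cutlass::half_t>::value, // Count = 8
//       float, // ElementAccumulator
//       float, // ElementCompute
//       /*isFirst=*/false,
//       /*isLast=*/true, // divide by s_prime to finish the normalization
//       AlphaBetaFragment>;
//
//   OutputOp output_op(s_prime, m_prime);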

} // namespace thread

namespace threadblock {
template <
    typename EO,
    typename ES,
    int Count,
    typename EA,
    typename EC,
    bool F,
    bool L,
    typename FAB,
    FloatRoundStyle R>
struct ApplyEpilogueOp<thread::MemoryEfficientAttentionNormalize<
    EO,
    ES,
    Count,
    EA,
    EC,
    F,
    L,
    FAB,
    R>> {
  using Op = thread::
      MemoryEfficientAttentionNormalize<EO, ES, Count, EA, EC, F, L, FAB, R>;
  static CUTLASS_DEVICE typename Op::FragmentOutput apply(
      Op const& output_op,
      int row_id,
      typename Op::FragmentAccumulator const& accum,
      typename Op::FragmentSource const& source) {
    return output_op(row_id, accum, source);
  }
  static CUTLASS_DEVICE typename Op::FragmentOutput apply(
      Op const& output_op,
      int row_id,
      typename Op::FragmentAccumulator const& accum) {
    return output_op(row_id, accum);
  }
};
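
// Sketch of how the pipelined epilogue (epilogue_pipelined.h) is expected to
// route through this specialization; `output_op`, `row_id`, `accum_frag`, and
// `source_frag` are illustrative names, not identifiers from this file:
//
//   using Dispatch = ApplyEpilogueOp<OutputOp>;
//   auto frag = Dispatch::apply(output_op, row_id, accum_frag, source_frag);
//
// Threading `row_id` through apply() is what lets the functor index the
// per-row `m_prime` / `s_prime` fragments.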

/////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace threadblock
} // namespace epilogue
} // namespace cutlass

/////////////////////////////////////////////////////////////////////////////////////////////////