/***************************************************************************************************
 * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 * list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
/*! \file
    \brief Templates implementing warp-level matrix multiply-accumulate operations targeting
      Tensor Cores.
*/

#pragma once

#include <cutlass/array.h>
#include <cutlass/cutlass.h>
#include <cutlass/platform/platform.h>

#include <cutlass/matrix_shape.h>
#include <cutlass/numeric_conversion.h>
#include <cutlass/numeric_types.h>

#include <cutlass/arch/memory_sm75.h>
#include <cutlass/arch/mma_sm75.h>
#include <cutlass/arch/mma_sm80.h>

#include <cutlass/gemm/gemm.h>
#include <cutlass/gemm/warp/mma.h>

#include <cutlass/gemm/warp/mma_tensor_op_policy.h>

#include <cutlass/gemm/warp/mma_tensor_op_tile_iterator.h>
#include <cutlass/gemm/warp/mma_tensor_op_tile_iterator_sm80.h>

/////////////////////////////////////////////////////////////////////////////////////////////////

namespace cutlass {
namespace gemm {
namespace warp {

/////////////////////////////////////////////////////////////////////////////////////////////////
/// Structure to compute the matrix product targeting Tensor Cores, where the B operand has been
/// converted to the fp16 / bf16 compute type of the underlying HMMA instruction.
template<
    /// Size of the Gemm problem - concept: gemm::GemmShape<>
    typename Shape_,
    /// Data type of A elements
    typename ElementA_,
    /// Layout of A matrix (concept: MatrixLayout)
    typename LayoutA_,
    /// Data type of B elements
    typename ElementB_,
    /// Layout of B matrix (concept: MatrixLayout)
    typename LayoutB_,
    /// Element type of C matrix
    typename ElementC_,
    /// Layout of C matrix (concept: MatrixLayout)
    typename LayoutC_,
    /// Policy describing warp-level MmaTensorOp (concept: MmaTensorOp policy)
    typename Policy_,
    /// Instruction shape to override shared memory iterators with
    typename SharedMemoryInstructionShape_,
    /// Number of partitions along K dimension
    int PartitionsK_ = 1,
    /// Store the accumulators in row major or column major.  Row major is used
    /// when output layout is interleaved.
    bool AccumulatorsInRowMajor = false,
    /// Used for partial specialization
    typename Enable = bool>
class MmaTensorOpComputeBWithF16 {
public:
    /// Shape of warp-level matrix operation (concept: GemmShape)
    using Shape = Shape_;

    /// Data type of multiplicand A
    using ElementA = ElementA_;

    /// Layout of multiplicand A
    using LayoutA = LayoutA_;

    /// Data type of multiplicand B
    using ElementB = ElementB_;

    /// Layout of multiplicand B
    using LayoutB = LayoutB_;

    /// Data type of accumulator matrix C
    using ElementC = ElementC_;

    /// Layout of accumulator matrix C
    using LayoutC = LayoutC_;

    /// Policy describing warp-level MmaTensorOp (concept: MmaTensorOpPolicy)
    using Policy = Policy_;

    /// Underlying matrix multiply operator (concept: arch::Mma)
    using ArchMmaOperator = typename Policy::Operator;

    /// Indicates math operator
    using MathOperator = typename ArchMmaOperator::Operator;

    /// Architecture tag from underlying instruction
    using ArchTag = typename ArchMmaOperator::ArchTag;
    static_assert((platform::is_same<typename ArchMmaOperator::ElementA, half_t>::value
                   && platform::is_same<typename ArchMmaOperator::ElementB, half_t>::value)
                      || (platform::is_same<typename ArchMmaOperator::ElementA, bfloat16_t>::value
                          && platform::is_same<typename ArchMmaOperator::ElementB, bfloat16_t>::value
                          && ArchTag::kMinComputeCapability >= 80),
                  "MmaTensorOpComputeBWithF16 only supports underlying HMMA");

    static_assert(platform::is_same<ElementA, half_t>::value
                      || (platform::is_same<ElementA, bfloat16_t>::value && ArchTag::kMinComputeCapability >= 80),
                  "MmaTensorOpComputeBWithF16 only supports Fp16 A or Bf16 A on Ampere+");

    /// Indicates class of matrix operator
    using OperatorClass = arch::OpClassTensorOp;

    /// Shape of underlying instruction
    using InstructionShape = typename ArchMmaOperator::Shape;

    /// Instruction shape to override shared memory iterators with
    using SharedMemoryInstructionShape = SharedMemoryInstructionShape_;

    static_assert(SharedMemoryInstructionShape::kM == InstructionShape::kM,
                  "M dimension of compute instruction must match load");
    static_assert(SharedMemoryInstructionShape::kN == InstructionShape::kN,
                  "N dimension of compute instruction must match load");

    static constexpr int kExpansionFactor = SharedMemoryInstructionShape::kK / InstructionShape::kK;
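    // Illustrative example (not asserted by this header): with an m16n8k16 HMMA instruction and a
    // 16x8x32 SharedMemoryInstructionShape, kExpansionFactor == 32 / 16 == 2, so each B fragment
    // loaded from shared memory spans two K sub-tiles of the compute instruction.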

    static_assert(!(Shape::kK % SharedMemoryInstructionShape::kK),
                  "Warp-level GemmShape K must be divisible by SharedMemoryInstructionShape K");

    /// Complex transform on A operand
    static ComplexTransform const kTransformA = ComplexTransform::kNone;

    /// Complex transform on B operand
    static ComplexTransform const kTransformB = ComplexTransform::kNone;

    /// Number of threads participating in warp-level matrix product
    static int const kThreadCount = 32;

    /// Number of partitions along K dimension
    static int const kPartitionsK = PartitionsK_;

public:
    /// Iterates over the A operand in memory
    using IteratorA = MmaTensorOpMultiplicandTileIterator<MatrixShape<Shape::kM, Shape::kK>,
                                                          Operand::kA,
                                                          ElementA,
                                                          LayoutA,
                                                          MatrixShape<InstructionShape::kM, InstructionShape::kK>,
                                                          Policy::OpDelta::kRow,
                                                          kThreadCount,
                                                          kPartitionsK>;

    /// Storage for A tile
    using FragmentA = typename IteratorA::Fragment;

    /// Storage for transformed A tile
    using TransformedFragmentA = Array<typename ArchMmaOperator::ElementA, FragmentA::kElements>;

    /// Iterates over the B operand in memory
    using IteratorB =
        MmaTensorOpMultiplicandTileIterator<MatrixShape<Shape::kK, Shape::kN>,
                                            Operand::kB,
                                            ElementB,
                                            LayoutB,
                                            MatrixShape<SharedMemoryInstructionShape::kK, InstructionShape::kN>,
                                            Policy::OpDelta::kRow,
                                            kThreadCount,
                                            kPartitionsK>;

    /// Storage for B tile
    using FragmentB = typename IteratorB::Fragment;

    /// Storage for transformed B tile
    using TransformedFragmentB = Array<typename ArchMmaOperator::ElementB, FragmentB::kElements>;

    /// Iterates over the C operand in memory
    using IteratorC = MmaTensorOpAccumulatorTileIterator<MatrixShape<Shape::kM, Shape::kN>,
                                                         ElementC,
                                                         LayoutC,
                                                         typename ArchMmaOperator::Shape,
                                                         typename Policy::OpDelta>;

    /// Storage for C tile
    using FragmentC = typename IteratorC::Fragment;

    /// Number of mma operations performed
    using MmaIterations = MatrixShape<(Shape::kM + ArchMmaOperator::Shape::kM - 1) / ArchMmaOperator::Shape::kM,
                                      (Shape::kN + ArchMmaOperator::Shape::kN - 1) / ArchMmaOperator::Shape::kN>;
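    // Illustrative example (not asserted by this header): a 64x64 warp tile with a 16x8
    // instruction shape gives MmaIterations of 4 rows by 8 columns, i.e. 32 mma instructions
    // per K step of the warp-level product.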

public:
    /// Underlying matrix multiply operator (concept: arch::Mma)
    ArchMmaOperator mma;

public:
    //
    // Methods
    //

    /// Ctor
    CUTLASS_DEVICE
    MmaTensorOpComputeBWithF16() {}

    /// Performs a warp-level matrix multiply-accumulate operation. warp_tileB_k_offset selects
    /// which of the kExpansionFactor K sub-tiles held in the transformed B fragment is consumed
    /// by this call.
    CUTLASS_DEVICE
    void operator()(FragmentC&                  D,
                    TransformedFragmentA const& A,
                    TransformedFragmentB const& B,
                    FragmentC const&            C,
                    const int                   warp_tileB_k_offset) const
    {

        using MmaOperandA = typename ArchMmaOperator::FragmentA;
        using MmaOperandB = typename ArchMmaOperator::FragmentB;
        using MmaOperandC = typename ArchMmaOperator::FragmentC;

        static_assert(
            TransformedFragmentB::kElements == MmaOperandB::kElements * kExpansionFactor * MmaIterations::kColumn,
            "Each thread should have a pack of mma registers for each column iteration AND for the expanded K dim of B");

        D = C;

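        // View the transformed fragments as arrays of per-instruction operand packs so that each
        // mma call below can index the pack belonging to its (m, n) iteration.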
        MmaOperandA const* ptr_A = reinterpret_cast<MmaOperandA const*>(&A);
        MmaOperandB const* ptr_B = reinterpret_cast<MmaOperandB const*>(&B);
        MmaOperandC*       ptr_D = reinterpret_cast<MmaOperandC*>(&D);

#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800)
        // Serpentine visitation order maximizing reuse of Rb
        CUTLASS_PRAGMA_UNROLL
        for (int n = 0; n < MmaIterations::kColumn; ++n) {

            CUTLASS_PRAGMA_UNROLL
            for (int m = 0; m < MmaIterations::kRow; ++m) {

                int m_serpentine = ((n % 2) ? (MmaIterations::kRow - 1 - m) : m);

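                // Operand packs for column n start at index kExpansionFactor * n in the B
                // fragment; warp_tileB_k_offset selects the K sub-tile within that group.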
                int n_offsetB = warp_tileB_k_offset + kExpansionFactor * n;
                if (AccumulatorsInRowMajor) {  // matrix B is reordered
                    mma(ptr_D[n + m_serpentine * MmaIterations::kColumn],
                        ptr_A[m_serpentine],
                        ptr_B[n_offsetB],
                        ptr_D[n + m_serpentine * MmaIterations::kColumn]);
                }
                else {
                    mma(ptr_D[m_serpentine + n * MmaIterations::kRow],
                        ptr_A[m_serpentine],
                        ptr_B[n_offsetB],
                        ptr_D[m_serpentine + n * MmaIterations::kRow]);
                }
            }
        }
#elif defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)
        // Serpentine visitation order maximizing reuse of Ra
        CUTLASS_PRAGMA_UNROLL
        for (int m = 0; m < MmaIterations::kRow; ++m) {

            CUTLASS_PRAGMA_UNROLL
            for (int n = 0; n < MmaIterations::kColumn; ++n) {

                int n_serpentine = ((m % 2) ? (MmaIterations::kColumn - 1 - n) : n);

                int n_serpentine_offsetB = warp_tileB_k_offset + kExpansionFactor * n_serpentine;
                if (AccumulatorsInRowMajor) {  // matrix B is reordered
                    mma(ptr_D[n_serpentine + m * MmaIterations::kColumn],
                        ptr_A[m],
                        ptr_B[n_serpentine_offsetB],
                        ptr_D[n_serpentine + m * MmaIterations::kColumn]);
                }
                else {
                    mma(ptr_D[m + n_serpentine * MmaIterations::kRow],
                        ptr_A[m],
                        ptr_B[n_serpentine_offsetB],
                        ptr_D[m + n_serpentine * MmaIterations::kRow]);
                }
            }
        }
#else
        assert(0);
#endif
    }
};

/////////////////////////////////////////////////////////////////////////////////////////////////

}  // namespace warp
}  // namespace gemm
}  // namespace cutlass
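
// Usage sketch (illustrative only; the mainloop structure, fragment names, and variable names
// below are assumptions, not part of this header). A caller that loads and converts one B
// fragment per SharedMemoryInstructionShape K-tile typically issues one compute step per
// expanded K sub-tile:
//
//   CUTLASS_PRAGMA_UNROLL
//   for (int k_sub = 0; k_sub < WarpMma::kExpansionFactor; ++k_sub) {
//       // frag_A holds the A operands for this K sub-tile; frag_B was loaded (and converted to
//       // fp16/bf16) once for the whole SharedMemoryInstructionShape K-tile.
//       warp_mma(accum, frag_A[k_sub], frag_B, accum, /*warp_tileB_k_offset=*/k_sub);
//   }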

/////////////////////////////////////////////////////////////////////////////////////////////////