/***************************************************************************************************
 * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 * list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
/*! \file
    \brief Templates implementing warp-level matrix multiply-accumulate operations targeting
      Tensor Cores.
*/

#pragma once

#include <cutlass/array.h>
#include <cutlass/cutlass.h>
#include <cutlass/platform/platform.h>

#include <cutlass/matrix_shape.h>
#include <cutlass/numeric_conversion.h>
#include <cutlass/numeric_types.h>

#include <cutlass/arch/memory_sm75.h>
#include <cutlass/arch/mma_sm75.h>
#include <cutlass/arch/mma_sm80.h>

#include <cutlass/gemm/gemm.h>
#include <cutlass/gemm/warp/mma.h>

#include <cutlass/gemm/warp/mma_tensor_op_policy.h>

#include <cutlass/gemm/warp/mma_tensor_op_tile_iterator.h>
#include <cutlass/gemm/warp/mma_tensor_op_tile_iterator_sm80.h>

/////////////////////////////////////////////////////////////////////////////////////////////////

namespace cutlass {
namespace gemm {
namespace warp {

/////////////////////////////////////////////////////////////////////////////////////////////////

/// Structure to compute the matrix product targeting Tensor Core math instructions, where the B
/// operand is consumed as f16/bf16 from shared-memory tiles staged at a wider K granularity than
/// the underlying instruction (see SharedMemoryInstructionShape and kExpansionFactor).
template<
    /// Size of the Gemm problem - concept: gemm::GemmShape<>
    typename Shape_,
    /// Data type of A elements
    typename ElementA_,
    /// Layout of A matrix (concept: MatrixLayout)
    typename LayoutA_,
    /// Data type of B elements
    typename ElementB_,
    /// Layout of B matrix (concept: MatrixLayout)
    typename LayoutB_,
    /// Element type of C matrix
    typename ElementC_,
    /// Layout of C matrix (concept: MatrixLayout)
    typename LayoutC_,
    /// Policy describing warp-level MmaTensorOp (concept: MmaTensorOp policy)
    typename Policy_,
    /// Instruction shape to override shared memory iterators with
    typename SharedMemoryInstructionShape_,
    /// Number of partitions along K dimension
    int PartitionsK_ = 1,
    /// Store the accumulators in row major or column major. Row major is used
    /// when output layout is interleaved.
    bool AccumulatorsInRowMajor = false,
    /// Used for partial specialization
    typename Enable = bool>
class MmaTensorOpComputeBWithF16 {
public:
    /// Shape of warp-level matrix operation (concept: GemmShape)
    using Shape = Shape_;

    /// Data type of multiplicand A
    using ElementA = ElementA_;

    /// Layout of multiplicand A
    using LayoutA = LayoutA_;

    /// Data type of multiplicand B
    using ElementB = ElementB_;

    /// Layout of multiplicand B
    using LayoutB = LayoutB_;

    /// Data type of accumulator matrix C
    using ElementC = ElementC_;

    /// Layout of accumulator matrix C
    using LayoutC = LayoutC_;

    /// Shape of the warp in units of thread (concept: MmaLanePolicySimt)
    using Policy = Policy_;

    /// Underlying matrix multiply operator (concept: arch::Mma)
    using ArchMmaOperator = typename Policy::Operator;

    /// Indicates math operator
    using MathOperator = typename ArchMmaOperator::Operator;

    /// Architecture tag from underlying instruction
    using ArchTag = typename ArchMmaOperator::ArchTag;

    static_assert((platform::is_same<typename ArchMmaOperator::ElementA, half_t>::value
                   && platform::is_same<typename ArchMmaOperator::ElementB, half_t>::value)
                      || (platform::is_same<typename ArchMmaOperator::ElementA, bfloat16_t>::value
                          && platform::is_same<typename ArchMmaOperator::ElementB, bfloat16_t>::value
                          && ArchTag::kMinComputeCapability >= 80),
                  "MmaTensorOpComputeBWithF16 only supports underlying HMMA");

    static_assert(platform::is_same<ElementA, half_t>::value
                      || (platform::is_same<ElementA, bfloat16_t>::value && ArchTag::kMinComputeCapability >= 80),
                  "MmaTensorOpComputeBWithF16 only supports fp16 A, or bf16 A on Ampere and newer");

    /// Indicates class of matrix operator
    using OperatorClass = arch::OpClassTensorOp;

    /// Shape of underlying instruction
    using InstructionShape = typename ArchMmaOperator::Shape;

    /// Instruction shape to override shared memory iterators with
    using SharedMemoryInstructionShape = SharedMemoryInstructionShape_;

    static_assert(SharedMemoryInstructionShape::kM == InstructionShape::kM,
                  "M dimension of compute instruction must match load");
    static_assert(SharedMemoryInstructionShape::kN == InstructionShape::kN,
                  "N dimension of compute instruction must match load");

    /// Ratio of the K extent loaded from shared memory to the K extent of the math instruction
    static constexpr int kExpansionFactor = SharedMemoryInstructionShape::kK / InstructionShape::kK;

    static_assert(!(Shape::kK % SharedMemoryInstructionShape::kK),
                  "Warp-level K extent must be divisible by the shared memory instruction K extent");
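
    // Worked example (illustrative only; the m16n8k16 instruction and doubled-K staging are
    // assumptions, not requirements of this header): if the underlying HMMA instruction is
    // m16n8k16 and B is staged in shared memory as m16n8k32 tiles, then
    //     kExpansionFactor = SharedMemoryInstructionShape::kK / InstructionShape::kK = 32 / 16 = 2,
    // so each loaded B fragment feeds two compute steps, selected by the warp_tileB_k_offset
    // argument of operator().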

    /// Complex transform on A operand
    static ComplexTransform const kTransformA = ComplexTransform::kNone;

    /// Complex transform on B operand
    static ComplexTransform const kTransformB = ComplexTransform::kNone;

    /// Number of threads participating in warp-level matrix product
    static int const kThreadCount = 32;

    /// Number of partitions along K dimension
    static int const kPartitionsK = PartitionsK_;

public:
    /// Iterates over the A operand in memory
    using IteratorA = MmaTensorOpMultiplicandTileIterator<MatrixShape<Shape::kM, Shape::kK>,
                                                          Operand::kA,
                                                          ElementA,
                                                          LayoutA,
                                                          MatrixShape<InstructionShape::kM, InstructionShape::kK>,
                                                          Policy::OpDelta::kRow,
                                                          kThreadCount,
                                                          kPartitionsK>;

    /// Storage for A tile
    using FragmentA = typename IteratorA::Fragment;

    /// Storage for transformed A tile
    using TransformedFragmentA = Array<typename ArchMmaOperator::ElementA, FragmentA::kElements>;

    /// Iterates over the B operand in memory
    using IteratorB =
        MmaTensorOpMultiplicandTileIterator<MatrixShape<Shape::kK, Shape::kN>,
                                            Operand::kB,
                                            ElementB,
                                            LayoutB,
                                            MatrixShape<SharedMemoryInstructionShape::kK, InstructionShape::kN>,
                                            Policy::OpDelta::kRow,
                                            kThreadCount,
                                            kPartitionsK>;

    /// Storage for B tile
    using FragmentB = typename IteratorB::Fragment;

    /// Storage for transformed B tile
    using TransformedFragmentB = Array<typename ArchMmaOperator::ElementB, FragmentB::kElements>;

    /// Iterates over the C operand in memory
    using IteratorC = MmaTensorOpAccumulatorTileIterator<MatrixShape<Shape::kM, Shape::kN>,
                                                         ElementC,
                                                         LayoutC,
                                                         typename ArchMmaOperator::Shape,
                                                         typename Policy::OpDelta>;

    /// Storage for C tile
    using FragmentC = typename IteratorC::Fragment;

    /// Number of mma operations performed
    using MmaIterations = MatrixShape<(Shape::kM + ArchMmaOperator::Shape::kM - 1) / ArchMmaOperator::Shape::kM,
                                      (Shape::kN + ArchMmaOperator::Shape::kN - 1) / ArchMmaOperator::Shape::kN>;

public:
    /// Underlying matrix multiply operator (concept: arch::Mma)
    ArchMmaOperator mma;

public:
    //
    // Methods
    //

    /// Ctor
    CUTLASS_DEVICE
    MmaTensorOpComputeBWithF16() {}

    /// Performs a warp-level matrix multiply-accumulate operation
    CUTLASS_DEVICE
    void operator()(FragmentC& D,
                    TransformedFragmentA const& A,
                    TransformedFragmentB const& B,
                    FragmentC const& C,
                    const int warp_tileB_k_offset) const
    {

        using MmaOperandA = typename ArchMmaOperator::FragmentA;
        using MmaOperandB = typename ArchMmaOperator::FragmentB;
        using MmaOperandC = typename ArchMmaOperator::FragmentC;

        static_assert(
            TransformedFragmentB::kElements == MmaOperandB::kElements * kExpansionFactor * MmaIterations::kColumn,
            "Each thread should have a pack of mma registers for each column iteration AND for the expanded K dim of B");

        D = C;

        MmaOperandA const* ptr_A = reinterpret_cast<MmaOperandA const*>(&A);
        MmaOperandB const* ptr_B = reinterpret_cast<MmaOperandB const*>(&B);
        MmaOperandC* ptr_D = reinterpret_cast<MmaOperandC*>(&D);
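
        // The loops below select which instruction-K slice of the wider B fragment this call
        // consumes: ptr_B is indexed by warp_tileB_k_offset + kExpansionFactor * n. Pre-SM80 code
        // visits m in a serpentine order to maximize reuse of the B register operands; SM80 and
        // newer visit n in a serpentine order to maximize reuse of A.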

#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800)
        // Serpentine visitation order maximizing reuse of Rb
        CUTLASS_PRAGMA_UNROLL
        for (int n = 0; n < MmaIterations::kColumn; ++n) {

            CUTLASS_PRAGMA_UNROLL
            for (int m = 0; m < MmaIterations::kRow; ++m) {

                int m_serpentine = ((n % 2) ? (MmaIterations::kRow - 1 - m) : m);

                int n_offsetB = warp_tileB_k_offset + kExpansionFactor * n;
                if (AccumulatorsInRowMajor) {  // matrix B is reordered
                    mma(ptr_D[n + m_serpentine * MmaIterations::kColumn],
                        ptr_A[m_serpentine],
                        ptr_B[n_offsetB],
                        ptr_D[n + m_serpentine * MmaIterations::kColumn]);
                }
                else {
                    mma(ptr_D[m_serpentine + n * MmaIterations::kRow],
                        ptr_A[m_serpentine],
                        ptr_B[n_offsetB],
                        ptr_D[m_serpentine + n * MmaIterations::kRow]);
                }
            }
        }
#elif defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)
        // Serpentine visitation order maximizing reuse of Ra
        CUTLASS_PRAGMA_UNROLL
        for (int m = 0; m < MmaIterations::kRow; ++m) {

            CUTLASS_PRAGMA_UNROLL
            for (int n = 0; n < MmaIterations::kColumn; ++n) {

                int n_serpentine = ((m % 2) ? (MmaIterations::kColumn - 1 - n) : n);

                int n_serpentine_offsetB = warp_tileB_k_offset + kExpansionFactor * n_serpentine;
                if (AccumulatorsInRowMajor) {  // matrix B is reordered
                    mma(ptr_D[n_serpentine + m * MmaIterations::kColumn],
                        ptr_A[m],
                        ptr_B[n_serpentine_offsetB],
                        ptr_D[n_serpentine + m * MmaIterations::kColumn]);
                }
                else {
                    mma(ptr_D[m + n_serpentine * MmaIterations::kRow],
                        ptr_A[m],
                        ptr_B[n_serpentine_offsetB],
                        ptr_D[m + n_serpentine * MmaIterations::kRow]);
                }
            }
        }
#else
        assert(0);
#endif
    }
};

/////////////////////////////////////////////////////////////////////////////////////////////////

}  // namespace warp
}  // namespace gemm
}  // namespace cutlass

/////////////////////////////////////////////////////////////////////////////////////////////////
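
// Illustrative call pattern for operator() (a sketch only; `WarpMma`, `frag_A`, `frag_B`, and the
// loop structure below are hypothetical and not defined in this header). With kExpansionFactor == 2,
// a mainloop would typically load one wide B fragment and invoke the warp-level MMA once per
// instruction-K step, passing the step index as warp_tileB_k_offset:
//
//   WarpMma warp_mma;
//   typename WarpMma::FragmentC accum;
//   accum.clear();
//
//   CUTLASS_PRAGMA_UNROLL
//   for (int k_step = 0; k_step < WarpMma::kExpansionFactor; ++k_step) {
//       // frag_A[k_step]: A fragment for this instruction-K slice; frag_B: the wide B fragment.
//       warp_mma(accum, frag_A[k_step], frag_B, accum, k_step);
//   }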