/***************************************************************************************************
 * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights
 * reserved. SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice,
 * this list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
 * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
 * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
 * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
 * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
 * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
 * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
 * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
 * POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
/*!
\file 33 \brief Instantiates the right WarpIterator to read from shared memory 34 The class `DefaultWarpIteratorAFromSharedMemory` is useful when reading 35 data dumped with `B2bGemm::accumToSmem`. 36 */ 37 38 #pragma once 39 40 #include <cutlass/cutlass.h> 41 #include <cutlass/gemm/warp/mma_tensor_op_tile_access_iterator.h> 42 #include <cutlass/platform/platform.h> 43 44 #include <ATen/native/transformers/cuda/mem_eff_attention/iterators/warp_iterator_from_smem.h> 45 46 namespace cutlass { 47 namespace gemm { 48 namespace threadblock { 49 50 template < 51 typename WarpShape, 52 typename InstructionShape, 53 typename RegularWarpIterator, 54 typename Policy, 55 typename Enable = void> 56 struct DefaultWarpIteratorAFromSharedMemory {}; 57 58 // TensorOp - Ampere half 59 template <typename RegularWarpIterator, typename Policy, int kInstrK> 60 struct DefaultWarpIteratorAFromSharedMemory< 61 cutlass::gemm::GemmShape<32, 32, 32>, 62 cutlass::gemm::GemmShape<16, 8, kInstrK>, 63 RegularWarpIterator, 64 Policy, 65 typename platform::enable_if<( 66 sizeof_bits<typename RegularWarpIterator::Element>::value == 16 && 67 Policy::Operator::Policy::OpDelta::kRow == 1)>::type> { 68 using OpDelta = typename Policy::Operator::Policy::OpDelta; 69 using WarpShape = cutlass::MatrixShape<32, 32>; 70 using InstructionShape = cutlass::gemm::GemmShape<16, 8, kInstrK>; 71 72 using WarpIterator = cutlass::gemm::warp::WarpIteratorFromSmem< 73 cutlass::gemm::Operand::kA, 74 typename RegularWarpIterator::Element, 75 cutlass::MatrixShape<InstructionShape::kM, InstructionShape::kK>>; 76 }; 77 78 // TensorOp - Ampere f32 79 template <typename WarpShape, typename RegularWarpIterator, typename Policy> 80 struct DefaultWarpIteratorAFromSharedMemory< 81 WarpShape, 82 cutlass::gemm::GemmShape<16, 8, 8>, 83 RegularWarpIterator, 84 Policy, 85 typename platform::enable_if<( 86 sizeof_bits<typename RegularWarpIterator::Element>::value != 16 || 87 Policy::Operator::Policy::OpDelta::kRow != 1)>::type> { 88 
using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; 89 static constexpr auto kWarpSize = 32; 90 using OpDelta = typename Policy::Operator::Policy::OpDelta; 91 92 using WarpIterator = 93 cutlass::gemm::warp::MmaTensorOpMultiplicandTileAccessIterator< 94 cutlass::MatrixShape<WarpShape::kM, WarpShape::kK>, 95 cutlass::gemm::Operand::kA, 96 typename RegularWarpIterator::Element, 97 cutlass::layout::RowMajor, 98 cutlass::MatrixShape<InstructionShape::kM, InstructionShape::kK>, 99 OpDelta::kRow, 100 kWarpSize>; 101 }; 102 103 // TensorOp - Volta 104 template <typename WarpShape, typename RegularWarpIterator, typename Policy> 105 struct DefaultWarpIteratorAFromSharedMemory< 106 WarpShape, 107 cutlass::gemm::GemmShape<16, 16, 4>, 108 RegularWarpIterator, 109 Policy> { 110 using InstructionShape = cutlass::gemm::GemmShape<16, 16, 4>; 111 static constexpr auto kWarpSize = 32; 112 using OpDelta = typename Policy::Operator::Policy::OpDelta; 113 114 using WarpIterator = 115 cutlass::gemm::warp::MmaVoltaTensorOpMultiplicandTileIterator< 116 cutlass::MatrixShape<32, 32>, // MatrixShape<WarpShape::kM, 117 // WarpShape::kK>, 118 cutlass::gemm::Operand::kA, 119 typename RegularWarpIterator::Element, 120 cutlass::layout::RowMajorVoltaTensorOpMultiplicandCrosswise<16, 32>, 121 cutlass::MatrixShape<16, 4>, 122 OpDelta::kRow, 123 kWarpSize>; 124 }; 125 126 // Simt 127 template <typename WarpShape, typename RegularWarpIterator, typename Policy> 128 struct DefaultWarpIteratorAFromSharedMemory< 129 WarpShape, 130 cutlass::gemm::GemmShape<1, 1, 1>, 131 RegularWarpIterator, 132 Policy> { 133 using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; 134 static constexpr auto kWarpSize = 32; 135 136 // We just use the same iterator, as we reproduced the same shared-memory 137 // schema. Just modify it to handle non-complete tiles. 138 using WarpIterator = RegularWarpIterator; 139 }; 140 141 } // namespace threadblock 142 } // namespace gemm 143 } // namespace cutlass 144