1 /***************************************************************************************************
2  * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights
3  *reserved. SPDX-License-Identifier: BSD-3-Clause
4  *
5  * Redistribution and use in source and binary forms, with or without
6  * modification, are permitted provided that the following conditions are met:
7  *
8  * 1. Redistributions of source code must retain the above copyright notice,
9  *this list of conditions and the following disclaimer.
10  *
11  * 2. Redistributions in binary form must reproduce the above copyright notice,
12  * this list of conditions and the following disclaimer in the documentation
13  * and/or other materials provided with the distribution.
14  *
15  * 3. Neither the name of the copyright holder nor the names of its
16  * contributors may be used to endorse or promote products derived from
17  * this software without specific prior written permission.
18  *
19  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20  * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21  * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
22  *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
23  *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
24  *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
25  *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
26  *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
27  *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
28  *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
29  *POSSIBILITY OF SUCH DAMAGE.
30  *
31  **************************************************************************************************/
/*! \file
    \brief Instantiates the right warp-level iterator to read the A operand
    from shared memory.

    The class `DefaultWarpIteratorAFromSharedMemory` is useful when reading
    data dumped with `B2bGemm::accumToSmem`.
*/
37 
38 #pragma once
39 
#include <cutlass/cutlass.h>
#include <cutlass/gemm/warp/mma_tensor_op_tile_access_iterator.h>
#include <cutlass/gemm/warp/mma_tensor_op_tile_iterator_sm70.h>
#include <cutlass/platform/platform.h>

#include <ATen/native/transformers/cuda/mem_eff_attention/iterators/warp_iterator_from_smem.h>
45 
46 namespace cutlass {
47 namespace gemm {
48 namespace threadblock {
49 
// Primary template: picks the warp-level iterator used to read the A operand
// from shared memory, dispatching on the warp shape, the instruction shape,
// the regular warp iterator and the MMA policy.
//
// Deliberately empty. Only the partial specializations in this file define
// `WarpIterator`, so an unsupported parameter combination fails to compile as
// soon as `::WarpIterator` is named. `Enable` exists so that specializations
// can be constrained with `platform::enable_if<...>::type`.
template <typename WarpShape, typename InstructionShape,
          typename RegularWarpIterator, typename Policy,
          typename Enable = void>
struct DefaultWarpIteratorAFromSharedMemory {
};
57 
58 // TensorOp - Ampere half
59 template <typename RegularWarpIterator, typename Policy, int kInstrK>
60 struct DefaultWarpIteratorAFromSharedMemory<
61     cutlass::gemm::GemmShape<32, 32, 32>,
62     cutlass::gemm::GemmShape<16, 8, kInstrK>,
63     RegularWarpIterator,
64     Policy,
65     typename platform::enable_if<(
66         sizeof_bits<typename RegularWarpIterator::Element>::value == 16 &&
67         Policy::Operator::Policy::OpDelta::kRow == 1)>::type> {
68   using OpDelta = typename Policy::Operator::Policy::OpDelta;
69   using WarpShape = cutlass::MatrixShape<32, 32>;
70   using InstructionShape = cutlass::gemm::GemmShape<16, 8, kInstrK>;
71 
72   using WarpIterator = cutlass::gemm::warp::WarpIteratorFromSmem<
73       cutlass::gemm::Operand::kA,
74       typename RegularWarpIterator::Element,
75       cutlass::MatrixShape<InstructionShape::kM, InstructionShape::kK>>;
76 };
77 
78 // TensorOp - Ampere f32
79 template <typename WarpShape, typename RegularWarpIterator, typename Policy>
80 struct DefaultWarpIteratorAFromSharedMemory<
81     WarpShape,
82     cutlass::gemm::GemmShape<16, 8, 8>,
83     RegularWarpIterator,
84     Policy,
85     typename platform::enable_if<(
86         sizeof_bits<typename RegularWarpIterator::Element>::value != 16 ||
87         Policy::Operator::Policy::OpDelta::kRow != 1)>::type> {
88   using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>;
89   static constexpr auto kWarpSize = 32;
90   using OpDelta = typename Policy::Operator::Policy::OpDelta;
91 
92   using WarpIterator =
93       cutlass::gemm::warp::MmaTensorOpMultiplicandTileAccessIterator<
94           cutlass::MatrixShape<WarpShape::kM, WarpShape::kK>,
95           cutlass::gemm::Operand::kA,
96           typename RegularWarpIterator::Element,
97           cutlass::layout::RowMajor,
98           cutlass::MatrixShape<InstructionShape::kM, InstructionShape::kK>,
99           OpDelta::kRow,
100           kWarpSize>;
101 };
102 
103 // TensorOp - Volta
104 template <typename WarpShape, typename RegularWarpIterator, typename Policy>
105 struct DefaultWarpIteratorAFromSharedMemory<
106     WarpShape,
107     cutlass::gemm::GemmShape<16, 16, 4>,
108     RegularWarpIterator,
109     Policy> {
110   using InstructionShape = cutlass::gemm::GemmShape<16, 16, 4>;
111   static constexpr auto kWarpSize = 32;
112   using OpDelta = typename Policy::Operator::Policy::OpDelta;
113 
114   using WarpIterator =
115       cutlass::gemm::warp::MmaVoltaTensorOpMultiplicandTileIterator<
116           cutlass::MatrixShape<32, 32>, // MatrixShape<WarpShape::kM,
117                                         // WarpShape::kK>,
118           cutlass::gemm::Operand::kA,
119           typename RegularWarpIterator::Element,
120           cutlass::layout::RowMajorVoltaTensorOpMultiplicandCrosswise<16, 32>,
121           cutlass::MatrixShape<16, 4>,
122           OpDelta::kRow,
123           kWarpSize>;
124 };
125 
126 // Simt
127 template <typename WarpShape, typename RegularWarpIterator, typename Policy>
128 struct DefaultWarpIteratorAFromSharedMemory<
129     WarpShape,
130     cutlass::gemm::GemmShape<1, 1, 1>,
131     RegularWarpIterator,
132     Policy> {
133   using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>;
134   static constexpr auto kWarpSize = 32;
135 
136   // We just use the same iterator, as we reproduced the same shared-memory
137   // schema. Just modify it to handle non-complete tiles.
138   using WarpIterator = RegularWarpIterator;
139 };
140 
141 } // namespace threadblock
142 } // namespace gemm
143 } // namespace cutlass
144