1 /*
2  * Copyright (c) Meta Platforms, Inc. and affiliates.
3  * All rights reserved.
4  *
5  * This source code is licensed under the BSD-style license found in the
6  * LICENSE file in the root directory of this source tree.
7  */
8 #pragma once
9 
10 #include <cutlass/cutlass.h>
11 #include <cutlass/aligned_buffer.h>
12 #include <cutlass/array.h>
13 #include <cutlass/layout/matrix.h>
14 #include <cutlass/layout/pitch_linear.h>
15 #include <cutlass/numeric_types.h>
16 #include <cutlass/transform/pitch_linear_thread_map.h>
17 #include <cutlass/transform/threadblock/predicated_tile_iterator.h>
18 #include <cutlass/transform/threadblock/regular_tile_iterator.h>
19 
20 template <
21     typename scalar_t, // scalar type
22     typename ThreadblockTileShape, // size of tile to load
23     int Threads, // number of participating threads
24     int ElementsPerAccess> // thread access width in elements
25 class TileSmemLoader {
26  public:
27   using SmemTile =
28       cutlass::AlignedBuffer<scalar_t, ThreadblockTileShape::kCount>;
29 
30   using ThreadMap = cutlass::transform::PitchLinearStripminedThreadMap<
31       cutlass::layout::PitchLinearShape<
32           ThreadblockTileShape::kColumn, // contiguous
33           ThreadblockTileShape::kRow>, // strided
34       Threads, // Threads
35       ElementsPerAccess>; // ElementsPerAccess
36 
37   using GmemTileIterator =
38       cutlass::transform::threadblock::PredicatedTileIterator<
39           ThreadblockTileShape, // Shape
40           scalar_t, // Element
41           cutlass::layout::RowMajor, // Layout
42           0, // AdvanceRank
43           ThreadMap>; // ThreadMap
44 
45   using SmemTileIterator = cutlass::transform::threadblock::RegularTileIterator<
46       ThreadblockTileShape, // Shape
47       scalar_t, // Element
48       cutlass::layout::RowMajor, // Layout
49       0, // AdvanceRank
50       ThreadMap>; // ThreadMap
51 
52   using Fragment = typename GmemTileIterator::Fragment;
53 
54   /// load a tile from global memory into shared memory
55   CUTLASS_DEVICE
load(GmemTileIterator tile_load_iter,SmemTileIterator tile_store_iter)56   static void load(
57       GmemTileIterator tile_load_iter,
58       SmemTileIterator tile_store_iter) {
59     Fragment tb_frag;
60     tb_frag.clear();
61     tile_load_iter.load(tb_frag);
62     tile_store_iter.store(tb_frag);
63 
64     __syncthreads();
65   }
66 };
67