1 /* 2 * Copyright (c) Meta Platforms, Inc. and affiliates. 3 * All rights reserved. 4 * 5 * This source code is licensed under the BSD-style license found in the 6 * LICENSE file in the root directory of this source tree. 7 */ 8 #pragma once 9 10 #include <cutlass/cutlass.h> 11 #include <cutlass/aligned_buffer.h> 12 #include <cutlass/array.h> 13 #include <cutlass/layout/matrix.h> 14 #include <cutlass/layout/pitch_linear.h> 15 #include <cutlass/numeric_types.h> 16 #include <cutlass/transform/pitch_linear_thread_map.h> 17 #include <cutlass/transform/threadblock/predicated_tile_iterator.h> 18 #include <cutlass/transform/threadblock/regular_tile_iterator.h> 19 20 template < 21 typename scalar_t, // scalar type 22 typename ThreadblockTileShape, // size of tile to load 23 int Threads, // number of participating threads 24 int ElementsPerAccess> // thread access width in elements 25 class TileSmemLoader { 26 public: 27 using SmemTile = 28 cutlass::AlignedBuffer<scalar_t, ThreadblockTileShape::kCount>; 29 30 using ThreadMap = cutlass::transform::PitchLinearStripminedThreadMap< 31 cutlass::layout::PitchLinearShape< 32 ThreadblockTileShape::kColumn, // contiguous 33 ThreadblockTileShape::kRow>, // strided 34 Threads, // Threads 35 ElementsPerAccess>; // ElementsPerAccess 36 37 using GmemTileIterator = 38 cutlass::transform::threadblock::PredicatedTileIterator< 39 ThreadblockTileShape, // Shape 40 scalar_t, // Element 41 cutlass::layout::RowMajor, // Layout 42 0, // AdvanceRank 43 ThreadMap>; // ThreadMap 44 45 using SmemTileIterator = cutlass::transform::threadblock::RegularTileIterator< 46 ThreadblockTileShape, // Shape 47 scalar_t, // Element 48 cutlass::layout::RowMajor, // Layout 49 0, // AdvanceRank 50 ThreadMap>; // ThreadMap 51 52 using Fragment = typename GmemTileIterator::Fragment; 53 54 /// load a tile from global memory into shared memory 55 CUTLASS_DEVICE load(GmemTileIterator tile_load_iter,SmemTileIterator tile_store_iter)56 static void load( 57 GmemTileIterator tile_load_iter, 58 SmemTileIterator tile_store_iter) { 59 Fragment tb_frag; 60 tb_frag.clear(); 61 tile_load_iter.load(tb_frag); 62 tile_store_iter.store(tb_frag); 63 64 __syncthreads(); 65 } 66 }; 67