xref: /XiangShan/src/main/scala/xiangshan/mem/prefetch/L1StridePrefetcher.scala (revision 99ce5576f0ecce1b5045b7bc0dbbb2debd934fbb)
/***************************************************************************************
* Copyright (c) 2024 Beijing Institute of Open Source Chip (BOSC)
* Copyright (c) 2020-2024 Institute of Computing Technology, Chinese Academy of Sciences
* Copyright (c) 2020-2021 Peng Cheng Laboratory
*
* XiangShan is licensed under Mulan PSL v2.
* You can use this software according to the terms and conditions of the Mulan PSL v2.
* You may obtain a copy of Mulan PSL v2 at:
*          http://license.coscl.org.cn/MulanPSL2
*
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
* EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
* MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
*
* See the Mulan PSL v2 for more details.
*
*
* Acknowledgement
*
* This implementation is inspired by several key papers:
* [1] Jean-Loup Baer, and Tien-Fu Chen. "[An effective on-chip preloading scheme to reduce data access penalty.]
* (https://doi.org/10.1145/125826.125932)" ACM/IEEE Conference on Supercomputing. 1991.
***************************************************************************************/

package xiangshan.mem.prefetch

import org.chipsalliance.cde.config.Parameters
import chisel3._
import chisel3.util._
import utils._
import utility._
import xiangshan._
import xiangshan.mem.L1PrefetchReq
import xiangshan.mem.Bundles.LsPrefetchTrainBundle
import xiangshan.mem.trace._
import xiangshan.cache.HasDCacheParameters
import xiangshan.cache.mmu._
import scala.collection.SeqLike

trait HasStridePrefetchHelper extends HasL1PrefetchHelper {
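  // Geometry: STRIDE_ENTRY_NUM entries tagged by hashed pc; pre_vaddr and stride are
  // tracked in a (10 + BLOCK_OFFSET)-bit window, i.e. a 10-bit block index above the
  // block offset; confidence is a STRIDE_CONF_BITS-bit saturating counter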
  val STRIDE_FILTER_SIZE = 6
  val STRIDE_ENTRY_NUM = 10
  val STRIDE_BITS = 10 + BLOCK_OFFSET
  val STRIDE_VADDR_BITS = 10 + BLOCK_OFFSET
  val STRIDE_CONF_BITS = 2

  // detailed control
  val ALWAYS_UPDATE_PRE_VADDR = true
  val AGGRESIVE_POLICY = false // if true, prefetch degree is STRIDE_LOOK_AHEAD_BLOCKS; otherwise it is 1
  val STRIDE_LOOK_AHEAD_BLOCKS = 2 // aggressive prefetch degree
  val LOOK_UP_STREAM = false // if true, look up the stream prefetcher to avoid colliding with a detected stream

  val STRIDE_WIDTH_BLOCKS = if(AGGRESIVE_POLICY) STRIDE_LOOK_AHEAD_BLOCKS else 1

  def MAX_CONF = (1 << STRIDE_CONF_BITS) - 1
}

class StrideMetaBundle(implicit p: Parameters) extends XSBundle with HasStridePrefetchHelper {
  val pre_vaddr = UInt(STRIDE_VADDR_BITS.W)
  val stride = UInt(STRIDE_BITS.W)
  val confidence = UInt(STRIDE_CONF_BITS.W)
  val hash_pc = UInt(HASH_TAG_WIDTH.W)

  def reset(index: Int) = {
    pre_vaddr := 0.U
    stride := 0.U
    confidence := 0.U
    hash_pc := index.U
  }

  def tag_match(valid1: Bool, valid2: Bool, new_hash_pc: UInt): Bool = {
    valid1 && valid2 && hash_pc === new_hash_pc
  }

  def alloc(vaddr: UInt, alloc_hash_pc: UInt) = {
    pre_vaddr := vaddr(STRIDE_VADDR_BITS - 1, 0)
    stride := 0.U
    confidence := 0.U
    hash_pc := alloc_hash_pc
  }

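  // Training algorithm: the new stride is the delta between the incoming vaddr and
  // pre_vaddr. A stride is valid only if it is non-negative and at least 2 blocks wide.
  // On a stride match the confidence counter saturates upwards; on a mismatch it is
  // decremented, and once confidence drops to <= 1 the entry retrains to the new stride.
  // A prefetch may be issued only when a valid, matching stride arrives at max confidence.
  // Illustrative walk-through (assuming 64-byte blocks): for accesses at 0x000, 0x100,
  // 0x200, ... the second access latches the 0x100 stride, accesses 3-5 raise confidence
  // to MAX_CONF, and the sixth access is the first that can trigger a prefetch.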
  def update(vaddr: UInt, always_update_pre_vaddr: Bool) = {
    val new_vaddr = vaddr(STRIDE_VADDR_BITS - 1, 0)
    val new_stride = new_vaddr - pre_vaddr
    val new_stride_blk = block_addr(new_stride)
    // NOTE: for now, negative strides are disabled; strides smaller than 2 blocks are ignored as well
    val stride_valid = new_stride_blk =/= 0.U && new_stride_blk =/= 1.U && new_stride(STRIDE_VADDR_BITS - 1) === 0.U
    val stride_match = new_stride === stride
    val low_confidence = confidence <= 1.U
    val can_send_pf = stride_valid && stride_match && confidence === MAX_CONF.U

    when(stride_valid) {
      when(stride_match) {
        confidence := Mux(confidence === MAX_CONF.U, confidence, confidence + 1.U)
      }.otherwise {
        confidence := Mux(confidence === 0.U, confidence, confidence - 1.U)
        when(low_confidence) {
          stride := new_stride
        }
      }
      pre_vaddr := new_vaddr
    }
    when(always_update_pre_vaddr) {
      pre_vaddr := new_vaddr
    }

    (can_send_pf, new_stride)
  }

}

class StrideMetaArray(implicit p: Parameters) extends XSModule with HasStridePrefetchHelper {
  val io = IO(new XSBundle {
    val enable = Input(Bool())
    // TODO: flush all entries when a process switch happens, or disable stride prefetch for a while
    val flush = Input(Bool())
    val dynamic_depth = Input(UInt(32.W)) // TODO: enable dynamic stride depth
    val train_req = Flipped(DecoupledIO(new PrefetchReqBundle))
    val l1_prefetch_req = ValidIO(new StreamPrefetchReqBundle)
    val l2_l3_prefetch_req = ValidIO(new StreamPrefetchReqBundle)
    // query the Stream component to see if a stream pattern has already been detected
    val stream_lookup_req  = ValidIO(new PrefetchReqBundle)
    val stream_lookup_resp = Input(Bool())
  })

  val array = Reg(Vec(STRIDE_ENTRY_NUM, new StrideMetaBundle))
  val valids = RegInit(VecInit(Seq.fill(STRIDE_ENTRY_NUM)(false.B)))

  def reset_array(i: Int): Unit = {
    valids(i) := false.B
    // only the valid bit needs to be reset, which is friendly to area
    // array(i).reset(i)
  }

  val replacement = ReplacementPolicy.fromString("plru", STRIDE_ENTRY_NUM)

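  // Pipeline overview:
  //   s0: hash the training pc and CAM it against all entries
  //   s1: allocate a new entry on miss, or train the matching entry on hit
  //   s2: once a stride is confident, compute the L1 and L2/L3 prefetch vaddrs
  //   s3: send the L1 prefetch request out
  //   s4: send the L2/L3 prefetch request out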
  // s0: hash pc -> cam all entries
  val s0_can_accept = Wire(Bool())
  val s0_valid = io.train_req.fire
  val s0_vaddr = io.train_req.bits.vaddr
  val s0_pc = io.train_req.bits.pc
  val s0_pc_hash = pc_hash_tag(s0_pc)
  val s0_pc_match_vec = VecInit(array zip valids map { case (e, v) => e.tag_match(v, s0_valid, s0_pc_hash) }).asUInt
  val s0_hit = s0_pc_match_vec.orR
  val s0_index = Mux(s0_hit, OHToUInt(s0_pc_match_vec), replacement.way)
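  // on a hit, train the matching entry; on a miss, allocate over the PLRU victim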
  io.train_req.ready := s0_can_accept
  io.stream_lookup_req.valid := s0_valid
  io.stream_lookup_req.bits  := io.train_req.bits

  when(s0_valid) {
    replacement.access(s0_index)
  }

  assert(PopCount(s0_pc_match_vec) <= 1.U)
  XSPerfAccumulate("s0_valid", s0_valid)
  XSPerfAccumulate("s0_hit", s0_valid && s0_hit)
  XSPerfAccumulate("s0_miss", s0_valid && !s0_hit)

  // s1: alloc or update
  val s1_valid = GatedValidRegNext(s0_valid)
  val s1_index = RegEnable(s0_index, s0_valid)
  val s1_pc_hash = RegEnable(s0_pc_hash, s0_valid)
  val s1_vaddr = RegEnable(s0_vaddr, s0_valid)
  val s1_hit = RegEnable(s0_hit, s0_valid)
  val s1_alloc = s1_valid && !s1_hit
  val s1_update = s1_valid && s1_hit
  val s1_stride = array(s1_index).stride
  val s1_new_stride = WireInit(0.U(STRIDE_BITS.W))
  val s1_can_send_pf = WireInit(false.B)
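  // back-to-back trains with the same hashed pc would let s0 read meta that s1 is
  // still updating, so the younger request is rejected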
  s0_can_accept := !(s1_valid && s1_pc_hash === s0_pc_hash)

  val always_update = Constantin.createRecord(s"always_update${p(XSCoreParamsKey).HartId}", initValue = ALWAYS_UPDATE_PRE_VADDR)

  when(s1_alloc) {
    valids(s1_index) := true.B
    array(s1_index).alloc(
      vaddr = s1_vaddr,
      alloc_hash_pc = s1_pc_hash
    )
  }.elsewhen(s1_update) {
    val res = array(s1_index).update(s1_vaddr, always_update)
    s1_can_send_pf := res._1
    s1_new_stride := res._2
  }

  val l1_stride_ratio_const = Constantin.createRecord(s"l1_stride_ratio${p(XSCoreParamsKey).HartId}", initValue = 2)
  val l1_stride_ratio = l1_stride_ratio_const(3, 0)
  val l2_stride_ratio_const = Constantin.createRecord(s"l2_stride_ratio${p(XSCoreParamsKey).HartId}", initValue = 5)
  val l2_stride_ratio = l2_stride_ratio_const(3, 0)
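  // prefetch distance is stride << ratio: with the default ratios, the L1 request runs
  // 4 strides ahead of the training access and the L2/L3 request runs 32 strides ahead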
  // s2: calculate L1 & L2 pf addr
  val s2_valid = GatedValidRegNext(s1_valid && s1_can_send_pf)
  val s2_vaddr = RegEnable(s1_vaddr, s1_valid && s1_can_send_pf)
  val s2_stride = RegEnable(s1_stride, s1_valid && s1_can_send_pf)
  val s2_l1_depth = s2_stride << l1_stride_ratio
  val s2_l1_pf_vaddr = (s2_vaddr + s2_l1_depth)(VAddrBits - 1, 0)
  val s2_l2_depth = s2_stride << l2_stride_ratio
  val s2_l2_pf_vaddr = (s2_vaddr + s2_l2_depth)(VAddrBits - 1, 0)
  val s2_l1_pf_req_bits = (new StreamPrefetchReqBundle).getStreamPrefetchReqBundle(
    valid = s2_valid,
    vaddr = s2_l1_pf_vaddr,
    width = STRIDE_WIDTH_BLOCKS,
    decr_mode = false.B,
    sink = SINK_L1,
    source = L1_HW_PREFETCH_STRIDE,
    // TODO: add stride debug db, not useful for now
    t_pc = 0xdeadbeefL.U,
    t_va = 0xdeadbeefL.U
  )
  val s2_l2_pf_req_bits = (new StreamPrefetchReqBundle).getStreamPrefetchReqBundle(
    valid = s2_valid,
    vaddr = s2_l2_pf_vaddr,
    width = STRIDE_WIDTH_BLOCKS,
    decr_mode = false.B,
    sink = SINK_L2,
    source = L1_HW_PREFETCH_STRIDE,
    // TODO: add stride debug db, not useful for now
    t_pc = 0xdeadbeefL.U,
    t_va = 0xdeadbeefL.U
  )

  // s3: send l1 pf out
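  // when LOOK_UP_STREAM is enabled, a positive stream_lookup_resp kills the request,
  // so that stride prefetches do not collide with an already-detected stream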
  val s3_valid = if (LOOK_UP_STREAM) GatedValidRegNext(s2_valid) && !io.stream_lookup_resp else GatedValidRegNext(s2_valid)
  val s3_l1_pf_req_bits = RegEnable(s2_l1_pf_req_bits, s2_valid)
  val s3_l2_pf_req_bits = RegEnable(s2_l2_pf_req_bits, s2_valid)

  // s4: send l2 pf out
  val s4_valid = GatedValidRegNext(s3_valid)
  val s4_l2_pf_req_bits = RegEnable(s3_l2_pf_req_bits, s3_valid)
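  // note: the L2/L3 request (s4) goes out one cycle after the corresponding L1 request (s3)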

  io.l1_prefetch_req.valid := s3_valid
  io.l1_prefetch_req.bits := s3_l1_pf_req_bits
  io.l2_l3_prefetch_req.valid := s4_valid
  io.l2_l3_prefetch_req.bits := s4_l2_pf_req_bits

  XSPerfAccumulate("pf_valid", PopCount(Seq(io.l1_prefetch_req.valid, io.l2_l3_prefetch_req.valid)))
  XSPerfAccumulate("l1_pf_valid", s3_valid)
  XSPerfAccumulate("l2_pf_valid", s4_valid)
  XSPerfAccumulate("detect_stream", io.stream_lookup_resp)
  XSPerfHistogram("high_conf_num", PopCount(VecInit(array.map(_.confidence === MAX_CONF.U))).asUInt, true.B, 0, STRIDE_ENTRY_NUM, 1)
  for(i <- 0 until STRIDE_ENTRY_NUM) {
    XSPerfAccumulate(s"entry_${i}_update", i.U === s1_index && s1_update)
    for(j <- 0 until 4) {
      XSPerfAccumulate(s"entry_${i}_disturb_${j}", i.U === s1_index && s1_update &&
                                                   j.U === s1_new_stride &&
                                                   array(s1_index).confidence === MAX_CONF.U &&
                                                   array(s1_index).stride =/= s1_new_stride
      )
    }
  }

  for(i <- 0 until STRIDE_ENTRY_NUM) {
    when(GatedValidRegNext(io.flush)) {
      reset_array(i)
    }
  }
}
257