/***************************************************************************************
* Copyright (c) 2024-2025 Institute of Information Engineering, Chinese Academy of Sciences
*
* XiangShan is licensed under Mulan PSL v2.
* You can use this software according to the terms and conditions of the Mulan PSL v2.
* You may obtain a copy of Mulan PSL v2 at:
*          http://license.coscl.org.cn/MulanPSL2
*
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
* EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
* MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
*
* See the Mulan PSL v2 for more details.
***************************************************************************************/

package xiangshan.cache.mmu

import org.chipsalliance.cde.config.Parameters
import chisel3._
import chisel3.util._
import xiangshan._
import xiangshan.cache.{HasDCacheParameters, MemoryOpConstants}
import utils._
import utility._
import freechips.rocketchip.diplomacy.{LazyModule, LazyModuleImp}
import freechips.rocketchip.tilelink._
import xiangshan.backend.fu.{PMPReqBundle, PMPRespBundle}

class bitmapReqBundle(implicit p: Parameters) extends XSBundle with HasPtwConst {
  val bmppn = UInt(ppnLen.W)
  val id = UInt(log2Up(l2tlbParams.llptwsize+2).W)
  val vpn = UInt(vpnLen.W)
  val level = UInt(log2Up(Level).W)
  val way_info = UInt(l2tlbParams.l0nWays.W)
  val hptw_bypassed = Bool()
}

class bitmapRespBundle(implicit p: Parameters) extends XSBundle with HasPtwConst {
  val cf = Bool()
  val cfs = Vec(tlbcontiguous, Bool())
  val id = UInt(log2Up(l2tlbParams.llptwsize+2).W)
}

class bitmapEntry(implicit p: Parameters) extends XSBundle with HasPtwConst {
  val ppn = UInt(ppnLen.W)
  val vpn = UInt(vpnLen.W)
  val id = UInt(bMemID.W)
  val wait_id = UInt(log2Up(l2tlbParams.llptwsize+2).W)
  // did the bitmap check fail? 0: success, 1: failed
  val cf = Bool()
  val hit = Bool()
  val cfs = Vec(tlbcontiguous, Bool())
  val level = UInt(log2Up(Level).W)
  val way_info = UInt(l2tlbParams.l0nWays.W)
  val hptw_bypassed = Bool()
  val data = UInt(XLEN.W)
}

class bitmapIO(implicit p: Parameters) extends MMUIOBaseBundle with HasPtwConst {
  val mem = new Bundle {
    val req = DecoupledIO(new L2TlbMemReqBundle())
    val resp = Flipped(DecoupledIO(new Bundle {
      val id = Output(UInt(bMemID.W))
      val value = Output(UInt(blockBits.W))
    }))
    val req_mask = Input(Vec(l2tlbParams.llptwsize+2, Bool()))
  }
  val req = Flipped(DecoupledIO(new bitmapReqBundle()))
  val resp = DecoupledIO(new bitmapRespBundle())

  val pmp = new Bundle {
    val req = ValidIO(new PMPReqBundle())
    val resp = Flipped(new PMPRespBundle())
  }

  val wakeup = ValidIO(new Bundle {
    val setIndex = UInt(PtwL0SetIdxLen.W)
    val tag = UInt(SPTagLen.W)
    val isSp = Bool()
    val way_info = UInt(l2tlbParams.l0nWays.W)
    val pte_index = UInt(sectortlbwidth.W)
    val check_success = Bool()
  })

  // bitmap cache req/resp and refill port
  val cache = new Bundle {
    val req = DecoupledIO(new bitmapCacheReqBundle())
    val resp = Flipped(DecoupledIO(new bitmapCacheRespBundle()))
  }
  val refill = Output(ValidIO(new Bundle {
    val tag = UInt(ppnLen.W)
    val data = UInt(XLEN.W)
  }))
}

class Bitmap(implicit p: Parameters) extends XSModule with HasPtwConst {
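  // The bitmap keeps one check bit per physical page, packed XLEN bits per
  // 8-byte word, so the word covering `ppn` lives at
  //   bitmap_base + (effective_ppn >> log2(XLEN)) * (XLEN/8).
  // Worked example (assuming XLEN = 64): for effective_ppn = 0x12345,
  // word index = 0x12345 >> 6 = 0x48D, byte offset = 0x48D << 3 = 0x2468,
  // so the check bit is bit (0x12345 & 0x3F) = 5 of the word at
  // bitmap_base + 0x2468. (bitmap_base is defined below; a def body may
  // forward-reference it, since it is only evaluated at the call sites.)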
  def getBitmapAddr(ppn: UInt): UInt = {
    val effective_ppn = ppn(ppnLen-KeyIDBits-1, 0)
    bitmap_base + (effective_ppn >> log2Ceil(XLEN) << log2Ceil(8))
  }

  val io = IO(new bitmapIO)

  val csr = io.csr
  val sfence = io.sfence
  val flush = sfence.valid || csr.satp.changed || csr.vsatp.changed || csr.hgatp.changed
  val bitmap_base = csr.mbmc.BMA << 6

  val entries = Reg(Vec(l2tlbParams.llptwsize+2, new bitmapEntry()))
  // per-entry state machine; the PMP check happens in state_addr_check
  val state_idle :: state_addr_check :: state_cache_req :: state_cache_resp :: state_mem_req :: state_mem_waiting :: state_mem_out :: Nil = Enum(7)
  val state = RegInit(VecInit(Seq.fill(l2tlbParams.llptwsize+2)(state_idle)))
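  // Per-entry flow:
  //   idle -> addr_check (PMP) -> cache_req -> cache_resp
  //     -> mem_out directly on a bitmap cache hit, otherwise
  //     -> mem_req -> mem_waiting -> mem_out, then back to idle.
  // An enqueuing request that matches an in-flight memory access can enter
  // at mem_waiting (or at mem_out, if the matching response arrives in the
  // same cycle).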

  val is_emptys = state.map(_ === state_idle)
  val is_cache_req = state.map(_ === state_cache_req)
  val is_cache_resp = state.map(_ === state_cache_resp)
  val is_mems = state.map(_ === state_mem_req)
  val is_waiting = state.map(_ === state_mem_waiting)
  val is_having = state.map(_ === state_mem_out)

  val full = !ParallelOR(is_emptys).asBool
  val waiting = ParallelOR(is_waiting).asBool
  val enq_ptr = ParallelPriorityEncoder(is_emptys)

  val mem_ptr = ParallelPriorityEncoder(is_having)
  val mem_arb = Module(new RRArbiter(new bitmapEntry(), l2tlbParams.llptwsize+2))

  val bitmapdata = Wire(Vec(blockBits / XLEN, UInt(XLEN.W)))
  if (HasBitmapCheckDefault) {
    for (i <- 0 until blockBits / XLEN) {
      bitmapdata(i) := 0.U
    }
  } else {
    bitmapdata := io.mem.resp.bits.value.asTypeOf(Vec(blockBits / XLEN, UInt(XLEN.W)))
  }
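  // With HasBitmapCheckDefault the memory data is tied to zero, so every
  // check bit reads as 0 (pass); presumably for configurations that enable
  // the check logic without a real bitmap backing it in memory.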

  for (i <- 0 until l2tlbParams.llptwsize+2) {
    mem_arb.io.in(i).bits := entries(i)
    mem_arb.io.in(i).valid := is_mems(i) && !io.mem.req_mask(i)
  }

  val cache_req_arb = Module(new Arbiter(new bitmapCacheReqBundle(), l2tlbParams.llptwsize + 2))
  for (i <- 0 until l2tlbParams.llptwsize+2) {
    cache_req_arb.io.in(i).valid := is_cache_req(i)
    cache_req_arb.io.in(i).bits.tag := entries(i).ppn
    cache_req_arb.io.in(i).bits.order := i.U
  }

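  // Duplicate detection at enqueue: a request that falls into the same
  // bitmap word as an entry already waiting on memory piggybacks on that
  // access, recording the owner in wait_id instead of issuing its own.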
  val dup_vec = state.indices.map(i =>
    dupBitmapPPN(io.req.bits.bmppn, entries(i).ppn)
  )
  val dup_req_fire = mem_arb.io.out.fire && dupBitmapPPN(io.req.bits.bmppn, mem_arb.io.out.bits.ppn)
  val dup_vec_wait = dup_vec.zip(is_waiting).map { case (d, w) => d && w }
  val dup_wait_resp = io.mem.resp.fire && VecInit(dup_vec_wait)(io.mem.resp.bits.id - (l2tlbParams.llptwsize + 2).U)
  val wait_id = Mux(dup_req_fire, mem_arb.io.chosen, ParallelMux(dup_vec_wait zip entries.map(_.wait_id)))

  val to_wait = Cat(dup_vec_wait).orR || dup_req_fire
  val to_mem_out = dup_wait_resp

  val enq_state_normal = MuxCase(state_addr_check, Seq(
    to_mem_out -> state_mem_out,
    to_wait -> state_mem_waiting
  ))
  val enq_state = enq_state_normal
  val enq_ptr_reg = RegNext(enq_ptr)

  val need_addr_check = RegNext(enq_state === state_addr_check && io.req.fire && !flush)

  io.pmp.req.valid := need_addr_check
  io.pmp.req.bits.addr := RegEnable(getBitmapAddr(io.req.bits.bmppn), io.req.fire)
  io.pmp.req.bits.cmd := TlbCmd.read
  io.pmp.req.bits.size := 3.U
  val pmp_resp_valid = io.pmp.req.valid
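  // The PMP is assumed to respond combinationally in the cycle the request
  // is valid, so the response is sampled while pmp_resp_valid is high.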

  when (io.req.fire) {
    state(enq_ptr) := enq_state
    entries(enq_ptr).ppn := io.req.bits.bmppn
    entries(enq_ptr).vpn := io.req.bits.vpn
    entries(enq_ptr).id := io.req.bits.id
    entries(enq_ptr).wait_id := Mux(to_wait, wait_id, enq_ptr)
    entries(enq_ptr).cf := false.B
    for (i <- 0 until tlbcontiguous) {
      entries(enq_ptr).cfs(i) := false.B
    }
    entries(enq_ptr).hit := to_wait
    entries(enq_ptr).level := io.req.bits.level
    entries(enq_ptr).way_info := io.req.bits.way_info
    entries(enq_ptr).hptw_bypassed := io.req.bits.hptw_bypassed
  }

  // a PMP access fault is reported through the cf (check failed) bits
  when (pmp_resp_valid) {
    val ptr = enq_ptr_reg
    val accessFault = io.pmp.resp.ld || io.pmp.resp.mmio
    entries(ptr).cf := accessFault
    for (i <- 0 until tlbcontiguous) {
      entries(ptr).cfs(i) := accessFault
    }
    // on success, query the bitmap cache before going to memory
    state(ptr) := Mux(accessFault, state_mem_out, state_cache_req)
  }

  val cache_wait = ParallelOR(is_cache_resp).asBool
  io.cache.resp.ready := !flush && cache_wait

  val hit = WireInit(false.B)
  io.cache.req.valid := cache_req_arb.io.out.valid && !flush
  io.cache.req.bits.tag := cache_req_arb.io.out.bits.tag
  io.cache.req.bits.order := cache_req_arb.io.out.bits.order
  cache_req_arb.io.out.ready := io.cache.req.ready

  when (cache_req_arb.io.out.fire) {
    for (i <- state.indices) {
      when (state(i) === state_cache_req && cache_req_arb.io.chosen === i.U) {
        state(i) := state_cache_resp
      }
    }
  }

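  // On a bitmap cache miss the entry re-runs the same duplicate matching as
  // at enqueue time, so it can still merge with a memory access that was
  // issued (or is completing) while it was querying the cache.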
  when (io.cache.resp.fire) {
    for (i <- state.indices) {
      val cm_dup_vec = state.indices.map(j =>
        dupBitmapPPN(entries(i).ppn, entries(j).ppn)
      )
      val cm_dup_req_fire = mem_arb.io.out.fire && dupBitmapPPN(entries(i).ppn, mem_arb.io.out.bits.ppn)
      val cm_dup_vec_wait = cm_dup_vec.zip(is_waiting).map { case (d, w) => d && w }
      val cm_dup_wait_resp = io.mem.resp.fire && VecInit(cm_dup_vec_wait)(io.mem.resp.bits.id - (l2tlbParams.llptwsize + 2).U)
      val cm_wait_id = Mux(cm_dup_req_fire, mem_arb.io.chosen, ParallelMux(cm_dup_vec_wait zip entries.map(_.wait_id)))
      val cm_to_wait = Cat(cm_dup_vec_wait).orR || cm_dup_req_fire
      val cm_to_mem_out = cm_dup_wait_resp
      val cm_next_state_normal = MuxCase(state_mem_req, Seq(
        cm_to_mem_out -> state_mem_out,
        cm_to_wait -> state_mem_waiting
      ))
      when (state(i) === state_cache_resp && io.cache.resp.bits.order === i.U) {
        hit := io.cache.resp.bits.hit
        when (hit) {
          entries(i).cf := io.cache.resp.bits.cfs(entries(i).ppn(5,0))
          entries(i).hit := true.B
          entries(i).cfs := io.cache.resp.bits.cfs
          state(i) := state_mem_out
        } .otherwise {
          state(i) := cm_next_state_normal
          entries(i).wait_id := Mux(cm_to_wait, cm_wait_id, entries(i).wait_id)
          entries(i).hit := cm_to_wait
        }
      }
    }
  }

  when (mem_arb.io.out.fire) {
    for (i <- state.indices) {
      when (state(i) === state_mem_req && dupBitmapPPN(entries(i).ppn, mem_arb.io.out.bits.ppn)) {
        state(i) := state_mem_waiting
        entries(i).wait_id := mem_arb.io.chosen
      }
    }
  }

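  // Each XLEN-bit bitmap word covers XLEN pages: ppn(5,0) selects this
  // page's check bit, and the aligned 8-bit group at ppn(5,3) << 3 yields
  // the cfs bits for the tlbcontiguous (8) pages of a sector fill.
  // Example: ppn(5,0) = 42 reads bit 42 for cf and bits 47..40 for cfs.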
  when (io.mem.resp.fire) {
    state.indices.foreach { i =>
      when (state(i) === state_mem_waiting && io.mem.resp.bits.id === entries(i).wait_id + (l2tlbParams.llptwsize + 2).U) {
        state(i) := state_mem_out
        val index = getBitmapAddr(entries(i).ppn)(log2Up(l2tlbParams.blockBytes)-1, log2Up(XLEN/8))
        entries(i).data := bitmapdata(index)
        entries(i).cf := bitmapdata(index)(entries(i).ppn(5,0))
        val ppnPart = entries(i).ppn(5,3)
        val start = ppnPart << 3.U
        val mask = (1.U << 8) - 1.U
        val selectedBits = (bitmapdata(index) >> start) & mask
        for (j <- 0 until tlbcontiguous) {
          entries(i).cfs(j) := selectedBits(j)
        }
      }
    }
  }

  when (io.resp.fire) {
    state(mem_ptr) := state_idle
  }

  when (flush) {
    state.foreach(_ := state_idle)
  }

  io.req.ready := !full

  io.resp.valid := ParallelOR(is_having).asBool
  // on a cache hit these fields were already filled from the cache response
  io.resp.bits.cf := entries(mem_ptr).cf
  io.resp.bits.cfs := entries(mem_ptr).cfs
  io.resp.bits.id := entries(mem_ptr).id

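  // Bitmap requests share the L2 TLB memory port with the LLPTW; the ID is
  // offset by llptwsize + 2 so the two ID ranges stay disjoint, and the
  // offset is subtracted again when responses are matched above.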
  io.mem.req.valid := mem_arb.io.out.valid && !flush
  io.mem.req.bits.addr := getBitmapAddr(mem_arb.io.out.bits.ppn)
  io.mem.req.bits.id := mem_arb.io.chosen + (l2tlbParams.llptwsize + 2).U
  mem_arb.io.out.ready := io.mem.req.ready

  io.mem.resp.ready := waiting

  io.mem.req.bits.hptw_bypassed := false.B

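  // Wake up the L0 page cache entry that was stalled on this check with the
  // result; hptw-bypassed requests have no waiter and raise no wakeup.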
  io.wakeup.valid := io.resp.valid && !entries(mem_ptr).hptw_bypassed
  io.wakeup.bits.setIndex := genPtwL0SetIdx(entries(mem_ptr).vpn)
  io.wakeup.bits.tag := entries(mem_ptr).vpn(vpnLen - 1, vpnLen - SPTagLen)
  io.wakeup.bits.isSp := entries(mem_ptr).level =/= 0.U
  io.wakeup.bits.way_info := entries(mem_ptr).way_info
  io.wakeup.bits.pte_index := entries(mem_ptr).vpn(sectortlbwidth - 1, 0)
  io.wakeup.bits.check_success := !entries(mem_ptr).cf

  // on a bitmap cache miss, refill the fetched bitmap word into the cache
  io.refill.valid := io.resp.valid && !entries(mem_ptr).hit
  io.refill.bits.tag := entries(mem_ptr).ppn
  io.refill.bits.data := entries(mem_ptr).data

  XSPerfAccumulate("bitmap_req", io.req.fire)
  XSPerfAccumulate("bitmap_mem_req", io.mem.req.fire)
}

// bitmap cache request/response bundles
class bitmapCacheReqBundle(implicit p: Parameters) extends PtwBundle {
  val order = UInt((l2tlbParams.llptwsize + 2).W)
  val tag = UInt(ppnLen.W)
}

class bitmapCacheRespBundle(implicit p: Parameters) extends PtwBundle {
  val hit = Bool()
  val cfs = Vec(tlbcontiguous, Bool())
  val order = UInt((l2tlbParams.llptwsize + 2).W)
  def apply(hit: Bool, cfs: Vec[Bool], order: UInt) = {
    this.hit := hit
    this.cfs := cfs
    this.order := order
  }
}

class bitmapCacheIO(implicit p: Parameters) extends MMUIOBaseBundle with HasPtwConst {
  val req = Flipped(DecoupledIO(new bitmapCacheReqBundle()))
  val resp = DecoupledIO(new bitmapCacheRespBundle())
  val refill = Flipped(ValidIO(new Bundle {
    val tag = UInt(ppnLen.W)
    val data = UInt(XLEN.W)
  }))
}
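
// Each cache entry holds one full XLEN-bit bitmap word (the check bits of
// XLEN consecutive pages); the tag is the remaining high part of the ppn.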
class bitmapCacheEntry(implicit p: Parameters) extends PtwBundle {
  val tag = UInt((ppnLen-log2Ceil(XLEN)).W)
  val data = UInt(XLEN.W) // one XLEN-bit bitmap word per entry
  val valid = Bool()
  def hit(tag: UInt) = {
    (this.tag === tag(ppnLen-1, log2Ceil(XLEN))) && this.valid
  }
  def refill(tag: UInt, data: UInt, valid: Bool) = {
    this.tag := tag(ppnLen-1, log2Ceil(XLEN))
    this.data := data
    this.valid := valid
  }
}

class BitmapCache(implicit p: Parameters) extends XSModule with HasPtwConst {
  val io = IO(new bitmapCacheIO)

  val csr = io.csr
  val sfence = io.sfence
  val flush = sfence.valid || csr.satp.changed || csr.vsatp.changed || csr.hgatp.changed
  val bitmap_cache_clear = csr.mbmc.BCLEAR

  val bitmapCachesize = 16
  val bitmapcache = Reg(Vec(bitmapCachesize, new bitmapCacheEntry()))
  val bitmapReplace = ReplacementPolicy.fromString(l2tlbParams.l3Replacer, bitmapCachesize)
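
  // Fully associative, 16 entries, two-stage lookup: S0 compares tags,
  // S1 registers the hit vector and selected word and drives the response.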

  // -----
  // -S0--
  // -----
  val addr_search = io.req.bits.tag
  val hitVecT = bitmapcache.map(_.hit(addr_search))

  // -----
  // -S1--
  // -----
  val index = RegEnable(addr_search(log2Up(XLEN)-1, 0), io.req.fire)
  val order = RegEnable(io.req.bits.order, io.req.fire)
  val hitVec = RegEnable(VecInit(hitVecT), io.req.fire)
  val CacheData = RegEnable(ParallelPriorityMux(hitVecT zip bitmapcache.map(_.data)), io.req.fire)
  val cfs = Wire(Vec(tlbcontiguous, Bool()))

  val start = index(5, 3) << 3.U
  val mask = (1.U << 8) - 1.U
  val cfsdata = (CacheData >> start) & mask
  for (i <- 0 until tlbcontiguous) {
    cfs(i) := cfsdata(i)
  }
  val hit = ParallelOR(hitVec)

  val resp_res = Wire(new bitmapCacheRespBundle())
  resp_res.apply(hit, cfs, order)

  val resp_valid_reg = RegInit(false.B)
  when (flush) {
    resp_valid_reg := false.B
  } .elsewhen (io.req.fire) {
    resp_valid_reg := true.B
  } .elsewhen (io.resp.fire) {
    resp_valid_reg := false.B
  }

  io.req.ready := !resp_valid_reg || io.resp.fire
  io.resp.valid := resp_valid_reg
  io.resp.bits := resp_res

  when (!flush && hit && io.resp.fire) {
    bitmapReplace.access(OHToUInt(hitVec))
  }

  // -----
  // refill
  // -----
  val rf_addr = io.refill.bits.tag
  val rf_data = io.refill.bits.data
  val rf_vd = io.refill.valid
  when (!flush && rf_vd) {
    val refillindex = bitmapReplace.way
    dontTouch(refillindex)
    bitmapcache(refillindex).refill(rf_addr, rf_data, true.B)
    bitmapReplace.access(refillindex)
  }
  when (bitmap_cache_clear === 1.U) {
    bitmapcache.foreach(_.valid := false.B)
  }

  XSPerfAccumulate("bitmap_cache_resp", io.resp.fire)
  XSPerfAccumulate("bitmap_cache_resp_miss", io.resp.fire && !io.resp.bits.hit)
}