/***************************************************************************************
* Copyright (c) 2024 Beijing Institute of Open Source Chip (BOSC)
* Copyright (c) 2020-2024 Institute of Computing Technology, Chinese Academy of Sciences
* Copyright (c) 2020-2021 Peng Cheng Laboratory
*
* XiangShan is licensed under Mulan PSL v2.
* You can use this software according to the terms and conditions of the Mulan PSL v2.
* You may obtain a copy of Mulan PSL v2 at:
*          http://license.coscl.org.cn/MulanPSL2
*
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
* EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
* MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
*
* See the Mulan PSL v2 for more details.
***************************************************************************************/

// See LICENSE.SiFive for license details.

package xiangshan.backend.fu

import org.chipsalliance.cde.config.Parameters
import chisel3._
import chisel3.util._
import utility.MaskedRegMap.WritableMask
import xiangshan._
import xiangshan.backend.fu.util.HasCSRConst
import utils._
import utility._
import xiangshan.cache.mmu.{TlbCmd, TlbExceptionBundle}
import freechips.rocketchip.rocket.CSRs

trait PMPConst extends HasPMParameters {
  val PMPOffBits = 2 // minimum granularity: 4 bytes
  val CoarserGrain: Boolean = PlatformGrain > PMPOffBits
}
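
/* Granularity note (illustrative, assuming the common PlatformGrain = 12, i.e.
 * a 4 KiB grain): since 12 > PMPOffBits, CoarserGrain is true, NA4 entries are
 * not selectable, and every NAPOT mask is forced to cover at least
 * 2^PlatformGrain bytes (see the grain term in match_mask below).
 */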

abstract class PMPBundle(implicit val p: Parameters) extends Bundle with PMPConst
abstract class PMPModule(implicit val p: Parameters) extends Module with PMPConst
abstract class PMPXSModule(implicit p: Parameters) extends XSModule with PMPConst

class PMPConfig(implicit p: Parameters) extends PMPBundle {
  val l = Bool()
  val c = Bool() // res(1), unused in pmp
  val atomic = Bool() // res(0), unused in pmp
  val a = UInt(2.W)
  val x = Bool()
  val w = Bool()
  val r = Bool()

  def res: UInt = Cat(c, atomic) // in pmp, unused
  def off = a === 0.U
  def tor = a === 1.U
  def na4 = { if (CoarserGrain) false.B else a === 2.U }
  def napot = { if (CoarserGrain) a(1).asBool else a === 3.U }
  def off_tor = !a(1)
  def na4_napot = a(1)

  def locked = l
  def addr_locked: Bool = locked
  def addr_locked(next: PMPConfig): Bool = locked || (next.locked && next.tor)
}
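
/* The a-field decode above follows the RISC-V privileged spec encoding:
 * 0 = OFF, 1 = TOR, 2 = NA4, 3 = NAPOT. addr_locked(next) mirrors the spec
 * rule that pmpaddr(i) is also unwritable when entry i+1 is locked and in TOR
 * mode, because pmpaddr(i) then forms entry i+1's lower bound.
 */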

object PMPConfigUInt {
  def apply(
    l: Boolean = false,
    c: Boolean = false,
    atomic: Boolean = false,
    a: Int = 0,
    x: Boolean = false,
    w: Boolean = false,
    r: Boolean = false)(implicit p: Parameters): UInt = {
    var config = 0
    if (l) { config += (1 << 7) }
    if (c) { config += (1 << 6) }
    if (atomic) { config += (1 << 5) }
    if (a > 0) { config += (a << 3) }
    if (x) { config += (1 << 2) }
    if (w) { config += (1 << 1) }
    if (r) { config += (1 << 0) }
    config.U(8.W)
  }
}
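
/* Usage sketch (arithmetic follows the bit layout above):
 *   PMPConfigUInt(l = true, a = 3, x = true, w = true, r = true)
 *   = (0x80 + (3 << 3) + 0x4 + 0x2 + 0x1).U = 0x9f.U, a locked NAPOT RWX entry.
 */
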
trait PMPReadWriteMethodBare extends PMPConst {
  def match_mask(cfg: PMPConfig, paddr: UInt) = {
    val match_mask_c_addr = Cat(paddr, cfg.a(0)) | (((1 << PlatformGrain) - 1) >> PMPOffBits).U((paddr.getWidth + 1).W)
    Cat(match_mask_c_addr & ~(match_mask_c_addr + 1.U), ((1 << PMPOffBits) - 1).U(PMPOffBits.W))
  }
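
  /* Worked example (a sketch, widths elided): for NAPOT, cfg.a(0) = 1, so a
   * pmpaddr ending in k trailing ones makes x = Cat(paddr, 1) end in k+1 ones.
   * x & ~(x + 1) keeps exactly those ones, and appending PMPOffBits more gives
   * a mask of k+3 low ones over the byte address, i.e. a 2^(k+3)-byte region,
   * matching the spec's NAPOT encoding (pmpaddr = 0b...0111 with k = 3 gives a
   * 64-byte region, mask = 0x3f). The OR term forces at least
   * PlatformGrain - PMPOffBits trailing ones, enforcing the platform grain.
   */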

  def write_cfg_vec(mask: Vec[UInt], addr: Vec[UInt], index: Int, oldcfg: UInt)(cfgs: UInt): UInt = {
    val cfgVec = Wire(Vec(cfgs.getWidth/8, new PMPConfig))
    for (i <- cfgVec.indices) {
      val cfg_w_m_tmp = cfgs((i+1)*8-1, i*8).asUInt.asTypeOf(new PMPConfig)
      val cfg_old_tmp = oldcfg((i+1)*8-1, i*8).asUInt.asTypeOf(new PMPConfig)
      cfgVec(i) := cfg_old_tmp
      when (!cfg_old_tmp.l) {
        cfgVec(i) := cfg_w_m_tmp
        cfgVec(i).w := cfg_w_m_tmp.w && cfg_w_m_tmp.r
        if (CoarserGrain) { cfgVec(i).a := Cat(cfg_w_m_tmp.a(1), cfg_w_m_tmp.a.orR) }
        when (cfgVec(i).na4_napot) {
          mask(index + i) := match_mask(cfgVec(i), addr(index + i))
        }
      }
    }
    cfgVec.asUInt
  }

  def read_addr(cfg: PMPConfig)(addr: UInt): UInt = {
    val G = PlatformGrain - PMPOffBits
    require(G >= 0)
    if (G == 0) {
      addr
    } else if (G >= 2) {
      Mux(cfg.na4_napot, set_low_bits(addr, G-1), clear_low_bits(addr, G))
    } else { // G is 1
      Mux(cfg.off_tor, clear_low_bits(addr, G), addr)
    }
  }

  def write_addr(next: PMPConfig, mask: UInt)(paddr: UInt, cfg: PMPConfig, addr: UInt): UInt = {
    val locked = cfg.addr_locked(next)
    mask := Mux(!locked, match_mask(cfg, paddr), mask)
    Mux(!locked, paddr, addr)
  }

  def set_low_bits(data: UInt, num: Int): UInt = {
    require(num >= 0)
    data | ((1 << num)-1).U
  }

  /** clear the data's low num bits (lsb) */
  def clear_low_bits(data: UInt, num: Int): UInt = {
    require(num >= 0)
    // use Cat instead of & with mask to avoid "Signal Width" problem
    if (num == 0) { data }
    else { Cat(data(data.getWidth-1, num), 0.U(num.W)) }
  }
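
  /* Illustrative values: set_low_bits("b1000".U, 2) = "b1011".U and
   * clear_low_bits("b1011".U, 2) = "b1000".U. */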
}

trait PMPReadWriteMethod extends PMPReadWriteMethodBare { this: PMPBase =>
  def write_cfg_vec(oldcfg: UInt)(cfgs: UInt): UInt = {
    val cfgVec = Wire(Vec(cfgs.getWidth/8, new PMPConfig))
    for (i <- cfgVec.indices) {
      val cfg_w_tmp = cfgs((i+1)*8-1, i*8).asUInt.asTypeOf(new PMPConfig)
      val cfg_old_tmp = oldcfg((i+1)*8-1, i*8).asUInt.asTypeOf(new PMPConfig)
      cfgVec(i) := cfg_old_tmp
      when (!cfg_old_tmp.l) {
        cfgVec(i) := cfg_w_tmp
        cfgVec(i).w := cfg_w_tmp.w && cfg_w_tmp.r
        if (CoarserGrain) { cfgVec(i).a := Cat(cfg_w_tmp.a(1), cfg_w_tmp.a.orR) }
      }
    }
    cfgVec.asUInt
  }

  /** In general, the PMP grain is 2**{G+2} bytes. When G >= 1, NA4 is not selectable.
   * When G >= 2 and cfg.a(1) is set (the mode is NAPOT), the bits addr(G-2, 0) read as all ones.
   * When G >= 1 and cfg.a(1) is clear (the mode is OFF or TOR), the bits addr(G-1, 0) read as zeros.
   * The low PMPOffBits are dropped.
   */
  def read_addr(): UInt = {
    read_addr(cfg)(addr)
  }
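
  /* Example under an assumed 4 KiB grain (PlatformGrain = 12, G = 10): a NAPOT
   * entry reads back with addr(8, 0) forced to ones via set_low_bits, while an
   * OFF/TOR entry reads back with addr(9, 0) forced to zeros via clear_low_bits.
   */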

  /** addr is the internal address with the low PMPOffBits dropped.
   * compare_addr is the internal address used for comparison.
   * paddr is the external (input) address.
   */
  def write_addr(next: PMPConfig)(paddr: UInt): UInt = {
    Mux(!cfg.addr_locked(next), paddr, addr)
  }
  def write_addr(paddr: UInt): UInt = {
    Mux(!cfg.addr_locked, paddr, addr)
  }
}

/** PMPBase for CSR unit
  * with only read and write logic
  */
class PMPBase(implicit p: Parameters) extends PMPBundle with PMPReadWriteMethod {
  val cfg = new PMPConfig
  val addr = UInt((PMPAddrBits - PMPOffBits).W)

  def gen(cfg: PMPConfig, addr: UInt) = {
    require(addr.getWidth == this.addr.getWidth)
    this.cfg := cfg
    this.addr := addr
  }
}

trait PMPMatchMethod extends PMPConst { this: PMPEntry =>
  /** compare_addr is used to compare with input addr */
  def compare_addr: UInt = ((addr << PMPOffBits) & ~(((1 << PlatformGrain) - 1).U(PMPAddrBits.W))).asUInt
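
  /* Illustrative value: a stored (word) addr of 0x2000_0000 gives compare_addr
   * = 0x8000_0000, i.e. the byte address with the low PlatformGrain bits
   * cleared to respect the platform grain. */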

  /** size and maxSize are both log2 sizes.
   * For dtlb, the maxSize is bPMXLEN, which is 8.
   * For itlb and ptw, the maxSize is log2(512)?
   * But we may only need the 64 bytes? How to prevent the bugs?
   * TODO: handle the special case that itlb & ptw & dcache access a wider size than PMXLEN
   */
  def is_match(paddr: UInt, lgSize: UInt, lgMaxSize: Int, last_pmp: PMPEntry): Bool = {
    Mux(cfg.na4_napot, napotMatch(paddr, lgSize, lgMaxSize),
      Mux(cfg.tor, torMatch(paddr, lgSize, lgMaxSize, last_pmp), false.B))
  }

  /** generate match mask to help match in napot mode */
  def match_mask(paddr: UInt): UInt = {
    match_mask(cfg, paddr)
  }

  def boundMatch(paddr: UInt, lgSize: UInt, lgMaxSize: Int): Bool = {
    if (lgMaxSize <= PlatformGrain) {
      (paddr < compare_addr)
    } else {
      val highLess = (paddr >> lgMaxSize) < (compare_addr >> lgMaxSize)
      val highEqual = (paddr >> lgMaxSize) === (compare_addr >> lgMaxSize)
      val lowLess = (paddr(lgMaxSize-1, 0) | OneHot.UIntToOH1(lgSize, lgMaxSize)) < compare_addr(lgMaxSize-1, 0)
      highLess || (highEqual && lowLess)
    }
  }

  def lowerBoundMatch(paddr: UInt, lgSize: UInt, lgMaxSize: Int): Bool = {
    !boundMatch(paddr, lgSize, lgMaxSize)
  }

  def higherBoundMatch(paddr: UInt, lgMaxSize: Int) = {
    boundMatch(paddr, 0.U, lgMaxSize)
  }

  def torMatch(paddr: UInt, lgSize: UInt, lgMaxSize: Int, last_pmp: PMPEntry): Bool = {
    last_pmp.lowerBoundMatch(paddr, lgSize, lgMaxSize) && higherBoundMatch(paddr, lgMaxSize)
  }
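
  /* Reading of the TOR helpers above: the access [paddr, paddr + 2^lgSize)
   * matches when its last byte reaches last_pmp.compare_addr or above (lower
   * bound) and its base sits below compare_addr (upper bound), i.e. it
   * overlaps [pmpaddr(i-1), pmpaddr(i)); aligned() below rejects accesses
   * that straddle either bound.
   */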

  def unmaskEqual(a: UInt, b: UInt, m: UInt) = {
    (a & ~m) === (b & ~m)
  }

  def napotMatch(paddr: UInt, lgSize: UInt, lgMaxSize: Int) = {
    if (lgMaxSize <= PlatformGrain) {
      unmaskEqual(paddr, compare_addr, mask)
    } else {
      val lowMask = mask | OneHot.UIntToOH1(lgSize, lgMaxSize)
      val highMatch = unmaskEqual(paddr >> lgMaxSize, compare_addr >> lgMaxSize, mask >> lgMaxSize)
      val lowMatch = unmaskEqual(paddr(lgMaxSize-1, 0), compare_addr(lgMaxSize-1, 0), lowMask(lgMaxSize-1, 0))
      highMatch && lowMatch
    }
  }

  def aligned(paddr: UInt, lgSize: UInt, lgMaxSize: Int, last: PMPEntry) = {
    if (lgMaxSize <= PlatformGrain) {
      true.B
    } else {
      val lowBitsMask = OneHot.UIntToOH1(lgSize, lgMaxSize)
      val lowerBound = ((paddr >> lgMaxSize) === (last.compare_addr >> lgMaxSize)) &&
        ((~paddr(lgMaxSize-1, 0) & last.compare_addr(lgMaxSize-1, 0)) =/= 0.U)
      val upperBound = ((paddr >> lgMaxSize) === (compare_addr >> lgMaxSize)) &&
        ((compare_addr(lgMaxSize-1, 0) & (paddr(lgMaxSize-1, 0) | lowBitsMask)) =/= 0.U)
      val torAligned = !(lowerBound || upperBound)
      val napotAligned = (lowBitsMask & ~mask(lgMaxSize-1, 0)) === 0.U
      Mux(cfg.na4_napot, napotAligned, torAligned)
    }
  }
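
  /* aligned answers "does the access sit entirely inside the region?". A
   * matching but straddling access gets its permissions cleared in
   * pmp_match_res, so it faults rather than being partially permitted. For
   * NAPOT this reduces to the access size not exceeding the region size
   * (lowBitsMask covered by mask).
   */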
}

/** PMPEntry for outside pmp copies
  * with one more element, mask, to help napot match
  * TODO: make mask an element, not a method, for timing opt
  */
class PMPEntry(implicit p: Parameters) extends PMPBase with PMPMatchMethod {
  val mask = UInt(PMPAddrBits.W) // help to match in napot

  def write_addr(next: PMPConfig, mask: UInt)(paddr: UInt) = {
    mask := Mux(!cfg.addr_locked(next), match_mask(paddr), mask)
    Mux(!cfg.addr_locked(next), paddr, addr)
  }

  def write_addr(mask: UInt)(paddr: UInt) = {
    mask := Mux(!cfg.addr_locked, match_mask(paddr), mask)
    Mux(!cfg.addr_locked, paddr, addr)
  }

  def gen(cfg: PMPConfig, addr: UInt, mask: UInt) = {
    require(addr.getWidth == this.addr.getWidth)
    this.cfg := cfg
    this.addr := addr
    this.mask := mask
  }
}

trait PMPMethod extends PMPConst {
  def pmp_init(): (Vec[UInt], Vec[UInt], Vec[UInt]) = {
    val cfg = WireInit(0.U.asTypeOf(Vec(NumPMP/8, UInt(PMXLEN.W))))
    // val addr = Wire(Vec(NumPMP, UInt((PMPAddrBits-PMPOffBits).W)))
    // val mask = Wire(Vec(NumPMP, UInt(PMPAddrBits.W)))
    // INFO: these CSRs could be uninitialized, but for difftesting with NEMU, we opt to initialize them.
    val addr = WireInit(0.U.asTypeOf(Vec(NumPMP, UInt((PMPAddrBits-PMPOffBits).W))))
    val mask = WireInit(0.U.asTypeOf(Vec(NumPMP, UInt(PMPAddrBits.W))))
    (cfg, addr, mask)
  }

  def pmp_gen_mapping
  (
    init: () => (Vec[UInt], Vec[UInt], Vec[UInt]),
    num: Int = 16,
    cfgBase: Int,
    addrBase: Int,
    entries: Vec[PMPEntry]
  ) = {
    val pmpCfgPerCSR = PMXLEN / new PMPConfig().getWidth
    def pmpCfgIndex(i: Int) = (PMXLEN / 32) * (i / pmpCfgPerCSR)
    val init_value = init()
    /** to fit MaskedRegMap's write, declare cfgs as merged CSRs and split them into each PMP */
    val cfgMerged = RegInit(init_value._1) // Vec(num / pmpCfgPerCSR, UInt(PMXLEN.W)), i.e. RegInit(VecInit(Seq.fill(num / pmpCfgPerCSR)(0.U(PMXLEN.W))))
    val cfgs = WireInit(cfgMerged).asTypeOf(Vec(num, new PMPConfig()))
    val addr = RegInit(init_value._2) // Vec(num, UInt((PMPAddrBits-PMPOffBits).W))
    val mask = RegInit(init_value._3) // Vec(num, UInt(PMPAddrBits.W))

    for (i <- entries.indices) {
      entries(i).gen(cfgs(i), addr(i), mask(i))
    }

    val cfg_mapping = (0 until num by pmpCfgPerCSR).map(i => {Map(
      MaskedRegMap(
        addr = cfgBase + pmpCfgIndex(i),
        reg = cfgMerged(i/pmpCfgPerCSR),
        wmask = WritableMask,
        wfn = new PMPBase().write_cfg_vec(mask, addr, i, cfgMerged(i/pmpCfgPerCSR))
      ))
    }).fold(Map())((a, b) => a ++ b) // ugly code; improvements welcome

    val addr_mapping = (0 until num).map(i => {Map(
      MaskedRegMap(
        addr = addrBase + i,
        reg = addr(i),
        wmask = WritableMask,
        wfn = { if (i != num-1) entries(i).write_addr(entries(i+1).cfg, mask(i)) else entries(i).write_addr(mask(i)) },
        rmask = WritableMask,
        rfn = new PMPBase().read_addr(entries(i).cfg)
      ))
    }).fold(Map())((a, b) => a ++ b) // ugly code; improvements welcome

    cfg_mapping ++ addr_mapping
  }
}
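
/* CSR index arithmetic in pmp_gen_mapping, worked for RV64 (PMXLEN = 64):
 * pmpCfgPerCSR = 64 / 8 = 8 configs per CSR and pmpCfgIndex steps by
 * PMXLEN / 32 = 2, so only the even pmpcfg CSRs (pmpcfg0, pmpcfg2, ...) are
 * mapped, as the privileged spec requires on RV64.
 */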

class PMP(implicit p: Parameters) extends PMPXSModule with HasXSParameter with PMPMethod with PMAMethod with HasCSRConst {
  val io = IO(new Bundle {
    val distribute_csr = Flipped(new DistributedCSRIO())
    val pmp = Output(Vec(NumPMP, new PMPEntry()))
    val pma = Output(Vec(NumPMA, new PMPEntry()))
  })

  val w = io.distribute_csr.w

  val pmp = Wire(Vec(NumPMP, new PMPEntry()))
  val pma = Wire(Vec(NumPMA, new PMPEntry()))

  val pmpMapping = pmp_gen_mapping(pmp_init, NumPMP, CSRs.pmpcfg0, CSRs.pmpaddr0, pmp)
  val pmaMapping = pmp_gen_mapping(pma_init, NumPMA, PmacfgBase, PmaaddrBase, pma)
  val mapping = pmpMapping ++ pmaMapping

  val rdata = Wire(UInt(PMXLEN.W))
  MaskedRegMap.generate(mapping, w.bits.addr, rdata, w.valid, w.bits.data)

  io.pmp := pmp
  io.pma := pma
}

class PMPReqBundle(lgMaxSize: Int = 3)(implicit p: Parameters) extends PMPBundle {
  val addr = Output(UInt(PMPAddrBits.W))
  val size = Output(UInt(log2Ceil(lgMaxSize+1).W))
  val cmd = Output(TlbCmd())

  def apply(addr: UInt, size: UInt, cmd: UInt): Unit = {
    this.addr := addr
    this.size := size
    this.cmd := cmd
  }

  def apply(addr: UInt): Unit = { // request minimal permission (read) with the max aligned size
    apply(addr, lgMaxSize.U, TlbCmd.read)
  }

}

class PMPRespBundle(implicit p: Parameters) extends PMPBundle {
  val ld = Output(Bool())
  val st = Output(Bool())
  val instr = Output(Bool())
  val mmio = Output(Bool())
  val atomic = Output(Bool())

  def |(resp: PMPRespBundle): PMPRespBundle = {
    val res = Wire(new PMPRespBundle())
    res.ld := this.ld || resp.ld
    res.st := this.st || resp.st
    res.instr := this.instr || resp.instr
    res.mmio := this.mmio || resp.mmio
    res.atomic := this.atomic || resp.atomic
    res
  }
}

trait PMPCheckMethod extends PMPConst {
  def pmp_check(cmd: UInt, cfg: PMPConfig) = {
    val resp = Wire(new PMPRespBundle)
    resp.ld := TlbCmd.isRead(cmd) && !TlbCmd.isAmo(cmd) && !cfg.r
    resp.st := (TlbCmd.isWrite(cmd) || TlbCmd.isAmo(cmd)) && !cfg.w
    resp.instr := TlbCmd.isExec(cmd) && !cfg.x
    resp.mmio := false.B
    resp.atomic := false.B
    resp
  }

  def pmp_match_res(leaveHitMux: Boolean = false, valid: Bool = true.B)(
    addr: UInt,
    size: UInt,
    pmpEntries: Vec[PMPEntry],
    mode: UInt,
    lgMaxSize: Int
  ) = {
    val num = pmpEntries.size
    require(num == NumPMP)

    val passThrough = if (pmpEntries.isEmpty) true.B else (mode > 1.U)
    val pmpDefault = WireInit(0.U.asTypeOf(new PMPEntry()))
    pmpDefault.cfg.r := passThrough
    pmpDefault.cfg.w := passThrough
    pmpDefault.cfg.x := passThrough

    val match_vec = Wire(Vec(num+1, Bool()))
    val cfg_vec = Wire(Vec(num+1, new PMPEntry()))

    pmpEntries.zip(pmpDefault +: pmpEntries.take(num-1)).zipWithIndex.foreach{ case ((pmp, last_pmp), i) =>
      val is_match = pmp.is_match(addr, size, lgMaxSize, last_pmp)
      val ignore = passThrough && !pmp.cfg.l
      val aligned = pmp.aligned(addr, size, lgMaxSize, last_pmp)

      val cur = WireInit(pmp)
      cur.cfg.r := aligned && (pmp.cfg.r || ignore)
      cur.cfg.w := aligned && (pmp.cfg.w || ignore)
      cur.cfg.x := aligned && (pmp.cfg.x || ignore)

      // Mux(is_match, cur, prev)
      match_vec(i) := is_match
      cfg_vec(i) := cur
    }

    // default value
    match_vec(num) := true.B
    cfg_vec(num) := pmpDefault

    if (leaveHitMux) {
      ParallelPriorityMux(match_vec.map(RegEnable(_, false.B, valid)), RegEnable(cfg_vec, valid))
    } else {
      ParallelPriorityMux(match_vec, cfg_vec)
    }
  }
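
  /* Priority semantics of the mux above: the lowest-numbered matching entry
   * wins, matching the spec's ordering of PMP entries. The appended default
   * entry always matches and carries passThrough permissions, so an M-mode
   * (mode > 1) access that hits no entry succeeds while an S/U-mode one
   * faults; ignore additionally lets M-mode bypass matched, unlocked entries.
   */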
}

class PMPCheckerEnv(implicit p: Parameters) extends PMPBundle {
  val cmode = Bool()
  val mode = UInt(2.W)
  val pmp = Vec(NumPMP, new PMPEntry())
  val pma = Vec(NumPMA, new PMPEntry())

  def apply(cmode: Bool, mode: UInt, pmp: Vec[PMPEntry], pma: Vec[PMPEntry]): Unit = {
    this.cmode := cmode
    this.mode := mode
    this.pmp := pmp
    this.pma := pma
  }

  def apply(mode: UInt, pmp: Vec[PMPEntry], pma: Vec[PMPEntry]): Unit = {
    this.cmode := true.B
    this.mode := mode
    this.pmp := pmp
    this.pma := pma
  }
}

class PMPCheckIO(lgMaxSize: Int)(implicit p: Parameters) extends PMPBundle {
  val check_env = Input(new PMPCheckerEnv())
  val req = Flipped(Valid(new PMPReqBundle(lgMaxSize))) // usage: drive valid with the fire signal
  val resp = new PMPRespBundle()

  def apply(cmode: Bool, mode: UInt, pmp: Vec[PMPEntry], pma: Vec[PMPEntry], req: Valid[PMPReqBundle]) = {
    check_env.apply(cmode, mode, pmp, pma)
    this.req := req
    resp
  }

  def apply(mode: UInt, pmp: Vec[PMPEntry], pma: Vec[PMPEntry], req: Valid[PMPReqBundle]) = {
    check_env.apply(mode, pmp, pma)
    this.req := req
    resp
  }

  def req_apply(valid: Bool, addr: UInt): Unit = {
    this.req.valid := valid
    this.req.bits.apply(addr)
  }

  def apply(mode: UInt, pmp: Vec[PMPEntry], pma: Vec[PMPEntry], valid: Bool, addr: UInt) = {
    check_env.apply(mode, pmp, pma)
    req_apply(valid, addr)
    resp
  }
}

class PMPCheckv2IO(lgMaxSize: Int)(implicit p: Parameters) extends PMPBundle {
  val check_env = Input(new PMPCheckerEnv())
  val req = Flipped(Valid(new PMPReqBundle(lgMaxSize))) // usage: drive valid with the fire signal
  val resp = Output(new PMPConfig())

  def apply(cmode: Bool, mode: UInt, pmp: Vec[PMPEntry], pma: Vec[PMPEntry], valid: Bool, addr: UInt) = {
    check_env.apply(cmode, mode, pmp, pma)
    req_apply(valid, addr)
    resp
  }

  def apply(mode: UInt, pmp: Vec[PMPEntry], pma: Vec[PMPEntry], req: Valid[PMPReqBundle]) = {
    check_env.apply(mode, pmp, pma)
    this.req := req
    resp
  }

  def req_apply(valid: Bool, addr: UInt): Unit = {
    this.req.valid := valid
    this.req.bits.apply(addr)
  }

  def apply(mode: UInt, pmp: Vec[PMPEntry], pma: Vec[PMPEntry], valid: Bool, addr: UInt) = {
    check_env.apply(mode, pmp, pma)
    req_apply(valid, addr)
    resp
  }
}

class PMPChecker
(
  lgMaxSize: Int = 3,
  sameCycle: Boolean = false,
  leaveHitMux: Boolean = false,
  pmpUsed: Boolean = true
)(implicit p: Parameters) extends PMPModule
  with PMPCheckMethod
  with PMACheckMethod
{
  require(!(leaveHitMux && sameCycle))
  val io = IO(new PMPCheckIO(lgMaxSize))

  val req = io.req.bits

  /* The KeyIDBits are used for memory encryption and occupy the address MSBs,
   * so (PMPKeyIDBits > 0) is usually set together with HasMEMencryption = true.
   *
   * Example:
   * PAddrBits=48 & PMPKeyIDBits=5
   * [47,46,45,44,43, 42,41,.......,1,0]
   * {----KeyID----} {----RealPAddr----}
   *
   * A nonzero KeyID is bound to an Enclave/CVM (cmode=1) to select a distinct memory encryption key,
   * while the OS/VMM/APP/VM (cmode=0) can only use zero as the KeyID.
   *
   * So only the RealPAddr needs the PMP & PMA checks.
   */

  val res_pmp = pmp_match_res(leaveHitMux, io.req.valid)(req.addr(PMPAddrBits-PMPKeyIDBits-1, 0), req.size, io.check_env.pmp, io.check_env.mode, lgMaxSize)
  val res_pma = pma_match_res(leaveHitMux, io.req.valid)(req.addr(PMPAddrBits-PMPKeyIDBits-1, 0), req.size, io.check_env.pma, io.check_env.mode, lgMaxSize)

  val resp_pmp = pmp_check(req.cmd, res_pmp.cfg)
  val resp_pma = pma_check(req.cmd, res_pma.cfg)

  def keyid_check(leaveHitMux: Boolean = false, valid: Bool = true.B, addr: UInt) = {
    val resp = Wire(new PMPRespBundle)
    val keyid_nz = if (PMPKeyIDBits > 0) addr(PMPAddrBits-1, PMPAddrBits-PMPKeyIDBits) =/= 0.U else false.B
    resp.ld := keyid_nz && !io.check_env.cmode && (io.check_env.mode < 3.U)
    resp.st := keyid_nz && !io.check_env.cmode && (io.check_env.mode < 3.U)
    resp.instr := keyid_nz && !io.check_env.cmode && (io.check_env.mode < 3.U)
    resp.mmio := false.B
    resp.atomic := false.B
    if (leaveHitMux) {
      RegEnable(resp, valid)
    } else {
      resp
    }
  }

  val resp_keyid = keyid_check(leaveHitMux, io.req.valid, req.addr)

  val resp = if (pmpUsed) (resp_pmp | resp_pma | resp_keyid) else (resp_pma | resp_keyid)

  if (sameCycle || leaveHitMux) {
    io.resp := resp
  } else {
    io.resp := RegEnable(resp, io.req.valid)
  }
}

/* get config with check */
class PMPCheckerv2
(
  lgMaxSize: Int = 3,
  sameCycle: Boolean = false,
  leaveHitMux: Boolean = false
)(implicit p: Parameters) extends PMPModule
  with PMPCheckMethod
  with PMACheckMethod
{
  require(!(leaveHitMux && sameCycle))
  val io = IO(new PMPCheckv2IO(lgMaxSize))

  val req = io.req.bits

  val res_pmp = pmp_match_res(leaveHitMux, io.req.valid)(req.addr, req.size, io.check_env.pmp, io.check_env.mode, lgMaxSize)
  val res_pma = pma_match_res(leaveHitMux, io.req.valid)(req.addr, req.size, io.check_env.pma, io.check_env.mode, lgMaxSize)

  val resp = and(res_pmp, res_pma)

  if (sameCycle || leaveHitMux) {
    io.resp := resp
  } else {
    io.resp := RegEnable(resp, io.req.valid)
  }

  def and(pmp: PMPEntry, pma: PMPEntry): PMPConfig = {
    val tmp_res = Wire(new PMPConfig)
    tmp_res.l := DontCare
    tmp_res.a := DontCare
    tmp_res.r := pmp.cfg.r && pma.cfg.r
    tmp_res.w := pmp.cfg.w && pma.cfg.w
    tmp_res.x := pmp.cfg.x && pma.cfg.x
    tmp_res.c := pma.cfg.c
    tmp_res.atomic := pma.cfg.atomic
    tmp_res
  }
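
  /* An access is allowed only when both the matched PMP and PMA entries allow
   * it (AND of r/w/x); cacheability (c) and atomicity come from the PMA alone,
   * since those bits are reserved/unused in PMP entries (see PMPConfig).
   */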
}
639