// Source: XiangShan/src/main/scala/xiangshan/frontend/SC.scala (revision 3e98122d7c1c9f544f8596a195d04abe375fd70c)
1package xiangshan.frontend
2
3import chisel3._
4import chisel3.util._
5import xiangshan._
6import utils._
7import chisel3.experimental.chiselName
8
9import scala.math.min
10
/** Read request for an SC table. */
class SCReq extends TageReq {
  // No SC-specific fields: the request layout is exactly TageReq's.
}

/**
 * Per-instruction-slot response of an SC table.
 *
 * @param ctrBits width of each signed counter
 */
class SCResp(val ctrBits: Int = 6) extends TageBundle {
  // Both stored counters of the selected row; SCTable keeps one SRAM per index (0 and 1).
  val ctr: Vec[SInt] = Vec(2, SInt(ctrBits.W))
}

/**
 * Training request for an SC table. Field order is kept as-is because it fixes the bundle's bit layout.
 *
 * @param ctrBits width of the signed counter carried in `oldCtr`
 */
class SCUpdate(val ctrBits: Int = 6) extends TageBundle {
  // pc and fetchIdx: SCTable rebuilds the indexing pc as pc - (fetchIdx << instOffsetBits).
  val pc: UInt = UInt(VAddrBits.W)
  val fetchIdx: UInt = UInt(log2Up(TageBanks).W)
  // Global history that is folded into the row index.
  val hist: UInt = UInt(HistoryLength.W)
  // mask(b) enables the write to physical bank b.
  val mask: Vec[Bool] = Vec(TageBanks, Bool())
  // Counter value the new (updated) counter is derived from.
  val oldCtr: SInt = SInt(ctrBits.W)
  // Selects which of the two counter copies (0 or 1) gets written.
  val tagePred: Bool = Bool()
  // Outcome handed to the counter update as its direction.
  val taken: Bool = Bool()
}

/** Port bundle shared by every SC table implementation. */
class SCTableIO extends TageBundle {
  // Lookup request (pc / history / mask inherited from TageReq).
  val req: Valid[SCReq] = Input(Valid(new SCReq))
  // One response per in-order instruction slot.
  val resp: Vec[SCResp] = Output(Vec(TageBanks, new SCResp))
  // Training request; per-bank write enables live in update.mask.
  val update: SCUpdate = Input(new SCUpdate)
}

/**
 * Common base of SC tables: declares the IO and a counter-centering helper.
 *
 * @param r  row count (SCTable passes its nRows)
 * @param cb counter width (SCTable passes its ctrBits)
 * @param h  history length (SCTable passes its histLen)
 */
abstract class BaseSCTable(val r: Int = 1024, val cb: Int = 6, val h: Int = 0) extends TageModule {
  val io: SCTableIO = IO(new SCTableIO)

  /** Maps counter value c to 2c+1 (one bit wider than `ctr`), so no counter value centres to zero. */
  def getCenteredValue(ctr: SInt): SInt = {
    val doubled = ctr << 1
    doubled +% 1.S
  }
}

/** Storage-less SC table stand-in: every response counter is driven to zero. */
class FakeSCTable extends BaseSCTable {
  io.resp := 0.U.asTypeOf(io.resp)
}

/**
 * One statistical-corrector (SC) table.
 *
 * Storage is TageBanks x 2 SRAMs, each holding nRows signed ctrBits-wide counters.
 * A read returns, for every in-order instruction slot, both counters of the row
 * selected by getIdx. An update writes the copy chosen by tagePred in every bank
 * enabled by update.mask. After reset, one row per cycle is cleared to zero; while
 * that sweep runs, update writes are replaced by the clearing writes.
 *
 * @param nRows   rows per SRAM; must be a power of two greater than one
 * @param ctrBits width of each signed counter
 * @param histLen number of global-history bits folded into the row index (0 = pc only)
 */
@chiselName
class SCTable(val nRows: Int, val ctrBits: Int, val histLen: Int) extends BaseSCTable(nRows, ctrBits, histLen) {

  // getIdx keeps exactly log2Ceil(nRows) index bits. With a non-power-of-two
  // row count the index could address rows past the end of each SRAM.
  require(nRows > 1 && isPow2(nRows), s"SCTable: nRows must be a power of two > 1, got $nRows")

  // table(b)(i): physical bank b, counter copy i (i = 0 or 1, matching SCResp.ctr).
  val table = List.fill(TageBanks) {
    List.fill(2) {
      Module(new SRAMTemplate(SInt(ctrBits.W), set=nRows, shouldReset=false, holdRead=true, singlePort=false))
    }
  }

  /** XOR-folds the low `histLen` bits of `hist` into `l`-bit chunks (the last may be narrower); 0.U when histLen is 0. */
  def compute_folded_hist(hist: UInt, l: Int) = {
    if (histLen > 0) {
      val nChunks = (histLen + l - 1) / l
      val hist_chunks = (0 until nChunks) map {i =>
        hist(min((i+1)*l, histLen)-1, i*l)
      }
      hist_chunks.reduce(_^_)
    }
    else 0.U
  }

  /** Row index: folded history XOR the instruction-granular pc, truncated to log2Ceil(nRows) bits. */
  def getIdx(hist: UInt, pc: UInt) = {
    (compute_folded_hist(hist, log2Ceil(nRows)) ^ (pc >> instOffsetBits))(log2Ceil(nRows)-1,0)
  }

  // signedSatUpdate is defined outside this file. It is presumably a saturating
  // signed step of a ctrBits-wide counter in the direction given by `cond` — confirm.
  def ctrUpdate(ctr: SInt, cond: Bool): SInt = signedSatUpdate(ctr, ctrBits, cond)

  // Post-reset sweep: reset_idx walks rows 0..nRows-1, one per cycle, while doing_reset
  // is set. doing_reset clears once the last row has been reached.
  val doing_reset = RegInit(true.B)
  val reset_idx = RegInit(0.U(log2Ceil(nRows).W))
  reset_idx := reset_idx + doing_reset
  when (reset_idx === (nRows-1).U) { doing_reset := false.B }

  val idx = getIdx(io.req.bits.hist, io.req.bits.pc)
  val idxLatch = RegEnable(idx, enable=io.req.valid)

  val table_r = WireInit(0.U.asTypeOf(Vec(TageBanks,Vec(2, SInt(ctrBits.W)))))

  // Physical bank of the request's first instruction: the log2Up(TageBanks) pc bits
  // directly above the instruction-offset bits. The top bit index includes
  // instOffsetBits, so the slice has the right width for any instruction alignment.
  val baseBank = io.req.bits.pc(log2Up(TageBanks)+instOffsetBits-1, instOffsetBits)
  val baseBankLatch = RegEnable(baseBank, enable=io.req.valid)

  // In-order slot b is served by physical bank (baseBank + b) mod TageBanks. The latched
  // baseBank is used because the SRAM read data is consumed after the request cycle.
  val bankIdxInOrder = VecInit((0 until TageBanks).map(b => (baseBankLatch +& b.U)(log2Up(TageBanks)-1, 0)))
  // Request mask rotated from in-order slots into physical-bank order
  // (circularShiftLeft is defined elsewhere; assumed to rotate by baseBank positions).
  val realMask = circularShiftLeft(io.req.bits.mask, TageBanks, baseBank)

  // Rebuild the indexing pc by stepping back fetchIdx instruction slots from update.pc.
  val update_idx = getIdx(io.update.hist, io.update.pc - (io.update.fetchIdx << instOffsetBits))
  val update_wdata = ctrUpdate(io.update.oldCtr, io.update.taken)

  for (b <- 0 until TageBanks; i <- 0 to 1) {
    val sram = table(b)(i)
    sram.reset := reset.asBool

    sram.io.r.req.valid := io.req.valid && realMask(b)
    sram.io.r.req.bits.setIdx := idx
    table_r(b)(i) := sram.io.r.resp.data(0)

    // Only the copy whose index equals tagePred is trained. During the reset sweep
    // every copy instead writes 0 at reset_idx.
    sram.io.w.req.valid := (io.update.mask(b) && i.U === io.update.tagePred.asUInt) || doing_reset
    sram.io.w.req.bits.setIdx := Mux(doing_reset, reset_idx, update_idx)
    sram.io.w.req.bits.data := Mux(doing_reset, 0.S, update_wdata)
  }

  for (b <- 0 until TageBanks) {
    io.resp(b).ctr := table_r(bankIdxInOrder(b))
  }

  if (BPUDebug && debug) {
    val u = io.update
    val b = PriorityEncoder(u.mask)
    XSDebug(io.req.valid, p"scTableReq: pc=0x${io.req.bits.pc}%x, idx=${idx}%d, hist=${io.req.bits.hist}%x, baseBank=${baseBank}%d, mask=${io.req.bits.mask}%b, realMask=${realMask}%b\n")
    for (i <- 0 until TageBanks) {
      XSDebug(RegNext(io.req.valid), p"scTableResp[${i.U}]: idx=${idxLatch}%d, ctr:${io.resp(i).ctr}\n")
    }
    XSDebug(io.update.mask.reduce(_||_), p"update Table: pc:${u.pc}%x, fetchIdx:${u.fetchIdx}%d, hist:${u.hist}%x, bank:${b}%d, tageTaken:${u.tagePred}%d, taken:${u.taken}%d, oldCtr:${u.oldCtr}%d\n")
  }

}

/**
 * Adaptive threshold state for the statistical corrector.
 *
 * `ctr` is a ctrBits-wide unsigned training counter; SCThreshold.apply starts it at
 * neutralVal (its midpoint). Whenever update() drives it to all-ones or to zero, it is
 * re-centred to neutralVal and `thres` steps up or down by one. `thres` is clamped
 * to [minThres, maxThres].
 *
 * @param ctrBits width of the training counter `ctr`
 */
class SCThreshold(val ctrBits: Int = 5) extends TageBundle {
  val ctr = UInt(ctrBits.W)
  /** True when the given counter (default: this.ctr) is all ones. */
  def satPos(ctr: UInt = this.ctr) = ctr === ((1.U << ctrBits) - 1.U)
  /** True when the given counter (default: this.ctr) is zero. */
  def satNeg(ctr: UInt = this.ctr) = ctr === 0.U
  /** Midpoint 2^(ctrBits-1) that `ctr` is re-centred to after saturating. */
  def neutralVal = (1.U << (ctrBits - 1))
  // 5 bits: exactly enough to hold maxThres (31).
  val thres = UInt(5.W)
  def minThres = 5.U
  def maxThres = 31.U

  /**
   * Computes the next threshold state.
   *
   * @param cause direction handed to satUpdate for `ctr` (satUpdate is defined elsewhere;
   *              presumably a saturating +1 when true, -1 when false — confirm)
   * @return a new wire holding the updated ctr and thres
   */
  def update(cause: Bool): SCThreshold = {
    val res = Wire(new SCThreshold(this.ctrBits))
    val newCtr = satUpdate(this.ctr, this.ctrBits, cause)
    // Keep thres inside [minThres, maxThres]. Unguarded, the width-preserving `+`/`-`
    // on the 5-bit field would wrap 31 -> 0 on increment and 0 -> 31 on decrement.
    val incThres = res.satPos(newCtr) && this.thres < maxThres
    val decThres = res.satNeg(newCtr) && this.thres > minThres
    val newThres = Mux(incThres, this.thres + 1.U,
                      Mux(decThres, this.thres - 1.U,
                      this.thres))
    res.thres := newThres
    res.ctr := Mux(res.satPos(newCtr) || res.satNeg(newCtr), res.neutralVal, newCtr)
    // XSDebug(true.B, p"scThres Update: cause${cause} newCtr ${newCtr} newThres ${newThres}\n")
    res
  }
}

object SCThreshold {
  /**
   * Builds an SCThreshold wire in its initial state.
   *
   * @param bits width of the training counter
   * @return a wire with thres = minThres and ctr = neutralVal
   */
  def apply(bits: Int): SCThreshold = {
    val init = Wire(new SCThreshold(ctrBits = bits))
    init.thres := init.minThres
    init.ctr := init.neutralVal
    init
  }
}
148}