xref: /XiangShan/src/main/scala/xiangshan/cache/wpu/WPU.scala (revision bb2f3f51dd67f6e16e0cc1ffe43368c9fc7e4aef)
1package xiangshan.cache.wpu
2
3import org.chipsalliance.cde.config.{Field, Parameters}
4import chisel3._
5import chisel3.util._
6import utility.XSPerfAccumulate
7import xiangshan.cache.{HasL1CacheParameters, L1CacheParameters}
8import xiangshan.{XSBundle, XSModule}
9
10/*
// TODO: work out the correct Scala syntax for a generic WPU base class like the one below
12abstract class WPUBaseModule[T <: Data](implicit P: Parameters) extends XSModule with HasWPUParameters{
13  def apply[T <: Data]
14  def pred(vaddr: UInt, en: Bool) : T
15  def update(vaddr: UInt, data: T ,en: Bool)
16}
17*/
18
/** cde config key (a `Field`) whose value is a [[WPUParameters]] configuration. */
case object WPUParamsKey extends Field[WPUParameters]
/**
  * Static configuration of a way-prediction unit (WPU).
  *
  * @param enWPU    enable flag for way prediction (not read in this file; presumably checked by the
  *                 cache that instantiates the WPU — TODO confirm)
  * @param algoName predictor algorithm; `HasWPUParameters.AlgoWPUMap` matches it case-insensitively
  *                 against "mru", "mmru" and "utag"
  * @param enCfPred extra prediction option (not read in this file — TODO confirm its meaning)
  * @param isICache when true, `BaseWPU` takes its cache geometry from the ICache parameters instead
  *                 of the DCache parameters
  */
case class WPUParameters(
  enWPU: Boolean = true,
  algoName: String = "mru",
  enCfPred: Boolean = false,
  isICache: Boolean = false
  // TODO: how to implement an extension that covers both HasL1Cache and the L2 cache
)
28
trait HasWPUParameters extends HasL1CacheParameters {
  /** Instantiate the way predictor named by `wpuParam.algoName`.
    *
    * The name is lower-cased before matching, so the lookup is case-insensitive.
    *
    * @param wpuParam predictor configuration; `algoName` selects the implementation
    * @param nPorts   number of prediction/update ports of the predictor's IO
    * @return the instantiated predictor module
    * @throws IllegalArgumentException if the name is not "mru", "mmru" or "utag"
    */
  def AlgoWPUMap(wpuParam: WPUParameters, nPorts: Int): BaseWPU = {
    val algo = wpuParam.algoName.toLowerCase
    algo match {
      case "mru"  => Module(new MruWPU(wpuParam, nPorts))
      case "mmru" => Module(new MmruWPU(wpuParam, nPorts))
      case "utag" => Module(new UtagWPU(wpuParam, nPorts))
      case _      => throw new IllegalArgumentException(s"unknown WPU Algorithm $algo")
    }
  }
}
39
/** Common base of the WPU bundles in this file; adds nothing beyond `XSBundle`. */
abstract class BaseWPUBundle(implicit P: Parameters) extends XSBundle
/** Common base of the WPU modules: an `XSModule` that also mixes in [[HasWPUParameters]]. */
abstract class WPUModule(implicit P: Parameters) extends XSModule with HasWPUParameters
42
/** Update request for a way predictor: the way (`way_en`) observed for virtual address `vaddr`.
  *
  * NOTE(review): keep the field declaration order; Chisel derives the bundle layout from it.
  */
class BaseWpuUpdateBundle(nWays: Int)(implicit p: Parameters) extends BaseWPUBundle{
  // update valid; every predictor in this file ignores the request while it is low
  val en = Bool()
  // virtual address; predictors derive the set index (and, for some, the virtual tag) from it
  val vaddr = UInt(VAddrBits.W)
  // accessed way, expected one-hot (predictors apply OHToUInt); all-zero is treated as a miss by UtagWPU
  val way_en = UInt(nWays.W)
}
48
/** Lookup-time update: the basic update plus the way that was predicted for this access. */
class LookupWpuUpdateBundle(nWays: Int)(implicit p: Parameters) extends BaseWpuUpdateBundle(nWays){
  // predicted way (presumably the WPU's own earlier output); UtagWPU compares it with `way_en`
  // and clears the valid bit of this way on a mismatch
  val pred_way_en = UInt(nWays.W)
}
52
/** One prediction port: the requester drives `en` and `vaddr`, the predictor answers with `way_en`. */
class BaseWpuPredictIO(nWays: Int)(implicit p: Parameters) extends BaseWPUBundle{
  // request valid; MruWPU and MmruWPU drive `way_en` to zero while it is low
  val en = Input(Bool())
  // virtual address of the request
  val vaddr = Input(UInt(VAddrBits.W))
  // predicted way (one-hot, or zero for "no prediction")
  val way_en = Output(UInt(nWays.W))
}
58
/** Complete predictor interface: `portNum` prediction ports plus three `portNum`-wide update channels. */
class WPUBaseIO(portNum: Int, nWays: Int)(implicit p:Parameters) extends BaseWPUBundle {
  // prediction ports (directions are set inside BaseWpuPredictIO)
  val predVec = Vec(portNum, new BaseWpuPredictIO(nWays))
  // updates observed at lookup time, carrying the predicted way for comparison
  val updLookup = Input(Vec(portNum, new LookupWpuUpdateBundle(nWays)))
  // updates from the "replay carry" path (presumably replayed requests — TODO confirm)
  val updReplaycarry = Input(Vec(portNum, new BaseWpuUpdateBundle(nWays)))
  // updates issued on tag writes (presumably refills — TODO confirm)
  val updTagwrite = Input(Vec(portNum, new BaseWpuUpdateBundle(nWays)))
}
65
/** Common skeleton shared by every way predictor.
  *
  * Chooses the ICache or DCache geometry from `wpuParam.isICache`, derives a few sizes used by the
  * concrete predictors and declares the shared [[WPUBaseIO]] interface.
  *
  * @param wpuParam predictor configuration
  * @param nPorts   number of prediction/update ports
  */
abstract class BaseWPU(wpuParam: WPUParameters, nPorts: Int)(implicit p:Parameters) extends WPUModule {
  val cacheParams: L1CacheParameters =
    if (wpuParam.isICache) icacheParameters
    else dcacheParameters

  /** Number of sets covered by the predictor state. */
  val setSize: Int = nSets
  /** Entries per set indexed by (part of) the virtual tag, used by MmruWPU; equals the way count. */
  val nTagIdx: Int = nWays
  /** Way index width plus one auxiliary bit, meant to flag a predicted cache miss. */
  val auxWayBits: Int = wayBits + 1
  /** Width of an index into the `nTagIdx` entries of one set. */
  val TagIdxBits: Int = log2Up(nTagIdx)
  /** Width of the hashed virtual tag (utag) stored by UtagWPU. */
  val utagBits: Int = 8

  val io: WPUBaseIO = IO(new WPUBaseIO(nPorts, nWays))

  /** Set index of `addr`, i.e. bits `[untagBits - 1, blockOffBits]`. */
  def get_wpu_idx(addr: UInt): UInt = addr(untagBits - 1, blockOffBits)
}
82
/** Most-recently-used way predictor: stores the last way written for each set and predicts it. */
class MruWPU(wpuParam: WPUParameters, nPorts: Int)(implicit p:Parameters) extends BaseWPU(wpuParam, nPorts){
  println("  WpuType: MruWPU")
  // One binary way index per set, reset to way 0.
  val predict_regs = RegInit(VecInit(Seq.fill(setSize)(0.U(wayBits.W))))

  /** Store the (one-hot) `upd.way_en` as the MRU way of the set addressed by `upd.vaddr`. */
  def write(upd: BaseWpuUpdateBundle): Unit = {
    val setIdx = get_wpu_idx(upd.vaddr)
    when(upd.en) {
      predict_regs(setIdx) := OHToUInt(upd.way_en)
    }
  }

  /** Drive `pred.way_en` with the stored MRU way as one-hot, or with zero while `pred.en` is low. */
  def predict(pred: BaseWpuPredictIO): Unit = {
    val mruWay = predict_regs(get_wpu_idx(pred.vaddr))
    pred.way_en := Mux(pred.en, UIntToOH(mruWay), 0.U(nWays.W))
  }

  io.predVec.foreach(predict)
  // Last connection wins, so keep the per-port order: lookup, replay carry, tag write.
  for {
    i   <- 0 until nPorts
    upd <- Seq(io.updLookup(i), io.updReplaycarry(i), io.updTagwrite(i))
  } write(upd)
}
111
/** Multi-MRU way predictor.
  *
  * Every set holds `nTagIdx` MRU entries, selected by the low `TagIdxBits` bits of the virtual tag,
  * so different tags mapping to the same set can keep separate predictions. Entries are
  * `auxWayBits` wide: a value with the auxiliary (top) bit set decodes to "no predicted way".
  * `write` in this class only stores `OHToUInt(way_en)`, so it never sets that bit itself.
  */
class MmruWPU(wpuParam: WPUParameters, nPorts: Int)(implicit p:Parameters) extends BaseWPU(wpuParam, nPorts){
  println("  WpuType: MmruWPU")
  val predict_regs = RegInit(VecInit(Seq.fill(setSize)(VecInit(Seq.fill(nTagIdx)(0.U(auxWayBits.W))))))

  /** Entry index inside a set: the low `TagIdxBits` bits of the virtual tag of `vaddr`.
    *
    * The full virtual tag is far wider than needed to address `nTagIdx` entries. Slicing it
    * explicitly avoids relying on implicit truncation (or an out-of-range access) of an over-wide
    * dynamic Vec index.
    */
  private def tagIdx(vaddr: UInt): UInt = get_vir_tag(vaddr)(TagIdxBits - 1, 0)

  /** Store the (one-hot) `upd.way_en` in the entry selected by the set and tag bits of `upd.vaddr`. */
  def write(upd: BaseWpuUpdateBundle): Unit = {
    when(upd.en) {
      val updSetIdx = get_wpu_idx(upd.vaddr)
      val updTagIdx = tagIdx(upd.vaddr)
      predict_regs(updSetIdx)(updTagIdx) := OHToUInt(upd.way_en)
    }
  }

  /** Drive `pred.way_en` with the selected entry as one-hot, or with zero while `pred.en` is low. */
  def predict(pred: BaseWpuPredictIO): Unit = {
    val predSetIdx = get_wpu_idx(pred.vaddr)
    val predTagIdx = tagIdx(pred.vaddr)
    when(pred.en) {
      // UIntToOH(entry) is 2^auxWayBits bits wide; keep its low nWays bits so an entry with the
      // auxiliary bit set decodes to zero: UIntToOH(8.U(4.W)) = 1_0000_0000 -> low 8 bits = 0.
      // UIntToOH(entry, nWays) would be wrong here: it truncates the *input*, so
      // UIntToOH(8.U(4.W), 8) = 0000_0001 would wrap such an entry to way 0.
      val decoded = UIntToOH(predict_regs(predSetIdx)(predTagIdx))
      pred.way_en := decoded(nWays - 1, 0)
    }.otherwise {
      pred.way_en := 0.U(nWays.W)
    }
  }

  // Last connection wins: per port, tag-write updates override replay-carry and lookup updates.
  for(i <- 0 until nPorts){
    predict(io.predVec(i))
    write(io.updLookup(i))
    write(io.updReplaycarry(i))
    write(io.updTagwrite(i))
  }

}
144
/** Micro-tag (utag) way predictor.
  *
  * For every (set, way) it keeps a `utagBits`-wide hash of the virtual tag plus a valid bit.
  * A request is predicted to hit the way whose valid utag equals the hash of its own virtual tag,
  * and "no way" (all-zero) when nothing matches.
  */
class UtagWPU(wpuParam: WPUParameters, nPorts: Int)(implicit p:Parameters) extends BaseWPU(wpuParam, nPorts){
  println("  WpuType: UtagWPU")
  val utag_regs = RegInit(VecInit(Seq.fill(setSize)(VecInit(Seq.fill(nWays)(0.U(utagBits.W))))))
  val valid_regs = RegInit(VecInit(Seq.fill(setSize)(VecInit(Seq.fill(nWays)(false.B)))))

  /** Hash of the virtual tag of `addr`: XOR of tag bits `[2*utagBits-1, utagBits]` and `[utagBits-1, 0]`.
    * Assumes the virtual tag is at least `2 * utagBits` bits wide.
    */
  def get_hash_utag(addr: UInt): UInt = {
    val vtag = get_vir_tag(addr)
    vtag(utagBits * 2 - 1, utagBits) ^ vtag(utagBits - 1, 0)

    // Alternative (unused) that folds the whole vtag into the hash:
    // val utagQuotient = vtagBits / utagBits
    // val utagRemainder = vtagBits % utagBits
    // val tmp = vtag(utagQuotient * utagBits - 1, 0).asTypeOf(Vec(utagQuotient, UInt(utagBits.W)))
    // val res1 = tmp.reduce(_ ^ _)
    // val res2 = Wire(UInt(utagRemainder.W))
    // if(utagRemainder!=0){
    //   res2 := res1(utagRemainder - 1, 0) ^ vtag(vtagBits - 1, utagBits * utagQuotient)
    //   Cat(res1(utagBits - 1, utagRemainder), res2)
    // }else{
    //   res1
    // }
  }

  /** Store the utag of `upd.vaddr` in way `upd.way_en` (one-hot) of its set and mark it valid. */
  def write_utag(upd: BaseWpuUpdateBundle): Unit = {
    when(upd.en){
      val upd_setIdx = get_wpu_idx(upd.vaddr)
      val upd_utag = get_hash_utag(upd.vaddr)
      val upd_way = OHToUInt(upd.way_en)
      utag_regs(upd_setIdx)(upd_way) := upd_utag
      valid_regs(upd_setIdx)(upd_way) := true.B
    }
  }

  /** Clear the valid bit of the predicted way `upd.pred_way_en` (one-hot) in the set of `upd.vaddr`. */
  def unvalid_utag(upd: LookupWpuUpdateBundle): Unit = {
    when(upd.en){
      val upd_setIdx = get_wpu_idx(upd.vaddr)
      val upd_way = OHToUInt(upd.pred_way_en)
      valid_regs(upd_setIdx)(upd_way) := false.B
    }
  }

  /** Drive `pred.way_en` with the matching way as one-hot, or zero when `pred.en` is low or nothing matches. */
  def predict(pred: BaseWpuPredictIO): Unit = {
    val req_setIdx = get_wpu_idx(pred.vaddr)
    val req_utag = get_hash_utag(pred.vaddr)
    val pred_way_en = Wire(UInt(nWays.W))
    when(pred.en) {
      pred_way_en := VecInit((0 until nWays).map(i =>
        req_utag === utag_regs(req_setIdx)(i) && valid_regs(req_setIdx)(i)
      )).asUInt
    }.otherwise {
      pred_way_en := 0.U(nWays.W)
    }
    // Hash conflicts can make several ways match; collapse the match vector to a single way.
    // UIntToOH(OHToUInt(0)) would select way 0, so an empty match vector (no hit, or en low)
    // is forced to zero explicitly, keeping "predicted miss" visible to the requester.
    pred.way_en := Mux(pred_way_en.orR, UIntToOH(OHToUInt(pred_way_en), nWays), 0.U(nWays.W))
  }

  val hash_conflict = Wire(Vec(nPorts, Bool()))
  for(i <- 0 until nPorts){
    predict(io.predVec(i))

    val lookup = io.updLookup(i)
    val real_way_en = lookup.way_en
    val pred_way_en = lookup.pred_way_en
    val pred_miss = lookup.en && !pred_way_en.orR
    val real_miss = lookup.en && !real_way_en.orR
    val way_match = lookup.en && pred_way_en === real_way_en
    // A valid lookup whose prediction named a way that disagrees with the real result.
    // Gated by en so idle ports are not counted as hash conflicts.
    val mispredict = lookup.en && !pred_miss && !way_match

    hash_conflict(i) := mispredict
    // look up: utag missed but the real tag hit ==> record the real way
    when(pred_miss && !real_miss) {
      write_utag(lookup)
    }
    // look up: utag hit but a different way (or none) really hit ==> invalidate the predicted way
    when(mispredict) {
      unvalid_utag(lookup)
    }
    // ... and, if some way really hit, record it
    when(mispredict && !real_miss) {
      write_utag(lookup)
    }
    // replay carry
    write_utag(io.updReplaycarry(i))
    // tag write (connected last, so it takes priority)
    write_utag(io.updTagwrite(i))
  }

  XSPerfAccumulate("utag_hash_conflict", PopCount(hash_conflict))
}