package xiangshan.cache.wpu

import org.chipsalliance.cde.config.{Field, Parameters}
import chisel3._
import chisel3.util._
import utility.XSPerfAccumulate
import xiangshan.cache.{HasL1CacheParameters, L1CacheParameters}
import xiangshan.{XSBundle, XSModule}

/*
// TODO: work out the Scala syntax for a generic (type-parameterized) WPU base module
abstract class WPUBaseModule[T <: Data](implicit P: Parameters) extends XSModule with HasWPUParameters{
  def apply[T <: Data]
  def pred(vaddr: UInt, en: Bool) : T
  def update(vaddr: UInt, data: T ,en: Bool)
}
*/

/** Config key under which the way-predictor parameters are stored. */
case object WPUParamsKey extends Field[WPUParameters]

/** Elaboration-time configuration of a way predictor (WPU).
  *
  * @param enWPU    not read in this file; presumably enables the predictor (confirm at the use site)
  * @param algoName predictor algorithm picked by `AlgoWPUMap`: "mru", "mmru" or "utag" (case-insensitive)
  * @param enCfPred not read in this file (confirm meaning at the use site)
  * @param isICache selects the ICache parameters when true, the DCache parameters otherwise
  */
case class WPUParameters
(
  enWPU: Boolean = true,
  algoName: String = "mru",
  enCfPred: Boolean = false,
  isICache: Boolean = false,
  // TODO: how to implement an extension that covers both the L1 caches and the L2 cache
)

trait HasWPUParameters extends HasL1CacheParameters {
  /** Instantiates the WPU implementation named by `wpuParam.algoName`.
    *
    * @throws IllegalArgumentException at elaboration time when the algorithm name is unknown
    */
  def AlgoWPUMap(wpuParam: WPUParameters, nPorts: Int): BaseWPU = {
    wpuParam.algoName.toLowerCase match {
      case "mru"  => Module(new MruWPU(wpuParam, nPorts))
      case "mmru" => Module(new MmruWPU(wpuParam, nPorts))
      case "utag" => Module(new UtagWPU(wpuParam, nPorts))
      case t      => throw new IllegalArgumentException(s"unknown WPU Algorithm $t")
    }
  }
}

abstract class BaseWPUBundle(implicit P: Parameters) extends XSBundle
abstract class WPUModule(implicit P: Parameters) extends XSModule with HasWPUParameters

/** Update request: `way_en` (decoded with OHToUInt by the predictors, i.e. treated as one-hot)
  * is the way to associate with `vaddr`; ignored while `en` is low.
  */
class BaseWpuUpdateBundle(nWays: Int)(implicit p: Parameters) extends BaseWPUBundle {
  val en = Bool()
  val vaddr = UInt(VAddrBits.W)
  val way_en = UInt(nWays.W)
}

/** Lookup-stage update: additionally carries the way mask that was predicted for the access. */
class LookupWpuUpdateBundle(nWays: Int)(implicit p: Parameters) extends BaseWpuUpdateBundle(nWays) {
  val pred_way_en = UInt(nWays.W)
}

/** One prediction port: `vaddr` in, predicted way mask out (driven to zero while `en` is low). */
class BaseWpuPredictIO(nWays: Int)(implicit p: Parameters) extends BaseWPUBundle {
  val en = Input(Bool())
  val vaddr = Input(UInt(VAddrBits.W))
  val way_en = Output(UInt(nWays.W))
}

/** Predictor IO: `portNum` prediction ports plus three update sources per port. */
class WPUBaseIO(portNum: Int, nWays: Int)(implicit p: Parameters) extends BaseWPUBundle {
  val predVec = Vec(portNum, new BaseWpuPredictIO(nWays))
  val updLookup = Input(Vec(portNum, new LookupWpuUpdateBundle(nWays)))
  val updReplaycarry = Input(Vec(portNum, new BaseWpuUpdateBundle(nWays)))
  val updTagwrite = Input(Vec(portNum, new BaseWpuUpdateBundle(nWays)))
}

/** Common geometry and IO shared by all way-predictor implementations. */
abstract class BaseWPU(wpuParam: WPUParameters, nPorts: Int)(implicit p: Parameters) extends WPUModule {
  val cacheParams: L1CacheParameters = if (wpuParam.isICache) icacheParameters else dcacheParameters

  val setSize = nSets
  val nTagIdx = nWays
  // one auxiliary bit on top of wayBits, intended to encode "cache miss"
  val auxWayBits = wayBits + 1
  val TagIdxBits = log2Up(nTagIdx)
  val utagBits = 8

  val io = IO(new WPUBaseIO(nPorts, nWays))

  /** Set index of `addr`: bits [untagBits - 1, blockOffBits]. */
  def get_wpu_idx(addr: UInt): UInt = {
    addr(untagBits - 1, blockOffBits)
  }
}

/** MRU predictor: one `wayBits`-wide register per set holding the most recently written way. */
class MruWPU(wpuParam: WPUParameters, nPorts: Int)(implicit p: Parameters) extends BaseWPU(wpuParam, nPorts) {
  println(" WpuType: MruWPU")
  val predict_regs = RegInit(VecInit(Seq.fill(setSize)(0.U(wayBits.W))))

  /** Records `upd.way_en` as the MRU way of its set (when `upd.en`). */
  def write(upd: BaseWpuUpdateBundle): Unit = {
    when(upd.en) {
      val upd_setIdx = get_wpu_idx(upd.vaddr)
      predict_regs(upd_setIdx) := OHToUInt(upd.way_en)
    }
  }

  /** Drives the stored MRU way of the set as a one-hot mask, or 0 when `pred.en` is low. */
  def predict(pred: BaseWpuPredictIO): Unit = {
    val predSetIdx = get_wpu_idx(pred.vaddr)
    when(pred.en) {
      pred.way_en := UIntToOH(predict_regs(predSetIdx))
    }.otherwise {
      pred.way_en := 0.U(nWays.W)
    }
  }

  // Chisel last-connect: if several updates target the same set in one cycle, the one
  // issued last here (higher port index, then tag write) wins.
  for (i <- 0 until nPorts) {
    predict(io.predVec(i))
    write(io.updLookup(i))
    write(io.updReplaycarry(i))
    write(io.updTagwrite(i))
  }
}

/** Multi-MRU predictor: per set, `nTagIdx` entries selected by the low virtual-tag bits. */
class MmruWPU(wpuParam: WPUParameters, nPorts: Int)(implicit p: Parameters) extends BaseWPU(wpuParam, nPorts) {
  println(" WpuType: MmruWPU")
  val predict_regs = RegInit(VecInit(Seq.fill(setSize)(VecInit(Seq.fill(nTagIdx)(0.U(auxWayBits.W))))))

  /** Entry index within a set: the low TagIdxBits bits of the virtual tag.
    * Explicit form of the truncation Chisel applies to an over-wide Vec index.
    */
  private def get_wpu_tag_idx(addr: UInt): UInt = get_vir_tag(addr)(TagIdxBits - 1, 0)

  /** Stores `upd.way_en` (as a binary way index) in the entry selected by `upd.vaddr`. */
  def write(upd: BaseWpuUpdateBundle): Unit = {
    when(upd.en) {
      val updSetIdx = get_wpu_idx(upd.vaddr)
      val updTagIdx = get_wpu_tag_idx(upd.vaddr)
      // NOTE(review): an all-zero way_en is stored as 0 (way 0); the auxiliary "miss" encoding
      // (value nWays) is never written by this module — confirm that is intended.
      predict_regs(updSetIdx)(updTagIdx) := OHToUInt(upd.way_en)
    }
  }

  /** Drives the selected entry as a one-hot way mask, or 0 when `pred.en` is low. */
  def predict(pred: BaseWpuPredictIO): Unit = {
    val predSetIdx = get_wpu_idx(pred.vaddr)
    val predTagIdx = get_wpu_tag_idx(pred.vaddr)
    when(pred.en) {
      // UIntToOH without a width produces 2^auxWayBits bits; connecting it to the nWays-bit output
      // keeps only the low nWays bits, so a stored value of nWays (aux bit set) decodes to "no way":
      //   UIntToOH(8.U(4.W))    = 1_0000_0000 (16-bit value) -> 0000_0000 after truncation to 8 bits
      //   UIntToOH(8.U(4.W), 8) = 0000_0001 (the index itself is truncated first), hence no width here
      pred.way_en := UIntToOH(predict_regs(predSetIdx)(predTagIdx))
    }.otherwise {
      pred.way_en := 0.U(nWays.W)
    }
  }

  for (i <- 0 until nPorts) {
    predict(io.predVec(i))
    write(io.updLookup(i))
    write(io.updReplaycarry(i))
    write(io.updTagwrite(i))
  }
}

/** Micro-tag predictor: per set and way, a valid bit plus a `utagBits` hash of the virtual tag. */
class UtagWPU(wpuParam: WPUParameters, nPorts: Int)(implicit p: Parameters) extends BaseWPU(wpuParam, nPorts) {
  println(" WpuType: UtagWPU")
  require(vtagBits >= 2 * utagBits,
    s"UtagWPU hashes the low ${2 * utagBits} virtual-tag bits, but vtagBits = $vtagBits")

  val utag_regs = RegInit(VecInit(Seq.fill(setSize)(VecInit(Seq.fill(nWays)(0.U(utagBits.W))))))
  val valid_regs = RegInit(VecInit(Seq.fill(setSize)(VecInit(Seq.fill(nWays)(false.B)))))

  /** Hashed utag: vtag[2*utagBits-1, utagBits] XOR vtag[utagBits-1, 0]. */
  def get_hash_utag(addr: UInt): UInt = {
    val vtag = get_vir_tag(addr)
    // Alternative (not used): fold every vtag bit into the utag by XOR-ing utagBits-wide chunks.
    //   val utagQuotient = vtagBits / utagBits
    //   val utagRemainder = vtagBits % utagBits
    //   val tmp = vtag(utagQuotient * utagBits - 1, 0).asTypeOf(Vec(utagQuotient, UInt(utagBits.W)))
    //   val res1 = tmp.reduce(_ ^ _)
    //   val res2 = Wire(UInt(utagRemainder.W))
    //   if (utagRemainder != 0) {
    //     res2 := res1(utagRemainder - 1, 0) ^ vtag(vtagBits - 1, utagBits * utagQuotient)
    //     Cat(res1(utagBits - 1, utagRemainder), res2)
    //   } else {
    //     res1
    //   }
    vtag(utagBits * 2 - 1, utagBits) ^ vtag(utagBits - 1, 0)
  }

  /** Writes the utag of `upd.vaddr` into way `upd.way_en` of its set and marks it valid. */
  def write_utag(upd: BaseWpuUpdateBundle): Unit = {
    when(upd.en) {
      val upd_setIdx = get_wpu_idx(upd.vaddr)
      val upd_utag = get_hash_utag(upd.vaddr)
      val upd_way = OHToUInt(upd.way_en)
      utag_regs(upd_setIdx)(upd_way) := upd_utag
      valid_regs(upd_setIdx)(upd_way) := true.B
    }
  }

  /** Invalidates the predicted way (`upd.pred_way_en`) of the set of `upd.vaddr`. */
  def unvalid_utag(upd: LookupWpuUpdateBundle): Unit = {
    when(upd.en) {
      val upd_setIdx = get_wpu_idx(upd.vaddr)
      val upd_way = OHToUInt(upd.pred_way_en)
      valid_regs(upd_setIdx)(upd_way) := false.B
    }
  }

  /** Predicts the way whose valid utag matches `pred.vaddr`; drives 0 on no match or when `en` is low. */
  def predict(pred: BaseWpuPredictIO): Unit = {
    val req_setIdx = get_wpu_idx(pred.vaddr)
    val req_utag = get_hash_utag(pred.vaddr)
    val pred_way_en = Wire(UInt(nWays.W))
    when(pred.en) {
      pred_way_en := VecInit((0 until nWays).map(i =>
        req_utag === utag_regs(req_setIdx)(i) && valid_regs(req_setIdx)(i)
      )).asUInt
    }.otherwise {
      pred_way_en := 0.U(nWays.W)
    }
    // Avoid hash conflicts: several ways may share a utag, so keep only the lowest matching way.
    // PriorityEncoderOH yields 0 when nothing matched, so a utag miss stays a miss.
    // (UIntToOH(OHToUInt(0)) would report way 0, and OHToUInt of a multi-hot mask ORs the
    // indices, which can name a way that did not match.)
    pred.way_en := PriorityEncoderOH(pred_way_en)
  }

  val hash_conflict = Wire(Vec(nPorts, Bool()))
  for (i <- 0 until nPorts) {
    predict(io.predVec(i))

    val lookup = io.updLookup(i)
    val real_way_en = lookup.way_en
    val pred_way_en = lookup.pred_way_en
    val pred_miss = lookup.en && !pred_way_en.orR
    val real_miss = lookup.en && !real_way_en.orR
    val way_match = lookup.en && pred_way_en === real_way_en
    // valid lookup whose utag prediction named a way other than the real one;
    // gated by en so idle cycles are not counted as conflicts
    val wrong_way = lookup.en && !pred_miss && !way_match

    hash_conflict(i) := wrong_way
    // look up: vtag miss but tag hit ==> record the real way
    when(pred_miss && !real_miss) {
      write_utag(lookup)
    }
    // look up: vtag hit but another way (or no way) hit ==> invalidate the predicted way ...
    when(wrong_way) {
      unvalid_utag(lookup)
    }
    // ... and record the real way if there was one
    when(wrong_way && !real_miss) {
      write_utag(lookup)
    }
    // replay carry
    write_utag(io.updReplaycarry(i))
    // tag write
    write_utag(io.updTagwrite(i))
  }

  XSPerfAccumulate("utag_hash_conflict", PopCount(hash_conflict))
}