xref: /XiangShan/src/main/scala/xiangshan/frontend/PreDecode.scala (revision 708ceed4afe43fb0ea3a52407e46b2794c573634)
1/***************************************************************************************
2* Copyright (c) 2020-2021 Institute of Computing Technology, Chinese Academy of Sciences
3* Copyright (c) 2020-2021 Peng Cheng Laboratory
4*
5* XiangShan is licensed under Mulan PSL v2.
6* You can use this software according to the terms and conditions of the Mulan PSL v2.
7* You may obtain a copy of Mulan PSL v2 at:
8*          http://license.coscl.org.cn/MulanPSL2
9*
10* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
11* EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
12* MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
13*
14* See the Mulan PSL v2 for more details.
15***************************************************************************************/
16
17package xiangshan.frontend
18
19import chipsalliance.rocketchip.config.Parameters
20import freechips.rocketchip.rocket.{RVCDecoder, ExpandedInstruction}
21import chisel3.{util, _}
22import chisel3.util._
23import utils._
24import xiangshan._
25import xiangshan.backend.decode.isa.predecode.PreDecodeInst
26import xiangshan.cache._
27
/** Constants and combinational helpers shared by the pre-decode logic. */
trait HasPdConst extends HasXSParameter with HasICacheParameters with HasIFUConst {
  /** True when `inst` is a 16-bit (RVC) encoding, i.e. its two lowest bits are not both set. */
  def isRVC(inst: UInt) = inst(1, 0) =/= 3.U

  /** True when `reg` names x1 or x5, the registers treated here as link registers. */
  def isLink(reg: UInt) = reg === 1.U || reg === 5.U

  /**
    * Classifies a (possibly compressed) instruction as a control-flow instruction.
    *
    * @param instr raw instruction bits
    * @return `List(brType, isCall, isRet)` where `brType` is a [[BrType]] value
    */
  def brInfo(instr: UInt) = {
    val brType :: Nil = ListLookup(instr, List(BrType.notCFI), PreDecodeInst.brTable)
    val compressed = isRVC(instr)
    // For RVC, bit 12 separates C.JALR (1) from C.JR (0), so it stands in for "rd" in the link test.
    val rd = Mux(compressed, instr(12), instr(11, 7))
    // Compressed instructions classified as jal carry no source register; others keep rs1 in [11:7].
    val rs = Mux(compressed, Mux(brType === BrType.jal, 0.U, instr(11, 7)), instr(19, 15))
    // Compressed jal never counts as a call ("Only for RV64").
    val isCall = (brType === BrType.jalr || brType === BrType.jal && !compressed) && isLink(rd)
    val isRet = brType === BrType.jalr && isLink(rs) && !isCall
    List(brType, isCall, isRet)
  }

  /** Sign-extended (to XLEN) byte offset of a jal-class instruction, RVC (C.J layout) or RVI (JAL layout). */
  def jal_offset(inst: UInt, rvc: Bool): UInt = {
    // C.J immediate: imm[11|4|9:8|10|6|7|3:1|5] in inst[12:2]
    val compressedImm = Cat(inst(12), inst(8), inst(10, 9), inst(6), inst(7), inst(2), inst(11), inst(5, 3), 0.U(1.W))
    // JAL immediate: imm[20|10:1|11|19:12] in inst[31:12]
    val fullImm = Cat(inst(31), inst(19, 12), inst(20), inst(30, 21), 0.U(1.W))
    val widest = fullImm.getWidth
    val picked = Mux(rvc, SignExt(compressedImm, widest), SignExt(fullImm, widest))
    SignExt(picked, XLEN)
  }

  /** Sign-extended (to XLEN) byte offset of a conditional branch, RVC (C.BEQZ/C.BNEZ layout) or RVI (B-type). */
  def br_offset(inst: UInt, rvc: Bool): UInt = {
    // C.B* immediate: imm[8|4:3] in inst[12:10], imm[7:6|2:1|5] in inst[6:2]
    val compressedImm = Cat(inst(12), inst(6, 5), inst(2), inst(11, 10), inst(4, 3), 0.U(1.W))
    // B-type immediate: imm[12|10:5] in inst[31:25], imm[4:1|11] in inst[11:7]
    val fullImm = Cat(inst(31), inst(7), inst(30, 25), inst(11, 8), 0.U(1.W))
    val widest = fullImm.getWidth
    val picked = Mux(rvc, SignExt(compressedImm, widest), SignExt(fullImm, widest))
    SignExt(picked, XLEN)
  }

  /** Slot index (bits [log2Ceil(PredictWidth) : instOffsetBits]) of the last instruction unit before `pc`, relative to `start`. */
  def getBasicBlockIdx(pc: UInt, start: UInt): UInt = {
    val lastUnitOffset = (pc - start) - instBytes.U
    lastUnitOffset(log2Ceil(PredictWidth), instOffsetBits)
  }

  /** Placeholder used for faulting / unfetched slots: 0x4501 encodes `c.li a0, 0`. */
  def NOP = "h4501".U(16.W)
}
58
/** 2-bit control-flow classification produced by pre-decode. */
object BrType {
  /** Hardware type able to hold a [[BrType]] value. */
  def apply(): UInt = UInt(2.W)

  def notCFI: UInt = "b00".U // not a control-flow instruction
  def branch: UInt = "b01".U // conditional branch
  def jal:    UInt = "b10".U // jal-class jump (target computed with jal_offset)
  def jalr:   UInt = "b11".U // register-indirect jump
}
66
/** Exception-type encoding for pre-decoded slots. */
object ExcType { // TODO: add more exception types
  /** Hardware type able to hold an [[ExcType]] value. */
  def apply(): UInt = UInt(3.W)

  /** No exception. */
  def notExc: UInt = "b000".U
}
71
/**
  * Per-slot pre-decode result: valid, isRVC, a 2-bit [[BrType]], isCall and isRet
  * (6 bits of payload; the 3-bit excType field is currently disabled).
  */
class PreDecodeInfo extends Bundle {
  val valid   = Bool()
  val isRVC   = Bool()
  val brType  = UInt(2.W)
  val isCall  = Bool()
  val isRet   = Bool()
  //val excType = UInt(3.W)

  // Convenience decoders of brType; these are defs, so they add no bundle fields.
  def notCFI: Bool = brType === BrType.notCFI
  def isBr:   Bool = brType === BrType.branch
  def isJal:  Bool = brType === BrType.jal
  def isJalr: Bool = brType === BrType.jalr
}
84
/** Output of [[PreDecode]]: per-slot decode results plus block-level prediction checks. Field order is the bundle layout. */
class PreDecodeResp(implicit p: Parameters) extends XSBundle with HasPdConst {
  /** One entry per predict slot. */
  private def slotVec[T <: Data](gen: T): Vec[T] = Vec(PredictWidth, gen)

  val pc           = slotVec(UInt(VAddrBits.W))
  val instrs       = slotVec(UInt(32.W))
  val pd           = slotVec(new PreDecodeInfo)
  val takens       = slotVec(Bool())
  // first mispredicted slot / first taken slot
  val misOffset    = ValidUndirectioned(UInt(log2Ceil(PredictWidth).W))
  val cfiOffset    = ValidUndirectioned(UInt(log2Ceil(PredictWidth).W))
  val target       = UInt(VAddrBits.W)
  val jalTarget    = UInt(VAddrBits.W)
  val hasLastHalf  = Bool()
  val realEndPC    = UInt(VAddrBits.W)
  val instrRange   = slotVec(Bool())
  val pageFault    = slotVec(Bool())
  val accessFault  = slotVec(Bool())
  val crossPageIPF = slotVec(Bool())
}
101
/**
  * Pre-decoder for one fetch block.
  *
  * For each of the PredictWidth 16-bit slots starting at `io.in.startAddr` it:
  *  - tracks instruction boundaries (`validStart`/`validEnd`) across RVC and 32-bit encodings,
  *  - substitutes [[NOP]] for slots that fault or that lie in a second cache line that was not fetched,
  *  - extracts branch type / call / return information and expands RVC instructions,
  *  - checks the prediction carried in `io.in` (ftqOffset, target, fallThruAddr) against the decode, and
  *  - reports the first mispredicted slot, the first taken slot, the repaired end PC and jump targets.
  */
class PreDecode(implicit p: Parameters) extends XSModule with HasPdConst {
  val io = IO(new Bundle() {
    val in = Input(new IfuToPreDecode)
    val out = Output(new PreDecodeResp)
  })

  val instValid     = io.in.instValid
  val data          = io.in.data
  val pcStart       = io.in.startAddr
  val pcEnd         = io.in.fallThruAddr
  val pcEndError    = io.in.fallThruError
  val isDoubleLine  = io.in.isDoubleLine
  val bbOffset      = io.in.ftqOffset.bits
  val bbTaken       = io.in.ftqOffset.valid
  val bbTarget      = io.in.target
  val oversize      = io.in.oversize
  val pageFault     = io.in.pageFault
  val accessFault   = io.in.accessFault

  val validStart        = Wire(Vec(PredictWidth, Bool()))
  dontTouch(validStart)
  val validEnd          = Wire(Vec(PredictWidth, Bool()))
  val targets           = Wire(Vec(PredictWidth, UInt(VAddrBits.W)))
  val misPred           = Wire(Vec(PredictWidth, Bool()))
  val takens            = Wire(Vec(PredictWidth, Bool()))
  val falseHit          = Wire(Vec(PredictWidth, Bool()))
  val instRange         = Wire(Vec(PredictWidth, Bool()))
  // "real" signals are generated from the end PC of this block as repaired with pre-decode information
  val realEndPC         = Wire(UInt(VAddrBits.W))
  val realHasLastHalf   = Wire(Vec(PredictWidth, Bool()))
  val realMissPred      = Wire(Vec(PredictWidth, Bool()))
  val realTakens        = Wire(Vec(PredictWidth, Bool()))

  // With the C extension a 32-bit instruction starting in slot i has its upper half in data(i + 1).
  val rawInsts = if (HasCExtension) VecInit((0 until PredictWidth).map(i => Cat(data(i + 1), data(i))))
                 else               VecInit((0 until PredictWidth).map(i => data(i)))

  // Start of the 64-byte line after the one holding pcStart (assumes align() rounds down).
  val nextLinePC = align(pcStart, 64) + 64.U

  for (i <- 0 until PredictWidth) {
    val currentPC = pcStart + (i << 1).U((log2Ceil(PredictWidth) + 1).W)

    //TODO: Terrible timing for pc comparing
    // A slot whose first halfword is at or beyond nextLinePC belongs to the second line. Using ">="
    // (not ">") matters: with ">" the slot exactly at nextLinePC was attributed to neither line, so its
    // faults were dropped and it was not nullified when the second line was not fetched.
    val isNextLine      = currentPC >= nextLinePC
    val nullInstruction = isNextLine && !isDoubleLine

    val hasPageFault   = validStart(i) && ((!isNextLine && pageFault(0))   || (isNextLine && pageFault(1)))
    val hasAccessFault = validStart(i) && ((!isNextLine && accessFault(0)) || (isNextLine && accessFault(1)))
    val exception      = hasPageFault || hasAccessFault
    val inst           = Mux(exception || nullInstruction, NOP, rawInsts(i))
    val expander       = Module(new RVCExpander)

    val currentIsRVC   = isRVC(inst) && HasCExtension.B

    // Slot 0 starts an instruction unless the previous block's last half is being completed here.
    val lastIsValidEnd = if (i == 0) !io.in.lastHalfMatch else validEnd(i - 1) || !HasCExtension.B

    validStart(i)   := lastIsValidEnd || !HasCExtension.B
    // A slot ends an instruction if it starts an RVC one, or if it is the upper half of a 32-bit one.
    validEnd(i)     := validStart(i) && currentIsRVC || !validStart(i) || !HasCExtension.B

    val brType :: isCall :: isRet :: Nil = brInfo(inst)
    val jalOffset = jal_offset(inst, currentIsRVC)
    val brOffset  = br_offset(inst, currentIsRVC)

    io.out.pd(i).valid         := validStart(i)
    io.out.pd(i).isRVC         := currentIsRVC
    io.out.pd(i).brType        := brType
    io.out.pd(i).isCall        := isCall
    io.out.pd(i).isRet         := isRet
    io.out.pc(i)               := currentPC
    io.out.pageFault(i)        := hasPageFault
    io.out.accessFault(i)      := hasAccessFault
    io.out.crossPageIPF(i)     := (currentPC === align(realEndPC, 64) - 2.U) && !pageFault(0) && pageFault(1) && !currentIsRVC

    expander.io.in             := inst
    io.out.instrs(i)           := expander.io.out.bits

    takens(i)    := (validStart(i) && (bbTaken && bbOffset === i.U && !io.out.pd(i).notCFI || io.out.pd(i).isJal || io.out.pd(i).isRet))

    val jumpTarget = currentPC + Mux(io.out.pd(i).isBr, brOffset, jalOffset)
    targets(i) := Mux(takens(i), jumpTarget, pcEnd)

    // Branch or jal predicted taken here, but with a different target
    val targetFault     = (validStart(i) && i.U === bbOffset && bbTaken && (io.out.pd(i).isBr || io.out.pd(i).isJal) && bbTarget =/= targets(i))
    // A non-CFI instruction is predicted taken
    val notCFIFault     = (validStart(i) && i.U === bbOffset && io.out.pd(i).notCFI && bbTaken)
    // A jal instruction is predicted not taken
    val jalFault        = (validStart(i) && !bbTaken && io.out.pd(i).isJal)
    // A ret instruction is predicted not taken
    val retFault        = (validStart(i) && !bbTaken && io.out.pd(i).isRet)
    // A slot that does not start an instruction is predicted taken
    val invalidInsFault = (!validStart(i) && i.U === bbOffset && bbTaken)

    misPred(i)   := targetFault || notCFIFault || jalFault || retFault || invalidInsFault || pcEndError
    falseHit(i)  := invalidInsFault || notCFIFault

    realMissPred(i)     := misPred(i) && instRange(i)
    realHasLastHalf(i)  := instValid && currentPC === (realEndPC - 2.U) && validStart(i) && instRange(i) && !currentIsRVC
    realTakens(i)       := takens(i) && instRange(i)
  }

  // First jal in the block (by slot order); used to repair the end PC after a false hit.
  val jumpOH      = VecInit(io.out.pd.zipWithIndex.map { case (pd, i) => pd.isJal && validStart(i) }) //TODO: need jalr?
  val jumpOffset  = PriorityEncoder(jumpOH)
  val rvcOH       = VecInit(io.out.pd.map(_.isRVC))
  val jumpPC      = io.out.pc(jumpOffset)
  val jumpIsRVC   = rvcOH(jumpOffset)
  val jumpNextPC  = jumpPC + Mux(jumpIsRVC, 2.U, 4.U)
  val (hasFalseHit, hasJump) = (ParallelOR(falseHit), ParallelOR(jumpOH))
  // Slots up to (and including) the one before realEndPC; all slots when oversize.
  val endRange    = ((Fill(PredictWidth, 1.U(1.W)) >> (~getBasicBlockIdx(realEndPC, pcStart))) | (Fill(PredictWidth, oversize)))
  // Slots up to (and including) the first taken one; all slots when nothing is taken.
  val takeRange   = Fill(PredictWidth, !ParallelOR(takens)) | (Fill(PredictWidth, 1.U(1.W)) >> (~PriorityEncoder(takens)))
  val fixCross    = ((pcStart + (FetchWidth * 4).U) > nextLinePC) && !isDoubleLine
  val boundPC     = Mux(fixCross, nextLinePC - 2.U, pcStart + (FetchWidth * 4).U)

  instRange := VecInit((0 until PredictWidth).map(i => endRange(i) && takeRange(i)))
  realEndPC := Mux(hasFalseHit, Mux(hasJump && jumpNextPC <= boundPC, jumpNextPC, boundPC), pcEnd)

  val validLastOffset = Mux(io.out.pd(PredictWidth - 1).valid, (PredictWidth - 1).U, (PredictWidth - 2).U)
  io.out.misOffset.valid := ParallelOR(realMissPred)
  io.out.misOffset.bits  := Mux(pcEndError, validLastOffset, PriorityEncoder(realMissPred))
  io.out.instrRange      := instRange

  io.out.cfiOffset.valid := ParallelOR(realTakens)
  io.out.cfiOffset.bits  := PriorityEncoder(realTakens)

  io.out.target          := Mux(io.out.cfiOffset.valid, targets(io.out.cfiOffset.bits), realEndPC)
  io.out.takens          := realTakens

  io.out.jalTarget       := targets(jumpOffset)

  io.out.hasLastHalf     := realHasLastHalf.reduce(_ || _)
  io.out.realEndPC       := realEndPC

  for (i <- 0 until PredictWidth) {
    XSDebug(true.B,
      p"instr ${Hexadecimal(io.out.instrs(i))}, " +
      p"validStart ${Binary(validStart(i))}, " +
      p"validEnd ${Binary(validEnd(i))}, " +
      p"pc ${Hexadecimal(io.out.pc(i))}, " +
      p"isRVC ${Binary(io.out.pd(i).isRVC)}, " +
      p"brType ${Binary(io.out.pd(i).brType)}, " +
      p"isRet ${Binary(io.out.pd(i).isRet)}, " +
      p"isCall ${Binary(io.out.pd(i).isCall)}\n"
    )
  }
}
246
/**
  * Expands a possibly-compressed instruction into an [[ExpandedInstruction]] using rocket-chip's RVCDecoder.
  * Without the C extension the decoder's `passthrough` view is used instead of `decode`.
  */
class RVCExpander(implicit p: Parameters) extends XSModule {
  val io = IO(new Bundle {
    val in = Input(UInt(32.W))
    val out = Output(new ExpandedInstruction)
  })

  private val decoder = new RVCDecoder(io.in, XLEN)
  io.out := (if (HasCExtension) decoder.decode else decoder.passthrough)
}
259