xref: /XiangShan/src/main/scala/xiangshan/frontend/RAS.scala (revision 8088cde17e46914b7de1cfc4b49b14cb20840c05)
1/***************************************************************************************
2* Copyright (c) 2020-2021 Institute of Computing Technology, Chinese Academy of Sciences
3* Copyright (c) 2020-2021 Peng Cheng Laboratory
4*
5* XiangShan is licensed under Mulan PSL v2.
6* You can use this software according to the terms and conditions of the Mulan PSL v2.
7* You may obtain a copy of Mulan PSL v2 at:
8*          http://license.coscl.org.cn/MulanPSL2
9*
10* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
11* EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
12* MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
13*
14* See the Mulan PSL v2 for more details.
15***************************************************************************************/
16
17package xiangshan.frontend
18
19import chipsalliance.rocketchip.config.Parameters
20import chisel3._
21import chisel3.experimental.chiselName
22import chisel3.util._
23import utils._
24import utility._
25import xiangshan._
26
/** One slot of the return address stack.
  *
  * The field order is part of the bundle layout and is kept exactly as declared.
  */
class RASEntry()(implicit p: Parameters) extends XSBundle {
  // Return address held by this slot.
  val retAddr = UInt(VAddrBits.W)
  // Layer of nested call functions (8-bit counter).
  val ctr = UInt(8.W)
}
31
32@chiselName
33class RAS(implicit p: Parameters) extends BasePredictor {
34  object RASEntry {
35    def apply(retAddr: UInt, ctr: UInt): RASEntry = {
36      val e = Wire(new RASEntry)
37      e.retAddr := retAddr
38      e.ctr := ctr
39      e
40    }
41  }
42
43  @chiselName
44  class RASStack(val rasSize: Int) extends XSModule {
45    val io = IO(new Bundle {
46      val push_valid = Input(Bool())
47      val pop_valid = Input(Bool())
48      val spec_new_addr = Input(UInt(VAddrBits.W))
49
50      val recover_sp = Input(UInt(log2Up(rasSize).W))
51      val recover_top = Input(new RASEntry)
52      val recover_valid = Input(Bool())
53      val recover_push = Input(Bool())
54      val recover_pop = Input(Bool())
55      val recover_new_addr = Input(UInt(VAddrBits.W))
56
57      val sp = Output(UInt(log2Up(rasSize).W))
58      val top = Output(new RASEntry)
59    })
60
61    val debugIO = IO(new Bundle{
62        val spec_push_entry = Output(new RASEntry)
63        val spec_alloc_new = Output(Bool())
64        val recover_push_entry = Output(new RASEntry)
65        val recover_alloc_new = Output(Bool())
66        val sp = Output(UInt(log2Up(rasSize).W))
67        val topRegister = Output(new RASEntry)
68        val out_mem = Output(Vec(RasSize, new RASEntry))
69    })
70
71    val stack = Mem(RasSize, new RASEntry)
72    val sp = RegInit(0.U(log2Up(rasSize).W))
73    val top = RegInit(0.U.asTypeOf(new RASEntry()))
74    val topPtr = RegInit(0.U(log2Up(rasSize).W))
75
76    val wen = WireInit(false.B)
77    val write_bypass_entry = RegInit(0.U.asTypeOf(new RASEntry()))
78    val write_bypass_ptr = RegInit(0.U(log2Up(rasSize).W))
79    val write_bypass_valid = RegInit(false.B)
80    when (wen) {
81      write_bypass_valid := true.B
82    }.elsewhen (write_bypass_valid) {
83      write_bypass_valid := false.B
84    }
85
86    when (write_bypass_valid) {
87      stack(write_bypass_ptr) := write_bypass_entry
88    }
89
90    def ptrInc(ptr: UInt) = Mux(ptr === (rasSize-1).U, 0.U, ptr + 1.U)
91    def ptrDec(ptr: UInt) = Mux(ptr === 0.U, (rasSize-1).U, ptr - 1.U)
92
93    val spec_alloc_new = io.spec_new_addr =/= top.retAddr || top.ctr.andR
94    val recover_alloc_new = io.recover_new_addr =/= io.recover_top.retAddr || io.recover_top.ctr.andR
95
96    // TODO: fix overflow and underflow bugs
97    def update(recover: Bool)(do_push: Bool, do_pop: Bool, do_alloc_new: Bool,
98                              do_sp: UInt, do_top_ptr: UInt, do_new_addr: UInt,
99                              do_top: RASEntry) = {
100      when (do_push) {
101        when (do_alloc_new) {
102          sp     := ptrInc(do_sp)
103          topPtr := do_sp
104          top.retAddr := do_new_addr
105          top.ctr := 0.U
106          // write bypass
107          wen := true.B
108          write_bypass_entry := RASEntry(do_new_addr, 0.U)
109          write_bypass_ptr := do_sp
110        }.otherwise {
111          when (recover) {
112            sp := do_sp
113            topPtr := do_top_ptr
114            top.retAddr := do_top.retAddr
115          }
116          top.ctr := do_top.ctr + 1.U
117          // write bypass
118          wen := true.B
119          write_bypass_entry := RASEntry(do_new_addr, do_top.ctr + 1.U)
120          write_bypass_ptr := do_top_ptr
121        }
122      }.elsewhen (do_pop) {
123        when (do_top.ctr === 0.U) {
124          sp     := ptrDec(do_sp)
125          topPtr := ptrDec(do_top_ptr)
126          // read bypass
127          top :=
128            Mux(ptrDec(do_top_ptr) === write_bypass_ptr && write_bypass_valid,
129              write_bypass_entry,
130              stack.read(ptrDec(do_top_ptr))
131            )
132        }.otherwise {
133          when (recover) {
134            sp := do_sp
135            topPtr := do_top_ptr
136            top.retAddr := do_top.retAddr
137          }
138          top.ctr := do_top.ctr - 1.U
139          // write bypass
140          wen := true.B
141          write_bypass_entry := RASEntry(do_top.retAddr, do_top.ctr - 1.U)
142          write_bypass_ptr := do_top_ptr
143        }
144      }.otherwise {
145        when (recover) {
146          sp := do_sp
147          topPtr := do_top_ptr
148          top := do_top
149          // write bypass
150          wen := true.B
151          write_bypass_entry := do_top
152          write_bypass_ptr := do_top_ptr
153        }
154      }
155    }
156
157
158    update(io.recover_valid)(
159      Mux(io.recover_valid, io.recover_push,     io.push_valid),
160      Mux(io.recover_valid, io.recover_pop,      io.pop_valid),
161      Mux(io.recover_valid, recover_alloc_new,   spec_alloc_new),
162      Mux(io.recover_valid, io.recover_sp,       sp),
163      Mux(io.recover_valid, io.recover_sp - 1.U, topPtr),
164      Mux(io.recover_valid, io.recover_new_addr, io.spec_new_addr),
165      Mux(io.recover_valid, io.recover_top,      top))
166
167    io.sp := sp
168    io.top := top
169
170    val resetIdx = RegInit(0.U(log2Ceil(RasSize).W))
171    val do_reset = RegInit(true.B)
172    when (do_reset) {
173      stack.write(resetIdx, RASEntry(0x80000000L.U, 0.U))
174    }
175    resetIdx := resetIdx + do_reset
176    when (resetIdx === (RasSize-1).U) {
177      do_reset := false.B
178    }
179
180    debugIO.spec_push_entry := RASEntry(io.spec_new_addr, Mux(spec_alloc_new, 1.U, top.ctr + 1.U))
181    debugIO.spec_alloc_new := spec_alloc_new
182    debugIO.recover_push_entry := RASEntry(io.recover_new_addr, Mux(recover_alloc_new, 1.U, io.recover_top.ctr + 1.U))
183    debugIO.recover_alloc_new := recover_alloc_new
184    debugIO.sp := sp
185    debugIO.topRegister := top
186    for (i <- 0 until RasSize) {
187        debugIO.out_mem(i) := Mux(i.U === write_bypass_ptr && write_bypass_valid, write_bypass_entry, stack.read(i.U))
188    }
189  }
190
191  val spec = Module(new RASStack(RasSize))
192  val spec_ras = spec.io
193  val spec_top_addr = spec_ras.top.retAddr
194
195
196  val s2_spec_push = WireInit(false.B)
197  val s2_spec_pop = WireInit(false.B)
198  val s2_full_pred = io.in.bits.resp_in(0).s2.full_pred
199  // when last inst is an rvi call, fall through address would be set to the middle of it, so an addition is needed
200  val s2_spec_new_addr = s2_full_pred.fallThroughAddr + Mux(s2_full_pred.last_may_be_rvi_call, 2.U, 0.U)
201  spec_ras.push_valid := s2_spec_push
202  spec_ras.pop_valid  := s2_spec_pop
203  spec_ras.spec_new_addr := s2_spec_new_addr
204
205  // confirm that the call/ret is the taken cfi
206  s2_spec_push := io.s2_fire && s2_full_pred.hit_taken_on_call && !io.s3_redirect
207  s2_spec_pop  := io.s2_fire && s2_full_pred.hit_taken_on_ret  && !io.s3_redirect
208
209  val s2_jalr_target = io.out.s2.full_pred.jalr_target
210  val s2_last_target_in = s2_full_pred.targets.last
211  val s2_last_target_out = io.out.s2.full_pred.targets.last
212  val s2_is_jalr = s2_full_pred.is_jalr
213  val s2_is_ret = s2_full_pred.is_ret
214  // assert(is_jalr && is_ret || !is_ret)
215  when(s2_is_ret && io.ctrl.ras_enable) {
216    s2_jalr_target := spec_top_addr
217    // FIXME: should use s1 globally
218  }
219  s2_last_target_out := Mux(s2_is_jalr, s2_jalr_target, s2_last_target_in)
220
221  val s3_top = RegEnable(spec_ras.top, io.s2_fire)
222  val s3_sp = RegEnable(spec_ras.sp, io.s2_fire)
223  val s3_spec_new_addr = RegEnable(s2_spec_new_addr, io.s2_fire)
224
225  val s3_jalr_target = io.out.s3.full_pred.jalr_target
226  val s3_last_target_in = io.in.bits.resp_in(0).s3.full_pred.targets.last
227  val s3_last_target_out = io.out.s3.full_pred.targets.last
228  val s3_is_jalr = io.in.bits.resp_in(0).s3.full_pred.is_jalr
229  val s3_is_ret = io.in.bits.resp_in(0).s3.full_pred.is_ret
230  // assert(is_jalr && is_ret || !is_ret)
231  when(s3_is_ret && io.ctrl.ras_enable) {
232    s3_jalr_target := s3_top.retAddr
233    // FIXME: should use s1 globally
234  }
235  s3_last_target_out := Mux(s3_is_jalr, s3_jalr_target, s3_last_target_in)
236
237  val s3_pushed_in_s2 = RegEnable(s2_spec_push, io.s2_fire)
238  val s3_popped_in_s2 = RegEnable(s2_spec_pop,  io.s2_fire)
239  val s3_push = io.in.bits.resp_in(0).s3.full_pred.hit_taken_on_call
240  val s3_pop  = io.in.bits.resp_in(0).s3.full_pred.hit_taken_on_ret
241
242  val s3_recover = io.s3_fire && (s3_pushed_in_s2 =/= s3_push || s3_popped_in_s2 =/= s3_pop)
243  io.out.last_stage_spec_info.rasSp  := s3_sp
244  io.out.last_stage_spec_info.rasTop := s3_top
245
246
247  val redirect = RegNext(io.redirect)
248  val do_recover = redirect.valid || s3_recover
249  val recover_cfi = redirect.bits.cfiUpdate
250
251  val retMissPred  = do_recover && redirect.bits.level === 0.U && recover_cfi.pd.isRet
252  val callMissPred = do_recover && redirect.bits.level === 0.U && recover_cfi.pd.isCall
253  // when we mispredict a call, we must redo a push operation
254  // similarly, when we mispredict a return, we should redo a pop
255  spec_ras.recover_valid := do_recover
256  spec_ras.recover_push := Mux(redirect.valid, callMissPred, s3_push)
257  spec_ras.recover_pop  := Mux(redirect.valid, retMissPred, s3_pop)
258
259  spec_ras.recover_sp  := Mux(redirect.valid, recover_cfi.rasSp, s3_sp)
260  spec_ras.recover_top := Mux(redirect.valid, recover_cfi.rasEntry, s3_top)
261  spec_ras.recover_new_addr := Mux(redirect.valid, recover_cfi.pc + Mux(recover_cfi.pd.isRVC, 2.U, 4.U), s3_spec_new_addr)
262
263
264  XSPerfAccumulate("ras_s3_recover", s3_recover)
265  XSPerfAccumulate("ras_redirect_recover", redirect.valid)
266  XSPerfAccumulate("ras_s3_and_redirect_recover_at_the_same_time", s3_recover && redirect.valid)
267  // TODO: back-up stack for ras
268  // use checkpoint to recover RAS
269
270  val spec_debug = spec.debugIO
271  XSDebug("----------------RAS----------------\n")
272  XSDebug(" TopRegister: 0x%x   %d \n",spec_debug.topRegister.retAddr,spec_debug.topRegister.ctr)
273  XSDebug("  index       addr           ctr \n")
274  for(i <- 0 until RasSize){
275      XSDebug("  (%d)   0x%x      %d",i.U,spec_debug.out_mem(i).retAddr,spec_debug.out_mem(i).ctr)
276      when(i.U === spec_debug.sp){XSDebug(false,true.B,"   <----sp")}
277      XSDebug(false,true.B,"\n")
278  }
279  XSDebug(s2_spec_push, "s2_spec_push  inAddr: 0x%x  inCtr: %d |  allocNewEntry:%d |   sp:%d \n",
280  s2_spec_new_addr,spec_debug.spec_push_entry.ctr,spec_debug.spec_alloc_new,spec_debug.sp.asUInt)
281  XSDebug(s2_spec_pop, "s2_spec_pop  outAddr: 0x%x \n",io.out.s2.getTarget)
282  val s3_recover_entry = spec_debug.recover_push_entry
283  XSDebug(s3_recover && s3_push, "s3_recover_push  inAddr: 0x%x  inCtr: %d |  allocNewEntry:%d |   sp:%d \n",
284    s3_recover_entry.retAddr, s3_recover_entry.ctr, spec_debug.recover_alloc_new, s3_sp.asUInt)
285  XSDebug(s3_recover && s3_pop, "s3_recover_pop  outAddr: 0x%x \n",io.out.s3.getTarget)
286  val redirectUpdate = redirect.bits.cfiUpdate
287  XSDebug(do_recover && callMissPred, "redirect_recover_push\n")
288  XSDebug(do_recover && retMissPred, "redirect_recover_pop\n")
289  XSDebug(do_recover, "redirect_recover(SP:%d retAddr:%x ctr:%d) \n",
290      redirectUpdate.rasSp,redirectUpdate.rasEntry.retAddr,redirectUpdate.rasEntry.ctr)
291
292  generatePerfEvent()
293}
294