xref: /XiangShan/src/main/scala/xiangshan/frontend/RAS.scala (revision d0a8077aab2f476131c2b61d66fc00c69da8d153)
1/***************************************************************************************
2* Copyright (c) 2020-2021 Institute of Computing Technology, Chinese Academy of Sciences
3* Copyright (c) 2020-2021 Peng Cheng Laboratory
4*
5* XiangShan is licensed under Mulan PSL v2.
6* You can use this software according to the terms and conditions of the Mulan PSL v2.
7* You may obtain a copy of Mulan PSL v2 at:
8*          http://license.coscl.org.cn/MulanPSL2
9*
10* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
11* EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
12* MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
13*
14* See the Mulan PSL v2 for more details.
15***************************************************************************************/
16
17package xiangshan.frontend
18
19import chipsalliance.rocketchip.config.Parameters
20import chisel3._
21import chisel3.experimental.chiselName
22import chisel3.util._
23import utils._
24import xiangshan._
25
/**
  * One entry of the return address stack.
  *
  * Field order is kept as declared because it determines the bundle's bit layout.
  */
class RASEntry()(implicit p: Parameters) extends XSBundle {
  /** Return address held by this entry (`VAddrBits` wide). */
  val retAddr = UInt(VAddrBits.W)

  /** Layer of nested call functions sharing this entry's return address. */
  val ctr = UInt(8.W)
}
30
/**
  * Return address stack predictor.
  *
  * Wraps a single speculative [[RASStack]]. In stage 2 the stack is pushed on a taken call and
  * popped on a taken return, and the top-of-stack address overrides the jalr target of returns
  * when `io.ctrl.ras_enable` is set. The stack is restored either when stage 3 disagrees with the
  * push/pop performed in stage 2, or one cycle after an `io.redirect`.
  */
@chiselName
class RAS(implicit p: Parameters) extends BasePredictor {
  /** Factory that builds a combinational `RASEntry` wire from its two fields. */
  object RASEntry {
    def apply(retAddr: UInt, ctr: UInt): RASEntry = {
      val e = Wire(new RASEntry)
      e.retAddr := retAddr
      e.ctr := ctr
      e
    }
  }

  /**
    * The stack storage plus a `top` register that mirrors the entry at `topPtr`.
    *
    * Consecutive pushes of the same return address do not allocate a new slot; they increment the
    * `ctr` of the top entry instead (until it saturates). Memory writes are delayed by one cycle
    * through the `write_bypass_*` registers, and pops forward from those registers when they hit
    * the pending write.
    *
    * @param rasSize number of entries in the stack; every internal structure is sized from it
    */
  @chiselName
  class RASStack(val rasSize: Int) extends XSModule {
    val io = IO(new Bundle {
      // speculative (normal) operation
      val push_valid = Input(Bool())
      val pop_valid = Input(Bool())
      val spec_new_addr = Input(UInt(VAddrBits.W))

      // recovery: restore sp/top from a snapshot, optionally redoing a push or a pop
      val recover_sp = Input(UInt(log2Up(rasSize).W))
      val recover_top = Input(new RASEntry)
      val recover_valid = Input(Bool())
      val recover_push = Input(Bool())
      val recover_pop = Input(Bool())
      val recover_new_addr = Input(UInt(VAddrBits.W))

      val sp = Output(UInt(log2Up(rasSize).W))
      val top = Output(new RASEntry)
    })

    val debugIO = IO(new Bundle {
      val spec_push_entry = Output(new RASEntry)
      val spec_alloc_new = Output(Bool())
      val recover_push_entry = Output(new RASEntry)
      val recover_alloc_new = Output(Bool())
      val sp = Output(UInt(log2Up(rasSize).W))
      val topRegister = Output(new RASEntry)
      // sized from the constructor parameter (was the enclosing RasSize)
      val out_mem = Output(Vec(rasSize, new RASEntry))
    })

    // sized from the constructor parameter (was the enclosing RasSize)
    val stack = Mem(rasSize, new RASEntry)
    val sp = RegInit(0.U(log2Up(rasSize).W))
    val top = RegInit(RASEntry(0x80000000L.U, 0.U))
    val topPtr = RegInit(0.U(log2Up(rasSize).W))

    // Delayed write port: `wen` captures entry/ptr this cycle, the memory write happens in the
    // cycle where write_bypass_valid is set.
    val wen = WireInit(false.B)
    val write_bypass_entry = Reg(new RASEntry())
    val write_bypass_ptr = Reg(UInt(log2Up(rasSize).W))
    val write_bypass_valid = Reg(Bool())
    when (wen) {
      write_bypass_valid := true.B
    }.elsewhen (write_bypass_valid) {
      write_bypass_valid := false.B
    }

    when (write_bypass_valid) {
      stack(write_bypass_ptr) := write_bypass_entry
    }

    /** Pointer increment that wraps from `rasSize - 1` to 0. */
    def ptrInc(ptr: UInt) = Mux(ptr === (rasSize-1).U, 0.U, ptr + 1.U)
    /** Pointer decrement that wraps from 0 to `rasSize - 1`. */
    def ptrDec(ptr: UInt) = Mux(ptr === 0.U, (rasSize-1).U, ptr - 1.U)

    // A push allocates a new slot when the address differs from the top entry's address or the
    // top entry's counter is saturated (all ones).
    val spec_alloc_new = io.spec_new_addr =/= top.retAddr || top.ctr.andR
    val recover_alloc_new = io.recover_new_addr =/= io.recover_top.retAddr || io.recover_top.ctr.andR

    // TODO: fix overflow and underflow bugs
    /**
      * Applies one stack operation. Push has priority over pop.
      *
      * @param recover      true when the operands come from a recovery snapshot; sp/topPtr/top are
      *                     then overwritten from the operands even when no slot moves
      * @param do_push      perform a push
      * @param do_pop       perform a pop (ignored when `do_push` is set)
      * @param do_alloc_new a push must allocate a new slot instead of bumping the top counter
      * @param do_sp        stack pointer to operate on
      * @param do_top_ptr   index of the top entry to operate on
      * @param do_new_addr  return address being pushed
      * @param do_top       top entry to operate on
      */
    def update(recover: Bool)(do_push: Bool, do_pop: Bool, do_alloc_new: Bool,
                              do_sp: UInt, do_top_ptr: UInt, do_new_addr: UInt,
                              do_top: RASEntry) = {
      when (do_push) {
        when (do_alloc_new) {
          // new slot at do_sp with counter 0
          sp     := ptrInc(do_sp)
          topPtr := do_sp
          top.retAddr := do_new_addr
          top.ctr := 0.U
          // write bypass
          wen := true.B
          write_bypass_entry := RASEntry(do_new_addr, 0.U)
          write_bypass_ptr := do_sp
        }.otherwise {
          // same address as the top entry: only bump its counter
          when (recover) {
            sp := do_sp
            topPtr := do_top_ptr
            top.retAddr := do_top.retAddr
          }
          top.ctr := do_top.ctr + 1.U
          // write bypass
          wen := true.B
          write_bypass_entry := RASEntry(do_new_addr, do_top.ctr + 1.U)
          write_bypass_ptr := do_top_ptr
        }
      }.elsewhen (do_pop) {
        when (do_top.ctr === 0.U) {
          // top entry is used up: move down one slot and reload `top`
          sp     := ptrDec(do_sp)
          topPtr := ptrDec(do_top_ptr)
          // read bypass
          top :=
            Mux(ptrDec(do_top_ptr) === write_bypass_ptr && write_bypass_valid,
              write_bypass_entry,
              stack.read(ptrDec(do_top_ptr))
            )
        }.otherwise {
          // top entry still has nested calls left: only decrement its counter
          when (recover) {
            sp := do_sp
            topPtr := do_top_ptr
            top.retAddr := do_top.retAddr
          }
          top.ctr := do_top.ctr - 1.U
          // write bypass
          wen := true.B
          write_bypass_entry := RASEntry(do_top.retAddr, do_top.ctr - 1.U)
          write_bypass_ptr := do_top_ptr
        }
      }.otherwise {
        when (recover) {
          // plain restore of the snapshot, also written back to memory
          sp := do_sp
          topPtr := do_top_ptr
          top := do_top
          // write bypass
          wen := true.B
          write_bypass_entry := do_top
          write_bypass_ptr := do_top_ptr
        }
      }
    }


    // Recovery operands take precedence over the speculative ones.
    update(io.recover_valid)(
      Mux(io.recover_valid, io.recover_push,     io.push_valid),
      Mux(io.recover_valid, io.recover_pop,      io.pop_valid),
      Mux(io.recover_valid, recover_alloc_new,   spec_alloc_new),
      Mux(io.recover_valid, io.recover_sp,       sp),
      // The recovered top sits one slot below the recovered sp. ptrDec keeps the index inside
      // [0, rasSize) even when rasSize is not a power of two (a bare `- 1.U` would not).
      Mux(io.recover_valid, ptrDec(io.recover_sp), topPtr),
      Mux(io.recover_valid, io.recover_new_addr, io.spec_new_addr),
      Mux(io.recover_valid, io.recover_top,      top))

    io.sp := sp
    io.top := top

    // After reset, sweep the memory once and fill every slot with the initial entry.
    val resetIdx = RegInit(0.U(log2Up(rasSize).W))
    val do_reset = RegInit(true.B)
    when (do_reset) {
      stack.write(resetIdx, RASEntry(0x80000000L.U, 0.U))
    }
    resetIdx := resetIdx + do_reset
    when (resetIdx === (rasSize-1).U) {
      do_reset := false.B
    }

    // NOTE(review): these debug entries report ctr 1 on allocation while update() stores ctr 0
    // for a newly allocated entry -- confirm which one the debug log is meant to show.
    debugIO.spec_push_entry := RASEntry(io.spec_new_addr, Mux(spec_alloc_new, 1.U, top.ctr + 1.U))
    debugIO.spec_alloc_new := spec_alloc_new
    debugIO.recover_push_entry := RASEntry(io.recover_new_addr, Mux(recover_alloc_new, 1.U, io.recover_top.ctr + 1.U))
    debugIO.recover_alloc_new := recover_alloc_new
    debugIO.sp := sp
    debugIO.topRegister := top
    // memory view with the pending (not yet written) entry forwarded
    for (i <- 0 until rasSize) {
      debugIO.out_mem(i) := Mux(i.U === write_bypass_ptr && write_bypass_valid, write_bypass_entry, stack.read(i.U))
    }
  }

  val spec = Module(new RASStack(RasSize))
  val spec_ras = spec.io
  val spec_top_addr = spec_ras.top.retAddr


  // ---------------- stage 2: speculative push / pop ----------------
  val s2_spec_push = WireInit(false.B)
  val s2_spec_pop = WireInit(false.B)
  val s2_full_pred = io.in.bits.resp_in(0).s2.full_pred
  // when last inst is an rvi call, fall through address would be set to the middle of it, so an addition is needed
  val s2_spec_new_addr = s2_full_pred.fallThroughAddr + Mux(s2_full_pred.last_may_be_rvi_call, 2.U, 0.U)
  spec_ras.push_valid := s2_spec_push
  spec_ras.pop_valid  := s2_spec_pop
  spec_ras.spec_new_addr := s2_spec_new_addr

  // confirm that the call/ret is the taken cfi
  s2_spec_push := io.s2_fire && s2_full_pred.hit_taken_on_call && !io.s3_redirect
  s2_spec_pop  := io.s2_fire && s2_full_pred.hit_taken_on_ret  && !io.s3_redirect

  val s2_jalr_target = io.out.resp.s2.full_pred.jalr_target
  val s2_last_target_in = s2_full_pred.targets.last
  val s2_last_target_out = io.out.resp.s2.full_pred.targets.last
  val s2_is_jalr = s2_full_pred.is_jalr
  val s2_is_ret = s2_full_pred.is_ret
  // assert(is_jalr && is_ret || !is_ret)
  // A predicted return takes its target from the current top of the stack.
  when(s2_is_ret && io.ctrl.ras_enable) {
    s2_jalr_target := spec_top_addr
    // FIXME: should use s1 globally
  }
  s2_last_target_out := Mux(s2_is_jalr, s2_jalr_target, s2_last_target_in)

  // ---------------- stage 3: snapshot of the stack state seen in stage 2 ----------------
  val s3_top = RegEnable(spec_ras.top, io.s2_fire)
  val s3_sp = RegEnable(spec_ras.sp, io.s2_fire)
  val s3_spec_new_addr = RegEnable(s2_spec_new_addr, io.s2_fire)

  val s3_jalr_target = io.out.resp.s3.full_pred.jalr_target
  val s3_last_target_in = io.in.bits.resp_in(0).s3.full_pred.targets.last
  val s3_last_target_out = io.out.resp.s3.full_pred.targets.last
  val s3_is_jalr = io.in.bits.resp_in(0).s3.full_pred.is_jalr
  val s3_is_ret = io.in.bits.resp_in(0).s3.full_pred.is_ret
  // assert(is_jalr && is_ret || !is_ret)
  when(s3_is_ret && io.ctrl.ras_enable) {
    s3_jalr_target := s3_top.retAddr
    // FIXME: should use s1 globally
  }
  s3_last_target_out := Mux(s3_is_jalr, s3_jalr_target, s3_last_target_in)

  val s3_pushed_in_s2 = RegEnable(s2_spec_push, io.s2_fire)
  val s3_popped_in_s2 = RegEnable(s2_spec_pop,  io.s2_fire)
  val s3_push = io.in.bits.resp_in(0).s3.full_pred.hit_taken_on_call
  val s3_pop  = io.in.bits.resp_in(0).s3.full_pred.hit_taken_on_ret

  // Stage 3 disagrees with the push/pop that stage 2 actually performed.
  val s3_recover = io.s3_fire && (s3_pushed_in_s2 =/= s3_push || s3_popped_in_s2 =/= s3_pop)
  io.out.resp.s3.rasSp  := s3_sp
  io.out.resp.s3.rasTop := s3_top


  // ---------------- recovery (registered redirect has priority over s3) ----------------
  val redirect = RegNext(io.redirect)
  val do_recover = redirect.valid || s3_recover
  val recover_cfi = redirect.bits.cfiUpdate

  val retMissPred  = do_recover && redirect.bits.level === 0.U && recover_cfi.pd.isRet
  val callMissPred = do_recover && redirect.bits.level === 0.U && recover_cfi.pd.isCall
  // when we mispredict a call, we must redo a push operation
  // similarly, when we mispredict a return, we should redo a pop
  spec_ras.recover_valid := do_recover
  spec_ras.recover_push := Mux(redirect.valid, callMissPred, s3_push)
  spec_ras.recover_pop  := Mux(redirect.valid, retMissPred, s3_pop)

  spec_ras.recover_sp  := Mux(redirect.valid, recover_cfi.rasSp, s3_sp)
  spec_ras.recover_top := Mux(redirect.valid, recover_cfi.rasEntry, s3_top)
  // on a redirect the pushed address is the one following the instruction at recover_cfi.pc
  spec_ras.recover_new_addr := Mux(redirect.valid, recover_cfi.pc + Mux(recover_cfi.pd.isRVC, 2.U, 4.U), s3_spec_new_addr)


  XSPerfAccumulate("ras_s3_recover", s3_recover)
  XSPerfAccumulate("ras_redirect_recover", redirect.valid)
  XSPerfAccumulate("ras_s3_and_redirect_recover_at_the_same_time", s3_recover && redirect.valid)
  // TODO: back-up stack for ras
  // use checkpoint to recover RAS

  // ---------------- debug dump ----------------
  val spec_debug = spec.debugIO
  XSDebug("----------------RAS----------------\n")
  XSDebug(" TopRegister: 0x%x   %d \n",spec_debug.topRegister.retAddr,spec_debug.topRegister.ctr)
  XSDebug("  index       addr           ctr \n")
  for(i <- 0 until RasSize){
    XSDebug("  (%d)   0x%x      %d",i.U,spec_debug.out_mem(i).retAddr,spec_debug.out_mem(i).ctr)
    when(i.U === spec_debug.sp){XSDebug(false,true.B,"   <----sp")}
    XSDebug(false,true.B,"\n")
  }
  XSDebug(s2_spec_push, "s2_spec_push  inAddr: 0x%x  inCtr: %d |  allocNewEntry:%d |   sp:%d \n",
  s2_spec_new_addr,spec_debug.spec_push_entry.ctr,spec_debug.spec_alloc_new,spec_debug.sp.asUInt)
  XSDebug(s2_spec_pop, "s2_spec_pop  outAddr: 0x%x \n",io.out.resp.s2.getTarget)
  val s3_recover_entry = spec_debug.recover_push_entry
  XSDebug(s3_recover && s3_push, "s3_recover_push  inAddr: 0x%x  inCtr: %d |  allocNewEntry:%d |   sp:%d \n",
    s3_recover_entry.retAddr, s3_recover_entry.ctr, spec_debug.recover_alloc_new, s3_sp.asUInt)
  XSDebug(s3_recover && s3_pop, "s3_recover_pop  outAddr: 0x%x \n",io.out.resp.s3.getTarget)
  val redirectUpdate = redirect.bits.cfiUpdate
  XSDebug(do_recover && callMissPred, "redirect_recover_push\n")
  XSDebug(do_recover && retMissPred, "redirect_recover_pop\n")
  XSDebug(do_recover, "redirect_recover(SP:%d retAddr:%x ctr:%d) \n",
      redirectUpdate.rasSp,redirectUpdate.rasEntry.retAddr,redirectUpdate.rasEntry.ctr)

  generatePerfEvent()
}
293