xref: /XiangShan/src/main/scala/xiangshan/backend/datapath/WbArbiter.scala (revision ac4d321d18df4775b9ddda83e77cf526a0b1ca67)
1package xiangshan.backend.datapath
2
3import org.chipsalliance.cde.config.Parameters
4import chisel3._
5import chisel3.util._
6import difftest.{DiffFpWriteback, DiffIntWriteback, DifftestModule}
7import utils.XSError
8import xiangshan.backend.BackendParams
9import xiangshan.backend.Bundles.{ExuOutput, WriteBackBundle}
10import xiangshan.backend.datapath.DataConfig.{IntData, VecData}
11import xiangshan.backend.regfile.RfWritePortWithConfig
12import xiangshan.{Redirect, XSBundle, XSModule}
13
/** Interface of [[WbArbiterDispatcher]]: a single decoupled request fanned out to `n` decoupled ports.
  *
  * @param gen payload type carried by the input and by every output
  * @param n   number of output ports
  */
class WbArbiterDispatcherIO[T <: Data](private val gen: T, n: Int) extends Bundle {
  /** Incoming request (flipped, so this side receives `valid`/`bits` and drives `ready`). */
  val in: DecoupledIO[T] = Flipped(DecoupledIO(gen))

  /** One outgoing request port per possible destination. */
  val out: Vec[DecoupledIO[T]] = Vec(n, DecoupledIO(gen))
}
19
/** Dispatches one decoupled request to (at most) one of `n` decoupled outputs.
  *
  * `acceptCond(io.in.bits)` returns `(selects, acceptWithoutOutput)`:
  *  - `io.out(i).valid` is `io.in.valid && selects(i)`; every output carries `io.in.bits`;
  *  - `io.in.ready` is high when some selected output is ready, or when `acceptWithoutOutput` is set.
  * A valid request that selects more than one output is reported through `XSError`.
  *
  * @param gen        payload type of the input and of every output
  * @param n          number of outputs; must equal the number of selects produced by `acceptCond`
  * @param acceptCond routing function evaluated on the input payload
  */
class WbArbiterDispatcher[T <: Data](private val gen: T, n: Int, acceptCond: T => (Seq[Bool], Bool))
                           (implicit p: Parameters)
  extends Module {

  val io = IO(new WbArbiterDispatcherIO(gen, n))

  // Evaluate the routing function once so both results come from the same hardware.
  private val accept: (Seq[Bool], Bool) = acceptCond(io.in.bits)
  require(accept._1.length == n,
    s"WbArbiterDispatcher: acceptCond yields ${accept._1.length} selects for $n outputs")

  private val acceptVec: Vec[Bool] = VecInit(accept._1)
  private val acceptWithoutOutput: Bool = accept._2

  // Pass the select vector as a format argument so its runtime value is printed; interpolating
  // Binary(...) into an s-string would only embed the object's elaboration-time toString.
  XSError(io.in.valid && PopCount(acceptVec) > 1.U, "[ExeUnit] accept vec should no more than 1, %b\n", acceptVec.asUInt)

  io.out.zipWithIndex.foreach { case (out, i) =>
    out.valid := acceptVec(i) && io.in.valid
    out.bits := io.in.bits
  }

  io.in.ready := Cat(io.out.zip(acceptVec).map { case (out, canAccept) => out.ready && canAccept }).orR || acceptWithoutOutput
}
37
/** Interface of [[WbArbiter]]; the shapes of `in` and `out` come from the implicit [[WbArbiterParams]]. */
class WbArbiterIO()(implicit p: Parameters, params: WbArbiterParams) extends XSBundle {
  val flush = Flipped(ValidIO(new Redirect))
  val in: MixedVec[DecoupledIO[WriteBackBundle]] = Flipped(params.genInput)
  val out: MixedVec[ValidIO[WriteBackBundle]] = params.genOutput

  /** Inputs grouped by their target write port; each group is sorted by ascending `priority`. */
  def inGroup: Map[Int, Seq[DecoupledIO[WriteBackBundle]]] =
    in.groupBy(_.bits.params.port).map { case (port, requesters) =>
      port -> requesters.sortBy(_.bits.params.priority).toSeq
    }
}
45
/** Arbitrates write-back requests onto register-file write ports.
  *
  * Inputs are grouped by their configured write port (`io.inGroup`, each group sorted by `priority`). One
  * [[RealWBArbiter]] is instantiated per output port that has at least one requester. Its output is always
  * accepted (`ready := true.B`) and copied to `io.out(port)`. Ports without requesters are driven to zero.
  * `io.flush` is not used inside this module.
  */
class WbArbiter(params: WbArbiterParams)(implicit p: Parameters) extends XSModule {
  val io = IO(new WbArbiterIO()(p, params))

  private val inGroup: Map[Int, Seq[DecoupledIO[WriteBackBundle]]] = io.inGroup

  // Build each arbiter's payload type from the head of its *own* group rather than from an
  // arbitrary entry of the (unordered) group map.
  private val arbiters: Seq[Option[RealWBArbiter[WriteBackBundle]]] = Seq.tabulate(params.numOut) { port =>
    inGroup.get(port).map { group =>
      Module(new RealWBArbiter(new WriteBackBundle(group.head.bits.params, backendParams), group.length))
    }
  }

  arbiters.zipWithIndex.foreach { case (arbOpt, port) =>
    arbOpt.foreach { arb =>
      arb.io.in.zip(inGroup(port)).foreach { case (arbIn, wbIn) =>
        arbIn <> wbIn
      }
    }
  }

  io.out.zip(arbiters).foreach { case (wbOut, arbOpt) =>
    arbOpt match {
      case Some(arb) =>
        val arbOut = arb.io.out
        arbOut.ready := true.B
        wbOut.valid := arbOut.valid
        wbOut.bits := arbOut.bits
      case None =>
        wbOut := 0.U.asTypeOf(wbOut)
    }
  }

  /** Maps each write-back config index to the write port it targets. */
  def getInOutMap: Map[Int, Int] = {
    (params.wbCfgs.indices zip params.wbCfgs.map(_.port)).toMap
  }
}
82
/** Interface of [[WbDataPath]]; every shape is derived from the implicit [[BackendParams]]. */
class WbDataPathIO()(implicit p: Parameters, params: BackendParams) extends XSBundle {
  /** Flush request, as a valid-qualified [[Redirect]]. */
  val flush = Flipped(ValidIO(new Redirect()))

  /** Signals from the top level. */
  val fromTop = new Bundle {
    // 8-bit hart identifier
    val hartId = Input(UInt(8.W))
  }

  /** Decoupled EXU outputs of the int scheduler (flipped). */
  val fromIntExu: MixedVec[MixedVec[DecoupledIO[ExuOutput]]] =
    Flipped(params.intSchdParams.get.genExuOutputDecoupledBundle)

  /** Decoupled EXU outputs of the vf scheduler (flipped). */
  val fromVfExu: MixedVec[MixedVec[DecoupledIO[ExuOutput]]] =
    Flipped(params.vfSchdParams.get.genExuOutputDecoupledBundle)

  /** Decoupled EXU outputs of the mem scheduler (flipped). */
  val fromMemExu: MixedVec[MixedVec[DecoupledIO[ExuOutput]]] =
    Flipped(params.memSchdParams.get.genExuOutputDecoupledBundle)

  /** Integer physical register file write ports, `params.numPregWb(IntData())` of them. */
  val toIntPreg = Flipped(
    MixedVec(
      Vec(
        params.numPregWb(IntData()),
        new RfWritePortWithConfig(params.intPregParams.dataCfg, params.intPregParams.addrWidth)
      )
    )
  )

  /** Vf physical register file write ports, `params.numPregWb(VecData())` of them. */
  val toVfPreg = Flipped(
    MixedVec(
      Vec(
        params.numPregWb(VecData()),
        new RfWritePortWithConfig(params.vfPregParams.dataCfg, params.vfPregParams.addrWidth)
      )
    )
  )

  /** Write-back information forwarded to the control block. */
  val toCtrlBlock = new Bundle {
    val writeback: MixedVec[ValidIO[ExuOutput]] = params.genWrite2CtrlBundles
  }
}
106
/** Write-back data path of the backend.
  *
  * As built below:
  *  1. All EXU outputs are flattened into `fromExu`; the single vector-load EXU output is first routed through
  *     [[VldMergeUnit]] and replaced by its merged output.
  *  2. EXUs with `writeIntRf` (`writeVfRf`) feed `intWbArbiter` (`vfWbArbiter`), whose outputs become the
  *     register-file write ports `io.toIntPreg` (`io.toVfPreg`).
  *  3. EXU outputs that fire are forwarded to `io.toCtrlBlock.writeback`.
  *  4. When difftest is enabled, every fired arbiter output is reported to difftest.
  *
  * NOTE: `ready` of each `fromExu` element is assigned several times in this body (dispatcher, the `*N`
  * false-defaults, the int/vf arbiter loops, and the final always-ready loop). Chisel's last-connect
  * semantics make the final applicable assignment win, so the order of these statements is significant.
  */
class WbDataPath(params: BackendParams)(implicit p: Parameters) extends XSModule {
  val io = IO(new WbDataPathIO()(p, params))

  // split: flatten every EXU output, then swap the single vector-load output for the VldMergeUnit output
  val fromExuPre = (io.fromIntExu ++ io.fromVfExu ++ io.fromMemExu).flatten
  val fromExuVld: Seq[DecoupledIO[ExuOutput]] = fromExuPre.filter(_.bits.params.hasVLoadFu).toSeq
  require(fromExuVld.size == 1, "vldCnt should be 1")
  val vldMgu = Module(new VldMergeUnit(fromExuVld.head.bits.params))
  vldMgu.io.flush := io.flush
  vldMgu.io.writeback <> fromExuVld.head
  val wbReplaceVld: Seq[DecoupledIO[ExuOutput]] = fromExuPre.updated(fromExuPre.indexWhere(_.bits.params.hasVLoadFu), vldMgu.io.writebackAfterMerge).toSeq
  // MixedVecInit creates a new wire initialised from wbReplaceVld; `fromExu` is that wire.
  val fromExu: MixedVec[DecoupledIO[ExuOutput]] = MixedVecInit(wbReplaceVld)

  // io.fromExuPre ------------------------------------------------------------> fromExu
  //               \                                                         /
  //                -> vldMgu.io.writeback -> vldMgu.io.writebackAfterMerge /
  // Propagate the (finally resolved) fromExu ready back to each source.
  (fromExu zip wbReplaceVld).foreach { case (sink, source) => source.ready := sink.ready }

  // alias
  // `filter`/`filterNot` return the same hardware objects, so the *Y / *N sequences alias `fromExu` elements.
  // NOTE(review): the loops below index the wires by exuIdx, which assumes exuIdx equals each EXU's
  // position in `fromExu` -- confirm.
  val intArbiterInputsWireY = fromExu.filter(_.bits.params.writeIntRf)
  val intArbiterInputsWireN = fromExu.filterNot(_.bits.params.writeIntRf)
  val intArbiterInputsWire = Wire(chiselTypeOf(fromExu))
  intArbiterInputsWire.foreach{ x =>
    val id = x.bits.params.exuIdx
    val indexY = intArbiterInputsWireY.map(_.bits.params.exuIdx).indexOf(id)
    val indexN = intArbiterInputsWireN.map(_.bits.params.exuIdx).indexOf(id)
    if (indexY > -1) intArbiterInputsWire(id) := intArbiterInputsWireY(indexY)
    else if(indexN > -1) intArbiterInputsWire(id) := intArbiterInputsWireN(indexN)
    else assert(false, "intArbiterInputsWire not in intArbiterInputsWireY or intArbiterInputsWireN")
  }
  val vfArbiterInputsWireY = fromExu.filter(_.bits.params.writeVfRf)
  val vfArbiterInputsWireN = fromExu.filterNot(_.bits.params.writeVfRf)
  val vfArbiterInputsWire = WireInit(fromExu)
  vfArbiterInputsWire.foreach { x =>
    val id = x.bits.params.exuIdx
    val indexY = vfArbiterInputsWireY.map(_.bits.params.exuIdx).indexOf(id)
    val indexN = vfArbiterInputsWireN.map(_.bits.params.exuIdx).indexOf(id)
    if (indexY > -1) vfArbiterInputsWire(id) := vfArbiterInputsWireY(indexY)
    else if (indexN > -1) vfArbiterInputsWire(id) := vfArbiterInputsWireN(indexN)
    else assert(false, "vfArbiterInputsWire not in vfArbiterInputsWireY or vfArbiterInputsWireN")
  }

  /** Routing function for [[WbArbiterDispatcher]].
    *
    * @return (Seq(select int side, select fp/vec side), no register file written at all)
    */
  def acceptCond(exuOutput: ExuOutput): (Seq[Bool], Bool) = {
    val intWen = if(exuOutput.intWen.isDefined) exuOutput.intWen.get else false.B
    val fpwen  = if(exuOutput.fpWen.isDefined) exuOutput.fpWen.get else false.B
    val vecWen = if(exuOutput.vecWen.isDefined) exuOutput.vecWen.get else false.B
    (Seq(intWen, fpwen || vecWen), !intWen && !fpwen && !vecWen)
  }

  // Number of real (non-fake) register files; loop invariant for the dispatcher instances below.
  private val regfilesTypeNum = params.pregParams.filterNot(_.isFake).size

  fromExu.zip(intArbiterInputsWire.zip(vfArbiterInputsWire)).foreach {
    case (exuOut, (intArbiterInput, vfArbiterInput)) =>
      val in1ToN = Module(new WbArbiterDispatcher(new ExuOutput(exuOut.bits.params), regfilesTypeNum, acceptCond))
      in1ToN.io.in.valid := exuOut.valid
      in1ToN.io.in.bits := exuOut.bits
      exuOut.ready := in1ToN.io.in.ready
      // NOTE(review): MixedVecInit builds a fresh wire copy of (intArbiterInput, vfArbiterInput), so the valid/bits
      // written here land in that copy, which nothing reads; intArbiterInputsWire/vfArbiterInputsWire are not read
      // after this loop either (the arbiters below use the *Y aliases of fromExu). Confirm this is intended.
      in1ToN.io.out.zip(MixedVecInit(intArbiterInput, vfArbiterInput)).foreach { case (source, sink) =>
        sink.valid := source.valid
        sink.bits := source.bits
        source.ready := sink.ready
      }
  }
  // These alias fromExu elements: override the dispatcher-driven ready for EXUs that do not write the
  // corresponding register file (later connections may override it again).
  intArbiterInputsWireN.foreach(_.ready := false.B)
  vfArbiterInputsWireN.foreach(_.ready := false.B)

  println(s"[WbDataPath] write int preg: " +
    s"IntExu(${io.fromIntExu.flatten.count(_.bits.params.writeIntRf)}) " +
    s"VfExu(${io.fromVfExu.flatten.count(_.bits.params.writeIntRf)}) " +
    s"MemExu(${io.fromMemExu.flatten.count(_.bits.params.writeIntRf)})"
  )
  println(s"[WbDataPath] write vf preg: " +
    s"IntExu(${io.fromIntExu.flatten.count(_.bits.params.writeVfRf)}) " +
    s"VfExu(${io.fromVfExu.flatten.count(_.bits.params.writeVfRf)}) " +
    s"MemExu(${io.fromMemExu.flatten.count(_.bits.params.writeVfRf)})"
  )

  // modules
  private val intWbArbiter = Module(new WbArbiter(params.getIntWbArbiterParams))
  private val vfWbArbiter = Module(new WbArbiter(params.getVfWbArbiterParams))
  println(s"[WbDataPath] int preg write back port num: ${intWbArbiter.io.out.size}, active port: ${intWbArbiter.io.inGroup.keys.toSeq.sorted}")
  println(s"[WbDataPath] vf preg write back port num: ${vfWbArbiter.io.out.size}, active port: ${vfWbArbiter.io.inGroup.keys.toSeq.sorted}")

  // module assign
  intWbArbiter.io.flush <> io.flush
  require(intWbArbiter.io.in.size == intArbiterInputsWireY.size, s"intWbArbiter input size: ${intWbArbiter.io.in.size}, all int wb size: ${intArbiterInputsWireY.size}")
  intWbArbiter.io.in.zip(intArbiterInputsWireY).foreach { case (arbiterIn, in) =>
    // Only requests that actually write the int register file compete for an int port.
    arbiterIn.valid := in.valid && in.bits.intWen.get
    in.ready := arbiterIn.ready
    arbiterIn.bits.fromExuOutput(in.bits)
  }
  private val intWbArbiterOut = intWbArbiter.io.out

  vfWbArbiter.io.flush <> io.flush
  require(vfWbArbiter.io.in.size == vfArbiterInputsWireY.size, s"vfWbArbiter input size: ${vfWbArbiter.io.in.size}, all vf wb size: ${vfArbiterInputsWireY.size}")
  vfWbArbiter.io.in.zip(vfArbiterInputsWireY).foreach { case (arbiterIn, in) =>
    // Only requests that write the fp or vec register file compete for a vf port.
    arbiterIn.valid := in.valid && (in.bits.fpWen.getOrElse(false.B) || in.bits.vecWen.getOrElse(false.B))
    in.ready := arbiterIn.ready
    arbiterIn.bits.fromExuOutput(in.bits)
  }

  private val vfWbArbiterOut = vfWbArbiter.io.out

  private val intExuInputs = io.fromIntExu.flatten.toSeq
  private val intExuWBs = WireInit(MixedVecInit(intExuInputs))
  private val vfExuInputs = io.fromVfExu.flatten.toSeq
  private val vfExuWBs = WireInit(MixedVecInit(vfExuInputs))
  private val memExuInputs = io.fromMemExu.flatten.toSeq
  private val memExuWBs = WireInit(MixedVecInit(memExuInputs))

  // only fired port can write back to ctrl block
  (intExuWBs zip intExuInputs).foreach { case (wb, input) => wb.valid := input.fire }
  (vfExuWBs zip vfExuInputs).foreach { case (wb, input) => wb.valid := input.fire }
  (memExuWBs zip memExuInputs).foreach { case (wb, input) => wb.valid := input.fire }

  // the ports not writing back pregs are always ready
  // the ports set highest priority are always ready
  // (this is the last assignment to fromExu ready, so it takes precedence over the ones above)
  (fromExu).foreach( x =>
    if (x.bits.params.hasNoDataWB || x.bits.params.isHighestWBPriority) x.ready := true.B
  )

  // io assign
  private val toIntPreg: MixedVec[RfWritePortWithConfig] = MixedVecInit(intWbArbiterOut.map(x => x.bits.asIntRfWriteBundle(x.fire)).toSeq)
  private val toVfPreg: MixedVec[RfWritePortWithConfig] = MixedVecInit(vfWbArbiterOut.map(x => x.bits.asVfRfWriteBundle(x.fire)).toSeq)

  private val wb2Ctrl = intExuWBs ++ vfExuWBs ++ memExuWBs

  io.toIntPreg := toIntPreg
  io.toVfPreg := toVfPreg
  io.toCtrlBlock.writeback.zip(wb2Ctrl).foreach { case (sink, source) =>
    sink.valid := source.valid
    sink.bits := source.bits
    source.ready := true.B
  }

  if (env.EnableDifftest || env.AlwaysBasicDiff) {
    intWbArbiterOut.foreach(out => {
      val difftest = DifftestModule(new DiffIntWriteback(IntPhyRegs))
      difftest.coreid := io.fromTop.hartId
      difftest.valid := out.fire && out.bits.rfWen
      difftest.address := out.bits.pdest
      difftest.data := out.bits.data
    })
  }

  if (env.EnableDifftest || env.AlwaysBasicDiff) {
    vfWbArbiterOut.foreach(out => {
      val difftest = DifftestModule(new DiffFpWriteback(VfPhyRegs))
      difftest.coreid := io.fromTop.hartId
      difftest.valid := out.fire // all fp instr will write fp rf
      difftest.address := out.bits.pdest
      difftest.data := out.bits.data
    })
  }

}
262
263
264
265
266