xref: /XiangShan/src/main/scala/xiangshan/backend/fu/Bku.scala (revision bb2f3f51dd67f6e16e0cc1ffe43368c9fc7e4aef)
1/***************************************************************************************
2* Copyright (c) 2020-2021 Institute of Computing Technology, Chinese Academy of Sciences
3* Copyright (c) 2020-2021 Peng Cheng Laboratory
4*
5* XiangShan is licensed under Mulan PSL v2.
6* You can use this software according to the terms and conditions of the Mulan PSL v2.
7* You may obtain a copy of Mulan PSL v2 at:
8*          http://license.coscl.org.cn/MulanPSL2
9*
10* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
11* EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
12* MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
13*
14* See the Mulan PSL v2 for more details.
15***************************************************************************************/
16
17package xiangshan.backend.fu
18
19import org.chipsalliance.cde.config.Parameters
20import chisel3._
21import chisel3.util._
22import utility.{LookupTreeDefault, ParallelMux, ParallelXOR, SignExt, ZeroExt}
23import utility.{XSDebug, XSError}
24import xiangshan._
25import xiangshan.backend.fu.util._
26
/** Bit-counting unit for the count instructions (clz/clzw, ctz/ctzw, cpop/cpopw).
  *
  * Split across one pipeline register: the first half of the leading-zero tree and the
  * 16-bit population counts are computed combinationally and captured when `regEnable` is
  * high; the rest of the tree, the population-count sums and the final select are
  * evaluated from the registered values.
  *
  * `func` bits read by this module:
  *  - func(2): 1 selects the population count, 0 selects the zero count
  *  - func(1): 1 bit-reverses the source so trailing zeros are counted as leading zeros
  *  - func(0): 1 selects the 32-bit (word) variant
  *
  * NOTE(review): the tree shape (32 two-bit leaves, four 16-bit popcount lanes) assumes
  * XLEN == 64.
  */
class CountModule(implicit p: Parameters) extends XSModule {
  val io = IO(new Bundle() {
    val src = Input(UInt(XLEN.W))
    val func = Input(UInt())
    val regEnable = Input(Bool())
    val out = Output(UInt(XLEN.W))
  })

  /** Leading-zero count of a 2-bit slice: 0 -> 2, 1 -> 1, anything else -> 0. */
  def encode(bits: UInt): UInt = {
    val lookup = List(0.U -> 2.U(2.W), 1.U -> 1.U(2.W))
    LookupTreeDefault(bits, 0.U, lookup)
  }

  /** Merges the leading-zero counts of two adjacent halves into the count of the whole.
    *
    * Both counts are `msb + 1` bits wide and bit `msb` is set only when that half is all
    * zeros. If the upper half (`left`) is not all zeros its count is the answer. Otherwise
    * the result is the full half-width plus the lower count, written as a concatenation.
    */
  def clzi(msb: Int, left: UInt, right: UInt): UInt = {
    val leftAllZero  = left(msb)
    val rightAllZero = right(msb)
    val rightLow     = if (msb == 1) right(0) else right(msb - 1, 0)
    val sum          = Cat(leftAllZero && rightAllZero, !rightAllZero, rightLow)
    Mux(leftAllZero, sum, left)
  }

  // ---- stage 0: combinational, before the pipeline register ----
  // Trailing-zero variants count leading zeros of the bit-reversed source.
  val countSrc = Mux(io.func(1), Reverse(io.src), io.src)

  // Per-2-bit leaf counts, then pairwise merge into per-4-bit counts.
  val c0 = VecInit(Seq.tabulate(32)(i => encode(countSrc(2 * i + 1, 2 * i))))
  val c1 = VecInit(Seq.tabulate(16)(i => clzi(1, c0(2 * i + 1), c0(2 * i))))

  // ---- pipeline register ----
  val funcReg = RegEnable(io.func, io.regEnable)
  // Per-8-bit zero counts.
  val c2 = RegEnable(
    VecInit(Seq.tabulate(8)(i => clzi(2, c1(2 * i + 1), c1(2 * i)))),
    io.regEnable)
  // Population count of each 16-bit lane of the raw (never reversed) source.
  val cpopTmp = RegEnable(
    VecInit(Seq.tabulate(4)(i => PopCount(io.src(16 * i + 15, 16 * i)))),
    io.regEnable)

  // ---- stage 1: combinational, after the pipeline register ----
  val c3 = VecInit(Seq.tabulate(4)(i => clzi(3, c2(2 * i + 1), c2(2 * i))))
  val c4 = VecInit(Seq.tabulate(2)(i => clzi(4, c3(2 * i + 1), c3(2 * i))))

  // c4(1) counts the upper 32 bits of countSrc, c4(0) the lower 32 bits.
  val zeroRes  = clzi(5, c4(1), c4(0))
  // clzw needs src's lower word (c4(0)); for ctzw that word was reversed into the upper half.
  val zeroWRes = Mux(funcReg(1), c4(1), c4(0))

  val cpopLo32 = cpopTmp(0) +& cpopTmp(1)
  val cpopRes  = cpopLo32 +& (cpopTmp(2) +& cpopTmp(3))
  val cpopWRes = cpopLo32

  val zeroOut = Mux(funcReg(0), zeroWRes, zeroRes)
  val cpopOut = Mux(funcReg(0), cpopWRes, cpopRes)
  io.out := Mux(funcReg(2), cpopOut, zeroOut)
}
82
/** Carry-less multiplier for clmul / clmulh / clmulr.
  *
  * Each bit of src1 gates a copy of src2 shifted to that bit position. The 64 partial
  * products are XOR-reduced in a binary tree: three levels before the pipeline register,
  * one more into the register, and the final 8-way reduction after it.
  *
  * NOTE(review): the partial-product count (64), the 128-bit product width and the result
  * slices assume XLEN == 64.
  */
class ClmulModule(implicit p: Parameters) extends XSModule {
  val io = IO(new Bundle() {
    val src = Vec(2, Input(UInt(XLEN.W)))
    val func = Input(UInt())
    val regEnable = Input(Bool())
    val out = Output(UInt(XLEN.W))
  })

  val (src1, src2) = (io.src(0), io.src(1))

  // ---- stage 0 ----
  // Explicit 128-bit element types: the partial products have different natural widths
  // (64 + i bits) and must all be zero-extended to the full product width.
  val mul0 = Wire(Vec(64, UInt(128.W)))
  val mul1 = Wire(Vec(32, UInt(128.W)))
  val mul2 = Wire(Vec(16, UInt(128.W)))

  for (i <- 0 until XLEN) {
    // src2 moved up to bit i when bit i of src1 is set
    mul0(i) := Mux(src1(i), src2 << i, 0.U)
  }
  for (i <- 0 until 32) { mul1(i) := mul0(2 * i) ^ mul0(2 * i + 1) }
  for (i <- 0 until 16) { mul2(i) := mul1(2 * i) ^ mul1(2 * i + 1) }

  // ---- pipeline register ----
  val funcReg = RegEnable(io.func, io.regEnable)
  val mul3 = RegEnable(
    VecInit(Seq.tabulate(8)(i => mul2(2 * i) ^ mul2(2 * i + 1))),
    io.regEnable)

  // ---- stage 1 ----
  val product = ParallelXOR(mul3)

  val clmul  = product(63, 0)    // low half of the product
  val clmulh = product(127, 64)  // high half of the product
  val clmulr = product(126, 63)  // product bits 126..63

  io.out := LookupTreeDefault(funcReg, clmul, List(
    BKUOpType.clmul  -> clmul,
    BKUOpType.clmulh -> clmulh,
    BKUOpType.clmulr -> clmulr
  ))
}
124
/** Crossbar-permutation unit (xperm4 / xperm8).
  *
  * Both permutations are computed combinationally from the operands. func(0) chooses
  * between them (0 = nibble permutation, 1 = byte permutation), and only the chosen result
  * is registered on `regEnable`.
  */
class MiscModule(implicit p: Parameters) extends XSModule {
  val io = IO(new Bundle() {
    val src = Vec(2, Input(UInt(XLEN.W)))
    val func = Input(UInt())
    val regEnable = Input(Bool())
    val out = Output(UInt(XLEN.W))
  })

  val (src1, src2) = (io.src(0), io.src(1))

  /** Returns the `width`-bit element of `table` at position `idx` (element 0 is the least
    * significant), or 0 when no element has that index.
    */
  def xpermLUT(table: UInt, idx: UInt, width: Int) : UInt = {
    val entries = (0 until XLEN / width).map { i =>
      val lo = i * width
      i.U -> table(lo + width - 1, lo)
    }
    LookupTreeDefault(idx, 0.U(width.W), entries)
  }

  // Nibble permutation: every nibble of src2 selects one of the 16 nibbles of src1.
  // Iterating from the top index down puts element 15 in the most significant position.
  val xpermn = Cat((0 until 16).reverse.map { i =>
    xpermLUT(src1, src2(i * 4 + 3, i * 4), 4)
  })

  // Byte permutation: every byte of src2 selects one of the 8 bytes of src1; any index of
  // 8 or more (a set bit in 7..3) yields zero.
  val xpermb = Cat((0 until 8).reverse.map { i =>
    val sel = src2(i * 8 + 7, i * 8)
    Mux(sel(7, 3).orR, 0.U, xpermLUT(src1, sel(2, 0), 8))
  })

  io.out := RegEnable(Mux(io.func(0), xpermb, xpermn), io.regEnable)
}
150
/** Hash-function unit for the SHA-256, SHA-512 and SM3 helper instructions.
  *
  * Every candidate result is computed combinationally from `src`; the selected one is
  * registered on `regEnable`.
  *
  * `func` bits read by this module:
  *  - func(3): 1 selects SM3, 0 selects SHA
  *  - func(2,0): SHA variant (SHA-256 sum0, sum1, sig0, sig1, then the SHA-512 four)
  *  - func(0): SM3 variant (0 = P0, 1 = P1)
  * The 32-bit (SHA-256 and SM3) results are sign-extended to XLEN.
  */
class HashModule(implicit p: Parameters) extends XSModule {
  val io = IO(new Bundle() {
    val src = Input(UInt(XLEN.W))
    val func = Input(UInt())
    val regEnable = Input(Bool())
    val out = Output(UInt(XLEN.W))
  })

  val src1 = io.src

  /** Sign-extends the low 32 bits of `x` to XLEN. */
  private def sextWord(x: UInt): UInt = SignExt(x(31, 0), XLEN)

  // SHA-256 sigma functions (32-bit)
  val sha256sum0 = ROR32(src1, 2)  ^ ROR32(src1, 13) ^ ROR32(src1, 22)
  val sha256sum1 = ROR32(src1, 6)  ^ ROR32(src1, 11) ^ ROR32(src1, 25)
  val sha256sig0 = ROR32(src1, 7)  ^ ROR32(src1, 18) ^ SHR32(src1, 3)
  val sha256sig1 = ROR32(src1, 17) ^ ROR32(src1, 19) ^ SHR32(src1, 10)

  // SHA-512 sigma functions (64-bit)
  val sha512sum0 = ROR64(src1, 28) ^ ROR64(src1, 34) ^ ROR64(src1, 39)
  val sha512sum1 = ROR64(src1, 14) ^ ROR64(src1, 18) ^ ROR64(src1, 41)
  val sha512sig0 = ROR64(src1, 1)  ^ ROR64(src1, 8)  ^ SHR64(src1, 7)
  val sha512sig1 = ROR64(src1, 19) ^ ROR64(src1, 61) ^ SHR64(src1, 6)

  // SM3 permutations (only the low 32 bits are used)
  val sm3p0      = ROR32(src1, 23) ^ ROR32(src1, 15) ^ src1
  val sm3p1      = ROR32(src1, 9)  ^ ROR32(src1, 17) ^ src1

  // Indexed directly by func(2,0).
  val shaSource = VecInit(Seq(
    sextWord(sha256sum0),
    sextWord(sha256sum1),
    sextWord(sha256sig0),
    sextWord(sha256sig1),
    sha512sum0,
    sha512sum1,
    sha512sig0,
    sha512sig1
  ))
  val sha = shaSource(io.func(2, 0))
  val sm3 = sextWord(Mux(io.func(0), sm3p1, sm3p0))

  io.out := RegEnable(Mux(io.func(3), sm3, sha), io.regEnable)
}
187
/** Block-cipher unit for the 64-bit AES instructions and the SM4 instructions.
  *
  * The S-box lookups are split across the pipeline register:
  *  - For the AES round S-boxes, the top layer and the common `SboxInv` stage are evaluated
  *    before the register and the output layer after it.
  *  - For the AES key-schedule and SM4 S-boxes, only the top layer is evaluated before the
  *    register.
  * Every other operand needed after the register is captured on `regEnable` as well.
  *
  * `func` bits read by this module:
  *  - func(3) (registered): 1 selects the SM4 result, 0 the AES result
  *  - func(1,0) (stage 0): which byte of src2 feeds the SM4 S-box
  *  - func(2,0) (registered): SM4 variant; bit 2 selects ks (1) vs ed (0), bits 1..0 the
  *    left-rotation of the result by 0/8/16/24 bits
  *  - the whole registered func selects among the AES operations
  */
class BlockCipherModule(implicit p: Parameters) extends XSModule {
  val io = IO(new Bundle() {
    val src = Vec(2, Input(UInt(XLEN.W)))
    val func = Input(UInt())
    val regEnable = Input(Bool())
    val out = Output(UInt(XLEN.W))
  })

  val (src1, src2, func, funcReg) = (io.src(0), io.src(1), io.func, RegEnable(io.func, io.regEnable))

  // Byte views of the operands; element 0 is the least-significant byte.
  val src1Bytes = VecInit((0 until 8).map(i => src1(i*8+7, i*8)))
  val src2Bytes = VecInit((0 until 8).map(i => src2(i*8+7, i*8)))

  // ------------------------------------------------------------------------------ AES
  // The (forward / inverse) ShiftRows selection picks the 8 state bytes fed to the S-boxes.
  val aesSboxIn  = ForwardShiftRows(src1Bytes, src2Bytes)
  val aesSboxMid = Reg(Vec(8, Vec(18, Bool())))
  val aesSboxOut = Wire(Vec(8, UInt(8.W)))

  val iaesSboxIn  = InverseShiftRows(src1Bytes, src2Bytes)
  val iaesSboxMid = Reg(Vec(8, Vec(18, Bool())))
  val iaesSboxOut = Wire(Vec(8, UInt(8.W)))

  aesSboxOut.zip(aesSboxMid).zip(aesSboxIn).foreach { case ((out, mid), in) =>
    when (io.regEnable) {
      mid := SboxInv(SboxAesTop(in))
    }
    out := SboxAesOut(mid)
  }

  iaesSboxOut.zip(iaesSboxMid).zip(iaesSboxIn).foreach { case ((out, mid), in) =>
    when (io.regEnable) {
      mid := SboxInv(SboxIaesTop(in))
    }
    out := SboxIaesOut(mid)
  }

  val aes64es = aesSboxOut.asUInt
  val aes64ds = iaesSboxOut.asUInt

  // aes64im mixes src1 directly, so its bytes are carried across the register.
  val imMinIn = RegEnable(src1Bytes, io.regEnable)

  // (Inverse) MixColumns applied to each 32-bit column: bytes 4..7 and bytes 0..3.
  val aes64esm = Cat(MixFwd(Seq(aesSboxOut(4), aesSboxOut(5), aesSboxOut(6), aesSboxOut(7))),
                     MixFwd(Seq(aesSboxOut(0), aesSboxOut(1), aesSboxOut(2), aesSboxOut(3))))
  val aes64dsm = Cat(MixInv(Seq(iaesSboxOut(4), iaesSboxOut(5), iaesSboxOut(6), iaesSboxOut(7))),
                     MixInv(Seq(iaesSboxOut(0), iaesSboxOut(1), iaesSboxOut(2), iaesSboxOut(3))))
  val aes64im  = Cat(MixInv(Seq(imMinIn(4), imMinIn(5), imMinIn(6), imMinIn(7))),
                     MixInv(Seq(imMinIn(0), imMinIn(1), imMinIn(2), imMinIn(3))))

  // Key-schedule round constants, indexed by rnum = src2(3,0). rnum 0xA selects 0x00 and
  // (below) also skips the word rotation. rnum 0xB..0xF are reserved encodings.
  // The table is padded to all 16 values of the 4-bit index with 0x00 so a reserved rnum
  // yields a defined value. Without the padding, the index would reach past the end of
  // an 11-entry Vec, which gives an undefined result.
  val rcon = VecInit(
    Seq(0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x1b, 0x36, 0x00)
      .padTo(16, 0x00)
      .map(_.U(8.W)))

  // S-box input is the upper word of src1, rotated right by one byte unless rnum == 0xA.
  val ksNoRotate = src2(3,0) === "ha".U
  val ksSboxIn  = Wire(Vec(4, UInt(8.W)))
  val ksSboxTop = Reg(Vec(4, Vec(21, Bool())))
  val ksSboxOut = Wire(Vec(4, UInt(8.W)))
  ksSboxIn(0) := Mux(ksNoRotate, src1Bytes(4), src1Bytes(5))
  ksSboxIn(1) := Mux(ksNoRotate, src1Bytes(5), src1Bytes(6))
  ksSboxIn(2) := Mux(ksNoRotate, src1Bytes(6), src1Bytes(7))
  ksSboxIn(3) := Mux(ksNoRotate, src1Bytes(7), src1Bytes(4))
  ksSboxOut.zip(ksSboxTop).zip(ksSboxIn).foreach { case ((out, top), in) =>
    when (io.regEnable) {
      top := SboxAesTop(in)
    }
    out := SboxAesOut(SboxInv(top))
  }

  val ks1Idx = RegEnable(src2(3,0), io.regEnable)
  // The round constant lands in the least-significant byte; the word fills both halves.
  val ks1Word = ksSboxOut.asUInt ^ rcon(ks1Idx)
  val aes64ks1i = Cat(ks1Word, ks1Word)

  val aes64ks2Temp = src1(63,32) ^ src2(31,0)
  val aes64ks2 = RegEnable(Cat(aes64ks2Temp ^ src2(63,32), aes64ks2Temp), io.regEnable)

  val aesResult = LookupTreeDefault(funcReg, aes64es, List(
    BKUOpType.aes64es   -> aes64es,
    BKUOpType.aes64esm  -> aes64esm,
    BKUOpType.aes64ds   -> aes64ds,
    BKUOpType.aes64dsm  -> aes64dsm,
    BKUOpType.aes64im   -> aes64im,
    BKUOpType.aes64ks1i -> aes64ks1i,
    BKUOpType.aes64ks2  -> aes64ks2
  ))

  // ------------------------------------------------------------------------------ SM4
  // One byte of src2 (chosen by func(1,0)) goes through the S-box.
  val sm4SboxIn  = src2Bytes(func(1,0))
  val sm4SboxTop = Reg(Vec(21, Bool()))
  when (io.regEnable) {
    sm4SboxTop := SboxSm4Top(sm4SboxIn)
  }
  val sm4SboxOut = SboxSm4Out(SboxInv(sm4SboxTop))

  // Linear transforms of the S-box output for the ed (data) and ks (key) variants.
  val sm4ed = sm4SboxOut ^ (sm4SboxOut<<8) ^ (sm4SboxOut<<2) ^ (sm4SboxOut<<18) ^ ((sm4SboxOut&"h3f".U)<<26) ^ ((sm4SboxOut&"hc0".U)<<10)
  val sm4ks = sm4SboxOut ^ ((sm4SboxOut&"h07".U)<<29) ^ ((sm4SboxOut&"hfe".U)<<7) ^ ((sm4SboxOut&"h01".U)<<23) ^ ((sm4SboxOut&"hf8".U)<<13)
  // Index = registered func(2,0): ed/ks, each rotated left by 0, 8, 16 or 24 bits.
  val sm4Source = VecInit(Seq(
    sm4ed(31,0),
    Cat(sm4ed(23,0), sm4ed(31,24)),
    Cat(sm4ed(15,0), sm4ed(31,16)),
    Cat(sm4ed( 7,0), sm4ed(31,8)),
    sm4ks(31,0),
    Cat(sm4ks(23,0), sm4ks(31,24)),
    Cat(sm4ks(15,0), sm4ks(31,16)),
    Cat(sm4ks( 7,0), sm4ks(31,8))
  ))
  val sm4Result = SignExt((sm4Source(funcReg(2,0)) ^ RegEnable(src1(31,0), io.regEnable))(31,0), XLEN)

  io.out := Mux(funcReg(3), sm4Result, aesResult)
}
295
/** Scalar-crypto dispatcher wrapping [[HashModule]] and [[BlockCipherModule]].
  *
  * Both sub-units receive every operand together with the same `func` and `regEnable`. The
  * registered func(4) selects which result is forwarded: 1 = hash, 0 = block cipher.
  */
class CryptoModule(implicit p: Parameters) extends XSModule {
  val io = IO(new Bundle() {
    val src = Vec(2, Input(UInt(XLEN.W)))
    val func = Input(UInt())
    val regEnable = Input(Bool())
    val out = Output(UInt(XLEN.W))
  })

  val (src1, src2, func) = (io.src(0), io.src(1), io.func)
  val funcReg = RegEnable(func, io.regEnable)

  val hashModule = Module(new HashModule)
  val blockCipherModule = Module(new BlockCipherModule)

  // Operands: the hash unit reads only the first source.
  hashModule.io.src := src1
  blockCipherModule.io.src(0) := src1
  blockCipherModule.io.src(1) := src2

  // Control is shared by both sub-units.
  Seq(hashModule.io.func, blockCipherModule.io.func).foreach(_ := func)
  Seq(hashModule.io.regEnable, blockCipherModule.io.regEnable).foreach(_ := io.regEnable)

  io.out := Mux(funcReg(4), hashModule.io.out, blockCipherModule.io.out)
}
320
/** Bit-manipulation / scalar-crypto functional unit with a two-cycle latency.
  *
  * The operands are broadcast to four sub-units: count, carry-less multiply, crossbar
  * permutation and crypto. Each sub-unit registers its intermediate state on `regEnable(1)`.
  * In the second cycle the registered fuOpType selects one sub-unit result, which is
  * registered again on `regEnable(2)`.
  *
  * Result select, by registered fuOpType bit:
  *  - bit 5 set: crypto
  *  - else bit 3 set: count
  *  - else bit 2 set: crossbar permutation
  *  - otherwise: carry-less multiply
  */
class Bku(cfg: FuConfig)(implicit p: Parameters) extends FuncUnit(cfg) with HasPipelineReg {

  override def latency: Int = 2

  val (src1, src2, func) = (
    io.in.bits.data.src(0),
    io.in.bits.data.src(1),
    io.in.bits.ctrl.fuOpType
  )

  val countModule = Module(new CountModule)
  countModule.io.src := src1
  countModule.io.func := func
  countModule.io.regEnable := regEnable(1)

  val clmulModule = Module(new ClmulModule)
  clmulModule.io.src(0) := src1
  clmulModule.io.src(1) := src2
  clmulModule.io.func := func
  clmulModule.io.regEnable := regEnable(1)

  val miscModule = Module(new MiscModule)
  miscModule.io.src(0) := src1
  miscModule.io.src(1) := src2
  miscModule.io.func := func
  miscModule.io.regEnable := regEnable(1)

  val cryptoModule = Module(new CryptoModule)
  cryptoModule.io.src(0) := src1
  cryptoModule.io.src(1) := src2
  cryptoModule.io.func := func
  cryptoModule.io.regEnable := regEnable(1)

  // CountModule, ClmulModule, MiscModule, and CryptoModule have a latency of 1 cycle.
  // The select register uses the same enable as the sub-units' internal registers, so the
  // select is always captured together with the data it chooses between. `io.in.fire`
  // would also latch the opcode of a µop that is flushed on entry.
  val funcReg = RegEnable(func, regEnable(1))
  val result = Mux(funcReg(5), cryptoModule.io.out,
                  Mux(funcReg(3), countModule.io.out,
                      Mux(funcReg(2), miscModule.io.out, clmulModule.io.out)))

  io.out.bits.res.data := RegEnable(result, regEnable(2))
  // connectNonPipedCtrlSingal
}
364