/***************************************************************************************
* Copyright (c) 2020-2021 Institute of Computing Technology, Chinese Academy of Sciences
* Copyright (c) 2020-2021 Peng Cheng Laboratory
*
* XiangShan is licensed under Mulan PSL v2.
* You can use this software according to the terms and conditions of the Mulan PSL v2.
* You may obtain a copy of Mulan PSL v2 at:
*          http://license.coscl.org.cn/MulanPSL2
*
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
* EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
* MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
*
* See the Mulan PSL v2 for more details.
***************************************************************************************/

package xiangshan.backend.fu

import org.chipsalliance.cde.config.Parameters
import chisel3._
import chisel3.util._
import utility.{LookupTreeDefault, ParallelMux, ParallelXOR, SignExt, ZeroExt}
import utility.{XSDebug, XSError}
import xiangshan._
import xiangshan.backend.fu.util._

/** Bit-counting unit: leading/trailing zero count and population count.
  *
  * One register stage (enabled by `io.regEnable`): the first levels of the zero-count tree
  * and four 16-bit popcounts are registered; the remaining merges happen after the register.
  *
  * `io.func` bits as used below (NOTE(review): mapping to BKUOpType values is defined
  * elsewhere — confirm):
  *  - func(2): 1 = population count result, 0 = zero-count result
  *  - func(1): 1 = bit-reverse the source first, so the leading-zero tree counts trailing zeros
  *  - func(0): 1 = 32-bit variant, computed on the low 32 bits of the source
  */
class CountModule(implicit p: Parameters) extends XSModule {
  val io = IO(new Bundle() {
    val src = Input(UInt(XLEN.W))
    val func = Input(UInt())
    val regEnable = Input(Bool())
    val out = Output(UInt(XLEN.W))
  })

  /** Leading-zero count of a 2-bit group (bit 1 is the more significant): 00 -> 2, 01 -> 1, else 0. */
  def encode(bits: UInt): UInt = {
    LookupTreeDefault(bits, 0.U, List(0.U -> 2.U(2.W), 1.U -> 1.U(2.W)))
  }

  /** Merges the leading-zero counts of two adjacent segments, each 2^msb bits wide.
    *
    * `left` is the count of the more significant segment. Its bit `msb` is set only when the
    * whole left segment is zero (count == 2^msb); then the result is 2^msb + right, formed by
    * concatenation rather than an adder. Otherwise the left count is already the answer.
    */
  def clzi(msb: Int, left: UInt, right: UInt): UInt = {
    Mux(left(msb),
      Cat(left(msb) && right(msb), !right(msb), if (msb == 1) right(0) else right(msb - 1, 0)),
      left)
  }

  // stage 0 (before the register): 2-bit group counts, merged into 4-bit group counts
  val c0 = Wire(Vec(32, UInt(2.W)))
  val c1 = Wire(Vec(16, UInt(3.W)))
  // Reversing the source turns the leading-zero tree into a trailing-zero counter
  val countSrc = Mux(io.func(1), Reverse(io.src), io.src)

  for (i <- 0 until 32) { c0(i) := encode(countSrc(2 * i + 1, 2 * i)) }
  for (i <- 0 until 16) { c1(i) := clzi(1, c0(i * 2 + 1), c0(i * 2)) }

  // pipeline registers: 8-bit group counts and four 16-bit popcounts
  val funcReg = RegEnable(io.func, io.regEnable)
  val c2 = Reg(Vec(8, UInt(4.W)))
  val cpopTmp = Reg(Vec(4, UInt(5.W)))
  when (io.regEnable) {
    for (i <- 0 until 8) {
      c2(i) := clzi(2, c1(i * 2 + 1), c1(i * 2))
    }
    for (i <- 0 until 4) {
      // popcount uses the un-reversed source
      cpopTmp(i) := PopCount(io.src(i * 16 + 15, i * 16))
    }
  }

  // stage 1 (after the register): finish the zero-count tree and sum the popcounts
  val c3 = Wire(Vec(4, UInt(5.W)))
  val c4 = Wire(Vec(2, UInt(6.W)))

  for (i <- 0 until 4) { c3(i) := clzi(3, c2(i * 2 + 1), c2(i * 2)) }
  for (i <- 0 until 2) { c4(i) := clzi(4, c3(i * 2 + 1), c3(i * 2)) }
  val zeroRes = clzi(5, c4(1), c4(0))
  // Word variant: c4(0) covers src bits 31..0; when the source was reversed (funcReg(1)),
  // those bits sit in the upper half, i.e. c4(1)
  val zeroWRes = Mux(funcReg(1), c4(1), c4(0))

  val cpopLo32 = cpopTmp(0) +& cpopTmp(1)
  val cpopHi32 = cpopTmp(2) +& cpopTmp(3)

  val cpopRes = cpopLo32 +& cpopHi32
  val cpopWRes = cpopLo32

  io.out := Mux(funcReg(2), Mux(funcReg(0), cpopWRes, cpopRes), Mux(funcReg(0), zeroWRes, zeroRes))
}

/** Carry-less multiply unit.
  *
  * Each set bit i of src1 contributes `src2 << i` to a 128-bit XOR sum. Stage 0 reduces the
  * 64 partial products to 16 combinationally; 8 partial sums are registered (enabled by
  * `io.regEnable`); stage 1 XORs them into the full product and the registered func selects
  * the clmul / clmulh / clmulr slice (clmul is the default).
  */
class ClmulModule(implicit p: Parameters) extends XSModule {
  val io = IO(new Bundle() {
    val src = Vec(2, Input(UInt(XLEN.W)))
    val func = Input(UInt())
    val regEnable = Input(Bool())
    val out = Output(UInt(XLEN.W))
  })

  // stage 0
  val (src1, src2) = (io.src(0), io.src(1))

  val mul0 = Wire(Vec(64, UInt(128.W)))
  val mul1 = Wire(Vec(32, UInt(128.W)))
  val mul2 = Wire(Vec(16, UInt(128.W)))

  // Only the connections matter here, so iterate with foreach rather than map
  (0 until XLEN).foreach { i =>
    mul0(i) := Mux(src1(i), if (i == 0) src2 else Cat(src2, 0.U(i.W)), 0.U)
  }
  (0 until 32).foreach { i => mul1(i) := mul0(i * 2) ^ mul0(i * 2 + 1) }
  (0 until 16).foreach { i => mul2(i) := mul1(i * 2) ^ mul1(i * 2 + 1) }

  // pipeline registers
  val funcReg = RegEnable(io.func, io.regEnable)
  val mul3 = Reg(Vec(8, UInt(128.W)))
  when (io.regEnable) {
    (0 until 8).foreach { i => mul3(i) := mul2(i * 2) ^ mul2(i * 2 + 1) }
  }

  // stage 1
  val res = ParallelXOR(mul3)

  val clmul = res(63, 0)
  val clmulh = res(127, 64)
  // clmulr takes product bits 126..63
  val clmulr = res(126, 63)

  io.out := LookupTreeDefault(funcReg, clmul, List(
    BKUOpType.clmul -> clmul,
    BKUOpType.clmulh -> clmulh,
    BKUOpType.clmulr -> clmulr
  ))
}

/** Crossbar-permutation unit (nibble and byte variants).
  *
  * Each 4-bit (xpermn) or 8-bit (xpermb) lane of src2 is an index that selects a lane of src1.
  * The result is registered once (enabled by `io.regEnable`); io.func(0) selects xpermb.
  */
class MiscModule(implicit p: Parameters) extends XSModule {
  val io = IO(new Bundle() {
    val src = Vec(2, Input(UInt(XLEN.W)))
    val func = Input(UInt())
    val regEnable = Input(Bool())
    val out = Output(UInt(XLEN.W))
  })

  val (src1, src2) = (io.src(0), io.src(1))

  /** Returns the `width`-bit lane of `table` at position `idx`, or 0 when no lane index matches. */
  def xpermLUT(table: UInt, idx: UInt, width: Int) : UInt = {
    // ParallelMux((0 until XLEN/width).map( i => i.U -> table(i)).map( x => (x._1 === idx, x._2)))
    LookupTreeDefault(idx, 0.U(width.W), (0 until XLEN/width).map( i => i.U -> table(i*width+width-1, i*width)))
  }

  // xpermn: 16 nibble lanes, each indexing a nibble of src1
  val xpermnVec = Wire(Vec(16, UInt(4.W)))
  (0 until 16).foreach(i => xpermnVec(i) := xpermLUT(src1, src2(i*4+3, i*4), 4))
  val xpermn = Cat(xpermnVec.reverse)

  // xpermb: a byte index with any of bits 7..3 set (>= 8) is out of range and yields 0
  val xpermbVec = Wire(Vec(8, UInt(8.W)))
  (0 until 8).foreach(i => xpermbVec(i) := Mux(src2(i*8+7, i*8+3).orR, 0.U, xpermLUT(src1, src2(i*8+2, i*8), 8)))
  val xpermb = Cat(xpermbVec.reverse)

  io.out := RegEnable(Mux(io.func(0), xpermb, xpermn), io.regEnable)
}

/** SHA-256 / SHA-512 sigma/sum functions and SM3 P0/P1 permutations.
  *
  * All candidates are computed combinationally from src; the selected one is registered once
  * (enabled by `io.regEnable`). io.func(3) selects SM3 over SHA, io.func(2,0) indexes the
  * eight SHA variants in `shaSource`, and io.func(0) picks sm3p1 over sm3p0.
  * 32-bit results are sign-extended to XLEN.
  */
class HashModule(implicit p: Parameters) extends XSModule {
  val io = IO(new Bundle() {
    val src = Input(UInt(XLEN.W))
    val func = Input(UInt())
    val regEnable = Input(Bool())
    val out = Output(UInt(XLEN.W))
  })

  val src1 = io.src

  // NOTE(review): ROR32/SHR32/ROR64/SHR64 are defined in xiangshan.backend.fu.util; presumably
  // they rotate/shift the low 32 bits (resp. all 64 bits). Only bits 31..0 of the 32-bit
  // results are used below.
  val sha256sum0 = ROR32(src1, 2) ^ ROR32(src1, 13) ^ ROR32(src1, 22)
  val sha256sum1 = ROR32(src1, 6) ^ ROR32(src1, 11) ^ ROR32(src1, 25)
  val sha256sig0 = ROR32(src1, 7) ^ ROR32(src1, 18) ^ SHR32(src1, 3)
  val sha256sig1 = ROR32(src1, 17) ^ ROR32(src1, 19) ^ SHR32(src1, 10)
  val sha512sum0 = ROR64(src1, 28) ^ ROR64(src1, 34) ^ ROR64(src1, 39)
  val sha512sum1 = ROR64(src1, 14) ^ ROR64(src1, 18) ^ ROR64(src1, 41)
  val sha512sig0 = ROR64(src1, 1) ^ ROR64(src1, 8) ^ SHR64(src1, 7)
  val sha512sig1 = ROR64(src1, 19) ^ ROR64(src1, 61) ^ SHR64(src1, 6)
  val sm3p0 = ROR32(src1, 23) ^ ROR32(src1, 15) ^ src1
  val sm3p1 = ROR32(src1, 9) ^ ROR32(src1, 17) ^ src1

  val shaSource = VecInit(Seq(
    SignExt(sha256sum0(31,0), XLEN),
    SignExt(sha256sum1(31,0), XLEN),
    SignExt(sha256sig0(31,0), XLEN),
    SignExt(sha256sig1(31,0), XLEN),
    sha512sum0,
    sha512sum1,
    sha512sig0,
    sha512sig1
  ))
  val sha = shaSource(io.func(2,0))
  val sm3 = Mux(io.func(0), SignExt(sm3p1(31,0), XLEN), SignExt(sm3p0(31,0), XLEN))

  io.out := RegEnable(Mux(io.func(3), sm3, sha), io.regEnable)
}

/** AES (aes64*) and SM4 (sm4ed / sm4ks) helper unit.
  *
  * S-box lookups are split into Top / Inv / Out sub-functions (defined elsewhere), and a
  * register enabled by `io.regEnable` sits between those sub-stages, giving one cycle of
  * latency. Results are chosen with the registered func: funcReg(3) selects SM4 over AES.
  */
class BlockCipherModule(implicit p: Parameters) extends XSModule {
  val io = IO(new Bundle() {
    val src = Vec(2, Input(UInt(XLEN.W)))
    val func = Input(UInt())
    val regEnable = Input(Bool())
    val out = Output(UInt(XLEN.W))
  })

  val (src1, src2, func, funcReg) = (io.src(0), io.src(1), io.func, RegEnable(io.func, io.regEnable))

  val src1Bytes = VecInit((0 until 8).map(i => src1(i*8+7, i*8)))
  val src2Bytes = VecInit((0 until 8).map(i => src2(i*8+7, i*8)))

  // AES: 8 forward and 8 inverse S-boxes; the register holds the SboxInv(Top(...)) value
  val aesSboxIn = ForwardShiftRows(src1Bytes, src2Bytes)
  val aesSboxMid = Reg(Vec(8, Vec(18, Bool())))
  val aesSboxOut = Wire(Vec(8, UInt(8.W)))

  val iaesSboxIn = InverseShiftRows(src1Bytes, src2Bytes)
  val iaesSboxMid = Reg(Vec(8, Vec(18, Bool())))
  val iaesSboxOut = Wire(Vec(8, UInt(8.W)))

  aesSboxOut.zip(aesSboxMid).zip(aesSboxIn).foreach { case ((out, mid), in) =>
    when (io.regEnable) {
      mid := SboxInv(SboxAesTop(in))
    }
    out := SboxAesOut(mid)
  }

  iaesSboxOut.zip(iaesSboxMid).zip(iaesSboxIn).foreach { case ((out, mid), in) =>
    when (io.regEnable) {
      mid := SboxInv(SboxIaesTop(in))
    }
    out := SboxIaesOut(mid)
  }

  val aes64es = aesSboxOut.asUInt
  val aes64ds = iaesSboxOut.asUInt

  // aes64im works directly on src1, so src1 is registered to line up with the S-box path
  val imMinIn = RegEnable(src1Bytes, io.regEnable)

  // MixFwd/MixInv (defined elsewhere) are applied per 4-byte group: bytes 0..3 form the low
  // 32 bits of the result, bytes 4..7 the high 32 bits
  val aes64esm = Cat(MixFwd(Seq(aesSboxOut(4), aesSboxOut(5), aesSboxOut(6), aesSboxOut(7))),
                     MixFwd(Seq(aesSboxOut(0), aesSboxOut(1), aesSboxOut(2), aesSboxOut(3))))
  val aes64dsm = Cat(MixInv(Seq(iaesSboxOut(4), iaesSboxOut(5), iaesSboxOut(6), iaesSboxOut(7))),
                     MixInv(Seq(iaesSboxOut(0), iaesSboxOut(1), iaesSboxOut(2), iaesSboxOut(3))))
  val aes64im = Cat(MixInv(Seq(imMinIn(4), imMinIn(5), imMinIn(6), imMinIn(7))),
                    MixInv(Seq(imMinIn(0), imMinIn(1), imMinIn(2), imMinIn(3))))


  // Round constants for aes64ks1i, indexed by the registered src2(3,0); entry 0xA is 0.
  // NOTE(review): only 11 entries, so indices 0xB..0xF fall outside the Vec; presumably those
  // encodings are rejected before reaching this unit — confirm in the decoder.
  val rcon = WireInit(VecInit(Seq("h01".U, "h02".U, "h04".U, "h08".U,
                                  "h10".U, "h20".U, "h40".U, "h80".U,
                                  "h1b".U, "h36".U, "h00".U)))

  // aes64ks1i: S-box over the upper 32 bits of src1, rotated right by one byte unless
  // src2(3,0) == 0xA; the register holds the SboxAesTop(...) value
  val ksSboxIn = Wire(Vec(4, UInt(8.W)))
  val ksSboxTop = Reg(Vec(4, Vec(21, Bool())))
  val ksSboxOut = Wire(Vec(4, UInt(8.W)))
  ksSboxIn(0) := Mux(src2(3,0) === "ha".U, src1Bytes(4), src1Bytes(5))
  ksSboxIn(1) := Mux(src2(3,0) === "ha".U, src1Bytes(5), src1Bytes(6))
  ksSboxIn(2) := Mux(src2(3,0) === "ha".U, src1Bytes(6), src1Bytes(7))
  ksSboxIn(3) := Mux(src2(3,0) === "ha".U, src1Bytes(7), src1Bytes(4))
  ksSboxOut.zip(ksSboxTop).zip(ksSboxIn).foreach { case ((out, top), in) =>
    when (io.regEnable) {
      top := SboxAesTop(in)
    }
    out := SboxAesOut(SboxInv(top))
  }

  val ks1Idx = RegEnable(src2(3,0), io.regEnable)
  // The same 32-bit word (S-box output XOR round constant) is placed in both halves
  val aes64ks1i = Cat(ksSboxOut.asUInt ^ rcon(ks1Idx), ksSboxOut.asUInt ^ rcon(ks1Idx))

  // aes64ks2 is pure XOR, computed entirely before the register
  val aes64ks2Temp = src1(63,32) ^ src2(31,0)
  val aes64ks2 = RegEnable(Cat(aes64ks2Temp ^ src2(63,32), aes64ks2Temp), io.regEnable)

  val aesResult = LookupTreeDefault(funcReg, aes64es, List(
    BKUOpType.aes64es -> aes64es,
    BKUOpType.aes64esm -> aes64esm,
    BKUOpType.aes64ds -> aes64ds,
    BKUOpType.aes64dsm -> aes64dsm,
    BKUOpType.aes64im -> aes64im,
    BKUOpType.aes64ks1i -> aes64ks1i,
    BKUOpType.aes64ks2 -> aes64ks2
  ))

  // SM4: one S-box on the byte of src2 selected by func(1,0); the register holds the Top value
  val sm4SboxIn = src2Bytes(func(1,0))
  val sm4SboxTop = Reg(Vec(21, Bool()))
  when (io.regEnable) {
    sm4SboxTop := SboxSm4Top(sm4SboxIn)
  }
  val sm4SboxOut = SboxSm4Out(SboxInv(sm4SboxTop))

  // Linear transforms of the S-box output for sm4ed and sm4ks
  val sm4ed = sm4SboxOut ^ (sm4SboxOut<<8) ^ (sm4SboxOut<<2) ^ (sm4SboxOut<<18) ^ ((sm4SboxOut&"h3f".U)<<26) ^ ((sm4SboxOut&"hc0".U)<<10)
  val sm4ks = sm4SboxOut ^ ((sm4SboxOut&"h07".U)<<29) ^ ((sm4SboxOut&"hfe".U)<<7) ^ ((sm4SboxOut&"h01".U)<<23) ^ ((sm4SboxOut&"hf8".U)<<13)
  // Entry k rotates the 32-bit value left by 8*(k % 4) bits, matching the byte chosen by
  // func(1,0); entries 4..7 (funcReg(2) set) use sm4ks instead of sm4ed
  val sm4Source = VecInit(Seq(
    sm4ed(31,0),
    Cat(sm4ed(23,0), sm4ed(31,24)),
    Cat(sm4ed(15,0), sm4ed(31,16)),
    Cat(sm4ed( 7,0), sm4ed(31,8)),
    sm4ks(31,0),
    Cat(sm4ks(23,0), sm4ks(31,24)),
    Cat(sm4ks(15,0), sm4ks(31,16)),
    Cat(sm4ks( 7,0), sm4ks(31,8))
  ))
  val sm4Result = SignExt((sm4Source(funcReg(2,0)) ^ RegEnable(src1(31,0), io.regEnable))(31,0), XLEN)

  io.out := Mux(funcReg(3), sm4Result, aesResult)
}

/** Combines HashModule and BlockCipherModule; both have one register stage on
  * `io.regEnable`, and the registered func(4) selects the hash result over the cipher result.
  */
class CryptoModule(implicit p: Parameters) extends XSModule {
  val io = IO(new Bundle() {
    val src = Vec(2, Input(UInt(XLEN.W)))
    val func = Input(UInt())
    val regEnable = Input(Bool())
    val out = Output(UInt(XLEN.W))
  })

  val (src1, src2, func) = (io.src(0), io.src(1), io.func)
  val funcReg = RegEnable(func, io.regEnable)

  val hashModule = Module(new HashModule)
  hashModule.io.src := src1
  hashModule.io.func := func
  hashModule.io.regEnable := io.regEnable

  val blockCipherModule = Module(new BlockCipherModule)
  blockCipherModule.io.src(0) := src1
  blockCipherModule.io.src(1) := src2
  blockCipherModule.io.func := func
  blockCipherModule.io.regEnable := io.regEnable

  io.out := Mux(funcReg(4), hashModule.io.out, blockCipherModule.io.out)
}

/** Bit-manipulation / crypto functional unit with a latency of 2.
  *
  * Every sub-module registers once on regEnable(1); the selected sub-module output is
  * registered again on regEnable(2) into io.out.bits.res.data. The result is chosen on the
  * registered fuOpType: bit 5 -> crypto, else bit 3 -> count, else bit 2 -> misc (xperm),
  * otherwise clmul.
  */
class Bku(cfg: FuConfig)(implicit p: Parameters) extends FuncUnit(cfg) with HasPipelineReg {

  override def latency = 2

  val (src1, src2, func) = (
    io.in.bits.data.src(0),
    io.in.bits.data.src(1),
    io.in.bits.ctrl.fuOpType
  )

  val countModule = Module(new CountModule)
  countModule.io.src := src1
  countModule.io.func := func
  countModule.io.regEnable := regEnable(1)

  val clmulModule = Module(new ClmulModule)
  clmulModule.io.src(0) := src1
  clmulModule.io.src(1) := src2
  clmulModule.io.func := func
  clmulModule.io.regEnable := regEnable(1)

  val miscModule = Module(new MiscModule)
  miscModule.io.src(0) := src1
  miscModule.io.src(1) := src2
  miscModule.io.func := func
  miscModule.io.regEnable := regEnable(1)

  val cryptoModule = Module(new CryptoModule)
  cryptoModule.io.src(0) := src1
  cryptoModule.io.src(1) := src2
  cryptoModule.io.func := func
  cryptoModule.io.regEnable := regEnable(1)


  // CountModule, ClmulModule, MiscModule, and CryptoModule have a latency of 1 cycle
  // NOTE(review): funcReg latches on io.in.fire while the sub-modules latch on regEnable(1);
  // this assumes both are the same condition — confirm against HasPipelineReg.
  val funcReg = RegEnable(func, io.in.fire)
  val result = Mux(funcReg(5), cryptoModule.io.out,
                  Mux(funcReg(3), countModule.io.out,
                  Mux(funcReg(2), miscModule.io.out, clmulModule.io.out)))

  io.out.bits.res.data := RegEnable(result, regEnable(2))
  // connectNonPipedCtrlSingal
}