// xref: /XiangShan/src/main/scala/xiangshan/backend/fu/util/CryptoUtils.scala (revision 8891a219bbc84f568e1d134854d8d5ed86d6d560)
/***************************************************************************************
* Copyright (c) 2020-2021 Institute of Computing Technology, Chinese Academy of Sciences
* Copyright (c) 2020-2021 Peng Cheng Laboratory
*
* XiangShan is licensed under Mulan PSL v2.
* You can use this software according to the terms and conditions of the Mulan PSL v2.
* You may obtain a copy of Mulan PSL v2 at:
*          http://license.coscl.org.cn/MulanPSL2
*
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
* EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
* MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
*
* See the Mulan PSL v2 for more details.
***************************************************************************************/
16
17package xiangshan.backend.fu.util
18
19import chisel3._
20import chisel3.util._
21import xiangshan._
22import org.chipsalliance.cde.config.Parameters
23
// 32-bit logical right shift by an elaboration-time constant.
// Only bits(31, 0) of the input are read. The shifted 32-bit word is returned
// zero-extended to a 64-bit UInt: (32 + shamt) zeros, then bits(31, shamt).
// shamt must lie in 1..31.
object SHR32 {
  def apply(bits: UInt, shamt: Int) = {
    require(shamt > 0 && shamt < 32)
    // bits(31, 31) is a legal one-bit extract, so shamt == 31 needs no separate branch.
    val kept = bits(31, shamt)
    val zeroFill = 0.U((32 + shamt).W)
    Cat(zeroFill, kept)
  }
}
32
// 64-bit logical right shift by an elaboration-time constant.
// Zeros are shifted in at the top. The result is a 64-bit UInt: shamt zeros
// followed by bits(63, shamt). shamt must lie in 1..63.
object SHR64 {
  def apply(bits: UInt, shamt: Int): UInt = {
    require(shamt > 0 && shamt < 64)
    // No special case for shamt == 63: bits(63, 63) is a legal one-bit extract.
    // The result must be zero-filled; wrapping the low bits around (as a rotate
    // does) would be wrong for a shift.
    Cat(0.U(shamt.W), bits(63, shamt))
  }
}
41
// 32-bit rotate right by an elaboration-time constant.
// Only bits(31, 0) of the input are read. The rotated 32-bit word is returned
// zero-extended to a 64-bit UInt. shamt must lie in 1..31.
object ROR32 {
  def apply(bits: UInt, shamt: Int) = {
    require(shamt > 0 && shamt < 32)
    // The low `shamt` bits wrap around to the top of the 32-bit word. Both slices
    // are non-empty for every legal shamt (bits(0, 0) and bits(31, 31) are valid
    // one-bit extracts), so shamt == 1 and shamt == 31 need no separate branches.
    val wrapped = bits(shamt - 1, 0)
    val shifted = bits(31, shamt)
    Cat(0.U(32.W), wrapped, shifted)
  }
}
51
// 64-bit rotate right by an elaboration-time constant; the result is 64 bits wide.
// shamt must lie in 1..63.
object ROR64 {
  def apply(bits: UInt, shamt: Int) = {
    require(shamt > 0 && shamt < 64)
    // The low `shamt` bits wrap around to the top. Both slices are non-empty for
    // every legal shamt (bits(0, 0) and bits(63, 63) are valid one-bit extracts),
    // so shamt == 1 and shamt == 63 need no separate branches.
    val wrapped = bits(shamt - 1, 0)
    val remaining = bits(63, shamt)
    Cat(wrapped, remaining)
  }
}
61
// AES forward shift rows: a fixed byte gather from two 8-byte sources.
// Output element k is taken from src1/src2 at the indices listed below; indices
// 0..7 of each source are read. Returns an 8-element Vec (which is a Seq[UInt]).
object ForwardShiftRows {
  def apply(src1: Seq[UInt], src2: Seq[UInt]): Seq[UInt] = {
    val bytes0to3 = Seq(src1(0), src1(5), src2(2), src2(7))
    val bytes4to7 = Seq(src1(4), src2(1), src2(6), src1(3))
    VecInit(bytes0to3 ++ bytes4to7)
  }
}
69
// AES inverse shift rows: a fixed byte gather from two 8-byte sources.
// Output element k is taken from src1/src2 at the indices listed below; indices
// 0..7 of each source are read. Returns an 8-element Vec (which is a Seq[UInt]).
object InverseShiftRows {
  def apply(src1: Seq[UInt], src2: Seq[UInt]): Seq[UInt] = {
    val bytes0to3 = Seq(src1(0), src2(5), src2(2), src1(7))
    val bytes4to7 = Seq(src1(4), src1(1), src2(6), src2(3))
    VecInit(bytes0to3 ++ bytes4to7)
  }
}
77
// AES encode sbox top: the input stage of the forward AES S-box (SboxAes feeds its
// output into SboxInv and then SboxAesOut).
//
// i: input byte; bits 7..0 are read.
// returns: 21 Bool signals for SboxInv.
//
// Pure XOR network: every output is an XOR combination of input bits, and o(0)
// passes i(0) through unchanged.
// NOTE(review): this looks like a ported, pre-minimised netlist; re-verify against a
// reference S-box table after any hand edit.
object SboxAesTop {
  def apply(i: UInt): Seq[Bool] = {
    // XOR sub-terms shared by several outputs.
    val t = Wire(Vec(6, Bool()))
    // The 21 signals handed to SboxInv.
    val o = Wire(Vec(21, Bool()))
    t( 0) := i( 3) ^ i( 1)
    t( 1) := i( 6) ^ i( 5)
    t( 2) := i( 6) ^ i( 2)
    t( 3) := i( 5) ^ i( 2)
    t( 4) := i( 4) ^ i( 0)
    t( 5) := i( 1) ^ i( 0)
    o( 0) := i( 0)
    o( 1) := i( 7) ^ i( 4)
    o( 2) := i( 7) ^ i( 2)
    o( 3) := i( 7) ^ i( 1)
    o( 4) := i( 4) ^ i( 2)
    o( 5) := o( 1) ^ t( 0)
    o( 6) := i( 0) ^ o( 5)
    o( 7) := i( 0) ^ t( 1)
    o( 8) := o( 5) ^ t( 1)
    o( 9) := o( 3) ^ o( 4)
    o(10) := o( 5) ^ t( 2)
    o(11) := t( 0) ^ t( 2)
    o(12) := t( 0) ^ t( 3)
    o(13) := o( 7) ^ o(12)
    o(14) := t( 1) ^ t( 4)
    o(15) := o( 1) ^ o(14)
    o(16) := t( 1) ^ t( 5)
    o(17) := o( 2) ^ o(16)
    o(18) := o( 2) ^ o( 8)
    o(19) := o(15) ^ o(13)
    o(20) := o( 1) ^ t( 3)
    o
  }
}
113
// AES decode sbox top: the input stage of the inverse AES S-box (SboxIaes feeds its
// output into SboxInv and then SboxIaesOut).
//
// i: input byte; bits 7..0 are read.
// returns: 21 Bool signals for SboxInv.
//
// XOR/XNOR network; the `^ ~` terms fold constant inversions into individual gates.
// Some outputs read o(...) elements that are connected further down (e.g. o(4) reads
// o(6), o(5) reads o(16)). This is legal because o and t are Wires whose elements
// are each driven exactly once, so source order does not matter.
// NOTE(review): this looks like a ported, pre-minimised netlist; re-verify against a
// reference inverse S-box table after any hand edit.
object SboxIaesTop {
  def apply(i: UInt): Seq[Bool] = {
    // XOR/XNOR sub-terms shared by several outputs.
    val t = Wire(Vec(5, Bool()))
    // The 21 signals handed to SboxInv.
    val o = Wire(Vec(21, Bool()))
    t( 0) := i( 1) ^  i( 0)
    t( 1) := i( 6) ^  i( 1)
    t( 2) := i( 5) ^ ~i( 2)
    t( 3) := i( 2) ^ ~i( 1)
    t( 4) := i( 5) ^ ~i( 3)
    o( 0) := i( 7) ^  t( 2)
    o( 1) := i( 4) ^  i( 3)
    o( 2) := i( 7) ^ ~i( 6)
    o( 3) := o( 1) ^  t( 0)
    o( 4) := i( 3) ^  o( 6)
    o( 5) := o(16) ^  t( 2)
    o( 6) := i( 6) ^ ~o(17)
    o( 7) := i( 0) ^ ~o( 1)
    o( 8) := o( 2) ^  o(18)
    o( 9) := o( 2) ^  t( 0)
    o(10) := o( 8) ^  t( 3)
    o(11) := o( 8) ^  o(20)
    o(12) := t( 1) ^  t( 4)
    o(13) := i( 5) ^ ~o(14)
    o(14) := o(16) ^  t( 0)
    o(15) := o(18) ^  t( 1)
    o(16) := i( 6) ^ ~i( 4)
    o(17) := i( 7) ^  i( 4)
    o(18) := i( 3) ^ ~i( 0)
    o(19) := i( 5) ^ ~o( 1)
    o(20) := o( 1) ^  t( 3)
    o
  }
}
148
// SM4 encode/decode sbox top: the input stage of the SM4 S-box (SboxSm4 feeds its
// output into SboxInv and then SboxSm4Out).
//
// i: input byte; bits 7..0 are read.
// returns: 21 Bool signals for SboxInv.
//
// XOR/XNOR network; the `^ ~` terms fold constant inversions into individual gates.
// o(9) passes i(3) through unchanged. Several signals read o(...) elements that are
// connected further down (t(2)/t(5) read o(18), o(0)/o(7) read o(10), o(3) reads
// o(4)). This is legal because o and t are Wires whose elements are each driven
// exactly once, so source order does not matter.
// NOTE(review): this looks like a ported, pre-minimised netlist; re-verify against a
// reference SM4 S-box table after any hand edit.
object SboxSm4Top {
  def apply(i: UInt): Seq[Bool] = {
    // XOR sub-terms shared by several outputs.
    val t = Wire(Vec(7, Bool()))
    // The 21 signals handed to SboxInv.
    val o = Wire(Vec(21, Bool()))
    t( 0) := i(3) ^  i( 4)
    t( 1) := i(2) ^  i( 7)
    t( 2) := i(7) ^  o(18)
    t( 3) := i(1) ^  t( 1)
    t( 4) := i(6) ^  i( 7)
    t( 5) := i(0) ^  o(18)
    t( 6) := i(3) ^  i( 6)
    o( 0) := i(5) ^ ~o(10)
    o( 1) := t(0) ^  t( 3)
    o( 2) := i(0) ^  t( 0)
    o( 3) := i(3) ^  o( 4)
    o( 4) := i(0) ^  t( 3)
    o( 5) := i(5) ^  t( 5)
    o( 6) := i(0) ^ ~i( 1)
    o( 7) := t(0) ^ ~o(10)
    o( 8) := t(0) ^  t( 5)
    o( 9) := i(3)
    o(10) := i(1) ^  o(18)
    o(11) := t(0) ^  t( 4)
    o(12) := i(5) ^  t( 4)
    o(13) := i(5) ^ ~o( 1)
    o(14) := i(4) ^ ~t( 2)
    o(15) := i(1) ^ ~t( 6)
    o(16) := i(0) ^ ~t( 2)
    o(17) := t(0) ^ ~t( 2)
    o(18) := i(2) ^  i( 6)
    o(19) := i(5) ^ ~o(14)
    o(20) := i(0) ^  t( 1)
    o
  }
}
185
// Sbox middle part for AES, AES^-1, SM4: SboxAes, SboxIaes and SboxSm4 all route
// their *Top output through this shared stage before their own *Out stage.
//
// i: the 21 signals produced by SboxAesTop / SboxIaesTop / SboxSm4Top (indices 0..20).
// returns: 18 Bool signals for the matching *Out stage.
//
// Fixed network of AND and XOR gates only (no inversions). Each t(k) is built only
// from inputs and lower-numbered temporaries, so the list reads top-down as a
// straight-line circuit.
// NOTE(review): this looks like a ported, pre-minimised nonlinear S-box core
// (possibly Boyar-Peralta style). Treat it as generated data and re-verify against
// reference S-box tables after any hand edit.
object SboxInv {
  def apply(i: Seq[Bool]): Seq[Bool] = {
    val t = Wire(Vec(46, Bool()))
    val o = Wire(Vec(18, Bool()))
    t( 0) := i( 3) ^ i(12)
    t( 1) := i( 9) & i( 5)
    t( 2) := i(17) & i( 6)
    t( 3) := i(10) ^ t( 1)
    t( 4) := i(14) & i( 0)
    t( 5) := t( 4) ^ t( 1)
    t( 6) := i( 3) & i(12)
    t( 7) := i(16) & i( 7)
    t( 8) := t( 0) ^ t( 6)
    t( 9) := i(15) & i(13)
    t(10) := t( 9) ^ t( 6)
    t(11) := i( 1) & i(11)
    t(12) := i( 4) & i(20)
    t(13) := t(12) ^ t(11)
    t(14) := i( 2) & i( 8)
    t(15) := t(14) ^ t(11)
    t(16) := t( 3) ^ t( 2)
    t(17) := t( 5) ^ i(18)
    t(18) := t( 8) ^ t( 7)
    t(19) := t(10) ^ t(15)
    t(20) := t(16) ^ t(13)
    t(21) := t(17) ^ t(15)
    t(22) := t(18) ^ t(13)
    t(23) := t(19) ^ i(19)
    t(24) := t(22) ^ t(23)
    t(25) := t(22) & t(20)
    t(26) := t(21) ^ t(25)
    t(27) := t(20) ^ t(21)
    t(28) := t(23) ^ t(25)
    t(29) := t(28) & t(27)
    t(30) := t(26) & t(24)
    t(31) := t(20) & t(23)
    t(32) := t(27) & t(31)
    t(33) := t(27) ^ t(25)
    t(34) := t(21) & t(22)
    t(35) := t(24) & t(34)
    t(36) := t(24) ^ t(25)
    t(37) := t(21) ^ t(29)
    t(38) := t(32) ^ t(33)
    t(39) := t(23) ^ t(30)
    t(40) := t(35) ^ t(36)
    t(41) := t(38) ^ t(40)
    t(42) := t(37) ^ t(39)
    t(43) := t(37) ^ t(38)
    t(44) := t(39) ^ t(40)
    t(45) := t(42) ^ t(41)
    // Each of the nine values t(37)..t(45) feeds exactly two outputs, ANDed with a
    // different input signal each time.
    o( 0) := t(38) & i( 7)
    o( 1) := t(37) & i(13)
    o( 2) := t(42) & i(11)
    o( 3) := t(45) & i(20)
    o( 4) := t(41) & i( 8)
    o( 5) := t(44) & i( 9)
    o( 6) := t(40) & i(17)
    o( 7) := t(39) & i(14)
    o( 8) := t(43) & i( 3)
    o( 9) := t(38) & i(16)
    o(10) := t(37) & i(15)
    o(11) := t(42) & i( 1)
    o(12) := t(45) & i( 4)
    o(13) := t(41) & i( 2)
    o(14) := t(44) & i( 5)
    o(15) := t(40) & i( 6)
    o(16) := t(39) & i( 0)
    o(17) := t(43) & i(12)
    o
  }
}
258
// AES encode sbox out: the output stage of the forward AES S-box (consumes SboxInv's
// result inside SboxAes).
//
// i: the 18 signals produced by SboxInv (indices 0..17).
// returns: 8-bit UInt; o(0) becomes bit 0 (Vec.asUInt puts element 0 in the LSB).
//
// XOR network. The XNOR (`^ ~`) on outputs 0, 1, 5 and 6 is equivalent to XOR-ing
// the byte with the constant 0x63 (AES's S-box affine constant).
// NOTE(review): this looks like a ported, pre-minimised netlist; re-verify against a
// reference S-box table after any hand edit.
object SboxAesOut {
  def apply(i: Seq[Bool]): UInt = {
    val t = Wire(Vec(30, Bool()))
    val o = Wire(Vec(8, Bool()))
    t( 0) := i(11) ^  i(12)
    t( 1) := i( 0) ^  i( 6)
    t( 2) := i(14) ^  i(16)
    t( 3) := i(15) ^  i( 5)
    t( 4) := i( 4) ^  i( 8)
    t( 5) := i(17) ^  i(11)
    t( 6) := i(12) ^  t( 5)
    t( 7) := i(14) ^  t( 3)
    t( 8) := i( 1) ^  i( 9)
    t( 9) := i( 2) ^  i( 3)
    t(10) := i( 3) ^  t( 4)
    t(11) := i(10) ^  t( 2)
    t(12) := i(16) ^  i( 1)
    t(13) := i( 0) ^  t( 0)
    t(14) := i( 2) ^  i(11)
    t(15) := i( 5) ^  t( 1)
    t(16) := i( 6) ^  t( 0)
    t(17) := i( 7) ^  t( 1)
    t(18) := i( 8) ^  t( 8)
    t(19) := i(13) ^  t( 4)
    t(20) := t( 0) ^  t( 1)
    t(21) := t( 1) ^  t( 7)
    t(22) := t( 3) ^  t(12)
    t(23) := t(18) ^  t( 2)
    t(24) := t(15) ^  t( 9)
    t(25) := t( 6) ^  t(10)
    t(26) := t( 7) ^  t( 9)
    t(27) := t( 8) ^  t(10)
    t(28) := t(11) ^  t(14)
    t(29) := t(11) ^  t(17)
    o( 0) := t( 6) ^ ~t(23)
    o( 1) := t(13) ^ ~t(27)
    o( 2) := t(25) ^  t(29)
    o( 3) := t(20) ^  t(22)
    o( 4) := t( 6) ^  t(21)
    o( 5) := t(19) ^ ~t(28)
    o( 6) := t(16) ^ ~t(26)
    o( 7) := t( 6) ^  t(24)
    o.asUInt
  }
}
305
// AES decode sbox out: the output stage of the inverse AES S-box (consumes SboxInv's
// result inside SboxIaes).
//
// i: the 18 signals produced by SboxInv (indices 0..17).
// returns: 8-bit UInt; o(0) becomes bit 0 (Vec.asUInt puts element 0 in the LSB).
//
// Pure XOR network (no XNORs, unlike the forward-AES and SM4 output stages).
// NOTE(review): this looks like a ported, pre-minimised netlist; re-verify against a
// reference inverse S-box table after any hand edit.
object SboxIaesOut {
  def apply(i: Seq[Bool]): UInt = {
    val t = Wire(Vec(30, Bool()))
    val o = Wire(Vec(8, Bool()))
    t( 0) := i( 2) ^ i(11)
    t( 1) := i( 8) ^ i( 9)
    t( 2) := i( 4) ^ i(12)
    t( 3) := i(15) ^ i( 0)
    t( 4) := i(16) ^ i( 6)
    t( 5) := i(14) ^ i( 1)
    t( 6) := i(17) ^ i(10)
    t( 7) := t( 0) ^ t( 1)
    t( 8) := i( 0) ^ i( 3)
    t( 9) := i( 5) ^ i(13)
    t(10) := i( 7) ^ t( 4)
    t(11) := t( 0) ^ t( 3)
    t(12) := i(14) ^ i(16)
    t(13) := i(17) ^ i( 1)
    t(14) := i(17) ^ i(12)
    t(15) := i( 4) ^ i( 9)
    t(16) := i( 7) ^ i(11)
    t(17) := i( 8) ^ t( 2)
    t(18) := i(13) ^ t( 5)
    t(19) := t( 2) ^ t( 3)
    t(20) := t( 4) ^ t( 6)
    // t(21) is never read below; it is tied off only so every Vec element is driven
    // and the t indices stay contiguous.
    // NOTE(review): false.B would be the more idiomatic literal for a Bool slot.
    t(21) := 0.U
    t(22) := t( 2) ^ t( 7)
    t(23) := t( 7) ^ t( 8)
    t(24) := t( 5) ^ t( 7)
    t(25) := t( 6) ^ t(10)
    t(26) := t( 9) ^ t(11)
    t(27) := t(10) ^ t(18)
    t(28) := t(11) ^ t(25)
    t(29) := t(15) ^ t(20)
    o( 0) := t( 9) ^ t(16)
    o( 1) := t(14) ^ t(23)
    o( 2) := t(19) ^ t(24)
    o( 3) := t(23) ^ t(27)
    o( 4) := t(12) ^ t(22)
    o( 5) := t(17) ^ t(28)
    o( 6) := t(26) ^ t(29)
    o( 7) := t(13) ^ t(22)
    o.asUInt
  }
}
352
// SM4 encode/decode sbox out: the output stage of the SM4 S-box (consumes SboxInv's
// result inside SboxSm4).
//
// i: the 18 signals produced by SboxInv (indices 0..17).
// returns: 8-bit UInt; o(0) becomes bit 0 (Vec.asUInt puts element 0 in the LSB).
//
// XOR network. The XNOR (`^ ~`) on outputs 0, 1, 4, 6 and 7 is equivalent to XOR-ing
// the byte with the constant 0xD3 (presumably SM4's affine output constant; confirm
// against the SM4 specification).
// NOTE(review): this looks like a ported, pre-minimised netlist; re-verify against a
// reference SM4 S-box table after any hand edit.
object SboxSm4Out {
  def apply(i: Seq[Bool]): UInt = {
    val t = Wire(Vec(30, Bool()))
    val o = Wire(Vec(8, Bool()))
    t( 0) := i( 4) ^  i( 7)
    t( 1) := i(13) ^  i(15)
    t( 2) := i( 2) ^  i(16)
    t( 3) := i( 6) ^  t( 0)
    t( 4) := i(12) ^  t( 1)
    t( 5) := i( 9) ^  i(10)
    t( 6) := i(11) ^  t( 2)
    t( 7) := i( 1) ^  t( 4)
    t( 8) := i( 0) ^  i(17)
    t( 9) := i( 3) ^  i(17)
    t(10) := i( 8) ^  t( 3)
    t(11) := t( 2) ^  t( 5)
    t(12) := i(14) ^  t( 6)
    t(13) := t( 7) ^  t( 9)
    t(14) := i( 0) ^  i( 6)
    t(15) := i( 7) ^  i(16)
    t(16) := i( 5) ^  i(13)
    t(17) := i( 3) ^  i(15)
    t(18) := i(10) ^  i(12)
    t(19) := i( 9) ^  t( 1)
    t(20) := i( 4) ^  t( 4)
    t(21) := i(14) ^  t( 3)
    t(22) := i(16) ^  t( 5)
    t(23) := t( 7) ^  t(14)
    t(24) := t( 8) ^  t(11)
    t(25) := t( 0) ^  t(12)
    t(26) := t(17) ^  t( 3)
    t(27) := t(18) ^  t(10)
    t(28) := t(19) ^  t( 6)
    t(29) := t( 8) ^  t(10)
    o( 0) := t(11) ^ ~t(13)
    o( 1) := t(15) ^ ~t(23)
    o( 2) := t(20) ^  t(24)
    o( 3) := t(16) ^  t(25)
    o( 4) := t(26) ^ ~t(22)
    o( 5) := t(21) ^  t(13)
    o( 6) := t(27) ^ ~t(12)
    o( 7) := t(28) ^ ~t(29)
    o.asUInt
  }
}
399
// Forward AES S-box for one byte: input stage, shared middle stage, output stage.
object SboxAes {
  def apply(byte: UInt): UInt = {
    val fromTop = SboxAesTop(byte)
    val fromMiddle = SboxInv(fromTop)
    SboxAesOut(fromMiddle)
  }
}
405
// Inverse AES S-box for one byte: input stage, shared middle stage, output stage.
object SboxIaes {
  def apply(byte: UInt): UInt = {
    val fromTop = SboxIaesTop(byte)
    val fromMiddle = SboxInv(fromTop)
    SboxIaesOut(fromMiddle)
  }
}
411
// SM4 S-box for one byte (used for both encode and decode): input stage, shared
// middle stage, output stage.
object SboxSm4 {
  def apply(byte: UInt): UInt = {
    val fromTop = SboxSm4Top(byte)
    val fromMiddle = SboxInv(fromTop)
    SboxSm4Out(fromMiddle)
  }
}
417
// Mix Column helpers: GF(2^8) multiplication using the AES reduction
// x^8 = x^4 + x^3 + x + 1 (the 0x1b folded back in below).
object XtN {
  // Multiply the low byte of `byte` by 2 in GF(2^8): shift left one place and, when
  // bit 7 is shifted out, reduce by XOR-ing 0x1b. The result is 8 bits wide.
  def Xt2(byte: UInt): UInt = {
    val shifted = byte << 1
    val reduction = Mux(byte(7), "h1b".U, 0.U)
    (shifted ^ reduction)(7, 0)
  }

  // Multiply `byte` by the 4-bit constant t in GF(2^8): XOR together byte*1, byte*2,
  // byte*4 and byte*8, each enabled by bit 0, 1, 2 or 3 of t respectively.
  // Only t(3, 0) is read; the result is 8 bits wide.
  def apply(byte: UInt, t: UInt): UInt = {
    // byte, 2*byte, 4*byte, 8*byte
    val multiples = Seq.iterate(byte, 4)(Xt2)
    val enabled = multiples.zipWithIndex.map { case (m, k) => Mux(t(k), m, 0.U) }
    val sum = enabled.reduce(_ ^ _)
    sum(7, 0)
  }
}
430
// One output byte of AES MixColumns over GF(2^8):
//   2*bytes(0) ^ 3*bytes(1) ^ bytes(2) ^ bytes(3)
// Only the first four elements of `bytes` are read.
object ByteEnc {
  def apply(bytes: Seq[UInt]): UInt = {
    val timesTwo = XtN(bytes(0), "h2".U)
    val timesThree = XtN(bytes(1), "h3".U)
    timesTwo ^ timesThree ^ bytes(2) ^ bytes(3)
  }
}
436
// One output byte of AES inverse MixColumns over GF(2^8):
//   0x0e*bytes(0) ^ 0x0b*bytes(1) ^ 0x0d*bytes(2) ^ 0x09*bytes(3)
// Only the first four elements of `bytes` are read.
object ByteDec {
  def apply(bytes: Seq[UInt]): UInt = {
    val coefficients = Seq("he".U, "hb".U, "hd".U, "h9".U)
    val products = coefficients.indices.map(k => XtN(bytes(k), coefficients(k)))
    products.reduce(_ ^ _)
  }
}
442
// AES forward MixColumns on one four-byte column.
// Output byte k is ByteEnc of the column rotated so that bytes(k) comes first.
// Cat puts its first element in the most-significant position, so output byte 3
// occupies the top of the result and output byte 0 the bottom (8 bits each,
// assuming 8-bit input bytes).
object MixFwd {
  def apply(bytes: Seq[UInt]): UInt = {
    val outBytes = (3 to 0 by -1).map { k =>
      ByteEnc(Seq.tabulate(4)(j => bytes((j + k) % 4)))
    }
    Cat(outBytes)
  }
}
451
// AES inverse MixColumns on one four-byte column.
// Output byte k is ByteDec of the column rotated so that bytes(k) comes first.
// Cat puts its first element in the most-significant position, so output byte 3
// occupies the top of the result and output byte 0 the bottom.
object MixInv {
  def apply(bytes: Seq[UInt]): UInt = {
    val outBytes = (3 to 0 by -1).map { k =>
      ByteDec(Seq.tabulate(4)(j => bytes((j + k) % 4)))
    }
    Cat(outBytes)
  }
}
460