xref: /XiangShan/src/main/scala/xiangshan/backend/fu/Multiplier.scala (revision c49ebec88f6e402aefec681225e3537e2c511430)
1/***************************************************************************************
2* Copyright (c) 2020-2021 Institute of Computing Technology, Chinese Academy of Sciences
3* Copyright (c) 2020-2021 Peng Cheng Laboratory
4*
5* XiangShan is licensed under Mulan PSL v2.
6* You can use this software according to the terms and conditions of the Mulan PSL v2.
7* You may obtain a copy of Mulan PSL v2 at:
8*          http://license.coscl.org.cn/MulanPSL2
9*
10* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
11* EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
12* MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
13*
14* See the Mulan PSL v2 for more details.
15*
16*
17* Acknowledgement
18*
19* This implementation is inspired by several key papers:
20* [1] Andrew D. Booth. "[A signed binary multiplication technique.](https://doi.org/10.1093/qjmam/4.2.236)" The
21* Quarterly Journal of Mechanics and Applied Mathematics 4.2: 236-240. 1951.
22* [2] Christopher. S. Wallace. "[A suggestion for a fast multiplier.](https://doi.org/10.1109/PGEC.1964.263830)" IEEE
23* Transactions on Electronic Computers 1: 14-17. 1964.
24***************************************************************************************/
25
26package xiangshan.backend.fu
27
28import org.chipsalliance.cde.config.Parameters
29import chisel3._
30import chisel3.util._
31import utility._
32import utils._
33import xiangshan._
34import xiangshan.backend.fu.util.{C22, C32, C53}
35
/** Control bundle for multiply/divide operations.
  *
  * Field declaration order determines the packed (`asUInt`) layout, so keep it stable.
  */
class MulDivCtrl extends Bundle {
  // NOTE(review): presumably selects signed operand interpretation — confirm at the producer.
  val sign: Bool = Bool()
  // NOTE(review): presumably flags the 32-bit ("W") instruction variants — confirm.
  val isW: Bool = Bool()
  // Presumably requests the high bits of the result (original note: "return hi bits of result ?").
  val isHi: Bool = Bool()
}
41
/** Two-stage pipelined array multiplier: radix-4 Booth partial products reduced by a
  * Wallace-style carry-save tree (C22 / C32 / C53 compressors).
  *
  * `io.a` and `io.b` are treated as `len`-bit two's-complement values (`b` is sign-extended
  * with `SignExt`, and `a` is Booth-recoded as a signed number). `io.result` is the low
  * `2 * len` bits of the product, formed as `sum + carry` of the compressed tree.
  *
  * Register placement:
  *  - every partial-product bit is registered with `io.regEnables(0)`;
  *  - the outputs of the compression layer at recursion depth 4 are registered with
  *    `io.regEnables(1)`. If the tree finishes in fewer layers, this stage is not inserted.
  *
  * @param len operand width in bits; must be at least 2
  */
class ArrayMulDataModule(len: Int) extends Module {
  require(len >= 2, s"ArrayMulDataModule: len must be >= 2, got $len")

  val io = IO(new Bundle() {
    val a, b = Input(UInt(len.W))
    val regEnables = Input(Vec(2, Bool()))
    val result = Output(UInt((2 * len).W))
  })
  val (a, b) = (io.a, io.b)

  // Booth multiples of b, all len+1 bits wide (2b needs one extra bit).
  // The negative multiples are one's complements: ~b == -b - 1 and (~b) << 1 == -2b - 2.
  // The missing +1 / +2 is added separately (see boothCorrection).
  val b_sext, bx2, neg_b, neg_bx2 = Wire(UInt((len+1).W))
  b_sext := SignExt(b, len+1)
  bx2 := b_sext << 1
  neg_b := (~b_sext).asUInt
  neg_bx2 := neg_b << 1

  /** Correction that completes the two's-complement negation chosen by Booth group `x`:
    * +2 for group 100 (-2b encoded as (~b) << 1), +1 for groups 101 / 110 (-b encoded as ~b),
    * and 0 otherwise.
    */
  def boothCorrection(x: UInt): UInt = MuxLookup(x, 0.U(2.W))(Seq(
    4.U -> 2.U(2.W),
    5.U -> 1.U(2.W),
    6.U -> 1.U(2.W)
  ))

  // columns(j) collects every bit of weight 2^j that still has to be summed.
  val columns: Array[Seq[Bool]] = Array.fill(2*len)(Seq())

  // Booth group of the previous iteration. Its correction rides in the low two bits of the
  // next partial product.
  var last_x = WireInit(0.U(3.W))
  for (i <- Range(0, len, 2)) {
    // Overlapping Booth group {a(i+1), a(i), a(i-1)}, with a(-1) = 0. When the group would
    // run past the MSB (odd len), the sign bit a(len-1) is replicated.
    val x = if(i==0) Cat(a(1,0), 0.U(1.W)) else if(i+1==len) SignExt(a(i, i-1), 3) else a(i+1, i-1)
    val pp_temp = MuxLookup(x, 0.U)(Seq(
      1.U -> b_sext,
      2.U -> b_sext,
      3.U -> bx2,
      4.U -> neg_bx2,
      5.U -> neg_b,
      6.U -> neg_b
    ))
    val s = pp_temp(len)
    val t = boothCorrection(last_x)
    last_x = x
    // Sign extension is avoided with constant-bit encodings: first row {~s, s, s},
    // middle rows {1, ~s}, last row {~s}. Together these constants sum to 0 mod 2^(2*len).
    // Rows after the first also carry the previous row's correction `t` in their two
    // lowest bits, which is why they start at weight i-2.
    val (pp, weight) = i match {
      case 0 =>
        (Cat(~s, s, s, pp_temp), 0)
      case n if (n==len-1) || (n==len-2) =>
        (Cat(~s, pp_temp, t), i-2)
      case _ =>
        (Cat(1.U(1.W), ~s, pp_temp, t), i-2)
    }
    // Bits that would land at or above weight 2*len are dropped (the result is mod 2^(2*len)).
    for (j <- weight until math.min(weight + pp.getWidth, columns.length)) {
      columns(j) = columns(j) :+ pp(j-weight)
    }
  }

  // The loop hands each group's correction to the *next* partial product, so the most
  // significant group's correction has no carrier and must be added here. Without it, the
  // product is wrong whenever the top Booth digit is -1 or -2. Example: for odd len with
  // a(len-1) = 1 and a(len-2) = 0, the result is short by 2^(len-1) * 1.
  val msbWeight = Range(0, len, 2).last
  val msbCorrection = boothCorrection(last_x)
  columns(msbWeight) = columns(msbWeight) :+ msbCorrection(0)
  // For odd len the top group is sign-extended, so x(2) == x(1). Group 100 (the only one
  // needing +2) therefore cannot occur there, and the upper correction bit is always zero.
  if (len % 2 == 0) {
    columns(msbWeight + 1) = columns(msbWeight + 1) :+ msbCorrection(1)
  }

  /** Reduces one column with a single compressor (or passes it through).
    *
    * @param col bits of this column
    * @param cin carries arriving from the previous column's C53 compressors (`cout1`);
    *            these are consumed in this layer
    * @return (bits that stay in this column, carries consumed by the next column in this
    *         layer (`cout1`), carries placed in the next column of the next layer (`cout2`))
    */
  def addOneColumn(col: Seq[Bool], cin: Seq[Bool]): (Seq[Bool], Seq[Bool], Seq[Bool]) = {
    var sum = Seq[Bool]()
    var cout1 = Seq[Bool]()
    var cout2 = Seq[Bool]()
    col.size match {
      case 1 =>  // do nothing
        sum = col ++ cin
      case 2 =>
        val c22 = Module(new C22)
        c22.io.in := col
        sum = c22.io.out(0).asBool +: cin
        cout2 = Seq(c22.io.out(1).asBool)
      case 3 =>
        val c32 = Module(new C32)
        c32.io.in := col
        sum = c32.io.out(0).asBool +: cin
        cout2 = Seq(c32.io.out(1).asBool)
      case 4 =>
        val c53 = Module(new C53)
        for((x, y) <- c53.io.in.take(4) zip col){
          x := y
        }
        // The fifth C53 input absorbs one incoming carry (or 0 if none).
        c53.io.in.last := (if(cin.nonEmpty) cin.head else 0.U)
        sum = Seq(c53.io.out(0).asBool) ++ (if(cin.nonEmpty) cin.drop(1) else Nil)
        cout1 = Seq(c53.io.out(1).asBool)
        cout2 = Seq(c53.io.out(2).asBool)
      case n =>
        // Taller columns: reduce the first four bits and the remainder independently.
        val cin_1 = if(cin.nonEmpty) Seq(cin.head) else Nil
        val cin_2 = if(cin.nonEmpty) cin.drop(1) else Nil
        val (s_1, c_1_1, c_1_2) = addOneColumn(col take 4, cin_1)
        val (s_2, c_2_1, c_2_2) = addOneColumn(col drop 4, cin_2)
        sum = s_1 ++ s_2
        cout1 = c_1_1 ++ c_2_1
        cout2 = c_1_2 ++ c_2_2
    }
    (sum, cout1, cout2)
  }

  /** Largest element of `in`. Throws on empty input, like the standard `max`. */
  def max(in: Iterable[Int]): Int = in.max

  /** Recursively compresses `cols` until every column holds at most two bits, then returns
    * the two `cols.size`-bit rows (sum, carry) whose addition gives the product.
    *
    * @param cols  per-weight bit columns, index 0 = least significant
    * @param depth current compression layer. The layer computed at depth 4 is registered
    *              with `io.regEnables(1)`.
    */
  def addAll(cols: Seq[Seq[Bool]], depth: Int): (UInt, UInt) = {
    if(max(cols.map(_.size)) <= 2){
      val sum = Cat(cols.map(_(0)).reverse)
      // A column with a single remaining bit contributes a zero at that weight of the carry
      // row. This also covers the low columns that never received a carry.
      val carry = Cat(cols.map(col => if (col.size > 1) col(1) else false.B).reverse)
      (sum, carry)
    } else {
      val columns_next = Array.fill(2*len)(Seq[Bool]())
      var cout1, cout2 = Seq[Bool]()
      for( i <- cols.indices){
        val (s, c1, c2) = addOneColumn(cols(i), cout1)
        columns_next(i) = s ++ cout2
        cout1 = c1
        cout2 = c2
      }

      val needReg = depth == 4
      val toNextLayer = if(needReg)
        columns_next.map(_.map(x => RegEnable(x, io.regEnables(1))))
      else
        columns_next

      addAll(toNextLayer.toSeq, depth+1)
    }
  }

  // First pipeline stage: register every partial-product bit.
  val columns_reg = columns.map(col => col.map(b => RegEnable(b, io.regEnables(0))))
  val (sum, carry) = addAll(cols = columns_reg.toSeq, depth = 0)

  io.result := sum + carry
}
162