xref: /XiangShan/src/main/scala/xiangshan/transforms/Helpers.scala (revision e3da8bad334fc71ba0d72f0607e2e93245ddaece)
1/***************************************************************************************
2* Copyright (c) 2024 Beijing Institute of Open Source Chip (BOSC)
3* Copyright (c) 2020-2024 Institute of Computing Technology, Chinese Academy of Sciences
4* Copyright (c) 2020-2021 Peng Cheng Laboratory
5*
6* XiangShan is licensed under Mulan PSL v2.
7* You can use this software according to the terms and conditions of the Mulan PSL v2.
8* You may obtain a copy of Mulan PSL v2 at:
9*          http://license.coscl.org.cn/MulanPSL2
10*
11* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
12* EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
13* MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
14*
15* See the Mulan PSL v2 for more details.
16***************************************************************************************/
17
18package xiangshan.transforms
19
/** Extension methods (via implicit classes) for rewriting and visiting FIRRTL IR nodes. */
object Helpers {

  /** Adds a module-level map operation to [[firrtl.ir.Circuit]]. */
  implicit class CircuitHelper(circuit: firrtl.ir.Circuit) {

    /** Builds a copy of the circuit in which every module has been passed through `f`.
      *
      * @param f transformation applied to each entry of `circuit.modules`
      * @return the circuit with `modules` replaced by the transformed modules; all other
      *         fields are carried over by `copy`
      */
    def mapModule(f: firrtl.ir.DefModule => firrtl.ir.DefModule): firrtl.ir.Circuit = {
      val rewrittenModules = circuit.modules.map(f)
      circuit.copy(modules = rewrittenModules)
    }
  }

  /** Adds body-level map/foreach operations to [[firrtl.ir.DefModule]]. */
  implicit class DefModuleHelper(defModule: firrtl.ir.DefModule) {

    /** Rewrites the body statement of a `Module` or `DefClass` with `f`.
      *
      * @param f transformation applied to the body statement
      * @return a new `Module` / `DefClass` holding `f(body)`; any other kind of
      *         `DefModule` is returned as-is and `f` is not invoked
      */
    def mapStmt(f: firrtl.ir.Statement => firrtl.ir.Statement): firrtl.ir.DefModule =
      defModule match {
        case module: firrtl.ir.Module =>
          module.copy(body = f(module.body))
        case cls: firrtl.ir.DefClass =>
          cls.copy(body = f(cls.body))
        // Every remaining DefModule is handed back untouched.
        case untouched: firrtl.ir.DefModule =>
          untouched
      }

    /** Invokes `f` on the body statement of a `Module` or `DefClass`.
      *
      * For any other kind of `DefModule` nothing happens.
      *
      * @param f side-effecting visitor applied to the body statement
      */
    def foreachStmt(f: firrtl.ir.Statement => Unit): Unit =
      defModule match {
        case module: firrtl.ir.Module => f(module.body)
        case cls: firrtl.ir.DefClass  => f(cls.body)
        case _: firrtl.ir.DefModule   => ()
      }
  }

  /** Adds a child-statement map operation to [[firrtl.ir.Statement]]. */
  implicit class StatementHelper(statement: firrtl.ir.Statement) {

    /** Applies `f` to the direct child statements of this statement.
      *
      *  - `Conditionally`: `f` is applied to the `conseq` branch and then to the `alt`
      *    branch; `info` and `pred` are kept.
      *  - `Block`: nested `Block`s are flattened and `EmptyStmt`s are dropped; `f` is
      *    applied to every remaining statement in source order. The values returned by
      *    `f` are stored as-is (they are not flattened again).
      *  - Anything else is returned unchanged and `f` is not invoked.
      *
      * @param f transformation applied to each child statement
      * @return the rebuilt statement
      */
    def mapStmt(f: firrtl.ir.Statement => firrtl.ir.Statement): firrtl.ir.Statement =
      statement match {
        case firrtl.ir.Conditionally(info, pred, whenTrue, whenFalse) =>
          val mappedTrue  = f(whenTrue)
          val mappedFalse = f(whenFalse)
          firrtl.ir.Conditionally(info, pred, mappedTrue, mappedFalse)

        case firrtl.ir.Block(stmts) =>
          val mapped = new scala.collection.mutable.ArrayBuffer[firrtl.ir.Statement]()

          // `pending` is a stack of iterators: the head belongs to the innermost block
          // currently being walked. Tail recursion keeps the traversal stack-safe no
          // matter how deeply blocks are nested.
          @scala.annotation.tailrec
          def drain(pending: List[Iterator[firrtl.ir.Statement]]): Unit = pending match {
            case Nil => ()
            case current :: outer =>
              if (!current.hasNext) {
                // Innermost block exhausted: resume the enclosing one.
                drain(outer)
              } else {
                current.next() match {
                  case firrtl.ir.EmptyStmt =>
                    // Empty statements are dropped from the result.
                    drain(pending)
                  case nested: firrtl.ir.Block =>
                    // Descend into the nested block before continuing with `current`.
                    drain(nested.stmts.iterator :: pending)
                  case leaf =>
                    mapped += f(leaf)
                    drain(pending)
                }
              }
          }

          drain(stmts.iterator :: Nil)
          firrtl.ir.Block(mapped.toSeq)

        case unchanged: firrtl.ir.Statement =>
          unchanged
      }
  }
}
65