xref: /XiangShan/src/main/scala/xiangshan/transforms/PrintControl.scala (revision e3da8bad334fc71ba0d72f0607e2e93245ddaece)
1/***************************************************************************************
2* Copyright (c) 2024 Beijing Institute of Open Source Chip (BOSC)
3* Copyright (c) 2020-2024 Institute of Computing Technology, Chinese Academy of Sciences
4* Copyright (c) 2020-2021 Peng Cheng Laboratory
5*
6* XiangShan is licensed under Mulan PSL v2.
7* You can use this software according to the terms and conditions of the Mulan PSL v2.
8* You may obtain a copy of Mulan PSL v2 at:
9*          http://license.coscl.org.cn/MulanPSL2
10*
11* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
12* EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
13* MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
14*
15* See the Mulan PSL v2 for more details.
16***************************************************************************************/
17
18package xiangshan.transforms
19
/** Requests that every verilog `printf` in module `m`, and in the modules instantiated beneath it,
  * be removed by [[PrintControl]].
  */
case class DisablePrintfAnnotation(m: String) extends firrtl.annotations.NoTargetAnnotation

/** Command-line binding: `--disable-module-print <module>` (short form `-dm`). */
object DisablePrintfAnnotation extends firrtl.options.HasShellOptions {
  import firrtl.options.ShellOption

  val options = Seq(
    new ShellOption[String](
      longOption = "disable-module-print",
      shortOption = Some("dm"),
      helpValueName = Some("<module>"),
      helpText = "The verilog 'printf' in the <module> and it's submodules will be removed\n",
      // Each occurrence of the option contributes one annotation naming the module.
      toAnnotationSeq = (moduleName: String) => Seq(DisablePrintfAnnotation(moduleName))
    )
  )
}
35
/** Requests that verilog `printf`s be kept only in module `m` and the modules instantiated beneath it;
  * [[PrintControl]] removes them everywhere else.
  */
case class EnablePrintfAnnotation(m: String) extends firrtl.annotations.NoTargetAnnotation

/** Command-line binding: `--enable-module-print <module>` (short form `-em`). */
object EnablePrintfAnnotation extends firrtl.options.HasShellOptions {
  import firrtl.options.ShellOption

  val options = Seq(
    new ShellOption[String](
      longOption = "enable-module-print",
      shortOption = Some("em"),
      helpValueName = Some("<module>"),
      helpText = "The verilog 'printf' except the <module> and it's submodules will be removed\n",
      // Each occurrence of the option contributes one annotation naming the module.
      toAnnotationSeq = (moduleName: String) => Seq(EnablePrintfAnnotation(moduleName))
    )
  )
}
50
/** Requests that every verilog `printf` in the whole circuit be removed by [[PrintControl]]. */
case class DisableAllPrintAnnotation() extends firrtl.annotations.NoTargetAnnotation

/** Command-line binding: the value-less flag `--disable-all` (short form `-dall`). */
object DisableAllPrintAnnotation extends firrtl.options.HasShellOptions {
  import firrtl.options.ShellOption

  val options = Seq(
    new ShellOption[Unit](
      longOption = "disable-all",
      shortOption = Some("dall"),
      helpText = "All the verilog 'printf' will be removed\n",
      // The flag carries no value; its presence alone produces the annotation.
      toAnnotationSeq = (_: Unit) => Seq(DisableAllPrintAnnotation())
    )
  )
}
63
/** Requests that [[PrintControl]] strip assertion-related statements from every module. */
case class RemoveAssertAnnotation() extends firrtl.annotations.NoTargetAnnotation

/** Command-line binding: the value-less flag `--remove-assert` (no short form). */
object RemoveAssertAnnotation extends firrtl.options.HasShellOptions {
  import firrtl.options.ShellOption

  val options = Seq(
    new ShellOption[Unit](
      longOption = "remove-assert",
      shortOption = None,
      helpText = "All the 'assert' will be removed\n",
      // The flag carries no value; its presence alone produces the annotation.
      toAnnotationSeq = (_: Unit) => Seq(RemoveAssertAnnotation())
    )
  )
}
75
76import scala.collection.mutable
77
/** Phase that strips debug statements from the FIRRTL circuit according to command-line annotations.
  *
  *  - [[DisablePrintfAnnotation]]: remove `printf`s in the named module and every module below it.
  *  - [[EnablePrintfAnnotation]]: keep `printf`s only in the named module and every module below it.
  *  - [[DisableAllPrintAnnotation]]: remove every `printf`.
  *  - [[RemoveAssertAnnotation]]: remove `Stop` and `Verification` (assert/assume/cover) statements
  *    from every module.
  *
  * Enable annotations may not be combined with disable annotations. External modules, and any module
  * definition that is not a plain `firrtl.ir.Module`, are returned unchanged.
  */
class PrintControl extends firrtl.options.Phase {

  // This phase does not invalidate any other phase.
  override def invalidates(a: firrtl.options.Phase) = false

  override def transform(annotations: firrtl.AnnotationSeq): firrtl.AnnotationSeq = {

    import xiangshan.transforms.Helpers._

    val disableList = annotations.collect {
      case DisablePrintfAnnotation(m) => m
    }
    val enableList = annotations.collect {
      case EnablePrintfAnnotation(m) => m
    }
    val disableAll = annotations.collectFirst {
      case DisableAllPrintAnnotation() => true
    }.nonEmpty
    val removeAssert = annotations.collectFirst {
      case RemoveAssertAnnotation() => true
    }.nonEmpty

    assert(
      !(enableList.nonEmpty && (disableAll || disableList.nonEmpty)),
      "--enable-module-print cannot be combined with --disable-module-print or --disable-all"
    )

    // Exactly one circuit annotation is expected; anything else fails the pattern with a MatchError.
    val (Seq(circuitAnno: firrtl.stage.FirrtlCircuitAnnotation), otherAnnos) = annotations.partition {
      case _: firrtl.stage.FirrtlCircuitAnnotation => true
      case _ => false
    }
    val c = circuitAnno.circuit
    val top = c.main

    // Index definitions once instead of a linear search per visited module.
    val moduleByName: Map[String, firrtl.ir.DefModule] = c.modules.map(m => m.name -> m).toMap

    // Names of the modules directly instantiated by `m`, computed at most once per module.
    val childCache = mutable.HashMap.empty[String, Seq[String]]
    def childrenOf(name: String): Seq[String] = childCache.getOrElseUpdate(name, {
      val found = mutable.ArrayBuffer.empty[String]
      def visit(s: firrtl.ir.Statement): Unit = s match {
        case firrtl.ir.DefInstance(_, _, module, _) => found += module
        case other => other.foreachStmt(visit)
      }
      // A name without a definition in the circuit simply has no children.
      moduleByName.get(name).foreach(_.foreachStmt(visit))
      found.toSeq
    })

    // ancestors(x): every module lying on ANY instantiation path from `top` down to `x` (excluding x).
    // Unioning over all paths makes the result independent of traversal order for modules that are
    // instantiated by more than one parent. A module is re-enqueued only when its set grows; sets are
    // bounded by the number of modules, so this worklist always terminates.
    val ancestors = mutable.HashMap[String, Set[String]](top -> Set.empty[String])
    val queue = mutable.Queue[String](top)
    while (queue.nonEmpty) {
      val curr = queue.dequeue()
      val inherited = ancestors(curr) + curr
      childrenOf(curr).foreach { child =>
        val known = ancestors.getOrElse(child, Set.empty[String])
        if (!inherited.subsetOf(known)) {
          ancestors(child) = known ++ inherited
          queue += child
        }
      }
    }

    def onModule(m: firrtl.ir.DefModule): firrtl.ir.DefModule = m match {
      case mod: firrtl.ir.Module =>
        // Modules unreachable from `top` have no recorded ancestors.
        val moduleAncestors = ancestors.getOrElse(mod.name, Set.empty[String])
        // True when `mod` is one of the listed modules or is instantiated beneath one of them.
        def inRange(seq: Seq[String]): Boolean =
          seq.exists(target => target == mod.name || moduleAncestors.contains(target))
        val enable = enableList.isEmpty || inRange(enableList)
        val disable = disableAll || inRange(disableList) || !enable
        def onStmt(s: firrtl.ir.Statement): firrtl.ir.Statement = s match {
          case _: firrtl.ir.Print if disable => firrtl.ir.EmptyStmt
          case _: firrtl.ir.Stop if removeAssert => firrtl.ir.EmptyStmt
          case _: firrtl.ir.Verification if removeAssert => firrtl.ir.EmptyStmt
          case other => other.mapStmt(onStmt)
        }
        mod.mapStmt(onStmt)
      // ExtModule (and any other non-Module definition) has no statements to rewrite.
      case other => other
    }

    firrtl.stage.FirrtlCircuitAnnotation(c.mapModule(onModule)) +: otherAnnos
  }
}
151