aboutsummaryrefslogtreecommitdiff
path: root/VexRiscv/src/main/scala/vexriscv/plugin/MulPlugin.scala
blob: 3e909a06702b5c4df7487e615a28d23fd445ad81 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
package vexriscv.plugin
import vexriscv._
import vexriscv.VexRiscv
import spinal.core._
import spinal.lib.KeepAttribute

//The input buffer generally prevents FPGA synthesis tools from duplicating registers inside the DSP cells, which could otherwise put considerable stress on timing.
/**
 * Multiplication plugin for VexRiscv, spread over three pipeline stages.
 *
 * The execute stage computes four partial products from the 16-bit halves of RS1 and RS2,
 * the memory stage sums the three partial products that make up the low part, and the
 * writeBack stage adds the high partial product and overrides REGFILE_WRITE_DATA.
 *
 * @param inputBuffer  when true, RS1 and RS2 are registered (RegNext) before feeding the
 *                     multipliers; adds one to the number of cycles a multiply halts in execute.
 * @param outputBuffer when true, the four partial products are registered (RegNext) before being
 *                     inserted into the pipeline; adds one to the number of cycles a multiply
 *                     halts in execute.
 */
class MulPlugin(var inputBuffer : Boolean = false,
                var outputBuffer : Boolean = false) extends Plugin[VexRiscv] with VexRiscvRegressionArg {
  // Partial products produced in the execute stage. "L"/"H" refer to the low/high 16 bits of
  // the first (a) and second (b) operand respectively.
  object MUL_LL extends Stageable(UInt(32 bits)) // low(a)  * low(b),  16 x 16 unsigned
  object MUL_LH extends Stageable(SInt(34 bits)) // low(a)  * high(b), 17 x 17 signed
  object MUL_HL extends Stageable(SInt(34 bits)) // high(a) * low(b),  17 x 17 signed
  object MUL_HH extends Stageable(SInt(34 bits)) // high(a) * high(b), 17 x 17 signed

  // MUL_LL + (MUL_LH << 16) + (MUL_HL << 16), produced in the memory stage.
  object MUL_LOW extends Stageable(SInt(34+16+2 bits))

  // Decoder flag: True for instructions registered with the `actions` list in setup,
  // False by default.
  object IS_MUL extends Stageable(Bool)

  /**
   * Arguments contributed through the VexRiscvRegressionArg trait.
   *
   * @return a single-element list containing "MUL=yes"
   *         (presumably consumed by the regression test build - TODO confirm)
   */
  override def getVexRiscvRegressionArgs(): Seq[String] = {
    List("MUL=yes")
  }

  /**
   * Registers the decoding of multiply instructions with the pipeline's DecoderService.
   *
   * IS_MUL defaults to False and is set, together with the other signals in `actions`,
   * for every instruction matching Riscv.MULX.
   *
   * @param pipeline the VexRiscv pipeline this plugin is attached to
   */
  override def setup(pipeline: VexRiscv): Unit = {
    import Riscv._
    import pipeline.config._


    // Decode signals driven for a multiply instruction.
    // NOTE(review): the result is only written in the writeBack stage (see build), which is
    // presumably why both BYPASSABLE_* flags are False - confirm against the hazard plugin.
    val actions = List[(Stageable[_ <: BaseType],Any)](
//      SRC1_CTRL                -> Src1CtrlEnum.RS,
//      SRC2_CTRL                -> Src2CtrlEnum.RS,
      REGFILE_WRITE_VALID      -> True,
      BYPASSABLE_EXECUTE_STAGE -> False,
      BYPASSABLE_MEMORY_STAGE  -> False,
      RS1_USE                 -> True,
      RS2_USE                 -> True,
      IS_MUL                   -> True
    )

    val decoderService = pipeline.service(classOf[DecoderService])
    decoderService.addDefault(IS_MUL, False)
    decoderService.add(List(
      MULX  -> actions
    ))

  }

  /**
   * Generates the multiplier hardware in the execute, memory and writeBack stages.
   *
   * @param pipeline the VexRiscv pipeline this plugin is attached to
   */
  override def build(pipeline: VexRiscv): Unit = {
    import pipeline._
    import pipeline.config._


    //Do partial multiplication, four times 16 bits * 16 bits
    execute plug new Area {
      import execute._
      // Signedness of each operand and the (optionally registered) raw operand values.
      val aSigned,bSigned = Bool
      val a,b = Bits(32 bit)

//      a := input(SRC1)
//      b := input(SRC2)

      // Number of extra cycles needed in execute: one per enabled buffer.
      val delay = (if(inputBuffer) 1 else 0) + (if(outputBuffer) 1 else 0)

      // Only generated when at least one buffer is enabled: halts a valid multiply in the
      // execute stage until `counter` reaches `delay`, giving the registers time to fill.
      val delayLogic = (delay != 0) generate new Area{
        val counter = Reg(UInt(log2Up(delay+1) bits))
        when(arbitration.isValid && input(IS_MUL) && counter =/= delay){
          arbitration.haltItself := True
        }

        // Count every cycle by default; the `when` below comes later and therefore takes
        // priority (SpinalHDL last-assignment-wins), clearing the counter whenever the stage
        // is not stuck or is stuck because of another stage.
        counter := counter + 1
        when(!arbitration.isStuck || arbitration.isStuckByOthers){
          counter := 0
        }
      }

      // Input-buffered variant: operands come from registered copies of RS1/RS2.
      val withInputBuffer = inputBuffer generate new Area{
        val rs1 = RegNext(input(RS1))
        val rs2 = RegNext(input(RS2))
        a := rs1
        b := rs2
      }

      // Unbuffered variant: operands are taken directly from the stage inputs.
      val noInputBuffer = (!inputBuffer) generate new Area{
        a := input(RS1)
        b := input(RS2)
      }

      // Operand signedness is selected by INSTRUCTION bits 13..12:
      //   01      -> a signed,   b signed
      //   10      -> a signed,   b unsigned
      //   00 / 11 -> a unsigned, b unsigned (for 00 only the low 32 bits of the product are
      //              used in writeBack, and those do not depend on signedness)
      // NOTE(review): these encodings line up with MULH / MULHSU / MUL / MULHU in the RISC-V
      // M extension, assuming Riscv.MULX matches exactly those four - confirm.
      switch(input(INSTRUCTION)(13 downto 12)) {
        is(B"01") {
          aSigned := True
          bSigned := True
        }
        is(B"10") {
          aSigned := True
          bSigned := False
        }
        default {
          aSigned := False
          bSigned := False
        }
      }

      // Low halves: a 16-bit unsigned view, plus a 17-bit signed view with a zero bit prepended
      // so the value stays non-negative when multiplied as signed.
      val aULow = a(15 downto 0).asUInt
      val bULow = b(15 downto 0).asUInt
      val aSLow = (False ## a(15 downto 0)).asSInt
      val bSLow = (False ## b(15 downto 0)).asSInt
      // High halves: 17-bit signed; the prepended bit is the operand's msb when the operand is
      // signed and 0 otherwise, i.e. sign extension or zero extension by one bit.
      val aHigh = (((aSigned && a.msb) ## a(31 downto 16))).asSInt
      val bHigh = (((bSigned && b.msb) ## b(31 downto 16))).asSInt

      // Output-buffered variant: each partial product goes through a register before being
      // inserted into the pipeline.
      // NOTE(review): "Ouput" is a typo for "Output"; the identifier is left unchanged because
      // SpinalHDL may derive generated signal names from it - confirm before renaming.
      val withOuputBuffer = outputBuffer generate new Area{
        val mul_ll = RegNext(aULow * bULow)
        val mul_lh = RegNext(aSLow * bHigh)
        val mul_hl = RegNext(aHigh * bSLow)
        val mul_hh = RegNext(aHigh * bHigh)

        insert(MUL_LL) := mul_ll
        insert(MUL_LH) := mul_lh
        insert(MUL_HL) := mul_hl
        insert(MUL_HH) := mul_hh
      }

      // Unbuffered variant: partial products are inserted combinationally.
      val noOutputBuffer = (!outputBuffer) generate new Area{
        insert(MUL_LL) := aULow * bULow
        insert(MUL_LH) := aSLow * bHigh
        insert(MUL_HL) := aHigh * bSLow
        insert(MUL_HH) := aHigh * bHigh
      }

      Component.current.afterElaboration{
        //Prevent synthesis tools from retiming RS1/RS2 from the execute stage into the decode stage, which leads to bad timings (ex : Vivado, even if retiming is disabled)
        KeepAttribute(input(RS1))
        KeepAttribute(input(RS2))
      }
    }

    //First aggregation of partial multiplication
    memory plug new Area {
      import memory._
      // MUL_LL is zero-extended by one bit before being viewed as signed; the two cross
      // products are shifted left by 16 bits. The leading zero constant has the full MUL_LOW
      // width (34 + 16 + 2 = 52 bits), presumably so the sum is carried out at that width.
      insert(MUL_LOW) := S(0, MUL_HL.dataType.getWidth + 16 + 2 bit) + (False ## input(MUL_LL)).asSInt + (input(MUL_LH) << 16) + (input(MUL_HL) << 16)
    }

    //Final aggregation of partial multiplications, REGFILE_WRITE_DATA overriding
    writeBack plug new Area {
      import writeBack._
      // Full product: the low aggregate plus the high * high partial product shifted left by 32.
      val result = input(MUL_LOW) + (input(MUL_HH) << 32)


      // Only override the register-file write data for a valid multiply instruction.
      when(arbitration.isValid && input(IS_MUL)){
        switch(input(INSTRUCTION)(13 downto 12)){
          // 00: low 32 bits of the product, which only need MUL_LOW.
          is(B"00"){
            output(REGFILE_WRITE_DATA) := input(MUL_LOW)(31 downto 0).asBits
          }
          // 01 / 10 / 11: bits 63..32 of the product.
          is(B"01",B"10",B"11"){
            output(REGFILE_WRITE_DATA) := result(63 downto 32).asBits
          }
        }
      }
    }
  }
}