aboutsummaryrefslogtreecommitdiff
path: root/VexRiscv/src/main/scala/vexriscv/ip/fpu/FpuSqrt.scala
blob: 0f809059f788e9d3e1d73eacff3ca614b9698489 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
package vexriscv.ip.fpu

import spinal.core._
import spinal.lib._
import spinal.lib.sim.{StreamDriver, StreamMonitor, StreamReadyRandomizer}

import scala.collection.mutable
import scala.util.Random

/** Square-root request accepted by [[FpuSqrt]] on its `io.input` stream.
  *
  * @param mantissaWidth width parameter of the owning [[FpuSqrt]]
  */
case class FpuSqrtCmd(mantissaWidth : Int) extends Bundle{
  /** Operand, mantissaWidth+2 bits wide; [[FpuSqrt]] consumes its two MSBs first. */
  val a: UInt = UInt((mantissaWidth + 2) bits)
}

/** Square-root response produced by [[FpuSqrt]] on its `io.output` stream.
  *
  * @param mantissaWidth width parameter of the owning [[FpuSqrt]]
  */
case class FpuSqrtRsp(mantissaWidth : Int) extends Bundle{
  /** Low mantissaWidth+1 bits of the computed root. */
  val result: UInt = UInt((mantissaWidth + 1) bits)
  /** Content of [[FpuSqrt]]'s partial-remainder register when the result is presented. */
  val remain: UInt = UInt((mantissaWidth + 5) bits)
}

/** Iterative square-root unit that produces one root bit per clock cycle.
  *
  * Algorithm: binary digit-by-digit square root. Each step trial-subtracts `(q << 2) | 1`
  * (written `q @@ U"01"`) from the partial remainder `a`. It keeps the difference only when
  * that is non-negative, and appends the matching bit (1 if kept, 0 otherwise) to the root `q`.
  * Two further operand bits are then shifted into `a`: first from `x`, then zeros once `x`
  * has been shifted out.
  *
  * Handshake:
  *  - A command is accepted while idle (`io.input.ready = !busy`).
  *  - Exactly `iterations` (= mantissaWidth+2) steps then run, one per cycle.
  *  - `io.output.valid` then rises and stays high until the response fires, which returns the unit to idle.
  *
  * `io.output.result` holds only the low mantissaWidth+1 bits of the mantissaWidth+2 root bits.
  * The root's leading bit is dropped by the `resized` on the `q` update.
  *
  * NOTE(review): the testbench in this file only drives operands in
  * [1 << mantissaWidth, 4 << mantissaWidth) and interprets them as `a / 2^mantissaWidth`, i.e. values in [1, 4).
  * Behaviour for operands outside that range is not exercised there.
  *
  * @param mantissaWidth operand/result width parameter; fixes the widths of [[FpuSqrtCmd]] and [[FpuSqrtRsp]]
  */
case class FpuSqrt(val mantissaWidth : Int) extends Component {
  val io = new Bundle{
    val input = slave Stream(FpuSqrtCmd(mantissaWidth))
    val output = master Stream(FpuSqrtRsp(mantissaWidth))
  }

  // One root bit is produced per iteration.
  val iterations = mantissaWidth+2
  // Step index while busy: 0 .. iterations-1.
  val counter = Reg(UInt(log2Up(iterations ) bits))
  // Set when a command is accepted; cleared when the response is consumed.
  val busy = RegInit(False) clearWhen(io.output.fire)
  // Registered: the step executed while counter === iterations-1 still runs before `done` rises,
  // so exactly `iterations` steps are performed.
  val done = RegInit(False) setWhen(busy && counter === iterations-1) clearWhen(io.output.fire)

  // Partial remainder (also exported as io.output.remain).
  val a = Reg(UInt(mantissaWidth+5 bits))
  // Operand bits not yet brought into `a`, consumed two at a time from the MSB end.
  val x = Reg(UInt(mantissaWidth bits))
  // Root bits produced so far, MSB first.
  val q = Reg(UInt(mantissaWidth+1 bits))
  // Trial subtraction a - (4*q + 1); t.msb is used as the "went negative" (borrow) flag.
  val t = a-(q @@ U"01")


  io.output.valid := done
  // `<< 0` leaves q unchanged; presumably a placeholder for a result-alignment shift — TODO confirm.
  io.output.result := (q << 0).resized
  io.output.remain := a
  io.input.ready := !busy

  // One iteration per cycle while no result is pending. This also evaluates while idle, but the
  // `when(!busy)` block below is written later and therefore overrides these assignments
  // (SpinalHDL: the last assignment wins).
  when(!done){
    counter := counter + 1
    // Keep the trial difference only if it did not go negative.
    val sel = CombInit(a)
    when(!t.msb){
      sel := t.resized
    }
    // Append the new root bit; `resized` truncates the MSB once q is full.
    q := (q @@ !t.msb).resized
    // Shift the remainder left by two and bring in the next two operand bits.
    a := (sel @@ x(widthOf(x)-2,2 bits)).resized
    // Logical left shift (width preserved), so zeros follow once the operand is exhausted.
    x := x |<< 2
  }

  // Idle: keep preloading the datapath from io.input. The load taken on the cycle the input
  // fires is the one kept, because busy rises only when valid is high.
  when(!busy){
    q := 0
    // The two MSBs of the operand seed the remainder...
    a := io.input.a(widthOf(io.input.a)-2,2 bits).resized
    // ...and the remaining low mantissaWidth bits are queued in x (resized truncates the top two).
    x := (io.input.a).resized
    counter := 0
    when(io.input.valid){
      busy := True
    }
  }
}


/** Randomized simulation testbench for [[FpuSqrt]].
  *
  * Stimulus:
  *  - Operands are pushed through `io.input` by a StreamDriver.
  *  - `io.output.ready` is randomized.
  *
  * Checking: every response is compared, in issue order, against a floating-point reference.
  *  - The operand `a` is read as the fixed-point value `a / 2^mantissaWidth`.
  *  - The expected `result` is the low `mantissaWidth+1` bits of
  *    `floor(sqrt(a / 2^mantissaWidth) * 2^(mantissaWidth+1))`.
  *  - The `remain` output is not checked.
  */
object FpuSqrtTester extends App{
  import spinal.core.sim._

  // Operands are generated as Int, which limits this bench to widths where 4 << w fits (w <= 29).
  for(w <- List(16)) {
    val simSeed = 2
    SimConfig.withFstWave.compile(new FpuSqrt(w)).doSim(seed = simSeed){dut =>
      dut.clockDomain.forkStimulus(10)

      // Seed operand generation like the simulator so a failing run can be replayed exactly.
      val rand = new Random(simSeed)

      val (_, cmdQueue) = StreamDriver.queue(dut.io.input, dut.clockDomain)
      // One checker per issued command; responses come back in issue order.
      val rspQueue = mutable.Queue[FpuSqrtRsp => Unit]()
      StreamMonitor(dut.io.output, dut.clockDomain)(rsp => rspQueue.dequeue()(rsp))
      StreamReadyRandomizer(dut.io.output, dut.clockDomain)

      val fracBits = dut.mantissaWidth
      val resultBits = dut.mantissaWidth + 1

      def test(a : Int): Unit = {
        cmdQueue += { p =>
          p.a #= a
        }
        rspQueue += { p =>
          val x = a.toDouble / (1L << fracBits)
          val result = (Math.sqrt(x) * (1L << resultBits)).toLong
          // The DUT keeps only resultBits bits, so the root's leading bit is dropped here too.
          val filtered = result % (1L << resultBits)
          assert(p.result.toLong == filtered, f"$a%x=${p.result.toLong}%x instead of $filtered%x")
        }
      }

      // Directed cases, scaled to the configured width:
      // 2.0 and 1.5, plus both ends of the [1, 4) operand range.
      test(2 << fracBits)
      test(3 << (fracBits - 1))
      test(1 << fracBits)
      test((4 << fracBits) - 1)

      for(_ <- 0 until 10000){
        test(rand.nextInt(3 << fracBits) + (1 << fracBits))
      }

      waitUntil(rspQueue.isEmpty)

      dut.clockDomain.waitSampling(100)
    }
  }
}