diff options
Diffstat (limited to 'VexRiscv/src/main/scala/vexriscv/plugin/DecoderSimplePlugin.scala')
-rw-r--r-- | VexRiscv/src/main/scala/vexriscv/plugin/DecoderSimplePlugin.scala | 402 |
1 files changed, 402 insertions, 0 deletions
diff --git a/VexRiscv/src/main/scala/vexriscv/plugin/DecoderSimplePlugin.scala b/VexRiscv/src/main/scala/vexriscv/plugin/DecoderSimplePlugin.scala new file mode 100644 index 0000000..a525b77 --- /dev/null +++ b/VexRiscv/src/main/scala/vexriscv/plugin/DecoderSimplePlugin.scala @@ -0,0 +1,402 @@ +package vexriscv.plugin + +import vexriscv._ +import spinal.core._ +import spinal.core.internals.Literal +import spinal.lib._ +import vexriscv.demo.GenFull + +import scala.collection.mutable +import scala.collection.mutable.ArrayBuffer + + +case class Masked(value : BigInt,care : BigInt){ + assert((value & ~care) == 0) + var isPrime = true + + def < (that: Masked) = value < that.value || value == that.value && ~care < ~that.care + + def intersects(x: Masked) = ((value ^ x.value) & care & x.care) == 0 + + def covers(x: Masked) = ((value ^ x.value) & care | (~x.care) & care) == 0 + + def setPrime(value : Boolean) = { + isPrime = value + this + } + + def mergeOneBitDifSmaller(x: Masked) = { + val bit = value - x.value + val ret = new Masked(value &~ bit, care & ~bit) + // ret.isPrime = isPrime || x.isPrime + isPrime = false + x.isPrime = false + ret + } + def isSimilarOneBitDifSmaller(x: Masked) = { + val diff = value - x.value + care == x.care && value > x.value && (diff & diff - 1) == 0 + } + + + def === (hard : Bits) : Bool = (hard & care) === (value & care) + + def toString(bitCount : Int) = (0 until bitCount).map(i => if(care.testBit(i)) (if(value.testBit(i)) "1" else "0") else "-").reverseIterator.reduce(_+_) +} + +class DecoderSimplePlugin(catchIllegalInstruction : Boolean = false, + throwIllegalInstruction : Boolean = false, + assertIllegalInstruction : Boolean = false, + forceLegalInstructionComputation : Boolean = false, + decoderIsolationBench : Boolean = false, + stupidDecoder : Boolean = false) extends Plugin[VexRiscv] with DecoderService { + override def add(encoding: Seq[(MaskedLiteral, Seq[(Stageable[_ <: BaseType], Any)])]): Unit = encoding.foreach(e => this.add(e._1,e._2)) + override def add(key: MaskedLiteral, values: Seq[(Stageable[_ <: BaseType], Any)]): Unit = { + val instructionModel = encodings.getOrElseUpdate(key,ArrayBuffer[(Stageable[_ <: BaseType], BaseType)]()) + values.map{case (a,b) => { + assert(!instructionModel.contains(a), s"Over specification of $a") + val value = b match { + case e: SpinalEnumElement[_] => e() + case e: BaseType => e + } + instructionModel += (a->value) + }} + } + + override def addDefault(key: Stageable[_ <: BaseType], value: Any): Unit = { + assert(!defaults.contains(key)) + defaults(key) = value match{ + case e : SpinalEnumElement[_] => e() + case e : BaseType => e + } + } + + def forceIllegal() : Unit = if(catchIllegalInstruction) pipeline.decode.input(pipeline.config.LEGAL_INSTRUCTION) := False + + val defaults = mutable.LinkedHashMap[Stageable[_ <: BaseType], BaseType]() + val encodings = mutable.LinkedHashMap[MaskedLiteral,ArrayBuffer[(Stageable[_ <: BaseType], BaseType)]]() + var decodeExceptionPort : Flow[ExceptionCause] = null + + + override def setup(pipeline: VexRiscv): Unit = { + if(!catchIllegalInstruction) { + SpinalWarning("This VexRiscv configuration is set without illegal instruction catch support. Some software may rely on it (ex: Rust)") + } + if(catchIllegalInstruction) { + val exceptionService = pipeline.plugins.filter(_.isInstanceOf[ExceptionService]).head.asInstanceOf[ExceptionService] + decodeExceptionPort = exceptionService.newExceptionPort(pipeline.decode).setName("decodeExceptionPort") + } + } + + val detectLegalInstructions = catchIllegalInstruction || throwIllegalInstruction || forceLegalInstructionComputation || assertIllegalInstruction + + object ASSERT_ERROR extends Stageable(Bool) + + override def build(pipeline: VexRiscv): Unit = { + import pipeline.config._ + import pipeline.decode._ + + val stageables = (encodings.flatMap(_._2.map(_._1)) ++ defaults.map(_._1)).toList.distinct + + + if(stupidDecoder){ + if (detectLegalInstructions) insert(LEGAL_INSTRUCTION) := False + for(stageable <- stageables){ + if(defaults.contains(stageable)){ + insert(stageable).assignFrom(defaults(stageable)) + } else { + insert(stageable).assignDontCare() + } + } + for((key, tasks) <- encodings){ + when(input(INSTRUCTION) === key){ + if (detectLegalInstructions) insert(LEGAL_INSTRUCTION) := True + for((stageable, value) <- tasks){ + insert(stageable).assignFrom(value) + } + } + } + } else { + var offset = 0 + var defaultValue, defaultCare = BigInt(0) + val offsetOf = mutable.LinkedHashMap[Stageable[_ <: BaseType], Int]() + + //Build defaults value and field offset map + stageables.foreach(e => { + defaults.get(e) match { + case Some(value) => { + value.head.source match { + case literal: EnumLiteral[_] => literal.fixEncoding(e.dataType.asInstanceOf[SpinalEnumCraft[_]].getEncoding) + case _ => + } + defaultValue += value.head.source.asInstanceOf[Literal].getValue << offset + defaultCare += ((BigInt(1) << e.dataType.getBitsWidth) - 1) << offset + + } + case _ => + } + offsetOf(e) = offset + offset += e.dataType.getBitsWidth + }) + + //Build spec + val spec = encodings.map { case (key, values) => + var decodedValue = defaultValue + var decodedCare = defaultCare + for ((e, literal) <- values) { + literal.head.source match { + case literal: EnumLiteral[_] => literal.fixEncoding(e.dataType.asInstanceOf[SpinalEnumCraft[_]].getEncoding) + case _ => + } + val offset = offsetOf(e) + decodedValue |= literal.head.source.asInstanceOf[Literal].getValue << offset + decodedCare |= ((BigInt(1) << e.dataType.getBitsWidth) - 1) << offset + } + (Masked(key.value, key.careAbout), Masked(decodedValue, decodedCare)) + } + + + // logic implementation + val decodedBits = Bits(stageables.foldLeft(0)(_ + _.dataType.getBitsWidth) bits) + decodedBits := Symplify(input(INSTRUCTION), spec, decodedBits.getWidth) + if (detectLegalInstructions) insert(LEGAL_INSTRUCTION) := Symplify.logicOf(input(INSTRUCTION), SymplifyBit.getPrimeImplicantsByTrueAndDontCare(spec.unzip._1.toSeq, Nil, 32)) + if (throwIllegalInstruction) { + input(LEGAL_INSTRUCTION) //Fill the request for later (prePopTask) + Component.current.addPrePopTask(() => arbitration.isValid clearWhen(!input(LEGAL_INSTRUCTION))) + } + if(assertIllegalInstruction){ + val reg = RegInit(False) setWhen(arbitration.isValid) clearWhen(arbitration.isRemoved || !arbitration.isStuck) + insert(ASSERT_ERROR) := arbitration.isValid || reg + } + + if(decoderIsolationBench){ + KeepAttribute(RegNext(KeepAttribute(RegNext(decodedBits.removeAssignments().asInput())))) + out(Bits(32 bits)).setName("instruction") := KeepAttribute(RegNext(KeepAttribute(RegNext(input(INSTRUCTION))))) + } + + //Unpack decodedBits and insert fields in the pipeline + offset = 0 + stageables.foreach(e => { + insert(e).assignFromBits(decodedBits(offset, e.dataType.getBitsWidth bits)) + // insert(e).assignFromBits(RegNext(decodedBits(offset, e.dataType.getBitsWidth bits))) + offset += e.dataType.getBitsWidth + }) + } + + if(catchIllegalInstruction){ + decodeExceptionPort.valid := arbitration.isValid && !input(LEGAL_INSTRUCTION) // ?? HalitIt to alow decoder stage to wait valid data from 2 stages cache cache ?? + decodeExceptionPort.code := 2 + decodeExceptionPort.badAddr := input(INSTRUCTION).asUInt + } + if(assertIllegalInstruction){ + pipeline.stages.tail.foreach(s => s.output(ASSERT_ERROR) clearWhen(s.arbitration.isRemoved)) + assert(!pipeline.stages.last.output(ASSERT_ERROR)) + } + } + + def bench(toplevel : VexRiscv): Unit ={ + toplevel.rework{ + import toplevel.config._ + toplevel.getAllIo.toList.foreach{io => + if(io.isInput) { io.assignDontCare()} + io.setAsDirectionLess() + } + toplevel.decode.input(INSTRUCTION).removeAssignments() + toplevel.decode.input(INSTRUCTION) := Delay((in Bits(32 bits)).setName("instruction"),2) + val stageables = encodings.flatMap(_._2.map(_._1)).toSet + stageables.foreach(e => out(RegNext(RegNext(toplevel.decode.insert(e)).setName(e.getName())))) + if(catchIllegalInstruction) out(RegNext(RegNext(toplevel.decode.insert(LEGAL_INSTRUCTION)).setName(LEGAL_INSTRUCTION.getName()))) + // toplevel.getAdditionalNodesRoot.clear() + } + } +} + +object DecodingBench extends App{ + SpinalVerilog{ + val top = GenFull.cpu() + top.service(classOf[DecoderSimplePlugin]).bench(top) + top + } +} + + +object Symplify{ + val cache = mutable.LinkedHashMap[Bits,mutable.LinkedHashMap[Masked,Bool]]() + def getCache(addr : Bits) = cache.getOrElseUpdate(addr,mutable.LinkedHashMap[Masked,Bool]()) + + //Generate terms logic for the given input + def logicOf(input : Bits,terms : Seq[Masked]) = terms.map(t => getCache(input).getOrElseUpdate(t,t === input)).asBits.orR + + //Decode 'input' b using an mapping[key, decoding] specification + def apply(input: Bits, mapping: Iterable[(Masked, Masked)],resultWidth : Int) : Bits = { + val addrWidth = widthOf(input) + (for(bitId <- 0 until resultWidth) yield{ + val trueTerm = mapping.filter { case (k,t) => (t.care.testBit(bitId) && t.value.testBit(bitId))}.map(_._1) + val falseTerm = mapping.filter { case (k,t) => (t.care.testBit(bitId) && !t.value.testBit(bitId))}.map(_._1) + val symplifiedTerms = SymplifyBit.getPrimeImplicantsByTrueAndFalse(trueTerm.toSeq, falseTerm.toSeq, addrWidth) + logicOf(input, symplifiedTerms) + }).asBits + } +} + +object SymplifyBit{ + + //Return a new term with only one bit difference with 'term' and not included in falseTerms. above => 0 to 1 dif, else 1 to 0 diff + def genImplicitDontCare(falseTerms: Seq[Masked], term: Masked, bits: Int, above: Boolean): Masked = { + for (i <- 0 until bits; if term.care.testBit(i)) { + var t: Masked = null + if(above) { + if (!term.value.testBit(i)) + t = Masked(term.value.setBit(i), term.care) + } else { + if (term.value.testBit(i)) + t = Masked(term.value.clearBit(i), term.care) + } + if (t != null && !falseTerms.exists(_.intersects(t))) { + t.isPrime = false + return t + } + } + null + } + + //Return primes implicants for the trueTerms, falseTerms spec. Default value is don't care + def getPrimeImplicantsByTrueAndFalse(trueTerms: Seq[Masked], falseTerms: Seq[Masked], inputWidth : Int): Seq[Masked] = { + val primes = mutable.LinkedHashSet[Masked]() + trueTerms.foreach(_.isPrime = true) + falseTerms.foreach(_.isPrime = true) + val trueTermByCareCount = (inputWidth to 0 by -1).map(b => trueTerms.filter(b == _.care.bitCount)) + //table[Vector[HashSet[Masked]]](careCount)(bitSetCount) + val table = trueTermByCareCount.map(c => (0 to inputWidth).map(b => collection.mutable.Set(c.filter(b == _.value.bitCount): _*))) + for (i <- 0 to inputWidth) { + //Expends explicit terms + for (j <- 0 until inputWidth - i){ + for(term <- table(i)(j)){ + table(i+1)(j) ++= table(i)(j+1).withFilter(_.isSimilarOneBitDifSmaller(term)).map(_.mergeOneBitDifSmaller(term)) + } + } + //Expends implicit don't care terms + for (j <- 0 until inputWidth-i) { + for (prime <- table(i)(j).withFilter(_.isPrime)) { + val dc = genImplicitDontCare(falseTerms, prime, inputWidth, true) + if (dc != null) + table(i+1)(j) += dc mergeOneBitDifSmaller prime + } + for (prime <- table(i)(j+1).withFilter(_.isPrime)) { + val dc = genImplicitDontCare(falseTerms, prime, inputWidth, false) + if (dc != null) + table(i+1)(j) += prime mergeOneBitDifSmaller dc + } + } + for (r <- table(i)) + for (p <- r; if p.isPrime) + primes += p + } + + def optimise() { + val duplicateds = primes.filter(prime => verifyTrueFalse(primes.filterNot(_ == prime), trueTerms, falseTerms)) + if(duplicateds.nonEmpty) { + primes -= duplicateds.maxBy(_.care.bitCount) + optimise() + } + } + + optimise() + + verifyTrueFalse(primes, trueTerms, falseTerms) + var duplication = 0 + for(prime <- primes){ + if(verifyTrueFalse(primes.filterNot(_ == prime), trueTerms, falseTerms)){ + duplication += 1 + } + } + if(duplication != 0){ + PendingError(s"Duplicated primes : $duplication") + } + primes.toSeq + } + + //Verify that the 'terms' doesn't violate the trueTerms ++ falseTerms spec + def verifyTrueFalse(terms : Iterable[Masked], trueTerms : Seq[Masked], falseTerms : Seq[Masked]): Boolean ={ + return (trueTerms.forall(trueTerm => terms.exists(_ covers trueTerm))) && (falseTerms.forall(falseTerm => !terms.exists(_ covers falseTerm))) + } + + def checkTrue(terms : Iterable[Masked], trueTerms : Seq[Masked]): Boolean ={ + return trueTerms.forall(trueTerm => terms.exists(_ covers trueTerm)) + } + + + def getPrimeImplicantsByTrue(trueTerms: Seq[Masked], inputWidth : Int) : Seq[Masked] = getPrimeImplicantsByTrueAndDontCare(trueTerms, Nil, inputWidth) + + // Return primes implicants for the trueTerms, default value is False. + // You can insert don't care values by adding non-prime implicants in the trueTerms + // Will simplify the trueTerms from the most constrained ones to the least constrained ones + def getPrimeImplicantsByTrueAndDontCare(trueTerms: Seq[Masked],dontCareTerms: Seq[Masked], inputWidth : Int): Seq[Masked] = { + val primes = mutable.LinkedHashSet[Masked]() + trueTerms.foreach(_.isPrime = true) + dontCareTerms.foreach(_.isPrime = false) + val termsByCareCount = (inputWidth to 0 by -1).map(b => (trueTerms ++ dontCareTerms).filter(b == _.care.bitCount)) + //table[Vector[HashSet[Masked]]](careCount)(bitSetCount) + val table = termsByCareCount.map(c => (0 to inputWidth).map(b => collection.mutable.Set(c.filter(m => b == m.value.bitCount): _*))) + for (i <- 0 to inputWidth) { + for (j <- 0 until inputWidth - i){ + for(term <- table(i)(j)){ + table(i+1)(j) ++= table(i)(j+1).withFilter(_.isSimilarOneBitDifSmaller(term)).map(_.mergeOneBitDifSmaller(term)) + } + } + for (r <- table(i)) + for (p <- r; if p.isPrime) + primes += p + } + + + def optimise() { + val duplicateds = primes.filter(prime => checkTrue(primes.filterNot(_ == prime), trueTerms)) + if(duplicateds.nonEmpty) { + primes -= duplicateds.maxBy(_.care.bitCount) + optimise() + } + } + + optimise() + + + var duplication = 0 + for(prime <- primes){ + if(checkTrue(primes.filterNot(_ == prime), trueTerms)){ + duplication += 1 + } + } + if(duplication != 0){ + PendingError(s"Duplicated primes : $duplication") + } + primes.toSeq + } + + def main(args: Array[String]) { + { + // val default = Masked(0, 0xF) + // val primeImplicants = List(4, 8, 10, 11, 12, 15).map(v => Masked(v, 0xF)) + // val dcImplicants = List(9, 14).map(v => Masked(v, 0xF).setPrime(false)) + // val reducedPrimeImplicants = getPrimeImplicantsByTrueAndDontCare(primeImplicants, dcImplicants, 4) + // println("UUT") + // println(reducedPrimeImplicants.map(_.toString(4)).mkString("\n")) + // println("REF") + // println("-100\n10--\n1--0\n1-1-") + } + + { + val primeImplicants = List(0).map(v => Masked(v, 0xF)) + val dcImplicants = (1 to 15).map(v => Masked(v, 0xF)) + val reducedPrimeImplicants = getPrimeImplicantsByTrueAndDontCare(primeImplicants, dcImplicants, 4) + println("UUT") + println(reducedPrimeImplicants.map(_.toString(4)).mkString("\n")) + } + { + val trueTerms = List(0, 15).map(v => Masked(v, 0xF)) + val falseTerms = List(3).map(v => Masked(v, 0xF)) + val primes = getPrimeImplicantsByTrueAndFalse(trueTerms, falseTerms, 4) + println(primes.map(_.toString(4)).mkString("\n")) + } + } +}
\ No newline at end of file |