aboutsummaryrefslogtreecommitdiff
path: root/VexRiscv/src/main/scala/vexriscv/plugin/BranchPlugin.scala
diff options
context:
space:
mode:
Diffstat (limited to 'VexRiscv/src/main/scala/vexriscv/plugin/BranchPlugin.scala')
-rw-r--r--VexRiscv/src/main/scala/vexriscv/plugin/BranchPlugin.scala386
1 files changed, 386 insertions, 0 deletions
diff --git a/VexRiscv/src/main/scala/vexriscv/plugin/BranchPlugin.scala b/VexRiscv/src/main/scala/vexriscv/plugin/BranchPlugin.scala
new file mode 100644
index 0000000..24d42fa
--- /dev/null
+++ b/VexRiscv/src/main/scala/vexriscv/plugin/BranchPlugin.scala
@@ -0,0 +1,386 @@
+package vexriscv.plugin
+
+import vexriscv.Riscv._
+import vexriscv._
+import spinal.core._
+import spinal.lib._
+
+trait BranchPrediction
+object NONE extends BranchPrediction
+object STATIC extends BranchPrediction
+object DYNAMIC extends BranchPrediction
+object DYNAMIC_TARGET extends BranchPrediction
+
+object BranchCtrlEnum extends SpinalEnum(binarySequential){
+ val INC,B,JAL,JALR = newElement()
+}
+object BRANCH_CTRL extends Stageable(BranchCtrlEnum())
+
+
+case class DecodePredictionCmd() extends Bundle {
+ val hadBranch = Bool
+}
+case class DecodePredictionRsp(stage : Stage) extends Bundle {
+ val wasWrong = Bool
+}
+case class DecodePredictionBus(stage : Stage) extends Bundle {
+ val cmd = DecodePredictionCmd()
+ val rsp = DecodePredictionRsp(stage)
+}
+
+case class FetchPredictionCmd() extends Bundle{
+ val hadBranch = Bool
+ val targetPc = UInt(32 bits)
+}
+case class FetchPredictionRsp() extends Bundle{
+ val wasRight = Bool
+ val finalPc = UInt(32 bits)
+ val sourceLastWord = UInt(32 bits)
+}
+case class FetchPredictionBus(stage : Stage) extends Bundle {
+ val cmd = FetchPredictionCmd()
+ val rsp = FetchPredictionRsp()
+}
+
+
+trait PredictionInterface{
+ def askFetchPrediction() : FetchPredictionBus
+ def askDecodePrediction() : DecodePredictionBus
+ def inDebugNoFetch() : Unit
+}
+
+
+
+class BranchPlugin(earlyBranch : Boolean,
+ catchAddressMisaligned : Boolean = false,
+ fenceiGenAsAJump : Boolean = false,
+ fenceiGenAsANop : Boolean = false,
+ decodeBranchSrc2 : Boolean = false) extends Plugin[VexRiscv] with PredictionInterface{
+
+
+ def catchAddressMisalignedForReal = catchAddressMisaligned && !pipeline.config.withRvc
+ lazy val branchStage = if(earlyBranch) pipeline.execute else pipeline.memory
+
+ object BRANCH_CALC extends Stageable(UInt(32 bits))
+ object BRANCH_DO extends Stageable(Bool)
+ object BRANCH_COND_RESULT extends Stageable(Bool)
+ object IS_FENCEI extends Stageable(Bool)
+
+ var jumpInterface : Flow[UInt] = null
+ var predictionExceptionPort : Flow[ExceptionCause] = null
+ var branchExceptionPort : Flow[ExceptionCause] = null
+ var inDebugNoFetchFlag : Bool = null
+
+
+ var decodePrediction : DecodePredictionBus = null
+ var fetchPrediction : FetchPredictionBus = null
+
+
+ override def askFetchPrediction() = {
+ fetchPrediction = FetchPredictionBus(branchStage)
+ fetchPrediction
+ }
+
+ override def askDecodePrediction() = {
+ decodePrediction = DecodePredictionBus(branchStage)
+ decodePrediction
+ }
+
+
+ override def inDebugNoFetch(): Unit = inDebugNoFetchFlag := True
+
+ def hasHazardOnBranch = if(earlyBranch) pipeline.service(classOf[HazardService]).hazardOnExecuteRS else False
+
+ override def setup(pipeline: VexRiscv): Unit = {
+ import Riscv._
+ import pipeline.config._
+ import IntAluPlugin._
+
+ assert(earlyBranch || withMemoryStage, "earlyBranch must be true when memory stage is disabled!")
+
+ val bActions = List[(Stageable[_ <: BaseType],Any)](
+ SRC1_CTRL -> Src1CtrlEnum.RS,
+ SRC2_CTRL -> Src2CtrlEnum.RS,
+ SRC_USE_SUB_LESS -> True,
+ RS1_USE -> True,
+ RS2_USE -> True,
+ HAS_SIDE_EFFECT -> True
+ )
+
+ val jActions = List[(Stageable[_ <: BaseType],Any)](
+ SRC1_CTRL -> Src1CtrlEnum.PC_INCREMENT,
+ SRC2_CTRL -> Src2CtrlEnum.PC,
+ SRC_USE_SUB_LESS -> False,
+ REGFILE_WRITE_VALID -> True,
+ HAS_SIDE_EFFECT -> True
+ )
+
+ val decoderService = pipeline.service(classOf[DecoderService])
+
+
+ decoderService.addDefault(BRANCH_CTRL, BranchCtrlEnum.INC)
+ decoderService.add(List(
+ JAL(true) -> (jActions ++ List(BRANCH_CTRL -> BranchCtrlEnum.JAL, ALU_CTRL -> AluCtrlEnum.ADD_SUB)),
+ JALR -> (jActions ++ List(BRANCH_CTRL -> BranchCtrlEnum.JALR, ALU_CTRL -> AluCtrlEnum.ADD_SUB, RS1_USE -> True)),
+ BEQ(true) -> (bActions ++ List(BRANCH_CTRL -> BranchCtrlEnum.B)),
+ BNE(true) -> (bActions ++ List(BRANCH_CTRL -> BranchCtrlEnum.B)),
+ BLT(true) -> (bActions ++ List(BRANCH_CTRL -> BranchCtrlEnum.B, SRC_LESS_UNSIGNED -> False)),
+ BGE(true) -> (bActions ++ List(BRANCH_CTRL -> BranchCtrlEnum.B, SRC_LESS_UNSIGNED -> False)),
+ BLTU(true) -> (bActions ++ List(BRANCH_CTRL -> BranchCtrlEnum.B, SRC_LESS_UNSIGNED -> True)),
+ BGEU(true) -> (bActions ++ List(BRANCH_CTRL -> BranchCtrlEnum.B, SRC_LESS_UNSIGNED -> True))
+ ))
+
+ if(fenceiGenAsAJump) {
+ decoderService.addDefault(IS_FENCEI, False)
+ decoderService.add(List(
+ FENCEI -> (List(IS_FENCEI -> True,HAS_SIDE_EFFECT -> True, BRANCH_CTRL -> BranchCtrlEnum.JAL))
+ ))
+ }
+
+ if(fenceiGenAsANop){
+ decoderService.add(List(FENCEI -> List()))
+ }
+
+ val pcManagerService = pipeline.service(classOf[JumpService])
+
+ //Priority -1, as DYNAMIC_TARGET misspredicted on non branch instruction should lose against other instructions
+ //legitim branches, as MRET for instance
+ jumpInterface = pcManagerService.createJumpInterface(branchStage, priority = -10)
+
+
+ if (catchAddressMisalignedForReal) {
+ val exceptionService = pipeline.service(classOf[ExceptionService])
+ branchExceptionPort = exceptionService.newExceptionPort(branchStage)
+ }
+ inDebugNoFetchFlag = False.setCompositeName(this, "inDebugNoFetchFlag")
+ }
+
+ override def build(pipeline: VexRiscv): Unit = {
+ (fetchPrediction,decodePrediction) match {
+ case (null, null) => buildWithoutPrediction(pipeline)
+ case (_ , null) => buildFetchPrediction(pipeline)
+ case (null, _) => buildDecodePrediction(pipeline)
+ }
+ if(fenceiGenAsAJump) {
+ import pipeline._
+ import pipeline.config._
+ when(decode.input(IS_FENCEI)) {
+ decode.output(INSTRUCTION)(12) := False
+ decode.output(INSTRUCTION)(22) := True
+ }
+ execute.arbitration.haltByOther setWhen(execute.arbitration.isValid && execute.input(IS_FENCEI) && stagesFromExecute.tail.map(_.arbitration.isValid).asBits.orR)
+ }
+ }
+
+ def buildWithoutPrediction(pipeline: VexRiscv): Unit = {
+ import pipeline._
+ import pipeline.config._
+
+ //Do branch calculations (conditions + target PC)
+ execute plug new Area {
+ import execute._
+
+ val less = input(SRC_LESS)
+ val eq = input(SRC1) === input(SRC2)
+
+ insert(BRANCH_DO) := input(BRANCH_CTRL).mux(
+ BranchCtrlEnum.INC -> False,
+ BranchCtrlEnum.JAL -> True,
+ BranchCtrlEnum.JALR -> True,
+ BranchCtrlEnum.B -> input(INSTRUCTION)(14 downto 12).mux(
+ B"000" -> eq ,
+ B"001" -> !eq ,
+ M"1-1" -> !less,
+ default -> less
+ )
+ )
+
+ val imm = IMM(input(INSTRUCTION))
+ val branch_src1 = (input(BRANCH_CTRL) === BranchCtrlEnum.JALR) ? input(RS1).asUInt | input(PC)
+ val branch_src2 = input(BRANCH_CTRL).mux(
+ BranchCtrlEnum.JAL -> imm.j_sext,
+ BranchCtrlEnum.JALR -> imm.i_sext,
+ default -> imm.b_sext
+ ).asUInt
+
+ val branchAdder = branch_src1 + branch_src2
+ insert(BRANCH_CALC) := branchAdder(31 downto 1) @@ U"0"
+ }
+
+ //Apply branchs (JAL,JALR, Bxx)
+ branchStage plug new Area {
+ import branchStage._
+ jumpInterface.valid := arbitration.isValid && input(BRANCH_DO) && !hasHazardOnBranch
+ jumpInterface.payload := input(BRANCH_CALC)
+ arbitration.flushNext setWhen(jumpInterface.valid)
+
+ if(catchAddressMisalignedForReal) {
+ branchExceptionPort.valid := arbitration.isValid && input(BRANCH_DO) && jumpInterface.payload(1)
+ branchExceptionPort.code := 0
+ branchExceptionPort.badAddr := jumpInterface.payload
+
+ if(branchStage == execute) branchExceptionPort.valid clearWhen(service(classOf[HazardService]).hazardOnExecuteRS)
+ }
+ }
+ }
+
+
+ def buildDecodePrediction(pipeline: VexRiscv): Unit = {
+ object PREDICTION_HAD_BRANCHED extends Stageable(Bool)
+
+ import pipeline._
+ import pipeline.config._
+
+
+ decode plug new Area {
+ import decode._
+ insert(PREDICTION_HAD_BRANCHED) := (if(fenceiGenAsAJump) decodePrediction.cmd.hadBranch && !decode.input(IS_FENCEI) else decodePrediction.cmd.hadBranch)
+ }
+
+ //Do real branch calculation
+ execute plug new Area {
+ import execute._
+
+ val less = input(SRC_LESS)
+ val eq = input(SRC1) === input(SRC2)
+
+ insert(BRANCH_COND_RESULT) := input(BRANCH_CTRL).mux(
+ BranchCtrlEnum.INC -> False,
+ BranchCtrlEnum.JAL -> True,
+ BranchCtrlEnum.JALR -> True,
+ BranchCtrlEnum.B -> input(INSTRUCTION)(14 downto 12).mux(
+ B"000" -> eq ,
+ B"001" -> !eq ,
+ M"1-1" -> !less,
+ default -> less
+ )
+ )
+
+ val imm = IMM(input(INSTRUCTION))
+ val missAlignedTarget = if(pipeline.config.withRvc) False else (input(BRANCH_COND_RESULT) && input(BRANCH_CTRL).mux(
+ BranchCtrlEnum.JALR -> (imm.i_sext(1) ^ input(RS1)(1)),
+ BranchCtrlEnum.JAL -> imm.j_sext(1),
+ default -> imm.b_sext(1)
+ ))
+
+ insert(BRANCH_DO) := input(PREDICTION_HAD_BRANCHED) =/= input(BRANCH_COND_RESULT) || missAlignedTarget
+
+ //Calculation of the branch target / correction
+ val branch_src1,branch_src2 = UInt(32 bits)
+ switch(input(BRANCH_CTRL)){
+ is(BranchCtrlEnum.JALR){
+ branch_src1 := input(RS1).asUInt
+ branch_src2 := imm.i_sext.asUInt
+ }
+ default{
+ branch_src1 := input(PC)
+ branch_src2 := ((input(BRANCH_CTRL) === BranchCtrlEnum.JAL) ? imm.j_sext | imm.b_sext).asUInt
+ when(input(PREDICTION_HAD_BRANCHED)){ //Assume the predictor never predict missaligned stuff, this avoid the need to know if the instruction should branch or not
+ branch_src2 := (if(pipeline.config.withRvc) Mux(input(IS_RVC), B(2), B(4)) else B(4)).asUInt.resized
+ }
+ }
+ }
+ val branchAdder = branch_src1 + branch_src2
+ insert(BRANCH_CALC) := branchAdder(31 downto 1) @@ U"0"
+ }
+
+
+ // branch JALR or JAL/Bxx prediction miss corrections
+ val branchStage = if(earlyBranch) execute else memory
+ branchStage plug new Area {
+ import branchStage._
+ jumpInterface.valid := arbitration.isValid && input(BRANCH_DO) && !hasHazardOnBranch
+ jumpInterface.payload := input(BRANCH_CALC)
+ arbitration.flushNext setWhen(jumpInterface.valid)
+
+ if(catchAddressMisalignedForReal) {
+ val unalignedJump = input(BRANCH_DO) && input(BRANCH_CALC)(1)
+ branchExceptionPort.valid := arbitration.isValid && unalignedJump
+ branchExceptionPort.code := 0
+ branchExceptionPort.badAddr := input(BRANCH_CALC) //pipeline.stages(pipeline.indexOf(branchStage)-1).input
+
+ if(branchStage == execute) branchExceptionPort.valid clearWhen(service(classOf[HazardService]).hazardOnExecuteRS)
+ }
+ }
+
+ decodePrediction.rsp.wasWrong := jumpInterface.valid
+ }
+
+
+
+
+
+ def buildFetchPrediction(pipeline: VexRiscv): Unit = {
+ import pipeline._
+ import pipeline.config._
+
+
+ //Do branch calculations (conditions + target PC)
+ object NEXT_PC extends Stageable(UInt(32 bits))
+ object TARGET_MISSMATCH extends Stageable(Bool)
+ object BRANCH_SRC2 extends Stageable(UInt(32 bits))
+ val branchSrc2Stage = if(decodeBranchSrc2) decode else execute
+ execute plug new Area {
+ import execute._
+
+ val less = input(SRC_LESS)
+ val eq = input(SRC1) === input(SRC2)
+
+ insert(BRANCH_DO) := input(BRANCH_CTRL).mux(
+ BranchCtrlEnum.INC -> False,
+ BranchCtrlEnum.JAL -> True,
+ BranchCtrlEnum.JALR -> True,
+ BranchCtrlEnum.B -> input(INSTRUCTION)(14 downto 12).mux(
+ B"000" -> eq ,
+ B"001" -> !eq ,
+ M"1-1" -> !less,
+ default -> less
+ )
+ )
+
+ val branch_src1 = (input(BRANCH_CTRL) === BranchCtrlEnum.JALR) ? input(RS1).asUInt | input(PC)
+
+ val imm = IMM(branchSrc2Stage.input(INSTRUCTION))
+ branchSrc2Stage.insert(BRANCH_SRC2) := branchSrc2Stage.input(BRANCH_CTRL).mux(
+ BranchCtrlEnum.JAL -> imm.j_sext,
+ BranchCtrlEnum.JALR -> imm.i_sext,
+ default -> imm.b_sext
+ ).asUInt
+
+ val branchAdder = branch_src1 + input(BRANCH_SRC2)
+ insert(BRANCH_CALC) := branchAdder(31 downto 1) @@ U"0"
+ insert(NEXT_PC) := input(PC) + (if(pipeline.config.withRvc) ((input(IS_RVC)) ? U(2) | U(4)) else 4)
+ insert(TARGET_MISSMATCH) := decode.input(PC) =/= input(BRANCH_CALC)
+ }
+
+ //Apply branchs (JAL,JALR, Bxx)
+ val branchStage = if(earlyBranch) execute else memory
+ branchStage plug new Area {
+ import branchStage._
+
+ val predictionMissmatch = fetchPrediction.cmd.hadBranch =/= input(BRANCH_DO) || (input(BRANCH_DO) && input(TARGET_MISSMATCH))
+ when(inDebugNoFetchFlag) { predictionMissmatch := input(BRANCH_DO)}
+ fetchPrediction.rsp.wasRight := ! predictionMissmatch
+ fetchPrediction.rsp.finalPc := input(BRANCH_CALC)
+ fetchPrediction.rsp.sourceLastWord := {
+ if(pipeline.config.withRvc)
+ ((!input(IS_RVC) && input(PC)(1)) ? input(NEXT_PC) | input(PC))
+ else
+ input(PC)
+ }
+
+ jumpInterface.valid := arbitration.isValid && predictionMissmatch && !hasHazardOnBranch
+ jumpInterface.payload := (input(BRANCH_DO) ? input(BRANCH_CALC) | input(NEXT_PC))
+ arbitration.flushNext setWhen(jumpInterface.valid)
+
+
+ if(catchAddressMisalignedForReal) {
+ branchExceptionPort.valid := arbitration.isValid && input(BRANCH_DO) && input(BRANCH_CALC)(1)
+ branchExceptionPort.code := 0
+ branchExceptionPort.badAddr := input(BRANCH_CALC)
+
+ if(branchStage == execute) branchExceptionPort.valid clearWhen(service(classOf[HazardService]).hazardOnExecuteRS)
+ }
+ }
+ }
+}