diff options
Diffstat (limited to 'VexRiscv/src/main/scala/vexriscv/plugin/BranchPlugin.scala')
-rw-r--r-- | VexRiscv/src/main/scala/vexriscv/plugin/BranchPlugin.scala | 386 |
1 files changed, 386 insertions, 0 deletions
diff --git a/VexRiscv/src/main/scala/vexriscv/plugin/BranchPlugin.scala b/VexRiscv/src/main/scala/vexriscv/plugin/BranchPlugin.scala new file mode 100644 index 0000000..24d42fa --- /dev/null +++ b/VexRiscv/src/main/scala/vexriscv/plugin/BranchPlugin.scala @@ -0,0 +1,386 @@ +package vexriscv.plugin + +import vexriscv.Riscv._ +import vexriscv._ +import spinal.core._ +import spinal.lib._ + +trait BranchPrediction +object NONE extends BranchPrediction +object STATIC extends BranchPrediction +object DYNAMIC extends BranchPrediction +object DYNAMIC_TARGET extends BranchPrediction + +object BranchCtrlEnum extends SpinalEnum(binarySequential){ + val INC,B,JAL,JALR = newElement() +} +object BRANCH_CTRL extends Stageable(BranchCtrlEnum()) + + +case class DecodePredictionCmd() extends Bundle { + val hadBranch = Bool +} +case class DecodePredictionRsp(stage : Stage) extends Bundle { + val wasWrong = Bool +} +case class DecodePredictionBus(stage : Stage) extends Bundle { + val cmd = DecodePredictionCmd() + val rsp = DecodePredictionRsp(stage) +} + +case class FetchPredictionCmd() extends Bundle{ + val hadBranch = Bool + val targetPc = UInt(32 bits) +} +case class FetchPredictionRsp() extends Bundle{ + val wasRight = Bool + val finalPc = UInt(32 bits) + val sourceLastWord = UInt(32 bits) +} +case class FetchPredictionBus(stage : Stage) extends Bundle { + val cmd = FetchPredictionCmd() + val rsp = FetchPredictionRsp() +} + + +trait PredictionInterface{ + def askFetchPrediction() : FetchPredictionBus + def askDecodePrediction() : DecodePredictionBus + def inDebugNoFetch() : Unit +} + + + +class BranchPlugin(earlyBranch : Boolean, + catchAddressMisaligned : Boolean = false, + fenceiGenAsAJump : Boolean = false, + fenceiGenAsANop : Boolean = false, + decodeBranchSrc2 : Boolean = false) extends Plugin[VexRiscv] with PredictionInterface{ + + + def catchAddressMisalignedForReal = catchAddressMisaligned && !pipeline.config.withRvc + lazy val branchStage = if(earlyBranch) pipeline.execute else pipeline.memory + + object BRANCH_CALC extends Stageable(UInt(32 bits)) + object BRANCH_DO extends Stageable(Bool) + object BRANCH_COND_RESULT extends Stageable(Bool) + object IS_FENCEI extends Stageable(Bool) + + var jumpInterface : Flow[UInt] = null + var predictionExceptionPort : Flow[ExceptionCause] = null + var branchExceptionPort : Flow[ExceptionCause] = null + var inDebugNoFetchFlag : Bool = null + + + var decodePrediction : DecodePredictionBus = null + var fetchPrediction : FetchPredictionBus = null + + + override def askFetchPrediction() = { + fetchPrediction = FetchPredictionBus(branchStage) + fetchPrediction + } + + override def askDecodePrediction() = { + decodePrediction = DecodePredictionBus(branchStage) + decodePrediction + } + + + override def inDebugNoFetch(): Unit = inDebugNoFetchFlag := True + + def hasHazardOnBranch = if(earlyBranch) pipeline.service(classOf[HazardService]).hazardOnExecuteRS else False + + override def setup(pipeline: VexRiscv): Unit = { + import Riscv._ + import pipeline.config._ + import IntAluPlugin._ + + assert(earlyBranch || withMemoryStage, "earlyBranch must be true when memory stage is disabled!") + + val bActions = List[(Stageable[_ <: BaseType],Any)]( + SRC1_CTRL -> Src1CtrlEnum.RS, + SRC2_CTRL -> Src2CtrlEnum.RS, + SRC_USE_SUB_LESS -> True, + RS1_USE -> True, + RS2_USE -> True, + HAS_SIDE_EFFECT -> True + ) + + val jActions = List[(Stageable[_ <: BaseType],Any)]( + SRC1_CTRL -> Src1CtrlEnum.PC_INCREMENT, + SRC2_CTRL -> Src2CtrlEnum.PC, + SRC_USE_SUB_LESS -> False, + REGFILE_WRITE_VALID -> True, + HAS_SIDE_EFFECT -> True + ) + + val decoderService = pipeline.service(classOf[DecoderService]) + + + decoderService.addDefault(BRANCH_CTRL, BranchCtrlEnum.INC) + decoderService.add(List( + JAL(true) -> (jActions ++ List(BRANCH_CTRL -> BranchCtrlEnum.JAL, ALU_CTRL -> AluCtrlEnum.ADD_SUB)), + JALR -> (jActions ++ List(BRANCH_CTRL -> BranchCtrlEnum.JALR, ALU_CTRL -> AluCtrlEnum.ADD_SUB, RS1_USE -> True)), + BEQ(true) -> (bActions ++ List(BRANCH_CTRL -> BranchCtrlEnum.B)), + BNE(true) -> (bActions ++ List(BRANCH_CTRL -> BranchCtrlEnum.B)), + BLT(true) -> (bActions ++ List(BRANCH_CTRL -> BranchCtrlEnum.B, SRC_LESS_UNSIGNED -> False)), + BGE(true) -> (bActions ++ List(BRANCH_CTRL -> BranchCtrlEnum.B, SRC_LESS_UNSIGNED -> False)), + BLTU(true) -> (bActions ++ List(BRANCH_CTRL -> BranchCtrlEnum.B, SRC_LESS_UNSIGNED -> True)), + BGEU(true) -> (bActions ++ List(BRANCH_CTRL -> BranchCtrlEnum.B, SRC_LESS_UNSIGNED -> True)) + )) + + if(fenceiGenAsAJump) { + decoderService.addDefault(IS_FENCEI, False) + decoderService.add(List( + FENCEI -> (List(IS_FENCEI -> True,HAS_SIDE_EFFECT -> True, BRANCH_CTRL -> BranchCtrlEnum.JAL)) + )) + } + + if(fenceiGenAsANop){ + decoderService.add(List(FENCEI -> List())) + } + + val pcManagerService = pipeline.service(classOf[JumpService]) + + //Priority -1, as DYNAMIC_TARGET misspredicted on non branch instruction should lose against other instructions + //legitim branches, as MRET for instance + jumpInterface = pcManagerService.createJumpInterface(branchStage, priority = -10) + + + if (catchAddressMisalignedForReal) { + val exceptionService = pipeline.service(classOf[ExceptionService]) + branchExceptionPort = exceptionService.newExceptionPort(branchStage) + } + inDebugNoFetchFlag = False.setCompositeName(this, "inDebugNoFetchFlag") + } + + override def build(pipeline: VexRiscv): Unit = { + (fetchPrediction,decodePrediction) match { + case (null, null) => buildWithoutPrediction(pipeline) + case (_ , null) => buildFetchPrediction(pipeline) + case (null, _) => buildDecodePrediction(pipeline) + } + if(fenceiGenAsAJump) { + import pipeline._ + import pipeline.config._ + when(decode.input(IS_FENCEI)) { + decode.output(INSTRUCTION)(12) := False + decode.output(INSTRUCTION)(22) := True + } + execute.arbitration.haltByOther setWhen(execute.arbitration.isValid && execute.input(IS_FENCEI) && stagesFromExecute.tail.map(_.arbitration.isValid).asBits.orR) + } + } + + def buildWithoutPrediction(pipeline: VexRiscv): Unit = { + import pipeline._ + import pipeline.config._ + + //Do branch calculations (conditions + target PC) + execute plug new Area { + import execute._ + + val less = input(SRC_LESS) + val eq = input(SRC1) === input(SRC2) + + insert(BRANCH_DO) := input(BRANCH_CTRL).mux( + BranchCtrlEnum.INC -> False, + BranchCtrlEnum.JAL -> True, + BranchCtrlEnum.JALR -> True, + BranchCtrlEnum.B -> input(INSTRUCTION)(14 downto 12).mux( + B"000" -> eq , + B"001" -> !eq , + M"1-1" -> !less, + default -> less + ) + ) + + val imm = IMM(input(INSTRUCTION)) + val branch_src1 = (input(BRANCH_CTRL) === BranchCtrlEnum.JALR) ? input(RS1).asUInt | input(PC) + val branch_src2 = input(BRANCH_CTRL).mux( + BranchCtrlEnum.JAL -> imm.j_sext, + BranchCtrlEnum.JALR -> imm.i_sext, + default -> imm.b_sext + ).asUInt + + val branchAdder = branch_src1 + branch_src2 + insert(BRANCH_CALC) := branchAdder(31 downto 1) @@ U"0" + } + + //Apply branchs (JAL,JALR, Bxx) + branchStage plug new Area { + import branchStage._ + jumpInterface.valid := arbitration.isValid && input(BRANCH_DO) && !hasHazardOnBranch + jumpInterface.payload := input(BRANCH_CALC) + arbitration.flushNext setWhen(jumpInterface.valid) + + if(catchAddressMisalignedForReal) { + branchExceptionPort.valid := arbitration.isValid && input(BRANCH_DO) && jumpInterface.payload(1) + branchExceptionPort.code := 0 + branchExceptionPort.badAddr := jumpInterface.payload + + if(branchStage == execute) branchExceptionPort.valid clearWhen(service(classOf[HazardService]).hazardOnExecuteRS) + } + } + } + + + def buildDecodePrediction(pipeline: VexRiscv): Unit = { + object PREDICTION_HAD_BRANCHED extends Stageable(Bool) + + import pipeline._ + import pipeline.config._ + + + decode plug new Area { + import decode._ + insert(PREDICTION_HAD_BRANCHED) := (if(fenceiGenAsAJump) decodePrediction.cmd.hadBranch && !decode.input(IS_FENCEI) else decodePrediction.cmd.hadBranch) + } + + //Do real branch calculation + execute plug new Area { + import execute._ + + val less = input(SRC_LESS) + val eq = input(SRC1) === input(SRC2) + + insert(BRANCH_COND_RESULT) := input(BRANCH_CTRL).mux( + BranchCtrlEnum.INC -> False, + BranchCtrlEnum.JAL -> True, + BranchCtrlEnum.JALR -> True, + BranchCtrlEnum.B -> input(INSTRUCTION)(14 downto 12).mux( + B"000" -> eq , + B"001" -> !eq , + M"1-1" -> !less, + default -> less + ) + ) + + val imm = IMM(input(INSTRUCTION)) + val missAlignedTarget = if(pipeline.config.withRvc) False else (input(BRANCH_COND_RESULT) && input(BRANCH_CTRL).mux( + BranchCtrlEnum.JALR -> (imm.i_sext(1) ^ input(RS1)(1)), + BranchCtrlEnum.JAL -> imm.j_sext(1), + default -> imm.b_sext(1) + )) + + insert(BRANCH_DO) := input(PREDICTION_HAD_BRANCHED) =/= input(BRANCH_COND_RESULT) || missAlignedTarget + + //Calculation of the branch target / correction + val branch_src1,branch_src2 = UInt(32 bits) + switch(input(BRANCH_CTRL)){ + is(BranchCtrlEnum.JALR){ + branch_src1 := input(RS1).asUInt + branch_src2 := imm.i_sext.asUInt + } + default{ + branch_src1 := input(PC) + branch_src2 := ((input(BRANCH_CTRL) === BranchCtrlEnum.JAL) ? imm.j_sext | imm.b_sext).asUInt + when(input(PREDICTION_HAD_BRANCHED)){ //Assume the predictor never predict missaligned stuff, this avoid the need to know if the instruction should branch or not + branch_src2 := (if(pipeline.config.withRvc) Mux(input(IS_RVC), B(2), B(4)) else B(4)).asUInt.resized + } + } + } + val branchAdder = branch_src1 + branch_src2 + insert(BRANCH_CALC) := branchAdder(31 downto 1) @@ U"0" + } + + + // branch JALR or JAL/Bxx prediction miss corrections + val branchStage = if(earlyBranch) execute else memory + branchStage plug new Area { + import branchStage._ + jumpInterface.valid := arbitration.isValid && input(BRANCH_DO) && !hasHazardOnBranch + jumpInterface.payload := input(BRANCH_CALC) + arbitration.flushNext setWhen(jumpInterface.valid) + + if(catchAddressMisalignedForReal) { + val unalignedJump = input(BRANCH_DO) && input(BRANCH_CALC)(1) + branchExceptionPort.valid := arbitration.isValid && unalignedJump + branchExceptionPort.code := 0 + branchExceptionPort.badAddr := input(BRANCH_CALC) //pipeline.stages(pipeline.indexOf(branchStage)-1).input + + if(branchStage == execute) branchExceptionPort.valid clearWhen(service(classOf[HazardService]).hazardOnExecuteRS) + } + } + + decodePrediction.rsp.wasWrong := jumpInterface.valid + } + + + + + + def buildFetchPrediction(pipeline: VexRiscv): Unit = { + import pipeline._ + import pipeline.config._ + + + //Do branch calculations (conditions + target PC) + object NEXT_PC extends Stageable(UInt(32 bits)) + object TARGET_MISSMATCH extends Stageable(Bool) + object BRANCH_SRC2 extends Stageable(UInt(32 bits)) + val branchSrc2Stage = if(decodeBranchSrc2) decode else execute + execute plug new Area { + import execute._ + + val less = input(SRC_LESS) + val eq = input(SRC1) === input(SRC2) + + insert(BRANCH_DO) := input(BRANCH_CTRL).mux( + BranchCtrlEnum.INC -> False, + BranchCtrlEnum.JAL -> True, + BranchCtrlEnum.JALR -> True, + BranchCtrlEnum.B -> input(INSTRUCTION)(14 downto 12).mux( + B"000" -> eq , + B"001" -> !eq , + M"1-1" -> !less, + default -> less + ) + ) + + val branch_src1 = (input(BRANCH_CTRL) === BranchCtrlEnum.JALR) ? input(RS1).asUInt | input(PC) + + val imm = IMM(branchSrc2Stage.input(INSTRUCTION)) + branchSrc2Stage.insert(BRANCH_SRC2) := branchSrc2Stage.input(BRANCH_CTRL).mux( + BranchCtrlEnum.JAL -> imm.j_sext, + BranchCtrlEnum.JALR -> imm.i_sext, + default -> imm.b_sext + ).asUInt + + val branchAdder = branch_src1 + input(BRANCH_SRC2) + insert(BRANCH_CALC) := branchAdder(31 downto 1) @@ U"0" + insert(NEXT_PC) := input(PC) + (if(pipeline.config.withRvc) ((input(IS_RVC)) ? U(2) | U(4)) else 4) + insert(TARGET_MISSMATCH) := decode.input(PC) =/= input(BRANCH_CALC) + } + + //Apply branchs (JAL,JALR, Bxx) + val branchStage = if(earlyBranch) execute else memory + branchStage plug new Area { + import branchStage._ + + val predictionMissmatch = fetchPrediction.cmd.hadBranch =/= input(BRANCH_DO) || (input(BRANCH_DO) && input(TARGET_MISSMATCH)) + when(inDebugNoFetchFlag) { predictionMissmatch := input(BRANCH_DO)} + fetchPrediction.rsp.wasRight := ! predictionMissmatch + fetchPrediction.rsp.finalPc := input(BRANCH_CALC) + fetchPrediction.rsp.sourceLastWord := { + if(pipeline.config.withRvc) + ((!input(IS_RVC) && input(PC)(1)) ? input(NEXT_PC) | input(PC)) + else + input(PC) + } + + jumpInterface.valid := arbitration.isValid && predictionMissmatch && !hasHazardOnBranch + jumpInterface.payload := (input(BRANCH_DO) ? input(BRANCH_CALC) | input(NEXT_PC)) + arbitration.flushNext setWhen(jumpInterface.valid) + + + if(catchAddressMisalignedForReal) { + branchExceptionPort.valid := arbitration.isValid && input(BRANCH_DO) && input(BRANCH_CALC)(1) + branchExceptionPort.code := 0 + branchExceptionPort.badAddr := input(BRANCH_CALC) + + if(branchStage == execute) branchExceptionPort.valid clearWhen(service(classOf[HazardService]).hazardOnExecuteRS) + } + } + } +} |