aboutsummaryrefslogtreecommitdiff
path: root/VexRiscv/src/main/scala/vexriscv/plugin/MmuPlugin.scala
diff options
context:
space:
mode:
Diffstat (limited to 'VexRiscv/src/main/scala/vexriscv/plugin/MmuPlugin.scala')
-rw-r--r--VexRiscv/src/main/scala/vexriscv/plugin/MmuPlugin.scala313
1 files changed, 313 insertions, 0 deletions
diff --git a/VexRiscv/src/main/scala/vexriscv/plugin/MmuPlugin.scala b/VexRiscv/src/main/scala/vexriscv/plugin/MmuPlugin.scala
new file mode 100644
index 0000000..093f59a
--- /dev/null
+++ b/VexRiscv/src/main/scala/vexriscv/plugin/MmuPlugin.scala
@@ -0,0 +1,313 @@
+package vexriscv.plugin
+
+import vexriscv.{VexRiscv, _}
+import spinal.core._
+import spinal.lib._
+
+import scala.collection.mutable.ArrayBuffer
+
+trait DBusAccessService{
+ def newDBusAccess() : DBusAccess
+}
+
+case class DBusAccessCmd() extends Bundle {
+ val address = UInt(32 bits)
+ val size = UInt(2 bits)
+ val write = Bool
+ val data = Bits(32 bits)
+ val writeMask = Bits(4 bits)
+}
+
+case class DBusAccessRsp() extends Bundle {
+ val data = Bits(32 bits)
+ val error = Bool()
+ val redo = Bool()
+}
+
+case class DBusAccess() extends Bundle {
+ val cmd = Stream(DBusAccessCmd())
+ val rsp = Flow(DBusAccessRsp())
+}
+
+
+object MmuPort{
+ val PRIORITY_DATA = 1
+ val PRIORITY_INSTRUCTION = 0
+}
+case class MmuPort(bus : MemoryTranslatorBus, priority : Int, args : MmuPortConfig, id : Int)
+
+case class MmuPortConfig(portTlbSize : Int, latency : Int = 0, earlyRequireMmuLockup : Boolean = false, earlyCacheHits : Boolean = false)
+
+class MmuPlugin(ioRange : UInt => Bool,
+ virtualRange : UInt => Bool = address => True,
+// allowUserIo : Boolean = false,
+ enableMmuInMachineMode : Boolean = false) extends Plugin[VexRiscv] with MemoryTranslator {
+
+ var dBusAccess : DBusAccess = null
+ val portsInfo = ArrayBuffer[MmuPort]()
+
+ override def newTranslationPort(priority : Int,args : Any): MemoryTranslatorBus = {
+ val config = args.asInstanceOf[MmuPortConfig]
+ val port = MmuPort(MemoryTranslatorBus(MemoryTranslatorBusParameter(wayCount = config.portTlbSize, latency = config.latency)),priority, config, portsInfo.length)
+ portsInfo += port
+ port.bus
+ }
+
+ object IS_SFENCE_VMA2 extends Stageable(Bool)
+ override def setup(pipeline: VexRiscv): Unit = {
+ import Riscv._
+ import pipeline.config._
+ val decoderService = pipeline.service(classOf[DecoderService])
+ decoderService.addDefault(IS_SFENCE_VMA2, False)
+ decoderService.add(SFENCE_VMA, List(IS_SFENCE_VMA2 -> True))
+
+
+ dBusAccess = pipeline.service(classOf[DBusAccessService]).newDBusAccess()
+ }
+
+ override def build(pipeline: VexRiscv): Unit = {
+ import pipeline._
+ import pipeline.config._
+ import Riscv._
+ val csrService = pipeline.service(classOf[CsrInterface])
+
+ //Sorted by priority
+ val sortedPortsInfo = portsInfo.sortBy(_.priority)
+
+ case class CacheLine() extends Bundle {
+ val valid, exception, superPage = Bool
+ val virtualAddress = Vec(UInt(10 bits), UInt(10 bits))
+ val physicalAddress = Vec(UInt(10 bits), UInt(10 bits))
+ val allowRead, allowWrite, allowExecute, allowUser = Bool
+
+ def init = {
+ valid init (False)
+ this
+ }
+ }
+
+ val csr = pipeline plug new Area{
+ val status = new Area{
+ val sum, mxr, mprv = RegInit(False)
+ }
+ val satp = new Area {
+ val mode = RegInit(False)
+ val asid = Reg(Bits(9 bits))
+ val ppn = Reg(UInt(20 bits))
+ }
+
+ for(offset <- List(CSR.MSTATUS, CSR.SSTATUS)) csrService.rw(offset, 19 -> status.mxr, 18 -> status.sum, 17 -> status.mprv)
+ csrService.rw(CSR.SATP, 31 -> satp.mode, 22 -> satp.asid, 0 -> satp.ppn)
+ }
+
+ val core = pipeline plug new Area {
+ val ports = for (port <- sortedPortsInfo) yield new Area {
+ val handle = port
+ val id = port.id
+ val privilegeService = pipeline.serviceElse(classOf[PrivilegeService], PrivilegeServiceDefault())
+ val cache = Vec(Reg(CacheLine()) init, port.args.portTlbSize)
+ val dirty = RegInit(False).allowUnsetRegToAvoidLatch
+ if(port.args.earlyRequireMmuLockup){
+ dirty clearWhen(!port.bus.cmd.last.isStuck)
+ }
+
+ def toRsp[T <: Data](data : T, from : MemoryTranslatorCmd) : T = from match {
+ case _ if from == port.bus.cmd.last => data
+ case _ => {
+ val next = port.bus.cmd.dropWhile(_ != from)(1)
+ toRsp(RegNextWhen(data, !next.isStuck), next)
+ }
+ }
+ val requireMmuLockupCmd = port.bus.cmd.takeRight(if(port.args.earlyRequireMmuLockup) 2 else 1).head
+
+ val requireMmuLockupCalc = virtualRange(requireMmuLockupCmd.virtualAddress) && !requireMmuLockupCmd.bypassTranslation && csr.satp.mode
+ if(!enableMmuInMachineMode) {
+ requireMmuLockupCalc clearWhen(!csr.status.mprv && privilegeService.isMachine())
+ when(privilegeService.isMachine()) {
+ if (port.priority == MmuPort.PRIORITY_DATA) {
+ requireMmuLockupCalc clearWhen (!csr.status.mprv || pipeline(MPP) === 3)
+ } else {
+ requireMmuLockupCalc := False
+ }
+ }
+ }
+
+ val cacheHitsCmd = port.bus.cmd.takeRight(if(port.args.earlyCacheHits) 2 else 1).head
+ val cacheHitsCalc = B(cache.map(line => line.valid && line.virtualAddress(1) === cacheHitsCmd.virtualAddress(31 downto 22) && (line.superPage || line.virtualAddress(0) === cacheHitsCmd.virtualAddress(21 downto 12))))
+
+
+ val requireMmuLockup = toRsp(requireMmuLockupCalc, requireMmuLockupCmd)
+ val cacheHits = toRsp(cacheHitsCalc, cacheHitsCmd)
+
+ val cacheHit = cacheHits.asBits.orR
+ val cacheLine = MuxOH(cacheHits, cache)
+ val entryToReplace = Counter(port.args.portTlbSize)
+
+
+ when(requireMmuLockup) {
+ port.bus.rsp.physicalAddress := cacheLine.physicalAddress(1) @@ (cacheLine.superPage ? port.bus.cmd.last.virtualAddress(21 downto 12) | cacheLine.physicalAddress(0)) @@ port.bus.cmd.last.virtualAddress(11 downto 0)
+ port.bus.rsp.allowRead := cacheLine.allowRead || csr.status.mxr && cacheLine.allowExecute
+ port.bus.rsp.allowWrite := cacheLine.allowWrite
+ port.bus.rsp.allowExecute := cacheLine.allowExecute
+ port.bus.rsp.exception := !dirty && cacheHit && (cacheLine.exception || cacheLine.allowUser && privilegeService.isSupervisor() && !csr.status.sum || !cacheLine.allowUser && privilegeService.isUser())
+ port.bus.rsp.refilling := dirty || !cacheHit
+ port.bus.rsp.isPaging := True
+ } otherwise {
+ port.bus.rsp.physicalAddress := port.bus.cmd.last.virtualAddress
+ port.bus.rsp.allowRead := True
+ port.bus.rsp.allowWrite := True
+ port.bus.rsp.allowExecute := True
+ port.bus.rsp.exception := False
+ port.bus.rsp.refilling := False
+ port.bus.rsp.isPaging := False
+ }
+ port.bus.rsp.isIoAccess := ioRange(port.bus.rsp.physicalAddress)
+
+ port.bus.rsp.bypassTranslation := !requireMmuLockup
+ for(wayId <- 0 until port.args.portTlbSize){
+ port.bus.rsp.ways(wayId).sel := cacheHits(wayId)
+ port.bus.rsp.ways(wayId).physical := cache(wayId).physicalAddress(1) @@ (cache(wayId).superPage ? port.bus.cmd.last.virtualAddress(21 downto 12) | cache(wayId).physicalAddress(0)) @@ port.bus.cmd.last.virtualAddress(11 downto 0)
+ }
+
+ // Avoid keeping any invalid line in the cache after an exception.
+ // https://github.com/riscv/riscv-linux/blob/8fe28cb58bcb235034b64cbbb7550a8a43fd88be/arch/riscv/include/asm/pgtable.h#L276
+ when(service(classOf[IContextSwitching]).isContextSwitching) {
+ for (line <- cache) {
+ when(line.exception) {
+ line.valid := False
+ }
+ }
+ }
+ }
+
+ val shared = new Area {
+ val State = new SpinalEnum{
+ val IDLE, L1_CMD, L1_RSP, L0_CMD, L0_RSP = newElement()
+ }
+ val state = RegInit(State.IDLE)
+ val vpn = Reg(Vec(UInt(10 bits), UInt(10 bits)))
+ val portSortedOh = Reg(Bits(portsInfo.length bits))
+ case class PTE() extends Bundle {
+ val V, R, W ,X, U, G, A, D = Bool()
+ val RSW = Bits(2 bits)
+ val PPN0 = UInt(10 bits)
+ val PPN1 = UInt(12 bits)
+ }
+
+ val dBusRspStaged = dBusAccess.rsp.stage()
+ val dBusRsp = new Area{
+ val pte = PTE()
+ pte.assignFromBits(dBusRspStaged.data)
+ val exception = !pte.V || (!pte.R && pte.W) || dBusRspStaged.error
+ val leaf = pte.R || pte.X
+ }
+
+ val pteBuffer = RegNextWhen(dBusRsp.pte, dBusRspStaged.valid && !dBusRspStaged.redo)
+
+ dBusAccess.cmd.valid := False
+ dBusAccess.cmd.write := False
+ dBusAccess.cmd.size := 2
+ dBusAccess.cmd.address.assignDontCare()
+ dBusAccess.cmd.data.assignDontCare()
+ dBusAccess.cmd.writeMask.assignDontCare()
+
+ val refills = OHMasking.last(B(ports.map(port => port.handle.bus.cmd.last.isValid && port.requireMmuLockup && !port.dirty && !port.cacheHit)))
+ switch(state){
+ is(State.IDLE){
+ when(refills.orR){
+ portSortedOh := refills
+ state := State.L1_CMD
+ val address = MuxOH(refills, sortedPortsInfo.map(_.bus.cmd.last.virtualAddress))
+ vpn(1) := address(31 downto 22)
+ vpn(0) := address(21 downto 12)
+ }
+// for(port <- portsInfo.sortBy(_.priority)){
+// when(port.bus.cmd.isValid && port.bus.rsp.refilling){
+// vpn(1) := port.bus.cmd.virtualAddress(31 downto 22)
+// vpn(0) := port.bus.cmd.virtualAddress(21 downto 12)
+// portId := port.id
+// state := State.L1_CMD
+// }
+// }
+ }
+ is(State.L1_CMD){
+ dBusAccess.cmd.valid := True
+ dBusAccess.cmd.address := csr.satp.ppn @@ vpn(1) @@ U"00"
+ when(dBusAccess.cmd.ready){
+ state := State.L1_RSP
+ }
+ }
+ is(State.L1_RSP){
+ when(dBusRspStaged.valid){
+ state := State.L0_CMD
+ when(dBusRsp.leaf || dBusRsp.exception){
+ state := State.IDLE
+ }
+ when(dBusRspStaged.redo){
+ state := State.L1_CMD
+ }
+ }
+ }
+ is(State.L0_CMD){
+ dBusAccess.cmd.valid := True
+ dBusAccess.cmd.address := pteBuffer.PPN1(9 downto 0) @@ pteBuffer.PPN0 @@ vpn(0) @@ U"00"
+ when(dBusAccess.cmd.ready){
+ state := State.L0_RSP
+ }
+ }
+ is(State.L0_RSP){
+ when(dBusRspStaged.valid) {
+ state := State.IDLE
+ when(dBusRspStaged.redo){
+ state := State.L0_CMD
+ }
+ }
+ }
+ }
+
+ for((port, id) <- sortedPortsInfo.zipWithIndex) {
+ port.bus.busy := state =/= State.IDLE && portSortedOh(id)
+ }
+
+ when(dBusRspStaged.valid && !dBusRspStaged.redo && (dBusRsp.leaf || dBusRsp.exception)){
+ for((port, id) <- ports.zipWithIndex) {
+ when(portSortedOh(id)) {
+ port.entryToReplace.increment()
+ if(port.handle.args.earlyRequireMmuLockup) {
+ port.dirty := True
+ } //Avoid having non coherent TLB lookup
+ for ((line, lineId) <- port.cache.zipWithIndex) {
+ when(port.entryToReplace === lineId){
+ val superPage = state === State.L1_RSP
+ line.valid := True
+ line.exception := dBusRsp.exception || (superPage && dBusRsp.pte.PPN0 =/= 0)
+ line.virtualAddress := vpn
+ line.physicalAddress := Vec(dBusRsp.pte.PPN0, dBusRsp.pte.PPN1(9 downto 0))
+ line.allowRead := dBusRsp.pte.R
+ line.allowWrite := dBusRsp.pte.W
+ line.allowExecute := dBusRsp.pte.X
+ line.allowUser := dBusRsp.pte.U
+ line.superPage := state === State.L1_RSP
+ }
+ }
+ }
+ }
+ }
+ }
+ }
+
+ val fenceStage = execute
+
+ //Both SFENCE_VMA and SATP reschedule the next instruction in the CsrPlugin itself with one extra cycle to ensure side effect propagation.
+ fenceStage plug new Area{
+ import fenceStage._
+ when(arbitration.isValid && arbitration.isFiring && input(IS_SFENCE_VMA2)){
+ for(port <- core.ports; line <- port.cache) line.valid := False
+ }
+
+ csrService.onWrite(CSR.SATP){
+ for(port <- core.ports; line <- port.cache) line.valid := False
+ }
+ }
+ }
+}