Commit 664cf4bb authored by Yuxiao Mao's avatar Yuxiao Mao
Browse files

src: separate RehadTop and RehadSlowDetection into different files. gen Verilog ok

parent 93f48c00
package rehad
import chisel3._
import chisel3.util.Cat
import chisel3.experimental.chiselName
import freechips.rocketchip.config.Parameters
import freechips.rocketchip.diplomacy.{LazyModule, LazyModuleImp, SimpleDevice, AddressSet, Resource, RationalCrossing}
import freechips.rocketchip.interrupts.{IntSourceNode, IntSourcePortSimple, IntOutwardCrossingHelper}
import freechips.rocketchip.diplomacy.{LazyModule, LazyModuleImp, RationalCrossing}
import freechips.rocketchip.interrupts.IntOutwardCrossingHelper
import freechips.rocketchip.prci.ClockSinkDomain
import freechips.rocketchip.regmapper.{RegField, RegFieldDesc, RegFieldRdAction}
import freechips.rocketchip.rocket.{IntCtrlSigs, CSR, CSRs}
import freechips.rocketchip.rocket.constants.MemoryOpConstants
import freechips.rocketchip.subsystem.BaseSubsystem
import freechips.rocketchip.tile.RehadCoreIO
import freechips.rocketchip.tilelink.{TLRegisterNode,TLInwardCrossingHelper}
import freechips.rocketchip.tilelink.TLInwardCrossingHelper
import freechips.rocketchip.util.SlowToFast
class RehadSlowDetectionImp(outer: RehadSlowDetection, params: RehadParams)
extends LazyModuleImp(outer)
with MemoryOpConstants
{
val io = IO(new Bundle {
val pack = Vec(params.clockDiv, Flipped(new RehadCoreIO))
})
//--------------------------------------------------------------
// Instruction analysis
//--------------------------------------------------------------
// Portion of RocketCore/decode_table
val decode_table = RehadDecode.table()
// Analyze each instruction of pack using decode_table
val pack_analysis: Seq[Seq[Bool]] = io.pack.map { core =>
val data = core.inst_data
val valid = core.inst_valid
val ctrlsig = Wire(new IntCtrlSigs()).decode(data, decode_table)
// Valid: Any of the clockDiv data is valid
val inst_is_valid = valid
// Timer: CSR get timer
// Use: Set mcounteren, scounteren to access CSR counters in User Mode.
// Note: Rocket-chip does not support rdtime and rdtimeh.
val inst_is_timer = (
valid && ctrlsig.csr =/= CSR.N && (
data(31,20) === CSRs.time.U || data(31,20) === CSRs.cycle.U || data(31,20) === CSRs.instret.U ||
data(31,20) === CSRs.timeh.U || data(31,20) === CSRs.cycleh.U || data(31,20) === CSRs.instreth.U ||
data(31,20) === CSRs.mcycle.U || data(31,20) === CSRs.minstret.U ||
data(31,20) === CSRs.mcycleh.U || data(31,20) === CSRs.minstreth.U))
// Flush L1 DCache (with core config hasCFlush)
// Use: Must call Flush in Machine Mode, as they are consider as CSR of Machine level.
// cf rocket-chip.rocket.Instructions.CFLUSH_D_L1, CDISCARD_D_L1.
// Note: Flush will be valid for 1+1 cycles (distance 5) in wb stage (2+2 cycles in mem stage).
val inst_is_flush = valid && ctrlsig.mem && ctrlsig.mem_cmd === M_FLUSH_ALL
// Flush L1 ICache: io.imem.flush_icache := wb_reg_valid && wb_ctrl.fence_i && !io.dmem.s2_nack
// Flush L2 Sifive Inclusive cache: 0x201_0240 "Flush32" flush physical address data << 4
// Analysis result of each instruction in the current pack
Seq(inst_is_valid, inst_is_timer, inst_is_flush)
}
//--------------------------------------------------------------
// Pack analysis
//--------------------------------------------------------------
// If at least one instruction has the instruction pattern in pack
val pack_has_valid = RegInit(false.B)
val pack_has_timer = RegInit(false.B)
val pack_has_flush = RegInit(false.B)
pack_has_valid := pack_analysis.map(_(0)).reduce(_||_)
pack_has_timer := pack_analysis.map(_(1)).reduce(_||_)
pack_has_flush := pack_analysis.map(_(2)).reduce(_||_)
//--------------------------------------------------------------
// NPack analysis
//--------------------------------------------------------------
// Registed (valid) pack analysis result for multiple slow cycles
val npackSize: Int = 16
val npack_timer = RegInit(0.U(npackSize.W))
val npack_flush = RegInit(0.U(npackSize.W))
when (pack_has_valid) {
npack_timer := Cat(npack_timer(npackSize-1, 1), pack_has_timer)
npack_flush := Cat(npack_flush(npackSize-1, 1), pack_has_flush)
}
// Detect cache side-channel attack
val atk_timer_timer = Wire(Bool())
val atk_timer_flush = Wire(Bool())
atk_timer_timer := npack_timer.orR && pack_has_timer
atk_timer_flush := npack_timer.orR && pack_has_flush
//--------------------------------------------------------------
// Memory-Mapped registers: config, counters, threshold ...
//--------------------------------------------------------------
val counterSize = params.counterWidth
val regTest = RegInit(0.U(counterSize.W))
val intEnable = RegInit(false.B)
// Pack event counter
val countPackValid = RegInit(0.U(counterSize.W))
val countPackTimer = RegInit(0.U(counterSize.W))
val countPackFlush = RegInit(0.U(counterSize.W))
countPackValid := countPackValid + pack_has_valid
countPackTimer := countPackTimer + pack_has_timer
countPackFlush := countPackFlush + pack_has_flush
// Attack event counter
val countAtkTimerTimer = RegInit(0.U(counterSize.W))
val countAtkTimerFlush = RegInit(0.U(counterSize.W))
countAtkTimerTimer := countAtkTimerTimer + atk_timer_timer
countAtkTimerFlush := countAtkTimerFlush + atk_timer_flush
// Attack detect threshold
val threshAtkTimerTimer = RegInit(0.U(counterSize.W))
// Attack alarm when counter >= threshold (reset by TL)
val alarmAtkTimerTimer = RegInit(false.B)
when ((threshAtkTimerTimer =/= 0.U) && (countAtkTimerTimer >= threshAtkTimerTimer)) {
alarmAtkTimerTimer := true.B
}
//--------------------------------------------------------------
// TileLink and Interrupts connections
//--------------------------------------------------------------
// TileLink MMIO RegMap
def regmap(mapping: RegField.Map*) = outer.node.regmap(mapping:_*)
regmap(
0x0 -> Seq(
RegField(counterSize, regTest, RegFieldDesc("regTest", "A simple test register."))),
0x4 -> Seq(
RegField(1, intEnable, RegFieldDesc("intEnable", "Interrupt enable.", reset=Some(0)))),
0x100 -> Seq(
RegField.r(counterSize, countPackValid,
RegFieldDesc("countPackValid", "Count if any instructions in current pack is valid."))),
0x104 -> Seq(
RegField.r(counterSize, countPackTimer,
RegFieldDesc("countPackTimer", "Count if any instructions in current pack is a Timer instruction."))),
0x108 -> Seq(
RegField.r(counterSize, countPackFlush,
RegFieldDesc("countPackFlush", "Count if any instructions in current pack is a L1 Cache Flush instruction."))),
0x200 -> Seq(
RegField.r(counterSize, countAtkTimerTimer,
RegFieldDesc("countAtkTimerTimer", "Count if any timer+timer attack pattern appear."))),
0x204 -> Seq(
RegField.r(counterSize, countAtkTimerFlush,
RegFieldDesc("countAtkTimerFlush", "Count if any timer+flush attack pattern appear."))),
0x300 -> Seq(
RegField(counterSize, threshAtkTimerTimer,
RegFieldDesc("threshAtkTimerTimer", "Threshold for countAtkTimerTimer. Must be > 0 to enable."))),
0x400 -> Seq(
RegField.r(1, alarmAtkTimerTimer,
RegFieldDesc("alarmAtkTimerTimer", "Alarm for countAtkTimerTimer reach threshold. Write to clear.",
rdAction=Some(RegFieldRdAction.CLEAR)))),
)
// RehadTop, connect cores to RehadSlowDetection.
// Contains all cross clock domain connection (clock, TLbus, interrupts, direct sequential to parallel link) between the processor and RehadSlowDetection module.
// Interrupts
val interrupts = if (outer.intnode.out.isEmpty) Vec(0, Bool()) else outer.intnode.out(0)._1
interrupts(0) := intEnable && (regTest >= 1024.U)
}
class RehadSlowDetection(params: RehadParams, beatBytes: Int)(implicit p: Parameters)
class RehadTopModule(params: RehadParams, beatBytes: Int)(implicit p: Parameters)
extends LazyModule
{
lazy val module = new RehadSlowDetectionImp(this, params)
lazy val module = new RehadTopModuleImp(this, params)
// Detection Module
val rehad_slow_detection = LazyModule(new RehadSlowDetection(params, beatBytes))
// Clock divider and slow clock domain
val clockdiv = LazyModule(new RehadClockDivider(params.clockDiv))
val rehad_slow_domain = LazyModule(new ClockSinkDomain(take = None))
rehad_slow_domain.clockNode := clockdiv.node
// TileLink and Interrupts node
val device = new SimpleDevice("rehad", Seq("laas,rehad"))
val node = TLRegisterNode(
address = Seq(AddressSet(params.address, 0xfff)),
device = device,
beatBytes = beatBytes)
val intnode = IntSourceNode(
IntSourcePortSimple(
num = 1,
resources = Seq(Resource(device, "int"))))
// TileLink and Interrupt crossing to slow clock domain
val node = TLInwardCrossingHelper("rehadXnode", rehad_slow_domain, rehad_slow_detection.node)(RationalCrossing(SlowToFast))
val intnode = IntOutwardCrossingHelper("rehadXint", rehad_slow_domain, rehad_slow_detection.intnode)(RationalCrossing(SlowToFast))
}
@chiselName
......@@ -210,23 +57,8 @@ class RehadTopModuleImp(outer: RehadTopModule, params: RehadParams)
}
}
class RehadTopModule(params: RehadParams, beatBytes: Int)(implicit p: Parameters)
extends LazyModule
{
lazy val module = new RehadTopModuleImp(this, params)
// Detection Module
val rehad_slow_detection = LazyModule(new RehadSlowDetection(params, beatBytes))
// Clock divider and slow clock domain
val clockdiv = LazyModule(new RehadClockDivider(params.clockDiv))
val rehad_slow_domain = LazyModule(new ClockSinkDomain(take = None))
rehad_slow_domain.clockNode := clockdiv.node
// TileLink and Interrupt crossing to slow clock domain
val node = TLInwardCrossingHelper("rehadXnode", rehad_slow_domain, rehad_slow_detection.node)(RationalCrossing(SlowToFast))
val intnode = IntOutwardCrossingHelper("rehadXint", rehad_slow_domain, rehad_slow_detection.intnode)(RationalCrossing(SlowToFast))
}
// Use these trait in a BaseSubsystem to include a Rehad module.
trait CanHavePeripheryRehad { this: BaseSubsystem =>
private val portName = "rehad"
......
package rehad
import chisel3._
import chisel3.util.Cat
import freechips.rocketchip.config.Parameters
import freechips.rocketchip.diplomacy.{LazyModule, LazyModuleImp, SimpleDevice, AddressSet, Resource}
import freechips.rocketchip.interrupts.{IntSourceNode, IntSourcePortSimple}
import freechips.rocketchip.regmapper.{RegField, RegFieldDesc, RegFieldRdAction}
import freechips.rocketchip.rocket.{IntCtrlSigs, CSR, CSRs}
import freechips.rocketchip.rocket.constants.MemoryOpConstants
import freechips.rocketchip.tile.RehadCoreIO
import freechips.rocketchip.tilelink.TLRegisterNode
// RehadSlowDetection, work with a single slow clock.
class RehadSlowDetection(params: RehadParams, beatBytes: Int)(implicit p: Parameters)
extends LazyModule
{
lazy val module = new RehadSlowDetectionImp(this, params)
// TileLink and Interrupts node
val device = new SimpleDevice("rehad", Seq("laas,rehad"))
val node = TLRegisterNode(
address = Seq(AddressSet(params.address, 0xfff)),
device = device,
beatBytes = beatBytes)
val intnode = IntSourceNode(
IntSourcePortSimple(
num = 1,
resources = Seq(Resource(device, "int"))))
}
class RehadSlowDetectionImp(outer: RehadSlowDetection, params: RehadParams)
extends LazyModuleImp(outer)
with MemoryOpConstants
{
val io = IO(new Bundle {
val pack = Vec(params.clockDiv, Flipped(new RehadCoreIO))
})
//--------------------------------------------------------------
// Instruction analysis
//--------------------------------------------------------------
// Portion of RocketCore/decode_table
val decode_table = RehadDecode.table()
// Analyze each instruction of pack using decode_table
val pack_analysis: Seq[Seq[Bool]] = io.pack.map { core =>
val data = core.inst_data
val valid = core.inst_valid
val ctrlsig = Wire(new IntCtrlSigs()).decode(data, decode_table)
// Valid: Any of the clockDiv data is valid
val inst_is_valid = valid
// Timer: CSR get timer
// Use: Set mcounteren, scounteren to access CSR counters in User Mode.
// Note: Rocket-chip does not support rdtime and rdtimeh.
val inst_is_timer = (
valid && ctrlsig.csr =/= CSR.N && (
data(31,20) === CSRs.time.U || data(31,20) === CSRs.cycle.U || data(31,20) === CSRs.instret.U ||
data(31,20) === CSRs.timeh.U || data(31,20) === CSRs.cycleh.U || data(31,20) === CSRs.instreth.U ||
data(31,20) === CSRs.mcycle.U || data(31,20) === CSRs.minstret.U ||
data(31,20) === CSRs.mcycleh.U || data(31,20) === CSRs.minstreth.U))
// Flush L1 DCache (with core config hasCFlush)
// Use: Must call Flush in Machine Mode, as they are consider as CSR of Machine level.
// cf rocket-chip.rocket.Instructions.CFLUSH_D_L1, CDISCARD_D_L1.
// Note: Flush will be valid for 1+1 cycles (distance 5) in wb stage (2+2 cycles in mem stage).
val inst_is_flush = valid && ctrlsig.mem && ctrlsig.mem_cmd === M_FLUSH_ALL
// Flush L1 ICache: io.imem.flush_icache := wb_reg_valid && wb_ctrl.fence_i && !io.dmem.s2_nack
// Flush L2 Sifive Inclusive cache: 0x201_0240 "Flush32" flush physical address data << 4
// Analysis result of each instruction in the current pack
Seq(inst_is_valid, inst_is_timer, inst_is_flush)
}
//--------------------------------------------------------------
// Pack analysis
//--------------------------------------------------------------
// If at least one instruction has the instruction pattern in pack
val pack_has_valid = RegInit(false.B)
val pack_has_timer = RegInit(false.B)
val pack_has_flush = RegInit(false.B)
pack_has_valid := pack_analysis.map(_(0)).reduce(_||_)
pack_has_timer := pack_analysis.map(_(1)).reduce(_||_)
pack_has_flush := pack_analysis.map(_(2)).reduce(_||_)
//--------------------------------------------------------------
// NPack analysis
//--------------------------------------------------------------
// Registed (valid) pack analysis result for multiple slow cycles
val npackSize: Int = 16
val npack_timer = RegInit(0.U(npackSize.W))
val npack_flush = RegInit(0.U(npackSize.W))
when (pack_has_valid) {
npack_timer := Cat(npack_timer(npackSize-1, 1), pack_has_timer)
npack_flush := Cat(npack_flush(npackSize-1, 1), pack_has_flush)
}
// Detect cache side-channel attack
val atk_timer_timer = Wire(Bool())
val atk_timer_flush = Wire(Bool())
atk_timer_timer := npack_timer.orR && pack_has_timer
atk_timer_flush := npack_timer.orR && pack_has_flush
//--------------------------------------------------------------
// Memory-Mapped registers: config, counters, threshold ...
//--------------------------------------------------------------
val counterSize = params.counterWidth
val regTest = RegInit(0.U(counterSize.W))
val intEnable = RegInit(false.B)
// Pack event counter
val countPackValid = RegInit(0.U(counterSize.W))
val countPackTimer = RegInit(0.U(counterSize.W))
val countPackFlush = RegInit(0.U(counterSize.W))
countPackValid := countPackValid + pack_has_valid
countPackTimer := countPackTimer + pack_has_timer
countPackFlush := countPackFlush + pack_has_flush
// Attack event counter
val countAtkTimerTimer = RegInit(0.U(counterSize.W))
val countAtkTimerFlush = RegInit(0.U(counterSize.W))
countAtkTimerTimer := countAtkTimerTimer + atk_timer_timer
countAtkTimerFlush := countAtkTimerFlush + atk_timer_flush
// Attack detect threshold
val threshAtkTimerTimer = RegInit(0.U(counterSize.W))
// Attack alarm when counter >= threshold (reset by TL)
val alarmAtkTimerTimer = RegInit(false.B)
when ((threshAtkTimerTimer =/= 0.U) && (countAtkTimerTimer >= threshAtkTimerTimer)) {
alarmAtkTimerTimer := true.B
}
//--------------------------------------------------------------
// TileLink and Interrupts connections
//--------------------------------------------------------------
// TileLink MMIO RegMap
def regmap(mapping: RegField.Map*) = outer.node.regmap(mapping:_*)
regmap(
0x0 -> Seq(
RegField(counterSize, regTest, RegFieldDesc("regTest", "A simple test register."))),
0x4 -> Seq(
RegField(1, intEnable, RegFieldDesc("intEnable", "Interrupt enable.", reset=Some(0)))),
0x100 -> Seq(
RegField.r(counterSize, countPackValid,
RegFieldDesc("countPackValid", "Count if any instructions in current pack is valid."))),
0x104 -> Seq(
RegField.r(counterSize, countPackTimer,
RegFieldDesc("countPackTimer", "Count if any instructions in current pack is a Timer instruction."))),
0x108 -> Seq(
RegField.r(counterSize, countPackFlush,
RegFieldDesc("countPackFlush", "Count if any instructions in current pack is a L1 Cache Flush instruction."))),
0x200 -> Seq(
RegField.r(counterSize, countAtkTimerTimer,
RegFieldDesc("countAtkTimerTimer", "Count if any timer+timer attack pattern appear."))),
0x204 -> Seq(
RegField.r(counterSize, countAtkTimerFlush,
RegFieldDesc("countAtkTimerFlush", "Count if any timer+flush attack pattern appear."))),
0x300 -> Seq(
RegField(counterSize, threshAtkTimerTimer,
RegFieldDesc("threshAtkTimerTimer", "Threshold for countAtkTimerTimer. Must be > 0 to enable."))),
0x400 -> Seq(
RegField.r(1, alarmAtkTimerTimer,
RegFieldDesc("alarmAtkTimerTimer", "Alarm for countAtkTimerTimer reach threshold. Write to clear.",
rdAction=Some(RegFieldRdAction.CLEAR)))),
)
// Interrupts
val interrupts = if (outer.intnode.out.isEmpty) Vec(0, Bool()) else outer.intnode.out(0)._1
interrupts(0) := intEnable && (regTest >= 1024.U)
}
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment