Commit bc434750 authored by Yuxiao Mao's avatar Yuxiao Mao
Browse files

DetectRop: simple heuristic based on jump, add connections. Add isMonitoring in Example.

parent 0c21755f
...@@ -44,4 +44,12 @@ class AnalyzeInst()(implicit mp: MatanaParams) extends MemoryOpConstants { ...@@ -44,4 +44,12 @@ class AnalyzeInst()(implicit mp: MatanaParams) extends MemoryOpConstants {
// Note: LW load 32-bit word, LD load 64-bit (rv64 only). Idea from RocketCore/ex_reg_mem_size // Note: LW load 32-bit word, LD load 64-bit (rv64 only). Idea from RocketCore/ex_reg_mem_size
val isLoadWord = Wire(Bool()) val isLoadWord = Wire(Bool())
isLoadWord := in.valid && ctrlsig.mem && ((ctrlsig.mem_cmd === M_XRD) && (in.data(14, 13) === 1.U)) isLoadWord := in.valid && ctrlsig.mem && ((ctrlsig.mem_cmd === M_XRD) && (in.data(14, 13) === 1.U))
// Jalr: Indirect jump based on register, including Ret (jalr ra)
val isJalr = Wire(Bool())
isJalr := in.valid && ctrlsig.jalr
// Jump: any jump including br, jal, jalr
val isJump = Wire(Bool())
isJump := in.valid && (ctrlsig.branch || ctrlsig.jal || ctrlsig.jalr)
} }
...@@ -16,10 +16,11 @@ case class MatanaParams( ...@@ -16,10 +16,11 @@ case class MatanaParams(
XLen: Int = 32, XLen: Int = 32,
haveCFlush: Boolean = false, haveCFlush: Boolean = false,
pidWidth: Int = 32, pidWidth: Int = 32,
dExample: Option[DetectExampleParams] = None, dExample: Option[DetectExampleParams] = None, //Some(DetectExampleParams()),
dPatternTimer: Option[DetectPatternTimerParams] = Some(DetectPatternTimerParams()), dPatternTimer: Option[DetectPatternTimerParams] = Some(DetectPatternTimerParams()),
dPatternPrime: Option[DetectPatternPrimeParams] = Some(DetectPatternPrimeParams()), dPatternPrime: Option[DetectPatternPrimeParams] = Some(DetectPatternPrimeParams()),
dPatternLibAccess: Option[DetectPatternLibAccessParams] = Some(DetectPatternLibAccessParams()), dPatternLibAccess: Option[DetectPatternLibAccessParams] = Some(DetectPatternLibAccessParams()),
dPatternRop: Option[DetectPatternRopParams] = Some(DetectPatternRopParams()),
counterWidth: Int = 32 counterWidth: Int = 32
) { ) {
require(isPow2(clockDiv)) require(isPow2(clockDiv))
......
...@@ -26,7 +26,7 @@ class DetectExampleInternal(params: DetectExampleParams)(implicit mp: MatanaPara ...@@ -26,7 +26,7 @@ class DetectExampleInternal(params: DetectExampleParams)(implicit mp: MatanaPara
// Logic that detect the attack // Logic that detect the attack
val atk_example = Wire(Bool()) val atk_example = Wire(Bool())
atk_example := in.some_atk atk_example := in.some_atk && in.isMonitoring
// The regmap for bus // The regmap for bus
override def regmap(offset: Int) = override def regmap(offset: Int) =
......
package matana
import chisel3._
import chisel3.util.Cat
import freechips.rocketchip.regmapper.RegField
case class DetectPatternRopParams(
withJalr: Boolean = true,
withMispredict: Boolean = true
)
class DetectPatternRopIn()(implicit val mp: MatanaParams) extends Bundle {
val resetCounters = Bool()
val isMonitoring = Bool()
val pack_has_valid = Bool()
val pack_has_jalr = Bool()
val pack_has_jump = Bool()
//val pack_has_mispredict = Bool() //TODO
}
class DetectPatternRop()(implicit mp: MatanaParams) {
val in = Wire(new DetectPatternRopIn())
def regmap(offset: Int): Seq[(Int, Seq[RegField])] = Nil
}
class DetectPatternRopInternal(params: DetectPatternRopParams)(implicit mp: MatanaParams) extends DetectPatternRop {
// Attack Pattern
// Description: Jalr chained short instruction sequence
// 1) jalr + jalr = c+1, other + jalr = c-2 (window size 1)(thresh proably 10 as in mispredict1)
// 2) jalr + jalr = c+1, other + jalr = c-2 (window size 2)
// 3) jalr + jalr = c+1, other jump = c-2 (window size 1)
// 4) jalr + jalr = c+1, other jump = c-2 (window size 2)
// Note:
// 1) clockDiv will impact the pattern on window size, smaller clockDiv is perhaps more accurate.
// 2) ROP gadget destination is likely not in cache, so jalr will take longer and will probably be visible in different clockDiv pack. But this property may increase false positive as legal chain will now seems to be shorter.
// 3) 1 pattern is a complet attack, threshold, if exists should be set to 1
val npackSizeMax = 2
val npack_jalr = RegInit(0.U(npackSizeMax.W)).suggestName("npack_jalr")
when (in.pack_has_valid) {
npack_jalr := Cat(npack_jalr(npackSizeMax - 2, 0), in.pack_has_jalr)
}
val countjalr1 = RegInit(0.U(mp.counterWidth.W)).suggestName("dprop_countjalr1")
val countjalr2 = RegInit(0.U(mp.counterWidth.W)).suggestName("dprop_countjalr2")
val countjalr3 = RegInit(0.U(mp.counterWidth.W)).suggestName("dprop_countjalr3")
val countjalr4 = RegInit(0.U(mp.counterWidth.W)).suggestName("dprop_countjalr4")
when (in.resetCounters) {
countjalr1 := 0.U
countjalr2 := 0.U
countjalr3 := 0.U
countjalr4 := 0.U
}.otherwise {
when (in.isMonitoring && in.pack_has_jalr) { // evaluate counter1,2 when jalr (include pack_has_valid)
countjalr1 := Mux(npack_jalr(0),
countjalr1 + 1.U,
Mux(countjalr1 >= 2.U, countjalr1 - 2.U, 0.U))
countjalr2 := Mux(npack_jalr(1,0).orR,
countjalr2 + 1.U,
Mux(countjalr2 >= 2.U, countjalr2 - 2.U, 0.U))
}
when (in.isMonitoring && in.pack_has_jump) { // evaluate counter3,4 when jump (include pack_has_valid)
countjalr3 := Mux(in.pack_has_jalr && npack_jalr(0),
countjalr3 + 1.U,
Mux(countjalr3 >= 2.U, countjalr3 - 2.U, 0.U))
countjalr4 := Mux(in.pack_has_jalr && npack_jalr(1,0).orR,
countjalr4 + 1.U,
Mux(countjalr4 >= 2.U, countjalr4 - 2.U, 0.U))
}
}
// Attack Pattern
// Original paper: PMU-extended Hardware ROP Attack Detection, Li et al., 2018
// Description: Mispredicted chain
// 0) Original (clockDiv=1): length <= 6 c+1, length > 6 c-2, thresh 10
// 1) mispredict and last 1 valid is mispredict c+1
// mispredict and last 1 valid is not mispredict c-2
// (thresh 10)
// Require: BTB //TODO verify
val atk_rop_mispredict = Wire(Bool())
atk_rop_mispredict := false.B && in.isMonitoring // TODO
override def regmap(offset: Int) =
RegmapUtil.readValueMax(countjalr1, in.resetCounters, mp.counterWidth, offset + 0x0, "CountRopJalr1") ++
RegmapUtil.readValueMax(countjalr2, in.resetCounters, mp.counterWidth, offset + 0x8, "CountRopJalr2") ++
RegmapUtil.readValueMax(countjalr3, in.resetCounters, mp.counterWidth, offset + 0x10, "CountRopJalr3") ++
RegmapUtil.readValueMax(countjalr4, in.resetCounters, mp.counterWidth, offset + 0x18, "CountRopJalr4") ++
Nil
}
object DetectPatternRop {
def apply()(implicit mp: MatanaParams) : DetectPatternRop = {
mp.dPatternRop match {
case Some(params) => new DetectPatternRopInternal(params)
case None => new DetectPatternRop()
}
}
}
...@@ -93,7 +93,9 @@ class MatanaSlowDetectionImp(outer: MatanaSlowDetection, params: MatanaParams) ...@@ -93,7 +93,9 @@ class MatanaSlowDetectionImp(outer: MatanaSlowDetection, params: MatanaParams)
aInst.in.valid := valid aInst.in.valid := valid
// Analysis result of each instruction in the current pack // Analysis result of each instruction in the current pack
Seq(valid, aInst.isTimer, aInst.isFlushL1D, aInst.isLoadWord) Seq(valid,
aInst.isTimer, aInst.isFlushL1D, aInst.isLoadWord, aInst.isJalr,
aInst.isJump)
} }
// If at least one instruction has the instruction pattern in pack // If at least one instruction has the instruction pattern in pack
...@@ -101,11 +103,15 @@ class MatanaSlowDetectionImp(outer: MatanaSlowDetection, params: MatanaParams) ...@@ -101,11 +103,15 @@ class MatanaSlowDetectionImp(outer: MatanaSlowDetection, params: MatanaParams)
val pack_has_timer = RegInit(false.B) val pack_has_timer = RegInit(false.B)
val pack_has_flush = RegInit(false.B) val pack_has_flush = RegInit(false.B)
val pack_has_loadword = RegInit(false.B) val pack_has_loadword = RegInit(false.B)
val pack_has_jalr = RegInit(false.B)
val pack_has_jump = RegInit(false.B)
// when isMonitoring, return actual value // when isMonitoring, return actual value
pack_has_valid := Mux(isMonitoring, pack_analysis.map(_(0)).reduce(_||_), false.B) pack_has_valid := Mux(isMonitoring, pack_analysis.map(_(0)).reduce(_||_), false.B)
pack_has_timer := Mux(isMonitoring, pack_analysis.map(_(1)).reduce(_||_), false.B) pack_has_timer := Mux(isMonitoring, pack_analysis.map(_(1)).reduce(_||_), false.B)
pack_has_flush := Mux(isMonitoring, pack_analysis.map(_(2)).reduce(_||_), false.B) pack_has_flush := Mux(isMonitoring, pack_analysis.map(_(2)).reduce(_||_), false.B)
pack_has_loadword := Mux(isMonitoring, pack_analysis.map(_(3)).reduce(_||_), false.B) pack_has_loadword := Mux(isMonitoring, pack_analysis.map(_(3)).reduce(_||_), false.B)
pack_has_jalr := Mux(isMonitoring, pack_analysis.map(_(4)).reduce(_||_), false.B)
pack_has_jump := Mux(isMonitoring, pack_analysis.map(_(5)).reduce(_||_), false.B)
//-------------------------------------------------------------- //--------------------------------------------------------------
...@@ -127,6 +133,13 @@ class MatanaSlowDetectionImp(outer: MatanaSlowDetection, params: MatanaParams) ...@@ -127,6 +133,13 @@ class MatanaSlowDetectionImp(outer: MatanaSlowDetection, params: MatanaParams)
detectPatternPrime.in.pack_has_valid := pack_has_valid detectPatternPrime.in.pack_has_valid := pack_has_valid
detectPatternPrime.in.pack_has_loadword := pack_has_loadword detectPatternPrime.in.pack_has_loadword := pack_has_loadword
val detectPatternRop = DetectPatternRop()
detectPatternRop.in.resetCounters := resetCounters
detectPatternRop.in.isMonitoring := isMonitoring
detectPatternRop.in.pack_has_valid := pack_has_valid
detectPatternRop.in.pack_has_jalr := pack_has_jalr
detectPatternRop.in.pack_has_jump := pack_has_jump
// Test PC // Test PC
val lastPC = RegInit(0.U(mp.addrWidth.W)) val lastPC = RegInit(0.U(mp.addrWidth.W))
lastPC := Mux(pack0.one_pc_valid, pack0.one_pc_data, lastPC) lastPC := Mux(pack0.one_pc_valid, pack0.one_pc_data, lastPC)
...@@ -181,8 +194,11 @@ class MatanaSlowDetectionImp(outer: MatanaSlowDetection, params: MatanaParams) ...@@ -181,8 +194,11 @@ class MatanaSlowDetectionImp(outer: MatanaSlowDetection, params: MatanaParams)
RegmapUtil.countEvent(pack_has_flush, resetCounters, 0x108, "PackHasFlushL1D") ++ RegmapUtil.countEvent(pack_has_flush, resetCounters, 0x108, "PackHasFlushL1D") ++
RegmapUtil.countEvent(pack0.one_dmemaddr_valid, resetCounters, 0x10C, "DmemAddrValid") ++ RegmapUtil.countEvent(pack0.one_dmemaddr_valid, resetCounters, 0x10C, "DmemAddrValid") ++
RegmapUtil.countEvent(true.B, resetCounters, 0x110, "SlowCycle") ++ RegmapUtil.countEvent(true.B, resetCounters, 0x110, "SlowCycle") ++
RegmapUtil.countEvent(pack_has_jalr, resetCounters, 0x114, "PackHasJalr") ++
RegmapUtil.countEvent(pack_has_jump, resetCounters, 0x118, "PackHasJump") ++
detectPatternTimer.regmap(0x200) ++ detectPatternTimer.regmap(0x200) ++
detectPatternPrime.regmap(0x300) ++ detectPatternPrime.regmap(0x300) ++
detectPatternRop.regmap(0x400) ++
// Example of empty list // Example of empty list
Nil Nil
......
...@@ -102,7 +102,6 @@ object RegmapUtil { ...@@ -102,7 +102,6 @@ object RegmapUtil {
countEventThreshAlarm(in, reset, false.B, regmapOffset, regName) countEventThreshAlarm(in, reset, false.B, regmapOffset, regName)
def controlByBus(in: UInt, regWidth: Int, regmapOffset: Int, regName: String, resetVal: Option[BigInt] = Some(0))(implicit mp: MatanaParams) : Seq[(Int, Seq[RegField])] = { def controlByBus(in: UInt, regWidth: Int, regmapOffset: Int, regName: String, resetVal: Option[BigInt] = Some(0))(implicit mp: MatanaParams) : Seq[(Int, Seq[RegField])] = {
val regmap = Seq( val regmap = Seq(
regmapOffset -> Seq( regmapOffset -> Seq(
RegField(regWidth, in, RegField(regWidth, in,
...@@ -112,4 +111,33 @@ object RegmapUtil { ...@@ -112,4 +111,33 @@ object RegmapUtil {
regmap regmap
} }
def readValue(in: UInt, regWidth: Int, regmapOffset:Int, regName: String, resetVal: Option[BigInt] = Some(0))(implicit mp: MatanaParams) : Seq[(Int, Seq[RegField])] = {
val regmap = Seq(
regmapOffset -> Seq(
RegField.r(regWidth, in,
RegFieldDesc(regName, regName + " value.", reset=resetVal)))
)
regmap
}
// Useful for counter that may decrement
def readValueMax(in: UInt, reset: Bool, regWidth: Int, regmapOffset:Int, regName: String, resetVal: Option[BigInt] = Some(0))(implicit mp: MatanaParams) : Seq[(Int, Seq[RegField])] = {
require(regWidth <= 32) // as regmapOffset +0x4
val maxValue = RegInit(0.U(regWidth.W)).suggestName("max" + regName)
maxValue := Mux(reset, 0.U, Mux(in > maxValue, in, maxValue))
val regmap = Seq(
regmapOffset -> Seq(
RegField.r(regWidth, in,
RegFieldDesc(regName, regName + " value.", reset=resetVal))),
regmapOffset + 0x4 -> Seq(
RegField.r(regWidth, maxValue,
RegFieldDesc("Max" + regName, "Max of " + regName + " value.", reset=Some(0))))
)
regmap
}
} }
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment