Commit df6b33e5 authored by Yuxiao Mao's avatar Yuxiao Mao
Browse files

Merge branch 'dev': RegmapUtil move some code to class, DetectPrime new...

Merge branch 'dev': RegmapUtil move some code to class, DetectPrime new patterns, DetectLibAccess modify regmap order, add DetectRop based on jump
parents 81405e79 dc4148ef
......@@ -44,4 +44,12 @@ class AnalyzeInst()(implicit mp: MatanaParams) extends MemoryOpConstants {
// Note: LW load 32-bit word, LD load 64-bit (rv64 only). Idea from RocketCore/ex_reg_mem_size
val isLoadWord = Wire(Bool())
isLoadWord := in.valid && ctrlsig.mem && ((ctrlsig.mem_cmd === M_XRD) && (in.data(14, 13) === 1.U))
// Jalr: Indirect jump based on register, including Ret (jalr ra)
val isJalr = Wire(Bool())
isJalr := in.valid && ctrlsig.jalr
// Jump: any jump including br, jal, jalr
val isJump = Wire(Bool())
isJump := in.valid && (ctrlsig.branch || ctrlsig.jal || ctrlsig.jalr)
}
......@@ -16,10 +16,11 @@ case class MatanaParams(
XLen: Int = 32,
haveCFlush: Boolean = false,
pidWidth: Int = 32,
dExample: Option[DetectExampleParams] = None,
dExample: Option[DetectExampleParams] = None, //Some(DetectExampleParams()),
dPatternTimer: Option[DetectPatternTimerParams] = Some(DetectPatternTimerParams()),
dPatternPrime: Option[DetectPatternPrimeParams] = Some(DetectPatternPrimeParams()),
dPatternLibAccess: Option[DetectPatternLibAccessParams] = Some(DetectPatternLibAccessParams()),
dPatternRop: Option[DetectPatternRopParams] = Some(DetectPatternRopParams()),
counterWidth: Int = 32
) {
require(isPow2(clockDiv))
......
......@@ -26,7 +26,7 @@ class DetectExampleInternal(params: DetectExampleParams)(implicit mp: MatanaPara
// Logic that detect the attack
val atk_example = Wire(Bool())
atk_example := in.some_atk
atk_example := in.some_atk && in.isMonitoring
// The regmap for bus
override def regmap(offset: Int) =
......
......@@ -43,11 +43,11 @@ class DetectPatternLibAccessInternal(params: DetectPatternLibAccessParams)(impli
in.pc_data < lib_iaddr_min) // pc outside of lib
override def regmap(offset: Int) =
RegmapUtil.countEventThreshAlarm(atk_lib_access, in.resetCounters, offset, "AtkLibAccess") ++
RegmapUtil.controlByBus(lib_daddr_min, mp.addrWidth, offset + 0x10, "LibDAddrMin") ++
RegmapUtil.controlByBus(lib_daddr_max, mp.addrWidth, offset + 0x18, "LibDAddrMax") ++
RegmapUtil.controlByBus(lib_iaddr_min, mp.addrWidth, offset + 0x20, "LibIAddrMin") ++
RegmapUtil.controlByBus(lib_iaddr_max, mp.addrWidth, offset + 0x28, "LibIAddrMax")
RegmapUtil.controlByBus(lib_daddr_min, mp.addrWidth, offset + 0x00, "LibDAddrMin") ++
RegmapUtil.controlByBus(lib_daddr_max, mp.addrWidth, offset + 0x08, "LibDAddrMax") ++
RegmapUtil.controlByBus(lib_iaddr_min, mp.addrWidth, offset + 0x10, "LibIAddrMin") ++
RegmapUtil.controlByBus(lib_iaddr_max, mp.addrWidth, offset + 0x18, "LibIAddrMax") ++
RegmapUtil.countEventThreshAlarm(atk_lib_access, in.resetCounters, offset + 0x20, "AtkLibAccess")
}
object DetectPatternLibAccess {
......
......@@ -42,16 +42,24 @@ class DetectPatternPrimeInternal(params: DetectPatternPrimeParams)(implicit mp:
// Attack pattern
// Description: Multiples access to DCache, each time in different pages
// 1) diff page c+1, same page c-2 (thresh 8)
// 2) diff page c+1, same page c-2 (thresh 4)
// 3) diff page same set c+2, diff page diff set c+1, same page same set c-2, same page diff set c-1 (thresh 16)
// 4) diff page same set c+2, diff page diff set c+1, same page same set c-2, same page diff set c-1 (thresh 8)
// Description: Multiples access to DCache, each time in same set with differents tag
// 1) diff tag c+1, same tag c-1 (thresh 8)
// 2) diff tag c+1, same tag c-1 (thresh 4)
// 3) diff tag same set c+2, diff tag diff set c+1, same tag same set c-2, same tag diff set c-1 (thresh 16)
// 4) diff tag same set c+2, diff tag diff set c+1, same tag same set c-2, same tag diff set c-1 (thresh 8)
// 5) diff tag c+1, same tag c-1 (thresh 4, if thresh then 0)
// 6) diff tag c+1, same tag c-2 (thresh 4)
// 7) diff tag same set c+2, diff tag diff set c+1, same tag same set c-2, same tag diff set c-1 (thresh 8, if thresh then 0)
// 8) diff tag same set c+2, diff tag diff set c+1, same tag same set c-4, same tag diff set c-2 (thresh 8)
val countThresh1 = 8
val countThresh2 = 4
val countThresh3 = 16
val countThresh4 = 8
val countThresh5 = 4
val countThresh6 = 4
val countThresh7 = 8
val countThresh8 = 8
val dmemtag = Wire(UInt((mp.addrWidth - params.tagAddrLSB).W))
dmemtag := in.dmemaddr_data(mp.addrWidth - 1, params.tagAddrLSB)
......@@ -61,25 +69,35 @@ class DetectPatternPrimeInternal(params: DetectPatternPrimeParams)(implicit mp:
val saved_dmemtag = RegInit(VecInit(Seq.fill(params.savedDMemTagNum)(0.U((mp.addrWidth - params.tagAddrLSB).W)))).suggestName("saved_dmemtag")
val saved_dmemset = RegInit(VecInit(Seq.fill(params.savedDMemSetNum)(0.U((params.tagAddrLSB - params.setAddrLSB).W)))).suggestName("saved_dmemset")
val count1 = RegInit(0.U(log2Ceil(countThresh1 + 1).W)).suggestName("dprimed_count1")
val count2 = RegInit(0.U(log2Ceil(countThresh2 + 1).W)).suggestName("dprimed_count2")
val count3 = RegInit(0.U(log2Ceil(countThresh3 + 1).W)).suggestName("dprimed_count3")
val count4 = RegInit(0.U(log2Ceil(countThresh4 + 1).W)).suggestName("dprimed_count4")
val count1 = RegInit(0.U(log2Ceil(countThresh1 + 1).W)).suggestName("dpprime_count1")
val count2 = RegInit(0.U(log2Ceil(countThresh2 + 1).W)).suggestName("dpprime_count2")
val count3 = RegInit(0.U(log2Ceil(countThresh3 + 1).W)).suggestName("dpprime_count3")
val count4 = RegInit(0.U(log2Ceil(countThresh4 + 1).W)).suggestName("dpprime_count4")
val count5 = RegInit(0.U(log2Ceil(countThresh5 + 1).W)).suggestName("dpprime_count5")
val count6 = RegInit(0.U(log2Ceil(countThresh6 + 1).W)).suggestName("dpprime_count6")
val count7 = RegInit(0.U(log2Ceil(countThresh7 + 1).W)).suggestName("dpprime_count7")
val count8 = RegInit(0.U(log2Ceil(countThresh8 + 1).W)).suggestName("dpprime_count8")
// Update saved_dmem* array in FIFO way
when (in.dmemaddr_valid) {
when (saved_dmemtag.map(_ =/= dmemtag).reduce(_&&_)) { // new page
when (saved_dmemtag.map(_ =/= dmemtag).reduce(_&&_)) { // new tag
when (saved_dmemset.map(_ =/= dmemset).reduce(_&&_)) { // new set
// Update saved_dmemset, newer at 0
saved_dmemset.foldLeft(dmemset) { (data, last) =>
last := data
last
}
count4 := Mux(count4 <= countThresh4.U - 1.U, count4 + 1.U, countThresh4.U)
count3 := Mux(count3 <= countThresh3.U - 1.U, count3 + 1.U, countThresh3.U)
count4 := Mux(count4 <= countThresh4.U - 1.U, count4 + 1.U, countThresh4.U)
count7 := Mux(count7 <= countThresh7.U - 1.U, count7 + 1.U, 1.U)
count8 := Mux(count8 <= countThresh8.U - 1.U, count8 + 1.U, countThresh8.U)
}.otherwise { // old set
count3 := Mux(count3 <= countThresh3.U - 2.U, count3 + 2.U, countThresh3.U)
count4 := Mux(count4 <= countThresh4.U - 2.U, count4 + 2.U, countThresh4.U)
count3 := Mux(count3 <= countThresh3.U - 1.U, count3 + 1.U, countThresh3.U)
count7 := Mux(count7 =/= countThresh7.U,
Mux(count7 <= countThresh7.U - 2.U, count7 + 2.U, countThresh7.U),
2.U)
count8 := Mux(count8 <= countThresh8.U - 2.U, count8 + 2.U, countThresh8.U)
}
// Update saved_dmemtag, newer at 0
saved_dmemtag.foldLeft(dmemtag) { (data, last) =>
......@@ -88,21 +106,35 @@ class DetectPatternPrimeInternal(params: DetectPatternPrimeParams)(implicit mp:
}
count1 := Mux(count1 <= countThresh1.U - 1.U, count1 + 1.U, countThresh1.U)
count2 := Mux(count2 <= countThresh2.U - 1.U, count2 + 1.U, countThresh2.U)
}.otherwise { // old page
count5 := Mux(count5 <= countThresh5.U - 1.U, count5 + 1.U, 1.U)
count6 := Mux(count6 <= countThresh6.U - 1.U, count6 + 1.U, countThresh6.U)
}.otherwise { // old tag
when (saved_dmemset.map(_ =/= dmemtag).reduce(_&&_)) { // new set
// Update saved_dmemset, newer at 0
saved_dmemset.foldLeft(dmemset) { (data, last) =>
last := data
last
}
count4 := Mux(count4 >= 1.U, count4 - 1.U, 0.U)
count3 := Mux(count3 >= 1.U, count3 - 1.U, 0.U)
count4 := Mux(count4 >= 1.U, count4 - 1.U, 0.U)
count7 := Mux(count7 =/= countThresh7.U,
Mux(count7 >= 1.U, count7 - 1.U, 0.U),
0.U)
count8 := Mux(count8 >= 2.U, count8 - 2.U, 0.U)
}.otherwise { // old set
count4 := Mux(count4 >= 2.U, count4 - 2.U, 0.U)
count3 := Mux(count3 >= 2.U, count3 - 2.U, 0.U)
count4 := Mux(count4 >= 2.U, count4 - 2.U, 0.U)
count7 := Mux(count7 =/= countThresh7.U,
Mux(count7 >= 2.U, count7 - 2.U, 0.U),
0.U)
count8 := Mux(count8 >= 4.U, count8 - 4.U, 0.U)
}
count1 := 0.U
count1 := Mux(count1 >= 1.U, count1 - 1.U, 0.U)
count2 := Mux(count2 >= 1.U, count2 - 1.U, 0.U)
count5 := Mux(count5 =/= countThresh5.U,
Mux(count5 >= 1.U, count5 - 1.U, 0.U),
0.U)
count6 := Mux(count6 >= 2.U, count6 - 2.U, 0.U)
}
}
......@@ -115,23 +147,31 @@ class DetectPatternPrimeInternal(params: DetectPatternPrimeParams)(implicit mp:
atk_prime_3 := (count3 === countThresh3.U) && in.isMonitoring && in.dmemaddr_valid
val atk_prime_4 = Wire(Bool())
atk_prime_4 := (count4 === countThresh4.U) && in.isMonitoring && in.dmemaddr_valid
val atk_prime_5 = Wire(Bool())
atk_prime_5 := (count5 === countThresh5.U) && in.isMonitoring && in.dmemaddr_valid
val atk_prime_6 = Wire(Bool())
atk_prime_6 := (count6 === countThresh6.U) && in.isMonitoring && in.dmemaddr_valid
val atk_prime_7 = Wire(Bool())
atk_prime_7 := (count7 === countThresh7.U) && in.isMonitoring && in.dmemaddr_valid
val atk_prime_8 = Wire(Bool())
atk_prime_8 := (count8 === countThresh8.U) && in.isMonitoring && in.dmemaddr_valid
// Attack Pattern
// Description: all valid pack contains an word size load
// 1) valid load c+1, valid not load c-1 (thresh 4)
// 2) valid load c+1, valid not load c-1 (thresh 8)
// 3) valid load c+1, valid not load c-1 (thresh 12)
// 1) valid load c+1, valid not load c-1 (thresh 32)
// 2) valid load c+1, valid not load c-2 (thresh 8) (maybe only apply to high clkDiv, as for -O0 there is store with load. we suppose if store to the same addr, the next load can be conclude in the same slow cycle)
// 3) valid load c+1, valid not load c-2 (thresh 16)
// 4) valid load c+1, valid not load c-1 (thresh 16)
val countInstThresh1 = 4
val countInstThresh1 = 32
val countInstThresh2 = 8
val countInstThresh3 = 12
val countInstThresh3 = 16
val countInstThresh4 = 16
val countinst1 = RegInit(0.U(log2Ceil(countInstThresh1 + 1).W)).suggestName("dprimed_countinst1")
val countinst2 = RegInit(0.U(log2Ceil(countInstThresh2 + 1).W)).suggestName("dprimed_countinst2")
val countinst3 = RegInit(0.U(log2Ceil(countInstThresh3 + 1).W)).suggestName("dprimed_countinst3")
val countinst4 = RegInit(0.U(log2Ceil(countInstThresh4 + 1).W)).suggestName("dprimed_countinst4")
val countinst1 = RegInit(0.U(log2Ceil(countInstThresh1 + 1).W)).suggestName("dpprime_countinst1")
val countinst2 = RegInit(0.U(log2Ceil(countInstThresh2 + 1).W)).suggestName("dpprime_countinst2")
val countinst3 = RegInit(0.U(log2Ceil(countInstThresh3 + 1).W)).suggestName("dpprime_countinst3")
val countinst4 = RegInit(0.U(log2Ceil(countInstThresh4 + 1).W)).suggestName("dpprime_countinst4")
when (in.pack_has_valid) {
when (in.pack_has_loadword) {
......@@ -141,8 +181,8 @@ class DetectPatternPrimeInternal(params: DetectPatternPrimeParams)(implicit mp:
countinst4 := Mux(countinst4 <= countInstThresh4.U - 1.U, countinst4 + 1.U, countInstThresh4.U)
}.otherwise {
countinst1 := Mux(countinst1 >= 1.U, countinst1 - 1.U, 0.U)
countinst2 := Mux(countinst2 >= 1.U, countinst2 - 1.U, 0.U)
countinst3 := Mux(countinst3 >= 1.U, countinst3 - 1.U, 0.U)
countinst2 := Mux(countinst2 >= 2.U, countinst2 - 2.U, 0.U)
countinst3 := Mux(countinst3 >= 2.U, countinst3 - 2.U, 0.U)
countinst4 := Mux(countinst4 >= 1.U, countinst4 - 1.U, 0.U)
}
}
......@@ -158,43 +198,116 @@ class DetectPatternPrimeInternal(params: DetectPatternPrimeParams)(implicit mp:
atk_primeinst_4 := (countinst4 === countInstThresh4.U) && in.isMonitoring && in.pack_has_valid
// Mixte pattern inst with atk_prime_4
// Attack Pattern
// Description: Mix pattern
// 1) count4, but decrement the counter by 1 if count_pack_valid reach thresh (0x800)
// => TODO: perhaps adjust using dmemaddr / slow cycle count
// 2) count4+countinst4: if 2 atk are recognized in 4 slow cycles
// 3) count4+countinst4: (similar to mix2 but use 1 counter for both count4 and countinst4)
// count4 max && countinst4 max: c+8
// count4 max || countinst4 max: c+4
// else : c-1 (as c4/ci4 no reset 0, this means no attack sign)
// (max 8, thresh 5)
// 4) count4+countinst4: (similar to mix3 but priorize count4)
// count4 max && countinst4 max: c+12
// count4 max : c+8
// countinst4 max: c+4
// else : c-1
// (max 12, thresh 9)
// 5) count4+countinst4: (only take when c4/ci4=max or c4/ci4=0, if thresh then 0 to prevent cmix remains max) TODO?
val countPackValidThresh = 0x800L // 2048
val count_pack_valid = RegInit(0.U(mp.counterWidth.W)).suggestName("dpprime_count_pack_valid")
val count_pack_valid_dec = RegInit(false.B).suggestName("dpprime_count_pack_valid_dec")
when (in.resetCounters) {
count_pack_valid := 0.U
count_pack_valid_dec := false.B
}.otherwise {
when (count_pack_valid =/= countPackValidThresh.U) { // count < thresh
count_pack_valid := count_pack_valid + in.pack_has_valid
count_pack_valid_dec := false.B
}.otherwise { // count reach threshold
count_pack_valid := 0.U
count_pack_valid_dec := true.B
}
}
val atk_primemix_1 = Wire(Bool())
atk_primemix_1 := atk_primeinst_1 && atk_prime_4
atk_primemix_1 := Mux(count_pack_valid_dec, !atk_prime_4, atk_prime_4) && in.isMonitoring // if dec=true, dec1 if atk0, dec0 if atk1; if dec=false, inc1 if atk1, inc0 if atk0.
val countMix2Thresh = 4
val countmix2_1 = RegInit(0.U(log2Ceil(countMix2Thresh + 1).W)).suggestName("dpprime_countmix2_1")
val countmix2_2 = RegInit(0.U(log2Ceil(countMix2Thresh + 1).W)).suggestName("dpprime_countmix2_2")
when (in.dmemaddr_valid) { // modify counter only when valid
countmix2_1 := Mux(count4 === countThresh4.U,
countMix2Thresh.U, // if count4 max, set countmix2_1 to max
Mux(countmix2_1 <= 1.U, 0.U, countmix2_1 - 1.U)) // else decrement
}
when (in.pack_has_valid) { // modify counter only when valid
countmix2_2 := Mux(countinst4 === countInstThresh4.U,
countMix2Thresh.U, // if countinst4 max, set countmix2_2 to max
Mux(countmix2_2 <= 1.U, 0.U, countmix2_2 - 1.U)) // else decrement
}
val atk_primemix_2 = Wire(Bool())
atk_primemix_2 := atk_primeinst_2 && atk_prime_4
atk_primemix_2 := countmix2_1.orR && countmix2_2.orR && in.isMonitoring
val countMix3Max = 8
val countMix3Thresh = 5
val countMix4Max = 12
val countMix4Thresh = 9
val countmix3 = RegInit(0.U(log2Ceil(countMix3Max + 1).W)).suggestName("dpprime_countmix3")
val countmix4 = RegInit(0.U(log2Ceil(countMix4Max + 1).W)).suggestName("dpprime_countmix4")
val countmix3state = Wire(UInt(2.W))
countmix3state := Cat((count4 === countThresh4.U) && in.dmemaddr_valid,
(countinst4 === countInstThresh4.U) && in.pack_has_valid)
when (countmix3state.andR) {
countmix3 := Mux(countmix3 <= countMix3Max.U - 8.U, countmix3 + 8.U, countMix3Max.U)
countmix4 := Mux(countmix4 <= countMix4Max.U - 12.U, countmix4 + 12.U, countMix4Max.U)
}.elsewhen (countmix3state(0) === true.B) {
countmix3 := Mux(countmix3 <= countMix3Max.U - 4.U, countmix3 + 4.U, countMix3Max.U)
countmix4 := Mux(countmix4 <= countMix4Max.U - 8.U, countmix4 + 8.U, countMix4Max.U)
}.elsewhen (countmix3state(1) === true.B) {
countmix3 := Mux(countmix3 <= countMix3Max.U - 4.U, countmix3 + 4.U, countMix3Max.U)
countmix4 := Mux(countmix4 <= countMix4Max.U - 4.U, countmix4 + 4.U, countMix4Max.U)
}.elsewhen (in.dmemaddr_valid || in.pack_has_valid) { // has valid input, should decrement
countmix3 := Mux(countmix3 >= 1.U, countmix3 - 1.U, 0.U)
countmix4 := Mux(countmix4 >= 1.U, countmix4 - 1.U, 0.U)
}.otherwise { // no valid package, remains old value
countmix3 := countmix3
countmix4 := countmix4
}
val atk_primemix_3 = Wire(Bool())
atk_primemix_3 := atk_primeinst_3 && atk_prime_4
atk_primemix_3 := (countmix3 >= countMix3Thresh.U) && in.isMonitoring
val atk_primemix_4 = Wire(Bool())
atk_primemix_4 := atk_primeinst_4 && atk_prime_4
override def regmap(offset: Int) = {
// Additional debug information
val zeros_line = RegInit(0.U(params.setAddrLSB.W)).suggestName("zeros_line")
val regmapDebug = Seq(
offset -> Seq(
RegField.r(mp.addrWidth, Cat(saved_dmemtag(0), Cat(saved_dmemset(0), zeros_line)),
RegFieldDesc("saved_dmemaddr",
"Last valid input DMem addr.", reset=Some(0))))
)
// Final regmap
regmapDebug ++
RegmapUtil.countEventThreshAlarm(atk_prime_1, in.resetCounters, offset + 0x10, "AtkPrime1") ++
RegmapUtil.countEventThreshAlarm(atk_prime_2, in.resetCounters, offset + 0x20, "AtkPrime2") ++
RegmapUtil.countEventThreshAlarm(atk_prime_3, in.resetCounters, offset + 0x30, "AtkPrime3") ++
RegmapUtil.countEventThreshAlarm(atk_prime_4, in.resetCounters, offset + 0x40, "AtkPrime4") ++
RegmapUtil.countEventThreshAlarm(atk_primeinst_1, in.resetCounters, offset + 0x50, "AtkPrimeInst1") ++
RegmapUtil.countEventThreshAlarm(atk_primeinst_2, in.resetCounters, offset + 0x60, "AtkPrimeInst2") ++
RegmapUtil.countEventThreshAlarm(atk_primeinst_3, in.resetCounters, offset + 0x70, "AtkPrimeInst3") ++
RegmapUtil.countEventThreshAlarm(atk_primeinst_4, in.resetCounters, offset + 0x80, "AtkPrimeInst4") ++
RegmapUtil.countEventThreshAlarm(atk_primemix_1, in.resetCounters, offset + 0x90, "AtkPrimeMix1") ++
RegmapUtil.countEventThreshAlarm(atk_primemix_2, in.resetCounters, offset + 0xA0, "AtkPrimeMix2") ++
RegmapUtil.countEventThreshAlarm(atk_primemix_3, in.resetCounters, offset + 0xB0, "AtkPrimeMix3") ++
RegmapUtil.countEventThreshAlarm(atk_primemix_4, in.resetCounters, offset + 0xC0, "AtkPrimeMix4")
}
atk_primemix_4 := (countmix4 >= countMix4Thresh.U) && in.isMonitoring
// Debug Signals
val zeros_line = RegInit(0.U(params.setAddrLSB.W)).suggestName("zeros_line")
val saved_addr0 = RegInit(0.U(mp.addrWidth.W)).suggestName("saved_addr0")
saved_addr0 := Cat(saved_dmemtag(0), Cat(saved_dmemset(0), zeros_line))
override def regmap(offset: Int) =
RegmapUtil.countEvent(atk_prime_1, in.resetCounters, offset + 0x0, "AtkPrime1") ++
RegmapUtil.countEvent(atk_prime_2, in.resetCounters, offset + 0x4, "AtkPrime2") ++
RegmapUtil.countEvent(atk_prime_3, in.resetCounters, offset + 0x8, "AtkPrime3") ++
RegmapUtil.countEvent(atk_prime_4, in.resetCounters, offset + 0xC, "AtkPrime4") ++
RegmapUtil.countEvent(atk_prime_5, in.resetCounters, offset + 0x10, "AtkPrime5") ++
RegmapUtil.countEvent(atk_prime_6, in.resetCounters, offset + 0x14, "AtkPrime6") ++
RegmapUtil.countEvent(atk_prime_7, in.resetCounters, offset + 0x18, "AtkPrime7") ++
RegmapUtil.countEvent(atk_prime_8, in.resetCounters, offset + 0x1C, "AtkPrime8") ++
RegmapUtil.countEvent(atk_primeinst_1, in.resetCounters, offset + 0x30, "AtkPrimeInst1") ++
RegmapUtil.countEvent(atk_primeinst_2, in.resetCounters, offset + 0x34, "AtkPrimeInst2") ++
RegmapUtil.countEvent(atk_primeinst_3, in.resetCounters, offset + 0x38, "AtkPrimeInst3") ++
RegmapUtil.countEvent(atk_primeinst_4, in.resetCounters, offset + 0x3C, "AtkPrimeInst4") ++
RegmapUtil.countEventMax(atk_primemix_1, in.resetCounters, count_pack_valid_dec, offset + 0x50, "AtkPrimeMix1") ++
RegmapUtil.countEvent(atk_primemix_2, in.resetCounters, offset + 0x64, "AtkPrimeMix2") ++
RegmapUtil.countEvent(atk_primemix_3, in.resetCounters, offset + 0x68, "AtkPrimeMix3") ++
RegmapUtil.countEvent(atk_primemix_4, in.resetCounters, offset + 0x6C, "AtkPrimeMix4")
}
object DetectPatternPrime {
......
package matana
import chisel3._
import chisel3.util.Cat
import freechips.rocketchip.regmapper.RegField
case class DetectPatternRopParams(
withJalr: Boolean = true,
withMispredict: Boolean = true
)
class DetectPatternRopIn()(implicit val mp: MatanaParams) extends Bundle {
val resetCounters = Bool()
val isMonitoring = Bool()
val pack_has_valid = Bool()
val pack_has_jalr = Bool()
val pack_has_jump = Bool()
//val pack_has_mispredict = Bool() //TODO
}
class DetectPatternRop()(implicit mp: MatanaParams) {
val in = Wire(new DetectPatternRopIn())
def regmap(offset: Int): Seq[(Int, Seq[RegField])] = Nil
}
class DetectPatternRopInternal(params: DetectPatternRopParams)(implicit mp: MatanaParams) extends DetectPatternRop {
// Attack Pattern
// Description: Jalr chained short instruction sequence
// 1) jalr + jalr = c+1, other + jalr = c-2 (window size 1)(thresh proably 10 as in mispredict1)
// 2) jalr + jalr = c+1, other + jalr = c-2 (window size 2)
// 3) jalr + jalr = c+1, other jump = c-2 (window size 1)
// 4) jalr + jalr = c+1, other jump = c-2 (window size 2)
// Note:
// 1) clockDiv will impact the pattern on window size, smaller clockDiv is perhaps more accurate.
// 2) ROP gadget destination is likely not in cache, so jalr will take longer and will probably be visible in different clockDiv pack. But this property may increase false positive as legal chain will now seems to be shorter.
// 3) 1 pattern is a complet attack, threshold, if exists should be set to 1
val npackSizeMax = 2
val npack_jalr = RegInit(0.U(npackSizeMax.W)).suggestName("npack_jalr")
when (in.pack_has_valid) {
npack_jalr := Cat(npack_jalr(npackSizeMax - 2, 0), in.pack_has_jalr)
}
val countjalr1 = RegInit(0.U(mp.counterWidth.W)).suggestName("dprop_countjalr1")
val countjalr2 = RegInit(0.U(mp.counterWidth.W)).suggestName("dprop_countjalr2")
val countjalr3 = RegInit(0.U(mp.counterWidth.W)).suggestName("dprop_countjalr3")
val countjalr4 = RegInit(0.U(mp.counterWidth.W)).suggestName("dprop_countjalr4")
when (in.resetCounters) {
countjalr1 := 0.U
countjalr2 := 0.U
countjalr3 := 0.U
countjalr4 := 0.U
}.otherwise {
when (in.isMonitoring && in.pack_has_jalr) { // evaluate counter1,2 when jalr (include pack_has_valid)
countjalr1 := Mux(npack_jalr(0),
countjalr1 + 1.U,
Mux(countjalr1 >= 2.U, countjalr1 - 2.U, 0.U))
countjalr2 := Mux(npack_jalr(1,0).orR,
countjalr2 + 1.U,
Mux(countjalr2 >= 2.U, countjalr2 - 2.U, 0.U))
}
when (in.isMonitoring && in.pack_has_jump) { // evaluate counter3,4 when jump (include pack_has_valid)
countjalr3 := Mux(in.pack_has_jalr && npack_jalr(0),
countjalr3 + 1.U,
Mux(countjalr3 >= 2.U, countjalr3 - 2.U, 0.U))
countjalr4 := Mux(in.pack_has_jalr && npack_jalr(1,0).orR,
countjalr4 + 1.U,
Mux(countjalr4 >= 2.U, countjalr4 - 2.U, 0.U))
}
}
// Attack Pattern
// Original paper: PMU-extended Hardware ROP Attack Detection, Li et al., 2018
// Description: Mispredicted chain
// 0) Original (clockDiv=1): length <= 6 c+1, length > 6 c-2, thresh 10
// 1) mispredict and last 1 valid is mispredict c+1
// mispredict and last 1 valid is not mispredict c-2
// (thresh 10)
// Require: BTB //TODO verify
val atk_rop_mispredict = Wire(Bool())
atk_rop_mispredict := false.B && in.isMonitoring // TODO
override def regmap(offset: Int) =
RegmapUtil.readValueMax(countjalr1, in.resetCounters, mp.counterWidth, offset + 0x0, "CountRopJalr1") ++
RegmapUtil.readValueMax(countjalr2, in.resetCounters, mp.counterWidth, offset + 0x8, "CountRopJalr2") ++
RegmapUtil.readValueMax(countjalr3, in.resetCounters, mp.counterWidth, offset + 0x10, "CountRopJalr3") ++
RegmapUtil.readValueMax(countjalr4, in.resetCounters, mp.counterWidth, offset + 0x18, "CountRopJalr4") ++
Nil
}
object DetectPatternRop {
def apply()(implicit mp: MatanaParams) : DetectPatternRop = {
mp.dPatternRop match {
case Some(params) => new DetectPatternRopInternal(params)
case None => new DetectPatternRop()
}
}
}
......@@ -93,7 +93,9 @@ class MatanaSlowDetectionImp(outer: MatanaSlowDetection, params: MatanaParams)
aInst.in.valid := valid
// Analysis result of each instruction in the current pack
Seq(valid, aInst.isTimer, aInst.isFlushL1D, aInst.isLoadWord)
Seq(valid,
aInst.isTimer, aInst.isFlushL1D, aInst.isLoadWord, aInst.isJalr,
aInst.isJump)
}
// If at least one instruction has the instruction pattern in pack
......@@ -101,11 +103,15 @@ class MatanaSlowDetectionImp(outer: MatanaSlowDetection, params: MatanaParams)
val pack_has_timer = RegInit(false.B)
val pack_has_flush = RegInit(false.B)
val pack_has_loadword = RegInit(false.B)
val pack_has_jalr = RegInit(false.B)
val pack_has_jump = RegInit(false.B)
// when isMonitoring, return actual value
pack_has_valid := Mux(isMonitoring, pack_analysis.map(_(0)).reduce(_||_), false.B)
pack_has_timer := Mux(isMonitoring, pack_analysis.map(_(1)).reduce(_||_), false.B)
pack_has_flush := Mux(isMonitoring, pack_analysis.map(_(2)).reduce(_||_), false.B)
pack_has_loadword := Mux(isMonitoring, pack_analysis.map(_(3)).reduce(_||_), false.B)
pack_has_jalr := Mux(isMonitoring, pack_analysis.map(_(4)).reduce(_||_), false.B)
pack_has_jump := Mux(isMonitoring, pack_analysis.map(_(5)).reduce(_||_), false.B)
//--------------------------------------------------------------
......@@ -127,6 +133,13 @@ class MatanaSlowDetectionImp(outer: MatanaSlowDetection, params: MatanaParams)
detectPatternPrime.in.pack_has_valid := pack_has_valid
detectPatternPrime.in.pack_has_loadword := pack_has_loadword
val detectPatternRop = DetectPatternRop()
detectPatternRop.in.resetCounters := resetCounters
detectPatternRop.in.isMonitoring := isMonitoring
detectPatternRop.in.pack_has_valid := pack_has_valid
detectPatternRop.in.pack_has_jalr := pack_has_jalr
detectPatternRop.in.pack_has_jump := pack_has_jump
// Test PC
val lastPC = RegInit(0.U(mp.addrWidth.W))
lastPC := Mux(pack0.one_pc_valid, pack0.one_pc_data, lastPC)
......@@ -173,14 +186,19 @@ class MatanaSlowDetectionImp(outer: MatanaSlowDetection, params: MatanaParams)
val regmapFinal =
regmap0 ++
// Debug for control signal
RegmapUtil.countEvent(reg_pid_change, resetCounters, 0x20, "PidChange", "PidChange") ++
RegmapUtil.countEvent(reg_pid_change_monitored, resetCounters, 0x24, "PidChangeMonitored", "PidChangeMonitored") ++
RegmapUtil.countEvent(reg_pid_change, resetCounters, 0x20, "PidChange") ++
RegmapUtil.countEvent(reg_pid_change_monitored, resetCounters, 0x24, "PidChangeMonitored") ++
// Internal State
RegmapUtil.countPackHas(pack_has_valid, resetCounters, 0x100, "Valid") ++
RegmapUtil.countPackHas(pack_has_timer, resetCounters, 0x104, "Timer") ++
RegmapUtil.countPackHas(pack_has_flush, resetCounters, 0x108, "FlushL1D") ++
RegmapUtil.countEvent(pack_has_valid, resetCounters, 0x100, "PackValid") ++
RegmapUtil.countEvent(pack_has_timer, resetCounters, 0x104, "PackTimer") ++
RegmapUtil.countEvent(pack_has_flush, resetCounters, 0x108, "PackFlushL1D") ++
RegmapUtil.countEvent(pack0.one_dmemaddr_valid, resetCounters, 0x10C, "DmemAddrValid") ++
RegmapUtil.countEvent(isMonitoring, resetCounters, 0x110, "SlowCycle") ++
RegmapUtil.countEvent(pack_has_jalr, resetCounters, 0x114, "PackJalr") ++
RegmapUtil.countEvent(pack_has_jump, resetCounters, 0x118, "PackJump") ++
detectPatternTimer.regmap(0x200) ++
detectPatternPrime.regmap(0x300) ++
detectPatternRop.regmap(0x400) ++
// Example of empty list
Nil
......
......@@ -3,65 +3,142 @@ package matana
import chisel3._
import freechips.rocketchip.regmapper.{RegField, RegFieldDesc}
object RegmapUtil {
class RegmapUtilCountEventIn()(implicit val mp: MatanaParams) extends Bundle {
val event = Bool()
val reset = Bool()
val decrement = Bool()
}
def countEvent(in: Bool, reset: Bool, regmapOffset: Int, regName: String, regDesc: String)(implicit mp: MatanaParams) : Seq[(Int, Seq[RegField])] = {
class RegmapUtilCountEvent(regName: String)(implicit mp: MatanaParams) {
val in = Wire(new RegmapUtilCountEventIn())
val countEvent = RegInit(0.U(mp.counterWidth.W)).suggestName("count" + regName)
countEvent := Mux(reset, 0.U, countEvent + in)
val countEvent = RegInit(0.U(mp.counterWidth.W)).suggestName("count" + regName)
countEvent := Mux(in.reset,
0.U,
Mux(in.decrement && in.event,
Mux(countEvent >= 1.U, countEvent - in.event, 0.U), // decrement do not below 0
countEvent + in.event))