Commit 16c1e504 authored by Yuxiao Mao's avatar Yuxiao Mao
Browse files

Detect: keep only the best Prime/ROP detect pattern

parent 784261e7
......@@ -43,22 +43,10 @@ class DetectPatternPrimeInternal(params: DetectPatternPrimeParams)(implicit mp:
// Attack pattern
// Description: Multiples access to DCache, each time in same set with differents tag
// 1) diff tag c+1, same tag c-1 (thresh 8)
// 2) diff tag c+1, same tag c-1 (thresh 4)
// 3) diff tag same set c+2, diff tag diff set c+1, same tag same set c-2, same tag diff set c-1 (thresh 16)
// 4) diff tag same set c+2, diff tag diff set c+1, same tag same set c-2, same tag diff set c-1 (thresh 8)
// 5) diff tag c+1, same tag c-1 (thresh 4, if thresh then 0)
// 6) diff tag c+1, same tag c-2 (thresh 4)
// 7) diff tag same set c+2, diff tag diff set c+1, same tag same set c-2, same tag diff set c-1 (thresh 8, if thresh then 0)
// 8) diff tag same set c+2, diff tag diff set c+1, same tag same set c-4, same tag diff set c-2 (thresh 8)
val countThresh1 = 8
val countThresh2 = 4
val countThresh3 = 16
val countThresh4 = 8
val countThresh5 = 4
val countThresh6 = 4
val countThresh7 = 8
val countThresh8 = 8
val dmemtag = Wire(UInt((mp.addrWidth - params.tagAddrLSB).W))
......@@ -69,13 +57,7 @@ class DetectPatternPrimeInternal(params: DetectPatternPrimeParams)(implicit mp:
val saved_dmemtag = RegInit(VecInit(Seq.fill(params.savedDMemTagNum)(0.U((mp.addrWidth - params.tagAddrLSB).W)))).suggestName("saved_dmemtag")
val saved_dmemset = RegInit(VecInit(Seq.fill(params.savedDMemSetNum)(0.U((params.tagAddrLSB - params.setAddrLSB).W)))).suggestName("saved_dmemset")
val count1 = RegInit(0.U(log2Ceil(countThresh1 + 1).W)).suggestName("dpprime_count1")
val count2 = RegInit(0.U(log2Ceil(countThresh2 + 1).W)).suggestName("dpprime_count2")
val count3 = RegInit(0.U(log2Ceil(countThresh3 + 1).W)).suggestName("dpprime_count3")
val count4 = RegInit(0.U(log2Ceil(countThresh4 + 1).W)).suggestName("dpprime_count4")
val count5 = RegInit(0.U(log2Ceil(countThresh5 + 1).W)).suggestName("dpprime_count5")
val count6 = RegInit(0.U(log2Ceil(countThresh6 + 1).W)).suggestName("dpprime_count6")
val count7 = RegInit(0.U(log2Ceil(countThresh7 + 1).W)).suggestName("dpprime_count7")
val count8 = RegInit(0.U(log2Ceil(countThresh8 + 1).W)).suggestName("dpprime_count8")
// Update saved_dmem* array in FIFO way
......@@ -87,16 +69,10 @@ class DetectPatternPrimeInternal(params: DetectPatternPrimeParams)(implicit mp:
last := data
last
}
count3 := Mux(count3 <= countThresh3.U - 1.U, count3 + 1.U, countThresh3.U)
count4 := Mux(count4 <= countThresh4.U - 1.U, count4 + 1.U, countThresh4.U)
count7 := Mux(count7 <= countThresh7.U - 1.U, count7 + 1.U, 1.U)
count8 := Mux(count8 <= countThresh8.U - 1.U, count8 + 1.U, countThresh8.U)
}.otherwise { // old set
count3 := Mux(count3 <= countThresh3.U - 2.U, count3 + 2.U, countThresh3.U)
count4 := Mux(count4 <= countThresh4.U - 2.U, count4 + 2.U, countThresh4.U)
count7 := Mux(count7 =/= countThresh7.U,
Mux(count7 <= countThresh7.U - 2.U, count7 + 2.U, countThresh7.U),
2.U)
count8 := Mux(count8 <= countThresh8.U - 2.U, count8 + 2.U, countThresh8.U)
}
// Update saved_dmemtag, newer at 0
......@@ -104,10 +80,6 @@ class DetectPatternPrimeInternal(params: DetectPatternPrimeParams)(implicit mp:
last := data
last
}
count1 := Mux(count1 <= countThresh1.U - 1.U, count1 + 1.U, countThresh1.U)
count2 := Mux(count2 <= countThresh2.U - 1.U, count2 + 1.U, countThresh2.U)
count5 := Mux(count5 <= countThresh5.U - 1.U, count5 + 1.U, 1.U)
count6 := Mux(count6 <= countThresh6.U - 1.U, count6 + 1.U, countThresh6.U)
}.otherwise { // old tag
when (saved_dmemset.map(_ =/= dmemtag).reduce(_&&_)) { // new set
// Update saved_dmemset, newer at 0
......@@ -115,108 +87,48 @@ class DetectPatternPrimeInternal(params: DetectPatternPrimeParams)(implicit mp:
last := data
last
}
count3 := Mux(count3 >= 1.U, count3 - 1.U, 0.U)
count4 := Mux(count4 >= 1.U, count4 - 1.U, 0.U)
count7 := Mux(count7 =/= countThresh7.U,
Mux(count7 >= 1.U, count7 - 1.U, 0.U),
0.U)
count8 := Mux(count8 >= 2.U, count8 - 2.U, 0.U)
}.otherwise { // old set
count3 := Mux(count3 >= 2.U, count3 - 2.U, 0.U)
count4 := Mux(count4 >= 2.U, count4 - 2.U, 0.U)
count7 := Mux(count7 =/= countThresh7.U,
Mux(count7 >= 2.U, count7 - 2.U, 0.U),
0.U)
count8 := Mux(count8 >= 4.U, count8 - 4.U, 0.U)
}
count1 := Mux(count1 >= 1.U, count1 - 1.U, 0.U)
count2 := Mux(count2 >= 1.U, count2 - 1.U, 0.U)
count5 := Mux(count5 =/= countThresh5.U,
Mux(count5 >= 1.U, count5 - 1.U, 0.U),
0.U)
count6 := Mux(count6 >= 2.U, count6 - 2.U, 0.U)
}
}
// only count atk during monitoring, and when an dmemaddr is present (prevent counting during wait)
val atk_prime_1 = Wire(Bool())
atk_prime_1 := (count1 === countThresh1.U) && in.isMonitoring && in.dmemaddr_valid
val atk_prime_2 = Wire(Bool())
atk_prime_2 := (count2 === countThresh2.U) && in.isMonitoring && in.dmemaddr_valid
val atk_prime_3 = Wire(Bool())
atk_prime_3 := (count3 === countThresh3.U) && in.isMonitoring && in.dmemaddr_valid
val atk_prime_4 = Wire(Bool())
atk_prime_4 := (count4 === countThresh4.U) && in.isMonitoring && in.dmemaddr_valid
val atk_prime_5 = Wire(Bool())
atk_prime_5 := (count5 === countThresh5.U) && in.isMonitoring && in.dmemaddr_valid
val atk_prime_6 = Wire(Bool())
atk_prime_6 := (count6 === countThresh6.U) && in.isMonitoring && in.dmemaddr_valid
val atk_prime_7 = Wire(Bool())
atk_prime_7 := (count7 === countThresh7.U) && in.isMonitoring && in.dmemaddr_valid
val atk_prime_8 = Wire(Bool())
atk_prime_8 := (count8 === countThresh8.U) && in.isMonitoring && in.dmemaddr_valid
// Attack Pattern
// Description: all valid pack contains an word size load
// 1) valid load c+1, valid not load c-1 (thresh 32)
// 2) valid load c+1, valid not load c-2 (thresh 32) (maybe only apply to high clkDiv, as for -O0 there is store with load. we suppose if store to the same addr, the next load can be conclude in the same slow cycle)
// Description: all valid pack contains an word size load. For -O0 there is store with load, we suppose if store to the same addr, the next load can be conclude in the same slow cycle.
// 3) valid load c+1, valid not load c-2 (thresh 16)
// 4) valid load c+1, valid not load c-4 (thresh 16)
val countInstThresh1 = 32
val countInstThresh2 = 32
val countInstThresh3 = 16
val countInstThresh4 = 16
val countinst1 = RegInit(0.U(log2Ceil(countInstThresh1 + 1).W)).suggestName("dpprime_countinst1")
val countinst2 = RegInit(0.U(log2Ceil(countInstThresh2 + 1).W)).suggestName("dpprime_countinst2")
val countinst3 = RegInit(0.U(log2Ceil(countInstThresh3 + 1).W)).suggestName("dpprime_countinst3")
val countinst4 = RegInit(0.U(log2Ceil(countInstThresh4 + 1).W)).suggestName("dpprime_countinst4")
when (in.pack_has_valid && in.isMonitoring) {
when (in.pack_has_loadword) {
countinst1 := Mux(countinst1 <= countInstThresh1.U - 1.U, countinst1 + 1.U, countInstThresh1.U)
countinst2 := Mux(countinst2 <= countInstThresh2.U - 1.U, countinst2 + 1.U, countInstThresh2.U)
countinst3 := Mux(countinst3 <= countInstThresh3.U - 1.U, countinst3 + 1.U, countInstThresh3.U)
countinst4 := Mux(countinst4 <= countInstThresh4.U - 1.U, countinst4 + 1.U, countInstThresh4.U)
}.otherwise {
countinst1 := Mux(countinst1 >= 1.U, countinst1 - 1.U, 0.U)
countinst2 := Mux(countinst2 >= 2.U, countinst2 - 2.U, 0.U)
countinst3 := Mux(countinst3 >= 2.U, countinst3 - 2.U, 0.U)
countinst4 := Mux(countinst4 >= 4.U, countinst4 - 4.U, 0.U)
}
}
// only count atk during monitoring, and when an valid inst is present (prevent counting during wait)
val atk_primeinst_1 = Wire(Bool())
atk_primeinst_1 := (countinst1 === countInstThresh1.U) && in.isMonitoring && in.pack_has_valid
val atk_primeinst_2 = Wire(Bool())
atk_primeinst_2 := (countinst2 === countInstThresh2.U) && in.isMonitoring && in.pack_has_valid
val atk_primeinst_3 = Wire(Bool())
atk_primeinst_3 := (countinst3 === countInstThresh3.U) && in.isMonitoring && in.pack_has_valid
val atk_primeinst_4 = Wire(Bool())
atk_primeinst_4 := (countinst4 === countInstThresh4.U) && in.isMonitoring && in.pack_has_valid
// Attack Pattern
// Description: Mix pattern
// 1) count4, but decrement the counter by 1 if count_pack_valid reach thresh (0x200=512)
// 1) count4, but decrement the counter by 1 if count_pack_valid reach thresh (0x800=2048)
// => TODO: perhaps adjust using dmemaddr / slow cycle count
// 2) count4+countinst4: if 2 atk are recognized in 8 slow cycles
// 3) count4+countinst4: (similar to mix2 but use 1 counter for both count4 and countinst4)
// count4 max && countinst4 max: c+8
// count4 max || countinst4 max: c+4
// else : c-1 (as c4/ci4 no reset 0, this means no attack sign)
// (max 8, thresh 5)
// 4) count4+countinst4: (similar to mix3 but priorize count4)
// count4 max && countinst4 max: c+12
// count4 max : c+8
// countinst4 max: c+4
// else : c-1
// (max 12, thresh 9)
// 5) count4+countinst4: (only take when c4/ci4=max or c4/ci4=0, if thresh then 0 to prevent cmix remains max) TODO?
val countPackValidThresh = 0x200L
val countPackValidThresh = 0x800L //2048
val count_pack_valid = RegInit(0.U(mp.counterWidth.W)).suggestName("dpprime_count_pack_valid")
val count_pack_valid_dec = RegInit(false.B).suggestName("dpprime_count_pack_valid_dec")
when (in.resetCounters) {
......@@ -235,57 +147,6 @@ class DetectPatternPrimeInternal(params: DetectPatternPrimeParams)(implicit mp:
val atk_primemix_1 = Wire(Bool())
atk_primemix_1 := Mux(count_pack_valid_dec, !atk_prime_4, atk_prime_4) && in.isMonitoring // if dec=true, dec1 if atk0, dec0 if atk1; if dec=false, inc1 if atk1, inc0 if atk0.
val countMix2Thresh = 8
val countmix2_1 = RegInit(0.U(log2Ceil(countMix2Thresh + 1).W)).suggestName("dpprime_countmix2_1")
val countmix2_2 = RegInit(0.U(log2Ceil(countMix2Thresh + 1).W)).suggestName("dpprime_countmix2_2")
when (in.dmemaddr_valid && in.isMonitoring) { // modify counter only when valid
countmix2_1 := Mux(count4 === countThresh4.U,
countMix2Thresh.U, // if count4 max, set countmix2_1 to max
Mux(countmix2_1 <= 1.U, 0.U, countmix2_1 - 1.U)) // else decrement
}
when (in.pack_has_valid && in.isMonitoring) { // modify counter only when valid
countmix2_2 := Mux(countinst4 === countInstThresh4.U,
countMix2Thresh.U, // if countinst4 max, set countmix2_2 to max
Mux(countmix2_2 <= 1.U, 0.U, countmix2_2 - 1.U)) // else decrement
}
val atk_primemix_2 = Wire(Bool())
atk_primemix_2 := countmix2_1.orR && countmix2_2.orR && in.isMonitoring
val countMix3Max = 8
val countMix3Thresh = 5
val countMix4Max = 12
val countMix4Thresh = 9
val countmix3 = RegInit(0.U(log2Ceil(countMix3Max + 1).W)).suggestName("dpprime_countmix3")
val countmix4 = RegInit(0.U(log2Ceil(countMix4Max + 1).W)).suggestName("dpprime_countmix4")
val countmix3state = Wire(UInt(2.W))
countmix3state := Cat((count4 === countThresh4.U) && in.dmemaddr_valid,
(countinst4 === countInstThresh4.U) && in.pack_has_valid)
when (in.resetCounters) {
countmix3 := 0.U
countmix4 := 0.U
}.elsewhen (in.isMonitoring) {
when (countmix3state.andR) {
countmix3 := Mux(countmix3 <= countMix3Max.U - 8.U, countmix3 + 8.U, countMix3Max.U)
countmix4 := Mux(countmix4 <= countMix4Max.U - 12.U, countmix4 + 12.U, countMix4Max.U)
}.elsewhen (countmix3state(0) === true.B) {
countmix3 := Mux(countmix3 <= countMix3Max.U - 4.U, countmix3 + 4.U, countMix3Max.U)
countmix4 := Mux(countmix4 <= countMix4Max.U - 8.U, countmix4 + 8.U, countMix4Max.U)
}.elsewhen (countmix3state(1) === true.B) {
countmix3 := Mux(countmix3 <= countMix3Max.U - 4.U, countmix3 + 4.U, countMix3Max.U)
countmix4 := Mux(countmix4 <= countMix4Max.U - 4.U, countmix4 + 4.U, countMix4Max.U)
}.elsewhen (in.dmemaddr_valid || in.pack_has_valid) { // has valid input, should decrement
countmix3 := Mux(countmix3 >= 1.U, countmix3 - 1.U, 0.U)
countmix4 := Mux(countmix4 >= 1.U, countmix4 - 1.U, 0.U)
} // otherwise no valid package, remains old value
}
val atk_primemix_3 = Wire(Bool())
atk_primemix_3 := (countmix3 >= countMix3Thresh.U) && in.isMonitoring
val atk_primemix_4 = Wire(Bool())
atk_primemix_4 := (countmix4 >= countMix4Thresh.U) && in.isMonitoring
// Debug Signals
val zeros_line = RegInit(0.U(params.setAddrLSB.W)).suggestName("zeros_line")
......@@ -294,22 +155,11 @@ class DetectPatternPrimeInternal(params: DetectPatternPrimeParams)(implicit mp:
override def regmap(offset: Int) =
RegmapUtil.countEvent(atk_prime_1, in.resetCounters, offset + 0x0, "AtkPrime1") ++
RegmapUtil.countEvent(atk_prime_2, in.resetCounters, offset + 0x4, "AtkPrime2") ++
RegmapUtil.countEvent(atk_prime_3, in.resetCounters, offset + 0x8, "AtkPrime3") ++
RegmapUtil.countEvent(atk_prime_4, in.resetCounters, offset + 0xC, "AtkPrime4") ++
RegmapUtil.countEvent(atk_prime_5, in.resetCounters, offset + 0x10, "AtkPrime5") ++
RegmapUtil.countEvent(atk_prime_6, in.resetCounters, offset + 0x14, "AtkPrime6") ++
RegmapUtil.countEvent(atk_prime_7, in.resetCounters, offset + 0x18, "AtkPrime7") ++
RegmapUtil.countEvent(atk_prime_8, in.resetCounters, offset + 0x1C, "AtkPrime8") ++
RegmapUtil.countEvent(atk_primeinst_1, in.resetCounters, offset + 0x30, "AtkPrimeInst1") ++
RegmapUtil.countEvent(atk_primeinst_2, in.resetCounters, offset + 0x34, "AtkPrimeInst2") ++
RegmapUtil.countEvent(atk_primeinst_3, in.resetCounters, offset + 0x38, "AtkPrimeInst3") ++
RegmapUtil.countEvent(atk_primeinst_4, in.resetCounters, offset + 0x3C, "AtkPrimeInst4") ++
RegmapUtil.countEventMax(atk_primemix_1, in.resetCounters, count_pack_valid_dec, offset + 0x50, "AtkPrimeMix1") ++
RegmapUtil.countEvent(atk_primemix_2, in.resetCounters, offset + 0x64, "AtkPrimeMix2") ++
RegmapUtil.countEvent(atk_primemix_3, in.resetCounters, offset + 0x68, "AtkPrimeMix3") ++
RegmapUtil.countEvent(atk_primemix_4, in.resetCounters, offset + 0x6C, "AtkPrimeMix4")
Nil
}
object DetectPatternPrime {
......
......@@ -26,15 +26,12 @@ class DetectPatternRop()(implicit mp: MatanaParams) {
class DetectPatternRopInternal(params: DetectPatternRopParams)(implicit mp: MatanaParams) extends DetectPatternRop {
// Attack Pattern
// Description: Jalr chained short instruction sequence
// 1) jalr + jalr = c+1, other + jalr = c-2 (window size 1)(thresh proably 10 as in mispredict1)
// 2) jalr + jalr = c+1, other + jalr = c-2 (window size 2)
// 3) jalr + jalr = c+1, other jump = c-2 (window size 1)
// 4) jalr + jalr = c+1, other jump = c-2 (window size 2)
// Description: Jalr chained short instruction sequence.
// 3) jalr + jalr = c+1, other jump = c-2 (window size 1) (thresh proably 10 as in mispredict1) (window size 2 has high false positive rate)
// Note:
// 1) clockDiv will impact the pattern on window size, smaller clockDiv is perhaps more accurate.
// 2) ROP gadget destination is likely not in cache, so jalr will take longer and will probably be visible in different clockDiv pack. But this property may increase false positive as legal chain will now seems to be shorter.
// 3) 1 pattern is a complet attack, threshold, if exists should be set to 1
val npackSizeMax = 2
val npack_jalr = RegInit(0.U(npackSizeMax.W)).suggestName("npack_jalr")
......@@ -42,32 +39,15 @@ class DetectPatternRopInternal(params: DetectPatternRopParams)(implicit mp: Mata
npack_jalr := Cat(npack_jalr(npackSizeMax - 2, 0), in.pack_has_jalr)
}
val countjalr1 = RegInit(0.U(mp.counterWidth.W)).suggestName("dprop_countjalr1")
val countjalr2 = RegInit(0.U(mp.counterWidth.W)).suggestName("dprop_countjalr2")
val countjalr3 = RegInit(0.U(mp.counterWidth.W)).suggestName("dprop_countjalr3")
val countjalr4 = RegInit(0.U(mp.counterWidth.W)).suggestName("dprop_countjalr4")
when (in.resetCounters) {
countjalr1 := 0.U
countjalr2 := 0.U
countjalr3 := 0.U
countjalr4 := 0.U
}.otherwise {
when (in.isMonitoring && in.pack_has_jalr) { // evaluate counter1,2 when jalr (include pack_has_valid)
countjalr1 := Mux(npack_jalr(0),
countjalr1 + 1.U,
Mux(countjalr1 >= 2.U, countjalr1 - 2.U, 0.U))
countjalr2 := Mux(npack_jalr(1,0).orR,
countjalr2 + 1.U,
Mux(countjalr2 >= 2.U, countjalr2 - 2.U, 0.U))
}
when (in.isMonitoring && in.pack_has_jump) { // evaluate counter3,4 when jump (include pack_has_valid)
when (in.isMonitoring && in.pack_has_jump) { // evaluate counter3 when jump (include pack_has_valid)
countjalr3 := Mux(in.pack_has_jalr && npack_jalr(0),
countjalr3 + 1.U,
Mux(countjalr3 >= 2.U, countjalr3 - 2.U, 0.U))
countjalr4 := Mux(in.pack_has_jalr && npack_jalr(1,0).orR,
countjalr4 + 1.U,
Mux(countjalr4 >= 2.U, countjalr4 - 2.U, 0.U))
}
}
......@@ -84,10 +64,7 @@ class DetectPatternRopInternal(params: DetectPatternRopParams)(implicit mp: Mata
atk_rop_mispredict := false.B && in.isMonitoring // TODO
override def regmap(offset: Int) =
RegmapUtil.readValueMax(countjalr1, in.resetCounters, mp.counterWidth, offset + 0x0, "CountRopJalr1") ++
RegmapUtil.readValueMax(countjalr2, in.resetCounters, mp.counterWidth, offset + 0x8, "CountRopJalr2") ++
RegmapUtil.readValueMax(countjalr3, in.resetCounters, mp.counterWidth, offset + 0x10, "CountRopJalr3") ++
RegmapUtil.readValueMax(countjalr4, in.resetCounters, mp.counterWidth, offset + 0x18, "CountRopJalr4") ++
Nil
}
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment