Commit 8b09fe7c authored by Yuxiao Mao's avatar Yuxiao Mao
Browse files

DetectPrime/Rop: rename patterns so that they are in order

parent 16c1e504
...@@ -43,11 +43,11 @@ class DetectPatternPrimeInternal(params: DetectPatternPrimeParams)(implicit mp: ...@@ -43,11 +43,11 @@ class DetectPatternPrimeInternal(params: DetectPatternPrimeParams)(implicit mp:
// Attack pattern // Attack pattern
// Description: Multiples access to DCache, each time in same set with differents tag // Description: Multiples access to DCache, each time in same set with differents tag
// 4) diff tag same set c+2, diff tag diff set c+1, same tag same set c-2, same tag diff set c-1 (thresh 8) // 1) diff tag same set c+2, diff tag diff set c+1, same tag same set c-2, same tag diff set c-1 (thresh 8)
// 8) diff tag same set c+2, diff tag diff set c+1, same tag same set c-4, same tag diff set c-2 (thresh 8) // 2) diff tag same set c+2, diff tag diff set c+1, same tag same set c-4, same tag diff set c-2 (thresh 8)
val countThresh4 = 8 val countThresh1 = 8
val countThresh8 = 8 val countThresh2 = 8
val dmemtag = Wire(UInt((mp.addrWidth - params.tagAddrLSB).W)) val dmemtag = Wire(UInt((mp.addrWidth - params.tagAddrLSB).W))
dmemtag := in.dmemaddr_data(mp.addrWidth - 1, params.tagAddrLSB) dmemtag := in.dmemaddr_data(mp.addrWidth - 1, params.tagAddrLSB)
...@@ -57,8 +57,8 @@ class DetectPatternPrimeInternal(params: DetectPatternPrimeParams)(implicit mp: ...@@ -57,8 +57,8 @@ class DetectPatternPrimeInternal(params: DetectPatternPrimeParams)(implicit mp:
val saved_dmemtag = RegInit(VecInit(Seq.fill(params.savedDMemTagNum)(0.U((mp.addrWidth - params.tagAddrLSB).W)))).suggestName("saved_dmemtag") val saved_dmemtag = RegInit(VecInit(Seq.fill(params.savedDMemTagNum)(0.U((mp.addrWidth - params.tagAddrLSB).W)))).suggestName("saved_dmemtag")
val saved_dmemset = RegInit(VecInit(Seq.fill(params.savedDMemSetNum)(0.U((params.tagAddrLSB - params.setAddrLSB).W)))).suggestName("saved_dmemset") val saved_dmemset = RegInit(VecInit(Seq.fill(params.savedDMemSetNum)(0.U((params.tagAddrLSB - params.setAddrLSB).W)))).suggestName("saved_dmemset")
val count4 = RegInit(0.U(log2Ceil(countThresh4 + 1).W)).suggestName("dpprime_count4") val count1 = RegInit(0.U(log2Ceil(countThresh1 + 1).W)).suggestName("dpprime_count1")
val count8 = RegInit(0.U(log2Ceil(countThresh8 + 1).W)).suggestName("dpprime_count8") val count2 = RegInit(0.U(log2Ceil(countThresh2 + 1).W)).suggestName("dpprime_count2")
// Update saved_dmem* array in FIFO way // Update saved_dmem* array in FIFO way
when (in.dmemaddr_valid && in.isMonitoring) { when (in.dmemaddr_valid && in.isMonitoring) {
...@@ -69,11 +69,11 @@ class DetectPatternPrimeInternal(params: DetectPatternPrimeParams)(implicit mp: ...@@ -69,11 +69,11 @@ class DetectPatternPrimeInternal(params: DetectPatternPrimeParams)(implicit mp:
last := data last := data
last last
} }
count4 := Mux(count4 <= countThresh4.U - 1.U, count4 + 1.U, countThresh4.U) count1 := Mux(count1 <= countThresh1.U - 1.U, count1 + 1.U, countThresh1.U)
count8 := Mux(count8 <= countThresh8.U - 1.U, count8 + 1.U, countThresh8.U) count2 := Mux(count2 <= countThresh2.U - 1.U, count2 + 1.U, countThresh2.U)
}.otherwise { // old set }.otherwise { // old set
count4 := Mux(count4 <= countThresh4.U - 2.U, count4 + 2.U, countThresh4.U) count1 := Mux(count1 <= countThresh1.U - 2.U, count1 + 2.U, countThresh1.U)
count8 := Mux(count8 <= countThresh8.U - 2.U, count8 + 2.U, countThresh8.U) count2 := Mux(count2 <= countThresh2.U - 2.U, count2 + 2.U, countThresh2.U)
} }
// Update saved_dmemtag, newer at 0 // Update saved_dmemtag, newer at 0
saved_dmemtag.foldLeft(dmemtag) { (data, last) => saved_dmemtag.foldLeft(dmemtag) { (data, last) =>
...@@ -87,45 +87,45 @@ class DetectPatternPrimeInternal(params: DetectPatternPrimeParams)(implicit mp: ...@@ -87,45 +87,45 @@ class DetectPatternPrimeInternal(params: DetectPatternPrimeParams)(implicit mp:
last := data last := data
last last
} }
count4 := Mux(count4 >= 1.U, count4 - 1.U, 0.U) count1 := Mux(count1 >= 1.U, count1 - 1.U, 0.U)
count8 := Mux(count8 >= 2.U, count8 - 2.U, 0.U) count2 := Mux(count2 >= 2.U, count2 - 2.U, 0.U)
}.otherwise { // old set }.otherwise { // old set
count4 := Mux(count4 >= 2.U, count4 - 2.U, 0.U) count1 := Mux(count1 >= 2.U, count1 - 2.U, 0.U)
count8 := Mux(count8 >= 4.U, count8 - 4.U, 0.U) count2 := Mux(count2 >= 4.U, count2 - 4.U, 0.U)
} }
} }
} }
// only count atk during monitoring, and when an dmemaddr is present (prevent counting during wait) // only count atk during monitoring, and when an dmemaddr is present (prevent counting during wait)
val atk_prime_4 = Wire(Bool()) val atk_prime_1 = Wire(Bool())
atk_prime_4 := (count4 === countThresh4.U) && in.isMonitoring && in.dmemaddr_valid atk_prime_1 := (count1 === countThresh1.U) && in.isMonitoring && in.dmemaddr_valid
val atk_prime_8 = Wire(Bool()) val atk_prime_2 = Wire(Bool())
atk_prime_8 := (count8 === countThresh8.U) && in.isMonitoring && in.dmemaddr_valid atk_prime_2 := (count2 === countThresh2.U) && in.isMonitoring && in.dmemaddr_valid
// Attack Pattern // Attack Pattern
// Description: all valid pack contains an word size load. For -O0 there is store with load, we suppose if store to the same addr, the next load can be conclude in the same slow cycle. // Description: all valid pack contains an word size load. For -O0 there is store with load, we suppose if store to the same addr, the next load can be conclude in the same slow cycle.
// 3) valid load c+1, valid not load c-2 (thresh 16) // 1) valid load c+1, valid not load c-2 (thresh 16)
val countInstThresh3 = 16 val countInstThresh1 = 16
val countinst3 = RegInit(0.U(log2Ceil(countInstThresh3 + 1).W)).suggestName("dpprime_countinst3") val countinst1 = RegInit(0.U(log2Ceil(countInstThresh1 + 1).W)).suggestName("dpprime_countinst1")
when (in.pack_has_valid && in.isMonitoring) { when (in.pack_has_valid && in.isMonitoring) {
when (in.pack_has_loadword) { when (in.pack_has_loadword) {
countinst3 := Mux(countinst3 <= countInstThresh3.U - 1.U, countinst3 + 1.U, countInstThresh3.U) countinst1 := Mux(countinst1 <= countInstThresh1.U - 1.U, countinst1 + 1.U, countInstThresh1.U)
}.otherwise { }.otherwise {
countinst3 := Mux(countinst3 >= 2.U, countinst3 - 2.U, 0.U) countinst1 := Mux(countinst1 >= 2.U, countinst1 - 2.U, 0.U)
} }
} }
// only count atk during monitoring, and when an valid inst is present (prevent counting during wait) // only count atk during monitoring, and when an valid inst is present (prevent counting during wait)
val atk_primeinst_3 = Wire(Bool()) val atk_primeinst_1 = Wire(Bool())
atk_primeinst_3 := (countinst3 === countInstThresh3.U) && in.isMonitoring && in.pack_has_valid atk_primeinst_1 := (countinst1 === countInstThresh1.U) && in.isMonitoring && in.pack_has_valid
// Attack Pattern // Attack Pattern
// Description: Mix pattern // Description: Mix pattern
// 1) count4, but decrement the counter by 1 if count_pack_valid reach thresh (0x800=2048) // 1) count1, but decrement the counter by 1 if count_pack_valid reach thresh (0x800=2048)
// => TODO: perhaps adjust using dmemaddr / slow cycle count // => TODO: perhaps adjust using dmemaddr / slow cycle count
val countPackValidThresh = 0x800L //2048 val countPackValidThresh = 0x800L //2048
...@@ -145,7 +145,7 @@ class DetectPatternPrimeInternal(params: DetectPatternPrimeParams)(implicit mp: ...@@ -145,7 +145,7 @@ class DetectPatternPrimeInternal(params: DetectPatternPrimeParams)(implicit mp:
} }
val atk_primemix_1 = Wire(Bool()) val atk_primemix_1 = Wire(Bool())
atk_primemix_1 := Mux(count_pack_valid_dec, !atk_prime_4, atk_prime_4) && in.isMonitoring // if dec=true, dec1 if atk0, dec0 if atk1; if dec=false, inc1 if atk1, inc0 if atk0. atk_primemix_1 := Mux(count_pack_valid_dec, !atk_prime_1, atk_prime_1) && in.isMonitoring // if dec=true, dec1 if atk0, dec0 if atk1; if dec=false, inc1 if atk1, inc0 if atk0.
// Debug Signals // Debug Signals
...@@ -155,10 +155,10 @@ class DetectPatternPrimeInternal(params: DetectPatternPrimeParams)(implicit mp: ...@@ -155,10 +155,10 @@ class DetectPatternPrimeInternal(params: DetectPatternPrimeParams)(implicit mp:
override def regmap(offset: Int) = override def regmap(offset: Int) =
RegmapUtil.countEvent(atk_prime_4, in.resetCounters, offset + 0xC, "AtkPrime4") ++ RegmapUtil.countEvent(atk_prime_1, in.resetCounters, offset + 0x0, "AtkPrime1") ++
RegmapUtil.countEvent(atk_prime_8, in.resetCounters, offset + 0x1C, "AtkPrime8") ++ RegmapUtil.countEvent(atk_prime_2, in.resetCounters, offset + 0x4, "AtkPrime2") ++
RegmapUtil.countEvent(atk_primeinst_3, in.resetCounters, offset + 0x38, "AtkPrimeInst3") ++ RegmapUtil.countEvent(atk_primeinst_1, in.resetCounters, offset + 0x10, "AtkPrimeInst1") ++
RegmapUtil.countEventMax(atk_primemix_1, in.resetCounters, count_pack_valid_dec, offset + 0x50, "AtkPrimeMix1") ++ RegmapUtil.countEventMax(atk_primemix_1, in.resetCounters, count_pack_valid_dec, offset + 0x20, "AtkPrimeMix1") ++
Nil Nil
} }
......
...@@ -27,7 +27,7 @@ class DetectPatternRopInternal(params: DetectPatternRopParams)(implicit mp: Mata ...@@ -27,7 +27,7 @@ class DetectPatternRopInternal(params: DetectPatternRopParams)(implicit mp: Mata
// Attack Pattern // Attack Pattern
// Description: Jalr chained short instruction sequence. // Description: Jalr chained short instruction sequence.
// 3) jalr + jalr = c+1, other jump = c-2 (window size 1) (thresh proably 10 as in mispredict1) (window size 2 has high false positive rate) // 1) jalr + jalr = c+1, other jump = c-2 (window size 1) (thresh proably 10 as in mispredict1) (window size 2 has high false positive rate)
// Note: // Note:
// 1) clockDiv will impact the pattern on window size, smaller clockDiv is perhaps more accurate. // 1) clockDiv will impact the pattern on window size, smaller clockDiv is perhaps more accurate.
// 2) ROP gadget destination is likely not in cache, so jalr will take longer and will probably be visible in different clockDiv pack. But this property may increase false positive as legal chain will now seems to be shorter. // 2) ROP gadget destination is likely not in cache, so jalr will take longer and will probably be visible in different clockDiv pack. But this property may increase false positive as legal chain will now seems to be shorter.
...@@ -39,15 +39,15 @@ class DetectPatternRopInternal(params: DetectPatternRopParams)(implicit mp: Mata ...@@ -39,15 +39,15 @@ class DetectPatternRopInternal(params: DetectPatternRopParams)(implicit mp: Mata
npack_jalr := Cat(npack_jalr(npackSizeMax - 2, 0), in.pack_has_jalr) npack_jalr := Cat(npack_jalr(npackSizeMax - 2, 0), in.pack_has_jalr)
} }
val countjalr3 = RegInit(0.U(mp.counterWidth.W)).suggestName("dprop_countjalr3") val countjalr1 = RegInit(0.U(mp.counterWidth.W)).suggestName("dprop_countjalr1")
when (in.resetCounters) { when (in.resetCounters) {
countjalr3 := 0.U countjalr1 := 0.U
}.otherwise { }.otherwise {
when (in.isMonitoring && in.pack_has_jump) { // evaluate counter3 when jump (include pack_has_valid) when (in.isMonitoring && in.pack_has_jump) { // evaluate counter3 when jump (include pack_has_valid)
countjalr3 := Mux(in.pack_has_jalr && npack_jalr(0), countjalr1 := Mux(in.pack_has_jalr && npack_jalr(0),
countjalr3 + 1.U, countjalr1 + 1.U,
Mux(countjalr3 >= 2.U, countjalr3 - 2.U, 0.U)) Mux(countjalr1 >= 2.U, countjalr1 - 2.U, 0.U))
} }
} }
...@@ -64,7 +64,7 @@ class DetectPatternRopInternal(params: DetectPatternRopParams)(implicit mp: Mata ...@@ -64,7 +64,7 @@ class DetectPatternRopInternal(params: DetectPatternRopParams)(implicit mp: Mata
atk_rop_mispredict := false.B && in.isMonitoring // TODO atk_rop_mispredict := false.B && in.isMonitoring // TODO
override def regmap(offset: Int) = override def regmap(offset: Int) =
RegmapUtil.readValueMax(countjalr3, in.resetCounters, mp.counterWidth, offset + 0x10, "CountRopJalr3") ++ RegmapUtil.readValueMax(countjalr1, in.resetCounters, mp.counterWidth, offset, "CountRopJalr1") ++
Nil Nil
} }
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment