Commit 8b09fe7c authored by Yuxiao Mao's avatar Yuxiao Mao
Browse files

DetectPrime/Rop: rename patterns so that they are in order

parent 16c1e504
......@@ -43,11 +43,11 @@ class DetectPatternPrimeInternal(params: DetectPatternPrimeParams)(implicit mp:
// Attack pattern
// Description: Multiples access to DCache, each time in same set with differents tag
// 4) diff tag same set c+2, diff tag diff set c+1, same tag same set c-2, same tag diff set c-1 (thresh 8)
// 8) diff tag same set c+2, diff tag diff set c+1, same tag same set c-4, same tag diff set c-2 (thresh 8)
// 1) diff tag same set c+2, diff tag diff set c+1, same tag same set c-2, same tag diff set c-1 (thresh 8)
// 2) diff tag same set c+2, diff tag diff set c+1, same tag same set c-4, same tag diff set c-2 (thresh 8)
val countThresh4 = 8
val countThresh8 = 8
val countThresh1 = 8
val countThresh2 = 8
val dmemtag = Wire(UInt((mp.addrWidth - params.tagAddrLSB).W))
dmemtag := in.dmemaddr_data(mp.addrWidth - 1, params.tagAddrLSB)
......@@ -57,8 +57,8 @@ class DetectPatternPrimeInternal(params: DetectPatternPrimeParams)(implicit mp:
val saved_dmemtag = RegInit(VecInit(Seq.fill(params.savedDMemTagNum)(0.U((mp.addrWidth - params.tagAddrLSB).W)))).suggestName("saved_dmemtag")
val saved_dmemset = RegInit(VecInit(Seq.fill(params.savedDMemSetNum)(0.U((params.tagAddrLSB - params.setAddrLSB).W)))).suggestName("saved_dmemset")
val count4 = RegInit(0.U(log2Ceil(countThresh4 + 1).W)).suggestName("dpprime_count4")
val count8 = RegInit(0.U(log2Ceil(countThresh8 + 1).W)).suggestName("dpprime_count8")
val count1 = RegInit(0.U(log2Ceil(countThresh1 + 1).W)).suggestName("dpprime_count1")
val count2 = RegInit(0.U(log2Ceil(countThresh2 + 1).W)).suggestName("dpprime_count2")
// Update saved_dmem* array in FIFO way
when (in.dmemaddr_valid && in.isMonitoring) {
......@@ -69,11 +69,11 @@ class DetectPatternPrimeInternal(params: DetectPatternPrimeParams)(implicit mp:
last := data
last
}
count4 := Mux(count4 <= countThresh4.U - 1.U, count4 + 1.U, countThresh4.U)
count8 := Mux(count8 <= countThresh8.U - 1.U, count8 + 1.U, countThresh8.U)
count1 := Mux(count1 <= countThresh1.U - 1.U, count1 + 1.U, countThresh1.U)
count2 := Mux(count2 <= countThresh2.U - 1.U, count2 + 1.U, countThresh2.U)
}.otherwise { // old set
count4 := Mux(count4 <= countThresh4.U - 2.U, count4 + 2.U, countThresh4.U)
count8 := Mux(count8 <= countThresh8.U - 2.U, count8 + 2.U, countThresh8.U)
count1 := Mux(count1 <= countThresh1.U - 2.U, count1 + 2.U, countThresh1.U)
count2 := Mux(count2 <= countThresh2.U - 2.U, count2 + 2.U, countThresh2.U)
}
// Update saved_dmemtag, newer at 0
saved_dmemtag.foldLeft(dmemtag) { (data, last) =>
......@@ -87,45 +87,45 @@ class DetectPatternPrimeInternal(params: DetectPatternPrimeParams)(implicit mp:
last := data
last
}
count4 := Mux(count4 >= 1.U, count4 - 1.U, 0.U)
count8 := Mux(count8 >= 2.U, count8 - 2.U, 0.U)
count1 := Mux(count1 >= 1.U, count1 - 1.U, 0.U)
count2 := Mux(count2 >= 2.U, count2 - 2.U, 0.U)
}.otherwise { // old set
count4 := Mux(count4 >= 2.U, count4 - 2.U, 0.U)
count8 := Mux(count8 >= 4.U, count8 - 4.U, 0.U)
count1 := Mux(count1 >= 2.U, count1 - 2.U, 0.U)
count2 := Mux(count2 >= 4.U, count2 - 4.U, 0.U)
}
}
}
// only count atk during monitoring, and when an dmemaddr is present (prevent counting during wait)
val atk_prime_4 = Wire(Bool())
atk_prime_4 := (count4 === countThresh4.U) && in.isMonitoring && in.dmemaddr_valid
val atk_prime_8 = Wire(Bool())
atk_prime_8 := (count8 === countThresh8.U) && in.isMonitoring && in.dmemaddr_valid
val atk_prime_1 = Wire(Bool())
atk_prime_1 := (count1 === countThresh1.U) && in.isMonitoring && in.dmemaddr_valid
val atk_prime_2 = Wire(Bool())
atk_prime_2 := (count2 === countThresh2.U) && in.isMonitoring && in.dmemaddr_valid
// Attack Pattern
// Description: all valid pack contains an word size load. For -O0 there is store with load, we suppose if store to the same addr, the next load can be conclude in the same slow cycle.
// 3) valid load c+1, valid not load c-2 (thresh 16)
// 1) valid load c+1, valid not load c-2 (thresh 16)
val countInstThresh3 = 16
val countinst3 = RegInit(0.U(log2Ceil(countInstThresh3 + 1).W)).suggestName("dpprime_countinst3")
val countInstThresh1 = 16
val countinst1 = RegInit(0.U(log2Ceil(countInstThresh1 + 1).W)).suggestName("dpprime_countinst1")
when (in.pack_has_valid && in.isMonitoring) {
when (in.pack_has_loadword) {
countinst3 := Mux(countinst3 <= countInstThresh3.U - 1.U, countinst3 + 1.U, countInstThresh3.U)
countinst1 := Mux(countinst1 <= countInstThresh1.U - 1.U, countinst1 + 1.U, countInstThresh1.U)
}.otherwise {
countinst3 := Mux(countinst3 >= 2.U, countinst3 - 2.U, 0.U)
countinst1 := Mux(countinst1 >= 2.U, countinst1 - 2.U, 0.U)
}
}
// only count atk during monitoring, and when an valid inst is present (prevent counting during wait)
val atk_primeinst_3 = Wire(Bool())
atk_primeinst_3 := (countinst3 === countInstThresh3.U) && in.isMonitoring && in.pack_has_valid
val atk_primeinst_1 = Wire(Bool())
atk_primeinst_1 := (countinst1 === countInstThresh1.U) && in.isMonitoring && in.pack_has_valid
// Attack Pattern
// Description: Mix pattern
// 1) count4, but decrement the counter by 1 if count_pack_valid reach thresh (0x800=2048)
// 1) count1, but decrement the counter by 1 if count_pack_valid reach thresh (0x800=2048)
// => TODO: perhaps adjust using dmemaddr / slow cycle count
val countPackValidThresh = 0x800L //2048
......@@ -145,7 +145,7 @@ class DetectPatternPrimeInternal(params: DetectPatternPrimeParams)(implicit mp:
}
val atk_primemix_1 = Wire(Bool())
atk_primemix_1 := Mux(count_pack_valid_dec, !atk_prime_4, atk_prime_4) && in.isMonitoring // if dec=true, dec1 if atk0, dec0 if atk1; if dec=false, inc1 if atk1, inc0 if atk0.
atk_primemix_1 := Mux(count_pack_valid_dec, !atk_prime_1, atk_prime_1) && in.isMonitoring // if dec=true, dec1 if atk0, dec0 if atk1; if dec=false, inc1 if atk1, inc0 if atk0.
// Debug Signals
......@@ -155,10 +155,10 @@ class DetectPatternPrimeInternal(params: DetectPatternPrimeParams)(implicit mp:
override def regmap(offset: Int) =
RegmapUtil.countEvent(atk_prime_4, in.resetCounters, offset + 0xC, "AtkPrime4") ++
RegmapUtil.countEvent(atk_prime_8, in.resetCounters, offset + 0x1C, "AtkPrime8") ++
RegmapUtil.countEvent(atk_primeinst_3, in.resetCounters, offset + 0x38, "AtkPrimeInst3") ++
RegmapUtil.countEventMax(atk_primemix_1, in.resetCounters, count_pack_valid_dec, offset + 0x50, "AtkPrimeMix1") ++
RegmapUtil.countEvent(atk_prime_1, in.resetCounters, offset + 0x0, "AtkPrime1") ++
RegmapUtil.countEvent(atk_prime_2, in.resetCounters, offset + 0x4, "AtkPrime2") ++
RegmapUtil.countEvent(atk_primeinst_1, in.resetCounters, offset + 0x10, "AtkPrimeInst1") ++
RegmapUtil.countEventMax(atk_primemix_1, in.resetCounters, count_pack_valid_dec, offset + 0x20, "AtkPrimeMix1") ++
Nil
}
......
......@@ -27,7 +27,7 @@ class DetectPatternRopInternal(params: DetectPatternRopParams)(implicit mp: Mata
// Attack Pattern
// Description: Jalr chained short instruction sequence.
// 3) jalr + jalr = c+1, other jump = c-2 (window size 1) (thresh proably 10 as in mispredict1) (window size 2 has high false positive rate)
// 1) jalr + jalr = c+1, other jump = c-2 (window size 1) (thresh proably 10 as in mispredict1) (window size 2 has high false positive rate)
// Note:
// 1) clockDiv will impact the pattern on window size, smaller clockDiv is perhaps more accurate.
// 2) ROP gadget destination is likely not in cache, so jalr will take longer and will probably be visible in different clockDiv pack. But this property may increase false positive as legal chain will now seems to be shorter.
......@@ -39,15 +39,15 @@ class DetectPatternRopInternal(params: DetectPatternRopParams)(implicit mp: Mata
npack_jalr := Cat(npack_jalr(npackSizeMax - 2, 0), in.pack_has_jalr)
}
val countjalr3 = RegInit(0.U(mp.counterWidth.W)).suggestName("dprop_countjalr3")
val countjalr1 = RegInit(0.U(mp.counterWidth.W)).suggestName("dprop_countjalr1")
when (in.resetCounters) {
countjalr3 := 0.U
countjalr1 := 0.U
}.otherwise {
when (in.isMonitoring && in.pack_has_jump) { // evaluate counter3 when jump (include pack_has_valid)
countjalr3 := Mux(in.pack_has_jalr && npack_jalr(0),
countjalr3 + 1.U,
Mux(countjalr3 >= 2.U, countjalr3 - 2.U, 0.U))
countjalr1 := Mux(in.pack_has_jalr && npack_jalr(0),
countjalr1 + 1.U,
Mux(countjalr1 >= 2.U, countjalr1 - 2.U, 0.U))
}
}
......@@ -64,7 +64,7 @@ class DetectPatternRopInternal(params: DetectPatternRopParams)(implicit mp: Mata
atk_rop_mispredict := false.B && in.isMonitoring // TODO
override def regmap(offset: Int) =
RegmapUtil.readValueMax(countjalr3, in.resetCounters, mp.counterWidth, offset + 0x10, "CountRopJalr3") ++
RegmapUtil.readValueMax(countjalr1, in.resetCounters, mp.counterWidth, offset, "CountRopJalr1") ++
Nil
}
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment