Commit 39a6efbf authored by Yuxiao Mao's avatar Yuxiao Mao
Browse files

Merge branch 'dev': add simple EventCounter connections and pattern, add logic...

Merge branch 'dev': add simple EventCounter connections and pattern, add logic for 1/2 slow cycle ROP detection, remove some not good patterns
parents df6b33e5 4303cc6a
diff --git a/src/main/scala/rocket/RocketCore.scala b/src/main/scala/rocket/RocketCore.scala
index c6d86ce5f..301c90c24 100644
index c6d86ce5f..99d98c9fc 100644
--- a/src/main/scala/rocket/RocketCore.scala
+++ b/src/main/scala/rocket/RocketCore.scala
@@ -846,6 +846,16 @@ class Rocket(tile: RocketTile)(implicit p: Parameters) extends CoreModule()(p)
@@ -846,6 +846,20 @@ class Rocket(tile: RocketTile)(implicit p: Parameters) extends CoreModule()(p)
io.rocc.cmd.bits.rs1 := wb_reg_wdata
io.rocc.cmd.bits.rs2 := wb_reg_rs2
......@@ -15,6 +15,10 @@ index c6d86ce5f..301c90c24 100644
+ // dmem access is at ex stage
+ io.matana.dmemaddr_data := encodeVirtualAddress(ex_rs(0), alu.io.adder_out) // copy from io.dmem.req.bits.addr
+ io.matana.dmemaddr_valid := ex_reg_valid && ex_ctrl.mem
+ // performance counters
+ io.matana.hpc.dcache_miss := io.dmem.perf.acquire
+ io.matana.hpc.dcache_release := io.dmem.perf.release
+ io.matana.hpc.branch_mispredict := take_pc_mem && mem_direction_misprediction
+
// gate the clock
val unpause = csr.io.time(rocketParams.lgPauseCycles-1, 0) === 0 || io.dmem.perf.release || take_pc
......
......@@ -12,4 +12,11 @@ class MatanaCoreIO extends Bundle {
val pc_valid = Output(Bool())
val dmemaddr_data = Output(UInt(40.W)) // used by DCacheMetadataReq
val dmemaddr_valid = Output(Bool())
val hpc = Output(new MatanaCoreHpcIO())
}
class MatanaCoreHpcIO extends Bundle {
val dcache_miss = Bool()
val dcache_release = Bool()
val branch_mispredict = Bool()
}
......@@ -19,6 +19,7 @@ case class MatanaParams(
dExample: Option[DetectExampleParams] = None, //Some(DetectExampleParams()),
dPatternTimer: Option[DetectPatternTimerParams] = Some(DetectPatternTimerParams()),
dPatternPrime: Option[DetectPatternPrimeParams] = Some(DetectPatternPrimeParams()),
dPatternCacheEvent: Option[DetectPatternCacheEventParams] = Some(DetectPatternCacheEventParams()),
dPatternLibAccess: Option[DetectPatternLibAccessParams] = Some(DetectPatternLibAccessParams()),
dPatternRop: Option[DetectPatternRopParams] = Some(DetectPatternRopParams()),
counterWidth: Int = 32
......
......@@ -25,12 +25,13 @@ class DetectExample()(implicit mp: MatanaParams) {
class DetectExampleInternal(params: DetectExampleParams)(implicit mp: MatanaParams) extends DetectExample {
// Logic that detect the attack
val atk_example = Wire(Bool())
val atk_example = Wire(Bool()) // or RegInit(false.B) to reduce circuit complexity
atk_example := in.some_atk && in.isMonitoring
// The regmap for bus
override def regmap(offset: Int) =
RegmapUtil.countEventThreshAlarm(atk_example, in.resetCounters, offset, "AtkExample")
RegmapUtil.countEventThreshAlarm(atk_example, in.resetCounters, offset, "AtkExample") ++
Nil
}
object DetectExample {
......
package matana
import chisel3._
import freechips.rocketchip.regmapper.RegField
import freechips.rocketchip.tile.MatanaCoreHpcIO
case class DetectPatternCacheEventParams(
)
class DetectPatternCacheEventIn()(implicit val mp: MatanaParams) extends Bundle {
val resetCounters = Bool()
val isMonitoring = Bool()
val pack_has_valid = Bool()
val vec_hpc = Vec(mp.clockDiv, new MatanaCoreHpcIO())
}
class DetectPatternCacheEvent()(implicit mp: MatanaParams) {
val in = Wire(new DetectPatternCacheEventIn())
def regmap(offset: Int): Seq[(Int, Seq[RegField])] = Nil
}
class DetectPatternCacheEventInternal(params: DetectPatternCacheEventParams)(implicit mp: MatanaParams) extends DetectPatternCacheEvent {
// Attack pattern
// Description: a high rate of Data cache miss / instruction
// -) Count pack_has_dcache_miss event (as count the exact value is difficult)
// N) = atk_dcache_miss, but decrement by 1 if count_pack_valid reach thresh X
val pack_has_dcache_miss = RegInit(false.B).suggestName("dpdcacheevent_pack_has_dcache_miss")
pack_has_dcache_miss := in.vec_hpc.map(_.dcache_miss).reduce(_||_)
val atk_dcache_miss = RegInit(false.B)
atk_dcache_miss := pack_has_dcache_miss && in.isMonitoring
// Count pack valid for decrement counters
val count_pack_valid = RegInit(0.U(mp.counterWidth.W)).suggestName("dpdcacheevent_count_pack_valid")
when (in.resetCounters) {
count_pack_valid := 0.U
}.elsewhen (in.isMonitoring) {
count_pack_valid := count_pack_valid + in.pack_has_valid
}
val countPackValidThreshLengthMinus1_1 = 2 // Thresh = 2^(2+1) = 8 = 0x8 (0b000)
val countPackValidThreshLengthMinus1_2 = 4 // Thresh = 2^(4+1) = 32 = 0x20
val countPackValidThreshLengthMinus1_3 = 6 // Thresh = 2^(6+1) = 128 = 0x80
val countPackValidThreshLengthMinus1_4 = 8 // Thresh = 2^(8+1) = 512 = 0x200
val countPackValidThreshLengthMinus1_5 = 10 // Thresh = 2^(10+1) = 2048 = 0x800
val atk_dcache_miss_dec1 = RegInit(false.B)
atk_dcache_miss_dec1 := count_pack_valid(countPackValidThreshLengthMinus1_1, 0) === 0.U
val atk_dcache_miss_dec2 = RegInit(false.B)
atk_dcache_miss_dec2 := count_pack_valid(countPackValidThreshLengthMinus1_2, 0) === 0.U
val atk_dcache_miss_dec3 = RegInit(false.B)
atk_dcache_miss_dec3 := count_pack_valid(countPackValidThreshLengthMinus1_3, 0) === 0.U
val atk_dcache_miss_dec4 = RegInit(false.B)
atk_dcache_miss_dec4 := count_pack_valid(countPackValidThreshLengthMinus1_4, 0) === 0.U
val atk_dcache_miss_dec5 = RegInit(false.B)
atk_dcache_miss_dec5 := count_pack_valid(countPackValidThreshLengthMinus1_5, 0) === 0.U
override def regmap(offset: Int) =
RegmapUtil.countEvent(atk_dcache_miss, in.resetCounters, offset, "AtkDCacheEventMiss") ++
RegmapUtil.countEventMax(atk_dcache_miss, in.resetCounters, atk_dcache_miss_dec1, offset + 0x10, "AtkDCacheEventMiss1") ++
RegmapUtil.countEventMax(atk_dcache_miss, in.resetCounters, atk_dcache_miss_dec2, offset + 0x18, "AtkDCacheEventMiss2") ++
RegmapUtil.countEventMax(atk_dcache_miss, in.resetCounters, atk_dcache_miss_dec3, offset + 0x20, "AtkDCacheEventMiss3") ++
RegmapUtil.countEventMax(atk_dcache_miss, in.resetCounters, atk_dcache_miss_dec4, offset + 0x28, "AtkDCacheEventMiss4") ++
RegmapUtil.countEventMax(atk_dcache_miss, in.resetCounters, atk_dcache_miss_dec5, offset + 0x30, "AtkDCacheEventMiss5") ++
Nil
}
object DetectPatternCacheEvent {
def apply()(implicit mp: MatanaParams) : DetectPatternCacheEvent = {
mp.dPatternCacheEvent match {
case Some(params) => new DetectPatternCacheEventInternal(params)
case None => new DetectPatternCacheEvent()
}
}
}
......@@ -40,26 +40,19 @@ class DetectPatternPrimeInternal(params: DetectPatternPrimeParams)(implicit mp:
// fdc42783 lw a5,-36(s0)
// The use of LD / LW is important in order to load an 32-bit+ address value, if use LH / LB then futher manipulation are needed
// Use of this pattern in addition of Timer pattern:
// 1) Proof of concept of memory access based detection
// 2) When the attacker use other means to time measurement (other than instruction, for example)
// 3) Detection of attacks which did not need time measurement (rowhammer, performance degratation attack)
// Attack pattern
// Description: Multiples access to DCache, each time in same set with differents tag
// 1) diff tag c+1, same tag c-1 (thresh 8)
// 2) diff tag c+1, same tag c-1 (thresh 4)
// 3) diff tag same set c+2, diff tag diff set c+1, same tag same set c-2, same tag diff set c-1 (thresh 16)
// 4) diff tag same set c+2, diff tag diff set c+1, same tag same set c-2, same tag diff set c-1 (thresh 8)
// 5) diff tag c+1, same tag c-1 (thresh 4, if thresh then 0)
// 6) diff tag c+1, same tag c-2 (thresh 4)
// 7) diff tag same set c+2, diff tag diff set c+1, same tag same set c-2, same tag diff set c-1 (thresh 8, if thresh then 0)
// 8) diff tag same set c+2, diff tag diff set c+1, same tag same set c-4, same tag diff set c-2 (thresh 8)
// 1) diff tag same set c+2, diff tag diff set c+1, same tag same set c-2, same tag diff set c-1 (thresh 8)
// 2) diff tag same set c+2, diff tag diff set c+1, same tag same set c-4, same tag diff set c-2 (thresh 8)
val countThresh1 = 8
val countThresh2 = 4
val countThresh3 = 16
val countThresh4 = 8
val countThresh5 = 4
val countThresh6 = 4
val countThresh7 = 8
val countThresh8 = 8
val countThresh2 = 8
val dmemtag = Wire(UInt((mp.addrWidth - params.tagAddrLSB).W))
dmemtag := in.dmemaddr_data(mp.addrWidth - 1, params.tagAddrLSB)
......@@ -71,15 +64,9 @@ class DetectPatternPrimeInternal(params: DetectPatternPrimeParams)(implicit mp:
val count1 = RegInit(0.U(log2Ceil(countThresh1 + 1).W)).suggestName("dpprime_count1")
val count2 = RegInit(0.U(log2Ceil(countThresh2 + 1).W)).suggestName("dpprime_count2")
val count3 = RegInit(0.U(log2Ceil(countThresh3 + 1).W)).suggestName("dpprime_count3")
val count4 = RegInit(0.U(log2Ceil(countThresh4 + 1).W)).suggestName("dpprime_count4")
val count5 = RegInit(0.U(log2Ceil(countThresh5 + 1).W)).suggestName("dpprime_count5")
val count6 = RegInit(0.U(log2Ceil(countThresh6 + 1).W)).suggestName("dpprime_count6")
val count7 = RegInit(0.U(log2Ceil(countThresh7 + 1).W)).suggestName("dpprime_count7")
val count8 = RegInit(0.U(log2Ceil(countThresh8 + 1).W)).suggestName("dpprime_count8")
// Update saved_dmem* array in FIFO way
when (in.dmemaddr_valid) {
when (in.dmemaddr_valid && in.isMonitoring) {
when (saved_dmemtag.map(_ =/= dmemtag).reduce(_&&_)) { // new tag
when (saved_dmemset.map(_ =/= dmemset).reduce(_&&_)) { // new set
// Update saved_dmemset, newer at 0
......@@ -87,27 +74,17 @@ class DetectPatternPrimeInternal(params: DetectPatternPrimeParams)(implicit mp:
last := data
last
}
count3 := Mux(count3 <= countThresh3.U - 1.U, count3 + 1.U, countThresh3.U)
count4 := Mux(count4 <= countThresh4.U - 1.U, count4 + 1.U, countThresh4.U)
count7 := Mux(count7 <= countThresh7.U - 1.U, count7 + 1.U, 1.U)
count8 := Mux(count8 <= countThresh8.U - 1.U, count8 + 1.U, countThresh8.U)
count1 := Mux(count1 <= countThresh1.U - 1.U, count1 + 1.U, countThresh1.U)
count2 := Mux(count2 <= countThresh2.U - 1.U, count2 + 1.U, countThresh2.U)
}.otherwise { // old set
count3 := Mux(count3 <= countThresh3.U - 2.U, count3 + 2.U, countThresh3.U)
count4 := Mux(count4 <= countThresh4.U - 2.U, count4 + 2.U, countThresh4.U)
count7 := Mux(count7 =/= countThresh7.U,
Mux(count7 <= countThresh7.U - 2.U, count7 + 2.U, countThresh7.U),
2.U)
count8 := Mux(count8 <= countThresh8.U - 2.U, count8 + 2.U, countThresh8.U)
count1 := Mux(count1 <= countThresh1.U - 2.U, count1 + 2.U, countThresh1.U)
count2 := Mux(count2 <= countThresh2.U - 2.U, count2 + 2.U, countThresh2.U)
}
// Update saved_dmemtag, newer at 0
saved_dmemtag.foldLeft(dmemtag) { (data, last) =>
last := data
last
}
count1 := Mux(count1 <= countThresh1.U - 1.U, count1 + 1.U, countThresh1.U)
count2 := Mux(count2 <= countThresh2.U - 1.U, count2 + 1.U, countThresh2.U)
count5 := Mux(count5 <= countThresh5.U - 1.U, count5 + 1.U, 1.U)
count6 := Mux(count6 <= countThresh6.U - 1.U, count6 + 1.U, countThresh6.U)
}.otherwise { // old tag
when (saved_dmemset.map(_ =/= dmemtag).reduce(_&&_)) { // new set
// Update saved_dmemset, newer at 0
......@@ -115,174 +92,62 @@ class DetectPatternPrimeInternal(params: DetectPatternPrimeParams)(implicit mp:
last := data
last
}
count3 := Mux(count3 >= 1.U, count3 - 1.U, 0.U)
count4 := Mux(count4 >= 1.U, count4 - 1.U, 0.U)
count7 := Mux(count7 =/= countThresh7.U,
Mux(count7 >= 1.U, count7 - 1.U, 0.U),
0.U)
count8 := Mux(count8 >= 2.U, count8 - 2.U, 0.U)
count1 := Mux(count1 >= 1.U, count1 - 1.U, 0.U)
count2 := Mux(count2 >= 2.U, count2 - 2.U, 0.U)
}.otherwise { // old set
count3 := Mux(count3 >= 2.U, count3 - 2.U, 0.U)
count4 := Mux(count4 >= 2.U, count4 - 2.U, 0.U)
count7 := Mux(count7 =/= countThresh7.U,
Mux(count7 >= 2.U, count7 - 2.U, 0.U),
0.U)
count8 := Mux(count8 >= 4.U, count8 - 4.U, 0.U)
count1 := Mux(count1 >= 2.U, count1 - 2.U, 0.U)
count2 := Mux(count2 >= 4.U, count2 - 4.U, 0.U)
}
count1 := Mux(count1 >= 1.U, count1 - 1.U, 0.U)
count2 := Mux(count2 >= 1.U, count2 - 1.U, 0.U)
count5 := Mux(count5 =/= countThresh5.U,
Mux(count5 >= 1.U, count5 - 1.U, 0.U),
0.U)
count6 := Mux(count6 >= 2.U, count6 - 2.U, 0.U)
}
}
// only count atk during monitoring, and when an dmemaddr is present (prevent counting during wait)
val atk_prime_1 = Wire(Bool())
val atk_prime_1 = RegInit(false.B)
atk_prime_1 := (count1 === countThresh1.U) && in.isMonitoring && in.dmemaddr_valid
val atk_prime_2 = Wire(Bool())
val atk_prime_2 = RegInit(false.B)
atk_prime_2 := (count2 === countThresh2.U) && in.isMonitoring && in.dmemaddr_valid
val atk_prime_3 = Wire(Bool())
atk_prime_3 := (count3 === countThresh3.U) && in.isMonitoring && in.dmemaddr_valid
val atk_prime_4 = Wire(Bool())
atk_prime_4 := (count4 === countThresh4.U) && in.isMonitoring && in.dmemaddr_valid
val atk_prime_5 = Wire(Bool())
atk_prime_5 := (count5 === countThresh5.U) && in.isMonitoring && in.dmemaddr_valid
val atk_prime_6 = Wire(Bool())
atk_prime_6 := (count6 === countThresh6.U) && in.isMonitoring && in.dmemaddr_valid
val atk_prime_7 = Wire(Bool())
atk_prime_7 := (count7 === countThresh7.U) && in.isMonitoring && in.dmemaddr_valid
val atk_prime_8 = Wire(Bool())
atk_prime_8 := (count8 === countThresh8.U) && in.isMonitoring && in.dmemaddr_valid
// Attack Pattern
// Description: all valid pack contains an word size load
// 1) valid load c+1, valid not load c-1 (thresh 32)
// 2) valid load c+1, valid not load c-2 (thresh 8) (maybe only apply to high clkDiv, as for -O0 there is store with load. we suppose if store to the same addr, the next load can be conclude in the same slow cycle)
// 3) valid load c+1, valid not load c-2 (thresh 16)
// 4) valid load c+1, valid not load c-1 (thresh 16)
// Attack Pattern (instruction-based)
// Description: all valid pack contains an word size load. For -O0 there is store with load, we suppose if store to the same addr, the next load can be conclude in the same slow cycle.
// 1) valid load c+1, valid not load c-2 (thresh 16)
val countInstThresh1 = 32
val countInstThresh2 = 8
val countInstThresh3 = 16
val countInstThresh4 = 16
val countInstThresh1 = 16
val countinst1 = RegInit(0.U(log2Ceil(countInstThresh1 + 1).W)).suggestName("dpprime_countinst1")
val countinst2 = RegInit(0.U(log2Ceil(countInstThresh2 + 1).W)).suggestName("dpprime_countinst2")
val countinst3 = RegInit(0.U(log2Ceil(countInstThresh3 + 1).W)).suggestName("dpprime_countinst3")
val countinst4 = RegInit(0.U(log2Ceil(countInstThresh4 + 1).W)).suggestName("dpprime_countinst4")
when (in.pack_has_valid) {
when (in.pack_has_valid && in.isMonitoring) {
when (in.pack_has_loadword) {
countinst1 := Mux(countinst1 <= countInstThresh1.U - 1.U, countinst1 + 1.U, countInstThresh1.U)
countinst2 := Mux(countinst2 <= countInstThresh2.U - 1.U, countinst2 + 1.U, countInstThresh2.U)
countinst3 := Mux(countinst3 <= countInstThresh3.U - 1.U, countinst3 + 1.U, countInstThresh3.U)
countinst4 := Mux(countinst4 <= countInstThresh4.U - 1.U, countinst4 + 1.U, countInstThresh4.U)
}.otherwise {
countinst1 := Mux(countinst1 >= 1.U, countinst1 - 1.U, 0.U)
countinst2 := Mux(countinst2 >= 2.U, countinst2 - 2.U, 0.U)
countinst3 := Mux(countinst3 >= 2.U, countinst3 - 2.U, 0.U)
countinst4 := Mux(countinst4 >= 1.U, countinst4 - 1.U, 0.U)
countinst1 := Mux(countinst1 >= 2.U, countinst1 - 2.U, 0.U)
}
}
// only count atk during monitoring, and when an valid inst is present (prevent counting during wait)
val atk_primeinst_1 = Wire(Bool())
val atk_primeinst_1 = RegInit(false.B)
atk_primeinst_1 := (countinst1 === countInstThresh1.U) && in.isMonitoring && in.pack_has_valid
val atk_primeinst_2 = Wire(Bool())
atk_primeinst_2 := (countinst2 === countInstThresh2.U) && in.isMonitoring && in.pack_has_valid
val atk_primeinst_3 = Wire(Bool())
atk_primeinst_3 := (countinst3 === countInstThresh3.U) && in.isMonitoring && in.pack_has_valid
val atk_primeinst_4 = Wire(Bool())
atk_primeinst_4 := (countinst4 === countInstThresh4.U) && in.isMonitoring && in.pack_has_valid
// Attack Pattern
// Description: Mix pattern
// 1) count4, but decrement the counter by 1 if count_pack_valid reach thresh (0x800)
// => TODO: perhaps adjust using dmemaddr / slow cycle count
// 2) count4+countinst4: if 2 atk are recognized in 4 slow cycles
// 3) count4+countinst4: (similar to mix2 but use 1 counter for both count4 and countinst4)
// count4 max && countinst4 max: c+8
// count4 max || countinst4 max: c+4
// else : c-1 (as c4/ci4 no reset 0, this means no attack sign)
// (max 8, thresh 5)
// 4) count4+countinst4: (similar to mix3 but priorize count4)
// count4 max && countinst4 max: c+12
// count4 max : c+8
// countinst4 max: c+4
// else : c-1
// (max 12, thresh 9)
// 5) count4+countinst4: (only take when c4/ci4=max or c4/ci4=0, if thresh then 0 to prevent cmix remains max) TODO?
// 1) = atk_prime_1, but decrement the counter by 1 if count_pack_valid reach thresh 0x800 (2048) // TODO: perhaps adjust using dmemaddr / slow cycle count
// 2) = atk_prime_1, but decrement the counter by 1 if count_pack_valid reach thresh 0x400 (1024)
val countPackValidThresh = 0x800L // 2048
// Count pack valid for decrement counters
val count_pack_valid = RegInit(0.U(mp.counterWidth.W)).suggestName("dpprime_count_pack_valid")
val count_pack_valid_dec = RegInit(false.B).suggestName("dpprime_count_pack_valid_dec")
when (in.resetCounters) {
count_pack_valid := 0.U
count_pack_valid_dec := false.B
}.otherwise {
when (count_pack_valid =/= countPackValidThresh.U) { // count < thresh
count_pack_valid := count_pack_valid + in.pack_has_valid
count_pack_valid_dec := false.B
}.otherwise { // count reach threshold
count_pack_valid := 0.U
count_pack_valid_dec := true.B
}
}.elsewhen (in.isMonitoring) {
count_pack_valid := count_pack_valid + in.pack_has_valid
}
val atk_primemix_1 = Wire(Bool())
atk_primemix_1 := Mux(count_pack_valid_dec, !atk_prime_4, atk_prime_4) && in.isMonitoring // if dec=true, dec1 if atk0, dec0 if atk1; if dec=false, inc1 if atk1, inc0 if atk0.
val countMix2Thresh = 4
val countmix2_1 = RegInit(0.U(log2Ceil(countMix2Thresh + 1).W)).suggestName("dpprime_countmix2_1")
val countmix2_2 = RegInit(0.U(log2Ceil(countMix2Thresh + 1).W)).suggestName("dpprime_countmix2_2")
when (in.dmemaddr_valid) { // modify counter only when valid
countmix2_1 := Mux(count4 === countThresh4.U,
countMix2Thresh.U, // if count4 max, set countmix2_1 to max
Mux(countmix2_1 <= 1.U, 0.U, countmix2_1 - 1.U)) // else decrement
}
when (in.pack_has_valid) { // modify counter only when valid
countmix2_2 := Mux(countinst4 === countInstThresh4.U,
countMix2Thresh.U, // if countinst4 max, set countmix2_2 to max
Mux(countmix2_2 <= 1.U, 0.U, countmix2_2 - 1.U)) // else decrement
}
val countPackValidThreshLengthMinus1_1 = 10 // Thresh = 2^(10+1) = 2048 = 0x800
val countPackValidThreshLengthMinus1_2 = 9 // Thresh = 2^(9+1) = 1024 = 0x400
val atk_primemix_2 = Wire(Bool())
atk_primemix_2 := countmix2_1.orR && countmix2_2.orR && in.isMonitoring
val countMix3Max = 8
val countMix3Thresh = 5
val countMix4Max = 12
val countMix4Thresh = 9
val countmix3 = RegInit(0.U(log2Ceil(countMix3Max + 1).W)).suggestName("dpprime_countmix3")
val countmix4 = RegInit(0.U(log2Ceil(countMix4Max + 1).W)).suggestName("dpprime_countmix4")
val countmix3state = Wire(UInt(2.W))
countmix3state := Cat((count4 === countThresh4.U) && in.dmemaddr_valid,
(countinst4 === countInstThresh4.U) && in.pack_has_valid)
when (countmix3state.andR) {
countmix3 := Mux(countmix3 <= countMix3Max.U - 8.U, countmix3 + 8.U, countMix3Max.U)
countmix4 := Mux(countmix4 <= countMix4Max.U - 12.U, countmix4 + 12.U, countMix4Max.U)
}.elsewhen (countmix3state(0) === true.B) {
countmix3 := Mux(countmix3 <= countMix3Max.U - 4.U, countmix3 + 4.U, countMix3Max.U)
countmix4 := Mux(countmix4 <= countMix4Max.U - 8.U, countmix4 + 8.U, countMix4Max.U)
}.elsewhen (countmix3state(1) === true.B) {
countmix3 := Mux(countmix3 <= countMix3Max.U - 4.U, countmix3 + 4.U, countMix3Max.U)
countmix4 := Mux(countmix4 <= countMix4Max.U - 4.U, countmix4 + 4.U, countMix4Max.U)
}.elsewhen (in.dmemaddr_valid || in.pack_has_valid) { // has valid input, should decrement
countmix3 := Mux(countmix3 >= 1.U, countmix3 - 1.U, 0.U)
countmix4 := Mux(countmix4 >= 1.U, countmix4 - 1.U, 0.U)
}.otherwise { // no valid package, remains old value
countmix3 := countmix3
countmix4 := countmix4
}
val atk_primemix_3 = Wire(Bool())
atk_primemix_3 := (countmix3 >= countMix3Thresh.U) && in.isMonitoring
val atk_primemix_4 = Wire(Bool())
atk_primemix_4 := (countmix4 >= countMix4Thresh.U) && in.isMonitoring
val atk_prime_mix1_dec = RegInit(false.B)
atk_prime_mix1_dec := count_pack_valid(countPackValidThreshLengthMinus1_1, 0) === 0.U
val atk_prime_mix2_dec = RegInit(false.B)
atk_prime_mix2_dec := count_pack_valid(countPackValidThreshLengthMinus1_2, 0) === 0.U
// Debug Signals
......@@ -290,24 +155,13 @@ class DetectPatternPrimeInternal(params: DetectPatternPrimeParams)(implicit mp:
val saved_addr0 = RegInit(0.U(mp.addrWidth.W)).suggestName("saved_addr0")
saved_addr0 := Cat(saved_dmemtag(0), Cat(saved_dmemset(0), zeros_line))
override def regmap(offset: Int) =
RegmapUtil.countEvent(atk_prime_1, in.resetCounters, offset + 0x0, "AtkPrime1") ++
RegmapUtil.countEvent(atk_prime_2, in.resetCounters, offset + 0x4, "AtkPrime2") ++
RegmapUtil.countEvent(atk_prime_3, in.resetCounters, offset + 0x8, "AtkPrime3") ++
RegmapUtil.countEvent(atk_prime_4, in.resetCounters, offset + 0xC, "AtkPrime4") ++
RegmapUtil.countEvent(atk_prime_5, in.resetCounters, offset + 0x10, "AtkPrime5") ++
RegmapUtil.countEvent(atk_prime_6, in.resetCounters, offset + 0x14, "AtkPrime6") ++
RegmapUtil.countEvent(atk_prime_7, in.resetCounters, offset + 0x18, "AtkPrime7") ++
RegmapUtil.countEvent(atk_prime_8, in.resetCounters, offset + 0x1C, "AtkPrime8") ++
RegmapUtil.countEvent(atk_primeinst_1, in.resetCounters, offset + 0x30, "AtkPrimeInst1") ++
RegmapUtil.countEvent(atk_primeinst_2, in.resetCounters, offset + 0x34, "AtkPrimeInst2") ++
RegmapUtil.countEvent(atk_primeinst_3, in.resetCounters, offset + 0x38, "AtkPrimeInst3") ++
RegmapUtil.countEvent(atk_primeinst_4, in.resetCounters, offset + 0x3C, "AtkPrimeInst4") ++
RegmapUtil.countEventMax(atk_primemix_1, in.resetCounters, count_pack_valid_dec, offset + 0x50, "AtkPrimeMix1") ++
RegmapUtil.countEvent(atk_primemix_2, in.resetCounters, offset + 0x64, "AtkPrimeMix2") ++
RegmapUtil.countEvent(atk_primemix_3, in.resetCounters, offset + 0x68, "AtkPrimeMix3") ++
RegmapUtil.countEvent(atk_primemix_4, in.resetCounters, offset + 0x6C, "AtkPrimeMix4")
RegmapUtil.countEvent(atk_primeinst_1, in.resetCounters, offset + 0x10, "AtkPrimeInst1") ++
RegmapUtil.countEventMax(atk_prime_1, in.resetCounters, atk_prime_mix1_dec, offset + 0x20, "AtkPrimeMix1") ++
RegmapUtil.countEventMax(atk_prime_1, in.resetCounters, atk_prime_mix2_dec, offset + 0x30, "AtkPrimeMix2") ++
Nil
}
object DetectPatternPrime {
......
......@@ -14,7 +14,7 @@ class DetectPatternRopIn()(implicit val mp: MatanaParams) extends Bundle {
val isMonitoring = Bool()
val pack_has_valid = Bool()
val pack_has_jalr = Bool()
val pack_has_jump = Bool()
val pack_jalr = Vec(mp.clockDiv, Bool())
//val pack_has_mispredict = Bool() //TODO
}
......@@ -26,48 +26,51 @@ class DetectPatternRop()(implicit mp: MatanaParams) {
class DetectPatternRopInternal(params: DetectPatternRopParams)(implicit mp: MatanaParams) extends DetectPatternRop {
// Attack Pattern
// Description: Jalr chained short instruction sequence
// 1) jalr + jalr = c+1, other + jalr = c-2 (window size 1)(thresh proably 10 as in mispredict1)
// 2) jalr + jalr = c+1, other + jalr = c-2 (window size 2)
// 3) jalr + jalr = c+1, other jump = c-2 (window size 1)
// 4) jalr + jalr = c+1, other jump = c-2 (window size 2)
// Description: Jalr chained short instruction sequence.
// 1) jalr + jalr = c+1, other jalr = c-2 (window size 1+1) (thresh proably 10 as in mispredict1) (window size 2+1 has high false positive rate)
// 2) jalr + jalr = c+1, other jalr = c-2 (window size 1+1, step 1/2)
// Note:
// 1) clockDiv will impact the pattern on window size, smaller clockDiv is perhaps more accurate.
// 2) ROP gadget destination is likely not in cache, so jalr will take longer and will probably be visible in different clockDiv pack. But this property may increase false positive as legal chain will now seems to be shorter.
// 3) 1 pattern is a complet attack, threshold, if exists should be set to 1
val npackSizeMax = 2
val npackSizeMax = 2
val npack_jalr = RegInit(0.U(npackSizeMax.W)).suggestName("npack_jalr")
when (in.pack_has_valid) {
npack_jalr := Cat(npack_jalr(npackSizeMax - 2, 0), in.pack_has_jalr)
}
val nstepNum = 2 // step 1/2 of slow cycle
require(mp.clockDiv % nstepNum == 0)
require(nstepNum == 2) // Only support value 2 for now, else needs modify the use of dpack_jalr in when
val nstepSize: Int = mp.clockDiv / nstepNum
val dpack_jalr = Wire(Vec(nstepNum, Bool())).suggestName("dpack_jalr")
for (i <- 0 until nstepNum) {
dpack_jalr(i) := in.pack_jalr.zipWithIndex.filter{ case (data, index) =>
((index <= nstepSize*(i+1)-1) && (index >= nstepSize*i))
}.map(_._1).reduce(_||_)
}
val countjalr1 = RegInit(0.U(mp.counterWidth.W)).suggestName("dprop_countjalr1")
val countjalr2 = RegInit(0.U(mp.counterWidth.W)).suggestName("dprop_countjalr2")
val countjalr3 = RegInit(0.U(mp.counterWidth.W)).suggestName("dprop_countjalr3")
val countjalr4 = RegInit(0.U(mp.counterWidth.W)).suggestName("dprop_countjalr4")
when (in.resetCounters) {
countjalr1 := 0.U
countjalr2 := 0.U
countjalr3 := 0.U
countjalr4 := 0.U
}.otherwise {
when (in.isMonitoring && in.pack_has_jalr) { // evaluate counter1,2 when jalr (include pack_has_valid)
countjalr1 := Mux(npack_jalr(0),
countjalr1 + 1.U,
Mux(countjalr1 >= 2.U, countjalr1 - 2.U, 0.U))
countjalr2 := Mux(npack_jalr(1,0).orR,
countjalr2 + 1.U,
Mux(countjalr2 >= 2.U, countjalr2 - 2.U, 0.U))
}
when (in.isMonitoring && in.pack_has_jump) { // evaluate counter3,4 when jump (include pack_has_valid)
countjalr3 := Mux(in.pack_has_jalr && npack_jalr(0),
countjalr3 + 1.U,
Mux(countjalr3 >= 2.U, countjalr3 - 2.U, 0.U))
countjalr4 := Mux(in.pack_has_jalr && npack_jalr(1,0).orR,
countjalr4 + 1.U,
Mux(countjalr4 >= 2.U, countjalr4 - 2.U, 0.U))
when (in.isMonitoring && in.pack_has_jalr) { // evaluate counter when jalr (include pack_has_valid)
when (npack_jalr(0)) { // has jalr adjacent
countjalr1 := countjalr1 + 1.U
}.otherwise { // no jalr adjacent
countjalr1 := Mux(countjalr1 >= 2.U, countjalr1 - 2.U, 0.U)
}
when (npack_jalr(0) && dpack_jalr.reduce(_&&_)) { // has jalr in adjacent window and step
countjalr2 := countjalr2 + 2.U
}.elsewhen (npack_jalr(0) || dpack_jalr.reduce(_&&_)) { // has jalr in adjacent window or step
countjalr2 := countjalr2 + 1.U
}.otherwise { // no jalr adjacent
countjalr2 := Mux(countjalr2 >= 2.U, countjalr2 - 2.U, 0.U)
}
}
}
......@@ -84,10 +87,8 @@ class DetectPatternRopInternal(params: DetectPatternRopParams)(implicit mp: Mata
atk_rop_mispredict := false.B && in.isMonitoring // TODO
override def regmap(offset: Int) =
RegmapUtil.readValueMax(countjalr1, in.resetCounters, mp.counterWidth, offset + 0x0, "CountRopJalr1") ++
RegmapUtil.readValueMax(countjalr2, in.resetCounters, mp.counterWidth, offset + 0x8, "CountRopJalr2") ++
RegmapUtil.readValueMax(countjalr3, in.resetCounters, mp.counterWidth, offset + 0x10, "CountRopJalr3") ++
RegmapUtil.readValueMax(countjalr4, in.resetCounters, mp.counterWidth, offset + 0x18, "CountRopJalr4") ++
RegmapUtil.readValueMax(countjalr1, in.resetCounters, mp.counterWidth, offset, "CountRopJalr1") ++
RegmapUtil.readValueMax(countjalr2, in.resetCounters, mp.counterWidth, offset + 0x10, "CountRopJalr2") ++
Nil
}
......
......@@ -35,15 +35,17 @@ class DetectPatternTimerInternal(params: DetectPatternTimerParams)(implicit mp:
val npackSize: Int = params.npackSizeInFastCycles / mp.clockDiv
val npack_timer = RegInit(0.U(npackSize.W)).suggestName("npack_timer")
val npack_timer_orR = RegInit(false.B)
when (in.pack_has_valid) {
npack_timer := Cat(npack_timer(npackSize-2, 0), in.pack_has_timer)
npack_timer_orR := npack_timer(npackSize-2, 0).orR || in.pack_has_timer
}
val atk_timer_timer = Wire(Bool())
atk_timer_timer := npack_timer.orR && in.pack_has_timer && in.isMonitoring
atk_timer_timer := npack_timer_orR && in.pack_has_timer && in.isMonitoring