Skip to content
This repository has been archived by the owner on Aug 20, 2024. It is now read-only.

Commit

Permalink
Merge branch 'master' into readmem-annotation
Browse files Browse the repository at this point in the history
  • Loading branch information
carlosedp committed Mar 9, 2021
2 parents 31d4849 + 29d57a6 commit 9ab5d65
Show file tree
Hide file tree
Showing 3 changed files with 134 additions and 88 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,7 @@

package firrtl.backends.experimental.smt.random

import firrtl.Utils.{isLiteral, kind, BoolType}
import firrtl.WrappedExpression.{we, weq}
import firrtl.Utils.{isLiteral, BoolType}
import firrtl._
import firrtl.annotations.NoTargetAnnotation
import firrtl.backends.experimental.smt._
Expand All @@ -14,7 +13,7 @@ import firrtl.passes.memlib.AnalysisUtils.Connects
import firrtl.passes.memlib.InferReadWritePass.checkComplement
import firrtl.passes.memlib.{AnalysisUtils, InferReadWritePass, VerilogMemDelays}
import firrtl.stage.Forms
import firrtl.transforms.{ConstantPropagation, RemoveWires}
import firrtl.transforms.RemoveWires

import scala.collection.mutable

Expand Down Expand Up @@ -111,7 +110,10 @@ private class InstrumentMems(

private def onMem(m: DefMemory): Statement = {
// collect wire and random statement defines
val declarations = mutable.ListBuffer[Statement]()
implicit val declarations: mutable.ListBuffer[Statement] = mutable.ListBuffer[Statement]()

// cache for the expressions of memory inputs
implicit val cache: mutable.HashMap[String, Expression] = mutable.HashMap[String, Expression]()

// only for non power of 2 memories do we have to worry about reading or writing out of bounds
val canBeOutOfBounds = !isPow2(m.depth)
Expand All @@ -132,51 +134,31 @@ private class InstrumentMems(
val maskRef = memPortField(m, write, "mask")

val prods = getProductTerms(enRef) ++ getProductTerms(maskRef)
val expr = Utils.and(readInput(m.info, enRef), readInput(m.info, maskRef))

// if we can have write/write conflicts, we are going to change the mask and enable pins
val expr = if (canHaveWriteWriteConflicts) {
val maskIsOne = isTrue(connects(maskRef.serialize))
// if the mask is connected to a constant true, we do not need to consider it, this is a common case
if (maskIsOne) {
val enWire = disconnectInput(m.info, enRef)
declarations += enWire
Reference(enWire)
} else {
val maskWire = disconnectInput(m.info, maskRef)
val enWire = disconnectInput(m.info, enRef)
// create a node for the conjunction
val nodeName = namespace.newName(s"${m.name}_${write}_mask_and_en")
val node = DefNode(m.info, nodeName, Utils.and(Reference(maskWire), Reference(enWire)))
declarations ++= List(maskWire, enWire, node)
Reference(node)
}
} else {
Utils.and(enRef, maskRef)
}
(expr, prods)
}

// implement the three undefined read behaviors
m.readers.foreach { read =>
// many memories have their read enable hard wired to true
val canBeDisabled = !isTrue(memPortField(m, read, "en"))
val readEn = if (canBeDisabled) memPortField(m, read, "en") else Utils.True()
val addr = memPortField(m, read, "addr")
val canBeDisabled = !isTrue(readInput(m, read, "en"))
val readEn = if (canBeDisabled) readInput(m, read, "en") else Utils.True()

// collect signals that would lead to a randomization
var doRand = List[Expression]()

// randomize the read value when the address is out of bounds
if (canBeOutOfBounds && opt.randomizeOutOfBoundsRead) {
val addr = readInput(m, read, "addr")
val cond = Utils.and(readEn, Utils.not(isInBounds(m.depth, addr)))
val node = DefNode(m.info, namespace.newName(s"${m.name}_${read}_oob"), cond)
declarations += node
doRand = Reference(node) +: doRand
}

if (readWriteUndefined && opt.randomizeReadWriteConflicts) {
val (cond, d) = readWriteConflict(m, read, writeEn)
declarations ++= d
val cond = readWriteConflict(m, read, writeEn)
val node = DefNode(m.info, namespace.newName(s"${m.name}_${read}_rwc"), cond)
declarations += node
doRand = Reference(node) +: doRand
Expand All @@ -203,7 +185,7 @@ private class InstrumentMems(
}
val doRandSignal = if (m.readLatency == 0) { doRandNode }
else {
val clock = memPortField(m, read, "clk")
val clock = readInput(m, read, "clk")
val (signal, regDecls) = pipeline(m.info, clock, doRandName, doRandNode, m.readLatency)
declarations ++= regDecls
signal
Expand All @@ -217,7 +199,7 @@ private class InstrumentMems(

// create a source of randomness and connect the new wire either to the actual data port or to the random value
val randName = namespace.newName(s"${m.name}_${read}_rand_data")
val random = DefRandom(m.info, randName, m.dataType, Some(memPortField(m, read, "clk")), doRandSignal)
val random = DefRandom(m.info, randName, m.dataType, Some(readInput(m, read, "clk")), doRandSignal)
declarations += random
val data = Utils.mux(doRandSignal, Reference(random), dataRef)
newStmts.append(Connect(m.info, Reference(dataWire), data))
Expand All @@ -226,16 +208,16 @@ private class InstrumentMems(

// write
if (opt.randomizeWriteWriteConflicts) {
declarations ++= writeWriteConflicts(m, writeEn)
writeWriteConflicts(m, writeEn)
}

// add an assertion that if the write is taking place, then the address must be in range
if (canBeOutOfBounds && opt.assertNoOutOfBoundsWrites) {
m.writers.zip(writeEn).foreach {
case (write, (combinedEn, _)) =>
val addr = memPortField(m, write, "addr")
val addr = readInput(m, write, "addr")
val cond = Utils.implies(combinedEn, isInBounds(m.depth, addr))
val clk = memPortField(m, write, "clk")
val clk = readInput(m, write, "clk")
val a = Verification(Formal.Assert, m.info, clk, cond, Utils.True(), StringLit("out of bounds read"))
newStmts.append(a)
}
Expand Down Expand Up @@ -268,11 +250,13 @@ private class InstrumentMems(
m: DefMemory,
read: String,
writeEn: Seq[(Expression, ProdTerms)]
): (Expression, Seq[Statement]) = {
if (m.writers.isEmpty) return (Utils.False(), List())
val declarations = mutable.ListBuffer[Statement]()
)(
implicit cache: mutable.HashMap[String, Expression],
declarations: mutable.ListBuffer[Statement]
): Expression = {
if (m.writers.isEmpty) return Utils.False()

val readEn = memPortField(m, read, "en")
val readEn = readInput(m, read, "en")
val readProd = getProductTerms(readEn)

// create all conflict signals
Expand All @@ -283,7 +267,7 @@ private class InstrumentMems(
} else {
val name = namespace.newName(s"${m.name}_${read}_${write}_rwc")
val bothEn = Utils.and(readEn, writeEn)
val sameAddr = Utils.eq(memPortField(m, read, "addr"), memPortField(m, write, "addr"))
val sameAddr = Utils.eq(readInput(m, read, "addr"), readInput(m, write, "addr"))
// we need a wire because this condition might be used in a random statement
val wire = DefWire(m.info, name, BoolType)
declarations += wire
Expand All @@ -292,13 +276,18 @@ private class InstrumentMems(
}
}

(conflicts.reduce(Utils.or), declarations.toList)
conflicts.reduce(Utils.or)
}

private type ProdTerms = Seq[Expression]
private def writeWriteConflicts(m: DefMemory, writeEn: Seq[(Expression, ProdTerms)]): Seq[Statement] = {
if (m.writers.size < 2) return List()
val declarations = mutable.ListBuffer[Statement]()
private def writeWriteConflicts(
m: DefMemory,
writeEn: Seq[(Expression, ProdTerms)]
)(
implicit cache: mutable.HashMap[String, Expression],
declarations: mutable.ListBuffer[Statement]
): Unit = {
if (m.writers.size < 2) return

// we first create all conflict signals:
val conflict =
Expand All @@ -314,7 +303,7 @@ private class InstrumentMems(
} else {
val name = namespace.newName(s"${m.name}_${w1}_${w2}_wwc")
val bothEn = Utils.and(en1, en2)
val sameAddr = Utils.eq(memPortField(m, w1, "addr"), memPortField(m, w2, "addr"))
val sameAddr = Utils.eq(readInput(m, w1, "addr"), readInput(m, w2, "addr"))
// we need a wire because this condition might be used in a random statement
val wire = DefWire(m.info, name, BoolType)
declarations += wire
Expand Down Expand Up @@ -355,25 +344,24 @@ private class InstrumentMems(

// create the source of randomness
val name = namespace.newName(s"${m.name}_${w1}_wwc_data")
val random = DefRandom(m.info, name, m.dataType, Some(memPortField(m, w1, "clk")), hasConflict)
val random = DefRandom(m.info, name, m.dataType, Some(readInput(m, w1, "clk")), hasConflict)
declarations.append(random)
// replace the old data input
val dataWire = disconnectInput(m.info, memPortField(m, w1, "data"))
declarations += dataWire

// generate new data input
val data = Utils.mux(hasConflict, Reference(random), Reference(dataWire))
val data = Utils.mux(hasConflict, Reference(random), readInput(m, w1, "data"))
newStmts.append(Connect(m.info, memPortField(m, w1, "data"), data))
doDisconnect.add(memPortField(m, w1, "data").serialize)
}

// connect data enable signals
val maskIsOne = isTrue(connects(memPortField(m, w1, "mask").serialize))
val maskIsOne = isTrue(readInput(m, w1, "mask"))
if (!maskIsOne) {
newStmts.append(Connect(m.info, memPortField(m, w1, "mask"), Utils.True()))
doDisconnect.add(memPortField(m, w1, "mask").serialize)
}
newStmts.append(Connect(m.info, memPortField(m, w1, "en"), en))
doDisconnect.add(memPortField(m, w1, "en").serialize)
}

declarations.toList
}

/** check whether two signals can be proven to be mutually exclusive */
Expand All @@ -383,27 +371,43 @@ private class InstrumentMems(
proofOfMutualExclusion.nonEmpty
}

/** replace a memory port with a wire */
private def disconnectInput(info: Info, signal: RefLikeExpression): DefWire = {
// disconnect the old value
doDisconnect.add(signal.serialize)

// if the old value is a literal, we just replace all references to it with this literal
val oldValue = connects(signal.serialize)
if (isLiteral(oldValue)) {
println("TODO: better code for literal")
}
/** memory inputs my not be read, only assigned to, thus we might need to add a wire to make them accessible */
private def readInput(
info: Info,
signal: RefLikeExpression
)(
implicit cache: mutable.HashMap[String, Expression],
declarations: mutable.ListBuffer[Statement]
): Expression =
cache.getOrElseUpdate(
signal.serialize, {
// if it is a literal, we just return it
val value = connects(signal.serialize)
if (isLiteral(value)) {
value
} else {
// otherwise we make a wire that refelect the value
val wire = DefWire(info, copyName(signal), signal.tpe)
declarations += wire

// create a new wire and replace all references to the original port with this wire
val wire = DefWire(info, copyName(signal), signal.tpe)
exprReplacements(signal.serialize) = Reference(wire)
// connect the old expression to the new wire
val con = Connect(info, Reference(wire), connects(signal.serialize))
newStmts.append(con)
// connect the old expression to the new wire
val con = Connect(info, Reference(wire), value)
newStmts.append(con)

// the wire definition should end up right after the memory definition
wire
}
// use a reference to this new wire
Reference(wire)
}
}
)
private def readInput(
m: DefMemory,
port: String,
field: String
)(
implicit cache: mutable.HashMap[String, Expression],
declarations: mutable.ListBuffer[Statement]
): Expression =
readInput(m.info, memPortField(m, port, field))

private def copyName(ref: RefLikeExpression): String =
namespace.newName(ref.serialize.replace('.', '_'))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -195,4 +195,46 @@ class MemorySpec extends EndToEndSMTBaseSpec {
"memory with two write ports" should "can have collisions when enables are unconstrained" taggedAs (RequiresZ3) in {
test(collisionTest("UInt(1)"), MCFail(1), kmax = 1)
}

private def readEnableSrc(pred: String, num: Int) =
s"""
|circuit ReadEnableTest$num:
| module ReadEnableTest$num:
| input c : Clock
| input preset: AsyncReset
|
| reg first: UInt<1>, c with: (reset => (preset, UInt(1)))
| first <= UInt(0)
|
| reg even: UInt<1>, c with: (reset => (preset, UInt(0)))
| node odd = not(even)
| even <= not(even)
|
| mem m:
| data-type => UInt<8>
| depth => 4
| reader => r
| read-latency => 1
| write-latency => 1
| read-under-write => undefined
|
| m.r.clk <= c
| m.r.addr <= UInt(0)
| ; the read port is enabled in even cycles
| m.r.en <= even
|
| assert(c, $pred, not(first), "")
|""".stripMargin

"a memory with read enable" should "supply valid data one cycle after en=1" in {
val init = Seq(MemoryScalarInitAnnotation(CircuitTarget(s"ReadEnableTest1").module(s"ReadEnableTest1").ref("m"), 0))
// the read port is enabled on even cycles, so on odd cycles we should reliably get zeros
test(readEnableSrc("or(not(odd), eq(m.r.data, UInt(0)))", 1), MCSuccess, kmax = 3, annos = init)
}

"a memory with read enable" should "supply invalid data one cycle after en=0" in {
val init = Seq(MemoryScalarInitAnnotation(CircuitTarget(s"ReadEnableTest2").module(s"ReadEnableTest2").ref("m"), 0))
// the read port is disabled on odd cycles, so on even cycles we should *NOT* reliably get zeros
test(readEnableSrc("or(not(even), eq(m.r.data, UInt(0)))", 2), MCFail(1), kmax = 1, annos = init)
}
}
Loading

0 comments on commit 9ab5d65

Please sign in to comment.