Skip to content

Commit

Permalink
Avoid default arguments in test helpers (#2108)
Browse files Browse the repository at this point in the history
A few of our test helper functions started using too many default arguments,
which makes them a bit messy to use in tests, especially when we want to
add new arguments.

We change this to use method overloading instead, which makes it easier to
add new overloads without changes to existing tests.
  • Loading branch information
t-bast committed Dec 15, 2021
1 parent 3d88c43 commit 8ff7dc7
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -228,18 +228,38 @@ trait ChannelStateTestsHelperMethods extends TestKitBase {
channelUpdateListener.expectMsgType[LocalChannelUpdate]
}

def localOrigin(replyTo: ActorRef): Origin.LocalHot = Origin.LocalHot(replyTo, UUID.randomUUID)
def localOrigin(replyTo: ActorRef): Origin.LocalHot = Origin.LocalHot(replyTo, UUID.randomUUID())

def makeCmdAdd(amount: MilliSatoshi, destination: PublicKey, currentBlockHeight: Long, paymentPreimage: ByteVector32 = randomBytes32(), upstream: Upstream = Upstream.Local(UUID.randomUUID), replyTo: ActorRef = TestProbe().ref): (ByteVector32, CMD_ADD_HTLC) = {
def makeCmdAdd(amount: MilliSatoshi, destination: PublicKey, currentBlockHeight: Long): (ByteVector32, CMD_ADD_HTLC) = {
makeCmdAdd(amount, CltvExpiryDelta(144), destination, randomBytes32(), currentBlockHeight, Upstream.Local(UUID.randomUUID()))
}

def makeCmdAdd(amount: MilliSatoshi, destination: PublicKey, currentBlockHeight: Long, paymentPreimage: ByteVector32): (ByteVector32, CMD_ADD_HTLC) = {
makeCmdAdd(amount, destination, currentBlockHeight, paymentPreimage, Upstream.Local(UUID.randomUUID()))
}

def makeCmdAdd(amount: MilliSatoshi, destination: PublicKey, currentBlockHeight: Long, paymentPreimage: ByteVector32, upstream: Upstream): (ByteVector32, CMD_ADD_HTLC) = {
makeCmdAdd(amount, CltvExpiryDelta(144), destination, paymentPreimage, currentBlockHeight, upstream)
}

def makeCmdAdd(amount: MilliSatoshi, cltvExpiryDelta: CltvExpiryDelta, destination: PublicKey, paymentPreimage: ByteVector32, currentBlockHeight: Long, upstream: Upstream, replyTo: ActorRef = TestProbe().ref): (ByteVector32, CMD_ADD_HTLC) = {
val paymentHash: ByteVector32 = Crypto.sha256(paymentPreimage)
val expiry = CltvExpiryDelta(144).toCltvExpiry(currentBlockHeight)
val expiry = cltvExpiryDelta.toCltvExpiry(currentBlockHeight)
val cmd = OutgoingPaymentPacket.buildCommand(replyTo, upstream, paymentHash, ChannelHop(null, destination, null) :: Nil, PaymentOnion.createSinglePartPayload(amount, expiry, randomBytes32()))._1.copy(commit = false)
(paymentPreimage, cmd)
}

def addHtlc(amount: MilliSatoshi, s: TestFSMRef[ChannelState, ChannelData, Channel], r: TestFSMRef[ChannelState, ChannelData, Channel], s2r: TestProbe, r2s: TestProbe, replyTo: ActorRef = TestProbe().ref): (ByteVector32, UpdateAddHtlc) = {
def addHtlc(amount: MilliSatoshi, s: TestFSMRef[ChannelState, ChannelData, Channel], r: TestFSMRef[ChannelState, ChannelData, Channel], s2r: TestProbe, r2s: TestProbe): (ByteVector32, UpdateAddHtlc) = {
addHtlc(amount, CltvExpiryDelta(144), s, r, s2r, r2s)
}

def addHtlc(amount: MilliSatoshi, s: TestFSMRef[ChannelState, ChannelData, Channel], r: TestFSMRef[ChannelState, ChannelData, Channel], s2r: TestProbe, r2s: TestProbe, replyTo: ActorRef): (ByteVector32, UpdateAddHtlc) = {
addHtlc(amount, CltvExpiryDelta(144), s, r, s2r, r2s, replyTo)
}

def addHtlc(amount: MilliSatoshi, cltvExpiryDelta: CltvExpiryDelta, s: TestFSMRef[ChannelState, ChannelData, Channel], r: TestFSMRef[ChannelState, ChannelData, Channel], s2r: TestProbe, r2s: TestProbe, replyTo: ActorRef = TestProbe().ref): (ByteVector32, UpdateAddHtlc) = {
val currentBlockHeight = s.underlyingActor.nodeParams.currentBlockHeight
val (payment_preimage, cmd) = makeCmdAdd(amount, r.underlyingActor.nodeParams.nodeId, currentBlockHeight, replyTo = replyTo)
val (payment_preimage, cmd) = makeCmdAdd(amount, cltvExpiryDelta, r.underlyingActor.nodeParams.nodeId, randomBytes32(), currentBlockHeight, Upstream.Local(UUID.randomUUID()), replyTo)
val htlc = addHtlc(cmd, s, r, s2r, r2s)
(payment_preimage, htlc)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ import fr.acinq.eclair.router.Router.ChannelHop
import fr.acinq.eclair.transactions.{DirectedHtlc, IncomingHtlc, OutgoingHtlc}
import fr.acinq.eclair.wire.internal.channel.ChannelCodecsSpec
import fr.acinq.eclair.wire.protocol._
import fr.acinq.eclair.{CltvExpiry, CltvExpiryDelta, CustomCommitmentsPlugin, MilliSatoshi, MilliSatoshiLong, NodeParams, TestConstants, TestKitBaseClass, TimestampMilli, TimestampMilliLong, randomBytes32}
import fr.acinq.eclair.{CltvExpiry, CltvExpiryDelta, CustomCommitmentsPlugin, MilliSatoshi, MilliSatoshiLong, NodeParams, TestConstants, TestKitBaseClass, TimestampMilliLong, randomBytes32}
import org.scalatest.funsuite.FixtureAnyFunSuiteLike
import org.scalatest.{Outcome, ParallelTestExecution}
import scodec.bits.ByteVector
Expand Down

0 comments on commit 8ff7dc7

Please sign in to comment.