Skip to content

Commit

Permalink
Refactor sphinx payment packet (#2052)
Browse files Browse the repository at this point in the history
We previously created restrictions in Sphinx.scala to only allow using it
for two types of onions: a 1300 bytes one for HTLCs and a 400 bytes one
for trampoline.

This doesn't make sense anymore. The latest version of trampoline allows
any onion size, and onion messages also allow any onion size. The Sphinx
protocol doesn't care either about the size of the payload.

Another reason to remove it is that it wasn't working that well with
pattern matching because of type erasure.

So now the caller must explicitly set the length of the payload, which is
more flexible. Verifying that the correct length is used is deferred to
higher level components.
  • Loading branch information
t-bast committed Nov 5, 2021
1 parent 3dc4ae1 commit b45dd00
Show file tree
Hide file tree
Showing 10 changed files with 296 additions and 311 deletions.
292 changes: 135 additions & 157 deletions eclair-core/src/main/scala/fr/acinq/eclair/crypto/Sphinx.scala

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,9 @@ import fr.acinq.eclair.router.Router.{ChannelHop, Hop, NodeHop}
import fr.acinq.eclair.wire.protocol._
import fr.acinq.eclair.{CltvExpiry, CltvExpiryDelta, MilliSatoshi, UInt64, randomKey}
import scodec.bits.ByteVector
import scodec.{Attempt, DecodeResult}
import scodec.{Attempt, Codec, DecodeResult}

import java.util.UUID
import scala.reflect.ClassTag

/**
* Created by t-bast on 08/10/2019.
Expand All @@ -56,11 +55,11 @@ object IncomingPaymentPacket {

case class DecodedOnionPacket[T <: PaymentOnion.PacketType](payload: T, next: OnionRoutingPacket)

private[payment] def decryptOnion[T <: PaymentOnion.PacketType : ClassTag](paymentHash: ByteVector32, privateKey: PrivateKey)(packet: OnionRoutingPacket, packetType: Sphinx.OnionRoutingPacket[T])(implicit log: LoggingAdapter): Either[FailureMessage, DecodedOnionPacket[T]] =
packetType.peel(privateKey, paymentHash, packet) match {
private[payment] def decryptOnion[T <: PaymentOnion.PacketType](paymentHash: ByteVector32, privateKey: PrivateKey, packet: OnionRoutingPacket, perHopPayloadCodec: Boolean => Codec[T])(implicit log: LoggingAdapter): Either[FailureMessage, DecodedOnionPacket[T]] =
Sphinx.peel(privateKey, paymentHash, packet) match {
case Right(p@Sphinx.DecryptedPacket(payload, nextPacket, _)) =>
(PaymentOnionCodecs.perHopPayloadCodecByPacketType(packetType, p.isLastPacket).decode(payload.bits): @unchecked) match {
case Attempt.Successful(DecodeResult(perHopPayload: T, _)) => Right(DecodedOnionPacket(perHopPayload, nextPacket))
perHopPayloadCodec(p.isLastPacket).decode(payload.bits) match {
case Attempt.Successful(DecodeResult(perHopPayload, _)) => Right(DecodedOnionPacket(perHopPayload, nextPacket))
case Attempt.Failure(e: OnionRoutingCodecs.MissingRequiredTlv) => Left(e.failureMessage)
// Onion is correctly encrypted but the content of the per-hop payload couldn't be parsed.
// It's hard to provide tag and offset information from scodec failures, so we currently don't do it.
Expand All @@ -81,12 +80,12 @@ object IncomingPaymentPacket {
* @return whether the payment is to be relayed or if our node is the final recipient (or an error).
*/
def decrypt(add: UpdateAddHtlc, privateKey: PrivateKey)(implicit log: LoggingAdapter): Either[FailureMessage, IncomingPaymentPacket] = {
decryptOnion(add.paymentHash, privateKey)(add.onionRoutingPacket, Sphinx.PaymentPacket) match {
decryptOnion(add.paymentHash, privateKey, add.onionRoutingPacket, PaymentOnionCodecs.paymentOnionPerHopPayloadCodec) match {
case Left(failure) => Left(failure)
// NB: we don't validate the ChannelRelayPacket here because its fees and cltv depend on what channel we'll choose to use.
case Right(DecodedOnionPacket(payload: PaymentOnion.ChannelRelayPayload, next)) => Right(ChannelRelayPacket(add, payload, next))
case Right(DecodedOnionPacket(payload: PaymentOnion.FinalTlvPayload, _)) => payload.records.get[OnionPaymentPayloadTlv.TrampolineOnion] match {
case Some(OnionPaymentPayloadTlv.TrampolineOnion(trampolinePacket)) => decryptOnion(add.paymentHash, privateKey)(trampolinePacket, Sphinx.TrampolinePacket) match {
case Some(OnionPaymentPayloadTlv.TrampolineOnion(trampolinePacket)) => decryptOnion(add.paymentHash, privateKey, trampolinePacket, PaymentOnionCodecs.trampolineOnionPerHopPayloadCodec) match {
case Left(failure) => Left(failure)
case Right(DecodedOnionPacket(innerPayload: PaymentOnion.NodeRelayPayload, next)) => validateNodeRelay(add, payload, innerPayload, next)
case Right(DecodedOnionPacket(innerPayload: PaymentOnion.FinalPayload, _)) => validateFinal(add, payload, innerPayload)
Expand Down Expand Up @@ -140,7 +139,7 @@ object OutgoingPaymentPacket {
/**
* Build an encrypted onion packet from onion payloads and node public keys.
*/
def buildOnion[T <: PaymentOnion.PacketType](packetType: Sphinx.OnionRoutingPacket[T])(nodes: Seq[PublicKey], payloads: Seq[PaymentOnion.PerHopPayload], associatedData: ByteVector32): Sphinx.PacketAndSecrets = {
private def buildOnion(packetPayloadLength: Int, nodes: Seq[PublicKey], payloads: Seq[PaymentOnion.PerHopPayload], associatedData: ByteVector32): Sphinx.PacketAndSecrets = {
require(nodes.size == payloads.size)
val sessionKey = randomKey()
val payloadsBin: Seq[ByteVector] = payloads
Expand All @@ -153,7 +152,7 @@ object OutgoingPaymentPacket {
case Attempt.Successful(bitVector) => bitVector.bytes
case Attempt.Failure(cause) => throw new RuntimeException(s"serialization error: $cause")
}
packetType.create(sessionKey, nodes, payloadsBin, associatedData)
Sphinx.create(sessionKey, packetPayloadLength, nodes, payloadsBin, associatedData)
}

/**
Expand Down Expand Up @@ -187,14 +186,20 @@ object OutgoingPaymentPacket {
* - firstExpiry is the cltv expiry for the first htlc in the route
* - the onion to include in the HTLC
*/
def buildPacket[T <: PaymentOnion.PacketType](packetType: Sphinx.OnionRoutingPacket[T])(paymentHash: ByteVector32, hops: Seq[Hop], finalPayload: PaymentOnion.FinalPayload): (MilliSatoshi, CltvExpiry, Sphinx.PacketAndSecrets) = {
private def buildPacket(packetPayloadLength: Int, paymentHash: ByteVector32, hops: Seq[Hop], finalPayload: PaymentOnion.FinalPayload): (MilliSatoshi, CltvExpiry, Sphinx.PacketAndSecrets) = {
val (firstAmount, firstExpiry, payloads) = buildPayloads(hops.drop(1), finalPayload)
val nodes = hops.map(_.nextNodeId)
// BOLT 2 requires that associatedData == paymentHash
val onion = buildOnion(packetType)(nodes, payloads, paymentHash)
val onion = buildOnion(packetPayloadLength, nodes, payloads, paymentHash)
(firstAmount, firstExpiry, onion)
}

def buildPaymentPacket(paymentHash: ByteVector32, hops: Seq[Hop], finalPayload: PaymentOnion.FinalPayload): (MilliSatoshi, CltvExpiry, Sphinx.PacketAndSecrets) =
buildPacket(PaymentOnionCodecs.paymentOnionPayloadLength, paymentHash, hops, finalPayload)

def buildTrampolinePacket(paymentHash: ByteVector32, hops: Seq[Hop], finalPayload: PaymentOnion.FinalPayload): (MilliSatoshi, CltvExpiry, Sphinx.PacketAndSecrets) =
buildPacket(PaymentOnionCodecs.trampolineOnionPayloadLength, paymentHash, hops, finalPayload)

/**
* Build an encrypted trampoline onion packet when the final recipient doesn't support trampoline.
* The next-to-last trampoline node payload will contain instructions to convert to a legacy payment.
Expand All @@ -219,7 +224,7 @@ object OutgoingPaymentPacket {
(amount + hop.fee(amount), expiry + hop.cltvExpiryDelta, payload +: payloads)
}
val nodes = hops.map(_.nextNodeId)
val onion = buildOnion(Sphinx.TrampolinePacket)(nodes, payloads, invoice.paymentHash)
val onion = buildOnion(PaymentOnionCodecs.trampolineOnionPayloadLength, nodes, payloads, invoice.paymentHash)
(firstAmount, firstExpiry, onion)
}

Expand All @@ -240,12 +245,12 @@ object OutgoingPaymentPacket {
* @return the command and the onion shared secrets (used to decrypt the error in case of payment failure)
*/
def buildCommand(replyTo: ActorRef, upstream: Upstream, paymentHash: ByteVector32, hops: Seq[ChannelHop], finalPayload: PaymentOnion.FinalPayload): (CMD_ADD_HTLC, Seq[(ByteVector32, PublicKey)]) = {
val (firstAmount, firstExpiry, onion) = buildPacket(Sphinx.PaymentPacket)(paymentHash, hops, finalPayload)
val (firstAmount, firstExpiry, onion) = buildPaymentPacket(paymentHash, hops, finalPayload)
CMD_ADD_HTLC(replyTo, firstAmount, paymentHash, firstExpiry, onion.packet, Origin.Hot(replyTo, upstream), commit = true) -> onion.sharedSecrets
}

def buildHtlcFailure(nodeSecret: PrivateKey, reason: Either[ByteVector, FailureMessage], add: UpdateAddHtlc): Either[CannotExtractSharedSecret, ByteVector] = {
Sphinx.PaymentPacket.peel(nodeSecret, add.paymentHash, add.onionRoutingPacket) match {
Sphinx.peel(nodeSecret, add.paymentHash, add.onionRoutingPacket) match {
case Right(Sphinx.DecryptedPacket(_, _, sharedSecret)) =>
val encryptedReason = reason match {
case Left(forwarded) => Sphinx.FailurePacket.wrap(forwarded, sharedSecret)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ class PaymentInitiator(nodeParams: NodeParams, outgoingPaymentFactory: PaymentIn
}
// We assume that the trampoline node supports multi-part payments (it should).
val (trampolineAmount, trampolineExpiry, trampolineOnion) = if (r.paymentRequest.features.allowTrampoline) {
OutgoingPaymentPacket.buildPacket(Sphinx.TrampolinePacket)(r.paymentHash, trampolineRoute, finalPayload)
OutgoingPaymentPacket.buildTrampolinePacket(r.paymentHash, trampolineRoute, finalPayload)
} else {
OutgoingPaymentPacket.buildTrampolineToLegacyPacket(r.paymentRequest, trampolineRoute, finalPayload)
}
Expand Down Expand Up @@ -260,13 +260,13 @@ object PaymentInitiator {
blockUntilComplete: Boolean = false) extends SendRequestedPayment

/**
* @param recipientAmount amount that should be received by the final recipient.
* @param recipientNodeId id of the final recipient.
* @param paymentPreimage payment preimage.
* @param maxAttempts maximum number of retries.
* @param externalId (optional) externally-controlled identifier (to reconcile between application DB and eclair DB).
* @param routeParams (optional) parameters to fine-tune the routing algorithm.
* @param userCustomTlvs (optional) user-defined custom tlvs that will be added to the onion sent to the target node.
* @param recipientAmount amount that should be received by the final recipient.
* @param recipientNodeId id of the final recipient.
* @param paymentPreimage payment preimage.
* @param maxAttempts maximum number of retries.
* @param externalId (optional) externally-controlled identifier (to reconcile between application DB and eclair DB).
* @param routeParams (optional) parameters to fine-tune the routing algorithm.
* @param userCustomTlvs (optional) user-defined custom tlvs that will be added to the onion sent to the target node.
* @param recordPathFindingMetrics will be used to build [[SendPaymentConfig]].
*/
case class SendSpontaneousPayment(recipientAmount: MilliSatoshi,
Expand Down Expand Up @@ -339,22 +339,22 @@ object PaymentInitiator {
/**
* Configuration for an instance of a payment state machine.
*
* @param id id of the outgoing payment (mapped to a single outgoing HTLC).
* @param parentId id of the whole payment (if using multi-part, there will be N associated child payments,
* each with a different id).
* @param externalId externally-controlled identifier (to reconcile between application DB and eclair DB).
* @param paymentHash payment hash.
* @param recipientAmount amount that should be received by the final recipient (usually from a Bolt 11 invoice).
* @param recipientNodeId id of the final recipient.
* @param upstream information about the payment origin (to link upstream to downstream when relaying a payment).
* @param paymentRequest Bolt 11 invoice.
* @param storeInDb whether to store data in the payments DB (e.g. when we're relaying a trampoline payment, we
* don't want to store in the DB).
* @param publishEvent whether to publish a [[fr.acinq.eclair.payment.PaymentEvent]] on success/failure (e.g. for
* multi-part child payments, we don't want to emit events for each child, only for the whole payment).
* @param recordPathFindingMetrics We don't record metrics for payments that don't use path finding or that are a part of a bigger payment.
* @param additionalHops additional hops that the payment state machine isn't aware of (e.g. when using trampoline, hops
* that occur after the first trampoline node).
* @param id id of the outgoing payment (mapped to a single outgoing HTLC).
* @param parentId id of the whole payment (if using multi-part, there will be N associated child payments,
* each with a different id).
* @param externalId externally-controlled identifier (to reconcile between application DB and eclair DB).
* @param paymentHash payment hash.
* @param recipientAmount amount that should be received by the final recipient (usually from a Bolt 11 invoice).
* @param recipientNodeId id of the final recipient.
* @param upstream information about the payment origin (to link upstream to downstream when relaying a payment).
* @param paymentRequest Bolt 11 invoice.
* @param storeInDb whether to store data in the payments DB (e.g. when we're relaying a trampoline payment, we
* don't want to store in the DB).
* @param publishEvent whether to publish a [[fr.acinq.eclair.payment.PaymentEvent]] on success/failure (e.g. for
* multi-part child payments, we don't want to emit events for each child, only for the whole payment).
* @param recordPathFindingMetrics We don't record metrics for payments that don't use path finding or that are a part of a bigger payment.
* @param additionalHops additional hops that the payment state machine isn't aware of (e.g. when using trampoline, hops
* that occur after the first trampoline node).
*/
case class SendPaymentConfig(id: UUID,
parentId: UUID,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@ package fr.acinq.eclair.wire.protocol

import fr.acinq.bitcoin.ByteVector32
import fr.acinq.bitcoin.Crypto.PublicKey
import fr.acinq.eclair.crypto.Sphinx
import fr.acinq.eclair.payment.PaymentRequest
import fr.acinq.eclair.wire.protocol.CommonCodecs._
import fr.acinq.eclair.wire.protocol.OnionRoutingCodecs.MissingRequiredTlv
Expand Down Expand Up @@ -208,13 +207,16 @@ object PaymentOnion {
def records: TlvStream[OnionPaymentPayloadTlv]
}

/** Payment onion packet type (see [[fr.acinq.eclair.crypto.Sphinx.OnionRoutingPacket]]). */
/** Payment onion packet type. */
sealed trait PacketType

/** See [[fr.acinq.eclair.crypto.Sphinx.PaymentPacket]]. */
/** A payment onion packet is used when offering an HTLC to a remote node. */
sealed trait PaymentPacket extends PacketType

/** See [[fr.acinq.eclair.crypto.Sphinx.TrampolinePacket]]. */
/**
* A trampoline onion packet is used to defer route construction to trampoline nodes.
* It is usually embedded inside a [[PaymentPacket]] in the final node's payload.
*/
sealed trait TrampolinePacket extends PacketType

/** Per-hop payload from an HTLC's payment onion (after decryption and decoding). */
Expand Down Expand Up @@ -309,9 +311,10 @@ object PaymentOnionCodecs {
import scodec.codecs._
import scodec.{Attempt, Codec, DecodeResult, Decoder}

val paymentOnionPacketCodec: Codec[OnionRoutingPacket] = OnionRoutingCodecs.onionRoutingPacketCodec(Sphinx.PaymentPacket.PayloadLength)

val trampolineOnionPacketCodec: Codec[OnionRoutingPacket] = OnionRoutingCodecs.onionRoutingPacketCodec(Sphinx.TrampolinePacket.PayloadLength)
val paymentOnionPayloadLength = 1300
val trampolineOnionPayloadLength = 400
val paymentOnionPacketCodec: Codec[OnionRoutingPacket] = OnionRoutingCodecs.onionRoutingPacketCodec(paymentOnionPayloadLength)
val trampolineOnionPacketCodec: Codec[OnionRoutingPacket] = OnionRoutingCodecs.onionRoutingPacketCodec(trampolineOnionPayloadLength)

/**
* The 1.1 BOLT spec changed the payment onion frame format to use variable-length per-hop payloads.
Expand Down Expand Up @@ -396,9 +399,8 @@ object PaymentOnionCodecs {
case FinalTlvPayload(tlvs) => tlvs
})

def perHopPayloadCodecByPacketType[T <: PacketType](packetType: Sphinx.OnionRoutingPacket[T], isLastPacket: Boolean): Codec[PacketType] = packetType match {
case Sphinx.PaymentPacket => if (isLastPacket) finalPerHopPayloadCodec.upcast[PacketType] else channelRelayPerHopPayloadCodec.upcast[PacketType]
case Sphinx.TrampolinePacket => if (isLastPacket) finalPerHopPayloadCodec.upcast[PacketType] else nodeRelayPerHopPayloadCodec.upcast[PacketType]
}
def paymentOnionPerHopPayloadCodec(isLastPacket: Boolean): Codec[PaymentPacket] = if (isLastPacket) finalPerHopPayloadCodec.upcast[PaymentPacket] else channelRelayPerHopPayloadCodec.upcast[PaymentPacket]

def trampolineOnionPerHopPayloadCodec(isLastPacket: Boolean): Codec[TrampolinePacket] = if (isLastPacket) finalPerHopPayloadCodec.upcast[TrampolinePacket] else nodeRelayPerHopPayloadCodec.upcast[TrampolinePacket]

}
Original file line number Diff line number Diff line change
Expand Up @@ -1420,7 +1420,7 @@ class NormalStateSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike with
import f._
val (_, htlc) = addHtlc(150000000 msat, alice, bob, alice2bob, bob2alice)
crossSign(alice, bob, alice2bob, bob2alice)
bob ! CMD_FAIL_MALFORMED_HTLC(htlc.id, Sphinx.PaymentPacket.hash(htlc.onionRoutingPacket), FailureMessageCodecs.BADONION)
bob ! CMD_FAIL_MALFORMED_HTLC(htlc.id, Sphinx.hash(htlc.onionRoutingPacket), FailureMessageCodecs.BADONION)
val fail = bob2alice.expectMsgType[UpdateFailMalformedHtlc]
bob2alice.forward(alice)
bob ! CMD_SIGN()
Expand Down Expand Up @@ -1755,7 +1755,7 @@ class NormalStateSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike with

// actual test begins
val initialState = bob.stateData.asInstanceOf[DATA_NORMAL]
bob ! CMD_FAIL_MALFORMED_HTLC(htlc.id, Sphinx.PaymentPacket.hash(htlc.onionRoutingPacket), FailureMessageCodecs.BADONION)
bob ! CMD_FAIL_MALFORMED_HTLC(htlc.id, Sphinx.hash(htlc.onionRoutingPacket), FailureMessageCodecs.BADONION)
val fail = bob2alice.expectMsgType[UpdateFailMalformedHtlc]
awaitCond(bob.stateData == initialState.copy(
commitments = initialState.commitments.copy(
Expand Down Expand Up @@ -1834,7 +1834,7 @@ class NormalStateSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike with
crossSign(alice, bob, alice2bob, bob2alice)
// Bob fails the HTLC because he cannot parse it
val initialState = alice.stateData.asInstanceOf[DATA_NORMAL]
bob ! CMD_FAIL_MALFORMED_HTLC(htlc.id, Sphinx.PaymentPacket.hash(htlc.onionRoutingPacket), FailureMessageCodecs.BADONION)
bob ! CMD_FAIL_MALFORMED_HTLC(htlc.id, Sphinx.hash(htlc.onionRoutingPacket), FailureMessageCodecs.BADONION)
val fail = bob2alice.expectMsgType[UpdateFailMalformedHtlc]
bob2alice.forward(alice)

Expand All @@ -1860,7 +1860,7 @@ class NormalStateSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike with

// actual test begins
val tx = alice.stateData.asInstanceOf[DATA_NORMAL].commitments.localCommit.commitTxAndRemoteSig.commitTx.tx
val fail = UpdateFailMalformedHtlc(ByteVector32.Zeroes, htlc.id, Sphinx.PaymentPacket.hash(htlc.onionRoutingPacket), 42)
val fail = UpdateFailMalformedHtlc(ByteVector32.Zeroes, htlc.id, Sphinx.hash(htlc.onionRoutingPacket), 42)
alice ! fail
val error = alice2bob.expectMsgType[Error]
assert(new String(error.data.toArray) === InvalidFailureCode(ByteVector32.Zeroes).getMessage)
Expand Down
Loading

0 comments on commit b45dd00

Please sign in to comment.