diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/wire/protocol/MessageOnion.scala b/eclair-core/src/main/scala/fr/acinq/eclair/wire/protocol/MessageOnion.scala index ecc28a86f5..594a2cde3d 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/wire/protocol/MessageOnion.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/wire/protocol/MessageOnion.scala @@ -122,11 +122,10 @@ object MessageOnionCodecs { def messageOnionPerHopPayloadCodec(isLastPacket: Boolean): Codec[PerHopPayload] = if (isLastPacket) finalPerHopPayloadCodec.upcast[PerHopPayload] else relayPerHopPayloadCodec.upcast[PerHopPayload] - val messageOnionPacketCodec: Codec[OnionRoutingPacket] = - (variableSizePrefixedBytes(uint16.xmap(_ - 66, _ + 66), - ("version" | uint8) ~ - ("publicKey" | bytes(33)), - ("onionPayload" | bytes)) ~ - ("hmac" | bytes32) flattenLeftPairs).as[OnionRoutingPacket] + val messageOnionPacketCodec: Codec[OnionRoutingPacket] = variableSizeBytes(uint16, bytes).exmap[OnionRoutingPacket]( + // The Sphinx packet header contains a version (1 byte), a public key (33 bytes) and a mac (32 bytes) -> total 66 bytes + bytes => OnionRoutingCodecs.onionRoutingPacketCodec(bytes.length.toInt - 66).decode(bytes.bits).map(_.value), + onion => OnionRoutingCodecs.onionRoutingPacketCodec(onion.payload.length.toInt).encode(onion).map(_.bytes) + ) } diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/wire/protocol/CommonCodecsSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/wire/protocol/CommonCodecsSpec.scala index 1050b395c5..05fa6eb687 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/wire/protocol/CommonCodecsSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/wire/protocol/CommonCodecsSpec.scala @@ -196,6 +196,57 @@ class CommonCodecsSpec extends AnyFunSuite { } } + test("encode/decode bytevector32") { + val testCases = Seq( + (hex"0000000000000000000000000000000000000000000000000000000000000000", Some(ByteVector32.Zeroes)), + (hex"0101010101010101010101010101010101010101010101010101010101010101", Some(ByteVector32(hex"0101010101010101010101010101010101010101010101010101010101010101"))), + (hex"ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff", Some(ByteVector32(hex"ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff"))), + // Ignore additional trailing bytes + (hex"000000000000000000000000000000000000000000000000000000000000000000", Some(ByteVector32.Zeroes)), + (hex"ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff00", Some(ByteVector32(hex"ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff"))), + // Not enough bytes + (hex"00000000000000000000000000000000000000000000000000000000000000", None), + (hex"", None) + ) + + for ((encoded, expected_opt) <- testCases) { + expected_opt match { + case Some(expected) => + val decoded = bytes32.decode(encoded.bits).require.value + assert(decoded === expected) + assert(expected.bytes === bytes32.encode(decoded).require.bytes) + case None => + assert(bytes32.decode(encoded.bits).isFailure) + } + } + } + + test("encode/decode bytevector64") { + val testCases = Seq( + (hex"00000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000", Some(ByteVector64.Zeroes)), + (hex"01010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101", Some(ByteVector64(hex"01010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101"))), + (hex"ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff", Some(ByteVector64(hex"ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff"))), + // Ignore additional trailing bytes + (hex"0000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000", Some(ByteVector64.Zeroes)), + (hex"ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff00", Some(ByteVector64(hex"ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff"))), + // Not enough bytes + (hex"000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000", None), + (hex"00000000000000000000000000000000000000000000000000000000000000", None), + (hex"", None) + ) + + for ((encoded, expected_opt) <- testCases) { + expected_opt match { + case Some(expected) => + val decoded = bytes64.decode(encoded.bits).require.value + assert(decoded === expected) + assert(expected.bytes === bytes64.encode(decoded).require.bytes) + case None => + assert(bytes64.decode(encoded.bits).isFailure) + } + } + } + test("encode/decode with private key codec") { val value = PrivateKey(randomBytes32()) val wire = privateKey.encode(value).require diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/wire/protocol/MessageOnionCodecsSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/wire/protocol/MessageOnionCodecsSpec.scala index dc65f958c4..3c88f2adda 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/wire/protocol/MessageOnionCodecsSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/wire/protocol/MessageOnionCodecsSpec.scala @@ -51,19 +51,39 @@ class MessageOnionCodecsSpec extends AnyFunSuiteLike { assert(finalPerHopPayloadCodec.decode(serialized.bits).require.value === payload) } - test("onion packet can be any size"){ + test("onion packet can be any size") { { // small onion - val onion = OnionRoutingPacket(1, hex"032c0b7cf95324a07d05398b240174dc0c2be444d96b159aa6c7f7b1e668680991", hex"012345679abcdef", ByteVector32(hex"0000111122223333444455556666777788889999aaaabbbbccccddddeeee0000")) - val serialized = hex"004a01032c0b7cf95324a07d05398b240174dc0c2be444d96b159aa6c7f7b1e6686809910012345679abcdef0000111122223333444455556666777788889999aaaabbbbccccddddeeee0000" + val onion = OnionRoutingPacket(1, hex"032c0b7cf95324a07d05398b240174dc0c2be444d96b159aa6c7f7b1e668680991", hex"0012345679abcdef", ByteVector32(hex"0000111122223333444455556666777788889999aaaabbbbccccddddeeee0000")) + val serialized = hex"004a 01 032c0b7cf95324a07d05398b240174dc0c2be444d96b159aa6c7f7b1e668680991 0012345679abcdef 0000111122223333444455556666777788889999aaaabbbbccccddddeeee0000" assert(messageOnionPacketCodec.encode(onion).require.bytes === serialized) assert(messageOnionPacketCodec.decode(serialized.bits).require.value === onion) } { // larger onion - val onion = OnionRoutingPacket(2, hex"027f31ebc5462c1fdce1b737ecff52d37d75dea43ce11c74d25aa297165faa2007", hex"012345679abcdef012345679abcdef012345679abcdef012345679abcdef012345679abcdef", ByteVector32(hex"eeee0000111122223333444455556666777788889999aaaabbbbccccddddeeee")) - val serialized = hex"006802027f31ebc5462c1fdce1b737ecff52d37d75dea43ce11c74d25aa297165faa20070012345679abcdef012345679abcdef012345679abcdef012345679abcdef012345679abcdefeeee0000111122223333444455556666777788889999aaaabbbbccccddddeeee" + val onion = OnionRoutingPacket(2, hex"027f31ebc5462c1fdce1b737ecff52d37d75dea43ce11c74d25aa297165faa2007", hex"0012345679abcdef012345679abcdef012345679abcdef012345679abcdef012345679abcdef", ByteVector32(hex"eeee0000111122223333444455556666777788889999aaaabbbbccccddddeeee")) + val serialized = hex"0068 02 027f31ebc5462c1fdce1b737ecff52d37d75dea43ce11c74d25aa297165faa2007 0012345679abcdef012345679abcdef012345679abcdef012345679abcdef012345679abcdef eeee0000111122223333444455556666777788889999aaaabbbbccccddddeeee" assert(messageOnionPacketCodec.encode(onion).require.bytes === serialized) assert(messageOnionPacketCodec.decode(serialized.bits).require.value === onion) } + { // onion with trailing additional bytes + val onion = OnionRoutingPacket(0, hex"032c0b7cf95324a07d05398b240174dc0c2be444d96b159aa6c7f7b1e668680991", hex"ffffffff", ByteVector32.Zeroes) + val serialized = hex"0046 00 032c0b7cf95324a07d05398b240174dc0c2be444d96b159aa6c7f7b1e668680991 ffffffff 0000000000000000000000000000000000000000000000000000000000000000 0a01020000030400000000" + assert(messageOnionPacketCodec.encode(onion).require.bytes === serialized.dropRight(11)) + assert(messageOnionPacketCodec.decode(serialized.bits).require.value === onion) + } + { // onion with empty payload + val onion = OnionRoutingPacket(0, hex"032c0b7cf95324a07d05398b240174dc0c2be444d96b159aa6c7f7b1e668680991", hex"", ByteVector32.Zeroes) + val serialized = hex"0042 00 032c0b7cf95324a07d05398b240174dc0c2be444d96b159aa6c7f7b1e668680991 0000000000000000000000000000000000000000000000000000000000000000" + assert(messageOnionPacketCodec.encode(onion).require.bytes === serialized) + assert(messageOnionPacketCodec.decode(serialized.bits).require.value === onion) + } + { // onion length too big + val serialized = hex"0048 00 032c0b7cf95324a07d05398b240174dc0c2be444d96b159aa6c7f7b1e668680991 ffffffff 0000000000000000000000000000000000000000000000000000000000000000" + assert(messageOnionPacketCodec.decode(serialized.bits).isFailure) + } + { // onion length way too big + val serialized = hex"00ff 00 032c0b7cf95324a07d05398b240174dc0c2be444d96b159aa6c7f7b1e668680991 ffffffff 0000000000000000000000000000000000000000000000000000000000000000" + assert(messageOnionPacketCodec.decode(serialized.bits).isFailure) + } } }