diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/Eclair.scala b/eclair-core/src/main/scala/fr/acinq/eclair/Eclair.scala index 61317a24b7..b94be2dfe1 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/Eclair.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/Eclair.scala @@ -122,9 +122,9 @@ trait Eclair { def sendOnChain(address: String, amount: Satoshi, confirmationTarget: Long): Future[ByteVector32] - def findRoute(targetNodeId: PublicKey, amount: MilliSatoshi, pathFindingExperimentName_opt: Option[String], assistedRoutes: Seq[Seq[PaymentRequest.ExtraHop]] = Seq.empty)(implicit timeout: Timeout): Future[RouteResponse] + def findRoute(targetNodeId: PublicKey, amount: MilliSatoshi, pathFindingExperimentName_opt: Option[String], assistedRoutes: Seq[Seq[PaymentRequest.ExtraHop]] = Seq.empty, firstHopChannelId_opt: Option[ShortChannelId], lastHopChannelId_opt: Option[ShortChannelId])(implicit timeout: Timeout): Future[RouteResponse] - def findRouteBetween(sourceNodeId: PublicKey, targetNodeId: PublicKey, amount: MilliSatoshi, pathFindingExperimentName_opt: Option[String], assistedRoutes: Seq[Seq[PaymentRequest.ExtraHop]] = Seq.empty)(implicit timeout: Timeout): Future[RouteResponse] + def findRouteBetween(sourceNodeId: PublicKey, targetNodeId: PublicKey, amount: MilliSatoshi, pathFindingExperimentName_opt: Option[String], assistedRoutes: Seq[Seq[PaymentRequest.ExtraHop]] = Seq.empty, firstHopChannelId: Option[ShortChannelId], lastHopChannelId: Option[ShortChannelId])(implicit timeout: Timeout): Future[RouteResponse] def sendToRoute(amount: MilliSatoshi, recipientAmount_opt: Option[MilliSatoshi], externalId_opt: Option[String], parentId_opt: Option[UUID], invoice: PaymentRequest, finalCltvExpiryDelta: CltvExpiryDelta, route: PredefinedRoute, trampolineSecret_opt: Option[ByteVector32] = None, trampolineFees_opt: Option[MilliSatoshi] = None, trampolineExpiryDelta_opt: Option[CltvExpiryDelta] = None, trampolineNodes_opt: Seq[PublicKey] = Nil)(implicit timeout: Timeout): Future[SendPaymentToRouteResponse] @@ -279,8 +279,8 @@ class EclairImpl(appKit: Kit) extends Eclair with Logging { } } - override def findRoute(targetNodeId: PublicKey, amount: MilliSatoshi, pathFindingExperimentName_opt: Option[String], assistedRoutes: Seq[Seq[PaymentRequest.ExtraHop]] = Seq.empty)(implicit timeout: Timeout): Future[RouteResponse] = - findRouteBetween(appKit.nodeParams.nodeId, targetNodeId, amount, pathFindingExperimentName_opt, assistedRoutes) + override def findRoute(targetNodeId: PublicKey, amount: MilliSatoshi, pathFindingExperimentName_opt: Option[String], assistedRoutes: Seq[Seq[PaymentRequest.ExtraHop]] = Seq.empty, firstHopChannelId_opt: Option[ShortChannelId], lastHopChannelId_opt: Option[ShortChannelId])(implicit timeout: Timeout): Future[RouteResponse] = + findRouteBetween(appKit.nodeParams.nodeId, targetNodeId, amount, pathFindingExperimentName_opt, assistedRoutes, firstHopChannelId_opt, lastHopChannelId_opt) private def getRouteParams(pathFindingExperimentName_opt: Option[String]): Option[RouteParams] = { pathFindingExperimentName_opt match { @@ -289,11 +289,11 @@ class EclairImpl(appKit: Kit) extends Eclair with Logging { } } - override def findRouteBetween(sourceNodeId: PublicKey, targetNodeId: PublicKey, amount: MilliSatoshi, pathFindingExperimentName_opt: Option[String], assistedRoutes: Seq[Seq[PaymentRequest.ExtraHop]] = Seq.empty)(implicit timeout: Timeout): Future[RouteResponse] = { + override def findRouteBetween(sourceNodeId: PublicKey, targetNodeId: PublicKey, amount: MilliSatoshi, pathFindingExperimentName_opt: Option[String], assistedRoutes: Seq[Seq[PaymentRequest.ExtraHop]] = Seq.empty, firstHopChannelId_opt: Option[ShortChannelId], lastHopChannelId_opt: Option[ShortChannelId])(implicit timeout: Timeout): Future[RouteResponse] = { getRouteParams(pathFindingExperimentName_opt) match { case Some(routeParams) => val maxFee = routeParams.getMaxFee(amount) - (appKit.router ? RouteRequest(sourceNodeId, targetNodeId, amount, maxFee, assistedRoutes, routeParams = routeParams)).mapTo[RouteResponse] + (appKit.router ? RouteRequest(sourceNodeId, targetNodeId, amount, maxFee, assistedRoutes, routeParams = routeParams, firstHopChannelId = firstHopChannelId_opt, lastHopChannelId = lastHopChannelId_opt)).mapTo[RouteResponse] case None => Future.failed(new IllegalArgumentException(s"Path-finding experiment ${pathFindingExperimentName_opt.get} does not exist.")) } } diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/router/Graph.scala b/eclair-core/src/main/scala/fr/acinq/eclair/router/Graph.scala index e4212f4599..64cf61cec6 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/router/Graph.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/router/Graph.scala @@ -102,10 +102,12 @@ object Graph { wr: WeightRatios, currentBlockHeight: Long, boundaries: RichWeight => Boolean, - includeLocalChannelCost: Boolean): Seq[WeightedPath] = { + includeLocalChannelCost: Boolean, + firstHopChannelId: Option[ShortChannelId] = None, + lastHopChannelId: Option[ShortChannelId] = None): Seq[WeightedPath] = { // find the shortest path (k = 0) val targetWeight = RichWeight(amount, 0, CltvExpiryDelta(0), 0) - val shortestPath = dijkstraShortestPath(graph, sourceNode, targetNode, ignoredEdges, ignoredVertices, extraEdges, targetWeight, boundaries, currentBlockHeight, wr, includeLocalChannelCost) + val shortestPath = dijkstraShortestPath(graph, sourceNode, targetNode, ignoredEdges, ignoredVertices, extraEdges, targetWeight, boundaries, currentBlockHeight, wr, includeLocalChannelCost, firstHopChannelId, lastHopChannelId) if (shortestPath.isEmpty) { return Seq.empty // if we can't even find a single path, avoid returning a Seq(Seq.empty) } @@ -144,7 +146,7 @@ object Graph { val alreadyExploredVertices = rootPathEdges.map(_.desc.b).toSet val rootPathWeight = pathWeight(sourceNode, rootPathEdges, amount, currentBlockHeight, wr, includeLocalChannelCost) // find the "spur" path, a sub-path going from the spur node to the target avoiding previously found sub-paths - val spurPath = dijkstraShortestPath(graph, sourceNode, spurNode, ignoredEdges ++ alreadyExploredEdges, ignoredVertices ++ alreadyExploredVertices, extraEdges, rootPathWeight, boundaries, currentBlockHeight, wr, includeLocalChannelCost) + val spurPath = dijkstraShortestPath(graph, sourceNode, spurNode, ignoredEdges ++ alreadyExploredEdges, ignoredVertices ++ alreadyExploredVertices, extraEdges, rootPathWeight, boundaries, currentBlockHeight, wr, includeLocalChannelCost, firstHopChannelId, lastHopChannelId) if (spurPath.nonEmpty) { val completePath = spurPath ++ rootPathEdges val candidatePath = WeightedPath(completePath, pathWeight(sourceNode, completePath, amount, currentBlockHeight, wr, includeLocalChannelCost)) @@ -192,7 +194,9 @@ object Graph { boundaries: RichWeight => Boolean, currentBlockHeight: Long, wr: WeightRatios, - includeLocalChannelCost: Boolean): Seq[GraphEdge] = { + includeLocalChannelCost: Boolean, + firstHopChannelId: Option[ShortChannelId], + lastHopChannelId: Option[ShortChannelId]): Seq[GraphEdge] = { // the graph does not contain source/destination nodes val sourceNotInGraph = !g.containsVertex(sourceNode) && !extraEdges.exists(_.desc.a == sourceNode) val targetNotInGraph = !g.containsVertex(targetNode) && !extraEdges.exists(_.desc.b == targetNode) @@ -222,9 +226,32 @@ object Graph { visitedNodes += current.key // build the neighbors with optional extra edges val neighborEdges = { - val extraNeighbors = extraEdges.filter(_.desc.b == current.key) - // the resulting set must have only one element per shortChannelId; we prioritize extra edges - g.getIncomingEdgesOf(current.key).filterNot(e => extraNeighbors.exists(_.desc.shortChannelId == e.desc.shortChannelId)) ++ extraNeighbors + + def allEdges: Seq[GraphEdge] = { + val extraNeighbors = extraEdges.filter(_.desc.b == current.key) + // the resulting set must have only one element per shortChannelId; we prioritize extra edges + g.getIncomingEdgesOf(current.key).filterNot(e => extraNeighbors.exists(_.desc.shortChannelId == e.desc.shortChannelId)) ++ extraNeighbors + } + + def lastHopEdges(edges: Seq[GraphEdge], lastHop: ShortChannelId) = { + if (current.key == targetNode) + edges.filter(e => e.desc.shortChannelId == lastHop) + else edges + } + + def firstHopEdges(edges: Seq[GraphEdge], firstHop: ShortChannelId) = { + edges.filter(e => e.desc.a != sourceNode || e.desc.shortChannelId == firstHop) + } + + (firstHopChannelId, lastHopChannelId) match { + case (None, None) => allEdges + case (Some(firstHop), Some(lastHop)) => + firstHopEdges(lastHopEdges(allEdges, lastHop), firstHop) + case (Some(firstHop), None) => + firstHopEdges(allEdges, firstHop) + case (None, Some(lastHop)) => + lastHopEdges(allEdges, lastHop) + } } neighborEdges.foreach { edge => val neighbor = edge.desc.a diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/router/RouteCalculation.scala b/eclair-core/src/main/scala/fr/acinq/eclair/router/RouteCalculation.scala index 63a02c5252..bd64504baa 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/router/RouteCalculation.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/router/RouteCalculation.scala @@ -115,7 +115,7 @@ object RouteCalculation { val result = if (r.allowMultiPart) { findMultiPartRoute(d.graph, r.source, r.target, r.amount, r.maxFee, extraEdges, ignoredEdges, r.ignore.nodes, r.pendingPayments, params, currentBlockHeight) } else { - findRoute(d.graph, r.source, r.target, r.amount, r.maxFee, routesToFind, extraEdges, ignoredEdges, r.ignore.nodes, params, currentBlockHeight) + findRoute(d.graph, r.source, r.target, r.amount, r.maxFee, routesToFind, extraEdges, ignoredEdges, r.ignore.nodes, params, currentBlockHeight, r.firstHopChannelId, r.lastHopChannelId) } result match { case Success(routes) => @@ -209,8 +209,10 @@ object RouteCalculation { ignoredEdges: Set[ChannelDesc] = Set.empty, ignoredVertices: Set[PublicKey] = Set.empty, routeParams: RouteParams, - currentBlockHeight: Long): Try[Seq[Route]] = Try { - findRouteInternal(g, localNodeId, targetNodeId, amount, maxFee, numRoutes, extraEdges, ignoredEdges, ignoredVertices, routeParams, currentBlockHeight) match { + currentBlockHeight: Long, + firstHopChannelId: Option[ShortChannelId] = None, + lastHopChannelId: Option[ShortChannelId] = None): Try[Seq[Route]] = Try { + findRouteInternal(g, localNodeId, targetNodeId, amount, maxFee, numRoutes, extraEdges, ignoredEdges, ignoredVertices, routeParams, currentBlockHeight, firstHopChannelId, lastHopChannelId) match { case Right(routes) => routes.map(route => Route(amount, route.path.map(graphEdgeToHop))) case Left(ex) => return Failure(ex) } @@ -227,7 +229,9 @@ object RouteCalculation { ignoredEdges: Set[ChannelDesc] = Set.empty, ignoredVertices: Set[PublicKey] = Set.empty, routeParams: RouteParams, - currentBlockHeight: Long): Either[RouterException, Seq[Graph.WeightedPath]] = { + currentBlockHeight: Long, + firstHopChannelId: Option[ShortChannelId], + lastHopChannelId: Option[ShortChannelId]): Either[RouterException, Seq[Graph.WeightedPath]] = { require(amount > 0.msat, "route amount must be strictly positive") if (localNodeId == targetNodeId) return Left(CannotRouteToSelf) @@ -240,7 +244,7 @@ object RouteCalculation { val boundaries: RichWeight => Boolean = { weight => feeOk(weight.cost - amount) && lengthOk(weight.length) && cltvOk(weight.cltv) } - val foundRoutes: Seq[Graph.WeightedPath] = Graph.yenKshortestPaths(g, localNodeId, targetNodeId, amount, ignoredEdges, ignoredVertices, extraEdges, numRoutes, routeParams.ratios, currentBlockHeight, boundaries, routeParams.includeLocalChannelCost) + val foundRoutes: Seq[Graph.WeightedPath] = Graph.yenKshortestPaths(g, localNodeId, targetNodeId, amount, ignoredEdges, ignoredVertices, extraEdges, numRoutes, routeParams.ratios, currentBlockHeight, boundaries, routeParams.includeLocalChannelCost, firstHopChannelId, lastHopChannelId) if (foundRoutes.nonEmpty) { val (directRoutes, indirectRoutes) = foundRoutes.partition(_.path.length == 1) val routes = if (routeParams.randomize) { @@ -252,7 +256,7 @@ object RouteCalculation { } else if (routeParams.maxRouteLength < ROUTE_MAX_LENGTH) { // if not found within the constraints we relax and repeat the search val relaxedRouteParams = routeParams.copy(maxRouteLength = ROUTE_MAX_LENGTH, maxCltv = DEFAULT_ROUTE_MAX_CLTV) - findRouteInternal(g, localNodeId, targetNodeId, amount, maxFee, numRoutes, extraEdges, ignoredEdges, ignoredVertices, relaxedRouteParams, currentBlockHeight) + findRouteInternal(g, localNodeId, targetNodeId, amount, maxFee, numRoutes, extraEdges, ignoredEdges, ignoredVertices, relaxedRouteParams, currentBlockHeight, firstHopChannelId, lastHopChannelId) } else { Left(RouteNotFound) } @@ -325,7 +329,7 @@ object RouteCalculation { val minPartAmount = routeParams.mpp.minPartAmount.max(amount / numRoutes).min(amount) routeParams.copy(mpp = MultiPartParams(minPartAmount, numRoutes)) } - findRouteInternal(g, localNodeId, targetNodeId, routeParams1.mpp.minPartAmount, maxFee, routeParams1.mpp.maxParts, extraEdges, ignoredEdges, ignoredVertices, routeParams1, currentBlockHeight) match { + findRouteInternal(g, localNodeId, targetNodeId, routeParams1.mpp.minPartAmount, maxFee, routeParams1.mpp.maxParts, extraEdges, ignoredEdges, ignoredVertices, routeParams1, currentBlockHeight, None, None) match { case Right(routes) => // We use these shortest paths to find a set of non-conflicting HTLCs that send the total amount. split(amount, mutable.Queue(routes: _*), initializeUsedCapacity(pendingHtlcs), routeParams1) match { diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/router/Router.scala b/eclair-core/src/main/scala/fr/acinq/eclair/router/Router.scala index 6ca3616658..3d24fc2721 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/router/Router.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/router/Router.scala @@ -484,7 +484,9 @@ object Router { routeParams: RouteParams, allowMultiPart: Boolean = false, pendingPayments: Seq[Route] = Nil, - paymentContext: Option[PaymentContext] = None) + paymentContext: Option[PaymentContext] = None, + firstHopChannelId: Option[ShortChannelId] = None, + lastHopChannelId: Option[ShortChannelId] = None) case class FinalizeRoute(amount: MilliSatoshi, route: PredefinedRoute, diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/router/RouteCalculationSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/router/RouteCalculationSpec.scala index 4ec2f47eee..c5044bb895 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/router/RouteCalculationSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/router/RouteCalculationSpec.scala @@ -134,6 +134,192 @@ class RouteCalculationSpec extends AnyFunSuite with ParallelTestExecution { } } + test("calculate the shortest path (with the first hop channel id)") { + val (a, b, c, d, e, f) = ( + PublicKey(hex"02999fa724ec3c244e4da52b4a91ad421dc96c9a810587849cd4b2469313519c73"), // a: source + PublicKey(hex"03f1cb1af20fe9ccda3ea128e27d7c39ee27375c8480f11a87c17197e97541ca6a"), + PublicKey(hex"0358e32d245ff5f5a3eb14c78c6f69c67cea7846bdf9aeeb7199e8f6fbb0306484"), + PublicKey(hex"029e059b6780f155f38e83601969919aae631ddf6faed58fe860c72225eb327d7c"), // d: target + PublicKey(hex"03864ef025fde8fb587d989186ce6a4a186895ee44a926bfc370e2c366597a3f8f"), + PublicKey(hex"020c65be6f9252e85ae2fe9a46eed892cb89565e2157730e78311b1621a0db4b22") + ) + + // note: we don't actually use floating point numbers + // cost(CD) = 10005 = amountMsat + 1 + (amountMsat * 400 / 1000000) + // cost(BC) = 10009,0015 = (cost(CD) + 1 + (cost(CD) * 300 / 1000000) + // cost(FD) = 10002 = amountMsat + 1 + (amountMsat * 100 / 1000000) + // cost(EF) = 10007,0008 = cost(FD) + 1 + (cost(FD) * 400 / 1000000) + // cost(AE) = 10007 -> A is source, shortest path found + // cost(AB) = 10009 + // + // The amounts that need to be sent through each edge are then: + // + // +--- A ---+ + // 10009,0015 msat | | 10007,0008 msat + // B E + // 10005 msat | | 10002 msat + // C F + // 10000 msat | | 10000 msat + // +--> D <--+ + + val amount = 10000 msat + val (ab, ae, bc, cd, df, fd) = ( + makeEdge(1L, a, b, feeBase = 1 msat, feeProportionalMillionth = 200, minHtlc = 0 msat), + makeEdge(4L, a, e, feeBase = 1 msat, feeProportionalMillionth = 200, minHtlc = 0 msat), + makeEdge(2L, b, c, feeBase = 1 msat, feeProportionalMillionth = 300, minHtlc = 0 msat), + makeEdge(3L, c, d, feeBase = 1 msat, feeProportionalMillionth = 400, minHtlc = 0 msat), + makeEdge(5L, e, f, feeBase = 1 msat, feeProportionalMillionth = 400, minHtlc = 0 msat), + makeEdge(6L, f, d, feeBase = 1 msat, feeProportionalMillionth = 100, minHtlc = 0 msat) + ) + val graph = DirectedGraph(List(ab, ae, bc, cd, df, fd)) + + { + // a route via C: A->B->C->D + val Success(route :: Nil) = findRoute(graph, a, d, amount, maxFee = 10 msat, numRoutes = 1, routeParams = DEFAULT_ROUTE_PARAMS, currentBlockHeight = 400000, firstHopChannelId = Some(ab.desc.shortChannelId)) + val weightedPath = Graph.pathWeight(a, route2Edges(route), amount, 0, NO_WEIGHT_RATIOS, false) + assert(route2Ids(route) === 1 :: 2 :: 3 :: Nil) + assert(weightedPath.length === 3) + assert(weightedPath.cost === 10009.msat) + } + + { + // a route via F: A->E->F->D + val Success(route :: Nil) = findRoute(graph, a, d, amount, maxFee = 10 msat, numRoutes = 1, routeParams = DEFAULT_ROUTE_PARAMS, currentBlockHeight = 400000, firstHopChannelId = Some(ae.desc.shortChannelId)) + val weightedPath = Graph.pathWeight(a, route2Edges(route), amount, 0, NO_WEIGHT_RATIOS, false) + assert(route2Ids(route) === 4 :: 5 :: 6 :: Nil) + assert(weightedPath.length === 3) + assert(weightedPath.cost === 10007.msat) + } + + // the route via C cannot be found because its fee is 9 > maxFee + assert(findRoute(graph, a, d, amount, maxFee = 7 msat, numRoutes = 1, routeParams = DEFAULT_ROUTE_PARAMS, currentBlockHeight = 400000, firstHopChannelId = Some(ab.desc.shortChannelId)) == Failure(RouteNotFound)) + } + + test("calculate the shortest path (with the last hop node id)") { + val (a, b, c, d, e, f) = ( + PublicKey(hex"02999fa724ec3c244e4da52b4a91ad421dc96c9a810587849cd4b2469313519c73"), // a: source + PublicKey(hex"03f1cb1af20fe9ccda3ea128e27d7c39ee27375c8480f11a87c17197e97541ca6a"), + PublicKey(hex"0358e32d245ff5f5a3eb14c78c6f69c67cea7846bdf9aeeb7199e8f6fbb0306484"), + PublicKey(hex"029e059b6780f155f38e83601969919aae631ddf6faed58fe860c72225eb327d7c"), // d: target + PublicKey(hex"03864ef025fde8fb587d989186ce6a4a186895ee44a926bfc370e2c366597a3f8f"), + PublicKey(hex"020c65be6f9252e85ae2fe9a46eed892cb89565e2157730e78311b1621a0db4b22") + ) + + // note: we don't actually use floating point numbers + // cost(CD) = 10005 = amountMsat + 1 + (amountMsat * 400 / 1000000) + // cost(BC) = 10009,0015 = (cost(CD) + 1 + (cost(CD) * 300 / 1000000) + // cost(FD) = 10002 = amountMsat + 1 + (amountMsat * 100 / 1000000) + // cost(EF) = 10007,0008 = cost(FD) + 1 + (cost(FD) * 400 / 1000000) + // cost(AE) = 10007 -> A is source, shortest path found + // cost(AB) = 10009 + // + // The amounts that need to be sent through each edge are then: + // + // +--- A ---+ + // 10009,0015 msat | | 10007,0008 msat + // B E + // 10005 msat | | 10002 msat + // C F + // 10000 msat | | 10000 msat + // +--> D <--+ + + // val amount = 8750 msat + val amount = 10000 msat + val (ab, ae, bc, cd, df, fd) = ( + makeEdge(1L, a, b, feeBase = 1 msat, feeProportionalMillionth = 200, minHtlc = 0 msat), + makeEdge(4L, a, e, feeBase = 1 msat, feeProportionalMillionth = 200, minHtlc = 0 msat), + makeEdge(2L, b, c, feeBase = 1 msat, feeProportionalMillionth = 300, minHtlc = 0 msat), + makeEdge(3L, c, d, feeBase = 1 msat, feeProportionalMillionth = 400, minHtlc = 0 msat), + makeEdge(5L, e, f, feeBase = 1 msat, feeProportionalMillionth = 400, minHtlc = 0 msat), + makeEdge(6L, f, d, feeBase = 1 msat, feeProportionalMillionth = 100, minHtlc = 0 msat) + ) + val graph = DirectedGraph(List(ab, ae, bc, cd, df, fd)) + + { + // a route via C: A->B->C->D + val Success(route :: Nil) = findRoute(graph, a, d, amount, maxFee = 10 msat, numRoutes = 1, routeParams = DEFAULT_ROUTE_PARAMS, currentBlockHeight = 400000, lastHopChannelId = Some(cd.desc.shortChannelId)) + val weightedPath = Graph.pathWeight(a, route2Edges(route), amount, 0, NO_WEIGHT_RATIOS, false) + assert(route2Ids(route) === 1 :: 2 :: 3 :: Nil) + assert(weightedPath.length === 3) + assert(weightedPath.cost === 10009.msat) + } + + { + // a route via F: A->E->F->D + val Success(route :: Nil) = findRoute(graph, a, d, amount, maxFee = 10 msat, numRoutes = 1, routeParams = DEFAULT_ROUTE_PARAMS, currentBlockHeight = 400000, lastHopChannelId = Some(fd.desc.shortChannelId)) + val weightedPath = Graph.pathWeight(a, route2Edges(route), amount, 0, NO_WEIGHT_RATIOS, false) + assert(route2Ids(route) === 4 :: 5 :: 6 :: Nil) + assert(weightedPath.length === 3) + assert(weightedPath.cost === 10007.msat) + } + + // the route via C cannot be found because its fee is 9 > maxFee + assert(findRoute(graph, a, d, amount, maxFee = 7 msat, numRoutes = 1, routeParams = DEFAULT_ROUTE_PARAMS, currentBlockHeight = 400000, lastHopChannelId = Some(cd.desc.shortChannelId)) == Failure(RouteNotFound)) + } + + test("calculate the shortest path (with the first hop channel id and the last hop node id)") { + val (a, b, c, d, e, f) = ( + PublicKey(hex"02999fa724ec3c244e4da52b4a91ad421dc96c9a810587849cd4b2469313519c73"), // a: source + PublicKey(hex"03f1cb1af20fe9ccda3ea128e27d7c39ee27375c8480f11a87c17197e97541ca6a"), + PublicKey(hex"0358e32d245ff5f5a3eb14c78c6f69c67cea7846bdf9aeeb7199e8f6fbb0306484"), + PublicKey(hex"029e059b6780f155f38e83601969919aae631ddf6faed58fe860c72225eb327d7c"), // d: target + PublicKey(hex"03864ef025fde8fb587d989186ce6a4a186895ee44a926bfc370e2c366597a3f8f"), + PublicKey(hex"020c65be6f9252e85ae2fe9a46eed892cb89565e2157730e78311b1621a0db4b22") + ) + + // note: we don't actually use floating point numbers + // cost(CD) = 10005 = amountMsat + 1 + (amountMsat * 400 / 1000000) + // cost(BC) = 10009,0015 = (cost(CD) + 1 + (cost(CD) * 300 / 1000000) + // cost(FD) = 10002 = amountMsat + 1 + (amountMsat * 100 / 1000000) + // cost(EF) = 10007,0008 = cost(FD) + 1 + (cost(FD) * 400 / 1000000) + // cost(AE) = 10007 -> A is source, shortest path found + // cost(AB) = 10009 + // + // The amounts that need to be sent through each edge are then: + // + // +--- A ---+ + // 10009,0015 msat | | 10007,0008 msat + // B E + // 10005 msat | | 10002 msat + // C F + // 10000 msat | | 10000 msat + // +--> D <--+ + + // val amount = 8750 msat + val amount = 10000 msat + val (ab, ae, bc, cd, df, fd) = ( + makeEdge(1L, a, b, feeBase = 1 msat, feeProportionalMillionth = 200, minHtlc = 0 msat), + makeEdge(4L, a, e, feeBase = 1 msat, feeProportionalMillionth = 200, minHtlc = 0 msat), + makeEdge(2L, b, c, feeBase = 1 msat, feeProportionalMillionth = 300, minHtlc = 0 msat), + makeEdge(3L, c, d, feeBase = 1 msat, feeProportionalMillionth = 400, minHtlc = 0 msat), + makeEdge(5L, e, f, feeBase = 1 msat, feeProportionalMillionth = 400, minHtlc = 0 msat), + makeEdge(6L, f, d, feeBase = 1 msat, feeProportionalMillionth = 100, minHtlc = 0 msat) + ) + val graph = DirectedGraph(List(ab, ae, bc, cd, df, fd)) + + { + // a route via C: A->B->C->D + val Success(route :: Nil) = findRoute(graph, a, d, amount, maxFee = 10 msat, numRoutes = 1, routeParams = DEFAULT_ROUTE_PARAMS, currentBlockHeight = 400000, firstHopChannelId = Some(ab.desc.shortChannelId), lastHopChannelId = Some(cd.desc.shortChannelId)) + val weightedPath = Graph.pathWeight(a, route2Edges(route), amount, 0, NO_WEIGHT_RATIOS, false) + assert(route2Ids(route) === 1 :: 2 :: 3 :: Nil) + assert(weightedPath.length === 3) + assert(weightedPath.cost === 10009.msat) + } + + assert(findRoute(graph, a, d, amount, maxFee = 10 msat, numRoutes = 1, routeParams = DEFAULT_ROUTE_PARAMS, currentBlockHeight = 400000, firstHopChannelId = Some(ab.desc.shortChannelId), lastHopChannelId = Some(fd.desc.shortChannelId)) == Failure(RouteNotFound)) + + { + // a route via C: A->E->F->D + val Success(route :: Nil) = findRoute(graph, a, d, amount, maxFee = 10 msat, numRoutes = 1, routeParams = DEFAULT_ROUTE_PARAMS, currentBlockHeight = 400000, firstHopChannelId = Some(ae.desc.shortChannelId), lastHopChannelId = Some(fd.desc.shortChannelId)) + val weightedPath = Graph.pathWeight(a, route2Edges(route), amount, 0, NO_WEIGHT_RATIOS, false) + assert(route2Ids(route) === 4 :: 5 :: 6 :: Nil) + assert(weightedPath.length === 3) + assert(weightedPath.cost === 10007.msat) + } + + assert(findRoute(graph, a, d, amount, maxFee = 10 msat, numRoutes = 1, routeParams = DEFAULT_ROUTE_PARAMS, currentBlockHeight = 400000, firstHopChannelId = Some(ae.desc.shortChannelId), lastHopChannelId = Some(cd.desc.shortChannelId)) == Failure(RouteNotFound)) + } + test("calculate route considering the direct channel pays no fees") { val g = DirectedGraph(List( makeEdge(1L, a, b, 5 msat, 0), // a -> b diff --git a/eclair-node/src/main/scala/fr/acinq/eclair/api/directives/ExtraDirectives.scala b/eclair-node/src/main/scala/fr/acinq/eclair/api/directives/ExtraDirectives.scala index bab2342a01..38a9a25717 100644 --- a/eclair-node/src/main/scala/fr/acinq/eclair/api/directives/ExtraDirectives.scala +++ b/eclair-node/src/main/scala/fr/acinq/eclair/api/directives/ExtraDirectives.scala @@ -47,6 +47,8 @@ trait ExtraDirectives extends Directives { val amountMsatFormParam: NameReceptacle[MilliSatoshi] = "amountMsat".as[MilliSatoshi] val invoiceFormParam: NameReceptacle[PaymentRequest] = "invoice".as[PaymentRequest] val routeFormat: NameUnmarshallerReceptacle[RouteFormat] = "format".as[RouteFormat](routeFormatUnmarshaller) + val firstHopChannelIdParam: NameUnmarshallerReceptacle[ShortChannelId] = "firstHopChannel".as[ShortChannelId](shortChannelIdUnmarshaller) + val lastHopChannelIdParam: NameUnmarshallerReceptacle[ShortChannelId] = "lastHopChannel".as[ShortChannelId](shortChannelIdUnmarshaller) // custom directive to fail with HTTP 404 (and JSON response) if the element was not found def completeOrNotFound[T](fut: Future[Option[T]])(implicit marshaller: ToResponseMarshaller[T]): Route = onComplete(fut) { diff --git a/eclair-node/src/main/scala/fr/acinq/eclair/api/handlers/PathFinding.scala b/eclair-node/src/main/scala/fr/acinq/eclair/api/handlers/PathFinding.scala index 874bc00a6c..e38a1a026d 100644 --- a/eclair-node/src/main/scala/fr/acinq/eclair/api/handlers/PathFinding.scala +++ b/eclair-node/src/main/scala/fr/acinq/eclair/api/handlers/PathFinding.scala @@ -33,11 +33,11 @@ trait PathFinding { private implicit def ec: ExecutionContext = actorSystem.dispatcher val findRoute: Route = postRequest("findroute") { implicit t => - formFields(invoiceFormParam, amountMsatFormParam.?, "pathFindingExperimentName".?, routeFormat.?) { - case (invoice@PaymentRequest(_, Some(amount), _, nodeId, _, _), None, pathFindingExperimentName_opt, routeFormat) => - complete(eclairApi.findRoute(nodeId, amount, pathFindingExperimentName_opt, invoice.routingInfo).map(r => RouteFormat.format(r, routeFormat))) - case (invoice, Some(overrideAmount), pathFindingExperimentName_opt, routeFormat) => - complete(eclairApi.findRoute(invoice.nodeId, overrideAmount, pathFindingExperimentName_opt, invoice.routingInfo).map(r => RouteFormat.format(r, routeFormat))) + formFields(invoiceFormParam, amountMsatFormParam.?, "pathFindingExperimentName".?, routeFormat.?, firstHopChannelIdParam.?, lastHopChannelIdParam.?) { + case (invoice@PaymentRequest(_, Some(amount), _, nodeId, _, _), None, pathFindingExperimentName_opt, routeFormat, firstHopChannel_opt, lastHopChannel_opt) => + complete(eclairApi.findRoute(nodeId, amount, pathFindingExperimentName_opt, invoice.routingInfo, firstHopChannel_opt, lastHopChannel_opt).map(r => RouteFormat.format(r, routeFormat))) + case (invoice, Some(overrideAmount), pathFindingExperimentName_opt, routeFormat, firstHopChannel_opt, lastHopChannel_opt) => + complete(eclairApi.findRoute(invoice.nodeId, overrideAmount, pathFindingExperimentName_opt, invoice.routingInfo, firstHopChannel_opt, lastHopChannel_opt).map(r => RouteFormat.format(r, routeFormat))) case _ => reject(MalformedFormFieldRejection( "invoice", "The invoice must have an amount or you need to specify one using 'amountMsat'" )) @@ -45,14 +45,16 @@ trait PathFinding { } val findRouteToNode: Route = postRequest("findroutetonode") { implicit t => - formFields(nodeIdFormParam, amountMsatFormParam, "pathFindingExperimentName".?, routeFormat.?) { (nodeId, amount, pathFindingExperimentName_opt, routeFormat) => - complete(eclairApi.findRoute(nodeId, amount, pathFindingExperimentName_opt).map(r => RouteFormat.format(r, routeFormat))) + formFields(nodeIdFormParam, amountMsatFormParam, "pathFindingExperimentName".?, routeFormat.?, firstHopChannelIdParam.?, lastHopChannelIdParam.?) { + (nodeId, amount, pathFindingExperimentName_opt, routeFormat, firstHopChannel_opt, lastHopChannel_opt) => + complete(eclairApi.findRoute(nodeId, amount, pathFindingExperimentName_opt, firstHopChannelId_opt = firstHopChannel_opt, lastHopChannelId_opt = lastHopChannel_opt).map(r => RouteFormat.format(r, routeFormat))) } } val findRouteBetweenNodes: Route = postRequest("findroutebetweennodes") { implicit t => - formFields("sourceNodeId".as[PublicKey], "targetNodeId".as[PublicKey], amountMsatFormParam, "pathFindingExperimentName".?, routeFormat.?) { (sourceNodeId, targetNodeId, amount, pathFindingExperimentName_opt, routeFormat) => - complete(eclairApi.findRouteBetween(sourceNodeId, targetNodeId, amount, pathFindingExperimentName_opt).map(r => RouteFormat.format(r, routeFormat))) + formFields("sourceNodeId".as[PublicKey], "targetNodeId".as[PublicKey], amountMsatFormParam, "pathFindingExperimentName".?, routeFormat.?, firstHopChannelIdParam.?, lastHopChannelIdParam.?) { + (sourceNodeId, targetNodeId, amount, pathFindingExperimentName_opt, routeFormat, firstHopChannel_opt, lastHopChannel_opt) => + complete(eclairApi.findRouteBetween(sourceNodeId, targetNodeId, amount, pathFindingExperimentName_opt, firstHopChannelId = firstHopChannel_opt, lastHopChannelId = lastHopChannel_opt).map(r => RouteFormat.format(r, routeFormat))) } } diff --git a/eclair-node/src/test/scala/fr/acinq/eclair/api/ApiServiceSpec.scala b/eclair-node/src/test/scala/fr/acinq/eclair/api/ApiServiceSpec.scala index 9223f9e960..9c249f5800 100644 --- a/eclair-node/src/test/scala/fr/acinq/eclair/api/ApiServiceSpec.scala +++ b/eclair-node/src/test/scala/fr/acinq/eclair/api/ApiServiceSpec.scala @@ -979,7 +979,7 @@ class ApiServiceSpec extends AnyFunSuite with ScalatestRouteTest with IdiomaticM val eclair = mock[Eclair] val mockService = new MockService(eclair) - eclair.findRoute(any, any, any, any)(any[Timeout]) returns Future.successful(Router.RouteResponse(Seq(Router.Route(456.msat, mockHops)))) + eclair.findRoute(any, any, any, any, any, any)(any[Timeout]) returns Future.successful(Router.RouteResponse(Seq(Router.Route(456.msat, mockHops)))) // invalid format Post("/findroute", FormData("format"-> "invalid-output-format", "invoice" -> invoice, "amountMsat" -> "456")) ~> @@ -989,7 +989,7 @@ class ApiServiceSpec extends AnyFunSuite with ScalatestRouteTest with IdiomaticM check { assert(handled) assert(status == BadRequest) - eclair.findRoute(PublicKey.fromBin(ByteVector.fromValidHex("036ded9bb8175d0c9fd3fad145965cf5005ec599570f35c682e710dc6001ff605e")), 456.msat, any, any)(any[Timeout]).wasNever(called) + eclair.findRoute(PublicKey.fromBin(ByteVector.fromValidHex("036ded9bb8175d0c9fd3fad145965cf5005ec599570f35c682e710dc6001ff605e")), 456.msat, any, any, any, any)(any[Timeout]).wasNever(called) } // default format @@ -1006,7 +1006,7 @@ class ApiServiceSpec extends AnyFunSuite with ScalatestRouteTest with IdiomaticM JString(mockHop3.nodeId.toString()), JString(mockHop3.nextNodeId.toString()) ))) - eclair.findRoute(PublicKey.fromBin(ByteVector.fromValidHex("036ded9bb8175d0c9fd3fad145965cf5005ec599570f35c682e710dc6001ff605e")), 456.msat, any, any)(any[Timeout]).wasCalled(once) + eclair.findRoute(PublicKey.fromBin(ByteVector.fromValidHex("036ded9bb8175d0c9fd3fad145965cf5005ec599570f35c682e710dc6001ff605e")), 456.msat, any, any, any, any)(any[Timeout]).wasCalled(once) } Post("/findroute", FormData("format" -> "nodeId", "invoice" -> invoice, "amountMsat" -> "456")) ~> @@ -1022,7 +1022,7 @@ class ApiServiceSpec extends AnyFunSuite with ScalatestRouteTest with IdiomaticM JString(mockHop3.nodeId.toString()), JString(mockHop3.nextNodeId.toString()) ))) - eclair.findRoute(PublicKey.fromBin(ByteVector.fromValidHex("036ded9bb8175d0c9fd3fad145965cf5005ec599570f35c682e710dc6001ff605e")), 456.msat, any, any)(any[Timeout]).wasCalled(twice) + eclair.findRoute(PublicKey.fromBin(ByteVector.fromValidHex("036ded9bb8175d0c9fd3fad145965cf5005ec599570f35c682e710dc6001ff605e")), 456.msat, any, any, any, any)(any[Timeout]).wasCalled(twice) } Post("/findroute", FormData("format" -> "shortChannelId", "invoice" -> invoice, "amountMsat" -> "456")) ~> @@ -1037,7 +1037,7 @@ class ApiServiceSpec extends AnyFunSuite with ScalatestRouteTest with IdiomaticM JString(mockHop2.lastUpdate.shortChannelId.toString()), JString(mockHop3.lastUpdate.shortChannelId.toString()) ))) - eclair.findRoute(PublicKey.fromBin(ByteVector.fromValidHex("036ded9bb8175d0c9fd3fad145965cf5005ec599570f35c682e710dc6001ff605e")), 456.msat, any, any)(any[Timeout]).wasCalled(threeTimes) + eclair.findRoute(PublicKey.fromBin(ByteVector.fromValidHex("036ded9bb8175d0c9fd3fad145965cf5005ec599570f35c682e710dc6001ff605e")), 456.msat, any, any, any, any)(any[Timeout]).wasCalled(threeTimes) } }