Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add firstHopChannel/lastHopChannel parameters to findroute* API calls #1950

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions eclair-core/src/main/scala/fr/acinq/eclair/Eclair.scala
Original file line number Diff line number Diff line change
Expand Up @@ -122,9 +122,9 @@ trait Eclair {

def sendOnChain(address: String, amount: Satoshi, confirmationTarget: Long): Future[ByteVector32]

def findRoute(targetNodeId: PublicKey, amount: MilliSatoshi, pathFindingExperimentName_opt: Option[String], assistedRoutes: Seq[Seq[PaymentRequest.ExtraHop]] = Seq.empty)(implicit timeout: Timeout): Future[RouteResponse]
def findRoute(targetNodeId: PublicKey, amount: MilliSatoshi, pathFindingExperimentName_opt: Option[String], assistedRoutes: Seq[Seq[PaymentRequest.ExtraHop]] = Seq.empty, firstHopChannelId_opt: Option[ShortChannelId], lastHopChannelId_opt: Option[ShortChannelId])(implicit timeout: Timeout): Future[RouteResponse]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you really need to change both findRoute and findRouteBetween?
It seems to me that findRouteBetween is sufficient and is the API that better matches rebalancing needs, isn't it?


def findRouteBetween(sourceNodeId: PublicKey, targetNodeId: PublicKey, amount: MilliSatoshi, pathFindingExperimentName_opt: Option[String], assistedRoutes: Seq[Seq[PaymentRequest.ExtraHop]] = Seq.empty)(implicit timeout: Timeout): Future[RouteResponse]
def findRouteBetween(sourceNodeId: PublicKey, targetNodeId: PublicKey, amount: MilliSatoshi, pathFindingExperimentName_opt: Option[String], assistedRoutes: Seq[Seq[PaymentRequest.ExtraHop]] = Seq.empty, firstHopChannelId: Option[ShortChannelId], lastHopChannelId: Option[ShortChannelId])(implicit timeout: Timeout): Future[RouteResponse]

def sendToRoute(amount: MilliSatoshi, recipientAmount_opt: Option[MilliSatoshi], externalId_opt: Option[String], parentId_opt: Option[UUID], invoice: PaymentRequest, finalCltvExpiryDelta: CltvExpiryDelta, route: PredefinedRoute, trampolineSecret_opt: Option[ByteVector32] = None, trampolineFees_opt: Option[MilliSatoshi] = None, trampolineExpiryDelta_opt: Option[CltvExpiryDelta] = None, trampolineNodes_opt: Seq[PublicKey] = Nil)(implicit timeout: Timeout): Future[SendPaymentToRouteResponse]

Expand Down Expand Up @@ -279,8 +279,8 @@ class EclairImpl(appKit: Kit) extends Eclair with Logging {
}
}

override def findRoute(targetNodeId: PublicKey, amount: MilliSatoshi, pathFindingExperimentName_opt: Option[String], assistedRoutes: Seq[Seq[PaymentRequest.ExtraHop]] = Seq.empty)(implicit timeout: Timeout): Future[RouteResponse] =
findRouteBetween(appKit.nodeParams.nodeId, targetNodeId, amount, pathFindingExperimentName_opt, assistedRoutes)
override def findRoute(targetNodeId: PublicKey, amount: MilliSatoshi, pathFindingExperimentName_opt: Option[String], assistedRoutes: Seq[Seq[PaymentRequest.ExtraHop]] = Seq.empty, firstHopChannelId_opt: Option[ShortChannelId], lastHopChannelId_opt: Option[ShortChannelId])(implicit timeout: Timeout): Future[RouteResponse] =
findRouteBetween(appKit.nodeParams.nodeId, targetNodeId, amount, pathFindingExperimentName_opt, assistedRoutes, firstHopChannelId_opt, lastHopChannelId_opt)

private def getRouteParams(pathFindingExperimentName_opt: Option[String]): Option[RouteParams] = {
pathFindingExperimentName_opt match {
Expand All @@ -289,11 +289,11 @@ class EclairImpl(appKit: Kit) extends Eclair with Logging {
}
}

override def findRouteBetween(sourceNodeId: PublicKey, targetNodeId: PublicKey, amount: MilliSatoshi, pathFindingExperimentName_opt: Option[String], assistedRoutes: Seq[Seq[PaymentRequest.ExtraHop]] = Seq.empty)(implicit timeout: Timeout): Future[RouteResponse] = {
override def findRouteBetween(sourceNodeId: PublicKey, targetNodeId: PublicKey, amount: MilliSatoshi, pathFindingExperimentName_opt: Option[String], assistedRoutes: Seq[Seq[PaymentRequest.ExtraHop]] = Seq.empty, firstHopChannelId_opt: Option[ShortChannelId], lastHopChannelId_opt: Option[ShortChannelId])(implicit timeout: Timeout): Future[RouteResponse] = {
getRouteParams(pathFindingExperimentName_opt) match {
case Some(routeParams) =>
val maxFee = routeParams.getMaxFee(amount)
(appKit.router ? RouteRequest(sourceNodeId, targetNodeId, amount, maxFee, assistedRoutes, routeParams = routeParams)).mapTo[RouteResponse]
(appKit.router ? RouteRequest(sourceNodeId, targetNodeId, amount, maxFee, assistedRoutes, routeParams = routeParams, firstHopChannelId = firstHopChannelId_opt, lastHopChannelId = lastHopChannelId_opt)).mapTo[RouteResponse]
case None => Future.failed(new IllegalArgumentException(s"Path-finding experiment ${pathFindingExperimentName_opt.get} does not exist."))
}
}
Expand Down
41 changes: 34 additions & 7 deletions eclair-core/src/main/scala/fr/acinq/eclair/router/Graph.scala
Original file line number Diff line number Diff line change
Expand Up @@ -102,10 +102,12 @@ object Graph {
wr: WeightRatios,
currentBlockHeight: Long,
boundaries: RichWeight => Boolean,
includeLocalChannelCost: Boolean): Seq[WeightedPath] = {
includeLocalChannelCost: Boolean,
firstHopChannelId: Option[ShortChannelId] = None,
lastHopChannelId: Option[ShortChannelId] = None): Seq[WeightedPath] = {
// find the shortest path (k = 0)
val targetWeight = RichWeight(amount, 0, CltvExpiryDelta(0), 0)
val shortestPath = dijkstraShortestPath(graph, sourceNode, targetNode, ignoredEdges, ignoredVertices, extraEdges, targetWeight, boundaries, currentBlockHeight, wr, includeLocalChannelCost)
val shortestPath = dijkstraShortestPath(graph, sourceNode, targetNode, ignoredEdges, ignoredVertices, extraEdges, targetWeight, boundaries, currentBlockHeight, wr, includeLocalChannelCost, firstHopChannelId, lastHopChannelId)
if (shortestPath.isEmpty) {
return Seq.empty // if we can't even find a single path, avoid returning a Seq(Seq.empty)
}
Expand Down Expand Up @@ -144,7 +146,7 @@ object Graph {
val alreadyExploredVertices = rootPathEdges.map(_.desc.b).toSet
val rootPathWeight = pathWeight(sourceNode, rootPathEdges, amount, currentBlockHeight, wr, includeLocalChannelCost)
// find the "spur" path, a sub-path going from the spur node to the target avoiding previously found sub-paths
val spurPath = dijkstraShortestPath(graph, sourceNode, spurNode, ignoredEdges ++ alreadyExploredEdges, ignoredVertices ++ alreadyExploredVertices, extraEdges, rootPathWeight, boundaries, currentBlockHeight, wr, includeLocalChannelCost)
val spurPath = dijkstraShortestPath(graph, sourceNode, spurNode, ignoredEdges ++ alreadyExploredEdges, ignoredVertices ++ alreadyExploredVertices, extraEdges, rootPathWeight, boundaries, currentBlockHeight, wr, includeLocalChannelCost, firstHopChannelId, lastHopChannelId)
if (spurPath.nonEmpty) {
val completePath = spurPath ++ rootPathEdges
val candidatePath = WeightedPath(completePath, pathWeight(sourceNode, completePath, amount, currentBlockHeight, wr, includeLocalChannelCost))
Expand Down Expand Up @@ -192,7 +194,9 @@ object Graph {
boundaries: RichWeight => Boolean,
currentBlockHeight: Long,
wr: WeightRatios,
includeLocalChannelCost: Boolean): Seq[GraphEdge] = {
includeLocalChannelCost: Boolean,
firstHopChannelId: Option[ShortChannelId],
lastHopChannelId: Option[ShortChannelId]): Seq[GraphEdge] = {
Comment on lines +198 to +199
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please don't touch dijkstraShortestPath, there is no need. Instead search for a path from the end of firstHopChannelId to the start of lastHopChannelId and add our nodeId to ignoredVertices.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It turns out there's no need in changing Router at all. All above can be done over the API parameters (see #1969). I'm closing this PR.

// the graph does not contain source/destination nodes
val sourceNotInGraph = !g.containsVertex(sourceNode) && !extraEdges.exists(_.desc.a == sourceNode)
val targetNotInGraph = !g.containsVertex(targetNode) && !extraEdges.exists(_.desc.b == targetNode)
Expand Down Expand Up @@ -222,9 +226,32 @@ object Graph {
visitedNodes += current.key
// build the neighbors with optional extra edges
val neighborEdges = {
val extraNeighbors = extraEdges.filter(_.desc.b == current.key)
// the resulting set must have only one element per shortChannelId; we prioritize extra edges
g.getIncomingEdgesOf(current.key).filterNot(e => extraNeighbors.exists(_.desc.shortChannelId == e.desc.shortChannelId)) ++ extraNeighbors

def allEdges: Seq[GraphEdge] = {
val extraNeighbors = extraEdges.filter(_.desc.b == current.key)
// the resulting set must have only one element per shortChannelId; we prioritize extra edges
g.getIncomingEdgesOf(current.key).filterNot(e => extraNeighbors.exists(_.desc.shortChannelId == e.desc.shortChannelId)) ++ extraNeighbors
}

def lastHopEdges(edges: Seq[GraphEdge], lastHop: ShortChannelId) = {
if (current.key == targetNode)
edges.filter(e => e.desc.shortChannelId == lastHop)
else edges
}

def firstHopEdges(edges: Seq[GraphEdge], firstHop: ShortChannelId) = {
edges.filter(e => e.desc.a != sourceNode || e.desc.shortChannelId == firstHop)
}

(firstHopChannelId, lastHopChannelId) match {
case (None, None) => allEdges
case (Some(firstHop), Some(lastHop)) =>
firstHopEdges(lastHopEdges(allEdges, lastHop), firstHop)
case (Some(firstHop), None) =>
firstHopEdges(allEdges, firstHop)
case (None, Some(lastHop)) =>
lastHopEdges(allEdges, lastHop)
}
}
neighborEdges.foreach { edge =>
val neighbor = edge.desc.a
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ object RouteCalculation {
val result = if (r.allowMultiPart) {
findMultiPartRoute(d.graph, r.source, r.target, r.amount, r.maxFee, extraEdges, ignoredEdges, r.ignore.nodes, r.pendingPayments, params, currentBlockHeight)
} else {
findRoute(d.graph, r.source, r.target, r.amount, r.maxFee, routesToFind, extraEdges, ignoredEdges, r.ignore.nodes, params, currentBlockHeight)
findRoute(d.graph, r.source, r.target, r.amount, r.maxFee, routesToFind, extraEdges, ignoredEdges, r.ignore.nodes, params, currentBlockHeight, r.firstHopChannelId, r.lastHopChannelId)
}
result match {
case Success(routes) =>
Expand Down Expand Up @@ -209,8 +209,10 @@ object RouteCalculation {
ignoredEdges: Set[ChannelDesc] = Set.empty,
ignoredVertices: Set[PublicKey] = Set.empty,
routeParams: RouteParams,
currentBlockHeight: Long): Try[Seq[Route]] = Try {
findRouteInternal(g, localNodeId, targetNodeId, amount, maxFee, numRoutes, extraEdges, ignoredEdges, ignoredVertices, routeParams, currentBlockHeight) match {
currentBlockHeight: Long,
firstHopChannelId: Option[ShortChannelId] = None,
lastHopChannelId: Option[ShortChannelId] = None): Try[Seq[Route]] = Try {
findRouteInternal(g, localNodeId, targetNodeId, amount, maxFee, numRoutes, extraEdges, ignoredEdges, ignoredVertices, routeParams, currentBlockHeight, firstHopChannelId, lastHopChannelId) match {
case Right(routes) => routes.map(route => Route(amount, route.path.map(graphEdgeToHop)))
case Left(ex) => return Failure(ex)
}
Expand All @@ -227,7 +229,9 @@ object RouteCalculation {
ignoredEdges: Set[ChannelDesc] = Set.empty,
ignoredVertices: Set[PublicKey] = Set.empty,
routeParams: RouteParams,
currentBlockHeight: Long): Either[RouterException, Seq[Graph.WeightedPath]] = {
currentBlockHeight: Long,
firstHopChannelId: Option[ShortChannelId],
lastHopChannelId: Option[ShortChannelId]): Either[RouterException, Seq[Graph.WeightedPath]] = {
require(amount > 0.msat, "route amount must be strictly positive")

if (localNodeId == targetNodeId) return Left(CannotRouteToSelf)
Expand All @@ -240,7 +244,7 @@ object RouteCalculation {

val boundaries: RichWeight => Boolean = { weight => feeOk(weight.cost - amount) && lengthOk(weight.length) && cltvOk(weight.cltv) }

val foundRoutes: Seq[Graph.WeightedPath] = Graph.yenKshortestPaths(g, localNodeId, targetNodeId, amount, ignoredEdges, ignoredVertices, extraEdges, numRoutes, routeParams.ratios, currentBlockHeight, boundaries, routeParams.includeLocalChannelCost)
val foundRoutes: Seq[Graph.WeightedPath] = Graph.yenKshortestPaths(g, localNodeId, targetNodeId, amount, ignoredEdges, ignoredVertices, extraEdges, numRoutes, routeParams.ratios, currentBlockHeight, boundaries, routeParams.includeLocalChannelCost, firstHopChannelId, lastHopChannelId)
if (foundRoutes.nonEmpty) {
val (directRoutes, indirectRoutes) = foundRoutes.partition(_.path.length == 1)
val routes = if (routeParams.randomize) {
Expand All @@ -252,7 +256,7 @@ object RouteCalculation {
} else if (routeParams.maxRouteLength < ROUTE_MAX_LENGTH) {
// if not found within the constraints we relax and repeat the search
val relaxedRouteParams = routeParams.copy(maxRouteLength = ROUTE_MAX_LENGTH, maxCltv = DEFAULT_ROUTE_MAX_CLTV)
findRouteInternal(g, localNodeId, targetNodeId, amount, maxFee, numRoutes, extraEdges, ignoredEdges, ignoredVertices, relaxedRouteParams, currentBlockHeight)
findRouteInternal(g, localNodeId, targetNodeId, amount, maxFee, numRoutes, extraEdges, ignoredEdges, ignoredVertices, relaxedRouteParams, currentBlockHeight, firstHopChannelId, lastHopChannelId)
} else {
Left(RouteNotFound)
}
Expand Down Expand Up @@ -325,7 +329,7 @@ object RouteCalculation {
val minPartAmount = routeParams.mpp.minPartAmount.max(amount / numRoutes).min(amount)
routeParams.copy(mpp = MultiPartParams(minPartAmount, numRoutes))
}
findRouteInternal(g, localNodeId, targetNodeId, routeParams1.mpp.minPartAmount, maxFee, routeParams1.mpp.maxParts, extraEdges, ignoredEdges, ignoredVertices, routeParams1, currentBlockHeight) match {
findRouteInternal(g, localNodeId, targetNodeId, routeParams1.mpp.minPartAmount, maxFee, routeParams1.mpp.maxParts, extraEdges, ignoredEdges, ignoredVertices, routeParams1, currentBlockHeight, None, None) match {
case Right(routes) =>
// We use these shortest paths to find a set of non-conflicting HTLCs that send the total amount.
split(amount, mutable.Queue(routes: _*), initializeUsedCapacity(pendingHtlcs), routeParams1) match {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -484,7 +484,9 @@ object Router {
routeParams: RouteParams,
allowMultiPart: Boolean = false,
pendingPayments: Seq[Route] = Nil,
paymentContext: Option[PaymentContext] = None)
paymentContext: Option[PaymentContext] = None,
firstHopChannelId: Option[ShortChannelId] = None,
lastHopChannelId: Option[ShortChannelId] = None)

case class FinalizeRoute(amount: MilliSatoshi,
route: PredefinedRoute,
Expand Down
Loading