Skip to content

Commit

Permalink
rfq: policies track accepted htlcs
Browse files Browse the repository at this point in the history
  • Loading branch information
GeorgeTsagk committed Dec 4, 2024
1 parent 6ce4316 commit 1923ca2
Showing 1 changed file with 221 additions and 10 deletions.
231 changes: 221 additions & 10 deletions rfq/order.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import (
"github.com/lightninglabs/taproot-assets/fn"
"github.com/lightninglabs/taproot-assets/rfqmath"
"github.com/lightninglabs/taproot-assets/rfqmsg"
"github.com/lightningnetwork/lnd/channeldb/models"
"github.com/lightningnetwork/lnd/input"
"github.com/lightningnetwork/lnd/lnrpc/routerrpc"
"github.com/lightningnetwork/lnd/lnutils"
Expand Down Expand Up @@ -71,6 +72,14 @@ type Policy interface {
// which the policy applies.
Scid() uint64

// TrackAcceptedHtlc makes the policy aware of this new accepted HTLC.
// This is important in cases where the set of existing HTLCs may affect
// whether the next compliance check passes.
TrackAcceptedHtlc(circuitKey models.CircuitKey, amt lnwire.MilliSatoshi)

// UntrackHtlc stops tracking the uniquely identified HTLC.
UntrackHtlc(circuitKey models.CircuitKey)

// GenerateInterceptorResponse generates an interceptor response for the
// HTLC interceptor from the policy.
GenerateInterceptorResponse(
Expand All @@ -95,22 +104,38 @@ type AssetSalePolicy struct {
// the policy.
MaxOutboundAssetAmount uint64

// CurrentAssetAmountMsat is the total amount that is held currently in
// accepted HTLCs.
CurrentAmountMsat lnwire.MilliSatoshi

// stateMutex is a mutex that locks access to this policy's internal
// state. This is needed as state is updated asynchronously by each
// routine that handles an intercepted HTLC.
stateMutex sync.RWMutex

// AskAssetRate is the quote's asking asset unit to BTC conversion rate.
AskAssetRate rfqmath.BigIntFixedPoint

// htlcToAmt maps the unique HTLC identifiers to the effective amount
// that they carry.
htlcToAmt map[models.CircuitKey]lnwire.MilliSatoshi

// expiry is the policy's expiry unix timestamp after which the policy
// is no longer valid.
expiry uint64
}

// NewAssetSalePolicy creates a new asset sale policy.
func NewAssetSalePolicy(quote rfqmsg.BuyAccept) *AssetSalePolicy {
htlcToAmtMap := make(map[models.CircuitKey]lnwire.MilliSatoshi)

return &AssetSalePolicy{
AssetSpecifier: quote.Request.AssetSpecifier,
AcceptedQuoteId: quote.ID,
MaxOutboundAssetAmount: quote.Request.AssetMaxAmt,
AskAssetRate: quote.AssetRate.Rate,
expiry: uint64(quote.AssetRate.Expiry.Unix()),
htlcToAmt: htlcToAmtMap,
}
}

Expand All @@ -128,7 +153,7 @@ func (c *AssetSalePolicy) CheckHtlcCompliance(
// Check that the channel SCID is as expected.
htlcScid := SerialisedScid(htlc.OutgoingChannelID.ToUint64())
if htlcScid != c.AcceptedQuoteId.Scid() {
return fmt.Errorf("htlc outgoing channel ID does not match "+
return fmt.Errorf("HTLC outgoing channel ID does not match "+
"policy's SCID (htlc_scid=%d, policy_scid=%d)",
htlcScid, c.AcceptedQuoteId.Scid())
}
Expand All @@ -152,8 +177,13 @@ func (c *AssetSalePolicy) CheckHtlcCompliance(
maxAssetAmount, c.AskAssetRate,
)

if htlc.AmountOutMsat > policyMaxOutMsat {
return fmt.Errorf("htlc out amount is greater than the policy "+
// Since we will be reading CurrentAmountMsat value we acquire a read
// lock.
c.stateMutex.RLock()
defer c.stateMutex.RUnlock()

if (c.CurrentAmountMsat + htlc.AmountOutMsat) > policyMaxOutMsat {
return fmt.Errorf("HTLC out amount is greater than the policy "+
"maximum (htlc_out_msat=%d, policy_max_out_msat=%d)",
htlc.AmountOutMsat, policyMaxOutMsat)
}
Expand All @@ -167,6 +197,34 @@ func (c *AssetSalePolicy) CheckHtlcCompliance(
return nil
}

// TrackAcceptedHtlc accounts for the newly accepted HTLC. This may affect the
// acceptance of future HTLCs.
func (c *AssetSalePolicy) TrackAcceptedHtlc(circuitKey models.CircuitKey,
amt lnwire.MilliSatoshi) {

c.stateMutex.Lock()
defer c.stateMutex.Unlock()

c.CurrentAmountMsat += amt

c.htlcToAmt[circuitKey] = amt
}

// UntrackHtlc stops tracking the uniquely identified HTLC.
func (c *AssetSalePolicy) UntrackHtlc(circuitKey models.CircuitKey) {
c.stateMutex.Lock()
defer c.stateMutex.Unlock()

amt, found := c.htlcToAmt[circuitKey]
if !found {
return
}

delete(c.htlcToAmt, circuitKey)

c.CurrentAmountMsat -= amt
}

// Expiry returns the policy's expiry time as a unix timestamp.
func (c *AssetSalePolicy) Expiry() uint64 {
return c.expiry
Expand Down Expand Up @@ -246,26 +304,42 @@ type AssetPurchasePolicy struct {
// AcceptedQuoteId is the ID of the accepted quote.
AcceptedQuoteId rfqmsg.ID

// CurrentAssetAmountMsat is the total amount that is held currently in
// accepted HTLCs.
CurrentAmountMsat lnwire.MilliSatoshi

// stateMutex is a mutex that locks access to this policy's internal
// state. This is needed as state is updated asynchronously by each
// routine that handles an intercepted HTLC.
stateMutex sync.RWMutex

// BidAssetRate is the quote's asset to BTC conversion rate.
BidAssetRate rfqmath.BigIntFixedPoint

// PaymentMaxAmt is the maximum agreed BTC payment.
PaymentMaxAmt lnwire.MilliSatoshi

// htlcToAmt maps the unique HTLC identifiers to the effective amount
// that they carry.
htlcToAmt map[models.CircuitKey]lnwire.MilliSatoshi

// expiry is the policy's expiry unix timestamp in seconds after which
// the policy is no longer valid.
expiry uint64
}

// NewAssetPurchasePolicy creates a new asset purchase policy.
func NewAssetPurchasePolicy(quote rfqmsg.SellAccept) *AssetPurchasePolicy {
htlcToAmtMap := make(map[models.CircuitKey]lnwire.MilliSatoshi)

return &AssetPurchasePolicy{
scid: quote.ShortChannelId(),
AssetSpecifier: quote.Request.AssetSpecifier,
AcceptedQuoteId: quote.ID,
BidAssetRate: quote.AssetRate.Rate,
PaymentMaxAmt: quote.Request.PaymentMaxAmt,
expiry: uint64(quote.AssetRate.Expiry.Unix()),
htlcToAmt: htlcToAmtMap,
}
}

Expand All @@ -288,7 +362,7 @@ func (c *AssetPurchasePolicy) CheckHtlcCompliance(

if rfqID != c.AcceptedQuoteId {
return fmt.Errorf("HTLC contains a custom record, but it does "+
"not contain the accepted quote ID (htlc=%v, "+
"not contain the accepted quote ID (HTLC=%v, "+
"accepted_quote_id=%v)", htlc, c.AcceptedQuoteId)
}

Expand All @@ -313,17 +387,22 @@ func (c *AssetPurchasePolicy) CheckHtlcCompliance(
)

if inboundAmountMSat < htlc.AmountOutMsat {
return fmt.Errorf("htlc out amount is more than inbound "+
return fmt.Errorf("HTLC out amount is more than inbound "+
"asset amount in millisatoshis (htlc_out_msat=%d, "+
"inbound_asset_amount=%s, "+
"inbound_asset_amount_msat=%v)", htlc.AmountOutMsat,
assetAmt.String(), inboundAmountMSat)
}

// Since we will be reading CurrentAmountMsat value we acquire a read
// lock.
c.stateMutex.RLock()
defer c.stateMutex.RUnlock()

// Ensure that the outbound HTLC amount is less than the maximum agreed
// BTC payment.
if htlc.AmountOutMsat > c.PaymentMaxAmt {
return fmt.Errorf("htlc out amount is more than the maximum "+
if (c.CurrentAmountMsat + htlc.AmountOutMsat) > c.PaymentMaxAmt {
return fmt.Errorf("HTLC out amount is more than the maximum "+
"agreed BTC payment (htlc_out_msat=%d, "+
"payment_max_amt=%d)", htlc.AmountOutMsat,
c.PaymentMaxAmt)
Expand All @@ -338,6 +417,34 @@ func (c *AssetPurchasePolicy) CheckHtlcCompliance(
return nil
}

// TrackAcceptedHtlc accounts for the newly accepted HTLC. This may affect the
// acceptance of future HTLCs.
func (c *AssetPurchasePolicy) TrackAcceptedHtlc(circuitKey models.CircuitKey,
amt lnwire.MilliSatoshi) {

c.stateMutex.Lock()
defer c.stateMutex.Unlock()

c.CurrentAmountMsat += amt

c.htlcToAmt[circuitKey] = amt
}

// UntrackHtlc stops tracking the uniquely identified HTLC.
func (c *AssetPurchasePolicy) UntrackHtlc(circuitKey models.CircuitKey) {
c.stateMutex.Lock()
defer c.stateMutex.Unlock()

amt, found := c.htlcToAmt[circuitKey]
if !found {
return
}

delete(c.htlcToAmt, circuitKey)

c.CurrentAmountMsat -= amt
}

// Expiry returns the policy's expiry time as a unix timestamp in seconds.
func (c *AssetPurchasePolicy) Expiry() uint64 {
return c.expiry
Expand Down Expand Up @@ -436,6 +543,27 @@ func (a *AssetForwardPolicy) CheckHtlcCompliance(
return nil
}

// TrackAcceptedHtlc accounts for the newly accepted HTLC. This may affect the
// acceptance of future HTLCs.
func (a *AssetForwardPolicy) TrackAcceptedHtlc(circuitKey models.CircuitKey,
amt lnwire.MilliSatoshi) {

// Track accepted HTLC in the incoming policy.
a.incomingPolicy.TrackAcceptedHtlc(circuitKey, amt)

// Track accepted HTLC in the outgoing policy.
a.outgoingPolicy.TrackAcceptedHtlc(circuitKey, amt)
}

// UntrackHtlc stops tracking the uniquely identified HTLC.
func (a *AssetForwardPolicy) UntrackHtlc(circuitKey models.CircuitKey) {
// Untrack HTLC in the incoming policy.
a.incomingPolicy.UntrackHtlc(circuitKey)

// Untrack HTLC in the outgoing policy.
a.outgoingPolicy.UntrackHtlc(circuitKey)
}

// Expiry returns the policy's expiry time as a unix timestamp in seconds. The
// returned expiry time is the earliest expiry time of the incoming and outgoing
// policies.
Expand Down Expand Up @@ -514,6 +642,10 @@ type OrderHandlerCfg struct {

// AcceptHtlcEvents is a channel that receives accepted HTLCs.
AcceptHtlcEvents chan<- *AcceptHtlcEvent

// HtlcSubscriber is a subscriber that is used to retrieve live HTLC
// event updates.
HtlcSubscriber HtlcSubscriber
}

// OrderHandler orchestrates management of accepted quote bundles. It monitors
Expand All @@ -530,6 +662,11 @@ type OrderHandler struct {
// associated asset transaction policies.
policies lnutils.SyncMap[SerialisedScid, Policy]

// htlcToPolicy maps an HTLC circuit key to the policy that applies to
// it. We need this map because for failed HTLCs we don't have the RFQ
// data available, so we need to cache this info.
htlcToPolicy lnutils.SyncMap[models.CircuitKey, Policy]

// ContextGuard provides a wait group and main quit channel that can be
// used to create guarded contexts.
*fn.ContextGuard
Expand Down Expand Up @@ -586,13 +723,19 @@ func (h *OrderHandler) handleIncomingHtlc(_ context.Context,
err = policy.CheckHtlcCompliance(htlc)
if err != nil {
log.Warnf("HTLC does not comply with policy: %v "+
"(htlc=%v, policy=%v)", err, htlc, policy)
"(HTLC=%v, policy=%v)", err, htlc, policy)

return &lndclient.InterceptedHtlcResponse{
Action: lndclient.InterceptorActionFail,
}, nil
}

h.htlcToPolicy.Store(htlc.IncomingCircuitKey, policy)

// The HTLC passed the compliance checks, so now we keep track of the
// accepted HTLC.
policy.TrackAcceptedHtlc(htlc.IncomingCircuitKey, htlc.AmountOutMsat)

log.Debug("HTLC complies with policy. Broadcasting accept event.")
h.cfg.AcceptHtlcEvents <- NewAcceptHtlcEvent(htlc, policy)

Expand Down Expand Up @@ -640,12 +783,66 @@ func (h *OrderHandler) mainEventLoop() {
}
}

// subscribeHtlcs subscribes the OrderHandler to HTLC events provided by the lnd
// RPC interface. We use this subscription to track HTLC forwarding failures,
// which we use to performn a live update of our policies.
func (h *OrderHandler) subscribeHtlcs(ctx context.Context) error {
events, chErr, err := h.cfg.HtlcSubscriber.SubscribeHtlcEvents(ctx)
if err != nil {
return err
}

for {
select {
case event := <-events:
// We only care about forwarding events.
if event.GetEventType() != routerrpc.HtlcEvent_FORWARD {
continue
}

// Retrieve the two instances that may be relevant.
failEvent := event.GetForwardFailEvent()
linkFail := event.GetLinkFailEvent()

// Craft the circuit key that identifies this HTLC.
circuitKey := models.CircuitKey{
ChanID: lnwire.NewShortChanIDFromInt(
event.IncomingChannelId,
),
HtlcID: event.IncomingHtlcId,
}

switch {
case failEvent != nil:
fallthrough
case linkFail != nil:
// Fetch the policy that is related to this
// HTLC.
policy, found := h.htlcToPolicy.LoadAndDelete(
circuitKey,
)

if !found {
continue
}

// Stop tracking this HTLC as it failed.
policy.UntrackHtlc(circuitKey)
}

case err := <-chErr:
return err

case <-ctx.Done():
return ctx.Err()
}
}
}

// Start starts the service.
func (h *OrderHandler) Start() error {
var startErr error
h.startOnce.Do(func() {
log.Info("Starting subsystem: order handler")

// Start the main event loop in a separate goroutine.
h.Wg.Add(1)
go func() {
Expand All @@ -663,6 +860,20 @@ func (h *OrderHandler) Start() error {

h.mainEventLoop()
}()

// Start the HTLC event subscription loop.
h.Wg.Add(1)
go func() {
defer h.Wg.Done()

ctx, cancel := h.WithCtxQuitNoTimeout()
defer cancel()

err := h.subscribeHtlcs(ctx)
if err != nil {
log.Errorf("HTLC subscriber error: %v", err)
}
}()
})

return startErr
Expand Down

0 comments on commit 1923ca2

Please sign in to comment.