diff --git a/feature/default_sets.go b/feature/default_sets.go index 4227691aa5c..28ecc5e53e1 100644 --- a/feature/default_sets.go +++ b/feature/default_sets.go @@ -49,6 +49,13 @@ var defaultSetDesc = setDesc{ SetInit: {}, // I SetNodeAnn: {}, // N }, + // Note: we set route blinding optionally in our init and announcement, + // but not yet in invoices (9) as the spec instructs because we do not + // yet support receiving payments to blinded routes, only relaying them. + lnwire.RouteBlindingOptional: { + SetInit: {}, // I + SetNodeAnn: {}, // N + }, lnwire.WumboChannelsOptional: { SetInit: {}, // I SetNodeAnn: {}, // N diff --git a/htlcswitch/hop/forwarding_info.go b/htlcswitch/hop/forwarding_info.go index 3ec358a0acb..484681b9d3c 100644 --- a/htlcswitch/hop/forwarding_info.go +++ b/htlcswitch/hop/forwarding_info.go @@ -1,6 +1,7 @@ package hop import ( + "github.com/btcsuite/btcd/btcec/v2" "github.com/lightningnetwork/lnd/lnwire" ) @@ -22,4 +23,9 @@ type ForwardingInfo struct { // OutgoingCTLV is the specified value of the CTLV timelock to be used // in the outgoing HTLC. OutgoingCTLV uint32 + + // NextBlinding is an optional blinding point to be passed to the next + // node in UpdateAddHtlc. This field is set if the htlc is part of a + // blinded route. + NextBlinding *btcec.PublicKey } diff --git a/htlcswitch/hop/iterator.go b/htlcswitch/hop/iterator.go index 14f11c56730..14c57675262 100644 --- a/htlcswitch/hop/iterator.go +++ b/htlcswitch/hop/iterator.go @@ -9,6 +9,7 @@ import ( "github.com/btcsuite/btcd/btcec/v2" sphinx "github.com/lightningnetwork/lightning-onion" "github.com/lightningnetwork/lnd/lnwire" + "github.com/lightningnetwork/lnd/record" ) // Iterator is an interface that abstracts away the routing information @@ -47,16 +48,24 @@ type sphinxHopIterator struct { // includes the information required to properly forward the packet to // the next hop. processedPacket *sphinx.ProcessedPacket + + // blindingKit contains the elements required to process hops that are + // part of a blinded route. + blindingKit *BlindingKit } // makeSphinxHopIterator converts a processed packet returned from a sphinx -// router and converts it into an hop iterator for usage in the link. +// router and converts it into an hop iterator for usage in the link. A +// blinding kit is passed through for the link to obtain forwarding information +// for blinded routes. func makeSphinxHopIterator(ogPacket *sphinx.OnionPacket, - packet *sphinx.ProcessedPacket) *sphinxHopIterator { + packet *sphinx.ProcessedPacket, + blindingKit *BlindingKit) *sphinxHopIterator { return &sphinxHopIterator{ ogPacket: ogPacket, processedPacket: packet, + blindingKit: blindingKit, } } @@ -92,7 +101,7 @@ func (r *sphinxHopIterator) HopPayload() (*Payload, error) { case sphinx.PayloadTLV: return NewPayloadFromReader(bytes.NewReader( r.processedPacket.Payload.Payload, - )) + ), r.blindingKit) default: return nil, fmt.Errorf("unknown sphinx payload type: %v", @@ -112,6 +121,166 @@ func (r *sphinxHopIterator) ExtractErrorEncrypter( return extracter(r.ogPacket.EphemeralKey) } +// BlindingProcessor is an interface that provides the cryptographic operations +// required for processing blinded hops. +type BlindingProcessor interface { + // UnBlindData decrypts a blinded blob of data using the ephemeral key + // provided. + UnBlindData(ephemPub *btcec.PublicKey, + encryptedData []byte) ([]byte, error) + + // NextEphemeral returns the next hop's ephemeral key, calculated + // from the current ephemeral key provided. + NextEphemeral(*btcec.PublicKey) (*btcec.PublicKey, error) +} + +// BlindingKit contains the components required to extract forwarding +// information for hops in a blinded route. +type BlindingKit struct { + // BlindingPoint holds a blinding point that was passed to the node via + // update_add_htlc's TLVs. + BlindingPoint *btcec.PublicKey + + // lastHop indicates whether we're in the last hop in the onion route. + lastHop bool + + // forwardingInfo uses the ephemeral blinding key provided to decrypt + // a blob of encrypted data provided in the onion and obtain the + // forwarding information for the blinded hop. + forwardingInfo func(*btcec.PublicKey, []byte) (*ForwardingInfo, + error) +} + +// MakeBlindingKit produces a kit that is used to decrypte and decode +// forwarding information for hops in blinded routes. +func MakeBlindingKit(processor BlindingProcessor, + blindingPoint *btcec.PublicKey, lastHop bool, + incomingAmount lnwire.MilliSatoshi, incomingCltv uint32) *BlindingKit { + + return &BlindingKit{ + BlindingPoint: blindingPoint, + lastHop: lastHop, + forwardingInfo: deriveForwardingInfo( + processor, incomingAmount, incomingCltv, + ), + } +} + +// deriveForwardingInfo produces a function that will decrypt and deserialize +// an encrypted blob of data for a hop in a blinded route and reconstruct the +// forwarding information for the hop from the information provided. +func deriveForwardingInfo(processor BlindingProcessor, + incomingAmount lnwire.MilliSatoshi, incomingCltv uint32) func( + *btcec.PublicKey, []byte) (*ForwardingInfo, error) { + + return func(blinding *btcec.PublicKey, data []byte) (*ForwardingInfo, + error) { + + decrypted, err := processor.UnBlindData(blinding, data) + if err != nil { + return nil, fmt.Errorf("decrypt blinded data: %w", err) + } + + b := bytes.NewBuffer(decrypted) + routeData, err := record.DecodeBlindedRouteData(b) + if err != nil { + return nil, fmt.Errorf("decode route data: %w", err) + } + + if err := validateBlindedRouteData( + routeData, incomingAmount, incomingCltv, + ); err != nil { + return nil, err + } + // If we have our short channel ID or expiry present, set + // values in our forwarding information. We start with the + // incoming values as defaults so that they will have the + // correct values for the final hop in the blinded route + // (which does not have relay info set). + var ( + nextHop = Exit + expiry = incomingCltv + fwdAmt = incomingAmount + ) + + if routeData.ShortChannelID != nil { + nextHop = *routeData.ShortChannelID + } + + if routeData.RelayInfo != nil { + fwdAmt, err = calculateForwardingAmount( + incomingAmount, routeData.RelayInfo.BaseFee, + routeData.RelayInfo.FeeRate, + ) + if err != nil { + return nil, err + } + + expiry = incomingCltv - uint32( + routeData.RelayInfo.CltvExpiryDelta, + ) + } + + nextEph, err := processor.NextEphemeral(blinding) + if err != nil { + return nil, err + } + + return &ForwardingInfo{ + Network: BitcoinNetwork, + NextHop: nextHop, + AmountToForward: fwdAmt, + OutgoingCTLV: expiry, + NextBlinding: nextEph, + }, nil + } +} + +// calculateForwardingAmount calculates the amount to forward for a blinded +// hop based on the incoming amount and forwarding parameters. +// +// When forwarding a payment, the fee we take is calculated, not on the +// incoming amount, but rather on the amount we forward. We charge fees based +// on our own liquidity we are forwarding downstream. +// +// With route blinding, we are NOT given the amount to forward. This +// unintuitive looking formula comes from the fact that without the amount to +// forward, we cannot compute the fees taken directly. +// +// The amount to be forwarded can be computed as follows: +// +// amt_to_forward = incoming_amount - total_fees //nolint:dupword +// total_fees = base_fee + amt_to_forward*(fee_rate/1000000) +// +// After substitution and some massaging you will get: +// +// amt_to_forward = (incoming_amount - base_fee) / +// ( 1 + fee_rate / 1000000 ) +// +// From there we use a ceiling formula for integer division so that we always +// round up, otherwise the sender may receive slightly less than intended: +// +// ceil(a/b) = (a + b - 1)/(b) +func calculateForwardingAmount(incomingAmount lnwire.MilliSatoshi, baseFee, + proportionalFee uint32) (lnwire.MilliSatoshi, error) { + + // proportionalParts is the number of parts that our proportional fee + // is expressed per. + var proportionalParts uint64 = 1_000_000 + + // Sanity check to prevent overflow. + if incomingAmount < lnwire.MilliSatoshi(baseFee) { + return 0, fmt.Errorf("incoming amount: %v < base fee: %v", + incomingAmount, baseFee) + } + + ceiling := ((uint64(incomingAmount) - uint64(baseFee)) + + (1 + uint64(proportionalFee)/proportionalParts) - 1) / + (1 + uint64(proportionalFee)/proportionalParts) + + return lnwire.MilliSatoshi(ceiling), nil +} + // OnionProcessor is responsible for keeping all sphinx dependent parts inside // and expose only decoding function. With such approach we give freedom for // subsystems which wants to decode sphinx path to not be dependable from @@ -146,53 +315,9 @@ func (p *OnionProcessor) Stop() error { return nil } -// DecodeHopIterator attempts to decode a valid sphinx packet from the passed io.Reader -// instance using the rHash as the associated data when checking the relevant -// MACs during the decoding process. -func (p *OnionProcessor) DecodeHopIterator(r io.Reader, rHash []byte, - incomingCltv uint32) (Iterator, lnwire.FailCode) { - - onionPkt := &sphinx.OnionPacket{} - if err := onionPkt.Decode(r); err != nil { - switch err { - case sphinx.ErrInvalidOnionVersion: - return nil, lnwire.CodeInvalidOnionVersion - case sphinx.ErrInvalidOnionKey: - return nil, lnwire.CodeInvalidOnionKey - default: - log.Errorf("unable to decode onion packet: %v", err) - return nil, lnwire.CodeInvalidOnionKey - } - } - - // Attempt to process the Sphinx packet. We include the payment hash of - // the HTLC as it's authenticated within the Sphinx packet itself as - // associated data in order to thwart attempts a replay attacks. In the - // case of a replay, an attacker is *forced* to use the same payment - // hash twice, thereby losing their money entirely. - sphinxPacket, err := p.router.ProcessOnionPacket( - onionPkt, rHash, incomingCltv, - ) - if err != nil { - switch err { - case sphinx.ErrInvalidOnionVersion: - return nil, lnwire.CodeInvalidOnionVersion - case sphinx.ErrInvalidOnionHMAC: - return nil, lnwire.CodeInvalidOnionHmac - case sphinx.ErrInvalidOnionKey: - return nil, lnwire.CodeInvalidOnionKey - default: - log.Errorf("unable to process onion packet: %v", err) - return nil, lnwire.CodeInvalidOnionKey - } - } - - return makeSphinxHopIterator(onionPkt, sphinxPacket), lnwire.CodeNone -} - -// ReconstructHopIterator attempts to decode a valid sphinx packet from the passed io.Reader -// instance using the rHash as the associated data when checking the relevant -// MACs during the decoding process. +// ReconstructHopIterator attempts to decode a valid sphinx packet from the +// passed io.Reader instance using the rHash as the associated data when +// checking the relevant MACs during the decoding process. func (p *OnionProcessor) ReconstructHopIterator(r io.Reader, rHash []byte) ( Iterator, error) { @@ -206,21 +331,28 @@ func (p *OnionProcessor) ReconstructHopIterator(r io.Reader, rHash []byte) ( // associated data in order to thwart attempts a replay attacks. In the // case of a replay, an attacker is *forced* to use the same payment // hash twice, thereby losing their money entirely. - sphinxPacket, err := p.router.ReconstructOnionPacket(onionPkt, rHash) + // + // TODO(carla): contract court will need to be able to pass the + // blinding point back in here (requires interface update). + sphinxPacket, err := p.router.ReconstructOnionPacket( + onionPkt, rHash, // Blinding Opts + ) if err != nil { return nil, err } - return makeSphinxHopIterator(onionPkt, sphinxPacket), nil + return makeSphinxHopIterator(onionPkt, sphinxPacket, nil), nil } // DecodeHopIteratorRequest encapsulates all date necessary to process an onion // packet, perform sphinx replay detection, and schedule the entry for garbage // collection. type DecodeHopIteratorRequest struct { - OnionReader io.Reader - RHash []byte - IncomingCltv uint32 + OnionReader io.Reader + RHash []byte + IncomingCltv uint32 + IncomingAmount lnwire.MilliSatoshi + BlindingPoint *btcec.PublicKey } // DecodeHopIteratorResponse encapsulates the outcome of a batched sphinx onion @@ -370,7 +502,15 @@ func (p *OnionProcessor) DecodeHopIterators(id []byte, // Finally, construct a hop iterator from our processed sphinx // packet, simultaneously caching the original onion packet. - resp.HopIterator = makeSphinxHopIterator(&onionPkts[i], &packets[i]) + resp.HopIterator = makeSphinxHopIterator( + &onionPkts[i], &packets[i], MakeBlindingKit( + p.router, reqs[i].BlindingPoint, + // We are the last hop if the next hop if the + // processed packet's action is to exit. + packets[i].Action == sphinx.ExitNode, + reqs[i].IncomingAmount, reqs[i].IncomingCltv, + ), + ) } return resps, nil diff --git a/htlcswitch/hop/iterator_test.go b/htlcswitch/hop/iterator_test.go index 524b70df095..a8dda32fc87 100644 --- a/htlcswitch/hop/iterator_test.go +++ b/htlcswitch/hop/iterator_test.go @@ -98,3 +98,48 @@ func TestSphinxHopIteratorForwardingInstructions(t *testing.T) { } } } + +// TestForwardingAmountCalc tests calculation of forwarding amounts from the +// hop's forwarding parameters. +func TestForwardingAmountCalc(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + incomingAmount lnwire.MilliSatoshi + baseFee uint32 + proportional uint32 + forwardAmount lnwire.MilliSatoshi + expectErr bool + }{ + { + name: "overflow", + incomingAmount: 10, + baseFee: 100, + expectErr: true, + }, + { + name: "ok", + incomingAmount: 100_000, + baseFee: 1000, + proportional: 10, + forwardAmount: 99000, + }, + } + + for _, testCase := range tests { + testCase := testCase + + t.Run(testCase.name, func(t *testing.T) { + t.Parallel() + + actual, err := calculateForwardingAmount( + testCase.incomingAmount, testCase.baseFee, + testCase.proportional, + ) + + require.Equal(t, testCase.expectErr, err != nil) + require.Equal(t, testCase.forwardAmount, actual) + }) + } +} diff --git a/htlcswitch/hop/payload.go b/htlcswitch/hop/payload.go index 03fdd94f321..03e7b2c9344 100644 --- a/htlcswitch/hop/payload.go +++ b/htlcswitch/hop/payload.go @@ -2,6 +2,7 @@ package hop import ( "encoding/binary" + "errors" "fmt" "io" @@ -28,6 +29,10 @@ const ( // RequiredViolation indicates that an unknown even type was found in // the payload that we could not process. RequiredViolation + + // InsufficientViolation indicates that the provided type does + // not satisfy constraints. + InsufficientViolation ) // String returns a human-readable description of the violation as a verb. @@ -42,6 +47,9 @@ func (v PayloadViolation) String() string { case RequiredViolation: return "required" + case InsufficientViolation: + return "insufficient" + default: return "unknown violation" } @@ -127,8 +135,12 @@ func NewLegacyPayload(f *sphinx.HopData) *Payload { } // NewPayloadFromReader builds a new Hop from the passed io.Reader. The reader -// should correspond to the bytes encapsulated in a TLV onion payload. -func NewPayloadFromReader(r io.Reader) (*Payload, error) { +// should correspond to the bytes encapsulated in a TLV onion payload. A +// blinding kit is passed in to help handle payloads that are part of a blinded +// route. +func NewPayloadFromReader(r io.Reader, blindingKit *BlindingKit) ( + *Payload, error) { + var ( cid uint64 amt uint64 @@ -166,7 +178,9 @@ func NewPayloadFromReader(r io.Reader) (*Payload, error) { // Validate whether the sender properly included or omitted tlv records // in accordance with BOLT 04. nextHop := lnwire.NewShortChanIDFromInt(cid) - err = ValidateParsedPayloadTypes(parsedTypes, nextHop) + activeBlindingPoint, err := ValidateParsedPayloadTypes( + parsedTypes, nextHop, blindingKit, blindingPoint, + ) if err != nil { return nil, err } @@ -208,20 +222,40 @@ func NewPayloadFromReader(r io.Reader) (*Payload, error) { // Filter out the custom records. customRecords := NewCustomRecords(parsedTypes) - return &Payload{ - FwdInfo: ForwardingInfo{ - NextHop: nextHop, - AmountToForward: lnwire.MilliSatoshi(amt), - OutgoingCTLV: cltv, - }, + payload := &Payload{ MPP: mpp, AMP: amp, metadata: metadata, encryptedData: encryptedData, - blindingPoint: blindingPoint, + blindingPoint: activeBlindingPoint, customRecords: customRecords, totalAmtMsat: lnwire.MilliSatoshi(totalAmtMsat), - }, nil + } + + // If there is no blinding point, set forwarding information from + // the payload. If we have a blinding point we'll get this information + // from the encrypted data blob, so can leave it nil for now. + if activeBlindingPoint == nil { + payload.FwdInfo = ForwardingInfo{ + NextHop: nextHop, + AmountToForward: lnwire.MilliSatoshi(amt), + OutgoingCTLV: cltv, + } + } else { + forwarding, err := blindingKit.forwardingInfo( + activeBlindingPoint, payload.encryptedData, + ) + if err != nil { + return nil, fmt.Errorf("decrypting blinded data "+ + "failed: %w", err) + } + + // If we obtained forwarding info without error, we expect this + // to be non-nil. + payload.FwdInfo = *forwarding + } + + return payload, nil } // ForwardingInfo returns the basic parameters required for HTLC forwarding, @@ -246,9 +280,24 @@ func NewCustomRecords(parsedTypes tlv.TypeMap) record.CustomSet { // ValidateParsedPayloadTypes checks the types parsed from a hop payload to // ensure that the proper fields are either included or omitted. The finalHop // boolean should be true if the payload was parsed for an exit hop. The -// requirements for this method are described in BOLT 04. +// requirements for this method are described in BOLT 04. If the payload is for +// a blinded route, it also returns the blinding point that should be used for +// further payload processing. func ValidateParsedPayloadTypes(parsedTypes tlv.TypeMap, - nextHop lnwire.ShortChannelID) error { + nextHop lnwire.ShortChannelID, + blindingKit *BlindingKit, + onionBlinding *btcec.PublicKey) (*btcec.PublicKey, error) { + + // If encrypted data is present in our payload, validate fields for + // a blinded route - this validation is different to regular payload + // validation because some fields are contained in encrypted data + // instead of the onion TLVs. + _, dataPresent := parsedTypes[record.EncryptedDataOnionType] + if dataPresent { + return validateBlindedRouteTypes( + parsedTypes, blindingKit, onionBlinding, + ) + } isFinalHop := nextHop == Exit @@ -262,7 +311,7 @@ func ValidateParsedPayloadTypes(parsedTypes tlv.TypeMap, // All hops must include an amount to forward. case !hasAmt: - return ErrInvalidPayload{ + return nil, ErrInvalidPayload{ Type: record.AmtOnionType, Violation: OmittedViolation, FinalHop: isFinalHop, @@ -270,7 +319,7 @@ func ValidateParsedPayloadTypes(parsedTypes tlv.TypeMap, // All hops must include a cltv expiry. case !hasLockTime: - return ErrInvalidPayload{ + return nil, ErrInvalidPayload{ Type: record.LockTimeOnionType, Violation: OmittedViolation, FinalHop: isFinalHop, @@ -280,7 +329,7 @@ func ValidateParsedPayloadTypes(parsedTypes tlv.TypeMap, // sender must have included a record, so we don't need to test for its // inclusion at intermediate hops directly. case isFinalHop && hasNextHop: - return ErrInvalidPayload{ + return nil, ErrInvalidPayload{ Type: record.NextHopOnionType, Violation: IncludedViolation, FinalHop: true, @@ -288,7 +337,7 @@ func ValidateParsedPayloadTypes(parsedTypes tlv.TypeMap, // Intermediate nodes should never receive MPP fields. case !isFinalHop && hasMPP: - return ErrInvalidPayload{ + return nil, ErrInvalidPayload{ Type: record.MPPOnionType, Violation: IncludedViolation, FinalHop: isFinalHop, @@ -296,14 +345,85 @@ func ValidateParsedPayloadTypes(parsedTypes tlv.TypeMap, // Intermediate nodes should never receive AMP fields. case !isFinalHop && hasAMP: - return ErrInvalidPayload{ + return nil, ErrInvalidPayload{ Type: record.AMPOnionType, Violation: IncludedViolation, FinalHop: isFinalHop, } } - return nil + return nil, nil +} + +// validateBlindedRouteTypes performs the validation required for payloads +// that are part of a blinded route, returning the blinding point that is in +// use for that payload. +func validateBlindedRouteTypes(parsedTypes tlv.TypeMap, + blindingKit *BlindingKit, + onionBlinding *btcec.PublicKey) (*btcec.PublicKey, error) { + + if blindingKit == nil { + return nil, errors.New("blinding kit required") + } + + var ( + blindingPoint *btcec.PublicKey + updateAddBlindingSet = blindingKit.BlindingPoint != nil + onionBlindingSet = onionBlinding != nil + ) + + switch { + // We should have a blinding key either in update_add_htlc or in the + // onion, but not both. + case updateAddBlindingSet && onionBlindingSet: + return nil, ErrInvalidPayload{ + Type: record.BlindingPointOnionType, + Violation: IncludedViolation, + FinalHop: false, + } + + case updateAddBlindingSet: + blindingPoint = blindingKit.BlindingPoint + + case onionBlindingSet: + blindingPoint = onionBlinding + } + + if _, ok := parsedTypes[record.EncryptedDataOnionType]; !ok { + return nil, ErrInvalidPayload{ + Type: record.EncryptedDataOnionType, + Violation: RequiredViolation, + FinalHop: false, + } + } + + // We have restrictions on the types of TLVs that we allow in + // intermediate and final hops. Set a map of allowed TLVs and run a + // check for any other non-nil values in our parsed map. + allowedTLVs := map[tlv.Type]struct{}{ + record.EncryptedDataOnionType: {}, + record.BlindingPointOnionType: {}, + } + + // The last hop is allowed some additional TLVs. + if blindingKit.lastHop { + allowedTLVs[record.AmtOnionType] = struct{}{} + allowedTLVs[record.LockTimeOnionType] = struct{}{} + } + + for tlvType := range parsedTypes { + if _, ok := allowedTLVs[tlvType]; ok { + continue + } + + return nil, ErrInvalidPayload{ + Type: tlvType, + Violation: IncludedViolation, + FinalHop: false, + } + } + + return blindingPoint, nil } // MultiPath returns the record corresponding the option_mpp parsed from the @@ -380,3 +500,39 @@ func getMinRequiredViolation(set tlv.TypeMap) *tlv.Type { return nil } + +// validateBlindedRouteData performs the additional validation that is +// required for payments that rely on data provided in an encrypted blob to +// be forwarded. We enforce additional constraints here to prevent malicious +// parties from probing portions of the blinded route to "un-blind" them. +func validateBlindedRouteData(blindedData *record.BlindedRouteData, + incomingAmount lnwire.MilliSatoshi, incomingTimelock uint32) error { + + if blindedData.Constraints != nil { + maxCLTV := blindedData.Constraints.MaxCltvExpiry + if maxCLTV != 0 && incomingTimelock > maxCLTV { + return ErrInvalidPayload{ + Type: record.LockTimeOnionType, + Violation: InsufficientViolation, + } + } + + if incomingAmount < blindedData.Constraints.HtlcMinimumMsat { + return ErrInvalidPayload{ + Type: record.AmtOnionType, + Violation: InsufficientViolation, + } + } + } + + // Fail if we don't understand any features (even or odd), because we + // expect the features to have been set from our announcement. + if blindedData.Features.UnknownFeatures() { + return ErrInvalidPayload{ + Type: record.FeatureVectorType, + Violation: InsufficientViolation, + } + } + + return nil +} diff --git a/htlcswitch/hop/payload_test.go b/htlcswitch/hop/payload_test.go index 6913f11c6f9..d8de83d127c 100644 --- a/htlcswitch/hop/payload_test.go +++ b/htlcswitch/hop/payload_test.go @@ -361,7 +361,9 @@ func testDecodeHopPayloadValidation(t *testing.T, test decodePayloadTest) { testChildIndex = uint32(9) ) - p, err := hop.NewPayloadFromReader(bytes.NewReader(test.payload)) + p, err := hop.NewPayloadFromReader( + bytes.NewReader(test.payload), &hop.BlindingKit{}, + ) if !reflect.DeepEqual(test.expErr, err) { t.Fatalf("expected error mismatch, want: %v, got: %v", test.expErr, err) diff --git a/htlcswitch/link.go b/htlcswitch/link.go index 46aee939d12..fa352079788 100644 --- a/htlcswitch/link.go +++ b/htlcswitch/link.go @@ -2939,9 +2939,11 @@ func (l *channelLink) processRemoteAdds(fwdPkg *channeldb.FwdPkg, onionReader := bytes.NewReader(pd.OnionBlob) req := hop.DecodeHopIteratorRequest{ - OnionReader: onionReader, - RHash: pd.RHash[:], - IncomingCltv: pd.Timeout, + OnionReader: onionReader, + RHash: pd.RHash[:], + IncomingCltv: pd.Timeout, + IncomingAmount: pd.Amount, + BlindingPoint: pd.BlindingPoint, } decodeReqs = append(decodeReqs, req) @@ -3092,8 +3094,11 @@ func (l *channelLink) processRemoteAdds(fwdPkg *channeldb.FwdPkg, // Otherwise, it was already processed, we can // can collect it and continue. addMsg := &lnwire.UpdateAddHTLC{ - Expiry: fwdInfo.OutgoingCTLV, - Amount: fwdInfo.AmountToForward, + Expiry: fwdInfo.OutgoingCTLV, + Amount: fwdInfo.AmountToForward, + BlindingPoint: lnwire.NewBlindingPoint( + fwdInfo.NextBlinding, + ), PaymentHash: pd.RHash, } @@ -3134,8 +3139,11 @@ func (l *channelLink) processRemoteAdds(fwdPkg *channeldb.FwdPkg, // create the outgoing HTLC using the parameters as // specified in the forwarding info. addMsg := &lnwire.UpdateAddHTLC{ - Expiry: fwdInfo.OutgoingCTLV, - Amount: fwdInfo.AmountToForward, + Expiry: fwdInfo.OutgoingCTLV, + Amount: fwdInfo.AmountToForward, + BlindingPoint: lnwire.NewBlindingPoint( + fwdInfo.NextBlinding, + ), PaymentHash: pd.RHash, } diff --git a/itest/list_on_test.go b/itest/list_on_test.go index 64277045808..6bbb8a68da0 100644 --- a/itest/list_on_test.go +++ b/itest/list_on_test.go @@ -546,4 +546,8 @@ var allTestCases = []*lntest.TestCase{ Name: "query blinded route", TestFunc: testQueryBlindedRoutes, }, + { + Name: "forward blinded", + TestFunc: testForwardBlindedRoute, + }, } diff --git a/itest/lnd_route_blinding.go b/itest/lnd_route_blinding.go index 2104cb1c830..7a1fefdd9cd 100644 --- a/itest/lnd_route_blinding.go +++ b/itest/lnd_route_blinding.go @@ -1,15 +1,22 @@ package itest import ( + "bytes" "crypto/sha256" "encoding/hex" + "fmt" + "time" "github.com/btcsuite/btcd/btcec/v2" "github.com/btcsuite/btcd/btcutil" + sphinx "github.com/lightningnetwork/lightning-onion" "github.com/lightningnetwork/lnd/chainreg" "github.com/lightningnetwork/lnd/lnrpc" "github.com/lightningnetwork/lnd/lnrpc/routerrpc" "github.com/lightningnetwork/lnd/lntest" + "github.com/lightningnetwork/lnd/lntest/node" + "github.com/lightningnetwork/lnd/lnwire" + "github.com/lightningnetwork/lnd/record" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -310,3 +317,414 @@ func testQueryBlindedRoutes(ht *lntest.HarnessTest) { ht.CloseChannel(alice, chanPointAliceBob) ht.CloseChannel(bob, chanPointBobCarol) } + +type blindedForwardTest struct { + ht *lntest.HarnessTest + carol *node.HarnessNode + dave *node.HarnessNode + channels []*lnrpc.ChannelPoint + + preimage [33]byte + + // ctx is a context to be used by the test. + ctx context.Context //nolint:containedctx + + // cancel will cancel the test's top level context. + cancel func() +} + +func newBlindedForwardTest(ht *lntest.HarnessTest) *blindedForwardTest { + ctx, cancel := context.WithCancel(context.Background()) + + return &blindedForwardTest{ + ht: ht, + ctx: ctx, + cancel: cancel, + preimage: [33]byte{1, 2, 3}, + } +} + +// setup spins up additional nodes needed for our test and creates a four hop +// network for testing blinded forwarding and returns a blinded route from +// Bob -> Carol -> Dave, with Bob acting as the introduction point. +func (b *blindedForwardTest) setup() *routing.BlindedPayment { + b.carol = b.ht.NewNode("Carol", nil) + b.dave = b.ht.NewNode("Dave", nil) + + b.channels = setupFourHopNetwork(b.ht, b.carol, b.dave) + + // Create a blinded route to Dave via Bob --- Carol --- Dave: + bobChan := b.ht.GetChannelByChanPoint(b.ht.Bob, b.channels[1]) + carolChan := b.ht.GetChannelByChanPoint(b.carol, b.channels[2]) + + edges := []*forwardingEdge{ + getForwardingEdge(b.ctx, b.ht, b.ht.Bob, bobChan.ChanId), + getForwardingEdge(b.ctx, b.ht, b.carol, carolChan.ChanId), + } + + davePk, err := btcec.ParsePubKey(b.dave.PubKey[:]) + require.NoError(b.ht, err, "dave pubkey") + + return b.createBlindedRoute(edges, davePk) +} + +// cleanup tears down all channels created by the test and cancels the top +// level context used in the test. +func (b *blindedForwardTest) cleanup() { + b.ht.CloseChannel(b.ht.Alice, b.channels[0]) + b.ht.CloseChannel(b.ht.Bob, b.channels[1]) + b.ht.CloseChannel(b.carol, b.channels[2]) + + b.cancel() +} + +// createRouteToBlinded queries for a route from alice to the blinded path +// provided. +// +//nolint:gomnd +func (b *blindedForwardTest) createRouteToBlinded(paymentAmt int64, + route *routing.BlindedPayment) *lnrpc.Route { + + intro := route.BlindedPath.IntroductionPoint.SerializeCompressed() + blinding := route.BlindedPath.BlindingPoint.SerializeCompressed() + + blindedRoute := &lnrpc.BlindedPath{ + IntroductionNode: intro, + BlindingPoint: blinding, + BlindedHops: make( + []*lnrpc.BlindedHop, + len(route.BlindedPath.BlindedHops), + ), + } + + for i, hop := range route.BlindedPath.BlindedHops { + blindedRoute.BlindedHops[i] = &lnrpc.BlindedHop{ + BlindedNode: hop.NodePub.SerializeCompressed(), + EncryptedData: hop.Payload, + } + } + blindedPath := &lnrpc.BlindedPaymentPath{ + BlindedPath: blindedRoute, + BaseFeeMsat: uint64( + route.BaseFee, + ), + ProportionalFeeMsat: uint64( + route.ProportionalFee, + ), + TotalCltvDelta: uint32( + route.CltvExpiryDelta, + ), + } + + ctxt, cancel := context.WithTimeout(b.ctx, defaultTimeout) + req := &lnrpc.QueryRoutesRequest{ + AmtMsat: paymentAmt, + // Our fee limit doesn't really matter, we just want to + // be able to make the payment. + FeeLimit: &lnrpc.FeeLimit{ + Limit: &lnrpc.FeeLimit_Percent{ + Percent: 50, + }, + }, + BlindedPaymentPaths: []*lnrpc.BlindedPaymentPath{ + blindedPath, + }, + } + + resp, err := b.ht.Alice.RPC.LN.QueryRoutes(ctxt, req) + cancel() + require.NoError(b.ht, err, "query routes") + require.Greater(b.ht, len(resp.Routes), 0, "no routes") + require.Len(b.ht, resp.Routes[0].Hops, 3, "unexpected route length") + + return resp.Routes[0] +} + +// sendBlindedPayment dispatches a payment to the route provided. The streaming +// client for the send is returned with a cancel function that can be used to +// terminate the stream. +func (b *blindedForwardTest) sendBlindedPayment(route *lnrpc.Route) ( + lnrpc.Lightning_SendToRouteClient, func()) { + + hash := sha256.Sum256(b.preimage[:]) + + ctxt, cancel := context.WithCancel(b.ctx) + sendReq := &lnrpc.SendToRouteRequest{ + PaymentHash: hash[:], + Route: route, + } + + sendClient, err := b.ht.Alice.RPC.LN.SendToRoute(ctxt) + require.NoError(b.ht, err, "send to route client") + + err = sendClient.SendMsg(sendReq) + require.NoError(b.ht, err, "send to route request") + + return sendClient, cancel +} + +// interceptFinalHop sets up a htlc interceptor on carol's incoming link to +// intercept incoming htlcs. +func (b *blindedForwardTest) interceptFinalHop() chan error { + interceptor, err := b.carol.RPC.Router.HtlcInterceptor(b.ctx) + require.NoError(b.ht, err, "interceptor") + + hash := sha256.Sum256(b.preimage[:]) + + // Create an error channel to deliver any interceptor errors. + errChan := make(chan error) + go func() { + forward, err := interceptor.Recv() + if err != nil { + errChan <- err + return + } + + if !bytes.Equal(forward.PaymentHash, hash[:]) { + errChan <- fmt.Errorf("unexpected payment hash: %v", + hash) + return + } + + //nolint:lll + resp := &routerrpc.ForwardHtlcInterceptResponse{ + IncomingCircuitKey: forward.IncomingCircuitKey, + Action: routerrpc.ResolveHoldForwardAction_SETTLE, + Preimage: b.preimage[:], + } + + errChan <- interceptor.Send(resp) + }() + + return errChan +} + +func (b *blindedForwardTest) assertIntercepted(errChan chan error) { + select { + case err := <-errChan: + require.NoError(b.ht, err, "interceptor error") + + case <-time.After(defaultTimeout): + b.ht.Fatalf("interceptor did not exit") + } +} + +// setupFourHopNetwork creates a network with the following topology and +// liquidity: +// Alice (100k)----- Bob (100k) ----- Carol (100k) ----- Dave +// +// The funding outpoint for AB / BC / CD are returned in-order. +func setupFourHopNetwork(ht *lntest.HarnessTest, + carol, dave *node.HarnessNode) []*lnrpc.ChannelPoint { + + const chanAmt = btcutil.Amount(100000) + var networkChans []*lnrpc.ChannelPoint + + // Open a channel with 100k satoshis between Alice and Bob with Alice + // being the sole funder of the channel. + chanPointAlice := ht.OpenChannel( + ht.Alice, ht.Bob, lntest.OpenChannelParams{ + Amt: chanAmt, + }, + ) + networkChans = append(networkChans, chanPointAlice) + + // Create a channel between bob and carol. + ht.EnsureConnected(ht.Bob, carol) + chanPointBob := ht.OpenChannel( + ht.Bob, carol, lntest.OpenChannelParams{ + Amt: chanAmt, + }, + ) + networkChans = append(networkChans, chanPointBob) + + // Fund carol and connect her and dave so that she can create a channel + // between them. + ht.FundCoins(btcutil.SatoshiPerBitcoin, carol) + ht.EnsureConnected(carol, dave) + + chanPointCarol := ht.OpenChannel( + carol, dave, lntest.OpenChannelParams{ + Amt: chanAmt, + }, + ) + networkChans = append(networkChans, chanPointCarol) + + // Wait for all nodes to have seen all channels. + nodes := []*node.HarnessNode{ht.Alice, ht.Bob, carol, dave} + for _, chanPoint := range networkChans { + for _, node := range nodes { + ht.AssertTopologyChannelOpen(node, chanPoint) + } + } + + return []*lnrpc.ChannelPoint{ + chanPointAlice, + chanPointBob, + chanPointCarol, + } +} + +// createBlindedRoute creates a blinded route to the recipient node provided. +// The set of hops is expected to start at the introduction node and end at +// the recipient. +func (b *blindedForwardTest) createBlindedRoute(hops []*forwardingEdge, + dest *btcec.PublicKey) *routing.BlindedPayment { + + // Create a path with space for each of our hops + the destination + // node. + blindedPayment := &routing.BlindedPayment{} + pathLength := len(hops) + 1 + blindedPath := make([]*sphinx.UnBlindedHopInfo, pathLength) + + for i := 0; i < len(hops); i++ { + node := hops[i] + payload := &record.BlindedRouteData{ + NextNodeID: node.pubkey, + ShortChannelID: &node.channelID, + } + + // Add the next hop's ID for all nodes that have a next hop. + if i < len(hops)-1 { + nextHop := hops[i+1] + + payload.NextNodeID = nextHop.pubkey + payload.ShortChannelID = &node.channelID + } + + // Set the relay information for this edge, and add it to our + // aggregate info and update our aggregate constraints. + delta := uint16(node.edge.TimeLockDelta) + payload.RelayInfo = &record.PaymentRelayInfo{ + BaseFee: uint32(node.edge.FeeBaseMsat), + FeeRate: uint32(node.edge.FeeRateMilliMsat), + CltvExpiryDelta: delta, + } + + // We set our constraints with our edge's actual htlc min, and + // an arbitrary maximum expiry (since it's just an anti-probing + // mechanism). + payload.Constraints = &record.PaymentConstraints{ + HtlcMinimumMsat: lnwire.MilliSatoshi(node.edge.MinHtlc), + MaxCltvExpiry: 100000, + } + + blindedPayment.BaseFee += payload.RelayInfo.BaseFee + blindedPayment.ProportionalFee += payload.RelayInfo.FeeRate + blindedPayment.CltvExpiryDelta += delta + + // Encode the route's blinded data and include it in the + // blinded hop. + payloadBytes, err := record.EncodeBlindedRouteData(payload) + require.NoError(b.ht, err) + + blindedPath[i] = &sphinx.UnBlindedHopInfo{ + NodePub: node.pubkey, + Payload: payloadBytes, + } + } + + // Add our destination node at the end of the path. We don't need to + // add any forwarding parameters because we're at the final hop. + payloadBytes, err := record.EncodeBlindedRouteData( + &record.BlindedRouteData{ + // TODO: we don't have support for the final hop fields, + // because only forwarding is supported. We add a next + // node ID here so that it _looks like_ a valid + // forwarding hop (though in reality it's the last + // hop). + NextNodeID: dest, + }, + ) + require.NoError(b.ht, err, "final payload") + + blindedPath[pathLength-1] = &sphinx.UnBlindedHopInfo{ + NodePub: dest, + Payload: payloadBytes, + } + + // Blind the path. + blindingKey, err := btcec.NewPrivateKey() + require.NoError(b.ht, err) + + blindedPayment.BlindedPath, err = sphinx.BuildBlindedPath( + blindingKey, blindedPath, + ) + require.NoError(b.ht, err, "build blinded path") + + return blindedPayment +} + +// forwardingEdge contains the channel id/source public key for a forwarding +// edge and the policy associated with the channel in that direction. +type forwardingEdge struct { + pubkey *btcec.PublicKey + channelID lnwire.ShortChannelID + edge *lnrpc.RoutingPolicy +} + +func getForwardingEdge(ctxb context.Context, ht *lntest.HarnessTest, + node *node.HarnessNode, chanID uint64) *forwardingEdge { + + ctxt, cancel := context.WithTimeout(ctxb, defaultTimeout) + chanInfo, err := node.RPC.LN.GetChanInfo(ctxt, &lnrpc.ChanInfoRequest{ + ChanId: chanID, + }) + cancel() + require.NoError(ht, err, "%v chan info", node.Cfg.Name) + + pubkey, err := btcec.ParsePubKey(node.PubKey[:]) + require.NoError(ht, err, "%v pubkey", node.Cfg.Name) + + fwdEdge := &forwardingEdge{ + pubkey: pubkey, + channelID: lnwire.NewShortChanIDFromInt(chanID), + } + + if chanInfo.Node1Pub == node.PubKeyStr { + fwdEdge.edge = chanInfo.Node1Policy + } else { + require.Equal(ht, node.PubKeyStr, chanInfo.Node2Pub, + "policy edge sanity check") + + fwdEdge.edge = chanInfo.Node2Policy + } + + return fwdEdge +} + +// testForwardBlindedRoute tests lnd's ability to forward payments in a blinded +// route. +func testForwardBlindedRoute(ht *lntest.HarnessTest) { + testCase := newBlindedForwardTest(ht) + defer testCase.cleanup() + + route := testCase.setup() + blindedRoute := testCase.createRouteToBlinded(100_000, route) + + // Receiving via blinded routes is not yet supported, so Dave won't be + // able to process the payment. + // + // We have an interceptor at our disposal that will catch htlcs as they + // are forwarded (ie, it won't intercept a HTLC that dave is receiving, + // since no forwarding occurs). We initiate this interceptor with + // Carol, so that we can catch it and settle on the outgoing link to + // Dave. Once we hit the outgoing link, we know that we successfully + // parsed the htlc, so this is an acceptable compromise. + // Assert that our interceptor has exited without an error. + errChan := testCase.interceptFinalHop() + // TODO: interceptor is racing w/ payment, make sure interceptor is + // used + testCase.sendBlindedPayment(blindedRoute) + + // Wait for the HTLC to be active on Alice's channel. + hash := sha256.Sum256(testCase.preimage[:]) + ht.AssertHLTCNotActive(ht.Alice, testCase.channels[0], hash[:]) + + // Intercept and settle the HTLC. + testCase.assertIntercepted(errChan) + + // Assert that the HTLC has settled before test cleanup runs so that + // we can cooperatively close all channels. + ht.AssertHLTCNotActive(ht.Alice, testCase.channels[0], hash[:]) +} diff --git a/lnwallet/channel.go b/lnwallet/channel.go index bbe12eb6e80..bc781ceee60 100644 --- a/lnwallet/channel.go +++ b/lnwallet/channel.go @@ -385,6 +385,12 @@ type PaymentDescriptor struct { // isForwarded denotes if an incoming HTLC has been forwarded to any // possible upstream peers in the route. isForwarded bool + + // BlindingPoint is an optional ephemeral key used in route blinding. + // This value is set for nodes that are relaying payments inside of a + // blinded route (ie, not the introduction node) from update_add_htlc's + // TLVs. + BlindingPoint *btcec.PublicKey } // PayDescsFromRemoteLogUpdates converts a slice of LogUpdates received from the @@ -425,6 +431,7 @@ func PayDescsFromRemoteLogUpdates(chanID lnwire.ShortChannelID, height uint64, Height: height, Index: uint16(i), }, + BlindingPoint: wireMsg.BlindingPoint.Pubkey(), } pd.OnionBlob = make([]byte, len(wireMsg.OnionBlob)) copy(pd.OnionBlob[:], wireMsg.OnionBlob[:]) @@ -748,6 +755,7 @@ func (c *commitment) toDiskCommit(ourCommit bool) *channeldb.ChannelCommitment { HtlcIndex: htlc.HtlcIndex, LogIndex: htlc.LogIndex, Incoming: false, + // TODO: blidning point } copy(h.OnionBlob[:], htlc.OnionBlob) @@ -772,6 +780,7 @@ func (c *commitment) toDiskCommit(ourCommit bool) *channeldb.ChannelCommitment { HtlcIndex: htlc.HtlcIndex, LogIndex: htlc.LogIndex, Incoming: true, + // TODO: blinding point } copy(h.OnionBlob[:], htlc.OnionBlob) @@ -790,6 +799,9 @@ func (c *commitment) toDiskCommit(ourCommit bool) *channeldb.ChannelCommitment { // commitment struct and updateLog. This function is used when we need to // restore commitment state written do disk back into memory once we need to // restart a channel session. +// +// Note that HTLCs are restored _without_ any additional data that was provided +// in UpdateAddHtlc's TLVs. func (lc *LightningChannel) diskHtlcToPayDesc(feeRate chainfee.SatPerKWeight, commitHeight uint64, htlc *channeldb.HTLC, localCommitKeys, remoteCommitKeys *CommitmentKeyRing, isLocal bool) (PaymentDescriptor, @@ -871,6 +883,7 @@ func (lc *LightningChannel) diskHtlcToPayDesc(feeRate chainfee.SatPerKWeight, ourWitnessScript: ourWitnessScript, theirPkScript: theirP2WSH, theirWitnessScript: theirWitnessScript, + // TODO: Blinding Point } return pd, nil @@ -1558,6 +1571,7 @@ func (lc *LightningChannel) logUpdateToPayDesc(logUpdate *channeldb.LogUpdate, HtlcIndex: wireMsg.ID, LogIndex: logUpdate.LogIndex, addCommitHeightRemote: commitHeight, + BlindingPoint: wireMsg.BlindingPoint.Pubkey(), } pd.OnionBlob = make([]byte, len(wireMsg.OnionBlob)) copy(pd.OnionBlob[:], wireMsg.OnionBlob[:]) @@ -1755,6 +1769,7 @@ func (lc *LightningChannel) remoteLogUpdateToPayDesc(logUpdate *channeldb.LogUpd HtlcIndex: wireMsg.ID, LogIndex: logUpdate.LogIndex, addCommitHeightLocal: commitHeight, + BlindingPoint: wireMsg.BlindingPoint.Pubkey(), } pd.OnionBlob = make([]byte, len(wireMsg.OnionBlob)) copy(pd.OnionBlob, wireMsg.OnionBlob[:]) @@ -3746,6 +3761,9 @@ func (lc *LightningChannel) getUnsignedAckedUpdates() []channeldb.LogUpdate { Amount: pd.Amount, Expiry: pd.Timeout, PaymentHash: pd.RHash, + BlindingPoint: lnwire.NewBlindingPoint( + pd.BlindingPoint, + ), } copy(htlc.OnionBlob[:], pd.OnionBlob) logUpdate.UpdateMsg = htlc @@ -5837,6 +5855,7 @@ func (lc *LightningChannel) htlcAddDescriptor(htlc *lnwire.UpdateAddHTLC, HtlcIndex: lc.localUpdateLog.htlcCounter, OnionBlob: htlc.OnionBlob[:], OpenCircuitKey: openKey, + BlindingPoint: htlc.BlindingPoint.Pubkey(), } } @@ -5885,13 +5904,14 @@ func (lc *LightningChannel) ReceiveHTLC(htlc *lnwire.UpdateAddHTLC) (uint64, err } pd := &PaymentDescriptor{ - EntryType: Add, - RHash: PaymentHash(htlc.PaymentHash), - Timeout: htlc.Expiry, - Amount: htlc.Amount, - LogIndex: lc.remoteUpdateLog.logIndex, - HtlcIndex: lc.remoteUpdateLog.htlcCounter, - OnionBlob: htlc.OnionBlob[:], + EntryType: Add, + RHash: PaymentHash(htlc.PaymentHash), + Timeout: htlc.Expiry, + Amount: htlc.Amount, + LogIndex: lc.remoteUpdateLog.logIndex, + HtlcIndex: lc.remoteUpdateLog.htlcCounter, + OnionBlob: htlc.OnionBlob[:], + BlindingPoint: htlc.BlindingPoint.Pubkey(), } localACKedIndex := lc.remoteCommitChain.tail().ourMessageIndex diff --git a/lnwire/blinding_point.go b/lnwire/blinding_point.go new file mode 100644 index 00000000000..ab3ff71eb8b --- /dev/null +++ b/lnwire/blinding_point.go @@ -0,0 +1,88 @@ +package lnwire + +import ( + "io" + + "github.com/btcsuite/btcd/btcec/v2" + "github.com/lightningnetwork/lnd/tlv" +) + +const ( + // BlindingPointRecordType is the type for ephemeral pubkeys used in + // route blinding. + BlindingPointRecordType tlv.Type = 0 +) + +// BlindingPoint is used to communicate ephemeral pubkeys used by route +// blinding. +// +// Note: this struct wraps a btcec.Pubkey key so that we can implement the +// RecordProducer interface on the struct and re-use the existing tlv library +// functions for public keys. The type is unexported to enforce proper handling +// of aliasing / nil types. +type blindingPoint btcec.PublicKey + +// NewBlindingPoint converts a pubkey into its aliased blinding point type, +// returning nil if the pubkey provided is nil. +// +//nolint:revive +func NewBlindingPoint(pubkey *btcec.PublicKey) *blindingPoint { + if pubkey == nil { + return nil + } + + blindingPoint := blindingPoint(*pubkey) + + return &blindingPoint +} + +// Pubkey returns the underlying btcec.Pubkey type for a blinding point. +func (b *blindingPoint) Pubkey() *btcec.PublicKey { + if b == nil { + return nil + } + + pubkey := btcec.PublicKey(*b) + + return &pubkey +} + +// Record returns a TLV record for blinded pubkeys. +// +// Note: implements the RecordProducer interface. +func (b *blindingPoint) Record() tlv.Record { + return tlv.MakeStaticRecord( + BlindingPointRecordType, b, 33, + blindingPointEncoder, blindingPointDecoder, + ) +} + +// blindingPointEncoder is a custom TLV encoder for the BlindingPoint record. +func blindingPointEncoder(w io.Writer, val interface{}, buf *[8]byte) error { + if v, ok := val.(*blindingPoint); ok { + // EPubkey requires a double pointer, so we de-alias and + // reference the blinding point provided. + pubkey := btcec.PublicKey(*v) + pubkeyRef := &pubkey + return tlv.EPubKey(w, &pubkeyRef, buf) + } + + return tlv.NewTypeForEncodingErr(val, "lnwire.BlindingPoint") +} + +// blindingPointDecoder is a custom TLV decoder for the BlindingPoint record. +func blindingPointDecoder(r io.Reader, val interface{}, buf *[8]byte, + l uint64) error { + + if v, ok := val.(*blindingPoint); ok { + var pubkey *btcec.PublicKey + if err := tlv.DPubKey(r, &pubkey, buf, l); err != nil { + return err + } + *v = blindingPoint(*pubkey) + + return nil + } + + return tlv.NewTypeForEncodingErr(val, "lnwire.BlindingPoint") +} diff --git a/lnwire/channel_type.go b/lnwire/channel_type.go index a0696048bef..de755e135aa 100644 --- a/lnwire/channel_type.go +++ b/lnwire/channel_type.go @@ -19,7 +19,7 @@ type ChannelType RawFeatureVector // featureBitLen returns the length in bytes of the encoded feature bits. func (c ChannelType) featureBitLen() uint64 { fv := RawFeatureVector(c) - return uint64(fv.SerializeSize()) + return fv.sizeFunc() } // Record returns a TLV record that can be used to encode/decode the channel @@ -34,25 +34,27 @@ func (c *ChannelType) Record() tlv.Record { // channelTypeEncoder is a custom TLV encoder for the ChannelType record. func channelTypeEncoder(w io.Writer, val interface{}, buf *[8]byte) error { if v, ok := val.(*ChannelType); ok { - // Encode the feature bits as a byte slice without its length - // prepended, as that's already taken care of by the TLV record. fv := RawFeatureVector(*v) - return fv.encode(w, fv.SerializeSize(), 8) + return rawFeatureEncoder(w, &fv, buf) } - return tlv.NewTypeForEncodingErr(val, "lnwire.ChannelType") + return tlv.NewTypeForEncodingErr(val, "*lnwire.ChannelType") } // channelTypeDecoder is a custom TLV decoder for the ChannelType record. -func channelTypeDecoder(r io.Reader, val interface{}, buf *[8]byte, l uint64) error { +func channelTypeDecoder(r io.Reader, val interface{}, buf *[8]byte, + l uint64) error { + if v, ok := val.(*ChannelType); ok { fv := NewRawFeatureVector() - if err := fv.decode(r, int(l), 8); err != nil { + + if err := rawFeatureDecoder(r, fv, buf, l); err != nil { return err } + *v = ChannelType(*fv) return nil } - return tlv.NewTypeForEncodingErr(val, "lnwire.ChannelType") + return tlv.NewTypeForEncodingErr(val, "*lnwire.ChannelType") } diff --git a/lnwire/features.go b/lnwire/features.go index 3d29aaf7309..66da7508e7c 100644 --- a/lnwire/features.go +++ b/lnwire/features.go @@ -5,6 +5,8 @@ import ( "errors" "fmt" "io" + + "github.com/lightningnetwork/lnd/tlv" ) var ( @@ -139,6 +141,15 @@ const ( // transactions, which also imply anchor commitments. AnchorsZeroFeeHtlcTxOptional FeatureBit = 23 + // RouteBlindingRequired is a required bit that indicates that the + // receiving peer must understand forwarding of blinded payments. + RouteBlindingRequired = 24 + + // RouteBlindingOptional is an optional bit that indicates that this + // node understands forwarding of blinded payments, but the remote + // peer is not required to. + RouteBlindingOptional = 25 + // ShutdownAnySegwitRequired is an required feature bit that signals // that the sender is able to properly handle/parse segwit witness // programs up to version 16. This enables utilization of Taproot @@ -297,6 +308,8 @@ var Features = map[FeatureBit]string{ AnchorsOptional: "anchor-commitments", AnchorsZeroFeeHtlcTxRequired: "anchors-zero-fee-htlc-tx", AnchorsZeroFeeHtlcTxOptional: "anchors-zero-fee-htlc-tx", + RouteBlindingRequired: "route-blinding", + RouteBlindingOptional: "route-blinding", WumboChannelsRequired: "wumbo-channels", WumboChannelsOptional: "wumbo-channels", AMPRequired: "amp", @@ -612,6 +625,41 @@ func (fv *RawFeatureVector) decode(r io.Reader, length, width int) error { return nil } +// sizeFunc returns the length required to encode the feature vector. +func (fv *RawFeatureVector) sizeFunc() uint64 { + return uint64(fv.SerializeSize()) +} + +// Record returns a TLV record that can be used to encode/decode raw feature +// vectors. Note that the length of the feature vector is not included, because +// it is covered by the TLV record's length field. +func (fv *RawFeatureVector) Record(recordType tlv.Type) tlv.Record { + return tlv.MakeDynamicRecord( + recordType, fv, fv.sizeFunc, rawFeatureEncoder, + rawFeatureDecoder, + ) +} + +// rawFeatureEncoder is a custom TLV encoder for raw feature vectors. +func rawFeatureEncoder(w io.Writer, val interface{}, buf *[8]byte) error { + if f, ok := val.(*RawFeatureVector); ok { + return f.encode(w, f.SerializeSize(), 8) + } + + return tlv.NewTypeForEncodingErr(val, "*lnwire.RawFeatureVector") +} + +// rawFeatureDecoder is a custom TLV decoder for raw feature vectors. +func rawFeatureDecoder(r io.Reader, val interface{}, buf *[8]byte, + l uint64) error { + + if f, ok := val.(*RawFeatureVector); ok { + return f.decode(r, int(l), 8) + } + + return tlv.NewTypeForEncodingErr(val, "*lnwire.RawFeatureVector") +} + // FeatureVector represents a set of enabled features. The set stores // information on enabled flags and metadata about the feature names. A feature // vector is serializable to a compact byte representation that is included in @@ -678,6 +726,18 @@ func (fv *FeatureVector) UnknownRequiredFeatures() []FeatureBit { return unknown } +// UnknownFeatures returns a boolean if a feature vector contains *any* +// unknown features (even if they are off). +func (fv *FeatureVector) UnknownFeatures() bool { + for feature := range fv.features { + if !fv.IsKnown(feature) { + return true + } + } + + return false +} + // Name returns a string identifier for the feature represented by this bit. If // the bit does not represent a known feature, this returns a string indicating // as such. diff --git a/lnwire/lnwire_test.go b/lnwire/lnwire_test.go index f5c028581bc..bbb00a6a8e2 100644 --- a/lnwire/lnwire_test.go +++ b/lnwire/lnwire_test.go @@ -1070,6 +1070,20 @@ func TestLightningWireProtocol(t *testing.T) { v[0] = reflect.ValueOf(req) }, + MsgUpdateAddHTLC: func(v []reflect.Value, r *rand.Rand) { + req := NewUpdateAddHTLC() + + pubkey, err := randPubKey() + if err != nil { + t.Fatalf("unable to generate key: %v", err) + return + } + + blinding := blindingPoint(*pubkey) + req.BlindingPoint = &blinding + + v[0] = reflect.ValueOf(*req) + }, } // With the above types defined, we'll now generate a slice of diff --git a/lnwire/update_add_htlc.go b/lnwire/update_add_htlc.go index 666a549427f..1438bdaa134 100644 --- a/lnwire/update_add_htlc.go +++ b/lnwire/update_add_htlc.go @@ -54,6 +54,10 @@ type UpdateAddHTLC struct { // used in the subsequent UpdateAddHTLC message. OnionBlob [OnionPacketSize]byte + // BlindingPoint is the ephemeral pubkey used to optionally blind the + // next hop for this htlc. + BlindingPoint *blindingPoint + // ExtraData is the set of data that was appended to this message to // fill out the full maximum transport message size. These fields can // be used to specify optional data such as custom TLV fields. @@ -74,7 +78,7 @@ var _ Message = (*UpdateAddHTLC)(nil) // // This is part of the lnwire.Message interface. func (c *UpdateAddHTLC) Decode(r io.Reader, pver uint32) error { - return ReadElements(r, + if err := ReadElements(r, &c.ChanID, &c.ID, &c.Amount, @@ -82,7 +86,20 @@ func (c *UpdateAddHTLC) Decode(r io.Reader, pver uint32) error { &c.Expiry, c.OnionBlob[:], &c.ExtraData, + ); err != nil { + return err + } + + var blindingPoint blindingPoint + tlvMap, err := c.ExtraData.ExtractRecords( + &blindingPoint, ) + + if _, ok := tlvMap[BlindingPointRecordType]; ok { + c.BlindingPoint = &blindingPoint + } + + return err } // Encode serializes the target UpdateAddHTLC into the passed io.Writer @@ -114,6 +131,14 @@ func (c *UpdateAddHTLC) Encode(w *bytes.Buffer, pver uint32) error { return err } + // Only include blinding point in extra data if present. + if c.BlindingPoint != nil { + err := EncodeMessageExtraData(&c.ExtraData, c.BlindingPoint) + if err != nil { + return err + } + } + return WriteBytes(w, c.ExtraData) } diff --git a/lnwire/update_add_htlc_test.go b/lnwire/update_add_htlc_test.go new file mode 100644 index 00000000000..01a70ce55dc --- /dev/null +++ b/lnwire/update_add_htlc_test.go @@ -0,0 +1,48 @@ +package lnwire + +import ( + "bytes" + "encoding/hex" + "testing" + + "github.com/btcsuite/btcd/btcec/v2" + "github.com/stretchr/testify/require" +) + +// TestUpdateAddHTLCExtraData tests encoding of update_add_htlc with and +// without a blinding point TLV included. +func TestUpdateAddHTLCExtraData(t *testing.T) { + t.Parallel() + + // First, test an update_add_htlc that does not include a blinding + // point. + htlc := UpdateAddHTLC{ + ID: 1, + Amount: 100, + Expiry: 25, + ExtraData: make([]byte, 0), + } + + var b bytes.Buffer + require.NoError(t, htlc.Encode(&b, 0)) + + var htlcDecoded UpdateAddHTLC + require.NoError(t, htlcDecoded.Decode(&b, 0)) + + require.Equal(t, htlc, htlcDecoded) + + // Next test inclusion of a blinding point. + pubKeyStr := "036a0c5ea35df8a528b98edf6f290b28676d51d0fe202b073fe677612a39c0aa09" //nolint:lll + pubHex, err := hex.DecodeString(pubKeyStr) + require.NoError(t, err, "unable to decode pubkey") + + pubKey, err := btcec.ParsePubKey(pubHex) + require.NoError(t, err, "unable to parse pubkey") + blindingPoint := blindingPoint(*pubKey) + + htlc.BlindingPoint = &blindingPoint + + require.NoError(t, htlc.Encode(&b, 0)) + require.NoError(t, htlcDecoded.Decode(&b, 0)) + require.Equal(t, htlc, htlcDecoded) +} diff --git a/record/blinded_data.go b/record/blinded_data.go new file mode 100644 index 00000000000..d10c7eaf0b7 --- /dev/null +++ b/record/blinded_data.go @@ -0,0 +1,319 @@ +package record + +import ( + "bytes" + "encoding/binary" + "io" + + "github.com/btcsuite/btcd/btcec/v2" + "github.com/lightningnetwork/lnd/lnwire" + "github.com/lightningnetwork/lnd/tlv" +) + +const ( + // ShortChannelIDType is a record type for the outgoing channel short + // ID. + ShortChannelIDType tlv.Type = 2 + + // NextNodeType is a record type for the unblinded next node ID. + NextNodeType tlv.Type = 4 + + // PaymentRelayType is the record type for a tlv containing fee and + // cltv forwarding information. + PaymentRelayType tlv.Type = 10 + + // PaymentConstraintType is a tlv containing the constraints placed + // on a forwarded payment. + PaymentConstraintType tlv.Type = 12 + + // FeatureVectorType is the record type for a tlv with the features + // supported by the blinded hop. + FeatureVectorType tlv.Type = 14 +) + +// BlindedRouteData contains the information that is included in a blinded +// route encrypted data blob. +type BlindedRouteData struct { + // ShortChannelID is the channel ID of the next hop. + ShortChannelID *lnwire.ShortChannelID + + // NextNodeID is the unblinded node ID of the next hop. + NextNodeID *btcec.PublicKey + + // RelayInfo provides the relay parameters for the hop. + RelayInfo *PaymentRelayInfo + + // Constraints provides the payment relay constraints for the hop. + Constraints *PaymentConstraints + + // Features is the set of features the payment requires. + Features *lnwire.FeatureVector +} + +// DecodeBlindedRouteData decodes the data provided within a blinded route. +func DecodeBlindedRouteData(r io.Reader) (*BlindedRouteData, error) { + var ( + routeData = &BlindedRouteData{ + RelayInfo: &PaymentRelayInfo{}, + Constraints: &PaymentConstraints{}, + // We create a non-nil but empty set of features by + // default, so that we don't need to worry about nil + // values and can decode directly into the raw vector. + Features: lnwire.NewFeatureVector( + lnwire.NewRawFeatureVector(), lnwire.Features, + ), + } + + shortID uint64 + ) + + records := []tlv.Record{ + tlv.MakePrimitiveRecord(ShortChannelIDType, &shortID), + tlv.MakePrimitiveRecord(NextNodeType, &routeData.NextNodeID), + newPaymentRelayRecord(routeData.RelayInfo), + newPaymentConstraintsRecord(routeData.Constraints), + routeData.Features.Record(FeatureVectorType), + } + + stream, err := tlv.NewStream(records...) + if err != nil { + return nil, err + } + + tlvMap, err := stream.DecodeWithParsedTypes(r) + if err != nil { + return nil, err + } + + if _, ok := tlvMap[PaymentRelayType]; !ok { + routeData.RelayInfo = nil + } + + if _, ok := tlvMap[PaymentConstraintType]; !ok { + routeData.Constraints = nil + } + + if _, ok := tlvMap[ShortChannelIDType]; ok { + shortID := lnwire.NewShortChanIDFromInt(shortID) + routeData.ShortChannelID = &shortID + } + + return routeData, nil +} + +// EncodeBlindedRouteData encodes the blinded route data provided. +func EncodeBlindedRouteData(data *BlindedRouteData) ([]byte, error) { + var ( + w = new(bytes.Buffer) + records []tlv.Record + ) + + if data.ShortChannelID != nil { + shortID := data.ShortChannelID.ToUint64() + shortIDRecord := tlv.MakePrimitiveRecord( + ShortChannelIDType, &shortID, + ) + + records = append(records, shortIDRecord) + } + + if data.NextNodeID != nil { + nodeIDRecord := tlv.MakePrimitiveRecord( + NextNodeType, &data.NextNodeID, + ) + records = append(records, nodeIDRecord) + } + + if data.RelayInfo != nil { + relayRecord := newPaymentRelayRecord(data.RelayInfo) + records = append(records, relayRecord) + } + + if data.Constraints != nil { + constraintsRecord := newPaymentConstraintsRecord( + data.Constraints, + ) + records = append(records, constraintsRecord) + } + + if data.Features != nil && !data.Features.IsEmpty() { + featuresRecord := data.Features.RawFeatureVector.Record( + FeatureVectorType, + ) + records = append(records, featuresRecord) + } + + stream, err := tlv.NewStream(records...) + if err != nil { + return nil, err + } + + if err := stream.Encode(w); err != nil { + return nil, err + } + + return w.Bytes(), nil +} + +// PaymentRelayInfo describes the relay policy for a blinded path. +type PaymentRelayInfo struct { + // CltvExpiryDelta is the expiry delta for the payment. + CltvExpiryDelta uint16 + + // FeeRate is the fee rate that will be charged per millionth of a + // satoshi. + FeeRate uint32 + + // BaseFee is the per-htlc fee charged. + BaseFee uint32 +} + +// newPaymentRelayRecord creates a tlv.Record that encodes the payment relay +// (type 10) type for an encrypted blob payload. +func newPaymentRelayRecord(info *PaymentRelayInfo) tlv.Record { + return tlv.MakeDynamicRecord( + PaymentRelayType, &info, func() uint64 { + // uint16 + uint32 + tuint32 + return 2 + 4 + tlv.SizeTUint32(info.BaseFee) + }, encodePaymentRelay, decodePaymentRelay, + ) +} + +func encodePaymentRelay(w io.Writer, val interface{}, buf *[8]byte) error { + if t, ok := val.(**PaymentRelayInfo); ok { + relayInfo := *t + + // Just write our first 6 bytes directly. + binary.BigEndian.PutUint16(buf[:2], relayInfo.CltvExpiryDelta) + binary.BigEndian.PutUint32(buf[2:6], relayInfo.FeeRate) + if _, err := w.Write(buf[0:6]); err != nil { + return err + } + + // We can safely reuse buf here because we overwrite its + // contents. + return tlv.ETUint32(w, &relayInfo.BaseFee, buf) + } + + return tlv.NewTypeForEncodingErr(val, "*hop.paymentRelayInfo") +} + +func decodePaymentRelay(r io.Reader, val interface{}, buf *[8]byte, + l uint64) error { + + if t, ok := val.(**PaymentRelayInfo); ok && l <= 10 { + scratch := make([]byte, l) + + n, err := io.ReadFull(r, scratch) + if err != nil { + return err + } + + // We expect at least 6 bytes, because we have q, bytes for + // cltv delta and 4 bytes for fee rate. + if n < 6 { + return tlv.NewTypeForDecodingErr(val, + "*hop.paymentRelayInfo", uint64(n), 6) + } + + relayInfo := *t + + relayInfo.CltvExpiryDelta = binary.BigEndian.Uint16( + scratch[0:2], + ) + relayInfo.FeeRate = binary.BigEndian.Uint32(scratch[2:6]) + + // To be able to re-use the DTUint32 function we create a + // buffer with just the bytes holding the variable length u32. + // If the base fee is zero, this will be an empty buffer, which + // is okay. + b := bytes.NewBuffer(scratch[6:]) + + return tlv.DTUint32(b, &relayInfo.BaseFee, buf, l-6) + } + + return tlv.NewTypeForDecodingErr(val, "*hop.paymentRelayInfo", l, 10) +} + +// PaymentConstraints is a set of restrictions on a payment. +type PaymentConstraints struct { + // MaxCltvExpiry is the maximum expiry height for the payment. + MaxCltvExpiry uint32 + + // HtlcMinimumMsat is the minimum htlc size for the payment. + HtlcMinimumMsat lnwire.MilliSatoshi +} + +func newPaymentConstraintsRecord(constraints *PaymentConstraints) tlv.Record { + return tlv.MakeDynamicRecord( + PaymentConstraintType, &constraints, func() uint64 { + // uint32 + tuint64. + return 4 + tlv.SizeTUint64(uint64( + constraints.HtlcMinimumMsat, + )) + }, + encodePaymentConstraints, decodePaymentConstraints, + ) +} + +func encodePaymentConstraints(w io.Writer, val interface{}, + buf *[8]byte) error { + + if c, ok := val.(**PaymentConstraints); ok { + constraints := *c + + binary.BigEndian.PutUint32(buf[:4], constraints.MaxCltvExpiry) + if _, err := w.Write(buf[:4]); err != nil { + return err + } + + // We can safely re-use buf here because we overwrite its + // contents. + htlcMsat := uint64(constraints.HtlcMinimumMsat) + + return tlv.ETUint64(w, &htlcMsat, buf) + } + + return tlv.NewTypeForEncodingErr(val, "**PaymentConstraints") +} + +func decodePaymentConstraints(r io.Reader, val interface{}, buf *[8]byte, + l uint64) error { + + if c, ok := val.(**PaymentConstraints); ok && l <= 12 { + scratch := make([]byte, l) + + n, err := io.ReadFull(r, scratch) + if err != nil { + return err + } + + // We expect at least 4 bytes for our uint32. + if n < 4 { + return tlv.NewTypeForDecodingErr(val, + "*paymentConstraints", uint64(n), 4) + } + + payConstraints := *c + + payConstraints.MaxCltvExpiry = binary.BigEndian.Uint32( + scratch[:4], + ) + + // This could be empty if our minimum is zero, that's okay. + var ( + b = bytes.NewBuffer(scratch[4:]) + minHtlc uint64 + ) + + err = tlv.DTUint64(b, &minHtlc, buf, l-4) + if err != nil { + return err + } + payConstraints.HtlcMinimumMsat = lnwire.MilliSatoshi(minHtlc) + + return nil + } + + return tlv.NewTypeForDecodingErr(val, "**PaymentConstraints", l, l) +} diff --git a/record/blinded_data_test.go b/record/blinded_data_test.go new file mode 100644 index 00000000000..b60165d6a0f --- /dev/null +++ b/record/blinded_data_test.go @@ -0,0 +1,119 @@ +package record + +import ( + "bytes" + "encoding/hex" + "math" + "testing" + + "github.com/btcsuite/btcd/btcec/v2" + "github.com/lightningnetwork/lnd/lnwire" + "github.com/stretchr/testify/require" +) + +//nolint:lll +const pubkeyStr = "02eec7245d6b7d2ccb30380bfbe2a3648cd7a942653f5aa340edcea1f283686619" + +func pubkey(t *testing.T) *btcec.PublicKey { + t.Helper() + + nodeBytes, err := hex.DecodeString(pubkeyStr) + require.NoError(t, err) + + nodePk, err := btcec.ParsePubKey(nodeBytes) + require.NoError(t, err) + + return nodePk +} + +// TestBlindedDataEncoding tests encoding and decoding of blinded data blobs. +// These tests specifically cover cases where the variable length encoded +// integers values have different numbers of leading zeros trimmed because +// these TLVs are the first composite records with variable length tlvs +// (previously, a variable length integer would take up the whole record). +func TestBlindedDataEncoding(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + baseFee uint32 + htlcMin lnwire.MilliSatoshi + features *lnwire.FeatureVector + }{ + { + name: "zero variable values", + baseFee: 0, + htlcMin: 0, + }, + { + name: "zeros trimmed", + baseFee: math.MaxUint32 / 2, + htlcMin: math.MaxUint64 / 2, + }, + { + name: "no zeros trimmed", + baseFee: math.MaxUint32, + htlcMin: math.MaxUint64, + }, + { + name: "nil feature vector", + features: nil, + }, + { + name: "non-nil, but empty feature vector", + features: lnwire.EmptyFeatureVector(), + }, + { + name: "populated feature vector", + features: lnwire.NewFeatureVector( + lnwire.NewRawFeatureVector(lnwire.AMPOptional), + lnwire.Features, + ), + }, + } + + for _, testCase := range tests { + testCase := testCase + t.Run(testCase.name, func(t *testing.T) { + t.Parallel() + + // Create a standard set of blinded route data, using + // the values from our test case for the variable + // length encoded values. + channelID := lnwire.NewShortChanIDFromInt(1) + encodedData := &BlindedRouteData{ + ShortChannelID: &channelID, + NextNodeID: pubkey(t), + RelayInfo: &PaymentRelayInfo{ + FeeRate: 2, + CltvExpiryDelta: 3, + BaseFee: testCase.baseFee, + }, + Constraints: &PaymentConstraints{ + MaxCltvExpiry: 4, + HtlcMinimumMsat: testCase.htlcMin, + }, + Features: testCase.features, + } + + encoded, err := EncodeBlindedRouteData(encodedData) + require.NoError(t, err) + + // We fill a non-nil feature vector if there is no + // features tlv, so we set our expected feature vector + // to an empty one if that's what we expect + if encodedData.Features == nil || + encodedData.Features.IsEmpty() { + + //nolint:lll + encodedData.Features = lnwire.EmptyFeatureVector() + } + + b := bytes.NewBuffer(encoded) + decodedData, err := DecodeBlindedRouteData(b) + require.NoError(t, err) + + require.Equal(t, encodedData, decodedData) + }) + } +}