@@ -49,13 +49,15 @@ pub fn find_path<L: Deref, GL: Deref>(
4949
5050 // Add our start and first-hops to `frontier`.
5151 let start = NodeId :: from_pubkey ( & our_node_pubkey) ;
52+ let mut valid_first_hops = HashSet :: new ( ) ;
5253 let mut frontier = BinaryHeap :: new ( ) ;
5354 frontier. push ( PathBuildingHop { cost : 0 , node_id : start, parent_node_id : start } ) ;
5455 if let Some ( first_hops) = first_hops {
5556 for hop in first_hops {
5657 if !hop. counterparty . features . supports_onion_messages ( ) { continue ; }
5758 let node_id = NodeId :: from_pubkey ( & hop. counterparty . node_id ) ;
5859 frontier. push ( PathBuildingHop { cost : 1 , node_id, parent_node_id : start } ) ;
60+ valid_first_hops. insert ( node_id) ;
5961 }
6062 }
6163
@@ -71,7 +73,7 @@ pub fn find_path<L: Deref, GL: Deref>(
7173 return Ok ( path)
7274 }
7375 if let Some ( node_info) = network_nodes. get ( & node_id) {
74- if node_id == our_node_id {
76+ if valid_first_hops . contains ( & node_id ) || node_id == our_node_id {
7577 } else if let Some ( node_ann) = & node_info. announcement_info {
7678 if !node_ann. features . supports_onion_messages ( ) || node_ann. features . requires_unknown_bits ( )
7779 { continue ; }
@@ -166,7 +168,7 @@ fn reverse_path(
166168#[ cfg( test) ]
167169mod tests {
168170 use ln:: features:: { InitFeatures , NodeFeatures } ;
169- use routing:: test_utils:: { add_or_update_node, build_graph_with_features, build_line_graph, get_nodes} ;
171+ use routing:: test_utils:: { add_or_update_node, build_graph_with_features, build_line_graph, get_channel_details , get_nodes} ;
170172
171173 use sync:: Arc ;
172174
@@ -239,6 +241,13 @@ mod tests {
239241 // If all nodes require some features we don't understand, route should fail
240242 let err = super :: find_path ( & our_id, & node_pks[ 2 ] , & network_graph, None , Arc :: clone ( & logger) ) . unwrap_err ( ) ;
241243 assert_eq ! ( err, super :: Error :: PathNotFound ) ;
244+
245+ // If we specify a channel to node7, that overrides our local channel view and that gets used
246+ let our_chans = vec ! [ get_channel_details( Some ( 42 ) , node_pks[ 7 ] . clone( ) , features, 250_000_000 ) ] ;
247+ let path = super :: find_path ( & our_id, & node_pks[ 2 ] , & network_graph, Some ( & our_chans. iter ( ) . collect :: < Vec < _ > > ( ) ) , Arc :: clone ( & logger) ) . unwrap ( ) ;
248+ assert_eq ! ( path. len( ) , 2 ) ;
249+ assert_eq ! ( path[ 0 ] , node_pks[ 7 ] ) ;
250+ assert_eq ! ( path[ 1 ] , node_pks[ 2 ] ) ;
242251 }
243252
244253 #[ test]
@@ -256,6 +265,12 @@ mod tests {
256265 assert_eq ! ( path[ 0 ] , node_pks[ 1 ] ) ;
257266 assert_eq ! ( path[ 1 ] , node_pks[ 2 ] ) ;
258267 assert_eq ! ( path[ 2 ] , node_pks[ 0 ] ) ;
268+
269+ // If we specify a channel to node1, that overrides our local channel view and that gets used
270+ let our_chans = vec ! [ get_channel_details( Some ( 42 ) , node_pks[ 0 ] . clone( ) , features, 250_000_000 ) ] ;
271+ let path = super :: find_path ( & our_id, & node_pks[ 0 ] , & network_graph, Some ( & our_chans. iter ( ) . collect :: < Vec < _ > > ( ) ) , Arc :: clone ( & logger) ) . unwrap ( ) ;
272+ assert_eq ! ( path. len( ) , 1 ) ;
273+ assert_eq ! ( path[ 0 ] , node_pks[ 0 ] ) ;
259274 }
260275}
261276
0 commit comments