@@ -20,11 +20,12 @@ class XFeat(nn.Module):
2020 It supports inference for both sparse and semi-dense feature extraction & matching.
2121 """
2222
23- def __init__ (self , weights = os .path .abspath (os .path .dirname (__file__ )) + '/../weights/xfeat.pt' , top_k = 4096 ):
23+ def __init__ (self , weights = os .path .abspath (os .path .dirname (__file__ )) + '/../weights/xfeat.pt' , top_k = 4096 , detection_threshold = 0.05 ):
2424 super ().__init__ ()
2525 self .dev = torch .device ('cuda' if torch .cuda .is_available () else 'cpu' )
2626 self .net = XFeatModel ().to (self .dev ).eval ()
2727 self .top_k = top_k
28+ self .detection_threshold = detection_threshold
2829
2930 if weights is not None :
3031 if isinstance (weights , str ):
@@ -36,7 +37,7 @@ def __init__(self, weights = os.path.abspath(os.path.dirname(__file__)) + '/../w
3637 self .interpolator = InterpolateSparse2d ('bicubic' )
3738
3839 @torch .inference_mode ()
39- def detectAndCompute (self , x , top_k = None ):
40+ def detectAndCompute (self , x , top_k = None , detection_threshold = None ):
4041 """
4142 Compute sparse keypoints & descriptors. Supports batched mode.
4243
@@ -50,6 +51,7 @@ def detectAndCompute(self, x, top_k = None):
5051 'descriptors' -> torch.Tensor(N, 64): local features
5152 """
5253 if top_k is None : top_k = self .top_k
54+ if detection_threshold is None : detection_threshold = self .detection_threshold
5355 x , rh1 , rw1 = self .preprocess_tensor (x )
5456
5557 B , _ , _H1 , _W1 = x .shape
@@ -59,7 +61,7 @@ def detectAndCompute(self, x, top_k = None):
5961
6062 #Convert logits to heatmap and extract kpts
6163 K1h = self .get_kpts_heatmap (K1 )
62- mkpts = self .NMS (K1h , threshold = 0.05 , kernel_size = 5 )
64+ mkpts = self .NMS (K1h , threshold = detection_threshold , kernel_size = 5 )
6365
6466 #Compute reliability scores
6567 _nearest = InterpolateSparse2d ('nearest' )
0 commit comments