Skip to content

Commit 65ae30b

Browse files
committed
add detection_threshold parameter for fixed keypoint number detector
1 parent 5580d38 commit 65ae30b

File tree

2 files changed

+7
-5
lines changed

2 files changed

+7
-5
lines changed

hubconf.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from modules.xfeat import XFeat as _XFeat
33
import torch
44

5-
def XFeat(pretrained=True, top_k=4096):
5+
def XFeat(pretrained=True, top_k=4096, detection_threshold=0.05):
66
"""
77
XFeat model
88
pretrained (bool): kwargs, load pretrained weights into the model
@@ -11,5 +11,5 @@ def XFeat(pretrained=True, top_k=4096):
1111
if pretrained:
1212
weights = torch.hub.load_state_dict_from_url("https://github.com/verlab/accelerated_features/raw/main/weights/xfeat.pt")
1313

14-
model = _XFeat(weights, top_k=top_k)
14+
model = _XFeat(weights, top_k=top_k, detection_threshold=detection_threshold)
1515
return model

modules/xfeat.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,11 +20,12 @@ class XFeat(nn.Module):
2020
It supports inference for both sparse and semi-dense feature extraction & matching.
2121
"""
2222

23-
def __init__(self, weights = os.path.abspath(os.path.dirname(__file__)) + '/../weights/xfeat.pt', top_k = 4096):
23+
def __init__(self, weights = os.path.abspath(os.path.dirname(__file__)) + '/../weights/xfeat.pt', top_k = 4096, detection_threshold=0.05):
2424
super().__init__()
2525
self.dev = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
2626
self.net = XFeatModel().to(self.dev).eval()
2727
self.top_k = top_k
28+
self.detection_threshold = detection_threshold
2829

2930
if weights is not None:
3031
if isinstance(weights, str):
@@ -36,7 +37,7 @@ def __init__(self, weights = os.path.abspath(os.path.dirname(__file__)) + '/../w
3637
self.interpolator = InterpolateSparse2d('bicubic')
3738

3839
@torch.inference_mode()
39-
def detectAndCompute(self, x, top_k = None):
40+
def detectAndCompute(self, x, top_k = None, detection_threshold = None):
4041
"""
4142
Compute sparse keypoints & descriptors. Supports batched mode.
4243
@@ -50,6 +51,7 @@ def detectAndCompute(self, x, top_k = None):
5051
'descriptors' -> torch.Tensor(N, 64): local features
5152
"""
5253
if top_k is None: top_k = self.top_k
54+
if detection_threshold is None: detection_threshold = self.detection_threshold
5355
x, rh1, rw1 = self.preprocess_tensor(x)
5456

5557
B, _, _H1, _W1 = x.shape
@@ -59,7 +61,7 @@ def detectAndCompute(self, x, top_k = None):
5961

6062
#Convert logits to heatmap and extract kpts
6163
K1h = self.get_kpts_heatmap(K1)
62-
mkpts = self.NMS(K1h, threshold=0.05, kernel_size=5)
64+
mkpts = self.NMS(K1h, threshold=detection_threshold, kernel_size=5)
6365

6466
#Compute reliability scores
6567
_nearest = InterpolateSparse2d('nearest')

0 commit comments

Comments
 (0)