Skip to content

Commit cef330b

Browse files
author
x299
committed
[Feat] Add XFeat + LightGlue and sequential matching, referencing cvg#441
1 parent 1252817 commit cef330b

File tree

8 files changed

+287
-16
lines changed

8 files changed

+287
-16
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,3 +8,4 @@ outputs/
88
datasets/*
99
!datasets/sacre_coeur/
1010
datasets/sacre_coeur/query
11+
.vscode

hloc/extract_features.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,17 @@
125125
"resize_max": 1024,
126126
},
127127
},
128+
"xfeat": {
129+
"output": "feats-xfeat-n5000-r1600",
130+
"model": {
131+
"name": "xfeat",
132+
"max_keypoints": 5000,
133+
},
134+
"preprocessing": {
135+
"grayscale": False,
136+
"resize_max": 1600,
137+
},
138+
},
128139
# Global descriptors
129140
"dir": {
130141
"output": "global-feats-dir",

hloc/extractors/xfeat.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
import torch
2+
3+
from hloc import logger
4+
5+
from ..utils.base_model import BaseModel
6+
7+
8+
class XFeat(BaseModel):
9+
default_conf = {
10+
"keypoint_threshold": 0.005,
11+
"max_keypoints": -1,
12+
}
13+
required_inputs = ["image"]
14+
15+
def _init(self, conf):
16+
self.net = torch.hub.load(
17+
"verlab/accelerated_features",
18+
"XFeat",
19+
pretrained=True,
20+
top_k=self.conf["max_keypoints"],
21+
)
22+
logger.info("Load XFeat(sparse) model done.")
23+
24+
def _forward(self, data):
25+
pred = self.net.detectAndCompute(
26+
data["image"], top_k=self.conf["max_keypoints"]
27+
)[0]
28+
pred = {
29+
"keypoints": pred["keypoints"][None],
30+
"scores": pred["scores"][None],
31+
"descriptors": pred["descriptors"].T[None],
32+
}
33+
return pred

hloc/match_features.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,13 @@
4242
"features": "aliked",
4343
},
4444
},
45+
"xfeat+lighterglue": {
46+
"output": "matches-xfeat-lighterglue",
47+
"model": {
48+
"name": "lighterglue",
49+
"features": "xfeat",
50+
},
51+
},
4552
"superglue": {
4653
"output": "matches-superglue",
4754
"model": {

hloc/matchers/lighterglue.py

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
import torch
2+
from lightglue import LightGlue as LightGlue_
3+
4+
from ..utils.base_model import BaseModel
5+
6+
7+
class LighterGlue(BaseModel):
8+
default_conf_xfeat = {
9+
"name": "lighterglue", # just for interfacing
10+
"input_dim": 64, # input descriptor dimension (autoselected from weights)
11+
"descriptor_dim": 96,
12+
"add_scale_ori": False,
13+
"add_laf": False, # for KeyNetAffNetHardNet
14+
"scale_coef": 1.0, # to compensate for the SIFT scale bigger than KeyNet
15+
"n_layers": 6,
16+
"num_heads": 1,
17+
"flash": True, # enable FlashAttention if available.
18+
"mp": False, # enable mixed precision
19+
"depth_confidence": -1, # early stopping, disable with -1
20+
"width_confidence": 0.95, # point pruning, disable with -1
21+
"filter_threshold": 0.1, # match threshold
22+
"weights": None,
23+
}
24+
required_inputs = [
25+
"image0",
26+
"keypoints0",
27+
"descriptors0",
28+
"image1",
29+
"keypoints1",
30+
"descriptors1",
31+
]
32+
33+
def _init(self, conf):
34+
LightGlue_.default_conf = self.default_conf_xfeat
35+
self.net = LightGlue_(None, **conf)
36+
url = "https://github.com/verlab/accelerated_features/raw/main/weights/xfeat-lighterglue.pt" # noqa: E501
37+
state_dict = torch.hub.load_state_dict_from_url(url)
38+
39+
# rename old state dict entries
40+
for i in range(self.net.conf.n_layers):
41+
pattern = f"self_attn.{i}", f"transformers.{i}.self_attn"
42+
state_dict = {k.replace(*pattern): v for k, v in state_dict.items()}
43+
pattern = f"cross_attn.{i}", f"transformers.{i}.cross_attn"
44+
state_dict = {k.replace(*pattern): v for k, v in state_dict.items()}
45+
state_dict = {k.replace("matcher.", ""): v for k, v in state_dict.items()}
46+
47+
self.net.load_state_dict(state_dict, strict=False)
48+
49+
def _forward(self, data):
50+
data["descriptors0"] = data["descriptors0"].transpose(-1, -2)
51+
data["descriptors1"] = data["descriptors1"].transpose(-1, -2)
52+
53+
return self.net(
54+
{
55+
"image0": {k[:-1]: v for k, v in data.items() if k[-1] == "0"},
56+
"image1": {k[:-1]: v for k, v in data.items() if k[-1] == "1"},
57+
}
58+
)

hloc/pairs_from_retrieval.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,7 @@ def main(
8181
db_list=None,
8282
db_model=None,
8383
db_descriptors=None,
84+
match_mask=None
8485
):
8586
logger.info("Extracting image pairs from a retrieval database.")
8687

@@ -109,7 +110,16 @@ def main(
109110
sim = torch.einsum("id,jd->ij", query_desc.to(device), db_desc.to(device))
110111

111112
# Avoid self-matching
112-
self = np.array(query_names)[:, None] == np.array(db_names)[None]
113+
# self = np.array(query_names)[:, None] == np.array(db_names)[None]
114+
if match_mask is None:
115+
# Avoid self-matching
116+
self = np.array(query_names)[:, None] == np.array(db_names)[None]
117+
else:
118+
assert match_mask.shape == (
119+
len(query_names),
120+
len(db_names),
121+
), "mask shape must match size of query and database images!"
122+
self = match_mask
113123
pairs = pairs_from_score_matrix(sim, self, num_matched, min_score=0)
114124
pairs = [(query_names[i], db_names[j]) for i, j in pairs]
115125

hloc/pairs_from_sequential.py

Lines changed: 147 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,147 @@
1+
import argparse
2+
import collections.abc as collections
3+
import os
4+
from pathlib import Path
5+
from typing import List, Optional, Union
6+
7+
import numpy as np
8+
9+
from hloc import logger, pairs_from_retrieval
10+
from hloc.utils.io import list_h5_names
11+
from hloc.utils.parsers import parse_image_lists, parse_retrieval
12+
13+
14+
def main(
15+
output: Path,
16+
image_list: Optional[Union[Path, List[str]]] = None,
17+
features: Optional[Path] = None,
18+
window_size: Optional[int] = 10,
19+
quadratic_overlap: bool = True,
20+
use_loop_closure: bool = False,
21+
retrieval_path: Optional[Union[Path, str]] = None,
22+
retrieval_interval: Optional[int] = 2,
23+
num_loc: Optional[int] = 5,
24+
) -> None:
25+
"""
26+
Generate pairs of images based on sequential matching and optional loop closure.
27+
Args:
28+
output (Path): The output file path where the pairs will be saved.
29+
image_list (Optional[Union[Path, List[str]]]):
30+
A path to a file containing a list of images or a list of image names.
31+
features (Optional[Path]):
32+
A path to a feature file containing image features.
33+
window_size (Optional[int]):
34+
The size of the window for sequential matching. Default is 10.
35+
quadratic_overlap (bool):
36+
Whether to use quadratic overlap in sequential matching. Default is True.
37+
use_loop_closure (bool):
38+
Whether to use loop closure for additional matching. Default is False.
39+
retrieval_path (Optional[Union[Path, str]]):
40+
The path to the retrieval file for loop closure.
41+
retrieval_interval (Optional[int]):
42+
The interval for selecting query images for loop closure. Default is 2.
43+
num_loc (Optional[int]):
44+
The number of top retrieval matches to consider for loop closure.
45+
Default is 5.
46+
Raises:
47+
ValueError: If neither image_list nor features are provided,
48+
or if image_list is of an unknown type.
49+
Returns:
50+
None
51+
"""
52+
if image_list is not None:
53+
if isinstance(image_list, (str, Path)):
54+
print(image_list)
55+
names_q = parse_image_lists(image_list)
56+
elif isinstance(image_list, collections.Iterable):
57+
names_q = list(image_list)
58+
else:
59+
raise ValueError(f"Unknown type for image list: {image_list}")
60+
elif features is not None:
61+
names_q = list_h5_names(features)
62+
else:
63+
raise ValueError("Provide either a list of images or a feature file.")
64+
65+
pairs = []
66+
N = len(names_q)
67+
68+
for i in range(N - 1):
69+
for j in range(i + 1, min(i + window_size + 1, N)):
70+
pairs.append((names_q[i], names_q[j]))
71+
72+
if quadratic_overlap:
73+
q = 2 ** (j - i)
74+
if q > window_size and i + q < N:
75+
pairs.append((names_q[i], names_q[i + q]))
76+
77+
if use_loop_closure:
78+
retrieval_pairs_tmp: Path = output.parent / "retrieval-pairs-tmp.txt"
79+
80+
# match mask describes for each image, which images NOT to include in retrevial
81+
# match search I.e., no reason to get retrieval matches for matches
82+
# already included from sequential matching
83+
84+
query_list = names_q[::retrieval_interval]
85+
M = len(query_list)
86+
match_mask = np.zeros((M, N), dtype=bool)
87+
88+
for i in range(M):
89+
for k in range(window_size + 1):
90+
if i * retrieval_interval - k >= 0 and i * retrieval_interval - k < N:
91+
match_mask[i][i * retrieval_interval - k] = 1
92+
if i * retrieval_interval + k >= 0 and i * retrieval_interval + k < N:
93+
match_mask[i][i * retrieval_interval + k] = 1
94+
95+
if quadratic_overlap:
96+
if (
97+
i * retrieval_interval - 2**k >= 0
98+
and i * retrieval_interval - 2**k < N
99+
):
100+
match_mask[i][i * retrieval_interval - 2**k] = 1
101+
if (
102+
i * retrieval_interval + 2**k >= 0
103+
and i * retrieval_interval + 2**k < N
104+
):
105+
match_mask[i][i * retrieval_interval + 2**k] = 1
106+
107+
pairs_from_retrieval.main(
108+
retrieval_path,
109+
retrieval_pairs_tmp,
110+
num_matched=num_loc,
111+
match_mask=match_mask,
112+
db_list=names_q,
113+
query_list=query_list,
114+
)
115+
116+
retrieval = parse_retrieval(retrieval_pairs_tmp)
117+
118+
for key, val in retrieval.items():
119+
for match in val:
120+
pairs.append((key, match))
121+
122+
os.unlink(retrieval_pairs_tmp)
123+
124+
logger.info(f"Found {len(pairs)} pairs.")
125+
with open(output, "w") as f:
126+
f.write("\n".join(" ".join([i, j]) for i, j in pairs))
127+
128+
129+
if __name__ == "__main__":
130+
parser = argparse.ArgumentParser(
131+
description="""
132+
Create a list of image pairs basedon the sequence of images on alphabetic order
133+
"""
134+
)
135+
parser.add_argument("--output", required=True, type=Path)
136+
parser.add_argument("--image_list", type=Path)
137+
parser.add_argument("--features", type=Path)
138+
parser.add_argument(
139+
"--overlap", type=int, default=10, help="Number of overlapping image pairs"
140+
)
141+
parser.add_argument(
142+
"--quadratic_overlap",
143+
action="store_true",
144+
help="Whether to match images against their quadratic neighbors.",
145+
)
146+
args = parser.parse_args()
147+
main(**args.__dict__)

hloc/reconstruction.py

Lines changed: 19 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from __future__ import annotations
12
import argparse
23
import multiprocessing
34
import shutil
@@ -15,16 +16,18 @@
1516
import_matches,
1617
parse_option_args,
1718
)
18-
from .utils.io import open_colmap_database
19+
from .utils.database import COLMAPDatabase
1920

2021

2122
def create_empty_db(database_path: Path):
2223
if database_path.exists():
2324
logger.warning("The database already exists, deleting it.")
2425
database_path.unlink()
2526
logger.info("Creating an empty database...")
26-
with open_colmap_database(database_path) as _:
27-
pass
27+
db = COLMAPDatabase.connect(database_path)
28+
db.create_tables()
29+
db.commit()
30+
db.close()
2831

2932

3033
def import_images(
@@ -51,9 +54,11 @@ def import_images(
5154

5255

5356
def get_image_ids(database_path: Path) -> Dict[str, int]:
57+
db = COLMAPDatabase.connect(database_path)
5458
images = {}
55-
with open_colmap_database(database_path) as db:
56-
images = {image.name: image.image_id for image in db.read_all_images()}
59+
for name, image_id in db.execute("SELECT name, image_id FROM images;"):
60+
images[name] = image_id
61+
db.close()
5762
return images
5863

5964

@@ -167,16 +172,15 @@ def main(
167172
create_empty_db(database)
168173
import_images(image_dir, database, camera_mode, image_list, image_options)
169174
image_ids = get_image_ids(database)
170-
with open_colmap_database(database) as db:
171-
import_features(image_ids, db, features)
172-
import_matches(
173-
image_ids,
174-
db,
175-
pairs,
176-
matches,
177-
min_match_score,
178-
skip_geometric_verification,
179-
)
175+
import_features(image_ids, database, features)
176+
import_matches(
177+
image_ids,
178+
database,
179+
pairs,
180+
matches,
181+
min_match_score,
182+
skip_geometric_verification,
183+
)
180184
if not skip_geometric_verification:
181185
estimation_and_geometric_verification(database, pairs, verbose)
182186
reconstruction = run_reconstruction(

0 commit comments

Comments
 (0)