Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion merging/merge.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,7 +217,8 @@ def split(x):

r = min(a.shape[1], r)
num_src_actual = a.shape[1]
chunk_size = min(5000, num_src_actual)
# Ensure chunk_size is at least 1 to avoid range() error
chunk_size = max(1, min(5000, num_src_actual))

node_max = torch.empty(B, num_src_actual, device=a.device, dtype=a.dtype)
node_idx = torch.empty(B, num_src_actual, device=a.device, dtype=torch.long)
Expand Down
28 changes: 25 additions & 3 deletions vggt/layers/attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,11 +168,33 @@ def forward(self, x: Tensor, pos=None, global_merging=None) -> Tensor:

merge_ratio = 0.9
r = int(x.shape[1] * merge_ratio)


# 动态计算patch_width和patch_height
# 假设每个图像有额外的5个token(相机token和register token)
# 尝试找到合适的w和h使得 w*h+5 能整除N
num_imgs = 1
while (N % (num_imgs) != 0) and (num_imgs < B + 1):
num_imgs += 1

tokens_per_img = N // num_imgs
# 减去5个额外token得到纯patch token数
patch_tokens = tokens_per_img - 5

# 尝试找到最接近的w和h使得w*h≈patch_tokens
import math
w = int(math.sqrt(patch_tokens))
while patch_tokens % w != 0 and w > 1:
w -= 1
h = patch_tokens // w if w > 0 else 1

# 如果计算失败,使用默认值
if w * h != patch_tokens:
w, h = self.patch_width, self.patch_height

m, u = token_merge_bipartite2d(
x,
self.patch_width,
self.patch_height,
w,
h,
2,
2,
r,
Expand Down