From b7bfc96a83e7738c09b343cc0f1ce6a21feb86c2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ziyang=20Chen=20=28=E9=99=88=E5=AD=90=E6=89=AC=29?= Date: Sun, 2 Nov 2025 20:50:20 +0800 Subject: [PATCH] Fix size computation for patch_width and patch_height --- merging/merge.py | 3 ++- vggt/layers/attention.py | 28 +++++++++++++++++++++++++--- 2 files changed, 27 insertions(+), 4 deletions(-) diff --git a/merging/merge.py b/merging/merge.py index e2094b6..9e80a3d 100644 --- a/merging/merge.py +++ b/merging/merge.py @@ -217,7 +217,8 @@ def split(x): r = min(a.shape[1], r) num_src_actual = a.shape[1] - chunk_size = min(5000, num_src_actual) + # Ensure chunk_size is at least 1 to avoid range() error + chunk_size = max(1, min(5000, num_src_actual)) node_max = torch.empty(B, num_src_actual, device=a.device, dtype=a.dtype) node_idx = torch.empty(B, num_src_actual, device=a.device, dtype=torch.long) diff --git a/vggt/layers/attention.py b/vggt/layers/attention.py index aef68a4..9c66cf6 100644 --- a/vggt/layers/attention.py +++ b/vggt/layers/attention.py @@ -168,11 +168,33 @@ def forward(self, x: Tensor, pos=None, global_merging=None) -> Tensor: merge_ratio = 0.9 r = int(x.shape[1] * merge_ratio) - + + # Dynamically compute patch_width and patch_height + # Assume each image contributes 5 extra tokens (camera token and register tokens) + # Find the smallest image count that evenly divides N + num_imgs = 1 + while (N % (num_imgs) != 0) and (num_imgs < B + 1): + num_imgs += 1 + + tokens_per_img = N // num_imgs + # Subtract the 5 extra tokens to get the pure patch-token count + patch_tokens = tokens_per_img - 5 + + # Find the w and h closest to square such that w*h == patch_tokens + import math + w = int(math.sqrt(patch_tokens)) + while patch_tokens % w != 0 and w > 1: + w -= 1 + h = patch_tokens // w if w > 0 else 1 + + # Fall back to the default values if the computation failed + if w * h != patch_tokens: + w, h = self.patch_width, self.patch_height + m, u = token_merge_bipartite2d( x, - self.patch_width, - self.patch_height, + w, + h, 2, 2, r,