
About the parameters of PositionRelationEmbedding #36

Open
mitu752 opened this issue Jan 16, 2025 · 1 comment
Labels
question Further information is requested

Comments

@mitu752

mitu752 commented Jan 16, 2025

Question

When I train with the position relation embedding for dense detection, GPU memory usage is very high, up to 45 GB. When I change the parameter in self.position_relation_embedding = PositionRelationEmbedding(16, self.num_heads) from 16 to 4, memory usage drops to 14 GB (which is acceptable), but the results are not as good as with 16. I would like to ask whether the author has other ways to reduce training memory usage.

Additional information

No response

@mitu752 mitu752 added the question Further information is requested label Jan 16, 2025
@mitu752 mitu752 changed the title About the parameter of PositionRelationEmbedding About the parameters of PositionRelationEmbedding Jan 16, 2025
@xiuqhou
Owner

xiuqhou commented Jan 16, 2025

The space complexity of the relation is proportional to the square of the number of input queries. For a conventional detection task, 900 queries are enough and the memory usage is still acceptable, but for dense detection the memory usage becomes very substantial 🤔
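As a rough back-of-the-envelope sketch (the [batch, num_heads, num_queries, num_queries] shape and float32 storage are assumptions made only for this estimate, not values taken from the repository), the pairwise relation tensor alone grows quadratically with the query count, and the intermediate activations inside PositionRelationEmbedding scale the same way:

# Hypothetical memory estimate for a float32 pairwise relation tensor
# of shape [batch, num_heads, num_queries, num_queries].
def relation_bias_bytes(batch, num_heads, num_queries, bytes_per_elem=4):
    return batch * num_heads * num_queries ** 2 * bytes_per_elem

print(relation_bias_bytes(2, 8, 900) / 1024 ** 2)   # ~49 MiB at 900 queries
print(relation_bias_bytes(2, 8, 9000) / 1024 ** 3)  # ~4.8 GiB at 9000 queries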

You can use PyTorch's checkpoint feature to trade a small increase in training time for lower training memory. Adding two lines of code to the relation_transformer.py file, around line 375, is enough:

from torch.utils.checkpoint import checkpoint

# Original code: pos_relation = self.position_relation_embedding(src_boxes, tgt_boxes).flatten(0, 1)
pos_relation = checkpoint(
    self.position_relation_embedding, src_boxes, tgt_boxes
).flatten(0, 1)
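As a side note, recent PyTorch versions emit a warning when checkpoint is called without an explicit use_reentrant argument; if your PyTorch version supports the keyword, the same call can be written as:

pos_relation = checkpoint(
    self.position_relation_embedding, src_boxes, tgt_boxes,
    use_reentrant=False,  # non-reentrant checkpointing; recomputes the embedding's forward pass during backward
).flatten(0, 1)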
