review query_learn #1

Open
bkj wants to merge 7 commits into master
Conversation

bkj commented Jan 19, 2021

No description provided.

main.py Outdated
@@ -126,53 +112,61 @@ def main(args):
 model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
 model_without_ddp = model.module
 n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
+params = [p for n, p in model_without_ddp.named_parameters() if "backbone" not in n and p.requires_grad]
bkj (Author):

@ezekielbarnett this might be wrong -- in the original they set a different learning rate for the parameters matching "backbone" not in n than for those matching "backbone" in n. Your version sets the same LR for both.
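
For reference, upstream DETR builds two optimizer param groups so the pretrained backbone can train at its own, smaller learning rate. A sketch close to the original main.py (model_without_ddp and args come from the surrounding main(); lr_backbone is DETR's flag for the backbone LR):

# Two param groups, as in upstream DETR: the backbone gets its own LR.
param_dicts = [
    # non-backbone parameters train at the default LR (args.lr)
    {"params": [p for n, p in model_without_ddp.named_parameters()
                if "backbone" not in n and p.requires_grad]},
    # backbone parameters train at a (typically smaller) args.lr_backbone
    {"params": [p for n, p in model_without_ddp.named_parameters()
                if "backbone" in n and p.requires_grad],
     "lr": args.lr_backbone},
]
optimizer = torch.optim.AdamW(param_dicts, lr=args.lr, weight_decay=args.weight_decay)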


# embed query
query_fts, query_pos = self.backbone(query)

bkj (Author):

Don't you have to pool the query somehow? I didn't look closely at their backbone, but I think it returns the full feature map, not the globally average-pooled version, which I'd think is what we want to use here.
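
A minimal sketch of that pooling, assuming self.backbone returns DETR-style NestedTensors (a feature tensor plus a padding mask); the shapes and the masked-mean choice here are assumptions, not something verified against their backbone:

# Hypothetical: masked global average pool over the query feature map.
# Assumes query_fts[-1].decompose() gives feat (B, C, H, W) and
# mask (B, H, W) with True marking padded pixels, as in DETR.
feat, mask = query_fts[-1].decompose()
valid = (~mask).unsqueeze(1).float()                      # (B, 1, H, W)
pooled = (feat * valid).sum(dim=(2, 3)) / valid.sum(dim=(2, 3)).clamp(min=1.0)
# pooled: (B, C) -- one global descriptor per query image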

models/detr.py Outdated

 src, mask = features[-1].decompose()
 assert mask is not None
-hs = self.transformer(self.input_proj(src), mask, self.query_embed.weight, pos[-1])[0]
+#hs = self.transformer(self.input_proj(src), mask, self.query_embed.weight, pos[-1])[0]
+hs = self.transformer(self.input_proj(src), mask, query_fts, pos[-1])[0]
bkj (Author):

Looking at this, it probably makes more sense to add query_fts to self.query_embed.weight instead of just repeating query_fts in its place. I don't 100% understand what query_embed learns, but I think that's more reasonable.
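
A hedged sketch of that additive variant, assuming query_fts has first been pooled to shape (B, hidden_dim). The broadcasting below is an assumption, and note that DETR's stock Transformer expects a (num_queries, hidden_dim) query embedding and repeats it over the batch itself, so it would need a small change to accept a per-image query tensor:

# Hypothetical: add the pooled query descriptor to the learned slot
# embeddings instead of replacing them outright.
# self.query_embed.weight: (num_queries, hidden_dim)
# query_fts (pooled):      (B, hidden_dim)
combined = self.query_embed.weight.unsqueeze(0) + query_fts.unsqueeze(1)
# combined: (B, num_queries, hidden_dim), one embedding set per image
hs = self.transformer(self.input_proj(src), mask, combined, pos[-1])[0]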
