review query_learn #1
base: master
Conversation
main.py (Outdated)
@@ -126,53 +112,61 @@ def main(args):
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
model_without_ddp = model.module
n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
params = [p for n, p in model_without_ddp.named_parameters() if "backbone" not in n and p.requires_grad]
@ezekielbarnett this might be wrong -- in the original they set a different learning rate for `"backbone" not in n` and `"backbone" in n`. Your version sets the same LR for both.
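For reference, a sketch of how upstream DETR's `main.py` builds the two parameter groups (hyperparameter names like `args.lr_backbone` follow the upstream repo and may differ here):

```python
# Sketch of upstream DETR's optimizer setup: non-backbone params train at
# args.lr, backbone params at the (typically smaller) args.lr_backbone.
param_dicts = [
    {"params": [p for n, p in model_without_ddp.named_parameters()
                if "backbone" not in n and p.requires_grad]},
    {"params": [p for n, p in model_without_ddp.named_parameters()
                if "backbone" in n and p.requires_grad],
     "lr": args.lr_backbone},
]
optimizer = torch.optim.AdamW(param_dicts, lr=args.lr,
                              weight_decay=args.weight_decay)
```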
# embed query
query_fts, query_pos = self.backbone(query)
Don't you have to pool the query somehow? I didn't look closely at their `backbone`, but I think it returns the full feature map, not the global average pooled version, which I'd think we want to use.
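Something like a masked global average pool would do it. A minimal sketch, assuming `query_fts` is a DETR-style NestedTensor list whose last level decomposes into a `(B, C, H, W)` feature map and a `(B, H, W)` padding mask (True at padded positions):

```python
# Hypothetical pooling step: average the feature map over spatial positions,
# excluding padding, to get one (B, C) embedding per query image.
src, mask = query_fts[-1].decompose()
valid = (~mask).unsqueeze(1).float()                      # (B, 1, H, W)
pooled = (src * valid).sum(dim=(2, 3)) / valid.sum(dim=(2, 3)).clamp(min=1.0)
```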
models/detr.py (Outdated)
src, mask = features[-1].decompose()
assert mask is not None
hs = self.transformer(self.input_proj(src), mask, self.query_embed.weight, pos[-1])[0]
#hs = self.transformer(self.input_proj(src), mask, self.query_embed.weight, pos[-1])[0]
hs = self.transformer(self.input_proj(src), mask, query_fts, pos[-1])[0]
Looking at this, it probably makes more sense to add `query_fts` to `self.query_embed.weight` instead of just repeating `query_fts`. I don't 100% understand what `query_embed` learns, but I think that's more reasonable.
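A sketch of that suggestion, assuming `query_fts` has already been pooled to `(B, C)` and `self.query_embed.weight` is `(num_queries, C)`; the sum is per-image, so the stock DETR transformer (which repeats a `(num_queries, C)` query tensor across the batch) would need a small change to accept it:

```python
# Hypothetical combination: add the pooled query embedding to every learned
# object query instead of replacing them.
queries = self.query_embed.weight.unsqueeze(0) + query_fts.unsqueeze(1)  # (B, num_queries, C)
hs = self.transformer(self.input_proj(src), mask, queries, pos[-1])[0]
```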