
Commit

[Enhance]: Avoid overflow issue in pooling layers of RTMDet when using AMP (open-mmlab#9670)

Co-authored-by: Wenwei Zhang <[email protected]>
zylo117 and ZwwWayne authored Jan 30, 2023
1 parent a2b3e84 commit 4e83e86
Showing 2 changed files with 5 additions and 2 deletions.
mmdet/models/backbones/csp_darknet.py: 3 additions & 1 deletion
@@ -115,7 +115,9 @@ def __init__(self,

     def forward(self, x):
         x = self.conv1(x)
-        x = torch.cat([x] + [pooling(x) for pooling in self.poolings], dim=1)
+        with torch.cuda.amp.autocast(enabled=False):
+            x = torch.cat(
+                [x] + [pooling(x) for pooling in self.poolings], dim=1)
         x = self.conv2(x)
         return x
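
The hunk above disables autocast around the pooling and concatenation so that region is not forced through fp16 casting during AMP training (fp16 saturates at roughly 65504, so large activations can overflow to inf). Below is a minimal, self-contained sketch of the same local-disable pattern; it is not part of the commit, and the safe_sum helper, tensor values, and shapes are illustrative assumptions.

import torch

def safe_sum(x: torch.Tensor) -> torch.Tensor:
    """Sum in fp32 even when called from inside an autocast region."""
    with torch.cuda.amp.autocast(enabled=False):
        # Ops inside a disabled-autocast region run in the input dtype,
        # so cast explicitly to fp32 before the reduction.
        return x.float().sum()

if torch.cuda.is_available():
    x = torch.full((1024,), 1000.0, device='cuda')
    with torch.cuda.amp.autocast():
        overflowed = x.half().sum()  # 1024 * 1000 exceeds the fp16 range -> inf
        kept_finite = safe_sum(x)    # fp32 result stays finite
    print(overflowed.item(), kept_finite.item())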

mmdet/models/layers/se_layer.py: 2 additions & 1 deletion
@@ -152,7 +152,8 @@ def __init__(self, channels: int, init_cfg: OptMultiConfig = None) -> None:

     def forward(self, x: Tensor) -> Tensor:
         """Forward function for ChannelAttention."""
-        out = self.global_avgpool(x)
+        with torch.cuda.amp.autocast(enabled=False):
+            out = self.global_avgpool(x)
         out = self.fc(out)
         out = self.act(out)
         return x * out
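
The same treatment is applied to the global average pool in ChannelAttention. Below is a hedged usage sketch, not part of the commit, that exercises the patched layer under AMP; the import path mirrors the file changed above, while the channel count and input shape are illustrative assumptions (requires mmdet installed and a CUDA device).

import torch
from mmdet.models.layers.se_layer import ChannelAttention

if torch.cuda.is_available():
    attn = ChannelAttention(channels=256).cuda()
    x = torch.randn(2, 256, 32, 32, device='cuda')
    with torch.cuda.amp.autocast():
        # The global average pool inside forward() now runs with autocast
        # disabled, while the rest of the layer still benefits from AMP.
        out = attn(x)
    print(out.shape)  # same shape as the input: (2, 256, 32, 32)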
