CSPResNeXt50-PANet-SPP #698
@LukeAI hi, thanks for the feedback! Off the top of my head I think we may not support some of the layers there (#631 (comment)). Do you have an exact *.cfg file that you saw improvements with? Is this complementary to Gaussian YOLO, i.e. can they both be used together? So this would be a replacement of the darknet53 backbone with a ResNeXt50 backbone? It's too bad @WongKinYiu didn't do the modifications directly in this repo :)
@WongKinYiu I'd like to implement this cfg in ultralytics/yolov3. The only new field I see is 'groups' in the convolution layers. Are there other new fields I didn't see? Do you know where I would slot groups into the PyTorch Conv2d() definition? (Lines 22 to 33 in 07c1faf)
@LukeAI @WongKinYiu I've added the 'groups' field to the Conv2d() definition in 3bfbab7. Is this sufficient to run CSPResNeXt50-PANet-SPP? @LukeAI can you try it? (Lines 22 to 34 in 3bfbab7)
yolov3-spp.cfg has 17 unique fields in its cfg:
csresnext50-panet-spp.cfg has 18 unique fields. It seems groups is the only newcomer. OK, so this repo should now fully support csresnext50-panet-spp.cfg @LukeAI.
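For reference, here is a minimal sketch of passing a cfg 'groups' field through to PyTorch. The mdef dict, the in_channels value, and the sizes are hypothetical stand-ins, not the repo's parser output; nn.Conv2d already accepts a groups argument that defaults to 1.

```python
import torch
import torch.nn as nn

# Hypothetical parsed [convolutional] block; 'groups' is the new cfg field.
mdef = {'filters': 128, 'size': 3, 'stride': 1, 'pad': 1, 'groups': 32}
in_channels = 128  # assumed channel count from the previous layer

conv = nn.Conv2d(in_channels=in_channels,
                 out_channels=mdef['filters'],
                 kernel_size=mdef['size'],
                 stride=mdef['stride'],
                 padding=mdef['size'] // 2 if mdef['pad'] else 0,
                 groups=mdef.get('groups', 1),  # defaults to 1 = ordinary conv
                 bias=False)

x = torch.zeros(1, in_channels, 104, 104)
y = conv(x)
print(tuple(y.shape))       # (1, 128, 104, 104)
print(conv.weight.numel())  # 128 * (128 // 32) * 3 * 3 = 4608
```

Note the weight count: a grouped conv stores out_channels * (in_channels // groups) * k * k weights, so groups=32 cuts the parameters of this layer by 32x.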
@WongKinYiu I am getting an error in 3 of the shortcut layers when running csresnext50-panet-spp.cfg. They are trying to add tensors with different channel counts:

# models.py line 260:
elif mtype == 'shortcut':
    try:
        x = x + layer_outputs[int(mdef['from'])]
    except:  # shapes do not match: channel counts differ
        print(i, x.shape, layer_outputs[int(mdef['from'])].shape)
        x = layer_outputs[int(mdef['from'])]

# excepted layers:
# 8  torch.Size([1, 128, 104, 104]) torch.Size([1, 64, 104, 104])
# 12 torch.Size([1, 128, 104, 104]) torch.Size([1, 64, 104, 104])
# 16 torch.Size([1, 128, 104, 104]) torch.Size([1, 64, 104, 104])

Possible fix: change …
@glenn-jocher Hello, in PyTorch, zero-pad to the same size, then add:

if residual_channel != shortcut_channel:
    padding = torch.autograd.Variable(torch.cuda.FloatTensor(batch_size, residual_channel - shortcut_channel, featuremap_size[0], featuremap_size[1]).fill_(0))
    out += torch.cat((shortcut, padding), 1)
else:
    out += shortcut
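The same zero-padding idea can be sketched with F.pad in current PyTorch. This is a sketch of the suggestion above, not the repo's code; shortcut_add is a hypothetical helper.

```python
import torch
import torch.nn.functional as F

def shortcut_add(x, prev):
    """Zero-pad the narrower tensor's channel dimension before adding, so
    shortcut layers with mismatched channel counts can still be summed."""
    c_x, c_prev = x.shape[1], prev.shape[1]
    if c_x != c_prev:
        pad = abs(c_x - c_prev)
        # F.pad's 6-tuple pads (W_left, W_right, H_top, H_bottom, C_front, C_back)
        if c_x < c_prev:
            x = F.pad(x, (0, 0, 0, 0, 0, pad))
        else:
            prev = F.pad(prev, (0, 0, 0, 0, 0, pad))
    return x + prev

# The three excepted layers above add [1, 128, 104, 104] to [1, 64, 104, 104]:
a = torch.ones(1, 128, 104, 104)
b = torch.ones(1, 64, 104, 104)
out = shortcut_add(a, b)
print(tuple(out.shape))  # (1, 128, 104, 104)
```

The first 64 output channels hold 1 + 1 = 2, while channels 64-127 stay at 1 because the padded half of b is zero.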
For convenience, I changed line 57 to …
@WongKinYiu @LukeAI @AlexeyAB I trained csresnext50-panet-spp.cfg 86588f1 against default yolov3-spp.cfg for 27 COCO epochs at 416 (10% of full training), but got worse results at a slower speed. I ran yolov3-spp3.cfg (see #694) with slightly worse results as well. Commands to reproduce: git clone https://github.com/ultralytics/yolov3
bash yolov3/data/get_coco_dataset_gdrive.sh
cd yolov3
python3 train.py --epochs 27 --weights '' --cfg yolov3-spp.cfg --name 113
python3 train.py --epochs 27 --weights '' --cfg yolov3-spp3.cfg --name 115
python3 train.py --epochs 27 --weights '' --cfg csresnext50-panet-spp.cfg --name 121
If you guys have time and are good with PyTorch, please feel free to clone this repo and try the https://github.com/WongKinYiu/CrossStagePartialNetworks/ implementations yourself. I'd really like to exploit some of the research there, but I don't have time. We are getting excellent results with our baseline yolov3-spp.cfg trained from scratch (see the mAP table at https://github.com/ultralytics/yolov3#map), so if the improvements are relative, I assume they should help here as well.
OK, I'll try to install this repo. So none of your trainings use an ImageNet pre-trained model?
@WongKinYiu ok great! No, I don't use any pre-trained model for the initial weights. In an earlier test I found that starting from darknet53.conv.74 produced worse mAP after 273 epochs than starting from randomly initialized weights. For quick results (a day or less of training), yes, the ImageNet-trained weights will help, but for longer trainings I found they hurt. To reproduce:
git clone https://github.com/ultralytics/yolov3
bash yolov3/data/get_coco_dataset_gdrive.sh
cd yolov3
python3 train.py --epochs 273 --weights darknet53.conv.74 --cfg yolov3-spp.cfg --name 41
python3 train.py --epochs 273 --weights '' --cfg yolov3-spp.cfg --name 42
Thanks for your reply. Are your models trained using a single GPU?
@WongKinYiu yes, I typically train them on one 2080 Ti or V100, which usually does about 50 epochs per day with the default settings (5 days to train COCO). See https://github.com/ultralytics/yolov3#speed for training speeds. Multi-GPU can also be used. To get the best mAPs, though …
Should I try csresnext50c.cfg? UPDATE: I put it in, but there are new layers again :)
No, if you train from scratch, I think you will get similar results. Oh, it is because csresnext50c.cfg is for the ImageNet classifier.
Oh, haha, ok I'll leave csresnext50c.cfg alone then.
start training...
@WongKinYiu the exact training command to get to 40.9 AP with one GPU is:
python3 train.py --weights '' --epochs 273 --batch 16 --accumulate 4 --multi --pre
If you use multi-GPU you will have more memory available, so you can use a larger --batch --accumulate combination that still multiplies to 64, like 32x2 or even 64x1:
python3 train.py --weights '' --epochs 273 --batch 32 --accumulate 2 --multi --pre
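The --batch 16 --accumulate 4 combination works by accumulating gradients over 4 mini-batches before each optimizer step, giving an effective batch size of 64. A minimal sketch of that idea with a hypothetical toy model, not the ultralytics training loop:

```python
import torch
import torch.nn as nn

model = nn.Linear(10, 1)
opt = torch.optim.SGD(model.parameters(), lr=0.01)
batch, accumulate = 16, 4  # effective batch size = 16 * 4 = 64
steps = 0

opt.zero_grad()
for i in range(8):  # 8 mini-batches of 16 images -> 2 optimizer steps
    x = torch.randn(batch, 10)
    loss = model(x).mean()
    loss.backward()  # gradients accumulate in param.grad across calls
    if (i + 1) % accumulate == 0:
        opt.step()       # one weight update per 64 images
        opt.zero_grad()  # reset accumulated gradients
        steps += 1
```

This is why a larger --batch with a smaller --accumulate (32x2, 64x1) gives the same effective batch when more GPU memory is available.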
My GPU RAM is not enough even though I set …
This is weird, did you measure speed on the GPU? And what FPS/ms did you get for SPP vs CSP? Have you tried converting a CSPResNeXt50-PANet-SPP model already trained in Darknet (cfg / weights) to ultralytics (PyTorch), and did you get better mAP and better speed? Or does this inconsistency interfere with the conversion? #698 (comment)
@AlexeyAB Hello, I think the slow speed refers to training speed.
@AlexeyAB @WongKinYiu if I run test.py on the two trained models, this applies inference (and NMS) to the 5000 images in 5k.txt. This takes 138 seconds with yolov3-spp.cfg on a P4 GPU, and 139 seconds with csresnext50-panet-spp.cfg. Ah interesting, so the inference speed is nearly identical, but training takes twice as long.
So is CSPResNeXt50-PANet-SPP operational? And does it provide better results? I am looking more into it right now, and reading the article.
@WongKinYiu I forgot to mention, you should install Nvidia Apex for mixed-precision training with this repo. It increases speed and reduces memory requirements substantially. Once installed correctly you should see this:
Yes, I have installed Apex.
@glenn-jocher Hello, I would like to know why …
@WongKinYiu ahhh, this is interesting, I had not realized that. There is a memory leak when invoking train.py repeatedly, which is very obvious when running hyperparameter evolution, as train.py is called repeatedly in a for loop (#392 (comment)), but I did not realize --pre also causes this. It makes sense though, as --pre calls train.py once to train the output biases for one epoch, then calls it again for the actual training. How much extra memory is this using?
I independently found good improvements (~+3 mAP) with Gaussian-Yolo and also with cspresnext50-pan-spp vs. yolov3-spp, but I got pretty bad results when I tried combining them (-10 mAP); this may be because … @WongKinYiu have you tried Gaussian with cspresnext-pan-spp? Do you have any thoughts or results?
I think you need more iterations for warmup when combining cspresnext50-pan-spp with gaussian-yolo (I have no GPUs to test it currently). In my experiments, when combining cspresnext50-pan-spp with gaussian-yolo, the precision drops and the recall improves. And the strange thing is that the loss becomes larger after 200k epochs.
@WongKinYiu ok great! I got the last darknet model to run, but the mAPs came back as 0.0. Note that I modified my default test NMS. Also note the latest yolov3-spp.cfg baseline trains to 41.9/61.8 at 608 with the default settings. The training commands to reproduce this are here. The two separate --img-size values are the train img-size and the test img-size. Multi-scale train img-sizes using this command will be 288 - 640.
python3 train.py --data coco2014.data --img-size 416 608 --epochs 273 --batch 16 --accum 4 --weights '' --device 0 --cfg yolov3-spp.cfg --multi
Yes, I know.
Thanks, I just used the default settings of the repo which I used to train the model. As I remember, that repo gets about 40.9 mAP@0.5:0.95 in your report. By the way, all of my results are obtained on the test-dev set and your results are obtained on the minival set.
@WongKinYiu ah, the test-dev set could be a difference too then! Well, it seems some differences remain, as the ultralytics repo can't load the best-performing darknet CSPDarknet53s-PANet-SPP model. These differences must be the source of the problem, I think.
What is the difference between your training and this yolov3-spp.cfg https://github.com/WongKinYiu/CrossStagePartialNetworks/tree/pytorch#ms-coco ?
I used this repo to train: https://github.com/ultralytics/yolov3/tree/a6f87a28e7595e71752583fb41340f9d1105d75f
@WongKinYiu @glenn-jocher So, I want to know what improvements have been made?
Hmmm, well, lots of small day-to-day changes. If I use the GitHub /compare it doesn't show the date of that commit, but it shows that there have been 400 commits since then, with many modifications. The README from then was showing 40.0/60.9 mAP, which is similar to what @WongKinYiu was seeing, vs today's README, which shows 41.9/61.8. The improvements span many different parts: the NMS, which now uses multi-label; the augmentation, which has been set to zero; the loss function reduction, which I returned to mean() instead of sum(); the cosine scheduler implementation; the increase in the LR to 0.01 after cos was implemented; and maybe a few other tiny things. The architecture itself is the same (yolov3-spp.cfg).

Actually, this is an important point. A lot of papers today are showing very outdated comparisons to YOLOv3, i.e. showing 33 mAP@0.5:0.95, like the EfficientDet paper, with a GPU latency of 51 ms. The reality is that the most recent YOLOv3-SPP model I trained is at 42.1 mAP@0.5:0.95 with a GPU latency of 12.8 ms (#679 (comment)), which puts it far better than their own D0-D2 models in both speed and mAP. I'm not sure how best to get that message out.
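As a sketch of the multi-label NMS change mentioned above: instead of keeping only the argmax class per box, every (box, class) pair above the confidence threshold is kept as a candidate before NMS. The scores below are illustrative, not the repo's implementation.

```python
import torch

conf_thres = 0.3
# per-box class scores: 2 boxes x 3 classes (hypothetical values)
scores = torch.tensor([[0.6, 0.5, 0.1],   # box 0: two classes pass the threshold
                       [0.2, 0.1, 0.9]])  # box 1: one class passes

# multi-label: every (box, class) pair above conf_thres becomes a candidate
box_idx, cls_idx = (scores > conf_thres).nonzero(as_tuple=True)
print(box_idx.tolist(), cls_idx.tolist())  # [0, 0, 1] [0, 1, 2]
```

With single-label (argmax) selection, box 0 would contribute only class 0; multi-label also keeps the 0.5-confidence class 1 detection, which is where the roughly +0.3 mAP comes from on overlapping categories.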
@glenn-jocher
Yes, NMS uses multi-label now, which bumped up mAP about +0.3. Yes, spatial augmentation seemed to hurt training, so I set it to zero, but left HSV augmentation on: 'hsv_h': 0.0138, # image HSV-Hue augmentation (fraction)
'hsv_s': 0.678, # image HSV-Saturation augmentation (fraction)
'hsv_v': 0.36, # image HSV-Value augmentation (fraction)
'degrees': 1.98 * 0, # image rotation (+/- deg)
'translate': 0.05 * 0, # image translation (+/- fraction)
'scale': 0.05 * 0, # image scale (+/- gain)
'shear': 0.641 * 0} # image shear (+/- deg)
I'm really hoping we might be able to merge the YOLO outputs some day so I can do away with this uncertainty in how to combine the losses from the different layers. ASFF seems to be an interesting step in that direction.
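A sketch of how gains from the HSV hyperparameters above might be sampled per image. The sampling helper here is hypothetical; the actual repo applies the gains to the HSV-converted image via OpenCV.

```python
import random

# HSV gain fractions from the hyperparameter dict above
hyp = {'hsv_h': 0.0138, 'hsv_s': 0.678, 'hsv_v': 0.36}

random.seed(0)
# hypothetical per-image sampling: a multiplicative gain in [1 - v, 1 + v]
# for each of the H, S, V channels
gains = {k: 1 + random.uniform(-1, 1) * v for k, v in hyp.items()}
print(sorted(gains))  # ['hsv_h', 'hsv_s', 'hsv_v']
```

The small hsv_h fraction (0.0138) keeps hue nearly fixed while saturation and value vary much more, which matches the intent of the values in the dict.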
@AlexeyAB ah, another change I forgot to mention: I changed multi-scale to change the resolution every batch, instead of every 10 batches as before. This seemed to smooth the results a bit, epoch to epoch.
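Per-batch multi-scale selection can be sketched as drawing a random image size each batch from the allowed range, constrained to multiples of the grid stride. This is a sketch of the idea, not the repo's code.

```python
import random

gs = 32  # grid stride: img sizes must be multiples of the network's max stride
sizes = list(range(288, 640 + 1, gs))  # 288, 320, ..., 640

random.seed(0)
img_size = random.choice(sizes)  # re-drawn every batch rather than every 10
print(len(sizes))  # 12 candidate sizes
```

Drawing a new size every batch mixes scales far more finely within an epoch, which is the smoothing effect described above.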
@WongKinYiu yes, they look super similar to each other, unfortunately. I'm not sure why we aren't seeing the same gains as the darknet training. It must have to do with the grouped convolutions, I think.
Does it currently work in such a way?
Then it will remove class1_prob = 0.5 and class2_prob = 0.5, and will leave:
Do you know how this changes the …
Yes, it may help to win competitions, but it may hurt cross-domain accuracy when the test images/videos are not similar to MS COCO. It seems it works well because Ultralytics uses letterbox image resizing by default, so it keeps the aspect ratio and doesn't require large spatial image transformations.
What do you mean?
Does it decrease training speed, because changing the network size requires time? If we use dynamic_minibatch=1 in Darknet, when we change …
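A sketch of the letterbox-resizing arithmetic mentioned above: the long side is scaled to the target size, the aspect ratio is preserved, and the short side is padded to fill the square canvas. letterbox_shape is a hypothetical helper returning shapes and pads only, not the repo's implementation.

```python
def letterbox_shape(h, w, new=416):
    """Return the resized (h, w) and the (pad_h, pad_w) needed to reach a
    square new x new canvas while preserving aspect ratio."""
    r = new / max(h, w)              # scale so the long side equals `new`
    nh, nw = round(h * r), round(w * r)
    return (nh, nw), (new - nh, new - nw)

# a 720x1280 frame letterboxed to 416x416:
print(letterbox_shape(720, 1280))  # ((234, 416), (182, 0))
```

Because only zero padding is added, no object is stretched, which is why letterboxing avoids the large spatial transformations discussed above.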
Have you checked whether EfficientNetB0-Yolo was added to the OpenCV-dnn module? So it only requires implementing …
I have only done experiments for … Have you tested the inference speed of enetb0-yolo using OpenCV-dnn?
Not yet. I will test it on an Intel CPU and an Intel Myriad X neurochip.
@AlexeyAB @WongKinYiu I made a simple Colab notebook to see the timing effects of group/mix convolutions. It times a tensor passing forward and backward (to mimic training) through a Conv2d() op. The speeds stay about the same even as the parameter count drops by >10X, so similar-sized models using these ops may be much slower.
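A minimal version of the timing experiment described above (a sketch, not the actual Colab notebook): parameter count falls roughly linearly with `groups`, while wall time often does not, especially on hardware without optimized grouped-conv kernels.

```python
import time

import torch
import torch.nn as nn

def time_conv(groups, reps=5):
    """Time forward+backward through a Conv2d with the given group count."""
    conv = nn.Conv2d(64, 64, kernel_size=3, padding=1, groups=groups, bias=False)
    x = torch.randn(2, 64, 56, 56)
    t0 = time.time()
    for _ in range(reps):
        conv(x).sum().backward()  # backward pass mimics training
    return conv.weight.numel(), time.time() - t0

params1, t1 = time_conv(1)     # ordinary convolution
params32, t32 = time_conv(32)  # grouped: 32x fewer weights
print(params1, params32)       # 36864 1152
```

Comparing t1 and t32 on a given machine shows whether the 32x parameter reduction actually buys any speed there; often it does not, which matches the Colab observation.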
@glenn-jocher …
Hi @AlexeyAB, I ran the speed test of this network on an Intel CPU. It looks like it is almost 5 times slower than the Tiny YOLOv3 PRN network on CPU as well. Below are the results. OpenCV: 3.4.10-pre (https://github.com/opencv/opencv/tree/377dd04224630e835cce8c7d67e651cae73fd3b3) It looks like depthwise convolutions are slow on the CPU as well. Any thoughts? Thanks
@glenn-jocher @mmaaz60 Take a look at the comparison: AlexeyAB/darknet#5079
@glenn-jocher Hi, did you successfully train the ASFF model?
@AlexeyAB yes, I trained ASFF on COCO (results99 in orange), but got slightly worse results in the end compared to the default (blue). Performance in the first 5% of epochs was much better, probably because the summation of outputs reduced a lot of the early noise in the model, but it did not help after that point. Of course, my implementation might be wrong!
@glenn-jocher …
@AlexeyAB yes, basically. I created a 12-anchor version of yolov3-spp.cfg called yolov4.cfg (4 anchors per yolo layer), which I used for my comparison (this 12-anchor model increases mAP a tiny bit, about +0.1). I compared yolov4.cfg against yolov4-asff.cfg. For ASFF I moved all of the yolo layers to the end, and added 3 features to the existing feature vector of length 340 to create the ASFF weights, so the input to each yolo layer is the same. I split the feature vectors into the traditional size …, etc., using this extra ASFF code. I used sigmoid weights since softmax was much slower, and did a linear interpolation for the resizing.

if ASFF:
    i, n = self.index, self.nl  # index in layers, number of layers
    p = out[self.layers[i]]
    bs, _, ny, nx = p.shape  # bs, 255, 13, 13
    if (self.nx, self.ny) != (nx, ny):
        create_grids(self, img_size, (nx, ny), p.device, p.dtype)

    # outputs and weights
    # w = F.softmax(p[:, -n:], 1)  # normalized weights
    w = torch.sigmoid(p[:, -n:]) * (2 / n)  # sigmoid weights (faster)
    # w = w / w.sum(1).unsqueeze(1)  # normalize across layer dimension

    # weighted ASFF sum
    p = out[self.layers[i]][:, :-n] * w[:, i:i + 1]
    for j in range(n):
        if j != i:
            p += w[:, j:j + 1] * \
                F.interpolate(out[self.layers[j]][:, :-n], size=[ny, nx], mode='bilinear', align_corners=False)

Training was multi-scale 288-640, with metrics plotted at 608 img-size. So no, so far I haven't been able to increase accuracy with BiFPN or ASFF. The only thing that improved a tiny bit was weighted feature fusion, but the gain was tiny (0.1 mAP).
@glenn-jocher So do you get AP50...95 higher than the 40.6 - 42.4% of ASFF at 608x608? https://github.com/ruinmessi/ASFF#coco It seems that ASFF+RFB or multi-block BiFPN should use a higher network resolution for higher accuracy.
@AlexeyAB no, I actually saw worse results for my ASFF implementation, about -0.5 mAP at 608 vs the default yolov4.cfg. A higher image size is definitely one of the ingredients in higher mAPs. EfficientDet uses 512@D0, 640@D1, all the way to 1280@D7. The official ASFF trains at 320-608 for 42.4@608 and at 480-800 for 43.9@800. https://github.com/ruinmessi/ASFF#models
Does this repo. support CSPResNeXt50-PANet-SPP? (https://github.com/WongKinYiu/CrossStagePartialNetworks/)
AlexeyAB's support: AlexeyAB/darknet#4406
My tests have found it to be a clear winner over yolov3-spp in terms of mAP and speed.