
Implemented weighted-multi_input-[shortcut] layer with weights-normalization #4662

Open
AlexeyAB opened this issue Jan 9, 2020 · 91 comments
Labels
ToDo RoadMap

Comments

@AlexeyAB
Owner

AlexeyAB commented Jan 9, 2020

Implemented weighted-multi_input-[shortcut] layer with weights-normalization, added:

The new [shortcut] layer can:

  • take more than 2 input layers for addition: from = -2, -3 (and -1 by default)

  • multiply the inputs by weights:

    • one weight per input layer: per_feature (per_layer)
    • or one weight per input layer and channel: per_channel
  • normalize the weights by using avg_relu or softmax

The simplest example: yolov3-tiny_3l_shortcut_multilayer_per_feature_softmax.cfg.txt


[shortcut]
from= -2, -3, 6            # means: -1, -2, -3, 6 - relative/absolute indexes of input layers
weights_type= per_feature  # none (default), per_feature, per_channel
weights_normalizion= relu  # none (default), relu, softmax
activation= linear         # linear (default), leaky, relu, logistic, swish, mish, ...

  • Original residual-connection:
[shortcut]
from= -5           # means: -1, -5 - relative indexes of input layers
activation= linear
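To make the weighted variant concrete, here is a minimal C sketch of what a weighted [shortcut] with weights_type=per_feature and weights_normalizion=relu computes. It is an illustration, not darknet's forward code: it assumes activation=linear, batch size 1, and inputs of identical shape stored as contiguous float arrays, and the function name weighted_shortcut_per_feature is hypothetical. Each input layer gets one scalar weight; the ReLU'd weights are divided by their sum plus a small eps (0.0001, the value used in the network.c code quoted later in this thread), and the inputs are added with those weights.

```c
#include <assert.h>
#include <stddef.h>

static float relu_f(float x) { return x > 0.0f ? x : 0.0f; }

/* Hypothetical sketch (not darknet's code): fuse n_inputs feature maps of
   `size` floats each, using one scalar weight per input layer
   (weights_type=per_feature) normalized as
   w_i' = relu(w_i) / (eps + sum_j relu(w_j)). */
void weighted_shortcut_per_feature(const float *const *inputs, int n_inputs,
                                   const float *weights, size_t size,
                                   float *out)
{
    assert(n_inputs > 0);
    const float eps = 0.0001f;              /* eps as in the fuse code */
    float sum = eps;
    for (int i = 0; i < n_inputs; ++i) sum += relu_f(weights[i]);

    for (size_t k = 0; k < size; ++k) out[k] = 0.0f;
    for (int i = 0; i < n_inputs; ++i) {
        const float w = relu_f(weights[i]) / sum;   /* normalized weight */
        for (size_t k = 0; k < size; ++k) out[k] += w * inputs[i][k];
    }
}
```

With weights {3, 1}, the first input contributes about 75% and the second about 25%. If every weight is negative, every normalized weight becomes 0 and the output is all zeros, which is the failure mode discussed later in this thread.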

[Figures from https://arxiv.org/abs/1911.09070v1]
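For reference, the two fusion rules from the EfficientDet paper, written with input feature maps $I_i$ and learnable weights $w_i$. Reading weights_normalizion=relu as the first rule and weights_normalizion=softmax as the second is our interpretation of the darknet code in this thread, not something the paper states.

```latex
% Fast normalized fusion (w_i >= 0 is enforced by applying ReLU to each weight)
O = \sum_i \frac{w_i}{\epsilon + \sum_j w_j} \, I_i
% Softmax-based fusion
O = \sum_i \frac{e^{w_i}}{\sum_j e^{w_j}} \, I_i
```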

@Kyuuki93

Kyuuki93 commented Jan 10, 2020

@AlexeyAB @WongKinYiu I made a network with 3 BiFPN blocks over P3-P5; please take a look:
darknet53-bifpn3.cfg.txt
but even with the input size set to 320x320 and subdivisions=32, I get this error:

CUDA Error Prev: an illegal memory access was encountered
CUDA Error Prev: an illegal memory access was encountered: Resource temporarily unavailable
darknet: ./src/utils.c:297: error: Assertion `0` failed.
Aborted (core dumped)

@AlexeyAB
Owner Author

@Kyuuki93 Try to use the latest commit. I set

batch=64
subdivisions=16
width=320
height=320

and trained your cfg-file for 100 iterations successfully.

@AlexeyAB
Owner Author

@Kyuuki93

There is a bug in your cfg:
use weights_type=per_feature instead of weights_type=per_feture

@AlexeyAB
Owner Author

@Kyuuki93

Also, you should not specify the -1 layer, since it is included by default in every [shortcut] layer.

Just use

[shortcut]
from=61
weights_type=per_feature
weights_normalizion=relu
activation=leaky

instead of

[shortcut]
from=-1,61
weights_type=per_feature
weights_normalizion=relu
activation=leaky

@Kyuuki93

Kyuuki93 commented Jan 11, 2020

@AlexeyAB The CUDA error was caused by weights_type=per_feture, my mistake.

The modified .cfg is here: darknet53-bifpn3.cfg.txt. It looks like this:
[Figure: darknet53-bifpn3-P3-5]

@Kyuuki93

Kyuuki93 commented Jan 11, 2020

Comparison on my dataset: all cases use the same training settings and MS COCO pre-trained detector weights, and differ only in the backbone.

Model                    mAP@0.5   mAP@…    precision(.7)  recall(.7)  inference time (416x416)
yolov3                   91.79%    63.09%   0.95           0.71        13.25 ms
csresnext50-panet        92.80%    64.16%   0.96           0.67        15.61 ms
darknet53-bifpn3(P3-5)   91.74%    63.48%   0.95           0.71        15.25 ms

All networks have an SPP layer, and inference time was measured on an RTX 2080 Ti. The BiFPN blocks use

weights_type=per_feature
weights_normalizion=relu

darknet53-bifpn3-spp performed very similarly to yolov3-spp. Possible reasons:

  • the capacity of BiFPN*3 (P3-5) is very close to that of the FPN in yolov3-spp
  • my dataset is too small to show a difference between the two FPNs

Possible next steps:

  • darknet53-bifpn*N(P3-5)
  • darknet53-bifpn*N(P2-5)
  • darknet53-bifpn*N(P3-7)
  • csresnext50-panet-bifpn*N(Px-Px)

but my machines are currently occupied by other tasks, so I will run further experiments when the GPUs are free.

@AlexeyAB
Owner Author

@Kyuuki93
Yes:

  • try using more than 3 pyramid levels (P)
  • also try using 3-5 BiFPN blocks

since BiFPN is not very expensive.

@AlexeyAB
Owner Author

@Kyuuki93 @WongKinYiu I just fixed weights_normalizion=softmax for the [shortcut] layer: 14172d4#diff-0c461530f46c81f7013a6eaec297ebcfR135

weights_normalizion=relu remains the same as earlier.

@WongKinYiu
Collaborator

@AlexeyAB

Starting to re-train the ASFF models now.

@AlexeyAB
Owner Author

AlexeyAB commented Jan 14, 2020

@WongKinYiu
The [shortcut] fix doesn't affect ASFF.
And [shortcut] weights_normalizion=softmax doesn't even affect the default BiFPN:

  • default ASFF uses [scale_channels] scale_wh=1
  • default BiFPN uses [shortcut] weights_normalizion=relu

@AlexeyAB
Owner Author

@Kyuuki93 @WongKinYiu
Also, I don't know which BiFPN (weights_normalizion=softmax) variant will give better results: with max_val=0 or without it, between these 2 lines in these 2 places:

So you can test both cases on two small datasets.

@AlexeyAB
Owner Author

@Kyuuki93 Hi,

Have you made any progress with BiFPN and BiFPN+ASFF?

@Kyuuki93

Have you made any progress with BiFPN and BiFPN+ASFF?

Sorry, I can't work these days because of the New Year and the new virus in China...

@glenn-jocher

@Kyuuki93 @AlexeyAB I'm interested in implementing a BiFPN head on top of darknet in https://github.com/ultralytics/yolov3 using https://github.com/AlexeyAB/darknet/files/4048909/darknet53-bifpn3.cfg.txt and benchmarking on COCO.

The EfficientDet paper https://arxiv.org/pdf/1911.09070.pdf mentions 3 summation methods: "scalar (per-feature), a vector (per-channel), or a multi-dimensional tensor (per-pixel)". It seems they select scalar/per-feature for their implementation. Then I assume that, to add multiple 4D tensors, say two of shape 16x256x13x13, we would have 2 scalar weights if done 'per_feature', and 2 vector weights of shape 1x256x1x1 if done 'per_channel'?

Also, I'm surprised softmax on the weights imparts such a slowdown (1.3X) in their paper, have you guys also observed this?

@AlexeyAB
Owner Author

@glenn-jocher

Then I assume that, to add multiple 4D tensors, say two of shape 16x256x13x13, we would have 2 scalar weights if done 'per_feature', and 2 vector weights of shape 1x256x1x1 if done 'per_channel'?

Yes.

  1. per_feature (per input layer) - 1 float value for each input layer
  2. per_channel - 1 float value for each channel of each input layer
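A sketch of the per_channel case, assuming the weight layout that the fuse code in network.c (quoted in this thread) indexes with chan + i * layer_step: weights[chan + i * channels] is the weight of channel chan of input i, so two inputs with 256 channels need 2 * 256 = 512 weights. Further assumptions: batch size 1, CHW layout, weights already normalized. The function name is hypothetical.

```c
#include <assert.h>

/* Hypothetical sketch of a per_channel weighted shortcut.
   inputs[i] is a CHW map with `channels` channels of `spatial` (= h*w)
   floats each; weights[c + i * channels] scales channel c of input i. */
void weighted_shortcut_per_channel(const float *const *inputs, int n_inputs,
                                   const float *weights, int channels,
                                   int spatial, float *out)
{
    assert(n_inputs > 0 && channels > 0 && spatial > 0);
    for (int c = 0; c < channels; ++c) {
        for (int s = 0; s < spatial; ++s) {
            float acc = 0.0f;
            for (int i = 0; i < n_inputs; ++i) {
                const float w = weights[c + i * channels]; /* per (input, channel) */
                acc += w * inputs[i][c * spatial + s];
            }
            out[c * spatial + s] = acc;
        }
    }
}
```

Setting every channel's weights to the same values reduces this to the per_feature case.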

Also, I'm surprised softmax on the weights imparts such a slowdown (1.3X) in their paper, have you guys also observed this?

Only for training, and only for the weighted [shortcut] layer.
As a result:

  • training: ~1-10% slower
  • detection: 0% slower (since normalization can be done once during initialization)

    darknet/src/network.c

    Lines 1112 to 1162 in 653eceb

    else if (l->type == SHORTCUT && l->weights && l->weights_normalizion)
    {
        // debug output: print the raw (not yet normalized) weights
        if (l->nweights > 0) {
            //cuda_pull_array(l.weights_gpu, l.weights, l.nweights);
            for (int i = 0; i < l->nweights; ++i) printf(" w = %f,", l->weights[i]);
            printf(" l->nweights = %d \n", l->nweights);
        }
        // nweights - l.n or l.n*l.c or (l.n*l.c*l.h*l.w)
        const int layer_step = l->nweights / (l->n + 1); // 1 or l.c or (l.c * l.h * l.w)
        int chan, i;
        // the weight of input i at position chan is stored at chan + i * layer_step
        for (chan = 0; chan < layer_step; ++chan)
        {
            float sum = 1, max_val = -FLT_MAX;
            // softmax: find the largest weight so that expf() cannot overflow
            if (l->weights_normalizion == SOFTMAX_NORMALIZATION) {
                for (i = 0; i < (l->n + 1); ++i) {
                    int w_index = chan + i * layer_step;
                    float w = l->weights[w_index];
                    if (max_val < w) max_val = w;
                }
            }
            const float eps = 0.0001;
            sum = eps;
            // denominator: eps + sum of relu(w) or of exp(w - max_val)
            for (i = 0; i < (l->n + 1); ++i) {
                int w_index = chan + i * layer_step;
                float w = l->weights[w_index];
                if (l->weights_normalizion == RELU_NORMALIZATION) sum += relu(w);
                else if (l->weights_normalizion == SOFTMAX_NORMALIZATION) sum += expf(w - max_val);
            }
            // overwrite each weight with its normalized value
            for (i = 0; i < (l->n + 1); ++i) {
                int w_index = chan + i * layer_step;
                float w = l->weights[w_index];
                if (l->weights_normalizion == RELU_NORMALIZATION) w = relu(w) / sum;
                else if (l->weights_normalizion == SOFTMAX_NORMALIZATION) w = expf(w - max_val) / sum;
                l->weights[w_index] = w;
            }
        }
        // weights are now normalized, so the forward pass skips normalization
        l->weights_normalizion = NO_NORMALIZATION;
    #ifdef GPU
        if (gpu_index >= 0) {
            push_shortcut_layer(*l);  // upload the normalized weights to the GPU
        }
    #endif
    }

Did you write a paper on Mosaic data augmentation?

@glenn-jocher

glenn-jocher commented Feb 17, 2020

@AlexeyAB ah I see. Actually no, I haven't had time to write a mosaic paper. It's too bad, because the results are pretty clear that it helps significantly.

Another thing I was wondering about: in the BiFPN cfg, only the head uses weighted shortcuts, while the darknet53 backbone does not. If it helps the head, then it may help the backbone as well, no? Though to keep expectations in check, the EfficientDet paper only shows a pretty small +0.45 mAP bump from a weighted vs non-weighted head.

One difference though is that currently the regular non-weighted shortcut layers effectively have weights=1 that sum to 2, 3 etc depending on the number of inputs. Isn't this a bit strange then that the head must be constrained to sum the weights to 1?


@AlexeyAB
Owner Author

AlexeyAB commented Feb 17, 2020

@glenn-jocher

If it helps the head then it may help the backbone as well no?

I have the same thoughts.
We can replace each shortcut with a weighted shortcut. Maybe it will give another +0.5 AP.


Isn't this a bit strange then that the head must be constrained to sum the weights to 1?

Batch normalization tries to do something similar: it moves most values into a normalized range (zero mean and unit variance, before its learned scale and shift), which increases mAP, speeds up training, and makes training more stable.
https://medium.com/@ilango100/batch-normalization-speed-up-neural-network-training-245e39a62f85

@glenn-jocher

@AlexeyAB ah of course, the BN after the Conv2d() following a shortcut will do this automatically, it completely slipped my mind. Good example. Ok, then I've taken @Kyuuki93's cfg, touched it up a bit (fixed the 'normalization' typo, reverted it to yolov3-spp.cfg default anchors and 80-class configuration, and implemented weighted shortcuts for all shortcut layers.)

I will experiment with a few different weighting techniques using my typical 27-epoch COCO training and will hopefully post the results later this week. I suppose the tests should be:

  1. yolov3-spp.cfg results (default)
  2. yolov3-spp.cfg (all shortcuts = weighted shortcuts)
  3. darknet53-bifpn3.cfg (all shortcuts in backbone and head = weighted shortcuts)

@AlexeyAB
Owner Author

AlexeyAB commented Feb 17, 2020

@glenn-jocher Also try to use this cfg-file that is made more similar to https://github.com/xuannianz/EfficientDet/blob/ccc795781fa173b32a6785765c8a7105ba702d0b/model.py

  1. csresnext50-bifpn-optimal.cfg.txt
    (or try to use BiFPN-head from this cfg-file with darknet53 backbone)

For a fair comparison, set all parameters in [net] and [yolo] to the same values in all 4 models.

@glenn-jocher

@AlexeyAB thanks for the cfg! I've been having some problems with the route layers on these csresnext50 cfgs, maybe I can just copy the BiFPN part and attach it to darknet53, ah but the shortcut layer numbers will be different... ok maybe I'll just try and use it directly. Let's see...

@AlexeyAB
Owner Author

@glenn-jocher Yes, you should change the route layers and the first from= value in the weighted [shortcut] layers. Tomorrow I will attach csdarknet53-bifpn-optimal.cfg.

@WongKinYiu
Collaborator

@AlexeyAB @glenn-jocher

I will get some free GPUs in about 2-4 days.
If you want to try weighted shortcuts in the backbone, I can train it on ImageNet.

@AlexeyAB
Owner Author

AlexeyAB commented Feb 18, 2020

@WongKinYiu @glenn-jocher

Yes, that would be nice.

Try to train these classifiers with weighted-shortcut layers:

  1. csdarknet53-ws.cfg.txt (relu)

  2. csresnext50-ws.cfg.txt (relu)

  3. csresnext50-ws-mi2.cfg.txt (softmax) - preferably for training (multi-input weighted shortcut-layers with softmax normalization)

These are similar to https://github.com/WongKinYiu/CrossStagePartialNetworks/blob/master/imagenet/results.md with mosaic, cutmix, mish, and label smoothing.


Try to train these 2 models at 416x416 with my BiFPN module, which is closer to the reference implementation: https://github.com/xuannianz/EfficientDet/blob/ccc795781fa173b32a6785765c8a7105ba702d0b/model.py

  1. csdarknet53-bifpn-optimal.cfg.txt

  2. csresnext50-bifpn-optimal.cfg.txt

For comparison with: ultralytics/yolov3#698 (comment)


Model                     Size     fps  AP    AP50  AP75  APS   APM   APL   cfg  weight
PANet-(SPP, CIoU)         512×512  44   42.4  64.4  45.9  23.2  45.5  55.3  -    -
PANet-(SPP, RFB, CIoU)    512×512  -    41.8  62.7  45.1  22.7  44.3  55.0  -    -

@WongKinYiu
Collaborator

@AlexeyAB
Owner Author

@WongKinYiu I think we can combine csresnext50-ws-mi2 + csresnext50sub and train with the new repo.

@AlexeyAB
Owner Author

AlexeyAB commented Mar 20, 2020

@WongKinYiu

csresnext50-ws-mi2.cfg.txt: 79.9 top-1, 95.3 top-5.

Some of the weights are negative, but I just checked: this model uses softmax normalization, so negative values are not a problem for it.


@AlexeyAB
Owner Author

@WongKinYiu

I created 2 new models:

  1. 40 FPS (618x618) csresnext50sub + csresnext50-ws-mi2 - csresnext50sub-mi2.cfg.txt

  2. 34 FPS (618x618) the same, just it uses route+maxout instead of weighted-shortcut csresnext50sub-mo.cfg.txt

@WongKinYiu
Collaborator

csdarknet53-ws: 65.6 top-1, 87.3 top-5.

@WongKinYiu
Collaborator

WongKinYiu commented Mar 21, 2020

@AlexeyAB

I created 2 new models:

  1. 40 FPS (618x618) csresnext50sub + csresnext50-ws-mi2 - csresnext50sub-mi2.cfg.txt
  2. 34 FPS (618x618) the same, just it uses route+maxout instead of weighted-shortcut csresnext50sub-mo.cfg.txt

Do you mean train with width=608 and height=608?

@AlexeyAB
Owner Author

@WongKinYiu No, train as usual at 256x256.

608x608 is just to estimate the final speed of the detector's backbone.


PS:
look at #5079

@WongKinYiu
Collaborator

OK, starting training.

@WongKinYiu
Collaborator

WongKinYiu commented Mar 26, 2020

@AlexeyAB Hello,

I used the 27 Feb repo to train the BiFPN model: 10a5861
but if I use the latest code to test the performance, it gives totally wrong results.
Is it because the new repo uses leaky relu instead of relu? d11caf4
Also, the inference speed of BiFPN seems to have become very slow in the latest repo.

update: csresnext50-bifpngamma: 512x512, 36.8/58.2/39.4.

@AlexeyAB
Owner Author

AlexeyAB commented Mar 26, 2020

@WongKinYiu Hi,

It seems so, yes: this is the reason if there are negative or very low weights. Temporarily set relu instead of lrelu for testing.
(Also, I found that I missed changing relu to lrelu in these places: d6181c6 and 2614a23.)

update: csresnext50-bifpngamma: 512x512, 36.8/58.2/39.4.

Can you share the cfg/weights? I will check the negative weights and the speed.


Also, try comparing the training time of lrelu and softmax: if the training speed is about the same, it may be better to use softmax, since during inference (detection) weight normalization is applied only once, at the initialization step:

darknet/src/network.c

Lines 1101 to 1194 in 2614a23

void fuse_conv_batchnorm(network net)
{
    int j;
    for (j = 0; j < net.n; ++j) {
        layer *l = &net.layers[j];
        if (l->type == CONVOLUTIONAL) {
            //printf(" Merges Convolutional-%d and batch_norm \n", j);
            if (l->share_layer != NULL) {
                l->batch_normalize = 0;
            }
            if (l->batch_normalize) {
                int f;
                // fold BN into the conv: b' = b - scale*mean/sqrt(var+eps), w' = w*scale/sqrt(var+eps)
                for (f = 0; f < l->n; ++f)
                {
                    l->biases[f] = l->biases[f] - (double)l->scales[f] * l->rolling_mean[f] / (sqrt((double)l->rolling_variance[f] + .00001));
                    const size_t filter_size = l->size*l->size*l->c / l->groups;
                    int i;
                    for (i = 0; i < filter_size; ++i) {
                        int w_index = f*filter_size + i;
                        l->weights[w_index] = (double)l->weights[w_index] * l->scales[f] / (sqrt((double)l->rolling_variance[f] + .00001));
                    }
                }
                free_convolutional_batchnorm(l);
                l->batch_normalize = 0;
#ifdef GPU
                if (gpu_index >= 0) {
                    push_convolutional_layer(*l);
                }
#endif
            }
        }
        else if (l->type == SHORTCUT && l->weights && l->weights_normalizion)
        {
            if (l->nweights > 0) {
                //cuda_pull_array(l.weights_gpu, l.weights, l.nweights);
                int i;
                for (i = 0; i < l->nweights; ++i) printf(" w = %f,", l->weights[i]);
                printf(" l->nweights = %d, j = %d \n", l->nweights, j);
            }
            // nweights - l.n or l.n*l.c or (l.n*l.c*l.h*l.w)
            const int layer_step = l->nweights / (l->n + 1); // 1 or l.c or (l.c * l.h * l.w)
            int chan, i;
            for (chan = 0; chan < layer_step; ++chan)
            {
                float sum = 1, max_val = -FLT_MAX;
                if (l->weights_normalizion == SOFTMAX_NORMALIZATION) {
                    for (i = 0; i < (l->n + 1); ++i) {
                        int w_index = chan + i * layer_step;
                        float w = l->weights[w_index];
                        if (max_val < w) max_val = w;
                    }
                }
                const float eps = 0.0001;
                sum = eps;
                for (i = 0; i < (l->n + 1); ++i) {
                    int w_index = chan + i * layer_step;
                    float w = l->weights[w_index];
                    if (l->weights_normalizion == RELU_NORMALIZATION) sum += lrelu(w);
                    else if (l->weights_normalizion == SOFTMAX_NORMALIZATION) sum += expf(w - max_val);
                }
                for (i = 0; i < (l->n + 1); ++i) {
                    int w_index = chan + i * layer_step;
                    float w = l->weights[w_index];
                    if (l->weights_normalizion == RELU_NORMALIZATION) w = lrelu(w) / sum;
                    else if (l->weights_normalizion == SOFTMAX_NORMALIZATION) w = expf(w - max_val) / sum;
                    l->weights[w_index] = w;
                }
            }
            // normalized once at load time; detection then uses the weights as-is
            l->weights_normalizion = NO_NORMALIZATION;
#ifdef GPU
            if (gpu_index >= 0) {
                push_shortcut_layer(*l);
            }
#endif
        }
        else {
            //printf(" Fusion skip layer type: %d \n", l->type);
        }
    }
}

So detection has the same speed with no normalization, relu, or softmax.

For example, models csresnext50-ws-mi2.cfg.txt and csresnext50sub-mi2.cfg.txt from this topic use softmax weight normalization.

@WongKinYiu
Collaborator

@AlexeyAB

cfg
weights

@AlexeyAB
Owner Author

@WongKinYiu

Yes, the BiFPN module in this weights file is totally broken, since most of the weights are negative, i.e. they don't pass any information through the weighted [shortcut] layers.
But the network still works because there are pass-through [route] layers before the [yolo] layers.

Ways to solve this problem:

  • for weights_normalizion=relu, use the latest commit, which uses lrelu instead of relu and constrains deltas to [-1; +1]; you should also add burnin_update=2 to each weighted [shortcut] layer in the cfg-file

  • or just use weights_normalizion=softmax

[Image: predictions]

@WongKinYiu
Collaborator

WongKinYiu commented Apr 1, 2020

@AlexeyAB

  1. efficientnet-lite3: ~470 (459k/1200k)
  2. csresnext50sub-spp-asff-bifpn-rfb-db: no info (21 March repo, 49k/550k)
  3. csresnext50sub-mi2: no info (21 March repo, 480k/1200k)
  4. csresnext50sub-mo: no info (21 March repo, 440k/1200k)
  5. csdarknet53-omega-mi: ~480 (194k/1200k)
  6. csdarknet53-omega-mi-db: ~500 (190k/1200k)

@AlexeyAB
Owner Author

AlexeyAB commented Apr 2, 2020

@WongKinYiu Thanks!

Can you check the intermediate Top-1/Top-5 accuracy of the csdarknet53-omega-mi and csdarknet53-omega-mi-db models?
Depending on which model shows higher accuracy, you can train one of these models (without or with DropBlock). I implemented two classification models with an Input Pyramid (for fusing spatial and semantic information):

They may degrade the accuracy of the classifier, but should improve the accuracy of the detector.

FPS drops by only 2.5%.


https://arxiv.org/abs/1912.00632v2

The image pyramid is obtained by downsampling(linear interpolate) the input image into four levels with a factor of 2.

In these models: The image pyramid is obtained by downsampling(linear interpolate) the input image into five levels with a factor of 2.
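A minimal sketch of building such an input pyramid in C for a single-channel image, assuming the image sides are divisible by 2^(levels-1). Each level is halved by averaging 2x2 blocks. For an exact factor of 2, this matches bilinear (linear) interpolation with half-pixel-centered sampling, but whether the cfg-files here resize in exactly this way is an assumption. The function names are hypothetical.

```c
#include <assert.h>
#include <stddef.h>

/* Halve a single-channel h x w image (h, w even) by averaging 2x2 blocks. */
void downsample2x(const float *src, int h, int w, float *dst)
{
    assert(h % 2 == 0 && w % 2 == 0);
    for (int y = 0; y < h / 2; ++y)
        for (int x = 0; x < w / 2; ++x) {
            const float *p = src + (2 * y) * w + 2 * x;   /* top-left of block */
            dst[y * (w / 2) + x] = 0.25f * (p[0] + p[1] + p[w] + p[w + 1]);
        }
}

/* Store `levels` levels back to back in buf (level 0 = the input itself,
   each following level half the previous one); offsets[k] receives the
   start of level k. Returns the total number of floats written. */
size_t build_pyramid(const float *img, int h, int w, int levels,
                     float *buf, size_t *offsets)
{
    assert(levels > 0);
    size_t off = 0;
    for (int k = 0; k < levels; ++k) {
        offsets[k] = off;
        if (k == 0) {
            for (size_t i = 0; i < (size_t)h * (size_t)w; ++i) buf[i] = img[i];
        } else {
            downsample2x(buf + offsets[k - 1], h, w, buf + off);
            h /= 2;
            w /= 2;
        }
        off += (size_t)h * (size_t)w;
    }
    return off;
}
```

For an RGB input, the same routine would run once per channel plane.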

@glenn-jocher

@AlexeyAB input pyramid looks very interesting. I just had an idea. Instead of downsampling by a linear interpolation, perhaps one could simply reshape the pyramid inputs by moving pixels around.

For example instead of downsampling (3,512,512) - > (3,256,256) we might be able to reshape (3,512,512) -> (12, 256, 256) with no information loss. I suppose the order would be (rgb,512,512) -> (rrrrggggbbbb,256,256), and it might generalize to any downsampling operation, perhaps in place of the maxpool layers for example in yolov3-tiny or the stride-2 convolutions in yolov3.

I'm not aware of any op that does this currently so it would have to be custom built somehow. What do you think?
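The proposed lossless reshape is a space-to-depth operation. Below is a sketch in C for a CHW tensor with block size 2, using the rrrrggggbbbb channel order described above (other implementations may order the output channels differently), together with its exact inverse to show that no information is lost. The function names are hypothetical.

```c
#include <assert.h>

/* (c, h, w) -> (4c, h/2, w/2). Output channel ch*4 + dy*2 + dx holds the
   pixels at offset (dy, dx) of every 2x2 block, so an RGB input becomes
   rrrrggggbbbb. */
void space_to_depth2(const float *src, int c, int h, int w, float *dst)
{
    assert(h % 2 == 0 && w % 2 == 0);
    const int oh = h / 2, ow = w / 2;
    for (int ch = 0; ch < c; ++ch)
        for (int dy = 0; dy < 2; ++dy)
            for (int dx = 0; dx < 2; ++dx) {
                const int oc = ch * 4 + dy * 2 + dx;
                for (int y = 0; y < oh; ++y)
                    for (int x = 0; x < ow; ++x)
                        dst[(oc * oh + y) * ow + x] =
                            src[(ch * h + 2 * y + dy) * w + 2 * x + dx];
            }
}

/* Exact inverse: (4c, h/2, w/2) -> (c, h, w); c, h, w describe the output. */
void depth_to_space2(const float *src, int c, int h, int w, float *dst)
{
    assert(h % 2 == 0 && w % 2 == 0);
    const int oh = h / 2, ow = w / 2;
    for (int ch = 0; ch < c; ++ch)
        for (int dy = 0; dy < 2; ++dy)
            for (int dx = 0; dx < 2; ++dx) {
                const int oc = ch * 4 + dy * 2 + dx;
                for (int y = 0; y < oh; ++y)
                    for (int x = 0; x < ow; ++x)
                        dst[(ch * h + 2 * y + dy) * w + 2 * x + dx] =
                            src[(oc * oh + y) * ow + x];
            }
}
```

Applied to a (3, 512, 512) input, space_to_depth2 produces the (12, 256, 256) tensor from the example above; depth_to_space2 restores the original values exactly.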

@AlexeyAB
Owner Author

AlexeyAB commented Apr 2, 2020

@glenn-jocher

About losing information:

  • It's not definitively proven which is better for subsampling semantic information: maxpool-stride=2, conv3x3-stride=2, reshape-stride=2, ...
    • reshape-stride=2 + conv1x1 is very similar to conv3x3-stride=2: information is not lost in either case (for conv3x3-stride=2, thanks to the overlapping 3x3 kernel). So we already use conv3x3-stride=2 in all our networks.
    • with maxpool-stride=2 we lose some information. To get coordinates we shouldn't lose spatial information, but to get objectness/classes we should lose spatial information to achieve shift-invariance. So in csresnext50sub-mi2.cfg.txt we already use 2 branches, one with conv3x3-stride=2 and one with maxpool-stride=2: Implemented weighted-multi_input-[shortcut] layer with weights-normalization #4662 (comment)

  • Yes, there are 2 reshape layers in Darknet: [reorg] (pjreddie's version, which is broken: Added DNN Darknet Yolo v2 for object detection opencv/opencv#9705 (comment)) and the newer [reorg3d] (a slightly fixed version, but I'm not sure it is correct for all input parameters; it is better to implement it from scratch), as used in Yolo v2.
    I think reshape is more suitable for semantic information than for spatial information, but you can check this by training 2 models with an Input Pyramid, one using [local_avgpool] and one using [reorg3d].

  • The Input Pyramid serves a different purpose than preserving details: it is used to bring in more spatial information (even with local avgpool truncation) than is contained in the semantic information.


@WongKinYiu
Collaborator

@AlexeyAB

csdarknet53-omega-mi: 240k epoch, 34.4 top-1, 60.9 top-5.
csdarknet53-omega-mi-db: 240k epoch, 32.3 top-1, 58.1 top-5.

@AlexeyAB
Owner Author

AlexeyAB commented Apr 3, 2020

@WongKinYiu If you want, you can train csdarknet53-omega-mi-ip.cfg.txt

@WongKinYiu
Collaborator

@AlexeyAB currently 24k epoch.

@AlexeyAB
Owner Author

AlexeyAB commented Apr 3, 2020

@WongKinYiu

csdarknet53-omega-mi: 240k epoch, 34.4 top-1, 60.9 top-5.
csdarknet53-omega-mi-db: 240k epoch, 32.3 top-1, 58.1 top-5.

  1. Try to train this model, csdarknet53-omega-mi-db (with DropBlock), from the beginning with the latest Darknet version, since I fixed and checked DropBlock: a9bae4f

  2. Also, try to stop training csresnext50sub-spp-asff-bifpn-rfb-db.cfg.txt and resume it with the new code: Try to train fast (grouped-conv) versions of csdarknet53 and csdarknet19 WongKinYiu/CrossStagePartialNetworks#6 (comment)


@glenn-jocher

glenn-jocher commented Apr 3, 2020

@AlexeyAB that's a very good point about reshaping layers that the regression may be more sensitive to small positional changes of the data than the obj/cls. I've always been worried about the loss of small spatial information as the image downsizes, especially for the smallest objects in P3.

That's also a very good point that downsampling operations are all 3x3 kernels, so they overlap enough that not much information should be lost. I'll try to experiment a bit with injecting the input pyramid in different areas.

One interesting thing about the Samsung paper was that they used a 7x7 kernel for the first convolution layer. I see this significantly affects neither the parameter count nor the FLOPS, but I also see the mixnet/efficientnet guys do not do this (they have 3x3 on conv0 like dn53), and I'm sure they must have experimented with it.

@AlexeyAB
Owner Author

AlexeyAB commented Apr 3, 2020

@glenn-jocher

One interesting thing about the Samsung paper was that they used a 7x7 kernel for the first convolution layer.

5x5, 7x7, or 9x9 convolutions should be used for stride=2 layers rather than for the 1st layer.

The MixNet authors say that for layers with stride 2, a larger kernel can significantly improve the accuracy: https://arxiv.org/pdf/1907.09595v3.pdf

As shown in the figure, large kernel size has different impact on different layers: for most of layers, the accuracy doesn’t change much, but for certain layers with stride 2, a larger kernel can significantly improve the accuracy. Notably, although MixConv3579 uses only half parameters and FLOPS than the vanilla DepthwiseConv9x9, our MixConv achieves similar or slightly better performance for most of the layers.
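The "half parameters" figure can be checked with a little arithmetic, assuming the channels are split equally among the four kernel sizes (an assumption; MixConv also allows other splits). A depthwise k×k kernel has k² weights per channel, and per-pixel FLOPs scale the same way:

```latex
\text{DepthwiseConv9x9: } 9^2 = 81 \text{ weights per channel} \\
\text{MixConv3579: } \tfrac{1}{4}\left(3^2 + 5^2 + 7^2 + 9^2\right) = \tfrac{164}{4} = 41 \approx \tfrac{81}{2} \text{ weights per channel}
```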

@WongKinYiu
Collaborator

WongKinYiu commented Apr 4, 2020

@AlexeyAB @glenn-jocher

Hello,

Reshape (reorg) layers are commonly used in depth prediction and semantic segmentation models, where the reorg/reversed reorg operations are usually called "space to depth (channel)" / "depth (channel) to space" (as in TensorFlow). I have changed the downsampling/upsampling layers of Elastic to reorg/reversed reorg, and it gave a small accuracy improvement.

Also, a CVPR 2020 paper uses this technique, for your reference: MuxConv.

@AlexeyAB
Owner Author

AlexeyAB commented Apr 4, 2020

@WongKinYiu

I have changed the downsampling/upsampling layers to reorg/reversed reorg of Elastic, and it got a little bit accuracy improvement.

Did you use [reorg] or [reorg3d]?
The original [reorg] layer has a bug: opencv/opencv#9705 (comment)

@WongKinYiu
Collaborator

@AlexeyAB I used [reorg3d].

@WongKinYiu
Collaborator

WongKinYiu commented Apr 23, 2020

@AlexeyAB

CSPResNeXt-50 default BoF+MISH : top-1 = 79.8%, top-5 = 95.2%

csresnext50-ws.cfg.txt: top-1 = 78.7%, top-5 = 94.7% (negative weights; burnin_update=2 and maybe weights_normalizion=softmax should be used; more about it)

csresnext50-ws-mi2.cfg.txt : top-1 = 79.9%, top-5 = 95.3%

csresnext50morelayers.cfg.txt : top-1 = 79.4%, top-5 = 95.2% url

csresnext50sub.cfg.txt : top-1 = 79.5%, top-5 = 95.3% url

csresnext50sub-mi2.cfg.txt: top-1 = 79.4%, top-5 = 95.3%, weights

csresnext50sub-mo.cfg.txt: top-1 = 79.2%, top-5 = 95.1%, weights


CSPDarknet-53 default BoF+MISH : top-1 = 78.7%, top-5 = 94.8%

csdarknet53-ws.cfg.txt: top-1 = 65.6%, top-5 = 87.3%

csdarknet53-omega-mi.cfg.txt: top-1 = 78.6%, top-5 = 94.7%, weights

csdarknet53-omega-mi-db.cfg.txt: top-1 = 78.4%, top-5 = 94.5%

csdarknet53-omega-mi-ip.cfg.txt: top-1 = 77.8%, top-5 = 94.3%

@AlexeyAB
Owner Author

AlexeyAB commented May 1, 2020

@WongKinYiu
Can you add the Top-1/Top-5 accuracy and attach the weights file for this model:
csdarknet53-ws.cfg.txt
#4498 (comment)
cfg files: #4662 (comment)

4 participants