
numerical issues in the mish implementation #5452

Closed · YashasSamaga opened this issue May 2, 2020 · 17 comments
@YashasSamaga

YashasSamaga commented May 2, 2020

__device__ float softplus_kernel(float x, float threshold = 20) {
    if (x > threshold) return x;              // too large
    else if (x < -threshold) return expf(x);  // too small
    return logf(expf(x) + 1);
}

output_gpu[i] = x_val * tanh_activate_kernel( softplus_kernel(x_val, MISH_THRESHOLD) );

This implementation has numerical issues due to the use of log(1 + exp(x)), which loses precision when exp(x) is small relative to 1. There is a significant loss of precision in the range [-20, -10]. It keeps losing precision as x gets more and more negative and eventually results in zeros for x in the range [-20, -16.6]. The precision is regained below -20 because softplus switches to exp(x) in that range.

The standard library has log1p, which takes care of this issue. CUDA also provides log1p for single-precision floats (log1pf).
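
As a rough illustration (a sketch of the proposed change, not committed code), only the last line of softplus_kernel needs to change:

__device__ float softplus_kernel_log1p(float x, float threshold = 20) {
    if (x > threshold) return x;              // softplus(x) ~ x for large x
    else if (x < -threshold) return expf(x);  // softplus(x) ~ exp(x) for very negative x
    return log1pf(expf(x));                   // stable: log1p(y) ~ y when exp(x) is tiny
}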

@AlexeyAB added the "want enhancement" label May 2, 2020
@YashasSamaga
Author

YashasSamaga commented May 12, 2020

YOLOv4 support was added to the CUDA backend in OpenCV. While profiling Darknet and OpenCV to find the performance discrepancy, I noticed that Darknet's bias and mish kernels together take 3.3x more time than OpenCV's fused bias-activation kernel. The convolution seems to be faster than them.

Here are some stats for the first convolution layer in YOLOv4 on GTX 1050 for 608x608 image.

Operation                    Time     SM SOL   Bandwidth
cuDNN convolution            1.4 ms   -        -
add_bias                     1.36 ms  ~86%     ~70 GB/s
activate_array_mish_kernel   1.9 ms   ~75%     ~74 GB/s
OpenCV's biasN_mish kernel   984 us   ~44%     ~97 GB/s

There are some neat approximations to mish which can be considerably faster: https://cs.stackexchange.com/questions/125002/fast-and-stable-x-tanhlog1pexpx-computation

The add_bias kernel can also be optimized to reuse computation and bandwidth, since the same bias value is added to many adjacent elements.
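
For illustration, a fused kernel along these lines (the name, the NCHW layout, and the indexing here are my assumptions, not the actual OpenCV code) reads and writes the tensor only once for both steps:

__global__ void fused_biasN_mish_kernel(float* inout, const float* bias,
                                        int n, int channel_stride, int num_channels)
{
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) {
        int c = (i / channel_stride) % num_channels; // per-channel bias, NCHW layout assumed
        float x = inout[i] + bias[c];                // bias addition
        float sp = log1pf(expf(x));                  // numerically stable softplus
        inout[i] = x * tanhf(sp);                    // mish(x) = x * tanh(softplus(x))
    }
}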

@AlexeyAB
Owner

The convolution seems to be faster than them.

Do you mean that cuDNN convolution is 2x faster than add_bias + activate_array_mish_kernel ?

  • How did you measure it?

  • Which layers did you use for measuring it?

  • There is a comparison of YOLOv4-mish (csdarknet53-mish 512x512) at 47 FPS vs YOLOv4-leaky (csdarknet53-opt 512x512) at 50 FPS in Comparison of some models on CPU vs VPU (neurochip) vs GPU #5079, so the difference is only 6%.
    ./darknet detector demo cfg/coco.data cfg/yolov4.cfg yolov4.weights test.mp4 -benchmark

  • Do you suggest using log1pf(), or the fast or the correct implementation from https://cs.stackexchange.com/a/125072?

  • What do you think about derivative?

    __global__ void gradient_array_mish_kernel(int n, float *activation_input_gpu, float *delta)
    {
        int i = (blockIdx.x + blockIdx.y * gridDim.x) * blockDim.x + threadIdx.x;
        if (i < n) {
            const float MISH_THRESHOLD = 20.0f;
            // implementation from TensorFlow: https://github.com/tensorflow/addons/blob/093cdfa85d334cbe19a37624c33198f3140109ed/tensorflow_addons/custom_ops/activations/cc/kernels/mish_op.h#L66-L80
            // implementation from Pytorch: https://github.com/thomasbrandon/mish-cuda/blob/master/csrc/mish.h#L26-L31
            // log1p(x) == log(x + 1)
            const float inp = activation_input_gpu[i];
            const float sp = softplus_kernel(inp, MISH_THRESHOLD);
            const float grad_sp = 1 - expf(-sp);
            const float tsp = tanh(sp);
            const float grad_tsp = (1 - tsp*tsp) * grad_sp;
            const float grad = inp * grad_tsp + tsp;
            delta[i] *= grad;
            //float x = activation_input[i];
            //float d = 2 * expf(x) + expf(2 * x) + 2;
            //float w = 4 * (x + 1) + 4 * expf(2 * x) + expf(3 * x) + expf(x)*(4 * x + 6);
            //float derivative = expf(x) * w / (d * d);
            //delta[i] *= derivative;
        }
    }


SM SOL

What is SM SOL?

@YashasSamaga
Author

YashasSamaga commented May 12, 2020

Do you mean that cuDNN convolution is 2x faster than add_bias + activate_array_mish_kernel ?

Yes, in the first few layers.

How did you measure it?

NVIDIA Nsight Compute (profiled all kernels that were executed with ./darknet detector test cfg/coco.data cfg/yolov4.cfg yolov4.weights data/dog.jpg)

Which layers did you use for measuring it?

The numbers I reported are for the first convolution layer in YOLOv4. The first 4 kernels that are executed by darknet are:

  • computeOffsetsKernel & conv kernel (both from cuDNN)
  • add_bias
  • activate_array_mish_kernel

The total darknet inference time is around 116ms on my device.

There is a comparison of YOLOv4-mish (csdarknet53-mish 512x512) at 47 FPS vs YOLOv4-leaky (csdarknet53-opt 512x512) at 50 FPS in #5079, so the difference is only 6%.

OpenCV is executing YOLOv4 about 15% faster than Darknet on my device. On profiling, at least on my device, the bias and mish kernels are much slower than their OpenCV equivalents.

The convolution is faster than the bias and mish kernels only in the initial few layers; at some point deeper in the network the convolution becomes the slower part.

Do you suggest using log1pf(), or the fast or the correct implementation from https://cs.stackexchange.com/a/125072?

The accepted answer is nearly as accurate as using log1pf. My answer is even faster but is slightly less accurate. They are both more accurate than log(1 + exp(x)) though. I think the approximation should work reasonably well. It seems to be working well in OpenCV. I haven't done a thorough check though.

What do you think about derivative?

1 - expf(-sp) can be replaced with -expm1(-sp).

I think it's possible to come up with accurate fast approximations for the gradient.

What is SM SOL?

Streaming Multiprocessor Speed Of Light

It's a measure of compute usage relative to the maximum theoretical performance. 100% indicates compute usage has reached maximum theoretical compute performance.

Bias and activation steps are generally expected to be bandwidth-bound kernels, so they should show high memory utilization and low compute utilization. The Darknet kernels show high compute usage, which prevents maximal utilization of the memory bandwidth (memory requests aren't issued fast enough to keep the memory subsystem busy).

@AlexeyAB
Owner

  1. Since our old mish activation corresponds to the old mish derivative, if we change the mish activation, should we change the mish derivative too?

  2. If we change both the mish activation and derivative, then maybe we should re-train the model?

  3. It is not strictly proven that Mish is the best activation, so a less accurate implementation of it may even give higher detection accuracy. What do you think?

Do you suggest to use one of these implementations?

__device__ float mish_njuffa(float x)
{
    float r;
    float e = expf(x);
    r = 1.0f / fmaf(fmaf(-0.5f, e, -1.0f), e, -1.0f);
    r = fmaf(r, x, x);
    return r;
}

__device__ float mish_yashas(float x)
{
    auto e = __expf(x);
    if (x <= -18.0f)
        return x * e;

    auto n = e * e + 2 * e;
    if (x <= -5.0f)
        return x * __fdividef(n, n + 2);

    return x - 2 * __fdividef(x, n + 2);
}

Streaming Multiprocessor Speed Of Light

Did you calculate it using TFLOPS from the GPU specification, the number of OPs in your function, and the execution time?

Is this your own definition? I can not find it even on Google.

@YashasSamaga
Author

YashasSamaga commented May 12, 2020

Since our old mish activation corresponds to the old mish derivative, if we change the mish activation, should we change the mish derivative too?

Nope. Mish is not being changed at all. Mish is still mish. The mathematical formulation is still the same; the difference is only in the implementation. The fast implementations are very good approximations to mish (errors in the range of a few ULPs).

Even using log(1 + exp(x)) is an approximation to mish. It returns zero when x equals -17, but if you calculated mish in double precision, you would get something that's small but still not zero.

If we change both the mish activation and derivative, then maybe we should re-train the model?

Not really. The trained models will still work. YOLOv4 is giving really good detections in OpenCV with the approximation. The max relative error for a sample image I used is 2e-5 compared to OpenCV's CPU implementation (which uses x * tanh(log(1 + exp(x)))).

I had written some code while designing and testing the approximations. You can find it here. With this, you can actually see how close the approximations are to the original mish function by comparing individual places in the mantissa.

Based on my memory, the approximations are practically identical to it for numbers below -20 and numbers greater than -2. There are small differences for some values in the range [-20, -2], and these differences are often in the last few significant digits (e.g. 0.1234567 vs 0.1234561).
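
To make "comparing individual places in the mantissa" concrete, here is a small hypothetical helper (not the code from the linked gist) that measures how many units in the last place (ULPs) two single-precision results differ by:

#include <cstdint>
#include <cstdlib>
#include <cstring>

// Map the IEEE-754 float bit pattern onto a monotonically increasing integer,
// so that adjacent representable floats map to adjacent integers.
static int32_t ordered_bits(float f) {
    int32_t i;
    std::memcpy(&i, &f, sizeof(i));
    return (i < 0) ? INT32_MIN - i : i;
}

// Number of units in the last place (ULPs) separating two floats.
static int64_t ulp_distance(float a, float b) {
    return std::llabs((int64_t)ordered_bits(a) - (int64_t)ordered_bits(b));
}

A distance of zero means bitwise identical; a handful of ULPs corresponds to the "last few significant digits" differences described above.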

It is not strictly proven that Mish is the best activation, so a less accurate implementation of it may even give higher detection accuracy. What do you think?

I don't think a less accurate implementation (at least the ones we are talking about currently) will change the accuracy. They are still more or less mish. Any improvements or losses could just be noise. It's my intuition and I could be wrong.

Do you suggest to use one of these implementations?

Mine is faster but less accurate. I have put it in OpenCV because it doesn't seem to alter the results, and the approximation in the CUDA backend is still more accurate than OpenCV's CPU implementation.

Mine gives at least 5 decimal places of precision if I remember correctly. Njuffa's implementation has errors in the 7th decimal place, which is almost perfect, but it's a tad slower.

I don't have much experience in DL. I don't know which is best but given that people use half precision for training (which is very inaccurate compared to the fp32 approximations presented here), I think these single-precision approximations should do really well.

Did you calculate it using TFLOPS from the GPU specification, the number of OPs in your function, and the execution time?

NVIDIA Nsight Compute does all the calculation and reports it.

Is this your own definition? I can not find it even on Google.

No, it's NVIDIA's metric. This article explains it.

@YashasSamaga
Author

YashasSamaga commented May 12, 2020

Comparisons of the double-precision version vs darknet vs njuffa vs mine

https://gist.github.com/YashasSamaga/3fdf001d32f04062e3f36495d5c962db

Note that even the direct double-precision implementation becomes less accurate after a certain point. The differences you see near -100 are because the direct double-precision version is less accurate than the approximations in that range.

You can see darknet's accuracy falling rapidly starting from -3 because it uses log(1 + exp(x)) instead of log1p(exp(x)). It recovers below -20 because that's where darknet's softplus implementation switches to another numerically stable expression; in fact, darknet performs better than the approximations in that range.

All of them are the same in the positive half-plane (probably bitwise identical). The differences are primarily in the [-18, -2] range. The approximations (njuffa's and mine) are quite close to the double-precision version.

Ideally, I should have compared the accuracy against a multi-precision library instead of a double-precision reference, but I wasn't able to set that up on CUDA. I had to use CUDA for testing because I was using CUDA intrinsics, which I cannot test on a CPU.

@AlexeyAB
Owner

As I understand it, all these approximations do not correspond strictly/analytically to the original Mish formula.

I mean that the values of your mish implementation are closer to fp64 mish than Darknet's mish is.

Does your mish implementation correspond to the current Darknet mish derivative better than the current Darknet mish implementation does?

It keeps losing precision as x gets more and more negative and eventually results in zeros for x in the range [-20, -16.6].

Did you check whether we have the same issue for the derivative? Maybe we have the same issue for training too?

Do we just need to use return log1pf(expf(x)); (with the f suffix) instead of the following?

return logf(expf(x) + 1);

Or should we change anything else?

@AlexeyAB
Owner

Yes, new mish is faster, gives the same accuracy (AP) of detection, and gives a smoother and more accurate curve.

Can you please check whether we have any problem with the derivative, and whether we should solve it?

  • old-mish - 35.8 FPS
  • new-mish - 37.0 FPS (+3% speedup)

I tested the new Mish implementation mish_yashas() on yolov4-512x512 (test-dev) and got the same good results as the default yolov4:

overall performance
 Average Precision  (AP) @[ IoU=0.50:0.95 | area=   all | maxDets=100 ] = 0.430
 Average Precision  (AP) @[ IoU=0.50      | area=   all | maxDets=100 ] = 0.649
 Average Precision  (AP) @[ IoU=0.75      | area=   all | maxDets=100 ] = 0.465
 Average Precision  (AP) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.243
 Average Precision  (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.461
 Average Precision  (AP) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.552
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets=  1 ] = 0.341
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets= 10 ] = 0.542
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets=100 ] = 0.572
 Average Recall     (AR) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.375
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.612
 Average Recall     (AR) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.729
Done (t=506.74s)

@YashasSamaga
Author

YashasSamaga commented May 12, 2020

Does your mish implementation correspond to the current Darknet mish derivative better than the current Darknet mish implementation does?

If Darknet is trying to implement the gradient of the analytical mish activation, then I think using log1p corresponds better. I am not sure about my implementation though. Mine does better than the current implementation in the [-20, -2] range but is slightly worse than darknet beyond -20. I think the [-20, -2] range is more important than [-infinity, -20), since large negative inputs lead to tiny gradients anyway, which probably doesn't matter at all.

Did you check whether we have the same issue for the derivative?

I have never even written the derivative of mish on paper. I haven't checked if there are any issues with the derivative. I'll look into it.

Maybe we have the same issue for training too?

The problems with the forward pass itself might affect training, since the forward pass computation affects the outputs of the layers that follow it (including the loss and eventually the gradients). Since the precision of the current mish implementation is really poor in the [-20, -2] range, it might be affecting the training. I don't know to what extent it does; I have very little intuition of how these precision errors translate to overall performance.

And then there could be numerical problems in the gradient implementation. The expm1 should help but I don't know if it is of any significance.

There is also a possibility that analytical mish is bad and the current darknet's implementation of mish is better than the analytical mish activation. If that's the case, then there is a new activation which is better than mish!

Do we just need to use return log1pf(expf(x)); (with the f suffix) instead of the following?

Yes, softplus should return log1pf(expf(x)) instead of logf(expf(x) + 1). This change will make the Darknet implementation of softplus identical to the TensorFlow implementation.

Or should we change anything else?

I had experimented with YOLOv4 two weeks ago. If I remember correctly, the dynamic range for activations happens to be in the range [-12000, 2000] for a sample of 10 random natural images. The majority of the activations are in the range [-100, 100].

The gradient would become really small for large negative numbers, so it probably doesn't matter what happens there. The log1p change is probably important since the current implementation has low precision in the [-20, -2] range before losing it all at -16. The current implementation is very good for positive numbers. So for the forward pass of the activation, changing mish to use log1p should bring it closer to analytical mish.


Looks like I am really slow at replying. This reply was written without looking at your latest reply.

Can you please check whether we have any problem with the derivative, and whether we should solve it?

I'm looking into it.

@YashasSamaga
Author

YashasSamaga commented May 12, 2020

// https://github.com/digantamisra98/Mish
__global__ void gradient_array_mish_kernel(int n, float *activation_input_gpu, float *delta)
{
    int i = (blockIdx.x + blockIdx.y * gridDim.x) * blockDim.x + threadIdx.x;
    if (i < n) {
        const float MISH_THRESHOLD = 20.0f;
        // implementation from TensorFlow: https://github.com/tensorflow/addons/blob/093cdfa85d334cbe19a37624c33198f3140109ed/tensorflow_addons/custom_ops/activations/cc/kernels/mish_op.h#L66-L80
        // implementation from Pytorch: https://github.com/thomasbrandon/mish-cuda/blob/master/csrc/mish.h#L26-L31
        // log1p(x) == log(x + 1)
        const float inp = activation_input_gpu[i];
        const float sp = softplus_kernel(inp, MISH_THRESHOLD);
        const float grad_sp = 1 - expf(-sp);
        const float tsp = tanh(sp);
        const float grad_tsp = (1 - tsp*tsp) * grad_sp;
        const float grad = inp * grad_tsp + tsp;
        delta[i] *= grad;
        //float x = activation_input[i];
        //float d = 2 * expf(x) + expf(2 * x) + 2;
        //float w = 4 * (x + 1) + 4 * expf(2 * x) + expf(3 * x) + expf(x)*(4 * x + 6);
        //float derivative = expf(x) * w / (d * d);
        //delta[i] *= derivative;
    }
}

const float sp = softplus_kernel(inp, MISH_THRESHOLD);

The introduction of log1p in softplus affects this computation too. In the positive half-plane, softplus is nearly the identity; in the negative half-plane, it returns really tiny positive values, and this fact is important for the next step.

const float grad_sp = 1 - expf(-sp);

This has numerical problems when -sp is really close to zero, and sp is in fact really close to zero in the negative half-plane. So there is a good reason to use expm1 here for better numerical accuracy.

I have compared these three:

  • reference implementation in double precision
  • current darknet gradient
  • log1p in softplus and expm1 in darknet gradient

Code and results here: https://gist.github.com/YashasSamaga/1572a1b446c66594000fc1e59d82158d

The current darknet gradient has problems. Using log1p and expm1 fixes it. Both are required for good numerical accuracy.

I don't know how better numerical accuracy will translate to the actual performance of the models. I hope it causes some improvement.
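
Putting the two changes together, a sketch of the stabilized gradient kernel could look like this (assuming softplus_kernel has already been switched to log1pf as proposed earlier):

__global__ void gradient_array_mish_kernel_stable(int n, float *activation_input_gpu, float *delta)
{
    int i = (blockIdx.x + blockIdx.y * gridDim.x) * blockDim.x + threadIdx.x;
    if (i < n) {
        const float MISH_THRESHOLD = 20.0f;
        const float inp = activation_input_gpu[i];
        const float sp = softplus_kernel(inp, MISH_THRESHOLD); // now log1pf(expf(x)) inside
        const float grad_sp = -expm1f(-sp);                    // == 1 - exp(-sp), but stable for tiny sp
        const float tsp = tanhf(sp);
        const float grad_tsp = (1 - tsp * tsp) * grad_sp;
        const float grad = inp * grad_tsp + tsp;
        delta[i] *= grad;
    }
}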

@AlexeyAB
Owner

@YashasSamaga Thanks!

@WongKinYiu Hi,

I fixed Mish-activation and Mish-gradient for Training and Detection:

It doesn't reduce detection accuracy (AP) during inference, but increases FPS by about 3%: #5452 (comment)

It's closer to the original Mish-formula.

The derivative (gradient) differs only for x in [-103, -15], and by less than 1e-5:
https://gist.github.com/YashasSamaga/1572a1b446c66594000fc1e59d82158d#file-reference_vs_darknet_vs_stable-txt

@YashasSamaga
Author

By the way, when I compared the mish implementations, the njuffa version I used was:

__device__ float my_mishf (float x)
{
    float r;
    if (x >= -1.0f) {
        float e = expf (x);
        r = 1.0f / fmaf (fmaf (-0.5f, e, -1.0f), e, -1.0f);
        r = fmaf (r, x, x);
    } else {
        float eh = expf (0.5f * x);
        float p =        1.03628484e-3f;  //  0x1.0fa7e6p-10
        p = fmaf (p, x, -7.28869531e-3f); // -0x1.ddac04p-8
        p = fmaf (p, x,  3.47027816e-2f); //  0x1.1c4902p-5
        p = fmaf (p, x, -3.54762226e-1f); // -0x1.6b46cap-2
        p = fmaf (p, x,  8.58785570e-1f); //  0x1.b7b2bep-1
        p = fmaf (p, x, -1.38065982e+0f); // -0x1.6172ecp+0
        p = fmaf (p, x,  5.97694337e-1f); //  0x1.3204fep-1
        float q =        1.03527203e-3f;  //  0x1.0f63eep-10
        q = fmaf (q, x, -7.35638570e-3f); // -0x1.e21bacp-8
        q = fmaf (q, x,  3.28683928e-2f); //  0x1.0d4204p-5
        q = fmaf (q, x, -3.79927397e-1f); // -0x1.850bb0p-2 
        q = fmaf (q, x,  6.86127126e-1f); //  0x1.5f4c0ep-1
        q = fmaf (q, x, -1.81509292e+0f); // -0x1.d0a9eep+0
        q = fmaf (q, x,  1.00000000e+0f); //  0x1.000000p+0
        r = (1.0f / q) * p;
        if (x < -15.0f) r = 1.0f;
        r = r * x * eh * eh;
    }
    return r;
}

This is the correct version but is slightly slower.

The other, shorter version he gave assumes that returning zeros below -16 is acceptable. It has the problems of log(1 + exp(x)) and also lacks the correction below -20 that darknet has.

@AlexeyAB added the "enhancement" label and removed the "want enhancement" label May 13, 2020
@armadillojim

Hi, @AlexeyAB and @YashasSamaga! I came here by way of the CS StackExchange question. Thought I'd give you a heads-up that I added an answer there.

Also saw #5922 which points to this code. That code is branch-free, and does use log1p where appropriate, but it still uses multiple special-function calls. I'm not sure it's nearly as fast as it could be.

TL;DR: I would suggest one of these two branch-free implementations:

float mish(float x)
{
    float expx = __expf(x);
    return x / (1.0f + 2.0f / (expx * (2.0f + expx)));
}

Or:

float mish(float x)
{
    float expx = __expf(x);
    float psi = expx * (2.0f + expx);
    return x * (psi / (2.0f + psi));
}

@armadillojim

Follow up: #5922 also points to:

  • davisking code. But as YashasSamaga points out, it suffers from precision problems. It also misses an opportunity to save a multiply: 2*e + e*e would be better as (2.0f + e)*e.
  • fastai code and rwightman code, but those apply tanh to the softplus function. They probably have accuracy and speed problems.
  • digantamisra98 code, which doesn't actually have any code; only a plot of function evaluations.

@YashasSamaga
Author

YashasSamaga commented Jun 13, 2020

@armadillojim Looking at your error graph, I think I have messed up the thresholds. Thanks for pointing that out.

Code                              Time
fwd_relu                          1.45 ms
mish_fwd_tb                       1.82 ms
mish_fwd_aj1                      2.17 ms
mish_fwd_aj2                      1.53 ms
mish_fwd_aj2_with_fast_division   1.49 ms
mish_fwd_dlib                     1.94 ms
mish_fwd_ocv                      1.51 ms

@armadillojim mish_fwd_aj1 and mish_fwd_aj2 are your first and second implementations. I replaced the division in mish_fwd_aj2 with fast approximate division and called it mish_fwd_aj2_with_fast_division. The fast division might affect accuracy, however.

mish_fwd_tb is what Darknet previously used; #5922 links it to this repo.

mish_fwd_ocv is what OpenCV and Darknet currently use.

The compute usage is lowest in mish_fwd_aj2_with_fast_division, followed by mish_fwd_ocv. This leaves room for fusing more operations like bias addition, elementwise addition, etc.


2 * e + e * e would be optimized to FMA(e, e, FADD(e, e))
(2 + e) * e would be optimized to FMUL(e, FADD(2, e))

And they might not be the same in terms of precision. I haven't checked but my intuition suggests that 2 * e + e * e is safer than (2 + e) * e.
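
One quick way to check that intuition (a hypothetical host-side harness, not something from this thread) is to evaluate both forms in single precision against a double-precision reference:

#include <cmath>
#include <cstdio>

int main() {
    double max_err_fma = 0.0, max_err_mul = 0.0;
    for (float x = -20.0f; x <= 0.0f; x += 1e-4f) {
        float e = std::exp(x);                                           // single-precision exp
        double ref = std::exp((double)x) * (2.0 + std::exp((double)x));  // e*e + 2*e in double
        float a = std::fma(e, e, e + e);                                 // 2*e + e*e as FMA(e, e, e + e)
        float b = (2.0f + e) * e;                                        // single-multiply form
        max_err_fma = std::fmax(max_err_fma, std::fabs((double)a - ref) / ref);
        max_err_mul = std::fmax(max_err_mul, std::fabs((double)b - ref) / ref);
    }
    std::printf("max relative error: FMA form %g, FMUL form %g\n", max_err_fma, max_err_mul);
    return 0;
}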

@YashasSamaga
Author

YashasSamaga commented Jun 13, 2020

@AlexeyAB There is a faster and more accurate mish implementation:

__device__ float mish(float x)
{
    auto e = fast_exp(x);
    auto n = e * e + 2 * e;
    if (x <= -0.6f)
        return x * fast_divide(n, n + 2);

    return x - 2 * fast_divide(x, n + 2);
}

#            Time (float)   Time (float4)   L2 norm*     SM SOL (float)   SM SOL (float4)
ReLU         1.47 ms        1.39 ms         N/A          21.31%           14.04%
Mish (old)   1.51 ms        1.39 ms         0.00012458   40.27%           28.77%
Mish (new)   1.49 ms        1.39 ms         2.4583e-05   38.5%            23.5%

* L2 norm of the error vector for 16 million+ activations uniformly sampled from [-50, 50]

Code: https://gist.github.com/YashasSamaga/8ad0cd3b30dbd0eb588c1f4c035db28c

The accuracy has improved by an entire decimal place. There was a bug in my threshold-finding code. Thanks to the error graph in @armadillojim's answer (I wouldn't have cross-checked otherwise).

Input Range    L2 norm
[-100, -80]    1.73186e-36
[-80, -20]     7.14342e-12
[-20, 0]       3.37544e-05
[0, 100]       1.94085e-05

Raw accuracy output (with other implementations):

For [-100, -80]

[vec1] mish_tb: 5.23885e-38
[vec1] mish_rw: 5.23885e-38
[vec1] mish_njuffa1: 9.41435e-31
[vec1] mish_njuffa2: 9.58304e-38
[vec1] mish_aj1: 1.69838e-34
[vec1] mish_aj2: 1.73186e-36
[vec1] mish_aj2_fastdiv: 1.73186e-36
[vec1] mish_dlib: 9.41435e-31
[vec1] mish_ocv: 1.73186e-36
[vec4] mish_tb: 5.23885e-38
[vec4] mish_rw: 5.23885e-38
[vec4] mish_njuffa1: 9.41435e-31
[vec4] mish_njuffa2: 9.58304e-38
[vec4] mish_aj1: 1.69838e-34
[vec4] mish_aj2: 1.73186e-36
[vec4] mish_aj2_fastdiv: 1.73186e-36
[vec4] mish_dlib: 9.41435e-31
[vec4] mish_ocv: 1.73186e-36

For [-80, -20]:

[vec1] mish_tb: 8.93439e-13
[vec1] mish_rw: 8.93439e-13
[vec1] mish_njuffa1: 1.58053e-05
[vec1] mish_njuffa2: 1.53516e-12
[vec1] mish_aj1: 7.15513e-12
[vec1] mish_aj2: 7.14342e-12
[vec1] mish_aj2_fastdiv: 7.14342e-12
[vec1] mish_dlib: 1.58053e-05
[vec1] mish_ocv: 7.14342e-12
[vec4] mish_tb: 8.93439e-13
[vec4] mish_rw: 8.93439e-13
[vec4] mish_njuffa1: 1.58053e-05
[vec4] mish_njuffa2: 1.53516e-12
[vec4] mish_aj1: 7.15513e-12
[vec4] mish_aj2: 7.14342e-12
[vec4] mish_aj2_fastdiv: 7.14342e-12
[vec4] mish_dlib: 1.58053e-05
[vec4] mish_ocv: 7.14342e-12

For [-20, 0]:

[vec1] mish_tb: 3.34032e-05
[vec1] mish_rw: 0.00141026
[vec1] mish_njuffa1: 0.00143353
[vec1] mish_njuffa2: 4.30863e-05
[vec1] mish_aj1: 3.23414e-05
[vec1] mish_aj2: 3.37344e-05
[vec1] mish_aj2_fastdiv: 3.50414e-05
[vec1] mish_dlib: 0.0015637
[vec1] mish_ocv: 3.37544e-05
[vec4] mish_tb: 3.34032e-05
[vec4] mish_rw: 0.00141026
[vec4] mish_njuffa1: 0.00143353
[vec4] mish_njuffa2: 4.30863e-05
[vec4] mish_aj1: 3.23414e-05
[vec4] mish_aj2: 3.37344e-05
[vec4] mish_aj2_fastdiv: 3.50414e-05
[vec4] mish_dlib: 0.0015637
[vec4] mish_ocv: 3.37544e-05

For [0, 100]:

[vec1] mish_tb: 0.000302692
[vec1] mish_rw: 0.000302707
[vec1] mish_njuffa1: 1.55673e-05
[vec1] mish_njuffa2: 1.55673e-05
[vec1] mish_aj1: 0.000268539
[vec1] mish_aj2: nan
[vec1] mish_aj2_fastdiv: nan
[vec1] mish_dlib: 1.64015e-05
[vec1] mish_ocv: 1.94085e-05
[vec4] mish_tb: 0.000302692
[vec4] mish_rw: 0.000302707
[vec4] mish_njuffa1: 1.55673e-05
[vec4] mish_njuffa2: 1.55673e-05
[vec4] mish_aj1: 0.000268539
[vec4] mish_aj2: nan
[vec4] mish_aj2_fastdiv: nan
[vec4] mish_dlib: 1.64015e-05
[vec4] mish_ocv: 1.94085e-05

The performance improvement is not significant as the activation is not compute-bound. But the reduction in compute usage makes more room for other compute-heavy operations to be fused with the activation (like the division step in bias addition). There will probably be improvements in timings in fused operations in OpenCV.

@AlexeyAB
Owner

@YashasSamaga Hi, Thanks!
I added it in d724306 and e08a818:

__device__ float mish_yashas2(float x)
{
    float e = __expf(x);
    float n = e * e + 2 * e;
    if (x <= -0.6f)
        return x * __fdividef(n, n + 2);

    return x - 2 * __fdividef(x, n + 2);
}

@cenit closed this as completed Jan 23, 2021