
Add AccThunk to avoid prematurely unthunking thunks #1562

Open · wants to merge 1 commit into base: master
Conversation

@pxl-th (Member) commented Mar 13, 2025

PR Checklist

  • Tests are added

@pxl-th pxl-th self-assigned this Mar 13, 2025
@pxl-th pxl-th requested review from mcabbott and ToucheSir March 13, 2025 10:57
@pxl-th pxl-th changed the title Add AccThunk Add AccThunk to avoid prematurely unthunking thunks Mar 13, 2025
@pxl-th (Member, Author) commented Mar 13, 2025

Test failures are unrelated.

@ToucheSir ToucheSir closed this Mar 13, 2025
@ToucheSir ToucheSir reopened this Mar 13, 2025
@ToucheSir (Member)

Coming from #1555 (comment), I think we need at least one test that shows we aren't introducing new type instabilities when multiple paths try to accumulate different (un)thunked types.

@pxl-th (Member, Author) commented Mar 13, 2025

> Coming from #1555 (comment), I think we need at least one test that shows we aren't introducing new type instabilities when multiple paths try to accumulate different (un)thunked types.

IIUC, you mean the case where AccThunk holds thunks of different types, and whether that leads to more type instability and thus worse performance?

I benchmarked GaussianSplatting.jl and it's ~2x faster with AccThunk than without.
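For readers following along, here is a minimal, self-contained sketch of the idea under discussion. The names `Thunk`, `AccThunk`, and `unthunk` below are stand-ins, not Zygote's or ChainRulesCore's actual implementations: a thunk delays computing a gradient, and an AccThunk-style wrapper additionally delays *accumulating* several thunks instead of unthunking each one eagerly.

```julia
# Stand-in for a delayed gradient computation.
struct Thunk{F}
    f::F
end
unthunk(t::Thunk) = t.f()

# Eager accumulation: every thunk is materialized as soon as it is accumulated.
accumulate_eager(a, t::Thunk) = a .+ unthunk(t)

# Deferred (AccThunk-style) accumulation: just collect the thunks.
# Note the Vector{Any}: thunks of different closure types all land here,
# which is the type-instability concern raised above.
struct AccThunk
    thunks::Vector{Any}
end
accumulate_lazy(acc::AccThunk, t::Thunk) = (push!(acc.thunks, t); acc)

# Materialize once, only when the accumulated gradient is actually used.
unthunk(acc::AccThunk) = reduce((a, b) -> a .+ b, map(unthunk, acc.thunks))

acc = AccThunk(Any[])
accumulate_lazy(acc, Thunk(() -> [1.0, 2.0]))
accumulate_lazy(acc, Thunk(() -> [10.0, 20.0]))
unthunk(acc)  # [11.0, 22.0]: both thunks run only here
```

If a gradient is never requested, the deferred version skips the unthunk work entirely, which is where the speedups in thunk-heavy workloads would come from.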

@ToucheSir (Member)

Yes. I'd expect any overhead from type instability to be a non-issue for GaussianSplatting.jl, because each operation is much larger and any AD bookkeeping overhead is relatively small. But for problems more sensitive to AD overhead, it would be good to have some numbers.

@pxl-th (Member, Author) commented Mar 14, 2025

@ToucheSir do you have such a test in mind?

@ToucheSir (Member)

I was thinking of something that involves smaller input sizes, e.g. this regression-esque example:

using Zygote, LinearAlgebra
using BenchmarkTools  # provides @benchmark

x = rand(10, 100)
w = rand(2, 10)

f(w, x) = sum(w * x) + norm(w)

@benchmark gradient(f, $w, $x)

@pxl-th (Member, Author) commented Mar 17, 2025

On your example, this PR is slower:

  • This PR:
BenchmarkTools.Trial: 10000 samples with 6 evaluations per sample.
 Range (min … max):  5.714 μs …  1.974 ms  ┊ GC (min … max):  0.00% … 99.07%
 Time  (median):     8.062 μs              ┊ GC (median):     0.00%
 Time  (mean ± σ):   9.420 μs ± 39.797 μs  ┊ GC (mean ± σ):  13.65% ±  3.90%
     ▂██▇▃             ▃▂                                     
  ▁▂▅█████▇▅▄▃▃▂▂▃▂▄▅▆█████▇▆▄▃▃▂▂▂▁▁▁▁▁▁▁▂▂▂▂▃▄▃▃▄▃▃▃▂▂▂▂▂▁ ▃
  5.71 μs        Histogram: frequency by time        12.8 μs <
 Memory estimate: 13.47 KiB, allocs estimate: 47.
  • master:
BenchmarkTools.Trial: 10000 samples with 8 evaluations per sample.
 Range (min … max):  3.047 μs … 603.206 μs  ┊ GC (min … max):  0.00% … 98.22%
 Time  (median):     3.696 μs               ┊ GC (median):     0.00%
 Time  (mean ± σ):   4.684 μs ±  11.762 μs  ┊ GC (mean ± σ):  10.44% ±  5.14%
   ▃▆▇▇██▇▆▄▃▂▂▁▁                   ▁▁▁▁▁▁▁▁▁         ▁▁▁▂▁   ▂
  ▆█████████████████▇▆█▇▆▆▅▆▅▅▃▃▂▃▅▇█████████▇▇▆▆▆▇▇▇████████ █
  3.05 μs      Histogram: log(frequency) by time      9.14 μs <
 Memory estimate: 11.91 KiB, allocs estimate: 15.

But the times are so small that it's not obvious whether it would actually be ~2x slower on other examples (i.e. whether the overhead is constant or scales with problem size).

Trying out the reduced MWE that I've added as a test:

using Zygote, BenchmarkTools  # imports needed to run the MWE

function main()
    W = ones(Float32, 10, 10)
    x = [ones(Float32, 10) for i in 1:512]
    gs = gradient(W) do W
        sum((W * xi)[1] for xi in x)
    end
    return
end
@benchmark main()
  • This PR:
BenchmarkTools.Trial: 438 samples with 1 evaluation per sample.
 Range (min … max):   9.155 ms … 21.202 ms  ┊ GC (min … max):  0.00% … 11.38%
 Time  (median):     11.008 ms              ┊ GC (median):    14.24%
 Time  (mean ± σ):   11.387 ms ±  1.351 ms  ┊ GC (mean ± σ):  13.36% ±  3.84%
           ▅█▅▇▄                                               
  ▃▃▂▂▃▄▄▅▆██████▅▆▅▅▆▄▄▄▄▃▂▅▁▃▂▂▃▃▂▁▃▁▃▁▂▁▁▂▂▁▁▁▁▂▁▁▁▁▁▁▁▁▁▃ ▃
  9.15 ms         Histogram: frequency by time        17.9 ms <
 Memory estimate: 31.41 MiB, allocs estimate: 49627.
  • master:
BenchmarkTools.Trial: 395 samples with 1 evaluation per sample.
 Range (min … max):  10.143 ms … 23.214 ms  ┊ GC (min … max):  0.00% … 29.74%
 Time  (median):     12.530 ms              ┊ GC (median):     0.00%
 Time  (mean ± σ):   12.623 ms ±  2.049 ms  ┊ GC (mean ± σ):  10.17% ± 10.01%
    ▄█  ▁                                                      
  ▃████▅█▅▆▆▅▅▄▃▃▃▃▃▃▃▅▅▇▆▅▆▆█▅▃▅▃▅▄▅▅▄▅▃▄▃▃▃▃▃▃▄▂▃▃▁▂▁▂▂▁▂▃▃ ▃
  10.1 ms         Histogram: frequency by time        17.6 ms <
 Memory estimate: 11.82 MiB, allocs estimate: 64480.

In general I think utilizing thunks is a good idea, especially since it gives huge perf boosts in my workflows :)

We can merge this PR and ask people who use Zygote to test their code for regressions before tagging a release.
If there are any, we can either try to optimize thunking itself or (if that proves too hard) add a way to disable it, as we did for second order.

@ToucheSir (Member)

My main worry would be that this introduces allocations into workloads which were low to zero allocation before. The loop example is perhaps the best case scenario for this change, because it involves a large number of accumulated thunks which all have the same type.

My other worry is that this makes debugging Zygote code quite a bit harder. Currently Cthulhu.jl works well because a lot of functions can be statically inferred. But with this change, many functions which were type-stable under differentiation will no longer be, so that tool will not be available. The change probably makes sense on balance, but losing access to such an important debugging tool (really the only debugging tool, because Debugger.jl does not work well enough) makes me uncomfortable.
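The inference concern can be seen in a tiny standalone example (my own illustration with hypothetical names, not Zygote's types): once an accumulated gradient may be either a plain array or a still-thunked closure, its type depends on a runtime value, which is exactly what defeats static inference and tools built on it like Cthulhu.jl.

```julia
# Hypothetical stand-ins, not Zygote's actual types.
make_thunk(x) = () -> 2 .* x  # a delayed gradient computation

# The return type depends on a runtime flag, so callers see
# Union{Vector{Float64}, <closure type>} and cannot be statically inferred.
accumulate_maybe_thunked(defer::Bool, x) = defer ? make_thunk(x) : x

a = accumulate_maybe_thunked(false, ones(2))  # a Vector{Float64}
b = accumulate_maybe_thunked(true,  ones(2))  # a closure; call b() to materialize
```

Running `@code_warntype` on a caller of `accumulate_maybe_thunked` would show the red `Union`/abstract return type that makes inspection of downstream code impractical.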

Anyhow, I'll defer to @mcabbott or @CarloLucibello for a second opinion on this.

@pxl-th (Member, Author) commented Mar 18, 2025

> The change probably makes sense on balance, but losing access to such an important debugging tool (really the only debugging tool, because Debugger.jl does not work well enough) makes me uncomfortable.

We can make this opt-in then and disable thunks by default.
