Add AccThunk to avoid prematurely unthunking thunks #1562
base: master
Conversation
Test failures are unrelated.

Coming from #1555 (comment), I think we need at least one test that shows we aren't introducing new type instabilities when multiple paths try to accumulate different (un)thunked types.
IIUC you mean the case where I benchmarked GaussianSplatting.jl and it's ~2x faster with
Yes. I'd expect any overhead from type instability to be a non-issue for GaussianSplatting.jl, because each operation is much larger and any AD bookkeeping overhead is relatively small. But for problems more sensitive to AD overhead, it would be good to have some numbers.
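For context, the instability concern can be sketched as follows. This is a minimal illustration with a hypothetical `LazyThunk` type and helper names, not Zygote's or ChainRulesCore's actual internals: when one pullback path yields a plain array and another yields a thunk, the accumulator variable takes a `Union` type, and each accumulation step forces the thunk eagerly.

```julia
# Hypothetical stand-in for a delayed gradient computation.
struct LazyThunk{F}
    f::F
end
force(t::LazyThunk) = t.f()   # evaluate the delayed computation
force(x) = x                  # plain values pass through unchanged

# Accumulate a mix of thunked and unthunked gradients. Because `grads`
# holds two different types, `acc .+ force(g)` cannot be inferred to a
# single concrete type, and every thunk is forced as soon as it arrives.
function accumulate_eager(grads)
    acc = zeros(2, 2)
    for g in grads
        acc = acc .+ force(g)   # eager unthunking at every step
    end
    return acc
end

accumulate_eager(Any[ones(2, 2), LazyThunk(() -> 2 .* ones(2, 2))])
```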
@ToucheSir do you have such a test in mind?
I was thinking of something which involves smaller input sizes, e.g. this regression-esque example:

```julia
using Zygote, LinearAlgebra, BenchmarkTools

x = rand(10, 100)
w = rand(2, 10)
f(w, x) = sum(w * x) + norm(w)

@benchmark gradient(f, $w, $x)
```
On your example this PR is slower:

```
BenchmarkTools.Trial: 10000 samples with 6 evaluations per sample.
 Range (min … max):  5.714 μs …  1.974 ms  ┊ GC (min … max):  0.00% … 99.07%
 Time  (median):     8.062 μs              ┊ GC (median):     0.00%
 Time  (mean ± σ):   9.420 μs ± 39.797 μs  ┊ GC (mean ± σ):  13.65% ±  3.90%

 Memory estimate: 13.47 KiB, allocs estimate: 47.
```

```
BenchmarkTools.Trial: 10000 samples with 8 evaluations per sample.
 Range (min … max):  3.047 μs … 603.206 μs  ┊ GC (min … max):  0.00% … 98.22%
 Time  (median):     3.696 μs               ┊ GC (median):     0.00%
 Time  (mean ± σ):   4.684 μs ±  11.762 μs  ┊ GC (mean ± σ):  10.44% ±  5.14%

 Memory estimate: 11.91 KiB, allocs estimate: 15.
```

But the times are so small that it is not obvious whether it's actually 2x slower on other examples (e.g. is the overhead constant or does it scale with problem size?). Trying out a reduced MWE that I've added as a test:

```julia
using Zygote, BenchmarkTools

function main()
    W = ones(Float32, 10, 10)
    x = [ones(Float32, 10) for i in 1:512]
    gs = gradient(W) do W
        sum((W * xi)[1] for xi in x)
    end
    return
end

@benchmark main()
```
```
BenchmarkTools.Trial: 438 samples with 1 evaluation per sample.
 Range (min … max):   9.155 ms … 21.202 ms  ┊ GC (min … max):  0.00% … 11.38%
 Time  (median):     11.008 ms              ┊ GC (median):    14.24%
 Time  (mean ± σ):   11.387 ms ±  1.351 ms  ┊ GC (mean ± σ):  13.36% ±  3.84%

 Memory estimate: 31.41 MiB, allocs estimate: 49627.
```

```
BenchmarkTools.Trial: 395 samples with 1 evaluation per sample.
 Range (min … max):  10.143 ms … 23.214 ms  ┊ GC (min … max):  0.00% … 29.74%
 Time  (median):     12.530 ms              ┊ GC (median):     0.00%
 Time  (mean ± σ):   12.623 ms ±  2.049 ms  ┊ GC (mean ± σ):  10.17% ± 10.01%

 Memory estimate: 11.82 MiB, allocs estimate: 64480.
```

In general I think utilizing thunks is a good idea, especially since it gives huge perf boosts in my workflows :) We can merge this PR and ask people who use Zygote to test their code for regressions before tagging any release.
My main worry would be that this introduces allocations into workloads which were low- to zero-allocation before. The loop example is perhaps the best-case scenario for this change, because it involves a large number of accumulated thunks which all have the same type.

My other worry is that this makes debugging Zygote code quite a bit harder. Currently Cthulhu.jl works well because a lot of functions can be statically inferred. But with this change, many functions which were type stable under differentiation will no longer be, so that tool will not be available. The change probably makes sense on balance, but losing access to such an important debugging tool (really the only debugging tool, because Debugger.jl does not work well enough) makes me uncomfortable.

Anyhow, I'll defer to @mcabbott or @CarloLucibello for a second opinion on this.
We can make this opt-in then and disable thunks by default. |
Add an `AccThunk` type (as suggested in "Don't create nested thunks when accumulating" #1555) that holds thunks to be accumulated without unthunking them prematurely.

PR Checklist
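As a rough sketch of the idea (all names below are illustrative, not the PR's actual implementation), an `AccThunk`-style accumulator only records pending gradients and defers all unthunking until the final gradient is materialized, so each thunk is forced exactly once:

```julia
# Hypothetical stand-in for a delayed gradient computation.
struct LazyThunk{F}
    f::F
end
force(t::LazyThunk) = t.f()   # evaluate the delayed computation
force(x) = x                  # plain values pass through unchanged

# Sketch of an AccThunk-style accumulator: it just collects pending
# gradients (thunked or not) instead of unthunking them on arrival.
struct AccThunkSketch
    parts::Vector{Any}
end
AccThunkSketch() = AccThunkSketch(Any[])

# Accumulation is O(1): record the incoming gradient, force nothing.
accum(a::AccThunkSketch, g) = AccThunkSketch(push!(copy(a.parts), g))

# Materialize once, when the caller actually needs the gradient.
materialize(a::AccThunkSketch) = reduce((x, y) -> x .+ y, force.(a.parts))

acc = AccThunkSketch()
acc = accum(acc, ones(2, 2))
acc = accum(acc, LazyThunk(() -> 2 .* ones(2, 2)))
materialize(acc)   # both parts are forced here, at the very end
```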