Skip to content

Commit 872ec18

Browse files
authored
Use multithreading in row_group_slots refarray method (#2661)
1 parent c0c8cd3 commit 872ec18

File tree

6 files changed

+179
-32
lines changed

6 files changed

+179
-32
lines changed

benchmarks/grouping_performance.jl

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
using DataFrames
2+
using CategoricalArrays
3+
using PooledArrays
4+
using BenchmarkTools
5+
using Random
6+
7+
# Fix the seed of the global RNG so that the `rand` calls below are reproducible
Random.seed!(1)
8+
9+
# Top-level group holding all grouping benchmarks defined in this file
grouping_benchmarks = BenchmarkGroup()
10+
11+
# `refpool`/`refarray` optimized grouping method
# (sub-group stored under the "refpool" key of `grouping_benchmarks`)
12+
refpool_benchmarks = grouping_benchmarks["refpool"] = BenchmarkGroup()
13+
14+
# Define one set of benchmarks per combination of number of distinct values `k`,
# number of rows `n`, array type and presence of missing values
for k in (10, 10_000), n in (100, 100_000, 10_000_000)
    for x in (PooledArray(rand(1:k, n)),
              CategoricalArray(rand(1:k, n)),
              PooledArray(rand([missing; 1:k], n)),
              CategoricalArray(rand([missing; 1:k], n)))
        df = DataFrame(x=x)

        # Each array type is used both with and without missing values, so the
        # type name alone does not identify a data set: without this extra key
        # component later entries would overwrite earlier ones
        hasmissing = "missing=$(Missing <: eltype(x))"

        refpool_benchmarks[k, n, nameof(typeof(x)), hasmissing, "skipmissing=false"] =
            @benchmarkable groupby($df, :x)

        # Skipping missing values
        refpool_benchmarks[k, n, nameof(typeof(x)), hasmissing, "skipmissing=true"] =
            @benchmarkable groupby($df, :x, skipmissing=true)

        # Empty group which requires adjusting group indices.
        # Work on a copy: `$df` above captures the data frame itself and benchmarks
        # are only run later, so mutating `df.x` in place would also change
        # the data measured by the two benchmarks defined above
        df_empty = copy(df)
        replace!(df_empty.x, 5 => 6)
        refpool_benchmarks[k, n, nameof(typeof(x)), hasmissing, "empty group"] =
            @benchmarkable groupby($df_empty, :x)
    end
end
34+
35+
# If a cache of tuned parameters already exists, use it, otherwise, tune and cache
36+
# the benchmark parameters. Reusing cached parameters is faster and more reliable
37+
# than re-tuning `grouping_benchmarks` every time the file is included.
# The cache is a "params.json" file stored in the same directory as this script.
38+
paramspath = joinpath(dirname(@__FILE__), "params.json")
39+
40+
if isfile(paramspath)
41+
# Only the `:evals` parameter is restored from the cached file
loadparams!(grouping_benchmarks, BenchmarkTools.load(paramspath)[1], :evals);
42+
else
43+
tune!(grouping_benchmarks)
44+
BenchmarkTools.save(paramspath, params(grouping_benchmarks));
45+
end
46+
47+
# Run all benchmarks defined above and keep the results for later comparison
grouping_results = run(grouping_benchmarks, verbose=true)
48+
# using Serialization
49+
# serialize("grouping_results.jls", grouping_results)
50+
# leaves(judge(median(grouping_results1), median(grouping_results2)))
51+
# leaves(regressions(judge(median(grouping_results1), median(grouping_results2))))
52+
# leaves(improvements(judge(median(grouping_results1), median(grouping_results2))))

docs/src/lib/internals.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,4 +15,5 @@ gennames
1515
getmaxwidths
1616
ourshow
1717
ourstrwidth
18+
tforeach
1819
```

src/dataframerow/utils.jl

Lines changed: 39 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -338,46 +338,53 @@ function row_group_slots(cols::NTuple{N, AbstractVector},
338338
end
339339
refmap
340340
end
341-
@inbounds for i in eachindex(groups)
342-
local refs_i
343-
let i=i # Workaround for julia#15276
344-
refs_i = map(c -> c[i], refarrays)
345-
end
346-
vals = map((m, r, s, fi) -> m[r-fi+1] * s, refmaps, refs_i, strides, firstinds)
347-
j = sum(vals) + 1
348-
# x < 0 happens with -1 in refmap, which corresponds to missing
349-
if skipmissing && any(x -> x < 0, vals)
350-
j = 0
351-
else
352-
seen[j] = true
341+
tforeach(eachindex(groups), basesize=1_000_000) do i
342+
@inbounds begin
343+
local refs_i
344+
let i=i # Workaround for julia#15276
345+
refs_i = map(c -> c[i], refarrays)
346+
end
347+
vals = map((m, r, s, fi) -> m[r-fi+1] * s, refmaps, refs_i, strides, firstinds)
348+
j = sum(vals) + 1
349+
# x < 0 happens with -1 in refmap, which corresponds to missing
350+
if skipmissing && any(x -> x < 0, vals)
351+
j = 0
352+
else
353+
seen[j] = true
354+
end
355+
groups[i] = j
353356
end
354-
groups[i] = j
355357
end
356358
else
357-
@inbounds for i in eachindex(groups)
358-
local refs_i
359-
let i=i # Workaround for julia#15276
360-
refs_i = map(refarrays, missinginds) do ref, missingind
361-
r = Int(ref[i])
362-
if skipmissing
363-
return r == missingind ? -1 : (r > missingind ? r-1 : r)
364-
else
365-
return r
359+
tforeach(eachindex(groups), basesize=1_000_000) do i
360+
@inbounds begin
361+
local refs_i
362+
let i=i # Workaround for julia#15276
363+
refs_i = map(refarrays, missinginds) do ref, missingind
364+
r = Int(ref[i])
365+
if skipmissing
366+
return r == missingind ? -1 : (r > missingind ? r-1 : r)
367+
else
368+
return r
369+
end
366370
end
367371
end
372+
vals = map((r, s, fi) -> (r-fi) * s, refs_i, strides, firstinds)
373+
j = sum(vals) + 1
374+
# x < 0 happens with -1, which corresponds to missing
375+
if skipmissing && any(x -> x < 0, vals)
376+
j = 0
377+
else
378+
seen[j] = true
379+
end
380+
groups[i] = j
368381
end
369-
vals = map((r, s, fi) -> (r-fi) * s, refs_i, strides, firstinds)
370-
j = sum(vals) + 1
371-
# x < 0 happens with -1, which corresponds to missing
372-
if skipmissing && any(x -> x < 0, vals)
373-
j = 0
374-
else
375-
seen[j] = true
376-
end
377-
groups[i] = j
378382
end
379383
end
380-
if !all(seen) # Compress group indices to remove unused ones
384+
# If some groups are unused, compress group indices to drop them
385+
# sum(seen) is faster than all(seen) when not short-circuiting,
386+
# and short-circuit would only happen in the slower case anyway
387+
if sum(seen) < length(seen)
381388
oldngroups = ngroups
382389
remap = zeros(Int, ngroups)
383390
ngroups = 0

src/other/utils.jl

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,3 +83,48 @@ else
8383
end
8484

8585
funname(c::ComposedFunction) = Symbol(funname(c.outer), :_, funname(c.inner))
86+
87+
# Partition `1:len` into consecutive ranges of indices, each holding at least
# `basesize` entries (or a single range covering `1:len` when `len < basesize`).
# Boundaries are spread evenly over `1:len`, so range lengths differ by at most
# one instead of leaving a small leftover range at the end.
function split_indices(len::Integer, basesize::Integer)
    total = Int64(len) # 64-bit products below avoid overflow on 32-bit machines
    nchunks = max(1, div(len, basesize))
    function chunk(i)
        start = 1 + ((i - 1) * total) ÷ nchunks
        stop = (i * total) ÷ nchunks
        return Int(start):Int(stop)
    end
    return (chunk(i) for i in 1:nchunks)
end
94+
95+
"""
    tforeach(f, x::AbstractArray; basesize::Integer)

Call function `f` on each entry of `x`, discarding the results.
When more than one thread is available and `x` has more than `basesize`
entries, entries are processed in parallel, spawning one separate task
for each block of at least `basesize` entries. Otherwise, entries are
processed sequentially in the current task.

A number of tasks higher than `Threads.nthreads()` may be spawned,
since that can allow for a more efficient load balancing in case
some threads are busy (nested parallelism).
"""
function tforeach(f, x::AbstractArray; basesize::Integer)
    # Blocks computed by `split_indices` start at 1
    @assert firstindex(x) == 1

    @static if VERSION >= v"1.4"
        n = length(x)
        if Threads.nthreads() > 1 && n > basesize
            # `@sync` waits for all spawned tasks before returning
            @sync for block in split_indices(n, basesize)
                Threads.@spawn for i in block
                    f(@inbounds x[i])
                end
            end
            return
        end
    end

    # Sequential fallback: old Julia version, single thread or small input
    for i in eachindex(x)
        f(@inbounds x[i])
    end
    return
end

test/grouping.jl

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3831,4 +3831,22 @@ end
38313831
((x, y, z) -> x[1] <= 5 ? unwrap(y[1]) : unwrap(z[1])) => :res)
38323832
end
38333833

3834+
@testset "groupby multithreading" begin
    # Use more rows than the `basesize` of 1_000_000 passed to `tforeach` by
    # `row_group_slots`, so that rows are split across several tasks
    # when more than one thread is available
    nrows = 1_100_000
    for x in (PooledArray(rand(1:10, nrows)),
              PooledArray(rand([1:9; missing], nrows)))
        for y in (PooledArray(rand(["a", "b", "c", "d"], nrows)),
                  PooledArray(rand(["a"; "b"; "c"; missing], nrows)))
            df = DataFrame(x=x, y=y)
            # Number of distinct non-missing values in each column
            nx = length(unique(skipmissing(x)))
            ny = length(unique(skipmissing(y)))

            # Checks are done by groupby_checked
            @test length(groupby_checked(df, :x)) == 10
            @test length(groupby_checked(df, :x, skipmissing=true)) == nx

            @test length(groupby_checked(df, [:x, :y])) == 40
            @test length(groupby_checked(df, [:x, :y], skipmissing=true)) == nx * ny
        end
    end
end
3851+
38343852
end # module

test/utils.jl

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,4 +101,28 @@ end
101101
@test fetch(t) === true
102102
end
103103

104+
@testset "split_indices" begin
    # Cover lengths below, equal to and well above `basesize` so that both
    # the single-chunk and the multi-chunk cases are exercised
    for len in 0:100, basesize in 1:10
        x = DataFrames.split_indices(len, basesize)

        @test length(x) == max(1, div(len, basesize))
        # Chunks must cover `1:len` exactly once and in order.
        # `==` rather than `===`: concatenating several ranges gives a `Vector`
        @test reduce(vcat, x) == 1:len
        # Chunk sizes are balanced: they differ by at most one
        vmin, vmax = extrema(length(v) for v in x)
        @test vmin + 1 == vmax || vmin == vmax
        # Each chunk has at least `basesize` entries unless there are too few
        @test len < basesize || vmin >= basesize
    end

    # Check overflow on 32-bit
    len = typemax(Int32)
    basesize = 100_000_000
    x = collect(DataFrames.split_indices(len, basesize))
    @test length(x) == div(len, basesize)
    @test x[1][1] === 1
    @test x[end][end] === Int(len)
    vmin, vmax = extrema(length(v) for v in x)
    @test vmin + 1 == vmax || vmin == vmax
    @test len < basesize || vmin >= basesize
end
127+
104128
end # module

0 commit comments

Comments
 (0)