Commit 53abd5c

Add more utils
1 parent 16486a2 commit 53abd5c

File tree: 4 files changed, +208 -1 lines changed


Project.toml

Lines changed: 3 additions & 1 deletion
@@ -1,9 +1,11 @@
 name = "SoleBase"
 uuid = "4475fa32-7023-44a0-aa70-4813b230e492"
 authors = ["Federico Manzella", "Patrik Cavina", "Eduard I. Stan", "Lorenzo Balboni", "Giovanni Pagliarini"]
-version = "0.11.1"
+version = "0.12.0"
 
 [deps]
+CategoricalArrays = "324d7699-5711-5eae-9e2f-1d82baa6b597"
+FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
 IterTools = "c8e1da08-722c-5040-9ed9-7db0dc04731e"
 Logging = "56ddb016-857b-54e1-b83d-db4d58db5568"
 Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"

src/SoleBase.jl

Lines changed: 1 addition & 0 deletions
@@ -108,6 +108,7 @@ function channelsize end
 # includes
 
 include("utils.jl")
+include("machine-learning-utils.jl")
 
 include("movingwindow.jl")
 

src/machine-learning-utils.jl

Lines changed: 149 additions & 0 deletions
@@ -0,0 +1,149 @@
+
+using FillArrays
+using CategoricalArrays
+using StatsBase # for `countmap` and `StatsBase.mean`, used below
+
+doc_supervised_ml = """
+    const CLabel = Union{String,Integer,CategoricalValue}
+    const RLabel = AbstractFloat
+    const Label = Union{CLabel,RLabel}
+
+Types for supervised machine learning labels (classification and regression).
+"""
+
+"""$(doc_supervised_ml)"""
+const CLabel = Union{String,Integer,CategoricalValue}
+"""$(doc_supervised_ml)"""
+const RLabel = AbstractFloat
+"""$(doc_supervised_ml)"""
+const Label = Union{CLabel,RLabel}
+
+# Raw labels
+const _CLabel = Integer # (classification labels are internally represented as integers)
+const _Label = Union{_CLabel,RLabel}
+
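For orientation, a few values and how they fall under these aliases (an illustrative sketch, not part of the commit; it assumes the constants are in scope, e.g. as SoleBase.CLabel):

    "cat" isa CLabel   # true (String)
    3     isa CLabel   # true (Integer)
    1.5   isa RLabel   # true (AbstractFloat)
    1.5   isa Label    # true, since RLabel <: Label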
+############################################################################################
+
+# Convert a list of labels to categorical form
+Base.@propagate_inbounds @inline function get_categorical_form(Y::AbstractVector)
+    class_names = unique(Y)
+
+    dict = Dict{eltype(Y),Int64}()
+    @simd for i in 1:length(class_names)
+        @inbounds dict[class_names[i]] = i
+    end
+
+    _Y = Array{Int64}(undef, length(Y))
+    @simd for i in 1:length(Y)
+        @inbounds _Y[i] = dict[Y[i]]
+    end
+
+    return class_names, _Y
+end
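A minimal usage sketch of get_categorical_form (illustrative, not part of the commit): it returns the distinct class names alongside the integer-coded label vector.

    class_names, _Y = get_categorical_form(["yes", "no", "yes", "yes"])
    # class_names == ["yes", "no"]
    # _Y          == [1, 2, 1, 1]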
+
+############################################################################################
+
+"""
+    bestguess(
+        labels::AbstractVector{<:Label},
+        weights::Union{Nothing,AbstractVector} = nothing;
+        suppress_parity_warning = false,
+    )
+
+Return the best guess for a set of labels; that is, the label that best approximates the
+labels provided. For classification labels, this function returns the majority class; for
+regression labels, the average value.
+If no labels are provided, `nothing` is returned.
+The computation can be weighted.
+
+See also
+[`CLabel`](@ref),
+[`RLabel`](@ref),
+[`Label`](@ref).
+"""
+function bestguess(
+    labels::AbstractVector{<:Label},
+    weights::Union{Nothing,AbstractVector} = nothing;
+    suppress_parity_warning = false,
+) end
+
+# Classification: (weighted) majority vote
+function bestguess(
+    labels::AbstractVector{<:CLabel},
+    weights::Union{Nothing,AbstractVector} = nothing;
+    suppress_parity_warning = false,
+)
+    if length(labels) == 0
+        return nothing
+    end
+
+    counts = begin
+        if isnothing(weights)
+            countmap(labels)
+        else
+            @assert length(labels) === length(weights) "Cannot compute " *
+                "best guess with mismatching number of votes " *
+                "$(length(labels)) and weights $(length(weights))."
+            countmap(labels, weights)
+        end
+    end
+
+    if !suppress_parity_warning && sum(counts[argmax(counts)] .== values(counts)) > 1
+        @warn "Parity encountered in bestguess! " *
+            "counts ($(length(labels)) elements): $(counts), " *
+            "argmax: $(argmax(counts)), " *
+            "max: $(counts[argmax(counts)]) (sum = $(sum(values(counts))))"
+    end
+    argmax(counts)
+end
+
+# Regression: (weighted) mean (or other central tendency measure?)
+function bestguess(
+    labels::AbstractVector{<:RLabel},
+    weights::Union{Nothing,AbstractVector} = nothing;
+    suppress_parity_warning = false,
+)
+    if length(labels) == 0
+        return nothing
+    end
+
+    (isnothing(weights) ? StatsBase.mean(labels) : sum(labels .* weights)/sum(weights))
+end
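A usage sketch covering both method branches (illustrative, not part of the commit; it assumes bestguess is in scope):

    bestguess(["yes", "no", "yes"])        # -> "yes"   (majority class)
    bestguess(["yes", "no"], [0.2, 0.8])   # -> "no"    (weighted majority vote)
    bestguess([1.0, 2.0, 6.0])             # -> 3.0     (mean)
    bestguess([1.0, 3.0], [3, 1])          # -> 1.5     (weighted mean)
    bestguess(Float64[])                   # -> nothing (no labels)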
+
+############################################################################################
+
+# Default weights are optimized using FillArrays
+"""
+    default_weights(n::Integer)::AbstractVector{<:Number}
+
+Return a default weight vector of `n` values.
+"""
+function default_weights(n::Integer)
+    Ones{Int64}(n)
+end
+default_weights(Y::AbstractVector) = default_weights(length(Y))
+
+# Class rebalancing weights (classification case)
+"""
+    balanced_weights(Y::AbstractVector{L}) where {L<:CLabel}::AbstractVector{<:Number}
+
+Return a class-rebalancing weight vector, given a label vector `Y`.
+"""
+function balanced_weights(Y::AbstractVector{L}) where {L<:CLabel}
+    class_counts_dict = countmap(Y)
+    if length(unique(values(class_counts_dict))) == 1 # balanced case
+        default_weights(length(Y))
+    else
+        # Assign weights in such a way that the dataset becomes balanced
+        tot = sum(values(class_counts_dict))
+        balanced_tot_per_class = tot/length(class_counts_dict)
+        weights_map = Dict{L,Float64}([class => (balanced_tot_per_class/n_instances)
+            for (class,n_instances) in class_counts_dict])
+        W = [weights_map[y] for y in Y]
+        W ./ sum(W)
+    end
+end
+
+slice_weights(W::Ones{Int64}, inds::AbstractVector) = default_weights(length(inds))
+slice_weights(W::Any, inds::AbstractVector) = @view W[inds]
+slice_weights(W::Ones{Int64}, i::Integer) = 1
+slice_weights(W::Any, i::Integer) = W[i]
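A usage sketch for the weight helpers (illustrative, not part of the commit; it assumes the functions are in scope):

    W0 = default_weights(3)                     # lazy Ones{Int64} of length 3, no allocation
    W  = balanced_weights(["a", "a", "a", "b"])
    # -> [1/6, 1/6, 1/6, 1/2]: the minority class "b" is upweighted; weights sum to 1
    slice_weights(W0, [1, 3])                   # stays a lazy Ones vector
    slice_weights(W, [1, 4])                    # view into the selected weights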

src/utils.jl

Lines changed: 55 additions & 0 deletions
@@ -135,3 +135,58 @@ Useful for reproducibility.
 function spawn(rng::Random.AbstractRNG)
     Random.MersenneTwister(abs(rand(rng, Int)))
 end
+
+############################################################################################
+
+@inline function softminimum(vals, alpha)
+    _vals = SoleBase.vectorize(vals);
+    partialsort!(_vals, ceil(Int, alpha*length(_vals)); rev=true)
+end
+
+@inline function softmaximum(vals, alpha)
+    _vals = SoleBase.vectorize(vals);
+    partialsort!(_vals, ceil(Int, alpha*length(_vals)))
+end
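In effect these are alpha-softened extrema: softminimum returns the ceil(alpha*n)-th largest value, a minimum that ignores the smallest (1-alpha) fraction of outliers, and softmaximum returns the ceil(alpha*n)-th smallest. A quick sketch (illustrative, not part of the commit; it assumes SoleBase.vectorize yields a plain mutable vector of the input values):

    vals = collect(1:10)
    softminimum(vals, 0.8)   # 8th largest  -> 3 (ignores the two smallest values)
    softmaximum(vals, 0.8)   # 8th smallest -> 8 (ignores the two largest values)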
+
+
+############################################################################################
+# I/O utils
+############################################################################################
+
+# Source: https://stackoverflow.com/questions/46671965/printing-variable-subscripts-in-julia/46674866
+# '₀'
+function subscriptnumber(i::Integer)
+    join([
+        (if i < 0
+            [Char(0x208B)]
+        else [] end)...,
+        [Char(0x2080+d) for d in reverse(digits(abs(i)))]...
+    ])
+end
+# https://www.w3.org/TR/xml-entity-names/020.html
+# '․', 'ₑ', '₋'
+function subscriptnumber(s::AbstractString)
+    char_to_subscript(ch) = begin
+        if ch == "e"
+            "ₑ"
+        elseif ch == "."
+            "․"
+        elseif ch == "-"
+            "₋"
+        else
+            subscriptnumber(parse(Int, ch))
+        end
+    end
+
+    try
+        join(map(char_to_subscript, [string(ch) for ch in s]))
+    catch
+        s
+    end
+end
+
+subscriptnumber(i::AbstractFloat) = subscriptnumber(string(i))
+subscriptnumber(i::Any) = i
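A few examples of the resulting subscripts (illustrative, not part of the commit; results follow the character mapping reconstructed above):

    subscriptnumber(2)      # -> "₂"
    subscriptnumber(-13)    # -> "₋₁₃"
    subscriptnumber(1.5)    # -> "₁․₅"
    subscriptnumber("x")    # -> "x" (non-numeric input is returned unchanged)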
