|
| 1 | + |
| 2 | +using FillArrays |
| 3 | +using CategoricalArrays |
| 4 | + |
| 5 | +doc_supervised_ml = """ |
| 6 | + const CLabel = Union{String,Integer,CategoricalValue} |
| 7 | + const RLabel = AbstractFloat |
| 8 | + const Label = Union{CLabel,RLabel} |
| 9 | +
|
| 10 | +Types for supervised machine learning labels (classification and regression). |
| 11 | +""" |
| 12 | + |
| 13 | +"""$(doc_supervised_ml)""" |
| 14 | +const CLabel = Union{String,Integer,CategoricalValue} |
| 15 | +"""$(doc_supervised_ml)""" |
| 16 | +const RLabel = AbstractFloat |
| 17 | +"""$(doc_supervised_ml)""" |
| 18 | +const Label = Union{CLabel,RLabel} |
| 19 | + |
| 20 | +# Raw labels |
| 21 | +const _CLabel = Integer # (classification labels are internally represented as integers) |
| 22 | +const _Label = Union{_CLabel,RLabel} |
| 23 | + |
| 24 | +############################################################################################ |
| 25 | + |
# Convert a list of labels to categorical form
"""
    get_categorical_form(Y::AbstractVector)

Encode a label vector `Y` into categorical (integer-coded) form.

Return a tuple `(class_names, _Y)` where `class_names` collects the unique
values of `Y` in order of first appearance, and `_Y` is a `Vector{Int64}`
such that `class_names[_Y[i]] == Y[i]` for every `i`.
"""
Base.@propagate_inbounds @inline function get_categorical_form(Y::AbstractVector)
    class_names = unique(Y)

    # Map each class to its index in `class_names`.
    # (Note: a plain loop with `@simd` over Dict insertion, as previously used,
    # is meaningless — Dict writes cannot be vectorized.)
    dict = Dict{eltype(Y),Int64}(cn => i for (i, cn) in enumerate(class_names))

    # Integer-code every label through the lookup table.
    _Y = Int64[dict[y] for y in Y]

    return class_names, _Y
end
| 42 | + |
| 43 | +############################################################################################ |
| 44 | + |
| 45 | +""" |
| 46 | + bestguess( |
| 47 | + labels::AbstractVector{<:Label}, |
| 48 | + weights::Union{Nothing,AbstractVector} = nothing; |
| 49 | + suppress_parity_warning = false, |
| 50 | + ) |
| 51 | +
|
| 52 | +Return the best guess for a set of labels; that is, the label that best approximates the |
| 53 | +labels provided. For classification labels, this function returns the majority class; for |
| 54 | +regression labels, the average value. |
| 55 | +If no labels are provided, `nothing` is returned. |
| 56 | +The computation can be weighted. |
| 57 | +
|
| 58 | +See also |
| 59 | +[`CLabel`](@ref), |
| 60 | +[`RLabel`](@ref), |
| 61 | +[`Label`](@ref). |
| 62 | +""" |
| 63 | +function bestguess( |
| 64 | + labels::AbstractVector{<:Label}, |
| 65 | + weights::Union{Nothing,AbstractVector} = nothing; |
| 66 | + suppress_parity_warning = false, |
| 67 | +) end |
| 68 | + |
# Classification: (weighted) majority vote.
#
# Returns the class with the highest (weighted) vote count, or `nothing` for an
# empty label vector. Emits a warning on ties unless `suppress_parity_warning`.
function bestguess(
    labels::AbstractVector{<:CLabel},
    weights::Union{Nothing,AbstractVector} = nothing;
    suppress_parity_warning = false,
)
    # Empty input: no guess can be made.
    if isempty(labels)
        return nothing
    end

    counts = begin
        if isnothing(weights)
            countmap(labels)
        else
            # Input validation must not be an `@assert`: asserts may be compiled
            # out at higher optimization levels.
            length(labels) == length(weights) || throw(ArgumentError(
                "Cannot compute best guess with mismatching number of votes " *
                "$(length(labels)) and weights $(length(weights))."))
            countmap(labels, weights)
        end
    end

    # `argmax` on a Dict yields the key with the maximum value (vote count).
    best = argmax(counts)
    # Parity check: warn when more than one class attains the maximum count.
    if !suppress_parity_warning && count(==(counts[best]), values(counts)) > 1
        @warn "Parity encountered in bestguess! " *
            "counts ($(length(labels)) elements): $(counts), " *
            "argmax: $(best), " *
            "max: $(counts[best]) (sum = $(sum(values(counts))))"
    end
    return best
end
| 98 | + |
# Regression: (weighted) mean (or other central tendency measure?)
#
# Returns the (weighted) arithmetic mean of `labels`, or `nothing` for an empty
# label vector. `suppress_parity_warning` is accepted for signature consistency
# with the classification method, but is unused here.
function bestguess(
    labels::AbstractVector{<:RLabel},
    weights::Union{Nothing,AbstractVector} = nothing;
    suppress_parity_warning = false,
)
    # Empty input: no guess can be made.
    if isempty(labels)
        return nothing
    end

    if isnothing(weights)
        return StatsBase.mean(labels)
    else
        # Validate weights length explicitly (mirrors the classification method,
        # which previously checked while this method silently mis-broadcast).
        length(labels) == length(weights) || throw(ArgumentError(
            "Cannot compute best guess with mismatching number of labels " *
            "$(length(labels)) and weights $(length(weights))."))
        return sum(labels .* weights) / sum(weights)
    end
end
| 111 | + |
| 112 | +############################################################################################ |
| 113 | + |
# Default (unit) weights; `Ones` from FillArrays represents them lazily,
# without allocating an actual vector.
"""
    default_weights(n::Integer)::AbstractVector{<:Number}

Return a default weight vector of `n` values.
"""
default_weights(n::Integer) = Ones{Int64}(n)

# Convenience overload: derive the length from a label vector.
default_weights(Y::AbstractVector) = default_weights(length(Y))
| 124 | + |
# Class rebalancing weights (classification case)
"""
    balanced_weights(Y::AbstractVector{L}) where {L<:CLabel}

Return a class-rebalancing weight vector, given a label vector `Y`: each
instance is weighted inversely to its class frequency, so that every class
contributes equally to the total weight. The result is normalized to sum to 1.
If the classes are already balanced, unit weights are returned instead
(note: those sum to `length(Y)`, not 1 — TODO confirm this asymmetry is intended).
"""
function balanced_weights(Y::AbstractVector{L}) where {L<:CLabel}
    class_counts_dict = countmap(Y)
    # Fixed a syntax error here: `unique(values(class_counts)_dict)` did not parse.
    if length(unique(values(class_counts_dict))) == 1 # balanced case
        default_weights(length(Y))
    else
        # Assign weights in such a way that the dataset becomes balanced
        tot = sum(values(class_counts_dict))
        balanced_tot_per_class = tot / length(class_counts_dict)
        weights_map = Dict{L,Float64}(
            class => (balanced_tot_per_class / n_instances)
            for (class, n_instances) in class_counts_dict
        )
        W = [weights_map[y] for y in Y]
        # Normalize so the weights sum to 1.
        W ./ sum(W)
    end
end
| 145 | + |
# Weight-slicing helpers: extract the weight(s) for a subset of instances.
# The `Ones{Int64}` (default-weight) specializations avoid materializing or
# indexing anything; the generic fallbacks handle arbitrary weight vectors.
function slice_weights(::Ones{Int64}, inds::AbstractVector)
    return default_weights(length(inds))
end
function slice_weights(W::Any, inds::AbstractVector)
    # A view, not a copy: slicing must not allocate in hot paths.
    return @view W[inds]
end
slice_weights(::Ones{Int64}, i::Integer) = 1
slice_weights(W::Any, i::Integer) = W[i]
0 commit comments