Skip to content

Commit 468bd68

Browse files
committed
add tests
1 parent 7eef6f7 commit 468bd68

File tree

3 files changed

+33
-8
lines changed

3 files changed

+33
-8
lines changed

README.md

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,12 +16,11 @@ which interplay with the functions:
1616
To create new clustering algorithms simply create a new
1717
subtype of `ClusteringAlgorithm` that extends `cluster`
1818
so that it returns a new subtype of `ClusteringResult`.
19-
The result must extend `cluster_number, cluster_labels`
19+
This result must extend `cluster_number, cluster_labels`
2020
and optionally `cluster_probs`.
2121

22-
Note that data input type must always be `AbstractVector` of vectors
23-
(anything that can have distance defined).
24-
Two helper functions `each_data_point, input_data_size` can help
25-
making this harmonious with matrix inputs.
22+
For developers: see two helper functions `each_data_point, input_data_size`
23+
so that you can support matrix input while abiding the declared api
24+
of iterable of vectors as input.
2625

2726
For more, see the docstring of `cluster`.

src/ClusteringAPI.jl

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -66,8 +66,14 @@ Return the cluster probabilities of the data points used in [`cluster`](@ref).
6666
They are length-`n` vectors containing the "probabilities" or "score" of each point
6767
belonging to one of the created clusters (used with fuzzy clustering algorithms).
6868
"""
69-
function cluster_labels(cr::ClusteringResults)
70-
return cr.labels # typically there
69+
function cluster_probs(cr::ClusteringResults)
70+
labels = cluster_labels(cr)
71+
n = cluster_number(cr)
72+
probs = [zeros(Real, n) for _ in 1:length(labels)]
73+
for (i, label) in enumerate(labels)
74+
probs[i][label] = 1
75+
end
76+
return probs
7177
end
7278

7379
# two helper functions for agnostic input data type

test/runtests.jl

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,25 @@
11
using Test
22
using ClusteringAPI
33

4-
@test true
4+
struct TestClustering <: ClusteringAlgorithm
5+
end
6+
struct TestResults <: ClusteringResults
7+
labels::Vector{Int}
8+
n::Int
9+
end
510

11+
function ClusteringAPI.cluster(::TestClustering, data)
12+
return TestResults(fill(1, length(data)), 2)
13+
end
14+
15+
cr = cluster(TestClustering(), randn(100))
16+
@test cluster_number(cr) == 1
17+
@test cluster_labels(cr) == fill(1, 100)
18+
@test cluster_probs(cr) == fill([1.0], 100)
19+
20+
@test ClusteringAPI.input_data_size([rand(3) for _ in 1:30]) == (3, 30)
21+
@test ClusteringAPI.input_data_size(rand(3,30)) == (3, 30)
22+
23+
v = [ones(3) for _ in 1:30]
24+
@test ClusteringAPI.each_data_point(v) == v
25+
@test ClusteringAPI.each_data_point(ones(3,30)) == v

0 commit comments

Comments
 (0)