
Commit 4de37f1

Authored Nov 27, 2024
Add BernoulliNB and Binarizer (#306)
1 parent f84177f commit 4de37f1

File tree

4 files changed, +695 -0 lines changed

‎lib/scholar/naive_bayes/bernoulli.ex

+432
@@ -0,0 +1,432 @@
defmodule Scholar.NaiveBayes.Bernoulli do
  @moduledoc """
  Naive Bayes classifier for multivariate Bernoulli models.

  Like MultinomialNB, this classifier is suitable for discrete data. The
  difference is that while MultinomialNB works with occurrence counts,
  BernoulliNB is designed for binary/boolean features.
  """
  import Nx.Defn
  import Scholar.Shared

  @derive {Nx.Container,
           containers: [
             :feature_count,
             :class_count,
             :class_log_priors,
             :feature_log_probability
           ]}
  defstruct [:feature_count, :class_count, :class_log_priors, :feature_log_probability]

  opts_schema = [
    num_classes: [
      type: :pos_integer,
      required: true,
      doc: ~S"""
      Number of different classes used in training.
      """
    ],
    alpha: [
      type: {:or, [:float, {:list, :float}]},
      default: 1.0,
      doc: ~S"""
      Additive (Laplace/Lidstone) smoothing parameter
      (set alpha to 0.0 and force_alpha to true for no smoothing).
      """
    ],
    force_alpha: [
      type: :boolean,
      default: true,
      doc: ~S"""
      If `false` and alpha is less than 1e-10, alpha is set to
      1e-10. If `true`, alpha remains unchanged. This may cause
      numerical errors if alpha is too close to 0.
      """
    ],
    binarize: [
      type: {:or, [:float, {:in, [nil]}]},
      default: 0.0,
      doc: ~S"""
      Threshold for binarizing (mapping to booleans) sample features.
      If nil, input is presumed to already consist of binary vectors.
      """
    ],
    fit_priors: [
      type: :boolean,
      default: true,
      doc: ~S"""
      Whether to learn class prior probabilities or not.
      If `false`, a uniform prior will be used.
      """
    ],
    class_priors: [
      type: {:custom, Scholar.Options, :weights, []},
      doc: ~S"""
      Prior probabilities of the classes. If specified, the priors are not
      adjusted according to the data.
      """
    ],
    sample_weights: [
      type: {:custom, Scholar.Options, :weights, []},
      doc: ~S"""
      List of `num_samples` elements.
      A list of 1.0 values is used if none is given.
      """
    ]
  ]

  @opts_schema NimbleOptions.new!(opts_schema)

  @doc """
  Fits a naive Bayes model. The function assumes that the targets `y` are integers
  between 0 and `num_classes` - 1 (inclusive). Otherwise, those samples will not
  contribute to `class_count`.

  ## Options

  #{NimbleOptions.docs(@opts_schema)}

  ## Return Values

  The function returns a struct with the following parameters:

    * `:class_count` - Number of samples encountered for each class during fitting. This
      value is weighted by the sample weight when provided.

    * `:class_log_priors` - Smoothed empirical log probability for each class.

    * `:feature_count` - Number of samples encountered for each (class, feature)
      during fitting. This value is weighted by the sample weight when
      provided.

    * `:feature_log_probability` - Empirical log probability of features
      given a class, ``P(x_i|y)``.

  ## Examples

      iex> x = Nx.iota({4, 3})
      iex> y = Nx.tensor([1, 2, 0, 2])
      iex> Scholar.NaiveBayes.Bernoulli.fit(x, y, num_classes: 3, binarize: 1.0)
      %Scholar.NaiveBayes.Bernoulli{
        feature_count: Nx.tensor(
          [
            [1.0, 1.0, 1.0],
            [0.0, 0.0, 1.0],
            [2.0, 2.0, 2.0]
          ]
        ),
        class_count: Nx.tensor(
          [1.0, 1.0, 2.0]
        ),
        class_log_priors: Nx.tensor(
          [-1.3862943649291992, -1.3862943649291992, -0.6931471824645996]
        ),
        feature_log_probability: Nx.tensor(
          [
            [-0.40546512603759766, -0.40546512603759766, -0.40546512603759766],
            [-1.0986123085021973, -1.0986123085021973, -0.40546512603759766],
            [-0.28768205642700195, -0.28768205642700195, -0.28768205642700195]
          ]
        )
      }

      iex> x = Nx.iota({4, 3})
      iex> y = Nx.tensor([1, 2, 0, 2])
      iex> Scholar.NaiveBayes.Bernoulli.fit(x, y, num_classes: 3, force_alpha: false, alpha: 0.0)
      %Scholar.NaiveBayes.Bernoulli{
        feature_count: Nx.tensor(
          [
            [1.0, 1.0, 1.0],
            [0.0, 1.0, 1.0],
            [2.0, 2.0, 2.0]
          ]
        ),
        class_count: Nx.tensor(
          [1.0, 1.0, 2.0]
        ),
        class_log_priors: Nx.tensor(
          [-1.3862943649291992, -1.3862943649291992, -0.6931471824645996]
        ),
        feature_log_probability: Nx.tensor(
          [
            [0.0, 0.0, 0.0],
            [-23.025850296020508, 0.0, 0.0],
            [0.0, 0.0, 0.0]
          ]
        )
      }
  """

  deftransform fit(x, y, opts \\ []) do
    if Nx.rank(x) != 2 do
      raise ArgumentError,
            """
            expected x to have shape {num_samples, num_features}, \
            got tensor with shape: #{inspect(Nx.shape(x))}\
            """
    end

    if Nx.rank(y) != 1 do
      raise ArgumentError,
            """
            expected y to have shape {num_samples}, \
            got tensor with shape: #{inspect(Nx.shape(y))}\
            """
    end

    {num_samples, num_features} = Nx.shape(x)

    if num_samples != Nx.axis_size(y, 0) do
      raise ArgumentError,
            """
            expected first dimension of x and y to be of same size, \
            got: #{num_samples} and #{Nx.axis_size(y, 0)}\
            """
    end

    opts = NimbleOptions.validate!(opts, @opts_schema)
    type = to_float_type(x)

    {alpha, opts} = Keyword.pop!(opts, :alpha)
    alpha = Nx.tensor(alpha, type: type)

    if Nx.shape(alpha) not in [{}, {num_features}] do
      raise ArgumentError,
            """
            when alpha is a list, it should have length equal to num_features = #{num_features}, \
            got: #{Nx.size(alpha)}\
            """
    end

    num_classes = opts[:num_classes]

    priors_flag = opts[:class_priors] != nil

    {class_priors, opts} = Keyword.pop(opts, :class_priors, :nan)
    class_priors = Nx.tensor(class_priors)

    if priors_flag and Nx.size(class_priors) != num_classes do
      raise ArgumentError,
            """
            expected class_priors to be a list of length num_classes = #{num_classes}, \
            got: #{Nx.size(class_priors)}\
            """
    end

    sample_weights_flag = opts[:sample_weights] != nil

    {sample_weights, opts} = Keyword.pop(opts, :sample_weights, :nan)
    sample_weights = Nx.tensor(sample_weights, type: type)

    if sample_weights_flag and Nx.shape(sample_weights) != {num_samples} do
      raise ArgumentError,
            """
            expected sample_weights to be a list of length num_samples = #{num_samples}, \
            got: #{Nx.size(sample_weights)}\
            """
    end

    opts =
      opts ++
        [
          type: type,
          priors_flag: priors_flag,
          sample_weights_flag: sample_weights_flag
        ]

    fit_n(x, y, class_priors, sample_weights, alpha, opts)
  end

  defnp fit_n(x, y, class_priors, sample_weights, alpha, opts) do
    type = opts[:type]
    num_samples = Nx.axis_size(x, 0)

    num_classes = opts[:num_classes]

    x =
      case opts[:binarize] do
        nil -> x
        binarize -> Scholar.Preprocessing.Binarizer.fit_transform(x, threshold: binarize)
      end

    y_one_hot = Scholar.Preprocessing.OneHotEncoder.fit_transform(y, num_categories: num_classes)
    y_one_hot = Nx.select(y_one_hot, Nx.tensor(1, type: type), Nx.tensor(0, type: type))

    y_weighted =
      if opts[:sample_weights_flag],
        do: Nx.reshape(sample_weights, {num_samples, 1}) * y_one_hot,
        else: y_one_hot

    alpha_lower_bound = Nx.tensor(1.0e-10, type: type)

    alpha =
      if opts[:force_alpha], do: alpha, else: Nx.max(alpha, alpha_lower_bound)

    class_count = Nx.sum(y_weighted, axes: [0])
    feature_count = Nx.dot(y_weighted, [0], x, [0])

    # Laplace/Lidstone smoothing: each feature is Bernoulli, so the
    # denominator accounts for both outcomes (hence alpha * 2).
    smoothed_feature_count = feature_count + alpha
    smoothed_cumulative_count = class_count + alpha * 2

    feature_log_probability =
      Nx.log(smoothed_feature_count) -
        Nx.log(Nx.reshape(smoothed_cumulative_count, {num_classes, 1}))

    class_log_priors =
      cond do
        opts[:priors_flag] ->
          Nx.log(class_priors)

        opts[:fit_priors] ->
          Nx.log(class_count) - Nx.log(Nx.sum(class_count))

        true ->
          Nx.broadcast(-Nx.log(num_classes), {num_classes})
      end

    %__MODULE__{
      class_count: class_count,
      class_log_priors: class_log_priors,
      feature_count: feature_count,
      feature_log_probability: feature_log_probability
    }
  end

  @doc """
  Performs classification on a tensor of test vectors `x` using `model`.
  The sorted classes from the training data must be passed as the third argument.

  ## Examples

      iex> x = Nx.iota({4, 3})
      iex> y = Nx.tensor([1, 2, 0, 2])
      iex> model = Scholar.NaiveBayes.Bernoulli.fit(x, y, num_classes: 3)
      iex> Scholar.NaiveBayes.Bernoulli.predict(model, Nx.tensor([[6, 2, 4], [8, 5, 9]]), Nx.tensor([0, 1, 2]))
      #Nx.Tensor<
        s32[2]
        [2, 2]
      >
  """
  defn predict(%__MODULE__{} = model, x, classes) do
    check_dim(x, Nx.axis_size(model.feature_count, 1))

    if Nx.rank(classes) != 1 do
      raise ArgumentError,
            """
            expected classes to be a 1D tensor, \
            got tensor with shape: #{inspect(Nx.shape(classes))}\
            """
    end

    if Nx.axis_size(classes, 0) != Nx.axis_size(model.class_count, 0) do
      raise ArgumentError,
            """
            expected classes to have same size as the number of classes in the model, \
            got: #{Nx.axis_size(classes, 0)} for classes and #{Nx.axis_size(model.class_count, 0)} for model\
            """
    end

    jll = joint_log_likelihood(model, x)
    classes[Nx.argmax(jll, axis: 1)]
  end

  @doc """
  Returns log-probability estimates for the test vector `x` using `model`.

  ## Examples

      iex> x = Nx.iota({4, 3})
      iex> y = Nx.tensor([1, 2, 0, 2])
      iex> model = Scholar.NaiveBayes.Bernoulli.fit(x, y, num_classes: 3)
      iex> Scholar.NaiveBayes.Bernoulli.predict_log_probability(model, Nx.tensor([[6, 2, 4], [8, 5, 9]]))
      #Nx.Tensor<
        f32[2][3]
        [
          [-4.704780578613281, -12.329399108886719, -0.009097099304199219],
          [-8.750494003295898, -19.147701263427734, -1.583099365234375e-4]
        ]
      >
  """
  defn predict_log_probability(%__MODULE__{} = model, x) do
    check_dim(x, Nx.axis_size(model.feature_count, 1))
    jll = joint_log_likelihood(model, x)

    log_proba_x =
      jll
      |> Nx.logsumexp(axes: [1])
      |> Nx.reshape({Nx.axis_size(jll, 0), 1})

    jll - log_proba_x
  end

  @doc """
  Returns probability estimates for the test vector `x` using `model`.

  ## Examples

      iex> x = Nx.iota({4, 3})
      iex> y = Nx.tensor([1, 2, 0, 2])
      iex> model = Scholar.NaiveBayes.Bernoulli.fit(x, y, num_classes: 3)
      iex> Scholar.NaiveBayes.Bernoulli.predict_probability(model, Nx.tensor([[6, 2, 4], [8, 5, 9]]))
      #Nx.Tensor<
        f32[2][3]
        [
          [0.00905190035700798, 4.4198750401847064e-6, 0.9909441471099854],
          [1.5838305989746004e-4, 4.833469624543341e-9, 0.9998416900634766]
        ]
      >
  """
  defn predict_probability(%__MODULE__{} = model, x) do
    Nx.exp(predict_log_probability(model, x))
  end

  @doc """
  Returns joint log probability estimates for the test vector `x` using `model`.

  ## Examples

      iex> x = Nx.iota({4, 3})
      iex> y = Nx.tensor([1, 2, 0, 2])
      iex> model = Scholar.NaiveBayes.Bernoulli.fit(x, y, num_classes: 3)
      iex> Scholar.NaiveBayes.Bernoulli.predict_joint_log_probability(model, Nx.tensor([[6, 2, 4], [8, 5, 9]]))
      #Nx.Tensor<
        f32[2][3]
        [
          [3.6356334686279297, -3.988985061645508, 8.331316947937012],
          [10.56710433959961, 0.16989731788635254, 19.317440032958984]
        ]
      >
  """
  defn predict_joint_log_probability(%__MODULE__{} = model, x) do
    check_dim(x, Nx.axis_size(model.feature_count, 1))
    joint_log_likelihood(model, x)
  end

  defnp check_dim(x, dim) do
    num_features = Nx.axis_size(x, 1)

    if num_features != dim do
      raise ArgumentError,
            """
            expected x to have same second dimension as data used for fitting model, \
            got: #{num_features} for x and #{dim} for training data\
            """
    end
  end

  defnp joint_log_likelihood(
          %__MODULE__{
            feature_log_probability: feature_log_probability,
            class_log_priors: class_log_priors
          },
          x
        ) do
    # log P(x_i = 0 | y), derived from the stored log P(x_i = 1 | y)
    neg_prob = Nx.log(1 - Nx.exp(feature_log_probability))
    jll = Nx.dot(x, [1], feature_log_probability - neg_prob, [1])
    jll + class_log_priors + Nx.sum(neg_prob, axes: [1])
  end
end
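
For readers tracing `joint_log_likelihood/2` above: it encodes the standard Bernoulli naive Bayes decision function. As a sketch in LaTeX notation, where $p_{ci}$ is the smoothed probability stored in `feature_log_probability`, $N_{ci}$ is `feature_count[c][i]`, and $N_c$ is `class_count[c]`:

\log P(y = c \mid x) \propto \log P(y = c) + \sum_i \left[ x_i \log p_{ci} + (1 - x_i) \log(1 - p_{ci}) \right]
                   = \log P(y = c) + \sum_i x_i \left( \log p_{ci} - \log(1 - p_{ci}) \right) + \sum_i \log(1 - p_{ci}),
\qquad p_{ci} = \frac{N_{ci} + \alpha}{N_c + 2\alpha}

The `Nx.dot/4` call computes the middle sum, and `predict_log_probability/2` then normalizes the result with `Nx.logsumexp/2`.
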
lib/scholar/preprocessing/binarizer.ex

+56
@@ -0,0 +1,56 @@
defmodule Scholar.Preprocessing.Binarizer do
  @moduledoc """
  Binarize data according to a threshold.
  """
  import Nx.Defn

  binarize_schema = [
    threshold: [
      type: :float,
      default: 0.0,
      doc: """
      Feature values less than or equal to this threshold are replaced by 0,
      values above it by 1.
      """
    ]
  ]

  @binarize_schema NimbleOptions.new!(binarize_schema)

  @doc """
  Values greater than the threshold map to 1, while values less than
  or equal to the threshold map to 0. With the default threshold of 0,
  only positive values map to 1.

  ## Options

  #{NimbleOptions.docs(@binarize_schema)}

  ## Examples

      iex> t = Nx.tensor([[0, 0, 0], [3, 4, 5], [-2, 4, 3]])
      iex> Scholar.Preprocessing.Binarizer.fit_transform(t, threshold: 3.0)
      #Nx.Tensor<
        u8[3][3]
        [
          [0, 0, 0],
          [0, 1, 1],
          [0, 1, 0]
        ]
      >

      iex> t = Nx.tensor([[0, 0, 0], [3, 4, 5], [-2, 4, 3]])
      iex> Scholar.Preprocessing.Binarizer.fit_transform(t, threshold: 0.4)
      #Nx.Tensor<
        u8[3][3]
        [
          [0, 0, 0],
          [1, 1, 1],
          [0, 1, 1]
        ]
      >
  """
  deftransform fit_transform(tensor, opts \\ []) do
    binarize_n(tensor, NimbleOptions.validate!(opts, @binarize_schema))
  end

  defnp binarize_n(tensor, opts) do
    threshold = opts[:threshold]
    # Elementwise comparison yields a u8 tensor of 0s and 1s
    tensor > threshold
  end
end
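
For context, a minimal sketch of how the two new modules compose (it mirrors the tests that follow and assumes `Nx` and `Scholar` are available): binarize once up front, then pass `binarize: nil` so `fit/3` skips its internal thresholding.

    x = Nx.iota({5, 6})
    # Map to {0, 1} explicitly instead of relying on fit's :binarize option
    x_bin = Scholar.Preprocessing.Binarizer.fit_transform(x, threshold: 0.0)
    y = Nx.tensor([1, 0, 1, 0, 1])

    model = Scholar.NaiveBayes.Bernoulli.fit(x_bin, y, num_classes: 2, binarize: nil)
    Scholar.NaiveBayes.Bernoulli.predict(model, x_bin, Nx.tensor([0, 1]))

Note that `predict/3` does not re-apply the threshold, so test vectors should be binarized the same way before prediction.
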
test/scholar/naive_bayes/bernoulli_test.exs

+173
@@ -0,0 +1,173 @@
defmodule Scholar.NaiveBayes.BernoulliTest do
  use Scholar.Case, async: true
  alias Scholar.NaiveBayes.Bernoulli
  doctest Bernoulli

  describe "fit" do
    test "binary y" do
      x = Nx.iota({5, 6})
      x = Scholar.Preprocessing.Binarizer.fit_transform(x)
      y = Nx.tensor([1, 0, 1, 0, 1])

      model = Bernoulli.fit(x, y, num_classes: 2, binarize: nil)

      assert model.feature_count ==
               Nx.tensor([
                 [2.0, 2.0, 2.0, 2.0, 2.0, 2.0],
                 [2.0, 3.0, 3.0, 3.0, 3.0, 3.0]
               ])

      expected_feature_log_probability =
        Nx.tensor([
          [-0.28768207, -0.28768207, -0.28768207, -0.28768207, -0.28768207, -0.28768207],
          [-0.51082562, -0.22314355, -0.22314355, -0.22314355, -0.22314355, -0.22314355]
        ])

      assert_all_close(model.feature_log_probability, expected_feature_log_probability)

      expected_class_log_priors =
        Nx.tensor([
          -0.91629073,
          -0.51082562
        ])

      assert_all_close(model.class_log_priors, expected_class_log_priors)

      assert model.class_count == Nx.tensor([2.0, 3.0])
    end

    test ":alpha set to a different value" do
      x = Nx.iota({5, 6})
      y = Nx.tensor([1, 2, 6, 3, 1])

      model = Bernoulli.fit(x, y, num_classes: 4, alpha: 0.4)

      assert model.feature_count ==
               Nx.tensor([
                 [1.0, 2.0, 2.0, 2.0, 2.0, 2.0],
                 [1.0, 1.0, 1.0, 1.0, 1.0, 1.0],
                 [1.0, 1.0, 1.0, 1.0, 1.0, 1.0],
                 [1.0, 1.0, 1.0, 1.0, 1.0, 1.0]
               ])

      expected_feature_log_probability =
        Nx.tensor([
          [-0.69314718, -0.15415068, -0.15415068, -0.15415068, -0.15415068, -0.15415068],
          [-0.25131443, -0.25131443, -0.25131443, -0.25131443, -0.25131443, -0.25131443],
          [-0.25131443, -0.25131443, -0.25131443, -0.25131443, -0.25131443, -0.25131443],
          [-0.25131443, -0.25131443, -0.25131443, -0.25131443, -0.25131443, -0.25131443]
        ])

      assert_all_close(model.feature_log_probability, expected_feature_log_probability)

      expected_class_log_priors =
        Nx.tensor([-0.91629073, -1.60943791, -1.60943791, -1.60943791])

      assert_all_close(model.class_log_priors, expected_class_log_priors)

      assert model.class_count == Nx.tensor([2.0, 1.0, 1.0, 1.0])
    end

    test ":fit_priors set to false" do
      x = Nx.iota({5, 6})
      y = Nx.tensor([1, 0, 1, 0, 1])

      model = Bernoulli.fit(x, y, num_classes: 2, fit_priors: false)

      assert model.feature_count ==
               Nx.tensor([
                 [2.0, 2.0, 2.0, 2.0, 2.0, 2.0],
                 [2.0, 3.0, 3.0, 3.0, 3.0, 3.0]
               ])

      expected_feature_log_probability =
        Nx.tensor([
          [-0.28768207, -0.28768207, -0.28768207, -0.28768207, -0.28768207, -0.28768207],
          [-0.51082562, -0.22314355, -0.22314355, -0.22314355, -0.22314355, -0.22314355]
        ])

      assert_all_close(model.feature_log_probability, expected_feature_log_probability)

      expected_class_log_priors =
        Nx.tensor([-0.69314718, -0.69314718])

      assert_all_close(model.class_log_priors, expected_class_log_priors)

      assert model.class_count == Nx.tensor([2.0, 3.0])
    end

    test ":class_priors set as a list" do
      x = Nx.iota({5, 6})
      y = Nx.tensor([1, 2, 3, 2, 1])

      model = Bernoulli.fit(x, y, num_classes: 3, class_priors: [0.3, 0.4, 0.3])

      assert model.feature_count ==
               Nx.tensor([
                 [1.0, 2.0, 2.0, 2.0, 2.0, 2.0],
                 [2.0, 2.0, 2.0, 2.0, 2.0, 2.0],
                 [1.0, 1.0, 1.0, 1.0, 1.0, 1.0]
               ])

      expected_feature_log_probability =
        Nx.tensor([
          [-0.69314718, -0.28768207, -0.28768207, -0.28768207, -0.28768207, -0.28768207],
          [-0.28768207, -0.28768207, -0.28768207, -0.28768207, -0.28768207, -0.28768207],
          [-0.40546511, -0.40546511, -0.40546511, -0.40546511, -0.40546511, -0.40546511]
        ])

      assert_all_close(model.feature_log_probability, expected_feature_log_probability)

      expected_class_log_priors =
        Nx.tensor([-1.2039728, -0.91629073, -1.2039728])

      assert_all_close(model.class_log_priors, expected_class_log_priors)

      assert model.class_count == Nx.tensor([2.0, 2.0, 1.0])
    end

    test "error handling for wrong input shapes" do
      assert_raise ArgumentError,
                   "expected x to have shape {num_samples, num_features}, got tensor with shape: {4}",
                   fn ->
                     Bernoulli.fit(
                       Nx.tensor([1, 2, 3, 4]),
                       Nx.tensor([1, 0, 1, 0]),
                       num_classes: 2
                     )
                   end

      assert_raise ArgumentError,
                   "expected y to have shape {num_samples}, got tensor with shape: {1, 4}",
                   fn ->
                     Bernoulli.fit(
                       Nx.tensor([[1, 2, 3, 4]]),
                       Nx.tensor([[1, 0, 1, 0]]),
                       num_classes: 2
                     )
                   end
    end
  end

  describe "predict" do
    test "predicts classes correctly for new data" do
      x = Nx.iota({5, 6})
      y = Nx.tensor([1, 2, 3, 4, 5])

      jit_model = Nx.Defn.jit(&Bernoulli.fit/3)
      model = jit_model.(x, y, num_classes: 5)

      x_test = Nx.tensor([[1, 2, 3, 4, 5, 6], [0, 0, 0, 0, 0, 0]])

      jit_predict = Nx.Defn.jit(&Bernoulli.predict/3)
      predictions = jit_predict.(model, x_test, Nx.tensor([1, 2, 3, 4, 5]))

      expected_predictions = Nx.tensor([2, 1])
      assert predictions == expected_predictions
    end
  end
end
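
As the predict test above shows, `fit/3` and `predict/3` can be wrapped in `Nx.Defn.jit/2`. A hedged sketch of doing the same with an accelerator-backed compiler (this assumes the optional EXLA dependency is installed; the default compiler is used otherwise):

    # Assumes EXLA is available in the project's dependencies.
    jit_fit = Nx.Defn.jit(&Scholar.NaiveBayes.Bernoulli.fit/3, compiler: EXLA)
    jit_predict = Nx.Defn.jit(&Scholar.NaiveBayes.Bernoulli.predict/3, compiler: EXLA)

    model = jit_fit.(Nx.iota({5, 6}), Nx.tensor([1, 2, 3, 4, 5]), num_classes: 5)
    jit_predict.(model, Nx.tensor([[1, 2, 3, 4, 5, 6]]), Nx.tensor([1, 2, 3, 4, 5]))
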
test/scholar/preprocessing/binarizer_test.exs

+34
@@ -0,0 +1,34 @@
defmodule Scholar.Preprocessing.BinarizerTest do
  use Scholar.Case, async: true
  alias Scholar.Preprocessing.Binarizer
  doctest Binarizer

  describe "binarization" do
    test "binarize with positive threshold" do
      tensor = Nx.tensor([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0], [-2.0, -1.0, 0.0]])

      jit_binarizer = Nx.Defn.jit(&Binarizer.fit_transform/2)

      result = jit_binarizer.(tensor, threshold: 2.0)

      assert Nx.to_flat_list(result) == [0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0]
    end

    test "binarize values with default threshold" do
      tensor = Nx.tensor([[0.0, -1.0, 2.0], [3.0, 4.0, -5.0], [-2.0, 1.0, 0.0]])

      result = Binarizer.fit_transform(tensor)

      assert Nx.to_flat_list(result) == [0.0, 0.0, 1.0, 1.0, 1.0, 0.0, 0.0, 1.0, 0.0]
    end

    test "binarize with threshold less than 0" do
      tensor = Nx.tensor([[0.0, 0.5, -0.5], [-0.1, -0.2, -0.3]])
      jit_binarizer = Nx.Defn.jit(&Binarizer.fit_transform/2)

      result = jit_binarizer.(tensor, threshold: -0.2)

      assert Nx.to_flat_list(result) == [1.0, 1.0, 0.0, 1.0, 0.0, 0.0]
    end
  end
end
