Skip to content

Commit 06344c4

Browse files
committed
test_c11_machine_learning.py
1 parent 6f41df5 commit 06344c4

File tree

2 files changed

+29
-2
lines changed

2 files changed

+29
-2
lines changed

src/data-scratch-library/tests/test_c11_machine_learning.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
from dsl.c04_linear_algebra.e0402_matrices import make_random_matrix
88
from dsl.c11_machine_learning import machine_learning
9+
from dsl.c11_machine_learning.machine_learning import split_data
910

1011
current_dir = os.path.dirname(__file__)
1112
parent_dir = os.path.join(current_dir, os.pardir)
@@ -63,13 +64,12 @@ def test_recall(tp, fp, fn, tn, expected):
6364

6465

6566
def test_split_data():
66-
result = machine_learning.split_data(make_random_matrix(), 0.5)
67+
result = split_data(make_random_matrix(), 0.5)
6768
assert len(result[0]) == pytest.approx(50, abs=10)
6869
assert len(result[1]) == pytest.approx(50, abs=10)
6970

7071

7172
def test_train_test_split():
72-
print(make_random_matrix)
7373
x_train, x_test, y_train, y_test = machine_learning.train_test_split(make_random_matrix(), make_random_matrix(), 0.5)
7474
assert len(x_train) == pytest.approx(50, abs=10)
7575
assert len(x_test) == pytest.approx(50, abs=10)
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
import random
2+
3+
from dsl.c08_gradient_descent import negate, negate_all
4+
from dsl.c08_gradient_descent.e0805_stochastic_gd import maximize_stochastic
5+
from dsl.c14_simple_linear_regression.simple_linear_regression import squared_error
6+
from dsl.c15_multiple_regression.multiple_regression import squared_error_gradient
7+
8+
9+
def test_maximize_stochastic_squared_error():
10+
x = [
11+
[1, 49, 4, 0], [1, 41, 9, 0], [1, 40, 8, 0],
12+
[1, 25, 6, 0], [1, 21, 1, 0], [1, 21, 0, 0],
13+
[1, 19, 3, 0], [1, 19, 0, 0], [1, 18, 9, 0], [1, 18, 8, 0]
14+
]
15+
y = [
16+
68.77, 51.25, 52.08,
17+
38.36, 44.54, 57.13,
18+
51.4, 41.42, 31.22, 34.76,
19+
]
20+
maximize_stochastic(
21+
target_fn=negate(squared_error),
22+
gradient_fn=negate_all(squared_error_gradient),
23+
x=x,
24+
y=y,
25+
theta_0=[random.random() for _ in x[0]],
26+
alpha_0=0.01
27+
)

0 commit comments

Comments
 (0)