forked from epfml/ML_course
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Scott Pesme
committed
Nov 3, 2022
1 parent
4181eec
commit 09432d0
Showing
4 changed files
with
637 additions
and
0 deletions.
There are no files selected for viewing
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
# -*- coding: utf-8 -*- | ||
"""some helper functions for project 1.""" | ||
import csv | ||
import numpy as np | ||
|
||
|
||
def load_csv_data(data_path, sub_sample=False): | ||
"""Loads data. | ||
return | ||
y(class labels), tX (features) and ids (event ids). | ||
""" | ||
y = np.genfromtxt(data_path, delimiter=",", skip_header=1, dtype=str, usecols=1) | ||
x = np.genfromtxt(data_path, delimiter=",", skip_header=1) | ||
ids = x[:, 0].astype(np.int) | ||
input_data = x[:, 2:] | ||
|
||
# convert class labels from strings to binary (-1,1) | ||
yb = np.ones(len(y)) | ||
yb[np.where(y == "b")] = -1 | ||
|
||
# sub-sample | ||
if sub_sample: | ||
yb = yb[::50] | ||
input_data = input_data[::50] | ||
ids = ids[::50] | ||
|
||
return yb, input_data, ids | ||
|
||
|
||
def predict_labels(weights, data): | ||
"""Generates class predictions given weights, and a test data matrix.""" | ||
y_pred = np.dot(data, weights) | ||
y_pred[np.where(y_pred <= 0)] = -1 | ||
y_pred[np.where(y_pred > 0)] = 1 | ||
|
||
return y_pred |
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,31 @@ | ||
import doctest | ||
import io | ||
import sys | ||
import numpy as np | ||
|
||
|
||
def test(f): | ||
# The `globs` defines the variables, functions and packages allowed in the docstring. | ||
tests = doctest.DocTestFinder().find(f) | ||
assert len(tests) <= 1 | ||
for test in tests: | ||
# We redirect stdout to a string, so we can tell if the tests worked out or not | ||
orig_stdout = sys.stdout | ||
sys.stdout = io.StringIO() | ||
|
||
try: | ||
results: doctest.TestResults = doctest.DocTestRunner().run(test) | ||
output = sys.stdout.getvalue() | ||
finally: | ||
sys.stdout = orig_stdout | ||
|
||
if results.failed > 0: | ||
print(f"❌ The are some issues with your implementation of `{f.__name__}`:") | ||
print(output, end="") | ||
print( | ||
"**********************************************************************" | ||
) | ||
elif results.attempted > 0: | ||
print(f"✅ Your `{f.__name__}` passed {results.attempted} tests.") | ||
else: | ||
print(f"Could not find any tests for {f.__name__}") |