Tim/privacy experiments ces22 #180
Open
neeshjaa wants to merge 33 commits into privacy-analysis-main-branch from tim/privacy-experiments-ces22
33 commits
e8f3700 update: Initial commit with privacy experiments on CES 2022 dataset (neeshjaa)
298bd15 update: Improvement to the documentation (neeshjaa)
ead6749 update: Generate CSV files with binary target feature directly from t… (neeshjaa)
3c7c4c2 update: Update documentation regarding outstanding question on predic… (neeshjaa)
2dbd04e update: More documentation tweaks (neeshjaa)
258b294 fix: Must make sure Python lint likes my indentation (neeshjaa)
06277a9 fix: Yet another indentation issue (tabs aren't allowed?) (neeshjaa)
bed5f93 fix: Lint pedantry needed to make the code ugly (neeshjaa)
2f7cbaa fix: Generalize selection of predictive probabilities in predictions.csv (neeshjaa)
e31f08d fix: Clean up some temporary versions (neeshjaa)
a948aec docs: Update privacy README to remove open question about predict.py (neeshjaa)
5c9d026 fix: Replace y_true[] input to roc_curve with synthetic column (neeshjaa)
ebea79f Revert "fix: Replace y_true[] input to roc_curve with synthetic column" (neeshjaa)
4d656ec Revert "docs: Update privacy README to remove open question about pre… (neeshjaa)
1387d39 Revert "fix: Clean up some temporary versions" (neeshjaa)
0597366 Revert "fix: Generalize selection of predictive probabilities in pred… (neeshjaa)
9aad228 Revert "fix: Lint pedantry needed to make the code ugly" (neeshjaa)
4cdeb7f Revert "fix: Yet another indentation issue (tabs aren't allowed?)" (neeshjaa)
6bc3dc2 Revert "fix: Must make sure Python lint likes my indentation" (neeshjaa)
2212bef Revert "update: More documentation tweaks" (neeshjaa)
e6a92a9 Revert "update: Update documentation regarding outstanding question o… (neeshjaa)
aeaff38 Revert "update: Generate CSV files with binary target feature directl… (neeshjaa)
9a7825e Revert "update: Improvement to the documentation" (neeshjaa)
cc94798 Revert "update: Initial commit with privacy experiments on CES 2022 d… (neeshjaa)
6f928d7 feat: Add column to predictions.csv with probability of the predicted… (neeshjaa)
8aefa67 fix: Python lint (neeshjaa)
944b082 fix: Don't choose target for prediction at random (Schaechtle)
c90d756 fix: Ensure predictions.csv is pointing to data/directory (Schaechtle)
451d357 chore: Remove plot.show() (Schaechtle)
289276f feat: Save results to disk (Schaechtle)
f500f2b feat: Add pipeline stages for Tim's analysis (Schaechtle)
6ca7b00 chore: Treat roc.png as pipeline output (Schaechtle)
c603f57 feat: Make the "positive" label for the ROC curve explicit throughout (Schaechtle)
@@ -24,3 +24,4 @@ devenv.local.nix
pom.xml
pom.xml.asc
sum-product-dsl/
/roc.png
@@ -39,3 +39,4 @@
/predictions.csv
/synthetic-data-iql.csv
/db.edn
/ml-metrics.csv
@@ -0,0 +1,39 @@
#!/usr/bin/python3

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from sklearn.metrics import roc_curve, auc
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
import yaml

# Generate generic ROC curve for the results found in 'data/predictions.csv',
# relative to the current directory

df = pd.read_csv("data/predictions.csv", header=0)
yv = df["prediction"]
yp = df["predictive-probability"]
tv = df["true_value"]

with open("params.yaml", "r") as f:
    params = yaml.safe_load(f.read())
# Get the label to treat as "positive" from the evaluation configuration.
pos_label = params["synthetic_data_evaluation"]["positive_label"]

# Compute ROC curve and ROC area
fpr, tpr, thresholds = roc_curve(tv, yp, pos_label=pos_label)
roc_auc = auc(fpr, tpr)

# Plot ROC curve
plt.figure()
plt.plot(fpr, tpr, color="darkorange", lw=2, label="ROC curve (area = %0.2f)" % roc_auc)
plt.plot([0, 1], [0, 1], color="navy", lw=2, linestyle="--")
plt.xlim([0.0, 1.0])
plt.ylim([0.0, 1.05])
plt.xlabel("False Positive Rate")
plt.ylabel("True Positive Rate")
plt.title("Receiver Operating Characteristic")
plt.legend(loc="lower right")
plt.savefig("roc.png")
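Why pinning down the positive label matters: the same score column gives complementary AUC values depending on which label counts as positive. The following is a dependency-free sketch, not scikit-learn's implementation. It relies on the rank interpretation of AUC (the probability that a randomly chosen positive row outscores a randomly chosen negative one, with ties counted as one half), which equals the area under the ROC curve. The function name `auc_for_label` and the toy data are hypothetical.

```python
def auc_for_label(true_values, scores, pos_label):
    """AUC as the fraction of (positive, negative) pairs ranked correctly.

    Ties between a positive and a negative score count as half a win.
    """
    pos = [s for t, s in zip(true_values, scores) if t == pos_label]
    neg = [s for t, s in zip(true_values, scores) if t != pos_label]
    if not pos or not neg:
        raise ValueError("need at least one positive and one negative row")
    wins = sum((p > n) + 0.5 * (p == n) for p in pos for n in neg)
    return wins / (len(pos) * len(neg))


# Hypothetical data: scores are meant as P(label == "yes").
true_values = ["yes", "no", "yes", "no"]
scores = [0.9, 0.2, 0.6, 0.7]

auc_yes = auc_for_label(true_values, scores, pos_label="yes")
auc_no = auc_for_label(true_values, scores, pos_label="no")
print(auc_yes, auc_no)
```

With a fixed score column, choosing the other label as positive turns an AUC of `a` into `1 - a`, so the score column and the positive label have to refer to the same class.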
@@ -0,0 +1,33 @@
#!/usr/bin/python3

import numpy as np
import pandas as pd

df = pd.read_csv("data/predictions.csv", header=0)

# Show some generic metrics for the results found in 'data/predictions.csv',
# relative to the current directory

X = df["true_value"]
Y = df["prediction"]

from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score
# scikit-learn metrics take (y_true, y_pred): X holds true values, Y predictions.
print("Accuracy...: %f" % accuracy_score(X, Y))
print("Precision..: %f" % precision_score(X, Y, average="macro"))
print("Recall.....: %f" % recall_score(X, Y, average="macro"))
print("F1.........: %f" % f1_score(X, Y, average="macro"))

# Also save to disk. Helpful to track the result with DVC.
result = pd.DataFrame(
    {
        "metric": ["Accuracy", "Precision", "Recall", "F1"],
        "score": [
            accuracy_score(X, Y),
            precision_score(X, Y, average="macro"),
            recall_score(X, Y, average="macro"),
            f1_score(X, Y, average="macro"),
        ],
    }
)
result.to_csv("data/ml-metrics.csv", index=False)
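A note on argument order: scikit-learn's metric functions expect `(y_true, y_pred)`. Accuracy and F1 are unaffected by swapping the two, but precision and recall trade places. The sketch below shows this with hand-rolled counts rather than scikit-learn itself; the helper `precision_recall` and the toy labels are hypothetical.

```python
def precision_recall(y_true, y_pred, pos_label):
    """Precision and recall for one class, computed from raw counts."""
    pairs = list(zip(y_true, y_pred))
    tp = sum(t == pos_label and p == pos_label for t, p in pairs)
    fp = sum(t != pos_label and p == pos_label for t, p in pairs)
    fn = sum(t == pos_label and p != pos_label for t, p in pairs)
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    return precision, recall


# Hypothetical labels: three true positives in the data, one of them predicted.
y_true = [1, 1, 1, 0]
y_pred = [1, 0, 0, 0]

correct = precision_recall(y_true, y_pred, pos_label=1)
swapped = precision_recall(y_pred, y_true, pos_label=1)
print(correct, swapped)
```

Passing the predictions first reports recall under the name of precision and vice versa, which also holds for each class in a macro average.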
So, we need to decide which one of the two binary classes we want to predict and always grab the corresponding vector. Assuming the labels are encoded as 0 and 1, you could do the following:
Alternatively, if your true value is encoded as "true-value", you could run j = list(ml_model.classes_).index("true-value").
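A minimal sketch of the index lookup suggested in this comment, assuming `ml_model` is a scikit-learn classifier whose `predict_proba` columns follow the order of `classes_`. To keep the example self-contained, plain lists stand in for `list(ml_model.classes_)` and for the `predict_proba` output; the label `"yes"` and the numbers are hypothetical.

```python
# Stand-in for list(ml_model.classes_): the label order of the classifier.
classes = ["no", "yes"]

# Stand-in for ml_model.predict_proba(X): one row per sample,
# one column per class, in the same order as `classes`.
proba = [
    [0.80, 0.20],
    [0.30, 0.70],
    [0.45, 0.55],
]

# Look up the column of the class we decided to treat as positive,
# instead of relying on it being column 1.
j = classes.index("yes")

# Always take that column, regardless of which class was predicted per row.
positive_scores = [row[j] for row in proba]
print(j, positive_scores)
```

The lookup keeps the score column tied to one fixed class for every row, which is what a ROC curve needs.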
Maybe for greatest generality we should specify the value to be predicted in params.yaml, then use that last method to create a new vector saying whether or not the predicted value equals that? That way we could handle dataset options like yes/no, true/false, without any preprocessing?
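One way this proposal could look, as a sketch only: read the positive label from the parsed configuration and turn both columns into 0/1 indicator vectors. The dictionary literal stands in for the parsed `params.yaml`; the key names mirror those used by the ROC script in this PR, and the helper `to_indicator` is hypothetical.

```python
# Stand-in for the dictionary obtained by parsing params.yaml.
params = {"synthetic_data_evaluation": {"positive_label": "yes"}}


def to_indicator(values, positive_label):
    """Return 1 where a value equals the positive label, else 0."""
    return [int(v == positive_label) for v in values]


positive_label = params["synthetic_data_evaluation"]["positive_label"]

# Hypothetical columns from predictions.csv.
true_value = ["yes", "no", "yes", "no"]
prediction = ["yes", "yes", "no", "no"]

true_indicator = to_indicator(true_value, positive_label)
pred_indicator = to_indicator(prediction, positive_label)
print(true_indicator, pred_indicator)
```

The same helper works for yes/no, true/false, or any other pair of labels, as long as the configured label has the same type and spelling as the values in the data.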