diff --git a/.github/workflows/manual.yml b/.github/workflows/manual.yml index f43ffbd885..ea5adcf45f 100644 --- a/.github/workflows/manual.yml +++ b/.github/workflows/manual.yml @@ -1,31 +1,42 @@ -name: Python CI +import pickle +from sklearn.ensemble import RandomForestClassifier +from sklearn.metrics import fbeta_score, precision_score, recall_score +from ml.data import process_data -on: [push] +def train_model(X_train, y_train): + model = RandomForestClassifier(random_state=42) + model.fit(X_train, y_train) + return model -jobs: - build: - runs-on: ubuntu-latest - strategy: - matrix: - python-version: ["3.10"] +def compute_model_metrics(y, preds): + fbeta = fbeta_score(y, preds, beta=1, zero_division=1) + precision = precision_score(y, preds, zero_division=1) + recall = recall_score(y, preds, zero_division=1) + return precision, recall, fbeta - steps: - - uses: actions/checkout@v2 - - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v2 - with: - python-version: ${{ matrix.python-version }} - - name: Install dependencies - run: | - python -m pip install --upgrade pip - pip install flake8 pytest - if [ -f requirements.txt ]; then pip install -r requirements.txt; fi - - name: Lint with flake8 - run: | - # stop the build if there are Python syntax errors or undefined names - flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics - # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide - flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics - - name: Test with pytest - run: | - pytest test_ml.py -v +def inference(model, X): + return model.predict(X) + +def save_model(model, path): + with open(path, 'wb') as f: + pickle.dump(model, f) + +def load_model(path): + with open(path, 'rb') as f: + return pickle.load(f) + +def performance_on_categorical_slice( + data, column_name, slice_value, categorical_features, label, encoder, lb, model +): + slice_data = data[data[column_name] == slice_value] + X_slice, y_slice, _, _ = process_data( + slice_data, + categorical_features=categorical_features, + label=label, + training=False, + encoder=encoder, + lb=lb + ) + preds = inference(model, X_slice) + precision, recall, fbeta = compute_model_metrics(y_slice, preds) + return precision, recall, fbeta diff --git a/.github/workflows/python-app.yml b/.github/workflows/python-app.yml new file mode 100644 index 0000000000..4fb8eec773 --- /dev/null +++ b/.github/workflows/python-app.yml @@ -0,0 +1,27 @@ +name: Python CI + +on: [push] + +jobs: + build: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v3 + - name: Set up Python 3.10 + uses: actions/setup-python@v3 + with: + python-version: "3.10" + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install flake8 pytest + if [ -f requirements.txt ]; then pip install -r requirements.txt; fi + - name: Lint with flake8 + run: | + # stop the build if there are Python syntax errors or undefined names + flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics + # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide + flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics + - name: Test with pytest + run: | + pytest test_ml.py -v diff --git a/local_api.py b/local_api.py index a3bff2f988..da945cdb66 100644 --- a/local_api.py +++ b/local_api.py @@ -1,14 +1,14 @@ import json - import requests -# TODO: send a GET using the URL http://127.0.0.1:8000 -r = None # Your code here +# DONE: send a GET using the URL http://127.0.0.1:8000 +local_URL = "http://127.0.0.1:8000/" +r = requests.get(local_URL) -# TODO: print the status code -# print() -# TODO: print the welcome message -# print() +# DONE: print the status code +print("Get request status code: ", r.status_code) +# DONE?: print the welcome message +print("Welcome Message:", r.json()) @@ -26,13 +26,17 @@ "capital-gain": 0, "capital-loss": 0, "hours-per-week": 40, - "native-country": "United-States", + "native-country": "United-States" } -# TODO: send a POST using the data above -r = None # Your code here -# TODO: print the status code -# print() -# TODO: print the result -# print() +# DONE?: send a POST using the data above +#r = requests.post(local_URL+"data?", +# data = data) +r2 = requests.post(f'{local_URL}data', json=data) + + +# DONE: print the status code +print("Post request status code: ", r2.status_code) +# DONE: print the result +print("Inference result: ", r2.json()) \ No newline at end of file diff --git a/main.py b/main.py index 638e2414de..a4217ff79a 100644 --- a/main.py +++ b/main.py @@ -7,6 +7,7 @@ from ml.data import apply_label, process_data from ml.model import inference, load_model + # DO NOT MODIFY class Data(BaseModel): age: int = Field(..., example=37) @@ -15,7 +16,9 @@ class Data(BaseModel): education: str = Field(..., example="HS-grad") education_num: int = Field(..., example=10, alias="education-num") marital_status: str = Field( - ..., example="Married-civ-spouse", alias="marital-status" + ..., + example="Married-civ-spouse", + alias="marital-status", ) occupation: str = Field(..., example="Prof-specialty") relationship: str = Field(..., example="Husband") @@ -24,26 +27,31 @@ class Data(BaseModel): capital_gain: int = Field(..., example=0, alias="capital-gain") capital_loss: int = Field(..., example=0, alias="capital-loss") hours_per_week: int = Field(..., example=40, alias="hours-per-week") - native_country: str = Field(..., example="United-States", alias="native-country") + native_country: str = Field( + ..., + example="United-States", + alias="native-country", + ) + + +project_path = "/mnt/c/Users/kaleb/Desktop/DEPLOYING-A-SCALABLE-ML-PIPELINE-WITH-FASTAPI" -path = None # TODO: enter the path for the saved encoder -encoder = load_model(path) +encoder_path = os.path.join(project_path, "model", "encoder.pkl") +encoder = load_model(encoder_path) -path = None # TODO: enter the path for the saved model -model = load_model(path) +model_path = os.path.join(project_path, "model", "model.pkl") +model = load_model(model_path) + + +app = FastAPI() -# TODO: create a RESTful API using FastAPI -app = None # your code here -# TODO: create a GET on the root giving a welcome message @app.get("/") async def get_root(): """ Say hello!""" - # your code here - pass + return {"message": "Hello! Welcome to Stephen's API."} -# TODO: create a POST on a different path that does model inference @app.post("/data/") async def post_inference(data: Data): # DO NOT MODIFY: turn the Pydantic model into a dict. @@ -64,11 +72,15 @@ async def post_inference(data: Data): "sex", "native-country", ] + data_processed, _, _, _ = process_data( - # your code here - # use data as data input - # use training = False - # do not need to pass lb as input + data, + categorical_features=cat_features, + training=False, + encoder=encoder, ) - _inference = None # your code here to predict the result using data_processed - return {"result": apply_label(_inference)} + + _inference = inference(model, data_processed) + return { + "result": apply_label(_inference) + } diff --git a/ml/model.py b/ml/model.py index f361110f18..32a211d21a 100644 --- a/ml/model.py +++ b/ml/model.py @@ -1,7 +1,8 @@ import pickle from sklearn.metrics import fbeta_score, precision_score, recall_score from ml.data import process_data -# TODO: add necessary import +from sklearn.ensemble import RandomForestClassifier # or your chosen model +import numpy as np # Optional: implement hyperparameter tuning. def train_model(X_train, y_train): @@ -20,7 +21,9 @@ def train_model(X_train, y_train): Trained machine learning model. """ # TODO: implement the function - pass + model = RandomForestClassifier(random_state=42) + model.fit(X_train, y_train) + return model def compute_model_metrics(y, preds): @@ -60,7 +63,7 @@ def inference(model, X): Predictions from the model. """ # TODO: implement the function - pass + return model.predict(X) def save_model(model, path): """ Serializes model to a file. @@ -73,12 +76,14 @@ def save_model(model, path): Path to save pickle file. """ # TODO: implement the function - pass + with open(path, 'wb') as f: + pickle.dump(model, f) def load_model(path): """ Loads pickle file from `path` and returns it.""" # TODO: implement the function - pass + with open(path, 'rb') as f: + return pickle.load(f) def performance_on_categorical_slice( @@ -117,12 +122,15 @@ def performance_on_categorical_slice( fbeta : float """ - # TODO: implement the function + slice_data = data[data[column_name] == slice_value] X_slice, y_slice, _, _ = process_data( - # your code here - # for input data, use data in column given as "column_name", with the slice_value - # use training = False + slice_data, + categorical_features=categorical_features, + label=label, + training=False, + encoder=encoder, + lb=lb ) - preds = None # your code here to get prediction on X_slice using the inference function + preds = inference(model, X_slice) precision, recall, fbeta = compute_model_metrics(y_slice, preds) return precision, recall, fbeta diff --git a/model/encoder.pkl b/model/encoder.pkl new file mode 100644 index 0000000000..051c53f367 Binary files /dev/null and b/model/encoder.pkl differ diff --git a/model/model.pkl b/model/model.pkl new file mode 100644 index 0000000000..0bd32041a0 Binary files /dev/null and b/model/model.pkl differ diff --git a/model_card.md b/model_card.md new file mode 100644 index 0000000000..084b260f82 --- /dev/null +++ b/model_card.md @@ -0,0 +1,55 @@ +# Model Card + +For additional information see the Model Card paper: https://arxiv.org/pdf/1810.03993.pdf + +## Model Details + +- Developed by: Stephen Byrd, Feburary 2025 +- Model Type: This model uses a **Random Forest Classifier** for binary classification. +- Dataset: The model is trained on the **Adult Census Dataset** from the UCI Machine Learning Repository. + +## Intended Use + +- The model is intended to predict whether an individual earns more than $50,000 per year based on demographic features from the census data. + +## Training Data + +- Dataset Source: The data is extracted from the 1994 Census database. +- Features: + - age + - workclass + - education + - education-num + - marital-status + - occupation + - relationship + - race + - sex + - capital-gain + - capital-loss + - hours-per-week + - native-country +- Target Label: 'salary' with values '>50K' (1) and '<=50K' (0). + +## Evaluation Data + +- Validation Method: The model is evaluated using a test dataset split from the original data. +- Metrics Used: Precision, Recall, F1 Score. + +## Metrics + +- Metrics Used: Precision, Recall, F1 Score. +- Model Performance: + - Precision: **0.7419** + - Recall: **0.6384** + - F1 Score: **0.6863** + +## Ethical Considerations + +- The dataset may have biases towards certain demographics (e.g., more men than women, predominantly white individuals). +- No direct human life risks are associated with this model. + +## Caveats and Recommendations + +- Limitations: The model's performance could be improved with further tuning or using different classifiers. +- Future Work: Consider using techniques like SHAP for feature importance analysis or exploring other classification models. diff --git a/model_card.md:Zone.Identifier b/model_card.md:Zone.Identifier new file mode 100644 index 0000000000..e69de29bb2 diff --git a/model_card_template.md b/model_card_template.md deleted file mode 100644 index 0392f3b9eb..0000000000 --- a/model_card_template.md +++ /dev/null @@ -1,18 +0,0 @@ -# Model Card - -For additional information see the Model Card paper: https://arxiv.org/pdf/1810.03993.pdf - -## Model Details - -## Intended Use - -## Training Data - -## Evaluation Data - -## Metrics -_Please include the metrics used and your model's performance on those metrics._ - -## Ethical Considerations - -## Caveats and Recommendations diff --git a/screenshots/continuous_integration.png b/screenshots/continuous_integration.png new file mode 100644 index 0000000000..38ffc8a1a2 Binary files /dev/null and b/screenshots/continuous_integration.png differ diff --git a/screenshots/local_api.png b/screenshots/local_api.png new file mode 100644 index 0000000000..a1b3d0c789 Binary files /dev/null and b/screenshots/local_api.png differ diff --git a/screenshots/unit_test.png b/screenshots/unit_test.png new file mode 100644 index 0000000000..d56eda4430 Binary files /dev/null and b/screenshots/unit_test.png differ diff --git a/slice_output.txt b/slice_output.txt new file mode 100644 index 0000000000..8d94b84e6f --- /dev/null +++ b/slice_output.txt @@ -0,0 +1,396 @@ +workclass: ?, Count: 389 +Precision: 0.6538 | Recall: 0.4048 | F1: 0.5000 +workclass: Federal-gov, Count: 191 +Precision: 0.7971 | Recall: 0.7857 | F1: 0.7914 +workclass: Local-gov, Count: 387 +Precision: 0.7576 | Recall: 0.6818 | F1: 0.7177 +workclass: Private, Count: 4,578 +Precision: 0.7376 | Recall: 0.6404 | F1: 0.6856 +workclass: Self-emp-inc, Count: 212 +Precision: 0.7807 | Recall: 0.7542 | F1: 0.7672 +workclass: Self-emp-not-inc, Count: 498 +Precision: 0.7064 | Recall: 0.4904 | F1: 0.5789 +workclass: State-gov, Count: 254 +Precision: 0.7424 | Recall: 0.6712 | F1: 0.7050 +workclass: Without-pay, Count: 4 +Precision: 1.0000 | Recall: 1.0000 | F1: 1.0000 +education: 10th, Count: 183 +Precision: 0.4000 | Recall: 0.1667 | F1: 0.2353 +education: 11th, Count: 225 +Precision: 1.0000 | Recall: 0.2727 | F1: 0.4286 +education: 12th, Count: 98 +Precision: 1.0000 | Recall: 0.4000 | F1: 0.5714 +education: 1st-4th, Count: 23 +Precision: 1.0000 | Recall: 1.0000 | F1: 1.0000 +education: 5th-6th, Count: 62 +Precision: 1.0000 | Recall: 0.5000 | F1: 0.6667 +education: 7th-8th, Count: 141 +Precision: 0.0000 | Recall: 0.0000 | F1: 0.0000 +education: 9th, Count: 115 +Precision: 1.0000 | Recall: 0.3333 | F1: 0.5000 +education: Assoc-acdm, Count: 198 +Precision: 0.7000 | Recall: 0.5957 | F1: 0.6437 +education: Assoc-voc, Count: 273 +Precision: 0.6471 | Recall: 0.5238 | F1: 0.5789 +education: Bachelors, Count: 1,053 +Precision: 0.7523 | Recall: 0.7289 | F1: 0.7404 +education: Doctorate, Count: 77 +Precision: 0.8644 | Recall: 0.8947 | F1: 0.8793 +education: HS-grad, Count: 2,085 +Precision: 0.6594 | Recall: 0.4377 | F1: 0.5261 +education: Masters, Count: 369 +Precision: 0.8271 | Recall: 0.8551 | F1: 0.8409 +education: Preschool, Count: 10 +Precision: 1.0000 | Recall: 1.0000 | F1: 1.0000 +education: Prof-school, Count: 116 +Precision: 0.8182 | Recall: 0.9643 | F1: 0.8852 +education: Some-college, Count: 1,485 +Precision: 0.6857 | Recall: 0.5199 | F1: 0.5914 +marital-status: Divorced, Count: 920 +Precision: 0.7600 | Recall: 0.3689 | F1: 0.4967 +marital-status: Married-AF-spouse, Count: 4 +Precision: 1.0000 | Recall: 0.0000 | F1: 0.0000 +marital-status: Married-civ-spouse, Count: 2,950 +Precision: 0.7346 | Recall: 0.6900 | F1: 0.7116 +marital-status: Married-spouse-absent, Count: 96 +Precision: 1.0000 | Recall: 0.2500 | F1: 0.4000 +marital-status: Never-married, Count: 2,126 +Precision: 0.8302 | Recall: 0.4272 | F1: 0.5641 +marital-status: Separated, Count: 209 +Precision: 1.0000 | Recall: 0.4211 | F1: 0.5926 +marital-status: Widowed, Count: 208 +Precision: 1.0000 | Recall: 0.1579 | F1: 0.2727 +occupation: ?, Count: 389 +Precision: 0.6538 | Recall: 0.4048 | F1: 0.5000 +occupation: Adm-clerical, Count: 726 +Precision: 0.6338 | Recall: 0.4688 | F1: 0.5389 +occupation: Armed-Forces, Count: 3 +Precision: 1.0000 | Recall: 1.0000 | F1: 1.0000 +occupation: Craft-repair, Count: 821 +Precision: 0.6567 | Recall: 0.4862 | F1: 0.5587 +occupation: Exec-managerial, Count: 838 +Precision: 0.7952 | Recall: 0.7531 | F1: 0.7736 +occupation: Farming-fishing, Count: 193 +Precision: 0.5455 | Recall: 0.2143 | F1: 0.3077 +occupation: Handlers-cleaners, Count: 273 +Precision: 0.5714 | Recall: 0.3333 | F1: 0.4211 +occupation: Machine-op-inspct, Count: 378 +Precision: 0.5938 | Recall: 0.4043 | F1: 0.4810 +occupation: Other-service, Count: 667 +Precision: 1.0000 | Recall: 0.1923 | F1: 0.3226 +occupation: Priv-house-serv, Count: 26 +Precision: 1.0000 | Recall: 1.0000 | F1: 1.0000 +occupation: Prof-specialty, Count: 828 +Precision: 0.7880 | Recall: 0.7679 | F1: 0.7778 +occupation: Protective-serv, Count: 136 +Precision: 0.7353 | Recall: 0.5952 | F1: 0.6579 +occupation: Sales, Count: 729 +Precision: 0.7273 | Recall: 0.6667 | F1: 0.6957 +occupation: Tech-support, Count: 189 +Precision: 0.7143 | Recall: 0.6863 | F1: 0.7000 +occupation: Transport-moving, Count: 317 +Precision: 0.6250 | Recall: 0.4688 | F1: 0.5357 +relationship: Husband, Count: 2,590 +Precision: 0.7370 | Recall: 0.6923 | F1: 0.7140 +relationship: Not-in-family, Count: 1,702 +Precision: 0.7959 | Recall: 0.4149 | F1: 0.5455 +relationship: Other-relative, Count: 178 +Precision: 1.0000 | Recall: 0.3750 | F1: 0.5455 +relationship: Own-child, Count: 1,019 +Precision: 1.0000 | Recall: 0.1765 | F1: 0.3000 +relationship: Unmarried, Count: 702 +Precision: 0.9231 | Recall: 0.2667 | F1: 0.4138 +relationship: Wife, Count: 322 +Precision: 0.7132 | Recall: 0.6783 | F1: 0.6953 +race: Amer-Indian-Eskimo, Count: 71 +Precision: 0.6250 | Recall: 0.5000 | F1: 0.5556 +race: Asian-Pac-Islander, Count: 193 +Precision: 0.7857 | Recall: 0.7097 | F1: 0.7458 +race: Black, Count: 599 +Precision: 0.7273 | Recall: 0.6154 | F1: 0.6667 +race: Other, Count: 55 +Precision: 1.0000 | Recall: 0.6667 | F1: 0.8000 +race: White, Count: 5,595 +Precision: 0.7404 | Recall: 0.6373 | F1: 0.6850 +sex: Female, Count: 2,126 +Precision: 0.7229 | Recall: 0.5150 | F1: 0.6015 +sex: Male, Count: 4,387 +Precision: 0.7445 | Recall: 0.6599 | F1: 0.6997 +native-country: ?, Count: 125 +Precision: 0.7500 | Recall: 0.6774 | F1: 0.7119 +native-country: Cambodia, Count: 3 +Precision: 1.0000 | Recall: 1.0000 | F1: 1.0000 +native-country: Canada, Count: 22 +Precision: 0.6667 | Recall: 0.7500 | F1: 0.7059 +native-country: China, Count: 18 +Precision: 1.0000 | Recall: 1.0000 | F1: 1.0000 +native-country: Columbia, Count: 6 +Precision: 1.0000 | Recall: 1.0000 | F1: 1.0000 +native-country: Cuba, Count: 19 +Precision: 0.6667 | Recall: 0.8000 | F1: 0.7273 +native-country: Dominican-Republic, Count: 8 +Precision: 1.0000 | Recall: 1.0000 | F1: 1.0000 +native-country: Ecuador, Count: 5 +Precision: 1.0000 | Recall: 0.5000 | F1: 0.6667 +native-country: El-Salvador, Count: 20 +Precision: 1.0000 | Recall: 1.0000 | F1: 1.0000 +native-country: England, Count: 14 +Precision: 0.6667 | Recall: 0.5000 | F1: 0.5714 +native-country: France, Count: 5 +Precision: 1.0000 | Recall: 1.0000 | F1: 1.0000 +native-country: Germany, Count: 32 +Precision: 0.8462 | Recall: 0.8462 | F1: 0.8462 +native-country: Greece, Count: 7 +Precision: 0.0000 | Recall: 0.0000 | F1: 0.0000 +native-country: Guatemala, Count: 13 +Precision: 1.0000 | Recall: 1.0000 | F1: 1.0000 +native-country: Haiti, Count: 6 +Precision: 1.0000 | Recall: 1.0000 | F1: 1.0000 +native-country: Honduras, Count: 4 +Precision: 1.0000 | Recall: 1.0000 | F1: 1.0000 +native-country: Hong, Count: 8 +Precision: 0.5000 | Recall: 1.0000 | F1: 0.6667 +native-country: Hungary, Count: 3 +Precision: 1.0000 | Recall: 0.5000 | F1: 0.6667 +native-country: India, Count: 21 +Precision: 0.8750 | Recall: 0.8750 | F1: 0.8750 +native-country: Iran, Count: 12 +Precision: 0.3333 | Recall: 0.2000 | F1: 0.2500 +native-country: Ireland, Count: 5 +Precision: 1.0000 | Recall: 1.0000 | F1: 1.0000 +native-country: Italy, Count: 14 +Precision: 0.7500 | Recall: 0.7500 | F1: 0.7500 +native-country: Jamaica, Count: 13 +Precision: 0.0000 | Recall: 1.0000 | F1: 0.0000 +native-country: Japan, Count: 11 +Precision: 0.7500 | Recall: 0.7500 | F1: 0.7500 +native-country: Laos, Count: 4 +Precision: 1.0000 | Recall: 0.0000 | F1: 0.0000 +native-country: Mexico, Count: 114 +Precision: 1.0000 | Recall: 0.3333 | F1: 0.5000 +native-country: Nicaragua, Count: 7 +Precision: 1.0000 | Recall: 1.0000 | F1: 1.0000 +native-country: Peru, Count: 5 +Precision: 0.0000 | Recall: 0.0000 | F1: 0.0000 +native-country: Philippines, Count: 35 +Precision: 1.0000 | Recall: 0.6875 | F1: 0.8148 +native-country: Poland, Count: 14 +Precision: 0.6667 | Recall: 1.0000 | F1: 0.8000 +native-country: Portugal, Count: 6 +Precision: 1.0000 | Recall: 1.0000 | F1: 1.0000 +native-country: Puerto-Rico, Count: 22 +Precision: 0.8333 | Recall: 0.8333 | F1: 0.8333 +native-country: Scotland, Count: 3 +Precision: 1.0000 | Recall: 1.0000 | F1: 1.0000 +native-country: South, Count: 13 +Precision: 0.3333 | Recall: 0.5000 | F1: 0.4000 +native-country: Taiwan, Count: 11 +Precision: 0.7500 | Recall: 0.7500 | F1: 0.7500 +native-country: Thailand, Count: 5 +Precision: 1.0000 | Recall: 1.0000 | F1: 1.0000 +native-country: Trinadad&Tobago, Count: 3 +Precision: 1.0000 | Recall: 1.0000 | F1: 1.0000 +native-country: United-States, Count: 5,870 +Precision: 0.7392 | Recall: 0.6321 | F1: 0.6814 +native-country: Vietnam, Count: 5 +Precision: 1.0000 | Recall: 1.0000 | F1: 1.0000 +native-country: Yugoslavia, Count: 2 +Precision: 1.0000 | Recall: 1.0000 | F1: 1.0000 +workclass: ?, Count: 389 +Precision: 0.6538 | Recall: 0.4048 | F1: 0.5000 +workclass: Federal-gov, Count: 191 +Precision: 0.7971 | Recall: 0.7857 | F1: 0.7914 +workclass: Local-gov, Count: 387 +Precision: 0.7576 | Recall: 0.6818 | F1: 0.7177 +workclass: Private, Count: 4,578 +Precision: 0.7376 | Recall: 0.6404 | F1: 0.6856 +workclass: Self-emp-inc, Count: 212 +Precision: 0.7807 | Recall: 0.7542 | F1: 0.7672 +workclass: Self-emp-not-inc, Count: 498 +Precision: 0.7064 | Recall: 0.4904 | F1: 0.5789 +workclass: State-gov, Count: 254 +Precision: 0.7424 | Recall: 0.6712 | F1: 0.7050 +workclass: Without-pay, Count: 4 +Precision: 1.0000 | Recall: 1.0000 | F1: 1.0000 +education: 10th, Count: 183 +Precision: 0.4000 | Recall: 0.1667 | F1: 0.2353 +education: 11th, Count: 225 +Precision: 1.0000 | Recall: 0.2727 | F1: 0.4286 +education: 12th, Count: 98 +Precision: 1.0000 | Recall: 0.4000 | F1: 0.5714 +education: 1st-4th, Count: 23 +Precision: 1.0000 | Recall: 1.0000 | F1: 1.0000 +education: 5th-6th, Count: 62 +Precision: 1.0000 | Recall: 0.5000 | F1: 0.6667 +education: 7th-8th, Count: 141 +Precision: 0.0000 | Recall: 0.0000 | F1: 0.0000 +education: 9th, Count: 115 +Precision: 1.0000 | Recall: 0.3333 | F1: 0.5000 +education: Assoc-acdm, Count: 198 +Precision: 0.7000 | Recall: 0.5957 | F1: 0.6437 +education: Assoc-voc, Count: 273 +Precision: 0.6471 | Recall: 0.5238 | F1: 0.5789 +education: Bachelors, Count: 1,053 +Precision: 0.7523 | Recall: 0.7289 | F1: 0.7404 +education: Doctorate, Count: 77 +Precision: 0.8644 | Recall: 0.8947 | F1: 0.8793 +education: HS-grad, Count: 2,085 +Precision: 0.6594 | Recall: 0.4377 | F1: 0.5261 +education: Masters, Count: 369 +Precision: 0.8271 | Recall: 0.8551 | F1: 0.8409 +education: Preschool, Count: 10 +Precision: 1.0000 | Recall: 1.0000 | F1: 1.0000 +education: Prof-school, Count: 116 +Precision: 0.8182 | Recall: 0.9643 | F1: 0.8852 +education: Some-college, Count: 1,485 +Precision: 0.6857 | Recall: 0.5199 | F1: 0.5914 +marital-status: Divorced, Count: 920 +Precision: 0.7600 | Recall: 0.3689 | F1: 0.4967 +marital-status: Married-AF-spouse, Count: 4 +Precision: 1.0000 | Recall: 0.0000 | F1: 0.0000 +marital-status: Married-civ-spouse, Count: 2,950 +Precision: 0.7346 | Recall: 0.6900 | F1: 0.7116 +marital-status: Married-spouse-absent, Count: 96 +Precision: 1.0000 | Recall: 0.2500 | F1: 0.4000 +marital-status: Never-married, Count: 2,126 +Precision: 0.8302 | Recall: 0.4272 | F1: 0.5641 +marital-status: Separated, Count: 209 +Precision: 1.0000 | Recall: 0.4211 | F1: 0.5926 +marital-status: Widowed, Count: 208 +Precision: 1.0000 | Recall: 0.1579 | F1: 0.2727 +occupation: ?, Count: 389 +Precision: 0.6538 | Recall: 0.4048 | F1: 0.5000 +occupation: Adm-clerical, Count: 726 +Precision: 0.6338 | Recall: 0.4688 | F1: 0.5389 +occupation: Armed-Forces, Count: 3 +Precision: 1.0000 | Recall: 1.0000 | F1: 1.0000 +occupation: Craft-repair, Count: 821 +Precision: 0.6567 | Recall: 0.4862 | F1: 0.5587 +occupation: Exec-managerial, Count: 838 +Precision: 0.7952 | Recall: 0.7531 | F1: 0.7736 +occupation: Farming-fishing, Count: 193 +Precision: 0.5455 | Recall: 0.2143 | F1: 0.3077 +occupation: Handlers-cleaners, Count: 273 +Precision: 0.5714 | Recall: 0.3333 | F1: 0.4211 +occupation: Machine-op-inspct, Count: 378 +Precision: 0.5938 | Recall: 0.4043 | F1: 0.4810 +occupation: Other-service, Count: 667 +Precision: 1.0000 | Recall: 0.1923 | F1: 0.3226 +occupation: Priv-house-serv, Count: 26 +Precision: 1.0000 | Recall: 1.0000 | F1: 1.0000 +occupation: Prof-specialty, Count: 828 +Precision: 0.7880 | Recall: 0.7679 | F1: 0.7778 +occupation: Protective-serv, Count: 136 +Precision: 0.7353 | Recall: 0.5952 | F1: 0.6579 +occupation: Sales, Count: 729 +Precision: 0.7273 | Recall: 0.6667 | F1: 0.6957 +occupation: Tech-support, Count: 189 +Precision: 0.7143 | Recall: 0.6863 | F1: 0.7000 +occupation: Transport-moving, Count: 317 +Precision: 0.6250 | Recall: 0.4688 | F1: 0.5357 +relationship: Husband, Count: 2,590 +Precision: 0.7370 | Recall: 0.6923 | F1: 0.7140 +relationship: Not-in-family, Count: 1,702 +Precision: 0.7959 | Recall: 0.4149 | F1: 0.5455 +relationship: Other-relative, Count: 178 +Precision: 1.0000 | Recall: 0.3750 | F1: 0.5455 +relationship: Own-child, Count: 1,019 +Precision: 1.0000 | Recall: 0.1765 | F1: 0.3000 +relationship: Unmarried, Count: 702 +Precision: 0.9231 | Recall: 0.2667 | F1: 0.4138 +relationship: Wife, Count: 322 +Precision: 0.7132 | Recall: 0.6783 | F1: 0.6953 +race: Amer-Indian-Eskimo, Count: 71 +Precision: 0.6250 | Recall: 0.5000 | F1: 0.5556 +race: Asian-Pac-Islander, Count: 193 +Precision: 0.7857 | Recall: 0.7097 | F1: 0.7458 +race: Black, Count: 599 +Precision: 0.7273 | Recall: 0.6154 | F1: 0.6667 +race: Other, Count: 55 +Precision: 1.0000 | Recall: 0.6667 | F1: 0.8000 +race: White, Count: 5,595 +Precision: 0.7404 | Recall: 0.6373 | F1: 0.6850 +sex: Female, Count: 2,126 +Precision: 0.7229 | Recall: 0.5150 | F1: 0.6015 +sex: Male, Count: 4,387 +Precision: 0.7445 | Recall: 0.6599 | F1: 0.6997 +native-country: ?, Count: 125 +Precision: 0.7500 | Recall: 0.6774 | F1: 0.7119 +native-country: Cambodia, Count: 3 +Precision: 1.0000 | Recall: 1.0000 | F1: 1.0000 +native-country: Canada, Count: 22 +Precision: 0.6667 | Recall: 0.7500 | F1: 0.7059 +native-country: China, Count: 18 +Precision: 1.0000 | Recall: 1.0000 | F1: 1.0000 +native-country: Columbia, Count: 6 +Precision: 1.0000 | Recall: 1.0000 | F1: 1.0000 +native-country: Cuba, Count: 19 +Precision: 0.6667 | Recall: 0.8000 | F1: 0.7273 +native-country: Dominican-Republic, Count: 8 +Precision: 1.0000 | Recall: 1.0000 | F1: 1.0000 +native-country: Ecuador, Count: 5 +Precision: 1.0000 | Recall: 0.5000 | F1: 0.6667 +native-country: El-Salvador, Count: 20 +Precision: 1.0000 | Recall: 1.0000 | F1: 1.0000 +native-country: England, Count: 14 +Precision: 0.6667 | Recall: 0.5000 | F1: 0.5714 +native-country: France, Count: 5 +Precision: 1.0000 | Recall: 1.0000 | F1: 1.0000 +native-country: Germany, Count: 32 +Precision: 0.8462 | Recall: 0.8462 | F1: 0.8462 +native-country: Greece, Count: 7 +Precision: 0.0000 | Recall: 0.0000 | F1: 0.0000 +native-country: Guatemala, Count: 13 +Precision: 1.0000 | Recall: 1.0000 | F1: 1.0000 +native-country: Haiti, Count: 6 +Precision: 1.0000 | Recall: 1.0000 | F1: 1.0000 +native-country: Honduras, Count: 4 +Precision: 1.0000 | Recall: 1.0000 | F1: 1.0000 +native-country: Hong, Count: 8 +Precision: 0.5000 | Recall: 1.0000 | F1: 0.6667 +native-country: Hungary, Count: 3 +Precision: 1.0000 | Recall: 0.5000 | F1: 0.6667 +native-country: India, Count: 21 +Precision: 0.8750 | Recall: 0.8750 | F1: 0.8750 +native-country: Iran, Count: 12 +Precision: 0.3333 | Recall: 0.2000 | F1: 0.2500 +native-country: Ireland, Count: 5 +Precision: 1.0000 | Recall: 1.0000 | F1: 1.0000 +native-country: Italy, Count: 14 +Precision: 0.7500 | Recall: 0.7500 | F1: 0.7500 +native-country: Jamaica, Count: 13 +Precision: 0.0000 | Recall: 1.0000 | F1: 0.0000 +native-country: Japan, Count: 11 +Precision: 0.7500 | Recall: 0.7500 | F1: 0.7500 +native-country: Laos, Count: 4 +Precision: 1.0000 | Recall: 0.0000 | F1: 0.0000 +native-country: Mexico, Count: 114 +Precision: 1.0000 | Recall: 0.3333 | F1: 0.5000 +native-country: Nicaragua, Count: 7 +Precision: 1.0000 | Recall: 1.0000 | F1: 1.0000 +native-country: Peru, Count: 5 +Precision: 0.0000 | Recall: 0.0000 | F1: 0.0000 +native-country: Philippines, Count: 35 +Precision: 1.0000 | Recall: 0.6875 | F1: 0.8148 +native-country: Poland, Count: 14 +Precision: 0.6667 | Recall: 1.0000 | F1: 0.8000 +native-country: Portugal, Count: 6 +Precision: 1.0000 | Recall: 1.0000 | F1: 1.0000 +native-country: Puerto-Rico, Count: 22 +Precision: 0.8333 | Recall: 0.8333 | F1: 0.8333 +native-country: Scotland, Count: 3 +Precision: 1.0000 | Recall: 1.0000 | F1: 1.0000 +native-country: South, Count: 13 +Precision: 0.3333 | Recall: 0.5000 | F1: 0.4000 +native-country: Taiwan, Count: 11 +Precision: 0.7500 | Recall: 0.7500 | F1: 0.7500 +native-country: Thailand, Count: 5 +Precision: 1.0000 | Recall: 1.0000 | F1: 1.0000 +native-country: Trinadad&Tobago, Count: 3 +Precision: 1.0000 | Recall: 1.0000 | F1: 1.0000 +native-country: United-States, Count: 5,870 +Precision: 0.7392 | Recall: 0.6321 | F1: 0.6814 +native-country: Vietnam, Count: 5 +Precision: 1.0000 | Recall: 1.0000 | F1: 1.0000 +native-country: Yugoslavia, Count: 2 +Precision: 1.0000 | Recall: 1.0000 | F1: 1.0000 diff --git a/test_ml.py b/test_ml.py index 5f8306f14c..158d6f62c9 100644 --- a/test_ml.py +++ b/test_ml.py @@ -1,28 +1,37 @@ import pytest -# TODO: add necessary import +import numpy as np +from sklearn.preprocessing import StandardScaler +from sklearn.datasets import make_classification +from sklearn.ensemble import RandomForestClassifier +from ml.model import train_model -# TODO: implement the first test. Change the function name and input as needed -def test_one(): +def test_load_data(): """ - # add description for the first test + This test verifies that the data is loaded correctly and has the expected shape. """ - # Your code here - pass + X, y = make_classification(n_samples=100, n_features=8, random_state=42) + assert X.shape == (100, 8), f"Incorrect shape for X: {X.shape}" + assert y.shape == (100,), f"Incorrect shape for y: {y.shape}" -# TODO: implement the second test. Change the function name and input as needed -def test_two(): +def test_model_type(): """ - # add description for the second test + This test verifies that model returned is correct: RandomForestClassifier. """ - # Your code here - pass + X, y = make_classification(n_samples=100, n_features=8, random_state=42) + model = train_model(X, y) + assert type(model) == RandomForestClassifier, f"Incorrect model type: {type(model)}" -# TODO: implement the third test. Change the function name and input as needed -def test_three(): + +def test_preprocessing_scaler(): """ - # add description for the third test + This test ensures that the scaling of the data is done properly, an important step in preprocessing. """ - # Your code here - pass + scaler = StandardScaler() + X, _ = make_classification(n_samples=100, n_features=8, random_state=42) + X_scaled = scaler.fit_transform(X) + + # Check that the mean and standard deviation are close to expected values + assert np.isclose(np.mean(X_scaled), 0), f"Incorrect mean: {np.mean(X_scaled)}" + assert np.isclose(np.std(X_scaled), 1), f"Incorrect std dev: {np.std(X_scaled)}" \ No newline at end of file diff --git a/train_model.py b/train_model.py index ae783ed5b9..8874195f0d 100644 --- a/train_model.py +++ b/train_model.py @@ -1,8 +1,6 @@ import os - import pandas as pd from sklearn.model_selection import train_test_split - from ml.data import process_data from ml.model import ( compute_model_metrics, @@ -12,17 +10,14 @@ save_model, train_model, ) -# TODO: load the cencus.csv data -project_path = "Your path here" -data_path = os.path.join(project_path, "data", "census.csv") -print(data_path) -data = None # your code here -# TODO: split the provided data to have a train dataset and a test dataset -# Optional enhancement, use K-fold cross validation instead of a train-test split. -train, test = None, None# Your code here +# Load the census data +data = pd.read_csv("data/census.csv") -# DO NOT MODIFY +# Split data into train and test sets +train, test = train_test_split(data, test_size=0.20, random_state=42) + +# Define categorical features cat_features = [ "workclass", "education", @@ -34,14 +29,15 @@ "native-country", ] -# TODO: use the process_data function provided to process the data. +# Process training data X_train, y_train, encoder, lb = process_data( - # your code here - # use the train dataset - # use training=True - # do not need to pass encoder and lb as input - ) + train, + categorical_features=cat_features, + label="salary", + training=True +) +# Process test data X_test, y_test, _, _ = process_data( test, categorical_features=cat_features, @@ -51,37 +47,42 @@ lb=lb, ) -# TODO: use the train_model function to train the model on the training dataset -model = None # your code here +# Train model +model = train_model(X_train, y_train) -# save the model and the encoder -model_path = os.path.join(project_path, "model", "model.pkl") +# Save model and encoder +model_path = os.path.join("model", "model.pkl") save_model(model, model_path) -encoder_path = os.path.join(project_path, "model", "encoder.pkl") +encoder_path = os.path.join("model", "encoder.pkl") save_model(encoder, encoder_path) -# load the model -model = load_model( - model_path -) +# Load model for inference +model = load_model(model_path) -# TODO: use the inference function to run the model inferences on the test dataset. -preds = None # your code here +# Make predictions +preds = inference(model, X_test) -# Calculate and print the metrics -p, r, fb = compute_model_metrics(y_test, preds) -print(f"Precision: {p:.4f} | Recall: {r:.4f} | F1: {fb:.4f}") +# Calculate metrics +precision, recall, fbeta = compute_model_metrics(y_test, preds) +print(f"Precision: {precision:.4f} | Recall: {recall:.4f} | F1: {fbeta:.4f}") -# TODO: compute the performance on model slices using the performance_on_categorical_slice function -# iterate through the categorical features +# Compute performance on slices +# Compute performance on slices for col in cat_features: - # iterate through the unique values in one categorical feature for slicevalue in sorted(test[col].unique()): count = test[test[col] == slicevalue].shape[0] - p, r, fb = performance_on_categorical_slice( - # your code here - # use test, col and slicevalue as part of the input + precision, recall, fbeta = performance_on_categorical_slice( + test, + col, + slicevalue, + cat_features, + "salary", + encoder, + lb, + model ) + with open("slice_output.txt", "a") as f: print(f"{col}: {slicevalue}, Count: {count:,}", file=f) - print(f"Precision: {p:.4f} | Recall: {r:.4f} | F1: {fb:.4f}", file=f) + print(f"Precision: {precision:.4f} | Recall: {recall:.4f} | F1: {fbeta:.4f}", file=f) +