A machine learning project to classify SMS messages as spam or legitimate (ham) communications using Natural Language Processing and Naive Bayes classification.
This project demonstrates how to build and train a machine learning model to identify spam SMS messages. The implementation includes data preprocessing, feature extraction, model training with hyperparameter optimization, and a testing framework to evaluate the model's performance on unseen data.
This project is based on materials from HTB Academy's "Applications of AI in InfoSec" module. The original codebase and learning materials are provided by Hack The Box Academy, and this implementation builds upon their educational resources.
- Data preprocessing pipeline for text normalization
- Feature extraction using TF-IDF vectorization
- Naive Bayes classification with hyperparameter tuning
- Model evaluation with F1 score metric
- Testing framework for evaluating model performance on new messages
- Python 3.8+
- pip package manager
- Clone the repository:
git clone https://github.com/yourusername/sms-spam-classification.git
cd sms-spam-classification
- Create a virtual environment (optional but recommended):
python -m venv venv
source venv/bin/activate # On Windows: venv\Scripts\activate
- Install the required packages:
pip install -r requirements.txt
The project uses the SMS Spam Collection dataset from the UCI Machine Learning Repository. The dataset is downloaded and preprocessed automatically when you run the main script.
To train the spam classification model:
python main.py
This will:
- Download the dataset
- Preprocess the text data
- Extract features
- Train a Naive Bayes classifier with grid search for hyperparameter optimization
- Save the trained model to spam_classifier.joblib
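The training steps above can be sketched roughly as follows. This is a minimal illustration and not the contents of main.py: the toy corpus, the `nb__alpha` grid, the 3-fold setting, and the output filename `spam_classifier_sketch.joblib` are all assumptions made for the example, and it presumes scikit-learn and joblib are installed.

```python
import joblib
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.model_selection import GridSearchCV
from sklearn.naive_bayes import MultinomialNB
from sklearn.pipeline import Pipeline

# Hypothetical toy corpus standing in for the preprocessed dataset (1 = spam, 0 = ham).
messages = [
    "free prize claim now", "win cash prize today", "urgent claim free reward",
    "free cash win now", "claim reward urgent today", "win free prize urgent",
    "see you at lunch", "call me later tonight", "meeting moved to monday",
    "are you home yet", "thanks for the lift", "lunch later today maybe",
]
labels = [1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0]

# Chain TF-IDF feature extraction and Multinomial Naive Bayes so the saved
# model accepts raw (preprocessed) text directly.
pipeline = Pipeline([
    ("tfidf", TfidfVectorizer()),
    ("nb", MultinomialNB()),
])

# Assumed search space: only the Naive Bayes smoothing parameter is tuned here.
param_grid = {"nb__alpha": [0.1, 0.5, 1.0]}

# Grid search with cross-validation, scored by F1 on the spam class.
search = GridSearchCV(pipeline, param_grid, cv=3, scoring="f1")
search.fit(messages, labels)

# Persist the best pipeline under a separate, hypothetical filename so the
# project's real spam_classifier.joblib is not overwritten.
joblib.dump(search.best_estimator_, "spam_classifier_sketch.joblib")
print("Best parameters:", search.best_params_)
```

Keeping the vectorizer inside the pipeline means cross-validation refits TF-IDF on each training fold, which avoids leaking vocabulary statistics from the held-out fold into training.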
To test the trained model on sample messages:
python test.py
By default, this will test the model on the last 15 messages from the test dataset. You can modify the sample_size and sample_method parameters in test.py to test different portions of the dataset:
# Define test parameters
sample_size = 15 # Number of messages to test
sample_method = "tail" # Options: "head", "tail", or set to another value for random sampling
To use the trained model in your own applications:
import joblib
import pandas as pd
from preprocessing import preprocessor
# Load the trained model
model = joblib.load('spam_classifier.joblib')
# Create a DataFrame with your message
message = "Your message here"
df = pd.DataFrame({'message': [message]})
# Preprocess your message
processed_df = preprocessor(df)
# Make a prediction
prediction = model.predict(processed_df['message'])
probability = model.predict_proba(processed_df['message'])
# Interpret the result
is_spam = prediction[0] == 1
spam_probability = probability[0][1]
print(f"Message: {message}")
print(f"Is spam: {is_spam}")
print(f"Spam probability: {spam_probability:.4f}")
- main.py: Main script for training the model
- test.py: Script for testing the model on new messages
- preprocessing.py: Functions for data preprocessing
- feature_extraction.py: Functions for feature extraction
- Dataset/: Directory containing the dataset files
- spam_classifier.joblib: Saved trained model
- Convert text to lowercase
- Remove punctuation and numbers (preserving $ and ! as they can be indicators of spam)
- Tokenize the text
- Remove stop words
- Apply stemming to reduce words to their base form
- Rejoin tokens into a single string
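The preprocessing steps above could look something like the sketch below. It is not the code in preprocessing.py: the function name `preprocess_message`, the tiny `STOP_WORDS` set, and the suffix-stripping `naive_stem` helper are hypothetical stand-ins chosen to keep the example dependency-free. A real implementation would more likely use a full stop-word list and a proper stemmer such as Porter's.

```python
import re

# Hypothetical, deliberately small stop-word list for illustration only.
STOP_WORDS = {"a", "an", "the", "is", "are", "to", "you", "your",
              "have", "of", "and", "for", "in", "on"}


def naive_stem(token):
    """Crude suffix stripping; a stand-in for a real stemmer, not Porter's algorithm."""
    for suffix in ("ing", "ed", "s"):
        # Only strip when at least three characters would remain.
        if token.endswith(suffix) and len(token) - len(suffix) >= 3:
            return token[: -len(suffix)]
    return token


def preprocess_message(text):
    text = text.lower()                                   # 1. lowercase
    text = re.sub(r"[^a-z\s$!]", "", text)                # 2. drop punctuation and digits, keep $ and !
    tokens = text.split()                                 # 3. tokenize on whitespace
    tokens = [t for t in tokens if t not in STOP_WORDS]   # 4. remove stop words
    tokens = [naive_stem(t) for t in tokens]              # 5. stem each token
    return " ".join(tokens)                               # 6. rejoin into one string


print(preprocess_message("Congratulations! You have WON $1000 in prizes."))
# prints: congratulations! won $ prize
```

Note that in this sketch a retained `!` stays attached to its word, so `congratulations!` is treated as a single token and is not stemmed.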
- Algorithm: Multinomial Naive Bayes
- Feature Extraction: TF-IDF Vectorization
- Hyperparameter Tuning: Grid search with cross-validation
- Evaluation Metric: F1 score
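For reference, the F1 score is the harmonic mean of precision and recall on the positive (spam) class. The sketch below computes it by hand on made-up labels to show why it can be more informative than accuracy when ham outnumbers spam. The helper name `f1_score_binary` and the sample predictions are assumptions for illustration, not output from this project's model.

```python
def f1_score_binary(y_true, y_pred, positive=1):
    """Compute F1 for the positive class from paired true and predicted labels."""
    pairs = list(zip(y_true, y_pred))
    tp = sum(1 for t, p in pairs if t == positive and p == positive)  # spam caught
    fp = sum(1 for t, p in pairs if t != positive and p == positive)  # ham flagged as spam
    fn = sum(1 for t, p in pairs if t == positive and p != positive)  # spam missed
    if tp == 0:
        return 0.0  # no true positives: precision and recall are zero or undefined
    precision = tp / (tp + fp)
    recall = tp / (tp + fn)
    return 2 * precision * recall / (precision + recall)


# Made-up labels: 3 spam and 5 ham messages, with one miss and one false alarm.
y_true = [1, 1, 1, 0, 0, 0, 0, 0]
y_pred = [1, 1, 0, 1, 0, 0, 0, 0]

accuracy = sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)
print(f"Accuracy: {accuracy:.4f}")                         # prints: Accuracy: 0.7500
print(f"F1 score: {f1_score_binary(y_true, y_pred):.4f}")  # prints: F1 score: 0.6667
```

Here precision and recall are both 2/3, so F1 is also 2/3, while accuracy is 0.75 because the correctly classified ham messages dominate the count.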
This project is for educational purposes and is based on materials from HTB Academy. Please refer to HTB Academy's terms for usage rights.
- HTB Academy for the original materials and guidance
- UCI Machine Learning Repository for the SMS Spam Collection dataset