diff --git a/README.md b/README.md
index 6ea6940..22e46cb 100644
--- a/README.md
+++ b/README.md
@@ -1,4 +1,4 @@
-# 🛡️ CyberAttackDetection-Python
+# Cybersecurity Attack Detection Framework
 
 ![🔧 Build Status](https://github.com/canstralian/CyberAttackDetection-Python/actions/workflows/ci.yml/badge.svg)
 ![📊 Coverage](https://codecov.io/gh/canstralian/CyberAttackDetection-Python/branch/main/graph/badge.svg)
@@ -8,57 +8,276 @@
 ![🚀 Release](https://img.shields.io/github/v/release/canstralian/CyberAttackDetection-Python)
 ![🐞 Issues](https://img.shields.io/github/issues/canstralian/CyberAttackDetection-Python)
 
----
+## Overview
 
-## 🛡️ About CyberAttackDetection-Python
+A modular, secure, and extensible framework for detecting cyber attacks using machine learning. This framework provides a robust foundation for cybersecurity threat detection with built-in best practices for data preprocessing, model management, REST API development, and security.
 
-**CyberAttackDetection-Python** is a Python application designed to detect and mitigate cyber attacks using **advanced machine learning techniques**.
+### Key Features
 
-### 🌟 Features
+- **🛡️ Security-First Design**: JWT authentication, rate limiting, input validation, and security headers
+- **🔧 Modular Architecture**: Clean separation of concerns with dedicated modules for preprocessing, modeling, and API
+- **📊 Multiple ML Models**: Support for Random Forest, Logistic Regression, SVM, and Neural Networks
+- **🚀 REST API**: Secure Flask-based API with comprehensive validation using Marshmallow
+- **📈 Comprehensive Evaluation**: Built-in metrics, cross-validation, and hyperparameter tuning
+- **🧪 Testing Framework**: Extensive test suite with pytest and continuous integration
+- **📖 Documentation**: Jupyter notebooks, scripts, and comprehensive API documentation
+- **🔄 Model Management**: Easy model saving, loading, and comparison capabilities
 
-- 🚀 **Real-Time Attack Detection**
-- 🧠 **Machine Learning Model Training and Evaluation**
-- 📝 **Comprehensive Logging and Alerting System**
+## Quick Start
 
----
+### Installation
 
-## 📋 How to Use
+```bash
+# Clone the repository
+git clone https://github.com/canstralian/CyberAttackDetection-Python.git
+cd CyberAttackDetection-Python
 
-1. **Clone the Repository**
-   ```bash
-   git clone https://github.com/canstralian/CyberAttackDetection-Python.git
-   ```
+# Install dependencies
+pip install -r requirements.txt
+```
 
-2. **Install Dependencies**
-   ```bash
-   pip install -r requirements.txt
-   ```
+### Basic Usage
 
-3. **Run the Application**
-   ```bash
-   python main.py
-   ```
+#### 1. Generate Sample Data
+```bash
+python scripts/generate_data.py --samples 1000 --features 20 --output-dir data/
+```
 
----
+#### 2. Train a Model
+```bash
+python scripts/train_model.py --data data/cyber_dataset_1000s_20f_2c.csv --model random_forest --hyperparameter-tuning
+```
 
-## 🤝 Contributing
+#### 3. Start the API Server
+```bash
+cd src/api && python app.py
+```
 
-Contributions are welcome! 🛠️ Please follow the guidelines outlined in the [CONTRIBUTING.md](CONTRIBUTING.md).
+#### 4. Use the Framework Programmatically
 
----
+```python
+from src.core.preprocessing import DataPreprocessor
+from src.models.detector import CyberAttackDetector
+from src.utils.helpers import create_sample_dataset
 
-## 📜 License
+# Generate sample data
+data = create_sample_dataset(n_samples=1000, n_features=20)
 
-This project is licensed under the **MIT License**.
-For more details, check the [LICENSE.md](LICENSE.md) file.
+# Preprocess data
+preprocessor = DataPreprocessor()
+X, y, _ = preprocessor.full_preprocessing_pipeline(data)
+X_scaled, _ = preprocessor.scale_features(X.select_dtypes(include=['float64', 'int64']))
----
+
+# Train model
+detector = CyberAttackDetector('random_forest')
+detector.train(X_scaled, y)
+
+# Make predictions
+predictions = detector.predict(X_scaled)
+evaluation = detector.evaluate(X_scaled, y)
+print(f"Accuracy: {evaluation['accuracy']:.4f}")
+```
+
+## Directory Structure
+
+```
+CyberAttackDetection-Python/
+├── .github/                          # GitHub workflows and templates
+│   └── workflows/
+│       ├── ci.yml                    # Continuous integration
+│       ├── security-scan.yml         # Security scanning
+│       └── python-app.yml            # Python application workflow
+├── src/                              # Main source code
+│   ├── __init__.py
+│   ├── api/                          # REST API module
+│   │   ├── __init__.py
+│   │   └── app.py                    # Flask application with security features
+│   ├── core/                         # Core detection and preprocessing
+│   │   ├── __init__.py
+│   │   └── preprocessing.py          # Data preprocessing pipeline
+│   ├── models/                       # Machine learning models
+│   │   ├── __init__.py
+│   │   └── detector.py               # Model classes and management
+│   └── utils/                        # Utility functions
+│       ├── __init__.py
+│       └── helpers.py                # Helper functions and utilities
+├── tests/                            # Test suite
+│   ├── __init__.py
+│   └── test_framework.py             # Comprehensive tests
+├── notebooks/                        # Jupyter notebooks
+│   └── model_training_demo.ipynb     # Training demonstration
+├── scripts/                          # Command-line scripts
+│   ├── train_model.py                # Model training script
+│   ├── evaluate_model.py             # Model evaluation script
+│   └── generate_data.py              # Data generation script
+├── config/                           # Configuration files
+│   ├── .env.example                  # Environment variables template
+│   └── config.py                     # Application configuration
+├── data/                             # Data directory
+├── models/                           # Saved models directory
+├── requirements.txt                  # Python dependencies
+├── pyproject.toml                    # Project configuration
+└── README.md                         # This file
+```
+
+## API Usage
+
+### Authentication
+
+First, obtain a JWT token:
+
+```bash
+curl -X POST http://localhost:5000/api/auth/token \
+  -H "Content-Type: application/json" \
+  -d '{"username": "user"}'
+```
+
+### Available Endpoints
+
+- `GET /api/health` - Health check
+- `POST /api/auth/token` - Get authentication token
+- `GET /api/models/available` - List available models
+- `POST /api/detect` - Detect attacks (requires auth)
+- `POST /api/train` - Train new model (requires auth)
+- `POST /api/models/{model_type}/evaluate` - Evaluate model (requires auth)
+
+### Detection Example
+
+```bash
+curl -X POST http://localhost:5000/api/detect \
+  -H "Content-Type: application/json" \
+  -H "Authorization: Bearer YOUR_JWT_TOKEN" \
+  -d '{
+    "data": [[1.2, 0.5, -0.3, ...], [0.8, -1.1, 0.4, ...]],
+    "model_type": "random_forest"
+  }'
+```
+
+## Model Types
+
+The framework supports multiple machine learning algorithms:
+
+- **Random Forest** (`random_forest`): Ensemble method with excellent performance
+- **Logistic Regression** (`logistic_regression`): Fast linear classifier
+- **Support Vector Machine** (`svm`): Powerful non-linear classifier
+- **Neural Network** (`neural_network`): Multi-layer perceptron
+
+## Security Features
+
+- **JWT Authentication**: Secure token-based authentication
+- **Rate Limiting**: Protection against API abuse
+- **Input Validation**: Comprehensive request validation using Marshmallow
+- **Security Headers**: Standard security headers (HSTS, XSS Protection, etc.)
+- **CORS Protection**: Configurable cross-origin resource sharing
+- **Error Handling**: Secure error responses without information leakage
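+
+### How These Pieces Fit Together
+
+The following is a minimal, illustrative sketch of how JWT authentication, rate limiting, and Marshmallow validation typically compose in a Flask app. It is not the project's actual `src/api/app.py`; it only uses libraries pinned in `requirements.txt` (Flask, flask-limiter, marshmallow, PyJWT), and the route shapes mirror the endpoints documented above:
+
+```python
+import os
+from datetime import datetime, timedelta
+
+import jwt
+from flask import Flask, jsonify, request
+from flask_limiter import Limiter
+from flask_limiter.util import get_remote_address
+from marshmallow import Schema, ValidationError, fields
+
+app = Flask(__name__)
+SECRET_KEY = os.environ.get("SECRET_KEY", "dev-secret-key-change-in-production")
+limiter = Limiter(get_remote_address, app=app, default_limits=["100 per hour"])
+
+
+class DetectSchema(Schema):
+    """Validates the /api/detect request body."""
+    data = fields.List(fields.List(fields.Float()), required=True)
+    model_type = fields.String(load_default="random_forest")
+
+
+@app.route("/api/auth/token", methods=["POST"])
+@limiter.limit("20 per minute")
+def issue_token():
+    # Issue a short-lived HS256 token; the 'exp' claim is checked on decode.
+    username = (request.get_json(silent=True) or {}).get("username", "")
+    payload = {"sub": username, "exp": datetime.utcnow() + timedelta(hours=1)}
+    return jsonify({"token": jwt.encode(payload, SECRET_KEY, algorithm="HS256")})
+
+
+@app.route("/api/detect", methods=["POST"])
+def detect():
+    # Reject requests without a valid bearer token.
+    token = request.headers.get("Authorization", "").removeprefix("Bearer ")
+    try:
+        jwt.decode(token, SECRET_KEY, algorithms=["HS256"])
+    except jwt.PyJWTError:
+        return jsonify({"error": "invalid or expired token"}), 401
+    # Validate the payload before it reaches any model code.
+    try:
+        body = DetectSchema().load(request.get_json(force=True))
+    except ValidationError as err:
+        return jsonify({"errors": err.messages}), 400
+    return jsonify({"rows": len(body["data"]), "model_type": body["model_type"]})
+```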
+
+## Development
+
+### Running Tests
+
+```bash
+# Run all tests
+pytest
+
+# Run with coverage
+pytest --cov=src
+
+# Run specific test file
+pytest tests/test_framework.py
+```
+
+### Code Quality
+
+```bash
+# Linting
+flake8 src/ tests/ scripts/
+
+# Code formatting
+black src/ tests/ scripts/
+
+# Type checking (if using mypy)
+mypy src/
+```
+
+### Contributing
+
+1. Fork the repository
+2. Create a feature branch (`git checkout -b feature/amazing-feature`)
+3. Make your changes following PEP 8 guidelines
+4. Add tests for new functionality
+5. Run the test suite (`pytest`)
+6. Commit your changes (`git commit -m 'Add amazing feature'`)
+7. Push to the branch (`git push origin feature/amazing-feature`)
+8. Open a Pull Request
+
+## Configuration
+
+### Environment Variables
+
+Copy `config/.env.example` to `.env` and customize:
+
+```bash
+# Security
+SECRET_KEY=your-secret-key-here
+JWT_EXPIRATION_HOURS=1
+
+# API Configuration
+RATE_LIMIT_PER_HOUR=100
+RATE_LIMIT_AUTH_PER_MINUTE=20
+
+# Model Configuration
+DEFAULT_MODEL_TYPE=random_forest
+MODEL_DIRECTORY=models/
+
+# Data Configuration
+DATA_DIRECTORY=data/
+MAX_FILE_SIZE_MB=16
+```
+
+## Jupyter Notebooks
+
+Explore the framework with interactive notebooks:
+
+- `notebooks/model_training_demo.ipynb`: Complete training and evaluation workflow
+- Examples of data preprocessing, model comparison, and hyperparameter tuning
+
+## Performance and Scalability
+
+- **Efficient preprocessing**: Optimized data pipeline with memory management
+- **Model caching**: Intelligent model loading and caching in the API
+- **Batch processing**: Support for batch predictions
+- **Performance instrumentation**: Built-in timers and logging to guide optimization
+
+## Troubleshooting
+
+### Common Issues
+
+1. **Import Errors**: Ensure you are running from the project root directory
+2. **Model Not Found**: Check that models are saved in the configured models directory
+3. **Authentication Errors**: Verify that the JWT token is valid and not expired
+4. **Memory Issues**: Reduce the batch size or dataset size when working with large datasets
+
+### Logging
+
+The framework includes comprehensive logging. To set the log level:
+
+```python
+from src.utils.helpers import setup_logging
+logger = setup_logging('DEBUG', 'logs/debug.log')
+```
+
+## License
+
+This project is licensed under the MIT License. See the [LICENSE](LICENSE) file for details.
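+
+## Python Client Example
+
+The curl commands above translate directly to Python. This is an illustrative sketch using the `requests` library (not pinned in `requirements.txt`); the `token` response field name is an assumption:
+
+```python
+import requests
+
+BASE_URL = "http://localhost:5000/api"
+
+# Obtain a JWT token (response field name "token" is assumed)
+resp = requests.post(f"{BASE_URL}/auth/token", json={"username": "user"})
+resp.raise_for_status()
+token = resp.json()["token"]
+
+# Submit two rows of 20 numeric features each for detection
+resp = requests.post(
+    f"{BASE_URL}/detect",
+    json={"data": [[0.1] * 20, [0.5] * 20], "model_type": "random_forest"},
+    headers={"Authorization": f"Bearer {token}"},
+)
+resp.raise_for_status()
+print(resp.json())
+```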
-### 🔍 Additional Information
+## Acknowledgments
 
-- **Last Updated:** ![🕒 Last Commit](https://img.shields.io/github/last-commit/canstralian/CyberAttackDetection-Python)
-- **Latest Release:** ![🚀 Release](https://img.shields.io/github/v/release/canstralian/CyberAttackDetection-Python)
-- **Open Issues:** ![🐞 Issues](https://img.shields.io/github/issues/canstralian/CyberAttackDetection-Python)
+- [scikit-learn](https://scikit-learn.org/) for machine learning algorithms
+- [Flask](https://flask.palletsprojects.com/) for the REST API framework
+- [Marshmallow](https://marshmallow.readthedocs.io/) for input validation
+- [PyJWT](https://pyjwt.readthedocs.io/) for JWT authentication
+- [pandas](https://pandas.pydata.org/) and [numpy](https://numpy.org/) for data processing
 
 ---
 
-Thank you for using **CyberAttackDetection-Python**! If you encounter any issues, feel free to report them under the [Issues tab](https://github.com/canstralian/CyberAttackDetection-Python/issues). 🛡️💻
+Built with ❤️ for cybersecurity professionals and researchers.
diff --git a/config/.env.example b/config/.env.example
new file mode 100644
index 0000000..04838d5
--- /dev/null
+++ b/config/.env.example
@@ -0,0 +1,28 @@
+# Cybersecurity Detection Framework Configuration
+
+# API Configuration
+SECRET_KEY=your-secret-key-here-change-in-production
+FLASK_ENV=development
+FLASK_DEBUG=False
+
+# Database Configuration (if needed in future)
+DATABASE_URL=sqlite:///cyberattack_detection.db
+
+# Model Configuration
+DEFAULT_MODEL_TYPE=random_forest
+MODEL_CACHE_SIZE=5
+MODEL_DIRECTORY=models/
+
+# Security Configuration
+JWT_EXPIRATION_HOURS=1
+RATE_LIMIT_PER_HOUR=100
+RATE_LIMIT_AUTH_PER_MINUTE=20
+
+# Data Configuration
+MAX_FILE_SIZE_MB=16
+ALLOWED_EXTENSIONS=csv,json
+DATA_DIRECTORY=data/
+
+# Logging Configuration
+LOG_LEVEL=INFO
+LOG_FILE=logs/cyberattack_detection.log
\ No newline at end of file
diff --git a/config/config.py b/config/config.py
new file mode 100644
index 0000000..43cd6a3
--- /dev/null
+++ b/config/config.py
@@ -0,0 +1,94 @@
+"""
+Configuration module for cybersecurity detection framework.
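+
+Typical usage is sketched below (the Flask ``app`` object is assumed to
+exist; only the ``config`` mapping defined at the bottom of this module
+is real):
+
+    import os
+    from config.config import config
+
+    cfg = config[os.environ.get('FLASK_ENV', 'default')]
+    app.config.from_object(cfg)
+    cfg.init_app(app)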
+""" + +import os +from datetime import timedelta +from typing import List + + +class Config: + """Base configuration class.""" + + # Application settings + SECRET_KEY = os.environ.get('SECRET_KEY', 'dev-secret-key-change-in-production') + DEBUG = False + TESTING = False + + # API settings + JWT_EXPIRATION_DELTA = timedelta(hours=int(os.environ.get('JWT_EXPIRATION_HOURS', 1))) + RATE_LIMIT_DEFAULT = f"{os.environ.get('RATE_LIMIT_PER_HOUR', 100)}/hour" + RATE_LIMIT_AUTH = f"{os.environ.get('RATE_LIMIT_AUTH_PER_MINUTE', 20)}/minute" + MAX_CONTENT_LENGTH = int(os.environ.get('MAX_FILE_SIZE_MB', 16)) * 1024 * 1024 + + # CORS settings + CORS_ORIGINS = ["http://localhost:3000", "http://127.0.0.1:3000"] + + # File paths + MODEL_DIRECTORY = os.environ.get('MODEL_DIRECTORY', 'models/') + DATA_DIRECTORY = os.environ.get('DATA_DIRECTORY', 'data/') + LOG_FILE = os.environ.get('LOG_FILE', 'logs/cyberattack_detection.log') + + # Model settings + DEFAULT_MODEL_TYPE = os.environ.get('DEFAULT_MODEL_TYPE', 'random_forest') + MODEL_CACHE_SIZE = int(os.environ.get('MODEL_CACHE_SIZE', 5)) + + # Security settings + ALLOWED_EXTENSIONS = {'csv', 'json'} + + @staticmethod + def init_app(app): + """Initialize application with config.""" + pass + + +class DevelopmentConfig(Config): + """Development configuration.""" + + DEBUG = True + LOG_LEVEL = 'DEBUG' + + +class ProductionConfig(Config): + """Production configuration.""" + + DEBUG = False + LOG_LEVEL = 'INFO' + + # Override with stronger security settings + CORS_ORIGINS = [] # Configure based on production needs + + @classmethod + def init_app(cls, app): + """Initialize production app.""" + Config.init_app(app) + + # Log to syslog in production + import logging + from logging.handlers import SysLogHandler + syslog_handler = SysLogHandler() + syslog_handler.setLevel(logging.INFO) + app.logger.addHandler(syslog_handler) + + +class TestingConfig(Config): + """Testing configuration.""" + + TESTING = True + DEBUG = True + LOG_LEVEL = 'DEBUG' + + # Use in-memory database for testing + DATABASE_URL = 'sqlite:///:memory:' + + # Disable rate limiting for tests + RATE_LIMIT_DEFAULT = "1000/hour" + RATE_LIMIT_AUTH = "100/minute" + + +config = { + 'development': DevelopmentConfig, + 'production': ProductionConfig, + 'testing': TestingConfig, + 'default': DevelopmentConfig +} \ No newline at end of file diff --git a/data/test_dataset.csv b/data/test_dataset.csv new file mode 100644 index 0000000..f152d30 --- /dev/null +++ b/data/test_dataset.csv @@ -0,0 +1,101 @@ +feature_00,feature_01,feature_02,feature_03,feature_04,feature_05,feature_06,feature_07,feature_08,feature_09,protocol_type,service,label +2.446035030546515,0.7882314015369567,-0.9404137745679023,-1.753344273601822,4.459658518374963,0.8138639081236053,5.145029362588673,-3.0646371410063646,0.492395827849068,0.2346520051550863,udp,http,attack +0.3799473327299019,-1.0470749606832916,-0.6491913249036794,2.5139299432560893,0.09549737967482448,1.531688801240915,-0.3067000649220044,1.001620252253804,-3.518209729066801,0.48056210741074734,icmp,http,attack +-1.0555496955829409,-0.668407734966505,0.8865174125690543,2.183498866689609,-0.7586382955521542,-0.46545018240405345,-0.3150141603143387,-1.5709099149769432,1.058637956137902,-1.40515992236446,udp,ftp,normal +-0.12225321227664476,-0.4424361458824986,-0.8255848430174609,0.9765255038672649,-3.22202016113998,-0.1657230531466151,-3.0110611268069216,0.9715063957717444,0.8900744467756414,2.3097227554061868,icmp,ssh,normal 
+-4.77202522973986,-1.712936631956997,-0.1490366667295299,-2.7458937711934364,3.6397899913839704,-1.1254190077468889,4.487370834009874,-1.0248226493397838,-1.0148534542694319,-0.006653454631237347,udp,ssh,normal +0.26929568320199326,2.757147072496183,2.6435715833985505,2.6745462129469226,2.2477201422725925,0.18053084180543139,0.3013079301649735,-2.0242249704321327,1.4558042487379736,-0.20284621967015248,icmp,http,normal +-1.7953633053498477,-0.9250072314735335,0.9198275101476635,1.7364132128377225,0.9546298045425635,-2.4846115790858145,1.661766509527523,-2.611085510476737,0.9095088849535573,0.31386196347013423,udp,ssh,normal +-5.146072668679875,-1.254466630146458,0.1000954425796983,0.6540676669806567,0.8161392842014232,0.1838833519051575,1.378016994390978,-2.0642631423515754,1.5169877864124994,-0.0902276640562097,icmp,ssh,normal +-1.630250553113196,-1.2429061533305559,-0.683431790130123,0.23206966783590421,-1.0920520352893404,1.2231709505914572,0.2543395463002789,-1.710735539158689,2.097639787092401,-0.5098575774999748,udp,smtp,normal +0.6640987873968396,2.697771390144384,-0.7984390474170536,0.022002639290531256,3.013816950344355,-1.5240152417270303,0.6672305710103317,0.6335982054283361,-1.1112188575166537,1.3413119197001346,tcp,ftp,attack +0.07172598817857034,-0.39613986506053767,0.5553768168069081,0.010798098833008862,-2.456435065519328,-0.5989243965593357,-2.1915399911559517,1.1143150326998956,0.5153752700952515,-0.6228704964232837,udp,smtp,normal +1.697901059690928,0.9031953883843074,-0.1581014900701986,-1.4318901069909098,2.645567250220823,0.5102963895823233,2.8965677982715023,-1.7387840060785005,0.7270865270428809,0.8065932479136867,icmp,http,attack +0.22892445914495796,0.6111802795578946,-0.7825437052344801,0.41540916386044835,0.9804944580141247,0.14928656347247762,0.8824678818318511,-1.3490317294493421,0.9666208507629566,0.6742804687467753,icmp,ftp,normal +-3.016492683124596,-0.5423445057919424,-1.0705780796771673,1.9284219455927991,0.465268722867552,-1.2073958316463744,0.9732803277242783,-2.9899120410581363,2.350221062801086,0.6826232364571462,icmp,ftp,normal +-0.4072923931734753,1.3995360934478815,-0.3348522723761133,0.3279023189866148,1.6205394303764415,0.1798454187280591,0.537049522605652,-0.5017504325322538,0.36276957976353286,1.9750273536307812,udp,http,attack +-0.19775911520613354,0.9940785766178477,-0.6587688257813566,-0.6553249151429961,1.7809289942550521,-0.005336470836663279,1.301760361809799,-0.7654614079297859,0.579585218949456,-1.68955722607827,tcp,ssh,attack +0.3638778607085752,2.650451650124713,-1.4343649671588456,-0.5907268488630968,2.3046595715694655,0.1985583628962918,-0.05778424737304766,1.3952755690525924,-1.0169126930485652,0.5810220098659856,udp,ssh,attack +-1.3367139111761204,-0.6533021062421234,-1.5265518616977465,2.0394893198605235,-1.628816647392755,0.6329416337934651,-1.4633480478748686,-0.47038357773225187,0.73040777593909,-0.057904907337784896,icmp,http,normal +0.4416190493637253,2.985717429947782,-0.7726671422734709,1.0601345150608228,2.554180763947457,0.4853159965604532,-0.004810554439582226,0.18174895050031817,-0.37728345459139057,-0.08922117559367017,udp,smtp,normal +-2.0752506407322393,-1.8426686102966199,0.30555432037725844,1.0329819692821174,1.0304097778927135,0.2400227376850462,1.7148835763051091,-0.8158239228964083,-1.626097269291621,0.932996642650568,tcp,http,attack 
+-1.7027449898526734,-1.1167802269900076,-0.008278074550049355,1.132969463120281,0.3665073089290823,-0.840908798047855,0.5239767767051813,-0.20748255645642744,-1.2171222495613376,0.6947610308465396,udp,smtp,attack +-0.17682526390298658,0.956405653521065,-0.4691477201949346,-1.8007307898595941,1.8154948602006793,0.7175915797930499,1.4129361373667952,-0.14169786594749217,0.35952756615422843,0.5269374126317904,tcp,ftp,attack +-1.2679288600201015,-0.5112368561936267,2.3573763650327084,1.1292617350943335,-0.08185951119126814,-0.7036200629264463,0.03665000227566895,-0.6676198722496249,0.145609565006678,0.937838301566644,tcp,smtp,attack +-1.9469158708597172,2.6704141282927027,1.7821636185360787,1.4287254057619792,2.5712086118707815,0.046915434950258676,-0.24462478211663896,0.26092761047701246,-0.3254179780829525,0.5902746190617881,udp,ftp,attack +0.3573231983720332,0.07391077684005465,1.5704946352611389,0.36924385891765965,1.111023973327628,1.620533445165877,1.864312657153616,-2.6829188645533653,1.8722816018416164,-1.0037578915633707,udp,ftp,attack +1.6100636611072607,1.8522832296223148,0.6075668643313585,0.11368053210095619,3.3380502652759176,0.17903679654896512,2.3962196151648927,-1.5499231155120357,0.02949703512001356,-1.519179982986177,icmp,ftp,attack +-2.3977601601006313,-2.2302914641555054,-0.4046658233070462,1.606522507787822,0.5848905283864407,-0.1863723943630917,1.2644101815732531,-0.49583675409108413,-2.1482731450340236,-0.3629572789687074,tcp,http,attack +-2.2214616085893293,-0.31962478140764095,0.8166985177135699,1.002197593120677,0.6859597434024389,0.11960566888706947,1.2146822023887545,-2.5977835754008898,2.157101633473154,0.6186619597099322,udp,ssh,normal +-3.576345430362077,0.3631359987222096,0.1550366235955187,2.110328991540375,1.2649714779934134,-0.11484401948302342,0.9142169468808572,-2.873991442043078,2.394646743601487,2.569926839690334,udp,ssh,normal +-0.2392092546708744,0.5302166714243535,0.13498472081968857,-1.622061914505117,-1.3411157938739509,-0.26102725324679543,-1.5004785224888444,1.248841731933082,1.1043246358187844,0.4125703966441225,tcp,ssh,normal +-2.696161131959916,-0.9869917054702277,-1.7628669288710401,0.467934873002812,-0.2788245066990082,0.9294825073447697,0.7454126612204779,-2.1821271184979505,2.313604314026599,1.1441352026164182,tcp,http,normal +0.6424238290700345,0.18037459428068814,-0.3357150728865708,-1.4152195241460073,-2.223613658710538,-0.03095825063657335,-1.5041732403851493,0.1231828712029506,2.517947740144099,1.3460442074647172,udp,ssh,attack +1.0151910760459748,0.5509528770934655,-1.2884473797573248,-0.3423683406672854,1.2578393600269882,-0.4181430340092719,1.714018716567562,-2.024991212985968,1.603892857665288,0.10081742225405217,icmp,ssh,attack +-0.3409046569641945,-0.13061123951079234,-0.3521729835995906,-0.8392240679631999,0.4444968618063666,0.9871222001287102,1.2323912023016284,-1.5627347257463624,1.7729690012413162,0.24252042561251944,udp,ftp,attack +-1.5748771062103537,-0.3830255974089919,0.8841008365591795,2.2554363636980943,0.3062383294123369,-1.2573403004288446,-0.7577973003881842,1.095866869633655,-2.5664090274445384,-0.950683791405327,tcp,smtp,attack +-0.7537706480526911,0.35590327107885245,0.28984226416846787,1.1446840180439886,0.8620696681284957,0.20370179813884956,0.9615258348338983,-2.2244374209368685,1.737154003464802,-0.2548212683723646,tcp,smtp,normal 
+-1.6295043185413636,2.0421002666205217,-0.7044701034595299,0.9909496976562769,2.4481910346347138,0.029983634527157816,0.33865256046406256,-0.119392293757285,-0.22438853605744913,-0.7283151499478859,icmp,ssh,normal +1.327905321494339,1.3477271410986802,-1.7531070469382153,-1.5546628597757652,1.3617518070212722,1.3157665082821597,0.8950802155772717,-0.02157594263462217,0.46581055733046617,0.7563515692873813,icmp,ftp,attack +-2.1612483380377747,-0.15201556759398915,0.5277647609316608,0.4825762935571829,1.6477974230981176,-0.14987796676246168,1.9099843255708093,-2.2635289232318367,1.3169918989872795,0.8332238776207417,udp,smtp,normal +1.6044233901686964,1.3402212326170893,-1.2415012309491225,0.7286702656444164,3.96780654295332,1.7646869342627187,3.572888282337743,-2.8760903010866077,0.21814062047014404,-0.8163054090548456,udp,ftp,attack +0.6511891103416863,-2.6164520440258285,-0.8335589872792668,1.421874977220551,-1.5463133742445632,-0.3703233685972966,1.169855828228326,-2.9267946715072357,1.4510753611553957,-0.6895403210803661,udp,http,normal +-0.47075245321648396,-0.7220816651208304,-1.6393962676640563,2.280012931692392,0.18536733052558488,-1.1318161421326958,-0.5716813220337686,1.2401999955046759,-3.2353996946438706,0.4625485526475405,udp,ssh,attack +0.1905116995332402,-1.6700002004681949,-0.5449222349589087,-2.871321200005598,-2.0752273481654973,-1.7534694162506628,0.8434294534656794,-1.8922274915245405,3.831323280781395,0.9718226901727355,icmp,ssh,attack +-1.769375076830352,0.6542705498982615,0.1582559809533675,0.9189883296161538,1.5358415002194352,0.0006991175571747527,0.9301541700122702,-1.3942099420461684,0.842837374334668,-0.6775662127337071,tcp,ftp,normal +-2.359982058668808,-1.4498059857529766,0.7133995902054457,1.1141811397767498,-0.3617594636071895,0.5855034608585792,1.0975861278335826,-2.8672820264783194,2.3410234551603226,0.04701721312374508,tcp,smtp,normal +1.4140519217894043,1.7239045349938196,0.6722572844137722,-1.2692669111081054,2.4523583752554132,2.155239538433073,1.605053461769832,-0.30856437626844674,0.009295242164500284,-0.6228545031012834,icmp,ftp,attack +-2.4169534276866127,1.767057795737878,-1.5027650176306036,2.0169237123679946,2.23216581467413,-0.8184742170423912,0.5273317573874531,-1.5790662860441156,1.084540871994762,-0.26073516489181897,tcp,http,attack +2.6872007364276405,1.6042973007373802,1.6463972426799733,-0.618619037785407,2.5357196874772225,0.25275550753788667,2.170267499597884,-1.4035832067887524,0.4230511308484325,-1.007089364306261,tcp,http,attack +-1.954634888106384,0.17199756070984518,-0.3765526510222658,0.7151947167691501,1.5413613539176285,0.34361858953521734,0.819932606011428,-0.2950641738549086,-0.7085389996598499,-0.7052107363297812,udp,ftp,attack +-1.0720352553851649,-2.478007948356267,0.003414751178410458,1.5894177001506598,-0.23445693459721503,-0.22571364756933313,1.2055536776204343,-1.2076118405476577,-1.1762068663934326,0.8177526958760498,udp,http,attack +0.5579200883668298,-1.249062046897537,0.3431350510168737,2.2919250286196764,-0.9457604059068692,1.2056280779528372,-0.26229478380375704,-0.7443030472470422,-0.7259418982501984,1.3238748037872132,udp,ssh,attack +1.62797121603795,0.8741293036726576,0.5518636553095008,-2.1083013872840746,2.7468986492122736,-1.6295929859763272,3.0515672213097695,-1.4473805067615466,0.6287420356552158,2.0055740456399533,tcp,http,attack 
+-1.726067494935437,-0.09032813280396212,0.16469782743714056,1.3303653461168208,1.3171315543292739,-0.8401100785833383,0.007707267523933248,1.3654119828136868,-3.0072890933786542,-0.36253502509936236,udp,smtp,attack +-0.6326839257833132,-1.2546416979638844,-0.3003604880886286,2.1845067557704017,-1.5113535061026575,0.7936579419803536,-0.9675455093881183,-0.4365713341976043,-0.21702475634546814,-0.8047869166603249,tcp,ssh,normal +-1.0882281666149176,-0.5269142689782376,0.13384765424599987,1.634605036335191,-1.9548502457293122,0.8184847346385397,-1.6208344426532728,-0.5800520171191808,1.393148030098717,0.4466397315170858,udp,http,normal +-1.9302628659954126,0.046902942511178436,0.10473529863813051,1.4994824417565407,1.2151568859935904,-0.8505564586621955,-0.008670195140871417,0.8063341144376911,-2.183038342399834,-1.7485469849575264,tcp,ftp,attack +-0.909730693313743,1.727073959715482,1.388582269040288,0.5529967234643703,3.10910462755486,-1.1071044546000215,1.7722348472877427,-1.382855272485153,0.37191595642433717,-2.6593513570493728,udp,http,normal +1.4029277767189676,1.815999036336924,0.5739669592016755,0.06765921579476175,4.473752678851211,-0.8837777888642685,3.5079704337119137,-2.0123476020299935,-0.40362776916181997,-1.753573717908687,tcp,http,attack +-0.5217374110728571,0.19277884709062265,0.532595109370551,-1.0086641714363016,-1.963141015653855,0.49973762011211026,-2.2642157450374203,1.9103408170837244,0.3642165051221631,-1.0877809352617618,tcp,ftp,normal +0.5935983800957327,-1.1297769411930276,0.26658054669493636,2.3329922526782783,-1.8711718569117406,0.2136304244285101,-0.987897336141532,-1.008324336630194,0.4736055889429407,-0.40137975733222647,udp,ftp,normal +0.9818331086652545,0.4987706952536981,0.28303220539521634,0.7092535795628199,-3.902086901519922,0.8677650241708884,-4.735288259320994,3.0644871195918504,-0.19632557474121648,-1.3494635714255234,icmp,http,normal +0.0957756895429831,-0.28423078773459154,-0.22624521900867614,-1.0958932842692732,-2.254580641364886,-1.3193478891374055,-1.8524676124926136,1.2523946598275544,0.8185188709490152,1.0645686780266999,tcp,smtp,normal +-1.6401865385512742,-0.5547050551123112,-1.7585989999941474,0.3305102399815646,-1.2953897795325535,-1.681820325133229,-0.7962474919184902,-0.5774765235759334,1.5665186169936476,-0.8726306273710107,udp,http,normal +0.969425310399492,0.61660731166243,-0.937355682223321,-1.5032206925461247,2.6156922611117466,-0.6572333181174939,3.027309620187378,-1.891858298480262,0.8925964963699855,-1.6697791720684,tcp,http,attack +0.7297178118521155,-2.1305720148687897,0.4182525547860004,4.026606603359985,-2.4274795514543133,0.49495551756152545,-1.8924579017856793,0.5667185336482521,-2.613505771819312,0.4004120292384369,udp,ssh,attack +-3.1437363345605998,-1.8211789568875552,-1.050911957828187,2.6333994679086734,-0.815158778021211,-1.006115893083496,0.6542544518282689,-3.3962339263710795,2.46631618070662,-1.9336783011075602,icmp,http,normal +-1.423713991373062,-1.092620679176252,1.0622098509137117,2.2232605140088193,-1.3687950859679032,1.286268613404601,-0.6878487533055844,-1.3883400569949393,1.0796836878225025,-0.4537443085364418,icmp,smtp,attack +0.8058014889630276,1.2556112136919808,-1.2406400776875708,1.2571025860047245,1.7390142162880697,-2.078376654918943,1.3863511448691264,-2.4161354000767212,1.5375796420454433,-1.1775283652749364,icmp,ssh,normal 
+2.052837106369524,-1.1312364389409817,-0.30455363848725747,-1.1339225277853846,2.656914427711499,2.0813146567651417,5.5787090566957716,-5.518044239494748,2.934273425753984,0.09532080694771912,tcp,http,attack +-0.5977292373603689,0.46208265603957865,-0.1303451459872082,0.26103232189638004,2.224458291436808,0.09874577763733462,2.1720524817711846,-2.001795777444905,0.7676283510968686,0.5004775194206322,udp,ftp,normal +0.9213058549839399,3.200886364562553,-1.8449052402960937,1.984403426083632,2.572592588583328,1.0804598690053497,0.10678328600651998,-0.8523821594458613,0.403789154806133,-1.0171921705453906,icmp,ssh,normal +0.3652925703925687,-0.1161681904726235,0.04795405028310565,1.4609856198944633,1.305868492351851,0.5380223547968438,0.6866161361951972,0.26647943762141374,-2.378873382244377,1.2089819441064658,udp,smtp,normal +-2.350830349917292,0.7245670402259833,0.8044391101108651,1.0629499627586636,2.521678985186988,-1.3191450452019449,1.9517091106115758,-2.3512590205435058,1.2091243706565278,-0.15721389336445102,icmp,smtp,normal +-0.9300147890516832,0.16293565441355928,-0.7991105708252652,-0.5769556920054904,2.1725896614956275,0.5166693086481229,2.2324214840037464,-1.3547304092841885,0.3003214533666635,-0.1645356580496962,udp,ftp,normal +-2.3179601683458486,-1.3646526363939109,-1.0424848742271386,2.127961160094677,0.07981729983851894,0.501033343517345,1.1853306017442442,-2.9762308503683825,1.643607792770292,1.2779268292381651,icmp,ssh,normal +-0.49491796655843456,-1.426078041219425,0.02464209842858423,0.6080218066849914,0.8234001931138106,-1.4893930596493061,1.8213738871172518,-1.3336955639548427,-0.6770042594670485,0.08995614812678975,tcp,ssh,normal +-1.416631072766702,2.3445394002102145,0.44211666531347293,2.567283004979708,1.70091392503033,0.2168610369829383,-0.38032749240209784,-1.2846106186131305,1.2131896378425417,0.2502439325736485,tcp,ssh,normal +0.222832637962322,1.308911675453657,-0.35068152787878737,-1.7758454086656639,-0.0793847853937693,-1.0532587206449961,-0.8109239734267952,1.1761312178062275,0.7364667576530325,-0.229216051553592,tcp,smtp,attack +-3.2760118871852897,-0.6616802380035336,-0.3618909712330037,0.4287762615822318,2.016310360018961,-1.326539016109759,2.6178974486675166,-2.8617775242482257,1.5272074827072672,-0.06958982291301025,icmp,ftp,normal +-0.7103981945742239,-1.7494692131393714,1.1290694375351427,2.600148845457415,-1.8971208787209952,-0.21327237596579143,-1.1816828490739442,-0.2427791542610506,-0.7325370862102922,-0.08041577370899264,tcp,ssh,normal +-1.1916395079505993,-1.1820491916224065,0.17554582989050802,1.063039799747586,0.6532497008235292,-1.4414942480636153,1.1244067249876095,-0.8239411398468927,-0.9263399040825036,0.40851234837742617,tcp,smtp,attack +-1.1319761786968676,0.534593552991031,-1.3414451577996975,0.7507169387660726,1.7355042508624812,-1.139754428156706,1.535175984918303,-2.0111392369612413,1.176698872686163,1.8868281121242367,udp,ftp,normal +-1.8680461695758477,0.4481048907845493,0.5434991027893419,0.4162501983556405,1.7141541812971937,0.5838003761562907,0.8208031413681565,-0.13013474692100965,-0.6734668751928957,1.083609653905471,tcp,smtp,attack +-2.6796291276944215,-0.3623681269380986,0.4056175931403788,-0.3593773757785892,3.7176226265143115,0.42168492391225987,3.0068919230565627,-0.06460022628393976,-2.667208019970844,0.7150070138011407,udp,ssh,attack 
+0.6158451276127594,0.06404234784218188,-0.16811489562529364,-1.0983089575245746,-2.2891867869422167,-0.49596904632647065,-2.1716981909594395,1.637506529472067,0.5923520604611137,-0.3733992683765694,icmp,smtp,normal +-1.2478045608210466,0.13503766592817334,-0.13340085789154404,1.1159111274293387,1.1386039050407832,-0.5880868263145528,1.486722833069141,-2.7971888537919525,2.0883148481029616,1.070996467205255,icmp,smtp,normal +-1.8565691201585455,-0.47240735030531134,0.8574459317174058,1.0857667478699795,0.5848610616558808,0.5024166321870274,0.2956189899244541,-0.1752840897200458,-0.7951520774228367,1.0085208247785817,udp,ssh,attack +-1.2372821489098667,0.19455697555404772,-3.480808828760084,0.20080062169970936,2.6668183961570073,-1.0608267958340962,1.8629186839765197,-0.07027914082409566,-1.8757512618516645,-0.8702232421435117,udp,http,attack +-0.8192588597304246,-0.7675322356502449,-0.6505848415568568,1.34290884281662,-3.0480600617210736,-1.9296845963141573,-2.581526738496681,0.2548566457717975,1.2602430093252197,1.0435459092865536,icmp,ssh,normal +1.1326499144745823,0.22380861557528042,-1.5579523176766679,-1.3841910063450813,-0.5956891557889185,-0.04726465262900927,0.3252953766263734,-1.086550964222601,2.35456692713069,-0.40038136018331794,tcp,ssh,attack +-0.6388607470425571,-2.872666504937722,1.167102914383636,1.407429474465963,-3.8428274597311414,-0.6638969135466657,-1.8672181803256886,0.015223092167168804,0.23570145421045297,1.6211085577877153,icmp,http,attack +-0.8267468119211125,1.9915747151804841,0.8522464890302681,1.2320905196226029,2.126677732113498,-0.5828006014129976,0.3279783034346414,-0.5737288122647728,0.2036974450156268,-2.3207373739781305,udp,http,normal +-1.2149461945440996,-1.6657306613963563,0.795906504847474,0.8113493622089705,0.7147073879975814,-1.1735702070121252,1.5915752810061734,-0.9871629192971766,-1.1063391733175116,1.0485309385076005,icmp,ftp,attack +2.093237385908343,0.7708747109638758,-2.3697214094844474,-0.4121772940686629,3.0799858070124335,0.38361395381677077,3.6395966998627385,-3.082364257464201,1.151303255527203,-0.34801308629683886,tcp,ftp,attack +1.9861558696376607,0.3817892780422074,1.246694736529587,0.025696938887743714,-0.4453361913707514,-1.4130243453417775,0.45373085005460145,-2.0064613965034375,2.499803771255707,-2.1698776870408585,udp,ssh,attack +-0.8939061135230677,1.4718658460843468,-1.350248411942825,-0.7394375142919738,3.3349100258599185,0.7533042690776492,1.9600662396586999,-0.1315039884303197,-0.8588768551760452,1.4951337455174116,udp,smtp,attack +-1.9358271425633762,-0.11851970101590759,0.07396296407356806,1.1290258843028023,-1.0684077011423678,-1.569960337657343,-0.7681195141645876,-1.4787223664237616,2.478217386162483,-0.5175336509082241,udp,http,normal +-0.8051240043890072,-0.171788457190987,-1.809958900980706,0.9891932272994489,1.0178072621596284,0.174639269367828,0.5836683720581672,-0.05462956023272619,-1.2696750671265455,0.9650149610632297,udp,ssh,attack +-3.8612084683381975,0.596692236561742,0.6281400134806694,1.030284118231267,2.385445469482641,-0.2298795714463868,1.5309629243350167,-1.8510559939680775,1.0539092418548264,1.2701255465757155,tcp,http,attack +2.049214838152612,1.4278396280383134,2.5386911017374865,0.8199333655388336,1.0404176533081861,0.7810732676164527,0.6287527530967024,-1.487170629966769,1.1411119826448135,0.5239526671077938,icmp,ftp,normal diff --git a/models/random_forest_model.pkl b/models/random_forest_model.pkl new file mode 100644 index 0000000..7c446be Binary files /dev/null and 
b/models/random_forest_model.pkl differ diff --git a/models/random_forest_model_metadata.json b/models/random_forest_model_metadata.json new file mode 100644 index 0000000..638b321 --- /dev/null +++ b/models/random_forest_model_metadata.json @@ -0,0 +1,117 @@ +{ + "model_type": "random_forest", + "data_file": "data/test_dataset.csv", + "data_shape": [ + 100, + 13 + ], + "test_size": 0.2, + "hyperparameter_tuning": false, + "validation_results": { + "is_valid": true, + "errors": [], + "warnings": [], + "shape": [ + 100, + 13 + ], + "columns": [ + "feature_00", + "feature_01", + "feature_02", + "feature_03", + "feature_04", + "feature_05", + "feature_06", + "feature_07", + "feature_08", + "feature_09", + "protocol_type", + "service", + "label" + ], + "dtypes": { + "feature_00": "float64", + "feature_01": "float64", + "feature_02": "float64", + "feature_03": "float64", + "feature_04": "float64", + "feature_05": "float64", + "feature_06": "float64", + "feature_07": "float64", + "feature_08": "float64", + "feature_09": "float64", + "protocol_type": "object", + "service": "object", + "label": "object" + }, + "missing_values": { + "feature_00": 0, + "feature_01": 0, + "feature_02": 0, + "feature_03": 0, + "feature_04": 0, + "feature_05": 0, + "feature_06": 0, + "feature_07": 0, + "feature_08": 0, + "feature_09": 0, + "protocol_type": 0, + "service": 0, + "label": 0 + }, + "memory_usage": "24107" + }, + "evaluation_results": { + "accuracy": 0.75, + "precision": 0.7525252525252526, + "recall": 0.75, + "f1_score": 0.7493734335839599, + "confusion_matrix": [ + [ + 8, + 2 + ], + [ + 3, + 7 + ] + ], + "classification_report": " precision recall f1-score support\n\n attack 0.73 0.80 0.76 10\n normal 0.78 0.70 0.74 10\n\n accuracy 0.75 20\n macro avg 0.75 0.75 0.75 20\nweighted avg 0.75 0.75 0.75 20\n", + "auc_score": 0.8, + "feature_importance": { + "feature_08": 0.18016042076879749, + "feature_03": 0.15341555378567834, + "feature_00": 0.1389365363834253, + "feature_07": 0.096349138713875, + "feature_04": 0.07728955832067869, + "feature_01": 0.0756043872971028, + "feature_05": 0.07052390586139434, + "feature_09": 0.06454198649319527, + "feature_06": 0.060302368574283935, + "feature_02": 0.04364081448870484, + "protocol_type_tcp": 0.0102319660889518, + "service_ssh": 0.010011947871145782, + "protocol_type_udp": 0.008241555991509638, + "service_http": 0.0064549694223867615, + "service_smtp": 0.004294889938870055 + } + }, + "model_path": "models/random_forest_model.pkl", + "training_results": { + "model_type": "random_forest", + "training_time": 0.09336, + "training_samples": 80, + "features_count": 15, + "cv_scores": [ + 0.75, + 0.75, + 0.8125, + 0.75, + 0.75 + ], + "cv_mean": 0.7625, + "cv_std": 0.024999999999999998, + "timestamp": "2025-09-26T13:50:01.470802" + } +} \ No newline at end of file diff --git a/notebooks/model_training_demo.ipynb b/notebooks/model_training_demo.ipynb new file mode 100644 index 0000000..50ae7cd --- /dev/null +++ b/notebooks/model_training_demo.ipynb @@ -0,0 +1,466 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Cybersecurity Attack Detection - Model Training\n", + "\n", + "This notebook demonstrates how to train and evaluate machine learning models for cybersecurity attack detection using our framework." 
+ ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Setup and Imports" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import pandas as pd\n", + "import numpy as np\n", + "import matplotlib.pyplot as plt\n", + "import seaborn as sns\n", + "from sklearn.model_selection import train_test_split\n", + "from sklearn.metrics import confusion_matrix, classification_report\n", + "\n", + "# Import our framework modules\n", + "import sys\n", + "import os\n", + "sys.path.append(os.path.join(os.getcwd(), '..'))\n", + "\n", + "from src.core.preprocessing import DataPreprocessor\n", + "from src.models.detector import CyberAttackDetector, ModelComparer, ModelRegistry\n", + "from src.utils.helpers import create_sample_dataset, calculate_dataset_statistics, PerformanceTimer\n", + "\n", + "# Configure plotting\n", + "plt.style.use('seaborn-v0_8')\n", + "plt.rcParams['figure.figsize'] = (12, 8)\n", + "\n", + "print(\"Framework modules imported successfully!\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Data Loading and Exploration" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Create sample dataset for demonstration\n", + "print(\"Creating sample cybersecurity dataset...\")\n", + "dataset = create_sample_dataset(n_samples=2000, n_features=15, n_classes=2)\n", + "\n", + "print(f\"Dataset shape: {dataset.shape}\")\n", + "print(f\"\\nColumns: {list(dataset.columns)}\")\n", + "print(f\"\\nLabel distribution:\")\n", + "print(dataset['label'].value_counts())\n", + "\n", + "dataset.head()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Calculate comprehensive dataset statistics\n", + "stats = calculate_dataset_statistics(dataset)\n", + "\n", + "print(\"Dataset Statistics:\")\n", + "print(f\"Shape: {stats['shape']}\")\n", + "print(f\"Memory usage: {stats['memory_usage_mb']:.2f} MB\")\n", + "print(f\"Missing values: {stats['missing_values']}\")\n", + "print(f\"Duplicate rows: {stats['duplicate_rows']}\")\n", + "print(f\"\\nData types distribution:\")\n", + "for dtype, count in stats['dtypes'].items():\n", + " print(f\" {dtype}: {count} columns\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Visualize label distribution\n", + "plt.figure(figsize=(10, 6))\n", + "\n", + "plt.subplot(1, 2, 1)\n", + "dataset['label'].value_counts().plot(kind='bar')\n", + "plt.title('Label Distribution')\n", + "plt.xlabel('Class')\n", + "plt.ylabel('Count')\n", + "\n", + "plt.subplot(1, 2, 2)\n", + "dataset['protocol_type'].value_counts().plot(kind='pie', autopct='%1.1f%%')\n", + "plt.title('Protocol Type Distribution')\n", + "\n", + "plt.tight_layout()\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Data Preprocessing" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Initialize preprocessor and run full pipeline\n", + "preprocessor = DataPreprocessor()\n", + "\n", + "with PerformanceTimer(\"Data preprocessing\"):\n", + " X, y, validation_results = preprocessor.full_preprocessing_pipeline(dataset)\n", + "\n", + "print(\"Preprocessing completed!\")\n", + "print(f\"Features shape: {X.shape}\")\n", + "print(f\"Labels shape: {y.shape}\")\n", + "print(f\"\\nValidation results: 
{validation_results['is_valid']}\")\n", + "\n", + "if validation_results['warnings']:\n", + " print(f\"Warnings: {validation_results['warnings']}\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Split data into training and testing sets\n", + "X_train, X_test, y_train, y_test = train_test_split(\n", + " X, y, test_size=0.2, random_state=42, stratify=y\n", + ")\n", + "\n", + "# Scale features\n", + "X_train_scaled, X_test_scaled = preprocessor.scale_features(X_train, X_test)\n", + "\n", + "print(f\"Training set shape: {X_train_scaled.shape}\")\n", + "print(f\"Test set shape: {X_test_scaled.shape}\")\n", + "print(f\"Training labels distribution:\")\n", + "print(pd.Series(y_train).value_counts())" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Model Training and Evaluation" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Show available models\n", + "available_models = ModelRegistry.get_available_models()\n", + "print(\"Available models:\")\n", + "for model_name in available_models.keys():\n", + " print(f\" - {model_name}\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Train a Random Forest model\n", + "print(\"Training Random Forest model...\")\n", + "rf_detector = CyberAttackDetector('random_forest')\n", + "\n", + "with PerformanceTimer(\"Random Forest training\"):\n", + " training_results = rf_detector.train(\n", + " X_train_scaled, y_train, feature_names=X.columns.tolist()\n", + " )\n", + "\n", + "print(f\"Training completed!\")\n", + "print(f\"CV Accuracy: {training_results['cv_mean']:.4f} (+/- {training_results['cv_std']*2:.4f})\")\n", + "print(f\"Training time: {training_results['training_time']:.2f} seconds\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Evaluate the trained model\n", + "evaluation_results = rf_detector.evaluate(X_test_scaled, y_test)\n", + "\n", + "print(\"Evaluation Results:\")\n", + "print(f\"Accuracy: {evaluation_results['accuracy']:.4f}\")\n", + "print(f\"Precision: {evaluation_results['precision']:.4f}\")\n", + "print(f\"Recall: {evaluation_results['recall']:.4f}\")\n", + "print(f\"F1 Score: {evaluation_results['f1_score']:.4f}\")\n", + "\n", + "if 'auc_score' in evaluation_results:\n", + " print(f\"AUC Score: {evaluation_results['auc_score']:.4f}\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Visualize confusion matrix\n", + "plt.figure(figsize=(8, 6))\n", + "cm = np.array(evaluation_results['confusion_matrix'])\n", + "sns.heatmap(cm, annot=True, fmt='d', cmap='Blues')\n", + "plt.title('Confusion Matrix - Random Forest')\n", + "plt.xlabel('Predicted')\n", + "plt.ylabel('Actual')\n", + "plt.show()\n", + "\n", + "print(\"\\nClassification Report:\")\n", + "print(evaluation_results['classification_report'])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Visualize feature importance\n", + "if 'feature_importance' in evaluation_results:\n", + " feature_importance = pd.Series(evaluation_results['feature_importance'])\n", + " \n", + " plt.figure(figsize=(12, 8))\n", + " feature_importance.head(15).plot(kind='barh')\n", + " plt.title('Top 15 Feature Importances - Random Forest')\n", + " plt.xlabel('Importance')\n", 
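+    "    # Feature importances are read from the trained estimator's\n",
+    "    # feature_importances_ attribute; the guard above skips this plot\n",
+    "    # for models that do not expose one.\n",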
+ " plt.tight_layout()\n", + " plt.show()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Model Comparison" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Compare multiple models\n", + "print(\"Comparing multiple models...\")\n", + "comparer = ModelComparer()\n", + "\n", + "# Add models to compare\n", + "comparer.add_model('Random Forest', 'random_forest')\n", + "comparer.add_model('Logistic Regression', 'logistic_regression')\n", + "comparer.add_model('Neural Network', 'neural_network')\n", + "\n", + "# Run comparison\n", + "with PerformanceTimer(\"Model comparison\"):\n", + " comparison_results = comparer.compare_models(\n", + " X_train_scaled, y_train, X_test_scaled, y_test, X.columns.tolist()\n", + " )" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Display comparison summary\n", + "summary = comparison_results['summary']\n", + "\n", + "print(\"Model Comparison Summary:\")\n", + "print(f\"Best Accuracy: {summary['best_accuracy']['model']} ({summary['best_accuracy']['score']:.4f})\")\n", + "print(f\"Best Precision: {summary['best_precision']['model']} ({summary['best_precision']['score']:.4f})\")\n", + "print(f\"Best Recall: {summary['best_recall']['model']} ({summary['best_recall']['score']:.4f})\")\n", + "print(f\"Best F1 Score: {summary['best_f1']['model']} ({summary['best_f1']['score']:.4f})\")\n", + "print(f\"Fastest Training: {summary['fastest_training']['model']} ({summary['fastest_training']['time']:.2f}s)\")\n", + "\n", + "print(\"\\nAccuracy Ranking:\")\n", + "for i, (model, accuracy) in enumerate(summary['accuracy_ranking'], 1):\n", + " print(f\"{i}. {model}: {accuracy:.4f}\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Visualize model comparison\n", + "detailed_results = comparison_results['detailed_results']\n", + "\n", + "# Extract metrics for visualization\n", + "models = list(detailed_results.keys())\n", + "metrics = ['accuracy', 'precision', 'recall', 'f1_score']\n", + "\n", + "comparison_data = pd.DataFrame(index=models, columns=metrics)\n", + "for model in models:\n", + " eval_results = detailed_results[model]['evaluation']\n", + " for metric in metrics:\n", + " comparison_data.loc[model, metric] = eval_results[metric]\n", + "\n", + "comparison_data = comparison_data.astype(float)\n", + "\n", + "# Plot comparison\n", + "plt.figure(figsize=(12, 8))\n", + "comparison_data.plot(kind='bar', rot=45)\n", + "plt.title('Model Performance Comparison')\n", + "plt.xlabel('Models')\n", + "plt.ylabel('Score')\n", + "plt.legend(title='Metrics')\n", + "plt.ylim(0, 1)\n", + "plt.tight_layout()\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Hyperparameter Tuning" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Demonstrate hyperparameter tuning\n", + "print(\"Performing hyperparameter tuning for Random Forest...\")\n", + "\n", + "tuned_detector = CyberAttackDetector('random_forest')\n", + "\n", + "with PerformanceTimer(\"Hyperparameter tuning\"):\n", + " tuning_results = tuned_detector.hyperparameter_tuning(\n", + " X_train_scaled, y_train, cv_folds=3 # Reduced for demo\n", + " )\n", + "\n", + "print(f\"Tuning completed!\")\n", + "print(f\"Best parameters: {tuning_results['best_params']}\")\n", + "print(f\"Best CV score: 
{tuning_results['best_score']:.4f}\")\n", + "print(f\"Combinations tested: {tuning_results['n_combinations']}\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Evaluate tuned model\n", + "tuned_evaluation = tuned_detector.evaluate(X_test_scaled, y_test)\n", + "\n", + "print(\"Tuned Model Evaluation:\")\n", + "print(f\"Accuracy: {tuned_evaluation['accuracy']:.4f}\")\n", + "print(f\"Precision: {tuned_evaluation['precision']:.4f}\")\n", + "print(f\"Recall: {tuned_evaluation['recall']:.4f}\")\n", + "print(f\"F1 Score: {tuned_evaluation['f1_score']:.4f}\")\n", + "\n", + "# Compare with default model\n", + "print(f\"\\nImprovement over default model:\")\n", + "print(f\"Accuracy: {tuned_evaluation['accuracy'] - evaluation_results['accuracy']:.4f}\")\n", + "print(f\"F1 Score: {tuned_evaluation['f1_score'] - evaluation_results['f1_score']:.4f}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Model Saving and Loading" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Save the best model\n", + "import os\n", + "os.makedirs('../models', exist_ok=True)\n", + "\n", + "model_path = '../models/best_rf_model.pkl'\n", + "tuned_detector.save_model(model_path)\n", + "print(f\"Model saved to {model_path}\")\n", + "\n", + "# Demonstrate loading\n", + "loaded_detector = CyberAttackDetector('random_forest')\n", + "loaded_detector.load_model(model_path)\n", + "print(\"Model loaded successfully!\")\n", + "\n", + "# Verify loaded model works\n", + "loaded_predictions = loaded_detector.predict(X_test_scaled)\n", + "original_predictions = tuned_detector.predict(X_test_scaled)\n", + "\n", + "print(f\"Predictions match: {np.array_equal(loaded_predictions, original_predictions)}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Conclusion\n", + "\n", + "This notebook demonstrated:\n", + "\n", + "1. **Data Loading and Exploration**: How to load and explore cybersecurity datasets\n", + "2. **Data Preprocessing**: Complete preprocessing pipeline with validation\n", + "3. **Model Training**: Training different types of models\n", + "4. **Model Evaluation**: Comprehensive evaluation with multiple metrics\n", + "5. **Model Comparison**: Comparing multiple models side-by-side\n", + "6. **Hyperparameter Tuning**: Optimizing model performance\n", + "7. **Model Persistence**: Saving and loading trained models\n", + "\n", + "The framework provides a robust, modular approach to cybersecurity attack detection with built-in best practices for data preprocessing, model management, and evaluation." 
+ ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.0" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index ad1c96f..d48b307 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,3 +1,29 @@ +# Core ML dependencies pandas==2.2.3 numpy==2.2.1 -scikit-learn==1.6.0 \ No newline at end of file +scikit-learn==1.6.0 + +# Web framework and API +flask==3.0.0 +flask-cors==4.0.0 +flask-limiter==3.5.0 + +# Input validation and serialization +marshmallow==3.20.1 + +# Authentication +PyJWT==2.8.0 + +# Development and testing +pytest==7.4.3 +pytest-cov==4.1.0 +flake8==6.1.0 +black==23.11.0 + +# Jupyter notebooks +jupyter==1.0.0 +ipykernel==6.26.0 + +# Additional utilities +python-dotenv==1.0.0 +click==8.1.7 \ No newline at end of file diff --git a/scripts/evaluate_model.py b/scripts/evaluate_model.py new file mode 100644 index 0000000..110c318 --- /dev/null +++ b/scripts/evaluate_model.py @@ -0,0 +1,228 @@ +#!/usr/bin/env python3 +""" +Evaluation script for cybersecurity attack detection models. + +This script provides a command-line interface for evaluating +trained models on new datasets. +""" + +import argparse +import sys +import os +import json +from pathlib import Path + +# Add src to Python path +sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..')) + +from src.core.preprocessing import DataPreprocessor +from src.models.detector import CyberAttackDetector +from src.utils.helpers import setup_logging, PerformanceTimer +import pandas as pd +import numpy as np +from sklearn.metrics import confusion_matrix +import matplotlib.pyplot as plt +import seaborn as sns + + +def main(): + """Main evaluation function.""" + parser = argparse.ArgumentParser( + description='Evaluate cybersecurity attack detection models' + ) + + # Required arguments + parser.add_argument( + '--model', + required=True, + help='Path to the trained model file (.pkl)' + ) + + parser.add_argument( + '--data', + required=True, + help='Path to the test data CSV file' + ) + + # Optional arguments + parser.add_argument( + '--output-dir', + default='results', + help='Directory to save evaluation results (default: results)' + ) + + parser.add_argument( + '--save-predictions', + action='store_true', + help='Save predictions to CSV file' + ) + + parser.add_argument( + '--plot-confusion-matrix', + action='store_true', + help='Generate and save confusion matrix plot' + ) + + parser.add_argument( + '--log-level', + choices=['DEBUG', 'INFO', 'WARNING', 'ERROR'], + default='INFO', + help='Logging level (default: INFO)' + ) + + parser.add_argument( + '--log-file', + help='Log file path (default: console only)' + ) + + args = parser.parse_args() + + # Setup logging + logger = setup_logging(args.log_level, args.log_file) + + try: + # Validate input files + if not os.path.exists(args.model): + logger.error(f"Model file not found: {args.model}") + sys.exit(1) + + if not os.path.exists(args.data): + logger.error(f"Data file not found: {args.data}") + sys.exit(1) + + logger.info(f"Evaluating model: {args.model}") + logger.info(f"Test data: {args.data}") + + # Create output directory + os.makedirs(args.output_dir, exist_ok=True) + + # Load model 
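+        # NOTE: the detector type is inferred from the model filename below;
+        # names that do not follow the <model_type>_model.pkl convention fall
+        # back to 'random_forest'. The *_metadata.json file written at training
+        # time also records the authoritative "model_type".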
+ logger.info("Loading model...") + # Extract model type from filename (assuming format: modeltype_model.pkl) + model_filename = Path(args.model).stem + if '_model' in model_filename: + model_type = model_filename.replace('_model', '') + else: + model_type = 'random_forest' # Default fallback + + detector = CyberAttackDetector(model_type) + detector.load_model(args.model) + logger.info(f"Model loaded successfully (type: {model_type})") + + # Load and preprocess test data + logger.info("Loading and preprocessing test data...") + preprocessor = DataPreprocessor() + + with PerformanceTimer("Data preprocessing", logger): + raw_data = pd.read_csv(args.data) + logger.info(f"Loaded test data with shape: {raw_data.shape}") + + X, y, validation_results = preprocessor.full_preprocessing_pipeline(raw_data) + + if not validation_results['is_valid']: + logger.error("Data validation failed:") + for error in validation_results['errors']: + logger.error(f" - {error}") + sys.exit(1) + + # Scale features + X_scaled, _ = preprocessor.scale_features(X) + + # Make predictions + logger.info("Making predictions...") + with PerformanceTimer("Prediction", logger): + predictions = detector.predict(X_scaled) + probabilities = detector.predict_proba(X_scaled) + + logger.info(f"Generated {len(predictions)} predictions") + + # Evaluate model + logger.info("Evaluating model performance...") + with PerformanceTimer("Model evaluation", logger): + evaluation_results = detector.evaluate(X_scaled, y) + + # Display results + logger.info("Evaluation Results:") + logger.info(f" Accuracy: {evaluation_results['accuracy']:.4f}") + logger.info(f" Precision: {evaluation_results['precision']:.4f}") + logger.info(f" Recall: {evaluation_results['recall']:.4f}") + logger.info(f" F1 Score: {evaluation_results['f1_score']:.4f}") + + if 'auc_score' in evaluation_results: + logger.info(f" AUC Score: {evaluation_results['auc_score']:.4f}") + + # Create detailed results + results_summary = { + 'model_path': args.model, + 'model_type': model_type, + 'test_data_path': args.data, + 'test_data_shape': raw_data.shape, + 'validation_results': validation_results, + 'evaluation_results': evaluation_results, + 'predictions_count': len(predictions) + } + + # Save results + results_filename = f"evaluation_results_{Path(args.model).stem}_{Path(args.data).stem}.json" + results_path = os.path.join(args.output_dir, results_filename) + + with open(results_path, 'w') as f: + json.dump(results_summary, f, indent=2, default=str) + + logger.info(f"Evaluation results saved to {results_path}") + + # Save predictions if requested + if args.save_predictions: + predictions_df = pd.DataFrame({ + 'true_label': y, + 'predicted_label': predictions + }) + + if probabilities is not None: + # Add probability columns + unique_classes = np.unique(y) + for i, class_name in enumerate(unique_classes): + predictions_df[f'prob_{class_name}'] = probabilities[:, i] + + predictions_filename = f"predictions_{Path(args.model).stem}_{Path(args.data).stem}.csv" + predictions_path = os.path.join(args.output_dir, predictions_filename) + + predictions_df.to_csv(predictions_path, index=False) + logger.info(f"Predictions saved to {predictions_path}") + + # Generate confusion matrix plot if requested + if args.plot_confusion_matrix: + logger.info("Generating confusion matrix plot...") + + plt.figure(figsize=(8, 6)) + cm = confusion_matrix(y, predictions) + + sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', + xticklabels=np.unique(y), yticklabels=np.unique(y)) + plt.title(f'Confusion Matrix - 
{model_type.title()}') + plt.xlabel('Predicted') + plt.ylabel('Actual') + + plot_filename = f"confusion_matrix_{Path(args.model).stem}_{Path(args.data).stem}.png" + plot_path = os.path.join(args.output_dir, plot_filename) + + plt.savefig(plot_path, dpi=300, bbox_inches='tight') + plt.close() + + logger.info(f"Confusion matrix plot saved to {plot_path}") + + # Print classification report + print("\nDetailed Classification Report:") + print(evaluation_results['classification_report']) + + logger.info("Evaluation completed successfully!") + + except Exception as e: + logger.error(f"Evaluation failed with error: {str(e)}") + import traceback + logger.debug(traceback.format_exc()) + sys.exit(1) + + +if __name__ == '__main__': + main() \ No newline at end of file diff --git a/scripts/generate_data.py b/scripts/generate_data.py new file mode 100755 index 0000000..7b5faad --- /dev/null +++ b/scripts/generate_data.py @@ -0,0 +1,191 @@ +#!/usr/bin/env python3 +""" +Data generation script for cybersecurity attack detection. + +This script generates sample cybersecurity datasets for training and testing. +""" + +import argparse +import sys +import os + +# Add src to Python path +sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..')) + +from src.utils.helpers import create_sample_dataset, setup_logging, ensure_directory +import pandas as pd + + +def main(): + """Main data generation function.""" + parser = argparse.ArgumentParser( + description='Generate sample cybersecurity datasets' + ) + + # Dataset parameters + parser.add_argument( + '--samples', + type=int, + default=1000, + help='Number of samples to generate (default: 1000)' + ) + + parser.add_argument( + '--features', + type=int, + default=20, + help='Number of features to generate (default: 20)' + ) + + parser.add_argument( + '--classes', + type=int, + default=2, + help='Number of classes (attack types) (default: 2)' + ) + + parser.add_argument( + '--noise', + type=float, + default=0.1, + help='Amount of noise to add (0.0-1.0) (default: 0.1)' + ) + + parser.add_argument( + '--random-state', + type=int, + default=42, + help='Random seed for reproducibility (default: 42)' + ) + + # Output parameters + parser.add_argument( + '--output-dir', + default='data', + help='Directory to save the generated data (default: data)' + ) + + parser.add_argument( + '--filename', + help='Output filename (default: auto-generated)' + ) + + parser.add_argument( + '--split', + action='store_true', + help='Generate separate train/test files' + ) + + parser.add_argument( + '--test-size', + type=float, + default=0.2, + help='Test set proportion when splitting (default: 0.2)' + ) + + # Logging parameters + parser.add_argument( + '--log-level', + choices=['DEBUG', 'INFO', 'WARNING', 'ERROR'], + default='INFO', + help='Logging level (default: INFO)' + ) + + args = parser.parse_args() + + # Setup logging + logger = setup_logging(args.log_level) + + try: + logger.info("Generating sample cybersecurity dataset...") + logger.info(f"Samples: {args.samples}") + logger.info(f"Features: {args.features}") + logger.info(f"Classes: {args.classes}") + logger.info(f"Noise level: {args.noise}") + + # Generate dataset + dataset = create_sample_dataset( + n_samples=args.samples, + n_features=args.features, + n_classes=args.classes, + noise=args.noise, + random_state=args.random_state + ) + + logger.info(f"Generated dataset with shape: {dataset.shape}") + + # Create output directory + ensure_directory(args.output_dir) + + # Determine filename + if args.filename: + base_filename = 
args.filename + else: + base_filename = f"cyber_dataset_{args.samples}s_{args.features}f_{args.classes}c" + + if args.split: + # Split into train and test sets + from sklearn.model_selection import train_test_split + + X = dataset.drop('label', axis=1) + y = dataset['label'] + + X_train, X_test, y_train, y_test = train_test_split( + X, y, test_size=args.test_size, random_state=args.random_state, stratify=y + ) + + train_dataset = pd.concat([X_train, y_train], axis=1) + test_dataset = pd.concat([X_test, y_test], axis=1) + + # Save train set + train_filename = f"{base_filename}_train.csv" + train_path = os.path.join(args.output_dir, train_filename) + train_dataset.to_csv(train_path, index=False) + logger.info(f"Training data saved to {train_path} (shape: {train_dataset.shape})") + + # Save test set + test_filename = f"{base_filename}_test.csv" + test_path = os.path.join(args.output_dir, test_filename) + test_dataset.to_csv(test_path, index=False) + logger.info(f"Test data saved to {test_path} (shape: {test_dataset.shape})") + + else: + # Save complete dataset + if not base_filename.endswith('.csv'): + base_filename += '.csv' + + output_path = os.path.join(args.output_dir, base_filename) + dataset.to_csv(output_path, index=False) + logger.info(f"Dataset saved to {output_path}") + + # Display dataset statistics + logger.info("\nDataset Statistics:") + logger.info(f" Total samples: {len(dataset)}") + logger.info(f" Features: {len(dataset.columns) - 1}") # Exclude label column + logger.info(f" Memory usage: {dataset.memory_usage(deep=True).sum() / 1024:.1f} KB") + + logger.info("\nLabel Distribution:") + label_counts = dataset['label'].value_counts() + for label, count in label_counts.items(): + percentage = (count / len(dataset)) * 100 + logger.info(f" {label}: {count} ({percentage:.1f}%)") + + logger.info("\nFeature Statistics:") + numeric_features = dataset.select_dtypes(include=['float64', 'int64']) + logger.info(f" Numeric features: {len(numeric_features.columns)}") + + categorical_features = dataset.select_dtypes(include=['object']) + categorical_features = categorical_features.drop('label', axis=1, errors='ignore') + logger.info(f" Categorical features: {len(categorical_features.columns)}") + + logger.info("Data generation completed successfully!") + + except Exception as e: + logger.error(f"Data generation failed with error: {str(e)}") + import traceback + logger.debug(traceback.format_exc()) + sys.exit(1) + + +if __name__ == '__main__': + main() \ No newline at end of file diff --git a/scripts/test_api.py b/scripts/test_api.py new file mode 100755 index 0000000..3603243 --- /dev/null +++ b/scripts/test_api.py @@ -0,0 +1,102 @@ +#!/usr/bin/env python3 +""" +Simple API client for testing the cybersecurity detection API. +""" + +import requests +import json +import sys +import numpy as np + +def test_api(): + """Test the API endpoints.""" + base_url = "http://localhost:5000" + + print("Testing Cybersecurity Detection API...") + + # 1. Test health check + print("\n1. Testing health check...") + try: + response = requests.get(f"{base_url}/api/health", timeout=5) + if response.status_code == 200: + print("โœ“ Health check passed") + print(f" Response: {response.json()}") + else: + print(f"โœ— Health check failed: {response.status_code}") + return False + except requests.exceptions.ConnectionError: + print("โœ— Cannot connect to API. Is the server running?") + return False + except Exception as e: + print(f"โœ— Health check error: {e}") + return False + + # 2. 
Get authentication token + print("\n2. Getting authentication token...") + try: + auth_data = {"username": "test_user"} + response = requests.post(f"{base_url}/api/auth/token", json=auth_data, timeout=5) + if response.status_code == 200: + token = response.json()['token'] + print("โœ“ Token obtained successfully") + else: + print(f"โœ— Token request failed: {response.status_code}") + print(f" Response: {response.text}") + return False + except Exception as e: + print(f"โœ— Token request error: {e}") + return False + + headers = {"Authorization": f"Bearer {token}"} + + # 3. Test available models + print("\n3. Testing available models endpoint...") + try: + response = requests.get(f"{base_url}/api/models/available", timeout=5) + if response.status_code == 200: + models = response.json()['models'] + print("โœ“ Available models retrieved") + print(f" Available models: {list(models.keys())}") + else: + print(f"โœ— Available models failed: {response.status_code}") + except Exception as e: + print(f"โœ— Available models error: {e}") + + # 4. Test detection (with sample data) + print("\n4. Testing detection endpoint...") + try: + # Generate some sample data for detection + sample_data = np.random.randn(5, 15).tolist() # 5 samples, 15 features + + detection_data = { + "data": sample_data, + "model_type": "random_forest" + } + + response = requests.post( + f"{base_url}/api/detect", + json=detection_data, + headers=headers, + timeout=10 + ) + + if response.status_code == 200: + result = response.json() + print("โœ“ Detection successful") + print(f" Predictions: {result.get('predictions', [])}") + print(f" Model used: {result.get('model_type')}") + elif response.status_code == 404: + print("! Model not found - this is expected if no model is trained yet") + print(" Train a model first using: python scripts/train_model.py") + else: + print(f"โœ— Detection failed: {response.status_code}") + print(f" Response: {response.text}") + except Exception as e: + print(f"โœ— Detection error: {e}") + + print("\n5. API test completed!") + return True + +if __name__ == '__main__': + success = test_api() + sys.exit(0 if success else 1) \ No newline at end of file diff --git a/scripts/train_model.py b/scripts/train_model.py new file mode 100755 index 0000000..0807c0d --- /dev/null +++ b/scripts/train_model.py @@ -0,0 +1,225 @@ +#!/usr/bin/env python3 +""" +Training script for cybersecurity attack detection models. + +This script provides a command-line interface for training models +using the cybersecurity detection framework. 
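+
+Example invocation (illustrative; mirrors the README quick start):
+    python scripts/train_model.py --data data/cyber_dataset_1000s_20f_2c.csv \
+        --model random_forest --hyperparameter-tuning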
+""" + +import argparse +import sys +import os +import json +from pathlib import Path + +# Add src to Python path +sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..')) + +from src.core.preprocessing import DataPreprocessor +from src.models.detector import CyberAttackDetector, ModelRegistry +from src.utils.helpers import setup_logging, PerformanceTimer, ensure_directory +from sklearn.model_selection import train_test_split +import pandas as pd + + +def main(): + """Main training function.""" + parser = argparse.ArgumentParser( + description='Train cybersecurity attack detection models' + ) + + # Data arguments + parser.add_argument( + '--data', + required=True, + help='Path to the training data CSV file' + ) + + parser.add_argument( + '--test-size', + type=float, + default=0.2, + help='Proportion of data to use for testing (default: 0.2)' + ) + + # Model arguments + parser.add_argument( + '--model', + choices=list(ModelRegistry.get_available_models().keys()), + default='random_forest', + help='Type of model to train (default: random_forest)' + ) + + parser.add_argument( + '--hyperparameter-tuning', + action='store_true', + help='Perform hyperparameter tuning' + ) + + parser.add_argument( + '--cv-folds', + type=int, + default=5, + help='Number of cross-validation folds for hyperparameter tuning (default: 5)' + ) + + # Output arguments + parser.add_argument( + '--output-dir', + default='models', + help='Directory to save the trained model (default: models)' + ) + + parser.add_argument( + '--model-name', + help='Name for the saved model (default: auto-generated)' + ) + + # Logging arguments + parser.add_argument( + '--log-level', + choices=['DEBUG', 'INFO', 'WARNING', 'ERROR'], + default='INFO', + help='Logging level (default: INFO)' + ) + + parser.add_argument( + '--log-file', + help='Log file path (default: console only)' + ) + + args = parser.parse_args() + + # Setup logging + logger = setup_logging(args.log_level, args.log_file) + + try: + # Validate input data file + if not os.path.exists(args.data): + logger.error(f"Data file not found: {args.data}") + sys.exit(1) + + logger.info(f"Starting training with data: {args.data}") + logger.info(f"Model type: {args.model}") + logger.info(f"Test size: {args.test_size}") + logger.info(f"Hyperparameter tuning: {args.hyperparameter_tuning}") + + # Load and preprocess data + logger.info("Loading and preprocessing data...") + preprocessor = DataPreprocessor() + + with PerformanceTimer("Data preprocessing", logger): + # Load data + raw_data = pd.read_csv(args.data) + logger.info(f"Loaded data with shape: {raw_data.shape}") + + # Preprocess data + X, y, validation_results = preprocessor.full_preprocessing_pipeline(raw_data) + + if not validation_results['is_valid']: + logger.error("Data validation failed:") + for error in validation_results['errors']: + logger.error(f" - {error}") + sys.exit(1) + + if validation_results['warnings']: + for warning in validation_results['warnings']: + logger.warning(f" - {warning}") + + # Split data + logger.info("Splitting data into train/test sets...") + X_train, X_test, y_train, y_test = train_test_split( + X, y, test_size=args.test_size, random_state=42, stratify=y + ) + + logger.info(f"Training set shape: {X_train.shape}") + logger.info(f"Test set shape: {X_test.shape}") + + # Scale features + X_train_scaled, X_test_scaled = preprocessor.scale_features(X_train, X_test) + + # Initialize model + logger.info(f"Initializing {args.model} model...") + detector = CyberAttackDetector(args.model) + + # Training + 
if args.hyperparameter_tuning: + logger.info("Performing hyperparameter tuning...") + with PerformanceTimer("Hyperparameter tuning", logger): + tuning_results = detector.hyperparameter_tuning( + X_train_scaled, y_train, cv_folds=args.cv_folds + ) + + logger.info(f"Best parameters: {tuning_results['best_params']}") + logger.info(f"Best CV score: {tuning_results['best_score']:.4f}") + + else: + logger.info("Training model with default parameters...") + with PerformanceTimer("Model training", logger): + training_results = detector.train( + X_train_scaled, y_train, feature_names=X.columns.tolist() + ) + + logger.info(f"Training completed - CV accuracy: {training_results['cv_mean']:.4f}") + + # Evaluation + logger.info("Evaluating model on test set...") + with PerformanceTimer("Model evaluation", logger): + evaluation_results = detector.evaluate(X_test_scaled, y_test) + + logger.info("Evaluation Results:") + logger.info(f" Accuracy: {evaluation_results['accuracy']:.4f}") + logger.info(f" Precision: {evaluation_results['precision']:.4f}") + logger.info(f" Recall: {evaluation_results['recall']:.4f}") + logger.info(f" F1 Score: {evaluation_results['f1_score']:.4f}") + + if 'auc_score' in evaluation_results: + logger.info(f" AUC Score: {evaluation_results['auc_score']:.4f}") + + # Save model + ensure_directory(args.output_dir) + + if args.model_name: + model_filename = f"{args.model_name}.pkl" + else: + model_filename = f"{args.model}_model.pkl" + + model_path = os.path.join(args.output_dir, model_filename) + + logger.info(f"Saving model to {model_path}...") + detector.save_model(model_path) + + # Save training metadata + metadata_path = os.path.join(args.output_dir, f"{Path(model_filename).stem}_metadata.json") + + metadata = { + 'model_type': args.model, + 'data_file': args.data, + 'data_shape': raw_data.shape, + 'test_size': args.test_size, + 'hyperparameter_tuning': args.hyperparameter_tuning, + 'validation_results': validation_results, + 'evaluation_results': evaluation_results, + 'model_path': model_path + } + + if args.hyperparameter_tuning: + metadata['tuning_results'] = tuning_results + else: + metadata['training_results'] = training_results + + with open(metadata_path, 'w') as f: + json.dump(metadata, f, indent=2, default=str) + + logger.info(f"Training metadata saved to {metadata_path}") + logger.info("Training completed successfully!") + + except Exception as e: + logger.error(f"Training failed with error: {str(e)}") + import traceback + logger.debug(traceback.format_exc()) + sys.exit(1) + + +if __name__ == '__main__': + main() \ No newline at end of file diff --git a/src/__init__.py b/src/__init__.py new file mode 100644 index 0000000..031a0c1 --- /dev/null +++ b/src/__init__.py @@ -0,0 +1,8 @@ +""" +Cybersecurity Attack Detection Framework + +A modular, secure framework for detecting cyber attacks using machine learning. +""" + +__version__ = "1.0.0" +__author__ = "Cybersecurity Detection Team" \ No newline at end of file diff --git a/src/api/__init__.py b/src/api/__init__.py new file mode 100644 index 0000000..7867bf7 --- /dev/null +++ b/src/api/__init__.py @@ -0,0 +1 @@ +"""API module for cybersecurity detection framework.""" \ No newline at end of file diff --git a/src/api/app.py b/src/api/app.py new file mode 100644 index 0000000..ac164a9 --- /dev/null +++ b/src/api/app.py @@ -0,0 +1,475 @@ +""" +Flask API for cybersecurity attack detection framework. 
+ +This module provides a secure REST API for the cybersecurity detection system +with JWT authentication, rate limiting, and input validation. +""" + +import os +import jwt +import logging +from datetime import datetime, timedelta +from functools import wraps +from typing import Dict, Any, Optional + +from flask import Flask, request, jsonify, current_app +from flask_limiter import Limiter +from flask_limiter.util import get_remote_address +from flask_cors import CORS +from marshmallow import Schema, fields, ValidationError +import numpy as np +import pandas as pd + +from ..core.preprocessing import DataPreprocessor +from ..models.detector import CyberAttackDetector, ModelRegistry + +# Configure logging +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + + +class DetectionRequestSchema(Schema): + """Schema for detection API requests.""" + data = fields.List(fields.List(fields.Float()), required=True) + model_type = fields.Str(missing='random_forest', validate=lambda x: x in ModelRegistry.get_available_models()) + + +class TrainingRequestSchema(Schema): + """Schema for model training API requests.""" + file_path = fields.Str(required=True) + model_type = fields.Str(missing='random_forest', validate=lambda x: x in ModelRegistry.get_available_models()) + test_size = fields.Float(missing=0.2, validate=lambda x: 0.1 <= x <= 0.5) + hyperparameter_tuning = fields.Bool(missing=False) + + +class APIConfig: + """Configuration class for the API.""" + SECRET_KEY = os.environ.get('SECRET_KEY', 'dev-secret-key-change-in-production') + JWT_EXPIRATION_DELTA = timedelta(hours=1) + RATE_LIMIT_DEFAULT = "100/hour" + RATE_LIMIT_AUTH = "20/minute" + CORS_ORIGINS = ["http://localhost:3000", "http://127.0.0.1:3000"] + MAX_CONTENT_LENGTH = 16 * 1024 * 1024 # 16MB max file size + + +def create_app(config_class=APIConfig) -> Flask: + """ + Create and configure the Flask application. 
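+
+    Startup sketch (mirrors the __main__ block at the bottom of this file):
+        app = create_app()
+        app.run(host='0.0.0.0', port=5000)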
+ + Args: + config_class: Configuration class to use + + Returns: + Configured Flask application + """ + app = Flask(__name__) + app.config.from_object(config_class) + + # Initialize extensions + CORS(app, origins=app.config['CORS_ORIGINS']) + + limiter = Limiter( + app, + key_func=get_remote_address, + default_limits=[app.config['RATE_LIMIT_DEFAULT']] + ) + + # Store limiter in app context for use in decorators + app.limiter = limiter + + # Initialize global components + app.preprocessor = DataPreprocessor() + app.models = {} # Cache for loaded models + + # Register error handlers + register_error_handlers(app) + + # Register routes + register_routes(app) + + # Security headers middleware + @app.after_request + def after_request(response): + response.headers['X-Content-Type-Options'] = 'nosniff' + response.headers['X-Frame-Options'] = 'DENY' + response.headers['X-XSS-Protection'] = '1; mode=block' + response.headers['Strict-Transport-Security'] = 'max-age=31536000; includeSubDomains' + return response + + logger.info("Flask application created and configured") + return app + + +def register_error_handlers(app: Flask) -> None: + """Register error handlers for the application.""" + + @app.errorhandler(400) + def bad_request(error): + return jsonify({ + 'success': False, + 'error': 'Bad Request', + 'message': 'Invalid request format or parameters' + }), 400 + + @app.errorhandler(401) + def unauthorized(error): + return jsonify({ + 'success': False, + 'error': 'Unauthorized', + 'message': 'Valid authentication token required' + }), 401 + + @app.errorhandler(403) + def forbidden(error): + return jsonify({ + 'success': False, + 'error': 'Forbidden', + 'message': 'Insufficient permissions' + }), 403 + + @app.errorhandler(429) + def ratelimit_handler(error): + return jsonify({ + 'success': False, + 'error': 'Too Many Requests', + 'message': 'Rate limit exceeded. Please try again later.' 
+ }), 429 + + @app.errorhandler(500) + def internal_error(error): + logger.error(f"Internal server error: {str(error)}") + return jsonify({ + 'success': False, + 'error': 'Internal Server Error', + 'message': 'An unexpected error occurred' + }), 500 + + +def require_auth(f): + """Decorator to require JWT authentication.""" + @wraps(f) + def decorated_function(*args, **kwargs): + token = request.headers.get('Authorization') + + if not token: + return jsonify({ + 'success': False, + 'error': 'Missing token', + 'message': 'Authorization header is required' + }), 401 + + try: + # Remove 'Bearer ' prefix if present + if token.startswith('Bearer '): + token = token[7:] + + payload = jwt.decode( + token, + current_app.config['SECRET_KEY'], + algorithms=['HS256'] + ) + + # Check if token is expired + if datetime.utcnow() > datetime.fromtimestamp(payload['exp']): + return jsonify({ + 'success': False, + 'error': 'Token expired', + 'message': 'Please obtain a new token' + }), 401 + + except jwt.InvalidTokenError as e: + return jsonify({ + 'success': False, + 'error': 'Invalid token', + 'message': str(e) + }), 401 + + return f(*args, **kwargs) + + return decorated_function + + +def register_routes(app: Flask) -> None: + """Register API routes.""" + + @app.route('/api/health', methods=['GET']) + def health_check(): + """Health check endpoint.""" + return jsonify({ + 'success': True, + 'message': 'Cybersecurity Detection API is running', + 'timestamp': datetime.utcnow().isoformat(), + 'version': '1.0.0' + }) + + @app.route('/api/auth/token', methods=['POST']) + @app.limiter.limit(app.config['RATE_LIMIT_AUTH']) + def get_token(): + """Generate JWT token (simplified - in production, validate credentials).""" + try: + # In production, validate username/password here + username = request.json.get('username', 'user') + + payload = { + 'username': username, + 'exp': datetime.utcnow() + app.config['JWT_EXPIRATION_DELTA'], + 'iat': datetime.utcnow() + } + + token = jwt.encode(payload, app.config['SECRET_KEY'], algorithm='HS256') + + return jsonify({ + 'success': True, + 'token': token, + 'expires_in': app.config['JWT_EXPIRATION_DELTA'].total_seconds() + }) + + except Exception as e: + logger.error(f"Token generation error: {str(e)}") + return jsonify({ + 'success': False, + 'error': 'Token generation failed', + 'message': str(e) + }), 500 + + @app.route('/api/models/available', methods=['GET']) + def get_available_models(): + """Get list of available models.""" + try: + available_models = ModelRegistry.get_available_models() + model_info = {} + + for name, config in available_models.items(): + model_info[name] = { + 'name': name, + 'class': config['class'].__name__, + 'default_params': config['default_params'], + 'tunable_params': list(config['param_grid'].keys()) + } + + return jsonify({ + 'success': True, + 'models': model_info + }) + + except Exception as e: + logger.error(f"Error getting available models: {str(e)}") + return jsonify({ + 'success': False, + 'error': 'Failed to get available models', + 'message': str(e) + }), 500 + + @app.route('/api/detect', methods=['POST']) + @require_auth + def detect_attacks(): + """Detect cyber attacks using trained model.""" + try: + # Validate input + schema = DetectionRequestSchema() + try: + data = schema.load(request.json) + except ValidationError as err: + return jsonify({ + 'success': False, + 'error': 'Validation error', + 'message': err.messages + }), 400 + + # Get or load model + model_type = data['model_type'] + model_key = f"detector_{model_type}" + + if 
model_key not in app.models:
+                # Try to load pre-trained model
+                model_path = f"models/{model_type}_model.pkl"
+                if os.path.exists(model_path):
+                    detector = CyberAttackDetector(model_type)
+                    detector.load_model(model_path)
+                    app.models[model_key] = detector
+                else:
+                    return jsonify({
+                        'success': False,
+                        'error': 'Model not found',
+                        'message': f'No trained model available for {model_type}'
+                    }), 404
+
+            # Make predictions
+            detector = app.models[model_key]
+            input_data = np.array(data['data'])
+
+            predictions = detector.predict(input_data)
+            probabilities = detector.predict_proba(input_data)
+
+            response_data = {
+                'success': True,
+                'predictions': predictions.tolist(),
+                'model_type': model_type,
+                'num_samples': len(predictions)
+            }
+
+            if probabilities is not None:
+                response_data['probabilities'] = probabilities.tolist()
+
+            return jsonify(response_data)
+
+        except Exception as e:
+            logger.error(f"Detection error: {str(e)}")
+            return jsonify({
+                'success': False,
+                'error': 'Detection failed',
+                'message': str(e)
+            }), 500
+
+    @app.route('/api/train', methods=['POST'])
+    @require_auth
+    def train_model():
+        """Train a new model."""
+        try:
+            # Validate input
+            schema = TrainingRequestSchema()
+            try:
+                data = schema.load(request.json)
+            except ValidationError as err:
+                return jsonify({
+                    'success': False,
+                    'error': 'Validation error',
+                    'message': err.messages
+                }), 400
+
+            # Load and preprocess data
+            file_path = data['file_path']
+            if not os.path.exists(file_path):
+                return jsonify({
+                    'success': False,
+                    'error': 'File not found',
+                    'message': f'Data file not found: {file_path}'
+                }), 404
+
+            X, y, validation_results = app.preprocessor.full_preprocessing_pipeline(
+                pd.read_csv(file_path)
+            )
+
+            if not validation_results['is_valid']:
+                return jsonify({
+                    'success': False,
+                    'error': 'Data validation failed',
+                    'message': validation_results['errors']
+                }), 400
+
+            # Split data (stratified to preserve class balance, matching scripts/train_model.py)
+            from sklearn.model_selection import train_test_split
+            test_size = data['test_size']
+            X_train, X_test, y_train, y_test = train_test_split(
+                X, y, test_size=test_size, random_state=42, stratify=y
+            )
+
+            # Scale features
+            X_train_scaled, X_test_scaled = app.preprocessor.scale_features(X_train, X_test)
+
+            # Initialize and train model
+            model_type = data['model_type']
+            detector = CyberAttackDetector(model_type)
+
+            # Perform hyperparameter tuning if requested
+            if data['hyperparameter_tuning']:
+                tuning_results = detector.hyperparameter_tuning(X_train_scaled, y_train)
+            else:
+                training_results = detector.train(X_train_scaled, y_train, X.columns.tolist())
+
+            # Evaluate model
+            evaluation_results = detector.evaluate(X_test_scaled, y_test)
+
+            # Save model
+            model_path = f"models/{model_type}_model.pkl"
+            os.makedirs('models', exist_ok=True)
+            detector.save_model(model_path)
+
+            # Cache model
+            app.models[f"detector_{model_type}"] = detector
+
+            response_data = {
+                'success': True,
+                'model_type': model_type,
+                'model_path': model_path,
+                'evaluation': evaluation_results,
+                'validation': validation_results
+            }
+
+            if data['hyperparameter_tuning']:
+                response_data['hyperparameter_tuning'] = tuning_results
+            else:
+                response_data['training'] = training_results
+
+            return jsonify(response_data)
+
+        except Exception as e:
+            logger.error(f"Training error: {str(e)}")
+            return jsonify({
+                'success': False,
+                'error': 'Training failed',
+                'message': str(e)
+            }), 500
+
+    @app.route('/api/models/<model_type>/evaluate', methods=['POST'])
+    @require_auth
+    def evaluate_model(model_type: str):
+        """Evaluate a trained model on test data."""
+        
try: + if model_type not in ModelRegistry.get_available_models(): + return jsonify({ + 'success': False, + 'error': 'Invalid model type', + 'message': f'Model type {model_type} not supported' + }), 400 + + # Get test data file path from request + file_path = request.json.get('file_path') + if not file_path or not os.path.exists(file_path): + return jsonify({ + 'success': False, + 'error': 'File not found', + 'message': 'Test data file not found' + }), 404 + + # Load model + model_key = f"detector_{model_type}" + if model_key not in app.models: + model_path = f"models/{model_type}_model.pkl" + if not os.path.exists(model_path): + return jsonify({ + 'success': False, + 'error': 'Model not found', + 'message': f'No trained model available for {model_type}' + }), 404 + + detector = CyberAttackDetector(model_type) + detector.load_model(model_path) + app.models[model_key] = detector + + # Preprocess test data + X, y, _ = app.preprocessor.full_preprocessing_pipeline(pd.read_csv(file_path)) + X_scaled, _ = app.preprocessor.scale_features(X) + + # Evaluate model + detector = app.models[model_key] + evaluation_results = detector.evaluate(X_scaled, y) + + return jsonify({ + 'success': True, + 'model_type': model_type, + 'evaluation': evaluation_results + }) + + except Exception as e: + logger.error(f"Evaluation error: {str(e)}") + return jsonify({ + 'success': False, + 'error': 'Evaluation failed', + 'message': str(e) + }), 500 + + +if __name__ == '__main__': + app = create_app() + app.run(debug=False, host='0.0.0.0', port=5000) \ No newline at end of file diff --git a/src/core/__init__.py b/src/core/__init__.py new file mode 100644 index 0000000..db85cfa --- /dev/null +++ b/src/core/__init__.py @@ -0,0 +1 @@ +"""Core detection and preprocessing modules.""" \ No newline at end of file diff --git a/src/core/preprocessing.py b/src/core/preprocessing.py new file mode 100644 index 0000000..66bb469 --- /dev/null +++ b/src/core/preprocessing.py @@ -0,0 +1,246 @@ +""" +Data preprocessing module for cybersecurity detection. + +This module provides functions for loading, cleaning, and preprocessing +cybersecurity datasets for machine learning model training and inference. +""" + +import pandas as pd +import numpy as np +from sklearn.preprocessing import StandardScaler, LabelEncoder +from typing import Tuple, Optional, Dict, Any +import logging + +# Configure logging +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + + +class DataPreprocessor: + """Class for preprocessing cybersecurity datasets.""" + + def __init__(self): + """Initialize the preprocessor with default settings.""" + self.scaler = StandardScaler() + self.label_encoder = LabelEncoder() + self.feature_columns = None + self.is_fitted = False + + def load_data(self, file_path: str) -> Optional[pd.DataFrame]: + """ + Load data from a CSV file with error handling. + + Args: + file_path: Path to the CSV file + + Returns: + DataFrame or None if loading fails + """ + try: + data = pd.read_csv(file_path) + logger.info(f"Successfully loaded data with shape: {data.shape}") + return data + except FileNotFoundError: + logger.error(f"File not found: {file_path}") + return None + except pd.errors.EmptyDataError: + logger.error("File is empty") + return None + except Exception as e: + logger.error(f"Error loading data: {str(e)}") + return None + + def validate_data(self, data: pd.DataFrame) -> Dict[str, Any]: + """ + Validate input data and provide summary statistics. 
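+
+        Example (illustrative):
+            results = preprocessor.validate_data(df)
+            if not results['is_valid']:
+                print(results['errors'])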
+ + Args: + data: Input DataFrame + + Returns: + Dictionary with validation results + """ + validation_results = { + 'is_valid': True, + 'errors': [], + 'warnings': [], + 'shape': data.shape, + 'columns': list(data.columns), + 'dtypes': data.dtypes.to_dict(), + 'missing_values': data.isnull().sum().to_dict(), + 'memory_usage': data.memory_usage(deep=True).sum() + } + + # Check for completely empty dataset + if data.empty: + validation_results['is_valid'] = False + validation_results['errors'].append("Dataset is empty") + + # Check for required columns (assuming 'label' is required) + if 'label' not in data.columns: + validation_results['warnings'].append( + "No 'label' column found - assuming unsupervised learning" + ) + + # Check for excessive missing values + missing_percentage = (data.isnull().sum() / len(data)) * 100 + high_missing_cols = missing_percentage[missing_percentage > 50].index.tolist() + if high_missing_cols: + validation_results['warnings'].append( + f"Columns with >50% missing values: {high_missing_cols}" + ) + + return validation_results + + def clean_data(self, data: pd.DataFrame) -> pd.DataFrame: + """ + Clean the dataset by handling missing values and outliers. + + Args: + data: Input DataFrame + + Returns: + Cleaned DataFrame + """ + cleaned_data = data.copy() + + # Handle missing values + # For numeric columns, fill with median + numeric_cols = cleaned_data.select_dtypes(include=[np.number]).columns + for col in numeric_cols: + if cleaned_data[col].isnull().any(): + median_value = cleaned_data[col].median() + cleaned_data.loc[:, col] = cleaned_data[col].fillna(median_value) + logger.info(f"Filled missing values in {col} with median: {median_value}") + + # For categorical columns, fill with mode + categorical_cols = cleaned_data.select_dtypes(include=['object']).columns + for col in categorical_cols: + if cleaned_data[col].isnull().any(): + mode_value = cleaned_data[col].mode().iloc[0] if not cleaned_data[col].mode().empty else 'unknown' + cleaned_data.loc[:, col] = cleaned_data[col].fillna(mode_value) + logger.info(f"Filled missing values in {col} with mode: {mode_value}") + + return cleaned_data + + def encode_categorical_features(self, data: pd.DataFrame) -> pd.DataFrame: + """ + Encode categorical features using one-hot encoding. + + Args: + data: Input DataFrame + + Returns: + DataFrame with encoded categorical features + """ + encoded_data = data.copy() + + # Get categorical columns (excluding the label column if present) + categorical_cols = encoded_data.select_dtypes(include=['object']).columns + categorical_cols = [col for col in categorical_cols if col != 'label'] + + if categorical_cols: + # Use pandas get_dummies for one-hot encoding + encoded_data = pd.get_dummies( + encoded_data, + columns=categorical_cols, + prefix=categorical_cols, + drop_first=True + ) + logger.info(f"One-hot encoded columns: {categorical_cols}") + + return encoded_data + + def prepare_features_labels(self, data: pd.DataFrame) -> Tuple[pd.DataFrame, Optional[pd.Series]]: + """ + Separate features and labels from the dataset. 
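+
+        Example (illustrative):
+            X, y = preprocessor.prepare_features_labels(df)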
+ + Args: + data: Input DataFrame + + Returns: + Tuple of (features DataFrame, labels Series or None) + """ + if 'label' in data.columns: + X = data.drop('label', axis=1) + y = data['label'] + logger.info(f"Separated features (shape: {X.shape}) and labels (unique: {y.nunique()})") + return X, y + else: + logger.warning("No label column found - returning all columns as features") + return data, None + + def scale_features(self, X_train: pd.DataFrame, X_test: Optional[pd.DataFrame] = None) -> Tuple[np.ndarray, Optional[np.ndarray]]: + """ + Scale features using StandardScaler. + + Args: + X_train: Training features + X_test: Testing features (optional) + + Returns: + Tuple of scaled training and testing features + """ + # Fit scaler on training data + X_train_scaled = self.scaler.fit_transform(X_train) + self.feature_columns = X_train.columns.tolist() + self.is_fitted = True + + logger.info("Features scaled using StandardScaler") + + if X_test is not None: + X_test_scaled = self.scaler.transform(X_test) + return X_train_scaled, X_test_scaled + else: + return X_train_scaled, None + + def full_preprocessing_pipeline(self, data: pd.DataFrame) -> Tuple[pd.DataFrame, Optional[pd.Series], Dict[str, Any]]: + """ + Complete preprocessing pipeline. + + Args: + data: Input raw DataFrame + + Returns: + Tuple of (processed features, labels, validation results) + """ + # Validate data + validation_results = self.validate_data(data) + + if not validation_results['is_valid']: + logger.error("Data validation failed") + return None, None, validation_results + + # Clean data + cleaned_data = self.clean_data(data) + + # Encode categorical features + encoded_data = self.encode_categorical_features(cleaned_data) + + # Separate features and labels + X, y = self.prepare_features_labels(encoded_data) + + logger.info("Full preprocessing pipeline completed successfully") + + return X, y, validation_results + + +def load_and_preprocess(file_path: str) -> Tuple[Optional[pd.DataFrame], Optional[pd.Series], Dict[str, Any]]: + """ + Convenience function to load and preprocess data in one step. + + Args: + file_path: Path to the CSV file + + Returns: + Tuple of (features, labels, validation results) + """ + preprocessor = DataPreprocessor() + + # Load data + data = preprocessor.load_data(file_path) + if data is None: + return None, None, {'is_valid': False, 'errors': ['Failed to load data']} + + # Run full preprocessing pipeline + return preprocessor.full_preprocessing_pipeline(data) \ No newline at end of file diff --git a/src/models/__init__.py b/src/models/__init__.py new file mode 100644 index 0000000..8536b7e --- /dev/null +++ b/src/models/__init__.py @@ -0,0 +1 @@ +"""Machine learning models for cybersecurity detection.""" \ No newline at end of file diff --git a/src/models/detector.py b/src/models/detector.py new file mode 100644 index 0000000..c4061e7 --- /dev/null +++ b/src/models/detector.py @@ -0,0 +1,473 @@ +""" +Machine learning models for cybersecurity attack detection. + +This module provides classes and functions for training, evaluating, +and managing machine learning models for cybersecurity threat detection. 
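+
+Typical usage (an illustrative sketch):
+    detector = CyberAttackDetector('random_forest')
+    detector.train(X_train, y_train)
+    predictions = detector.predict(X_test)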
+""" + +import pickle +import json +import os +from typing import Dict, Any, Tuple, Optional, List +from datetime import datetime +import logging + +import numpy as np +import pandas as pd +from sklearn.ensemble import RandomForestClassifier +from sklearn.linear_model import LogisticRegression +from sklearn.svm import SVC +from sklearn.neural_network import MLPClassifier +from sklearn.model_selection import cross_val_score, GridSearchCV +from sklearn.metrics import ( + accuracy_score, + precision_score, + recall_score, + f1_score, + confusion_matrix, + classification_report, + roc_auc_score +) + +# Configure logging +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + + +class ModelRegistry: + """Registry for available machine learning models.""" + + @staticmethod + def get_available_models() -> Dict[str, Dict[str, Any]]: + """ + Get dictionary of available models with their configurations. + + Returns: + Dictionary mapping model names to their configurations + """ + return { + 'random_forest': { + 'class': RandomForestClassifier, + 'default_params': { + 'n_estimators': 100, + 'random_state': 42, + 'max_depth': 10, + 'min_samples_split': 5 + }, + 'param_grid': { + 'n_estimators': [50, 100, 200], + 'max_depth': [5, 10, 15, None], + 'min_samples_split': [2, 5, 10] + } + }, + 'logistic_regression': { + 'class': LogisticRegression, + 'default_params': { + 'random_state': 42, + 'max_iter': 1000, + 'C': 1.0 + }, + 'param_grid': { + 'C': [0.1, 1.0, 10.0], + 'penalty': ['l1', 'l2'], + 'solver': ['liblinear'] + } + }, + 'svm': { + 'class': SVC, + 'default_params': { + 'random_state': 42, + 'probability': True, + 'C': 1.0, + 'kernel': 'rbf' + }, + 'param_grid': { + 'C': [0.1, 1.0, 10.0], + 'kernel': ['rbf', 'linear'], + 'gamma': ['scale', 'auto'] + } + }, + 'neural_network': { + 'class': MLPClassifier, + 'default_params': { + 'random_state': 42, + 'max_iter': 1000, + 'hidden_layer_sizes': (100,), + 'alpha': 0.0001 + }, + 'param_grid': { + 'hidden_layer_sizes': [(50,), (100,), (100, 50)], + 'alpha': [0.0001, 0.001, 0.01], + 'learning_rate': ['constant', 'adaptive'] + } + } + } + + +class CyberAttackDetector: + """Main class for cybersecurity attack detection models.""" + + def __init__(self, model_type: str = 'random_forest'): + """ + Initialize the detector with a specific model type. + + Args: + model_type: Type of model to use ('random_forest', 'logistic_regression', 'svm', 'neural_network') + """ + self.model_type = model_type + self.model = None + self.model_config = None + self.training_history = {} + self.feature_names = None + self.is_trained = False + + # Get model configuration + available_models = ModelRegistry.get_available_models() + if model_type not in available_models: + raise ValueError(f"Model type '{model_type}' not supported. " + f"Available models: {list(available_models.keys())}") + + self.model_config = available_models[model_type] + self.model = self.model_config['class'](**self.model_config['default_params']) + + logger.info(f"Initialized {model_type} detector") + + def train(self, X_train: np.ndarray, y_train: np.ndarray, + feature_names: Optional[List[str]] = None) -> Dict[str, Any]: + """ + Train the model on the provided dataset. 
+ + Args: + X_train: Training features + y_train: Training labels + feature_names: Names of features (optional) + + Returns: + Dictionary with training results + """ + logger.info(f"Training {self.model_type} model...") + + # Store feature names + self.feature_names = feature_names or [f"feature_{i}" for i in range(X_train.shape[1])] + + # Train the model + start_time = datetime.now() + self.model.fit(X_train, y_train) + training_time = (datetime.now() - start_time).total_seconds() + + # Perform cross-validation + cv_scores = cross_val_score(self.model, X_train, y_train, cv=5, scoring='accuracy') + + # Store training history + self.training_history = { + 'model_type': self.model_type, + 'training_time': training_time, + 'training_samples': X_train.shape[0], + 'features_count': X_train.shape[1], + 'cv_scores': cv_scores.tolist(), + 'cv_mean': cv_scores.mean(), + 'cv_std': cv_scores.std(), + 'timestamp': datetime.now().isoformat() + } + + self.is_trained = True + + logger.info(f"Training completed in {training_time:.2f} seconds") + logger.info(f"Cross-validation accuracy: {cv_scores.mean():.4f} (+/- {cv_scores.std() * 2:.4f})") + + return self.training_history + + def predict(self, X: np.ndarray) -> np.ndarray: + """ + Make predictions on new data. + + Args: + X: Features to predict + + Returns: + Predicted labels + """ + if not self.is_trained: + raise ValueError("Model must be trained before making predictions") + + return self.model.predict(X) + + def predict_proba(self, X: np.ndarray) -> Optional[np.ndarray]: + """ + Get prediction probabilities if supported by the model. + + Args: + X: Features to predict + + Returns: + Prediction probabilities or None if not supported + """ + if not self.is_trained: + raise ValueError("Model must be trained before making predictions") + + if hasattr(self.model, 'predict_proba'): + return self.model.predict_proba(X) + else: + logger.warning(f"Model {self.model_type} does not support probability predictions") + return None + + def evaluate(self, X_test: np.ndarray, y_test: np.ndarray) -> Dict[str, Any]: + """ + Evaluate the model performance on test data. 
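+
+        Example (illustrative):
+            metrics = detector.evaluate(X_test, y_test)
+            print(metrics['accuracy'], metrics['f1_score'])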
+ + Args: + X_test: Test features + y_test: Test labels + + Returns: + Dictionary with evaluation metrics + """ + if not self.is_trained: + raise ValueError("Model must be trained before evaluation") + + # Make predictions + y_pred = self.predict(X_test) + y_proba = self.predict_proba(X_test) + + # Calculate metrics + metrics = { + 'accuracy': accuracy_score(y_test, y_pred), + 'precision': precision_score(y_test, y_pred, average='weighted', zero_division=0), + 'recall': recall_score(y_test, y_pred, average='weighted', zero_division=0), + 'f1_score': f1_score(y_test, y_pred, average='weighted', zero_division=0), + 'confusion_matrix': confusion_matrix(y_test, y_pred).tolist(), + 'classification_report': classification_report(y_test, y_pred, zero_division=0) + } + + # Add AUC score if probabilities are available + if y_proba is not None and len(np.unique(y_test)) == 2: + metrics['auc_score'] = roc_auc_score(y_test, y_proba[:, 1]) + + # Get feature importance if available + if hasattr(self.model, 'feature_importances_'): + feature_importance = pd.Series( + self.model.feature_importances_, + index=self.feature_names + ).sort_values(ascending=False) + metrics['feature_importance'] = feature_importance.to_dict() + + logger.info(f"Model evaluation completed - Accuracy: {metrics['accuracy']:.4f}") + + return metrics + + def hyperparameter_tuning(self, X_train: np.ndarray, y_train: np.ndarray, + cv_folds: int = 5) -> Dict[str, Any]: + """ + Perform hyperparameter tuning using GridSearchCV. + + Args: + X_train: Training features + y_train: Training labels + cv_folds: Number of cross-validation folds + + Returns: + Dictionary with tuning results + """ + logger.info(f"Starting hyperparameter tuning for {self.model_type}...") + + # Get parameter grid + param_grid = self.model_config['param_grid'] + + # Initialize fresh model + base_model = self.model_config['class']() + + # Perform grid search + grid_search = GridSearchCV( + base_model, + param_grid, + cv=cv_folds, + scoring='accuracy', + n_jobs=-1, + verbose=1 + ) + + start_time = datetime.now() + grid_search.fit(X_train, y_train) + tuning_time = (datetime.now() - start_time).total_seconds() + + # Update model with best parameters + self.model = grid_search.best_estimator_ + self.is_trained = True + + # Store tuning results + tuning_results = { + 'best_params': grid_search.best_params_, + 'best_score': grid_search.best_score_, + 'tuning_time': tuning_time, + 'n_combinations': len(grid_search.cv_results_['params']), + 'timestamp': datetime.now().isoformat() + } + + logger.info(f"Hyperparameter tuning completed in {tuning_time:.2f} seconds") + logger.info(f"Best parameters: {grid_search.best_params_}") + logger.info(f"Best cross-validation score: {grid_search.best_score_:.4f}") + + return tuning_results + + def save_model(self, file_path: str) -> None: + """ + Save the trained model to disk. + + Args: + file_path: Path to save the model + """ + if not self.is_trained: + raise ValueError("Model must be trained before saving") + + # Create directory if it doesn't exist + os.makedirs(os.path.dirname(file_path), exist_ok=True) + + # Save model and metadata + model_data = { + 'model': self.model, + 'model_type': self.model_type, + 'feature_names': self.feature_names, + 'training_history': self.training_history + } + + with open(file_path, 'wb') as f: + pickle.dump(model_data, f) + + logger.info(f"Model saved to {file_path}") + + def load_model(self, file_path: str) -> None: + """ + Load a trained model from disk. 
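+
+        Example (illustrative; the path follows the save convention in scripts/train_model.py):
+            detector = CyberAttackDetector('random_forest')
+            detector.load_model('models/random_forest_model.pkl')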
+ + Args: + file_path: Path to the saved model + """ + try: + with open(file_path, 'rb') as f: + model_data = pickle.load(f) + + self.model = model_data['model'] + self.model_type = model_data['model_type'] + self.feature_names = model_data['feature_names'] + self.training_history = model_data['training_history'] + self.is_trained = True + + logger.info(f"Model loaded from {file_path}") + + except FileNotFoundError: + logger.error(f"Model file not found: {file_path}") + raise + except Exception as e: + logger.error(f"Error loading model: {str(e)}") + raise + + +class ModelComparer: + """Class for comparing multiple models.""" + + def __init__(self): + """Initialize the model comparer.""" + self.models = {} + self.comparison_results = {} + + def add_model(self, name: str, model_type: str) -> None: + """ + Add a model to the comparison. + + Args: + name: Unique name for the model + model_type: Type of model ('random_forest', 'logistic_regression', etc.) + """ + self.models[name] = CyberAttackDetector(model_type) + + def compare_models(self, X_train: np.ndarray, y_train: np.ndarray, + X_test: np.ndarray, y_test: np.ndarray, + feature_names: Optional[List[str]] = None) -> Dict[str, Any]: + """ + Compare all added models on the same dataset. + + Args: + X_train: Training features + y_train: Training labels + X_test: Test features + y_test: Test labels + feature_names: Names of features + + Returns: + Dictionary with comparison results + """ + logger.info(f"Comparing {len(self.models)} models...") + + self.comparison_results = {} + + for name, model in self.models.items(): + logger.info(f"Training and evaluating {name}...") + + # Train model + training_results = model.train(X_train, y_train, feature_names) + + # Evaluate model + evaluation_results = model.evaluate(X_test, y_test) + + # Store results + self.comparison_results[name] = { + 'training': training_results, + 'evaluation': evaluation_results + } + + # Create comparison summary + summary = self._create_comparison_summary() + + logger.info("Model comparison completed") + + return { + 'detailed_results': self.comparison_results, + 'summary': summary + } + + def _create_comparison_summary(self) -> Dict[str, Any]: + """Create a summary of the model comparison.""" + if not self.comparison_results: + return {} + + summary = { + 'best_accuracy': {'model': '', 'score': 0}, + 'best_precision': {'model': '', 'score': 0}, + 'best_recall': {'model': '', 'score': 0}, + 'best_f1': {'model': '', 'score': 0}, + 'fastest_training': {'model': '', 'time': float('inf')}, + 'accuracy_ranking': [] + } + + for name, results in self.comparison_results.items(): + eval_results = results['evaluation'] + train_results = results['training'] + + # Check best accuracy + if eval_results['accuracy'] > summary['best_accuracy']['score']: + summary['best_accuracy'] = {'model': name, 'score': eval_results['accuracy']} + + # Check best precision + if eval_results['precision'] > summary['best_precision']['score']: + summary['best_precision'] = {'model': name, 'score': eval_results['precision']} + + # Check best recall + if eval_results['recall'] > summary['best_recall']['score']: + summary['best_recall'] = {'model': name, 'score': eval_results['recall']} + + # Check best F1 score + if eval_results['f1_score'] > summary['best_f1']['score']: + summary['best_f1'] = {'model': name, 'score': eval_results['f1_score']} + + # Check fastest training + if train_results['training_time'] < summary['fastest_training']['time']: + summary['fastest_training'] = {'model': name, 'time': 
train_results['training_time']} + + # Create accuracy ranking + summary['accuracy_ranking'] = sorted( + [(name, results['evaluation']['accuracy']) for name, results in self.comparison_results.items()], + key=lambda x: x[1], + reverse=True + ) + + return summary \ No newline at end of file diff --git a/src/utils/__init__.py b/src/utils/__init__.py new file mode 100644 index 0000000..b8795e3 --- /dev/null +++ b/src/utils/__init__.py @@ -0,0 +1 @@ +"""Utility functions and helpers.""" \ No newline at end of file diff --git a/src/utils/helpers.py b/src/utils/helpers.py new file mode 100644 index 0000000..8e2ebee --- /dev/null +++ b/src/utils/helpers.py @@ -0,0 +1,301 @@ +""" +Utility functions for the cybersecurity detection framework. +""" + +import os +import json +import logging +from typing import Dict, Any, Optional, Union +from pathlib import Path +import pandas as pd +import numpy as np + + +def setup_logging(log_level: str = 'INFO', log_file: Optional[str] = None) -> logging.Logger: + """ + Set up logging configuration. + + Args: + log_level: Logging level ('DEBUG', 'INFO', 'WARNING', 'ERROR') + log_file: Optional log file path + + Returns: + Configured logger instance + """ + # Create logs directory if it doesn't exist + if log_file: + log_dir = Path(log_file).parent + log_dir.mkdir(parents=True, exist_ok=True) + + # Configure logging + logging.basicConfig( + level=getattr(logging, log_level.upper()), + format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', + handlers=[ + logging.StreamHandler(), + logging.FileHandler(log_file) if log_file else logging.NullHandler() + ] + ) + + return logging.getLogger(__name__) + + +def ensure_directory(directory_path: str) -> None: + """ + Ensure a directory exists, creating it if necessary. + + Args: + directory_path: Path to the directory + """ + Path(directory_path).mkdir(parents=True, exist_ok=True) + + +def load_json_config(config_path: str) -> Dict[str, Any]: + """ + Load configuration from a JSON file. + + Args: + config_path: Path to the JSON configuration file + + Returns: + Dictionary with configuration values + + Raises: + FileNotFoundError: If the configuration file doesn't exist + json.JSONDecodeError: If the file contains invalid JSON + """ + try: + with open(config_path, 'r') as f: + return json.load(f) + except FileNotFoundError: + raise FileNotFoundError(f"Configuration file not found: {config_path}") + except json.JSONDecodeError as e: + raise json.JSONDecodeError(f"Invalid JSON in configuration file: {config_path}", e.doc, e.pos) + + +def save_json_config(config: Dict[str, Any], config_path: str) -> None: + """ + Save configuration to a JSON file. + + Args: + config: Dictionary with configuration values + config_path: Path to save the configuration file + """ + ensure_directory(os.path.dirname(config_path)) + + with open(config_path, 'w') as f: + json.dump(config, f, indent=2, default=str) + + +def validate_file_extension(file_path: str, allowed_extensions: set) -> bool: + """ + Validate if a file has an allowed extension. + + Args: + file_path: Path to the file + allowed_extensions: Set of allowed file extensions (without dots) + + Returns: + True if the extension is allowed, False otherwise + """ + file_extension = Path(file_path).suffix.lower().lstrip('.') + return file_extension in allowed_extensions + + +def get_file_size_mb(file_path: str) -> float: + """ + Get the size of a file in megabytes. 
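+
+    Example (illustrative; the path is hypothetical):
+        size_mb = get_file_size_mb('data/dataset.csv')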
+ + Args: + file_path: Path to the file + + Returns: + File size in megabytes + """ + return os.path.getsize(file_path) / (1024 * 1024) + + +def sanitize_filename(filename: str) -> str: + """ + Sanitize a filename to remove potentially dangerous characters. + + Args: + filename: Original filename + + Returns: + Sanitized filename + """ + # Remove path separators and other potentially dangerous characters + dangerous_chars = ['/', '\\', '..', '<', '>', ':', '"', '|', '?', '*'] + sanitized = filename + + for char in dangerous_chars: + sanitized = sanitized.replace(char, '_') + + # Remove leading/trailing whitespace and dots + sanitized = sanitized.strip().strip('.') + + # Ensure filename is not empty + if not sanitized: + sanitized = 'unnamed_file' + + return sanitized + + +def convert_numpy_types(obj: Any) -> Any: + """ + Convert numpy types to Python native types for JSON serialization. + + Args: + obj: Object that may contain numpy types + + Returns: + Object with numpy types converted to Python types + """ + if isinstance(obj, np.integer): + return int(obj) + elif isinstance(obj, np.floating): + return float(obj) + elif isinstance(obj, np.ndarray): + return obj.tolist() + elif isinstance(obj, dict): + return {key: convert_numpy_types(value) for key, value in obj.items()} + elif isinstance(obj, list): + return [convert_numpy_types(item) for item in obj] + else: + return obj + + +def format_memory_usage(bytes_used: int) -> str: + """ + Format memory usage in human-readable form. + + Args: + bytes_used: Memory usage in bytes + + Returns: + Formatted memory usage string + """ + for unit in ['B', 'KB', 'MB', 'GB']: + if bytes_used < 1024.0: + return f"{bytes_used:.1f} {unit}" + bytes_used /= 1024.0 + return f"{bytes_used:.1f} TB" + + +def create_sample_dataset(n_samples: int = 1000, n_features: int = 20, + n_classes: int = 2, noise: float = 0.1, + random_state: int = 42) -> pd.DataFrame: + """ + Create a sample cybersecurity dataset for testing. + + Args: + n_samples: Number of samples to generate + n_features: Number of features + n_classes: Number of classes (attack types) + noise: Amount of noise to add + random_state: Random seed for reproducibility + + Returns: + Generated DataFrame with features and labels + """ + from sklearn.datasets import make_classification + + X, y = make_classification( + n_samples=n_samples, + n_features=n_features, + n_classes=n_classes, + n_redundant=max(0, n_features // 4), + n_informative=max(2, n_features // 2), # Ensure at least 2 informative features + random_state=random_state, + flip_y=noise + ) + + # Create feature names + feature_names = [f'feature_{i:02d}' for i in range(n_features)] + + # Create DataFrame + df = pd.DataFrame(X, columns=feature_names) + + # Add some categorical features for realism + df['protocol_type'] = np.random.choice(['tcp', 'udp', 'icmp'], size=n_samples) + df['service'] = np.random.choice(['http', 'ftp', 'ssh', 'smtp'], size=n_samples) + + # Add labels + label_mapping = {0: 'normal', 1: 'attack'} if n_classes == 2 else {i: f'class_{i}' for i in range(n_classes)} + df['label'] = [label_mapping[label] for label in y] + + return df + + +def calculate_dataset_statistics(df: pd.DataFrame) -> Dict[str, Any]: + """ + Calculate comprehensive statistics for a dataset. 
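+
+    Example (illustrative):
+        stats = calculate_dataset_statistics(df)
+        print(stats['shape'], stats['duplicate_rows'])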
+ + Args: + df: Input DataFrame + + Returns: + Dictionary with dataset statistics + """ + stats = { + 'shape': df.shape, + 'memory_usage_mb': df.memory_usage(deep=True).sum() / (1024 * 1024), + 'dtypes': df.dtypes.value_counts().to_dict(), + 'missing_values': df.isnull().sum().sum(), + 'duplicate_rows': df.duplicated().sum(), + } + + # Numeric statistics + numeric_cols = df.select_dtypes(include=[np.number]).columns + if len(numeric_cols) > 0: + numeric_stats = df[numeric_cols].describe() + stats['numeric_summary'] = numeric_stats.to_dict() + + # Categorical statistics + categorical_cols = df.select_dtypes(include=['object']).columns + if len(categorical_cols) > 0: + categorical_stats = {} + for col in categorical_cols: + categorical_stats[col] = { + 'unique_values': df[col].nunique(), + 'top_values': df[col].value_counts().head(5).to_dict() + } + stats['categorical_summary'] = categorical_stats + + return stats + + +class PerformanceTimer: + """Context manager for timing operations.""" + + def __init__(self, operation_name: str, logger: Optional[logging.Logger] = None): + """ + Initialize the timer. + + Args: + operation_name: Name of the operation being timed + logger: Optional logger to use for output + """ + self.operation_name = operation_name + self.logger = logger or logging.getLogger(__name__) + self.start_time = None + + def __enter__(self): + """Start timing.""" + import time + self.start_time = time.time() + self.logger.info(f"Starting {self.operation_name}...") + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + """End timing and log results.""" + import time + elapsed_time = time.time() - self.start_time + + if exc_type is None: + self.logger.info(f"Completed {self.operation_name} in {elapsed_time:.2f} seconds") + else: + self.logger.error(f"Failed {self.operation_name} after {elapsed_time:.2f} seconds") + + return False # Don't suppress exceptions \ No newline at end of file diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..2d80752 --- /dev/null +++ b/tests/__init__.py @@ -0,0 +1 @@ +"""Test suite for the cybersecurity detection framework.""" \ No newline at end of file diff --git a/tests/test_framework.py b/tests/test_framework.py new file mode 100644 index 0000000..d4e0a28 --- /dev/null +++ b/tests/test_framework.py @@ -0,0 +1,354 @@ +""" +Test suite for the cybersecurity detection framework. 
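+
+Run the suite with pytest, e.g.: pytest tests/test_framework.py -v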
+""" + +import pytest +import pandas as pd +import numpy as np +from unittest.mock import patch, MagicMock +import tempfile +import os + +from src.core.preprocessing import DataPreprocessor, load_and_preprocess +from src.models.detector import CyberAttackDetector, ModelRegistry, ModelComparer +from src.utils.helpers import ( + ensure_directory, + validate_file_extension, + sanitize_filename, + create_sample_dataset, + calculate_dataset_statistics +) + + +class TestDataPreprocessor: + """Test cases for the DataPreprocessor class.""" + + def setup_method(self): + """Set up test fixtures.""" + self.preprocessor = DataPreprocessor() + self.sample_data = create_sample_dataset(n_samples=100, n_features=5) + + def test_initialization(self): + """Test DataPreprocessor initialization.""" + assert self.preprocessor.scaler is not None + assert self.preprocessor.label_encoder is not None + assert self.preprocessor.is_fitted is False + + def test_validate_data_valid(self): + """Test data validation with valid data.""" + validation_results = self.preprocessor.validate_data(self.sample_data) + + assert validation_results['is_valid'] is True + assert len(validation_results['errors']) == 0 + assert validation_results['shape'] == self.sample_data.shape + assert 'label' in validation_results['columns'] + + def test_validate_data_empty(self): + """Test data validation with empty data.""" + empty_data = pd.DataFrame() + validation_results = self.preprocessor.validate_data(empty_data) + + assert validation_results['is_valid'] is False + assert 'Dataset is empty' in validation_results['errors'] + + def test_clean_data(self): + """Test data cleaning functionality.""" + # Create data with missing values + data_with_missing = self.sample_data.copy() + data_with_missing.loc[0:4, 'feature_01'] = np.nan + data_with_missing.loc[0:2, 'protocol_type'] = np.nan + + cleaned_data = self.preprocessor.clean_data(data_with_missing) + + assert cleaned_data.isnull().sum().sum() == 0 # No missing values + + def test_encode_categorical_features(self): + """Test categorical feature encoding.""" + encoded_data = self.preprocessor.encode_categorical_features(self.sample_data) + + # Check that categorical columns are encoded + assert 'protocol_type' not in encoded_data.columns + assert any('protocol_type_' in col for col in encoded_data.columns) + + def test_prepare_features_labels(self): + """Test feature and label separation.""" + X, y = self.preprocessor.prepare_features_labels(self.sample_data) + + assert 'label' not in X.columns + assert y is not None + assert len(X) == len(y) + + def test_scale_features(self): + """Test feature scaling.""" + X, y = self.preprocessor.prepare_features_labels(self.sample_data) + X_numeric = X.select_dtypes(include=[np.number]) + + X_scaled, _ = self.preprocessor.scale_features(X_numeric) + + assert self.preprocessor.is_fitted is True + assert X_scaled.shape == X_numeric.shape + # Check that scaled data has zero mean and unit variance (approximately) + assert np.allclose(np.mean(X_scaled, axis=0), 0, atol=1e-7) + assert np.allclose(np.std(X_scaled, axis=0), 1, atol=1e-7) + + def test_full_preprocessing_pipeline(self): + """Test the complete preprocessing pipeline.""" + X, y, validation_results = self.preprocessor.full_preprocessing_pipeline(self.sample_data) + + assert validation_results['is_valid'] is True + assert X is not None + assert y is not None + assert len(X) == len(y) + assert 'label' not in X.columns + + +class TestCyberAttackDetector: + """Test cases for the CyberAttackDetector class.""" 
+
+
+class TestCyberAttackDetector:
+    """Test cases for the CyberAttackDetector class."""
+
+    def setup_method(self):
+        """Set up test fixtures."""
+        self.detector = CyberAttackDetector('random_forest')
+        self.sample_data = create_sample_dataset(n_samples=100, n_features=10)
+
+        # Prepare training data
+        preprocessor = DataPreprocessor()
+        X, y, _ = preprocessor.full_preprocessing_pipeline(self.sample_data)
+        X_scaled, _ = preprocessor.scale_features(X.select_dtypes(include=[np.number]))
+
+        self.X_train = X_scaled
+        self.y_train = y
+
+    def test_initialization(self):
+        """Test detector initialization."""
+        assert self.detector.model_type == 'random_forest'
+        assert self.detector.model is not None
+        assert self.detector.is_trained is False
+
+    def test_initialization_invalid_model(self):
+        """Test that an invalid model type raises ValueError."""
+        with pytest.raises(ValueError):
+            CyberAttackDetector('invalid_model')
+
+    def test_train(self):
+        """Test model training."""
+        training_results = self.detector.train(self.X_train, self.y_train)
+
+        assert self.detector.is_trained is True
+        assert 'training_time' in training_results
+        assert 'cv_scores' in training_results
+        assert training_results['training_samples'] == len(self.X_train)
+
+    def test_predict_untrained(self):
+        """Test that predicting with an untrained model raises ValueError."""
+        with pytest.raises(ValueError):
+            self.detector.predict(self.X_train)
+
+    def test_predict_trained(self):
+        """Test prediction with a trained model."""
+        self.detector.train(self.X_train, self.y_train)
+        predictions = self.detector.predict(self.X_train)
+
+        assert len(predictions) == len(self.X_train)
+        assert all(pred in self.y_train.unique() for pred in predictions)
+
+    def test_evaluate(self):
+        """Test model evaluation."""
+        self.detector.train(self.X_train, self.y_train)
+        evaluation_results = self.detector.evaluate(self.X_train, self.y_train)
+
+        assert 'accuracy' in evaluation_results
+        assert 'precision' in evaluation_results
+        assert 'recall' in evaluation_results
+        assert 'f1_score' in evaluation_results
+        assert 0 <= evaluation_results['accuracy'] <= 1
+
+    def test_save_load_model(self):
+        """Test model saving and loading."""
+        self.detector.train(self.X_train, self.y_train)
+
+        with tempfile.NamedTemporaryFile(suffix='.pkl', delete=False) as tmp_file:
+            model_path = tmp_file.name
+
+        try:
+            # Save model
+            self.detector.save_model(model_path)
+            assert os.path.exists(model_path)
+
+            # Create new detector and load model
+            new_detector = CyberAttackDetector('random_forest')
+            new_detector.load_model(model_path)
+
+            assert new_detector.is_trained is True
+            assert new_detector.model_type == self.detector.model_type
+
+            # Predictions from the loaded model should match the original
+            original_pred = self.detector.predict(self.X_train)
+            loaded_pred = new_detector.predict(self.X_train)
+            assert np.array_equal(original_pred, loaded_pred)
+
+        finally:
+            if os.path.exists(model_path):
+                os.unlink(model_path)
+
+
+class TestModelRegistry:
+    """Test cases for the ModelRegistry class."""
+
+    def test_get_available_models(self):
+        """Test getting available models."""
+        models = ModelRegistry.get_available_models()
+
+        assert isinstance(models, dict)
+        assert len(models) > 0
+        assert 'random_forest' in models
+        assert 'logistic_regression' in models
+
+        for model_name, config in models.items():
+            assert 'class' in config
+            assert 'default_params' in config
+            assert 'param_grid' in config
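+
+
+# Registry usage sketch (assumes, as the assertions above suggest, that
+# config['class'] is an estimator class and config['default_params'] its
+# constructor kwargs):
+#
+#     config = ModelRegistry.get_available_models()['random_forest']
+#     model = config['class'](**config['default_params'])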
+
+
+class TestModelComparer:
+    """Test cases for the ModelComparer class."""
+
+    def setup_method(self):
+        """Set up test fixtures."""
+        self.comparer = ModelComparer()
+        # Small dataset keeps these tests fast
+        self.sample_data = create_sample_dataset(n_samples=50, n_features=5)
+
+        # Prepare data
+        preprocessor = DataPreprocessor()
+        X, y, _ = preprocessor.full_preprocessing_pipeline(self.sample_data)
+        X_scaled, _ = preprocessor.scale_features(X.select_dtypes(include=[np.number]))
+
+        from sklearn.model_selection import train_test_split
+        self.X_train, self.X_test, self.y_train, self.y_test = train_test_split(
+            X_scaled, y, test_size=0.3, random_state=42
+        )
+
+    def test_add_model(self):
+        """Test adding models to the comparer."""
+        self.comparer.add_model('rf', 'random_forest')
+        self.comparer.add_model('lr', 'logistic_regression')
+
+        assert len(self.comparer.models) == 2
+        assert 'rf' in self.comparer.models
+        assert 'lr' in self.comparer.models
+
+    def test_compare_models(self):
+        """Test model comparison."""
+        self.comparer.add_model('rf', 'random_forest')
+        self.comparer.add_model('lr', 'logistic_regression')
+
+        comparison_results = self.comparer.compare_models(
+            self.X_train, self.y_train, self.X_test, self.y_test
+        )
+
+        assert 'detailed_results' in comparison_results
+        assert 'summary' in comparison_results
+
+        detailed_results = comparison_results['detailed_results']
+        assert 'rf' in detailed_results
+        assert 'lr' in detailed_results
+
+        for model_name, results in detailed_results.items():
+            assert 'training' in results
+            assert 'evaluation' in results
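+
+
+# Comparison usage sketch (mirrors the calls exercised above):
+#
+#     comparer = ModelComparer()
+#     comparer.add_model('rf', 'random_forest')
+#     comparer.add_model('lr', 'logistic_regression')
+#     results = comparer.compare_models(X_train, y_train, X_test, y_test)
+#     print(results['summary'])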
+
+
+class TestUtilities:
+    """Test cases for utility functions."""
+
+    def test_ensure_directory(self):
+        """Test directory creation."""
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            test_path = os.path.join(tmp_dir, 'test', 'nested', 'directory')
+            ensure_directory(test_path)
+            assert os.path.exists(test_path)
+
+    def test_validate_file_extension(self):
+        """Test file extension validation."""
+        allowed_extensions = {'csv', 'json', 'txt'}
+
+        assert validate_file_extension('data.csv', allowed_extensions) is True
+        assert validate_file_extension('config.json', allowed_extensions) is True
+        assert validate_file_extension('readme.txt', allowed_extensions) is True
+        assert validate_file_extension('image.png', allowed_extensions) is False
+        assert validate_file_extension('data.CSV', allowed_extensions) is True  # Case-insensitive
+
+    def test_sanitize_filename(self):
+        """Test filename sanitization."""
+        dangerous_filename = '../../../etc/passwd'
+        sanitized = sanitize_filename(dangerous_filename)
+
+        assert '/' not in sanitized
+        assert '..' not in sanitized
+        assert len(sanitized) > 0
+
+    def test_create_sample_dataset(self):
+        """Test sample dataset creation."""
+        dataset = create_sample_dataset(n_samples=100, n_features=5)
+
+        assert len(dataset) == 100
+        assert 'label' in dataset.columns
+        assert 'protocol_type' in dataset.columns
+        assert 'service' in dataset.columns
+
+        # Check feature columns
+        feature_cols = [col for col in dataset.columns if col.startswith('feature_')]
+        assert len(feature_cols) == 5
+
+    def test_calculate_dataset_statistics(self):
+        """Test dataset statistics calculation."""
+        dataset = create_sample_dataset(n_samples=50, n_features=5)
+        stats = calculate_dataset_statistics(dataset)
+
+        assert 'shape' in stats
+        assert 'memory_usage_mb' in stats
+        assert 'dtypes' in stats
+        assert 'missing_values' in stats
+        assert 'duplicate_rows' in stats
+
+        assert stats['shape'] == dataset.shape
+        assert stats['missing_values'] == 0  # Sample dataset has no missing values
+
+
+@pytest.fixture
+def sample_csv_file():
+    """Create a temporary CSV file for testing."""
+    dataset = create_sample_dataset(n_samples=50, n_features=3)
+
+    # Close the handle before writing so the path can be reopened on Windows
+    tmp_file = tempfile.NamedTemporaryFile(mode='w', suffix='.csv', delete=False)
+    tmp_file.close()
+    dataset.to_csv(tmp_file.name, index=False)
+
+    yield tmp_file.name
+
+    os.unlink(tmp_file.name)
+
+
+class TestIntegration:
+    """Integration tests for the complete framework."""
+
+    def test_end_to_end_pipeline(self, sample_csv_file):
+        """Test the complete end-to-end pipeline."""
+        # Load and preprocess data
+        X, y, validation_results = load_and_preprocess(sample_csv_file)
+
+        assert validation_results['is_valid'] is True
+        assert X is not None
+        assert y is not None
+
+        # Train model
+        detector = CyberAttackDetector('random_forest')
+        X_scaled, _ = DataPreprocessor().scale_features(X.select_dtypes(include=[np.number]))
+
+        training_results = detector.train(X_scaled, y)
+        assert detector.is_trained is True
+
+        # Make predictions
+        predictions = detector.predict(X_scaled)
+        assert len(predictions) == len(X_scaled)
+
+        # Evaluate model
+        evaluation_results = detector.evaluate(X_scaled, y)
+        assert 'accuracy' in evaluation_results
\ No newline at end of file