This project implements a Snake game with an AI agent that learns to play the game using reinforcement learning. The AI is trained using a Deep Q-Network (DQN) approach.

    ├── game.py
    ├── model.pth
    ├── requirements.txt
    ├── training_stats.txt
    └── README.md

- `game.py`: Contains the main code for the Snake game and the AI agent.
- `model.pth`: The saved model weights for the AI agent.
- `training_stats.txt`: Contains training statistics such as the number of games played and the highest score achieved.
- `requirements.txt`: Lists the required Python packages.
- `README.md`: This file.

Requirements:
- Python 3.x
- Pygame
- NumPy
- PyTorch
- Matplotlib

To set up the project:

- Clone the repository:

      git clone https://github.com/pranavms13/snake-dqn
      cd snake-dqn

- Install the required packages:

      pip install -r requirements.txt

To train the AI agent, run:

    python game.py

The training process will start, and the AI agent will learn to play the Snake game. The training statistics and model weights will be saved periodically.

The main classes and functions in `game.py`:

- `Direction`: Enum class for the direction of the snake.
- `Point`: Class representing a point on the game board.
- `SnakeGame`: Class representing the Snake game.
  - `__init__(self, w=640, h=480)`: Initializes the game.
  - `reset(self)`: Resets the game.
  - `play_step(self, action)`: Executes one game step based on the action.
  - `is_collision(self, pt=None)`: Checks for collisions.
  - `_update_ui(self)`: Updates the game UI.
  - `_move(self, action)`: Moves the snake based on the action.
  - `_place_fruit(self)`: Places a fruit on the game board.
- `Agent`: Class representing the AI agent.
  - `__init__(self)`: Initializes the agent.
  - `get_state(self, game)`: Gets the current state of the game.
  - `remember(self, state, action, reward, next_state, game_over)`: Stores the experience in memory.
  - `train_long_memory(self)`: Trains the agent using long-term memory.
  - `train_short_memory(self, state, action, reward, next_state, game_over)`: Trains the agent using short-term memory.
  - `get_action(self, state)`: Chooses the action the agent takes in the given state.
- `Linear_QNet`: Class representing the neural network for the Q-learning agent.
  - `__init__(self, input_size, hidden_size, output_size)`: Initializes the neural network.
  - `forward(self, x)`: Forward pass of the neural network.
  - `save(self, file_name='model.pth')`: Saves the model weights.
  - `load(self, file_name='model.pth')`: Loads the model weights.
- `QTrainer`: Class for training the Q-learning agent.
  - `__init__(self, model, lr, gamma)`: Initializes the trainer.
  - `train_step(self, state, action, reward, next_state, game_over)`: Performs a training step.
- `train()`: Main function to train the AI agent.
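
The game code itself is not reproduced in this README. As a rough illustration of how `Direction`, `Point`, `_move` and `is_collision` can fit together, here is a minimal stand-alone sketch. It assumes, without confirmation from the project, that an action is a three-element one-hot list `[straight, right turn, left turn]`, that the board is divided into 20-pixel blocks, and that positions use screen coordinates (y grows downward). The names `turn`, `move`, `CLOCKWISE` and `BLOCK_SIZE` are hypothetical.

```python
from collections import namedtuple
from enum import Enum


class Direction(Enum):
    RIGHT = 1
    LEFT = 2
    UP = 3
    DOWN = 4


Point = namedtuple("Point", "x, y")

# Hypothetical grid cell size in pixels; the real game may use another value.
BLOCK_SIZE = 20

# Clockwise order: a right turn advances one step in this list.
CLOCKWISE = [Direction.RIGHT, Direction.DOWN, Direction.LEFT, Direction.UP]


def turn(direction, action):
    """Map an assumed one-hot action [straight, right, left] to a new direction."""
    idx = CLOCKWISE.index(direction)
    if action == [1, 0, 0]:
        return direction  # keep going straight
    if action == [0, 1, 0]:
        return CLOCKWISE[(idx + 1) % 4]  # right turn (clockwise)
    return CLOCKWISE[(idx - 1) % 4]  # left turn (counter-clockwise)


def move(head, direction):
    """Return the next head position one block along `direction`."""
    dx, dy = {
        Direction.RIGHT: (BLOCK_SIZE, 0),
        Direction.LEFT: (-BLOCK_SIZE, 0),
        Direction.UP: (0, -BLOCK_SIZE),  # screen y grows downward
        Direction.DOWN: (0, BLOCK_SIZE),
    }[direction]
    return Point(head.x + dx, head.y + dy)


def is_collision(pt, snake, w=640, h=480):
    """True if `pt` leaves the w x h board or hits the snake's body (snake[0] is the head)."""
    out_of_bounds = (
        pt.x < 0 or pt.x > w - BLOCK_SIZE or pt.y < 0 or pt.y > h - BLOCK_SIZE
    )
    return out_of_bounds or pt in snake[1:]
```

For example, a snake heading `RIGHT` that receives `[0, 1, 0]` turns to `DOWN`, and `move(Point(100, 100), Direction.DOWN)` gives `Point(100, 120)`. Encoding turns relative to the current heading keeps the action space at three moves and makes an immediate 180-degree reversal impossible.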
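
The `Agent` methods `remember`, `train_long_memory` and `get_action` describe a common DQN pattern: an experience-replay buffer combined with epsilon-greedy exploration. The sketch below illustrates that pattern only; it is not the project's code. `ReplayAgent`, `sample_batch`, `MAX_MEMORY`, `BATCH_SIZE` and the epsilon schedule are hypothetical, and the Q-values are passed in as a plain list rather than produced by `Linear_QNet`.

```python
import random
from collections import deque

# Hypothetical hyperparameters; the project's actual values may differ.
MAX_MEMORY = 100_000
BATCH_SIZE = 1000


class ReplayAgent:
    """Sketch of the memory and action-selection parts of a DQN agent."""

    def __init__(self, n_actions=3, seed=None):
        self.n_games = 0
        self.n_actions = n_actions
        # Once maxlen is reached, appending drops the oldest experience.
        self.memory = deque(maxlen=MAX_MEMORY)
        self.rng = random.Random(seed)

    def remember(self, state, action, reward, next_state, game_over):
        self.memory.append((state, action, reward, next_state, game_over))

    def sample_batch(self):
        """Random mini-batch for long-term (experience replay) training."""
        if len(self.memory) > BATCH_SIZE:
            return self.rng.sample(list(self.memory), BATCH_SIZE)
        return list(self.memory)

    def get_action(self, q_values):
        """Epsilon-greedy choice: explore in early games, exploit later.

        `q_values` stands in for the network's output for the current state.
        """
        epsilon = max(0, 80 - self.n_games)  # hypothetical linear decay
        action = [0] * self.n_actions
        if self.rng.randint(0, 200) < epsilon:
            idx = self.rng.randrange(self.n_actions)  # random exploratory move
        else:
            # Greedy move: index of the largest predicted Q-value.
            idx = max(range(self.n_actions), key=lambda i: q_values[i])
        action[idx] = 1
        return action
```

Sampling random mini-batches from the buffer breaks the correlation between consecutive game frames, while the per-step short-memory update lets the agent react to the most recent transition immediately. With this schedule, after 80 games epsilon reaches zero and the agent always picks the highest-valued action.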
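
`QTrainer.train_step` is described only as performing a training step. In a standard DQN update, the target for the action that was taken is the reward plus the discounted best Q-value of the next state, `Q_new = reward + gamma * max(Q(next_state))`, or just `reward` when the game has ended; the other actions keep their predicted values so they contribute no loss. The plain-Python sketch below shows that computation for one transition. A real implementation would presumably do this on PyTorch tensors produced by `Linear_QNet`; `q_targets`, `mse` and the default `gamma=0.9` are assumptions.

```python
def q_targets(pred, action, reward, next_q, game_over, gamma=0.9):
    """Bellman target for a single transition.

    pred:    Q-values predicted for `state`, one per action
    action:  one-hot list marking the action that was taken
    next_q:  Q-values predicted for `next_state`
    Returns a copy of `pred` where only the taken action's entry is
    replaced by the target value.
    """
    target = list(pred)
    if game_over:
        q_new = reward  # no future reward after a terminal state
    else:
        q_new = reward + gamma * max(next_q)
    target[action.index(1)] = q_new
    return target


def mse(pred, target):
    """Mean squared error between prediction and target."""
    return sum((p - t) ** 2 for p, t in zip(pred, target)) / len(pred)
```

For instance, `q_targets([0.5, 0.2, 0.1], [1, 0, 0], 1, [1, 2, 3], False, gamma=0.5)` returns `[2.5, 0.2, 0.1]`: only the first entry changes, to `1 + 0.5 * 3`. Minimizing `mse(pred, target)` then pulls the prediction for the chosen action toward the target.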
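
The README does not spell out the order in which `train()` calls the game and agent methods. The sketch below shows one plausible loop for a DQN of this shape. It assumes, without confirmation, that `play_step` returns `(reward, game_over, score)` and that the agent keeps an `n_games` counter. Unlike the project's `train()`, this version takes its dependencies and an `n_games` limit as arguments, and the `ToyGame` and `ToyAgent` stubs are hypothetical stand-ins that let the loop run without Pygame or PyTorch.

```python
def train(agent, game, n_games):
    """Hypothetical training loop; returns the best score seen."""
    record = 0
    while agent.n_games < n_games:
        state_old = agent.get_state(game)
        action = agent.get_action(state_old)
        reward, game_over, score = game.play_step(action)  # assumed return shape
        state_new = agent.get_state(game)

        # Learn from this single step, then store it for replay.
        agent.train_short_memory(state_old, action, reward, state_new, game_over)
        agent.remember(state_old, action, reward, state_new, game_over)

        if game_over:
            game.reset()
            agent.n_games += 1
            agent.train_long_memory()  # replay a batch of past experiences
            record = max(record, score)
    return record


class ToyGame:
    """Stand-in game: every episode ends after 3 steps with score 3."""

    def __init__(self):
        self.steps = 0

    def reset(self):
        self.steps = 0

    def play_step(self, action):
        self.steps += 1
        return 0, self.steps >= 3, self.steps


class ToyAgent:
    """Stand-in agent that only counts calls."""

    def __init__(self):
        self.n_games = 0
        self.short_updates = 0
        self.long_updates = 0
        self.memory = []

    def get_state(self, game):
        return [game.steps]

    def get_action(self, state):
        return [1, 0, 0]

    def remember(self, *transition):
        self.memory.append(transition)

    def train_short_memory(self, *transition):
        self.short_updates += 1

    def train_long_memory(self):
        self.long_updates += 1
```

With the stubs, `train(ToyAgent(), ToyGame(), n_games=4)` plays four 3-step games and returns a record of 3. A real loop would also save the model weights and update the training statistics, which the README says happens periodically.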
This project is licensed under the MIT License.