A Python implementation of the classic Snake game powered by a Deep Q-Network (DQN). This project combines PyGame for visualization with PyTorch for deep reinforcement learning, producing an agent that learns to play Snake from its own experience.
- Classic Snake game implementation using PyGame
- Deep Q-Network (DQN) built with PyTorch
- Experience replay for stable training
- Model save/load functionality
- Real-time visualization of training
- Configurable hyperparameters
- Support for both CPU and CUDA training
- pygame
- numpy
- torch
- Clone the repository:
git clone https://github.com/Abhigyan126/Snake-DQN
cd Snake-DQN
- Install dependencies:
pip install pygame numpy torch
Run the main script to start training:
python train.py
You'll be presented with three options:
- Start new training
- Continue training existing model
- Exit
The training process will display:
- Real-time game visualization
- Episode progress
- Total rewards
- Current exploration rate (epsilon)
Training pipeline
flowchart LR
subgraph Game["Snake Game Environment"]
State["State\n(12 inputs)"] --> Action
Action --> Reward
Action --> NextState["Next State"]
end
subgraph Agent["DQN Agent"]
Policy["Policy Network\n(5 layers)"]
Target["Target Network\n(5 layers)"]
Memory["Replay Buffer\n(10000 capacity)"]
Policy --> |"Select Action"| Action
State --> Policy
Memory --> |"Batch (64)"| Policy
Policy --> |"Update"| Target
end
State --> Memory
Action --> Memory
Reward --> Memory
NextState --> Memory
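The update loop in the flowchart (store transitions, sample a batch of 64, train the policy network, copy its weights into the target network) can be sketched in PyTorch roughly as follows. This is a minimal illustration, not the project's `train.py`: the names `dqn_update` and `sync_target`, the MSE loss, the Adam optimizer, the hard (full-copy) target update, and the one-hidden-layer stand-in network are all assumptions. Only the batch size, buffer capacity, learning rate, and discount factor come from this README.

```python
import random
from collections import deque
from typing import Optional

import torch
import torch.nn as nn

STATE_DIM, N_ACTIONS = 12, 4
GAMMA, BATCH_SIZE, LR = 0.99, 64, 0.001


def make_net() -> nn.Module:
    # Hypothetical small stand-in for the README's five-layer network
    return nn.Sequential(nn.Linear(STATE_DIM, 64), nn.LeakyReLU(), nn.Linear(64, N_ACTIONS))


policy_net = make_net()
target_net = make_net()
target_net.load_state_dict(policy_net.state_dict())  # start from identical weights
optimizer = torch.optim.Adam(policy_net.parameters(), lr=LR)  # optimizer choice is assumed
memory = deque(maxlen=10_000)  # transitions: (state, action, reward, next_state, done)


def dqn_update() -> Optional[float]:
    """One gradient step on a random minibatch; returns the loss, or None if memory is too small."""
    if len(memory) < BATCH_SIZE:
        return None
    batch = random.sample(memory, BATCH_SIZE)
    states, actions, rewards, next_states, dones = zip(*batch)
    states = torch.tensor(states, dtype=torch.float32)
    actions = torch.tensor(actions, dtype=torch.int64).unsqueeze(1)
    rewards = torch.tensor(rewards, dtype=torch.float32)
    next_states = torch.tensor(next_states, dtype=torch.float32)
    dones = torch.tensor([float(d) for d in dones], dtype=torch.float32)

    # Q(s, a) for the actions that were actually taken
    q_taken = policy_net(states).gather(1, actions).squeeze(1)
    # Bootstrapped target r + gamma * max_a' Q_target(s', a'), with no bootstrap after a terminal step
    with torch.no_grad():
        next_q = target_net(next_states).max(dim=1).values
    target = rewards + GAMMA * next_q * (1.0 - dones)

    loss = nn.functional.mse_loss(q_taken, target)  # loss choice is assumed
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()


def sync_target() -> None:
    # The "Update" arrow: copy policy weights into the target network (hard update assumed)
    target_net.load_state_dict(policy_net.state_dict())
```

Keeping a separate, slowly refreshed target network means the regression target does not shift on every gradient step, which is the usual reason DQN training uses two networks.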
- Input layer: 12 nodes (state space)
- Hidden layers: 64 → 64 → 128 → 64 nodes
- Output layer: 4 nodes (action space)
- Activation function: Leaky ReLU
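The layer sizes above map directly onto a small PyTorch module. A minimal sketch, assuming Leaky ReLU (PyTorch's default negative slope of 0.01) after every hidden layer and no activation on the output, so the four outputs are raw Q-values; the class name `DQN` is hypothetical.

```python
import torch
import torch.nn as nn


class DQN(nn.Module):
    """MLP 12 -> 64 -> 64 -> 128 -> 64 -> 4 with Leaky ReLU between layers (sketch)."""

    def __init__(self, state_dim: int = 12, n_actions: int = 4) -> None:
        super().__init__()
        sizes = [state_dim, 64, 64, 128, 64]
        layers = []
        for n_in, n_out in zip(sizes, sizes[1:]):
            layers += [nn.Linear(n_in, n_out), nn.LeakyReLU()]
        # Output layer: one Q-value per action, left linear (assumed)
        layers.append(nn.Linear(sizes[-1], n_actions))
        self.net = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x)


model = DQN()
q_values = model(torch.zeros(1, 12))  # shape (1, 4)
```

Counting the `nn.Linear` modules gives five, which matches the "5 layers" label in the flowchart; under these sizes the network has 21,828 trainable parameters.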
The game state consists of 12 binary values:
- Danger detection (4 directions)
- Current direction (4 possibilities)
- Food location relative to snake (4 directions)
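The README lists the twelve features but not how they are computed. A minimal sketch, assuming a grid of `(x, y)` cells with `y` growing downward, directions indexed up/right/down/left, "danger" meaning a wall or body segment in the adjacent cell in each absolute direction, and food flags that are set independently (food up and to the right sets two flags). `get_state` and its arguments are hypothetical names, not the project's API.

```python
from typing import List, Sequence, Tuple

Point = Tuple[int, int]
# Hypothetical absolute direction order: up, right, down, left (y grows downward)
DIRECTIONS: List[Point] = [(0, -1), (1, 0), (0, 1), (-1, 0)]


def get_state(snake: Sequence[Point], direction: int, food: Point,
              width: int, height: int) -> List[int]:
    """Return 12 binary features: 4 danger flags, 4 one-hot heading flags, 4 food flags."""
    head_x, head_y = snake[0]
    body = set(snake[1:])  # simplification: treats the tail cell as blocked too

    def is_danger(p: Point) -> bool:
        x, y = p
        return not (0 <= x < width and 0 <= y < height) or p in body

    danger = [int(is_danger((head_x + dx, head_y + dy))) for dx, dy in DIRECTIONS]
    heading = [int(i == direction) for i in range(4)]
    food_x, food_y = food
    food_flags = [
        int(food_y < head_y),  # food is above the head
        int(food_x > head_x),  # food is to the right
        int(food_y > head_y),  # food is below
        int(food_x < head_x),  # food is to the left
    ]
    return danger + heading + food_flags


# Head in the top-left corner, moving left, food down and to the right
state = get_state([(0, 0), (1, 0)], direction=3, food=(3, 3), width=5, height=5)
# -> [1, 1, 0, 1, 0, 0, 0, 1, 0, 1, 1, 0]
```

Because every feature is binary and relative to the head, the same 12-value vector works on any board size, which keeps the network's input layer fixed.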
- Replay buffer size: 10,000
- Batch size: 64
- Learning rate: 0.001
- Discount factor (gamma): 0.99
- Initial epsilon: 1.0
- Minimum epsilon: 0.01
- Epsilon decay: 0.995
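The replay buffer and exploration settings above can be sketched with the standard library alone. This assumes uniform random sampling from the buffer, epsilon-greedy action selection, and multiplicative epsilon decay applied once per episode and clamped at the minimum; the README does not say how often decay is applied, and `ReplayBuffer`, `select_action`, and `decay_epsilon` are hypothetical names.

```python
import math
import random
from collections import deque
from typing import Any, List, Sequence, Tuple

BUFFER_SIZE = 10_000
BATCH_SIZE = 64
EPS_START, EPS_MIN, EPS_DECAY = 1.0, 0.01, 0.995


class ReplayBuffer:
    """Fixed-capacity FIFO of transitions; the oldest are discarded once full."""

    def __init__(self, capacity: int = BUFFER_SIZE) -> None:
        self.memory = deque(maxlen=capacity)

    def push(self, state, action, reward, next_state, done) -> None:
        self.memory.append((state, action, reward, next_state, done))

    def sample(self, batch_size: int = BATCH_SIZE) -> List[Tuple[Any, ...]]:
        # Uniform sampling breaks the correlation between consecutive game steps
        return random.sample(self.memory, batch_size)

    def __len__(self) -> int:
        return len(self.memory)


def select_action(q_values: Sequence[float], epsilon: float) -> int:
    """Epsilon-greedy: a random action with probability epsilon, otherwise argmax Q."""
    if random.random() < epsilon:
        return random.randrange(len(q_values))
    return max(range(len(q_values)), key=lambda a: q_values[a])


def decay_epsilon(epsilon: float) -> float:
    # Assumed to be called once per episode; never drops below EPS_MIN
    return max(EPS_MIN, epsilon * EPS_DECAY)


# Episodes until exploration reaches its floor under this schedule
episodes_to_min = math.ceil(math.log(EPS_MIN / EPS_START) / math.log(EPS_DECAY))
```

Under the per-episode assumption, epsilon falls from 1.0 to 0.01 after about 919 episodes (0.995^919 ≈ 0.00999); if the project decays per step instead, exploration ends far sooner.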
Snake-DQN/
├── train.py # Main game and training logic
├── test.py # Test script to run the trained model
├── README.md # Project documentation
└── snake_dqn_model.pth # Saved model checkpoints
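The README does not document what `snake_dqn_model.pth` contains. One common layout that would support "Continue training existing model" is a dictionary of `state_dict`s plus the current epsilon, sketched below with `torch.save` and `torch.load`. The dictionary keys and the helpers `save_checkpoint` and `load_checkpoint` are hypothetical, not the project's format.

```python
import torch
import torch.nn as nn

CHECKPOINT_PATH = "snake_dqn_model.pth"


def save_checkpoint(model: nn.Module, optimizer: torch.optim.Optimizer,
                    epsilon: float, path: str = CHECKPOINT_PATH) -> None:
    # Save state_dicts rather than whole objects so loading does not depend on pickled class code
    torch.save({
        "model": model.state_dict(),
        "optimizer": optimizer.state_dict(),
        "epsilon": epsilon,  # hypothetical key: lets resumed training keep its exploration rate
    }, path)


def load_checkpoint(model: nn.Module, optimizer: torch.optim.Optimizer,
                    path: str = CHECKPOINT_PATH, device: str = "cpu") -> float:
    # map_location lets a checkpoint written during CUDA training be loaded on a CPU-only machine
    ckpt = torch.load(path, map_location=device)
    model.load_state_dict(ckpt["model"])
    optimizer.load_state_dict(ckpt["optimizer"])
    return ckpt["epsilon"]
```

Restoring the optimizer state alongside the weights keeps optimizers such as Adam from restarting their running averages when training resumes.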