diff --git a/README.md b/README.md
index d4a4e49..00fc92b 100644
--- a/README.md
+++ b/README.md
@@ -28,13 +28,15 @@ To complete these labs, you will need:
Add the following string before each each Jupyter files (incluidng path)
```
-https://colab.research.google.com/github.com/frankwxu/AI4DigitalForensics/blob/main/
+https://colab.research.google.com/github/frankwxu/AI4DigitalForensics/blob/main/
```
- Lab 1: [Hate speed detection](https://colab.research.google.com/github/frankwxu/AI4DigitalForensics/blob/main/lab01_Hate_speech_detection/social_media_threat_detection.ipynb)
- Lab 2: [Gun detection](https://colab.research.google.com/github/frankwxu/AI4DigitalForensics/blob/main/lab02_Gun_detection_fasterRCNN/gun_detection_fasterRCNN.ipynb)
+- Lab 10: [Reinforcement Learning](https://colab.research.google.com/github/frankwxu/AI4DigitalForensics/blob/main/lab10_Reinforcement_Learning/dqn_lunar_lander_demo.ipynb)
+
## Contributing
We welcome contributions from students, faculty, and researchers! Feel free to:
diff --git a/lab10_Reinforcement_Learning/dqn_lunar_lander_demo.ipynb b/lab10_Reinforcement_Learning/dqn_lunar_lander_demo.ipynb
index 8423c68..bc358cb 100644
--- a/lab10_Reinforcement_Learning/dqn_lunar_lander_demo.ipynb
+++ b/lab10_Reinforcement_Learning/dqn_lunar_lander_demo.ipynb
@@ -2,7 +2,7 @@
"cells": [
{
"cell_type": "code",
- "execution_count": 17,
+ "execution_count": null,
"metadata": {},
"outputs": [],
"source": [
@@ -11,7 +11,11 @@
"# pip install swig\n",
"# pip install box2d\n",
"# pip install gymnasium\n",
- "# pip install requests"
+ "# pip install requests\n",
+ "# pip install gymnasium[other]\n",
+ "\n",
+ "# Step 1: Install Dependencies\n",
+ "!pip install torch torchvision matplotlib opencv-python pycocotools py7zr requests pygame swig box2d gymnasium gymnasium[other] imageio"
]
},
{
@@ -1029,7 +1033,7 @@
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 32,
"metadata": {},
"outputs": [
{
@@ -1046,7 +1050,7 @@
"Loaded model from models_xu/dqn_model_episode_1600.pth\n",
"\n",
"Starting demo...\n",
- "Score obtained: 226.3022059875887\n"
+ "Score obtained: 216.69617981903787\n"
]
}
],
@@ -1069,6 +1073,124 @@
" score = demo_dqn_model(model_path, render=True, delay=0.000)"
]
},
+ {
+ "cell_type": "code",
+ "execution_count": 45,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Episode finished. Total reward: 146.24241059984874\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ ""
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "\n",
+ "def show_video(env_name):\n",
+ " \"\"\"\n",
+ " Display a saved video from the 'video' folder in Colab.\n",
+ " \n",
+ " Parameters\n",
+ " ----------\n",
+ " env_name: str\n",
+ " Name of the environment (used to find the video file).\n",
+ " \"\"\"\n",
+ " mp4list = glob.glob(f'video/{env_name}-episode-*.mp4')\n",
+ " if len(mp4list) > 0:\n",
+ " mp4 = mp4list[0] # Take the first matching file\n",
+ " video = io.open(mp4, 'r+b').read()\n",
+ " encoded = base64.b64encode(video)\n",
+      "        display(HTML(data='''<video alt=\"demo\" autoplay loop controls style=\"height: 400px;\"><source src=\"data:video/mp4;base64,{0}\" type=\"video/mp4\" /></video>'''.format(encoded.decode('ascii'))))\n",
+ " else:\n",
+ " print(\"Could not find video\")\n",
+ "\n",
+ "def show_video_of_model(agent, env_name, model_path):\n",
+ " \"\"\"\n",
+ " Play and record a video of an episode using the given agent and trained model in the specified environment.\n",
+ " \n",
+ " Parameters\n",
+ " ----------\n",
+ " agent: DQNAgent\n",
+ " The agent instance to use (untrained initially).\n",
+ " env_name: str\n",
+ " Name of the environment (e.g., 'LunarLander-v3').\n",
+ " model_path: str\n",
+ " Path to the saved trained model file (e.g., 'models_xu/dqn_model_episode_1600.pth').\n",
+ " \"\"\"\n",
+ " # Create environment with rgb_array rendering\n",
+ " env = gym.make(env_name, render_mode=\"rgb_array\")\n",
+ " \n",
+ " # Wrap with RecordVideo to save the episode\n",
+ " env = RecordVideo(env, video_folder=\"video\", name_prefix=env_name, episode_trigger=lambda x: True)\n",
+ "\n",
+ " # Load the trained model into the agent\n",
+ " try:\n",
+ " agent.qnetwork_local.load_state_dict(torch.load(model_path))\n",
+ " agent.qnetwork_local.eval()\n",
+ " except AttributeError:\n",
+ " # Fallback to q_network if qnetwork_local doesn't exist\n",
+ " agent.q_network.load_state_dict(torch.load(model_path))\n",
+ " agent.q_network.eval()\n",
+ "\n",
+ " # Play one episode\n",
+ " state, _ = env.reset()\n",
+ " total_reward = 0\n",
+ " done = False\n",
+ " while not done:\n",
+ " action = agent.act(state, eps=0) # Greedy policy\n",
+ " state, reward, terminated, truncated, _ = env.step(action)\n",
+ " total_reward += reward\n",
+ " done = terminated or truncated\n",
+ "\n",
+ " print(f\"Episode finished. Total reward: {total_reward}\")\n",
+ " env.close()\n",
+ "\n",
+ " # Display the video\n",
+ " show_video(env_name)\n",
+ "\n",
+ "# Example usage in Colab\n",
+ "if __name__ == \"__main__\":\n",
+ " # Initialize agent\n",
+ " state_size = 8 # LunarLander-v3 state size\n",
+ " action_size = 4 # LunarLander-v3 action size\n",
+ " agent = DQNAgent(state_size=state_size, action_size=action_size)\n",
+ "\n",
+ " # Create video directory if it doesn't exist\n",
+ " import os\n",
+ " os.makedirs(\"video\", exist_ok=True)\n",
+ "\n",
+ " # Specify the trained model path\n",
+ " model_path = 'models_xu/dqn_model_episode_1700.pth' # Pass your desired model here\n",
+ "\n",
+ " # Run the demo with the specified model\n",
+ " show_video_of_model(agent, 'LunarLander-v3', model_path)\n",
+ "\n",
+ " # Optional: Download the video\n",
+ " # from google.colab import files\n",
+      "    # video_file = glob.glob('video/LunarLander-v3-episode-*.mp4')[0]  # Match the generated file\n",
+ " # files.download(video_file)"
+ ]
+ },
{
"cell_type": "code",
"execution_count": null,