diff --git a/README.md b/README.md index d4a4e49..00fc92b 100644 --- a/README.md +++ b/README.md @@ -28,13 +28,15 @@ To complete these labs, you will need: Add the following string before each Jupyter file (including path) ``` -https://colab.research.google.com/github.com/frankwxu/AI4DigitalForensics/blob/main/ +https://colab.research.google.com/github/frankwxu/AI4DigitalForensics/blob/main/ ``` - Lab 1: [Hate speech detection](https://colab.research.google.com/github/frankwxu/AI4DigitalForensics/blob/main/lab01_Hate_speech_detection/social_media_threat_detection.ipynb) - Lab 2: [Gun detection](https://colab.research.google.com/github/frankwxu/AI4DigitalForensics/blob/main/lab02_Gun_detection_fasterRCNN/gun_detection_fasterRCNN.ipynb) +- Lab 10: [Reinforcement Learning](https://colab.research.google.com/github/frankwxu/AI4DigitalForensics/blob/main/lab10_Reinforcement_Learning/dqn_lunar_lander_demo.ipynb) + ## Contributing We welcome contributions from students, faculty, and researchers! 
Feel free to: diff --git a/lab10_Reinforcement_Learning/dqn_lunar_lander_demo.ipynb b/lab10_Reinforcement_Learning/dqn_lunar_lander_demo.ipynb index 8423c68..bc358cb 100644 --- a/lab10_Reinforcement_Learning/dqn_lunar_lander_demo.ipynb +++ b/lab10_Reinforcement_Learning/dqn_lunar_lander_demo.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "code", - "execution_count": 17, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -11,7 +11,11 @@ "# pip install swig\n", "# pip install box2d\n", "# pip install gymnasium\n", - "# pip install requests" + "# pip install requests\n", + "# pip install gymnasium[other]\n", + "\n", + "# Step 1: Install Dependencies\n", + "!pip install torch torchvision matplotlib opencv-python pycocotools py7zr requests pygame swig box2d gymnasium gymnasium[other] imageio" ] }, { @@ -1029,7 +1033,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 32, "metadata": {}, "outputs": [ { @@ -1046,7 +1050,7 @@ "Loaded model from models_xu/dqn_model_episode_1600.pth\n", "\n", "Starting demo...\n", - "Score obtained: 226.3022059875887\n" + "Score obtained: 216.69617981903787\n" ] } ], @@ -1069,6 +1073,124 @@ " score = demo_dqn_model(model_path, render=True, delay=0.000)" ] }, + { + "cell_type": "code", + "execution_count": 45, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Episode finished. 
Total reward: 146.24241059984874\n" + ] + }, + { + "data": { + "text/html": [ + "" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "\n", + "def show_video(env_name):\n", + " \"\"\"\n", + " Display a saved video from the 'video' folder in Colab.\n", + " \n", + " Parameters\n", + " ----------\n", + " env_name: str\n", + " Name of the environment (used to find the video file).\n", + " \"\"\"\n", + " mp4list = glob.glob(f'video/{env_name}-episode-*.mp4')\n", + " if len(mp4list) > 0:\n", + " mp4 = mp4list[0] # Take the first matching file\n", + " video = io.open(mp4, 'r+b').read()\n", + " encoded = base64.b64encode(video)\n", + " display(HTML(data=''''''.format(encoded.decode('ascii'))))\n", + " else:\n", + " print(\"Could not find video\")\n", + "\n", + "def show_video_of_model(agent, env_name, model_path):\n", + " \"\"\"\n", + " Play and record a video of an episode using the given agent and trained model in the specified environment.\n", + " \n", + " Parameters\n", + " ----------\n", + " agent: DQNAgent\n", + " The agent instance to use (untrained initially).\n", + " env_name: str\n", + " Name of the environment (e.g., 'LunarLander-v3').\n", + " model_path: str\n", + " Path to the saved trained model file (e.g., 'models_xu/dqn_model_episode_1600.pth').\n", + " \"\"\"\n", + " # Create environment with rgb_array rendering\n", + " env = gym.make(env_name, render_mode=\"rgb_array\")\n", + " \n", + " # Wrap with RecordVideo to save the episode\n", + " env = RecordVideo(env, video_folder=\"video\", name_prefix=env_name, episode_trigger=lambda x: True)\n", + "\n", + " # Load the trained model into the agent\n", + " try:\n", + " agent.qnetwork_local.load_state_dict(torch.load(model_path))\n", + " agent.qnetwork_local.eval()\n", + " except AttributeError:\n", + " # Fallback to q_network if qnetwork_local doesn't exist\n", + " agent.q_network.load_state_dict(torch.load(model_path))\n", + " 
agent.q_network.eval()\n", + "\n", + " # Play one episode\n", + " state, _ = env.reset()\n", + " total_reward = 0\n", + " done = False\n", + " while not done:\n", + " action = agent.act(state, eps=0) # Greedy policy\n", + " state, reward, terminated, truncated, _ = env.step(action)\n", + " total_reward += reward\n", + " done = terminated or truncated\n", + "\n", + " print(f\"Episode finished. Total reward: {total_reward}\")\n", + " env.close()\n", + "\n", + " # Display the video\n", + " show_video(env_name)\n", + "\n", + "# Example usage in Colab\n", + "if __name__ == \"__main__\":\n", + " # Initialize agent\n", + " state_size = 8 # LunarLander-v3 state size\n", + " action_size = 4 # LunarLander-v3 action size\n", + " agent = DQNAgent(state_size=state_size, action_size=action_size)\n", + "\n", + " # Create video directory if it doesn't exist\n", + " import os\n", + " os.makedirs(\"video\", exist_ok=True)\n", + "\n", + " # Specify the trained model path\n", + " model_path = 'models_xu/dqn_model_episode_1700.pth' # Pass your desired model here\n", + "\n", + " # Run the demo with the specified model\n", + " show_video_of_model(agent, 'LunarLander-v3', model_path)\n", + "\n", + " # Optional: Download the video\n", + " # from google.colab import files\n", + " # video_file = glob.glob('video/LunarLander-v3-rl-video-episode-*.mp4')[0] # Match the generated file\n", + " # files.download(video_file)" + ] + }, { "cell_type": "code", "execution_count": null,