{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "7c1e0a42",
   "metadata": {},
   "source": [
    "# Tic-tac-toe against Monte Carlo tree search\n",
    "\n",
    "You are player 1 and move first; the computer is player 2 and picks its moves with UCT (UCB1-guided Monte Carlo tree search over uniformly random rollouts).\n",
    "\n",
    "Enter each move as `row,col` with 0-based indices, e.g. `1,1` for the centre square. The last cell starts an interactive game, so **Run All** will stop and wait for input there."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "18676180",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import ipywidgets as widgets\n",
    "from tqdm import tqdm\n",
    "import random\n",
    "import matplotlib.pyplot as plt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d69c70f2",
   "metadata": {},
   "outputs": [],
   "source": [
    "class MCTSNode:\n",
    "    \"\"\"A tic-tac-toe position in the Monte Carlo search tree.\n",
    "\n",
    "    Player 1 always moves first, so the player to move is derived from the\n",
    "    number of pieces each player has on the board.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, state, parent_node):\n",
    "        self.state = state                # 3x3 array: 0 = empty, 1/2 = player pieces\n",
    "        self.parent_node = parent_node    # None for the root of a search\n",
    "        self.total_visits = 0\n",
    "        # Sum of rollout results (+1 win, -1 loss, 0 draw); MCTS.backpropagate\n",
    "        # decides whose perspective they are recorded from.\n",
    "        self.total_score = 0\n",
    "        self.children_nodes = []\n",
    "        self.player = self.check_player(state)   # player to move in this state\n",
    "        self.terminate_state = False\n",
    "        self.all_children_nodes = False   # True once every legal move has a child node\n",
    "\n",
    "    def check_player(self, state):\n",
    "        \"\"\"Return the player to move: 2 if player 1 has more pieces, otherwise 1.\"\"\"\n",
    "        if np.sum(state == 1) > np.sum(state == 2):\n",
    "            return 2\n",
    "        return 1\n",
    "\n",
    "\n",
    "class MCTS:\n",
    "    \"\"\"Monte Carlo tree search (UCT) player for 3x3 tic-tac-toe.\"\"\"\n",
    "\n",
    "    def __init__(self, exploration_constant=2):\n",
    "        self.exploration_constant = exploration_constant\n",
    "\n",
    "    def is_terminal(self, board):\n",
    "        \"\"\"Return True if the board has no empty squares (wins are not checked).\"\"\"\n",
    "        return not np.any(board == 0)\n",
    "\n",
    "    def is_win(self, state, player):\n",
    "        \"\"\"Return True if `player` fills a whole row, column or diagonal.\"\"\"\n",
    "        col_win = (np.sum(state == player, axis=0) == 3).any()\n",
    "        row_win = (np.sum(state == player, axis=1) == 3).any()\n",
    "        diagonal_win = np.trace(state == player) == 3\n",
    "        opposite_diagonal = np.trace(np.fliplr(state) == player) == 3\n",
    "        return col_win or row_win or diagonal_win or opposite_diagonal\n",
    "\n",
    "    def is_game_over(self, state):\n",
    "        \"\"\"Return True if the board is full or either player has won.\"\"\"\n",
    "        return self.is_terminal(state) or self.is_win(state, 1) or self.is_win(state, 2)\n",
    "\n",
    "    def select(self, curr_node, should_explore=True):\n",
    "        \"\"\"Walk down the tree from `curr_node` and return the node to simulate from.\n",
    "\n",
    "        While a node is fully expanded, step to its best child by UCB1\n",
    "        (exploitation only when `should_explore` is False). Stop at the first\n",
    "        node with an untried move, expanding and returning one new child, or at\n",
    "        a game-over node, which is returned unchanged.\n",
    "        \"\"\"\n",
    "        while not self.is_game_over(curr_node.state):\n",
    "            if not curr_node.all_children_nodes:\n",
    "                return self.expand(curr_node)\n",
    "            curr_node = self.best_child(curr_node, should_explore)\n",
    "        return curr_node\n",
    "\n",
    "    def best_child(self, curr_node, should_explore=True):\n",
    "        \"\"\"Return the child of `curr_node` with the highest UCB1 score.\"\"\"\n",
    "        # max(..., 1) guards against log(0) for a node that was never visited.\n",
    "        log_parent_visits = np.log(max(curr_node.total_visits, 1))\n",
    "\n",
    "        def ucb1(child):\n",
    "            if child.total_visits == 0:\n",
    "                return float('inf')   # try unvisited children before comparing scores\n",
    "            exploitation = child.total_score / child.total_visits\n",
    "            exploration = self.exploration_constant * np.sqrt(log_parent_visits / child.total_visits)\n",
    "            return exploitation + should_explore * exploration\n",
    "\n",
    "        return max(curr_node.children_nodes, key=ucb1)\n",
    "\n",
    "    def expand(self, curr_node):\n",
    "        \"\"\"Attach one randomly chosen untried move to `curr_node` and return it.\n",
    "\n",
    "        Marks `curr_node` as fully expanded once every legal move has a child,\n",
    "        and returns `curr_node` itself if no untried moves remain.\n",
    "        \"\"\"\n",
    "        states = self.generate_next_states(curr_node)\n",
    "        tried = {child.state.tobytes() for child in curr_node.children_nodes}\n",
    "        untried = [state for state in states if state.tobytes() not in tried]\n",
    "        if not untried:\n",
    "            curr_node.all_children_nodes = True\n",
    "            return curr_node\n",
    "\n",
    "        # One child per call: each new child is simulated and backpropagated\n",
    "        # before the next is added, so UCB1 never divides by zero visits.\n",
    "        child_node = MCTSNode(random.choice(untried), curr_node)\n",
    "        curr_node.children_nodes.append(child_node)\n",
    "        if len(curr_node.children_nodes) == len(states):\n",
    "            curr_node.all_children_nodes = True\n",
    "        return child_node\n",
    "\n",
    "    def simulate(self, curr_node, computer_playing):\n",
    "        \"\"\"Play uniformly random moves from `curr_node` until the game ends.\n",
    "\n",
    "        Returns 1 if `computer_playing` wins, -1 if the opponent wins and 0 for\n",
    "        a draw. Rollout nodes are never added to any node's children.\n",
    "        \"\"\"\n",
    "        opponent = 1 if computer_playing == 2 else 2\n",
    "        while not self.is_game_over(curr_node.state):\n",
    "            next_states = self.generate_next_states(curr_node)\n",
    "            curr_node = MCTSNode(random.choice(next_states), curr_node)\n",
    "\n",
    "        if self.is_win(curr_node.state, player=computer_playing):\n",
    "            return 1\n",
    "        if self.is_win(curr_node.state, player=opponent):\n",
    "            return -1\n",
    "        return 0\n",
    "\n",
    "    def backpropagate(self, node, score, scoring_player=None):\n",
    "        \"\"\"Add one visit and the rollout result to `node` and all its ancestors.\n",
    "\n",
    "        If `scoring_player` is given, `score` is from that player's perspective\n",
    "        and each node records it from the perspective of the player who moved\n",
    "        into that node, so UCB1 at every level maximises for the player whose\n",
    "        turn it is. If it is None, `score` is added to every node unchanged.\n",
    "        \"\"\"\n",
    "        while node is not None:\n",
    "            node.total_visits += 1\n",
    "            if scoring_player is None:\n",
    "                node.total_score += score\n",
    "            else:\n",
    "                mover = 3 - node.player   # players are numbered 1 and 2\n",
    "                node.total_score += score if mover == scoring_player else -score\n",
    "            node = node.parent_node\n",
    "\n",
    "    def generate_next_states(self, curr_node):\n",
    "        \"\"\"Return one board copy per empty square, filled by the player to move.\"\"\"\n",
    "        player = curr_node.player\n",
    "        curr_state = curr_node.state\n",
    "        next_states = []\n",
    "        for i in range(3):\n",
    "            for j in range(3):\n",
    "                if curr_state[i, j] == 0:\n",
    "                    to_append = np.copy(curr_state)\n",
    "                    to_append[i, j] = player\n",
    "                    next_states.append(to_append)\n",
    "        return next_states\n",
    "\n",
    "    def get_move(self, root, num_iterations=1000):\n",
    "        \"\"\"Search from `root` and return the child node for the chosen move.\n",
    "\n",
    "        Runs `num_iterations` select/simulate/backpropagate rounds and picks the\n",
    "        most-visited child of `root`. Backpropagation follows `parent_node`\n",
    "        links, so pass a root whose parent is None to keep the statistics local\n",
    "        to this search.\n",
    "\n",
    "        Raises:\n",
    "            ValueError: if `root` is already a finished game.\n",
    "        \"\"\"\n",
    "        if self.is_game_over(root.state):\n",
    "            raise ValueError('get_move called on a finished game')\n",
    "        for _ in range(num_iterations):\n",
    "            leaf = self.select(root)\n",
    "            obtained_value = self.simulate(leaf, root.player)\n",
    "            self.backpropagate(leaf, obtained_value, scoring_player=root.player)\n",
    "\n",
    "        return max(root.children_nodes, key=lambda child: child.total_visits)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5e6f7a8b",
   "metadata": {},
   "outputs": [],
   "source": [
    "def read_human_move(board):\n",
    "    \"\"\"Prompt until the user enters `row,col` naming an empty square (0-based).\"\"\"\n",
    "    while True:\n",
    "        row_col = input('Row and column to place with ,').split(',')\n",
    "        try:\n",
    "            row, col = int(row_col[0]), int(row_col[1])\n",
    "        except (ValueError, IndexError):\n",
    "            print('Enter two integers separated by a comma, e.g. 1,2')\n",
    "            continue\n",
    "        if not (0 <= row < 3 and 0 <= col < 3):\n",
    "            print('Row and column must each be 0, 1 or 2')\n",
    "        elif board[row, col] != 0:\n",
    "            print('That square is already taken')\n",
    "        else:\n",
    "            return row, col"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "36e39228",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "a = np.zeros((3, 3))\n",
    "mc = MCTS()\n",
    "# The human is player 1 and always moves first; the computer is player 2.\n",
    "root = MCTSNode(a, None)\n",
    "\n",
    "while not mc.is_game_over(root.state):\n",
    "    row, col = read_human_move(root.state)\n",
    "    state = np.copy(root.state)\n",
    "    state[row, col] = 1\n",
    "    # A parentless node keeps backpropagation inside this turn's search\n",
    "    # instead of walking back through earlier turns.\n",
    "    root = MCTSNode(state, None)\n",
    "    if mc.is_game_over(root.state):\n",
    "        break\n",
    "\n",
    "    root = mc.get_move(root)\n",
    "    print(root.state)\n",
    "\n",
    "if mc.is_win(root.state, 1):\n",
    "    outcome = 'You win!'\n",
    "elif mc.is_win(root.state, 2):\n",
    "    outcome = 'Computer wins!'\n",
    "else:\n",
    "    outcome = 'Draw.'\n",
    "print(f'Final: {root.state}')\n",
    "print(outcome)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}