mirror of
https://github.com/aladdinpersson/Machine-Learning-Collection.git
synced 2026-02-20 13:50:41 +00:00
added_mcts_and_metrics
This commit is contained in:
228
ML/algorithms/MCTS/Monte Carlo Tree Search - TicTacToe.ipynb
Normal file
228
ML/algorithms/MCTS/Monte Carlo Tree Search - TicTacToe.ipynb
Normal file
@@ -0,0 +1,228 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"id": "18676180",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import numpy as np\n",
|
||||
"import ipywidgets as widgets\n",
|
||||
"from tqdm import tqdm\n",
|
||||
"import random\n",
|
||||
"import matplotlib.pyplot as plt"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 254,
|
||||
"id": "d69c70f2",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"class MCTSNode:\n",
|
||||
" def __init__(self, state, parent_node):\n",
|
||||
" self.state = state\n",
|
||||
" self.parent_node = parent_node\n",
|
||||
" self.total_visits = 0\n",
|
||||
" self.total_score = 0\n",
|
||||
" self.children_nodes = []\n",
|
||||
" self.player = self.check_player(state)\n",
|
||||
" self.terminate_state = False\n",
|
||||
" self.all_children_nodes = False\n",
|
||||
"\n",
|
||||
" def check_player(self, state):\n",
|
||||
" if np.sum(state==1) > np.sum(state==2):\n",
|
||||
" return 2\n",
|
||||
" else:\n",
|
||||
" return 1\n",
|
||||
"\n",
|
||||
"class MCTS:\n",
|
||||
" def __init__(self, exploration_constant = 2):\n",
|
||||
" self.exploration_constant = exploration_constant\n",
|
||||
"\n",
|
||||
" def is_terminal(self, board):\n",
|
||||
" return not np.any(board == 0)\n",
|
||||
"\n",
|
||||
" def is_win(self, state, player):\n",
|
||||
" col_win = (np.sum(state == player, axis=0) == 3).any()\n",
|
||||
" row_win = (np.sum(state == player, axis=1) == 3).any()\n",
|
||||
" diagonal_win = np.trace(state == player) == 3\n",
|
||||
" opposite_diagonal = np.trace(np.fliplr(state) == player) == 3\n",
|
||||
" return col_win or row_win or diagonal_win or opposite_diagonal\n",
|
||||
"\n",
|
||||
" def select(self, curr_node, should_explore=True):\n",
|
||||
" while not self.is_terminal(curr_node.state) and not (self.is_win(curr_node.state, 1) or self.is_win(curr_node.state, 2)):\n",
|
||||
" if curr_node.all_children_nodes:\n",
|
||||
" highest_value = -float(\"inf\")\n",
|
||||
" chosen_child = None\n",
|
||||
"\n",
|
||||
" # loop all children nodes and take the best one according to heuristic\n",
|
||||
" for child in curr_node.children_nodes:\n",
|
||||
" # compute UCB1 score\n",
|
||||
" child_val = (child.total_score/child.total_visits) + should_explore*self.exploration_constant*np.sqrt(np.log(curr_node.total_visits)/child.total_visits)\n",
|
||||
"\n",
|
||||
" # if it has highest value then store it as the chosen child from this step\n",
|
||||
" if child_val > highest_value:\n",
|
||||
" highest_value = child_val\n",
|
||||
" chosen_child = child\n",
|
||||
"\n",
|
||||
" # choose highest value move\n",
|
||||
" return chosen_child\n",
|
||||
"\n",
|
||||
" else:\n",
|
||||
" # if not all children nodes accessible then expand the node first\n",
|
||||
" return self.expand(curr_node)\n",
|
||||
"\n",
|
||||
" print(\"should never come here\")\n",
|
||||
"\n",
|
||||
" def expand(self, curr_node):\n",
|
||||
" states = self.generate_next_states(curr_node)\n",
|
||||
"\n",
|
||||
" for state in states:\n",
|
||||
" # unroll children states, and ensure we do not expand to a state we have \n",
|
||||
" # already expanded to in a previous iteration\n",
|
||||
" if str(state) not in [str(b.state) for b in curr_node.children_nodes]:\n",
|
||||
" child_node = MCTSNode(state, curr_node)\n",
|
||||
" curr_node.children_nodes.append(child_node)\n",
|
||||
" \n",
|
||||
" # if the num children nodes equal the amount of possible next states\n",
|
||||
" # we have explored all child nodes for this state\n",
|
||||
" if len(states) == len(curr_node.children_nodes):\n",
|
||||
" curr_node.all_children_nodes = True\n",
|
||||
"\n",
|
||||
" return child_node\n",
|
||||
"\n",
|
||||
"\n",
|
||||
" def simulate(self, curr_node, computer_playing):\n",
|
||||
" opponent = 1 if computer_playing == 2 else 2\n",
|
||||
" \n",
|
||||
" while not self.is_terminal(curr_node.state) and not (self.is_win(curr_node.state, 1) or self.is_win(curr_node.state, 2)):\n",
|
||||
" next_states = self.generate_next_states(curr_node)\n",
|
||||
" curr_node = MCTSNode(next_states[random.randint(0, len(next_states) - 1)], curr_node)\n",
|
||||
" \n",
|
||||
" if self.is_win(curr_node.state, player=computer_playing):\n",
|
||||
" return 1\n",
|
||||
" elif self.is_win(curr_node.state, player=opponent):\n",
|
||||
" return -1\n",
|
||||
" else:\n",
|
||||
" return 0\n",
|
||||
"\n",
|
||||
" \n",
|
||||
" def backpropagate(self, node, score):\n",
|
||||
" while node:\n",
|
||||
" node.total_visits += 1\n",
|
||||
" node.total_score += score\n",
|
||||
" node = node.parent_node\n",
|
||||
" \n",
|
||||
" def generate_next_states(self, curr_node):\n",
|
||||
" player = curr_node.player\n",
|
||||
" curr_state = curr_node.state\n",
|
||||
" next_states = []\n",
|
||||
" for i in range(3):\n",
|
||||
" for j in range(3):\n",
|
||||
" if curr_state[i,j] == 0:\n",
|
||||
" to_append = np.copy(curr_state)\n",
|
||||
" to_append[i,j] = player\n",
|
||||
" next_states.append(to_append)\n",
|
||||
" return next_states\n",
|
||||
"\n",
|
||||
"\n",
|
||||
" def get_move(self, root, num_iterations=1000):\n",
|
||||
" for it in range(num_iterations):\n",
|
||||
" curr_node = self.select(root)\n",
|
||||
" obtained_value = self.simulate(curr_node, root.player)\n",
|
||||
" self.backpropagate(curr_node, obtained_value)\n",
|
||||
" \n",
|
||||
" chosen_move = self.select(root, should_explore=False)\n",
|
||||
" return chosen_move"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 263,
|
||||
"id": "36e39228",
|
||||
"metadata": {
|
||||
"scrolled": true
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Row and column to place with ,1,1\n",
|
||||
"[[0. 0. 0.]\n",
|
||||
" [0. 1. 0.]\n",
|
||||
" [0. 0. 2.]]\n",
|
||||
"Row and column to place with ,0,0\n",
|
||||
"[[1. 0. 0.]\n",
|
||||
" [0. 1. 0.]\n",
|
||||
" [2. 0. 2.]]\n",
|
||||
"Row and column to place with ,2,1\n",
|
||||
"[[1. 2. 0.]\n",
|
||||
" [0. 1. 0.]\n",
|
||||
" [2. 1. 2.]]\n",
|
||||
"Row and column to place with ,1,2\n",
|
||||
"[[1. 2. 0.]\n",
|
||||
" [2. 1. 1.]\n",
|
||||
" [2. 1. 2.]]\n",
|
||||
"Row and column to place with ,0,2\n",
|
||||
"should never come here\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"ename": "AttributeError",
|
||||
"evalue": "'NoneType' object has no attribute 'state'",
|
||||
"output_type": "error",
|
||||
"traceback": [
|
||||
"\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
|
||||
"\u001b[1;31mAttributeError\u001b[0m Traceback (most recent call last)",
|
||||
"\u001b[1;32m~\\AppData\\Local\\Temp/ipykernel_15720/2518229713.py\u001b[0m in \u001b[0;36m<module>\u001b[1;34m\u001b[0m\n\u001b[0;32m 9\u001b[0m \u001b[0mnext_node\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mMCTSNode\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mstate\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mroot\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 10\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 11\u001b[1;33m \u001b[0mroot\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mmc\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mget_move\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mnext_node\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 12\u001b[0m \u001b[0mprint\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mroot\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mstate\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 13\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n",
|
||||
"\u001b[1;32m~\\AppData\\Local\\Temp/ipykernel_15720/416212796.py\u001b[0m in \u001b[0;36mget_move\u001b[1;34m(self, root, num_iterations)\u001b[0m\n\u001b[0;32m 110\u001b[0m \u001b[1;32mfor\u001b[0m \u001b[0mit\u001b[0m \u001b[1;32min\u001b[0m \u001b[0mrange\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mnum_iterations\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 111\u001b[0m \u001b[0mcurr_node\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mselect\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mroot\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 112\u001b[1;33m \u001b[0mobtained_value\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0msimulate\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mcurr_node\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mroot\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mplayer\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 113\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mbackpropagate\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mcurr_node\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mobtained_value\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 114\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n",
|
||||
"\u001b[1;32m~\\AppData\\Local\\Temp/ipykernel_15720/416212796.py\u001b[0m in \u001b[0;36msimulate\u001b[1;34m(self, curr_node, computer_playing)\u001b[0m\n\u001b[0;32m 76\u001b[0m \u001b[0mopponent\u001b[0m \u001b[1;33m=\u001b[0m \u001b[1;36m1\u001b[0m \u001b[1;32mif\u001b[0m \u001b[0mcomputer_playing\u001b[0m \u001b[1;33m==\u001b[0m \u001b[1;36m2\u001b[0m \u001b[1;32melse\u001b[0m \u001b[1;36m1\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 77\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 78\u001b[1;33m \u001b[1;32mwhile\u001b[0m \u001b[1;32mnot\u001b[0m \u001b[0mis_terminal\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mcurr_node\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mstate\u001b[0m\u001b[1;33m)\u001b[0m \u001b[1;32mand\u001b[0m \u001b[1;32mnot\u001b[0m \u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mis_win\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mcurr_node\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mstate\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;36m1\u001b[0m\u001b[1;33m)\u001b[0m \u001b[1;32mor\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mis_win\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mcurr_node\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mstate\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;36m2\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 79\u001b[0m \u001b[0mnext_states\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mgenerate_next_states\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mcurr_node\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 80\u001b[0m \u001b[0mcurr_node\u001b[0m \u001b[1;33m=\u001b[0m 
\u001b[0mMCTSNode\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mnext_states\u001b[0m\u001b[1;33m[\u001b[0m\u001b[0mrandom\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mrandint\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;36m0\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mlen\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mnext_states\u001b[0m\u001b[1;33m)\u001b[0m \u001b[1;33m-\u001b[0m \u001b[1;36m1\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mcurr_node\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
|
||||
"\u001b[1;31mAttributeError\u001b[0m: 'NoneType' object has no attribute 'state'"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"a = np.zeros((3,3))\n",
|
||||
"root = MCTSNode(a, None)\n",
|
||||
"mc = MCTS()\n",
|
||||
"\n",
|
||||
"for i in range(9):\n",
|
||||
" row_col = input(\"Row and column to place with ,\").split(\",\")\n",
|
||||
" state = np.copy(root.state)\n",
|
||||
" state[int(row_col[0]), int(row_col[1])] = 1\n",
|
||||
" next_node = MCTSNode(state, root)\n",
|
||||
" \n",
|
||||
" root = mc.get_move(next_node)\n",
|
||||
" print(root.state)\n",
|
||||
"\n",
|
||||
"print(f\"Final: {root.state}\")\n",
|
||||
" "
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.9.5"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
||||
100
ML/ml_metrics/data.txt
Normal file
100
ML/ml_metrics/data.txt
Normal file
@@ -0,0 +1,100 @@
|
||||
0 0.827142151760153
|
||||
0 0.6044595910412887
|
||||
0 0.7916340858282026
|
||||
0 0.16080518180592987
|
||||
0 0.611222921705038
|
||||
0 0.2555087295500818
|
||||
0 0.5681507664364468
|
||||
0 0.05990570219972058
|
||||
0 0.6644434078306367
|
||||
0 0.11293577405861703
|
||||
0 0.06152372321587048
|
||||
0 0.35250697207600584
|
||||
0 0.3226701829081975
|
||||
0 0.43339115381458776
|
||||
0 0.2280744262436838
|
||||
0 0.7219848389339433
|
||||
0 0.23527698971402375
|
||||
0 0.2850245335200196
|
||||
0 0.4107047877448165
|
||||
0 0.2008356196164621
|
||||
0 0.3711921802697385
|
||||
0 0.4234822657253734
|
||||
0 0.4876482027124213
|
||||
0 0.4234822657253734
|
||||
0 0.5750985220664769
|
||||
0 0.6734047730095499
|
||||
0 0.7355892648444824
|
||||
0 0.7137899092959652
|
||||
0 0.3873972469024071
|
||||
0 0.24042033264833723
|
||||
0 0.1663411647259707
|
||||
0 0.1663411647259707
|
||||
0 0.2850245335200196
|
||||
0 0.3683741846950643
|
||||
0 0.17375784896208155
|
||||
0 0.43636290738886574
|
||||
0 0.7219848389339433
|
||||
0 0.46745878087292836
|
||||
0 0.23527698971402375
|
||||
0 0.17202866439941822
|
||||
0 0.17786913865061538
|
||||
0 0.44335359557308707
|
||||
0 0.2768503833164947
|
||||
0 0.06891755391553003
|
||||
0 0.21414010746535972
|
||||
0 0.27120595352357546
|
||||
0 0.26328216986315905
|
||||
0 0.48056205121673834
|
||||
0 0.08848560476699129
|
||||
0 0.2555087295500818
|
||||
1 0.5681507664364468
|
||||
1 0.2850245335200196
|
||||
1 0.842216416418616
|
||||
1 0.5280820469827786
|
||||
1 0.6302728469340095
|
||||
1 0.9325162813331325
|
||||
1 0.062225621463076315
|
||||
1 0.8823445035377085
|
||||
1 0.670739773835188
|
||||
1 0.891663414209465
|
||||
1 0.6489254823470298
|
||||
1 0.5552119758821265
|
||||
1 0.7510275470993321
|
||||
1 0.23310831157247616
|
||||
1 0.2933421288888426
|
||||
1 0.6044595910412887
|
||||
1 0.6302728469340095
|
||||
1 0.9585115007613662
|
||||
1 0.9342800686704079
|
||||
1 0.3226701829081975
|
||||
1 0.7982301827889998
|
||||
1 0.22102862644325694
|
||||
1 0.9390780973389883
|
||||
1 0.5078780077620866
|
||||
1 0.7379344573081708
|
||||
1 0.8750078631067137
|
||||
1 0.4704701704107932
|
||||
1 0.44335359557308707
|
||||
1 0.5651814720676593
|
||||
1 0.8658845001112441
|
||||
1 0.897024614730928
|
||||
1 0.9712637967845552
|
||||
1 0.5651814720676593
|
||||
1 0.517987379389242
|
||||
1 0.40385540386469254
|
||||
1 0.9435470013187671
|
||||
1 0.5780506539476005
|
||||
1 0.594744923406366
|
||||
1 0.3970432858350056
|
||||
1 0.7916340858282026
|
||||
1 0.7219848389339433
|
||||
1 0.7916340858282026
|
||||
1 0.2850245335200196
|
||||
1 0.7658513560779588
|
||||
1 0.7379344573081708
|
||||
1 0.7137899092959652
|
||||
1 0.4876482027124213
|
||||
1 0.6302728469340095
|
||||
1 0.5310944974701136
|
||||
1 0.35250697207600584
|
||||
240
ML/ml_metrics/metrics.py
Normal file
240
ML/ml_metrics/metrics.py
Normal file
@@ -0,0 +1,240 @@
|
||||
import numpy as np
|
||||
from scipy.integrate import simpson
|
||||
import matplotlib.pyplot as plt
|
||||
import warnings
|
||||
|
||||
|
||||
def true_positives(y_true, y_pred):
    """Count samples that are positive and were predicted positive.

    Args:
        y_true: Iterable of ground-truth labels (1 marks the positive class).
        y_pred: Iterable of predicted labels, aligned with ``y_true``.

    Returns:
        Number of positions where both the label and the prediction equal 1.
    """
    # zip() stops at the shorter input, so unmatched trailing items are ignored.
    return sum(
        1 for label, pred in zip(y_true, y_pred) if pred == 1 and label == 1
    )
|
||||
|
||||
|
||||
def true_negatives(y_true, y_pred):
    """Count samples that are negative and were predicted negative.

    Args:
        y_true: Iterable of ground-truth labels (0 marks the negative class).
        y_pred: Iterable of predicted labels, aligned with ``y_true``.

    Returns:
        Number of positions where both the label and the prediction equal 0.
    """
    # zip() stops at the shorter input, so unmatched trailing items are ignored.
    return sum(
        1 for label, pred in zip(y_true, y_pred) if pred == 0 and label == 0
    )
|
||||
|
||||
|
||||
def false_positives(y_true, y_pred):
    """Count samples that are negative but were predicted positive.

    Args:
        y_true: Iterable of ground-truth labels (0 marks the negative class).
        y_pred: Iterable of predicted labels, aligned with ``y_true``.

    Returns:
        Number of positions where the prediction is 1 and the label is 0.
    """
    # zip() stops at the shorter input, so unmatched trailing items are ignored.
    return sum(
        1 for label, pred in zip(y_true, y_pred) if pred == 1 and label == 0
    )
|
||||
|
||||
|
||||
def false_negatives(y_true, y_pred):
    """Count samples that are positive but were predicted negative.

    Args:
        y_true: Iterable of ground-truth labels (1 marks the positive class).
        y_pred: Iterable of predicted labels, aligned with ``y_true``.

    Returns:
        Number of positions where the prediction is 0 and the label is 1.
    """
    # zip() stops at the shorter input, so unmatched trailing items are ignored.
    return sum(
        1 for label, pred in zip(y_true, y_pred) if pred == 0 and label == 1
    )
|
||||
|
||||
|
||||
def binary_accuracy(y_true, y_pred):
    """Return the fraction of binary predictions that match their label.

    Computed as (TP + TN) / (TP + TN + FP + FN) using the module's counting
    helpers, so only pairs whose label and prediction are both 0 or 1
    contribute to the denominator.
    """
    hits = true_positives(y_true, y_pred) + true_negatives(y_true, y_pred)
    misses = false_positives(y_true, y_pred) + false_negatives(y_true, y_pred)
    return hits / (hits + misses)
|
||||
|
||||
|
||||
def precision(y_true, y_pred):
    """
    Fraction of positive predictions that are correct: TP / (TP + FP).

    How I view it: Assuming we say someone has cancer: how often are we correct?
    It tells us how much we can trust the model when it predicts an individual
    as positive.

    Args:
        y_true: Iterable of ground-truth binary labels (0 or 1).
        y_pred: Iterable of predicted binary labels, aligned with ``y_true``.

    Returns:
        The precision as a float; 0.0 when nothing was predicted positive.
    """
    # The numerator must be the true *positives*; counting true negatives here
    # would make the result unrelated to precision.
    tp = true_positives(y_true, y_pred)
    fp = false_positives(y_true, y_pred)
    predicted_positive = tp + fp
    if predicted_positive == 0:
        # No positive predictions at all: precision is undefined, report 0.0
        # instead of raising ZeroDivisionError.
        return 0.0
    return tp / predicted_positive
|
||||
|
||||
|
||||
def recall(y_true, y_pred):
    """
    Recall measures the model's predictive accuracy for the positive class:
    TP / (TP + FN).

    How I view it, out of all the people that have cancer: how often are
    we able to detect it?

    Args:
        y_true: Iterable of ground-truth binary labels (0 or 1).
        y_pred: Iterable of predicted binary labels, aligned with ``y_true``.

    Returns:
        The recall as a float; 0.0 when there are no positive labels.
    """
    # The numerator must be the true *positives*; counting true negatives here
    # would make the result unrelated to recall.
    tp = true_positives(y_true, y_pred)
    fn = false_negatives(y_true, y_pred)
    actual_positive = tp + fn
    if actual_positive == 0:
        # No positive samples at all: recall is undefined, report 0.0
        # instead of raising ZeroDivisionError.
        return 0.0
    return tp / actual_positive
|
||||
|
||||
|
||||
def multiclass_accuracy(y_true, y_pred):
    """Return the fraction of predictions that equal their label.

    Args:
        y_true: Sized sequence of ground-truth class labels.
        y_pred: Iterable of predicted class labels, aligned with ``y_true``.

    Returns:
        Number of matching pairs divided by ``len(y_true)``.

    Raises:
        ZeroDivisionError: If ``y_true`` is empty.
    """
    total = len(y_true)
    # Each comparison contributes 1 when equal and 0 otherwise.
    correct = sum(label == pred for label, pred in zip(y_true, y_pred))
    return correct / total
|
||||
|
||||
|
||||
def confusion_matrix(y_true, y_pred):
    """Build a square confusion matrix indexed as ``cm[true_label, predicted]``.

    Labels are used directly as row/column indices, so they are expected to be
    non-negative integer class ids.

    Args:
        y_true: Array-like of ground-truth integer class labels.
        y_pred: Array-like of predicted integer class labels, same shape.

    Returns:
        An ``int64`` array of shape ``(k, k)`` where ``cm[i, j]`` counts the
        samples with true label ``i`` that were predicted as ``j``.
    """
    y_true = np.array(y_true)
    y_pred = np.array(y_pred)
    assert y_true.shape == y_pred.shape

    seen_labels = np.concatenate([y_true, y_pred], axis=0)
    num_classes = np.unique(seen_labels).shape[0]
    if seen_labels.size:
        # Labels double as indices, so the matrix must reach the largest label
        # even when some lower class ids never occur (e.g. every label is 1);
        # sizing by the distinct-label count alone would raise IndexError.
        num_classes = max(num_classes, int(seen_labels.max()) + 1)
    cm = np.zeros((num_classes, num_classes), dtype=np.int64)

    for label, pred in zip(y_true, y_pred):
        cm[label, pred] += 1

    return cm
|
||||
|
||||
|
||||
def accuracy_cm(cm):
    """Return overall accuracy from a confusion matrix.

    The diagonal holds the correctly classified counts, so accuracy is the
    diagonal sum divided by the sum of every cell.
    """
    n_correct = np.trace(cm)
    n_samples = np.sum(cm)
    return n_correct / n_samples
|
||||
|
||||
|
||||
def balanced_accuracy_cm(cm):
    """Return the mean per-class recall computed from a confusion matrix.

    Rows are summed to get the number of true samples per class. Classes whose
    row sums to zero are left out of the average, and a warning is issued when
    that happens.
    """
    support = np.sum(cm, axis=1)
    populated = np.nonzero(support)[0]
    if populated.shape[0] != support.shape[0]:
        warnings.warn("y_pred contains classes not in y_true")
    per_class = np.diagonal(cm)[populated] / support[populated]
    return np.sum(per_class) / per_class.shape[0]
|
||||
|
||||
|
||||
def precision_cm(cm, average="specific", class_label=1, eps=1e-12):
    """Compute precision from a confusion matrix laid out as ``cm[true, pred]``.

    Args:
        cm: Square confusion matrix (NumPy array).
        average: One of
            ``"none"``     - array with the precision of every class,
            ``"specific"`` - precision of ``class_label`` only,
            ``"micro"``    - pooled TP / (TP + FP) over all classes,
            ``"macro"``    - unweighted mean of the per-class precisions,
            ``"weighted"`` - per-class precisions averaged with each class's
                             number of true samples (row sum) as weight.
        class_label: Class index used when ``average == "specific"``.
        eps: Small constant added to denominators to avoid division by zero.

    Returns:
        A float, or an array of per-class values for ``average="none"``.

    Raises:
        ValueError: If ``average`` is not one of the supported modes.
    """
    tp = np.diagonal(cm)
    # Column sums count everything predicted as a class; removing the diagonal
    # leaves the false positives.
    fp = np.sum(cm, axis=0) - tp
    per_class = tp / (tp + fp + eps)

    if average == "none":
        return per_class

    if average == "specific":
        return per_class[class_label]

    if average == "micro":
        # All samples contribute equally, so populous classes dominate.
        return np.sum(tp) / (np.sum(tp) + np.sum(fp) + eps)

    if average == "macro":
        # All classes contribute equally regardless of how populated they are.
        return np.sum(per_class) / per_class.shape[0]

    if average == "weighted":
        support = np.sum(cm, axis=1)
        return np.sum(per_class * support) / (np.sum(support) + eps)

    raise ValueError(
        "average must be one of 'none', 'specific', 'micro', 'macro', "
        f"'weighted'; got {average!r}"
    )
|
||||
|
||||
|
||||
def recall_cm(cm, average="specific", class_label=1, eps=1e-12):
    """Compute recall from a confusion matrix laid out as ``cm[true, pred]``.

    Args:
        cm: Square confusion matrix (NumPy array).
        average: One of
            ``"none"``     - array with the recall of every class,
            ``"specific"`` - recall of ``class_label`` only,
            ``"micro"``    - pooled TP / (TP + FN) over all classes,
            ``"macro"``    - unweighted mean of the per-class recalls,
            ``"weighted"`` - per-class recalls averaged with each class's
                             number of true samples (row sum) as weight.
        class_label: Class index used when ``average == "specific"``.
        eps: Small constant added to denominators to avoid division by zero.

    Returns:
        A float, or an array of per-class values for ``average="none"``.

    Raises:
        ValueError: If ``average`` is not one of the supported modes.
    """
    tp = np.diagonal(cm)
    # Row sums count every sample that truly belongs to a class; removing the
    # diagonal leaves the false negatives.
    fn = np.sum(cm, axis=1) - tp
    per_class = tp / (tp + fn + eps)

    if average == "none":
        return per_class

    if average == "specific":
        return per_class[class_label]

    if average == "micro":
        # eps keeps an all-zero matrix from producing 0/0, matching the
        # other branches.
        return np.sum(tp) / (np.sum(tp) + np.sum(fn) + eps)

    if average == "macro":
        return np.sum(per_class) / per_class.shape[0]

    if average == "weighted":
        support = np.sum(cm, axis=1)
        return np.sum(per_class * support) / (np.sum(support) + eps)

    raise ValueError(
        "average must be one of 'none', 'specific', 'micro', 'macro', "
        f"'weighted'; got {average!r}"
    )
|
||||
|
||||
|
||||
def f1score_cm(cm, average="specific", class_label=1, eps=1e-12):
    """Compute the F1 score (harmonic mean of precision and recall).

    Precision and recall are obtained from ``precision_cm`` and ``recall_cm``
    with the same ``average`` and ``class_label``, then combined as
    ``2 * P * R / (P + R)``. Note that for averaged modes this combines the
    averaged precision with the averaged recall rather than averaging
    per-class F1 values.

    Args:
        cm: Square confusion matrix laid out as ``cm[true, pred]``.
        average: Averaging mode forwarded to ``precision_cm`` / ``recall_cm``.
        class_label: Class index used when ``average == "specific"``.
        eps: Small constant forwarded to both helpers and added to the final
            denominator so that P == R == 0 yields 0 instead of 0/0.

    Returns:
        The F1 score (an array of per-class scores for ``average="none"``).
    """
    # Local names avoid shadowing the module-level precision()/recall().
    prec = precision_cm(cm, average, class_label, eps)
    rec = recall_cm(cm, average, class_label, eps)
    return 2 * (prec * rec) / (prec + rec + eps)
|
||||
|
||||
# true positive rate <-> sensitivity <-> recall
|
||||
# true negative rate <-> specificity <-> recall for neg. class
|
||||
# ROC curve
|
||||
# AUC from ROC
|
||||
# Precision-Recall Curve
|
||||
# Log Loss
|
||||
# Matthews Correlation Coefficient (MCC)
|
||||
# Cohen Kappa score
|
||||
# --> REGRESSION METRICS
|
||||
|
||||
|
||||
def roc_curve(y_true, y_preds, plot_graph=True, calculate_AUC=True, threshold_step=0.01):
    """Sweep a decision threshold over the scores and trace the ROC curve.

    For every threshold the scores are binarised with ``score > threshold``,
    a confusion matrix is built with ``confusion_matrix`` and the per-class
    recalls are read from ``recall_cm``: recall of class 1 is the true
    positive rate and ``1 - recall of class 0`` is the false positive rate.

    Args:
        y_true: Ground-truth binary labels (0 or 1).
        y_preds: Predicted scores/probabilities for the positive class.
        plot_graph: When True, plot TPR against FPR with matplotlib.
        calculate_AUC: When True, integrate the curve with the trapezoidal
            rule, print the area and return it.
        threshold_step: Spacing between consecutive thresholds.

    Returns:
        The area under the ROC curve when ``calculate_AUC`` is True,
        otherwise None.
    """
    # Explicit array conversion so that plain Python lists are supported.
    scores = np.asarray(y_preds, dtype=float)
    lowest, highest = np.min(scores), np.max(scores)

    # The comparison below is strict, so a sweep over [min, max) alone never
    # yields "everything positive" (FPR=TPR=1) nor "nothing positive"
    # (FPR=TPR=0). Add one threshold below the minimum and one at the maximum
    # so the curve spans both corners and the area is not underestimated.
    thresholds = np.concatenate((
        [lowest - threshold_step],
        np.arange(lowest, highest, threshold_step),
        [highest],
    ))

    TPR, FPR = [], []
    for threshold in thresholds:
        predictions = (scores > threshold) * 1
        cm = confusion_matrix(y_true, predictions)
        recalls = recall_cm(cm, average="none")
        # TPR == sensitivity == recall of the positive class.
        TPR.append(recalls[1])
        # TNR == specificity == recall of the negative class; FPR = 1 - TNR.
        FPR.append(1 - recalls[0])

    if plot_graph:
        plt.plot(FPR, TPR)
        plt.xlabel("False Positive Rate")
        plt.ylabel("True Positive Rate")
        plt.title("ROC curve")
        plt.show()

    auc = None
    if calculate_AUC:
        # np.trapz is deprecated in NumPy 2.0 in favour of np.trapezoid.
        trapezoid = getattr(np, "trapezoid", None) or np.trapz
        # FPR decreases as the threshold grows, so the signed area is negative.
        auc = np.abs(trapezoid(TPR, FPR))
        print(auc)
    return auc
|
||||
|
||||
|
||||
def precision_recall_curve(y_true, y_preds, plot_graph=True, calculate_AUC=True, threshold_step=0.01):
    """Sweep a decision threshold over the scores and trace precision vs recall.

    For every threshold the scores are binarised with ``score > threshold``,
    a confusion matrix is built with ``confusion_matrix`` and the precision
    and recall of class 1 are read from ``precision_cm`` / ``recall_cm``.
    A final (recall=0, precision=1) point is appended to close the curve.

    Args:
        y_true: Ground-truth binary labels (0 or 1).
        y_preds: Predicted scores/probabilities for the positive class.
        plot_graph: When True, plot precision against recall with matplotlib.
        calculate_AUC: When True, integrate the curve with the trapezoidal
            rule, print the area and return it.
        threshold_step: Spacing between consecutive thresholds.

    Returns:
        The area under the precision-recall curve when ``calculate_AUC`` is
        True, otherwise None.
    """
    # Explicit array conversion so that plain Python lists are supported.
    scores = np.asarray(y_preds, dtype=float)
    lowest, highest = np.min(scores), np.max(scores)

    # The comparison below is strict, so starting the sweep at the minimum
    # score already drops the lowest-scored samples. Prepend one threshold
    # below the minimum so the "everything positive" point (recall = 1) is
    # part of the curve.
    thresholds = np.concatenate((
        [lowest - threshold_step],
        np.arange(lowest, highest, threshold_step),
    ))

    recalls, precisions = [], []
    for threshold in thresholds:
        predictions = (scores > threshold) * 1
        cm = confusion_matrix(y_true, predictions)
        # Local names avoid shadowing the module-level precision()/recall().
        recall_at = recall_cm(cm, average="specific", class_label=1)
        precision_at = precision_cm(cm, average="specific", class_label=1)
        recalls.append(recall_at)
        precisions.append(precision_at)

    # Conventional end point of the curve.
    recalls.append(0)
    precisions.append(1)

    if plot_graph:
        plt.plot(recalls, precisions)
        plt.xlabel("Recall")
        plt.ylabel("Precision")
        plt.title("Precision-Recall curve")
        plt.show()

    auc = None
    if calculate_AUC:
        # np.trapz is deprecated in NumPy 2.0 in favour of np.trapezoid.
        trapezoid = getattr(np, "trapezoid", None) or np.trapz
        # Recall decreases as the threshold grows, so the signed area is negative.
        auc = np.abs(trapezoid(precisions, recalls))
        print(auc)
    return auc
|
||||
|
||||
|
||||
# Demo: load "<label> <score>" pairs from data.txt (resolved against the
# current working directory) and draw the precision-recall curve for them.
y = []
probs = []
with open("data.txt") as f:
    # Iterate the file lazily instead of materialising it with readlines().
    for line in f:
        if not line.strip():
            # A blank or trailing empty line would break the two-value unpack.
            continue
        label, pred = line.split()
        label = int(label)
        pred = float(pred)
        y.append(label)
        probs.append(pred)

precision_recall_curve(y, probs, threshold_step=0.001)
|
||||
#from sklearn.metrics import precision_recall_curve
|
||||
#precisions, recalls, _ = precision_recall_curve(y, probs)
|
||||
#plt.plot(recalls, precisions)
|
||||
#plt.xlabel("Recall")
|
||||
#plt.ylabel("Precision")
|
||||
#plt.title("Precision-Recall curve")
|
||||
#plt.show()
|
||||
#print(np.abs(np.trapz(precisions, recalls)))
|
||||
|
||||
|
||||
Reference in New Issue
Block a user