From 088bdb63e9ccba93da9f14289cb4e66f67539afb Mon Sep 17 00:00:00 2001 From: dino Date: Thu, 29 Sep 2022 11:18:12 +0200 Subject: [PATCH] added_mcts_and_metrics --- .../Monte Carlo Tree Search - TicTacToe.ipynb | 228 +++++++++++++++++ ML/ml_metrics/data.txt | 100 ++++++++ ML/ml_metrics/metrics.py | 240 ++++++++++++++++++ 3 files changed, 568 insertions(+) create mode 100644 ML/algorithms/MCTS/Monte Carlo Tree Search - TicTacToe.ipynb create mode 100644 ML/ml_metrics/data.txt create mode 100644 ML/ml_metrics/metrics.py diff --git a/ML/algorithms/MCTS/Monte Carlo Tree Search - TicTacToe.ipynb b/ML/algorithms/MCTS/Monte Carlo Tree Search - TicTacToe.ipynb new file mode 100644 index 0000000..9f56e15 --- /dev/null +++ b/ML/algorithms/MCTS/Monte Carlo Tree Search - TicTacToe.ipynb @@ -0,0 +1,228 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "id": "18676180", + "metadata": {}, + "outputs": [], + "source": [ + "import numpy as np\n", + "import ipywidgets as widgets\n", + "from tqdm import tqdm\n", + "import random\n", + "import matplotlib.pyplot as plt" + ] + }, + { + "cell_type": "code", + "execution_count": 254, + "id": "d69c70f2", + "metadata": {}, + "outputs": [], + "source": [ + "class MCTSNode:\n", + " def __init__(self, state, parent_node):\n", + " self.state = state\n", + " self.parent_node = parent_node\n", + " self.total_visits = 0\n", + " self.total_score = 0\n", + " self.children_nodes = []\n", + " self.player = self.check_player(state)\n", + " self.terminate_state = False\n", + " self.all_children_nodes = False\n", + "\n", + " def check_player(self, state):\n", + " if np.sum(state==1) > np.sum(state==2):\n", + " return 2\n", + " else:\n", + " return 1\n", + "\n", + "class MCTS:\n", + " def __init__(self, exploration_constant = 2):\n", + " self.exploration_constant = exploration_constant\n", + "\n", + " def is_terminal(self, board):\n", + " return not np.any(board == 0)\n", + "\n", + " def is_win(self, state, player):\n", + " col_win = (np.sum(state == player, axis=0) == 3).any()\n", + " row_win = (np.sum(state == player, axis=1) == 3).any()\n", + " diagonal_win = np.trace(state == player) == 3\n", + " opposite_diagonal = np.trace(np.fliplr(state) == player) == 3\n", + " return col_win or row_win or diagonal_win or opposite_diagonal\n", + "\n", + " def select(self, curr_node, should_explore=True):\n", + " while not is_terminal(curr_node.state) and not (self.is_win(curr_node.state, 1) or self.is_win(curr_node.state, 2)):\n", + " if curr_node.all_children_nodes:\n", + " highest_value = -float(\"inf\")\n", + " chosen_child = None\n", + "\n", + " # loop all children nodes and take the best one according to heuristic\n", + " for child in curr_node.children_nodes:\n", + " # compute UCB1 score\n", + " child_val = (child.total_score/child.total_visits) + should_explore*self.exploration_constant*np.sqrt(np.log(curr_node.total_visits)/child.total_visits)\n", + "\n", + " # if it has highest value then store it as the chosen child from this step\n", + " if child_val > highest_value:\n", + " highest_value = child_val\n", + " chosen_child = child\n", + "\n", + " # choose highest value move\n", + " return chosen_child\n", + "\n", + " else:\n", + " # if not all children nodes accessible then expand the node first\n", + " return self.expand(curr_node)\n", + "\n", + " print(\"should never come here\")\n", + "\n", + " def expand(self, curr_node):\n", + " states = self.generate_next_states(curr_node)\n", + "\n", + " for state in states:\n", + " # unroll children states, and ensure we do not expand to a state we have \n", + " # already expanded to in a previous iteration\n", + " if str(state) not in [str(b.state) for b in curr_node.children_nodes]:\n", + " child_node = MCTSNode(state, curr_node)\n", + " curr_node.children_nodes.append(child_node)\n", + " \n", + " # if the num children nodes equal the amount of possible next states\n", + " # we have explored all child nodes for this state\n", + " if len(states) == len(curr_node.children_nodes):\n", + " curr_node.all_children_nodes = True\n", + "\n", + " return child_node\n", + "\n", + "\n", + " def simulate(self, curr_node, computer_playing):\n", + " opponent = 1 if computer_playing == 2 else 1\n", + " \n", + " while not is_terminal(curr_node.state) and not (self.is_win(curr_node.state, 1) or self.is_win(curr_node.state, 2)):\n", + " next_states = self.generate_next_states(curr_node)\n", + " curr_node = MCTSNode(next_states[random.randint(0, len(next_states) - 1)], curr_node)\n", + " \n", + " if self.is_win(curr_node.state, player=computer_playing):\n", + " return 1\n", + " elif self.is_win(curr_node.state, player=opponent):\n", + " return -1\n", + " else:\n", + " return 0\n", + "\n", + " \n", + " def backpropagate(self, node, score):\n", + " while node:\n", + " node.total_visits += 1\n", + " node.total_score += score\n", + " node = node.parent_node\n", + " \n", + " def generate_next_states(self, curr_node):\n", + " player = curr_node.player\n", + " curr_state = curr_node.state\n", + " next_states = []\n", + " for i in range(3):\n", + " for j in range(3):\n", + " if curr_state[i,j] == 0:\n", + " to_append = np.copy(curr_state)\n", + " to_append[i,j] = player\n", + " next_states.append(to_append)\n", + " return next_states\n", + "\n", + "\n", + " def get_move(self, root, num_iterations=1000):\n", + " for it in range(num_iterations):\n", + " curr_node = self.select(root)\n", + " obtained_value = self.simulate(curr_node, root.player)\n", + " self.backpropagate(curr_node, obtained_value)\n", + " \n", + " chosen_move = self.select(root, should_explore=False)\n", + " return chosen_move" + ] + }, + { + "cell_type": "code", + "execution_count": 263, + "id": "36e39228", + "metadata": { + "scrolled": true + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Row and column to place with ,1,1\n", + "[[0. 0. 0.]\n", + " [0. 1. 0.]\n", + " [0. 0. 2.]]\n", + "Row and column to place with ,0,0\n", + "[[1. 0. 0.]\n", + " [0. 1. 0.]\n", + " [2. 0. 2.]]\n", + "Row and column to place with ,2,1\n", + "[[1. 2. 0.]\n", + " [0. 1. 0.]\n", + " [2. 1. 2.]]\n", + "Row and column to place with ,1,2\n", + "[[1. 2. 0.]\n", + " [2. 1. 1.]\n", + " [2. 1. 2.]]\n", + "Row and column to place with ,0,2\n", + "should never come here\n" + ] + }, + { + "ename": "AttributeError", + "evalue": "'NoneType' object has no attribute 'state'", + "output_type": "error", + "traceback": [ + "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[1;31mAttributeError\u001b[0m Traceback (most recent call last)", + "\u001b[1;32m~\\AppData\\Local\\Temp/ipykernel_15720/2518229713.py\u001b[0m in \u001b[0;36m\u001b[1;34m\u001b[0m\n\u001b[0;32m 9\u001b[0m \u001b[0mnext_node\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mMCTSNode\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mstate\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mroot\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 10\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 11\u001b[1;33m \u001b[0mroot\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mmc\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mget_move\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mnext_node\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 12\u001b[0m \u001b[0mprint\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mroot\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mstate\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 13\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n", + "\u001b[1;32m~\\AppData\\Local\\Temp/ipykernel_15720/416212796.py\u001b[0m in \u001b[0;36mget_move\u001b[1;34m(self, root, num_iterations)\u001b[0m\n\u001b[0;32m 110\u001b[0m \u001b[1;32mfor\u001b[0m \u001b[0mit\u001b[0m \u001b[1;32min\u001b[0m \u001b[0mrange\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mnum_iterations\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 111\u001b[0m \u001b[0mcurr_node\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mselect\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mroot\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 112\u001b[1;33m \u001b[0mobtained_value\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0msimulate\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mcurr_node\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mroot\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mplayer\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 113\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mbackpropagate\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mcurr_node\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mobtained_value\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 114\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n", + "\u001b[1;32m~\\AppData\\Local\\Temp/ipykernel_15720/416212796.py\u001b[0m in \u001b[0;36msimulate\u001b[1;34m(self, curr_node, computer_playing)\u001b[0m\n\u001b[0;32m 76\u001b[0m \u001b[0mopponent\u001b[0m \u001b[1;33m=\u001b[0m \u001b[1;36m1\u001b[0m \u001b[1;32mif\u001b[0m \u001b[0mcomputer_playing\u001b[0m \u001b[1;33m==\u001b[0m \u001b[1;36m2\u001b[0m \u001b[1;32melse\u001b[0m \u001b[1;36m1\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 77\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 78\u001b[1;33m \u001b[1;32mwhile\u001b[0m \u001b[1;32mnot\u001b[0m \u001b[0mis_terminal\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mcurr_node\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mstate\u001b[0m\u001b[1;33m)\u001b[0m \u001b[1;32mand\u001b[0m \u001b[1;32mnot\u001b[0m \u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mis_win\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mcurr_node\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mstate\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;36m1\u001b[0m\u001b[1;33m)\u001b[0m \u001b[1;32mor\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mis_win\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mcurr_node\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mstate\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;36m2\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 79\u001b[0m \u001b[0mnext_states\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mgenerate_next_states\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mcurr_node\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 80\u001b[0m \u001b[0mcurr_node\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mMCTSNode\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mnext_states\u001b[0m\u001b[1;33m[\u001b[0m\u001b[0mrandom\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mrandint\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;36m0\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mlen\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mnext_states\u001b[0m\u001b[1;33m)\u001b[0m \u001b[1;33m-\u001b[0m \u001b[1;36m1\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mcurr_node\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", + "\u001b[1;31mAttributeError\u001b[0m: 'NoneType' object has no attribute 'state'" + ] + } + ], + "source": [ + "a = np.zeros((3,3))\n", + "root = MCTSNode(a, None)\n", + "mc = MCTS()\n", + "\n", + "for i in range(9):\n", + " row_col = input(\"Row and column to place with ,\").split(\",\")\n", + " state = np.copy(root.state)\n", + " state[int(row_col[0]), int(row_col[1])] = 1\n", + " next_node = MCTSNode(state, root)\n", + " \n", + " root = mc.get_move(next_node)\n", + " print(root.state)\n", + "\n", + "print(\"Final: {root.state}\")\n", + " " + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.5" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/ML/ml_metrics/data.txt b/ML/ml_metrics/data.txt new file mode 100644 index 0000000..076203e --- /dev/null +++ b/ML/ml_metrics/data.txt @@ -0,0 +1,100 @@ +0 0.827142151760153 +0 0.6044595910412887 +0 0.7916340858282026 +0 0.16080518180592987 +0 0.611222921705038 +0 0.2555087295500818 +0 0.5681507664364468 +0 0.05990570219972058 +0 0.6644434078306367 +0 0.11293577405861703 +0 0.06152372321587048 +0 0.35250697207600584 +0 0.3226701829081975 +0 0.43339115381458776 +0 0.2280744262436838 +0 0.7219848389339433 +0 0.23527698971402375 +0 0.2850245335200196 +0 0.4107047877448165 +0 0.2008356196164621 +0 0.3711921802697385 +0 0.4234822657253734 +0 0.4876482027124213 +0 0.4234822657253734 +0 0.5750985220664769 +0 0.6734047730095499 +0 0.7355892648444824 +0 0.7137899092959652 +0 0.3873972469024071 +0 0.24042033264833723 +0 0.1663411647259707 +0 0.1663411647259707 +0 0.2850245335200196 +0 0.3683741846950643 +0 0.17375784896208155 +0 0.43636290738886574 +0 0.7219848389339433 +0 0.46745878087292836 +0 0.23527698971402375 +0 0.17202866439941822 +0 0.17786913865061538 +0 0.44335359557308707 +0 0.2768503833164947 +0 0.06891755391553003 +0 0.21414010746535972 +0 0.27120595352357546 +0 0.26328216986315905 +0 0.48056205121673834 +0 0.08848560476699129 +0 0.2555087295500818 +1 0.5681507664364468 +1 0.2850245335200196 +1 0.842216416418616 +1 0.5280820469827786 +1 0.6302728469340095 +1 0.9325162813331325 +1 0.062225621463076315 +1 0.8823445035377085 +1 0.670739773835188 +1 0.891663414209465 +1 0.6489254823470298 +1 0.5552119758821265 +1 0.7510275470993321 +1 0.23310831157247616 +1 0.2933421288888426 +1 0.6044595910412887 +1 0.6302728469340095 +1 0.9585115007613662 +1 0.9342800686704079 +1 0.3226701829081975 +1 0.7982301827889998 +1 0.22102862644325694 +1 0.9390780973389883 +1 0.5078780077620866 +1 0.7379344573081708 +1 0.8750078631067137 +1 0.4704701704107932 +1 0.44335359557308707 +1 0.5651814720676593 +1 0.8658845001112441 +1 0.897024614730928 +1 0.9712637967845552 +1 0.5651814720676593 +1 0.517987379389242 +1 0.40385540386469254 +1 0.9435470013187671 +1 0.5780506539476005 +1 0.594744923406366 +1 0.3970432858350056 +1 0.7916340858282026 +1 0.7219848389339433 +1 0.7916340858282026 +1 0.2850245335200196 +1 0.7658513560779588 +1 0.7379344573081708 +1 0.7137899092959652 +1 0.4876482027124213 +1 0.6302728469340095 +1 0.5310944974701136 +1 0.35250697207600584 \ No newline at end of file diff --git a/ML/ml_metrics/metrics.py b/ML/ml_metrics/metrics.py new file mode 100644 index 0000000..2e103b2 --- /dev/null +++ b/ML/ml_metrics/metrics.py @@ -0,0 +1,240 @@ +import numpy as np +from scipy.integrate import simpson +import matplotlib.pyplot as plt +import warnings + + +def true_positives(y_true, y_pred): + tp = 0 + for label, pred in zip(y_true, y_pred): + if pred == 1 and label == 1: + tp += 1 + return tp + + +def true_negatives(y_true, y_pred): + tn = 0 + for label, pred in zip(y_true, y_pred): + if pred == 0 and label == 0: + tn += 1 + return tn + + +def false_positives(y_true, y_pred): + fp = 0 + for label, pred in zip(y_true, y_pred): + if pred == 1 and label == 0: + fp += 1 + return fp + + +def false_negatives(y_true, y_pred): + fn = 0 + for label, pred in zip(y_true, y_pred): + if pred == 0 and label == 1: + fn += 1 + return fn + + +def binary_accuracy(y_true, y_pred): + tp = true_positives(y_true, y_pred) + tn = true_negatives(y_true, y_pred) + fp = false_positives(y_true, y_pred) + fn = false_negatives(y_true, y_pred) + return (tp + tn) / (tp + tn + fp + fn) + + +def precision(y_true, y_pred): + """ + Fraction of True Positive Elements divided by total number of positive predicted units + How I view it: Assuming we say someone has cancer: how often are we correct? + It tells us how much we can trust the model when it predicts an individual as positive. + """ + tp = true_negatives(y_true, y_pred) + fp = false_positives(y_true, y_pred) + return tp / (tp + fp) + + +def recall(y_true, y_pred): + """ + Recall meaasure the model's predictive accuracy for the positive class. + How I view it, out of all the people that has cancer: how often are + we able to detect it? + """ + tp = true_negatives(y_true, y_pred) + fn = false_negatives(y_true, y_pred) + return tp / (tp + fn) + + +def multiclass_accuracy(y_true, y_pred): + correct = 0 + total = len(y_true) + for label, pred in zip(y_true, y_pred): + correct += label == pred + return correct/total + + +def confusion_matrix(y_true, y_pred): + y_true = np.array(y_true) + y_pred = np.array(y_pred) + assert y_true.shape == y_pred.shape + unique_classes = np.unique(np.concatenate([y_true, y_pred], axis=0)).shape[0] + cm = np.zeros((unique_classes, unique_classes), dtype=np.int64) + + for label, pred in zip(y_true, y_pred): + cm[label, pred] += 1 + + return cm + + +def accuracy_cm(cm): + return np.trace(cm)/np.sum(cm) + + +def balanced_accuracy_cm(cm): + correctly_classified = np.diagonal(cm) + rows_sum = np.sum(cm, axis=1) + indices = np.nonzero(rows_sum)[0] + if rows_sum.shape[0] != indices.shape[0]: + warnings.warn("y_pred contains classes not in y_true") + accuracy_per_class = correctly_classified[indices]/(rows_sum[indices]) + return np.sum(accuracy_per_class)/accuracy_per_class.shape[0] + + +def precision_cm(cm, average="specific", class_label=1, eps=1e-12): + tp = np.diagonal(cm) + fp = np.sum(cm, axis=0) - tp + #precisions = np.diagonal(cm)/np.maximum(np.sum(cm, axis=0), 1e-12) + + if average == "none": + return tp/(tp+fp+eps) + + if average == "specific": + precisions = tp / (tp + fp + eps) + return precisions[class_label] + + if average == "micro": + # all samples equally contribute to the average, + # hence there is a distinction between highly + # and poorly populated classes + return np.sum(tp) / (np.sum(tp) + np.sum(fp) + eps) + + if average == "macro": + # all classes equally contribute to the average, + # no distinction between highly and poorly populated classes. + precisions = tp / (tp + fp + eps) + return np.sum(precisions)/precisions.shape[0] + + if average == "weighted": + pass + + +def recall_cm(cm, average="specific", class_label=1, eps=1e-12): + tp = np.diagonal(cm) + fn = np.sum(cm, axis=1) - tp + + if average == "none": + return tp / (tp + fn + eps) + + if average == "specific": + recalls = tp / (tp + fn + eps) + return recalls[class_label] + + if average == "micro": + return np.sum(tp) / (np.sum(tp) + np.sum(fn)) + + if average == "macro": + recalls = tp / (tp + fn + eps) + return np.sum(recalls)/recalls.shape[0] + + if average == "weighted": + pass + + +def f1score_cm(cm, average="specific", class_label=1): + precision = precision_cm(cm, average, class_label) + recall = recall_cm(cm, average, class_label) + return 2 * (precision*recall)/(precision+recall) + +# true positive rate <-> sensitivity <-> recall +# true negative rate <-> specificity <-> recall for neg. class +# ROC curve +# AUC from ROC +# Precision-Recall Curve +# Log Loss +# Mattheus Correlation +# Cohen Kappa score +# --> REGRESSION METRICS + + +def roc_curve(y_true, y_preds, plot_graph=True, calculate_AUC=True, threshold_step=0.01): + TPR, FPR = [], [] + + for threshold in np.arange(np.min(y_preds), np.max(y_preds), threshold_step): + predictions = (y_preds > threshold) * 1 + cm = confusion_matrix(y_true, predictions) + recalls = recall_cm(cm, average="none") + # note TPR == sensitivity == recall + tpr = recalls[1] + # note tnr == specificity (which is same as recall for the negative class) + tnr = recalls[0] + TPR.append(tpr) + FPR.append(1-tnr) + + if plot_graph: + plt.plot(FPR, TPR) + plt.xlabel("False Positive Rate") + plt.ylabel("True Positive Rate") + plt.title("ROC curve") + plt.show() + + if calculate_AUC: + print(np.abs(np.trapz(TPR, FPR))) + + +def precision_recall_curve(y_true, y_preds, plot_graph=True, calculate_AUC=True, threshold_step=0.01): + recalls, precisions = [], [] + + for threshold in np.arange(np.min(y_preds), np.max(y_preds), threshold_step): + predictions = (y_preds > threshold) * 1 + cm = confusion_matrix(y_true, predictions) + recall = recall_cm(cm, average="specific", class_label=1) + precision = precision_cm(cm, average="specific", class_label=1) + recalls.append(recall) + precisions.append(precision) + + recalls.append(0) + precisions.append(1) + + if plot_graph: + plt.plot(recalls, precisions) + plt.xlabel("Recall") + plt.ylabel("Precision") + plt.title("Precision-Recall curve") + plt.show() + + if calculate_AUC: + print(np.abs(np.trapz(precisions, recalls))) + + +y = [] +probs = [] +with open("data.txt") as f: + for line in f.readlines(): + label, pred = line.split() + label = int(label) + pred = float(pred) + y.append(label) + probs.append(pred) + +precision_recall_curve(y, probs, threshold_step=0.001) +#from sklearn.metrics import precision_recall_curve +#precisions, recalls, _ = precision_recall_curve(y, probs) +#plt.plot(recalls, precisions) +#plt.xlabel("Recall") +#plt.ylabel("Precision") +#plt.title("Precision-Recall curve") +#plt.show() +#print(np.abs(np.trapz(precisions, recalls))) + +