{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "f868e2bc-5741-48c5-ba69-33f8ccc58f4c",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "\n",
      "#演示目的:利用鸢尾花数据集画出P-R曲线\n",
      "\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# Demo: precision-recall (P-R) curves for a one-vs-rest linear SVM on the iris data set.\n",
    "# The string below is kept verbatim because print(__doc__) echoes it as this cell's\n",
    "# output; it reads: purpose of the demo - draw P-R curves using the iris data set.\n",
    "\"\"\"\n",
    "\n",
    "#演示目的:利用鸢尾花数据集画出P-R曲线\n",
    "\n",
    "\"\"\"\n",
    "print(__doc__)\n",
    "%matplotlib inline\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "from sklearn import svm, datasets\n",
    "from sklearn.metrics import precision_recall_curve\n",
    "from sklearn.metrics import average_precision_score\n",
    "from sklearn.preprocessing import label_binarize\n",
    "from sklearn.multiclass import OneVsRestClassifier\n",
    "# train_test_split is imported from sklearn.model_selection; the legacy\n",
    "# sklearn.cross_validation module no longer exists in current scikit-learn releases.\n",
    "from sklearn.model_selection import train_test_split\n",
    "\n",
    "# Tunable constants.\n",
    "NOISE_FACTOR = 200  # noise columns appended per original feature (200 * 4 = 800 for iris)\n",
    "TEST_SIZE = 0.5     # fraction of the samples held out for testing\n",
    "\n",
    "# Use the iris data set as the example for drawing P-R curves.\n",
    "iris = datasets.load_iris()\n",
    "X = iris.data\n",
    "y = iris.target\n",
    "\n",
    "# Binarize the labels: the three classes become one indicator column each\n",
    "# (the 001 / 010 / 100 format). This is a multi-class problem, and the\n",
    "# OneVsRestClassifier strategy used below turns it into binary problems.\n",
    "y = label_binarize(y, classes=[0, 1, 2])\n",
    "n_classes = y.shape[1]\n",
    "\n",
    "# Append NOISE_FACTOR * n_features columns of random normal noise features.\n",
    "random_state = np.random.RandomState(0)\n",
    "n_samples, n_features = X.shape\n",
    "X = np.c_[X, random_state.randn(n_samples, NOISE_FACTOR * n_features)]\n",
    "\n",
    "# Split into training and test sets. The seeded RandomState makes the split\n",
    "# reproducible on a fresh run (an integer seed would too); only\n",
    "# random_state=None gives a different split on every run.\n",
    "X_train, X_test, y_train, y_test = train_test_split(\n",
    "    X, y, test_size=TEST_SIZE, random_state=random_state)\n",
    "\n",
    "# Run the classifier. Only decision_function scores are used below, so SVC's\n",
    "# probability estimates stay disabled: probability=True slows fit() down, and\n",
    "# the random_state argument only seeds those estimates.\n",
    "classifier = OneVsRestClassifier(svm.SVC(kernel='linear'))\n",
    "y_score = classifier.fit(X_train, y_train).decision_function(X_test)\n",
    "\n",
    "# Compute the precision-recall curve and average precision for each class.\n",
    "# precision_recall_curve also returns the thresholds; they are not used again,\n",
    "# so they are bound to a throwaway name.\n",
    "# The last precision and recall values are 1. and 0. respectively and do not\n",
    "# have a corresponding threshold, so every curve ends on the y axis.\n",
    "precision = {}\n",
    "recall = {}\n",
    "average_precision = {}\n",
    "for i in range(n_classes):\n",
    "    # Column i holds the labels / scores of class i.\n",
    "    precision[i], recall[i], _thresholds = precision_recall_curve(y_test[:, i], y_score[:, i])\n",
    "    average_precision[i] = average_precision_score(y_test[:, i], y_score[:, i])\n",
    "\n",
    "# Compute the micro-average curve and area; ravel() flattens the\n",
    "# multi-dimensional arrays to one dimension.\n",
    "precision[\"micro\"], recall[\"micro\"], _thresholds = precision_recall_curve(y_test.ravel(), y_score.ravel())\n",
    "# This score corresponds to the area under the precision-recall curve.\n",
    "average_precision[\"micro\"] = average_precision_score(y_test, y_score, average=\"micro\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "7be86f04-83fa-4190-9f1a-2524ae38e682",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "\n",
      "text/plain": [
       ""
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Plot the micro-averaged curve and one precision-recall curve per class.\n",
    "# A fresh figure/axes pair is created explicitly, so re-running this cell\n",
    "# never draws on top of an earlier figure.\n",
    "fig, ax = plt.subplots()\n",
    "ax.plot(recall[\"micro\"], precision[\"micro\"],\n",
    "        label='micro-average Precision-recall curve (area = {0:0.2f})'.format(average_precision[\"micro\"]))\n",
    "for i in range(n_classes):\n",
    "    ax.plot(recall[i], precision[i],\n",
    "            label='Precision-recall curve of class {0} (area = {1:0.2f})'.format(i, average_precision[i]))\n",
    "\n",
    "# Set the displayed range of the X and Y axes.\n",
    "ax.set_xlim([0.0, 1.0])\n",
    "ax.set_ylim([0.0, 1.05])\n",
    "ax.set_xlabel('Recall', fontsize=16)\n",
    "ax.set_ylabel('Precision', fontsize=16)\n",
    "ax.set_title('Extension of Precision-Recall curve to multi-class', fontsize=16)\n",
    "ax.legend(loc=\"lower right\")  # place the legend\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fbdcbe72-182a-453f-92ec-dabc6b456573",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "XGBoost-Sklearn",
   "language": "python",
   "name": "xgboost-sklearn"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}