python绘制二分类混淆矩阵,混淆矩阵简单例子

  python绘制二分类混淆矩阵,混淆矩阵简单例子

  从python的基础到现在的入门,每个人都必须有一定的python基础。今天边肖给大家带来一个关于python的高级内容,——绘制混淆矩阵。来看看吧~

  介绍:

  混淆矩阵通过指示正确/不正确标签的计数来指示表格格式的模型的准确性。

  计算/绘制混淆矩阵:

  下面是计算混淆矩阵的过程。

  您需要一个包含预期结果值的测试数据集或验证数据集。

  对测试数据集中的每一行进行预测。

  从预期结果和预测中计数:

  每一类的正确预测数。

  每个类的错误预测数,由预测的类组织。

  然后将这些数字组织成表格或矩阵,如下所示:

  预期值:矩阵的每一行都对应一个预测类。

  跨顶预测:矩阵的每一列对应一个实际的类。

  -justify:inter-ideograph">然后将正确和不正确分类的计数填入表格中。

  

Reading混淆矩阵:

  

一个类的正确预测的总数进入该类值的预期行,以及该类值的预测列。

  

以同样的方式,一个类别的不正确预测总数进入该类别值的预期行,以及该类别值的预测列。

  

对角元素表示预测标签等于真实标签的点的数量,而非对角线元素是分类器错误标记的元素。混淆矩阵的对角线值越高越好,表明许多正确的预测。

  

Python绘制混淆矩阵 :

  

importitertools

  

  importnumpyasnp

  

  importmatplotlib.pyplotasplt

  

  fromsklearnimportsvm,datasets

  

  fromsklearn.model_selectionimporttrain_test_split

  

  fromsklearn.metricsimportconfusion_matrix

  

  #importsomedatatoplaywith

  

  iris=datasets.load_iris()

  

  X=iris.data

  

  y=iris.target

  

  class_names=iris.target_names

  

  #Splitthedataintoatrainingsetandatestset

  

  X_train,X_test,y_train,y_test=train_test_split(X,y,random_state=0)

  

  #Runclassifier,usingamodelthatistooregularized(Ctoolow)tosee

  

  #theimpactontheresults

  

  classifier=svm.SVC(kernel='linear',C=0.01)

  

  y_pred=classifier.fit(X_train,y_train).predict(X_test)

  

  defplot_confusion_matrix(cm,classes,

  

  normalize=False,

  

  title='Confusionmatrix',

  

  cmap=plt.cm.Blues):

  

  """

  

  Thisfunctionprintsandplotstheconfusionmatrix.

  

  Normalizationcanbeappliedbysetting`normalize=True`.

  

  """

  

  ifnormalize:

  

  cm=cm.astype('float')/cm.sum(axis=1)[:,np.newaxis]

  

  print("Normalizedconfusionmatrix")

  

  else:

  

  print('Confusionmatrix,withoutnormalization')

  

  print(cm)

  

  plt.imshow(cm,interpolation='nearest',cmap=cmap)

  

  plt.title(title)

  

  plt.colorbar()

  

  tick_marks=np.arange(len(classes))

  

  plt.xticks(tick_marks,classes,rotation=45)

  

  plt.yticks(tick_marks,classes)

  

  fmt='.2f'ifnormalizeelse'd'

  

  thresh=cm.max()/2.

  

  fori,jinitertools.product(range(cm.shape[0]),range(cm.shape[1])):

  

  plt.text(j,i,format(cm[i,j],fmt),

  

  horizontalalignment="center",

  

  color="white"ifcm[i,j]>threshelse"black")

  

  color="white"ifcm[i,j]>threshelse"black")

  

  plt.tight_layout()

  

  plt.ylabel('Truelabel')

  

  plt.xlabel('Predictedlabel')

  

  #Computeconfusionmatrix

  

  cnf_matrix=confusion_matrix(y_test,y_pred)

  

  np.set_printoptions(precision=2)

  

  #Plotnon-normalizedconfusionmatrix

  

  plt.figure()

  

  plot_confusion_matrix(cnf_matrix,classes=class_names,

  

  title='Confusionmatrix,withoutnormalization')

  

  #Plotnormalizedconfusionmatrix

  

  plt.figure()

  

  plot_confusion_matrix(cnf_matrix,classes=class_names,normalize=True,

  

  title='Normalizedconfusionmatrix')

  

  plt.show()

  

  

Confusionmatrix,withoutnormalization

  

  [[1300]

  

  [0106]

  

  [009]]

  

  Normalizedconfusionmatrix

  

  [[1.0.0.]

  

  [0.0.620.38]

  

  [0.0.1.]]

好了,大家可以消化学习下哦~如需了解更多python实用知识,点击进入PyThon学习网教学中心

郑重声明:本文由网友发布,不代表盛行IT的观点,版权归原作者所有,仅为传播更多信息之目的,如有侵权请联系,我们将第一时间修改或删除,多谢。

留言与评论(共有 条评论)
   
验证码: