python绘制二分类混淆矩阵,混淆矩阵简单例子
从python的基础到现在的入门,每个人都必须有一定的python基础。今天边肖给大家带来一个关于python的高级内容,——绘制混淆矩阵。来看看吧~
介绍:
混淆矩阵通过指示正确/不正确标签的计数来指示表格格式的模型的准确性。
计算/绘制混淆矩阵:
下面是计算混淆矩阵的过程。
您需要一个包含预期结果值的测试数据集或验证数据集。
对测试数据集中的每一行进行预测。
从预期结果和预测中计数:
每一类的正确预测数。
每个类的错误预测数,由预测的类组织。
然后将这些数字组织成表格或矩阵,如下所示:
预期值:矩阵的每一行都对应一个预测类。
跨顶预测:矩阵的每一列对应一个实际的类。
-justify:inter-ideograph">然后将正确和不正确分类的计数填入表格中。
Reading混淆矩阵:
一个类的正确预测的总数进入该类值的预期行,以及该类值的预测列。
以同样的方式,一个类别的不正确预测总数进入该类别值的预期行,以及该类别值的预测列。
对角元素表示预测标签等于真实标签的点的数量,而非对角线元素是分类器错误标记的元素。混淆矩阵的对角线值越高越好,表明许多正确的预测。
用Python绘制混淆矩阵 :
importitertools
importnumpyasnp
importmatplotlib.pyplotasplt
fromsklearnimportsvm,datasets
fromsklearn.model_selectionimporttrain_test_split
fromsklearn.metricsimportconfusion_matrix
#importsomedatatoplaywith
iris=datasets.load_iris()
X=iris.data
y=iris.target
class_names=iris.target_names
#Splitthedataintoatrainingsetandatestset
X_train,X_test,y_train,y_test=train_test_split(X,y,random_state=0)
#Runclassifier,usingamodelthatistooregularized(Ctoolow)tosee
#theimpactontheresults
classifier=svm.SVC(kernel='linear',C=0.01)
y_pred=classifier.fit(X_train,y_train).predict(X_test)
defplot_confusion_matrix(cm,classes,
normalize=False,
title='Confusionmatrix',
cmap=plt.cm.Blues):
"""
Thisfunctionprintsandplotstheconfusionmatrix.
Normalizationcanbeappliedbysetting`normalize=True`.
"""
ifnormalize:
cm=cm.astype('float')/cm.sum(axis=1)[:,np.newaxis]
print("Normalizedconfusionmatrix")
else:
print('Confusionmatrix,withoutnormalization')
print(cm)
plt.imshow(cm,interpolation='nearest',cmap=cmap)
plt.title(title)
plt.colorbar()
tick_marks=np.arange(len(classes))
plt.xticks(tick_marks,classes,rotation=45)
plt.yticks(tick_marks,classes)
fmt='.2f'ifnormalizeelse'd'
thresh=cm.max()/2.
fori,jinitertools.product(range(cm.shape[0]),range(cm.shape[1])):
plt.text(j,i,format(cm[i,j],fmt),
horizontalalignment="center",
color="white"ifcm[i,j]>threshelse"black")
color="white"ifcm[i,j]>threshelse"black")
plt.tight_layout()
plt.ylabel('Truelabel')
plt.xlabel('Predictedlabel')
#Computeconfusionmatrix
cnf_matrix=confusion_matrix(y_test,y_pred)
np.set_printoptions(precision=2)
#Plotnon-normalizedconfusionmatrix
plt.figure()
plot_confusion_matrix(cnf_matrix,classes=class_names,
title='Confusionmatrix,withoutnormalization')
#Plotnormalizedconfusionmatrix
plt.figure()
plot_confusion_matrix(cnf_matrix,classes=class_names,normalize=True,
title='Normalizedconfusionmatrix')
plt.show()
Confusionmatrix,withoutnormalization
[[1300]
[0106]
[009]]
Normalizedconfusionmatrix
[[1.0.0.]
[0.0.620.38]
[0.0.1.]]
好了,大家可以消化学习下哦~如需了解更多python实用知识,点击进入PyThon学习网教学中心。
郑重声明:本文由网友发布,不代表盛行IT的观点,版权归原作者所有,仅为传播更多信息之目的,如有侵权请联系,我们将第一时间修改或删除,多谢。