pytorch 计算混淆矩阵
创始人
2025-05-28 15:17:22
0

混淆矩阵是评估模型结果的一种指标 用来判断分类模型的好坏

 预测对了 为对角线 

还可以通过矩阵的上下角发现哪些容易出错

从这个 矩阵出发 可以得到 acc != precision recall  特异度?

 

 目标检测01笔记AP mAP recall precision是什么 查全率是什么 查准率是什么 什么是准确率 什么是召回率_:)�东东要拼命的博客-CSDN博客

 acc  是对所有类别来说的

其他三个都是 对于类别来说的

下面给出源码 

import json
import osimport matplotlib.pyplot as plt
import numpy as np
import torch
from prettytable import PrettyTable
from torchvision import datasets
from torchvision.models import MobileNetV2
from torchvision.transforms import transformsclass ConfusionMatrix(object):"""注意版本问题,使用numpy来进行数值计算的"""def __init__(self, num_classes: int, labels: list):self.matrix = np.zeros((num_classes, num_classes))self.num_classes = num_classesself.labels = labelsdef update(self, preds, labels):for p, t in zip(preds, labels):self.matrix[t, p] += 1# 行代表预测标签 列表示真实标签def summary(self):# calculate accuracysum_TP = 0for i in range(self.num_classes):sum_TP += self.matrix[i, i]acc = sum_TP / np.sum(self.matrix)print("acc is", acc)# precision, recall, specificitytable = PrettyTable()table.fields_names = ["", "pre", "recall", "spec"]for i in range(self.num_classes):TP = self.matrix[i, i]FP = np.sum(self.matrix[i, :]) - TPFN = np.sum(self.matrix[:, i]) - TPTN = np.sum(self.matrix) - TP - FP - FNpre = round(TP / (TP + FP), 3)    # round 保留三位小数recall = round(TP / (TP + FN), 3)spec = round(TN / (FP + FN), 3)table.add_row([self.labels[i], pre, recall, spec])print(table)def plot(self):matrix = self.matrixprint(matrix)plt.imshow(matrix, cmap=plt.cm.Blues)  # 颜色变化从白色到蓝色# 设置 x  轴坐标 labelplt.xticks(range(self.num_classes), self.labels, rotation=45)# 将原来的 x 轴的数字替换成我们想要的信息 self.num_classes  x 轴旋转45度# 设置 y  轴坐标 labelplt.yticks(range(self.num_classes), self.labels)# 显示 color bar  可以通过颜色的密度看出数值的分布plt.colorbar()plt.xlabel("true_label")plt.ylabel("Predicted_label")plt.title("ConfusionMatrix")# 在图中标注数量 概率信息thresh = matrix.max() / 2# 设定阈值来设定数值文本的颜色 开始遍历图像的时候一般是图像的左上角for x in range(self.num_classes):for y in range(self.num_classes):# 这里矩阵的行列交换,因为遍历的方向 第y行 第x列info = int(matrix[y, x])plt.text(x, y, info,verticalalignment='center',horizontalalignment='center',color="white" if info > thresh else "black")plt.tight_layout()# 图形显示更加的紧凑plt.show()if __name__ ==' __main__':device = torch.device("cuda:0" if torch.cuda.is_available()else "cpu")print(device)# 使用验证集的预处理方式data_transform = transforms.Compose([transforms.Resize(256),transforms.CenterCrop(224),transforms.ToTensor()transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])data_loot = os.path.abspath(os.path.join(os.getcwd(), "../.."))# get data root pathimage_path = data_loot + "/data_set/flower_data/"# flower data set pathvalidate_dataset = datasets.ImageFolder(root=image_path +"val",transform=data_transform)batch_size = 16validate_loader = torch.utils.data.DataLoder(validate_dataset,batch_size=batch_size,shuffle=False,num_workers=2)net = MobileNetV2(num_classes=5)#加载预训练的权重model_weight_path = "./MobileNetV2.pth"net.load_state_dict(torch.load(model_weight_path, map_location=device))net.to(device)#read class_indicttry:json_file = open('./class_indicts.json', 'r')class_indict = json.load(json_file)except Exception as e:print(e)exit(-1)labels = [label for _, label in class_indict.item()]# 通过json文件读出来的labelconfusion = ConfusionMatrix(num_classes=5, labels=labels)net.eval()# 启动验证模式# 通过上下文管理器  no_grad  来停止pytorch的变量对梯度的跟踪with torch.no_grad():for val_data in validate_loader:val_images, val_labels = val_dataoutputs = net(val_images.to(device))outputs = torch.softmax(outputs, dim=1)outputs = torch.argmax(outputs, dim=1)# 获取概率最大的元素confusion.update(outputs.numpy(), val_labels.numpy())# 预测值和标签值confusion.plot()# 绘制混淆矩阵confusion.summary()# 来打印各个指标信息

是这样的 这篇算是一个学习笔记,其中的基础图都源于我的导师

 霹雳吧啦Wz的个人空间_哔哩哔哩_bilibili

欢迎无依无靠的CV同学加入 

讲的非常好 代码其实也是导师给的 

我能做的就是读懂每一行加点注释

给不想看视频的同学留点时间

相关内容

热门资讯

【MySQL】锁 锁 文章目录锁全局锁表级锁表锁元数据锁(MDL)意向锁AUTO-INC锁...
【内网安全】 隧道搭建穿透上线... 文章目录内网穿透-Ngrok-入门-上线1、服务端配置:2、客户端连接服务端ÿ...
GCN的几种模型复现笔记 引言 本篇笔记紧接上文,主要是上一篇看写了快2w字,再去接入代码感觉有点...
数据分页展示逻辑 import java.util.Arrays;import java.util.List;impo...
Redis为什么选择单线程?R... 目录专栏导读一、Redis版本迭代二、Redis4.0之前为什么一直采用单线程?三、R...
【已解决】ERROR: Cou... 正确指令: pip install pyyaml
关于测试,我发现了哪些新大陆 关于测试 平常也只是听说过一些关于测试的术语,但并没有使用过测试工具。偶然看到编程老师...
Lock 接口解读 前置知识点Synchronized synchronized 是 Java 中的关键字,...
Win7 专业版安装中文包、汉... 参考资料:http://www.metsky.com/archives/350.htm...
3 ROS1通讯编程提高(1) 3 ROS1通讯编程提高3.1 使用VS Code编译ROS13.1.1 VS Code的安装和配置...
大模型未来趋势 大模型是人工智能领域的重要发展趋势之一,未来有着广阔的应用前景和发展空间。以下是大模型未来的趋势和展...
python实战应用讲解-【n... 目录 如何在Python中计算残余的平方和 方法1:使用其Base公式 方法2:使用statsmod...
学习u-boot 需要了解的m... 一、常用函数 1. origin 函数 origin 函数的返回值就是变量来源。使用格式如下...
常用python爬虫库介绍与简... 通用 urllib -网络库(stdlib)。 requests -网络库。 grab – 网络库&...
药品批准文号查询|药融云-中国... 药品批文是国家食品药品监督管理局(NMPA)对药品的审评和批准的证明文件...
【2023-03-22】SRS... 【2023-03-22】SRS推流搭配FFmpeg实现目标检测 说明: 外侧测试使用SRS播放器测...
有限元三角形单元的等效节点力 文章目录前言一、重新复习一下有限元三角形单元的理论1、三角形单元的形函数(Nÿ...
初级算法-哈希表 主要记录算法和数据结构学习笔记,新的一年更上一层楼! 初级算法-哈希表...
进程间通信【Linux】 1. 进程间通信 1.1 什么是进程间通信 在 Linux 系统中,进程间通信...
【Docker】P3 Dock... Docker数据卷、宿主机与挂载数据卷的概念及作用挂载宿主机配置数据卷挂载操作示例一个容器挂载多个目...