【学习打卡07】 可解释机器学习笔记之Shape+Lime代码实战
创始人
2024-04-30 22:44:01
0

可解释机器学习笔记之Shape+Lime代码实战

文章目录

  • 可解释机器学习笔记之Shape+Lime代码实战
    • 基于Shapley值的可解释性分析
      • 使用Pytorch对MNIST分类可解释性分析
      • 使用shap的Deep Explainer进行可视化
      • 使用Pytorch对预训练ImageNet图像分类可解释性分析
        • 指定单个预测类别
        • 指定多个预测类别
        • 前k个预测类别
    • LIME代码实战
      • 对葡萄酒数据集二分类并进行LIME可解释性分析
        • 使用随机森林模型对葡萄酒数据集二分类
        • LIME可解释性分析
      • 对ImageNet预训练图像分类并进行LIME可解释性分析
        • LIME可解释性分析
    • 总结

首先非常感谢同济子豪兄拍摄的可解释机器学习公开课,并且免费分享,这门课程,包含人工智能可解释性、显著性分析领域的导论、算法综述、经典论文精读、代码实战、前沿讲座。由B站知名人工智能科普UP主“同济子豪兄”主讲。 课程主页: https://github.com/TommyZihao/zihao_course/blob/main/XAI 一起打开AI的黑盒子,洞悉AI的脑回路和注意力,解释它、了解它、改进它,进而信赖它。知其然,也知其所以然。这里给出链接,倡导大家一起学习, 别忘了给子豪兄点个关注哦。

学习GitHub 内容链接:
https://github.com/TommyZihao/zihao_course/tree/main/XAI

B站视频合集链接:
https://space.bilibili.com/1900783/channel/collectiondetail?sid=713364

基于Shapley值的可解释性分析

SHAP 属于模型事后解释的方法,它的核心思想是计算特征对模型输出的边际贡献,再从全局和局部两个层面对“黑盒模型”进行解释。SHAP构建一个加性的解释模型,所有的特征都视为“贡献者”。对于每个预测样本,模型都产生一个预测值,SHAP value就是该样本中每个特征所分配到的数值。基本思想:计算一个特征加入到模型时的边际贡献,然后考虑到该特征在所有的特征序列的情况下不同的边际贡献,取均值,即某该特征的SHAPbaseline value
SHAP(SHapley Additive exPlanation)是Python开发的一个"模型解释"包,可以解释任何机器学习模型的输出。

import torch
import torchvision
from torchvision import datasets, transforms, models
from torch import nn, optim
from torch.nn import functional as F
import osimport numpy as np
import json
from PIL import Image
# 使用torch-gpu
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

shap代码实战

import shap

使用Pytorch对MNIST分类可解释性分析

用Pytorch构建简单的卷积神经网络,在MNIST手写数字数据集上,使用shap的Deep Explainer进行可解释性分析,并可视化每一张图像的每一个像素,对模型预测为每一个类别的影响。

# 构建卷积神经网络
class Net(nn.Module):def __init__(self):super(Net, self).__init__()self.conv_layers = nn.Sequential(nn.Conv2d(1, 10, kernel_size=5),nn.MaxPool2d(2),nn.ReLU(),nn.Conv2d(10, 20, kernel_size=5),nn.Dropout(),nn.MaxPool2d(2),nn.ReLU(),)self.fc_layers = nn.Sequential(nn.Linear(320, 50),nn.ReLU(),nn.Dropout(),nn.Linear(50, 10),nn.Softmax(dim=1))def forward(self, x):x = self.conv_layers(x)x = x.view(-1, 320)x = self.fc_layers(x)return x
# 初始化模型
model = Net().to(device)
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.5)
# 加载MNIST数据集
train_dataset = datasets.MNIST('mnist_data', train=True, download=True,transform=transforms.Compose([transforms.ToTensor()]))test_dataset = datasets.MNIST('mnist_data', train=False, download=True,transform=transforms.Compose([transforms.ToTensor()]))
# 设置dataloader
batch_size = 256
train_loader = torch.utils.data.DataLoader(train_dataset,batch_size=batch_size, shuffle=True)test_loader = torch.utils.data.DataLoader(test_dataset,batch_size=batch_size, shuffle=True)
def train(model, device, train_loader, optimizer, epoch):# 训练一个 epochmodel.train()for batch_idx, (data, target) in enumerate(train_loader):data, target = data.to(device), target.to(device)optimizer.zero_grad()output = model(data)loss = F.nll_loss(output.log(), target).to(device)loss.backward()optimizer.step()if batch_idx % 100 == 0:print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(epoch, batch_idx * len(data), len(train_loader.dataset),100. * batch_idx / len(train_loader), loss.item()))def test(model, device, test_loader):# 测试一个 epochmodel.eval()test_loss = 0correct = 0with torch.no_grad():for data, target in test_loader:data, target = data.to(device), target.to(device)output = model(data)test_loss += F.nll_loss(output.log(), target).item() # sum up batch losspred = output.max(1, keepdim=True)[1] # get the index of the max log-probabilitycorrect += pred.eq(target.view_as(pred)).sum().item()test_loss /= len(test_loader.dataset)print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(test_loss, correct, len(test_loader.dataset),100. * correct / len(test_loader.dataset)))
num_epochs = 5for epoch in range(1, num_epochs + 1):train(model, device, train_loader, optimizer, epoch)test(model, device, test_loader)
Train Epoch: 1 [0/60000 (0%)]	Loss: 2.297472
Train Epoch: 1 [25600/60000 (43%)]	Loss: 2.202407
Train Epoch: 1 [51200/60000 (85%)]	Loss: 1.399053Test set: Average loss: 0.0050, Accuracy: 7855/10000 (79%)Train Epoch: 2 [0/60000 (0%)]	Loss: 1.234514
Train Epoch: 2 [25600/60000 (43%)]	Loss: 0.933571
Train Epoch: 2 [51200/60000 (85%)]	Loss: 0.774069Test set: Average loss: 0.0025, Accuracy: 8880/10000 (89%)Train Epoch: 3 [0/60000 (0%)]	Loss: 0.748982
Train Epoch: 3 [25600/60000 (43%)]	Loss: 0.621569
Train Epoch: 3 [51200/60000 (85%)]	Loss: 0.535523Test set: Average loss: 0.0017, Accuracy: 9151/10000 (92%)Train Epoch: 4 [0/60000 (0%)]	Loss: 0.569322
Train Epoch: 4 [25600/60000 (43%)]	Loss: 0.596375
Train Epoch: 4 [51200/60000 (85%)]	Loss: 0.552551Test set: Average loss: 0.0014, Accuracy: 9330/10000 (93%)Train Epoch: 5 [0/60000 (0%)]	Loss: 0.447947
Train Epoch: 5 [25600/60000 (43%)]	Loss: 0.550949
Train Epoch: 5 [51200/60000 (85%)]	Loss: 0.531695Test set: Average loss: 0.0012, Accuracy: 9410/10000 (94%)

使用shap的Deep Explainer进行可视化

images, labels = next(iter(test_loader))
# 背景图像样本
background = images[:250]
background.shape
torch.Size([250, 1, 28, 28])
# 测试图像样本
test_images = images[250:254]
test_images.shape
torch.Size([4, 1, 28, 28])
# 初始化Deep Explainer
background = background.to(device)e = shap.DeepExplainer(model, background)
# 计算每个类别、每张测试图像、每个像素,对应的 shap 值
shap_values = e.shap_values(test_images)
Using a non-full backward hook when the forward contains multiple autograd Nodes is deprecated and will be removed in future versions. This hook will be missing some grad_input. Please use register_full_backward_hook to get the documented behavior.
# shap 值
shap_numpy = [np.swapaxes(np.swapaxes(s, 1, -1), 1, 2) for s in shap_values]# 测试图像
test_numpy = np.swapaxes(np.swapaxes(test_images.numpy(), 1, -1), 1, 2)
shap.image_plot(shap_numpy, -test_numpy)

png

红色代表 shap 正值:对模型预测为该类别有正向作用

蓝色代表 shap 负值:对模型预测为该类别有负向作用

使用Pytorch对预训练ImageNet图像分类可解释性分析

# 载入ImageNet预训练图像分类模型
model = torchvision.models.mobilenet_v2(weights=models.MobileNet_V2_Weights.DEFAULT, progress=False).eval().to(device)
with open('./data/imagenet_class_index.json') as file:class_names = [v[1] for v in json.load(file).values()]
# 测试图片
img_path = 'test_img/cat_dog.jpg'img_pil = Image.open(img_path)
X = torch.Tensor(np.array(img_pil)).unsqueeze(0)
# 预处理
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]def nhwc_to_nchw(x: torch.Tensor) -> torch.Tensor:if x.dim() == 4:x = x if x.shape[1] == 3 else x.permute(0, 3, 1, 2)elif x.dim() == 3:x = x if x.shape[0] == 3 else x.permute(2, 0, 1)return xdef nchw_to_nhwc(x: torch.Tensor) -> torch.Tensor:if x.dim() == 4:x = x if x.shape[3] == 3 else x.permute(0, 2, 3, 1)elif x.dim() == 3:x = x if x.shape[2] == 3 else x.permute(1, 2, 0)return x transform= [transforms.Lambda(nhwc_to_nchw),transforms.Resize(224),transforms.Lambda(lambda x: x*(1/255)),transforms.Normalize(mean=mean, std=std),transforms.Lambda(nchw_to_nhwc),
]inv_transform= [transforms.Lambda(nhwc_to_nchw),transforms.Normalize(mean = (-1 * np.array(mean) / np.array(std)).tolist(),std = (1 / np.array(std)).tolist()),transforms.Lambda(nchw_to_nhwc),
]transform = torchvision.transforms.Compose(transform)
inv_transform = torchvision.transforms.Compose(inv_transform)
# 构建模型预测函数
def predict(img: np.ndarray) -> torch.Tensor:img = nhwc_to_nchw(torch.Tensor(img)).to(device)output = model(img)return outputdef predict(img):img = nhwc_to_nchw(torch.Tensor(img)).to(device)output = model(img)return output
Xtr = transform(X)
out = predict(Xtr[0:1])
classes = torch.argmax(out, axis=1).detach().cpu().numpy()
print(f'Classes: {classes}: {np.array(class_names)[classes]}')
Classes: [239]: ['Bernese_mountain_dog']
# 构造输入图像
input_img = Xtr[0].unsqueeze(0)
batch_size = 50n_evals = 5000 # 迭代次数越大,显著性分析粒度越精细,计算消耗时间越长# 定义 mask,遮盖输入图像上的局部区域
masker_blur = shap.maskers.Image("blur(64, 64)", Xtr[0].shape)# 创建可解释分析算法
explainer = shap.Explainer(predict, masker_blur, output_names=class_names)

指定单个预测类别

# 281:虎斑猫 tabby
shap_values = explainer(input_img, max_evals=n_evals, batch_size=batch_size, outputs=[281])
  0%|          | 0/4998 [00:00
# 整理张量维度
shap_values.data = inv_transform(shap_values.data).cpu().numpy()[0] # 原图
shap_values.values = [val for val in np.moveaxis(shap_values.values[0],-1, 0)] # shap值热力图
# 可视化
shap.image_plot(shap_values=shap_values.values,pixel_values=shap_values.data,labels=shap_values.output_names)
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).

png

指定多个预测类别

# 232 边牧犬 border collie
# 281:虎斑猫 tabby
# 852 网球 tennis ball
# 288 豹子 leopard
shap_values = explainer(input_img, max_evals=n_evals, batch_size=batch_size, outputs=[232, 281, 852, 288])# 整理张量维度
shap_values.data = inv_transform(shap_values.data).cpu().numpy()[0] # 原图
shap_values.values = [val for val in np.moveaxis(shap_values.values[0],-1, 0)] # shap值热力图# 可视化
shap.image_plot(shap_values=shap_values.values,pixel_values=shap_values.data,labels=shap_values.output_names)
  0%|          | 0/4998 [00:00

png

前k个预测类别

topk = 5shap_values = explainer(input_img, max_evals=n_evals, batch_size=batch_size, outputs=shap.Explanation.argsort.flip[:topk])# 整理张量维度
shap_values.data = inv_transform(shap_values.data).cpu().numpy()[0] # 原图
shap_values.values = [val for val in np.moveaxis(shap_values.values[0],-1, 0)] # 各个类别的shap值热力图# 可视化
shap.image_plot(shap_values=shap_values.values,pixel_values=shap_values.data,labels=shap_values.output_names)
  0%|          | 0/4998 [00:00

png

LIME代码实战

对葡萄酒数据集二分类并进行LIME可解释性分析

使用随机森林模型对葡萄酒数据集二分类

import numpy as np
import pandas as pdimport lime
from lime import lime_tabular
# 加载数据集
df = pd.read_csv('./data/wine.csv')
df.head()
fixed acidityvolatile aciditycitric acidresidual sugarchloridesfree sulfur dioxidetotal sulfur dioxidedensitypHsulphatesalcoholquality
07.00.270.3620.70.04545.0170.01.00103.000.458.8bad
16.30.300.341.60.04914.0132.00.99403.300.499.5bad
28.10.280.406.90.05030.097.00.99513.260.4410.1bad
37.20.230.328.50.05847.0186.00.99563.190.409.9bad
47.20.230.328.50.05847.0186.00.99563.190.409.9bad
from sklearn.model_selection import train_test_splitX = df.drop('quality', axis=1)
y = df['quality']# 划分数据集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
from sklearn.ensemble import RandomForestClassifier# 使用随机森林模型训练
model = RandomForestClassifier(random_state=42)
model.fit(X_train, y_train)
RandomForestClassifier(random_state=42)
score = model.score(X_test, y_test)
score
0.8887755102040816

LIME可解释性分析

# 初始化LIME可解释性分析算法
explainer = lime_tabular.LimeTabularExplainer(training_data=np.array(X_train), # 训练集特征,必须是 numpy 的 Arrayfeature_names=X_train.columns, # 特征列名class_names=['bad', 'good'], # 预测类别名称mode='classification' # 分类模式
)
# 从测试集中选取一个样本,输入训练好的模型中预测,查看预测结果
idx = 3data_test = np.array(X_test.iloc[idx]).reshape(1, -1)
prediction = model.predict(data_test)[0]
y_true = np.array(y_test)[idx]
print('测试集中的 {} 号样本, 模型预测为 {}, 真实类别为 {}'.format(idx, prediction, y_true))
测试集中的 3 号样本, 模型预测为 bad, 真实类别为 bad
# 可解释性分析
exp = explainer.explain_instance(data_row=X_test.iloc[idx], predict_fn=model.predict_proba
)
exp.show_in_notebook(show_table=True)

对ImageNet预训练图像分类并进行LIME可解释性分析

img_path = './test_img/cat_dog.jpg'img_pil = Image.open(img_path)
img_pil

[外链图片转存中…(img-jYx6rMWl-1671980645822)]

# 加载模型
model = models.inception_v3(weights=models.Inception_V3_Weights.DEFAULT).eval().to(device)
# 载入ImageNet-1000类别
idx2label, cls2label, cls2idx = [], {}, {}
with open('./data/imagenet_class_index.json', 'r') as read_file:class_idx = json.load(read_file)idx2label = [class_idx[str(k)][1] for k in range(len(class_idx))]cls2label = {class_idx[str(k)][0]: class_idx[str(k)][1] for k in range(len(class_idx))}cls2idx = {class_idx[str(k)][0]: k for k in range(len(class_idx))}    
# 预处理
trans_norm = transforms.Normalize(mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225])trans_A = transforms.Compose([transforms.Resize((256, 256)),transforms.CenterCrop(224),transforms.ToTensor(),trans_norm])trans_B = transforms.Compose([transforms.ToTensor(),trans_norm])trans_C = transforms.Compose([transforms.Resize((256, 256)),transforms.CenterCrop(224)
])
# 进行图像分类
input_tensor = trans_A(img_pil).unsqueeze(0).to(device)
pred_logits = model(input_tensor)
pred_softmax = F.softmax(pred_logits, dim=1)
top_n = pred_softmax.topk(5)
# 定义分类预测函数
def batch_predict(images):batch = torch.stack(tuple(trans_B(i) for i in images), dim=0)batch = batch.to(device)logits = model(batch)probs = F.softmax(logits, dim=1)return probs.detach().cpu().numpy()
test_pred = batch_predict([trans_C(img_pil)])
test_pred.squeeze().argmax()
231

LIME可解释性分析

from lime import lime_image
explainer = lime_image.LimeImageExplainer()
explanation = explainer.explain_instance(np.array(trans_C(img_pil)), batch_predict, # 分类预测函数top_labels=5, hide_color=0, num_samples=8000) # LIME生成的邻域图像个数
  0%|          | 0/8000 [00:00
explanation.top_labels[0]
231
from skimage.segmentation import mark_boundaries
import matplotlib.pyplot as plt
temp, mask = explanation.get_image_and_mask(explanation.top_labels[0], positive_only=False, num_features=20, hide_rest=False)
img_boundry = mark_boundaries(temp/255.0, mask)
plt.imshow(img_boundry)
plt.show()

png

temp, mask = explanation.get_image_and_mask(281, positive_only=False, num_features=20, hide_rest=False)
img_boundry = mark_boundaries(temp/255.0, mask)
plt.imshow(img_boundry)
plt.show()

png

绿色表示该区域对当前类别影响为正,红色表示该区域对当前类别影响为负

总结

在这次任务中,主要学习到了Shap和Lime工具包的使用,在图像分类的基础上去解释他,知其然还要知其所以然。使用CAM和Captum工具包,可以减少我们很多很多的代码量,并且能快速使用,快速应用在自己的任务中、

在经过一个多星期的学习,也是需要这种代码实战告诉我们,这些应用是全面且方方面面的,这样就不会空读理论,这样可以让我们有机会将理论和实践结合起来,希望后续能够将XAI和Lime运用到我的领域中,学习到更多的知识。

相关内容

热门资讯

怎么破解安卓车载系统,破解之道... 如何破解安卓车载系统:一场技术冒险之旅在当今数字化时代,汽车已经不仅仅是一种交通工具,它更是一个集成...
安卓系统桌面制作软件,打造个性... 你有没有想过,你的安卓手机桌面是不是也能变得像杂志封面一样炫酷呢?没错,今天就要来聊聊这个话题——安...
安卓官服什么系统最好,探寻最佳... 你有没有想过,你的安卓官服到底该用哪个系统呢?这可是个让人头疼的问题,毕竟每个系统都有它的特色和优缺...
安卓系统怎么安定位,步骤详解与... 你有没有想过,为什么你的手机总是能精准地告诉你附近有什么好吃的、好玩的地方呢?这都要归功于安卓系统的...
华为参与开发安卓系统,共筑智能... 你知道吗?最近有个大新闻,那就是华为竟然参与了安卓系统的开发!是不是觉得有点不可思议?别急,让我带你...
安卓新系统好还是旧系统,安卓新... 你有没有发现,每次安卓系统更新,朋友圈里就炸开了锅?有人欢呼雀跃,有人愁眉苦脸。那么,安卓新系统真的...
安卓系统主要界面元素,探索主要... 你有没有发现,每次打开安卓手机,那熟悉的界面总是让人眼前一亮?今天,就让我带你一起探索安卓系统那些让...
安卓平板7.0系统好吗,智能生... 你有没有想过,拥有一台运行着最新安卓7.0系统的平板电脑,会是怎样的体验呢?想象手指轻轻滑过屏幕,流...
安卓手机换联想系统,深度体验联... 你有没有想过,你的安卓手机换上联想系统后,会发生哪些奇妙的变化呢?想象原本熟悉的界面突然焕然一新,是...
刷安卓系统的工具,轻松实现系统... 你有没有想过,你的安卓手机是不是也能像电脑一样,装上各种有趣的系统呢?没错,今天就要来聊聊这个神奇的...
机械革命安卓系统设置,个性化定... 机械革命安卓系统设置全解析在当今这个数字化时代,智能手机已经成为我们生活中不可或缺的一部分。它不仅仅...
安卓监管系统有哪些,技术手段与... 你知道吗?随着智能手机的普及,安卓系统已经成为了全球最受欢迎的操作系统之一。但是,你知道吗?为了让这...
安卓系统更新知乎,畅享智能生活... 你有没有发现,你的安卓手机最近是不是总在提醒你更新系统呢?别急,别急,今天就来给你好好聊聊这个话题。...
安卓手机系统铃声目录,个性化音... 你有没有发现,每次拿起安卓手机,那熟悉的铃声总是能瞬间唤醒你的注意力?今天,就让我带你一起探索一下安...
安卓系统修改开机画面,安卓系统... 亲爱的手机控们,你是否厌倦了每次开机时看到的那张千篇一律的开机画面?想要给你的安卓手机来点新鲜感?那...
安卓系统隐私密码,守护个人隐私... 你有没有想过,你的安卓手机里藏着多少秘密?那些聊天记录、照片、支付信息,全都在那里静静地躺着,等着被...
8848是安卓什么系统,搭载安... 你有没有想过,你的手机里那个高大上的8848手机,它到底是用的是什么操作系统呢?别急,今天就来给你揭...
安卓刷windowsxp系统下... 你有没有想过,让你的安卓手机瞬间变身成一台Windows XP电脑呢?没错,就是那个经典的操作系统!...
插画安卓系统推荐哪个,插画风格... 你有没有想过,手机里的插画风格也能成为个性展示的一部分呢?想象你的手机界面就像是一幅精美的画作,是不...
安卓系统怎么升级cpu,解锁性... 亲爱的安卓用户们,你是否也和我一样,对手机性能的提升充满了期待?想要让你的安卓手机跑得更快,升级CP...