Bert的pooler_output是什么?
admin
2024-02-05 14:12:40
0

BERT的两个输出

在学习bert的时候,我们知道bert是输出每个token的embeding。但在使用hugging face的bert模型时,发现除了last_hidden_state还多了一个pooler_output输出。

例如:

from transformers import AutoTokenizer, AutoModeltokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
model = AutoModel.from_pretrained("bert-base-uncased")inputs = tokenizer("I'm caixunkun. I like singing, dancing, rap and basketball.", return_tensors="pt")
outputs = model(**inputs)print("last_hidden_state shape:", outputs.last_hidden_state.size())
print("pooler_output shape:", outputs.pooler_output.size())
last_hidden_state shape: torch.Size([1, 20, 768])
pooler_output shape: torch.Size([1, 768])

许多人可能以为pooler_output[CLS]token的embedding,但使用last_hidden_state shape[:, 0]比较后,发现又不是,然后就很奇怪。

Bert的Pooler_output

先说一下结论: pooler_output可以理解成该句子语义的特征向量表示

那它是怎么来的?和[CLS]token的embedding区别在哪?

我们将Bert模型打印一下,会发现最后还有一个BertPooler层,pooler_output就是从这来的。如下所示:

BertModel((embedding): BertEmbeddings(....)(encoder): BertEncoder(... # 12层TransformerEncoder)(pooler): BertPooler((dense): Linear(in_features=768, out_features=768, bias=True)(activation): Tanh())
)

其中encoder就是将BERT的所有token经过12个TransformerEncoder进行embedding。pooler就是将[CLS]这个token再过一下全连接层+Tanh激活函数,作为该句子的特征向量

我们可以从Bert源码中验证以上结论。在transformers.models.bert.modeling_bert.BertModel.forward方法中这么一行代码:

# sequence_output就是last_hidden_state
# self.pooler就是上面的BertPooler
pooled_output = self.pooler(sequence_output) if self.pooler is not None else None

我们再来看看transformers.models.bert.modeling_bert.BertPooler的源码:

class BertPooler(nn.Module):def __init__(self, config):super().__init__()self.dense = nn.Linear(config.hidden_size, config.hidden_size)self.activation = nn.Tanh()def forward(self, hidden_states):# hidden_states的第一个维度是batch_size。所以用[:, 0]取所有句子的[CLS]的embeddingfirst_token_tensor = hidden_states[:, 0]pooled_output = self.dense(first_token_tensor)pooled_output = self.activation(pooled_output)return pooled_output

从上面的源码可以看出,pooler_output 就是[CLS]embedding又经历了一次全连接层的输出。我们可以通过以下代码进行验证:

print("pooler:", model.pooler)
my_pooler_output = model.pooler(outputs.last_hidden_state)
print(my_pooler_output[0, :5])
print(outputs.pooler_output[0, :5])
pooler: BertPooler((dense): Linear(in_features=768, out_features=768, bias=True)(activation): Tanh()
)
tensor([-0.8129, -0.6216, -0.9810,  0.8090,  0.9032], grad_fn=)
tensor([-0.8129, -0.6216, -0.9810,  0.8090,  0.9032], grad_fn=)

Bert的Pooler_output的由来

我们知道,BERT的训练包含两个任务:MLM和NSP任务(Next Sentence Prediction)。 对这两个任务不熟悉的朋友可以参考:BERT源码实现与解读(Pytorch) 和 【论文阅读】BERT 两篇文章。

其中MLM就是挖空,然后让bert预测这个空是什么。做该任务是使用token embedding进行预测。

而Next Sentence Prediction就是预测bert接受的两句话是否为一对。例如:窗前明月光,疑是地上霜 为 True,窗前明月光,李白打开窗为False。

所以,NSP任务需要句子的语义信息来预测,但是我们看下源码是怎么做的。transformers.models.bert.modeling_bert.BertForNextSentencePrediction的部分源码如下:

class BertForNextSentencePrediction(BertPreTrainedModel):def __init__(self, config):super().__init__(config)self.bert = BertModel(config)self.cls = BertOnlyNSPHead(config)	# 这个就是一个 nn.Linear(config.hidden_size, 2)...def forward(...):...outputs = self.bert(...)pooled_output = outputs[1] # 取pooler_outputseq_relationship_scores = self.cls(pooled_output)	# 使用pooler_ouput送给后续的全连接层进行预测...

从上面的源码可以看出,在NSP任务训练时,并不是直接使用[CLS]token的embedding作为句子特征传给后续分类头的,而是使用的是pooler_output。个人原因可能是因为直接使用[CLS]的embedding效果不够好。

但在MLM任务时,是直接使用的是last_hidden_state,有兴趣可以看一下transformers.models.bert.modeling_bert.BertForMaskedLM的源码。

相关内容

热门资讯

【MySQL】锁 锁 文章目录锁全局锁表级锁表锁元数据锁(MDL)意向锁AUTO-INC锁...
【内网安全】 隧道搭建穿透上线... 文章目录内网穿透-Ngrok-入门-上线1、服务端配置:2、客户端连接服务端ÿ...
GCN的几种模型复现笔记 引言 本篇笔记紧接上文,主要是上一篇看写了快2w字,再去接入代码感觉有点...
数据分页展示逻辑 import java.util.Arrays;import java.util.List;impo...
Redis为什么选择单线程?R... 目录专栏导读一、Redis版本迭代二、Redis4.0之前为什么一直采用单线程?三、R...
【已解决】ERROR: Cou... 正确指令: pip install pyyaml
关于测试,我发现了哪些新大陆 关于测试 平常也只是听说过一些关于测试的术语,但并没有使用过测试工具。偶然看到编程老师...
Lock 接口解读 前置知识点Synchronized synchronized 是 Java 中的关键字,...
Win7 专业版安装中文包、汉... 参考资料:http://www.metsky.com/archives/350.htm...
3 ROS1通讯编程提高(1) 3 ROS1通讯编程提高3.1 使用VS Code编译ROS13.1.1 VS Code的安装和配置...
大模型未来趋势 大模型是人工智能领域的重要发展趋势之一,未来有着广阔的应用前景和发展空间。以下是大模型未来的趋势和展...
python实战应用讲解-【n... 目录 如何在Python中计算残余的平方和 方法1:使用其Base公式 方法2:使用statsmod...
学习u-boot 需要了解的m... 一、常用函数 1. origin 函数 origin 函数的返回值就是变量来源。使用格式如下...
常用python爬虫库介绍与简... 通用 urllib -网络库(stdlib)。 requests -网络库。 grab – 网络库&...
药品批准文号查询|药融云-中国... 药品批文是国家食品药品监督管理局(NMPA)对药品的审评和批准的证明文件...
【2023-03-22】SRS... 【2023-03-22】SRS推流搭配FFmpeg实现目标检测 说明: 外侧测试使用SRS播放器测...
有限元三角形单元的等效节点力 文章目录前言一、重新复习一下有限元三角形单元的理论1、三角形单元的形函数(Nÿ...
初级算法-哈希表 主要记录算法和数据结构学习笔记,新的一年更上一层楼! 初级算法-哈希表...
进程间通信【Linux】 1. 进程间通信 1.1 什么是进程间通信 在 Linux 系统中,进程间通信...
【Docker】P3 Dock... Docker数据卷、宿主机与挂载数据卷的概念及作用挂载宿主机配置数据卷挂载操作示例一个容器挂载多个目...