Voiceprint Recognition: Speaker Verification
Founder
2024-05-10 07:51:32

      

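Contents

I. Introduction to speaker verification

II. Mainstream approaches and models

1. The Ecapa_TDNN model

2. WavLm

III. Code practice

1. The Ecapa_TDNN approach

a. Model architecture

b. Loss

c. Data processing

d. Model training and evaluation

e. Speaker verification inference

2. The WavLm pre-training approach

a. Model architecture and loss

b. Data processing

c. Model training

d. Inference and evaluation

IV. Demo

V. Summary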


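A note up front: I have not put much time into blogging over the past few months, mainly because a lot of other things kept me busy. Since August 2022 I have been preparing for my wedding, looking at and buying a home, holding the wedding, shopping for a car, and so on. At work I have been on projects and a few competitions, none of which produced anything worth writing about. My job recently brought me into voiceprint recognition, a branch of the speech field, and I did some preliminary research on it. This post walks through what I learned about speaker verification.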

I. Introduction to speaker verification

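Speaker verification belongs to the field of voiceprint (speaker) recognition. Given two audio clips, it decides whether the same person spoke both. There are two types: text-dependent and text-independent. In text-dependent verification the speaker must say text from a restricted set every time they are checked. Text-independent verification has no such constraint, and the speaker can say anything. The former is relatively easy and the latter is harder. The core of speaker verification is a model that extracts voice features that discriminate well between different speakers.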

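As shown in the figure above, there are two ways to decide whether the Enrollment and Evaluation clips come from the same speaker. One option is to feed both clips into a model and train a classifier that decides whether they belong to the same class. The other option is to extract a multi-dimensional embedding vector from the Enrollment clip in advance with a trained model. When an Evaluation clip needs to be verified, the same model extracts its embedding, the similarity between the two vectors is computed, and a threshold makes the decision. In practice the second approach is used far more often because it is efficient. Embeddings of enrolled audio are extracted ahead of time and stored in a database. The audio under test is embedded in real time and its similarity to the stored vector is computed. The preset threshold then decides whether it is the same person.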
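The second (embedding plus threshold) workflow is shown below as a minimal NumPy sketch. It assumes some model has already turned each clip into a fixed-length embedding. The `SpeakerRegistry` class, its method names, and the threshold value are hypothetical placeholders, not part of any library.

```python
from typing import Dict, Tuple

import numpy as np


def cosine_similarity(a: np.ndarray, b: np.ndarray) -> float:
    """Cosine similarity between two 1-D embedding vectors."""
    return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b) + 1e-12))


class SpeakerRegistry:
    """Hypothetical in-memory store of enrollment embeddings, keyed by speaker ID."""

    def __init__(self, threshold: float):
        self.threshold = threshold
        self.embeddings: Dict[str, np.ndarray] = {}

    def enroll(self, speaker_id: str, embedding: np.ndarray) -> None:
        # Offline step: keep the embedding extracted from the enrollment audio.
        self.embeddings[speaker_id] = np.asarray(embedding, dtype=np.float64)

    def verify(self, speaker_id: str, embedding: np.ndarray) -> Tuple[bool, float]:
        # Online step: score the test embedding against the stored one, then apply the threshold.
        stored = self.embeddings[speaker_id]
        score = cosine_similarity(stored, np.asarray(embedding, dtype=np.float64))
        return score >= self.threshold, score
```

In a real system the dictionary would be replaced by a vector store. The threshold would come from a search on a validation set rather than a fixed placeholder.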

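Before deployment, both the trained model and the speaker verification system as a whole need to be evaluated. The model itself is evaluated according to its training task, usually with F1, accuracy, recall, and similar metrics. A deployed speaker verification system has its own set of metrics, mainly the following: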

FAR (False Accept Rate)

FRR (False Reject Rate)

EER (Equal Error Rate)

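FRR = Nfr / Ntarget, where Nfr is the number of trials that should have been accepted but were rejected, and Ntarget is the total number of trials that should be accepted.

FAR = Nfa / Nnotarget, where Nfa is the number of trials that should have been rejected but were accepted, and Nnotarget is the total number of trials that should be rejected.

EER is the error rate at the threshold where FAR equals FRR. It is the most commonly used performance metric for speaker verification systems.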

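EER does not account for the fact that false accepts and false rejects can have different consequences. To capture those differences, each error type is given its own cost weight (C_FA for a false accept, C_FR for a false reject). The prior probabilities of a genuine or an impostor speaker are also included. Together these give a new metric, the detection cost function (DCF):

DCF = C_FR × FRR × PT + C_FA × FAR × PI

PT is the prior probability that a genuine (target) speaker appears, and PI = 1 − PT is the prior probability of an impostor. The stricter the system, the larger PI/PT. Common PT:PI ratios are 1:99 and 1:999.

DCF changes as the threshold is adjusted. The threshold at which DCF is smallest (minDCF) gives the system its best overall performance.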
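These definitions can be turned into a small threshold sweep, sketched here in NumPy. The sketch makes several assumptions:

* Label 1 marks a same-speaker (target) trial and 0 a different-speaker trial.
* A trial is accepted when its score is at least the threshold.
* The function names are hypothetical.

It tries every observed score as a threshold rather than a fixed grid. It reports EER as the midpoint of FAR and FRR where the two are closest. By default it uses PT = 0.01 (PT:PI = 1:99) with unit costs.

```python
import numpy as np


def far_frr(scores, labels, threshold):
    """FAR and FRR at one threshold; a trial is accepted when score >= threshold.
    labels: 1 = same speaker (target trial), 0 = different speakers (non-target trial)."""
    scores = np.asarray(scores, dtype=np.float64)
    labels = np.asarray(labels)
    accept = scores >= threshold
    n_target = np.sum(labels == 1)
    n_nontarget = np.sum(labels == 0)
    frr = np.sum(~accept & (labels == 1)) / n_target     # Nfr / Ntarget
    far = np.sum(accept & (labels == 0)) / n_nontarget   # Nfa / Nnotarget
    return float(far), float(frr)


def eer_and_min_dcf(scores, labels, p_target=0.01, c_fa=1.0, c_fr=1.0):
    """Sweep thresholds and return (eer, min_dcf, threshold_at_min_dcf)."""
    scores = np.asarray(scores, dtype=np.float64)
    # every observed score, plus +inf (reject everything), is a candidate threshold
    candidates = np.append(np.unique(scores), np.inf)
    best_gap, eer = np.inf, 1.0
    min_dcf, min_thr = np.inf, None
    for thr in candidates:
        far, frr = far_frr(scores, labels, thr)
        if abs(far - frr) < best_gap:  # EER: the point where FAR and FRR are closest
            best_gap, eer = abs(far - frr), (far + frr) / 2
        dcf = c_fr * frr * p_target + c_fa * far * (1 - p_target)
        if dcf < min_dcf:
            min_dcf, min_thr = dcf, float(thr)
    return eer, min_dcf, min_thr
```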

II. Mainstream approaches and models

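Speaker verification has been developed for many years, and there are many approaches. Traditional approaches mainly rely on signal processing: the time-domain signal is transformed into the frequency domain, and various techniques are then used to tell speakers apart. Here is a diagram of how the approaches evolved (taken from the Zhihu Q&A "What voiceprint recognition algorithms are there?"):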

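The acoustic features involved include MFCC, FBank, and spectrograms, together with various data augmentations applied to them. As of 2022, attention has shifted to end-to-end approaches in which a neural network learns the acoustic representation automatically.

The most mainstream of these is the Ecapa_TDNN model, proposed in 2020. It introduces SE (squeeze-and-excitation) modules and channel attention. This approach took first place in the international speaker recognition challenge VoxSRC 2020, and the model was also used as the baseline for the FFSVC 2022 speaker verification task.

The other family is pre-trained models. The speech field has many BERT-like pre-trained models analogous to those in NLP, and in my opinion WavLm performs best among them.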

1. The Ecapa_TDNN model

First, the overall architecture diagram:

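As the figure shows, ecapa_tdnn is built from Conv1D+BN, SE-Res2Block, ASP+BN (attentive statistics pooling), FC+BN, and AAM-softmax modules. The SE-Res2Blocks let the model capture more global information from the audio, which makes it work better than the earlier d-vector approach.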
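ASP (attentive statistics pooling) is the step that turns a variable number of frames into one fixed-size vector. Below is a rough NumPy sketch of the channel-dependent version. It assumes the same parameter naming (W, b, V, k) as the comments in the AttentiveStatsPool class. Each channel gets its own softmax attention over time, and the output concatenates the attention-weighted mean and standard deviation.

```python
import numpy as np


def softmax(x, axis):
    x = x - x.max(axis=axis, keepdims=True)  # subtract the max for numerical stability
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)


def attentive_stats_pool(h, W, b, V, k):
    """h: (C, T) frame-level features; W: (B, C), b: (B,), V: (C, B), k: (C,).
    Returns a (2C,) utterance-level vector [weighted mean, weighted std]."""
    e = V @ np.tanh(W @ h + b[:, None]) + k[:, None]  # (C, T) attention scores
    alpha = softmax(e, axis=1)                         # normalize over time, per channel
    mean = (alpha * h).sum(axis=1)
    var = (alpha * h ** 2).sum(axis=1) - mean ** 2
    std = np.sqrt(np.clip(var, 1e-9, None))            # clip avoids sqrt of tiny negatives
    return np.concatenate([mean, std])
```

With all-zero parameters the attention is uniform, and the result reduces to the plain per-channel mean and standard deviation.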

SE-Res2Block:

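The SE-Res2Block is essentially a Res2Block with an SE-Block inserted. The SE-Block is a channel-attention module, a classic design that works well across many kinds of networks.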
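The SE-Block itself is only a few operations. Below is a minimal NumPy sketch for a single (channels, frames) feature map. It assumes the two kernel_size=1 convolutions are written as plain matrices w1/b1 and w2/b2 (hypothetical names).

```python
import numpy as np


def se_block(x, w1, b1, w2, b2):
    """Squeeze-and-excitation over a (C, T) feature map.
    w1: (S, C), b1: (S,), w2: (C, S), b2: (C,) play the role of the two 1x1 convolutions."""
    s = x.mean(axis=1)                         # squeeze: average over time -> (C,)
    s = np.maximum(w1 @ s + b1, 0.0)           # bottleneck + ReLU -> (S,)
    s = 1.0 / (1.0 + np.exp(-(w2 @ s + b2)))   # excitation: per-channel gate in (0, 1)
    return x * s[:, None]                      # rescale each channel across all frames
```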

2. WavLm

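WavLm was developed by Microsoft Research Asia together with the Microsoft Azure Speech group. It uses a Transformer architecture and the Denoising Masked Speech Modeling framework to do BERT-style masked pre-training directly on raw audio waveforms. It was pre-trained on a very large amount of audio and achieves strong results on speech tasks.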

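The network structure is shown in the figure: a CNN feature extractor followed by Transformer blocks as the encoder. I won't go into the model details here. You can think of it as a BERT for audio with some implementation differences. For the implementation, see the Hugging Face classes such as WavLMModel.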

III. Code practice

1. The Ecapa_TDNN approach

a. Model architecture

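The code draws on two sources, trimmed in places: the Paddle version in Baidu's PaddleSpeech and the PyTorch version in SpeechBrain. It also draws on the personal implementation VoiceprintRecognition-Pytorch. Weighing these implementations together gives the following Ecapa_TDNN model code: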

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Parameter


class TDNNBlock(nn.Module):
    """An implementation of TDNN: Conv1d -> ReLU -> BatchNorm1d."""

    def __init__(self, in_channels, out_channels, kernel_size, dilation, groups=1, padding=0):
        super(TDNNBlock, self).__init__()
        self.conv = nn.Conv1d(in_channels=in_channels, out_channels=out_channels,
                              kernel_size=kernel_size, dilation=dilation,
                              groups=groups, padding=padding)
        self.activation = nn.ReLU()
        self.bn = nn.BatchNorm1d(out_channels)

    def forward(self, x):
        x = self.conv(x)
        x = self.activation(x)
        x = self.bn(x)
        return x


class Res2NetBlock(torch.nn.Module):
    """An implementation of Res2NetBlock w/ dilation.

    Example
    -------
    inp_tensor = torch.rand([8, 120, 64]).transpose(1, 2)
    layer = Res2NetBlock(64, 64, scale=4, dilation=3, padding=3)
    out_tensor = layer(inp_tensor).transpose(1, 2)
    out_tensor.shape
    torch.Size([8, 120, 64])
    """

    def __init__(self, in_channels, out_channels, scale=8, kernel_size=3, dilation=1, padding=0):
        super(Res2NetBlock, self).__init__()
        assert in_channels % scale == 0
        assert out_channels % scale == 0
        in_channel = in_channels // scale
        hidden_channel = out_channels // scale
        self.blocks = nn.ModuleList([
            TDNNBlock(in_channel, hidden_channel, kernel_size=kernel_size,
                      dilation=dilation, padding=padding)
            for _ in range(scale - 1)
        ])
        self.scale = scale

    def forward(self, x):
        y = []
        for i, x_i in enumerate(torch.chunk(x, self.scale, dim=1)):
            if i == 0:
                y_i = x_i
            elif i == 1:
                y_i = self.blocks[i - 1](x_i)
            else:
                # hierarchical residual: each split also receives the previous split's output
                y_i = self.blocks[i - 1](x_i + y_i)
            y.append(y_i)
        y = torch.cat(y, dim=1)
        return y


class SEBlock(nn.Module):
    """Squeeze-and-excitation block (the length mask is omitted)."""

    def __init__(self, in_channels, se_channels, out_channels):
        super(SEBlock, self).__init__()
        self.conv1 = nn.Conv1d(in_channels=in_channels, out_channels=se_channels, kernel_size=1)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv1d(in_channels=se_channels, out_channels=out_channels, kernel_size=1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        s = x.mean(dim=2, keepdim=True)   # squeeze: average over time
        s = self.relu(self.conv1(s))
        s = self.sigmoid(self.conv2(s))   # excitation: per-channel gate
        out = s * x
        return out


class SERes2NetBlock(nn.Module):
    def __init__(self, in_channels, out_channels, res2net_scale=8, se_channels=128,
                 kernel_size=1, dilation=1, groups=1, padding=0):
        super(SERes2NetBlock, self).__init__()
        self.out_channels = out_channels
        self.tdnn1 = TDNNBlock(in_channels, out_channels, kernel_size=1, dilation=1, groups=groups)
        # keyword arguments, so that padding and dilation cannot be swapped by position
        self.res2net_block = Res2NetBlock(out_channels, out_channels, scale=res2net_scale,
                                          kernel_size=kernel_size, dilation=dilation,
                                          padding=padding)
        self.tdnn2 = TDNNBlock(out_channels, out_channels, kernel_size=1, dilation=1, groups=groups)
        self.se_block = SEBlock(out_channels, se_channels, out_channels)
        self.shortcut = None
        if in_channels != out_channels:
            self.shortcut = nn.Conv1d(in_channels=in_channels, out_channels=out_channels, kernel_size=1)

    def forward(self, x):
        """Processes the input tensor x and returns an output tensor."""
        residual = x
        if self.shortcut is not None:
            residual = self.shortcut(x)
        x = self.tdnn1(x)
        x = self.res2net_block(x)
        x = self.tdnn2(x)
        x = self.se_block(x)
        return x + residual


class AttentiveStatsPool(nn.Module):
    def __init__(self, in_dim, bottleneck_dim):
        super(AttentiveStatsPool, self).__init__()
        # Use Conv1d with stride == 1 rather than Linear, then we don't need to transpose inputs.
        self.linear1 = nn.Conv1d(in_dim, bottleneck_dim, kernel_size=1)  # equals W and b in the paper
        self.linear2 = nn.Conv1d(bottleneck_dim, in_dim, kernel_size=1)  # equals V and k in the paper

    def forward(self, x):
        # DON'T use ReLU here! In experiments, I find ReLU hard to converge.
        alpha = torch.tanh(self.linear1(x))
        alpha = torch.softmax(self.linear2(alpha), dim=2)
        mean = torch.sum(alpha * x, dim=2)
        residuals = torch.sum(alpha * x ** 2, dim=2) - mean ** 2
        std = torch.sqrt(residuals.clamp(min=1e-9))
        return torch.cat([mean, std], dim=1)


class ECAPATDNN(nn.Module):
    def __init__(self, input_size, lin_neurons=192, channels=[512, 512, 512, 512, 1536],
                 kernel_sizes=[5, 3, 3, 3, 1], dilations=[1, 2, 3, 4, 1], attention_channels=128,
                 res2net_scale=8, se_channels=128, groups=[1, 1, 1, 1, 1], paddings=[0, 2, 3, 4, 0]):
        super(ECAPATDNN, self).__init__()
        assert len(channels) == len(kernel_sizes)
        assert len(channels) == len(dilations)
        self.emb_size = lin_neurons
        self.channels = channels
        self.blocks = nn.ModuleList()
        self.blocks.append(TDNNBlock(input_size, channels[0], kernel_sizes[0], dilations[0], groups[0]))
        for i in range(1, len(channels) - 1):
            self.blocks.append(SERes2NetBlock(channels[i - 1], channels[i], res2net_scale, se_channels,
                                              kernel_sizes[i], dilations[i], groups[i], paddings[i]))
        self.mfa = TDNNBlock(channels[-1], channels[-1], kernel_sizes[-1], dilations[-1], groups[-1])
        self.asp = AttentiveStatsPool(channels[-1], attention_channels)
        self.asp_bn = nn.BatchNorm1d(channels[-1] * 2)
        self.fc = nn.Conv1d(in_channels=channels[-1] * 2, out_channels=lin_neurons, kernel_size=1)

    def forward(self, x):
        xl = []
        for layer in self.blocks:
            x = layer(x)
            xl.append(x)
        # Multi-layer feature aggregation over the SE-Res2Net block outputs.
        # Do not use x.data here: it detaches the graph and stops gradients reaching self.blocks.
        x = torch.cat(xl[1:], dim=1)
        x = self.mfa(x)
        # Attentive Statistical Pooling
        x = self.asp(x)
        x = self.asp_bn(x)
        x = x.unsqueeze(2)
        # Final linear transformation -> (batch, lin_neurons, 1)
        x = self.fc(x)
        return x


class SpeakerIdentificationModel(nn.Module):
    def __init__(self, backbone, num_class=1, dropout=0.1):
        super(SpeakerIdentificationModel, self).__init__()
        self.backbone = backbone
        if dropout > 0:
            self.dropout = nn.Dropout(dropout)
        else:
            self.dropout = None
        input_size = self.backbone.emb_size
        # Final layer: a bare weight matrix with a different initialization from nn.Linear (Xavier normal)
        self.weight = Parameter(torch.FloatTensor(num_class, input_size), requires_grad=True)
        nn.init.xavier_normal_(self.weight, gain=1)

    def forward(self, x):
        x = self.backbone(x)
        if self.dropout is not None:
            x = self.dropout(x)
        # cosine similarity between the L2-normalized embedding and each L2-normalized class weight
        logits = F.linear(F.normalize(x.squeeze(2)), weight=F.normalize(self.weight, dim=-1))
        return logits

b. Loss

This code is taken from VoiceprintRecognition-Pytorch.

The Additive Angular Margin loss is combined with KLDivLoss (Kullback-Leibler divergence loss) to give the final AAMLoss.
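The angular margin can be checked numerically without PyTorch. The sketch below assumes its inputs are already cosine similarities between normalized embeddings and normalized class weights. It replaces the target-class cosine cos(θ) with cos(θ + m), scales by s, and applies softmax cross-entropy. With a one-hot target, this equals the summed KL divergence divided by the batch size that AAMLoss computes. For simplicity it omits the easy_margin / th / mm fallback that the PyTorch code applies when θ + m would exceed π.

```python
import numpy as np


def aam_softmax_loss(cosine, targets, margin=0.2, scale=30.0):
    """cosine: (N, K) cosine similarities; targets: (N,) integer class labels."""
    cosine = np.clip(np.asarray(cosine, dtype=np.float64), -1.0, 1.0)
    targets = np.asarray(targets)
    rows = np.arange(cosine.shape[0])
    theta_y = np.arccos(cosine[rows, targets])           # angle to the target-class weight
    logits = cosine.copy()
    logits[rows, targets] = np.cos(theta_y + margin)     # margin applied to the target class only
    logits = scale * logits
    logits = logits - logits.max(axis=1, keepdims=True)  # numerical stability
    log_probs = logits - np.log(np.exp(logits).sum(axis=1, keepdims=True))
    return float(-log_probs[rows, targets].mean())
```

A positive margin always makes the loss larger for the same cosines. That forces the model to pull same-speaker embeddings closer to their class weight than plain softmax would.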

import math
import torch
import torch.nn as nn
import torch.nn.functional as F


class AdditiveAngularMargin(nn.Module):
    def __init__(self, margin=0.0, scale=1.0, easy_margin=False):
        """The Implementation of Additive Angular Margin (AAM) proposed
        in the following paper: '''Margin Matters: Towards More Discriminative Deep Neural Network Embeddings for Speaker Recognition'''
        (https://arxiv.org/abs/1906.07317)

        Args:
            margin (float, optional): margin factor. Defaults to 0.0.
            scale (float, optional): scale factor. Defaults to 1.0.
            easy_margin (bool, optional): easy_margin flag. Defaults to False.
        """
        super(AdditiveAngularMargin, self).__init__()
        self.margin = margin
        self.scale = scale
        self.easy_margin = easy_margin
        self.cos_m = math.cos(self.margin)
        self.sin_m = math.sin(self.margin)
        self.th = math.cos(math.pi - self.margin)
        self.mm = math.sin(math.pi - self.margin) * self.margin

    def forward(self, outputs, targets):
        cosine = outputs.float()
        # clamp so that rounding error (|cos| slightly above 1) cannot produce NaN in the sqrt
        sine = torch.sqrt((1.0 - torch.pow(cosine, 2)).clamp(min=0.0))
        phi = cosine * self.cos_m - sine * self.sin_m  # cos(theta + m)
        if self.easy_margin:
            phi = torch.where(cosine > 0, phi, cosine)
        else:
            # past theta > pi - m, cos(theta + m) is no longer monotonic; use a linear penalty instead
            phi = torch.where(cosine > self.th, phi, cosine - self.mm)
        # apply the margin only to the target class (targets is one-hot here)
        outputs = (targets * phi) + ((1.0 - targets) * cosine)
        return self.scale * outputs


class AAMLoss(nn.Module):
    def __init__(self, margin=0.2, scale=30, easy_margin=False):
        super(AAMLoss, self).__init__()
        self.loss_fn = AdditiveAngularMargin(margin=margin, scale=scale, easy_margin=easy_margin)
        self.criterion = torch.nn.KLDivLoss(reduction="sum")

    def forward(self, outputs, targets):
        targets = F.one_hot(targets, outputs.shape[1]).float()
        predictions = self.loss_fn(outputs, targets)
        predictions = F.log_softmax(predictions, dim=1)
        # with one-hot targets, summed KL divergence / batch size equals mean cross-entropy
        loss = self.criterion(predictions, targets) / targets.sum()
        return loss

c. Data processing

This code turns wav or mp3 files into acoustic features, such as fbank (melspectrogram), spectrogram, and Mel-frequency cepstral coefficients (MFCC).
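The reader uses 16 kHz audio with n_fft = win_length = 400 and hop_length = 160, so each frame covers 25 ms and frames advance by 10 ms. The hypothetical helper below predicts the feature shape. It assumes torchaudio's default center=True padding, which yields 1 + num_samples // hop_length frames. It also shows where the reader's input_size values come from: 80 is the number of Mel bins, and 201 is n_fft // 2 + 1 STFT bins. The n_mfcc=40 value is assumed from torchaudio's MFCC default.

```python
def feature_shape(duration_s, sr=16000, n_fft=400, hop_length=160, n_mels=80, n_mfcc=40,
                  method="melspectrogram"):
    """Expected (n_bins, n_frames) of one feature map, assuming center=True framing."""
    n_samples = int(duration_s * sr)
    n_frames = 1 + n_samples // hop_length  # center=True pads n_fft // 2 samples on each side
    if method == "melspectrogram":
        n_bins = n_mels
    elif method == "spectrogram":
        n_bins = n_fft // 2 + 1             # one-sided STFT bins
    elif method == "MFCC":
        n_bins = n_mfcc
    else:
        raise ValueError(f"unknown feature method: {method}")
    return n_bins, n_frames
```

For the default 3-second training chunk, this gives an 80 × 301 Fbank matrix per clip.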

import random
import torch
from torch.utils.data import Dataset
import torchaudio
from tqdm import tqdm


class AudioDataReader(Dataset):
    def __init__(self, data_list_path, feature_method='melspectrogram', mode='train', sr=16000,
                 chunk_duration=3, min_duration=0.5, label2ids=None, augmentors=None):
        super(AudioDataReader, self).__init__()
        assert data_list_path is not None
        with open(data_list_path, 'r', encoding='utf-8') as f:
            self.lines = f.readlines()
        self.feature_method = feature_method
        self.mode = mode
        self.sr = sr
        self.chunk_duration = chunk_duration
        self.min_duration = min_duration
        self.augmentors = augmentors
        self.label2ids = label2ids if label2ids is not None else {}
        self.audiofeatures = self.getaudiofeatures()

    def load_audio(self, audio_path, feature_method='melspectrogram', mode='train', sr=16000,
                   chunk_duration=3, min_duration=0.5, augmentors=None):
        """Load and preprocess one audio file.
        :param audio_path: path to the audio file
        :param feature_method: melspectrogram (Fbank, Mel spectrogram) / MFCC (Mel-frequency cepstral coefficients) / spectrogram
        :param mode: how the data are processed: train, eval, or infer
        :param sr: sample rate
        :param chunk_duration: audio length (seconds) used for training or evaluation
        :param min_duration: minimum audio length (seconds) allowed for training
        :param augmentors: data augmentation methods
        :return: features of shape (n_bins, n_frames)
        """
        wav, sample_rate = torchaudio.load(audio_path)  # returns a (channels, samples) tensor
        if sample_rate != sr:
            wav = torchaudio.functional.resample(wav, sample_rate, sr)
        num_wav_samples = wav.shape[1]
        # clips that are too short are bad for training
        if mode == 'train' and num_wav_samples < int(min_duration * sr):
            raise Exception(f'Audio shorter than {min_duration}s, actual length: {(num_wav_samples / sr):.2f}s')
        # pad clips shorter than chunk_duration by repeating them, then cut to exactly chunk length
        num_chunk_samples = int(chunk_duration * sr)
        if num_wav_samples < num_chunk_samples:
            repeats = num_chunk_samples // num_wav_samples + 1
            wav = wav.repeat(1, repeats)[:, :num_chunk_samples]
        if mode == 'train':
            # random crop
            num_wav_samples = wav.shape[1]
            if num_wav_samples > num_chunk_samples + 1:
                start = random.randint(0, num_wav_samples - num_chunk_samples - 1)
                end = start + num_chunk_samples
                wav = wav[:, start:end]
            # waveform-level augmentation (SpecAugment is applied later, on the features)
            if augmentors is not None:
                for key, augmentor in augmentors.items():
                    if key == 'specaug':
                        continue
                    wav = wav.numpy()  # the augmentors operate on numpy arrays
                    wav = augmentor(wav)
                    wav = torch.from_numpy(wav)
        elif mode == 'eval':
            # keep only chunk_duration seconds to avoid running out of GPU memory
            num_wav_samples = wav.shape[1]
            if num_wav_samples > num_chunk_samples + 1:
                wav = wav[:, 0:num_chunk_samples]
        if feature_method == "melspectrogram":
            # Mel spectrogram (Fbank)
            features = torchaudio.transforms.MelSpectrogram(
                sample_rate=sr, n_fft=400, n_mels=80, hop_length=160, win_length=400)(wav)
        elif feature_method == "spectrogram":
            # linear-frequency spectrogram
            features = torchaudio.transforms.Spectrogram(n_fft=400, win_length=400, hop_length=160)(wav)
        elif feature_method == "MFCC":
            # MFCC takes the STFT / Mel settings through melkwargs; n_mfcc=40 is torchaudio's default
            features = torchaudio.transforms.MFCC(
                sample_rate=sr, n_mfcc=40,
                melkwargs={"n_fft": 400, "n_mels": 80, "hop_length": 160, "win_length": 400})(wav)
        else:
            raise Exception(f'Feature method {feature_method} does not exist!')
        # SpecAugment on the features
        if mode == 'train' and augmentors is not None:
            for key, augmentor in augmentors.items():
                if key == 'specaug':
                    features = augmentor(features)
        # normalize each frequency bin over time (parameter-free layer norm over the last axis)
        features = torch.nn.functional.layer_norm(features, features.shape[-1:]).squeeze(0)
        return features

    def getaudiofeatures(self):
        res = []
        for line in tqdm(self.lines, desc=self.mode + ' load all audios', ncols=100):
            try:
                audio_path, label = line.replace('\n', '').split('\t')
                label = self.label2ids[label]
                features = self.load_audio(audio_path=audio_path, feature_method=self.feature_method,
                                           mode=self.mode, sr=self.sr, chunk_duration=self.chunk_duration,
                                           min_duration=self.min_duration, augmentors=self.augmentors)
                label = torch.as_tensor(label, dtype=torch.long)
                res.append([features, label])
            except Exception as e:
                print(f'{e}, load audio data exception')
        return res

    @property
    def input_size(self):
        if self.feature_method == 'melspectrogram':
            return 80
        elif self.feature_method == 'spectrogram':
            return 201
        elif self.feature_method == 'MFCC':
            return 40
        else:
            raise Exception(f'Feature method {self.feature_method} does not exist!')

    def __getitem__(self, item):
        return self.audiofeatures[item][0], self.audiofeatures[item][1]

    def __len__(self):
        return len(self.audiofeatures)

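Note that the audio is not read in __getitem__(). Instead, all of it is loaded into memory up front. A side effect is that the random crop and any augmentation are computed only once, so every epoch sees the same version of each clip. If the dataset is too large, read and process the audio inside __getitem__() instead to reduce memory use, at the cost of slower training.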
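A lazy variant could look like the following hypothetical sketch. It keeps only (path, label id) pairs in memory and calls a user-supplied load_features function on demand. It implements the __len__/__getitem__ protocol that torch.utils.data.DataLoader expects from a map-style dataset. In practice you would subclass Dataset and pass a wrapper around load_audio as load_features.

```python
from typing import Callable, Dict, List, Tuple


class LazyAudioDataset:
    """Hypothetical map-style dataset: features are computed per item, not up front."""

    def __init__(self, list_path: str, load_features: Callable[[str], object], label2ids: Dict[str, int]):
        self.items: List[Tuple[str, int]] = []
        with open(list_path, "r", encoding="utf-8") as f:
            for line in f:
                line = line.rstrip("\n")
                if not line:
                    continue
                path, label = line.split("\t")
                self.items.append((path, label2ids[label]))
        self.load_features = load_features

    def __len__(self) -> int:
        return len(self.items)

    def __getitem__(self, idx: int):
        path, label = self.items[idx]
        # decoding, cropping, augmentation and feature extraction all happen here,
        # so random crops and augmentations differ from epoch to epoch
        return self.load_features(path), label
```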

d. Model training and evaluation

The dataset is the zhstcmds subset of the public zhvoice: Chinese voice corpus:

"zhstcmds": {
    "character_W": 111.9317,
    "duration_H": 74.53628,
    "n_audio_per_speaker": 120.0,
    "n_character_per_sentence": 10.909522417153998,
    "n_minute_per_speaker": 5.230616140350877,
    "n_second_per_audio": 2.6153080701754385,
    "n_speaker": 855,
    "sentence_W": 10.26,
    "size_MB": 767.7000274658203
}

There are 104,963 utterances in total. They are split randomly into a validation set of 10,000 and a training set of 94,963.
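The split itself is a shuffle over utterance lines. Below is a minimal sketch; the function name and the seed are illustrative (100 is the training script's default random_seed). It splits utterances rather than speakers, so every validation speaker also appears in training. That is what makes closed-set classification accuracy a meaningful validation metric here.

```python
import random


def split_lines(lines, n_val=10000, seed=100):
    """Shuffle utterance lines and hold out n_val of them for validation."""
    lines = list(lines)
    rng = random.Random(seed)  # private generator, so the global random state is untouched
    rng.shuffle(lines)
    return lines[n_val:], lines[:n_val]  # (train, val)
```

Applied to the 104,963 zhstcmds lines with the default n_val, this gives the 94,963 / 10,000 split described above.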

The training code is as follows:

from models.loss import AAMLoss
from models.ecapa_tdnn import SpeakerIdentificationModel,ECAPATDNN
# from models.ecapa_tdnn import SpeakerIdetification,EcapaTdnn
from tools.log import Logger
from tools.progressbar import ProgressBar
from data_utils.reader import AudioDataReader
from data_utils.noise_perturb import NoisePerturbAugmentor
from data_utils.speed_perturb import SpeedPerturbAugmentor
from data_utils.volum_perturb import VolumePerturbAugmentor
from data_utils.spec_augment import SpecAugmentor
from torch.utils.data import DataLoader
import torch
import os
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR
import argparse
import random
import numpy as np
from torch.utils.tensorboard import SummaryWriter
from datetime import datetime
import yaml
import torch.nn as nn


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--train_datas_path", type=str, default='./data/train_audio_paths.txt', help="train text file")
    parser.add_argument("--val_datas_path", type=str, default='./data/val_audio_paths.txt', help="val text file")
    parser.add_argument("--log_file", type=str, default="./log_output/speaker_identification.log", help="log_file")
    parser.add_argument("--model_out", type=str, default="./output/", help="model output path")
    parser.add_argument("--batch_size", type=int, default=64, help="batch size")
    parser.add_argument("--epochs", type=int, default=30, help="epochs")
    parser.add_argument("--lr", type=float, default=1e-3, help="learning rate")
    parser.add_argument("--random_seed", type=int, default=100, help="random_seed")
    parser.add_argument("--device", type=str, default='1', help="device")
    return parser.parse_args()


def build_label2ids(paths):
    """Map every speaker label found in the list files to a contiguous integer id."""
    label2ids = {}
    for path in paths:
        with open(path, 'r', encoding='utf-8') as f:
            for line in f:
                label = line.strip('\n').split('\t')[-1]
                if label not in label2ids:
                    label2ids[label] = len(label2ids)
    return label2ids


def training(args):
    os.environ['CUDA_VISIBLE_DEVICES'] = args.device
    logger = Logger(log_name='SI', log_level=10, log_file=args.log_file).logger
    logger.info(args)
    label2ids = build_label2ids([args.train_datas_path, args.val_datas_path])

    augmentors = {}
    with open("augment.ymal", 'r', encoding="utf-8") as fp:
        configs = yaml.load(fp, Loader=yaml.FullLoader)
    augmentors['noise'] = NoisePerturbAugmentor(**configs['noise'])
    augmentors['speed'] = SpeedPerturbAugmentor(**configs['speed'])
    augmentors['volume'] = VolumePerturbAugmentor(**configs['volume'])
    augmentors['specaug'] = SpecAugmentor(**configs['specaug'])
    augmentors = None  # augmentation is switched off for this run; delete this line to enable it

    time_srt = datetime.now().strftime('%Y-%m-%d')
    save_dir = os.path.join(args.model_out, time_srt)
    os.makedirs(save_dir, exist_ok=True)
    logger.info(save_dir)
    model_file = os.path.join(save_dir, "ecapa_tdnn.bin")
    device = "cuda:0" if torch.cuda.is_available() else "cpu"

    train_dataset = AudioDataReader(feature_method='melspectrogram', data_list_path=args.train_datas_path,
                                    mode='train', label2ids=label2ids, augmentors=augmentors)
    train_dataloader = DataLoader(train_dataset, shuffle=True, batch_size=args.batch_size)
    val_dataset = AudioDataReader(feature_method='melspectrogram', data_list_path=args.val_datas_path,
                                  mode='eval', label2ids=label2ids, augmentors=augmentors)
    val_dataloader = DataLoader(val_dataset, shuffle=False, batch_size=args.batch_size)
    num_class = len(label2ids)
    logger.info('num_class:%d' % num_class)

    ecapa_tdnn = ECAPATDNN(input_size=train_dataset.input_size)
    model = SpeakerIdentificationModel(backbone=ecapa_tdnn, num_class=num_class).to(device)
    loss_function = AAMLoss()
    optimizer = AdamW(lr=args.lr, params=model.parameters())
    scheduler = CosineAnnealingLR(optimizer, T_max=args.epochs)

    logger.info("***** Running training *****")
    logger.info("  Num examples = %d" % len(train_dataset))
    logger.info("  Num Epochs = %d" % args.epochs)
    writer = SummaryWriter('./runs/' + time_srt + '/')
    best_acc = 0
    total_step = 0
    unimproving_count = 0
    for epoch in range(args.epochs):
        pbar = ProgressBar(n_total=len(train_dataloader), desc='Training')
        model.train()
        total_loss = 0
        for step, batch in enumerate(train_dataloader):
            audio, speakers = [t.to(device) for t in batch]
            output = model(audio)
            loss = loss_function(output, speakers)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            total_step += 1
            writer.add_scalar('Train/Learning loss', loss.item(), total_step)
            total_loss += loss.item()
            pbar(step, {'loss': loss.item()})

        val_acc = evaluate(model, val_dataloader, device)
        lr = scheduler.get_last_lr()[0]
        if best_acc < val_acc:
            best_acc = val_acc
            torch.save(model.state_dict(), model_file)  # keep only the best checkpoint
            unimproving_count = 0
            suffix = " Save model!"
        else:
            unimproving_count += 1
            suffix = ""
        logger.info(f"Train epoch [{epoch + 1}/{args.epochs}],batch [{step + 1}],Best_acc: {best_acc},"
                    f"Val_acc:{val_acc}, lr:{lr}, total_loss:{round(total_loss, 4)}.{suffix}")
        writer.add_scalar('Val/val_acc', val_acc, total_step)
        writer.add_scalar('Val/best_acc', best_acc, total_step)
        writer.add_scalar('Train/Learning rate', lr, total_step)
        scheduler.step()
        if unimproving_count >= 5:
            logger.info('unimproving %d epochs, early stop!' % unimproving_count)
            break


def evaluate(model, val_dataloader, device):
    """Closed-set speaker classification accuracy on the validation set."""
    total = 0
    correct_total = 0
    model.eval()
    with torch.no_grad():
        pbar = ProgressBar(n_total=len(val_dataloader), desc='evaluate')
        for step, batch in enumerate(val_dataloader):
            audio, speakers = [t.to(device) for t in batch]
            output = model(audio)
            total += speakers.shape[0]
            preds = torch.argmax(output, dim=-1)
            correct_total += (speakers == preds).sum().item()
            pbar(step, {})
    model.train()
    return correct_total / total


def set_seed(seed):
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.backends.cudnn.deterministic = True


if __name__ == '__main__':
    args = parse_args()
    set_seed(args.random_seed)
    training(args)

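During training, the evaluation metric is simply classification accuracy. The log is as follows:

The validation classification accuracy is 0.9503.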

e. Speaker verification inference

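The trained Ecapa_TDNN model extracts embedding vectors from the processed audio, and their similarities are computed. A preset threshold then decides whether two clips come from the same speaker. The threshold itself has to be found by searching over a purpose-built validation set.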
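Verification trials are pairs of utterances. Below is a minimal sketch of building them. It assumes integer indices into the embedding matrix, and build_trials is a hypothetical name. Every unordered pair with i < j is used once, so an utterance is never compared with itself. Non-target pairs are randomly subsampled to at most 99 per target pair, matching the 1:99 prior. The inference script draws negatives with random.choices (with replacement); this sketch uses random.sample (without replacement) instead.

```python
import random
from itertools import combinations


def build_trials(labels, neg_per_pos=99, seed=100):
    """Return (i, j, is_target) trials over all pairs i < j of evaluation utterances."""
    pos, neg = [], []
    for i, j in combinations(range(len(labels)), 2):
        (pos if labels[i] == labels[j] else neg).append((i, j))
    rng = random.Random(seed)
    k = min(len(neg), len(pos) * neg_per_pos)
    neg = rng.sample(neg, k)  # subsample non-target trials without replacement
    return [(i, j, 1) for i, j in pos] + [(i, j, 0) for i, j in neg]
```

Each trial's score would then be the cosine similarity between embeddings i and j, fed into the threshold sweep.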

from models.ecapa_tdnn import SpeakerIdentificationModel,ECAPATDNN
from tools.log import Logger
from tools.progressbar import ProgressBar
from data_utils.reader import AudioDataReader
from data_utils.noise_perturb import NoisePerturbAugmentor
from data_utils.speed_perturb import SpeedPerturbAugmentor
from data_utils.volum_perturb import VolumePerturbAugmentor
from data_utils.spec_augment import SpecAugmentor
from torch.utils.data import DataLoader
import torch
import os
import argparse
import numpy as np
import yaml
from tqdm import tqdm
import matplotlib.pyplot as plt
import time
import random
random.seed(100)


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--train_datas_path", type=str, default='./data/train_audio_paths.txt', help="train text file")
    parser.add_argument("--val_datas_path", type=str, default='./data/val_audio_paths.txt', help="val text file")
    parser.add_argument("--log_file", type=str, default="./log_output/speaker_identification_evaluate.log", help="log_file")
    parser.add_argument("--batch_size", type=int, default=64, help="batch size")
    parser.add_argument("--random_seed", type=int, default=100, help="random_seed")
    parser.add_argument("--device", type=str, default='0', help="device")
    return parser.parse_args()


def build_label2ids(paths):
    """Map every speaker label found in the list files to a contiguous integer id (same as in training)."""
    label2ids = {}
    for path in paths:
        with open(path, 'r', encoding='utf-8') as f:
            for line in f:
                label = line.strip('\n').split('\t')[-1]
                if label not in label2ids:
                    label2ids[label] = len(label2ids)
    return label2ids


def evaluate(args):
    os.environ['CUDA_VISIBLE_DEVICES'] = args.device
    logger = Logger(log_name='SI', log_level=10, log_file=args.log_file).logger
    logger.info(args)
    label2ids = build_label2ids([args.train_datas_path, args.val_datas_path])
    augmentors = {}
    with open("augment.ymal", 'r', encoding="utf-8") as fp:
        configs = yaml.load(fp, Loader=yaml.FullLoader)
    augmentors['noise'] = NoisePerturbAugmentor(**configs['noise'])
    augmentors['speed'] = SpeedPerturbAugmentor(**configs['speed'])
    augmentors['volume'] = VolumePerturbAugmentor(**configs['volume'])
    augmentors['specaug'] = SpecAugmentor(**configs['specaug'])
    augmentors = None  # no augmentation at evaluation time
    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    val_dataset = AudioDataReader(feature_method='melspectrogram', data_list_path=args.val_datas_path,
                                  mode='eval', label2ids=label2ids, augmentors=augmentors)
    val_dataloader = DataLoader(val_dataset, shuffle=False, batch_size=args.batch_size)
    num_class = len(label2ids)  # must equal the number of classes the checkpoint was trained with
    logger.info('num_class:%d' % num_class)
    ecapa_tdnn = ECAPATDNN(input_size=val_dataset.input_size)
    model = SpeakerIdentificationModel(backbone=ecapa_tdnn, num_class=num_class).to(device)
    weights = torch.load('./output/2022-11-07/ecapa_tdnn.bin', map_location=device)
    model.load_state_dict(weights)
    model.eval()
    logger.info("***** Running evaluate *****")
    logger.info("  Num examples = %d" % len(val_dataset))
    pbar = ProgressBar(n_total=len(val_dataloader), desc='extract features')
    labels = []
    features = []
    with torch.no_grad():
        for step, batch in enumerate(val_dataloader):
            audio, speakers = [t.to(device) for t in batch]
            output = model.backbone(audio)  # (batch, 192, 1) speaker embeddings
            labels.append(speakers)
            features.append(output.squeeze(2))
            pbar(step, info={'step': step})
    labels = torch.cat(labels)
    features = torch.cat(features)

    # score every unordered pair (i, j) with i < j exactly once; self-pairs are excluded
    scores_pos, scores_neg = [], []
    for i in tqdm(range(features.shape[0] - 1), desc='pairwise similarity', ncols=100):
        same = labels[i] == labels[i + 1:]
        cos = torch.cosine_similarity(features[i].unsqueeze(0), features[i + 1:], dim=-1)
        scores_pos.extend(cos[same].cpu().tolist())
        scores_neg.extend(cos[~same].cpu().tolist())
    print('len(scores_pos)', len(scores_pos))
    print('len(scores_neg)', len(scores_neg))
    # keep target : non-target at 1 : 99 by subsampling the non-target trials
    if len(scores_pos) * 99 < len(scores_neg):
        scores_neg = random.choices(scores_neg, k=len(scores_pos) * 99)
    scores = scores_pos + scores_neg
    y_true = [1] * len(scores_pos) + [0] * len(scores_neg)
    print('len(scores)', len(scores))
    print('len(y_true)', len(y_true))
    scores = torch.tensor(scores, dtype=torch.float32)
    y_true = torch.tensor(y_true, dtype=torch.long)
    # choice_best_threshold(scores, y_true)
    choice_best_threshold_dcf(scores, y_true)


def choice_best_threshold_dcf(scores, y_true):
    """Sweep thresholds 0.00-0.99 and report the EER, minDCF and max-F1 operating points."""
    thresholds, fars, frrs, dcfs, precisions, recalls, f1s = [], [], [], [], [], [], []
    max_f1, f1_threshold = 0, 0
    min_dcf, d_threshold = 1, 0
    cfr, cfa = 1, 1  # costs of a false reject / a false accept
    err, err_threshold, diff = 0.0, 0, 1
    for i in tqdm(range(100), desc='choice_best_threshold', ncols=100):
        threshold = 0.01 * i
        thresholds.append(threshold)
        y_preds = (scores > threshold).long()
        tp = ((y_true == 1) & (y_preds == 1)).sum().item()
        fp = ((y_true == 0) & (y_preds == 1)).sum().item()
        tn = ((y_true == 0) & (y_preds == 0)).sum().item()
        fn = ((y_true == 1) & (y_preds == 0)).sum().item()
        pos = tp + fn
        neg = tn + fp
        precision = tp / (tp + fp + 1e-13)
        recall = tp / (tp + fn + 1e-13)
        f1 = 2 * precision * recall / (precision + recall + 1e-13)
        far = fp / (fp + tn + 1e-13)
        frr = fn / (tp + fn + 1e-13)
        # DCF using the empirical priors of the trial list (1:99 after subsampling)
        dcf = cfa * far * (neg / (neg + pos)) + cfr * frr * (pos / (pos + neg))
        precisions.append(precision)
        recalls.append(recall)
        f1s.append(f1)
        fars.append(far)
        frrs.append(frr)
        dcfs.append(dcf)
        if max_f1 < f1:
            max_f1 = f1
            f1_threshold = threshold
        if min_dcf > dcf:
            min_dcf = dcf
            d_threshold = threshold
        if abs(far - frr) < diff:  # EER: the threshold where FAR and FRR are closest
            err = (far + frr) / 2
            diff = abs(far - frr)
            err_threshold = threshold
    print(pos + neg)
    print('threshold:%.4f err:%.4f' % (err_threshold, err))
    print("d_threshold:%.4f, min_dcf%.4f" % (d_threshold, min_dcf))
    print("f1_threshold:%.4f, max_f1%.4f" % (f1_threshold, max_f1))
    start = time.time()
    plt.figure(figsize=(30, 30), dpi=80)
    plt.title('2D curve ')
    for values, name in [(frrs, 'frr'), (fars, 'far'), (dcfs, 'dcf'),
                         (precisions, 'pre'), (recalls, 'recall'), (f1s, 'f1')]:
        plt.plot(thresholds, values, label=name)
    plt.legend(loc=0)
    plt.scatter(d_threshold, min_dcf, c='red', s=100)
    plt.text(d_threshold, min_dcf, " min_dcf(%.4f,%.4f)" % (d_threshold, min_dcf))
    plt.scatter(err_threshold, err, c='blue', s=100)
    plt.text(err_threshold, err, " err(%.4f,%.4f)" % (err_threshold, err))
    plt.scatter(f1_threshold, max_f1, c='yellow', s=100)
    plt.text(f1_threshold, max_f1, " f1(%.4f,%.4f)" % (f1_threshold, max_f1))
    plt.xlabel('threshold')
    plt.ylabel('frr far dcf recall or precision')
    plt.xticks(thresholds[::2])
    plt.yticks(thresholds[::2])
    end = time.time()
    print('plot time is', end - start)
    plt.savefig('ecapatdnn_2d_curve_voiceprint_dcf.png')
    plt.show()
    print("finish")


def choice_best_threshold(scores, y_true):
    """Alternative report (not called above): best-precision, best-recall, best-F1 and EER operating points."""
    best = {}  # metric name -> statistics at the threshold where that metric is best
    diff = 1
    for i in tqdm(range(100), desc='choice_best_threshold', ncols=100):
        threshold = 0.01 * i
        y_preds = (scores > threshold).long()
        tp = ((y_true == 1) & (y_preds == 1)).sum().item()
        fp = ((y_true == 0) & (y_preds == 1)).sum().item()
        tn = ((y_true == 0) & (y_preds == 0)).sum().item()
        fn = ((y_true == 1) & (y_preds == 0)).sum().item()
        precision = tp / (tp + fp + 1e-13)
        recall = tp / (tp + fn + 1e-13)
        stats = {'threshold': threshold, 'tp': tp, 'fp': fp, 'tn': tn, 'fn': fn,
                 'precision': precision, 'recall': recall,
                 'f1': 2 * precision * recall / (precision + recall + 1e-13),
                 'far': fp / (fp + tn + 1e-13), 'frr': fn / (tp + fn + 1e-13)}
        for name in ('precision', 'recall', 'f1'):
            if name not in best or stats[name] > best[name][name]:
                best[name] = stats
        if abs(stats['far'] - stats['frr']) < diff:  # EER operating point
            diff = abs(stats['far'] - stats['frr'])
            best['eer'] = stats
    for name, s in best.items():
        print(f"[{name}] threshold:{s['threshold']:.4f} tp:{s['tp']} fp:{s['fp']} tn:{s['tn']} fn:{s['fn']} "
              f"precision:{s['precision']:.4f} recall:{s['recall']:.4f} f1:{s['f1']:.4f} "
              f"far:{s['far']:.4f} frr:{s['frr']:.4f} err:{(s['far'] + s['frr']) / 2:.4f}")
        print('*' * 50)


def set_seed(seed):
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.backends.cudnn.deterministic = True


if __name__ == '__main__':
    args = parse_args()
    set_seed(args.random_seed)
    evaluate(args)

FAR, FRR, EER and DCF are used to choose the best threshold:

As can be seen, the similarity threshold that gives the minimum DCF is 0.4500.
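As a compact cross-check of the threshold sweep above, here is a minimal NumPy sketch that returns the EER point and the minimum DCF in a single pass. Unlike the scripts in this article, which weight FAR and FRR by the empirical proportions of the sampled trials, it takes the target prior `p_target` as an explicit parameter (0.01 corresponds to the 1:99 target/non-target ratio used when sampling trials). The function name and defaults are my own assumptions, not part of the original code.

```python
import numpy as np


def eer_and_min_dcf(scores, labels, p_target=0.01, c_miss=1.0, c_fa=1.0, num_steps=100):
    """Sweep thresholds 0, 1/num_steps, ... and return (eer, eer_thr, min_dcf, dcf_thr).

    scores: cosine similarities; labels: 1 = same speaker (target), 0 = different speakers.
    """
    scores = np.asarray(scores, dtype=np.float64)
    labels = np.asarray(labels, dtype=np.int64)
    n_target = max(int((labels == 1).sum()), 1)
    n_nontarget = max(int((labels == 0).sum()), 1)
    best_diff, eer, eer_thr = np.inf, None, None
    min_dcf, dcf_thr = np.inf, None
    for i in range(num_steps):
        thr = i / num_steps
        accept = scores > thr
        frr = np.sum(~accept & (labels == 1)) / n_target    # genuine trials rejected
        far = np.sum(accept & (labels == 0)) / n_nontarget  # impostor trials accepted
        # Detection cost with an explicit target prior (assumption: unnormalised DCF)
        dcf = c_miss * frr * p_target + c_fa * far * (1 - p_target)
        if abs(far - frr) < best_diff:  # EER: threshold where FAR and FRR are closest
            best_diff, eer, eer_thr = abs(far - frr), (far + frr) / 2, thr
        if dcf < min_dcf:
            min_dcf, dcf_thr = dcf, thr
    return eer, eer_thr, min_dcf, dcf_thr
```

For example, `eer_and_min_dcf([0.9, 0.8, 0.3, 0.2], [1, 1, 0, 0])` finds a perfectly separating threshold at 0.3, so both the EER and the minimum DCF are 0 there. Keeping `p_target` explicit makes minDCF comparable across evaluation sets whose sampled class ratios differ.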

2. WavLm pretrained-model approach

a. Model structure and loss

from typing import Optional, Union

import torch
import torch.nn as nn
from transformers import WavLMModel, WavLMPreTrainedModel
from transformers.modeling_outputs import XVectorOutput

# Position of hidden_states in WavLMModel's output (after last_hidden_state and extract_features)
_HIDDEN_STATES_START_POSITION = 2


class TDNNLayer(nn.Module):
    """One TDNN (time-delay) layer: a dilated 1-D convolution implemented with unfold + Linear."""

    def __init__(self, config, layer_id=0):
        super().__init__()
        self.in_conv_dim = config.tdnn_dim[layer_id - 1] if layer_id > 0 else config.tdnn_dim[layer_id]
        self.out_conv_dim = config.tdnn_dim[layer_id]
        self.kernel_size = config.tdnn_kernel[layer_id]
        self.dilation = config.tdnn_dilation[layer_id]
        self.kernel = nn.Linear(self.in_conv_dim * self.kernel_size, self.out_conv_dim)
        self.activation = nn.ReLU()

    def forward(self, hidden_states):
        # (batch, time, dim) -> (batch, 1, time, dim) so that unfold slides over the time axis
        hidden_states = hidden_states.unsqueeze(1)
        hidden_states = nn.functional.unfold(
            hidden_states,
            (self.kernel_size, self.in_conv_dim),
            stride=(1, self.in_conv_dim),
            dilation=(self.dilation, 1),
        )
        hidden_states = hidden_states.transpose(1, 2)
        hidden_states = self.kernel(hidden_states)
        hidden_states = self.activation(hidden_states)
        return hidden_states


class AMSoftmaxLoss(nn.Module):
    """Additive-margin softmax: subtract `margin` from the target-class cosine, then scale by `scale`."""

    def __init__(self, input_dim, num_labels, scale=30.0, margin=0.4):
        super().__init__()
        self.scale = scale
        self.margin = margin
        self.num_labels = num_labels
        self.weight = nn.Parameter(torch.randn(input_dim, num_labels), requires_grad=True)
        self.loss = nn.CrossEntropyLoss()

    def forward(self, hidden_states, labels=None):
        # Cosine similarity between L2-normalised embeddings and L2-normalised class weights
        weight = nn.functional.normalize(self.weight, dim=0)
        hidden_states = nn.functional.normalize(hidden_states, dim=1)
        cos_theta = torch.mm(hidden_states, weight)
        if labels is not None:
            psi = cos_theta - self.margin
            labels = labels.flatten()
            onehot = nn.functional.one_hot(labels, self.num_labels)
            logits = self.scale * torch.where(onehot.bool(), psi, cos_theta)
            loss = self.loss(logits, labels)
            return loss, cos_theta
        return cos_theta


class WavLm(WavLMPreTrainedModel):
    """WavLM backbone + TDNN x-vector head + AM-Softmax (same parameter names as transformers' WavLMForXVector)."""

    def __init__(self, config):
        super().__init__(config)
        self.wavlm = WavLMModel(config)
        num_layers = config.num_hidden_layers + 1  # transformer layers + input embeddings
        if config.use_weighted_layer_sum:
            self.layer_weights = nn.Parameter(torch.ones(num_layers) / num_layers)
        self.projector = nn.Linear(config.hidden_size, config.tdnn_dim[0])
        tdnn_layers = [TDNNLayer(config, i) for i in range(len(config.tdnn_dim))]
        self.tdnn = nn.ModuleList(tdnn_layers)
        self.feature_extractor = nn.Linear(config.tdnn_dim[-1] * 2, config.xvector_output_dim)
        self.classifier = nn.Linear(config.xvector_output_dim, config.xvector_output_dim)
        self.objective = AMSoftmaxLoss(config.xvector_output_dim, config.num_labels)
        self.init_weights()

    def forward(
        self,
        input_values: Optional[torch.Tensor],
        attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        labels: Optional[torch.Tensor] = None,
    ):
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        output_hidden_states = True if self.config.use_weighted_layer_sum else output_hidden_states
        outputs = self.wavlm(
            input_values,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        if self.config.use_weighted_layer_sum:
            # Learned softmax-weighted sum over all layer outputs
            hidden_states = outputs[_HIDDEN_STATES_START_POSITION]
            hidden_states = torch.stack(hidden_states, dim=1)
            norm_weights = nn.functional.softmax(self.layer_weights, dim=-1)
            hidden_states = (hidden_states * norm_weights.view(-1, 1, 1)).sum(dim=1)
        else:
            hidden_states = outputs[0]
        hidden_states = self.projector(hidden_states)
        for tdnn_layer in self.tdnn:
            hidden_states = tdnn_layer(hidden_states)
        # Statistic pooling: mean and std over the valid (unpadded) frames
        if attention_mask is None:
            mean_features = hidden_states.mean(dim=1)
            std_features = hidden_states.std(dim=1)
        else:
            feat_extract_output_lengths = self._get_feat_extract_output_lengths(attention_mask.sum(dim=1))
            tdnn_output_lengths = self._get_tdnn_output_lengths(feat_extract_output_lengths)
            mean_features = []
            std_features = []
            for i, length in enumerate(tdnn_output_lengths):
                mean_features.append(hidden_states[i, :length].mean(dim=0))
                std_features.append(hidden_states[i, :length].std(dim=0))
            mean_features = torch.stack(mean_features)
            std_features = torch.stack(std_features)
        statistic_pooling = torch.cat([mean_features, std_features], dim=-1)
        output_embeddings = self.feature_extractor(statistic_pooling)  # speaker embedding (x-vector)
        logits = self.classifier(output_embeddings)
        loss = None
        if labels is not None:
            loss, cos_theta = self.objective(logits, labels)
        else:
            cos_theta = self.objective(logits, labels)
        logits = cos_theta
        if not return_dict:
            output = (logits, output_embeddings) + outputs[_HIDDEN_STATES_START_POSITION:]
            return ((loss,) + output) if loss is not None else output
        return XVectorOutput(
            loss=loss,
            logits=logits,
            embeddings=output_embeddings,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def _get_tdnn_output_lengths(self, input_lengths: Union[torch.LongTensor, int]):
        """Computes the output length of the TDNN layers."""

        def _conv_out_length(input_length, kernel_size, dilation, stride=1):
            # Conv1d output length without padding, including dilation, see
            # https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html
            return (input_length - dilation * (kernel_size - 1) - 1) // stride + 1

        for kernel_size, dilation in zip(self.config.tdnn_kernel, self.config.tdnn_dilation):
            input_lengths = _conv_out_length(input_lengths, kernel_size, dilation)
        return input_lengths

    def _get_feat_extract_output_lengths(self, input_lengths: Union[torch.LongTensor, int],
                                         add_adapter: Optional[bool] = None):
        """Computes the output length of the convolutional feature encoder."""
        add_adapter = self.config.add_adapter if add_adapter is None else add_adapter

        def _conv_out_length(input_length, kernel_size, stride):
            # Floor division works for ints and tensors; it replaces transformers' torch_int_div,
            # which is no longer available in recent transformers releases
            return (input_length - kernel_size) // stride + 1

        for kernel_size, stride in zip(self.config.conv_kernel, self.config.conv_stride):
            input_lengths = _conv_out_length(input_lengths, kernel_size, stride)
        if add_adapter:
            for _ in range(self.config.num_adapter_layers):
                input_lengths = _conv_out_length(input_lengths, 1, self.config.adapter_stride)
        return input_lengths
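To see what `AMSoftmaxLoss` does numerically, the following NumPy sketch (an illustration, not the training code) repeats its steps: normalise the embeddings and class weights, subtract the margin from the target-class cosine only, scale, and apply cross-entropy. The function name `am_softmax_loss` is hypothetical; `scale=30.0` and `margin=0.4` match the defaults of the class above.

```python
import numpy as np


def am_softmax_loss(embeddings, weight, labels, scale=30.0, margin=0.4):
    """Return (mean AM-Softmax loss, cos_theta) for a batch; weight has shape (dim, num_classes)."""
    # L2-normalise embeddings (rows) and class weights (columns): x @ w is then cos(theta)
    x = embeddings / np.linalg.norm(embeddings, axis=1, keepdims=True)
    w = weight / np.linalg.norm(weight, axis=0, keepdims=True)
    cos_theta = x @ w                                   # (batch, num_classes)
    onehot = np.eye(w.shape[1], dtype=bool)[labels]
    # Subtract the additive margin from the target-class cosine only, then scale
    logits = scale * np.where(onehot, cos_theta - margin, cos_theta)
    # Numerically stable cross-entropy over the margin-adjusted logits
    logits = logits - logits.max(axis=1, keepdims=True)
    log_probs = logits - np.log(np.exp(logits).sum(axis=1, keepdims=True))
    return -log_probs[np.arange(len(labels)), labels].mean(), cos_theta
```

Even for an embedding that points exactly at its class weight, the loss with `margin=0.4` is higher than with `margin=0.0`: the target cosine must beat every other class by roughly the margin before the loss becomes negligible, which is what pulls each speaker's embeddings closer together.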

b. Data processing

import random
import torch
from torch.utils.data import Dataset
import torchaudio
from tqdm import tqdm


class AudioDataReader(Dataset):
    """Loads every file listed in data_list_path (one "audio_path<TAB>speaker" per line) into memory."""

    def __init__(self, data_list_path, mode='train', sr=16000, chunk_duration=3, min_duration=0.5,
                 label2ids=None, augmentors=None):
        super().__init__()
        assert data_list_path is not None
        with open(data_list_path, 'r', encoding='utf-8') as f:
            self.lines = f.readlines()
        self.mode = mode
        self.sr = sr
        self.chunk_duration = chunk_duration
        self.min_duration = min_duration
        self.augmentors = augmentors
        # Avoid a shared mutable default argument
        self.label2ids = label2ids if label2ids is not None else {}
        self.audiofeatures = self.getaudiofeatures()

    def handle_features(self, wav, sr, mode, chunk_duration, min_duration):
        num_wav_samples = wav.shape[1]
        # Clips that are too short do not help training
        if mode == 'train' and num_wav_samples < int(min_duration * sr):
            raise Exception(f'Audio shorter than {min_duration}s, actual length: {(num_wav_samples / sr):.2f}s')
        # Pad clips shorter than the chunk length by repeating them
        num_chunk_samples = int(chunk_duration * sr)
        if num_wav_samples < num_chunk_samples:
            times = int(num_chunk_samples / num_wav_samples) - 1
            shortages = [wav]
            temp_num_wav_samples = num_wav_samples
            if times >= 1:
                for _ in range(times):
                    shortages.append(wav)
                    temp_num_wav_samples += num_wav_samples
                shortages.append(wav[:, 0:(num_chunk_samples - temp_num_wav_samples)])
            else:
                shortages.append(wav[:, 0:(num_chunk_samples - num_wav_samples)])
            wav = torch.cat(shortages, dim=1)
        # Crop to the chunk length
        num_wav_samples = wav.shape[1]
        if mode == 'train':
            # Random crop
            if num_wav_samples > num_chunk_samples + 1:
                start = random.randint(0, num_wav_samples - num_chunk_samples - 1)
                wav = wav[:, start:start + num_chunk_samples]
        elif mode == 'eval':
            # Keep only the first chunk to avoid running out of GPU memory
            if num_wav_samples > num_chunk_samples + 1:
                wav = wav[:, 0:num_chunk_samples]
        return wav

    def getaudiofeatures(self):
        res = []
        for line in tqdm(self.lines, desc=self.mode + ' load all audios', ncols=100):
            try:
                audio_path, label = line.strip('\n').split('\t')
                label = self.label2ids[label]
                wav, sample_rate = torchaudio.load(audio_path)  # (channels, samples) tensor
                if wav.shape[0] > 1:
                    wav = wav.mean(dim=0, keepdim=True)  # down-mix to mono
                if sample_rate != self.sr:
                    # Resample so that chunk lengths in samples match self.sr
                    wav = torchaudio.functional.resample(wav, sample_rate, self.sr)
                wav = self.handle_features(wav, sr=self.sr, mode=self.mode,
                                           chunk_duration=self.chunk_duration,
                                           min_duration=self.min_duration)
                features = wav[:, 0:int(self.sr * self.chunk_duration)].squeeze(0)  # (samples,)
                attention_mask = torch.ones_like(features, dtype=torch.long)
                label = torch.as_tensor(label, dtype=torch.long)
                res.append([features, attention_mask, label])
            except Exception as e:
                # Skip unreadable or too-short files, but report them
                print(f'{e}, load audio data exception')
        return res

    def __getitem__(self, item):
        features, attention_mask, label = self.audiofeatures[item]
        return features, attention_mask, label

    def __len__(self):
        return len(self.audiofeatures)

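        The difference from Ecapa_TDNN is that the raw time-domain waveform is fed to the model directly, instead of frequency-domain features produced by acoustic analysis. The data code only controls the length of the training and validation samples (short clips are padded by repetition, then cropped to a fixed length), so it is fairly simple.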
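The repetition-padding loop in `handle_features` can be written more compactly with `np.tile`. Below is a NumPy sketch for a 1-D mono waveform, assuming the same behaviour: pad short clips by repetition, random-crop in `train` mode, keep the first chunk in `eval` mode. `fix_length` is a hypothetical helper name, it assumes a non-empty waveform, and it crops whenever the clip is longer than the chunk (the original code uses a `+ 1` margin).

```python
import numpy as np


def fix_length(wav, sr=16000, chunk_duration=3.0, mode="train", rng=None):
    """Repeat-pad a 1-D waveform to chunk_duration seconds, then crop to exactly that length."""
    rng = rng if rng is not None else np.random.default_rng()
    target = int(chunk_duration * sr)
    if len(wav) < target:
        # Tile enough copies to cover the target, then cut off the excess
        repeats = int(np.ceil(target / len(wav)))
        wav = np.tile(wav, repeats)[:target]
    if len(wav) > target:
        # Random crop for training, deterministic first chunk for evaluation
        start = int(rng.integers(0, len(wav) - target + 1)) if mode == "train" else 0
        wav = wav[start:start + target]
    return wav
```

For example, with `sr=2` and `chunk_duration=3.0` (a 6-sample target), `fix_length(np.arange(5), ..., mode="eval")` returns `[0, 1, 2, 3, 4, 0]`.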

c. Model training

from transformers import WavLMConfig
from models.wavlm import WavLm
from tools.log import Logger
from tools.progressbar import ProgressBar
from data_utils.wavlm_reader import AudioDataReader
from torch.utils.data import DataLoader
import torch
import os
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR
import argparse
import random
import numpy as np
from torch.utils.tensorboard import SummaryWriter
from datetime import datetime
from torch.nn.utils.rnn import pad_sequence


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--train_datas_path", type=str, default='./data/train_audio_paths.txt', help="train text file")
    parser.add_argument("--val_datas_path", type=str, default='./data/val_audio_paths.txt', help="val text file")
    # parser.add_argument("--train_datas_path", type=str, default='./data/train_audio_paths_small.txt', help="train text file")
    # parser.add_argument("--val_datas_path", type=str, default='./data/val_audio_paths_small.txt', help="val text file")
    parser.add_argument("--log_file", type=str, default="./log_output/speaker_identification_wavlm.log", help="log_file")
    parser.add_argument("--model_out", type=str, default="./output/wavlm/", help="model output path")
    parser.add_argument("--batch_size", type=int, default=32, help="batch size")
    parser.add_argument("--epochs", type=int, default=30, help="epochs")
    parser.add_argument("--lr", type=float, default=1e-5, help="learning rate")
    parser.add_argument("--random_seed", type=int, default=100, help="random_seed")
    parser.add_argument("--device", type=str, default='0', help="device")
    return parser.parse_args()


def training(args):
    os.environ['CUDA_VISIBLE_DEVICES'] = args.device
    logger = Logger(log_name='SI', log_level=10, log_file=args.log_file).logger
    logger.info(args)
    config = WavLMConfig.from_pretrained('./pretrained_models/torch/wavlm-base-plus-sv/')
    # Speaker -> contiguous id over the train and val lists
    # (val speakers must also appear in train for the accuracy to be meaningful)
    label2ids = {}
    for path in (args.train_datas_path, args.val_datas_path):
        with open(path, 'r', encoding='utf-8') as f:
            for line in f:
                speaker = line.strip('\n').split('\t')[-1]
                if speaker not in label2ids:
                    label2ids[speaker] = len(label2ids)
    time_srt = datetime.now().strftime('%Y-%m-%d')
    save_path = os.path.join(args.model_out, time_srt)
    os.makedirs(save_path, exist_ok=True)
    logger.info(save_path)
    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    train_dataset = AudioDataReader(data_list_path=args.train_datas_path, mode='train', label2ids=label2ids)
    train_dataloader = DataLoader(train_dataset, shuffle=True, batch_size=args.batch_size, collate_fn=collate_fn)
    val_dataset = AudioDataReader(data_list_path=args.val_datas_path, mode='eval', label2ids=label2ids)
    val_dataloader = DataLoader(val_dataset, shuffle=False, batch_size=args.batch_size, collate_fn=collate_fn)
    num_class = len(label2ids)
    logger.info('num_class:%d' % num_class)
    config.num_labels = num_class
    # ignore_mismatched_sizes: the pretrained AM-Softmax weight has a different number of classes
    model = WavLm.from_pretrained('./pretrained_models/torch/wavlm-base-plus-sv/', config=config,
                                  ignore_mismatched_sizes=True).to(device)
    optimizer = AdamW(lr=args.lr, params=model.parameters())
    scheduler = CosineAnnealingLR(optimizer, T_max=args.epochs)
    logger.info("***** Running training *****")
    logger.info("  Num examples = %d" % len(train_dataset))
    logger.info("  Num Epochs = %d" % args.epochs)
    writer = SummaryWriter('./runs/' + time_srt + '/')
    best_acc = 0
    total_step = 0
    unimproving_count = 0
    for epoch in range(args.epochs):
        pbar = ProgressBar(n_total=len(train_dataloader), desc='Training')
        model.train()
        total_loss = 0
        for step, batch in enumerate(train_dataloader):
            wav, mask, speakers = [t.to(device) for t in batch]
            output = model(input_values=wav, attention_mask=mask, labels=speakers)
            loss = output.loss
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            total_step += 1
            writer.add_scalar('Train/Learning loss', loss.item(), total_step)
            total_loss += loss.item()
            pbar(step, {'loss': loss.item()})
        val_acc = evaluate(model, val_dataloader, device)
        is_improving = best_acc < val_acc
        if is_improving:
            best_acc = val_acc
            model.save_pretrained(save_path)
            unimproving_count = 0
        else:
            unimproving_count += 1
        msg = (f"Train epoch [{epoch + 1}/{args.epochs}],batch [{step + 1}],Best_acc: {best_acc},Val_acc:{val_acc}, "
               f"lr:{scheduler.get_last_lr()[0]}, total_loss:{round(total_loss, 4)}.")
        logger.info(msg + (" Save model!" if is_improving else ""))
        writer.add_scalar('Val/val_acc', val_acc, total_step)
        writer.add_scalar('Val/best_acc', best_acc, total_step)
        writer.add_scalar('Train/Learning rate', scheduler.get_last_lr()[0], total_step)
        scheduler.step()
        # Early stopping after 5 epochs without improvement
        if unimproving_count >= 5:
            logger.info('unimproving %d epochs, early stop!' % unimproving_count)
            break


def evaluate(model, val_dataloader, device):
    """Closed-set speaker classification accuracy on the validation set."""
    total = 0
    correct_total = 0
    model.eval()
    with torch.no_grad():
        pbar = ProgressBar(n_total=len(val_dataloader), desc='evaluate')
        for step, batch in enumerate(val_dataloader):
            wav, mask, speakers = [t.to(device) for t in batch]
            logits = model(input_values=wav, attention_mask=mask).logits
            preds = torch.argmax(logits, dim=-1)
            total += speakers.shape[0]
            correct_total += (speakers == preds).sum().item()
            pbar(step, {})
    return correct_total / total


def set_seed(seed):
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.backends.cudnn.deterministic = True


def collate_fn(batch):
    # Zero-pad waveforms (and their attention masks) to the longest clip in the batch
    features, attention_mask, labels = zip(*batch)
    features = pad_sequence(features, batch_first=True, padding_value=0.0)
    attention_mask = pad_sequence(attention_mask, batch_first=True, padding_value=0)
    labels = torch.stack(labels, dim=-1)
    return features, attention_mask, labels


if __name__ == '__main__':
    args = parse_args()
    set_seed(args.random_seed)
    training(args)

Results:

 Classification accuracy: 0.9684

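d. Inference and evaluation

As with Ecapa_TDNN, the model is evaluated with FAR, FRR, EER, DCF, F1, recall and precision. An embedding is extracted for every validation clip, pairs of clips are scored by cosine similarity, and negatives are subsampled to at most 99 per positive pair.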
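The per-sample loop in the script below compares each embedding with the remaining ones. For reference, the same trial list can be built in one shot from a cosine-similarity matrix and its upper triangle. This NumPy sketch assumes the embeddings fit in memory as an (N, D) array; `build_trials` is a hypothetical name, and it subsamples negatives without replacement, whereas `random.choices` samples with replacement.

```python
import numpy as np


def build_trials(embeddings, labels, neg_per_pos=99, rng=None):
    """Score every unordered pair (i < j) once; label it 1 if both clips share a speaker.

    Negatives are subsampled without replacement to at most neg_per_pos per positive.
    """
    rng = rng if rng is not None else np.random.default_rng(0)
    emb = np.asarray(embeddings, dtype=np.float64)
    emb = emb / np.linalg.norm(emb, axis=1, keepdims=True)
    labels = np.asarray(labels)
    sim = emb @ emb.T                                   # cosine-similarity matrix
    i, j = np.triu_indices(len(labels), k=1)            # k=1 skips self-pairs
    scores = sim[i, j]
    same = (labels[i] == labels[j]).astype(np.int64)
    pos, neg = np.flatnonzero(same == 1), np.flatnonzero(same == 0)
    if len(neg) > neg_per_pos * len(pos):
        neg = rng.choice(neg, size=neg_per_pos * len(pos), replace=False)
    keep = np.concatenate([pos, neg])
    return scores[keep], same[keep]
```

The matrix approach needs O(N^2) memory, so for large validation sets the row-by-row loop (or a blocked version of it) is the safer choice.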

from transformers import WavLMForXVector
from tools.log import Logger
from tools.progressbar import ProgressBar
from data_utils.wavlm_reader import AudioDataReader
from torch.utils.data import DataLoader
import torch
import os
import argparse
import random
import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt
from torch.nn.utils.rnn import pad_sequence
import time


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--train_datas_path", type=str, default='./data/train_audio_paths.txt', help="train text file")
    parser.add_argument("--val_datas_path", type=str, default='./data/val_audio_paths.txt', help="val text file")
    # parser.add_argument("--train_datas_path", type=str, default='./data/train_audio_paths_small.txt', help="train text file")
    # parser.add_argument("--val_datas_path", type=str, default='./data/val_audio_paths_small.txt', help="val text file")
    parser.add_argument("--log_file", type=str, default="./log_output/speaker_identification_evaluate.log", help="log_file")
    parser.add_argument("--batch_size", type=int, default=64, help="batch size")
    parser.add_argument("--random_seed", type=int, default=100, help="random_seed")
    parser.add_argument("--device", type=str, default='0', help="device")
    return parser.parse_args()


def evaluate(args):
    os.environ['CUDA_VISIBLE_DEVICES'] = args.device
    logger = Logger(log_name='SI', log_level=10, log_file=args.log_file).logger
    logger.info(args)
    # Rebuild the same speaker -> id mapping as in training
    label2ids = {}
    for path in (args.train_datas_path, args.val_datas_path):
        with open(path, 'r', encoding='utf-8') as f:
            for line in f:
                speaker = line.strip('\n').split('\t')[-1]
                if speaker not in label2ids:
                    label2ids[speaker] = len(label2ids)
    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    val_dataset = AudioDataReader(data_list_path=args.val_datas_path, mode='eval', label2ids=label2ids)
    val_dataloader = DataLoader(val_dataset, shuffle=False, batch_size=args.batch_size, collate_fn=collate_fn)
    logger.info('num_class:%d' % len(label2ids))
    # The checkpoint saved by the custom WavLm class uses the same parameter names as WavLMForXVector
    model = WavLMForXVector.from_pretrained('./output/wavlm/2022-11-11/').to(device)
    model.eval()
    logger.info("***** Running evaluate *****")
    logger.info("  Num examples = %d" % len(val_dataset))
    pbar = ProgressBar(n_total=len(val_dataloader), desc='extract features')
    labels = []
    features = []
    with torch.no_grad():
        for step, batch in enumerate(val_dataloader):
            wav, mask, speakers = [t.to(device) for t in batch]
            output = model(input_values=wav, attention_mask=mask)
            labels.append(speakers)
            features.append(output.embeddings)  # x-vector embeddings, not class logits
            pbar(step, info={'step': step})
    labels = torch.cat(labels)
    features = torch.cat(features)
    scores_pos, scores_neg = [], []
    y_true_pos, y_true_neg = [], []
    for i in tqdm(range(features.shape[0]), desc='pairwise similarity', ncols=100):
        query = features[i].unsqueeze(0)
        # Compare only with later samples: each pair is scored once and self-pairs (score 1.0) are excluded
        others = features[i + 1:, :]
        same = labels[i] == labels[i + 1:]
        cos = torch.cosine_similarity(query, others, dim=-1)
        scores_pos.extend(cos[same].cpu().tolist())
        scores_neg.extend(cos[~same].cpu().tolist())
        y_true_pos.extend([1] * int(same.sum()))
        y_true_neg.extend([0] * int((~same).sum()))
    print('len(y_true_pos)', len(y_true_pos))
    print('len(y_true_neg)', len(y_true_neg))
    # Keep at most 99 negatives per positive (target : non-target = 1 : 99), sampled without replacement
    if len(y_true_pos) * 99 < len(y_true_neg):
        indexs = random.sample(range(len(y_true_neg)), k=len(y_true_pos) * 99)
        scores = scores_pos + [scores_neg[index] for index in indexs]
        y_true = y_true_pos + [y_true_neg[index] for index in indexs]
    else:
        scores = scores_pos + scores_neg
        y_true = y_true_pos + y_true_neg
    print('len(scores)', len(scores))
    print('len(y_true)', len(y_true))
    scores = torch.tensor(scores, dtype=torch.float32)
    y_true = torch.tensor(y_true, dtype=torch.long)
    choice_best_threshold_dcf(scores, y_true)


def choice_best_threshold_dcf(scores, y_true):
    thresholds, fars, frrs, dcfs, precisions, recalls, f1s = [], [], [], [], [], [], []
    max_precision = 0
    max_recall = 0
    max_f1 = 0
    f1_threshold = 0
    min_dcf = 1
    d_threshold = 0
    cfr = 1  # cost of a false rejection
    cfa = 1  # cost of a false acceptance
    err = 0.0
    err_threshold = 0
    diff = 1
    for i in tqdm(range(100), desc='choice_best_threshold', ncols=100):
        threshold = 0.01 * i
        thresholds.append(threshold)
        y_preds = (scores > threshold).long()
        tp = ((y_true == 1) * (y_preds == 1)).sum().item()
        fp = ((y_true == 0) * (y_preds == 1)).sum().item()
        tn = ((y_true == 0) * (y_preds == 0)).sum().item()
        fn = ((y_true == 1) * (y_preds == 0)).sum().item()
        pos = tp + fn  # target (same-speaker) trials
        neg = tn + fp  # non-target trials
        precision = tp / (tp + fp + 1e-13)
        recall = tp / (tp + fn + 1e-13)
        f1 = 2 * precision * recall / (precision + recall + 1e-13)
        far = fp / (fp + tn + 1e-13)
        frr = fn / (tp + fn + 1e-13)
        # DCF with the empirical priors P_target = pos / (pos + neg) and P_nontarget = neg / (pos + neg)
        dcf = cfa * far * (neg / (neg + pos)) + cfr * frr * (pos / (pos + neg))
        precisions.append(precision)
        recalls.append(recall)
        f1s.append(f1)
        fars.append(far)
        frrs.append(frr)
        dcfs.append(dcf)
        max_precision = max(max_precision, precision)
        max_recall = max(max_recall, recall)
        if max_f1 < f1:
            max_f1 = f1
            f1_threshold = threshold
        if min_dcf > dcf:
            min_dcf = dcf
            d_threshold = threshold
        if abs(far - frr) < diff:
            err = (far + frr) / 2
            diff = abs(far - frr)
            err_threshold = threshold
    print(pos + neg)
    print('threshold:%.4f err:%.4f' % (err_threshold, err))
    print("d_threshold:%.4f, min_dcf:%.4f" % (d_threshold, min_dcf))
    print("f1_threshold:%.4f, max_f1:%.4f" % (f1_threshold, max_f1))
    start = time.time()
    plt.figure(figsize=(30, 30), dpi=80)
    plt.title('2D curve ')
    plt.plot(thresholds, frrs, label='frr')
    plt.plot(thresholds, fars, label='far')
    plt.plot(thresholds, dcfs, label='dcf')
    plt.plot(thresholds, precisions, label='pre')
    plt.plot(thresholds, recalls, label='recall')
    plt.plot(thresholds, f1s, label='f1')
    plt.legend(loc=0)
    plt.scatter(d_threshold, min_dcf, c='red', s=100)
    plt.text(d_threshold, min_dcf, " min_dcf(%.4f,%.4f)" % (d_threshold, min_dcf))
    plt.scatter(err_threshold, err, c='blue', s=100)
    plt.text(err_threshold, err, " err(%.4f,%.4f)" % (err_threshold, err))
    plt.scatter(f1_threshold, max_f1, c='yellow', s=100)
    plt.text(f1_threshold, max_f1, " f1(%.4f,%.4f)" % (f1_threshold, max_f1))
    plt.xlabel('threshold')
    plt.ylabel('far, frr, dcf, recall or precision')
    plt.xticks(thresholds[::2])
    plt.yticks(thresholds[::2])
    end = time.time()
    print('plot time is', end - start)
    plt.savefig('wavlm_2d_curve_voiceprint_dcf.png')
    plt.show()
    print("finish")


def set_seed(seed):
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.backends.cudnn.deterministic = True


def collate_fn(batch):
    features, attention_mask, labels = zip(*batch)
    features = pad_sequence(features, batch_first=True, padding_value=0.0)
    attention_mask = pad_sequence(attention_mask, batch_first=True, padding_value=0)
    labels = torch.stack(labels, dim=-1)
    return features, attention_mask, labels


if __name__ == '__main__':
    args = parse_args()
    set_seed(args.random_seed)
    evaluate(args)

Results:

        At threshold = 0.69, both DCF and F1 are at their best, with F1 = 0.9765, and both EER and DCF are very low. On this dataset, WavLm clearly outperforms Ecapa_TDNN.

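IV. Demo

       I spent nearly two weeks of evenings and weekends learning Vue 2.0 and Vue 3.0 from Shangguigu's video course on Bilibili, and built a front-end demo for speaker verification in Vue 3.0. First, the overall page:

A rough outline of how the demo is implemented:

        1. The back end is plain Python + Flask, which is very simple.

        2. The front end uses Vue 3.0 + HTML + CSS; simple pages like these are easy to build (though someone with no front-end background will still need some time to learn it).

        3. The algorithm side uses Python + PyTorch, with the WavLm and Ecapa_TDNN models.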
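The demo's back-end code is not shown. As a hedged sketch of the verification logic such a service needs, following the enroll-then-verify flow described at the start of this article, here is a framework-free, in-memory store: enrolled embeddings are L2-normalised once, and a query is accepted when its cosine similarity exceeds a threshold, such as the 0.69 found for WavLm above. `VoiceprintStore` and its methods are hypothetical names; in a real Flask app they would sit behind the upload endpoints, and the embeddings would come from the model's `embeddings` output.

```python
import numpy as np


class VoiceprintStore:
    """Hypothetical in-memory enrollment store: speaker name -> L2-normalised embedding."""

    def __init__(self, threshold):
        self.threshold = threshold
        self._db = {}

    @staticmethod
    def _normalise(vec):
        vec = np.asarray(vec, dtype=np.float64)
        return vec / np.linalg.norm(vec)

    def enroll(self, name, embedding):
        # Normalise once at enrollment so each verification is a single dot product
        self._db[name] = self._normalise(embedding)

    def verify(self, name, embedding):
        """Return (accepted, cosine score) for a claimed identity."""
        score = float(self._db[name] @ self._normalise(embedding))
        return score > self.threshold, score
```

Storing normalised vectors keeps verification cheap; a production system would persist them in a database or vector index rather than a Python dict.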

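V. Summary

        This article is only a quick experiment to check how well the mainstream models work for voiceprint recognition. It does not consider real business scenarios: background noise, cross-device and cross-distance recording, replay attacks (a recording played back instead of a live speaker), further optimization, and what to watch out for when deploying to production are all left undiscussed. There is still a lot to learn here; my knowledge is limited, and I will come back to these topics later.

        Between the pretrained WavLM model and CNN-based models, I think WavLm is becoming the more mainstream choice, and I am more optimistic about it. Given suitable audio data, continued pretraining plus fine-tuning should be able to solve some domain-specific problems, provided there is large-scale data.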

References:

Speaker Verification: study notes

EER and minDCF: performance metrics for speaker verification systems

ECAPA-TDNN: Emphasized Channel Attention, Propagation and Aggregation in TDNN Based Speaker Verification

A general model and a new framework: the WavLM speech pretraining model explained

WavLM: Large-Scale Self-Supervised Pre-Training for Full Stack Speech Processing
