K-BERT 通过多种预训练任务来解决用 NLP 方法处理 SMILES 时的多种弊端。代码:Knowledge-based-BERT;原文:《Knowledge-based BERT: a method to extract molecular features like computational chemists》。本文继续解析代码中的 downstream_task 部分,模型框架如下:
# Driver: run the downstream experiment once per task and, within a task,
# once per repetition (each repetition uses its own random seed).
for task in args['task_name_list']:
    args['task_name'] = task
    args['data_path'] = '../data/task_data/' + args['task_name'] + '.npy'

    # Result containers for all repetitions of the current task.
    all_times_train_result = []
    all_times_val_result = []
    all_times_test_result = []

    # Row labels of the result table; the order matches the metric list
    # returned by run_an_eval_global_epoch.
    result_pd = pd.DataFrame()
    result_pd['index'] = [
        'roc_auc', 'accuracy', 'sensitivity', 'specificity', 'f1-score',
        'precision', 'recall', 'error rate', 'mcc',
    ]

    for time_id in range(args['times']):
        # Seed changes with the repetition index so runs differ but stay
        # reproducible.
        set_random_seed(2020 + time_id)
        train_set, val_set, test_set, task_number = build_data.load_data_for_random_splited(
            data_path=args['data_path'], shuffle=True)
        print("Molecule graph is loaded!")
def load_data_for_random_splited(data_path='example.npy', shuffle=True):
    """Load a task file and split its molecules into train/val/test sets.

    The ``.npy`` file is indexed as five parallel sequences:
    ``data[0]`` SMILES, ``data[1]`` token indices, ``data[2]`` labels,
    ``data[3]`` masks and ``data[4]`` the split name of each molecule.

    Args:
        data_path: Path of the ``.npy`` file to read.
        shuffle: When true, the split names are shuffled in place with
            ``random.shuffle`` so molecules are re-assigned to splits at
            random (the split sizes stay the same).

    Returns:
        ``(train_set, val_set, test_set, task_number)``. Each set is a list of
        ``[smiles, tokens_idx, labels, mask]`` entries; ``task_number`` is the
        length of one label entry (0 when the file holds no molecules).
    """
    # NOTE(security): allow_pickle=True unpickles arbitrary objects; only
    # load files from a trusted source.
    data = np.load(data_path, allow_pickle=True)
    smiles_list = data[0]
    tokens_idx_list = data[1]
    labels_list = data[2]
    mask_list = data[3]
    group_list = data[4]

    if shuffle:
        random.shuffle(group_list)
    print(group_list)

    # Use the first entry (not the second) so a single-molecule file works.
    task_number = len(labels_list[0]) if len(labels_list) > 0 else 0

    train_set = []
    val_set = []
    test_set = []
    for i, group in enumerate(group_list):
        molecule = [smiles_list[i], tokens_idx_list[i], labels_list[i], mask_list[i]]
        if group == 'training':
            train_set.append(molecule)
        elif group == 'val':
            val_set.append(molecule)
        else:
            # Any other split name is treated as test data.
            test_set.append(molecule)

    print('Training set: {}, Validation set: {}, Test set: {}, task number: {}'.format(
        len(train_set), len(val_set), len(test_set), task_number))
    return train_set, val_set, test_set, task_number
# Per-repetition setup. These statements use `time_id` and the data sets
# produced by the repetition loop, so they presumably belong to that loop's body.

# Only the training loader shuffles; evaluation loaders keep dataset order.
train_loader = DataLoader(dataset=train_set,
                          batch_size=args['batch_size'],
                          shuffle=True,
                          collate_fn=collate_data)
val_loader = DataLoader(dataset=val_set,
                        batch_size=args['batch_size'],
                        collate_fn=collate_data)
test_loader = DataLoader(dataset=test_set,
                         batch_size=args['batch_size'],
                         collate_fn=collate_data)

# Per-task weight for the positive class, computed on the training set only.
pos_weight_task = pos_weight(train_set)

one_time_train_result = []
one_time_val_result = []
one_time_test_result = []

print('***************************************************************************************************')
print('{}, {}/{} time'.format(args['task_name'], time_id+1, args['times']))
print('***************************************************************************************************')

# reduction='none' keeps the element-wise loss instead of averaging it here.
loss_criterion = torch.nn.BCEWithLogitsLoss(reduction='none',
                                            pos_weight=pos_weight_task.to(args['device']))
model = K_BERT_WCL(d_model=args['d_model'], n_layers=args['n_layers'], vocab_size=args['vocab_size'],
                   maxlen=args['maxlen'], d_k=args['d_k'], d_v=args['d_v'], n_heads=args['n_heads'],
                   d_ff=args['d_ff'], global_label_dim=args['global_labels_dim'],
                   atom_label_dim=args['atom_labels_dim'])
stopper = EarlyStopping(patience=args['patience'], pretrained_model=args['pretrain_model'],
                        pretrain_layer=args['pretrain_layer'],
                        task_name=args['task_name']+'_downstream_k_bert_wcl', mode=args['mode'])
model.to(args['device'])
# The stopper object also copies the selected pre-trained weights into the model.
stopper.load_pretrained_model(model)
optimizer = Adam(model.parameters(), lr=args['lr'])
def pos_weight(train_set):
    """Compute one positive-class weight per task from the training set.

    For every task the weight is ``#labels equal to 0 / #labels equal to 1``
    (a tiny epsilon in the denominator avoids division by zero). Label values
    other than 0 and 1 are counted in neither group.

    Args:
        train_set: Non-empty sequence of molecules; the labels of a molecule
            are read from index 2 of its entry.

    Returns:
        A 1-D ``torch.Tensor`` with one weight per task.

    Raises:
        ValueError: If ``train_set`` is empty.
    """
    if len(train_set) == 0:
        raise ValueError('train_set is empty; cannot compute pos_weight')

    labels = [molecule[2] for molecule in train_set]
    # Use the first sample (not the second) so a single-sample set works.
    num_tasks = len(labels[0])

    task_pos_weight_list = []
    for task in range(num_tasks):
        num_pos = sum(1 for row in labels if row[task] == 1)
        num_impos = sum(1 for row in labels if row[task] == 0)
        task_pos_weight_list.append(num_impos / (num_pos + 0.00000001))
    return torch.tensor(task_pos_weight_list)
def load_pretrained_model(self, model):
    """Copy selected pre-trained weights from a checkpoint into ``model``.

    The checkpoint at ``self.pretrained_model`` is loaded on the CPU and must
    contain a ``'model_state_dict'`` mapping. Which entries are copied depends
    on ``self.pretrain_layer``:

    * an integer ``n >= 1``: the embedding parameters plus encoder layers
      ``0 .. n-1``;
    * ``'all_12layer'``: the embedding parameters, encoder layers ``0 .. 10``
      and the ``fc`` / ``classifier_global`` / ``classifier_atom`` heads.

    Parameters of ``model`` that are not selected keep their current values.

    Args:
        model: Module whose ``load_state_dict`` receives the selected weights.

    Raises:
        ValueError: If ``self.pretrain_layer`` is not a supported value.
    """
    embedding_names = [
        'embedding.tok_embed.weight', 'embedding.pos_embed.weight',
        'embedding.norm.weight', 'embedding.norm.bias',
    ]
    # Parameter names inside one encoder layer; every layer uses the same set.
    layer_suffixes = (
        'enc_self_attn.linear.weight', 'enc_self_attn.linear.bias',
        'enc_self_attn.layernorm.weight', 'enc_self_attn.layernorm.bias',
        'enc_self_attn.W_Q.weight', 'enc_self_attn.W_Q.bias',
        'enc_self_attn.W_K.weight', 'enc_self_attn.W_K.bias',
        'enc_self_attn.W_V.weight', 'enc_self_attn.W_V.bias',
        'pos_ffn.fc.0.weight', 'pos_ffn.fc.2.weight',
        'pos_ffn.layernorm.weight', 'pos_ffn.layernorm.bias',
    )
    head_names = [
        'fc.1.weight', 'fc.1.bias', 'fc.3.weight', 'fc.3.bias',
        'classifier_global.weight', 'classifier_global.bias',
        'classifier_atom.weight', 'classifier_atom.bias',
    ]

    def encoder_names(num_layers):
        # Names of all parameters in encoder layers 0 .. num_layers-1.
        return ['layers.{}.{}'.format(index, suffix)
                for index in range(num_layers)
                for suffix in layer_suffixes]

    if self.pretrain_layer == 'all_12layer':
        # NOTE(review): the hand-written list this replaces stopped at
        # layers.10 (11 layers) despite the option name -- confirm whether
        # layers.11 should be included as well.
        pretrained_parameters = embedding_names + encoder_names(11) + head_names
    else:
        try:
            num_layers = int(self.pretrain_layer)
        except (TypeError, ValueError):
            num_layers = None
        # Accept only values that compare equal to a positive integer.
        if num_layers is None or num_layers != self.pretrain_layer or num_layers < 1:
            raise ValueError(
                'unsupported pretrain_layer: {!r}'.format(self.pretrain_layer))
        pretrained_parameters = embedding_names + encoder_names(num_layers)

    # NOTE(security): torch.load may unpickle objects; only load trusted
    # checkpoint files.
    pretrained_model = torch.load(self.pretrained_model, map_location=torch.device('cpu'))
    wanted = set(pretrained_parameters)
    pretrained_dict = {k: v for k, v in pretrained_model['model_state_dict'].items()
                       if k in wanted}
    # strict=False: keys missing from pretrained_dict are simply left as-is.
    model.load_state_dict(pretrained_dict, strict=False)
# Training loop for one repetition: train, evaluate on all three splits,
# and let the early-stopping helper decide when to stop.
for epoch in range(args['num_epochs']):
    train_score = run_a_train_global_epoch(args, epoch, model, train_loader, loss_criterion, optimizer)

    # Validation and early stop
    # The training-set evaluation result is discarded; element 0 of each
    # returned metric list is the ROC-AUC.
    _ = run_an_eval_global_epoch(args, model, train_loader)[0]
    val_score = run_an_eval_global_epoch(args, model, val_loader)[0]
    test_score = run_an_eval_global_epoch(args, model, test_loader)[0]

    # During the first five epochs a score of 0 is reported to the stopper
    # instead of the real validation score.
    if epoch < 5:
        early_stop = stopper.step(0, model)
    else:
        early_stop = stopper.step(val_score, model)

    print('epoch {:d}/{:d}, {}, lr: {:.6f}, train: {:.4f}, valid: {:.4f}, best valid {:.4f}, '
          'test: {:.4f}'.format(
              epoch + 1, args['num_epochs'], args['metric_name'], optimizer.param_groups[0]['lr'],
              train_score, val_score, stopper.best_score, test_score))
    if early_stop:
        break

# Reload the checkpoint written by the stopper before final evaluation.
stopper.load_checkpoint(model)
def run_an_eval_global_epoch(args, model, data_loader):
    """Evaluate ``model`` on every batch of ``data_loader``.

    Predictions are collected with a ``Meter``, passed through a sigmoid and
    thresholded at 0.5 to obtain class labels.

    Args:
        args: Settings dict; ``args['device']`` is the device tensors are
            moved to.
        model: Model called as ``model(token_idx)`` to obtain logits.
        data_loader: Iterable yielding ``(smiles, token_idx, global_labels,
            mask)`` batches.

    Returns:
        ``[auc, accuracy, sensitivity, specificity, f1, precision, recall,
        error_rate, mcc]`` where f1/precision/recall refer to class 1.
    """
    model.eval()
    eval_meter = Meter()
    with torch.no_grad():
        for batch_data in data_loader:
            _, token_idx, global_labels, mask = batch_data
            token_idx = token_idx.long().to(args['device'])
            mask = mask.float().to(args['device'])
            global_labels = global_labels.float().to(args['device'])
            logits_global = model(token_idx)
            eval_meter.update(logits_global, global_labels, mask=mask)
            # Drop batch tensors early and release cached GPU memory.
            del token_idx, global_labels
            torch.cuda.empty_cache()

        y_pred, y_true = eval_meter.compute_metric('return_pred_true')
        # squeeze(dim=1) only removes that axis when it has size 1.
        y_true_list = y_true.squeeze(dim=1).tolist()
        y_pred_list = torch.sigmoid(y_pred).squeeze(dim=1).tolist()

    # save prediction
    y_pred_label = [1 if x >= 0.5 else 0 for x in y_pred_list]

    auc = metrics.roc_auc_score(y_true_list, y_pred_list)
    accuracy = metrics.accuracy_score(y_true_list, y_pred_label)
    se, sp = sesp_score(y_true_list, y_pred_label)
    # labels=[0, 1] guarantees two entries per metric, so indexing class 1
    # cannot fail when only one class occurs in the data.
    pre, rec, f1, _ = metrics.precision_recall_fscore_support(
        y_true_list, y_pred_label, labels=[0, 1])
    mcc = metrics.matthews_corrcoef(y_true_list, y_pred_label)

    f1 = f1[1]
    rec = rec[1]
    pre = pre[1]
    err = 1 - accuracy
    result = [auc, accuracy, se, sp, f1, pre, rec, err, mcc]
    return result
def step(self, score, model):
    """Record the score of one epoch and report whether to stop training.

    A checkpoint is saved on the first call and whenever ``self._check``
    reports ``score`` as an improvement over ``self.best_score``. Otherwise
    the patience counter is increased, and once it reaches ``self.patience``
    the ``early_stop`` flag is set.

    Args:
        score: Score of the current epoch.
        model: Model handed to ``self.save_checkpoint`` when a checkpoint is
            written.

    Returns:
        The current value of ``self.early_stop``.
    """
    if self.best_score is None:
        # First call: nothing to compare against, so just record and save.
        self.best_score = score
        self.save_checkpoint(model)
    elif self._check(score, self.best_score):
        # Improvement: keep the new best and restart the patience counter.
        self.best_score = score
        self.save_checkpoint(model)
        self.counter = 0
    else:
        self.counter += 1
        print('EarlyStopping counter: {} out of {}'.format(self.counter, self.patience))
        if self.counter >= self.patience:
            self.early_stop = True
    return self.early_stop
上一篇:学习C++基本数值类型
下一篇:springboot 定时任务