
ALBERT vs. BERT: Can the Slimmed-Down ALBERT Replace BERT?


Shisan, from Aofei Temple

Reported by QbitAI | WeChat official account QbitAI

80% fewer parameters than BERT, yet better performance.

That is ALBERT, the "slimmed-down BERT" that Google introduced last year.

As soon as the model was released it drew a great deal of attention, and comparing the two became a hot topic.

Recently, a developer named Naman Bansal raised a question:

Should we switch from BERT to ALBERT?

Whether it can replace BERT, a head-to-head comparison will tell.

BERT vs. ALBERT

The BERT model is the one most people are familiar with.

Proposed by Google in 2018, it was trained on a huge corpus of 3.3 billion words.

Its main innovation is in the pre-training stage: it uses two objectives, Masked LM and Next Sentence Prediction, to capture word-level and sentence-level representations respectively.
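
To make the two objectives concrete, here is a minimal illustrative sketch (not BERT's actual preprocessing code): the 15% masking rate comes from the BERT paper, while the toy sentences and the simplified masking (always [MASK], no random-token or keep-original cases) are assumptions for illustration only.

import random

# Toy sketch of BERT's two pre-training objectives.
tokens = ["the", "pasta", "was", "cooked", "perfectly"]

# Masked LM: randomly pick ~15% of tokens, replace them with [MASK],
# and keep the original token as the prediction target.
masked_tokens, mlm_targets = list(tokens), {}
for i in range(len(tokens)):
    if random.random() < 0.15:
        mlm_targets[i] = tokens[i]
        masked_tokens[i] = "[MASK]"

# Next Sentence Prediction: for a sentence pair, the label says whether
# sentence B actually follows sentence A in the corpus (1) or is a random sentence (0).
sentence_a = ["the", "pasta", "was", "cooked", "perfectly"]
sentence_b = ["i", "would", "order", "it", "again"]
is_next_label = 1  # B really follows A in this toy example

print(masked_tokens, mlm_targets, is_next_label)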

BERT's arrival fundamentally changed the relationship between pre-trained word representations and downstream NLP tasks.

A year later, Google proposed ALBERT ("A Lite BERT"). Its backbone is similar to BERT's: it still uses a Transformer encoder, and the activation function is still GELU.

Its biggest achievement is using 80% fewer parameters than BERT while achieving even better results.

Compared with BERT, the main improvements are factorized embedding parameterization, cross-layer parameter sharing, a sentence-order prediction (SOP) loss for inter-sentence coherence, and the removal of dropout.
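
To get a rough sense of where the embedding-related savings come from, here is a back-of-the-envelope sketch. The numbers below (30k WordPiece vocabulary, hidden size 768, ALBERT embedding size 128) are the commonly cited values from the papers and are used only for illustration; cross-layer sharing contributes further savings not shown here.

# Illustrative parameter count for the embedding table only.
V = 30000   # vocabulary size
H = 768     # hidden size
E = 128     # ALBERT's factorized embedding size

bert_embedding_params = V * H            # one V x H embedding matrix
albert_embedding_params = V * E + E * H  # V x E embedding followed by an E x H projection

print(bert_embedding_params)    # 23,040,000
print(albert_embedding_params)  # 3,938,304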

The figure below compares the performance of BERT and ALBERT on the SQuAD and RACE datasets.

As the results show, ALBERT performs well.

How to pre-train ALBERT on a custom corpus?

To get a better feel for ALBERT, we will next implement it on a custom corpus.

The dataset is a restaurant-review dataset, and the goal is to use the ALBERT model to identify dish names in the reviews.
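
Concretely, the task is framed as token classification: every token of a review gets label 1 if it is part of a dish name and 0 otherwise. The review and dish name below are made-up examples, only to show the label format that the label_sent helper in step 3 produces.

# Hypothetical example of the token-level labels used for dish-name extraction.
review_tokens = ["the", "pad", "thai", "was", "amazing"]
dish_tokens = ["pad", "thai"]

# Expected label sequence: 1 marks tokens belonging to the dish name.
labels = [0, 1, 1, 0, 0]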

Step 1: Download the dataset and prepare the files

# Downloading all files and data

!wget https://github.com/LydiaXiaohongLi/Albert_Finetune_with_Pretrain_on_Custom_Corpus/raw/master/data_toy/dish_name_train.csv
!wget https://github.com/LydiaXiaohongLi/Albert_Finetune_with_Pretrain_on_Custom_Corpus/raw/master/data_toy/dish_name_val.csv
!wget https://github.com/LydiaXiaohongLi/Albert_Finetune_with_Pretrain_on_Custom_Corpus/raw/master/data_toy/restaurant_review.txt
!wget https://github.com/LydiaXiaohongLi/Albert_Finetune_with_Pretrain_on_Custom_Corpus/raw/master/data_toy/restaurant_review_nopunct.txt
!wget https://github.com/LydiaXiaohongLi/Albert_Finetune_with_Pretrain_on_Custom_Corpus/raw/master/models_toy/albert_config.json
!wget https://github.com/LydiaXiaohongLi/Albert_Finetune_with_Pretrain_on_Custom_Corpus/raw/master/model_checkpoint/finetune_checkpoint
!wget https://github.com/LydiaXiaohongLi/Albert_Finetune_with_Pretrain_on_Custom_Corpus/raw/master/model_checkpoint/pretrain_checkpoint

# Creating files and setting up ALBERT

!pip install sentencepiece
!git clone https://github.com/google-research/ALBERT
!python ./ALBERT/create_pretraining_data.py --input_file "restaurant_review.txt" --output_file "restaurant_review_train" --vocab_file "vocab.txt" --max_seq_length=64
!pip install transformers
!pip install tfrecord

Step 2: Use transformers and define the model layers

# Defining Layers for ALBERT

from transformers.modeling_albert import AlbertModel, AlbertPreTrainedModel
from transformers.configuration_albert import AlbertConfig
import torch
import torch.nn as nn

class AlbertSequenceOrderHead(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, 2)
        self.bias = nn.Parameter(torch.zeros(2))

    def forward(self, hidden_states):
        hidden_states = self.dense(hidden_states)
        prediction_scores = hidden_states + self.bias

        return prediction_scores

from torch.nn import CrossEntropyLoss
from transformers.modeling_bert import ACT2FN

class AlbertForPretrain(AlbertPreTrainedModel):

    def __init__(self, config):
        super().__init__(config)

        self.albert = AlbertModel(config)

        # For Masked LM
        # The original huggingface implementation created new output weights via a dense layer.
        # However, the original ALBERT ties the MLM decoder weights to the input word embeddings,
        # which is what is done below.
        self.predictions_dense = nn.Linear(config.hidden_size, config.embedding_size)
        self.predictions_activation = ACT2FN[config.hidden_act]
        self.predictions_LayerNorm = nn.LayerNorm(config.embedding_size)
        self.predictions_bias = nn.Parameter(torch.zeros(config.vocab_size))
        self.predictions_decoder = nn.Linear(config.embedding_size, config.vocab_size)

        self.predictions_decoder.weight = self.albert.embeddings.word_embeddings.weight

        # For sequence order prediction
        self.seq_relationship = AlbertSequenceOrderHead(config)

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        masked_lm_labels=None,
        seq_relationship_labels=None,
    ):

        outputs = self.albert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
        )

        loss_fct = CrossEntropyLoss()

        sequence_output = outputs[0]

        sequence_output = self.predictions_dense(sequence_output)
        sequence_output = self.predictions_activation(sequence_output)
        sequence_output = self.predictions_LayerNorm(sequence_output)
        prediction_scores = self.predictions_decoder(sequence_output)

        if masked_lm_labels is not None:
            masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size),
                                      masked_lm_labels.view(-1))

        pooled_output = outputs[1]
        seq_relationship_scores = self.seq_relationship(pooled_output)
        if seq_relationship_labels is not None:
            seq_relationship_loss = loss_fct(seq_relationship_scores.view(-1, 2), seq_relationship_labels.view(-1))

        loss = masked_lm_loss + seq_relationship_loss

        return loss
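
Before wiring the model into the training loop, a quick sanity check helps confirm the pre-training heads are assembled correctly. The snippet below is a minimal smoke test with random tensors; it is not part of the original tutorial, and the batch size, sequence length, and label positions are arbitrary assumptions. It only verifies that a scalar loss comes out.

# Minimal smoke test for AlbertForPretrain (illustrative only).
import torch

config = AlbertConfig.from_json_file("albert_config.json")
model = AlbertForPretrain(config)

batch_size, seq_len = 2, 64
input_ids = torch.randint(0, config.vocab_size, (batch_size, seq_len))
attention_mask = torch.ones(batch_size, seq_len, dtype=torch.long)
token_type_ids = torch.zeros(batch_size, seq_len, dtype=torch.long)

# -100 marks positions CrossEntropyLoss ignores; label a few positions only.
masked_lm_labels = torch.full((batch_size, seq_len), -100, dtype=torch.long)
masked_lm_labels[:, 5] = torch.randint(0, config.vocab_size, (batch_size,))
seq_relationship_labels = torch.zeros(batch_size, dtype=torch.long)

loss = model(input_ids=input_ids,
             attention_mask=attention_mask,
             token_type_ids=token_type_ids,
             masked_lm_labels=masked_lm_labels,
             seq_relationship_labels=seq_relationship_labels)
print(loss.item())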

Step 3: Pre-train with the LAMB optimizer and set up ALBERT fine-tuning

# Using LAMB optimizer
# LAMB - "https://github.com/cybertronai/pytorch-lamb"

import torch
from torch.optim import Optimizer

class Lamb(Optimizer):
    r"""Implements Lamb algorithm.
    It has been proposed in `Large Batch Optimization for Deep Learning: Training BERT in 76 minutes`_.
    Arguments:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups
        lr (float, optional): learning rate (default: 1e-3)
        betas (Tuple[float, float], optional): coefficients used for computing
            running averages of gradient and its square (default: (0.9, 0.999))
        eps (float, optional): term added to the denominator to improve
            numerical stability (default: 1e-8)
        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
        adam (bool, optional): always use trust ratio = 1, which turns this into
            Adam. Useful for comparison purposes.
    .. _Large Batch Optimization for Deep Learning\: Training BERT in 76 minutes:
        https://arxiv.org/abs/1904.00962
    """

    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-6,
                 weight_decay=0, adam=False):
        if not 0.0 <= lr:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if not 0.0 <= eps:
            raise ValueError("Invalid epsilon value: {}".format(eps))
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
        defaults = dict(lr=lr, betas=betas, eps=eps,
                        weight_decay=weight_decay)
        self.adam = adam
        super(Lamb, self).__init__(params, defaults)

    def step(self, closure=None):
        """Performs a single optimization step.
        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad.data
                if grad.is_sparse:
                    raise RuntimeError('Lamb does not support sparse gradients, consider SparseAdam instead.')

                state = self.state[p]

                # State initialization
                if len(state) == 0:
                    state['step'] = 0
                    # Exponential moving average of gradient values
                    state['exp_avg'] = torch.zeros_like(p.data)
                    # Exponential moving average of squared gradient values
                    state['exp_avg_sq'] = torch.zeros_like(p.data)

                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
                beta1, beta2 = group['betas']

                state['step'] += 1

                # Decay the first and second moment running average coefficient
                # m_t
                exp_avg.mul_(beta1).add_(1 - beta1, grad)
                # v_t
                exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)

                # Paper v3 does not use debiasing.
                # bias_correction1 = 1 - beta1 ** state['step']
                # bias_correction2 = 1 - beta2 ** state['step']
                # Apply bias to lr to avoid broadcast.
                step_size = group['lr']  # * math.sqrt(bias_correction2) / bias_correction1

                weight_norm = p.data.pow(2).sum().sqrt().clamp(0, 10)

                adam_step = exp_avg / exp_avg_sq.sqrt().add(group['eps'])
                if group['weight_decay'] != 0:
                    adam_step.add_(group['weight_decay'], p.data)

                adam_norm = adam_step.pow(2).sum().sqrt()
                if weight_norm == 0 or adam_norm == 0:
                    trust_ratio = 1
                else:
                    trust_ratio = weight_norm / adam_norm
                state['weight_norm'] = weight_norm
                state['adam_norm'] = adam_norm
                state['trust_ratio'] = trust_ratio
                if self.adam:
                    trust_ratio = 1

                p.data.add_(-step_size * trust_ratio, adam_step)

        return loss

import time
import torch.nn as nn
import torch
from tfrecord.torch.dataset import TFRecordDataset
import numpy as np
import os

LEARNING_RATE = 0.001
EPOCH = 40
BATCH_SIZE = 2
MAX_GRAD_NORM = 1.0

print(f"--- Resume/Start training ---")
feat_map = {"input_ids": "int",
            "input_mask": "int",
            "segment_ids": "int",
            "next_sentence_labels": "int",
            "masked_lm_positions": "int",
            "masked_lm_ids": "int"}
pretrain_file = 'restaurant_review_train'

# Create albert pretrain model
config = AlbertConfig.from_json_file("albert_config.json")
albert_pretrain = AlbertForPretrain(config)
# Create optimizer
optimizer = Lamb([{"params": [p for n, p in list(albert_pretrain.named_parameters())]}], lr=LEARNING_RATE)
albert_pretrain.train()
dataset = TFRecordDataset(pretrain_file, index_path=None, description=feat_map)
loader = torch.utils.data.DataLoader(dataset, batch_size=BATCH_SIZE)

tmp_loss = 0
start_time = time.time()

if os.path.isfile('pretrain_checkpoint'):
    print(f"--- Load from checkpoint ---")
    checkpoint = torch.load("pretrain_checkpoint")
    albert_pretrain.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    epoch = checkpoint['epoch']
    loss = checkpoint['loss']
    losses = checkpoint['losses']

else:
    epoch = -1
    losses = []

for e in range(epoch + 1, EPOCH):
    for batch in loader:
        b_input_ids = batch['input_ids'].long()
        b_token_type_ids = batch['segment_ids'].long()
        b_seq_relationship_labels = batch['next_sentence_labels'].long()

        # Convert the data from the decoded format produced by Google's ALBERT
        # create_pretraining_data.py script into the format required by
        # huggingface's pytorch implementation of ALBERT
        mask_rows = np.nonzero(batch['masked_lm_positions'].numpy())[0]
        mask_cols = batch['masked_lm_positions'].numpy()[batch['masked_lm_positions'].numpy() != 0]
        b_attention_mask = np.zeros((BATCH_SIZE, 64), dtype=np.int64)
        b_attention_mask[mask_rows, mask_cols] = 1
        b_masked_lm_labels = np.zeros((BATCH_SIZE, 64), dtype=np.int64) - 100
        b_masked_lm_labels[mask_rows, mask_cols] = batch['masked_lm_ids'].numpy()[batch['masked_lm_positions'].numpy() != 0]
        b_attention_mask = torch.tensor(b_attention_mask).long()
        b_masked_lm_labels = torch.tensor(b_masked_lm_labels).long()

        loss = albert_pretrain(input_ids=b_input_ids
                               , attention_mask=b_attention_mask
                               , token_type_ids=b_token_type_ids
                               , masked_lm_labels=b_masked_lm_labels
                               , seq_relationship_labels=b_seq_relationship_labels)

        # clears old gradients
        optimizer.zero_grad()
        # backward pass
        loss.backward()
        # gradient clipping
        torch.nn.utils.clip_grad_norm_(parameters=albert_pretrain.parameters(), max_norm=MAX_GRAD_NORM)
        # update parameters
        optimizer.step()

        tmp_loss += loss.detach().item()

    # print metrics and save to checkpoint every epoch
    print(f"Epoch: {e}")
    print(f"Train loss: {(tmp_loss/20)}")
    print(f"Train Time: {(time.time()-start_time)/60} mins")
    losses.append(tmp_loss/20)

    tmp_loss = 0
    start_time = time.time()

    torch.save({'model_state_dict': albert_pretrain.state_dict(), 'optimizer_state_dict': optimizer.state_dict(),
                'epoch': e, 'loss': loss, 'losses': losses}
               , 'pretrain_checkpoint')

from matplotlib import pyplot as plot
plot.plot(losses)

# Finetuning ALBERT

# At the time of writing, Huggingface didn't provide the class object for
# AlbertForTokenClassification, hence write your own definition below
from transformers.modeling_albert import AlbertModel, AlbertPreTrainedModel
from transformers.configuration_albert import AlbertConfig
from transformers.tokenization_bert import BertTokenizer
import torch.nn as nn
from torch.nn import CrossEntropyLoss

class AlbertForTokenClassification(AlbertPreTrainedModel):

    def __init__(self, albert, config):
        super().__init__(config)
        self.num_labels = config.num_labels

        self.albert = albert
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
    ):

        outputs = self.albert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
        )

        sequence_output = outputs[0]

        sequence_output = self.dropout(sequence_output)
        logits = self.classifier(sequence_output)

        return logits

import numpy as np

def label_sent(name_tokens, sent_tokens):
    # Label each token of the review: 1 if it belongs to the dish name, else 0.
    label = []
    i = 0
    if len(name_tokens) > len(sent_tokens):
        label = np.zeros(len(sent_tokens))
    else:
        while i < len(sent_tokens):
            found_match = False
            if name_tokens[0] == sent_tokens[i]:
                found_match = True
                for j in range(len(name_tokens) - 1):
                    if (i + j + 1) >= len(sent_tokens):
                        return label
                    if name_tokens[j + 1] != sent_tokens[i + j + 1]:
                        found_match = False
                if found_match:
                    label.extend(list(np.ones(len(name_tokens)).astype(int)))
                    i = i + len(name_tokens)
                else:
                    label.extend([0])
                    i = i + 1
            else:
                label.extend([0])
                i = i + 1
    return label

import pandas as pd
import glob
import os

tokenizer = BertTokenizer(vocab_file="vocab.txt")

df_data_train = pd.read_csv("dish_name_train.csv")
df_data_train['name_tokens'] = df_data_train['dish_name'].apply(tokenizer.tokenize)
df_data_train['review_tokens'] = df_data_train.review.apply(tokenizer.tokenize)
df_data_train['review_label'] = df_data_train.apply(lambda row: label_sent(row['name_tokens'], row['review_tokens']), axis=1)

df_data_val = pd.read_csv("dish_name_val.csv")
df_data_val = df_data_val.dropna().reset_index()
df_data_val['name_tokens'] = df_data_val['dish_name'].apply(tokenizer.tokenize)
df_data_val['review_tokens'] = df_data_val.review.apply(tokenizer.tokenize)
df_data_val['review_label'] = df_data_val.apply(lambda row: label_sent(row['name_tokens'], row['review_tokens']), axis=1)

MAX_LEN = 64
BATCH_SIZE = 1
from keras.preprocessing.sequence import pad_sequences
import torch
from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler

tr_inputs = pad_sequences([tokenizer.convert_tokens_to_ids(txt) for txt in df_data_train['review_tokens']], maxlen=MAX_LEN, dtype="long", truncating="post", padding="post")
tr_tags = pad_sequences(df_data_train['review_label'], maxlen=MAX_LEN, padding="post", dtype="long", truncating="post")
# create the mask to ignore the padded elements in the sequences.
tr_masks = [[float(i > 0) for i in ii] for ii in tr_inputs]
tr_inputs = torch.tensor(tr_inputs)
tr_tags = torch.tensor(tr_tags)
tr_masks = torch.tensor(tr_masks)
train_data = TensorDataset(tr_inputs, tr_masks, tr_tags)
train_sampler = RandomSampler(train_data)
train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=BATCH_SIZE)

val_inputs = pad_sequences([tokenizer.convert_tokens_to_ids(txt) for txt in df_data_val['review_tokens']], maxlen=MAX_LEN, dtype="long", truncating="post", padding="post")
val_tags = pad_sequences(df_data_val['review_label'], maxlen=MAX_LEN, padding="post", dtype="long", truncating="post")
# create the mask to ignore the padded elements in the sequences.
val_masks = [[float(i > 0) for i in ii] for ii in val_inputs]
val_inputs = torch.tensor(val_inputs)
val_tags = torch.tensor(val_tags)
val_masks = torch.tensor(val_masks)
val_data = TensorDataset(val_inputs, val_masks, val_tags)
val_sampler = RandomSampler(val_data)
val_dataloader = DataLoader(val_data, sampler=val_sampler, batch_size=BATCH_SIZE)

model_tokenclassification = AlbertForTokenClassification(albert_pretrain.albert, config)
from torch.optim import Adam
LEARNING_RATE = 0.0000003
FULL_FINETUNING = True
if FULL_FINETUNING:
    param_optimizer = list(model_tokenclassification.named_parameters())
    no_decay = ['bias', 'gamma', 'beta']
    optimizer_grouped_parameters = [
        {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
         'weight_decay_rate': 0.01},
        {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
         'weight_decay_rate': 0.0}
    ]
else:
    param_optimizer = list(model_tokenclassification.classifier.named_parameters())
    optimizer_grouped_parameters = [{"params": [p for n, p in param_optimizer]}]
optimizer = Adam(optimizer_grouped_parameters, lr=LEARNING_RATE)

Step 4: Train the model on the custom corpus

# Training the model

# from torch.utils.tensorboard import SummaryWriter
import time
import os.path
import torch.nn as nn
import torch

EPOCH = 800
MAX_GRAD_NORM = 1.0

start_time = time.time()
tr_loss, tr_acc, nb_tr_steps = 0, 0, 0
eval_loss, eval_acc, nb_eval_steps = 0, 0, 0

if os.path.isfile('finetune_checkpoint'):
    print(f"--- Load from checkpoint ---")
    checkpoint = torch.load("finetune_checkpoint")
    model_tokenclassification.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    epoch = checkpoint['epoch']
    train_losses = checkpoint['train_losses']
    train_accs = checkpoint['train_accs']
    eval_losses = checkpoint['eval_losses']
    eval_accs = checkpoint['eval_accs']

else:
    epoch = -1
    train_losses, train_accs, eval_losses, eval_accs = [], [], [], []

print(f"--- Resume/Start training ---")
for e in range(epoch + 1, EPOCH):

    # TRAIN loop
    model_tokenclassification.train()

    for batch in train_dataloader:
        # add batch to gpu
        batch = tuple(t for t in batch)
        b_input_ids, b_input_mask, b_labels = batch
        # forward pass
        b_outputs = model_tokenclassification(b_input_ids, token_type_ids=None, attention_mask=b_input_mask, labels=b_labels)

        ce_loss_fct = CrossEntropyLoss()
        # Only keep active parts of the loss
        b_active_loss = b_input_mask.view(-1) == 1
        b_active_logits = b_outputs.view(-1, config.num_labels)[b_active_loss]
        b_active_labels = b_labels.view(-1)[b_active_loss]

        loss = ce_loss_fct(b_active_logits, b_active_labels)
        acc = torch.mean((torch.max(b_active_logits.detach(), 1)[1] == b_active_labels.detach()).float())

        model_tokenclassification.zero_grad()
        # backward pass
        loss.backward()
        # track train loss
        tr_loss += loss.item()
        tr_acc += acc
        nb_tr_steps += 1
        # gradient clipping
        torch.nn.utils.clip_grad_norm_(parameters=model_tokenclassification.parameters(), max_norm=MAX_GRAD_NORM)
        # update parameters
        optimizer.step()

    # VALIDATION on validation set
    model_tokenclassification.eval()
    for batch in val_dataloader:
        batch = tuple(t for t in batch)
        b_input_ids, b_input_mask, b_labels = batch

        with torch.no_grad():

            b_outputs = model_tokenclassification(b_input_ids, token_type_ids=None,
                                                  attention_mask=b_input_mask, labels=b_labels)

        loss_fct = CrossEntropyLoss()
        # Only keep active parts of the loss
        b_active_loss = b_input_mask.view(-1) == 1
        b_active_logits = b_outputs.view(-1, config.num_labels)[b_active_loss]
        b_active_labels = b_labels.view(-1)[b_active_loss]
        loss = loss_fct(b_active_logits, b_active_labels)
        acc = np.mean(np.argmax(b_active_logits.detach().cpu().numpy(), axis=1).flatten() == b_active_labels.detach().cpu().numpy().flatten())

        eval_loss += loss.mean().item()
        eval_acc += acc
        nb_eval_steps += 1

    if e % 10 == 0:

        print(f"Epoch: {e}")
        print(f"Train loss: {(tr_loss/nb_tr_steps)}")
        print(f"Train acc: {(tr_acc/nb_tr_steps)}")
        print(f"Train Time: {(time.time()-start_time)/60} mins")

        print(f"Validation loss: {eval_loss/nb_eval_steps}")
        print(f"Validation Accuracy: {(eval_acc/nb_eval_steps)}")

        train_losses.append(tr_loss/nb_tr_steps)
        train_accs.append(tr_acc/nb_tr_steps)
        eval_losses.append(eval_loss/nb_eval_steps)
        eval_accs.append(eval_acc/nb_eval_steps)

        tr_loss, tr_acc, nb_tr_steps = 0, 0, 0
        eval_loss, eval_acc, nb_eval_steps = 0, 0, 0
        start_time = time.time()

        torch.save({'model_state_dict': model_tokenclassification.state_dict(), 'optimizer_state_dict': optimizer.state_dict(),
                    'epoch': e, 'train_losses': train_losses, 'train_accs': train_accs, 'eval_losses': eval_losses, 'eval_accs': eval_accs}
                   , 'finetune_checkpoint')

plot.plot(train_losses)
plot.plot(train_accs)
plot.plot(eval_losses)
plot.plot(eval_accs)
plot.legend(labels=['train_loss', 'train_accuracy', 'validation_loss', 'validation_accuracy'])

Step 5: Prediction

# Prediction

def predict(texts):
    tokenized_texts = [tokenizer.tokenize(txt) for txt in texts]
    input_ids = pad_sequences([tokenizer.convert_tokens_to_ids(txt) for txt in tokenized_texts],
                              maxlen=MAX_LEN, dtype="long", truncating="post", padding="post")
    attention_mask = [[float(i > 0) for i in ii] for ii in input_ids]

    input_ids = torch.tensor(input_ids)
    attention_mask = torch.tensor(attention_mask)

    dataset = TensorDataset(input_ids, attention_mask)
    datasampler = SequentialSampler(dataset)
    dataloader = DataLoader(dataset, sampler=datasampler, batch_size=BATCH_SIZE)

    predicted_labels = []

    for batch in dataloader:
        batch = tuple(t for t in batch)
        b_input_ids, b_input_mask = batch

        with torch.no_grad():
            logits = model_tokenclassification(b_input_ids, token_type_ids=None,
                                               attention_mask=b_input_mask)

        predicted_labels.append(np.multiply(np.argmax(logits.detach().cpu().numpy(), axis=2), b_input_mask.detach().cpu().numpy()))
    # np.concatenate(predicted_labels) flattens the list of arrays of batch_size*max_len into a list of arrays of max_len
    return np.concatenate(predicted_labels).astype(int), tokenized_texts

def get_dish_candidate_names(predicted_label, tokenized_text):
    name_lists = []
    if len(np.where(predicted_label > 0)[0]) > 0:
        name_idx_combined = np.where(predicted_label > 0)[0]
        name_idxs = np.split(name_idx_combined, np.where(np.diff(name_idx_combined) != 1)[0] + 1)
        name_lists.append([" ".join(np.take(tokenized_text, name_idx)) for name_idx in name_idxs])
        # If there are duplicate names in the name_lists
        name_lists = np.unique(name_lists)
        return name_lists
    else:
        return None

texts = df_data_val.review.values
predicted_labels, _ = predict(texts)
df_data_val['predicted_review_label'] = list(predicted_labels)
df_data_val['predicted_name'] = df_data_val.apply(lambda row: get_dish_candidate_names(row.predicted_review_label, row.review_tokens)
                                                  , axis=1)

texts = df_data_train.review.values
predicted_labels, _ = predict(texts)
df_data_train['predicted_review_label'] = list(predicted_labels)
df_data_train['predicted_name'] = df_data_train.apply(lambda row: get_dish_candidate_names(row.predicted_review_label, row.review_tokens)
                                                      , axis=1)

# display the validation dataframe with the predicted dish names (notebook cell output)
(df_data_val)

Results

As you can see, the model successfully extracts dish names from the restaurant reviews.

Model comparison

The hands-on example above shows that ALBERT, lite as it is, delivers quite good results.

So, with fewer parameters and good results, can it simply replace BERT?

Let's look more closely at the experimental comparison between the two; note that "Speedup" here refers to training time.

Because ALBERT has fewer parameters, throughput during distributed training goes up, so it trains faster. At inference time, however, it still has to run through the same amount of Transformer computation as BERT.

So the comparison can be summed up as: ALBERT trains faster, but it is no faster than BERT at inference.

In addition, Naman Bansal argues that, because of ALBERT's architecture, running ALBERT is actually somewhat more computationally expensive than running BERT.

So it remains a trade-off: you cannot have it both ways. For ALBERT to fully surpass and replace BERT, further research and refinement are still needed.

Link

Blog post:

https://medium.com/@namanbansal9909/should-we-shift-from-bert-to-albert-e6fbb7779d3e

(End)

QbitAI · Contracted creator on Toutiao

Follow us for the latest on cutting-edge technology
