\begin{listing}[htbp]
\begin{minted}{Python}
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import BertModel


class SentenceBERT(nn.Module):
    """
    Siamese architecture: both sentences are encoded by the same BERT,
    as [CLS] sen_a [SEP] and [CLS] sen_b [SEP] respectively.
    """

    def __init__(self, config):
        super(SentenceBERT, self).__init__()
        self.config = config
        self.bert = BertModel.from_pretrained(config.model_name)
        self.dropout = nn.Dropout(config.dropout_rate)
        self.bert_config = self.bert.config
        # The classifier input is [u; v; |u - v|], hence 3 * hidden_size.
        self.fc = nn.Linear(self.bert_config.hidden_size * 3,
                            config.num_classes)

    def forward(self, sen_a_input_ids, sen_a_token_type_ids,
                sen_a_attention_mask, sen_b_input_ids,
                sen_b_token_type_ids, sen_b_attention_mask,
                inference=False):
        sen_a_bert_outputs = self.bert(
            input_ids=sen_a_input_ids,
            token_type_ids=sen_a_token_type_ids,
            attention_mask=sen_a_attention_mask
        )
        sen_b_bert_outputs = self.bert(
            input_ids=sen_b_input_ids,
            token_type_ids=sen_b_token_type_ids,
            attention_mask=sen_b_attention_mask
        )
        # Last hidden states: (batch_size, seq_len, hidden_size)
        sen_a_bert_output = sen_a_bert_outputs[0]
        sen_b_bert_output = sen_b_bert_outputs[0]

        # Mean pooling over real tokens only: zero out padding positions
        # before summing, then divide by the number of non-padding tokens.
        sen_a_len = (sen_a_attention_mask != 0).sum(dim=1, keepdim=True)
        sen_b_len = (sen_b_attention_mask != 0).sum(dim=1, keepdim=True)
        sen_a_pooling = (sen_a_bert_output
                         * sen_a_attention_mask.unsqueeze(-1)).sum(dim=1) / sen_a_len
        sen_b_pooling = (sen_b_bert_output
                         * sen_b_attention_mask.unsqueeze(-1)).sum(dim=1) / sen_b_len
        # Pooled sentence embeddings: (batch_size, hidden_size)

        if inference:
            # At inference time, score the pair by the cosine similarity
            # of the two pooled sentence embeddings.
            return F.cosine_similarity(sen_a_pooling, sen_b_pooling, dim=1)

        # Training: classify on [u; v; |u - v|] as in Sentence-BERT.
        hidden = torch.cat([sen_a_pooling, sen_b_pooling,
                            torch.abs(sen_a_pooling - sen_b_pooling)], dim=1)
        return self.fc(self.dropout(hidden))
\end{minted}
\caption{model.py}
\label{model}
\end{listing}
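For reference, Listing~\ref{model} can be exercised with a minimal sketch like the one below. The checkpoint name \texttt{bert-base-uncased}, the example sentences, and the \texttt{SimpleNamespace} config are illustrative assumptions; the model itself only requires a config object exposing \texttt{model\_name}, \texttt{dropout\_rate}, and \texttt{num\_classes}.
\begin{minted}{Python}
from types import SimpleNamespace

import torch
from transformers import BertTokenizer

# Illustrative config; only these three fields are assumed by the model.
config = SimpleNamespace(model_name="bert-base-uncased",
                         dropout_rate=0.1, num_classes=2)

tokenizer = BertTokenizer.from_pretrained(config.model_name)
model = SentenceBERT(config)
model.eval()

# Encode the two sentences separately, as the Siamese forward() expects.
sen_a = tokenizer("The weather is nice today.", return_tensors="pt")
sen_b = tokenizer("It is a lovely day today.", return_tensors="pt")

with torch.no_grad():
    similarity = model(sen_a["input_ids"], sen_a["token_type_ids"],
                       sen_a["attention_mask"], sen_b["input_ids"],
                       sen_b["token_type_ids"], sen_b["attention_mask"],
                       inference=True)
print(similarity)  # shape (1,), cosine similarity in [-1, 1]
\end{minted}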