\begin{listing}[htbp]
\begin{minted}{Python}
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import BertModel


class SentenceBERT(nn.Module):
    """
    Siamese architecture: the two sentences are encoded by the same BERT
    separately, as [CLS] sen_a [SEP] and [CLS] sen_b [SEP].
    """
    def __init__(self, config):
        super(SentenceBERT, self).__init__()
        self.config = config
        self.bert = BertModel.from_pretrained(config.model_name)
        self.dropout = nn.Dropout(config.dropout_rate)
        self.bert_config = self.bert.config
        # Classifier over [u; v; |u - v|], hence 3 * hidden_size inputs.
        self.fc = nn.Linear(self.bert_config.hidden_size * 3, config.num_classes)

    def forward(self,
                sen_a_input_ids,
                sen_a_token_type_ids,
                sen_a_attention_mask,
                sen_b_input_ids,
                sen_b_token_type_ids,
                sen_b_attention_mask,
                inference=False):
        sen_a_bert_outputs = self.bert(
            input_ids=sen_a_input_ids,
            token_type_ids=sen_a_token_type_ids,
            attention_mask=sen_a_attention_mask
        )
        sen_b_bert_outputs = self.bert(
            input_ids=sen_b_input_ids,
            token_type_ids=sen_b_token_type_ids,
            attention_mask=sen_b_attention_mask
        )
        # Last hidden states: (batch_size, seq_len, hidden_size)
        sen_a_bert_output = sen_a_bert_outputs[0]
        sen_b_bert_output = sen_b_bert_outputs[0]
        # Mean pooling over real tokens only: zero out padding positions
        # before summing, then divide by the true sequence lengths.
        sen_a_mask = sen_a_attention_mask.unsqueeze(-1).float()
        sen_b_mask = sen_b_attention_mask.unsqueeze(-1).float()
        sen_a_len = sen_a_attention_mask.sum(dim=1, keepdim=True).float()
        sen_b_len = sen_b_attention_mask.sum(dim=1, keepdim=True).float()
        sen_a_pooling = (sen_a_bert_output * sen_a_mask).sum(dim=1) / sen_a_len
        sen_b_pooling = (sen_b_bert_output * sen_b_mask).sum(dim=1) / sen_b_len
        # (batch_size, hidden_size)
        if inference:
            # At inference time, return the cosine similarity of the two
            # sentence embeddings directly.
            similarity = F.cosine_similarity(sen_a_pooling, sen_b_pooling, dim=1)
            return similarity
        # Training: classify the concatenation [u; v; |u - v|].
        hidden = torch.cat(
            [sen_a_pooling, sen_b_pooling,
             torch.abs(sen_a_pooling - sen_b_pooling)],
            dim=1)
        return self.fc(hidden)
\end{minted}
\caption{model.py}
\label{model}
\end{listing}
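For reference, the sketch below shows how the model in Listing~\ref{model} might be instantiated and queried for a similarity score. The \texttt{Config} class, the checkpoint name, and the example sentences are illustrative assumptions, not part of the original code.

\begin{minted}{Python}
# A minimal usage sketch for the SentenceBERT model above.
# Config, checkpoint name, and sentences are assumptions for illustration.
import torch
from transformers import BertTokenizer


class Config:
    model_name = "bert-base-uncased"   # assumed checkpoint
    dropout_rate = 0.1
    num_classes = 2                    # e.g. similar / not similar


config = Config()
tokenizer = BertTokenizer.from_pretrained(config.model_name)
model = SentenceBERT(config)
model.eval()

enc_a = tokenizer("A man is playing a guitar.", return_tensors="pt")
enc_b = tokenizer("Someone is playing music.", return_tensors="pt")

with torch.no_grad():
    # inference=True returns the cosine similarity of the two
    # mean-pooled sentence embeddings, in [-1, 1].
    score = model(
        sen_a_input_ids=enc_a["input_ids"],
        sen_a_token_type_ids=enc_a["token_type_ids"],
        sen_a_attention_mask=enc_a["attention_mask"],
        sen_b_input_ids=enc_b["input_ids"],
        sen_b_token_type_ids=enc_b["token_type_ids"],
        sen_b_attention_mask=enc_b["attention_mask"],
        inference=True,
    )
print(score.item())
\end{minted}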