# bert.py
# Source: https://gitlab.com/mcoavoux/mtgpy-release-findings-2021.git
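# Wraps a pretrained BERT encoder so that a pre-tokenized sentence is mapped
# to one contextual vector per input token (the vector of its first
# wordpiece), in either single-sentence or batched mode.
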
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence
from transformers import AutoTokenizer, AutoModel, BertModel


class BertEncoder(nn.Module):
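    """BERT wrapper that returns one contextual vector per input token."""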
    def __init__(self, bert_id, reinitialize):
        super(BertEncoder, self).__init__()
        # The slow tokenizer is required: forward() uses its .wordpiece_tokenizer
        self.tokenizer = AutoTokenizer.from_pretrained(bert_id, do_lower_case=("uncased" in bert_id), use_fast=False)
        self.bert = AutoModel.from_pretrained(bert_id)
        if reinitialize:
            # Keep the architecture but discard the pretrained weights
            config = self.bert.config
            self.bert = BertModel(config)
            print("Reinitializing BERT parameters")
        # Dummy parameter whose .device attribute tracks the module's device
        self.device = nn.Parameter(torch.zeros(1))
        # Map PTB-style escaped brackets back to literal brackets
        self.replace = {"-LRB-": "(", "-RRB-": ")", "#LRB#": "(", "#RRB#": ")"}
        # Hidden size of the encoder (input dimension of the pooler)
        self.dim = self.bert.pooler.dense.in_features

    def forward(self, sentence, batch):
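        """Encode pre-tokenized input and return one vector per token.

        With batch=False, `sentence` is a list of tokens and the result is a
        (n_tokens, dim) tensor; with batch=True, `sentence` is a list of
        token lists and the result is a list of such tensors.
        """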
        if batch:
            batch_tokens = []
            batch_masks = []
            batch_lengths = []
            for token_list in sentence:
                token_list = [self.replace.get(token, token) for token in token_list]
                # Split each token into wordpieces; mark the first wordpiece
                # of every original token with a 1 in the mask
                wptokens = []
                for token in token_list:
                    wp = self.tokenizer.wordpiece_tokenizer.tokenize(token)
                    wptokens.extend(wp)
                mask = [0 if tok[:2] == "##" else 1 for tok in wptokens]
                mask = torch.tensor(mask, dtype=torch.bool, device=self.device.device)
                indexed_tokens = self.tokenizer.convert_tokens_to_ids(wptokens)
                tokens_tensor = torch.tensor([indexed_tokens], device=self.device.device)
                batch_tokens.append(tokens_tensor.view(-1))
                batch_masks.append(mask)
                batch_lengths.append(len(wptokens))
            # Pad to the longest sequence; padded positions have id 0 ([PAD])
            padded = pad_sequence(batch_tokens, batch_first=True)
            mask = padded != 0
            encoded_layers = self.bert(input_ids=padded, attention_mask=mask)[0]
            # One (1, max_len, dim) slice per input sentence
            split_layers = encoded_layers.split([1 for _ in sentence])
            assert len(split_layers) == len(batch_masks)
            # Strip padding, then keep only the first wordpiece of each token
            filtered_layers = [layer.squeeze(0)[:l][m]
                               for layer, l, m in zip(split_layers, batch_lengths, batch_masks)]
            return filtered_layers
        else:
            # Avoid mutating the caller's token list
            sentence = [self.replace.get(token, token) for token in sentence]
            wptokens = []
            for token in sentence:
                wp = self.tokenizer.wordpiece_tokenizer.tokenize(token)
                wptokens.extend(wp)
            mask = [0 if tok[:2] == "##" else 1 for tok in wptokens]
            mask = torch.tensor(mask, dtype=torch.bool, device=self.device.device)
            indexed_tokens = self.tokenizer.convert_tokens_to_ids(wptokens)
            tokens_tensor = torch.tensor([indexed_tokens], device=self.device.device)
            encoded_layers = self.bert(input_ids=tokens_tensor)[0]
            # Keep only the vector of the first wordpiece of each token
            filtered_layers = encoded_layers.squeeze(0)[mask]
            return filtered_layers
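
# A minimal sketch (not part of the original file) of how BertEncoder's
# per-token vectors could feed a downstream module. TaggingHead and n_labels
# are hypothetical names introduced here for illustration only.
class TaggingHead(nn.Module):
    def __init__(self, encoder, n_labels):
        super().__init__()
        self.encoder = encoder
        # encoder.dim is the BERT hidden size exposed by BertEncoder
        self.proj = nn.Linear(encoder.dim, n_labels)

    def forward(self, tokens):
        states = self.encoder(tokens, batch=False)  # (n_tokens, dim)
        return self.proj(states)                    # (n_tokens, n_labels)
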
if __name__ == "__main__":
    # Smoke test. Any BERT checkpoint from the HuggingFace hub can be
    # substituted here (e.g. 'bert-base-uncased', 'bert-large-cased',
    # 'bert-base-multilingual-cased', 'bert-base-german-dbmdz-cased').
    for bert_id in ['bert-base-cased']:
        for reinitialize in [False, True]:
            bert = BertEncoder(bert_id, reinitialize=reinitialize)
            bert.eval()
            # bert.cuda()  # uncomment to run on GPU
            sentence = ("The bill , whose backers include Chairman Dan Rostenkowski "
                        "-LRB- D. , Ill. -RRB- , would prevent the Resolution Trust Corp. "
                        "from raising temporary working capital by having an RTC-owned "
                        "bank or thrift issue debt that would n't be counted on the "
                        "federal budget .").split()
            print(len(sentence))
            # Single-sentence mode: output is a (n_tokens, dim) tensor
            output = bert(sentence, batch=False)
            print(output[0][:10])
            print(len(output))
            # Batched mode: output is a list of (n_tokens_i, dim) tensors
            output = bert([sentence[:10], sentence], batch=True)
            print(output[0][0][:10])
            print(output[1][0][:10])