https://gricad-gitlab.univ-grenoble-alpes.fr/coavouxm/flaubertagger.git
Tip revision: c939ca9fac094ac3c379256ef3d3d4d14a5a4bf1 authored by m on 23 February 2024, 16:44:50 UTC
up
up
Tip revision: c939ca9
bert.py
import logging
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence
from transformers import BertTokenizer, BertModel
from transformers import AutoTokenizer, AutoModel
class BertEncoder(nn.Module):
    """Wrap a pretrained transformer language model as a word-level encoder.

    Each input word is split into word pieces by the model's tokenizer, the
    whole sequence is run through the model, and one vector per input word
    (the vector of the word's LAST word piece) is returned.

    Attributes:
        tokenizer: tokenizer loaded with ``AutoTokenizer.from_pretrained``.
        bert: model loaded with ``AutoModel.from_pretrained``.
        dim: size of the vectors produced by ``bert``.
    """

    def __init__(self, bert_id):
        """Load the tokenizer and model identified by ``bert_id``.

        Args:
            bert_id: model name or path understood by ``from_pretrained``.
                Lower-casing is enabled when the id contains "uncased".
        """
        super().__init__()
        self.tokenizer = AutoTokenizer.from_pretrained(
            bert_id, do_lower_case=("uncased" in bert_id)
        )
        self.bert = AutoModel.from_pretrained(bert_id)
        # Not every model class exposes a ``dim`` attribute; keep using it
        # when present and otherwise read the hidden size from the config.
        dim = getattr(self.bert, "dim", None)
        if dim is None:
            dim = self.bert.config.hidden_size
        self.dim = dim

    def get_wp_tokens(self, sentence):
        """Split the words of ``sentence`` into word pieces.

        Args:
            sentence: iterable of word strings.

        Returns:
            A pair ``(wptokens, mask)``: ``wptokens`` is the list of word
            pieces surrounded by a start and an end special token, and
            ``mask`` is a bool tensor of the same length that is True exactly
            at the last word piece of each input word.
        """
        tok = self.tokenizer
        # Prefer the tokens the original code used (bos / sep); fall back to
        # cls / eos for tokenizers that do not define them.
        start = tok.bos_token
        if start is None:
            start = getattr(tok, "cls_token", None)
        end = tok.sep_token
        if end is None:
            end = getattr(tok, "eos_token", None)

        wptokens = [start]
        mask = [False]
        for token in sentence:
            pieces = tok.tokenize(token)
            if not pieces:
                # A word that yields no word piece would otherwise get a mask
                # entry with no matching token and desynchronise the two lists.
                pieces = [tok.unk_token]
            wptokens.extend(pieces)
            mask.extend([False] * (len(pieces) - 1))
            mask.append(True)
        wptokens.append(end)
        mask.append(False)
        return wptokens, torch.tensor(mask, dtype=torch.bool, device=self.bert.device)

    def forward(self, sentence, batch):
        """Encode one sentence or a batch of sentences.

        Args:
            sentence: a list of words when ``batch`` is false, or a list of
                such lists when ``batch`` is true.
            batch: whether ``sentence`` holds several sentences.

        Returns:
            When ``batch`` is false, a tensor with one row per input word.
            When ``batch`` is true, a list with one such tensor per sentence
            (an empty list for an empty batch).
        """
        device = self.bert.device

        if not batch:
            wptokens, mask = self.get_wp_tokens(sentence)
            ids = self.tokenizer.convert_tokens_to_ids(wptokens)
            tokens_tensor = torch.tensor([ids], device=device)
            encoded = self.bert(input_ids=tokens_tensor)[0]
            return encoded.squeeze(0)[mask]

        if not sentence:
            # pad_sequence cannot handle an empty list of sequences.
            return []

        id_tensors = []
        word_masks = []
        lengths = []
        for token_list in sentence:
            wptokens, mask = self.get_wp_tokens(token_list)
            ids = self.tokenizer.convert_tokens_to_ids(wptokens)
            id_tensors.append(torch.tensor(ids, dtype=torch.long, device=device))
            word_masks.append(mask)
            lengths.append(len(wptokens))

        # Pad with the tokenizer's own padding id and derive the attention
        # mask from the true lengths.  Comparing the padded ids with 0 is only
        # correct if the pad id is 0 and no real token (e.g. the start token)
        # has id 0.
        pad_id = self.tokenizer.pad_token_id
        if pad_id is None:
            pad_id = 0
        padded = pad_sequence(id_tensors, batch_first=True, padding_value=pad_id)
        positions = torch.arange(padded.shape[1], device=device)
        attention = positions.unsqueeze(0) < torch.tensor(lengths, device=device).unsqueeze(1)

        encoded = self.bert(input_ids=padded, attention_mask=attention)[0]
        # Drop the padding positions of each sentence, then keep only the
        # last word piece of every word.
        return [
            layer[:length][mask]
            for layer, length, mask in zip(encoded, lengths, word_masks)
        ]
if __name__ == "__main__":
    # Manual smoke test: load each model, encode one sentence on its own and
    # then as part of a batch, and print the resulting sizes.
    words = "Le chat mange une pomme de pin sur l' anti-brouillard .".split()
    model_names = [
        'camembert/camembert-base-wikipedia-4gb',
        'camembert/camembert-base-wikipedia-4gb',
        'camembert/camembert-base',
        'flaubert/flaubert_small_cased',
    ]

    for model_name in model_names:
        print("loading")
        encoder = BertEncoder(model_name)
        print("eval")
        encoder.eval()

        print("length sentence", len(words))
        print("computing")
        single = encoder(words, batch=False)
        print(single[0][:10])
        print("length output", len(single))

        print("Batch")
        print("lengths", 10, len(words), 4)
        # Three sentences of different lengths to exercise padding.
        batched = encoder([words[:10], words, words[:4]], batch=True)
        for index in range(3):
            print(batched[index].shape)