https://gricad-gitlab.univ-grenoble-alpes.fr/coavouxm/flaubertagger.git
Tip revision: c939ca9fac094ac3c379256ef3d3d4d14a5a4bf1 authored by m on 23 February 2024, 16:44:50 UTC
tagger.py
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import logging
import pickle
import sys
import random
import os
from collections import defaultdict
from sentence_encoders import HierarchicalLSTM
from bert import BertEncoder
from corpus_reader import load_corpus, DepTree
def accuracy(gold, pred, header):
correct = [0 for _ in header]
exact_match = 0
tokens = 0
sentences = 0
assert(len(gold) == len(pred))
for g_sent, p_sent in zip(gold, pred):
assert(len(g_sent) == len(p_sent))
sentences += 1
tokens += len(g_sent)
for i in range(len(g_sent)):
assert(len(g_sent[i]) == len(p_sent[i]))
assert(len(g_sent[i]) == len(header))
if g_sent[i] == p_sent[i]:
exact_match += 1
for j in range(len(correct)):
if g_sent[i][j] == p_sent[i][j]:
correct[j] += 1
output_dict = {f"{h}": f"{c / tokens * 100:.1f}" for h, c in zip(header, correct)}
output_dict["exact_match"] = round(exact_match / tokens * 100, 1)
return output_dict
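# Illustrative example (tag names and values are hypothetical): per-column
# accuracy over tokens plus per-token exact match.
#   gold = [[["NOUN", "m"], ["VERB", "_"]]]
#   pred = [[["NOUN", "f"], ["VERB", "_"]]]
#   accuracy(gold, pred, ["POS", "gender"])
#   -> {"POS": "100.0", "gender": "50.0", "exact_match": 50.0}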
class FlaubertTagger(nn.Module):
def __init__(self, output_tags, char2i, word2i, args):
super(FlaubertTagger, self).__init__()
self.args = args
if args.lstm:
self.encoder = HierarchicalLSTM(args, char2i, word2i)
else:
self.encoder = BertEncoder(args.bert_id)
self.cpos = args.cpos
self.output_tags = output_tags
self.classifiers = nn.ModuleList([
nn.Sequential(nn.Dropout(args.D),
nn.Linear(self.encoder.dim, len(i2tag)))
for i2tag in self.output_tags])
self.loss = nn.CrossEntropyLoss()
        if False:  # optionally freeze the encoder parameters (currently disabled)
for param in self.encoder.parameters():
param.requires_grad = False
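    # The tagger is one shared encoder (Flaubert by default, or a hierarchical
    # character/word LSTM with --lstm) followed by one independent linear head
    # per tag type (typically POS plus the morphological attributes).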
def forward(self, sentences, labels_all=None):
contextualized_embeddings = self.encoder(sentences, batch=True)
linearized = torch.cat(contextualized_embeddings)
logits_all = [classifier(linearized) for classifier in self.classifiers]
predictions_all = [torch.argmax(logits, dim=1) for logits in logits_all]
# [tag_type][sentence_id]
tags_idx = [[tens.cpu().numpy() for tens in predictions.split([len(sent) for sent in sentences])]
for predictions in predictions_all]
# [sentence_id][tag_type]
tags_idx = list(zip(*tags_idx))
output_dict = {"predictions": tags_idx}
if labels_all is not None:
tokens = 0
correct = np.array([0 for _ in labels_all[0]])
            for gold_tags, pred_tags in zip(labels_all, tags_idx):
                # count each token once (not once per tag type)
                tokens += len(gold_tags[0])
                for i in range(len(correct)):
                    correct[i] += np.sum(gold_tags[i].cpu().numpy() == pred_tags[i])
labels_all = list(zip(*labels_all))
losses = [self.loss(logits, torch.cat(labels)) for logits, labels in zip(logits_all, labels_all)]
loss = sum(losses)
output_dict["all_losses"] = [l.item() for l in losses]
output_dict["loss"] = loss
output_dict["acc"] = {'tokens': tokens, "correct": correct}
return output_dict
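    # forward() returns a dict with "predictions" (for each sentence, one index
    # array per tag type) and, when gold labels are given, "loss" (summed over
    # tag types), "all_losses" (one value per tag type) and "acc"
    # (correct/token counts used for the per-batch accuracy log).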
def predict_sentences(self, logger, corpus, batch_size, i2tags):
self.eval()
result = []
k = 0
for i in range(0, len(corpus), batch_size):
k += 1
if k % 100 == 0:
logger.info(f"Tagging batch {k} over {len(corpus)//batch_size}: sentences {i} to {i+batch_size} over {len(corpus)}")
sentences = corpus[i:i+batch_size]
output_dict = self(sentences)
result.extend(output_dict["predictions"])
output = []
for sentence_tags in result:
sentence = []
for tags in zip(*sentence_tags):
tags = [i2t[t] for i2t,t in zip(i2tags, tags)]
sentence.append(tags)
output.append(sentence)
return output
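    # predict_sentences() decodes the index predictions through i2tags and
    # returns, for each sentence, one list of tag strings per token; a
    # hypothetical token entry could look like ["NOUN", "m", "s"].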
def eval_on_corpus(self, trees, batch_size, i2tags, header):
examples = [tree.get_training_example(self.cpos) for tree in trees]
sentences = [ex["tokens"] for ex in examples]
tags_seqs = [ex["all_tags"] for ex in examples]
        # predict_sentences expects a logger as its first argument
        predictions = self.predict_sentences(logging.getLogger(), sentences, batch_size, i2tags)
eval_dict = accuracy(tags_seqs, predictions, header)
return eval_dict
def ids_to(self, device):
if self.args.lstm:
self.encoder.embedder.word2tensors.to(device)
def prepare_batch(args, trees, tags2i, device):
training_examples = [tree.get_training_example(args.cpos) for tree in trees]
sentences = [ex["tokens"] for ex in training_examples]
tags_seqs_sentences = [ex["all_tags"] for ex in training_examples]
tags_seqs_sentences = [list(zip(*ex)) for ex in tags_seqs_sentences]
tags_idx = [[[tags2i[i][t] for t in tag_seq] for i, tag_seq in enumerate(tags_seq_sentence)] for tags_seq_sentence in tags_seqs_sentences]
tags_idx = [[torch.tensor(tag_seq, device=device) for tag_seq in tags_sentence] for tags_sentence in tags_idx]
return sentences, tags_idx
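# prepare_batch() returns the raw token lists plus, for each sentence, one
# tensor of tag indices per tag type: the per-token tag tuples coming from the
# corpus are transposed into per-tag-type sequences before being indexed.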
def sample_batches(aux_data, batch_size, sample_size):
    if len(aux_data) == 0:
        return []
    # sample up to sample_size auxiliary trees (without replacement) and split them into batches
    sample_size = min(sample_size, len(aux_data))
    sample = list(np.random.choice(aux_data, sample_size, replace=False))
    return [sample[i:i+batch_size] for i in range(0, len(sample), batch_size)]
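# Illustrative sketch (sizes are hypothetical): with batch_size=2 and
# sample_size=5, sample_batches() returns three batches of auxiliary trees
# (2 + 2 + 1), drawn without replacement.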
def main_train(args, logger, device):
corpus = load_corpus(args)
train, dev, test = corpus["corpus"]
aux_data = corpus["aux_data"]
i2tags = corpus["i2tags"]
tags2i = corpus["tags2i"]
stats = corpus["stats"]
header = corpus["header"]
char2i = corpus["char2i"]
word2i = corpus["word2i"]
with open(f"{args.model}/data", "wb") as fp:
pickle.dump([char2i, word2i, i2tags, args], fp)
for tagtype, tagset in zip(header, stats):
print(tagtype, len(tagset))
for k, v in sorted(tagset.items(), key = lambda x: x[1], reverse=True):
print(k, v)
print()
flautag = FlaubertTagger(i2tags, char2i, word2i, args)
flautag.to(device)
flautag.ids_to(device)
optimizer = optim.Adam(flautag.parameters(), lr=args.l)
batch_size = args.B
best_dev = 0
best_epoch = 0
for epoch in range(args.i):
epoch_loss = 0
flautag.train()
batches_train = [train[i:i+batch_size] for i in range(0, len(train), batch_size) ]
batches_aux = sample_batches(aux_data, batch_size=5*batch_size, sample_size=2*len(train))
all_batches = batches_train + batches_aux
random.shuffle(all_batches)
logger.info(f"Epoch {epoch}, train batches: {len(batches_train)}, aux batches: {len(batches_aux)}")
for i, trees in enumerate(all_batches):
optimizer.zero_grad()
            sentences, tags_idx = prepare_batch(args, trees, tags2i, device)
output_dict = flautag(sentences, tags_idx)
loss = output_dict["loss"]
losses = output_dict["all_losses"]
np_loss = loss.item()
epoch_loss += np_loss
loss.backward()
optimizer.step()
losses_str = " ".join([f"{l:.3f}" for l in losses])
acc_str = output_dict["acc"]["correct"] / output_dict["acc"]["tokens"] * 100
acc_str = " ".join([f"{x:.2f}" for x in acc_str])
logger.info(f"Batch {i} sentences {i*batch_size} to {(i+1)*batch_size-1} of {len(batches_train)} loss {np_loss:.4f}: {losses_str} acc {acc_str}")
eval_dict = flautag.eval_on_corpus(dev, args.eval_batchsize, i2tags, header)
        eval_dict_str = " ".join([f"{k[:10]}={v}" for k, v in sorted(eval_dict.items(), key=lambda x: x[0] != "POS")])
if float(eval_dict["POS"]) > best_dev:
best_dev = float(eval_dict["POS"])
best_epoch = epoch
flautag.cpu()
flautag.ids_to(torch.device("cpu"))
torch.save(flautag.state_dict(), "{}/model".format(args.model))
with open(f"{args.model}/best_dev_score", "w", encoding="utf8") as ff:
ff.write(f"{best_dev}\n{best_epoch}")
flautag.to(device)
flautag.ids_to(device)
print(f"Epoch {epoch} loss: {epoch_loss:.4f} dev: {eval_dict_str} best={best_dev} e{best_epoch}", flush = True)
def load_data(modeldir):
with open(f"{modeldir}/data", "rb") as fp:
return pickle.load(fp)
def predictions_to_conll(sentences, predictions, use_cpos):
trees = []
for sentence, prediction in zip(sentences, predictions):
        # DepTree(tokens, cpos=None, fpos=None, features=None)
cpos = None
fpos = None
tags = [p[0] for p in prediction]
if use_cpos:
cpos = tags
else:
fpos = tags
features = []
for _, *rest in prediction:
f = [r for r in rest if r != "_"]
if len(f) > 0:
features.append("|".join(f))
else:
features.append("_")
assert(len(sentence) == len(features))
tree = DepTree(tokens=sentence, cpos=cpos, fpos=fpos, features=features)
trees.append(tree)
return trees
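# Illustrative example (tag values are hypothetical): a predicted token
# ["NOUN", "g=f", "n=s", "_"] yields POS "NOUN" (written to the cpos or fpos
# column depending on the --cpos training option) and the feature string
# "g=f|n=s"; a token whose remaining tags are all "_" gets features "_".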
def main_eval(args, logger, device):
char2i, word2i, i2tags, args_train = load_data(f"{args.model}")
flautag = FlaubertTagger(i2tags, char2i, word2i, args_train)
state_dict = torch.load("{}/model".format(args.model))
flautag.load_state_dict(state_dict)
flautag.to(device)
flautag.ids_to(device)
flautag.eval()
logger.info("Model loaded")
with open(args.corpus, encoding="utf8") as f:
sentences = []
for line in f:
sent = line.strip().replace("\ufeff", "").replace("\u200e", "").replace("\u200b", "").replace("\x7f ","")
if sent:
sent = sent.split()
sentences.append(sent)
predictions = flautag.predict_sentences(logger, sentences, args.eval_batchsize, i2tags)
conll_trees = predictions_to_conll(sentences, predictions, args_train.cpos)
with open(f"{args.output}", "w", encoding="utf8") as of:
for tree in conll_trees:
of.write(f"{str(tree)}\n\n")
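# Eval-mode I/O: the input corpus is one pre-tokenized sentence per line
# (tokens separated by whitespace, BOM and zero-width characters stripped);
# the output is written as blank-line-separated CoNLL blocks, one per sentence.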
def main(args, logger, device):
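    """Dispatch to training (mode "train") or tagging/evaluation (mode "eval")."""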
if args.mode == "train":
try:
main_train(args, logger, device)
except KeyboardInterrupt:
print("Training interrupted, exiting")
return
else:
main_eval(args, logger, device)
if __name__ == "__main__":
import argparse
logging.basicConfig(stream=sys.stderr, level=logging.DEBUG,
format='%(asctime)s-%(relativeCreated)d-%(levelname)s:%(message)s')
usage = main.__doc__
parser = argparse.ArgumentParser(description = usage, formatter_class=argparse.ArgumentDefaultsHelpFormatter)
subparsers = parser.add_subparsers(dest="mode", description="Execution modes", help='train: training, eval: test')
subparsers.required = True
train_parser = subparsers.add_parser("train", formatter_class=argparse.ArgumentDefaultsHelpFormatter)
eval_parser = subparsers.add_parser("eval", formatter_class=argparse.ArgumentDefaultsHelpFormatter)
# train corpora
train_parser.add_argument("model", help="Directory (to be created) for model exportation")
#train_parser.add_argument("train", help="Training corpus")
#train_parser.add_argument("dev", help="Dev corpus")
#train_parser.add_argument("--fmt", help="Format for train and dev corpus", choices=["ctbk", "discbracket", "bracket"], default="discbracket")
#train_parser.add_argument("--fmt", help="Format for train and dev corpus", choices=["ctbk", "discbracket", "bracket"], default="discbracket")
train_parser.add_argument("--corpus", default="ftb_spmrl", choices=["ftb_spmrl", "fqb", "ftb", "gsd", "partut", "pud", "sequoia", "spoken"],
help="Dataset. Has train/dev/test: ftb, gsd, partut, sequoia, spoken. Has morphology: ftb, gsd, partut, sequoia.")
train_parser.add_argument("--aux-data", default=None, help="Path to artificial extra dataset")
train_parser.add_argument("--cpos", action="store_true", help="Use coarse pos instead of fpos")
# general options
train_parser.add_argument("--cuda", type=int, default=None, help="Use GPU if available")
#train_parser.add_argument("-t", type=int, default=1, help="Number of threads for torch cpu")
train_parser.add_argument("-S", type=int, default=None, help="Use only X first training examples")
train_parser.add_argument("--bert-id", default="flaubert/flaubert_small_cased", help="Flaubert identifier")
train_parser.add_argument("--lstm", action="store_true", help="Use hierarchical LSTM instead of Flaubert")
# training options
train_parser.add_argument("-i", default=12, type=int, help="Number of epochs")
train_parser.add_argument("-l", default=0.00001, type=float, help="Learning rate")
train_parser.add_argument("-B", type=int, default=20, help="Size of batch")
train_parser.add_argument("-D", type=float, default=0.5, help="Dropout")
train_parser.add_argument("--eval-batchsize", type=int, default=30, help="Batch size for eval")
#train_parser.add_argument("--start-averaging", type=int, default=40, help="Aadam starts averaging at epoch <int> ")
train_parser.add_argument("-s", type=int, default=10, help="Random seed")
HierarchicalLSTM.add_cmd_options(train_parser)
# test corpus
eval_parser.add_argument("model", help="Pytorch model")
eval_parser.add_argument("corpus", help="Test corpus, 1 tokenized sentence per line")
eval_parser.add_argument("output", help="Outputfile")
# eval_parser.add_argument("--gold", default=None, help="Gold corpus (disc)bracket. If provided, eval with discodop")
eval_parser.add_argument("--eval-batchsize", type=int, default=32, help="Batch size for eval")
# test options
eval_parser.add_argument("--cuda", action="store_true", help="Use GPU if available")
# eval_parser.add_argument("-v", type=int, default=1, choices=[0,1], help="Verbosity level")
# eval_parser.add_argument("-t", type=int, default=1, help="Number of threads for torch cpu")
args = parser.parse_args()
for k, v in vars(args).items():
print(k, v)
#torch.set_num_threads(args.t)
logger = logging.getLogger()
logger.info("Mode={}".format(args.mode))
if args.mode == "train":
os.makedirs(args.model, exist_ok = True)
use_cuda = torch.cuda.is_available()
    # treat --cuda 0 (train) as a valid GPU request; eval uses --cuda as a flag
    if use_cuda and args.cuda is not None and args.cuda is not False:
logger.info("Using gpu")
device = torch.device("cuda")
else:
logger.info("Using cpu")
device = torch.device("cpu")
logger.info(f"{str(device)}")
SEED = 0
if args.mode == "train":
SEED = args.s
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
main(args, logger, device)
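# Example invocations (paths and option values are placeholders):
#   python tagger.py train models/ftb --corpus ftb --cuda 0 -i 12 -B 20
#   python tagger.py eval models/ftb raw_sentences.txt tagged.conll --cuda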