# mtg.py
# Source: https://gitlab.com/mcoavoux/mtgpy-release-findings-2021.git
# Tip revision: c9972219cd75049269d26632d2bb79619d661298 (mcoavoux, 20 May 2021)
import sys
import os
from collections import defaultdict
import copy
import logging
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import pickle
import random
from Asgd import MyAsgd, AAdam
from word_encoders import WordEmbedder, Words2Tensors
from sentence_encoders import TransformerNetwork, SentenceEncoderLSTM
from nk_transformer import NKTransformer
from nk_transformer import LayerNormalization
import fasttext_embeddings
import discodop_eval
from state_gap import State
import corpus_reader
import tree as T
from features import feature_functions
class Parser(nn.Module):
def __init__(self, args,
i2labels, i2tags, i2chars, i2words):
super(Parser, self).__init__()
self.discontinuous = not args.projective
self.args = args
self.i2labels = i2labels
self.i2tags = i2tags
self.i2chars = i2chars
self.i2words = i2words
self.labels2i = {l : i for i,l in enumerate(i2labels)}
self.tags2i = {l : i for i,l in enumerate(i2tags)}
self.chars2i = {l : i for i,l in enumerate(i2chars)}
self.words2i = {l : i for i,l in enumerate(i2words)}
self.words2tensors = Words2Tensors(self.chars2i, self.words2i, pchar=args.dcu, pword=args.dwu)
self.num_labels = len(i2labels)
self.num_tags = len(i2tags)
self.num_words = len(i2words)
self.num_chars = len(i2chars)
self.word_embedder = WordEmbedder(args, len(self.i2words), len(self.i2chars), self.words2tensors, self.words2i)
d_input = self.word_embedder.output_dim
if args.enc == "lstm":
self.encoder = SentenceEncoderLSTM(args, d_input)
elif args.enc == "transformer":
self.encoder = TransformerNetwork(args, d_input)
elif args.enc == "nktransformer":
self.encoder = NKTransformer(args, d_input)
dim_encoder = args.lstm_dim
if args.enc == "no":
dim_encoder = self.word_embedder.output_dim
elif args.enc == "transformer" or args.enc == "nktransformer":
dim_encoder = args.trans_dmodel
self.feature_function, num_features = feature_functions[args.feats]
dict_fun = {"tanh": nn.Tanh, "relu": nn.ReLU}
activation_function = dict_fun[args.fun]
# NK sequence
# self.f_label = nn.Sequential(
# nn.Linear(hparams.d_model, hparams.d_label_hidden),
# LayerNormalization(hparams.d_label_hidden),
# nn.ReLU(),
# nn.Linear(hparams.d_label_hidden, label_vocab.size - 1),
# )
self.structure = nn.Sequential(
nn.Dropout(args.diff),
nn.Linear(dim_encoder*num_features, args.H),
LayerNormalization(args.H),
activation_function(),
nn.Linear(args.H, 3),
nn.LogSoftmax(dim=1))
self.label = nn.Sequential(
nn.Dropout(args.diff),
nn.Linear(dim_encoder*num_features, args.H),
LayerNormalization(args.H),
activation_function(),
nn.Linear(args.H, self.num_labels),
nn.LogSoftmax(dim=1))
self.tagger = nn.Sequential(
nn.Dropout(args.diff),
nn.Linear(dim_encoder, args.H),
LayerNormalization(args.H),
activation_function(),
nn.Linear(args.H, self.num_tags),
nn.LogSoftmax(dim=1))
# # Structural actions
# self.structure = nn.Sequential(
# nn.Dropout(args.diff),
# nn.Linear(dim_encoder*num_features, args.H),
# nn.LayerNorm(args.H),
# activation_function(),
## nn.Linear(args.H, args.H),
## nn.LayerNorm(args.H)
## activation_function(),
# nn.Linear(args.H, 3),
# nn.LogSoftmax(dim=1))
# # Labelling actions
# self.label = nn.Sequential(
# nn.Dropout(args.diff),
# nn.Linear(dim_encoder*num_features, args.H),
# nn.LayerNorm(args.H),
# activation_function(),
## nn.Linear(args.H, args.H),
## nn.LayerNorm(args.H)
## activation_function(),
# nn.Linear(args.H, self.num_labels),
# nn.LogSoftmax(dim=1))
# # POS tags
# self.tagger = nn.Sequential(
# nn.Dropout(args.dtag),
# nn.Linear(dim_encoder, args.H),
# nn.LayerNorm(args.H),
# activation_function(),
# nn.Linear(args.H, self.num_tags),
# nn.LogSoftmax(dim=1))
self.loss_function = nn.NLLLoss(reduction="sum")
self.default_values = nn.Parameter(torch.Tensor(4, dim_encoder))
self.initialize_parameters(args)
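    # Shape sketch (illustrative comment, not executed): each decision is scored
    # from `num_features` concatenated encoder vectors of size dim_encoder:
    #   self.structure : (batch, dim_encoder*num_features) -> (batch, 3)           log-probs over the 3 structural actions (shift/combine/gap)
    #   self.label     : (batch, dim_encoder*num_features) -> (batch, num_labels)  index 0 is "nolabel"
    #   self.tagger    : (batch, dim_encoder)              -> (batch, num_tags)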
def initialize_parameters(self, args):
        # Uniform initialization for the default stack/buffer vectors only;
        # the linear layers keep PyTorch's default initialization
        # (the Xavier initializations below are left commented out).
torch.nn.init.uniform_(self.default_values, -0.01, 0.01)
# torch.nn.init.xavier_normal_(self.label[1].weight)
# torch.nn.init.xavier_normal_(self.label[4].weight)
# torch.nn.init.xavier_normal_(self.structure[1].weight)
# torch.nn.init.xavier_normal_(self.structure[4].weight)
# torch.nn.init.xavier_normal_(self.tagger[1].weight)
# torch.nn.init.xavier_normal_(self.tagger[4].weight)
def train_batch(self, batch_sentences, all_embeddings, targets):
batch_features, batch_tags = targets
output_l1, output_l2 = all_embeddings
batch_label_input = []
batch_label_output = []
batch_struct_input = []
batch_struct_output = []
for sentence_embeddings, features in zip(output_l2, batch_features):
#batch_features: {"struct": (struct_input, struct_output), "label": (labels_input, labels_output)}
l_input, l_output = features["label"]
s_input, s_output = features["struct"]
batch_label_input.append(sentence_embeddings[l_input].view(len(l_input), -1))
batch_label_output.append(l_output)
batch_struct_input.append(sentence_embeddings[s_input].view(len(s_input), -1))
batch_struct_output.append(s_output)
batch_label_input = torch.cat(batch_label_input, dim=0)
batch_label_output = torch.cat(batch_label_output)
batch_struct_input = torch.cat(batch_struct_input, dim=0)
batch_struct_output = torch.cat(batch_struct_output)
labels_output_tensors = self.label(batch_label_input)
label_loss = self.loss_function(labels_output_tensors, batch_label_output)
label_loss /= len(batch_label_input)
struct_output_tensors = self.structure(batch_struct_input)
struct_loss = self.loss_function(struct_output_tensors, batch_struct_output)
struct_loss /= len(batch_struct_input)
batch_tagging_input = torch.cat(output_l1, dim=0)
batch_tagging_targets = torch.cat(batch_tags)
tagging_output = self.tagger(batch_tagging_input)
tagging_loss = torch.sum(self.loss_function(tagging_output, batch_tagging_targets))
tagging_loss /= len(batch_tagging_targets)
return {"tagging loss": tagging_loss, "struct loss": struct_loss, "label loss": label_loss}
def forward(self, batch_sentences_raw, targets=None):
# raw: no SOS / EOS
batch_sentences = [["<SOS>"] + sent + ["<EOS>"] for sent in batch_sentences_raw]
lengths = torch.tensor([len(s) for s in batch_sentences])
all_embeddings = self.word_embedder(batch_sentences)
if self.args.enc != "no":
padded_char_based_embeddings = torch.nn.utils.rnn.pad_sequence(all_embeddings, batch_first=True)
unpacked_l1, unpacked_l2 = self.encoder(padded_char_based_embeddings, lengths)
else:
unpacked_l1 = all_embeddings
unpacked_l2 = all_embeddings
unpacked_l2 = [torch.cat([self.default_values, l2_sent], dim=0) for l2_sent in unpacked_l2]
if targets is not None:
return self.train_batch(batch_sentences, (unpacked_l1, unpacked_l2), targets)
tokens = [[T.Token(tok, i, [None]) for i, tok in enumerate(sent)] for sent in batch_sentences_raw]
return self.parse_batch(tokens, (unpacked_l1, unpacked_l2))
#return self.parse(tokens, (unpacked_l1, unpacked_l2))
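    # Calling-convention sketch (illustrative, kept as a comment):
    #   losses = model(batch_sentences, targets=(batch_features, batch_tags))  # training: dict of losses
    #   trees  = model(batch_sentences)                                        # parsing: list of predicted trees
    # batch_sentences is a list of token lists *without* <SOS>/<EOS> (they are
    # added here); targets follow the (batch_features, batch_tags) format
    # assembled in train_epoch().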
def parse(self, sentences, sentence_embeddings):
# sentences: list of list of T.Tokens
output_l1, output_l2 = sentence_embeddings
lengths = [len(sent) for sent in sentences]
batch_tagging_input = torch.cat(output_l1, dim=0)
tag_scores = self.tagger(batch_tagging_input)
tag_predictions = torch.argmax(tag_scores, dim=1).cpu().split(lengths)
tag_predictions = [v.numpy() for v in tag_predictions]
for sent_tags, sent in zip(tag_predictions, sentences):
for tag, tok in zip(sent_tags, sent):
tok.set_tag(self.i2tags[tag])
trees = []
# make this more parallel / batch
for sentence, sentence_embedding in zip(sentences, output_l2):
state = State(sentence, self.discontinuous)
while not state.is_final():
next_type = state.next_action_type()
stack, queue, buffer, sent_len = state.get_input()
input_features = self.feature_function(sentence_embedding.device, stack, queue, buffer, sent_len)
input_embeddings = sentence_embedding[input_features].view(1, -1)
if next_type == State.LABEL:
output = self.label(input_embeddings).squeeze()
state.filter_action(next_type, output)
prediction = torch.argmax(output)
if prediction == 0:
state.nolabel()
else:
pred_label = self.i2labels[prediction]
state.labelX(pred_label)
else:
output = self.structure(input_embeddings).squeeze()
state.filter_action(next_type, output)
prediction = torch.argmax(output.squeeze())
if prediction == State.SHIFT:
state.shift()
elif prediction == State.COMBINE:
state.combine()
else:
assert(prediction == State.GAP)
state.gap()
trees.append(state.get_tree())
return trees
def parse_batch(self, sentences, sentence_embeddings):
# sentences: list of list of T.Tokens
output_l1, output_l2 = sentence_embeddings
lengths = [len(sent) for sent in sentences]
batch_tagging_input = torch.cat(output_l1, dim=0)
tag_scores = self.tagger(batch_tagging_input)
tag_predictions = torch.argmax(tag_scores, dim=1).cpu().split(lengths)
tag_predictions = [v.numpy() for v in tag_predictions]
for sent_tags, sent in zip(tag_predictions, sentences):
for tag, tok in zip(sent_tags, sent):
tok.set_tag(self.i2tags[tag])
trees = [None for _ in range(len(sentences))]
states = [State(sentence, self.discontinuous, sent_id=i) for i, sentence in enumerate(sentences)]
while len(states) > 0:
next_types = [state.next_action_type() for state in states]
input_structs = [state.get_input() for state in states]
input_features = [self.feature_function(output_l2[0].device, stack, queue, buffer, sent_len)
for stack, queue, buffer, sent_len in input_structs]
input_embeddings = [output_l2[state.sent_id][in_feats].view(-1) for state, in_feats in zip(states, input_features)]
label_states = [states[i] for i, ntype in enumerate(next_types) if ntype == State.LABEL]
label_inputs = [input_embeddings[i] for i, ntype in enumerate(next_types) if ntype == State.LABEL]
if len(label_states) > 0:
label_outputs = self.label(torch.stack(label_inputs))
for i, state in enumerate(label_states):
state.filter_action(State.LABEL, label_outputs[i])
prediction = torch.argmax(label_outputs[i])
if prediction == 0:
state.nolabel()
else:
pred_label = self.i2labels[prediction]
state.labelX(pred_label)
struct_states = [states[i] for i, ntype in enumerate(next_types) if ntype == State.STRUCT]
struct_inputs = [input_embeddings[i] for i, ntype in enumerate(next_types) if ntype == State.STRUCT]
#print(torch.stack(struct_inputs).shape)
if len(struct_states) > 0:
struct_outputs = self.structure(torch.stack(struct_inputs))
for i, state in enumerate(struct_states):
state.filter_action(State.STRUCT, struct_outputs[i])
prediction = torch.argmax(struct_outputs[i])
if prediction == State.SHIFT:
state.shift()
elif prediction == State.COMBINE:
state.combine()
else:
assert(prediction == State.GAP)
state.gap()
states = []
for state in label_states + struct_states:
if state.is_final():
trees[state.sent_id] = state.get_tree()
else:
states.append(state)
return trees
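    # Batched decoding sketch (illustrative): all unfinished sentences advance in
    # lock-step, one transition per outer iteration. At each step the states are
    # grouped by the type of their next decision, e.g. with 3 active states:
    #   next_types = [State.LABEL, State.STRUCT, State.LABEL]
    # the two labelling inputs are scored in one self.label() call and the single
    # structural input in one self.structure() call; finished states are dropped.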
def predict_corpus(self, corpus, batch_size):
# corpus: list of list of tokens
trees = []
# sort by length for lstm batching
indices, corpus = zip(*sorted(zip(range(len(corpus)),
corpus),
key = lambda x: len(x[1]),
reverse=True))
with torch.no_grad():
for i in range(0, len(corpus), batch_size):
tree_batch = self.forward(corpus[i:i+batch_size])
for t in tree_batch:
t.expand_unaries()
trees.append(t)
# reorder
_, trees = zip(*sorted(zip(indices, trees), key = lambda x:x[0]))
return trees
def assign_pretrained_embeddings(model, ft, dim_ft, freeze):
print("Assigning pretrained fasttext embeddings")
embedding = model.word_embedder.word_embeddings
device = embedding.weight.data.device
for w, v in ft.items():
i = model.words2i[w]
embedding.weight.data[i, :dim_ft] = torch.tensor(v).to(device)
if freeze:
print("Freezing word embedding layer")
embedding.requires_grad_(False)
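# Expected format for `ft` (a sketch, assuming fasttext_embeddings.load_vectors
# returns a word-to-vector mapping), e.g.
#   ft = {"the": [0.1, -0.2, ...], "cat": [...]}   # each vector has dim_ft floats
# Only the first dim_ft columns of each embedding row are overwritten.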
def get_vocabulary(corpus):
"""Extract vocabulary for characters, tokens, non-terminals, POS tags"""
words = defaultdict(int)
chars = defaultdict(int)
# chars["<START>"] += 2
# chars["<STOP>"] += 2
# These should be treated as single characters
chars["-LRB-"] += 2
chars["-RRB-"] += 2
chars["#LRB#"] += 2
chars["#RRB#"] += 2
tag_set = defaultdict(int)
label_set = defaultdict(int)
for tree in corpus:
tokens = T.get_yield(tree)
for tok in tokens:
for char in tok.token:
chars[char] += 1
tag_set[tok.get_tag()] += 1
words[tok.token] += 1
constituents = T.get_constituents(tree)
for label, _ in constituents:
label_set[label] += 1
return words, chars, label_set, tag_set
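# Example of the returned counters (illustrative values):
#   words     = {"the": 120, "cat": 3, ...}
#   chars     = {"t": 400, "h": 250, "-LRB-": 2, ...}   # bracket tokens count as single "characters"
#   label_set = {"NP": 95, "VP": 80, ...}
#   tag_set   = {"DT": 120, "NN": 90, ...}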
def extract_features(device, labels2i, sentence, feature_function):
state = State(sentence, True)
struct_input = []
struct_output = []
labels_input = []
labels_output = []
while not state.is_final():
next_type = state.next_action_type()
gold_action, (stack, queue, buffer, sent_len) = state.oracle()
input_features = feature_function(device, stack, queue, buffer, sent_len)
if next_type == State.STRUCT:
struct_input.append(input_features)
struct_output.append(State.mapping[gold_action[0]])
else:
assert(next_type == State.LABEL)
labels_input.append(input_features)
labels_output.append(labels2i[gold_action[1]])
struct_input = torch.stack(struct_input)
struct_output = torch.tensor(struct_output, dtype=torch.long, device=device)
labels_input = torch.stack(labels_input)
labels_output = torch.tensor(labels_output, dtype=torch.long, device=device)
#print(labels_output)
return {"struct": (struct_input, struct_output), "label": (labels_input, labels_output)}
def extract_tags(device, tags2i, sentence):
# Returns a tensor for tag ids for a single sentence
idxes = [tags2i[tok.get_tag()] for tok in sentence]
return torch.tensor(idxes, dtype=torch.long, device=device)
def compute_f(TPs, total_golds, total_preds):
p, r, f = 0, 0, 0
if total_preds > 0:
p = TPs / total_preds
if total_golds > 0:
r = TPs / total_golds
if (p, r) != (0, 0):
f = 2*p*r / (p+r)
return p, r, f
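# Worked example (comment only): TPs=8, total_golds=10, total_preds=12 gives
#   p = 8/12 ~ 0.667,  r = 8/10 = 0.8,  f = 2*p*r/(p+r) ~ 0.727
# p (resp. r) stays 0 when total_preds (resp. total_golds) is 0, and f stays 0
# when p = r = 0.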
def Fscore_corpus(golds, preds):
TPs = 0
total_preds = 0
total_golds = 0
UTPs = 0
for gold, pred in zip(golds, preds):
TPs += len([c for c in gold if c in pred])
total_golds += len(gold)
total_preds += len(pred)
ugold = defaultdict(int)
for _, span in gold:
ugold[span] += 1
upred = defaultdict(int)
for _, span in pred:
upred[span] += 1
for span in upred:
UTPs += min(upred[span], ugold[span])
p, r, f = compute_f(TPs, total_golds, total_preds)
up, ur, uf = compute_f(UTPs, total_golds, total_preds)
return p*100, r*100, f*100, up*100, ur*100, uf*100
def prepare_corpus(corpus):
sentences = [T.get_yield(corpus[i]) for i in range(len(corpus))]
raw_sentences = [[tok.token for tok in sentence] for sentence in sentences]
return sentences, raw_sentences
def eval_tagging(gold, pred):
# Returns accuracy for tag predictions
acc = 0
tot = 0
assert(len(gold) == len(pred))
for sent_g, sent_p in zip(gold, pred):
assert(len(sent_g) == len(sent_p))
for tok_g, tok_p in zip(sent_g, sent_p):
if tok_g.get_tag() == tok_p.get_tag():
acc += 1
tot += len(sent_g)
return acc * 100 / tot
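# Worked example (comment only): with two sentences of 3 and 2 tokens and
# 4 matching tags out of 5, eval_tagging returns 4 * 100 / 5 = 80.0.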
def train_epoch(device, epoch, model, optimizer, args, logger, train_raw_sentences, features, tag_features, idxs, scheduling_dict, batch_size, check_callback):
tag_loss = 0
struct_loss = 0
label_loss = 0
model.train()
grad_norm = 0
random.shuffle(idxs)
n_batches = len(idxs)//batch_size
    check_id = max(1, n_batches // args.check)  # guard against 0 when there are fewer batches than checks
num_check = 0
for j, i in enumerate(range(0, len(idxs), batch_size)):
batch_sentences = [train_raw_sentences[idx] for idx in idxs[i:i+batch_size]]
batch_features = [features[idx] for idx in idxs[i:i+batch_size]]
batch_tags = [tag_features[idx] for idx in idxs[i:i+batch_size]]
# print(batch_features[0])
# print(batch_tags[0])
batch_features = [ {k: (v[0].to(device), v[1].to(device)) for k,v in feats.items()} for feats in batch_features]
batch_tags = [feats.to(device) for feats in batch_tags]
batch_sentences, batch_features, batch_tags = zip(*sorted(zip(batch_sentences, batch_features, batch_tags),
key = lambda x: len(x[0]), reverse=True))
optimizer.zero_grad()
losses = model(batch_sentences, targets=(batch_features, batch_tags))
batch_total_loss = losses["tagging loss"] + losses["struct loss"] + losses["label loss"]
batch_total_loss.backward()
tag_loss += losses["tagging loss"].item()
struct_loss += losses["struct loss"].item()
label_loss += losses["label loss"].item()
if args.G is not None:
grad_norm += torch.nn.utils.clip_grad_norm_(model.parameters(), args.G)
optimizer.step()
if j % 20 == 0 and args.v > 0:
logger.info(f"Epoch {epoch} Training batch {j}/{n_batches} sent {i} to {i+batch_size} lr:{list(optimizer.param_groups)[0]['lr']:.4e}")
scheduling_dict["total_steps"] += 1
if scheduling_dict["total_steps"] <= scheduling_dict["warmup_steps"]:
set_lr(optimizer, scheduling_dict["total_steps"] * scheduling_dict["warmup_coeff"])
if (j + 1) % check_id == 0 and num_check < args.check - 1 and epoch >= 30:
num_check += 1
check_callback(epoch, num_check, False)
grad_norm /= n_batches
return {"tag loss": tag_loss, "struct loss": struct_loss, "label loss": label_loss, "grad norm": grad_norm}
def set_lr(optimizer, new_lr):
for param_group in optimizer.param_groups:
param_group['lr'] = new_lr
def eval_on_corpus(args, model, sentences, raw_sentences, corpus):
pred_trees = model.predict_corpus(raw_sentences, args.eval_batchsize)
sentence_pred_tokens = [T.get_yield(tree) for tree in pred_trees]
tageval = eval_tagging(sentences, sentence_pred_tokens)
outfile = f"{args.model}/tmp_{corpus}.discbracket"
with open(outfile, "w", encoding="utf8") as fstream:
for tree in pred_trees:
fstream.write("{}\n".format(str(tree)))
if corpus == "dev":
goldfile = args.dev.replace(".ctbk", ".discbracket")
else:
goldfile = f"{args.model}/tmp_sample_train_gold.discbracket"
discop, discor, discodop_f = discodop_eval.call_eval(goldfile, outfile)
disco2p, disco2r, discodop2_f = discodop_eval.call_eval(goldfile, outfile, disconly=True)
return {"p": discop, "r": discor, "f": discodop_f,
"dp": disco2p, "dr": disco2r, "df": discodop2_f,
"tag": round(tageval, 2)}
def load_corpora(args, logger):
logger.info("Loading corpora...")
if args.fmt == "ctbk":
train_corpus = corpus_reader.read_ctbk_corpus(args.train)
dev_corpus = corpus_reader.read_ctbk_corpus(args.dev)
elif args.fmt == "discbracket":
train_corpus = corpus_reader.read_discbracket_corpus(args.train)
dev_corpus = corpus_reader.read_discbracket_corpus(args.dev)
elif args.fmt == "bracket":
train_corpus = corpus_reader.read_bracket_corpus(args.train)
dev_corpus = corpus_reader.read_bracket_corpus(args.dev)
parsed_corpus = []
if args.parsed is not None:
parsed_corpus = corpus_reader.read_discbracket_corpus(args.parsed)
random.shuffle(parsed_corpus)
for tree in train_corpus + parsed_corpus:
tree.merge_unaries()
return train_corpus, dev_corpus, parsed_corpus
def get_voc_dicts(words, vocabulary, label_set, tag_set):
i2chars = ["<PAD>", "<UNK>", "<START>", "<STOP>", "<SOS>", "<EOS>"] + sorted(vocabulary, key = lambda x: vocabulary[x], reverse=True)
#chars2i = {k:i for i, k in enumerate(i2chars)}
i2labels = ["nolabel"] + sorted(label_set)
labels2i = {k: i for i, k in enumerate(i2labels)}
i2tags = sorted(tag_set)
tags2i = {k:i for i, k in enumerate(i2tags)}
i2words = ["<PAD>", "<UNK>", "<SOS>", "<EOS>"] + sorted(words, key=lambda x: words[x], reverse=True)
#words2i = {k:i for i, k in enumerate(i2words)}
return i2chars, i2labels, labels2i, i2tags, tags2i, i2words
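# Index layout (follows directly from the lists built above):
#   i2chars[0:6] = ["<PAD>", "<UNK>", "<START>", "<STOP>", "<SOS>", "<EOS>"], then chars by frequency
#   i2labels[0]  = "nolabel", then sorted non-terminal labels
#   i2words[0:4] = ["<PAD>", "<UNK>", "<SOS>", "<EOS>"], then words by frequency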
def check_dev(epoch, check_id, optimizer, model, args, dev_sentences, dev_raw_sentences, sample_train, sample_train_raw, scheduling_dict, average):
optimizer.zero_grad()
model.eval()
if args.O in ["asgd", "aadam"] and average:
optimizer.average()
dev_results = eval_on_corpus(args, model, dev_sentences, dev_raw_sentences, "dev")
train_results = eval_on_corpus(args, model, sample_train, sample_train_raw, "train")
discodop_f = dev_results["f"]
print(f"Epoch {epoch}.{check_id} lr:{optimizer.param_groups[0]['lr']:.4e}",
"dev:", " ".join([f"{k}={v}" for k, v in sorted(dev_results.items())]),
"train:", " ".join([f"{k}={v}" for k, v in sorted(train_results.items())]),
# f" dev tag:{dtag:.2f} f:{discodop_f} p:{discop} r:{discor} disco f:{discodop2_f} p:{disco2p} r:{disco2r}",
# f" train tag:{dtag:.2f} f:{discodop_f} p:{discop} r:{discor} disco f:{discodop2_f} p:{disco2p} r:{disco2r}",
f"best: {scheduling_dict['best_dev_f']} ep {scheduling_dict['best_epoch']} avg={int(average)}",
flush=True)
if discodop_f > scheduling_dict['best_dev_f']:
scheduling_dict["best_epoch"] = epoch
scheduling_dict['best_dev_f'] = discodop_f
model.cpu()
model.words2tensors.to(torch.device("cpu"))
torch.save(model.state_dict(), "{}/model".format(args.model))
with open(f"{args.model}/best_dev_score", "w", encoding="utf8") as ff:
ff.write(str(scheduling_dict['best_dev_f']))
        model.to(device)  # `device` is the module-level global set in the __main__ block
        model.words2tensors.to(device)
if args.O in ["asgd", "aadam"] and average:
optimizer.cancel_average()
model.train()
def do_training(args, device, model, optimizer, train_corpus, dev_corpus, labels2i, tags2i, num_epochs, batch_size):
logger.info("Constructing training examples...")
train_sentences, train_raw_sentences = prepare_corpus(train_corpus)
dev_sentences, dev_raw_sentences = prepare_corpus(dev_corpus)
# just 500 training sentences to monitor learning
sample_indexes = sorted(np.random.choice(len(train_sentences), min(500, len(train_sentences)), replace=False))
sample_train = [train_sentences[i] for i in sample_indexes]
sample_train_raw = [train_raw_sentences[i] for i in sample_indexes]
sample_trainfile = f"{args.model}/tmp_sample_train_gold.discbracket"
with open(sample_trainfile, "w", encoding="utf8") as fstream:
for i in sample_indexes:
tree = train_corpus[i]
tree.expand_unaries()
fstream.write("{}\n".format(str(tree)))
tree.merge_unaries()
feature_function, _ = feature_functions[args.feats]
features = [extract_features(torch.device("cpu"), labels2i, sentence, feature_function) for sentence in train_sentences]
tag_features = [extract_tags(torch.device("cpu"), tags2i, sentence) for sentence in train_sentences]
idxs = list(range(len(train_sentences)))
num_tokens = sum([len(sentence) for sentence in train_sentences])
random.shuffle(idxs)
logger.info("Starting training")
warmup_steps = int(len(train_corpus) / args.B)
print(f"warmup steps: {warmup_steps}")
scheduling_dict = {"warmup_coeff": args.l / warmup_steps,
"total_steps": 0,
"patience": args.patience,
"warmup_steps": warmup_steps,
"scheduler": torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'max',
factor=0.5,
patience=args.patience,
verbose=True),
"best_epoch": 0,
"best_dev_f": 0}
set_lr(optimizer, scheduling_dict["warmup_coeff"])
check_callback = lambda epoch, check_id, avg: check_dev(epoch, check_id, optimizer, model, args,
dev_sentences, dev_raw_sentences,
sample_train, sample_train_raw, scheduling_dict, avg)
# TODO: change criterion to stop training
for epoch in range(1, num_epochs+1):
epoch_dict = train_epoch(device, epoch, model, optimizer, args, logger, train_raw_sentences, features, tag_features, idxs, scheduling_dict, batch_size, check_callback)
tag_loss = epoch_dict["tag loss"]
struct_loss = epoch_dict["struct loss"]
label_loss = epoch_dict["label loss"]
grad_norm = epoch_dict["grad norm"]
check_callback(epoch, args.check, False)
if args.O in ["asgd", "aadam"] and epoch >= args.start_averaging:
check_callback(epoch, args.check, True)
print(f"Epoch {epoch} tag:{tag_loss:.4f} str:{struct_loss:.4f} lab:{label_loss:.4f} gn:{grad_norm:.4f}")
if scheduling_dict["total_steps"] > scheduling_dict["warmup_steps"]:
scheduling_dict["scheduler"].step(scheduling_dict["best_dev_f"])
if epoch - scheduling_dict["best_epoch"] > (scheduling_dict["patience"] + 1) * 3:
print("Finishing training: no improvement on dev.")
break
if args.O in ["asgd", "aadam"]:
optimizer.average()
def main_train(args, logger, device):
train_corpus, dev_corpus, parsed_corpus = load_corpora(args, logger)
if args.ft is not None:
logger.info("Loading fast text vectors")
ft, dim_ft = fasttext_embeddings.load_vectors(args.ft)
logger.info("Loading fast text vectors: done")
else:
ft = None
dim_ft = None
if args.load_pretrained is None:
logger.info("Vocabulary extraction...")
words, vocabulary, label_set, tag_set = get_vocabulary(train_corpus + parsed_corpus)
i2chars, i2labels, labels2i, i2tags, tags2i, i2words = get_voc_dicts(words, vocabulary, label_set, tag_set)
if ft is not None:
voc_set = set(i2words)
for token in ft:
if token not in voc_set:
i2words.append(token)
logger.info("Model initialization...")
model = Parser(args, i2labels, i2tags, i2chars, i2words)
if ft is not None:
assign_pretrained_embeddings(model, ft, dim_ft, freeze=args.freeze_ft)
else:
i2chars, i2labels, i2tags, i2words, args_train = load_data(f"{args.load_pretrained}")
labels2i = {label: i for i, label in enumerate(i2labels)}
tags2i = {tag: i for i, tag in enumerate(i2tags)}
model = Parser(args_train, i2labels, i2tags, i2chars, i2words)
state_dict = torch.load(f"{args.load_pretrained}/model")
model.load_state_dict(state_dict)
with open(f"{args.model}/data", "wb") as fp:
pickle.dump([i2chars, i2labels, i2tags, i2words, args], fp)
model.to(device)
model.words2tensors.to(device)
random.shuffle(train_corpus)
if args.S is not None:
train_corpus = train_corpus[:args.S]
dev_corpus = dev_corpus[:args.S]
parsed_corpus = parsed_corpus[:args.S]
print("Training sentences: {}".format(len(train_corpus)))
print("Additional non-gold training sentences: {}".format(len(parsed_corpus)))
print("Dev set sentences: {}".format(len(dev_corpus)))
# num_parameters = 0
# num_parameters_embeddings = 0
# for name, p in model.named_parameters():
# print("parameter '{}', {}".format(name, p.numel()))
# if "embedding" in name:
# num_parameters_embeddings += p.numel()
# else:
# num_parameters += p.numel()
# print("Total number of parameters: {}".format(num_parameters + num_parameters_embeddings))
# print("Embedding parameters: {}".format(num_parameters_embeddings))
# print("Non embedding parameters: {}".format(num_parameters))
parameters = [pp for pp in model.parameters() if pp.requires_grad]
if len(parsed_corpus) > 0:
optimizer = optim.Adam(parameters, lr=args.parsed_lr, betas=(0.9, 0.98), eps=1e-9)
do_training(args, device, model, optimizer, parsed_corpus, dev_corpus, labels2i, tags2i, args.parsed_epochs, args.parsed_batchsize)
return
if args.O == "asgd":
optimizer = MyAsgd(parameters, lr=args.l,
momentum=args.m, weight_decay=0,
dc=args.d)
elif args.O == "adam":
optimizer = optim.Adam(parameters, lr=args.l, betas=(0.9, 0.98), eps=1e-9)
elif args.O == "aadam":
        start = (len(train_corpus) // args.B) * args.start_averaging  # start averaging after args.start_averaging epochs
        optimizer = AAdam(parameters, start=start, lr=args.l, betas=(0.9, 0.98), eps=1e-9)
else:
optimizer = optim.SGD(parameters, lr=args.l, momentum=args.m, weight_decay=0)
do_training(args, device, model, optimizer, train_corpus, dev_corpus, labels2i, tags2i, args.i, args.B)
def read_raw_corpus(filename):
sentences = []
with open(filename, encoding="utf8") as f:
for line in f:
# For negra
line = line.replace("(", "#LRB#").replace(")", "#RRB#")
line = line.strip().split()
if len(line) > 0:
sentences.append(line)
return sentences
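# Input format sketch for eval mode: one tokenized sentence per line, e.g.
#   The cat sleeps .
# Round brackets are rewritten to #LRB# / #RRB# before parsing (Negra-style
# escaping, matching the bracket entries added in get_vocabulary).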
def load_data(modeldir):
with open(f"{modeldir}/data", "rb") as fp:
return pickle.load(fp)
def main_eval(args, logger, device):
i2chars, i2labels, i2tags, i2words, args_train = load_data(f"{args.model}")
model = Parser(args_train, i2labels, i2tags, i2chars, i2words)
state_dict = torch.load("{}/model".format(args.model))
model.load_state_dict(state_dict)
model.to(device)
model.words2tensors.to(device)
model.eval()
logger.info("Model loaded")
sentences = read_raw_corpus(args.corpus)
#sentence_toks = [[T.Token(token, i, [None]) for i, token in enumerate(sentence)] for sentence in sentences]
trees = model.predict_corpus(sentences, args.eval_batchsize)
if args.output is None:
for tree in trees:
print(tree)
else:
with open(args.output, "w", encoding="utf8") as f:
for tree in trees:
f.write("{}\n".format(str(tree)))
if args.gold is not None:
p, r, f = discodop_eval.call_eval(args.gold, args.output, disconly=False)
dp, dr, df = discodop_eval.call_eval(args.gold, args.output, disconly=True)
print(f"precision={p}")
print(f"recall={r}")
print(f"fscore={f}")
print(f"disc-precision={dp}")
print(f"disc-recall={dr}")
print(f"disc-fscore={df}")
pred_tokens = [T.get_yield(tree) for tree in trees]
gold_tokens = [T.get_yield(tree) for tree in corpus_reader.read_discbracket_corpus(args.gold)]
tag_eval = eval_tagging(gold_tokens, pred_tokens)
print(f"tagging={tag_eval:.2f}")
def main(args, logger, device):
"""
Discoparset: transition based discontinuous constituency parser
    warning: CPU training is non-deterministic when several threads are used
"""
if args.mode == "train":
try:
main_train(args, logger, device)
except KeyboardInterrupt:
print("Training interrupted, exiting")
return
else:
main_eval(args, logger, device)
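# Example invocations (paths are placeholders, not files shipped with this repository):
#   python mtg.py train models/exp1 train.discbracket dev.discbracket --gpu 0 -B 32
#   python mtg.py eval  models/exp1 test_tokens.txt test_pred.discbracket --gold test.discbracket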
if __name__ == "__main__":
# import cProfile
# # if check avoids hackery when not profiling
# # Optional; hackery *seems* to work fine even when not profiling, it's just wasteful
# if sys.modules['__main__'].__file__ == cProfile.__file__:
# import main # Imports you again (does *not* use cache or execute as __main__)
# globals().update(vars(main)) # Replaces current contents with newly imported stuff
# sys.modules['__main__'] = main # Ensu
sys.setrecursionlimit(1500)
logging.basicConfig(stream=sys.stderr, level=logging.DEBUG,
format='%(asctime)s-%(relativeCreated)d-%(levelname)s:%(message)s')
import argparse
usage = main.__doc__
parser = argparse.ArgumentParser(description = usage, formatter_class=argparse.ArgumentDefaultsHelpFormatter)
subparsers = parser.add_subparsers(dest="mode", description="Execution modes", help='train: training, eval: test')
subparsers.required = True
train_parser = subparsers.add_parser("train", formatter_class=argparse.ArgumentDefaultsHelpFormatter)
eval_parser = subparsers.add_parser("eval", formatter_class=argparse.ArgumentDefaultsHelpFormatter)
# train corpora
train_parser.add_argument("model", help="Directory (to be created) for model exportation")
train_parser.add_argument("train", help="Training corpus")
train_parser.add_argument("dev", help="Dev corpus")
train_parser.add_argument("--fmt", help="Format for train and dev corpus", choices=["ctbk", "discbracket", "bracket"], default="discbracket")
# general options
train_parser.add_argument("--gpu", type=int, default=None, help="Use GPU if available")
train_parser.add_argument("-t", type=int, default=1, help="Number of threads for torch cpu")
train_parser.add_argument("-S", type=int, default=None, help="Use only X first training examples")
train_parser.add_argument("-v", type=int, default=1, choices=[0,1], help="Verbosity level")
train_parser.add_argument("--projective", action="store_true", help="Projective mode, with --fmt bracket.")
train_parser.add_argument("--enc", choices=["lstm", "transformer", "nktransformer", "no"], default="lstm", help="Type of sentence encoder")
# training options
train_parser.add_argument("-i", default=1000, type=int, help="Number of epochs")
train_parser.add_argument("-l", default=0.0008, type=float, help="Learning rate")
train_parser.add_argument("-m", default=0, type=float, help="Momentum (for sgd and asgd)")
train_parser.add_argument("-d", default=1e-7, type=float, help="Decay constant for lr (asgd)")
train_parser.add_argument("-I", type=float, default=0.01, help="Embedding init uniform on [-I, I]")
train_parser.add_argument("-G", default=None, type=float, help="Max norm for gradient clipping")
train_parser.add_argument("-B", type=int, default=1, help="Size of batch")
train_parser.add_argument("--eval-batchsize", type=int, default=30, help="Batch size for eval")
train_parser.add_argument("-O", default="adam", choices=["adam", "aadam", "sgd", "asgd"], help="Optimizer")
train_parser.add_argument("--start-averaging", type=int, default=40, help="Aadam starts averaging at epoch <int> ")
train_parser.add_argument("-s", type=int, default=10, help="Random seed")
train_parser.add_argument("--patience", type=int, default=5, help="Patience or lr scheduler")
train_parser.add_argument("--check", type=int, default=4, help="Checks on dev set per epoch")
# Dropout
train_parser.add_argument("--dcu", default=0.2, type=float, help="Dropout characters (replace by UNK)")
#train_parser.add_argument("--dce", default=0.2, type=float, help="Dropout for character embedding layer")
train_parser.add_argument("--diff", default=0.2, type=float, help="Dropout for FF output layers")
#train_parser.add_argument("--dtag", default=0.5, type=float, help="Dropout for tagger")
train_parser.add_argument("--dwu", default=0.3, type=float, help="Dropout words (replace by UNK)")
# Dimensions / architecture
train_parser.add_argument("-H", type=int, default=250, help="Dimension of hidden layer for FF nets")
train_parser.add_argument("-C", type=int, default=100, help="Dimension of char bilstm")
train_parser.add_argument("-c", type=int, default=64, help="Dimension of char embeddings")
train_parser.add_argument("-w", type=int, default=32, help="Use word embeddings with dim=w")
train_parser.add_argument("--ft", default=None, help="use pretrained fasttext embeddings (path), automatically set -w")
train_parser.add_argument("--freeze-ft", action="store_true", help="don't fine tune fasttext embeddings")
# activation
train_parser.add_argument("--fun", type=str, default="tanh", choices=["relu", "tanh"],
help="Activation functions for FF networks")
# feature functions
train_parser.add_argument("--feats", default="tacl", choices=["all", "tacl", "tacl_base", "global"], help="Feature functions")
# bert:
train_parser.add_argument("--bert", action="store_true", help="Use Bert as lexical model.")
train_parser.add_argument("--init-bert", action="store_true", help="Reinitialize bert parameters")
train_parser.add_argument("--bert-id",
default="bert-base-cased", type=str,
help="Which BERT model to use?")
train_parser.add_argument("--no-char", action="store_true",
help="Deactivate char based embeddings")
train_parser.add_argument("--no-word", action="store_true",
help="Deactivate word embeddings")
# Semisupervision with additional corpus
train_parser.add_argument("--parsed", default=None, help="Path to corpus for pretraining")
train_parser.add_argument("--parsed-batchsize", type=int, metavar=" ", default=300, help="Batch size for pretraining") # utility: having a different default value
train_parser.add_argument("--parsed-epochs", type=int, metavar=" ", default=1, help="Number of epochs for pretraining")
train_parser.add_argument("--parsed-lr", type=float, metavar=" ", default=0.0001, help="Learning rate for pretraining")
train_parser.add_argument("--load-pretrained", metavar=" ", default=None, help="Path to pretrained model")
# Options from transformer
SentenceEncoderLSTM.add_cmd_options(train_parser)
TransformerNetwork.add_cmd_options(train_parser)
# test corpus
eval_parser.add_argument("model", help="Pytorch model")
eval_parser.add_argument("corpus", help="Test corpus, 1 tokenized sentence per line")
eval_parser.add_argument("output", help="Outputfile")
eval_parser.add_argument("--gold", default=None, help="Gold corpus (discbracket). If provided, eval with discodop")
eval_parser.add_argument("--eval-batchsize", type=int, default=250, help="Batch size for eval")
# test options
eval_parser.add_argument("--gpu", type=int, default=None, help="Use GPU <int> if available")
eval_parser.add_argument("-v", type=int, default=1, choices=[0,1], help="Verbosity level")
eval_parser.add_argument("-t", type=int, default=1, help="Number of threads for torch cpu")
args = parser.parse_args()
if hasattr(args, 'ft') and args.ft is not None:
args.w = 300
for k, v in vars(args).items():
print(k, v)
torch.set_num_threads(args.t)
logger = logging.getLogger()
logger.info("Mode={}".format(args.mode))
if args.mode == "train":
os.makedirs(args.model, exist_ok = True)
if args.gpu is not None:
os.environ["CUDA_VISIBLE_DEVICES"]=str(args.gpu)
use_cuda = torch.cuda.is_available()
if use_cuda and args.gpu is not None:
logger.info("Using gpu {}".format(args.gpu))
device = torch.device("cuda".format(args.gpu))
else:
logger.info("Using cpu")
device = torch.device("cpu")
SEED = 0
if args.mode == "train":
SEED = args.s
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
main(args, logger, device)