Software Heritage archive
Origin: https://gitlab.com/mcoavoux/mtgpy-release-findings-2021.git (visited 04 December 2021, 19:36:23 UTC)
File: src/mtg.py at tip revision c9972219cd75049269d26632d2bb79619d661298 (release v1.0), authored by mcoavoux on 20 May 2021, 13:04:44 UTC
SWHID: swh:1:cnt:2cfd33b6f538ba2f841fe9b0033203f4e10c36f0

mtg.py
import sys
import os
from collections import defaultdict
import copy
import logging
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import pickle
import random

from Asgd import MyAsgd, AAdam
from word_encoders import WordEmbedder, Words2Tensors
from sentence_encoders import TransformerNetwork, SentenceEncoderLSTM
from nk_transformer import NKTransformer
from nk_transformer import LayerNormalization

import fasttext_embeddings

import discodop_eval
from state_gap import State
import corpus_reader
import tree as T

from features import feature_functions



class Parser(nn.Module):
    def __init__(self, args,
                 i2labels, i2tags, i2chars, i2words):
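        """Transition-based constituency parser.

        A shared word embedder and sentence encoder (selected by args.enc;
        with --enc no the word embeddings are used directly) feed three
        feed-forward classifiers: structural actions (shift/combine/gap),
        labelling actions, and POS tags.
        """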
        super(Parser, self).__init__()

        self.discontinuous = not args.projective

        self.args = args
        self.i2labels = i2labels
        self.i2tags = i2tags
        self.i2chars = i2chars
        self.i2words = i2words

        self.labels2i = {l : i for i,l in enumerate(i2labels)}
        self.tags2i = {l : i for i,l in enumerate(i2tags)}
        self.chars2i = {l : i for i,l in enumerate(i2chars)}
        self.words2i = {l : i for i,l in enumerate(i2words)}

        self.words2tensors = Words2Tensors(self.chars2i, self.words2i, pchar=args.dcu, pword=args.dwu)

        self.num_labels = len(i2labels)
        self.num_tags = len(i2tags)
        self.num_words = len(i2words)
        self.num_chars = len(i2chars)

        self.word_embedder = WordEmbedder(args, len(self.i2words), len(self.i2chars), self.words2tensors, self.words2i)

        d_input = self.word_embedder.output_dim

        if args.enc == "lstm":
            self.encoder = SentenceEncoderLSTM(args, d_input)
        elif args.enc == "transformer":
            self.encoder = TransformerNetwork(args, d_input)
        elif args.enc == "nktransformer":
            self.encoder = NKTransformer(args, d_input)
    
        dim_encoder = args.lstm_dim
        if args.enc == "no":
            dim_encoder = self.word_embedder.output_dim
        elif args.enc == "transformer" or args.enc == "nktransformer":
            dim_encoder = args.trans_dmodel

        self.feature_function, num_features = feature_functions[args.feats]

        dict_fun = {"tanh": nn.Tanh, "relu": nn.ReLU}
        activation_function = dict_fun[args.fun]


        # NK sequence
#        self.f_label = nn.Sequential(
#            nn.Linear(hparams.d_model, hparams.d_label_hidden),
#            LayerNormalization(hparams.d_label_hidden),
#            nn.ReLU(),
#            nn.Linear(hparams.d_label_hidden, label_vocab.size - 1),
#            )

        self.structure = nn.Sequential(
                    nn.Dropout(args.diff),
                    nn.Linear(dim_encoder*num_features, args.H),
                    LayerNormalization(args.H),
                    activation_function(),
                    nn.Linear(args.H, 3),
                    nn.LogSoftmax(dim=1))

        self.label = nn.Sequential(
                    nn.Dropout(args.diff),
                    nn.Linear(dim_encoder*num_features, args.H),
                    LayerNormalization(args.H),
                    activation_function(),
                    nn.Linear(args.H, self.num_labels),
                    nn.LogSoftmax(dim=1))

        self.tagger = nn.Sequential(
                    nn.Dropout(args.diff),
                    nn.Linear(dim_encoder, args.H),
                    LayerNormalization(args.H),
                    activation_function(),
                    nn.Linear(args.H, self.num_tags),
                    nn.LogSoftmax(dim=1))

#        # Structural actions
#        self.structure = nn.Sequential(
#                            nn.Dropout(args.diff),
#                            nn.Linear(dim_encoder*num_features, args.H),
#                            nn.LayerNorm(args.H),
#                            activation_function(),
##                            nn.Linear(args.H, args.H),
##                            nn.LayerNorm(args.H)
##                            activation_function(),
#                            nn.Linear(args.H, 3),
#                            nn.LogSoftmax(dim=1))
#        # Labelling actions
#        self.label = nn.Sequential(
#                        nn.Dropout(args.diff),
#                        nn.Linear(dim_encoder*num_features, args.H),
#                        nn.LayerNorm(args.H),
#                        activation_function(),
##                        nn.Linear(args.H, args.H),
##                        nn.LayerNorm(args.H)
##                        activation_function(),
#                        nn.Linear(args.H, self.num_labels),
#                        nn.LogSoftmax(dim=1))
#        # POS tags
#        self.tagger = nn.Sequential(
#                        nn.Dropout(args.dtag),
#                        nn.Linear(dim_encoder, args.H),
#                        nn.LayerNorm(args.H),
#                        activation_function(),
#                        nn.Linear(args.H, self.num_tags),
#                        nn.LogSoftmax(dim=1))

        self.loss_function = nn.NLLLoss(reduction="sum")
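        # 4 learned default vectors, prepended to each sentence encoding in forward()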
        self.default_values = nn.Parameter(torch.Tensor(4, dim_encoder))

        self.initialize_parameters(args)

    def initialize_parameters(self, args):
        # Uniform initialization for the learned default vectors only;
        # the other layers keep the PyTorch default initialization
        # (the Xavier initializations below are commented out)

        torch.nn.init.uniform_(self.default_values, -0.01, 0.01)
#        torch.nn.init.xavier_normal_(self.label[1].weight)
#        torch.nn.init.xavier_normal_(self.label[4].weight)

#        torch.nn.init.xavier_normal_(self.structure[1].weight)
#        torch.nn.init.xavier_normal_(self.structure[4].weight)

#        torch.nn.init.xavier_normal_(self.tagger[1].weight)
#        torch.nn.init.xavier_normal_(self.tagger[4].weight)


    def train_batch(self, batch_sentences, all_embeddings, targets):
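        """Compute tagging, structure and label losses for a batch.

        targets = (batch_features, batch_tags); each element of batch_features
        maps "struct"/"label" to (feature positions, gold actions) as produced
        by extract_features. Returns a dict with the mean negative
        log-likelihood for the tagger, the structural classifier and the
        label classifier.
        """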
        batch_features, batch_tags = targets
        output_l1, output_l2 = all_embeddings

        batch_label_input = []
        batch_label_output = []
        batch_struct_input = []
        batch_struct_output = []

        for sentence_embeddings, features in zip(output_l2, batch_features):
            #batch_features: {"struct": (struct_input, struct_output), "label": (labels_input, labels_output)}
            l_input, l_output = features["label"]
            s_input, s_output = features["struct"]
            
            batch_label_input.append(sentence_embeddings[l_input].view(len(l_input), -1))
            batch_label_output.append(l_output)
            
            batch_struct_input.append(sentence_embeddings[s_input].view(len(s_input), -1))
            batch_struct_output.append(s_output)

        batch_label_input = torch.cat(batch_label_input, dim=0)

        batch_label_output = torch.cat(batch_label_output)

        batch_struct_input = torch.cat(batch_struct_input, dim=0)
        batch_struct_output = torch.cat(batch_struct_output)

        labels_output_tensors = self.label(batch_label_input)
        label_loss = self.loss_function(labels_output_tensors, batch_label_output)
        label_loss /= len(batch_label_input)

        struct_output_tensors = self.structure(batch_struct_input)
        struct_loss = self.loss_function(struct_output_tensors, batch_struct_output)
        struct_loss /= len(batch_struct_input)

        batch_tagging_input = torch.cat(output_l1, dim=0)
        batch_tagging_targets = torch.cat(batch_tags)

        tagging_output = self.tagger(batch_tagging_input)
        tagging_loss = self.loss_function(tagging_output, batch_tagging_targets)  # NLLLoss already uses reduction="sum"
        tagging_loss /= len(batch_tagging_targets)


        return {"tagging loss": tagging_loss, "struct loss": struct_loss, "label loss": label_loss}


    def forward(self, batch_sentences_raw, targets=None):
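        """Encode a batch of sentences, then either train or parse.

        Adds <SOS>/<EOS> markers, runs the word embedder and (unless
        args.enc == "no") the sentence encoder, and prepends the learned
        default vectors to each sentence encoding. With targets, returns the
        training losses (train_batch); otherwise returns predicted trees
        (parse_batch).
        """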
        # raw: no SOS / EOS
        batch_sentences = [["<SOS>"] + sent + ["<EOS>"] for sent in batch_sentences_raw]

        lengths = torch.tensor([len(s) for s in batch_sentences])
        all_embeddings = self.word_embedder(batch_sentences)

        if self.args.enc != "no":
            padded_char_based_embeddings = torch.nn.utils.rnn.pad_sequence(all_embeddings, batch_first=True)
            unpacked_l1, unpacked_l2 = self.encoder(padded_char_based_embeddings, lengths)
        else:
            unpacked_l1 = all_embeddings
            unpacked_l2 = all_embeddings

        unpacked_l2 = [torch.cat([self.default_values, l2_sent], dim=0) for l2_sent in unpacked_l2]

        if targets is not None:
            return self.train_batch(batch_sentences, (unpacked_l1, unpacked_l2), targets)

        tokens = [[T.Token(tok, i, [None]) for i, tok in enumerate(sent)] for sent in batch_sentences_raw]

        return self.parse_batch(tokens, (unpacked_l1, unpacked_l2))
        #return self.parse(tokens, (unpacked_l1, unpacked_l2))

    def parse(self, sentences, sentence_embeddings):
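        """Greedy transition-based parsing, one sentence at a time (unbatched variant of parse_batch)."""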
        # sentences: list of list of T.Tokens
        output_l1, output_l2 = sentence_embeddings

        lengths = [len(sent) for sent in sentences]
        batch_tagging_input = torch.cat(output_l1, dim=0)
        tag_scores = self.tagger(batch_tagging_input)
        tag_predictions = torch.argmax(tag_scores, dim=1).cpu().split(lengths)
        tag_predictions = [v.numpy() for v in tag_predictions]

        for sent_tags, sent in zip(tag_predictions, sentences):
            for tag, tok in zip(sent_tags, sent):
                tok.set_tag(self.i2tags[tag])

        trees = []
        # make this more parallel / batch
        for sentence, sentence_embedding in zip(sentences, output_l2):
            state = State(sentence, self.discontinuous)
            while not state.is_final():
                next_type = state.next_action_type()
                
                stack, queue, buffer, sent_len = state.get_input()
                input_features = self.feature_function(sentence_embedding.device, stack, queue, buffer, sent_len)
                input_embeddings = sentence_embedding[input_features].view(1, -1)
                if next_type == State.LABEL:
                    output = self.label(input_embeddings).squeeze()
                    state.filter_action(next_type, output)

                    prediction = torch.argmax(output)
                    if prediction == 0:
                        state.nolabel()
                    else:
                        pred_label = self.i2labels[prediction]
                        state.labelX(pred_label)
                else:
                    output = self.structure(input_embeddings).squeeze()
                    state.filter_action(next_type, output)

                    prediction = torch.argmax(output.squeeze())
                    if prediction == State.SHIFT:
                        state.shift()
                    elif prediction == State.COMBINE:
                        state.combine()
                    else:
                        assert(prediction == State.GAP)
                        state.gap()
            trees.append(state.get_tree())
        return trees

    def parse_batch(self, sentences, sentence_embeddings):
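        """Greedy transition-based parsing of a batch of sentences.

        POS tags are predicted first; then all non-final parser states are
        advanced in lockstep, batching the label and structure classifiers
        over the states that require the same action type.
        """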
        # sentences: list of list of T.Tokens
        output_l1, output_l2 = sentence_embeddings

        lengths = [len(sent) for sent in sentences]
        batch_tagging_input = torch.cat(output_l1, dim=0)
        tag_scores = self.tagger(batch_tagging_input)
        tag_predictions = torch.argmax(tag_scores, dim=1).cpu().split(lengths)
        tag_predictions = [v.numpy() for v in tag_predictions]

        for sent_tags, sent in zip(tag_predictions, sentences):
            for tag, tok in zip(sent_tags, sent):
                tok.set_tag(self.i2tags[tag])

        
        trees = [None for _ in range(len(sentences))]
        states = [State(sentence, self.discontinuous, sent_id=i) for i, sentence in enumerate(sentences)]

        while len(states) > 0:
            next_types = [state.next_action_type() for state in states]
            input_structs = [state.get_input() for state in states]
            input_features = [self.feature_function(output_l2[0].device, stack, queue, buffer, sent_len)
                                for stack, queue, buffer, sent_len in input_structs]

            input_embeddings = [output_l2[state.sent_id][in_feats].view(-1) for state, in_feats in zip(states, input_features)]

            label_states = [states[i] for i, ntype in enumerate(next_types) if ntype == State.LABEL]
            label_inputs = [input_embeddings[i] for i, ntype in enumerate(next_types) if ntype == State.LABEL]

            if len(label_states) > 0:
                label_outputs = self.label(torch.stack(label_inputs))
                for i, state in enumerate(label_states):
                    state.filter_action(State.LABEL, label_outputs[i])
                    prediction = torch.argmax(label_outputs[i])
                    if prediction == 0: 
                        state.nolabel()
                    else:
                        pred_label = self.i2labels[prediction]
                        state.labelX(pred_label)

            struct_states = [states[i] for i, ntype in enumerate(next_types) if ntype == State.STRUCT]
            struct_inputs = [input_embeddings[i] for i, ntype in enumerate(next_types) if ntype == State.STRUCT]

            #print(torch.stack(struct_inputs).shape)
            if len(struct_states) > 0:
                struct_outputs = self.structure(torch.stack(struct_inputs))
                for i, state in enumerate(struct_states):
                    state.filter_action(State.STRUCT, struct_outputs[i])
                    prediction = torch.argmax(struct_outputs[i])
                    if prediction == State.SHIFT:
                        state.shift()
                    elif prediction == State.COMBINE:
                        state.combine()
                    else:
                        assert(prediction == State.GAP)
                        state.gap()

            states = []
            for state in label_states + struct_states:
                if state.is_final():
                    trees[state.sent_id] = state.get_tree()
                else:
                    states.append(state)
        return trees

    def predict_corpus(self, corpus, batch_size):
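        """Parse a corpus in batches with gradients disabled.

        Sentences are sorted by decreasing length for batching, unary chains
        are re-expanded on the predicted trees, and the trees are returned in
        the original corpus order.
        """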
        # corpus: list of lists of raw token strings
        trees = []

        # sort by length for lstm batching
        indices, corpus = zip(*sorted(zip(range(len(corpus)), 
                                                corpus), 
                                                key = lambda x: len(x[1]),
                                                reverse=True))

        with torch.no_grad():
            for i in range(0, len(corpus), batch_size):
                tree_batch = self.forward(corpus[i:i+batch_size])

                for t in tree_batch:
                    t.expand_unaries()
                    trees.append(t)

        # reorder
        _, trees = zip(*sorted(zip(indices, trees), key = lambda x:x[0]))
        return trees

def assign_pretrained_embeddings(model, ft, dim_ft, freeze):
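    """Copy pretrained fastText vectors into the word embedding matrix, optionally freezing it."""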
    print("Assigning pretrained fasttext embeddings")
    embedding = model.word_embedder.word_embeddings
    device = embedding.weight.data.device
    for w, v in ft.items():
        i = model.words2i[w]
        embedding.weight.data[i, :dim_ft] = torch.tensor(v).to(device)

    if freeze:
        print("Freezing word embedding layer")
        embedding.requires_grad_(False)
    

def get_vocabulary(corpus):
    """Extract vocabulary for characters, tokens, non-terminals, POS tags"""
    words = defaultdict(int)

    chars = defaultdict(int)
#    chars["<START>"] += 2
#    chars["<STOP>"] += 2

    # These should be treated as single characters
    chars["-LRB-"] += 2
    chars["-RRB-"] += 2
    chars["#LRB#"] += 2
    chars["#RRB#"] += 2

    tag_set = defaultdict(int)

    label_set = defaultdict(int)
    for tree in corpus:
        tokens = T.get_yield(tree)
        for tok in tokens:
            for char in tok.token:
                chars[char] += 1

            tag_set[tok.get_tag()] += 1

            words[tok.token] += 1

        constituents = T.get_constituents(tree)
        for label, _ in constituents:
            label_set[label] += 1

    return words, chars, label_set, tag_set

def extract_features(device, labels2i, sentence, feature_function):
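    """Replay the gold derivation of a sentence with the transition oracle.

    Collects, for each configuration, the feature positions (feature_function)
    and the gold action: structural actions under "struct", labelling actions
    under "label". Used to build static training examples.
    """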

    state = State(sentence, True)

    struct_input = []
    struct_output = []

    labels_input = []
    labels_output = []

    while not state.is_final():
        next_type = state.next_action_type()
        gold_action, (stack, queue, buffer, sent_len) = state.oracle()
        input_features = feature_function(device, stack, queue, buffer, sent_len)

        if next_type == State.STRUCT:
            struct_input.append(input_features)
            struct_output.append(State.mapping[gold_action[0]])

        else:
            assert(next_type == State.LABEL)
            labels_input.append(input_features)
            labels_output.append(labels2i[gold_action[1]])


    struct_input = torch.stack(struct_input)
    struct_output = torch.tensor(struct_output, dtype=torch.long, device=device)

    labels_input = torch.stack(labels_input)
    labels_output = torch.tensor(labels_output, dtype=torch.long, device=device)

    #print(labels_output)

    return {"struct": (struct_input, struct_output), "label": (labels_input, labels_output)}


def extract_tags(device, tags2i, sentence):
    # Returns a tensor for tag ids for a single sentence
    idxes = [tags2i[tok.get_tag()] for tok in sentence]
    return torch.tensor(idxes, dtype=torch.long, device=device)

def compute_f(TPs, total_golds, total_preds):
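    """Precision, recall and F1 from true-positive, gold and predicted counts (0 when undefined)."""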
    p, r, f = 0, 0, 0
    if total_preds > 0:
        p = TPs / total_preds
    if total_golds > 0:
        r = TPs / total_golds
    if (p, r) != (0, 0):
        f = 2*p*r / (p+r)
    return p, r, f

def Fscore_corpus(golds, preds):
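    """Labelled and unlabelled precision/recall/F1 over constituents, in percent.

    golds and preds contain, per sentence, collections of (label, span) pairs;
    the unlabelled scores ignore labels and compare span multisets.
    """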
    TPs = 0
    total_preds = 0
    total_golds = 0
    UTPs = 0
    for gold, pred in zip(golds, preds):
        TPs += len([c for c in gold if c in pred])
        total_golds += len(gold)
        total_preds += len(pred)

        ugold = defaultdict(int)
        for _, span in gold:
            ugold[span] += 1
        upred = defaultdict(int)
        for _, span in pred:
            upred[span] += 1
        for span in upred:
            UTPs += min(upred[span], ugold[span])

    p, r, f = compute_f(TPs, total_golds, total_preds)

    up, ur, uf = compute_f(UTPs, total_golds, total_preds)
    return p*100, r*100, f*100, up*100, ur*100, uf*100


def prepare_corpus(corpus):
    sentences = [T.get_yield(corpus[i]) for i in range(len(corpus))]
    raw_sentences = [[tok.token for tok in sentence] for sentence in sentences]
    return sentences, raw_sentences


def eval_tagging(gold, pred):
    # Returns accuracy for tag predictions
    acc = 0
    tot = 0
    assert(len(gold) == len(pred))
    for sent_g, sent_p in zip(gold, pred):
        assert(len(sent_g) == len(sent_p))
        for tok_g, tok_p in zip(sent_g, sent_p):
            if tok_g.get_tag() == tok_p.get_tag():
                acc += 1
        tot += len(sent_g)
    return acc * 100 / tot

def train_epoch(device, epoch, model, optimizer, args, logger, train_raw_sentences, features, tag_features, idxs, scheduling_dict, batch_size, check_callback):
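    """Run one training epoch over shuffled training examples, in batches.

    Accumulates tagging/structure/label losses, applies optional gradient
    clipping (args.G) and linear learning-rate warmup, and calls
    check_callback up to args.check - 1 times per epoch (from epoch 30 onwards).
    """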

    tag_loss = 0
    struct_loss = 0
    label_loss = 0

    model.train()
    grad_norm = 0

    random.shuffle(idxs)

    n_batches = len(idxs)//batch_size
    check_id = max(1, n_batches // args.check)  # avoid modulo-by-zero when there are fewer batches than checks
    num_check = 0
    for j, i in enumerate(range(0, len(idxs), batch_size)):

        batch_sentences = [train_raw_sentences[idx] for idx in idxs[i:i+batch_size]]
        batch_features = [features[idx] for idx in idxs[i:i+batch_size]]
        batch_tags = [tag_features[idx] for idx in idxs[i:i+batch_size]]

#        print(batch_features[0])
#        print(batch_tags[0])
        batch_features = [ {k: (v[0].to(device), v[1].to(device)) for k,v in feats.items()} for feats in batch_features]
        batch_tags = [feats.to(device) for feats in batch_tags]

        batch_sentences, batch_features, batch_tags = zip(*sorted(zip(batch_sentences, batch_features, batch_tags), 
                                                           key = lambda x: len(x[0]), reverse=True))

        optimizer.zero_grad()
        losses = model(batch_sentences, targets=(batch_features, batch_tags))

        batch_total_loss = losses["tagging loss"] + losses["struct loss"] + losses["label loss"]
        batch_total_loss.backward()

        tag_loss += losses["tagging loss"].item()
        struct_loss += losses["struct loss"].item()
        label_loss += losses["label loss"].item()

        if args.G is not None:
            grad_norm += torch.nn.utils.clip_grad_norm_(model.parameters(), args.G)
        optimizer.step()

        if j % 20 == 0 and args.v > 0:
            logger.info(f"Epoch {epoch} Training batch {j}/{n_batches} sent {i} to {i+batch_size} lr:{list(optimizer.param_groups)[0]['lr']:.4e}")

        scheduling_dict["total_steps"] += 1
        if scheduling_dict["total_steps"] <= scheduling_dict["warmup_steps"]:
            set_lr(optimizer, scheduling_dict["total_steps"] * scheduling_dict["warmup_coeff"])


        if (j + 1) % check_id == 0 and num_check < args.check - 1 and epoch >= 30:
            num_check += 1
            check_callback(epoch, num_check, False)
            

    grad_norm /= n_batches

    return {"tag loss": tag_loss, "struct loss": struct_loss, "label loss": label_loss, "grad norm": grad_norm}


def set_lr(optimizer, new_lr):
    for param_group in optimizer.param_groups:
        param_group['lr'] = new_lr


def eval_on_corpus(args, model, sentences, raw_sentences, corpus):
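    """Parse a corpus, write it to a temporary discbracket file and score it.

    Returns discodop precision/recall/F1 (overall and discontinuous-only)
    together with tagging accuracy.
    """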
    pred_trees = model.predict_corpus(raw_sentences, args.eval_batchsize)

    sentence_pred_tokens = [T.get_yield(tree) for tree in pred_trees]

    tageval = eval_tagging(sentences, sentence_pred_tokens)

    outfile = f"{args.model}/tmp_{corpus}.discbracket"
    with open(outfile, "w", encoding="utf8") as fstream:
        for tree in pred_trees:
            fstream.write("{}\n".format(str(tree)))

    if corpus == "dev":
        goldfile = args.dev.replace(".ctbk", ".discbracket")
    else:
        goldfile = f"{args.model}/tmp_sample_train_gold.discbracket"

    discop, discor, discodop_f = discodop_eval.call_eval(goldfile, outfile)
    disco2p, disco2r, discodop2_f = discodop_eval.call_eval(goldfile, outfile, disconly=True)

    return {"p": discop, "r": discor, "f": discodop_f,
            "dp": disco2p, "dr": disco2r, "df": discodop2_f, 
            "tag": round(tageval, 2)}


def load_corpora(args, logger):
    logger.info("Loading corpora...")
    if args.fmt == "ctbk":
        train_corpus = corpus_reader.read_ctbk_corpus(args.train)
        dev_corpus = corpus_reader.read_ctbk_corpus(args.dev)
    elif args.fmt == "discbracket":
        train_corpus = corpus_reader.read_discbracket_corpus(args.train)
        dev_corpus = corpus_reader.read_discbracket_corpus(args.dev)
    elif args.fmt == "bracket":
        train_corpus = corpus_reader.read_bracket_corpus(args.train)
        dev_corpus = corpus_reader.read_bracket_corpus(args.dev)

    parsed_corpus = []
    if args.parsed is not None:
        parsed_corpus = corpus_reader.read_discbracket_corpus(args.parsed)
        random.shuffle(parsed_corpus)

    for tree in train_corpus + parsed_corpus:
        tree.merge_unaries()
    return train_corpus, dev_corpus, parsed_corpus

def get_voc_dicts(words, vocabulary, label_set, tag_set):
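    """Build index-to-symbol lists and reverse dictionaries for the vocabularies.

    Characters and words are ordered by decreasing frequency after the special
    symbols; labels (with "nolabel" first) and tags are sorted alphabetically.
    """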
    i2chars = ["<PAD>", "<UNK>", "<START>", "<STOP>", "<SOS>", "<EOS>"] + sorted(vocabulary, key = lambda x: vocabulary[x], reverse=True)
    #chars2i = {k:i for i, k in enumerate(i2chars)}

    i2labels = ["nolabel"] + sorted(label_set)
    labels2i = {k: i for i, k in enumerate(i2labels)}

    i2tags = sorted(tag_set)
    tags2i = {k:i for i, k in enumerate(i2tags)}

    i2words = ["<PAD>", "<UNK>", "<SOS>", "<EOS>"] + sorted(words, key=lambda x: words[x], reverse=True)
    #words2i = {k:i for i, k in enumerate(i2words)}
    return i2chars, i2labels, labels2i, i2tags, tags2i, i2words

def check_dev(epoch, check_id, optimizer, model, args, dev_sentences, dev_raw_sentences, sample_train, sample_train_raw, scheduling_dict, average):
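    """Evaluate on dev and on a training sample; save the model if dev F-score improves.

    With an averaging optimizer (asgd/aadam) and average=True, evaluation uses
    the averaged weights (optimizer.average()), reverted afterwards with
    cancel_average().
    """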
    optimizer.zero_grad()
    model.eval()

    if args.O in ["asgd", "aadam"] and average:
        optimizer.average()

    dev_results = eval_on_corpus(args, model, dev_sentences, dev_raw_sentences, "dev")
    train_results = eval_on_corpus(args, model, sample_train, sample_train_raw, "train")
    discodop_f = dev_results["f"]

    print(f"Epoch {epoch}.{check_id} lr:{optimizer.param_groups[0]['lr']:.4e}",
         "dev:",   " ".join([f"{k}={v}" for k, v in sorted(dev_results.items())]),
         "train:", " ".join([f"{k}={v}" for k, v in sorted(train_results.items())]),
#              f" dev tag:{dtag:.2f} f:{discodop_f} p:{discop} r:{discor} disco f:{discodop2_f} p:{disco2p} r:{disco2r}", 
#              f" train tag:{dtag:.2f} f:{discodop_f} p:{discop} r:{discor} disco f:{discodop2_f} p:{disco2p} r:{disco2r}", 
          f"best: {scheduling_dict['best_dev_f']} ep {scheduling_dict['best_epoch']} avg={int(average)}",
          flush=True)

    if discodop_f > scheduling_dict['best_dev_f']:
        scheduling_dict["best_epoch"] = epoch
        scheduling_dict['best_dev_f'] = discodop_f
        device = next(model.parameters()).device  # remember the current device before saving on CPU
        model.cpu()
        model.words2tensors.to(torch.device("cpu"))
        torch.save(model.state_dict(), "{}/model".format(args.model))
        with open(f"{args.model}/best_dev_score", "w", encoding="utf8") as ff:
            ff.write(str(scheduling_dict['best_dev_f']))
        model.to(device)
        model.words2tensors.to(device)

    if args.O in ["asgd", "aadam"] and average:
        optimizer.cancel_average()

    model.train()


def do_training(args, device, model, optimizer, train_corpus, dev_corpus, labels2i, tags2i, num_epochs, batch_size):
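    """Full training loop: warmup, per-epoch training, periodic dev checks, LR scheduling and early stopping.

    Uses the module-level logger. A sample of at most 500 training sentences
    is written to disk and used to monitor training F-score.
    """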

    logger.info("Constructing training examples...")
    train_sentences, train_raw_sentences = prepare_corpus(train_corpus)

    dev_sentences, dev_raw_sentences = prepare_corpus(dev_corpus)

    # just 500 training sentences to monitor learning
    sample_indexes = sorted(np.random.choice(len(train_sentences), min(500, len(train_sentences)), replace=False))
    sample_train = [train_sentences[i] for i in sample_indexes]
    sample_train_raw = [train_raw_sentences[i] for i in sample_indexes]

    sample_trainfile = f"{args.model}/tmp_sample_train_gold.discbracket"
    with open(sample_trainfile, "w", encoding="utf8") as fstream:
        for i in sample_indexes:
            tree = train_corpus[i]
            tree.expand_unaries()
            fstream.write("{}\n".format(str(tree)))
            tree.merge_unaries()


    feature_function, _ = feature_functions[args.feats]
    features = [extract_features(torch.device("cpu"), labels2i, sentence, feature_function) for sentence in train_sentences]
    tag_features = [extract_tags(torch.device("cpu"), tags2i, sentence) for sentence in train_sentences]

    idxs = list(range(len(train_sentences)))
    num_tokens = sum([len(sentence) for sentence in train_sentences])
    random.shuffle(idxs)

    logger.info("Starting training")

    warmup_steps = max(1, len(train_corpus) // args.B)  # at least 1 step to avoid dividing by zero in warmup_coeff
    print(f"warmup steps: {warmup_steps}")
    scheduling_dict = {"warmup_coeff": args.l / warmup_steps,
                       "total_steps": 0,
                       "patience": args.patience,
                       "warmup_steps": warmup_steps,
                       "scheduler": torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'max',
                                                                               factor=0.5,
                                                                               patience=args.patience,
                                                                               verbose=True),
                       "best_epoch": 0,
                       "best_dev_f": 0}

    set_lr(optimizer, scheduling_dict["warmup_coeff"])


    check_callback = lambda epoch, check_id, avg: check_dev(epoch, check_id, optimizer, model, args, 
                                                    dev_sentences, dev_raw_sentences, 
                                                    sample_train, sample_train_raw, scheduling_dict, avg)

    # TODO: change criterion to stop training
    for epoch in range(1, num_epochs+1):

        epoch_dict = train_epoch(device, epoch, model, optimizer, args, logger, train_raw_sentences, features, tag_features, idxs, scheduling_dict, batch_size, check_callback)

        tag_loss = epoch_dict["tag loss"]
        struct_loss = epoch_dict["struct loss"]
        label_loss = epoch_dict["label loss"]
        grad_norm = epoch_dict["grad norm"]


        check_callback(epoch, args.check, False)
        if args.O in ["asgd", "aadam"] and epoch >= args.start_averaging:
            check_callback(epoch, args.check, True)

        print(f"Epoch {epoch} tag:{tag_loss:.4f} str:{struct_loss:.4f} lab:{label_loss:.4f} gn:{grad_norm:.4f}")

        if scheduling_dict["total_steps"] > scheduling_dict["warmup_steps"]:
            scheduling_dict["scheduler"].step(scheduling_dict["best_dev_f"])
            if epoch - scheduling_dict["best_epoch"] > (scheduling_dict["patience"] + 1) * 3:
                print("Finishing training: no improvement on dev.")
                break

    if args.O in ["asgd", "aadam"]:
        optimizer.average()


def main_train(args, logger, device):
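    """Training entry point: load corpora and optional fastText vectors, build or load the model, pick the optimizer and train."""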
    train_corpus, dev_corpus, parsed_corpus = load_corpora(args, logger)

    if args.ft is not None:
        logger.info("Loading fast text vectors")
        ft, dim_ft = fasttext_embeddings.load_vectors(args.ft)
        logger.info("Loading fast text vectors: done")
    else:
        ft = None
        dim_ft = None

    if args.load_pretrained is None:
        logger.info("Vocabulary extraction...")
        words, vocabulary, label_set, tag_set = get_vocabulary(train_corpus + parsed_corpus)

        i2chars, i2labels, labels2i, i2tags, tags2i, i2words = get_voc_dicts(words, vocabulary, label_set, tag_set)

        if ft is not None:
            voc_set = set(i2words)
            for token in ft:
                if token not in voc_set:
                    i2words.append(token)

        logger.info("Model initialization...")
        model = Parser(args, i2labels, i2tags, i2chars, i2words)
        
        if ft is not None:
            assign_pretrained_embeddings(model, ft, dim_ft, freeze=args.freeze_ft)
        
    else:
        i2chars, i2labels, i2tags, i2words, args_train = load_data(f"{args.load_pretrained}")
        labels2i = {label: i for i, label in enumerate(i2labels)}
        tags2i = {tag: i for i, tag in enumerate(i2tags)}

        model = Parser(args_train, i2labels, i2tags, i2chars, i2words)
        
        state_dict = torch.load(f"{args.load_pretrained}/model")
        model.load_state_dict(state_dict)

    with open(f"{args.model}/data", "wb") as fp:
        pickle.dump([i2chars, i2labels, i2tags, i2words, args], fp)

    model.to(device)
    model.words2tensors.to(device)


    random.shuffle(train_corpus)

    if args.S is not None:
        train_corpus = train_corpus[:args.S]
        dev_corpus = dev_corpus[:args.S]
        parsed_corpus = parsed_corpus[:args.S]

    print("Training sentences: {}".format(len(train_corpus)))
    print("Additional non-gold training sentences: {}".format(len(parsed_corpus)))
    print("Dev set sentences: {}".format(len(dev_corpus)))


#    num_parameters = 0
#    num_parameters_embeddings = 0
#    for name, p in model.named_parameters():
#        print("parameter '{}', {}".format(name, p.numel()))
#        if "embedding" in name:
#            num_parameters_embeddings += p.numel()
#        else:
#            num_parameters += p.numel()
#    print("Total number of parameters: {}".format(num_parameters + num_parameters_embeddings))
#    print("Embedding parameters: {}".format(num_parameters_embeddings))
#    print("Non embedding parameters: {}".format(num_parameters))

    parameters = [pp for pp in model.parameters() if pp.requires_grad]

    if len(parsed_corpus) > 0:
        optimizer = optim.Adam(parameters, lr=args.parsed_lr, betas=(0.9, 0.98), eps=1e-9)
        do_training(args, device, model, optimizer, parsed_corpus, dev_corpus, labels2i, tags2i, args.parsed_epochs, args.parsed_batchsize)
        return

    if args.O == "asgd":
        optimizer = MyAsgd(parameters, lr=args.l, 
                           momentum=args.m, weight_decay=0,
                           dc=args.d)
    elif args.O == "adam":
        optimizer = optim.Adam(parameters, lr=args.l, betas=(0.9, 0.98), eps=1e-9)
    elif args.O == "aadam":
        start = (len(train_corpus) // args.B) * args.start_averaging  # start averaging after args.start_averaging epochs
        optimizer = AAdam(parameters, start=start, lr=args.l, betas=(0.9, 0.98), eps=1e-9)
    else:
        optimizer = optim.SGD(parameters, lr=args.l, momentum=args.m, weight_decay=0)

    do_training(args, device, model, optimizer, train_corpus, dev_corpus, labels2i, tags2i, args.i, args.B)

def read_raw_corpus(filename):
    sentences = []
    with open(filename, encoding="utf8") as f:
        for line in f:
            # For negra
            line = line.replace("(", "#LRB#").replace(")", "#RRB#")
            line = line.strip().split()
            if len(line) > 0:
                sentences.append(line)
    return sentences

def load_data(modeldir):
    with open(f"{modeldir}/data", "rb") as fp:
        return pickle.load(fp)

def main_eval(args, logger, device):
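    """Evaluation entry point: load a trained model, parse a raw corpus and optionally score it against a gold discbracket file."""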

    i2chars, i2labels, i2tags, i2words, args_train = load_data(f"{args.model}")

    model = Parser(args_train, i2labels, i2tags, i2chars, i2words)
    
    state_dict = torch.load("{}/model".format(args.model))
    model.load_state_dict(state_dict)
    model.to(device)
    model.words2tensors.to(device)
    model.eval()
    logger.info("Model loaded")
    

    sentences = read_raw_corpus(args.corpus)
    #sentence_toks = [[T.Token(token, i, [None]) for i, token in enumerate(sentence)] for sentence in sentences]

    trees = model.predict_corpus(sentences, args.eval_batchsize)

    if args.output is None:
        for tree in trees:
            print(tree)
    else:
        with open(args.output, "w", encoding="utf8") as f:
            for tree in trees:
                f.write("{}\n".format(str(tree)))

        if args.gold is not None:
            p, r, f = discodop_eval.call_eval(args.gold, args.output, disconly=False)
            dp, dr, df = discodop_eval.call_eval(args.gold, args.output, disconly=True)

            print(f"precision={p}")
            print(f"recall={r}")
            print(f"fscore={f}")
            print(f"disc-precision={dp}")
            print(f"disc-recall={dr}")
            print(f"disc-fscore={df}")

            pred_tokens = [T.get_yield(tree) for tree in trees]
            gold_tokens = [T.get_yield(tree) for tree in corpus_reader.read_discbracket_corpus(args.gold)]

            tag_eval = eval_tagging(gold_tokens, pred_tokens)

            print(f"tagging={tag_eval:.2f}")


def main(args, logger, device):
    """
        Discoparset: transition-based discontinuous constituency parser

    warning: CPU training is non-deterministic when using several threads
    """
    if args.mode == "train":
        try:
            main_train(args, logger, device)
        except KeyboardInterrupt:
            print("Training interrupted, exiting")
            return

    else:
        main_eval(args, logger, device)

if __name__ == "__main__":
#    import cProfile
#    # if check avoids hackery when not profiling
#    # Optional; hackery *seems* to work fine even when not profiling, it's just wasteful
#    if sys.modules['__main__'].__file__ == cProfile.__file__:
#        import main  # Imports you again (does *not* use cache or execute as __main__)
#        globals().update(vars(main))  # Replaces current contents with newly imported stuff
#        sys.modules['__main__'] = main  # Ensu


    sys.setrecursionlimit(1500)

    logging.basicConfig(stream=sys.stderr, level=logging.DEBUG, 
                format='%(asctime)s-%(relativeCreated)d-%(levelname)s:%(message)s')
    import argparse

    usage = main.__doc__

    parser = argparse.ArgumentParser(description = usage, formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    
    subparsers = parser.add_subparsers(dest="mode", description="Execution modes", help='train: training, eval: test')
    subparsers.required = True

    train_parser = subparsers.add_parser("train", formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    eval_parser = subparsers.add_parser("eval", formatter_class=argparse.ArgumentDefaultsHelpFormatter)

    # train corpora
    train_parser.add_argument("model", help="Directory (to be created) for model exportation")
    train_parser.add_argument("train", help="Training corpus")
    train_parser.add_argument("dev",   help="Dev corpus")
    train_parser.add_argument("--fmt", help="Format for train and dev corpus", choices=["ctbk", "discbracket", "bracket"], default="discbracket")

    # general options
    train_parser.add_argument("--gpu", type=int, default=None, help="Use GPU if available")
    train_parser.add_argument("-t", type=int, default=1, help="Number of threads for torch cpu")
    train_parser.add_argument("-S", type=int, default=None, help="Use only X first training examples")
    train_parser.add_argument("-v", type=int, default=1, choices=[0,1], help="Verbosity level")

    train_parser.add_argument("--projective", action="store_true", help="Projective mode, with --fmt bracket.")

    train_parser.add_argument("--enc", choices=["lstm", "transformer", "nktransformer", "no"], default="lstm", help="Type of sentence encoder")

    # training options
    train_parser.add_argument("-i", default=1000, type=int, help="Number of epochs")
    train_parser.add_argument("-l", default=0.0008, type=float, help="Learning rate")
    train_parser.add_argument("-m", default=0, type=float, help="Momentum (for sgd and asgd)")
    train_parser.add_argument("-d", default=1e-7, type=float, help="Decay constant for lr (asgd)")

    train_parser.add_argument("-I", type=float, default=0.01, help="Embedding init uniform on [-I, I]")
    train_parser.add_argument("-G", default=None, type=float, help="Max norm for gradient clipping")

    train_parser.add_argument("-B", type=int, default=1, help="Size of batch")
    train_parser.add_argument("--eval-batchsize", type=int, default=30, help="Batch size for eval")
    train_parser.add_argument("-O", default="adam", choices=["adam", "aadam", "sgd", "asgd"], help="Optimizer")
    train_parser.add_argument("--start-averaging", type=int, default=40, help="Aadam starts averaging at epoch <int> ")
    train_parser.add_argument("-s", type=int, default=10, help="Random seed")

    train_parser.add_argument("--patience", type=int, default=5, help="Patience or lr scheduler")
    train_parser.add_argument("--check", type=int, default=4, help="Checks on dev set per epoch")


    # Dropout
    train_parser.add_argument("--dcu",  default=0.2, type=float, help="Dropout characters (replace by UNK)")
    #train_parser.add_argument("--dce",  default=0.2, type=float, help="Dropout for character embedding layer")
    train_parser.add_argument("--diff", default=0.2, type=float, help="Dropout for FF output layers")
    #train_parser.add_argument("--dtag", default=0.5, type=float, help="Dropout for tagger")
    train_parser.add_argument("--dwu",  default=0.3, type=float, help="Dropout words (replace by UNK)")


    # Dimensions / architecture
    train_parser.add_argument("-H", type=int, default=250, help="Dimension of hidden layer for FF nets")
    train_parser.add_argument("-C", type=int, default=100, help="Dimension of char bilstm")
    train_parser.add_argument("-c", type=int, default=64, help="Dimension of char embeddings")
    train_parser.add_argument("-w", type=int, default=32, help="Use word embeddings with dim=w")
    train_parser.add_argument("--ft", default=None, help="use pretrained fasttext embeddings (path), automatically set -w")
    train_parser.add_argument("--freeze-ft", action="store_true", help="don't fine tune fasttext embeddings")


    # activation
    train_parser.add_argument("--fun", type=str, default="tanh", choices=["relu", "tanh"],
                        help="Activation functions for FF networks")

    # feature functions
    train_parser.add_argument("--feats", default="tacl", choices=["all", "tacl", "tacl_base", "global"], help="Feature functions")

    # bert:
    train_parser.add_argument("--bert", action="store_true", help="Use Bert as lexical model.")
    train_parser.add_argument("--init-bert", action="store_true", help="Reinitialize bert parameters")
    train_parser.add_argument("--bert-id",
                              default="bert-base-cased", type=str,
                              help="Which BERT model to use?")

    train_parser.add_argument("--no-char", action="store_true",
                              help="Deactivate char based embeddings")
    train_parser.add_argument("--no-word", action="store_true",
                              help="Deactivate word embeddings")

    # Semisupervision with additional corpus
    train_parser.add_argument("--parsed", default=None, help="Path to corpus for pretraining")
    train_parser.add_argument("--parsed-batchsize", type=int, metavar=" ", default=300, help="Batch size for pretraining") # utility: having a different default value
    train_parser.add_argument("--parsed-epochs", type=int, metavar=" ", default=1, help="Number of epochs for pretraining")
    train_parser.add_argument("--parsed-lr", type=float, metavar=" ", default=0.0001, help="Learning rate for pretraining")
    train_parser.add_argument("--load-pretrained", metavar=" ", default=None, help="Path to pretrained model")


    # Options from transformer
    SentenceEncoderLSTM.add_cmd_options(train_parser)
    TransformerNetwork.add_cmd_options(train_parser)

    # test corpus
    eval_parser.add_argument("model", help="Pytorch model")
    eval_parser.add_argument("corpus", help="Test corpus, 1 tokenized sentence per line")
    eval_parser.add_argument("output", help="Outputfile")
    eval_parser.add_argument("--gold", default=None, help="Gold corpus (discbracket). If provided, eval with discodop")
    eval_parser.add_argument("--eval-batchsize", type=int, default=250, help="Batch size for eval")

    # test options
    eval_parser.add_argument("--gpu", type=int, default=None, help="Use GPU <int> if available")
    eval_parser.add_argument("-v", type=int, default=1, choices=[0,1], help="Verbosity level")
    eval_parser.add_argument("-t", type=int, default=1, help="Number of threads for torch cpu")



    args = parser.parse_args()
    if hasattr(args, 'ft') and args.ft is not None:
        args.w = 300

    for k, v in vars(args).items():
        print(k, v)
    
    torch.set_num_threads(args.t)

    logger = logging.getLogger()
    logger.info("Mode={}".format(args.mode))

    if args.mode == "train":
        os.makedirs(args.model, exist_ok = True)

    if args.gpu is not None:
        os.environ["CUDA_VISIBLE_DEVICES"]=str(args.gpu)

    use_cuda = torch.cuda.is_available()
    if use_cuda and args.gpu is not None:
        logger.info("Using gpu {}".format(args.gpu))
        device = torch.device("cuda")  # CUDA_VISIBLE_DEVICES already restricts visibility to args.gpu
    else:
        logger.info("Using cpu")
        device = torch.device("cpu")
    
    SEED = 0
    if args.mode == "train":
        SEED = args.s
    random.seed(SEED)
    np.random.seed(SEED)
    torch.manual_seed(SEED)

    main(args, logger, device)
