Skip to main content
  • Home
  • Development
  • Documentation
  • Donate
  • Operational login
  • Browse the archive

swh logo
SoftwareHeritage
Software
Heritage
Archive
Features
  • Search

  • Downloads

  • Save code now

  • Add forge now

  • Help

https://gricad-gitlab.univ-grenoble-alpes.fr/coavouxm/flaubertagger.git
09 April 2024, 05:01:41 UTC
  • Code
  • Branches (1)
  • Releases (0)
  • Visits
    • Branches
    • Releases
    • HEAD
    • refs/heads/master
    • c939ca9fac094ac3c379256ef3d3d4d14a5a4bf1
    No releases to show
  • 388777f
  • /
  • src
  • /
  • tagger.py
Raw File Download
Take a new snapshot of a software origin

If the archived software origin currently browsed is not synchronized with its upstream version (for instance when new commits have been issued), you can explicitly request Software Heritage to take a new snapshot of it.

Use the form below to proceed. Once a request has been submitted and accepted, it will be processed as soon as possible. You can then check its processing state by visiting this dedicated page.
swh spinner

Processing "take a new snapshot" request ...

Permalinks

To reference or cite the objects present in the Software Heritage archive, permalinks based on SoftWare Hash IDentifiers (SWHIDs) must be used.
Select below a type of object currently browsed in order to display its associated SWHID and permalink.

  • content
  • directory
  • revision
  • snapshot
origin badgecontent badge Iframe embedding
swh:1:cnt:8d751da2e60b252a7ddf9a8c64052806690048a7
origin badgedirectory badge Iframe embedding
swh:1:dir:c69bc28a718a9d9be4f56505734cbb90b23d2ead
origin badgerevision badge
swh:1:rev:c939ca9fac094ac3c379256ef3d3d4d14a5a4bf1
origin badgesnapshot badge
swh:1:snp:82db0a85833d76805a36c23f8377236abee4ebea
Citations

This interface makes it possible to generate software citations, provided that the root directory of the browsed objects contains a citation.cff or codemeta.json file.
Select below a type of object currently browsed in order to generate citations for them.

  • content
  • directory
  • revision
  • snapshot
Generate software citation in BibTex format (requires biblatex-software package)
Generating citation ...
Generate software citation in BibTex format (requires biblatex-software package)
Generating citation ...
Generate software citation in BibTex format (requires biblatex-software package)
Generating citation ...
Generate software citation in BibTex format (requires biblatex-software package)
Generating citation ...
Tip revision: c939ca9fac094ac3c379256ef3d3d4d14a5a4bf1 authored by m on 23 February 2024, 16:44:50 UTC
up
Tip revision: c939ca9
tagger.py
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import logging
import pickle
import sys
import random
import os

from collections import defaultdict
from sentence_encoders import HierarchicalLSTM
from bert import BertEncoder
from corpus_reader import load_corpus, DepTree



def accuracy(gold, pred, header):
    """Compute per-column and exact-match token accuracies, in percent.

    Args:
        gold: list of sentences; each sentence is a sequence of tokens and
            each token is a sequence holding one tag per entry of ``header``.
        pred: predicted tags, with the same nesting and lengths as ``gold``.
        header: names of the tag columns; they become the keys of the result.

    Returns:
        A dict mapping each header name to its accuracy formatted as a
        string with one decimal (e.g. ``"97.3"``), plus ``"exact_match"``,
        a float rounded to one decimal giving the share of tokens whose
        complete tag sequence compares equal to the gold one.
        When there are no tokens at all every score is reported as zero
        instead of dividing by zero.

    Raises:
        AssertionError: if ``gold`` and ``pred`` do not line up, or a token
            does not carry exactly ``len(header)`` tags.
    """
    correct = [0] * len(header)
    exact_match = 0
    tokens = 0
    assert len(gold) == len(pred)
    for g_sent, p_sent in zip(gold, pred):
        assert len(g_sent) == len(p_sent)
        tokens += len(g_sent)

        for g_tok, p_tok in zip(g_sent, p_sent):
            assert len(g_tok) == len(p_tok)
            assert len(g_tok) == len(header)
            # Whole-token comparison: the two containers must compare equal,
            # so a list never matches a tuple even with identical items.
            if g_tok == p_tok:
                exact_match += 1

            for j, (g_tag, p_tag) in enumerate(zip(g_tok, p_tok)):
                if g_tag == p_tag:
                    correct[j] += 1

    def percent(count):
        # An empty corpus has no meaningful accuracy; report 0 rather than crash.
        return count / tokens * 100 if tokens else 0.0

    output_dict = {f"{h}": f"{percent(c):.1f}" for h, c in zip(header, correct)}
    output_dict["exact_match"] = round(percent(exact_match), 1)
    return output_dict


class FlaubertTagger(nn.Module):
    """Multi-output token tagger: one sentence encoder, one linear head per tag type."""

    def __init__(self, output_tags, char2i, word2i, args):
        """Build the encoder and one classifier per tag type.

        Args:
            output_tags: one collection of tags per tag type; its length
                fixes the number of heads and ``len()`` of each entry fixes
                the number of classes of the matching head.
            char2i: character vocabulary, passed to ``HierarchicalLSTM``
                (only used when ``args.lstm`` is set).
            word2i: word vocabulary, passed to ``HierarchicalLSTM``
                (only used when ``args.lstm`` is set).
            args: options namespace; this class reads ``lstm``, ``bert_id``,
                ``cpos`` and ``D`` (dropout probability).
        """
        super().__init__()
        self.args = args
        if args.lstm:
            self.encoder = HierarchicalLSTM(args, char2i, word2i)
        else:
            self.encoder = BertEncoder(args.bert_id)

        self.cpos = args.cpos
        self.output_tags = output_tags

        # One dropout + linear projection per tag type, all fed with the same
        # encoder output.
        self.classifiers = nn.ModuleList([
            nn.Sequential(nn.Dropout(args.D),
                          nn.Linear(self.encoder.dim, len(i2tag)))
            for i2tag in self.output_tags])

        self.loss = nn.CrossEntropyLoss()

        # NOTE: the encoder is always fine-tuned. The original code carried a
        # disabled (`if False`) branch that would have frozen its parameters.

    def forward(self, sentences, labels_all=None):
        """Tag a batch of sentences and optionally compute loss and accuracy.

        Args:
            sentences: batch of tokenized sentences; ``len(sentence)`` must
                equal the number of rows the encoder produces for it.
            labels_all: optional gold labels indexed ``[sentence][tag_type]``,
                each entry being a tensor of class indices for one sentence.

        Returns:
            A dict with ``"predictions"`` (indexed ``[sentence][tag_type]``,
            each a numpy array of class indices). When ``labels_all`` is
            given it also contains ``"loss"`` (sum of the per-head losses, a
            tensor), ``"all_losses"`` (list of floats, one per head) and
            ``"acc"`` (``{"tokens": int, "correct": ndarray}`` with one
            count per tag type).
        """
        contextualized_embeddings = self.encoder(sentences, batch=True)
        # Stack all tokens of the batch so every head runs once over the batch.
        linearized = torch.cat(contextualized_embeddings)
        logits_all = [classifier(linearized) for classifier in self.classifiers]

        sentence_lengths = [len(sent) for sent in sentences]
        predictions_all = [torch.argmax(logits, dim=1) for logits in logits_all]
        # [tag_type][sentence_id]
        tags_idx = [[chunk.cpu().numpy() for chunk in predictions.split(sentence_lengths)]
                    for predictions in predictions_all]

        # [sentence_id][tag_type]
        tags_idx = list(zip(*tags_idx))
        output_dict = {"predictions": tags_idx}

        if labels_all is not None:
            tokens = 0
            correct = np.zeros(len(labels_all[0]), dtype=int)
            for gold_tags, pred_tags in zip(labels_all, tags_idx):
                for i in range(len(correct)):
                    correct[i] += np.sum(gold_tags[i].cpu().numpy() == pred_tags[i])
                # Every tag type labels the same tokens, so any column gives
                # the sentence length (the original relied on the leaked
                # loop index for this).
                tokens += len(gold_tags[0])

            # Regroup the labels as [tag_type][sentence_id] to match logits_all.
            labels_all = list(zip(*labels_all))
            losses = [self.loss(logits, torch.cat(labels))
                      for logits, labels in zip(logits_all, labels_all)]
            loss = sum(losses)

            output_dict["all_losses"] = [l.item() for l in losses]
            output_dict["loss"] = loss
            output_dict["acc"] = {'tokens': tokens, "correct": correct}

        return output_dict

    def predict_sentences(self, logger, corpus, batch_size, i2tags):
        """Tag ``corpus`` batch by batch and return tag strings.

        Switches the module to eval mode (and leaves it there).

        Args:
            logger: logger used for progress messages (one every 100 batches).
            corpus: list of tokenized sentences.
            batch_size: number of sentences per forward pass.
            i2tags: one index-to-tag lookup per tag type.

        Returns:
            A list indexed ``[sentence][token]`` whose entries are lists
            holding one tag per tag type.
        """
        self.eval()
        result = []
        # No gradients are needed for inference; skipping them avoids
        # building autograd graphs for every batch.
        with torch.no_grad():
            for k, i in enumerate(range(0, len(corpus), batch_size), start=1):
                if k % 100 == 0:
                    logger.info(f"Tagging batch {k} over {len(corpus)//batch_size}: sentences {i} to {i+batch_size} over {len(corpus)}")
                sentences = corpus[i:i+batch_size]
                output_dict = self(sentences)
                result.extend(output_dict["predictions"])

        output = []
        for sentence_tags in result:
            # sentence_tags is [tag_type][token]; transpose to iterate tokens.
            sentence = [[i2t[t] for i2t, t in zip(i2tags, tags)]
                        for tags in zip(*sentence_tags)]
            output.append(sentence)
        return output

    def eval_on_corpus(self, trees, batch_size, i2tags, header, logger=None):
        """Tag the sentences of ``trees`` and score them against their gold tags.

        Args:
            trees: objects exposing ``get_training_example(cpos)`` that
                returns a dict with ``"tokens"`` and ``"all_tags"``.
            batch_size: number of sentences per forward pass.
            i2tags: one index-to-tag lookup per tag type.
            header: names of the tag columns, forwarded to ``accuracy``.
            logger: optional logger for progress messages; defaults to the
                root logger.

        Returns:
            The dict produced by ``accuracy``.
        """
        if logger is None:
            logger = logging.getLogger()

        examples = [tree.get_training_example(self.cpos) for tree in trees]
        sentences = [ex["tokens"] for ex in examples]
        tags_seqs = [ex["all_tags"] for ex in examples]

        # predict_sentences requires the logger as its first argument; the
        # original call omitted it and raised TypeError.
        predictions = self.predict_sentences(logger, sentences, batch_size, i2tags)

        eval_dict = accuracy(tags_seqs, predictions, header)
        return eval_dict

    def ids_to(self, device):
        """Move the LSTM embedder's ``word2tensors`` to ``device`` (LSTM encoder only)."""
        if self.args.lstm:
            # NOTE(review): this only has an effect if `word2tensors.to` works
            # in place (e.g. an nn.Module); on a plain tensor `.to` returns a
            # copy that is discarded here — confirm against sentence_encoders.
            self.encoder.embedder.word2tensors.to(device)


def prepare_batch(args, trees, tags2i, device=None):
    """Turn a batch of trees into token lists and gold-label tensors.

    Args:
        args: options namespace; only ``args.cpos`` is read and forwarded to
            ``tree.get_training_example``.
        trees: objects exposing ``get_training_example(cpos)`` that returns a
            dict with ``"tokens"`` and ``"all_tags"`` (one tag sequence per
            token, one tag per tag type).
        tags2i: one tag-to-index mapping per tag type.
        device: device on which the label tensors are created. When omitted,
            the module-level ``device`` set by the script entry point is used
            if it exists, otherwise torch's default device.

    Returns:
        ``(sentences, tags_idx)`` where ``sentences`` is the list of token
        lists and ``tags_idx`` is indexed ``[sentence][tag_type]``, each entry
        being a 1-D tensor of label indices.
    """
    if device is None:
        # The original body read a global `device` that is only defined when
        # the file runs as a script; importing the module and calling this
        # function raised NameError.
        device = globals().get("device")

    training_examples = [tree.get_training_example(args.cpos) for tree in trees]
    sentences = [ex["tokens"] for ex in training_examples]

    tags_idx = []
    for ex in training_examples:
        # Transpose [token][tag_type] into [tag_type][token].
        columns = zip(*ex["all_tags"])
        tags_idx.append([
            torch.tensor([tags2i[i][t] for t in column], device=device)
            for i, column in enumerate(columns)
        ])
    return sentences, tags_idx

def sample_batches(aux_data, batch_size, sample_size):
    """Draw a random sample of ``aux_data`` without replacement and batch it.

    Args:
        aux_data: sequence of auxiliary training items (may be empty).
        batch_size: maximum number of items per returned batch.
        sample_size: number of items to draw; capped at ``len(aux_data)``
            because sampling is done without replacement.

    Returns:
        A list of batches (lists of items). Empty when ``aux_data`` is empty.
    """
    if len(aux_data) == 0:
        return []
    # Without replacement numpy refuses to draw more items than available
    # (ValueError), so never ask for more than the population size.
    sample_size = min(sample_size, len(aux_data))
    # Sample indices rather than the items themselves: this keeps the items
    # untouched instead of letting numpy coerce the list into an ndarray.
    chosen = np.random.choice(len(aux_data), sample_size, replace=False)
    sample = [aux_data[int(idx)] for idx in chosen]
    return [sample[i:i+batch_size] for i in range(0, len(sample), batch_size)]

def main_train(args, logger, device):
    """Train a ``FlaubertTagger`` and keep the checkpoint with the best dev POS score.

    Writes three files in the ``args.model`` directory: ``data`` (pickled
    vocabularies, tag sets and training args), ``model`` (state dict of the
    best epoch) and ``best_dev_score`` (best dev score and its epoch).

    Args:
        args: training options; this function reads ``model``, ``l``
            (learning rate), ``B`` (batch size), ``i`` (epochs) and
            ``eval_batchsize``, and forwards ``args`` to ``load_corpus``,
            ``FlaubertTagger`` and ``prepare_batch``.
        logger: logger used for per-epoch and per-batch progress messages.
        device: torch device the model is moved to for training.
    """
    corpus = load_corpus(args)

    # The test split is loaded but not used during training.
    train, dev, _test = corpus["corpus"]
    aux_data = corpus["aux_data"]
    i2tags = corpus["i2tags"]
    tags2i = corpus["tags2i"]
    stats = corpus["stats"]
    header = corpus["header"]

    char2i = corpus["char2i"]
    word2i = corpus["word2i"]

    # Everything main_eval needs to rebuild the model besides the weights.
    with open(f"{args.model}/data", "wb") as fp:
        pickle.dump([char2i, word2i, i2tags, args], fp)

    # Print the tag distribution of every tag type, most frequent first.
    for tagtype, tagset in zip(header, stats):
        print(tagtype, len(tagset))
        for k, v in sorted(tagset.items(), key=lambda x: x[1], reverse=True):
            print(k, v)
        print()

    flautag = FlaubertTagger(i2tags, char2i, word2i, args)
    flautag.to(device)
    flautag.ids_to(device)

    optimizer = optim.Adam(flautag.parameters(), lr=args.l)

    batch_size = args.B

    best_dev = 0
    best_epoch = 0
    for epoch in range(args.i):
        epoch_loss = 0
        flautag.train()

        batches_train = [train[i:i+batch_size] for i in range(0, len(train), batch_size)]
        # Auxiliary batches are five times larger and re-sampled every epoch.
        batches_aux = sample_batches(aux_data, batch_size=5*batch_size, sample_size=2*len(train))
        all_batches = batches_train + batches_aux
        random.shuffle(all_batches)
        logger.info(f"Epoch {epoch}, train batches: {len(batches_train)}, aux batches: {len(batches_aux)}")
        for i, trees in enumerate(all_batches):
            optimizer.zero_grad()

            sentences, tags_idx = prepare_batch(args, trees, tags2i)
            output_dict = flautag(sentences, tags_idx)

            loss = output_dict["loss"]
            losses = output_dict["all_losses"]
            np_loss = loss.item()
            epoch_loss += np_loss
            loss.backward()
            optimizer.step()
            losses_str = " ".join([f"{l:.3f}" for l in losses])
            acc_str = output_dict["acc"]["correct"] / output_dict["acc"]["tokens"] * 100
            acc_str = " ".join([f"{x:.2f}" for x in acc_str])
            # NOTE: the sentence range in this message assumes train-sized
            # batches, so it is only indicative once aux batches are mixed in.
            logger.info(f"Batch {i} sentences {i*batch_size} to {(i+1)*batch_size-1} of {len(batches_train)} loss {np_loss:.4f}: {losses_str} acc {acc_str}")

        eval_dict = flautag.eval_on_corpus(dev, args.eval_batchsize, i2tags, header)
        # Show POS first. The sort key must look at the dict key: the original
        # compared the whole (key, value) pair with "POS", which is always
        # unequal, so the sort was a no-op.
        eval_dict_str = " ".join([f"{k[:10]}={v}" for k, v in sorted(eval_dict.items(), key=lambda kv: kv[0] != "POS")])

        if float(eval_dict["POS"]) > best_dev:
            best_dev = float(eval_dict["POS"])
            best_epoch = epoch

            # Save from the CPU so the checkpoint loads on machines without a GPU.
            flautag.cpu()
            flautag.ids_to(torch.device("cpu"))
            torch.save(flautag.state_dict(), "{}/model".format(args.model))
            with open(f"{args.model}/best_dev_score", "w", encoding="utf8") as ff:
                ff.write(f"{best_dev}\n{best_epoch}")

            flautag.to(device)
            flautag.ids_to(device)

        print(f"Epoch {epoch} loss: {epoch_loss:.4f} dev: {eval_dict_str} best={best_dev} e{best_epoch}", flush=True)

def load_data(modeldir):
    """Return the object unpickled from the ``data`` file inside ``modeldir``."""
    data_path = f"{modeldir}/data"
    # NOTE: unpickling can execute arbitrary code; only load model
    # directories that come from a trusted source.
    with open(data_path, "rb") as handle:
        payload = pickle.load(handle)
    return payload

def predictions_to_conll(sentences, predictions, use_cpos):
    """Wrap tagged sentences into ``DepTree`` objects.

    Args:
        sentences: list of token lists.
        predictions: per sentence, one tag sequence per token; the first tag
            is the part of speech and the remaining ones are features.
        use_cpos: when true the part-of-speech tags are passed to ``DepTree``
            as ``cpos``, otherwise as ``fpos``.

    Returns:
        One ``DepTree`` per (sentence, prediction) pair.
    """
    trees = []
    for tokens, token_tags in zip(sentences, predictions):
        pos_tags = [tags[0] for tags in token_tags]
        # Exactly one of cpos / fpos receives the tags; the other stays None.
        cpos, fpos = (pos_tags, None) if use_cpos else (None, pos_tags)

        features = []
        for tags in token_tags:
            # Keep only the features that are set; "_" marks an empty slot.
            informative = [value for value in tags[1:] if value != "_"]
            features.append("|".join(informative) if informative else "_")

        assert len(tokens) == len(features)
        trees.append(DepTree(tokens=tokens, cpos=cpos, fpos=fpos, features=features))
    return trees

def _read_tokenized_sentences(path):
    """Read one whitespace-tokenized sentence per line, skipping empty ones."""
    sentences = []
    with open(path, encoding="utf8") as f:
        for line in f:
            # Drop BOM / zero-width / directional marks and the "DEL + space"
            # sequence before tokenizing.
            sent = line.strip().replace("\ufeff", "").replace("\u200e", "").replace("\u200b", "").replace("\x7f ","")
            tokens = sent.split()
            # Check the token list, not the string: a line made only of the
            # characters removed above plus spaces would otherwise yield an
            # empty sentence.
            if tokens:
                sentences.append(tokens)
    return sentences


def main_eval(args, logger, device):
    """Tag the corpus in ``args.corpus`` with a trained model and write the result.

    Args:
        args: evaluation options; this function reads ``model`` (directory
            holding the ``data`` and ``model`` files), ``corpus`` (input
            file, one tokenized sentence per line), ``output`` (output path)
            and ``eval_batchsize``.
        logger: logger used for progress messages.
        device: torch device on which tagging runs.
    """
    # The training-time args are stored with the model and drive its construction.
    char2i, word2i, i2tags, args_train = load_data(f"{args.model}")

    flautag = FlaubertTagger(i2tags, char2i, word2i, args_train)

    state_dict = torch.load("{}/model".format(args.model))
    flautag.load_state_dict(state_dict)
    flautag.to(device)
    flautag.ids_to(device)
    flautag.eval()
    logger.info("Model loaded")

    sentences = _read_tokenized_sentences(args.corpus)

    predictions = flautag.predict_sentences(logger, sentences, args.eval_batchsize, i2tags)

    conll_trees = predictions_to_conll(sentences, predictions, args_train.cpos)

    # One tree per block, separated by a blank line.
    with open(f"{args.output}", "w", encoding="utf8") as of:
        for tree in conll_trees:
            of.write(f"{str(tree)}\n\n")


def main(args, logger, device):
    """Train a tagger (mode 'train') or tag a tokenized corpus with a trained model (mode 'eval').

    This docstring is also used as the description of the command-line parser.
    """
    if args.mode == "train":
        try:
            main_train(args, logger, device)
        except KeyboardInterrupt:
            # Ctrl-C during training is a normal way to stop early.
            print("Training interrupted, exiting")
            return
    else:
        main_eval(args, logger, device)


if __name__ == "__main__":
    # Command-line entry point: parse the options, pick the device, seed the
    # random generators and dispatch to main().
    import argparse
    logging.basicConfig(stream=sys.stderr, level=logging.DEBUG, 
            format='%(asctime)s-%(relativeCreated)d-%(levelname)s:%(message)s')


    usage = main.__doc__

    parser = argparse.ArgumentParser(description = usage, formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    
    subparsers = parser.add_subparsers(dest="mode", description="Execution modes", help='train: training, eval: test')
    subparsers.required = True

    train_parser = subparsers.add_parser("train", formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    eval_parser = subparsers.add_parser("eval", formatter_class=argparse.ArgumentDefaultsHelpFormatter)

    # train corpora
    train_parser.add_argument("model", help="Directory (to be created) for model exportation")
    #train_parser.add_argument("train", help="Training corpus")
    #train_parser.add_argument("dev",   help="Dev corpus")
    #train_parser.add_argument("--fmt", help="Format for train and dev corpus", choices=["ctbk", "discbracket", "bracket"], default="discbracket")

    train_parser.add_argument("--corpus", default="ftb_spmrl", choices=["ftb_spmrl", "fqb", "ftb", "gsd", "partut", "pud", "sequoia", "spoken"], 
                help="Dataset. Has train/dev/test: ftb, gsd, partut, sequoia, spoken. Has morphology: ftb, gsd, partut, sequoia.")
    train_parser.add_argument("--aux-data", default=None, help="Path to artificial extra dataset")
    train_parser.add_argument("--cpos", action="store_true", help="Use coarse pos instead of fpos")

    # general options
    train_parser.add_argument("--cuda", type=int, default=None, help="Use GPU if available")
    #train_parser.add_argument("-t", type=int, default=1, help="Number of threads for torch cpu")
    train_parser.add_argument("-S", type=int, default=None, help="Use only X first training examples")
    
    train_parser.add_argument("--bert-id", default="flaubert/flaubert_small_cased", help="Flaubert identifier")
    train_parser.add_argument("--lstm", action="store_true", help="Use hierarchical LSTM instead of Flaubert")

    # training options
    train_parser.add_argument("-i", default=12, type=int, help="Number of epochs")
    train_parser.add_argument("-l", default=0.00001, type=float, help="Learning rate")
    train_parser.add_argument("-B", type=int, default=20, help="Size of batch")
    train_parser.add_argument("-D", type=float, default=0.5, help="Dropout")

    train_parser.add_argument("--eval-batchsize", type=int, default=30, help="Batch size for eval")
    
    #train_parser.add_argument("--start-averaging", type=int, default=40, help="Aadam starts averaging at epoch <int> ")
    train_parser.add_argument("-s", type=int, default=10, help="Random seed")

    # The LSTM encoder registers its own options on the train parser.
    HierarchicalLSTM.add_cmd_options(train_parser)

    # test corpus
    eval_parser.add_argument("model", help="Pytorch model")
    eval_parser.add_argument("corpus", help="Test corpus, 1 tokenized sentence per line")
    eval_parser.add_argument("output", help="Outputfile")
#    eval_parser.add_argument("--gold", default=None, help="Gold corpus (disc)bracket. If provided, eval with discodop")
    eval_parser.add_argument("--eval-batchsize", type=int, default=32, help="Batch size for eval")

    # test options
    eval_parser.add_argument("--cuda", action="store_true", help="Use GPU if available")
#    eval_parser.add_argument("-v", type=int, default=1, choices=[0,1], help="Verbosity level")
#    eval_parser.add_argument("-t", type=int, default=1, help="Number of threads for torch cpu")

    args = parser.parse_args()
    for k, v in vars(args).items():
        print(k, v)
    
    #torch.set_num_threads(args.t)
    logger = logging.getLogger()
    logger.info("Mode={}".format(args.mode))

    if args.mode == "train":
        os.makedirs(args.model, exist_ok = True)

    # `--cuda` is an optional int in train mode (None when absent) and a flag
    # in eval mode (False when absent). Test for presence explicitly: a plain
    # truthiness check treats `--cuda 0` as "no GPU requested".
    # NOTE: the integer value is not used to select a particular GPU.
    cuda_requested = args.cuda is not None and args.cuda is not False
    use_cuda = torch.cuda.is_available()
    if use_cuda and cuda_requested:
        logger.info("Using gpu")
        device = torch.device("cuda")
    else:
        logger.info("Using cpu")
        device = torch.device("cpu")
    logger.info(f"{str(device)}")
    
    # Seed every random generator used by the script; eval mode has no seed
    # option and always uses 0.
    SEED = 0
    if args.mode == "train":
        SEED = args.s
    random.seed(SEED)
    np.random.seed(SEED)
    torch.manual_seed(SEED)

    main(args, logger, device)



Software Heritage — Copyright (C) 2015–2025, The Software Heritage developers. License: GNU AGPLv3+.
The source code of Software Heritage itself is available on our development forge.
The source code files archived by Software Heritage are available under their own copyright and licenses.
Terms of use: Archive access, API— Contact— JavaScript license information— Web API

back to top