https://gricad-gitlab.univ-grenoble-alpes.fr/coavouxm/flaubertagger.git
word_encoders.py
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
SPECIAL_CHARS = {"-LRB-", "-RRB-", "#RRB#", "#LRB#"}  # warning: duplicated from the corpus reader


class Words2Tensors():
    """Stores each token as a long tensor of character indices."""

    def __init__(self, char2i, word2i, pchar, pword):
        """Expects word2i ordered by decreasing frequency."""
        self.w2tensor = {}
        self.cunk = char2i["<UNK>"]
        self.char2i = char2i
        self.pchar = pchar
        self.pword = pword
        self.wunk = word2i["<UNK>"]
        self.freq_cap = int(0.25 * len(word2i))  # only replace the 75% least frequent tokens
        longest = max(len(w) for w in word2i) + 8
        self.cmaskr = torch.rand(longest)
        # bool: boolean selection of indices (torch.long -> index selection)
        self.cmaski = torch.zeros(longest, dtype=torch.bool)
        longest_sentence = 200000
        self.wmaskr = torch.rand(longest_sentence)
        # bool: boolean selection of indices (torch.long -> index selection)
        self.wmaski = torch.zeros(longest_sentence, dtype=torch.bool)
        self.words2i = word2i
        self.initialize()

    def initialize(self):
        for tok in self.words2i:
            self.add(tok, init=True)

    def add(self, tok, init):
        char2i = self.char2i
        start, stop = [char2i["<START>"]], [char2i["<STOP>"]]
        if tok in SPECIAL_CHARS:
            chars_idx = [start[0], char2i[tok], stop[0]]
        elif tok in {"<SOS>", "<EOS>"}:
            chars_idx = [char2i[tok]]
        else:
            chars_idx = start + [char2i[c] if c in char2i else char2i["<UNK>"] for c in tok] + stop
        if init:
            self.w2tensor[tok] = (
                torch.tensor(chars_idx, dtype=torch.long),
                torch.tensor(chars_idx, dtype=torch.long)
            )
        else:
            device = self.w2tensor["<UNK>"][0].device
            self.w2tensor[tok] = (
                torch.tensor(chars_idx, dtype=torch.long, device=device),
                torch.tensor(chars_idx, dtype=torch.long, device=device)
            )

    def get(self, words, training):
        for w in words:
            if w not in self.w2tensor:
                self.add(w, init=False)
        if not training or self.pchar is None:
            return [self.w2tensor[w][0] for w in words]
        tensors = [self.w2tensor[w] for w in words]
        for c1, c2 in tensors:
            c2.copy_(c1)
            n = len(c2)
            self.cmaskr[:n].uniform_(0, 1)
            self.cmaski[:n].copy_(self.cmaskr[:n] > (1 - self.pchar))
            # do not replace the <START> and <STOP> symbols
            self.cmaski[0] = False
            self.cmaski[n - 1] = False
            c2[self.cmaski[:n]] = self.cunk
        return [t for _, t in tensors]

    def get_word_idxes(self, words, training):
        res = []
        for word in words:
            if word in self.words2i:
                res.append(self.words2i[word])
            else:
                res.append(self.wunk)
        tensor = torch.tensor(res)
        n = len(tensor)
        if training and self.pword > 0:
            rarest_token = tensor > self.freq_cap
            self.wmaskr[:n].uniform_(0, 1)
            self.wmaski[:n].copy_(self.wmaskr[:n] > (1 - self.pword))
            tensor[self.wmaski[:n] & rarest_token] = self.wunk
        return tensor

    def to(self, device):
        self.cmaskr = self.cmaskr.to(device)
        self.cmaski = self.cmaski.to(device)
        self.w2tensor = {k: (v0.to(device), v1.to(device)) for k, (v0, v1) in self.w2tensor.items()}
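

# Usage sketch (not part of the original module; assumes char2i/word2i
# dictionaries built as in the __main__ block below). Words2Tensors maps each
# token to a character-index tensor framed by <START>/<STOP>:
#
#     w2t = Words2Tensors(char2i, word2i, pchar=0.2, pword=0.2)
#     w2t.get(["House", "Committee"], training=False)
#     # -> [tensor([start, H, o, u, s, e, stop]), tensor([...])]
#
# With training=True, roughly a fraction pchar of the inner characters of each
# token is replaced by the <UNK> character index (a form of character dropout),
# and get_word_idxes similarly replaces rare word indices with probability pword.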


class CharacterLstmLayer(nn.Module):

    def __init__(self, emb_dim, voc_size, out_dim, words2tensors=None, dropout=0.2, embed_init=0.1):
        """
        Args:
            emb_dim: dimension of the input character embeddings
            voc_size: size of the character vocabulary (index 0 is padding)
            out_dim: dimension of the bi-LSTM output (each direction is out_dim // 2)
        """
        super(CharacterLstmLayer, self).__init__()
        self.words2tensors = words2tensors
        self.emb_dim = emb_dim
        self.out_dim = out_dim
        self.voc_size = voc_size
        self.embeddings = nn.Embedding(voc_size, emb_dim, padding_idx=0)
        self.lstm = nn.LSTM(emb_dim, out_dim // 2, num_layers=1, bidirectional=True, batch_first=True)
        if dropout > 0:
            self.dropout = nn.Dropout(dropout)
        else:
            self.dropout = None
        self.initialize_parameters(embed_init)

    def initialize_parameters(self, embed_init):
        torch.nn.init.uniform_(self.embeddings.weight.data, -embed_init, embed_init)
        self.embeddings.weight.data[0].fill_(0)

    def forward(self, input):
        """
        Args:
            input: list of torch.long tensors, OR
                a list of tokens (str) if self.words2tensors is not None
        Returns:
            res: tensor of size (batch, out_dim)
        """
        if self.words2tensors is not None:
            input = self.words2tensors.get(input, training=self.training)
        # PyTorch RNN batches need to be sorted by decreasing length:
        order, sorted_by_length = zip(*sorted(enumerate(input), key=lambda x: len(x[1]), reverse=True))
        lengths = [len(i) for i in sorted_by_length]
        padded_char_seqs = torch.nn.utils.rnn.pad_sequence(sorted_by_length, batch_first=True)
        padded_char_seqs_embeddings = self.embeddings(padded_char_seqs)
        if self.dropout is not None:
            padded_char_seqs_embeddings = self.dropout(padded_char_seqs_embeddings)
        packed_padded_char_seqs = torch.nn.utils.rnn.pack_padded_sequence(
            padded_char_seqs_embeddings, lengths, batch_first=True)
        _, (hn_xdir_bat_xdim, _) = self.lstm(packed_padded_char_seqs)
        # hn is (num directions, batch, out_dim // 2): concatenate both directions
        lstm_output = torch.cat([hn_xdir_bat_xdim[0], hn_xdir_bat_xdim[1]], 1)
        # restore the original token order
        rev, _ = zip(*sorted(enumerate(order), key=lambda x: x[1]))
        res = torch.embedding(lstm_output, torch.tensor(rev, dtype=torch.long).to(lstm_output.device))
        return res
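

# Shape sketch (values borrowed from the commented-out example in the __main__
# block below, not from the original file): with out_dim=20, a batch of k tokens
# yields one character-based embedding per token, i.e. a (k, 20) tensor:
#
#     char_lstm = CharacterLstmLayer(30, len(i2char), 20, words2tensors=words2tensors)
#     vectors = char_lstm(["House", "Ways", "and", "Means"])   # vectors.shape == (4, 20)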


class WordEmbedder(nn.Module):

    def __init__(self, args, num_words, num_chars, word2tensors):
        """
        Args:
            args.w: word embedding dimension
            args.c: character embedding dimension
            args.C: dimension of the character-based word embeddings
            args.I: half-width of the uniform embedding initialization
            args.no_char: bool, deactivate character-based embeddings
            args.no_word: bool, deactivate word embeddings
        """
        super(WordEmbedder, self).__init__()
        self.no_char = args.no_char
        self.no_word = args.no_word
        self.output_dim = 0
        if not self.no_char:
            self.output_dim += args.C
        else:
            args.c = 2
            args.C = 2
        if not self.no_word:
            self.output_dim += args.w
        else:
            args.w = 2
        self.word_embeddings = nn.Embedding(num_words, args.w, padding_idx=0)
        self.char_encoder = CharacterLstmLayer(
            emb_dim=args.c,
            voc_size=num_chars,
            out_dim=args.C,
            embed_init=args.I,
            words2tensors=word2tensors,
            dropout=0)  # characters are already dropped via pchar in Words2Tensors
        self.word2tensors = word2tensors
        self.initialize_parameters(args)

    def initialize_parameters(self, args):
        torch.nn.init.uniform_(self.word_embeddings.weight.data, -args.I, args.I)
        self.word_embeddings.weight.data[0].fill_(0)

    def forward(self, sentences):
        embeddings = []
        lengths = [len(s) for s in sentences]
        all_tokens = [tok for s in sentences for tok in s]
        if not self.no_char:
            # character-based embeddings
            char_based_embeddings = self.char_encoder(all_tokens)
            char_based_embeddings = char_based_embeddings.split(lengths)
            embeddings.append(char_based_embeddings)
        if not self.no_word:
            # word embeddings
            sentences_idxs = [
                self.word2tensors.get_word_idxes(sent, training=self.training).to(self.word_embeddings.weight.device)
                for sent in sentences]
            word_embeddings = [self.word_embeddings(s) for s in sentences_idxs]
            embeddings.append(word_embeddings)
        if len(embeddings) == 1:
            return embeddings[0], lengths
        result = []
        for sentence_embeddings in zip(*embeddings):
            result.append(torch.cat(sentence_embeddings, dim=1))
        return result, lengths
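

# Output sketch: forward() returns (embeddings, lengths), where embeddings holds
# one (sentence_length, output_dim) tensor per input sentence, with
# output_dim = args.C + args.w when both encoders are active (only one of the
# two terms otherwise).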


if __name__ == "__main__":
    from collections import defaultdict
    import argparse

    sentences = ["Influential members of the House Ways and Means Committee introduced legislation that would restrict how the new savings-and-loan bailout agency can raise capital , creating another potential obstacle to the government 's sale of sick thrifts .",
                 "The bill , whose backers include Chairman Dan Rostenkowski -LRB- D. , Ill. -RRB- , would prevent the Resolution Trust Corp. from raising temporary working capital by having an RTC-owned bank or thrift issue debt that would n't be counted on the federal budget ."]
    sentences = [sent.split() for sent in sentences]
    voc = []
    chars = []
    for sentence in sentences:
        for token in sentence:
            voc.append(token)
            for char in token:
                chars.append(char)
    i2word = ["<PAD>", "<UNK>"] + sorted(set(voc))
    i2char = ["<PAD>", "<UNK>", "<START>", "<STOP>", "-LRB-", "-RRB-"] + sorted(set(chars))
    char2i = {c: i for i, c in enumerate(i2char)}
    word2i = {w: i for i, w in enumerate(i2word)}
    words2tensors = Words2Tensors(char2i, word2i, 0.2, 0.2)
    # print(words2tensors.w2tensor["members"])

    # char_lstm = CharacterLstmLayer(30, len(i2char), 20, words2tensors=words2tensors)
    # embed_sentence = char_lstm(sentences[0])
    # print(embed_sentence.shape)

    args = argparse.Namespace()
    args.c = 10
    args.C = 12
    args.I = 0.01
    args.w = 15
    args.no_char = False
    args.no_word = False
    embedder = WordEmbedder(args, len(i2word), len(i2char), words2tensors)
    output = embedder(sentences)
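
    # Sanity check (sketch, not in the original script): each sentence should be
    # embedded as a (sentence_length, args.C + args.w) tensor, i.e. (n, 27) here.
    embeddings, lengths = output
    for emb, n in zip(embeddings, lengths):
        print(emb.shape, n)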