https://gitlab.com/mcoavoux/mtgpy-release-findings-2021.git
Tip revision: c9972219cd75049269d26632d2bb79619d661298 authored by mcoavoux on 20 May 2021, 13:04:44 UTC
up readme
up readme
Tip revision: c997221
extract_discontinuous_trees.py
from nltk import Tree as nTree
import glob
class Tree:
    """Tree class with different types of annotations.

    A leaf is built from a string of the form ``"<index>=<token>"``; an
    internal node is built from any object that exposes ``label()`` and
    iterates over its children (e.g. an nltk tree).

    Attributes set on every node:
        label      -- token (leaf) or constituent label (internal node)
        children   -- list of child Tree objects, sorted by left_index
        left_index -- smallest terminal index covered by the node
        span       -- set of all terminal indexes covered by the node
    """

    def __init__(self, t):
        """
        t -- nltk tree, or a "<index>=<token>" string for a leaf

        Raises ValueError if a leaf index is not an integer or if an
        internal node has no children.
        """
        if isinstance(t, str):
            # Split on the first "=" only: the token itself may contain "=".
            index, label = t.split("=", 1)
            self.label = label
            self.children = []
            self.left_index = int(index)
            self.span = {self.left_index}
        else:
            self.label = t.label()
            # sorted() is stable, like the in-place list.sort() it replaces.
            self.children = sorted((Tree(c) for c in t),
                                   key=lambda child: child.left_index)
            if not self.children:
                raise ValueError(
                    f"Internal node {self.label!r} has no children")
            # Children are sorted, so the first one holds the minimum index.
            self.left_index = self.children[0].left_index
            self.span = set().union(*(c.span for c in self.children))

    def is_leaf(self):
        """Return True if the node has no children."""
        return not self.children

    def is_preterminal(self):
        """Return True if the node has exactly one child and it is a leaf."""
        return len(self.children) == 1 and self.children[0].is_leaf()

    def get_frontier(self, lst):
        """
        Update recursively lst to contain the labels of all terminals in
        the tree, in depth-first order (children are visited by increasing
        left_index, so for a discontinuous tree this is not necessarily
        the sentence order).
        """
        if self.is_leaf():
            lst.append(self.label)
            return
        for c in self.children:
            c.get_frontier(lst)

    def get_list_of_terminals(self, lst):
        """Alias of get_frontier, kept for callers using this name."""
        self.get_frontier(lst)

    def is_discontinuous(self):
        """Return True if this node or any descendant has a span with a gap."""
        if self.is_leaf():
            return False
        # A set of distinct integers is a contiguous range exactly when
        # max - min + 1 equals its size.
        if max(self.span) - min(self.span) + 1 != len(self.span):
            return True
        return any(c.is_discontinuous() for c in self.children)

    def __str__(self):
        """Return the bracketed form, with leaves written "<index>=<token>"."""
        if self.is_leaf():
            return f"{self.left_index}={self.label}"
        return f"({self.label} {' '.join(str(c) for c in self.children)})"
def get_tokens(tree):
    """Return a new list filled by ``tree.get_frontier``.

    tree -- object exposing ``get_frontier(lst)``, which appends the
            terminal labels to the list it is given
    """
    tokens = []
    tree.get_frontier(tokens)
    return tokens
if __name__ == "__main__":
    # Command-line entry point: scan <folder>/*/*.discbracket, keep the
    # discontinuous trees whose length is within bounds, and write them
    # one per line to <outfile>.
    import argparse

    usage = """Extract discontinuous trees from a corpus"""
    parser = argparse.ArgumentParser(description=usage, formatter_class=argparse.RawTextHelpFormatter)
    parser.add_argument("folder", type=str, help="Path to folder containing .discbracket files")
    parser.add_argument("outfile", type=str, help="Print discontinuous tree in file")
    parser.add_argument("--max-length", default=100, type=int, help="Max length of kept sentences")
    parser.add_argument("--min-length", default=6, type=int, help="Min length of kept sentences")
    args = parser.parse_args()

    filenames = glob.glob(f"{args.folder}/*/*.discbracket")
    printed_tree = 0

    with open(args.outfile, "w", encoding="utf8") as outstream:
        for i, filename in enumerate(filenames):
            # Parentheses are required: `i+1 % 100` parses as `i + (1 % 100)`,
            # which is never 0, so the progress line was never printed.
            if (i + 1) % 100 == 0:
                print(f"Processing {i+1} out of {len(filenames)} files")

            # Parse the whole file first; the context manager closes the
            # input file even if a line fails to parse.
            with open(filename, encoding="utf8") as instream:
                ctrees = [nTree.fromstring(line.strip()) for line in instream]

            ftrees = []
            for t in ctrees:
                try:
                    ftrees.append(Tree(t))
                except Exception:
                    # Best effort: skip trees that cannot be converted, but
                    # let KeyboardInterrupt / SystemExit propagate.
                    print("Error while reading tree, ignore")

            for tree in ftrees:
                if not tree.is_discontinuous():
                    continue
                # Only collect tokens for trees that passed the first filter.
                n_tokens = len(get_tokens(tree))
                if n_tokens < args.min_length or n_tokens > args.max_length:
                    continue
                outstream.write(f"{str(tree)}\n")
                printed_tree += 1
                if printed_tree % 100 == 0:
                    print(f"Printed {printed_tree} trees so far")