https://gitlab.com/mcoavoux/mtgpy-release-findings-2021.git
Tip revision: c9972219cd75049269d26632d2bb79619d661298 authored by mcoavoux on 20 May 2021, 13:04:44 UTC
up readme
up readme
Tip revision: c997221
state_gap.py
import tree as T
def get_item_sets(l):
    """Return one index set per item of ``l``.

    Each item is an iterable of nodes exposing ``get_span()``; the set
    produced for an item is the union of the spans of all its nodes (an
    empty item yields an empty set).
    """
    def _union_of_spans(item):
        acc = set()
        for node in item:
            acc |= node.get_span()
        return acc

    return [_union_of_spans(item) for item in l]
class State:
    """Parser configuration for a gap-based transition system.

    A configuration holds:
      * ``stack``: list of items, where an item is a list of tree nodes;
      * ``queue``: list of items; ``shift`` and ``combine`` leave exactly one
        item in it, while ``gap`` prepends the top stack item to it;
      * ``buffer``: the input tokens, ``i`` being the index of the next token
        to shift;
      * ``j``: step counter whose parity tells whether the next action is
        structural (even) or a labelling action (odd).

    ``shift`` and ``combine`` advance ``j`` and are therefore followed by a
    labelling step (``labelX`` or ``nolabel``); ``gap`` does not advance
    ``j``, so it is followed by another structural step.
    """
    STRUCT, LABEL = 0, 1
    SHIFT, COMBINE, GAP = 0, 1, 2
    mapping = {"shift": SHIFT, "combine": COMBINE, "gap": GAP}
    # Score written into an output vector to rule out an illegal action.
    _FORBIDDEN = -10**10

    def __init__(self, tokens, discontinuous, sent_id=None):
        """Build the initial configuration.

        tokens: sequence of nodes to parse (stored as the buffer).
        discontinuous: if false, the ``gap`` action is never allowed.
        sent_id: optional identifier kept for the caller's convenience.
        """
        self.sent_id = sent_id
        self.discontinuous = discontinuous
        self.stack = []
        self.queue = []
        self.j = 0
        self.i = 0
        self.buffer = tokens

    @staticmethod
    def _item_spans(items):
        """Return, for each item, the union of the spans of its nodes."""
        spans = []
        for item in items:
            span = set()
            for node in item:
                span |= node.get_span()
            spans.append(span)
        return spans

    def print(self):
        """Dump the configuration to stdout (debugging helper)."""
        print("Stack", " ".join(f"{s}" for s in self._item_spans(self.stack)))
        print("Queue", " ".join(f"{s}" for s in self._item_spans(self.queue)))
        print(f"i={self.i}")
        print(f"j={self.j}")
        print(f"buffer size = {len(self.buffer)}")

    def next_action_type(self):
        """Return ``State.STRUCT`` or ``State.LABEL`` for the next action."""
        return self.j % 2

    def shift(self):
        """Push the queue onto the stack and move the next token into the queue."""
        self.stack += self.queue
        self.queue = [[self.buffer[self.i]]]
        self.i += 1
        self.j += 1

    def combine(self):
        """Merge the top stack item with the last queue item.

        The merged item becomes the only queue item; any other (gapped)
        queue items are pushed back onto the stack first.
        """
        new_item = self.stack[-1] + self.queue[-1]
        self.stack.pop()
        self.queue.pop()
        self.stack += self.queue
        self.queue = [new_item]
        self.j += 1

    def gap(self):
        """Move the top stack item to the front of the queue.

        ``j`` is deliberately not advanced: a gap is not followed by a
        labelling step.
        """
        self.queue = [self.stack.pop()] + self.queue

    def labelX(self, X):
        """Wrap the single queue item into a new ``T.Tree`` labelled ``X`` (str)."""
        assert len(self.queue) == 1
        self.queue = [[T.Tree(X, self.queue[0])]]
        self.j += 1

    def nolabel(self):
        """Apply the no-label action (only advances the step counter)."""
        self.j += 1

    def is_prefinal(self):
        """Return True if the configuration is final, or would be after a labelling action."""
        if len(self.stack) > 0:
            return False
        if len(self.queue) != 1:
            return False
        return self.i == len(self.buffer)

    def is_final(self):
        """Return True if the configuration is final (a single full tree in the queue)."""
        return self.is_prefinal() and self.next_action_type() == State.STRUCT

    def can_shift(self):
        """Return True if shift is legal: structural step, no pending gap, buffer not empty."""
        if len(self.queue) > 1:
            # A gap is pending: it must be resolved by a combine first.
            return False
        if self.next_action_type() == State.LABEL:
            return False
        return self.i != len(self.buffer)

    def can_combine(self):
        """Return True if combine is legal: structural step with a non-empty stack."""
        return len(self.stack) > 0 and self.next_action_type() == State.STRUCT

    def can_gap(self):
        """Return True if gap is legal.

        Gap is a structural action, so it requires a structural step, a
        discontinuous system, and at least two stack items so that a combine
        remains possible afterwards.
        """
        return (self.discontinuous
                and len(self.stack) > 1
                and self.next_action_type() == State.STRUCT)

    def get_tree(self):
        """Return the predicted tree; the configuration must be final."""
        assert self.is_final()
        return self.queue[0][0]

    def filter_action(self, action_type, output):
        """Mask illegal actions in ``output`` in place.

        For structural steps ``output`` is indexed by ``SHIFT``/``COMBINE``/
        ``GAP``, and every illegal action receives ``_FORBIDDEN``. For labelling
        steps, index 0 is masked when the configuration is prefinal.
        """
        if action_type == State.STRUCT:
            if not self.can_shift():
                output[State.SHIFT] = State._FORBIDDEN
            if not self.can_combine():
                output[State.COMBINE] = State._FORBIDDEN
            if not self.can_gap():
                output[State.GAP] = State._FORBIDDEN
        elif self.is_prefinal():
            # NOTE(review): presumably index 0 is "nolabel", so that the root
            # always receives a label -- confirm against the label vocabulary.
            output[0] = State._FORBIDDEN

    def oracle(self):
        """Return the gold action for the current step and apply it.

        Assumes the configuration is built upon gold tree nodes (exposing
        ``parent``, ``label`` and ``get_span()``). Returns
        ``(action, input_res)``, where ``action`` is one of
        ``("shift", None)``, ``("combine", None)``, ``("gap", None)``,
        ``("label", label)`` or ``("nolabel", "nolabel")``, and ``input_res``
        is the configuration input computed *before* applying the action.
        """
        if self.next_action_type() == State.LABEL:
            input_res = self.get_labelling_step_input()
            self.j += 1
            parent = self.queue[0][0].parent
            current_idxs = set()
            for node in self.queue[0]:
                current_idxs |= node.get_span()
            # Label only once the item covers its parent's whole span.
            if current_idxs == parent.get_span():
                self.queue = [[parent]]
                return ("label", parent.label), input_res
            return ("nolabel", "nolabel"), input_res

        input_res = self.get_structural_step_input()
        if not self.stack:
            self.shift()
            return ("shift", None), input_res
        parent = self.queue[-1][0].parent
        if self.stack[-1][0].parent == parent:
            self.combine()
            return ("combine", None), input_res
        # A sibling is buried deeper in the stack: gap the top item away so
        # that repeated gaps eventually expose the sibling for a combine.
        if any(item[0].parent == parent for item in reversed(self.stack[:-1])):
            self.gap()
            return ("gap", None), input_res
        self.shift()
        return ("shift", None), input_res

    def get_input(self):
        """Return ``(stack_sets, queue_sets, i, n)`` describing the configuration.

        ``stack_sets``/``queue_sets`` hold the spans of the last (at most two)
        stack/queue items. ``i`` is the index of the next buffer token and
        ``n`` the buffer length (``i == n`` once the buffer is exhausted).
        """
        return (self._item_spans(self.stack[-2:]),
                self._item_spans(self.queue[-2:]),
                self.i,
                len(self.buffer))

    def get_structural_step_input(self):
        """Return the input for a structural step (same as ``get_input``)."""
        return self.get_input()

    def get_labelling_step_input(self):
        """Return the input for a labelling step (same as ``get_input``)."""
        return self.get_input()
if __name__ == "__main__":
    # Sanity check of the static oracle: replaying the gold actions on each
    # sample sentence should rebuild the gold tree (prints True per tree).
    import corpus_reader

    corpus = corpus_reader.read_discbracket_corpus("../sample_data/train_sample.discbracket")
    for tree in corpus:
        tree.merge_unaries()

    sentences = [T.get_yield(tree) for tree in corpus]
    trees = []
    for sentence in sentences:
        # discbracket trees can be discontinuous, so enable the gap action.
        state = State(sentence, discontinuous=True)
        while not state.is_final():
            state.oracle()
        trees.append(state.get_tree())

    for gold, predicted in zip(corpus, trees):
        print(str(gold) == str(predicted))