https://github.com/mirefek/geo_logic
Tip revision: c8c9b715cae0acfb07585c25659b1333d075cc5b authored by mirefek on 25 July 2021, 19:43:55 UTC
bugfix
bugfix
Tip revision: c8c9b71
uf_dict.py
from collections import defaultdict
from stop_watch import StopWatch
"""
UnionFindDict is a dictionary-like structure for the lookup table of the logical core.
It is a dictionary of the form (label, tuple of objects) -> (tuple of objects)
which can be set using
add(label, input, output)
and retrieved using
get(label, input)
The "add" function cannot overwrite, that is, it raises an exception
if the (label, input) is already in the database. The "get" function
returns a tuple, or None if it is not in the database.
Moreover, the structure allows "gluing" of objects. The function
glue(obj1, obj2)
makes the two objects equal from the perspective of the UnionFindDict.
The "glue" function returns the list of all pairs (a,b) that were glued,
they include the initial (obj1, obj2) and other pairs glued
due to extensionality.
"""
class UnionFindDict:
def __init__(self):
self.data = dict() # the main dictionary
self.obj_to_root_d = dict() # obj -> (representative) obj
self.obj_to_children = defaultdict(set) # inverse of obj_to_root
self.obj_to_keys = defaultdict(set) # obj -> (label, input) such that obj in input or output
def obj_to_root(self, obj):
return self.obj_to_root_d.get(obj, obj)
def tup_to_root(self, tup):
return tuple(map(self.obj_to_root, tup))
def _data_add(self, label, args, vals):
#print("_data_add", label, args, vals)
key = label, args
if key in self.data:
if self.data[key] == vals: return
raise KeyError("key {} is already in the uf_dictionary".format(key))
self.data[key] = vals
for obj in args + vals:
#print(" obj_to_keys[{}] :".format(obj))
#print(" {}".format(self.obj_to_keys[obj]))
self.obj_to_keys[obj].add(key)
#print(" {}".format(self.obj_to_keys[obj]))
def _data_remove(self, label, args):
#print("_data_remove", label, args)
key = label, args
vals = self.data[key]
del self.data[key]
for obj in args + vals:
#print(" obj_to_keys[{}] :".format(obj))
#print(" {}".format(self.obj_to_keys[obj]))
self.obj_to_keys[obj].discard(key)
#print(" {}".format(self.obj_to_keys[obj]))
return vals
def add(self, label, args, vals):
#print('add', label, args, vals)
args, vals = map(self.tup_to_root, (args, vals))
self._data_add(label, args, vals)
return args, vals
def is_equal(self, n1, n2):
n1, n2 = map(self.obj_to_root, (n1, n2))
return n1 == n2
def glue(self, n1, n2):
#print('glue', n1, n2)
result = self.multi_glue((n1, n2))
return result
def multi_glue(self, *pairs):
changed = []
to_glue = list(pairs)
while to_glue:
n1, n2 = to_glue.pop()
n1, n2 = map(self.obj_to_root, (n1, n2))
if n1 == n2: continue
c1, c2 = [
len(self.obj_to_children[n]) + len(self.obj_to_keys[n])
for n in (n1, n2)
]
if c1 < c2: n1, n2 = n2, n1
changed.append((n1, n2))
self.obj_to_root_d[n2] = n1
children1 = self.obj_to_children[n1]
children2 = self.obj_to_children[n2]
for child in children2:
self.obj_to_root_d[child] = n1
children1.update(children2)
children1.add(n2)
children2.clear()
#print("{} : {}".format(n2, self.obj_to_keys[n2]))
for key in tuple(self.obj_to_keys[n2]):
label, args = key
vals = self._data_remove(label, args)
args, vals = map(self.tup_to_root, (args, vals))
ori_val = self.get(label, args)
if ori_val is not None:
assert(len(vals) == len(ori_val))
to_glue.extend(zip(vals, ori_val))
else: self._data_add(label, args, vals)
return changed
def __contains__(self, key):
return self.tup_to_root(key) in self.data
def get(self, label, args): # default = None, otherwise tuple
args = self.tup_to_root(args)
return self.data.get((label, args), None)
if __name__ == "__main__":
d = UnionFindDict()
d.add("A", (1, 0), ())
d.add("B", (1, 0), (2,))
d.add("C", (1, 0), (2,))
d.add("D", (1, 2), (3,))
d.add("E", (3,), (4,))
d.add("F", (3,), (4,))
d.add("G", (3,), (5,))
d.add("H", (3,), (5,))
d.glue(1, 4)
d.glue(2, 5)
print(d.data)