Raw File
CPH.py
import networkx as nx
import random
import ast
import re
import time
import joblib
import numpy as np
from copy import deepcopy
from copy import copy

import pandas as pd

from Features import Features
memodict = {}


'''
This script consists of 4 parts: 
- Help functions (mainly for handling the input)
- Input_Set CLASS: code for opening the input and running the Cherry Picking Heuristic (CPH)
- PhT CLASS: environment for a phylogenetic tree
- PhN CLASS: environment for a phylogenetic network
'''

#######################################################################
########                   HELP FUNCTIONS                      ########
#######################################################################

# Parse a newick string with lengths: convert ":" to "," and then evaluate the result as a list of lists
# using ast.literal_eval. In each list, every node is then followed by the length of its incoming arc.
# This only works as long as each branch has a length and all internal nodes are labeled.
def newick_to_tree(newick, current_labels=None):
    """Parse a newick string into a networkx DiGraph.

    The string is rewritten into a Python nested-list literal (taxon names
    quoted, ":" turned into ",", parentheses turned into brackets) and then
    evaluated with ast.literal_eval. A trailing ";" is not removed here, so
    the caller has to strip it before calling.

    :param newick: the newick string; a ":" anywhere in it switches on
        branch lengths.
    :param current_labels: dict mapping taxon name -> leaf node id. It is
        extended in place with taxa that were not seen before. When omitted,
        a new empty dict is used for this call.
    :return: (tree, leaves, current_labels, distances) where tree is the
        DiGraph (arc attribute 'length', root node 2 below an extra arc
        (1, 2) of length 0), leaves is the set of leaf node ids, and
        distances tells whether the newick string contained lengths.
    """
    # A fresh dict per call: a dict() default would be created once and silently
    # shared by every call that omits the argument.
    if current_labels is None:
        current_labels = dict()
    # presence of : indicates the use of lengths in the trees
    distances = ":" in newick
    # taxon names may already be enclosed by " or ', otherwise, we add these now
    if "'" not in newick and '"' not in newick:
        newick = re.sub(r"([,\(])([a-zA-Z\d]+)", r"\1'\2", newick)
        if distances:
            newick = re.sub(r"([a-zA-Z\d]):", r"\1':", newick)
        else:
            newick = re.sub(r"([a-zA-Z\d])([,\(\)])", r"\1'\2", newick)
    if distances:
        # a length then simply becomes the list element following its subtree
        newick = newick.replace(":", ",")
    # turn the string into a python nested list using [ instead of (
    newick = newick.replace("(", "[").replace(")", "]")
    nestedtree = ast.literal_eval(newick)
    # parse the nested list into a list of edges with some additional information about the leaves
    # we start with the root 2, so that we can append a root edge (1,2)
    edges, leaves, current_labels, _ = nested_list_to_tree(nestedtree, 2, current_labels, distances=distances)
    edges.append((1, 2, 0))
    # put all this information into a networkx DiGraph; arcs always carry a 'length'
    # (nested_list_to_tree uses length 1 when the newick string has none)
    tree = nx.DiGraph()
    tree.add_weighted_edges_from(edges, weight='length')
    add_node_attributes(tree, distances=distances, root=2)
    return tree, leaves, current_labels, distances


# Auxiliary function to convert list of lists to tree (graph)
# Works recursively, where we keep track of the nodes we have already used
# Leaves are nodes with negative integer as ID, and already existing taxa are coupled to node IDs by current_labels.
def nested_list_to_tree(nestedList, next_node, current_labels, distances=False):
    """Recursively turn a nested list into a list of weighted edges.

    :param nestedList: the children of the current node. Without distances
        every element is a child (a list for an internal node, anything else
        for a leaf). With distances the elements alternate: child, length of
        its incoming arc, child, length, ...
    :param next_node: node id to use for the node represented by nestedList.
    :param current_labels: dict taxon name -> leaf node id; extended in place.
        A new taxon gets the id -len(current_labels), so leaf ids are 0, -1, ...
    :param distances: whether nestedList contains lengths (see above). Without
        them every arc gets length 1.
    :return: (edges, leaves, current_labels, current_node): edges is a list of
        (parent, child, length) tuples, leaves the set of leaf ids below this
        node, and current_node the id handed to the most recent internal child.
    """
    edges = []
    leaves = set()
    top_node = next_node
    current_node = next_node + 1
    if distances:
        # subtree and length are adjacent in nestedList; the lookup is lazy so a dangling
        # subtree without a length still raises IndexError when it is reached
        children = ((nestedList[i], nestedList[i + 1]) for i in range(0, len(nestedList), 2))
    else:
        # no lengths/distances, so each subtree is simply an element of nestedList
        children = ((subtree, 1) for subtree in nestedList)

    for subtree, length in children:
        if isinstance(subtree, list):  # Not a leaf
            edges.append((top_node, current_node, length))
            sub_edges, sub_leaves, current_labels, current_node = nested_list_to_tree(
                subtree, current_node, current_labels, distances=distances)
            # extend in place instead of rebuilding the lists for every child
            edges.extend(sub_edges)
            leaves |= sub_leaves
        else:  # A leaf
            label = str(subtree)
            if label not in current_labels:
                current_labels[label] = -len(current_labels)
            leaf = current_labels[label]
            edges.append((top_node, leaf, length))
            leaves.add(leaf)
    return edges, leaves, current_labels, current_node


# per node, add the edge based and comb height of the node as an attribute.
def add_node_attributes(tree, distances=True, root=0):
    """Store each node's distance from the root as node attributes.

    Two attributes are set on every node of `tree`:
    - "node_comb": number of arcs on the shortest path from `root`;
    - "node_length": shortest-path length from `root` using the arc attribute
      "length" when `distances` is true, otherwise equal to "node_comb".
    Nodes that cannot be reached from `root` get None for both.

    :param tree: networkx DiGraph, modified in place.
    :param distances: whether arc lengths should be used for "node_length".
    :param root: node from which the distances are measured.
    """
    if len(tree) == 0:
        # nothing to annotate
        return
    # One search from the root covers every reachable node, instead of running a separate
    # shortest-path query per node.
    comb = nx.single_source_shortest_path_length(tree, root)
    if distances:
        length = nx.single_source_dijkstra_path_length(tree, root, weight="length")
    else:
        length = comb
    # dict.get yields None for nodes the search did not reach
    attrs = {x: {"node_length": length.get(x), "node_comb": comb.get(x)} for x in tree.nodes}
    nx.set_node_attributes(tree, attrs)


# Modifies a cherry-picking sequence so that it represents a network with exactly one root.
# A sequence may be such that reconstructing a network from the sequence results in multiple roots
# This function adds some pairs to the sequence so that the network has a single root.
# returns the new sequence, and also modifies the sets of trees reduced by each pair in the sequence,
# so that the new pairs are also represented (they reduce no trees)
def sequence_add_roots(seq, red_trees):
    """Append pairs to a cherry-picking sequence so that it has a single root.

    Reading the sequence backwards, the second element of a pair is a root
    when it has not appeared in any pair read so far. All roots found this way
    are chained together by extra pairs appended to `seq`; for each extra pair
    an empty set is appended to `red_trees` (the new pairs reduce no trees).

    Both lists are modified in place and returned as (seq, red_trees).
    """
    seen = set()
    root_set = set()
    for pair in reversed(seq):
        first, second = pair[0], pair[1]
        if second not in seen:
            root_set.add(second)
        seen.update((first, second))

    # Link consecutive roots; the last root stays the root of the whole network.
    ordered_roots = list(root_set)
    for upper, lower in zip(ordered_roots, ordered_roots[1:]):
        seq.append((upper, lower))
        red_trees.append(set())
    return seq, red_trees


# return leaves from network
def get_leaves(net):
    """Return the set of nodes of `net` that have no outgoing arcs."""
    leaves = set()
    for node in net.nodes():
        if net.out_degree(node) == 0:
            leaves.add(node)
    return leaves

#######################################################################
########            INPUT SET CLASS with CPH method            ########
#######################################################################


# Methods for sets of phylogenetic trees
class Input_Set:
    def __init__(self, newick_strings=[], tree_set=None, instance=0, leaves=None, full_leaf_set=True):
        # The dictionary of trees
        self.trees = dict()
        # the set of leaf labels of the trees
        self.labels = dict()
        self.labels_reversed = dict()
        self.leaves = set()
        self.instance = instance

        # the current best sequence we have found for this set of trees
        self.best_seq = None
        # the list of reduced trees for each of the pairs in the best sequence
        self.best_red_trees = None

        # the best sequence for the algorithm using lengths as input as well
        self.best_seq_with_lengths = None
        # the sets of reduced trees for each pair in this sequence
        self.best_seq_with_lengths_red_trees = None
        # the height of each pair in this sequence
        self.best_seq_with_lengths_heights = None

        # true if distances are used
        self.distances = True
        # computation times
        self.CPS_Compute_Time = 0
        self.CPS_Compute_Reps = 0
        self.DurationPerTrial = []
        self.RetPerTrial = []

        if tree_set is None:
            # read the input trees in 'newick_strings'
            for n in newick_strings:
                tree = PhT()
                self.trees[len(self.trees)] = tree
                self.labels, distances_in_tree = tree.tree_from_newick(newick=n, current_labels=self.labels)
                self.distances = self.distances and distances_in_tree
            # ONLY INTERSECTION
            self.leaves = set(self.labels.keys())
        else:
            self.trees = {t: PhT(tree) for t, tree in tree_set.items()}
            self.leaves = leaves
            self.labels = {l: l for l in self.leaves}

        if full_leaf_set:
            self.unique_leaves = self.leaves
        else:
            self.unique_leaves = self.trees[0].leaves
            for t, tree in self.trees.items():
                if t == 0:
                    continue
                self.unique_leaves = tree.leaves.intersection(self.unique_leaves)
        self.num_leaves = len(self.unique_leaves)

        # make a reverse dictionary for the leaf labels, to look up the label of a given node
        for l, i in self.labels.items():
            self.labels_reversed[i] = l

    # Make a deepcopy of an instance
    def __deepcopy__(self, memodict={}):
        copy_inputs = Input_Set()
        copy_inputs.trees = deepcopy(self.trees, memodict)
        copy_inputs.labels = deepcopy(self.labels, memodict)
        copy_inputs.labels_reversed = deepcopy(self.labels_reversed, memodict)
        copy_inputs.leaves = deepcopy(self.leaves, memodict)
        return copy_inputs

    # Find new cherry-picking sequences for the trees and update the best found
    def CPSBound(self, repeats=1,
                 progress=False,
                 time_limit=None,
                 reduce_trivial=False,
                 pick_ml=False,
                 pick_ml_triv=False,
                 pick_random=False,
                 model_name=None,
                 relabel=False,
                 relabel_cher_triv=False,
                 ml_thresh=None,
                 problem_type=None):
        # Set the specific heuristic that we use, based on the user input and whether the trees have lengths
        Heuristic = self.CPHeuristic
        # Initialize the recorded best sequences and corresponding data
        best = None
        red_trees_best = []
        starting_time = time.time()
        self.DurationPerTrial = dict()
        self.RetPerTrial = dict()
        # Try as many times as required by the integer 'repeats'
        df_pred = None
        for i in np.arange(repeats):
            start_trial = time.time()
            if not (pick_ml or pick_ml_triv) and progress:
                print(f"\nInstance {self.instance} {problem_type}: Trial {i}\n")
            # RUN HEURISTIC
            new, reduced_trees, df_pred = Heuristic(progress=progress,
                                                    reduce_trivial=reduce_trivial,
                                                    pick_ml=pick_ml,
                                                    pick_ml_triv=pick_ml_triv,
                                                    pick_random=pick_random,
                                                    model_name=model_name,
                                                    relabel=relabel,
                                                    relabel_cher_triv=relabel_cher_triv,
                                                    ml_thresh=ml_thresh,
                                                    problem_type=problem_type)
            if progress:
                print(f"Instance {self.instance} {problem_type}: found sequence of length: {len(new)}")
            # COMPLETE PARTIAL SEQUENCE
            new, reduced_trees = sequence_add_roots(new, reduced_trees)
            if progress:
                print(f"Instance {self.instance} {problem_type}: length after completing sequence: {len(new)}")

            self.CPS_Compute_Reps += 1
            self.DurationPerTrial[i] = time.time() - start_trial

            # FIND RETICULATION NUMBER
            seq_length = len(new)
            self.RetPerTrial[i] = seq_length - self.num_leaves + 1
            # store best sequence
            if best is None or seq_length < best_length:
                best = new
                best_length = copy(seq_length)
                red_trees_best = reduced_trees
            if progress:
                print(f"Instance {self.instance} {problem_type}: best sequence has length: {best_length}")
            if time_limit and time.time() - starting_time > time_limit:
                break

        # storing stuff of heuristic
        self.CPS_Compute_Time += time.time() - starting_time
        # storing best network
        new_seq = best
        if not self.best_seq_with_lengths or len(new_seq) < len(self.best_seq_with_lengths):
            converted_new_seq = []
            for pair in new_seq:
                converted_new_seq += [(self.labels_reversed[pair[0]], self.labels_reversed[pair[1]])]
            self.best_seq_with_lengths = converted_new_seq
            self.best_seq_with_lengths_red_trees = red_trees_best
        seq_return = [(self.labels_reversed[x], self.labels_reversed[y]) for x, y in new_seq]
        return self.best_seq_with_lengths, seq_return, df_pred

    def CPHeuristic(self, progress=False, reduce_trivial=True, pick_ml=False,
                    pick_ml_triv=False, model_name=None, pick_random=False, relabel=False, relabel_cher_triv=False,
                    ml_thresh=None, problem_type=None):
        """Build one cherry-picking sequence by repeatedly choosing and reducing a pair.

        The trees are reduced in a deep copy of this object; the loop runs
        until no trees are left in that copy. In every iteration a pair is
        chosen by the enabled strategies, which are evaluated in this order
        and may overwrite each other's choice: `reduce_trivial`,
        `pick_ml_triv`, `pick_ml`, `pick_random`.

        :param progress: print progress information.
        :param reduce_trivial: pick a pair via pick_trivial; when it returns no
            pair, `pick_random` is switched on for this iteration.
        :param pick_ml: choose the pair using the loaded prediction model.
        :param pick_ml_triv: choose a pair flagged "trivial" in the features
            when there is one; `pick_ml` is switched on or off accordingly.
        :param model_name: path of the joblib model; a default path is used if None.
        :param pick_random: choose a pair uniformly at random from the reducible pairs.
        :param relabel: relabel leaves (relabel_trivial) after a trivial pick.
        :param relabel_cher_triv: alternative condition for marking an ML pick as trivial.
        :param ml_thresh: if set, an ML pick whose score is below it is replaced by a random pair.
        :param problem_type: only used in progress output.
        :return: (CPS, reduced_trees, df_pred): the list of chosen pairs, for
            each pair the set of trees it reduced, and the prediction log
            (None unless pick_ml or pick_ml_triv was set on entry).

        NOTE(review): when none of the four strategies assigns `chosen_cherry`
        (e.g. all four flags are False), `CPS += [chosen_cherry]` fails because
        the name is unbound - confirm callers always enable at least one.
        """
        # Works in a copy of the input trees, copy_of_inputs, because trees have to be reduced somewhere.
        copy_of_inputs = deepcopy(self)
        CPS = []
        reduced_trees = []
        # Make dict of reducible pairs
        reducible_pairs = self.find_all_pairs()
        triv_picked = False
        if pick_ml or pick_ml_triv:
            # create initial features
            start_time_init = time.time()
            features = Features(reducible_pairs, copy_of_inputs.trees, root=2)
            # log of the pairs chosen by the model; one row is appended per ML pick below
            df_pred = pd.DataFrame(columns=["x", "y", "no_cher_pred", "cher_pred", "ret_cher_pred", "no_ret_cher_pred", "trees_reduced"])
            if progress:
                print(
                    f"Instance {self.instance} {problem_type}: Initial features found in {np.round(time.time() - start_time_init, 3)}s")
            # open prediction model
            if model_name is None:
                model_name = "LearningCherries/RFModels/rf_cherries_N1000_maxL20_random_balanced.joblib"
            rf_model = joblib.load(model_name)
        else:
            df_pred = None
        while copy_of_inputs.trees:
            if progress and (pick_ml or pick_ml_triv):
                print(f"Instance {self.instance} {problem_type}: Sequence has length {len(CPS)}")
                print(f"Instance {self.instance} {problem_type}: {len(copy_of_inputs.trees)} trees left.\n")
            if reduce_trivial:
                chosen_cherry, triv_picked = copy_of_inputs.pick_trivial(reducible_pairs)
                # fall back to a random pair in this iteration when there is no trivial pair
                if chosen_cherry is None:
                    pick_random = True
                else:
                    pick_random = False

            if pick_ml_triv:
                trivial_slice = features.data["trivial"] == 1
                if trivial_slice.any():       # find trivial cherry
                    triv_cherries = list(features.data.loc[trivial_slice].index)
                    if relabel:
                        triv_cherry_num = np.random.choice(np.arange(len(triv_cherries)))
                        chosen_cherry = triv_cherries[triv_cherry_num]
                    else:
                        chosen_cherry = copy_of_inputs.pick_cherry(triv_cherries, reducible_pairs)
                    if features.data.loc[chosen_cherry]["cherry_in_tree"] < 1:
                        triv_picked = True
                    else:
                        triv_picked = False
                    # a trivial pair was found, so the model is not consulted in this iteration
                    pick_ml = False
                else:
                    pick_ml = True

            if pick_ml:
                # predict if cherry
                # NOTE(review): the four prediction columns 0..3 appear to line up with the df_pred
                # columns no_cher_pred, cher_pred, ret_cher_pred, no_ret_cher_pred - confirm against the model
                prediction = pd.DataFrame(np.array([p[:, 1] for p in rf_model.predict_proba(
                    features.data)]).transpose(), index=features.data.index)

                # the score of a pair is the sum of prediction columns 1 and 2
                max_cherry = (prediction[1] + prediction[2]).argmax()
                chosen_cherry = prediction.index[max_cherry]
                chosen_cherry_prob = prediction.loc[chosen_cherry, 1] + prediction.loc[chosen_cherry, 2]

                if ml_thresh is not None and chosen_cherry_prob < ml_thresh:
                    random_cherry_num = np.random.choice(len(reducible_pairs))
                    chosen_cherry = list(reducible_pairs)[random_cherry_num]

                # "trees_reduced" is logged as NaN here
                df_pred.loc[len(df_pred)] = [*chosen_cherry, *prediction.loc[chosen_cherry], np.nan]

                if relabel and features.data.loc[chosen_cherry]["trivial"] == 1 and features.data.loc[chosen_cherry]["cherry_in_tree"] < 1:
                    triv_picked = True
                elif relabel_cher_triv and features.data.loc[chosen_cherry]["trivial"] == 1 and \
                        features.data.loc[chosen_cherry]["cherry_in_tree"] < 1 and \
                        prediction.loc[chosen_cherry][1] > prediction.loc[chosen_cherry][[0, 2, 3]].sum():
                    triv_picked = True
                else:
                    triv_picked = False

                if progress and (pick_ml or pick_ml_triv):
                    print(f"Instance {self.instance} {problem_type}: chosen cherry = {chosen_cherry}, "
                          f"ML prediction = {list(prediction.loc[chosen_cherry])}")

            if pick_random:
                random_cherry_num = np.random.choice(len(reducible_pairs))
                chosen_cherry = list(reducible_pairs)[random_cherry_num]
                if reduce_trivial:
                    triv_picked = False
                elif relabel and copy_of_inputs.trivial_check(chosen_cherry, reducible_pairs[chosen_cherry]):
                    triv_picked = True
                else:
                    triv_picked = False

            CPS += [chosen_cherry]

            # RELABEL
            if triv_picked and relabel:
                reducible_pairs, merged_cherries = copy_of_inputs.relabel_trivial(*chosen_cherry, reducible_pairs)
                if pick_ml or pick_ml_triv:
                    features.relabel_trivial_features(*chosen_cherry, reducible_pairs, merged_cherries, copy_of_inputs.trees)

            # UPDATE SOME FEATURES BEFORE
            if pick_ml or pick_ml_triv:
                features.update_cherry_features_before(chosen_cherry, reducible_pairs, copy_of_inputs.trees)

            # REDUCE CHOSEN CHERRY FROM FOREST
            new_reduced = copy_of_inputs.reduce_pair_in_all(chosen_cherry, reducible_pairs=reducible_pairs)
            reducible_pairs = copy_of_inputs.update_reducible_pairs(reducible_pairs, new_reduced)
            if progress and (pick_ml or pick_ml_triv):
                print(f"Instance {self.instance} {problem_type}: {len(reducible_pairs)} reducible pairs left")
            reduced_trees += [new_reduced]

            # all trees are fully reduced: skip the feature update and stop
            if len(copy_of_inputs.trees) == 0:
                break

            # UPDATE FEATURES AFTER REDUCTION
            if pick_ml or pick_ml_triv:
                copy_of_inputs.update_node_comb_length(*chosen_cherry, new_reduced)
                features.update_cherry_features_after(chosen_cherry, reducible_pairs, copy_of_inputs.trees, new_reduced)

        return CPS, reduced_trees, df_pred

    # select order of chosen cherry
    def pick_order(self, x, y, new_reduced, return_cherry=True):
        leaf_x_left = 0
        leaf_y_left = 0
        for t, tree in self.trees.items():
            if t in new_reduced:
                continue
            if x in tree.leaves:
                leaf_x_left += 1
            if y in tree.leaves:
                leaf_y_left += 1
        if return_cherry:
            # FAVOR X, Y OVER Y, X
            if leaf_x_left <= leaf_y_left:
                return x, y
            else:
                return y, x
        else:
            return leaf_x_left, leaf_y_left

    def pick_cherry(self, triv_cherries, reducible_pairs):
        leaf_left = dict()
        for x, y in triv_cherries:
            leaf_x_left, leaf_y_left = self.pick_order(x, y, reducible_pairs[x, y], return_cherry=False)
            leaf_left[(x, y)] = leaf_x_left
            leaf_left[(y, x)] = leaf_y_left
        best_cherry_id = np.argmin(list(leaf_left.values()))
        return list(leaf_left)[best_cherry_id]

    # when using machine learning, update the topological/combinatorial length of nodes
    def update_node_comb_length(self, x, y, reduced_trees):
        for t in reduced_trees:
            try:
                self.trees[t].nw.nodes[y]["node_comb"] -= 1
            except KeyError:
                continue

    # Finds the set of reducible pairs in all trees
    # Returns a dictionary with reducible pairs as keys, and the trees they reduce as values.
    def find_all_pairs(self):
        reducible_pairs = dict()
        for i, t in self.trees.items():
            red_pairs_t = t.find_all_reducible_pairs()
            for pair in red_pairs_t:
                if pair in reducible_pairs:
                    reducible_pairs[pair].add(i)
                else:
                    reducible_pairs[pair] = {i}
        return reducible_pairs

    # Returns the updated dictionary of reducible pairs in all trees after a reduction (with the trees they reduce as values)
    # we only need to update for the trees that got reduced: 'new_red_trees'
    def update_reducible_pairs(self, reducible_pairs, new_red_trees):
        # Remove trees to update from all pairs
        pair_del = []
        for pair, trees in reducible_pairs.items():
            trees.difference_update(new_red_trees)
            if len(trees) == 0:
                pair_del.append(pair)
        for pair in pair_del:
            del reducible_pairs[pair]
        # Add the trees to the right pairs again
        for index in new_red_trees:
            if index in self.trees:
                t = self.trees[index]
                red_pairs_t = t.find_all_reducible_pairs()
                for pair in red_pairs_t:
                    if pair in reducible_pairs:
                        reducible_pairs[pair].add(index)
                    else:
                        reducible_pairs[pair] = {index}
        return reducible_pairs

    # reduces the given pair in all trees
    # Returns the set of trees that were reduced
    # CHANGES THE SET OF TREES, ONLY PERFORM IN A COPY OF THE CLASS INSTANCE
    def reduce_pair_in_all(self, pair, reducible_pairs=None):
        if not len(reducible_pairs):
            print("no reducible pairs")
        if reducible_pairs is None:
            reducible_pairs = dict()
        reduced_trees_for_pair = []
        if pair in reducible_pairs:
            trees_to_reduce = reducible_pairs[pair]
        else:
            trees_to_reduce = deepcopy(self.trees)
        for t in trees_to_reduce:
            if t in self.trees:
                tree = self.trees[t]
                # print(t, tree.leaves)
                if tree.reduce_pair(*pair):
                    reduced_trees_for_pair += [t]
                    if (self.trees[t].root == 0 and len(tree.nw.edges()) <= 1) or \
                            (self.trees[t].root == 2 and len(tree.nw.edges()) <= 2):
                        # print(t, pair, tree.leaves)
                        del self.trees[t]
        return set(reduced_trees_for_pair)

    def pick_trivial(self, reducible_pairs):
        trivial_cherries = []
        trivial_in_all_cherries = []
        for c, trees in reducible_pairs.items():
            if len(trees) == len(self.trees):
                trivial_in_all_cherries.append(c)
                continue
            trivial_check = self.trivial_check(c, trees)
            if trivial_check:
                trivial_cherries.append(c)

        if trivial_in_all_cherries:
            chosen_cherry = trivial_in_all_cherries[np.random.choice(len(trivial_in_all_cherries))]
            triv_picked = False
        elif trivial_cherries:
            chosen_cherry = trivial_cherries[np.random.choice(len(trivial_cherries))]
            triv_picked = True
        else:
            chosen_cherry = None
            triv_picked = False

        return chosen_cherry, triv_picked

    def trivial_check(self, c, trees):
        if len(trees) == len(self.trees):
            return False
        return len([t for t, tree in self.trees.items() if (set(c).issubset(tree.leaves) and t not in trees)]) == 0

    # Returns all trivial pairs involving the leaf l
    def trivial_pair_with(self, l):
        pairs = set()
        # Go through all trees t with index i.
        for i, t in self.trees.items():
            # If the leaf occurs in t
            if l in t.leaves:
                # Compute reducible pairs of t with the leaf as first coordinate
                pairs_in_t = t.find_pairs_with_first(l)
                # If we did not have a set of candidate pairs yet, use pairs_in_t
                if not pairs:
                    pairs = pairs_in_t
                # Else, the candidate pairs must also be in t, so take intersection
                else:
                    pairs = pairs & pairs_in_t
                # If we do not have any candidate pairs after checking a tree with l as leaf, we stop.
                if not pairs:
                    break
        return pairs

    def relabel_trivial(self, x, y, reducible_pairs):
        """Rename leaf x to y in every tree that the pair (x, y) does not reduce.

        For each such tree containing x, the leaf set and the graph are updated
        and `reducible_pairs` is adjusted for pairs that y now forms with leaf
        siblings of its parent.

        :param x: the leaf that is renamed.
        :param y: the new name.
        :param reducible_pairs: dict pair -> set of tree ids; modified in place.
            It must contain the key (x, y).
        :return: (reducible_pairs, merged_cherries) where merged_cherries holds
            the pairs (x, c) whose entries were removed because (c, y) was
            already a key of `reducible_pairs`.
        """
        # print(f"Cherry = {(x, y)}: RELABEL X = {x} to Y = {y}")
        merged_cherries = set()
        # NOTE(review): new_cherries is filled below but neither returned nor read - confirm it can go
        new_cherries = set()
        for t, tree in self.trees.items():
            # trees in which (x, y) is reducible are left alone
            if t in reducible_pairs[(x, y)]:
                continue
            if x in tree.leaves:
                # change leaf set
                tree.leaves.remove(x)
                tree.leaves.add(y)

                # relabel x to y
                tree.nw = nx.relabel_nodes(tree.nw, {x: y})

                # check if we have a new cherry now
                for p in tree.nw.predecessors(y):
                    for c in tree.nw.successors(p):
                        # only other leaves below the same parent are of interest
                        if c == y:
                            continue
                        if c not in tree.leaves:
                            continue
                        if (c, y) in reducible_pairs:
                            # the pair is already known: register this tree for both orders
                            reducible_pairs[(c, y)].add(t)
                            reducible_pairs[(y, c)].add(t)
                            # the del statement removes (c, x) first; if (x, c) is then missing, the
                            # KeyError is swallowed and the pair is not recorded as merged
                            try:
                                del reducible_pairs[(c, x)], reducible_pairs[(x, c)]
                                merged_cherries.add((x, c))
                            except KeyError:
                                pass
                        else:
                            # add to reducible_pairs?
                            reducible_pairs[(c, y)] = {t}
                            reducible_pairs[(y, c)] = {t}
                            new_cherries.add((c, y))
                            try:
                                del reducible_pairs[(c, x)], reducible_pairs[(x, c)]
                            except KeyError:
                                pass
        return reducible_pairs, merged_cherries


#######################################################################
########              PHYLOGENETIC TREE CLASS                  ########
#######################################################################

# A class representing a phylogenetic tree
# Contains methods to reduce trees
class PhT:
    """A phylogenetic tree stored as a networkx DiGraph, with methods to reduce pairs.

    Attributes:
        nw: the graph itself.
        leaves: the set of leaf nodes of the tree.
        root: 2 for a tree created empty (trees built from newick use root
            node 2), 0 for a tree wrapping a supplied graph.
    """

    def __init__(self, tree=None):
        """Create an empty tree, or wrap the given graph and read its leaves from it."""
        if tree is None:
            # the actual graph
            self.nw = nx.DiGraph()
            # the set of leaf labels of the network
            self.leaves = set()
            self.root = 2
        else:
            self.nw = tree
            self.root = 0
            self.leaves = get_leaves(self.nw)

    def tree_from_newick(self, newick=None, current_labels=None):
        """Replace the graph and leaf set by the tree described by a newick string.

        :param newick: the newick string.
        :param current_labels: dict taxon name -> leaf id shared between trees;
            a new empty dict is used when it is omitted.
        :return: (current_labels, distances) as returned by newick_to_tree.
        """
        # A dict() default would be created once and shared by all calls without the argument.
        if current_labels is None:
            current_labels = dict()
        self.nw, self.leaves, current_labels, distances = newick_to_tree(newick, current_labels)
        return current_labels, distances

    def _parent(self, v):
        """Return the predecessor of v that is listed last, or -1 if v has no predecessor."""
        parent = -1
        for p in self.nw.predecessors(v):
            parent = p
        return parent

    def is_cherry(self, x, y):
        """Check whether the leaves x and y have the same parent, i.e. form a cherry."""
        if (x not in self.leaves) or (y not in self.leaves):
            return False
        return self._parent(x) == self._parent(y)

    def height_of_cherry(self, x, y):
        """Return the lengths of the two arcs of the cherry (x, y).

        :return: the list [length(p, x), length(p, y)] as floats, where p is
            the common parent, or False if (x, y) is not a cherry.
        """
        if (x not in self.leaves) or (y not in self.leaves):
            return False
        px = self._parent(x)
        py = self._parent(y)
        if px != py:
            return False
        return [float(self.nw[px][x]['length']), float(self.nw[py][y]['length'])]

    def clean_node(self, v):
        """Suppress v if it has exactly one incoming and one outgoing arc.

        The parent of v is connected directly to the child of v. The new arc is
        created with the attributes of the arc (parent, v); when both old arcs
        have a 'length', the new length is length(p,v)+length(v,c).

        :return: True if v was suppressed, False if v is not a degree-2 node.
        """
        if not (self.nw.out_degree(v) == 1 and self.nw.in_degree(v) == 1):
            return False
        pv = self._parent(v)
        cv = -1
        for c in self.nw.successors(v):
            cv = c
        self.nw.add_edges_from([(pv, cv, self.nw[pv][v])])
        if 'length' in self.nw[pv][v] and 'length' in self.nw[v][cv]:
            self.nw[pv][cv]['length'] = self.nw[pv][v]['length'] + self.nw[v][cv]['length']
        self.nw.remove_node(v)
        return True

    def reduce_pair(self, x, y):
        """Reduce the pair (x, y) if it is a cherry of the tree.

        The leaf x is removed together with its incoming arc, and the former
        common parent is cleaned up with clean_node. The length of the arc to x
        is lost.

        :return: True if the pair was reduced, False otherwise.
        """
        if not self.is_cherry(x, y):
            return False
        # look the parent up before x is removed from the graph
        py = self._parent(y)
        self.nw.remove_node(x)
        self.leaves.remove(x)
        self.clean_node(py)
        return True

    def find_pairs_with_first(self, x):
        """Return all pairs (x, s) where s is another leaf with the same parent as x."""
        pairs = set()
        px = self._parent(x)
        if self.nw.out_degree(px) > 1:
            for cpx in self.nw.successors(px):
                if cpx in self.leaves and cpx != x:
                    pairs.add((x, cpx))
        return pairs

    def find_all_reducible_pairs(self):
        """Return all reducible pairs of the tree, in both orders."""
        red_pairs = set()
        for l in self.leaves:
            red_pairs |= self.find_pairs_with_first(l)
        return red_pairs


#######################################################################
########              PHYLOGENETIC NETWORK CLASS               ########
#######################################################################


# A class for phylogenetic networks
class PhN:
    def __init__(self, net=None, seq=None, newick=None, best_tree_from_network=None, reduced_trees=None, heights=None):
        """Create a phylogenetic network.

        At most one source is used to build the network, checked in this
        order: ``seq``, ``newick``, ``best_tree_from_network``, ``net``.
        If none is given, the network stays empty.

        :param net: existing directed graph to use as the network; its leaves
            are obtained with ``get_leaves`` and each leaf is labelled by itself.
        :param seq: cherry-picking sequence of pairs ``(x, y)``; the pairs are
            added in reverse order with ``add_pair``. Defaults to an empty set.
        :param newick: newick string, parsed with ``self.newick_to_network``.
        :param best_tree_from_network: network whose ``Best_Tree()`` edges,
            labels and leaf bookkeeping are taken over.
        :param reduced_trees: for every position of ``seq``, the set of trees
            passed to ``add_pair`` as ``red_trees``.
        :param heights: for every position of ``seq``, the ``height`` argument
            passed to ``add_pair``; only used together with ``reduced_trees``.
        """
        # Use a fresh container per instance: a default object in the signature
        # would be shared (through TCS/CPS) by every network built without a sequence.
        if seq is None:
            seq = set()
        # the actual graph
        self.nw = nx.DiGraph()
        # the set of leaf labels of the network
        self.leaves = set()
        # a dictionary giving the node for a given leaf label
        self.labels = dict()
        # the number of nodes in the graph
        self.no_nodes = 0
        # a dictionary giving the leaf label for a given leaf node
        self.leaf_nodes = dict()
        # TCS and CPS both refer to the given sequence object
        self.TCS = seq
        self.CPS = seq
        self.newick = newick
        self.reducible_pairs = set()
        self.reticulated_cherries = set()
        self.cherries = set()
        self.level = None
        self.no_embedded_trees = 0
        # if a cherry-picking sequence is given, build the network from this sequence
        if seq:
            total_len = len(seq)
            current_trees_embedded = set()
            # Creates a phylogenetic network from a cherry picking sequence:
            if reduced_trees:
                for i, pair in enumerate(reversed(seq)):
                    # position of this pair in the original (non-reversed) sequence
                    idx = total_len - 1 - i
                    if heights:
                        self.add_pair(*pair, red_trees=reduced_trees[idx],
                                      current_trees=current_trees_embedded, height=heights[idx])
                        current_trees_embedded = current_trees_embedded | reduced_trees[idx]
                    else:
                        # NOTE(review): unlike the branch above, the embedded-tree set is not
                        # extended here, so no_embedded_trees stays 0 without heights -
                        # confirm whether this is intended.
                        self.add_pair(*pair, red_trees=reduced_trees[idx],
                                      current_trees=current_trees_embedded)
                self.no_embedded_trees = len(current_trees_embedded)
            else:
                for pair in reversed(seq):
                    self.add_pair(*pair)
        # if a newick string is given, build the network from the newick string
        elif newick:
            network, self.leaves, self.labels = self.newick_to_network(newick)
            self.nw = network
            self.no_nodes = len(self.nw)
            self.compute_leaf_nodes()
        # if a network 'best_tree_from_network' is given, extract the best tree from this network and use this tree
        # as the network
        elif best_tree_from_network:
            self.nw.add_edges_from(best_tree_from_network.Best_Tree())
            self.labels = best_tree_from_network.labels
            self.leaf_nodes = best_tree_from_network.leaf_nodes
            self.leaves = best_tree_from_network.leaves
            self.no_nodes = best_tree_from_network.no_nodes
        # if a graph is given, use it directly; every leaf is its own label
        elif net:
            self.nw = net
            self.leaves = get_leaves(self.nw)
            self.labels = {l: l for l in self.leaves}

    def is_cherry(self, x, y):
        if (x not in self.leaves) or (y not in self.leaves):
            return False
        px = -1
        py = -1
        for p in self.nw.predecessors(x):
            px = p
        for p in self.nw.predecessors(y):
            py = p
        return px == py

    # Returns true if (x_label,y_label) forms a reticulate cherry in the network, false otherwise
    def is_ret_cherry(self, x_label, y_label):
        if not x_label in self.leaves or not x_label in self.leaves:
            return False
        x = self.labels[x_label]
        y = self.labels[y_label]
        px = -1
        py = -1
        for p in self.nw.predecessors(x):
            px = p
        for p in self.nw.predecessors(y):
            py = p
        return (self.nw.in_degree(px) > 1) and self.nw.out_degree(px) == 1 and (py in self.nw.predecessors(px))

    # Recomputes self.leaf_nodes (leaf node -> leaf label) as the inverse of self.labels
    def compute_leaf_nodes(self):
        self.leaf_nodes = dict()
        for v in self.labels:
            self.leaf_nodes[self.labels[v]] = v

    def reticulations_non_binary(self):
        return [self.nw.in_degree(v)-1 for v in self.nw.nodes() if self.nw.in_degree(v) >= 2]

    # Adds a pair to the network, using the construction from a cherry-picking sequence
    # returns false if y is not yet in the network and the network is not empty
    def add_pair(self, x, y, red_trees=set(), current_trees=set(), height=[1, 1]):
        # if the network is empty, create a cherry (x,y)
        if len(self.leaves) == 0:
            self.nw.add_edge(0, 1, no_of_trees=len(red_trees), length=0)
            self.nw.add_edge(1, 2, no_of_trees=len(red_trees), length=height[0])
            self.nw.add_edge(1, 3, no_of_trees=len(red_trees), length=height[1])
            self.leaves = {x, y}
            self.labels[x] = 2
            self.labels[y] = 3
            self.leaf_nodes[2] = x
            self.leaf_nodes[3] = y
            self.no_nodes = 4
            return True
        # if y is not in the network return false, as there is no way to add the pair and get a phylogenetic network
        if y not in self.leaves:
            return False
        # add the pair to the existing network
        node_y = self.labels[y]
        parent_node_y = -1
        for p in self.nw.predecessors(node_y):
            parent_node_y = p

        # first add all edges around y
        length_incoming_y = self.nw[parent_node_y][node_y]['length']
        no_of_trees_incoming_y = self.nw[parent_node_y][node_y]['no_of_trees']
        height_goal_x = height[0]
        if height[1] < length_incoming_y:
            height_pair_y_real = height[1]
        else:
            height_pair_y_real = length_incoming_y
            height_goal_x += height[1] - height_pair_y_real

        self.nw.add_edge(node_y, self.no_nodes, no_of_trees=no_of_trees_incoming_y + len(red_trees - current_trees),
                         length=height_pair_y_real)
        self.nw[parent_node_y][node_y]['length'] = length_incoming_y - height_pair_y_real
        self.leaf_nodes.pop(self.labels[y], False)
        self.labels[y] = self.no_nodes
        self.leaf_nodes[self.no_nodes] = y

        # Now also add edges around x
        # x is not yet in the network, so make a cherry (x,y)
        if x not in self.leaves:
            self.nw.add_edge(node_y, self.no_nodes + 1, no_of_trees=len(red_trees), length=height_goal_x)
            self.leaves.add(x)
            self.labels[x] = self.no_nodes + 1
            self.leaf_nodes[self.no_nodes + 1] = x
            self.no_nodes += 2
        # x is already in the network, so create a reticulate cherry (x,y)
        else:
            node_x = self.labels[x]
            for parent in self.nw.predecessors(node_x):
                px = parent
            length_incoming_x = self.nw[px][node_x]['length']
            no_of_trees_incoming_x = self.nw[px][node_x]['no_of_trees']
            # if x is below a reticulation, and the height of the new pair is above the height of this reticulation,
            # add the new hybrid arc to the existing reticulation
            if self.nw.in_degree(px) > 1 and length_incoming_x <= height_goal_x:
                self.nw.add_edge(node_y, px, no_of_trees=len(red_trees), length=height_goal_x - length_incoming_x)
                self.nw[px][node_x]['no_of_trees'] += len(red_trees)
                self.no_nodes += 1
            # create a new reticulation vertex above x to attach the hybrid arc to
            else:
                height_pair_x = min(height_goal_x, length_incoming_x)
                self.nw.add_edge(node_y, node_x, no_of_trees=len(red_trees), length=height_goal_x - height_pair_x)
                self.nw.add_edge(node_x, self.no_nodes + 1, no_of_trees=no_of_trees_incoming_x + len(red_trees),
                                 length=height_pair_x)
                self.nw[px][node_x]['length'] = length_incoming_x - height_pair_x
                self.leaf_nodes.pop(self.labels[x], False)
                self.labels[x] = self.no_nodes + 1
                self.leaf_nodes[self.no_nodes + 1] = x
                self.no_nodes += 2
        return True

    # suppresses v if it is a degree-2 node
    def Clean_Node(self, v):
        if self.nw.out_degree(v) == 1 and self.nw.in_degree(v) == 1:
            pv = -1
            for p in self.nw.predecessors(v):
                pv = p
            cv = -1
            for c in self.nw.successors(v):
                cv = c
            self.nw.add_edges_from([(pv, cv, self.nw[v][cv])])
            if self.nw[pv][v]['length'] and self.nw[v][cv]['length']:
                self.nw[pv][cv]['length'] = self.nw[pv][v]['length'] + self.nw[v][cv]['length']
            self.nw.remove_node(v)
            return True
        return False

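    # reduces the pair (x_label,y_label) if it is reducible in the network
    # returns the set of new reducible pairs that involve the leaves x_label and y_label, together with the
    # placeholder pair ("no_leaf", "no_leaf"); returns an empty set if the pair could not be reduced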
    def reduce_pair(self, x_label, y_label):
        if x_label not in self.leaves or not y_label in self.leaves:
            return set()
        x = self.labels[x_label]
        y = self.labels[y_label]
        px = -1
        py = -1
        for p in self.nw.predecessors(x):
            px = p
        for p in self.nw.predecessors(y):
            py = p
        if self.is_cherry(x_label, y_label):
            self.reducible_pairs.difference_update({(x_label, y_label), (y_label, x_label)})
            self.nw.remove_node(x)
            self.leaves.remove(x_label)
            self.labels.pop(x_label, False)
            self.Clean_Node(py)
            # AddCherriesInvolving y
            new_pairs = {("no_leaf", "no_leaf")} | self.Find_Pairs_With_First(y_label) | self.Find_Pairs_With_Second(
                y_label)
            self.reducible_pairs = self.reducible_pairs.union(new_pairs - {("no_leaf", "no_leaf")})
            return new_pairs
        if self.is_ret_cherry(x_label, y_label):
            # print(f"{(x_label, y_label)} is a reticulated cherry")
            self.reducible_pairs.difference_update({(x_label, y_label), (y_label, x_label)})
            self.nw.remove_edge(py, px)
            self.Clean_Node(px)
            self.Clean_Node(py)
            # AddCherriesInvolving x and y
            new_pairs = {("no_leaf", "no_leaf")} | self.Find_Pairs_With_First(x_label) | self.Find_Pairs_With_Second(
                x_label) | self.Find_Pairs_With_First(y_label) | self.Find_Pairs_With_Second(y_label)
            self.reducible_pairs = self.reducible_pairs.union(new_pairs - {("no_leaf", "no_leaf")})
            return new_pairs
        return set()

    # Returns all reducible pairs in the network where x_label is the first element of the pair
    def Find_Pairs_With_First(self, x_label):
        pairs = set()
        x = self.labels[x_label]
        px = -1
        for p in self.nw.predecessors(x):
            px = p
        if self.nw.in_degree(px) > 1:
            for ppx in self.nw.predecessors(px):
                for cppx in self.nw.successors(ppx):
                    if cppx in self.leaf_nodes:
                        pairs.add((x_label, self.leaf_nodes[cppx]))
        if self.nw.out_degree(px) > 1:
            for cpx in self.nw.successors(px):
                if cpx in self.leaf_nodes:
                    pairs.add((x_label, self.leaf_nodes[cpx]))
        return pairs - {(x_label, x_label)}

    # Returns all reducible pairs in the network where x_label is the second element of the pair
    def Find_Pairs_With_Second(self, x_label):
        pairs = set()
        x = self.labels[x_label]
        px = -1
        for p in self.nw.predecessors(x):
            px = p
        if self.nw.out_degree(px) > 1:
            for cpx in self.nw.successors(px):
                if cpx in self.leaf_nodes:
                    pairs.add((self.leaf_nodes[cpx], x_label))
                if self.nw.in_degree(cpx) > 1:
                    for ccpx in self.nw.successors(cpx):
                        if ccpx in self.leaf_nodes:
                            pairs.add((self.leaf_nodes[ccpx], x_label))
        return pairs - {(x_label, x_label)}
# back to top