1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
#(c) 2013-2014 by Authors
#This file is a part of Ragout program.
#Released under the BSD license (see LICENSE file)

"""
This module provides some phylogeny-related functions
(including one for solving the half-breakpoint state parsimony problem)
"""

import math
from collections import defaultdict
from itertools import chain
import logging

from ragout.parsers.phylogeny_parser import (parse_tree, PhyloException)
from ragout.phylogeny.inferer import TreeInferer
logger = logging.getLogger()

class Phylogeny:
    """
    Represents phylogenetic tree and scores it with
    given half-breakpoint states

    Tree nodes are accessed only through ``terminal``, ``identifier``
    (read for terminal nodes) and ``get_edges()``, which is iterated as
    ``(child, bootstrap, length)`` triples.
    """
    def __init__(self, tree):
        """
        Stores the tree and its string form, then fits ``self.mu``
        """
        self.tree = tree
        self.tree_string = str(tree)
        self._scale_branches()

    @classmethod
    def from_newick(cls, newick_str):
        """
        Builds a phylogeny from the tree that parse_tree()
        returns for the given string
        """
        return cls(parse_tree(newick_str))

    @classmethod
    def from_permutations(cls, perm_container):
        """
        Builds a phylogeny from the tree that TreeInferer
        builds for the given permutations container
        """
        ti = TreeInferer(perm_container)
        return cls(ti.build())

    def _scale_branches(self):
        """
        Fits mu coefficient according to branch lengths
        to avoid underflows/overflows

        Sets ``self.mu`` to the reciprocal of the median branch length
        """
        def branch_lengths(node):
            #lengths of all the branches in the subtree below the node
            if node.terminal:
                return []

            lengths = []
            for child, _bootstrap, length in node.get_edges():
                assert length is not None, "tree branch without a length"
                lengths.append(length)
                lengths.extend(branch_lengths(child))

            return lengths

        lengths = branch_lengths(self.tree)
        assert len(lengths), "tree has no branches"
        self.mu = float(1) / _median(lengths)
        logger.debug("Branch lengths: {0}, mu = {1}".format(lengths, self.mu))

    def estimate_tree(self, leaf_states):
        """
        Scores the tree with weighted parsimony procedure

        ``leaf_states`` maps the identifier of each terminal node to its
        state; the candidate states of internal nodes are the values of
        this mapping. Uses ``self.mu`` to weight branch lengths.

        Returns the smallest score among the root states. Raises KeyError
        if a terminal node's identifier is missing from ``leaf_states``.
        """
        all_states = set(leaf_states.values())

        #score of a tree branch
        def branch_score(parent, child, branch):
            #no penalty when the states match or the child state is None
            if parent == child or child is None:
                return 0.0
            else:
                #prevent underflow
                length = max(branch, 0.0000001)
                #adding one to counter possibly small exp value
                return 1.0 + math.exp(-self.mu * length)

        #recursive
        def rec_helper(root):
            if root.terminal:
                #a leaf is free in its own state and forbidden in any other
                leaf_state = leaf_states[root.identifier]
                return {s : (0.0 if s == leaf_state else float("inf"))
                        for s in all_states}

            #score of a state is the sum, over the children, of the
            #cheapest (child score + branch score) combination
            root_scores = defaultdict(float)
            for node, _bootstrap, branch_length in root.get_edges():
                node_scores = rec_helper(node)
                for root_state in all_states:
                    root_scores[root_state] += min(
                            node_scores[child_state] +
                            branch_score(root_state, child_state,
                                         branch_length)
                            for child_state in all_states)

            return root_scores

        return min(rec_helper(self.tree).values())

    def terminals_dfs_order(self):
        """
        Returns terminal nodes' names in dfs order

        Among the children of a node, non-terminal ones are visited first;
        within each of the two groups, longer branches are visited first.
        """
        def get_labels(root):
            if root.terminal:
                return [root.identifier]

            #sorted() is stable, so the second sort keeps
            #the by-length order inside each group
            edges = sorted(root.get_edges(), key=lambda e: e[2], reverse=True)
            edges = sorted(edges, key=lambda e: e[0].terminal)
            return list(chain.from_iterable(get_labels(e[0]) for e in edges))

        return get_labels(self.tree)


def _median(values):
    sorted_values = sorted(values)
    return sorted_values[(len(values) - 1) / 2]