https://github.com/blab/mumps-wa-phylodynamics
Tip revision: b8358a0d49d70670dbab9eeffa9972c277b3021b authored by Louise Moncla on 17 March 2021, 19:25:20 UTC
adding in callouts for Wisconsin H and A genomes
adding in callouts for Wisconsin H and A genomes
Tip revision: b8358a0
baltic-iqtree.py
import re
import copy
import math
import datetime as dt
import json
def decimalDate(date,fmt="%Y-%m-%d",variable=False,dateSplitter='-'):
    """ Convert a calendar date string into a decimal (fractional) year.

    date: the date string to convert.
    fmt: strptime format describing a fully specified date.
    variable: if True the date may lack trailing fields (only year-month, or only year)
        and fmt is shortened to the number of fields actually present in date.
    dateSplitter: separator used between fields in both date and fmt.
    Returns the year plus the fraction of that year elapsed at the given date.
    """
    if variable==True: ## reduced precision allowed - trim the format down to what the date provides
        fmt_fields=fmt.split(dateSplitter)
        fields_present=len(date.split(dateSplitter))
        if fields_present==2:
            fmt=dateSplitter.join(fmt_fields[:-1])
        elif fields_present==1:
            fmt=dateSplitter.join(fmt_fields[:-2])

    parsed=dt.datetime.strptime(date,fmt) ## missing fields default to the first month/day
    year_start=dt.datetime(parsed.year,1,1)
    next_year_start=dt.datetime(parsed.year+1,1,1)
    elapsed=(parsed-year_start).total_seconds() ## seconds since the year began
    year_span=(next_year_start-year_start).total_seconds() ## seconds in this particular year
    return parsed.year+(elapsed/year_span)
def convertDate(x,start,end):
    """ Re-express date string x, parsed with strptime format start, using strftime format end. """
    parsed=dt.datetime.strptime(x,start)
    return parsed.strftime(end)
class clade: ## clade class
    """ Placeholder for a collapsed subtree, it reports itself as a leaf. """
    def __init__(self,givenName):
        self.branchType='leaf' ## presents itself as a tip to code that inspects branchType
        self.subtree=self.leaves=None ## collapsed branches and their tip names, nothing stored yet
        self.length=0.0
        self.height=self.absoluteTime=self.parent=None
        self.traits={}
        self.index=None
        self.name=self.numName=givenName ## label used wherever a tip name would be
        self.x=self.y=None
        self.lastHeight=self.lastAbsoluteTime=None ## height and absolute time of the highest tip in the collapsed clade
        self.width=1
class node: ## node class
    """ Internal branch of a tree. """
    def __init__(self):
        self.branchType='node'
        self.length=0.0 ## branch length
        ## height is filled in by tree traversal, absoluteTime by calibration, parent when attached to a tree
        self.height=self.absoluteTime=self.parent=None
        self.children=[] ## branches descending directly from this node
        self.traits={} ## annotations, e.g. {'posterior':1.0}
        ## index identifies the branch within a tree, childHeight is the height of the highest descendant,
        ## x and y are plotting coordinates
        self.index=self.childHeight=self.x=self.y=None
        self.leaves=set() ## names of all tips descended from this node
class leaf: ## leaf class
    """ Tip of a tree. """
    def __init__(self):
        self.branchType='leaf'
        ## name is the translated label, numName the label as written in the tree string,
        ## index identifies the branch within a tree
        self.name=self.numName=self.index=None
        ## branch length, position in absolute time, height and parent branch - all unknown at creation
        self.length=self.absoluteTime=self.height=self.parent=None
        self.traits={} ## trait dictionary
        self.x=self.y=None ## plotting coordinates
class tree: ## tree class
def __init__(self):
    """ Set up an empty tree with a placeholder node that the first parsed branch attaches to. """
    placeholder=node()
    placeholder.index='Root'
    placeholder.length=0.0 ## placeholder contributes no branch length
    placeholder.height=0.0 ## heights of attached branches are counted from zero
    self.cur_node=placeholder ## branch that new branches get attached to
    self.root=None ## assigned when the first branch is added
    self.Objects=[] ## flat list of all branches in the tree
    self.tipMap=None
    self.treeHeight=0 ## distance between the root and the highest tip
    self.ySpan=0.0
def add_node(self,i):
    """ Create a node with index i, attach it to the current node and make it the current node. """
    attach_to=self.cur_node
    fresh=node()
    fresh.index=i
    if self.root is None: ## first branch ever added becomes the root
        self.root=fresh
    fresh.parent=attach_to
    attach_to.children.append(fresh)
    self.cur_node=fresh
    self.Objects.append(fresh)
def add_leaf(self,i,name):
    """ Create a tip with index i and numName name, attach it to the current node and make it the current node. """
    attach_to=self.cur_node
    tip=leaf()
    tip.index=i
    if self.root is None: ## first branch ever added becomes the root
        self.root=tip
    tip.parent=attach_to
    attach_to.children.append(tip)
    tip.numName=name ## label as it appears in the tree string, the translated name is not set here
    self.cur_node=tip
    self.Objects.append(tip)
def subtree(self,k=None,traverse_condition=None):
    """ Generate a subtree (as a baltic tree object) from a traversal.
    k is the starting branch for traversal (default: root).
    traverse_condition is a function that determines whether a child branch should be visited (default: always true).
    Returns a new baltic tree instance, or None if the traversal contains no leaf-type branches.
    Note - custom traversal functions can result in multitype trees.
    If this is undesired call singleType() on the resulting subtree afterwards. """
    visited=self.traverse_tree(k,include_condition=lambda k:True,traverse_condition=traverse_condition)
    branches=copy.deepcopy(visited) ## the new tree must not share branches with this one
    if branches is None:
        return None
    if not any(b.branchType=='leaf' for b in branches):
        return None

    new_tree=tree()
    new_tree.Objects=branches
    new_tree.root=branches[0] ## traversal lists its starting branch first
    if traverse_condition is not None: ## some children may have been skipped - drop them and any nodes left empty
        kept=set(branches)
        for nd in new_tree.getInternal():
            nd.children=[ch for ch in nd.children if ch in kept]
        new_tree.fixHangingNodes()
    return new_tree
def singleType(self):
    """ Remove nodes that have exactly one child, joining the child to the node's parent, then re-sort branches. """
    def single_child_nodes():
        return [k for k in self.Objects if k.branchType=='node' and len(k.children)==1]

    pending=single_child_nodes()
    while len(pending)>0:
        pending=single_child_nodes()
        for k in sorted(pending,key=lambda x:-x.height): ## highest nodes first
            only_child=k.children[0]
            upstream=k.parent
            only_child.parent=upstream ## child skips over the removed node
            upstream.children.append(only_child)
            upstream.children.remove(k)
            upstream.children=list(set(upstream.children)) ## de-duplicate children
            only_child.length+=k.length ## child absorbs the removed branch's length
            pending.remove(k)
            self.Objects.remove(k)
    self.sortBranches()
def setAbsoluteTime(self,date):
    """ Place every branch in absolute time, date being the absolute time of the highest point of the tree. """
    for branch in self.Objects:
        ## a branch at height treeHeight lands on date, everything lower lands earlier
        branch.absoluteTime=date-self.treeHeight+branch.height
def treeStats(self):
    """ Traverse the tree and print its height, length, shape and number of branches. """
    self.traverse_tree() ## refresh heights before reporting
    branches=self.Objects
    total_length=sum([x.length for x in branches])
    print('\nTree height: %.6f\nTree length: %.6f'%(self.treeHeight,total_length))

    internal=self.getInternal()
    child_counts=[len(x.children) for x in internal]
    singleton=len(child_counts)==0 ## no nodes at all
    strictlyBifurcating=False
    multiType=False
    if not singleton:
        fewest,most=min(child_counts),max(child_counts)
        strictlyBifurcating=(most==2 and fewest==2) ## every node has exactly two children
        multiType=(fewest==1) ## at least one node has a single child

    hasTraits=max([len(x.traits) for x in branches])>0 ## any branch carrying an annotation

    reports=[(strictlyBifurcating,'strictly bifurcating tree'),
             (multiType,'multitype tree'),
             (singleton,'singleton tree'),
             (hasTraits,'annotations present')]
    for applies,message in reports:
        if applies:
            print(message)
    print('\nNumbers of objects in tree: %d (%d nodes and %d leaves)\n'%(len(branches),len(internal),len(branches)-len(internal)))
def traverse_tree(self,cur_node=None,include_condition=lambda k:k.branchType=='leaf',traverse_condition=lambda k:True,collect=None,verbose=False):
    """ Walk the tree depth-first, setting heights and descendant tip sets along the way.

    cur_node: branch to start from (default: root, which also resets heights and tip sets of all branches).
    include_condition: function deciding which visited branches are returned (default: leaf-type branches).
    traverse_condition: function deciding which children of a node are visited (default: all of them).
    collect: list that included branches are appended to, created if not given.
    verbose: boolean, debug.
    Returns the list of included branches in the order they were visited.
    """
    if cur_node is None: ## fresh traversal - wipe what earlier traversals computed
        for branch in self.Objects:
            if branch.branchType=='node':
                branch.leaves=set()
                branch.childHeight=None
            branch.height=None
        if verbose==True:
            print('Initiated traversal from root')
        cur_node=self.root

    if collect is None:
        collect=[]

    if cur_node.height==None: ## only fill in heights that are not known yet
        if cur_node.parent:
            cur_node.height=cur_node.length+cur_node.parent.height
        else: ## no parent to measure from
            cur_node.height=0.0

    if verbose==True:
        print('at %s (%s)'%(cur_node.index,cur_node.branchType))

    if include_condition(cur_node):
        collect.append(cur_node)

    if cur_node.branchType=='leaf':
        cur_node.parent.leaves.add(cur_node.numName) ## register tip with its parent
    elif cur_node.branchType=='node':
        for child in filter(traverse_condition,cur_node.children):
            if verbose==True:
                print('visiting child %s'%(child.index))
            self.traverse_tree(cur_node=child,include_condition=include_condition,traverse_condition=traverse_condition,verbose=verbose,collect=collect)
            if verbose==True:
                print('child %s done'%(child.index))
        assert len(cur_node.children)>0, 'Tried traversing through hanging node without children. Index: %s'%(cur_node.index)
        cur_node.childHeight=max([child.childHeight if child.branchType=='node' else child.height for child in cur_node.children])
        if cur_node.parent:
            cur_node.parent.leaves=cur_node.parent.leaves.union(cur_node.leaves) ## hand tips seen below on to the parent
        self.treeHeight=cur_node.childHeight ## last assignment comes from the starting node
    return collect
def renameTips(self,d=None):
    """ Set the name of every leaf-type branch by looking up its numName in d (default: the tree's tipMap). """
    if d is None and self.tipMap is not None:
        d=self.tipMap
    for tip in self.getExternal():
        tip.name=d[tip.numName]
def sortBranches(self,descending=True):
    """ Order the children of every node by tip count and branch length, then redraw the tree.

    descending: True puts leaf-type children before nodes, False puts nodes first.
    """
    if descending==True:
        modifier=-1 ## flips every sort key below
    elif descending==False:
        modifier=1

    for k in self.getInternal():
        ## nodes and leaf-type children are ordered separately and then joined
        child_nodes=[x for x in k.children if x.branchType=='node']
        child_tips=[x for x in k.children if x.branchType=='leaf']
        child_nodes.sort(key=lambda q:(-len(q.leaves)*modifier,q.length*modifier))
        child_tips.sort(key=lambda q:q.length*modifier)
        if modifier==1:
            k.children=child_nodes+child_tips
        elif modifier==-1:
            k.children=child_tips+child_nodes
    self.drawTree() ## y positions depend on child order, so recompute coordinates
def drawTree(self,order=None,verbose=False):
    """ Find x and y coordinates of each branch.

    order: list of leaf-type branches giving the vertical order of tips (default: result of traverse_tree()).
    verbose: boolean, debug.
    Sets x (height) and y on every branch in self.Objects, a yRange attribute on nodes
    (lowest and highest y amongst descendants) and self.ySpan (sum of the vertical space taken by all tips).
    """
    if order==None:
        order=self.traverse_tree() ## order is a list of tips recovered from a tree traversal to make sure they're plotted in the correct order along the vertical tree dimension
        if verbose==True:
            print('Drawing tree in pre-order')
    else:
        if verbose==True:
            print('Drawing tree with provided order')

    name_order=[x.numName for x in order]
    skips=[1 if isinstance(x,leaf) else x.width+1 for x in order] ## vertical space taken by each entry: 1 for a tip, width+1 for anything else (e.g. a collapsed clade)

    for k in self.Objects: ## reset coordinates for all objects
        k.x=None
        k.y=None

    storePlotted=0 ## number of branches placed by the end of the previous pass
    drawn={} ## drawn keeps track of what's been drawn, keyed by branch index
    while len(drawn)!=len(self.Objects): # keep drawing the tree until everything is drawn
        if verbose==True:
            print('Drawing iteration %d'%(len(drawn)))
        for k in filter(lambda w:w.index not in drawn,self.Objects): ## iterate through objects that have not been drawn
            if k.branchType=='leaf': ## if leaf - get position of leaf, draw branch connecting tip to parent node
                if verbose==True:
                    print('Setting leaf %s y coordinate to'%(k.index))
                x=k.height ## x position is height
                y_idx=name_order.index(k.numName) ## y position of leaf is given by the order in which tips were visited during the traversal
                y=sum(skips[y_idx:]) ## sum across skips to find y position, so entries early in the order get the largest y
                if verbose==True:
                    print('%s'%(y))
                if isinstance(k,clade) and skips[y_idx]>1: ## if dealing with collapsed clade - adjust y position to be in the middle of the skip
                    y-=skips[y_idx]/2.0
                    if verbose==True:
                        print('adjusting clade y position to %s'%(y))
                k.x=x ## set x and y coordinates
                k.y=y
                drawn[k.index]=None ## remember that this objects has been drawn

                if hasattr(k.parent,'yRange')==False: ## if parent doesn't have a maximum extent of its children's y coordinates
                    setattr(k.parent,'yRange',[k.y,k.y]) ## assign it

            if k.branchType=='node': ## a node can only be placed once y positions of all its children are known
                if len([q.y for q in k.children if q.y!=None])==len(k.children):
                    if verbose==True:
                        print('Setting node %s coordinates'%(k.index))
                    x=k.height ## x position is height
                    children_y_coords=[q.y for q in k.children if q.y!=None] ## get all existing y coordinates of the node
                    y=sum(children_y_coords)/float(len(children_y_coords)) ## node sits at the mean y of its children
                    k.x=x
                    k.y=y
                    drawn[k.index]=None ## remember that this objects has been drawn
                    minYrange=min([min(child.yRange) if child.branchType=='node' else child.y for child in k.children]) ## get lowest y coordinate across children
                    maxYrange=max([max(child.yRange) if child.branchType=='node' else child.y for child in k.children]) ## get highest y coordinate across children
                    setattr(k,'yRange',[minYrange,maxYrange]) ## assign the maximum extent of children's y coordinates
        assert len(drawn)>storePlotted,'Got stuck trying to find y positions of objects' ## every pass has to place at least one new branch
        storePlotted=len(drawn)
        self.ySpan=sum(skips) ## skips does not change between passes, so neither does this value
def drawUnrooted(self,n=None,total=None):
    """
    Calculate x and y coordinates in an unrooted arrangement.
    Code translated from https://github.com/nextstrain/auspice/commit/fc50bbf5e1d09908be2209450c6c3264f298e98c, written by Richard Neher.

    n: branch to place (default: root, which also resets tau, x and y of every branch).
    total: total weight of the tips, computed when starting from the root.
    """
    if n==None:
        total=sum([1 if isinstance(x,leaf) else x.width+1 for x in self.getExternal()])
        n=self.root
        for k in self.Objects:
            k.traits['tau']=0.0
            k.x=0.0
            k.y=0.0

    def sector(branch):
        ## angle given to a branch is proportional to the number of tips it carries
        share=1.0 if branch.branchType=='leaf' else len(branch.leaves)
        return 2*math.pi*share/float(total)

    w=sector(n)
    if n.parent.x==None: ## parent was never positioned - place it at the origin
        n.parent.x=0.0
        n.parent.y=0.0

    angle=n.traits['tau']+w*0.5 ## branch points down the middle of its sector
    n.x=n.parent.x+n.length*math.cos(angle)
    n.y=n.parent.y+n.length*math.sin(angle)

    eta=n.traits['tau'] ## running start angle for the children's sectors
    if n.branchType=='node':
        for ch in n.children:
            ch.traits['tau']=eta
            eta+=sector(ch)
            self.drawUnrooted(ch,total)
def commonAncestor(self,descendants,numName=False,strict=False):
    """ Find the highest branch whose descendant tips include every requested tip.

    descendants: list of tip names, all of the same type.
    numName: whether descendants are encoded names (True) or decoded names (default: False).
        Decoded names are matched by translating each branch's tips through self.tipMap.
    strict: if True the ancestor is only returned when it has exactly as many tips as requested, otherwise None.
    Returns a node (or collapsed clade).
    """
    types=[desc.__class__ for desc in descendants]
    assert len(set(types))==1,'More than one type of data detected in descendants list'

    ## collect tip labels once, rather than once per requested descendant
    if numName==False:
        tip_labels=set(w.name for w in self.getExternal())
    else:
        tip_labels=set(w.numName for w in self.getExternal())
    assert all(k in tip_labels for k in descendants),'Not all specified descendants are in tree: %s'%(descendants)

    ## only branches with at least as many tips as requested can be ancestors, lowest first
    candidates=sorted([k for k in self.Objects if (k.branchType=='node' or isinstance(k,clade)) and len(k.leaves)>=len(descendants)],key=lambda x:x.height)

    def labels_under(branch):
        ## tips below a branch, expressed in the same naming scheme as descendants
        if numName==False:
            return [self.tipMap[w] for w in branch.leaves]
        return list(branch.leaves)

    matching=[]
    for k in candidates:
        below=labels_under(k) ## built once per branch instead of once per descendant
        if sum(below.count(l) for l in descendants)==len(descendants):
            matching.append(k)
    ancestor=matching[-1] ## candidates are sorted by height, so the last match is the most recent ancestor

    if strict==False or (strict==True and len(ancestor.leaves)==len(descendants)):
        return ancestor
    return None
def collapseSubtree(self,cl,givenName,verbose=False,widthFunction=lambda k:len(k.leaves)):
    """ Replace node cl and everything below it with a single clade object named givenName.

    cl: node to collapse.
    givenName: name given to the clade that takes the node's place.
    verbose: boolean, debug.
    widthFunction: function of cl giving the clade's width (default: number of tips below cl).
    The tree is traversed and sorted again afterwards.
    """
    assert cl.branchType=='node','Cannot collapse non-node class'
    collapsedClade=clade(givenName)
    for attr in ('index','leaves','length','height','parent','absoluteTime','traits'): ## clade takes over the node's place in the tree
        setattr(collapsedClade,attr,getattr(cl,attr))
    collapsedClade.width=widthFunction(cl)

    if verbose==True:
        print('Replacing node %s (parent %s) with a clade class'%(cl.index,cl.parent.index))
    parent=cl.parent

    collapsed=self.traverse_tree(cl,include_condition=lambda k: True) ## cl itself and every branch below it
    collapsedClade.subtree=collapsed ## kept so that the collapse can be undone
    assert len(collapsed)<len(self.Objects),'Attempted collapse of entire tree'
    collapsedClade.lastHeight=max([x.height for x in collapsed])
    times=[x.absoluteTime for x in collapsed]
    if times.count(None)!=len(times): ## at least one branch has been placed in absolute time
        collapsedClade.lastAbsoluteTime=max(times)

    for k in collapsed:
        self.Objects.remove(k)

    parent.children.remove(cl)
    parent.children.append(collapsedClade)
    self.Objects.append(collapsedClade)
    collapsedClade.parent=parent
    if self.tipMap!=None:
        self.tipMap[givenName]=givenName ## clade name translates to itself

    self.traverse_tree()
    self.sortBranches()
def uncollapseSubtree(self):
    """ Put the stored branches of every collapsed clade back into the tree, then traverse it again. """
    def collapsed_clades():
        return [k for k in self.Objects if isinstance(k,clade)]

    while len(collapsed_clades())>0: ## restored branches may contain clades themselves, so repeat until none are left
        for cl in collapsed_clades():
            parent=cl.parent
            restored=cl.subtree
            parent.children.remove(cl)
            parent.children.append(restored[0]) ## first stored branch is the node that was collapsed
            self.Objects.extend(restored)
            self.Objects.remove(cl)
            if self.tipMap!=None:
                self.tipMap.pop(cl.name,None)
    self.traverse_tree()
def collapseBranches(self,collapseIf=lambda x:x.traits['posterior']<=0.5,designated_nodes=[],verbose=False):
    """ Collapse internal branches, attaching their children to the collapsed node's parent.

    collapseIf: function applied to every non-root node, nodes for which it returns True are collapsed
        (default is an anonymous function that returns true if posterior probability is <=0.5).
        Only used when designated_nodes is empty.
    designated_nodes: alternatively, a list of nodes to collapse. They are matched to the copied tree by index
        and must not include the root. The list is only read, never modified.
    verbose: boolean, debug.
    Returns a deep copied version of the tree, this tree is left untouched.
    Raises AssertionError if a designated branch is not a node or is the root.
    """
    newTree=copy.deepcopy(self) ## work on a copy of the tree

    if len(designated_nodes)==0: ## no nodes were designated for deletion - relying on collapseIf
        def find_targets():
            ## root is ruled out before collapseIf is called, so it does not need to carry the tested trait
            return [n for n in newTree.Objects if n.branchType=='node' and n is not newTree.root and collapseIf(n)==True]
    else:
        assert all(w.branchType=='node' for w in designated_nodes),'Non-node class detected in list of nodes designated for deletion'
        ## nodes are compared by index, since the tree has been copied and its branches are different objects
        designated_indices=set(w.index for w in designated_nodes)
        assert newTree.root.index not in designated_indices,'Root node was designated for deletion'
        def find_targets():
            return [w for w in newTree.Objects if w.index in designated_indices]

    nodes_to_delete=find_targets()
    if verbose==True:
        print('%s nodes set for collapsing: %s'%(len(nodes_to_delete),[w.index for w in nodes_to_delete]))

    while len(nodes_to_delete)>0: ## as long as there are branches to be collapsed - keep reducing the tree
        if verbose==True:
            print('Continuing collapse cycle, %s nodes left'%(len(nodes_to_delete)))
        for k in sorted(nodes_to_delete,key=lambda x:-x.height): ## start with branches near the tips
            new_parent=k.parent ## non-root nodes always have a parent
            if verbose==True:
                print('Removing node %s, attaching children %s to node %s'%(k.index,[w.index for w in k.children],new_parent.index))
            for child in k.children: ## children move up a level and absorb the collapsed branch length
                child.parent=new_parent
                child.length+=k.length
                if verbose==True:
                    print('Fixing branch length for node %s'%(child.index))
            new_parent.children+=k.children
            new_parent.children.remove(k)
            newTree.Objects.remove(k)
            if verbose==True:
                print('Removing references to node %s'%(k.index))
        nodes_to_delete=find_targets() ## collapsing can change what collapseIf sees, so check again

    newTree.sortBranches() ## sort the tree to traverse, draw and sort tree to adjust y coordinates
    return newTree ## return collapsed tree
def toString(self,cur_node=None,traits=None,numName=False,verbose=False,nexus=False,string_fragment=None,traverse_condition=None,json=False):
    """ Output the topology of the tree with branch lengths and comments to string.

    cur_node: starting point (default: None, starts at root)
    traits: list of keys that will be used to output entries in traits dict of each branch (default: all traits, in sorted order so that output is reproducible)
    numName: boolean, whether encoded (True) or decoded (default: False) tip names will be output
    verbose: boolean, debug
    nexus: boolean, whether to output newick (default: False) or nexus (True) formatted tree
    string_fragment: list of characters that comprise the tree string, output is appended to it when provided
    traverse_condition: function deciding which children of a node are written (default: all of them)
    json: only checked to be False when nexus output is requested
    Returns the string assembled from string_fragment. The terminating ';' (and the nexus footer) are only added when cur_node is the root.
    """
    if cur_node is None:
        cur_node=self.root
    if traits is None: ## fetch all trait keys
        traits=sorted(set(key for k in self.Objects for key in k.traits),key=str)
    if string_fragment is None:
        string_fragment=[]
        if nexus==True: ## header only goes in front of a freshly started string
            assert json==False,'Nexus format not a valid option for JSON output'
            if verbose==True:
                print('Exporting to Nexus format')
            string_fragment.append('#NEXUS\nBegin trees;\ntree TREE1 = [&R] ')
    if traverse_condition is None:
        traverse_condition=lambda k: True

    def annotations(branch):
        """ Format the requested traits of a branch as a list of key=value strings. """
        comment=[]
        for tr in traits:
            if tr not in branch.traits:
                if verbose==True:
                    print('trait %s unavailable for %s (%s)'%(tr,branch.index,branch.branchType))
                continue
            value=branch.traits[tr]
            if verbose==True:
                print('trait %s available for %s (%s) type: %s'%(tr,branch.index,branch.branchType,type(value)))
            if isinstance(value,str): ## string value - quoted
                comment.append('%s="%s"'%(tr,value))
                if verbose==True:
                    print('adding string comment %s'%(comment[-1]))
            elif isinstance(value,(float,int)): ## float or integer
                comment.append('%s=%s'%(tr,value))
                if verbose==True:
                    print('adding numeric comment %s'%(comment[-1]))
            elif isinstance(value,list): ## lists are written inside curly brackets, entries of other types are left out
                rangeComment=[]
                for val in value:
                    if isinstance(val,str):
                        rangeComment.append('"%s"'%(val))
                    elif isinstance(val,(float,int)):
                        rangeComment.append('%s'%(val))
                comment.append('%s={%s}'%(tr,','.join(rangeComment)))
                if verbose==True:
                    print('adding range comment %s'%(comment[-1]))
        return comment

    def write(branch):
        """ Append the string of a branch, and everything below it, to string_fragment. """
        comment=annotations(branch)
        if branch.branchType=='node':
            if verbose==True:
                print('node: %s'%(branch.index))
            string_fragment.append('(')
            traverseChildren=list(filter(traverse_condition,branch.children))
            assert len(traverseChildren)>0,'Node %s does not have traversable children'%(branch.index)
            for c,child in enumerate(traverseChildren):
                if verbose==True:
                    print('moving to child %s of node %s'%(child.index,branch.index))
                write(child)
                if (c+1)<len(traverseChildren): ## not done with children, add comma for next iteration
                    string_fragment.append(',')
            string_fragment.append(')') ## last child, node terminates
        elif branch.branchType=='leaf':
            if numName==False: ## if real names wanted
                assert branch.name!=None,'Tip does not have converted name'
                treeName=branch.name
            elif numName==True: ## if number names wanted
                treeName=branch.numName
            if verbose==True:
                print('leaf: %s (%s)'%(branch.index,treeName))
            string_fragment.append("'%s'"%(treeName))

        if len(comment)>0:
            if verbose==True:
                print('adding comment to %s'%(branch.index))
            string_fragment.append('[&'+','.join(comment)+']') ## annotations follow the branch
        if verbose==True:
            print('adding branch length to %s'%(branch.index))
        string_fragment.append(':%8f'%(branch.length)) ## branch length comes last

    write(cur_node)

    if cur_node==self.root: ## only a string of the whole tree gets terminated
        string_fragment.append(';')
        if nexus==True:
            string_fragment.append('\nEnd;')
        if verbose==True:
            print('finished')
    return ''.join(string_fragment)
def allTMRCAs(self,numName=True):
    """ Build a pairwise matrix (dict of dicts) with the absolute time of the most recent common ancestor of every pair of tips.

    numName: whether tips are labelled with encoded names (default: True) or names decoded through self.tipMap (False).
    Entries of a tip with itself are 0.0.
    """
    tips=[k for k in self.Objects if isinstance(k,leaf)]
    if numName==False:
        assert len(self.tipMap)>0,'Tree does not have a translation dict for tip names'
        tip_names=[self.tipMap[k.numName] for k in tips]
    else:
        tip_names=[k.numName for k in tips]

    tmrcaMatrix={x:{y:(0.0 if x==y else None) for y in tip_names} for x in tip_names}

    for k in self.getInternal():
        if numName==True:
            tips_below=list(k.leaves)
        else:
            tips_below=[self.tipMap[lv] for lv in k.leaves]
        for i,tipA in enumerate(tips_below[:-1]): ## every unordered pair of tips below this node
            for tipB in tips_below[i+1:]:
                known=tmrcaMatrix[tipA][tipB]
                if known==None or known<=k.absoluteTime: ## keep the most recent node shared by the pair
                    tmrcaMatrix[tipA][tipB]=k.absoluteTime
                    tmrcaMatrix[tipB][tipA]=k.absoluteTime
    return tmrcaMatrix
def reduceTree(self,keep,verbose=False):
    """
    Reduce the tree to just those branches that connect a given set of tips to the root.

    keep: non-empty list of leaf-type branches of this tree, matched to the copied tree by index.
    verbose: boolean, debug.
    Returns a new baltic tree object, this tree is left untouched.
    Raises AssertionError if keep is empty or contains a non-leaf branch.
    """
    assert len(keep)>0,"No tips given to reduce the tree to."
    non_tips=[k for k in keep if k.branchType!='leaf']
    assert len(non_tips)==0, "Embedding contains %d non-leaf branches."%(len(non_tips))
    if verbose==True:
        print("Preparing branch hash for keeping %d branches"%(len(keep)))
    keep_indices=set(k.index for k in keep)

    if verbose==True:
        print("Deep copying tree")
    reduced_tree=copy.deepcopy(self) ## new tree object
    root=reduced_tree.root

    embedding=set() ## branches on any path between a kept tip and the root
    for k in reduced_tree.Objects:
        if k.index in keep_indices: ## if branch is designated as one to keep
            cur_b=k
            if verbose==True:
                print("Traversing to root from %s"%(cur_b.index))
            ## a branch already in the embedding has had its whole path to the root recorded, so stop there
            while cur_b!=root and cur_b not in embedding:
                if verbose==True:
                    print("at %s root: %s"%(cur_b.index,cur_b==root))
                embedding.add(cur_b)
                cur_b=cur_b.parent
    embedding.add(root)
    if verbose==True:
        print("Finished extracting embedding")

    ## filtering the existing list (rather than sorting the set) keeps ties in height in a reproducible order
    kept=[k for k in reduced_tree.Objects if k in embedding]
    if root not in kept:
        kept.append(root)
    reduced_tree.Objects=sorted(kept,key=lambda x:x.height)

    if verbose==True:
        print("Pruning untraversed lineages")
    for k in reduced_tree.getInternal():
        k.children=[c for c in k.children if c in embedding] ## only keep children that are present in lineage traceback
    if root.branchType=='node': ## a root that is a tip has no children to prune
        root.children=[c for c in root.children if c in embedding]
    reduced_tree.fixHangingNodes()

    if verbose==True:
        print("Last traversal and branch sorting")
    reduced_tree.traverse_tree()
    reduced_tree.sortBranches()
    return reduced_tree
def countLineages(self,t,condition=lambda x:True):
    """ Count the branches that exist at time t and satisfy condition.

    A branch exists at t if its parent's absoluteTime is before t and its own absoluteTime is at or after t.
    A branch whose parent has no absoluteTime (e.g. the root, which hangs off an undated placeholder)
    is treated as extending indefinitely back in time, and a branch without an absoluteTime is never counted -
    this is how comparisons against None evaluated under Python 2, whereas Python 3 raises TypeError for them.
    t: time point, in the same units as absoluteTime.
    condition: function of a branch deciding whether it is eligible (default: always true).
    """
    count=0
    for k in self.Objects:
        end=k.absoluteTime
        if end is None: ## branch has not been placed in absolute time
            continue
        start=k.parent.absoluteTime if k.parent is not None else None
        if (start is None or start<t) and t<=end and condition(k):
            count+=1
    return count
def getExternal(self):
    """ Return a list of all branches in the tree whose branchType is 'leaf'. """
    return [k for k in self.Objects if k.branchType=='leaf']
def getInternal(self):
    """ Return a list of all branches in the tree whose branchType is 'node'. """
    return [k for k in self.Objects if k.branchType=='node']
def getBranches(self,attrs=lambda x:True):
    """ Fetch the branches for which function attrs is true.

    Returns the branch itself if exactly one matches, otherwise a list of matching branches.
    Raises Exception if no branch matches.
    """
    matches=list(filter(attrs,self.Objects))
    if not matches:
        raise Exception('No branches satisfying function were found amongst branches')
    if len(matches)==1:
        return matches[0]
    return matches
def fixHangingNodes(self):
    """
    Remove internal nodes without any children, repeating until none are left.
    """
    def childless_nodes():
        return [k for k in self.Objects if k.branchType=='node' and len(k.children)==0]

    hanging=childless_nodes()
    while len(hanging)>0:
        for h in sorted(hanging,key=lambda x:-x.height): ## highest nodes first
            h.parent.children.remove(h)
            self.Objects.remove(h)
        hanging=childless_nodes() ## removals can leave the parents childless too
def addText(self,ax,target=lambda k:k.branchType=='leaf',position=lambda k:(k.x*1.01,k.y),text=lambda k:k.numName,zorder_function=lambda k: 101,**kwargs):
    """ Write a text label next to branches by calling ax.text once per selected branch.

    ax: object with a text method (e.g. matplotlib axes).
    target: function selecting which branches get labelled (default: leaf-type branches).
    position: function returning (x,y) of the label (default: slightly beyond the branch's x).
    text: function returning the label (default: numName).
    zorder_function: function returning the zorder of the label.
    kwargs are passed on to ax.text. Returns ax.
    """
    for branch in filter(target,self.Objects):
        x,y=position(branch)
        z=zorder_function(branch)
        label=text(branch)
        ax.text(x,y,label,zorder=z,**kwargs)
    return ax
def plotPoints(self,ax,x_attr=lambda k:k.height,y_attr=lambda k:k.y,target=lambda k:k.branchType=='leaf',size_function=lambda k:40,colour_function=lambda k:'k',zorder_function=lambda k: 100,**kwargs):
    """ Mark branches with a point by calling ax.scatter once per selected branch.

    ax: object with a scatter method (e.g. matplotlib axes).
    x_attr, y_attr: functions returning the coordinates of a branch (default: height and y).
    target: function selecting which branches get a point (default: leaf-type branches).
    size_function, colour_function, zorder_function: functions returning size, face colour and zorder of a point.
    kwargs are passed on to ax.scatter. Returns ax.
    """
    for branch in filter(target,self.Objects):
        y=y_attr(branch)
        x=x_attr(branch)
        colour=colour_function(branch)
        area=size_function(branch)
        layer=zorder_function(branch)
        ax.scatter(x,y,s=area,facecolor=colour,edgecolor='none',zorder=layer,**kwargs) ## points have no outline
    return ax
def plotTree(self,ax,type='rectangular',target=lambda k: True,x_attr=lambda k:k.height,y_attr=lambda k:k.y,branchWidth=lambda k:2,colour_function=lambda f:'k',zorder_function=lambda k: 98,**kwargs):
    """ Draw the branches of the tree with calls to ax.plot.

    ax: object with a plot method (e.g. matplotlib axes).
    type: 'rectangular' (branch drawn horizontally to the parent's x, nodes also get a vertical bar
        between their first and last child) or 'unrooted' (straight line from branch to parent).
    target: function selecting which branches are drawn (default: all of them).
    x_attr, y_attr: functions returning the coordinates of a branch (default: height and y).
    branchWidth, colour_function, zorder_function: functions returning line width, colour and zorder of a branch.
    kwargs are passed on to ax.plot. Returns ax.
    """
    assert type in ['rectangular','unrooted'],'Unrecognised drawing type "%s"'%(type)
    for branch in filter(target,self.Objects):
        y=y_attr(branch)
        x=x_attr(branch)
        parent_x=x_attr(branch.parent)
        if parent_x==None: ## parent has no position - branch collapses to a point at its own x
            parent_x=x
        c=colour_function(branch)
        b=branchWidth(branch)
        z=zorder_function(branch)
        if type=='unrooted':
            parent_y=y_attr(branch.parent)
            ax.plot([x,parent_x],[y,parent_y],color=c,lw=b,zorder=z,**kwargs)
        elif type=='rectangular':
            if branch.branchType=='node':
                first_y=y_attr(branch.children[0])
                last_y=y_attr(branch.children[-1])
                ax.plot([x,x],[first_y,last_y],color=c,lw=b,zorder=z,**kwargs) ## vertical bar spanning the node's children
            ax.plot([x,parent_x],[y,y],color=c,lw=b,zorder=z,**kwargs) ## horizontal branch to parent
    return ax
def make_tree(data,ll=None,verbose=False):
    """
    Parse a tree string into a tree object.

    data is a tree string, ll (LL) is an instance of a tree object. If ll is not given a new
    tree() is created. The string is consumed left to right: '(' adds a node, tip names add
    leaves, bracketed [&...] comments are turned into entries of the current branch's traits,
    plain labels are stored under traits['label'], numbers (optionally preceded by ':') become
    branch lengths, and ',' or ')' move the current node back to its parent.

    Returns ll when a ';' is reached.
    Raises AssertionError when a full pass over the parsing steps consumes no characters.
    NOTE(review): a string without a terminating ';' is not handled explicitly - confirm callers always pass terminated strings.
    """
    if not isinstance(data,str): ## tree string is not an instance of string (could be unicode) - convert
        data=str(data)
    if ll is None: ## calling without providing a tree object - create one
        ll=tree()
    i=0 ## is an adjustable index along the tree string, it is incremented to advance through the string
    stored_i=None ## store the i at the end of the loop, to make sure we haven't gotten stuck somewhere in an infinite loop
    while i < len(data): ## while there's characters left in the tree string - loop away
        if stored_i == i and verbose:
            print('%d >%s<'%(i,data[i]))
        assert (stored_i != i),'\nTree string unparseable\nStopped at >>%s<<\nstring region looks like this: %s'%(data[i],data[i:i+5000]) ## make sure that you've actually parsed something last time, if not - there's something unexpected in the tree string
        stored_i=i ## store i for later

        if data[i] == '(': ## look for new nodes
            if verbose:
                print('%d adding node'%(i))
            ll.add_node(i) ## add node to current node in tree ll
            i+=1 ## advance in tree string by one character

        cerberus=re.match(r'(\(|,)([0-9]+)(\[|\:)',data[i-1:i+100]) ## look for tips in BEAST format (integers).
        if cerberus is not None:
            if verbose:
                print('%d adding leaf (BEAST) %s'%(i,cerberus.group(2)))
            ll.add_leaf(i,cerberus.group(2)) ## add tip
            i+=len(cerberus.group(2)) ## advance in tree string by however many characters the tip is encoded

        cerberus=re.match(r'''(\(|,)('|")*([A-Za-z\_\-\|\.0-9\?\/ ]+)('|"|)(\[)*''',data[i-1:i+200]) ## look for tips with unencoded names - if the tips have some unusual format you'll have to modify this
        if cerberus is not None:
            if verbose:
                print('%d adding leaf (non-BEAST) %s'%(i,cerberus.group(3)))
            ll.add_leaf(i,cerberus.group(3).strip('"').strip("'")) ## add tip
            i+=len(cerberus.group(3))+cerberus.group().count("'")+cerberus.group().count('"') ## advance in tree string by however many characters the tip is encoded

        cerberus=re.match(r'\)([0-9]+)\[',data[i-1:i+100]) ## look for multitype tree singletons.
        if cerberus is not None:
            if verbose:
                print('%d adding multitype node %s'%(i,cerberus.group(1)))
            i+=len(cerberus.group(1))

        cerberus=re.match(r'''(\:)*\[(&[A-Za-z\_\-{}\,0-9\.\%="'\+!# :\/\(\)\&]+)\]''',data[i:]) ## look for MCC comments
        if cerberus is not None:
            if verbose:
                print('%d comment: %s'%(i,cerberus.group(2)))
            comment=cerberus.group(2)
            numerics=re.findall(r'[,&][A-Za-z\_\.0-9]+=[0-9\-Ee\.]+',comment) ## find all entries that have values as floats
            strings=re.findall(r'''[,&][A-Za-z\_\.0-9]+=["|']*[A-Za-z\_0-9\.\+ :\/\(\)\&]+["|']*''',comment) ## strings
            treelist=re.findall(r'[,&][A-Za-z\_\.0-9]+={[A-Za-z\_,{}0-9\. :\/\(\)\&]+}',comment) ## complete history logged robust counting (MCMC trees)
            sets=re.findall(r'[,&][A-Za-z\_\.0-9\%]+={[A-Za-z\.\-0-9eE,"\_ :\/\(\)\&]+}',comment) ## sets and ranges
            figtree=re.findall(r'\![A-Za-z]+=[A-Za-z0-9# :\/\(\)\&]+',comment)

            for vals in strings: ## strings go first so that purely numeric values are overwritten with floats below
                tr,val=vals.split('=')
                tr=tr[1:] ## drop the leading ',' or '&'
                if '+' in val:
                    val=val.split('+')[0] ## DO NOT ALLOW EQUIPROBABLE DOUBLE ANNOTATIONS (which are in format "A+B") - just get the first one
                ll.cur_node.traits[tr]=val.strip('"')

            for vals in numerics: ## assign all parsed annotations to traits of current branch
                tr,val=vals.split('=') ## split each value by =, left side is name, right side is value
                tr=tr[1:]
                ll.cur_node.traits[tr]=float(val)

            for vals in treelist:
                tr,val=vals.split('=')
                tr=tr[1:]
                microcerberus=re.findall(r'{([0-9]+,[0-9\.\-e]+,[A-Z]+,[A-Z]+)}',val)
                ll.cur_node.traits[tr]=[]
                for entry in microcerberus: ## each entry is codon,timing,start,end
                    codon,timing,start,end=entry.split(',')
                    ll.cur_node.traits[tr].append((int(codon),float(timing),start,end))

            for vals in sets:
                tr,val=vals.split('=')
                tr=tr[1:]
                if 'set' in tr:
                    ll.cur_node.traits[tr]=[]
                    for v in val[1:-1].split(','): ## strip the curly brackets and split into members
                        if 'set.prob' in tr:
                            ll.cur_node.traits[tr].append(float(v))
                        else:
                            ll.cur_node.traits[tr].append(v.strip('"'))
                else:
                    try:
                        ll.cur_node.traits[tr]=list(map(float,val[1:-1].split(',')))
                    except ValueError: ## members are not numbers - report and leave the trait as it is
                        print('some other trait: %s'%(vals))

            if len(figtree)>0:
                print('FigTree comment found, ignoring')
            i+=len(cerberus.group()) ## advance in tree string by however many characters it took to encode labels

        cerberus=re.match(r'([A-Za-z\_\-0-9\.]+)(\:|\;)',data[i:]) ## look for old school node labels
        if cerberus is not None:
            if verbose:
                print('old school comment found: %s'%(cerberus.group(1)))
            ll.cur_node.traits['label']=cerberus.group(1)
            i+=len(cerberus.group(1))

        microcerberus=re.match(r'(\:)*([0-9\.\-Ee]+)',data[i:i+100]) ## look for branch lengths without comments
        if microcerberus is not None:
            if verbose:
                print('adding branch length (%d) %.6f'%(i,float(microcerberus.group(2))))
            ll.cur_node.length=float(microcerberus.group(2)) ## set branch length of current node
            i+=len(microcerberus.group()) ## advance in tree string by however many characters it took to encode branch length

        if data[i] == ',' or data[i] == ')': ## look for bifurcations or clade ends
            i+=1 ## advance in tree string
            ll.cur_node=ll.cur_node.parent

        if data[i] == ';': ## look for string end
            return ll
def make_treeJSON(JSONnode,json_translation,ll=None,verbose=False):
    """
    Recursively turn a nested JSON dictionary into branches of a tree object.

    JSONnode is the dictionary describing the current branch; dictionaries with a 'children'
    key become node() objects, all others become leaf() objects named after the entry stored
    under json_translation['name']. If ll is not given a new tree() is created and the first
    branch becomes its root. An 'attr' sub-dictionary, when present, is merged into JSONnode
    itself (JSONnode is modified in place). Every non-'children' entry of JSONnode ends up in
    the branch's traits. Returns ll.
    """
    if 'children' not in JSONnode: ## no children - this is a tip
        branch=leaf()
        branch.numName=branch.name=JSONnode[json_translation['name']] ## numName and name are the same for tips
    else: ## only nodes have children
        branch=node()

    if ll is None: ## first call - start a new tree rooted at this branch
        ll=tree()
        ll.root=branch

    if 'attr' in JSONnode: ## flatten nested attributes into the branch's own dictionary
        JSONnode.update(JSONnode.pop('attr'))

    parent=ll.cur_node
    branch.parent=parent ## set parent-child relationships
    parent.children.append(branch)
    branch.index=JSONnode[json_translation['name']] ## indexing is based on name
    branch.traits={key:value for key,value in JSONnode.items() if key!='children'} ## everything except children becomes a trait
    ll.Objects.append(branch)
    ll.cur_node=branch ## descend into the new branch

    if 'children' in JSONnode:
        for descendant in JSONnode['children']:
            make_treeJSON(descendant,json_translation,ll)
            ll.cur_node=ll.cur_node.parent ## step back up after each child has been added
    return ll
def loadNewick(tree_path,tip_regex=r'\|([0-9]+\-[0-9]+\-[0-9]+)',date_fmt='%Y-%m-%d',variableDate=True,absoluteTime=True,verbose=False):
    """
    Load a tree from a newick file by providing either the path to the file or a file handle.

    Every line containing '(' is passed to make_tree starting from the first '('; if several
    lines qualify the tree parsed from the last one is kept. The tree is then traversed and
    its branches sorted. A file opened here from a path is closed before returning; a handle
    supplied by the caller is left open.

    When absoluteTime is True each tip's name is set to its numName and the date captured by
    the first group of tip_regex is converted with decimalDate using date_fmt and variableDate.
    Returns the tree object.
    Raises AssertionError if no tree string was found.
    """
    ll=None
    opened_here=isinstance(tree_path,str) ## only close what this function opened
    handle=open(tree_path,'r') if opened_here else tree_path
    try:
        for line in handle:
            l=line.strip('\n')
            if '(' in l:
                treeString_start=l.index('(')
                ll=make_tree(l[treeString_start:],verbose=verbose) ## send tree string to make_tree function
                if verbose:
                    print('Identified tree string')
    finally:
        if opened_here:
            handle.close()

    assert ll,'Regular expression failed to find tree string'
    ll.traverse_tree(verbose=verbose) ## traverse tree
    ll.sortBranches() ## traverses tree, sorts branches, draws tree

    if absoluteTime:
        tipDates=[]
        for k in ll.getExternal():
            n=k.numName
            k.name=k.numName
            cerberus=re.search(tip_regex,n)
            if cerberus is not None:
                tipDates.append(decimalDate(cerberus.group(1),fmt=date_fmt,variable=variableDate))
        ## NOTE: tipDates is collected but not applied to the tree - this function does not call setAbsoluteTime
    return ll
def loadNexus(tree_path,tip_regex=r'\|([0-9]+\-[0-9]+\-[0-9]+)',date_fmt='%Y-%m-%d',treestring_regex=r'tree [A-Za-z\_]+([0-9]+)',variableDate=True,absoluteTime=True,verbose=False):
    """
    Load a tree from a nexus file by providing either the path to the file or a file handle.

    Lines matching treestring_regex are passed to make_tree starting from their first '('.
    Lines between a line containing 'translate' and the next line containing ';' are read as
    tip translations (number followed by name) and, if any were found, used to rename tips
    and stored as ll.tipMap. A file opened here from a path is closed before returning; a
    handle supplied by the caller is left open.

    When absoluteTime is True the date captured by the first group of tip_regex in each tip
    name is converted with decimalDate (using date_fmt and variableDate) and the highest
    value is passed to ll.setAbsoluteTime.
    Returns the tree object.
    Raises AssertionError if no tree string was found and ValueError if absoluteTime is
    requested but no tip name matches tip_regex.
    """
    tipFlag=False ## True while inside the translate block
    tips={}
    tipNum=0
    ll=None
    opened_here=isinstance(tree_path,str) ## only close what this function opened
    handle=open(tree_path,'r') if opened_here else tree_path
    try:
        for line in handle:
            l=line.strip('\n')

            cerberus=re.search('dimensions ntax=([0-9]+);',l.lower())
            if cerberus is not None:
                tipNum=int(cerberus.group(1))
                if verbose:
                    print('File should contain %d taxa'%(tipNum))

            cerberus=re.search(treestring_regex,l)
            if cerberus is not None:
                treeString_start=l.index('(')
                ll=make_tree(l[treeString_start:],verbose=verbose) ## send tree string to make_tree function
                if verbose:
                    print('Identified tree string')

            if tipFlag:
                cerberus=re.search(r"([0-9]+) ([A-Za-z\-\_\/\.'0-9 \|?]+)",l)
                if cerberus is not None:
                    tips[cerberus.group(1)]=cerberus.group(2).strip('"').strip("'")
                    if verbose:
                        print('Identified tip translation %s: %s'%(cerberus.group(1),tips[cerberus.group(1)]))
                elif ';' not in l:
                    print('tip not captured by regex:',l.replace('\t',''))

            if 'translate' in l.lower(): ## translate block starts on the next line
                tipFlag=True
            if ';' in l: ## any semicolon ends the translate block
                tipFlag=False
    finally:
        if opened_here:
            handle.close()

    assert ll,'Regular expression failed to find tree string'
    ll.traverse_tree(verbose=verbose) ## traverse tree
    ll.sortBranches() ## traverses tree, sorts branches, draws tree
    if len(tips)>0:
        ll.renameTips(tips) ## renames tips from numbers to actual names
        ll.tipMap=tips
    if absoluteTime:
        tipDates=[]
        for k in ll.getExternal():
            if len(tips)>0:
                n=k.name
            else:
                n=k.numName
            cerberus=re.search(tip_regex,n)
            if cerberus is not None:
                tipDates.append(decimalDate(cerberus.group(1),fmt=date_fmt,variable=variableDate))
        if len(tipDates)==0: ## max() of an empty list fails with an unhelpful message - explain what went wrong instead
            raise ValueError('No tip names matched tip_regex %s - cannot set absolute time (use absoluteTime=False or adjust tip_regex)'%(tip_regex))
        highestTip=max(tipDates)
        ll.setAbsoluteTime(highestTip)
    return ll
def loadJSON(tree_path,json_translation={'name':'strain','absoluteTime':'num_date'},json_meta=None,verbose=False,sort=True,stats=True):
    """
    Load a nextstrain JSON by providing either the path to JSON or a file handle.
    json_translation is a dictionary that translates JSON attributes to baltic branch attributes (e.g. 'absoluteTime' is called 'num_date' in nextstrain JSONs).
    Note that to avoid conflicts in setting node heights you can either define the absolute time of each node or branch lengths (e.g. if you want a substitution tree).

    json_meta, if given, is a dictionary with a 'file' entry (path or file handle of a metadata JSON)
    and a 'traitName' entry; the colour map stored under metadata['color_options'][traitName]['color_map']
    is attached to the tree as ll.cmap.
    sort and stats control whether sortBranches and treeStats are called on the tree.
    Returns the tree object.
    Raises AssertionError if json_translation lacks 'name', or does not contain exactly one of 'absoluteTime' and 'length'.
    """
    assert 'name' in json_translation and ('absoluteTime' in json_translation or 'length' in json_translation),'JSON translation dictionary missing entries: %s'%(', '.join([entry for entry in ['name','height','absoluteTime','length'] if (entry in json_translation)==False]))
    if verbose:
        print('Reading JSON')

    if isinstance(tree_path,str):
        with open(tree_path) as json_data:
            d = json.load(json_data)
            ll=make_treeJSON(d,json_translation,verbose=verbose)
    else:
        ll=make_treeJSON(json.load(tree_path),json_translation,verbose=verbose)

    assert ('absoluteTime' in json_translation and 'length' not in json_translation) or ('absoluteTime' not in json_translation and 'length' in json_translation),'Cannot use both absolute time and branch length, include only one in json_translation dictionary.'

    for attr in json_translation: ## iterate through attributes in json_translation
        for k in ll.Objects: ## for every branch
            setattr(k,attr,k.traits[json_translation[attr]]) ## set attribute value for branch

    if 'absoluteTime' in json_translation: ## if using absoluteTime need to set branch lengths for traversals
        time_key=json_translation['absoluteTime']
        for k in ll.Objects:
            if time_key in k.parent.traits:
                k.length=k.traits[time_key]-k.parent.traits[time_key] ## branch length is the time elapsed since the parent
            else:
                k.length=0.0 ## parent carries no time - zero-length branch

    if verbose:
        print('Traversing and drawing tree')

    ll.traverse_tree(verbose=verbose)
    ll.drawTree()
    if stats:
        ll.treeStats() ## initial traversal, checks for stats
    if sort:
        ll.sortBranches() ## traverses tree, sorts branches, draws tree

    if json_meta:
        meta_source=json_meta['file']
        if isinstance(meta_source,str): ## the 'file' entry (not json_meta itself) decides between path and handle
            with open(meta_source,'r') as meta_handle: ## opened here, so closed here
                metadata=json.load(meta_handle)
        else:
            metadata=json.load(meta_source)
        cmap=dict(metadata['color_options'][json_meta['traitName']]['color_map'])
        setattr(ll,'cmap',cmap)
    return ll
if __name__ == '__main__':
    ## Command line use: the first argument is a tree string, the height of the parsed tree is written to stdout.
    import sys
    ll=make_tree(sys.argv[1]) ## no tree object exists yet at this point - let make_tree create one
    ll.traverse_tree()
    sys.stdout.write('%s\n'%(ll.treeHeight))