import re import copy import math import datetime as dt import json def decimalDate(date,fmt="%Y-%m-%d",variable=False,dateSplitter='-'): """ Converts calendar dates in specified format to decimal date. """ if variable==True: ## if date is variable - extract what is available dateL=len(date.split(dateSplitter)) if dateL==2: fmt=dateSplitter.join(fmt.split(dateSplitter)[:-1]) elif dateL==1: fmt=dateSplitter.join(fmt.split(dateSplitter)[:-2]) adatetime=dt.datetime.strptime(date,fmt) ## convert to datetime object year = adatetime.year ## get year boy = dt.datetime(year, 1, 1) ## get beginning of the year eoy = dt.datetime(year + 1, 1, 1) ## get beginning of next year return year + ((adatetime - boy).total_seconds() / ((eoy - boy).total_seconds())) ## return fractional year def convertDate(x,start,end): """ Converts calendar dates between given formats """ return dt.datetime.strftime(dt.datetime.strptime(x,start),end) class clade: ## clade class def __init__(self,givenName): self.branchType='leaf' ## clade class poses as a leaf self.subtree=None ## subtree will contain all the branches that were collapsed self.leaves=None self.length=0.0 self.height=None self.absoluteTime=None self.parent=None self.traits={} self.index=None self.name=givenName ## the pretend tip name for the clade self.numName=givenName self.x=None self.y=None self.lastHeight=None ## refers to the height of the highest tip in the collapsed clade self.lastAbsoluteTime=None ## refers to the absolute time of the highest tip in the collapsed clade self.width=1 class node: ## node class def __init__(self): self.branchType='node' self.length=0.0 ## branch length, recovered from string self.height=None ## height, set by traversing the tree, which adds up branch lengths along the way self.absoluteTime=None ## branch end point in absolute time, once calibrations are done self.parent=None ## reference to parent node of the node self.children=[] ## a list of descendent branches of this node self.traits={} ## dictionary that will contain annotations from the tree string, e.g. {'posterior':1.0} self.index=None ## index of the character designating this object in the tree string, it's a unique identifier for every object in the tree self.childHeight=None ## the youngest descendant tip of this node self.x=None ## X and Y coordinates of this node, once drawTree() is called self.y=None ## contains references to all tips of this node self.leaves=set() ## is a set of tips that are descended from it class leaf: ## leaf class def __init__(self): self.branchType='leaf' self.name=None ## name of tip after translation, since BEAST trees will generally have numbers for taxa but will provide a map at the beginning of the file self.numName=None ## the original name of the taxon, would be an integer if coming from BEAST, otherwise can be actual name self.index=None ## index of the character that defines this object, will be a unique ID for each object in the tree self.length=None ## branch length self.absoluteTime=None ## position of tip in absolute time self.height=None ## height of tip self.parent=None ## parent self.traits={} ## trait dictionary self.x=None ## position of tip on x axis if the tip were to be plotted self.y=None ## position of tip on y axis if the tip were to be plotted class tree: ## tree class def __init__(self): self.cur_node=node() ## current node is a new instance of a node class self.cur_node.index='Root' ## first object in the tree is the root to which the rest gets attached self.cur_node.length=0.0 ## startind node branch length is 0 self.cur_node.height=0.0 ## starting node height is 0 self.root=None #self.cur_node ## root of the tree is current node self.Objects=[] ## tree objects have a flat list of all branches in them self.tipMap=None self.treeHeight=0 ## tree height is the distance between the root and the most recent tip self.ySpan=0.0 def add_node(self,i): """ Attaches a new node to current node. """ new_node=node() ## new node instance new_node.index=i ## new node's index is the position along the tree string if self.root is None: self.root=new_node new_node.parent=self.cur_node ## new node's parent is current node self.cur_node.children.append(new_node) ## new node is a child of current node self.cur_node=new_node ## current node is now new node self.Objects.append(self.cur_node) ## add new node to list of objects in the tree def add_leaf(self,i,name): """ Attach a new leaf (tip) to current node. """ new_leaf=leaf() ## new instance of leaf object new_leaf.index=i ## index is position along tree string if self.root is None: self.root=new_leaf new_leaf.parent=self.cur_node ## leaf's parent is current node self.cur_node.children.append(new_leaf) ## assign leaf to parent's children new_leaf.numName=name ## numName is the name tip has inside tree string, BEAST trees usually have numbers for tip names self.cur_node=new_leaf ## current node is now new leaf self.Objects.append(self.cur_node) ## add leaf to all objects in the tree def subtree(self,k=None,traverse_condition=None): """ Generate a subtree (as a baltic tree object) from a traversal. k is the starting branch for traversal (default: root). traverse_condition is a function that determines whether a child branch should be visited (default: always true). Returns a new baltic tree instance. Note - custom traversal functions can result in multitype trees. If this is undesired call singleType() on the resulting subtree afterwards. """ subtree=copy.deepcopy(self.traverse_tree(k,include_condition=lambda k:True,traverse_condition=traverse_condition)) if subtree is None or len([k for k in subtree if k.branchType=='leaf'])==0: return None else: local_tree=tree() ## create a new tree object where the subtree will be local_tree.Objects=subtree ## assign branches to new tree object local_tree.root=subtree[0] ## connect tree object's root with subtree subtree_set=set(subtree) ## turn branches into set for quicker look up later if traverse_condition is not None: ## didn't use default traverse condition, might need to deal with hanging nodes and prune children for nd in local_tree.getInternal(): ## iterate over nodes nd.children=list(filter(lambda k:k in subtree_set,nd.children)) ## only keep children seen in traversal local_tree.fixHangingNodes() return local_tree def singleType(self): """ Removes any branches with a single child (multitype nodes). """ multiTypeNodes=[k for k in self.Objects if k.branchType=='node' and len(k.children)==1] while len(multiTypeNodes)>0: multiTypeNodes=[k for k in self.Objects if k.branchType=='node' and len(k.children)==1] for k in sorted(multiTypeNodes,key=lambda x:-x.height): child=k.children[0] ## fetch child grandparent=k.parent ## fetch grandparent child.parent=grandparent ## child's parent is now grandparent grandparent.children.append(child) ## add child to grandparent's children grandparent.children.remove(k) ## remove old parent from grandparent's children grandparent.children=list(set(grandparent.children)) child.length+=k.length ## adjust child length multiTypeNodes.remove(k) ## remove old parent from multitype nodes self.Objects.remove(k) ## remove old parent from all objects self.sortBranches() def setAbsoluteTime(self,date): """ place all objects in absolute time by providing the date of the most recent tip """ for i in self.Objects: ## iterate over all objects i.absoluteTime=date-self.treeHeight+i.height ## heights are in units of time from the root def treeStats(self): """ provide information about the tree """ self.traverse_tree() ## traverse the tree obs=self.Objects ## convenient list of all objects in the tree print('\nTree height: %.6f\nTree length: %.6f'%(self.treeHeight,sum([x.length for x in obs]))) ## report the height and length of tree nodes=self.getInternal() ## get all nodes strictlyBifurcating=False ## assume tree is not strictly bifurcating multiType=False singleton=False N_children=[len(x.children) for x in nodes] if len(N_children)==0: singleton=True else: minChildren,maxChildren=min(N_children),max(N_children) ## get the largest number of descendant branches of any node if maxChildren==2 and minChildren==2: ## if every node has at most two children branches strictlyBifurcating=True ## it's strictly bifurcating if minChildren==1: multiType=True hasTraits=False ## assume tree has no annotations maxAnnotations=max([len(x.traits) for x in obs]) ## check the largest number of annotations any branch has if maxAnnotations>0: ## if it's more than 0 hasTraits=True ## there are annotations if strictlyBifurcating: print('strictly bifurcating tree') ## report if multiType: print('multitype tree') ## report if singleton: print('singleton tree') if hasTraits: print('annotations present') ## report print('\nNumbers of objects in tree: %d (%d nodes and %d leaves)\n'%(len(obs),len(nodes),len(obs)-len(nodes))) ## report numbers of different objects in the tree def traverse_tree(self,cur_node=None,include_condition=lambda k:k.branchType=='leaf',traverse_condition=lambda k:True,collect=None,verbose=False): if cur_node==None: ## if no starting point defined - start from root for k in self.Objects: ## reset various parameters if k.branchType=='node': k.leaves=set() k.childHeight=None k.height=None if verbose==True: print('Initiated traversal from root') cur_node=self.root#.children[-1] if collect==None: ## initiate collect list if not initiated collect=[] if cur_node.parent and cur_node.height==None: ## cur_node has a parent - set height if it doesn't already cur_node.height=cur_node.length+cur_node.parent.height elif cur_node.height==None: ## cur_node does not have a parent (root), if height not set before it's zero cur_node.height=0.0 if verbose==True: print('at %s (%s)'%(cur_node.index,cur_node.branchType)) if include_condition(cur_node): ## test if interested in cur_node collect.append(cur_node) ## add to collect list for reporting later if cur_node.branchType=='leaf': ## cur_node is a tip cur_node.parent.leaves.add(cur_node.numName) ## add to parent's list of tips elif cur_node.branchType=='node': ## cur_node is node for child in filter(traverse_condition,cur_node.children): ## only traverse through children we're interested if verbose==True: print('visiting child %s'%(child.index)) self.traverse_tree(cur_node=child,include_condition=include_condition,traverse_condition=traverse_condition,verbose=verbose,collect=collect) ## recurse through children if verbose==True: print('child %s done'%(child.index)) assert len(cur_node.children)>0, 'Tried traversing through hanging node without children. Index: %s'%(cur_node.index) cur_node.childHeight=max([child.childHeight if child.branchType=='node' else child.height for child in cur_node.children]) if cur_node.parent: cur_node.parent.leaves=cur_node.parent.leaves.union(cur_node.leaves) ## pass tips seen during traversal to parent self.treeHeight=cur_node.childHeight ## it's the highest child of the starting node return collect def renameTips(self,d=None): """ Give each tip its correct label using a dictionary. """ if d==None and self.tipMap!=None: d=self.tipMap for k in self.getExternal(): ## iterate through leaf objects in tree k.name=d[k.numName] ## change its name def sortBranches(self,descending=True): """ Sort descendants of each node. """ if descending==True: modifier=-1 ## define the modifier for sorting function later elif descending==False: modifier=1 for k in self.getInternal(): ## iterate over nodes ## split node's offspring into nodes and leaves, sort each list individually nodes=sorted([x for x in k.children if x.branchType=='node'],key=lambda q:(-len(q.leaves)*modifier,q.length*modifier)) leaves=sorted([x for x in k.children if x.branchType=='leaf'],key=lambda q:q.length*modifier) if modifier==1: ## if sorting one way - nodes come first, leaves later k.children=nodes+leaves elif modifier==-1: ## otherwise sort the other way k.children=leaves+nodes self.drawTree() ## update x and y positions of each branch, since y positions will have changed because of sorting def drawTree(self,order=None,verbose=False): """ Find x and y coordinates of each branch. """ if order==None: order=self.traverse_tree() ## order is a list of tips recovered from a tree traversal to make sure they're plotted in the correct order along the vertical tree dimension if verbose==True: print('Drawing tree in pre-order') else: if verbose==True: print('Drawing tree with provided order') name_order=[x.numName for x in order] skips=[1 if isinstance(x,leaf) else x.width+1 for x in order] for k in self.Objects: ## reset coordinates for all objects k.x=None k.y=None storePlotted=0 drawn={} ## drawn keeps track of what's been drawn while len(drawn)!=len(self.Objects): # keep drawing the tree until everything is drawn if verbose==True: print('Drawing iteration %d'%(len(drawn))) for k in filter(lambda w:w.index not in drawn,self.Objects): ## iterate through objects that have not been drawn if k.branchType=='leaf': ## if leaf - get position of leaf, draw branch connecting tip to parent node if verbose==True: print('Setting leaf %s y coordinate to'%(k.index)) x=k.height ## x position is height y_idx=name_order.index(k.numName) ## y position of leaf is given by the order in which tips were visited during the traversal y=sum(skips[y_idx:]) ## sum across skips to find y position if verbose==True: print('%s'%(y)) if isinstance(k,clade) and skips[y_idx]>1: ## if dealing with collapsed clade - adjust y position to be in the middle of the skip y-=skips[y_idx]/2.0 if verbose==True: print('adjusting clade y position to %s'%(y)) k.x=x ## set x and y coordinates k.y=y drawn[k.index]=None ## remember that this objects has been drawn if hasattr(k.parent,'yRange')==False: ## if parent doesn't have a maximum extent of its children's y coordinates setattr(k.parent,'yRange',[k.y,k.y]) ## assign it if k.branchType=='node': ## if parent is non-root node and y positions of all its children are known if len([q.y for q in k.children if q.y!=None])==len(k.children): if verbose==True: print('Setting node %s coordinates'%(k.index)) x=k.height ## x position is height children_y_coords=[q.y for q in k.children if q.y!=None] ## get all existing y coordinates of the node y=sum(children_y_coords)/float(len(children_y_coords)) ## internal branch is in the middle of the vertical bar k.x=x k.y=y drawn[k.index]=None ## remember that this objects has been drawn minYrange=min([min(child.yRange) if child.branchType=='node' else child.y for child in k.children]) ## get lowest y coordinate across children maxYrange=max([max(child.yRange) if child.branchType=='node' else child.y for child in k.children]) ## get highest y coordinate across children setattr(k,'yRange',[minYrange,maxYrange]) ## assign the maximum extent of children's y coordinates assert len(drawn)>storePlotted,'Got stuck trying to find y positions of objects' storePlotted=len(drawn) self.ySpan=sum(skips) def drawUnrooted(self,n=None,total=None): """ Calculate x and y coordinates in an unrooted arrangement. Code translated from https://github.com/nextstrain/auspice/commit/fc50bbf5e1d09908be2209450c6c3264f298e98c, written by Richard Neher. """ if n==None: total=sum([1 if isinstance(x,leaf) else x.width+1 for x in self.getExternal()]) n=self.root#.children[0] for k in self.Objects: k.traits['tau']=0.0 k.x=0.0 k.y=0.0 if n.branchType=='leaf': w=2*math.pi*1.0/float(total) else: w=2*math.pi*len(n.leaves)/float(total) if n.parent.x==None: n.parent.x=0.0 n.parent.y=0.0 n.x = n.parent.x + n.length * math.cos(n.traits['tau'] + w*0.5) n.y = n.parent.y + n.length * math.sin(n.traits['tau'] + w*0.5) eta=n.traits['tau'] if n.branchType=='node': for ch in n.children: if ch.branchType=='leaf': w=2*math.pi*1.0/float(total) else: w=2*math.pi*len(ch.leaves)/float(total) ch.traits['tau'] = eta eta += w self.drawUnrooted(ch,total) def commonAncestor(self,descendants,numName=False,strict=False): types=[desc.__class__ for desc in descendants] assert len(set(types))==1,'More than one type of data detected in descendants list' if numName==False: assert sum([1 if k in [w.name for w in self.getExternal()] else 0 for k in descendants])==len(descendants),'Not all specified descendants are in tree: %s'%(descendants) else: assert sum([1 if k in [w.numName for w in self.getExternal()] else 0 for k in descendants])==len(descendants),'Not all specified descendants are in tree: %s'%(descendants) dtype=list(set(types))[0] allAncestors=sorted([k for k in self.Objects if (k.branchType=='node' or isinstance(k,clade)) and len(k.leaves)>=len(descendants)],key=lambda x:x.height) if numName==False: ancestor=[k for k in allAncestors if sum([[self.tipMap[w] for w in k.leaves].count(l) for l in descendants])==len(descendants)][-1] else: ancestor=[k for k in allAncestors if sum([[w for w in k.leaves].count(l) for l in descendants])==len(descendants)][-1] if strict==False: return ancestor elif strict==True and len(ancestor.leaves)==len(descendants): return ancestor elif strict==True and len(ancestor.leaves)>len(descendants): return None def collapseSubtree(self,cl,givenName,verbose=False,widthFunction=lambda k:len(k.leaves)): """ Collapse an entire subtree into a clade object. """ assert cl.branchType=='node','Cannot collapse non-node class' collapsedClade=clade(givenName) collapsedClade.index=cl.index collapsedClade.leaves=cl.leaves collapsedClade.length=cl.length collapsedClade.height=cl.height collapsedClade.parent=cl.parent collapsedClade.absoluteTime=cl.absoluteTime collapsedClade.traits=cl.traits collapsedClade.width=widthFunction(cl) if verbose==True: print('Replacing node %s (parent %s) with a clade class'%(cl.index,cl.parent.index)) parent=cl.parent remove_from_tree=self.traverse_tree(cl,include_condition=lambda k: True) collapsedClade.subtree=remove_from_tree assert len(remove_from_tree)0: clades=[k for k in self.Objects if isinstance(k,clade)] for cl in clades: parent=cl.parent subtree=cl.subtree parent.children.remove(cl) parent.children.append(subtree[0]) self.Objects+=subtree self.Objects.remove(cl) if self.tipMap!=None: self.tipMap.pop(cl.name,None) self.traverse_tree() def collapseBranches(self,collapseIf=lambda x:x.traits['posterior']<=0.5,designated_nodes=[],verbose=False): """ Collapse all branches that satisfy a function collapseIf (default is an anonymous function that returns true if posterior probability is <=0.5). Alternatively, a list of nodes can be supplied to the script. Returns a deep copied version of the tree. """ newTree=copy.deepcopy(self) ## work on a copy of the tree if len(designated_nodes)==0: ## no nodes were designated for deletion - relying on anonymous function to collapse nodes nodes_to_delete=list(filter(lambda n: n.branchType=='node' and collapseIf(n)==True and n!=newTree.root, newTree.Objects)) ## fetch a list of all nodes who are not the root and who satisfy the condition else: assert [w.branchType for w in designated_nodes].count('node')==len(designated_nodes),'Non-node class detected in list of nodes designated for deletion' assert len([w for w in designated_nodes if w!=newTree.root])==0,'Root node was designated for deletion' nodes_to_delete=list(filter(lambda w: w.index in [q.index for q in designated_nodes], newTree.Objects)) ## need to look up nodes designated for deletion by their indices, since the tree has been copied and nodes will have new memory addresses if verbose==True: print('%s nodes set for collapsing: %s'%(len(nodes_to_delete),[w.index for w in nodes_to_delete])) # assert len(nodes_to_delete)0: ## as long as there are branches to be collapsed - keep reducing the tree if verbose==True: print('Continuing collapse cycle, %s nodes left'%(len(nodes_to_delete))) for k in sorted(nodes_to_delete,key=lambda x:-x.height): ## start with branches near the tips zero_node=k.children ## fetch the node's children k.parent.children+=zero_node ## add them to the zero node's parent old_parent=k ## node to be deleted is the old parent new_parent=k.parent ## once node is deleted, the parent to all their children will be the parent of the deleted node if new_parent==None: new_parent=self.root if verbose==True: print('Removing node %s, attaching children %s to node %s'%(old_parent.index,[w.index for w in k.children],new_parent.index)) for w in newTree.Objects: ## assign the parent of deleted node as the parent to any children of deleted node if w.parent==old_parent: w.parent=new_parent w.length+=old_parent.length if verbose==True: print('Fixing branch length for node %s'%(w.index)) k.parent.children.remove(k) ## remove traces of deleted node - it doesn't exist as a child, doesn't exist in the tree and doesn't exist in the nodes list newTree.Objects.remove(k) nodes_to_delete.remove(k) ## in fact, the node never existed if len(designated_nodes)==0: nodes_to_delete==list(filter(lambda n: n.branchType=='node' and collapseIf(n)==True and n!=newTree.root, newTree.Objects)) else: assert [w.branchType for w in designated_nodes].count('node')==len(designated_nodes),'Non-node class detected in list of nodes designated for deletion' assert len([w for w in designated_nodes if w!=newTree.root])==0,'Root node was designated for deletion' nodes_to_delete=[w for w in newTree.Objects if w.index in [q.index for q in designated_nodes]] if verbose==True: print('Removing references to node %s'%(k.index)) newTree.sortBranches() ## sort the tree to traverse, draw and sort tree to adjust y coordinates return newTree ## return collapsed tree def toString(self,cur_node=None,traits=None,numName=False,verbose=False,nexus=False,string_fragment=None,traverse_condition=None,json=False): """ Output the topology of the tree with branch lengths and comments to stringself. cur_node: starting point (default: None, starts at root) traits: list of keys that will be used to output entries in traits dict of each branch (default: all traits) numName: boolean, whether encoded (True) or decoded (default: False) tip names will be output verbose: boolean, debug nexus: boolean, whether to output newick (default: False) or nexus (True) formatted tree string_fragment: list of characters that comprise the tree string """ if cur_node==None: cur_node=self.root#.children[-1] if traits==None: ## if None traits=set(sum([list(k.traits.keys()) for k in self.Objects],[])) ## fetch all trait keys if string_fragment==None: string_fragment=[] if nexus==True: assert json==False,'Nexus format not a valid option for JSON output' if verbose==True: print('Exporting to Nexus format') string_fragment.append('#NEXUS\nBegin trees;\ntree TREE1 = [&R] ') if traverse_condition==None: traverse_condition=lambda k: True comment=[] ## will hold comment if len(traits)>0: ## non-empty list of traits to output for tr in traits: ## iterate through keys if tr in cur_node.traits: ## if key is available if verbose==True: print('trait %s available for %s (%s) type: %s'%(tr,cur_node.index,cur_node.branchType,type(cur_node.traits[tr]))) if isinstance(cur_node.traits[tr],str): ## string value comment.append('%s="%s"'%(tr,cur_node.traits[tr])) if verbose==True: print('adding string comment %s'%(comment[-1])) elif isinstance(cur_node.traits[tr],float) or isinstance(cur_node.traits[tr],int): ## float or integer comment.append('%s=%s'%(tr,cur_node.traits[tr])) if verbose==True: print('adding numeric comment %s'%(comment[-1])) elif isinstance(cur_node.traits[tr],list): ## lists rangeComment=[] for val in cur_node.traits[tr]: if isinstance(val,str): ## string rangeComment.append('"%s"'%(val)) elif isinstance(val,float) or isinstance(val,int): ## float or integer rangeComment.append('%s'%(val)) comment.append('%s={%s}'%(tr,','.join(rangeComment))) if verbose==True: print('adding range comment %s'%(comment[-1])) elif verbose==True: print('trait %s unavailable for %s (%s)'%(tr,cur_node.index,cur_node.branchType)) if cur_node.branchType=='node': if verbose==True: print('node: %s'%(cur_node.index)) string_fragment.append('(') traverseChildren=list(filter(traverse_condition,cur_node.children)) assert len(traverseChildren)>0,'Node %s does not have traversable children'%(cur_node.index) for c,child in enumerate(traverseChildren): ## iterate through children of node if they satisfy traverse condition if verbose==True: print('moving to child %s of node %s'%(child.index,cur_node.index)) self.toString(cur_node=child,traits=traits,numName=numName,verbose=verbose,nexus=nexus,string_fragment=string_fragment,traverse_condition=traverse_condition) if (c+1)0: if verbose==True: print('adding comment to %s'%(cur_node.index)) comment=','.join(comment) comment='[&'+comment+']' string_fragment.append('%s'%(comment)) ## end of node, add annotations if verbose==True: print('adding branch length to %s'%(cur_node.index)) string_fragment.append(':%8f'%(cur_node.length)) ## end of node, add branch length if cur_node==self.root:#.children[-1]: string_fragment.append(';') if nexus==True: string_fragment.append('\nEnd;') if verbose==True: print('finished') return ''.join(string_fragment) def allTMRCAs(self,numName=True): if numName==False: assert len(self.tipMap)>0,'Tree does not have a translation dict for tip names' tip_names=[self.tipMap[k.numName] for k in self.Objects if isinstance(k,leaf)] else: tip_names=[k.numName for k in self.Objects if isinstance(k,leaf)] tmrcaMatrix={x:{y:None if x!=y else 0.0 for y in tip_names} for x in tip_names} ## pairwise matrix of tips for k in self.getInternal(): if numName==True: all_children=list(k.leaves) ## fetch all descendant tips of node else: all_children=[self.tipMap[lv] for lv in k.leaves] for x in range(0,len(all_children)-1): ## for all pairwise comparisons of tips for y in range(x+1,len(all_children)): tipA=all_children[x] tipB=all_children[y] if tmrcaMatrix[tipA][tipB]==None or tmrcaMatrix[tipA][tipB]<=k.absoluteTime: ## if node's time is more recent than previous entry - set new TMRCA value for pair of tips tmrcaMatrix[tipA][tipB]=k.absoluteTime tmrcaMatrix[tipB][tipA]=k.absoluteTime return tmrcaMatrix def reduceTree(self,keep,verbose=False): """ Reduce the tree to just those tracking a small number of tips. Returns a new baltic tree object. """ assert len(keep)>0,"No tips given to reduce the tree to." assert len([k for k in keep if k.branchType!='leaf'])==0, "Embedding contains %d non-leaf branches."%(len([k for k in keep if k.branchType!='leaf'])) if verbose==True: print("Preparing branch hash for keeping %d branches"%(len(keep))) branch_hash={k.index:k for k in keep} embedding=[] if verbose==True: print("Deep copying tree") reduced_tree=copy.deepcopy(self) ## new tree object for k in reduced_tree.Objects: ## deep copy branches from current tree if k.index in branch_hash: ## if branch is designated as one to keep cur_b=k if verbose==True: print("Traversing to root from %s"%(cur_b.index)) while cur_b!=reduced_tree.root: ## descend to root if verbose==True: print("at %s root: %s"%(cur_b.index,cur_b==reduced_tree.root)) embedding.append(cur_b) ## keep track of the path to root cur_b=cur_b.parent embedding.append(reduced_tree.root) ## add root to embedding if verbose==True: print("Finished extracting embedding") embedding=set(embedding) ## prune down to only unique branches reduced_tree.Objects=sorted(list(embedding),key=lambda x:x.height) ## assign branches that are kept to new tree's Objects if verbose==True: print("Pruning untraversed lineages") for k in reduced_tree.getInternal(): ## iterate through reduced tree k.children = [c for c in k.children if c in embedding] ## only keep children that are present in lineage traceback reduced_tree.root.children=[c for c in reduced_tree.root.children if c in embedding] ## do the same for root reduced_tree.fixHangingNodes() if verbose==True: print("Last traversal and branch sorting") reduced_tree.traverse_tree() ## traverse reduced_tree.sortBranches() ## sort return reduced_tree ## return new tree def countLineages(self,t,condition=lambda x:True): return len([k for k in self.Objects if k.parent.absoluteTime0: for h in sorted(hangingNodes,key=lambda x:-x.height): h.parent.children.remove(h) ## remove old parent from grandparent's children hangingNodes.remove(h) ## remove old parent from multitype nodes self.Objects.remove(h) ## remove old parent from all objects hangingNodes=list(filter(hangingCondition,self.Objects)) ## regenerate list def addText(self,ax,target=lambda k:k.branchType=='leaf',position=lambda k:(k.x*1.01,k.y),text=lambda k:k.numName,zorder_function=lambda k: 101,**kwargs): for k in filter(target,self.Objects): x,y=position(k) z=zorder_function(k) ax.text(x,y,text(k),zorder=z,**kwargs) return ax def plotPoints(self,ax,x_attr=lambda k:k.height,y_attr=lambda k:k.y,target=lambda k:k.branchType=='leaf',size_function=lambda k:40,colour_function=lambda k:'k',zorder_function=lambda k: 100,**kwargs): for k in filter(target,self.Objects): y=y_attr(k) ## get y coordinates x=x_attr(k) ## x coordinate c=colour_function(k) size=size_function(k) z=zorder_function(k) ax.scatter(x,y,s=size,facecolor=c,edgecolor='none',zorder=z,**kwargs) ## put a circle at each tip return ax def plotTree(self,ax,type='rectangular',target=lambda k: True,x_attr=lambda k:k.height,y_attr=lambda k:k.y,branchWidth=lambda k:2,colour_function=lambda f:'k',zorder_function=lambda k: 98,**kwargs): assert type in ['rectangular','unrooted'],'Unrecognised drawing type "%s"'%(type) for k in filter(target,self.Objects): ## iterate over branches in the tree y=y_attr(k) ## get y coordinates x=x_attr(k) ## x coordinate xp=x_attr(k.parent) ## get parent's x if xp==None: xp=x c=colour_function(k) b=branchWidth(k) z=zorder_function(k) if type=='rectangular': if k.branchType=='node': ## if node... yl=y_attr(k.children[0]) ## get y coordinates of first and last child yr=y_attr(k.children[-1]) ax.plot([x,x],[yl,yr],color=c,lw=b,zorder=z,**kwargs) ## plot vertical bar connecting node to both its offspring ax.plot([x,xp],[y,y],color=c,lw=b,zorder=z,**kwargs) ## plot horizontal branch to parent elif type=='unrooted': yp=y_attr(k.parent) ax.plot([x,xp],[y,yp],color=c,lw=b,zorder=z,**kwargs) return ax def make_tree(data,ll=None,verbose=False): """ data is a tree string, ll (LL) is an instance of a tree object """ if isinstance(data,str)==False: ## tree string is not an instance of string (could be unicode) - convert data=str(data) if ll==None: ## calling without providing a tree object - create one ll=tree() i=0 ## is an adjustable index along the tree string, it is incremented to advance through the string stored_i=None ## store the i at the end of the loop, to make sure we haven't gotten stuck somewhere in an infinite loop while i < len(data): ## while there's characters left in the tree string - loop away if stored_i == i and verbose==True: print('%d >%s<'%(i,data[i])) assert (stored_i != i),'\nTree string unparseable\nStopped at >>%s<<\nstring region looks like this: %s'%(data[i],data[i:i+5000]) ## make sure that you've actually parsed something last time, if not - there's something unexpected in the tree string stored_i=i ## store i for later if data[i] == '(': ## look for new nodes if verbose==True: print('%d adding node'%(i)) ll.add_node(i) ## add node to current node in tree ll i+=1 ## advance in tree string by one character cerberus=re.match('(\(|,)([0-9]+)(\[|\:)',data[i-1:i+100]) ## look for tips in BEAST format (integers). if cerberus is not None: if verbose==True: print('%d adding leaf (BEAST) %s'%(i,cerberus.group(2))) ll.add_leaf(i,cerberus.group(2)) ## add tip i+=len(cerberus.group(2)) ## advance in tree string by however many characters the tip is encoded cerberus=re.match('(\(|,)(\'|\")*([A-Za-z\_\-\|\.0-9\?\/ ]+)(\'|\"|)(\[)*',data[i-1:i+200]) ## look for tips with unencoded names - if the tips have some unusual format you'll have to modify this if cerberus is not None: if verbose==True: print('%d adding leaf (non-BEAST) %s'%(i,cerberus.group(3))) ll.add_leaf(i,cerberus.group(3).strip('"').strip("'")) ## add tip i+=len(cerberus.group(3))+cerberus.group().count("'")+cerberus.group().count('"') ## advance in tree string by however many characters the tip is encoded cerberus=re.match('\)([0-9]+)\[',data[i-1:i+100]) ## look for multitype tree singletons. if cerberus is not None: if verbose==True: print('%d adding multitype node %s'%(i,cerberus.group(1))) i+=len(cerberus.group(1)) cerberus=re.match('(\:)*\[(&[A-Za-z\_\-{}\,0-9\.\%=\"\'\+!# :\/\(\)\&]+)\]',data[i:])## look for MCC comments if cerberus is not None: if verbose==True: print('%d comment: %s'%(i,cerberus.group(2))) comment=cerberus.group(2) numerics=re.findall('[,&][A-Za-z\_\.0-9]+=[0-9\-Ee\.]+',comment) ## find all entries that have values as floats strings=re.findall('[,&][A-Za-z\_\.0-9]+=["|\']*[A-Za-z\_0-9\.\+ :\/\(\)\&]+["|\']*',comment) ## strings treelist=re.findall('[,&][A-Za-z\_\.0-9]+={[A-Za-z\_,{}0-9\. :\/\(\)\&]+}',comment) ## complete history logged robust counting (MCMC trees) sets=re.findall('[,&][A-Za-z\_\.0-9\%]+={[A-Za-z\.\-0-9eE,\"\_ :\/\(\)\&]+}',comment) ## sets and ranges figtree=re.findall('\![A-Za-z]+=[A-Za-z0-9# :\/\(\)\&]+',comment) for vals in strings: tr,val=vals.split('=') tr=tr[1:] if '+' in val: val=val.split('+')[0] ## DO NOT ALLOW EQUIPROBABLE DOUBLE ANNOTATIONS (which are in format "A+B") - just get the first one ll.cur_node.traits[tr]=val.strip('"') for vals in numerics: ## assign all parsed annotations to traits of current branch tr,val=vals.split('=') ## split each value by =, left side is name, right side is value tr=tr[1:] ll.cur_node.traits[tr]=float(val) for val in treelist: tr,val=val.split('=') tr=tr[1:] microcerberus=re.findall('{([0-9]+,[0-9\.\-e]+,[A-Z]+,[A-Z]+)}',val) ll.cur_node.traits[tr]=[] for val in microcerberus: codon,timing,start,end=val.split(',') ll.cur_node.traits[tr].append((int(codon),float(timing),start,end)) for vals in sets: tr,val=vals.split('=') tr=tr[1:] if 'set' in tr: ll.cur_node.traits[tr]=[] for v in val[1:-1].split(','): if 'set.prob' in tr: ll.cur_node.traits[tr].append(float(v)) else: ll.cur_node.traits[tr].append(v.strip('"')) else: try: ll.cur_node.traits[tr]=list(map(float,val[1:-1].split(','))) except: print('some other trait: %s'%(vals)) if len(figtree)>0: print('FigTree comment found, ignoring') i+=len(cerberus.group()) ## advance in tree string by however many characters it took to encode labels cerberus=re.match('([A-Za-z\_\-0-9\.]+)(\:|\;)',data[i:])## look for old school node labels if cerberus is not None: if verbose==True: print('old school comment found: %s'%(cerberus.group(1))) ll.cur_node.traits['label']=cerberus.group(1) i+=len(cerberus.group(1)) microcerberus=re.match('(\:)*([0-9\.\-Ee]+)',data[i:i+100]) ## look for branch lengths without comments if microcerberus is not None: if verbose==True: print('adding branch length (%d) %.6f'%(i,float(microcerberus.group(2)))) ll.cur_node.length=float(microcerberus.group(2)) ## set branch length of current node i+=len(microcerberus.group()) ## advance in tree string by however many characters it took to encode branch length if data[i] == ',' or data[i] == ')': ## look for bifurcations or clade ends i+=1 ## advance in tree string ll.cur_node=ll.cur_node.parent if data[i] == ';': ## look for string end return ll break ## end loop def make_treeJSON(JSONnode,json_translation,ll=None,verbose=False): if 'children' in JSONnode: ## only nodes have children new_node=node() else: new_node=leaf() new_node.numName=JSONnode[json_translation['name']] ## set leaf numName new_node.name=JSONnode[json_translation['name']] ## set leaf name to be the same if ll is None: ll=tree() ll.root=new_node if 'attr' in JSONnode: attr = JSONnode.pop('attr') JSONnode.update(attr) new_node.parent=ll.cur_node ## set parent-child relationships ll.cur_node.children.append(new_node) new_node.index=JSONnode[json_translation['name']] ## indexing is based on name new_node.traits={n:JSONnode[n] for n in list(JSONnode.keys()) if n!='children'} ## set traits to non-children attributes ll.Objects.append(new_node) ll.cur_node=new_node if 'children' in JSONnode: for child in JSONnode['children']: make_treeJSON(child,json_translation,ll) ll.cur_node=ll.cur_node.parent return ll def loadNewick(tree_path,tip_regex='\|([0-9]+\-[0-9]+\-[0-9]+)',date_fmt='%Y-%m-%d',variableDate=True,absoluteTime=True,verbose=False): ll=None if isinstance(tree_path,str): handle=open(tree_path,'r') else: handle=tree_path for line in handle: l=line.strip('\n') if '(' in l: treeString_start=l.index('(') ll=make_tree(l[treeString_start:],verbose=verbose) ## send tree string to make_tree function if verbose==True: print('Identified tree string') assert ll,'Regular expression failed to find tree string' ll.traverse_tree(verbose=verbose) ## traverse tree ll.sortBranches() ## traverses tree, sorts branches, draws tree if absoluteTime==True: tipDates=[] for k in ll.getExternal(): n=k.numName k.name=k.numName cerberus=re.search(tip_regex,n) if cerberus is not None: tipDates.append(decimalDate(cerberus.group(1),fmt=date_fmt,variable=variableDate)) #highestTip=max(tipDates) #ll.setAbsoluteTime(highestTip) return ll def loadNexus(tree_path,tip_regex='\|([0-9]+\-[0-9]+\-[0-9]+)',date_fmt='%Y-%m-%d',treestring_regex='tree [A-Za-z\_]+([0-9]+)',variableDate=True,absoluteTime=True,verbose=False): tipFlag=False tips={} tipNum=0 ll=None if isinstance(tree_path,str): handle=open(tree_path,'r') else: handle=tree_path for line in handle: l=line.strip('\n') cerberus=re.search('dimensions ntax=([0-9]+);',l.lower()) if cerberus is not None: tipNum=int(cerberus.group(1)) if verbose==True: print('File should contain %d taxa'%(tipNum)) cerberus=re.search(treestring_regex,l) if cerberus is not None: treeString_start=l.index('(') ll=make_tree(l[treeString_start:]) ## send tree string to make_tree function if verbose==True: print('Identified tree string') if tipFlag==True: cerberus=re.search('([0-9]+) ([A-Za-z\-\_\/\.\'0-9 \|?]+)',l) if cerberus is not None: tips[cerberus.group(1)]=cerberus.group(2).strip('"').strip("'") if verbose==True: print('Identified tip translation %s: %s'%(cerberus.group(1),tips[cerberus.group(1)])) elif ';' not in l: print('tip not captured by regex:',l.replace('\t','')) if 'translate' in l.lower(): tipFlag=True if ';' in l: tipFlag=False assert ll,'Regular expression failed to find tree string' ll.traverse_tree() ## traverse tree ll.sortBranches() ## traverses tree, sorts branches, draws tree if len(tips)>0: ll.renameTips(tips) ## renames tips from numbers to actual names ll.tipMap=tips if absoluteTime==True: tipDates=[] for k in ll.getExternal(): if len(tips)>0: n=k.name else: n=k.numName cerberus=re.search(tip_regex,n) if cerberus is not None: tipDates.append(decimalDate(cerberus.group(1),fmt=date_fmt,variable=variableDate)) highestTip=max(tipDates) ll.setAbsoluteTime(highestTip) return ll def loadJSON(tree_path,json_translation={'name':'strain','absoluteTime':'num_date'},json_meta=None,verbose=False,sort=True,stats=True): """ Load a nextstrain JSON by providing either the path to JSON or a file handle. json_translation is a dictionary that translates JSON attributes to baltic branch attributes (e.g. 'absoluteTime' is called 'num_date' in nextstrain JSONs). Note that to avoid conflicts in setting node heights you can either define the absolute time of each node or branch lengths (e.g. if you want a substitution tree). """ assert 'name' in json_translation and ('absoluteTime' in json_translation or 'length' in json_translation),'JSON translation dictionary missing entries: %s'%(', '.join([entry for entry in ['name','height','absoluteTime','length'] if (entry in json_translation)==False])) if verbose==True: print('Reading JSON') if isinstance(tree_path,str): with open(tree_path) as json_data: d = json.load(json_data) ll=make_treeJSON(d,json_translation,verbose=verbose) else: ll=make_treeJSON(json.load(tree_path),json_translation,verbose=verbose) assert ('absoluteTime' in json_translation and 'length' not in json_translation) or ('absoluteTime' not in json_translation and 'length' in json_translation),'Cannot use both absolute time and branch length, include only one in json_translation dictionary.' for attr in json_translation: ## iterate through attributes in json_translation for k in ll.Objects: ## for every branch setattr(k,attr,k.traits[json_translation[attr]]) ## set attribute value for branch if 'absoluteTime' in json_translation: ## if using absoluteTime need to set branch lengths for traversals for k in ll.Objects: if json_translation['absoluteTime'] in k.parent.traits: k.length=k.traits[json_translation['absoluteTime']]-k.parent.traits[json_translation['absoluteTime']] else: k.length=0.0 if verbose==True: print('Traversing and drawing tree') ll.traverse_tree(verbose=verbose) ll.drawTree() if stats==True: ll.treeStats() ## initial traversal, checks for stats if sort==True: ll.sortBranches() ## traverses tree, sorts branches, draws tree if json_meta: if isinstance(json_meta,str): metadata=json.load(open(json_meta['file'],'r')) else: metadata=json.load(json_meta['file']) cmap=dict(metadata['color_options'][json_meta['traitName']]['color_map']) setattr(ll,'cmap',cmap) return ll if __name__ == '__main__': import sys ll=make_tree(sys.argv[1],ll) ll.traverse_tree() sys.stdout.write('%s\n'%(ll.treeHeight))