import argparse
import bisect
import csv
import operator
import os
import re
import sys
import tempfile
import time
from collections import Counter
from statistics import mean
from statistics import stdev

import dendropy
import numpy
import pandas
from Bio import Phylo
from Bio import SeqIO
from dendropy.model.parsimony import fitch_down_pass
from dendropy.model.parsimony import fitch_up_pass
from scipy.stats import mannwhitneyu
from scipy.stats import wilcoxon

###############################################################################
## Constants ##################################################################
###############################################################################

# The twelve directed base substitutions, in the order of the output columns.
_BASE_CHANGES = (
    ("A", "T"), ("A", "C"), ("A", "G"),
    ("T", "A"), ("T", "C"), ("T", "G"),
    ("C", "A"), ("C", "T"), ("C", "G"),
    ("G", "A"), ("G", "T"), ("G", "C"),
)

# Positional columns used when testing whether a mutation lies inside a
# recombination block.
# NOTE(review): positions 2 and 3 of the mutation table are presumed to be the
# same columns as 'start_base' / 'end_base' -- confirm against the CSV schema.
_MUT_START_BASE_COL = 2
_MUT_END_BASE_COL = 3
_MUT_POSITION_COL = 4
_REC_START_COL = 2
_REC_END_COL = 3


###############################################################################
## Command line ###############################################################
###############################################################################

def get_options():
    """Parse the command line and return the argparse namespace.

    All three options (--gubbins_res, --hit_locs_csv, --out_name) are
    required strings.
    """
    purpose = '''This is a python script to output each branch's mutation rate.
    Usage: node_balance_and_mutation_get.py '''

    parser = argparse.ArgumentParser(description=purpose,
                                     prog='node_mutation.py')
    parser.add_argument('--gubbins_res', required=True,
                        help='Directory where all cluster dirs of gubbins res are stored"',
                        type=str)
    parser.add_argument('--hit_locs_csv', required=True,
                        help='Hit csv file from hit allocator',
                        type=str)
    parser.add_argument('--out_name', required=True,
                        help='Prefix to append to out out_files',
                        type=str)

    return parser.parse_args()


###############################################################################
## Small helpers ##############################################################
###############################################################################

def search_function(value, df, col_name):
    """Locate the first row of ``df`` whose ``col_name`` equals ``value``.

    Returns a one-element pandas Index holding the index label of the first
    matching row, or the string "Not_here" when nothing matches.

    A one-element Index is returned even when several rows match, so callers
    can always use ``result[0]`` / ``result.values[0]``.
    """
    column = df[col_name]
    matches = column.index[column == value]
    if matches.size == 0:
        return "Not_here"
    return matches[:1]


def find_nth(haystack, needle, n):
    """Return the start offset of the n-th non-overlapping ``needle``.

    Returns -1 when ``haystack`` holds fewer than ``n`` occurrences.
    """
    start = haystack.find(needle)
    while start >= 0 and n > 1:
        start = haystack.find(needle, start + len(needle))
        n -= 1
    return start


def _rows_by_end_node(table):
    """Split ``table`` into one sub-frame per distinct 'end_node' value."""
    return {branch: rows for branch, rows in table.groupby('end_node', sort=False)}


def _count_base_changes(start_bases, end_bases):
    """Count each of the twelve substitutions, ordered as ``_BASE_CHANGES``."""
    tallies = Counter(zip(start_bases, end_bases))
    return [tallies.get(change, 0) for change in _BASE_CHANGES]


def _first_matching(file_names, marker, directory):
    """Return the path of the first file in ``directory`` containing ``marker``.

    Raises FileNotFoundError when no file name contains ``marker``.
    """
    for file_name in file_names:
        if marker in file_name:
            return os.path.join(directory, file_name)
    raise FileNotFoundError(
        "No file containing '%s' found in %s" % (marker, directory))


###############################################################################
## Insertion state reconstruction #############################################
###############################################################################

def node_reconstruct(tree_loc, hit_csv):
    """Reconstruct insertion states on the tree at ``tree_loc`` with Fitch parsimony.

    Each tip is given the integer in the 'insert_name' column of the row of
    ``hit_csv`` whose 'id' equals the tip label (or the tip label + "_1"),
    and 0 when neither is present. The states are then propagated through the
    tree with dendropy's Fitch down pass and up pass.

    Parameters:
        tree_loc: path to a newick tree.
        hit_csv: DataFrame with at least 'id' and 'insert_name' columns.

    Returns:
        The dendropy Tree after both Fitch passes have been run on it.
    """
    tree = dendropy.Tree.get(path=tree_loc,
                             schema="newick",
                             preserve_underscores=True)

    # Reset the index so that index labels equal row positions; the label
    # returned by search_function is used positionally with .iloc below.
    cluster_csv = hit_csv.copy().reset_index(drop=True)
    insert_col = cluster_csv.columns.get_loc("insert_name")

    ## Map every tree tip to its insertion number, with 0 being no hit.
    mega_characters = {}
    for taxon in tree.taxon_namespace:
        tip_name = str(taxon.label)
        hit_row = search_function(tip_name, cluster_csv, 'id')
        if isinstance(hit_row, str):
            # Isolates with several hits are stored with a numeric suffix.
            hit_row = search_function(tip_name + "_1", cluster_csv, 'id')
        if isinstance(hit_row, str):
            mega_characters[tip_name] = 0
        else:
            mega_characters[tip_name] = int(cluster_csv.iloc[hit_row[0], insert_col])

    # The character matrix is handed to dendropy through a FASTA file. The
    # scratch files live in a private temporary directory so concurrent runs
    # cannot clash and nothing is left behind if a step fails.
    # NOTE(review): every insert number is written as that tip's whole
    # "sequence", so numbers with different digit counts give sequences of
    # different lengths -- confirm insert_name stays single-digit.
    with tempfile.TemporaryDirectory() as scratch_dir:
        tsv_path = os.path.join(scratch_dir, "fasta_data.tsv")
        fasta_path = os.path.join(scratch_dir, "fasta_data.fasta")

        with open(tsv_path, 'w', newline='') as outfile:
            csv_writer = csv.writer(outfile, delimiter='\t', quotechar='|',
                                    quoting=csv.QUOTE_MINIMAL)
            for tip_name, insert_num in mega_characters.items():
                csv_writer.writerow([tip_name, insert_num])

        SeqIO.convert(tsv_path, "tab", fasta_path, "fasta")

        taxa = dendropy.TaxonNamespace()
        data_mega = dendropy.StandardCharacterMatrix.get_from_path(
            fasta_path, "fasta", taxon_namespace=taxa)
        taxon_state_sets_map = data_mega.taxon_state_sets_map(gaps_as_missing=True)

    # Re-read the tree so that it shares its taxon namespace with the matrix.
    tree = dendropy.Tree.get_from_path(tree_loc,
                                       schema="newick",
                                       preserve_underscores=True,
                                       taxon_namespace=taxa)

    fitch_down_pass(postorder_nodes=tree.postorder_node_iter(),
                    taxon_state_sets_map=taxon_state_sets_map)
    # The final pass resolves each node from its parent, so parents have to be
    # visited before their children: pre-order, not post-order.
    fitch_up_pass(tree.preorder_node_iter())

    return tree


###############################################################################
## Per branch mutation counting ###############################################
###############################################################################

def branch_mutations(tree, embl_csv, embl_reccy):
    """Summarise the base substitutions on every branch of ``tree``.

    For each branch (identified by the label of the node it ends at) the
    twelve directed substitutions are counted twice: once over all mutations
    recorded for that branch in ``embl_csv``, and once ("clonal") after
    removing the mutations whose position falls inside any recombination
    recorded for the same branch in ``embl_reccy``.

    Parameters:
        tree: dendropy Tree; internal nodes are identified by ``node.label``
            and tips by ``node.taxon.label``.
        embl_csv: DataFrame of mutations with 'end_node', 'start_base' and
            'end_base' columns; the position is read from column index 4.
        embl_reccy: DataFrame of recombinations with an 'end_node' column;
            block start and end are read from column indices 2 and 3.

    Returns:
        DataFrame with one row per branch: 'start_node', 'end_node', 'tag',
        the twelve total counts ('A-T' ... 'G-C'), the twelve clonal counts
        ('A-T_clonal' ... 'G-C_clonal'), 'branch_lengths' and
        'starting_isolate' (the tip from which the branch was first reached).
    """
    # Group both tables once instead of re-filtering them for every branch.
    mutations_by_branch = _rows_by_end_node(embl_csv)
    recombinations_by_branch = _rows_by_end_node(embl_reccy)

    total_names = ["%s-%s" % change for change in _BASE_CHANGES]
    clonal_names = [name + "_clonal" for name in total_names]
    column_order = (['start_node', 'end_node', 'tag'] + total_names
                    + clonal_names + ['branch_lengths', 'starting_isolate'])
    columns = {name: [] for name in column_order}

    leaves = list(tree.leaf_node_iter())
    branches_done = set()

    for isolate, leaf in enumerate(leaves):
        leaf_label = leaf.taxon.label

        ## Walk from the tip up to the root, recording each node's id and the
        ## length of the branch leading to it.
        chain = [(leaf_label, leaf.edge.length)]
        ancestor = leaf.parent_node
        while ancestor is not None:
            chain.append((ancestor.label, ancestor.edge.length))
            ancestor = ancestor.parent_node

        ## Consecutive entries of the chain are (child, parent) pairs, i.e.
        ## the branches; the root has no branch above it.
        for (branch_id, branch_length), (parent_id, _) in zip(chain, chain[1:]):
            if branch_id in branches_done:
                # Branches shared between tips are only reported once.
                continue
            branches_done.add(branch_id)

            branch_muts = mutations_by_branch.get(branch_id)
            if branch_muts is None:
                totals = [0] * len(_BASE_CHANGES)
                clonal = [0] * len(_BASE_CHANGES)
            else:
                totals = _count_base_changes(branch_muts['start_base'],
                                             branch_muts['end_base'])
                branch_recs = recombinations_by_branch.get(branch_id)
                if branch_recs is None:
                    clonal = list(totals)
                else:
                    # Flag a mutation if ANY recombination block covers it, so
                    # overlapping blocks cannot remove the same SNP twice.
                    positions = branch_muts.iloc[:, _MUT_POSITION_COL].to_numpy()
                    inside = numpy.zeros(len(positions), dtype=bool)
                    block_bounds = zip(branch_recs.iloc[:, _REC_START_COL],
                                       branch_recs.iloc[:, _REC_END_COL])
                    for block_start, block_end in block_bounds:
                        inside |= (positions >= block_start) & (positions <= block_end)
                    recombinant = branch_muts[inside]
                    removed = _count_base_changes(
                        recombinant.iloc[:, _MUT_START_BASE_COL],
                        recombinant.iloc[:, _MUT_END_BASE_COL])
                    clonal = [total - gone for total, gone in zip(totals, removed)]

            columns['start_node'].append(parent_id)
            columns['end_node'].append(branch_id)
            columns['tag'].append("No")
            for name, count in zip(total_names, totals):
                columns[name].append(count)
            for name, count in zip(clonal_names, clonal):
                columns[name].append(count)
            columns['branch_lengths'].append(branch_length)
            columns['starting_isolate'].append(leaf_label)

        # Progress, as a percentage of tips processed.
        if isolate % 10 == 0:
            print(isolate / len(leaves) * 100)

    return pandas.DataFrame(columns)


###############################################################################
## Load up the csv and then run through each of the clusters to check through #
## the res ####################################################################
###############################################################################

def main():
    """Write a <cluster>_per_branch_mutations.csv into every cluster directory.

    Clusters come from the 'cluster_name' column of the hit csv. A cluster
    whose directory already holds a per-branch file is skipped.
    """
    start_overall = time.perf_counter()
    input_args = get_options()

    cluster_csv = pandas.read_csv(input_args.hit_locs_csv)
    base_loc = input_args.gubbins_res
    unique_clusters = cluster_csv['cluster_name'].unique()

    for seq_clus, cluster in enumerate(unique_clusters, start=1):
        print("On cluster: %s, %s of %s" % (cluster, seq_clus, len(unique_clusters)))
        cluster_label = str(cluster)
        current_dir = os.path.join(base_loc, cluster_label)
        try:
            cluster_files = os.listdir(current_dir)
        except (FileNotFoundError, NotADirectoryError):
            # Results may instead be stored in a "<cluster>_run_data" directory.
            current_dir = current_dir + "_run_data"
            cluster_files = os.listdir(current_dir)

        if any("_per_branch_mutations.csv" in name for name in cluster_files):
            print("Using already formed per_branch_mutations.csv")
            continue

        tree_loc = _first_matching(cluster_files, "node_labelled.final_tree.tre", current_dir)
        embl_branch_loc = _first_matching(cluster_files, "_branch_base.csv", current_dir)
        embl_rec_loc = _first_matching(cluster_files, "_recombinations.csv", current_dir)

        embl_csv = pandas.read_csv(embl_branch_loc)
        embl_reccy_csv = pandas.read_csv(embl_rec_loc)

        ## The node labelled tree is read as is (node_reconstruct is not run
        ## here), then every branch is summarised.
        tree = dendropy.Tree.get(path=tree_loc,
                                 schema="newick",
                                 preserve_underscores=True)

        branches_csv = branch_mutations(tree, embl_csv, embl_reccy_csv)
        branches_csv['cluster_name'] = cluster
        branches_out_loc = os.path.join(
            current_dir, cluster_label + "_per_branch_mutations.csv")
        branches_csv.to_csv(path_or_buf=branches_out_loc, index=False)

    toc_tot = time.perf_counter()

    print("~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~")
    print("Mutation finder took: %s (seconds)" % (toc_tot - start_overall))
    print("~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Branch mutations found ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~")


if __name__ == '__main__':
    main()