import dendropy
from dendropy.model.parsimony import fitch_down_pass
from dendropy.model.parsimony import fitch_up_pass
import pandas
import csv
import numpy
import sys
import re
from Bio import Phylo
from Bio import SeqIO
from statistics import mean
from statistics import stdev
from scipy.stats import wilcoxon
from scipy.stats import mannwhitneyu
import bisect
import operator
import argparse
import time
import os
###############################################################################
## Load up the trees ##########################################################
###############################################################################

def get_options():

    purpose = '''This is a python script to output each branch's mutation rate.
    Usage: node_mutation.py --gubbins_res <gubbins_dir> --hit_locs_csv <hit_locs_csv> --out_name <out_prefix> '''
    parser = argparse.ArgumentParser(description=purpose, prog='node_mutation.py')

    parser.add_argument('--gubbins_res', required=True, help='Directory where all cluster dirs of gubbins res are stored', type=str)
    parser.add_argument('--hit_locs_csv', required=True, help='Hit csv file from hit allocator', type=str)
    parser.add_argument('--out_name', required=True, help='Prefix to prepend to output files', type=str)

    args = parser.parse_args()

    return args
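
## Example invocation (the paths and prefix below are placeholders, not real files):
##   python node_mutation.py --gubbins_res /path/to/gubbins_clusters/ \
##       --hit_locs_csv hit_locations.csv --out_name my_run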

def search_function(value, df, col_name):
    ## Return the pandas Index of the rows where df[col_name] equals value,
    ## or the sentinel string "Not_here" when there is no match. Callers take
    ## the first element of the returned Index, so multiple matches are fine.
    current_col = df[col_name]
    indexio = current_col.loc[current_col == value].index
    if indexio.size > 0:
        return indexio
    else:
        return "Not_here"


def find_nth(haystack, needle, n):
    start = haystack.find(needle)
    while start >= 0 and n > 1:
        start = haystack.find(needle, start + len(needle))
        n -= 1
    return start
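
## For reference (hypothetical string): find_nth("A_B_C_D", "_", 2) returns 3,
## the index of the second underscore; if there are fewer than n occurrences,
## the str.find sentinel of -1 is returned.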

def node_reconstruct(tree_loc, hit_csv):
    ## Function to reconstruct the insertion node of the individual cluster for a specific GPS cluster


    tree = dendropy.Tree.get(path=tree_loc,
                             schema="newick", preserve_underscores=True)


    ## hit df ##
    cluster_csv = hit_csv.copy()
    cluster_csv = cluster_csv.reset_index(drop=True)

    # print(tree.taxon_namespace)
    ## Now we'll create the mapping of the tree isolates to the cluster nums, with 0 being no hit.
    cluster_num_col = [1]
    tree_names = ["start"]
    for taxon in tree.taxon_namespace:
        current_name = taxon.label
        current_name_string = str(current_name)
        indy = search_function(current_name_string, cluster_csv, 'id')
        if isinstance(indy, str):
            multi_hit_name = current_name_string + "_1"
            indy_2 = search_function(multi_hit_name, cluster_csv, 'id')
            if isinstance(indy_2, str):
                cluster_num_col.append(0)
            else:
                cluster_val = cluster_csv.iloc[indy_2[0], cluster_csv.columns.get_loc("insert_name")]
                cluster_num_col.append(int(cluster_val))
        else:

            cluster_val = cluster_csv.iloc[indy.values[0], cluster_csv.columns.get_loc("insert_name")]
            cluster_num_col.append(int(cluster_val))
        tree_names.append(current_name_string)

        node_changer = tree.find_node_with_taxon_label(label=current_name)
        ## Taxon labels are plain strings, so assign the string itself rather than a Taxon object
        node_changer.taxon.label = current_name_string

    cluster_num_col = cluster_num_col[1:]
    tree_names = tree_names[1:]

    ## Here's the dictionary of insertion numbers keyed by the tree tip ids
    mega_characters = dict(zip(tree_names, cluster_num_col))

    with open("./fasta_data.tsv", 'w') as outfile:
        csv_writer = csv.writer(outfile, delimiter='\t', quotechar='|', quoting=csv.QUOTE_MINIMAL)

        for k, v in mega_characters.items():
            csv_writer.writerow([k] + [v])

    SeqIO.convert("./fasta_data.tsv",
                  "tab", "./fasta_data.fasta",
                  "fasta")

    taxa = dendropy.TaxonNamespace()

    data_mega = dendropy.StandardCharacterMatrix.get_from_path("./fasta_data.fasta",
                                                               "fasta", taxon_namespace=taxa)
    taxon_state_sets_map = data_mega.taxon_state_sets_map(gaps_as_missing=True)

    tree = dendropy.Tree.get_from_path(tree_loc,
                                       schema="newick", preserve_underscores=True, taxon_namespace=taxa)

    os.remove("./fasta_data.tsv")
    os.remove("./fasta_data.fasta")

    score = fitch_down_pass(postorder_nodes=tree.postorder_node_iter(),
                            taxon_state_sets_map=taxon_state_sets_map)
    ## The final (up) pass must visit parents before their children, so use a preorder traversal
    fitch_up_pass(tree.preorder_node_iter())

    ###########################################################################
    ## The tree has now been reconstructed with the Fitch down-pass/up-pass ###
    ## algorithm, so we return the node-labelled tree #########################
    ###########################################################################
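
    ## (For reference: 'score' above is the parsimony score from the down pass;
    ## after the up pass each node should also carry its reconstructed state
    ## set, e.g. something like node.state_sets in dendropy, which is where the
    ## insertion state of the internal nodes can be read off.)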


    return tree

def branch_mutations(tree, embl_csv, embl_reccy):
    ## Function to work out all the mutations and their positions along a branch.
    ## The embl csv gives the total mutations, both inside and outside recombinations,
    ## so we first tally the totals, then check for any recombination events along each
    ## branch and subtract the SNPs that fall inside those events to leave just the
    ## clonal ("outside") base substitutions.
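
    ## Worked example (hypothetical branch): if a branch carries 5 A->T
    ## substitutions in total and 2 of them fall inside a recombination block
    ## reconstructed onto that branch, the output row gets A-T = 5 and
    ## A-T_clonal = 3.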



    total_isolates = []

    for leaf in tree.leaf_node_iter():
        ## leaf_node_iter yields nodes, so take the attached Taxon object
        leaf_label = leaf.taxon
        total_isolates.append(leaf_label)

    start_id = []
    isolates_ids = []
    start_node = []
    finish_node = []
    tag = []
    AT = []
    AC = []
    AG = []
    TA = []
    TC = []
    TG = []
    CA = []
    CT = []
    CG = []
    GA = []
    GT = []
    GC = []
    AT_outside = []
    AC_outside = []
    AG_outside = []
    TA_outside = []
    TC_outside = []
    TG_outside = []
    CA_outside = []
    CT_outside = []
    CG_outside = []
    GA_outside = []
    GT_outside = []
    GC_outside = []
    cluster_num = []
    node_chain_len = []
    end_nodes = []
    particular_end_node = []
    node_already_tested = []
    edge_lengths = []


    for isolate in range(len(total_isolates)):
        ## Loop through all the isolates so every branch in the tree gets visited

        ## str() on a dendropy Taxon can wrap the label in quote characters, so
        ## strip any single quotes to recover the plain tip label
        current_row = re.sub("\'", "", str(total_isolates[isolate]))

        nodes_in_insertion = []
        edge_length = []

        ## First part of the loop establishes the chain of nodes leading from this
        ## tip back up to the root, recording node labels and edge lengths.

        nodes_in_insertion.append(current_row)
        tree_node = tree.find_node_with_taxon_label(label=current_row)
        edge_length.append(tree_node.edge.length)
        parent_node = tree_node.parent_node


        while parent_node is not None:
            new_parent = parent_node.parent_node
            old_parent = parent_node
            parent_node = new_parent

            nodes_in_insertion.append(old_parent.label)
            edge_length.append(old_parent.edge.length)

            if parent_node is None:
                end_node = old_parent.label
                if end_node not in end_nodes:
                    end_nodes.append(end_node)

        for nodeys in range(len(nodes_in_insertion) - 1):
            current_end_id = nodes_in_insertion[nodeys]
            if current_end_id not in node_already_tested:
                source_node = nodes_in_insertion[nodeys + 1]
                target_node = current_end_id
                start_node.append(source_node)
                finish_node.append(target_node)
                tag.append("No")

                subset_embl_csv = embl_csv[embl_csv['end_node'] == current_end_id]

                A_T = 0
                A_C = 0
                A_G = 0
                T_A = 0
                T_C = 0
                T_G = 0
                C_A = 0
                C_T = 0
                C_G = 0
                G_A = 0
                G_T = 0
                G_C = 0
                A_T_outside = 0
                A_C_outside = 0
                A_G_outside = 0
                T_A_outside = 0
                T_C_outside = 0
                T_G_outside = 0
                C_A_outside = 0
                C_T_outside = 0
                C_G_outside = 0
                G_A_outside = 0
                G_T_outside = 0
                G_C_outside = 0

                if subset_embl_csv.empty:
                    ## No mutations recorded for this branch; all counters stay at zero
                    pass
                else:

                    A_to_T = len(subset_embl_csv[(subset_embl_csv['start_base'] == "A") & (
                                subset_embl_csv['end_base'] == "T")].index)
                    A_to_G = len(subset_embl_csv[(subset_embl_csv['start_base'] == "A") & (
                                subset_embl_csv['end_base'] == "G")].index)
                    A_to_C = len(subset_embl_csv[(subset_embl_csv['start_base'] == "A") & (
                                subset_embl_csv['end_base'] == "C")].index)
                    T_to_A = len(subset_embl_csv[(subset_embl_csv['start_base'] == "T") & (
                                subset_embl_csv['end_base'] == "A")].index)
                    T_to_G = len(subset_embl_csv[(subset_embl_csv['start_base'] == "T") & (
                                subset_embl_csv['end_base'] == "G")].index)
                    T_to_C = len(subset_embl_csv[(subset_embl_csv['start_base'] == "T") & (
                                subset_embl_csv['end_base'] == "C")].index)
                    G_to_A = len(subset_embl_csv[(subset_embl_csv['start_base'] == "G") & (
                                subset_embl_csv['end_base'] == "A")].index)
                    G_to_T = len(subset_embl_csv[(subset_embl_csv['start_base'] == "G") & (
                                subset_embl_csv['end_base'] == "T")].index)
                    G_to_C = len(subset_embl_csv[(subset_embl_csv['start_base'] == "G") & (
                                subset_embl_csv['end_base'] == "C")].index)
                    C_to_A = len(subset_embl_csv[(subset_embl_csv['start_base'] == "C") & (
                                subset_embl_csv['end_base'] == "A")].index)
                    C_to_T = len(subset_embl_csv[(subset_embl_csv['start_base'] == "C") & (
                                subset_embl_csv['end_base'] == "T")].index)
                    C_to_G = len(subset_embl_csv[(subset_embl_csv['start_base'] == "C") & (
                                subset_embl_csv['end_base'] == "G")].index)

                    A_T += (A_to_T)
                    A_C += (A_to_C)
                    A_G += (A_to_G)
                    T_A += (T_to_A)
                    T_C += (T_to_C)
                    T_G += (T_to_G)
                    C_A += (C_to_A)
                    C_T += (C_to_T)
                    C_G += (C_to_G)
                    G_A += (G_to_A)
                    G_T += (G_to_T)
                    G_C += (G_to_C)

                    subset_reccy_csv = embl_reccy[embl_reccy['end_node'] == current_end_id]

                    if subset_reccy_csv.empty:
                        A_T_outside += A_to_T
                        A_C_outside += A_to_C
                        A_G_outside += A_to_G
                        T_A_outside += T_to_A
                        T_C_outside += T_to_C
                        T_G_outside += T_to_G
                        C_A_outside += C_to_A
                        C_T_outside += C_to_T
                        C_G_outside += C_to_G
                        G_A_outside += G_to_A
                        G_T_outside += G_to_T
                        G_C_outside += G_to_C
                    else:
                        for reccy_row in range(len(subset_reccy_csv.index)):
                            current_start_reccy = subset_reccy_csv.iloc[reccy_row, 2]
                            current_end_reccy = subset_reccy_csv.iloc[reccy_row, 3]
                            for mut_row in range(len(subset_embl_csv.index)):
                                base_pos = subset_embl_csv.iloc[mut_row, 4]
                                start_base = subset_embl_csv.iloc[mut_row, 2]
                                end_base = subset_embl_csv.iloc[mut_row, 3]
                                if current_start_reccy <= base_pos <= current_end_reccy:
                                    ## This SNP sits inside the recombination block, so
                                    ## remove it from the clonal (outside) tally
                                    if start_base == "A" and end_base == "T":
                                        A_to_T -= 1
                                    elif start_base == "A" and end_base == "C":
                                        A_to_C -= 1
                                    elif start_base == "A" and end_base == "G":
                                        A_to_G -= 1
                                    elif start_base == "T" and end_base == "A":
                                        T_to_A -= 1
                                    elif start_base == "T" and end_base == "C":
                                        T_to_C -= 1
                                    elif start_base == "T" and end_base == "G":
                                        T_to_G -= 1
                                    elif start_base == "C" and end_base == "A":
                                        C_to_A -= 1
                                    elif start_base == "C" and end_base == "T":
                                        C_to_T -= 1
                                    elif start_base == "C" and end_base == "G":
                                        C_to_G -= 1
                                    elif start_base == "G" and end_base == "A":
                                        G_to_A -= 1
                                    elif start_base == "G" and end_base == "T":
                                        G_to_T -= 1
                                    elif start_base == "G" and end_base == "C":
                                        G_to_C -= 1

                        A_T_outside += A_to_T
                        A_C_outside += A_to_C
                        A_G_outside += A_to_G
                        T_A_outside += T_to_A
                        T_C_outside += T_to_C
                        T_G_outside += T_to_G
                        C_A_outside += C_to_A
                        C_T_outside += C_to_T
                        C_G_outside += C_to_G
                        G_A_outside += G_to_A
                        G_T_outside += G_to_T
                        G_C_outside += G_to_C

                AT.append(A_T)
                AC.append(A_C)
                AG.append(A_G)
                TA.append(T_A)
                TC.append(T_C)
                TG.append(T_G)
                CA.append(C_A)
                CT.append(C_T)
                CG.append(C_G)
                GA.append(G_A)
                GT.append(G_T)
                GC.append(G_C)
                AT_outside.append(A_T_outside)
                AC_outside.append(A_C_outside)
                AG_outside.append(A_G_outside)
                TA_outside.append(T_A_outside)
                TC_outside.append(T_C_outside)
                TG_outside.append(T_G_outside)
                CA_outside.append(C_A_outside)

                CT_outside.append(C_T_outside)
                CG_outside.append(C_G_outside)
                GA_outside.append(G_A_outside)
                GT_outside.append(G_T_outside)
                GC_outside.append(G_C_outside)
                edge_lengths.append(edge_length[nodeys])
                node_already_tested.append(current_end_id)
                start_id.append(current_row)

        if isolate % 10 == 0:
            print("%.1f%% of isolates processed" % (isolate / len(total_isolates) * 100))

    non_mge_mutations_out = pandas.DataFrame({'start_node': start_node,
                                              'end_node': finish_node,
                                              'tag': tag,
                                              'A-T': AT,
                                              'A-C': AC,
                                              'A-G': AG,
                                              'T-A': TA,
                                              'T-C': TC,
                                              'T-G': TG,
                                              'C-A': CA,
                                              'C-T': CT,
                                              'C-G': CG,
                                              'G-A': GA,
                                              'G-T': GT,
                                              'G-C': GC,
                                              'A-T_clonal': AT_outside,
                                              'A-C_clonal': AC_outside,
                                              'A-G_clonal': AG_outside,
                                              'T-A_clonal': TA_outside,
                                              'T-C_clonal': TC_outside,
                                              'T-G_clonal': TG_outside,
                                              'C-A_clonal': CA_outside,
                                              'C-T_clonal': CT_outside,
                                              'C-G_clonal': CG_outside,
                                              'G-A_clonal': GA_outside,
                                              'G-T_clonal': GT_outside,
                                              'G-C_clonal': GC_outside,
                                              'branch_lengths': edge_lengths,
                                              'starting_isolate': start_id})

    return non_mge_mutations_out
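
## A per-branch rate could then be derived from the returned table by dividing
## counts by branch length, e.g. (sketch on the columns produced above, where
## df is the data frame returned by branch_mutations):
##   clonal_cols = [c for c in df.columns if c.endswith('_clonal')]
##   df['clonal_rate'] = df[clonal_cols].sum(axis=1) / df['branch_lengths']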


if __name__ == '__main__':

    start_overall = time.perf_counter()

    input_args = get_options()
    ###############################################################################
    ## Load up the csv and then run through each of the clusters to check through #
    ## the res ####################################################################
    ###############################################################################

    cluster_csv = pandas.read_csv(input_args.hit_locs_csv)

    base_loc = input_args.gubbins_res
    unique_clusters = cluster_csv['cluster_name'].unique()

    tot_reccy_csv = pandas.DataFrame()
    tot_non_reccy = pandas.DataFrame()

    seq_clus = 1
    for cluster in unique_clusters:
        print("On cluster: %s, %s of %s" % (cluster, seq_clus, len(unique_clusters)))
        tic_cluster = time.perf_counter()
        current_dat = cluster_csv[cluster_csv['cluster_name'] == cluster]
        current_dir = base_loc + cluster
        try:
            cluster_files = os.listdir(current_dir)
        except OSError:
            ## Some clusters keep their gubbins output under "<cluster>_run_data"
            current_dir = current_dir + "_run_data"
            cluster_files = os.listdir(current_dir)
        tree_indexio = [k for k, s in enumerate(cluster_files) if "node_labelled.final_tree.tre" in s]
        embl_branch = [k for k, s in enumerate(cluster_files) if "_branch_base.csv" in s]
        embl_reccy = [k for k, s in enumerate(cluster_files) if "_recombinations.csv" in s]
        per_branch_file = [k for k, s in enumerate(cluster_files) if "_per_branch_mutations.csv" in s]

        if len(per_branch_file) > 0:
            print("Using already formed per_branch_mutations.csv")
            seq_clus += 1
            continue

        tree_loc = current_dir + "/" + cluster_files[tree_indexio[0]]
        embl_branch_loc = current_dir + "/" + cluster_files[embl_branch[0]]
        embl_rec_loc = current_dir + "/" + cluster_files[embl_reccy[0]]

        embl_csv = pandas.read_csv(embl_branch_loc)
        embl_reccy_csv = pandas.read_csv(embl_rec_loc)
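
        ## (Layout assumed by branch_mutations: the branch_base csv carries
        ## 'end_node', 'start_base' and 'end_base' columns with the substitution
        ## position in column index 4; the recombinations csv carries 'end_node'
        ## with the event start/end coordinates in column indices 2 and 3.)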

        ## So now we've got all the files we need for this particular cluster. We load the
        ## node-labelled tree from gubbins (the node_reconstruct step below is currently
        ## commented out) and then run through the branches to summarise the mutations present.

        #tree = node_reconstruct(tree_loc, current_dat)
        tree = dendropy.Tree.get(path=tree_loc,
                                 schema="newick", preserve_underscores=True)

        branches_csv = branch_mutations(tree, embl_csv, embl_reccy_csv)

        branches_csv['cluster_name'] = cluster


        branches_out_loc = current_dir + "/" + cluster + "_per_branch_mutations.csv"
        branches_csv.to_csv(path_or_buf=branches_out_loc, index=False)
        seq_clus += 1

    toc_tot = time.perf_counter()

    print("~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~")
    print("Mutation finder took: %s (seconds)" % (toc_tot - start_overall))
    print("~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Branch mutations found ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~")



