https://github.com/knorwood0/MNRVA
Tip revision: dc5512ff7c7ce6d5cd1de5bd7a8678193cdcf750 authored by Katrina on 07 September 2020, 19:27:37 UTC
updated user parse_arg options
updated user parse_arg options
Tip revision: dc5512f
Apical_Basal_Classification.py
import os
import XML_Parse as parse
import Euclidean_Distance as ed
import pandas as pd
import numpy as np
import plotly.graph_objects as go
import matplotlib.pyplot as plt
#######################################################################################################################
# Data Management and Organization:
def end_pt_df(root):
"""Function to create a dataframe of the coordinates of the end points within a skeleton"""
# Creating a dataframe of the each node coordinates and their ID:
node_position_df = parse.node_pos_df(root)
# Filtering the dataframe to only include endpoints and associated coordinates (the end nodes will be used to deterine if branch is apical or basal):
end_nodes_list = parse.end_points(root)
endpoints = node_position_df['Node ID'].isin(end_nodes_list)
endpoints_position_df = node_position_df[endpoints]
return endpoints_position_df
def average_list(list):
"""Finds the average value of a list"""
return sum(list)/len(list)
#######################################################################################################################
# Data Analysis:
def apical_basal_classifier(root):
"""Function to classify skeleton branches as either apical or basal based on the position of the branches end node in comparison to the soma"""
# Creating a dataframe of the each node coordinates and their ID:
endpoints_position_df = end_pt_df(root)
# Determining the Soma of each skeleton and creating a dataframe with its position:
soma_df = parse.soma_df(root)
# Creating an empty list that will contain whether each end node is apical or basal and two variables to represent the y positions of soma and end nodes:
class_list = []
endpts_y = endpoints_position_df.loc[:, (data_ori)]
soma_y = soma_df.loc[:, (data_ori)]
# Comparing the position of the end node and the soma to determine if the branch is apical or basal
for value in endpts_y:
for soma in soma_y:
if value > soma:
class_list.append('Apical')
if value < soma:
class_list.append('Basal')
if value == soma:
class_list.append('Uncategorized')
# Creating one final dataframe containing the positions and classifications:
endpoints_position_df['Classification'] = class_list
skeleton_df = pd.DataFrame(parse.skeleton_id(root), columns=['Skeleton ID'])
endpoints_position_df.reset_index(drop=True, inplace=True)
classification_df = pd.concat([endpoints_position_df, skeleton_df], axis=1)
return classification_df
def nodes_from_soma(root):
"""Function to determine the nodes coming directly off of the soma"""
# Determining the Soma of the skeleton:
soma_id = parse.find_soma(root)
# Creating a dataframe with all of the nodes and their connections, this will be used to find what the soma is connected to:
source_target_df = parse.source_target_pos_df(root)
soma_df = source_target_df.loc[source_target_df['Source ID'].isin(soma_id)]
# Determining which of the nodes are apical and basal depending on their 'y' position:
class_list = []
soma = soma_df.loc[0, 'x_x']
branch_nodes = soma_df.drop(['Source ID', 'x_x', 'y_x', 'z_x'], axis=1)
bny_columns = list(branch_nodes.loc[:, 'x_x'])
for value in bny_columns:
if value > soma:
class_list.append('Apical')
if value < soma:
class_list.append('Basal')
if value == soma:
class_list.append('Uncategorized')
# Creating one final dataframe containing the new information:
soma_df['Classification'] = class_list
skeleton_df = pd.DataFrame(parse.skeleton_id(root), columns=['Skeleton ID'])
soma_df.reset_index(drop=True, inplace=True)
classification_df = pd.concat([soma_df, skeleton_df], axis=1)
return classification_df
def nodes_from_soma_dict(root):
"""Purpose to create a dict of the soma's target node and its apical/basal classification"""
ab_class_node = nodes_from_soma(root)
ab_class_dict = dict(zip(ab_class_node.loc[:, 'Target ID'], ab_class_node.loc[:, 'Classification']))
return ab_class_dict
#######################################################################################################################
# Visualization
def ab_barchart(root):
"""Purpose of the function is to create a plotly bar chart of the distribution of apical/basal dendrites"""
# Getting the number of apical/basal branches per skeleton:
apical_list = []
basal_list = []
skeleton_list_int = parse.skeleton_id(parse.root)
skeleton_list = []
for i in skeleton_list_int:
i_string = 'ID ' + str(i)
skeleton_list.append(i_string)
for thing in parse.root.iter('thing'):
# Creating lists to count the number of time Apical and Basal appears within each Skeleton
skeleton_class_df = apical_basal_classifier(thing)
classes = skeleton_class_df['Classification'].tolist()
# Counting how many Apical / Basal branches appear:
apical_count = classes.count('Apical')
basal_count = classes.count('Basal')
# Appending the total counts:
apical_list.append(apical_count)
basal_list.append(basal_count)
# Configuring the bar chart:
classification_fig = go.Figure(data=[
go.Bar(name='Apical', x=skeleton_list, y=apical_list),
go.Bar(name='Basal', x=skeleton_list, y=basal_list)
])
# Changing labels and chart:
classification_fig.update_layout(barmode='stack')
classification_fig.update_layout(
title="Number of Apical and Basal Branches per Neuron",
xaxis_title='Skeleton ID',
yaxis_title='Total Number of Branches'
)
classification_fig.show()
def overlaid_histogram(root):
"""creates a histogram of the distribution of apical and basal branches within a data set"""
# Getting the number of apical/basal branches per skeleton:
apical_list = []
basal_list = []
for thing in parse.root.iter('thing'):
# Creating lists to count the number of time Apical and Basal appears within each Skeleton
skeleton_class_df = apical_basal_classifier(thing)
classes = skeleton_class_df['Classification'].tolist()
# Counting how many Apical / Basal branches appear:
apical_count = classes.count('Apical')
basal_count = classes.count('Basal')
# Appending the total counts:
apical_list.append(apical_count)
basal_list.append(basal_count)
class_histo = go.Figure()
class_histo.add_trace(go.Histogram(x=apical_list, name='Apical'))
class_histo.add_trace(go.Histogram(x=basal_list, name='Basal'))
# Overlay both histograms
class_histo.update_layout(barmode='overlay')
# Reduce opacity to see both histograms
class_histo.update_traces(opacity=0.50)
class_histo.update_layout(
title="Apical and Basal Branch Frequency",
xaxis_title='Number of Neurons',
yaxis_title='Frequency of Branch Type'
)
class_histo.show()
#######################################################################################################################
# Determining data set orientation:
data_ori = input("What coordinate plane is the upward direction in the original data set? (X, Y, or Z coordinate plane): ")
data_ori.islower()
# Creating a fail safe that checks the correct data orientation:
def check_orientation(root):
"""Determines the orientation of the data set based on the number of nodes in a direction"""
final_classx = []
final_classy = []
final_classz = []
for thing in parse.root.iter('thing'):
soma = parse.soma_df(thing)
endpts_df = end_pt_df(thing)
x_d_list = []
y_d_list = []
z_d_list = []
somax = list(soma.loc[:, 'x'])
somay = list(soma.loc[:, 'y'])
somaz = list(soma.loc[:, 'z'])
endptsx = list(endpts_df.loc[:, 'x'])
endptsy = list(endpts_df.loc[:, 'y'])
endptsz = list(endpts_df.loc[:, 'z'])
for soma_value in somax:
for value in endptsx:
distancex = abs(soma_value - value)
x_d_list.append(distancex)
for soma_value in somay:
for value in endptsy:
distancey = abs(soma_value - value)
y_d_list.append(distancey)
for soma_value in somaz:
for value in endptsz:
distancez = abs(soma_value - value)
z_d_list.append(distancez)
final_classx.append(max(x_d_list))
final_classy.append(max(y_d_list))
final_classz.append(max(z_d_list))
x = average_list(final_classx)
y = average_list(final_classy)
z = average_list(final_classz)
avg_dict = dict({'x': x, 'y': y, 'z': z})
# print(avg_dict)
return max(avg_dict, key=avg_dict.get)
# Should user not specify the correct orientation the following will be run:
if data_ori == '':
data_ori = check_orientation(parse.root)
print("Predicted orientation: ", data_ori)
#######################################################################################################################