# Source: https://github.com/scw97/PiezoPaperExpressoCode (raw file)
# Tip revision: bd8a58fa0e4f796e2ed0b72fe807862305b84b6b (bd8a58f), authored by
#   scw97 on 05 February 2021, 20:18:22 UTC: "changed to exact version used in paper"
# File: compare_gui_vs_annotations.py
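# -*- coding: utf-8 -*-
"""
Compare bouts detected by bout_analysis (the "new code") against
observer-corrected annotations and the "old code" bouts stored in the same
annotation CSV files, then print summary metrics and plot per-file results.

Created on Tue Mar 21 23:25:11 2017

@author: Fruit Flies
"""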
import sys

from os import listdir
from os.path import isfile, join, splitext

import numpy as np
from scipy import signal, interpolate
from matplotlib import pyplot as plt

import csv

from load_hdf5_data import load_hdf5
from bout_analysis_func import bout_analysis
from expresso_gui_params import analysisParams
#------------------------------------------------------------------------------
data_dir = 'F:\\Dropbox\\Sam\\eBrewerQ\\annotationsdatafiles_hdf5\\'
annotations_dir = 'F:\\Dropbox\\Sam\\eBrewerQ\\observer_corrected\\'

# HDF5 file/group keys handed to load_hdf5. The u'' prefix gives a unicode
#   object on Python 2 and a plain str on Python 3 (same values as before).
filekeyname = u'XP02'
groupkeyname = u'channel_2'

min_bout_duration = analysisParams['min_bout_duration']
min_bout_volume = analysisParams['min_bout_volume']

data_filenames = [f for f in listdir(data_dir)
                  if isfile(join(data_dir, f)) and f.endswith('.hdf5')]
annotation_filenames = [f for f in listdir(annotations_dir)
                        if isfile(join(annotations_dir, f)) and f.endswith('.csv')]

# Fail early with a clear message: every data file needs an annotation file
#   (otherwise the reordering below dies with a bare IndexError).
if len(annotation_filenames) < len(data_filenames):
    raise ValueError('Found %d .hdf5 data files but only %d .csv annotation '
                     'files' % (len(data_filenames), len(annotation_filenames)))

# Data files are named <integer>.hdf5; sort them numerically (so 2 < 10).
data_filenames_int = np.array([int(splitext(f)[0]) for f in data_filenames],
                              dtype=int)
filename_sort_ind = np.argsort(data_filenames_int)
data_filenames_sorted = [data_filenames[i] for i in filename_sort_ind]
# NOTE(review): annotation files are reordered with the *data* files' sort
#   permutation, so the pairing is only right if listdir returns both
#   directories' files in corresponding order -- confirm, or match by name.
annotation_filenames_sorted = [annotation_filenames[i]
                               for i in filename_sort_ind]
#------------------------------------------------------------------------------
# initialize lists for data storage. Files skipped as bad datasets get no
#   entry, so list positions only line up with `ind` when nothing is skipped.
bouts_list_data = []
bouts_list_annotations = []
bouts_list_oldcode = []
dset_size_list = []

# initialize arrays for comparison metrics. First column will compare my code
#   and the user annotations. Second will compare my code and old code.
#   Rows start as NaN (not uninitialized memory) so a skipped file shows up
#   as missing instead of contributing garbage.
norm_hamm_dist = np.full((len(data_filenames), 2), np.nan)
num_bouts_comp = np.full((len(data_filenames), 2), np.nan)
bout_edges_comp = np.full((len(data_filenames), 2), np.nan)


def _read_csv_rows(path):
    """Return every row of the CSV file at `path` as a list of strings.

    csv.reader wants a binary-mode file on Python 2 but a text-mode file
    opened with newline='' on Python 3 (a 'rb' file makes it fail there).
    """
    if sys.version_info[0] < 3:
        csvfile = open(path, 'rb')
    else:
        csvfile = open(path, 'r', newline='')
    with csvfile:
        return list(csv.reader(csvfile))


def _annotation_file_is_empty(csv_rows):
    """Return True when an annotation CSV holds no bouts to parse.

    A first data row whose 4th column is blank (originally a single space)
    marks a file with no bouts; a file without any data row, or with too few
    columns, is treated the same way instead of raising IndexError.
    """
    if len(csv_rows) < 2 or len(csv_rows[1]) < 4:
        return True
    return csv_rows[1][3].strip() == ''


def _parse_annotation_bouts(csv_rows, t, min_duration):
    """Read user-annotated and old-code bouts from annotation CSV rows.

    Row 0 is a header. In each data row, columns 3/4/5 hold the user's bout
    start, end and duration and columns 6/7/8 the old code's. Start/end are
    shifted down by one (presumably 1-based -> 0-based). User start/end are
    mapped to frame indices by searching the time vector `t`; old-code
    start/end are used as frame indices directly. Bouts whose duration is
    below `min_duration` are dropped.

    Returns (ann_pairs, oc_pairs): lists of (start, end) frame-index pairs.
    """
    ann_pairs = []
    oc_pairs = []
    for row_curr in csv_rows[1:]:
        # user annotated bout timing, mapped onto frame indices via t
        bout_start_ann_t = int(row_curr[3]) - 1
        bout_end_ann_t = int(row_curr[4]) - 1
        bout_ann_duration = int(row_curr[5])
        bout_start_ann = np.searchsorted(t, bout_start_ann_t, side='right')
        bout_end_ann = np.searchsorted(t, bout_end_ann_t, side='right')

        # old code annotated bout timing, used directly as frame indices
        bout_start_oc = int(float(row_curr[6])) - 1
        bout_end_oc = int(float(row_curr[7])) - 1
        bout_oc_duration = int(float(row_curr[8]))

        # remove bouts with too short duration
        if bout_ann_duration >= min_duration:
            ann_pairs.append((bout_start_ann, bout_end_ann))
        if bout_oc_duration >= min_duration:
            oc_pairs.append((bout_start_oc, bout_end_oc))
    return ann_pairs, oc_pairs


def _bouts_to_array(bout_pairs):
    """Stack (start, end) pairs into a (2, n) int array (row 0 starts, row 1 ends)."""
    # reshape(-1, 2) keeps the no-bout case at shape (2, 0)
    return np.array(bout_pairs, dtype=int).reshape(-1, 2).T


def _bouts_to_binary(bouts, size):
    """Return a float array of length `size` set to 1 inside every bout.

    `bouts` is a (2, n) array of [start, end) indices (row 0 starts, row 1 ends).
    """
    binary_array = np.zeros(size)
    for start, end in np.asarray(bouts).T:
        # int() so float-typed bout arrays still work as slice bounds
        binary_array[int(start):int(end)] = 1
    return binary_array


for ind, data_fname in enumerate(data_filenames_sorted):

    # load and analyze data
    data_file = join(data_dir, data_fname)
    dset, t = load_hdf5(data_file, filekeyname, groupkeyname)

    # samples equal to -1 are treated as invalid; skip files with no valid ones
    dset_check = (dset != -1)
    if np.sum(dset_check) == 0:
        messagestr = "Bad dataset: " + data_file
        print(messagestr)
        continue

    dset_size = dset.size
    frames = np.arange(0, dset_size)
    dset_size_list.append(dset_size)

    # drop invalid samples, then resample signal and time onto every frame
    #   index up to the last valid one with interpolating splines
    dset = dset[dset_check]
    frames = frames[np.squeeze(dset_check)]
    t = t[dset_check]

    new_frames = np.arange(0, np.max(frames) + 1)
    sp_raw = interpolate.InterpolatedUnivariateSpline(frames, dset)
    sp_t = interpolate.InterpolatedUnivariateSpline(frames, t)
    dset = sp_raw(new_frames)
    t = sp_t(new_frames)
    frames = new_frames

    _, bouts_data, _ = bout_analysis(dset, frames)
    bouts_list_data.append(bouts_data)
    data_binary_array = _bouts_to_binary(bouts_data, dset_size)

    # get annotations (user-corrected and old-code bouts share one CSV file)
    annotations_file = join(annotations_dir, annotation_filenames_sorted[ind])
    csv_rows = _read_csv_rows(annotations_file)
    if _annotation_file_is_empty(csv_rows):
        ann_pairs, oc_pairs = [], []
    else:
        ann_pairs, oc_pairs = _parse_annotation_bouts(csv_rows, t,
                                                      min_bout_duration)

    bouts_annotation = _bouts_to_array(ann_pairs)
    bouts_oldcode = _bouts_to_array(oc_pairs)
    annotation_binary_array = _bouts_to_binary(bouts_annotation, dset_size)
    oldcode_binary_array = _bouts_to_binary(bouts_oldcode, dset_size)

    # add to lists
    bouts_list_annotations.append(bouts_annotation)
    bouts_list_oldcode.append(bouts_oldcode)

    # analyze differences. Want to look at:
    #   -normalized hamming distance
    #   -bout edges
    #   -number of bouts detected
    bout_num_diff_1 = bouts_data.shape[1] - bouts_annotation.shape[1]
    bout_num_diff_2 = bouts_data.shape[1] - bouts_oldcode.shape[1]
    num_bouts_comp[ind, 0] = bout_num_diff_1
    num_bouts_comp[ind, 1] = bout_num_diff_2

    hamming_dist_1 = np.sum(np.abs(data_binary_array - annotation_binary_array))
    hamming_dist_2 = np.sum(np.abs(data_binary_array - oldcode_binary_array))
    norm_hamm_dist[ind, 0] = hamming_dist_1 / dset_size
    norm_hamm_dist[ind, 1] = hamming_dist_2 / dset_size

    # NOTE: no dedicated edge metric yet; this stores the raw (unnormalized)
    #   Hamming distance
    bout_edges_comp[ind, 0] = hamming_dist_1
    bout_edges_comp[ind, 1] = hamming_dist_2
    
#------------------------------------------------------------------------------
#print summary
    
# Average each metric column over files. nanmean skips NaN rows (files whose
#   metrics were never filled in); with no NaNs it equals np.mean.
# NOTE(review): the bout-count average is abs() of the mean of *signed*
#   per-file differences, so over- and under-counts in different files cancel.
avg_num_bout_diff_1 = np.abs(np.nanmean(num_bouts_comp[:, 0]))
avg_num_bout_diff_2 = np.abs(np.nanmean(num_bouts_comp[:, 1]))

print('Average difference in number of bouts detected:')
print('User vs new code:')
print(avg_num_bout_diff_1)

print('Old code vs new code:')
print(avg_num_bout_diff_2)

avg_norm_hamm_dist_1 = np.nanmean(norm_hamm_dist[:, 0])
avg_norm_hamm_dist_2 = np.nanmean(norm_hamm_dist[:, 1])

print('\n')
print('Average difference in normalized hamming distance:')
print('User vs new code:')
print(avg_norm_hamm_dist_1)

print('Old code vs new code:')
print(avg_norm_hamm_dist_2)

#------------------------------------------------------------------------------    
    
            
# make plots to display results 
#------------------------------------------------------------------------------  
bin_centers = np.arange(0, len(data_filenames_sorted)) + 1
width = 0.35


def _comparison_bar_plot(values, ylabel, title, legend_labels):
    """Draw one figure with a pair of bars per data file.

    Column 0 of `values` is drawn in red and column 1 in blue, side by side
    around each file's bin center. Returns (fig, ax, rects_1, rects_2).
    """
    n_files = len(data_filenames_sorted)
    fig, ax = plt.subplots(figsize=(17, 4.5))
    # align='center' is explicit so the pair straddles the bin center on every
    #   matplotlib version (the default changed from 'edge' to 'center' in 2.0,
    #   which left the old x-width / x offsets shifted left of the bin).
    rects_1 = ax.bar(bin_centers - width / 2, values[:, 0], width, color='r',
                     align='center', label=legend_labels[0])
    rects_2 = ax.bar(bin_centers + width / 2, values[:, 1], width, color='b',
                     align='center', label=legend_labels[1])

    ax.set_ylabel(ylabel)
    ax.set_xlabel('Data File')
    ax.set_title(title)
    # bind the legend to this axes rather than pyplot's "current" axes
    ax.legend()

    ax.set_xlim([0, n_files + 1])
    ax.set_xticks(np.arange(n_files + 1))
    ax.set_xticklabels(np.arange(n_files + 1))
    fig.set_tight_layout(True)
    return fig, ax, rects_1, rects_2


#------------------------------------------------------------------------------
fig_numbouts, ax_numbouts, rects_numbouts_1, rects_numbouts_2 = \
    _comparison_bar_plot(num_bouts_comp, 'Diff. in # of Bouts',
                         'Difference in Number of Bouts',
                         ('new code - user annotations', 'new code - old code'))

#------------------------------------------------------------------------------
fig_normhamm, ax_normhamm, rects_normhamm_1, rects_normhamm_2 = \
    _comparison_bar_plot(norm_hamm_dist, 'Norm. Hamming Dist. [Idx]',
                         'Normalized Hamming Distance',
                         ('user annotations v. new code', 'old code v. new code'))
                   
#------------------------------------------------------------------------------   
# Disabled: bout-edge comparison plot, kept only as an inert string literal.
# bout_edges_comp is filled with the raw (unnormalized) Hamming distances, so
# re-enabling this would show the same comparison as the normalized Hamming
# plot, just without dividing by each file's dset_size.
"""
fig_edgecomp, ax_edgecomp  = plt.subplots() 

rects_edgecomp_1 = ax_edgecomp.bar(bin_centers-width/2, bout_edges_comp[:,0], 
                                   width, color='r')
rects_edgecomp_2 = ax_edgecomp.bar(bin_centers+width/2, bout_edges_comp[:,1], 
                                   width, color='b')

ax_edgecomp.set_ylabel('Bout Edge Comp. [Idx]')
ax_edgecomp.set_xlabel('Data File')    
ax_edgecomp.set_title('Bout Edge Comparison')
plt.legend(('user annotations v. new code', 'old code v. new code'))                               

ax_edgecomp.set_xlim([0,len(data_filenames_sorted)+1]) 
fig_edgecomp.set_tight_layout(True)    
"""
            
# ----- end of compare_gui_vs_annotations.py -----