https://github.com/comp-neural-circuits/adaptation-of-drosophila-larva-foraging
Tip revision: d501cd1f0df3df2fda0f286d58e98ee0f72b4898 authored by Dylan Festa on 16 December 2022, 13:50:38 UTC
Added script to plot single larvea paths
Added script to plot single larvea paths
Tip revision: d501cd1
preprocess_data.py
# preprocess_data.py
# Reads the data, applies the RDP algorithm and saves it as a single dataframe.
#%%%
import itertools
import pathlib # to find current path
import re
import sys

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import regex
from rdp import rdp
from tqdm import tqdm
# Locate the `Data` folder that sits next to this script.
path_thisfile = pathlib.Path(__file__).parent.resolve()
path_data = path_thisfile.joinpath('Data')
if not path_data.is_dir():
    raise ValueError(f"\nDirectory `{path_data}` does not exist!")

# Every experimental condition is one (arena, substrate, larva) triple.
arenas = ['Homogeneous', 'Two_patches', 'Nonnutrient_patches', 'Eight_patches']
substrates = ['Sucrose','Yeast','Agar','Gel','Apple_juice']
larvae = ['Rover','Sitter','Anosmic']

# Fail early when one of the arena folders has not been imported yet.
for dd in arenas:
    dfull = path_data.joinpath(dd)
    if dfull.is_dir():
        continue
    raise ValueError(f"\nDirectory `{dfull}` does not exist, please import the data first!")

# NOTE: `itertools.product` returns a one-shot iterator, so it can only be
# traversed once.
all_combinations = itertools.product(arenas,substrates,larvae)
def path_in_data(combination):
    """
    Map an (arena, substrate, larva) triple to its folder inside `path_data`.

    Returns the folder path when it exists on disk, otherwise the empty
    path `pathlib.Path('')` as a "not available" marker.
    """
    arena, sub, lar = combination
    candidate = path_data.joinpath(arena, sub, lar)
    return candidate if candidate.is_dir() else pathlib.Path('')
# Keep only the condition triples whose folder is actually present.
# `all_paths[k]` and `path_elements[k]` describe the same session.
all_paths = []
path_elements = []
for combination in all_combinations:
    fullpath = path_in_data(combination)
    # missing folders come back as the empty path, whose string form is
    # much shorter than any real folder path
    if len(str(fullpath)) <= 5:
        continue
    all_paths.append(fullpath)
    path_elements.append(combination)
n_sessions = len(all_paths)
#%%%
def clean_format_data(
        data_as_csv, frames_list, exp_ind, factor_frames, factor_frames2):
    """
    Reduce one raw experimental spreadsheet to its numeric coordinate rows
    and label every surviving column as one larva of experiment `exp_ind`.

    Parameters
    ----------
    data_as_csv: DataFrame
        contains output of FIMTrack software
    frames_list: list
        row counts accumulated so far, one entry per larva already processed
    exp_ind: int
        index of experiment (1, 2, 3), used in the column labels
    factor_frames: int
        divisor applied to the table length to keep only the coordinate rows
    factor_frames2: int
        fallback divisor, used when `factor_frames` leaves fewer than
        4000 rows

    Returns
    -------
    data_clean: DataFrame
        the retained rows, without non-numeric and all-NaN columns, with
        columns renamed to `exp<exp_ind>_larva_<position>`
    frame_list_aux: ndarray
        `frames_list` extended by the retained row count, repeated once per
        column of `data_clean`

    Raises
    ------
    ValueError
        if the number of retained rows is odd (half of the rows are x
        coordinates and the other half y coordinates)
    """
    numeric_only = data_as_csv.select_dtypes(exclude=['object'])
    n_rows = len(numeric_only)

    # Some tables carry all the tracking information and others only the x,y
    # coordinates. No table has fewer than 4000 coordinate rows, so a smaller
    # result means the table was already coordinates-only.
    frames = int(n_rows / factor_frames)
    if frames < 4000:
        frames = int(n_rows / factor_frames2)

    # x and y are stacked on top of each other, so the count must be even
    if frames % 2:
        raise ValueError('Error, odd number of frames on table' + str(exp_ind))

    # keep the coordinate rows, then drop columns that have ALL NaN values
    trimmed = numeric_only.iloc[:frames].dropna(how='all', axis=1)

    labels = [f'exp{exp_ind}_larva_{pos}' for pos in range(trimmed.shape[1])]
    trimmed.columns = labels

    frame_list_aux = np.append(frames_list, [frames] * len(labels), axis=0)
    return trimmed, frame_list_aux
#%%%
# Column labels look like `exp<digits>_larva_<digits>`. The pattern needs no
# features beyond the standard library `re` module, and is compiled once.
_EXP_LARVA_PATTERN = re.compile(r'^exp([0-9]*)_larva_([0-9]*)$')


def get_exp_and_larva_str(k:str):
    """
    Split a column label such as `exp2_larva_13` into its two index levels.

    Parameters
    ----------
    k: str
        column label of the form `exp<digits>_larva_<digits>`

    Returns
    -------
    tuple of (str, str)
        experiment tag and larva tag, e.g. `('exp2', 'L13')`

    Raises
    ------
    AssertionError
        if `k` does not have the expected form. It is raised explicitly
        (not through an `assert` statement) so the check is still performed
        when Python runs with `-O`.
    """
    rr = _EXP_LARVA_PATTERN.search(k)
    if rr is None:
        raise AssertionError(f"wrong string value! Cannot parse this: {k} ")
    return 'exp' + rr.group(1), 'L' + rr.group(2)
def read_single_dataset(arena,substrate,larva,datapath,scale=8.0,dt=0.5):
    """
    Read every table of one (arena, substrate, larva) session folder, simplify
    each larva trajectory with the RDP algorithm and return one row per larva.

    Parameters
    ----------
    arena: str
        arena name; 'Two_patches', 'Nonnutrient_patches' and 'Eight_patches'
        additionally trigger the reading of the `ROI_coord<n>.csv` files
    substrate: str
        substrate name; when it contains 'Yeast' the RDP tolerance is
        `10/scale`, otherwise `20/scale`
    larva: str
        larva type, only used for the printout and the output index
    datapath: pathlib.Path
        folder holding `table1.csv`, `table2.csv` and optionally `table3.csv`
    scale: float
        number of pixels per mm, used to convert coordinates to millimeters
    dt: float
        time between two frames (the default 0.5 corresponds to 2 Hz)

    Returns
    -------
    dfout: DataFrame
        indexed by (arena, substrate, larva, experiment, larva id), where the
        arena names containing 'Nonnutrient', 'Two' or 'Eight' are shortened
        to that word. Columns: 'x', 'y', 'nframes', 'time',
        'simple_trajectory', 'idx_turn_points', 'rdp_mask', 'rdp_epsilon',
        'patch_info'.

    Raises
    ------
    ValueError
        if `datapath` is not a directory
    AssertionError
        if all x coordinates of a larva are NaN
    Exception
        if a patch arena has a larva whose experiment tag has no ROI file
    """
    factor_frames = 15 # when the table is bigger and has more lines than necessary
    factor_frames2 = 1 # when the table is small
    print(f"""
############################
Now reading and processing data from:
arena : {arena}
substrate : {substrate}
larva type : {larva}
##############################
""")
    if not datapath.is_dir():
        raise ValueError(f"\nDirectory `{datapath}` does not exist!")
    rdp_epsilon = 10/scale if 'Yeast' in substrate else 20/scale

    # tables 1 and 2 are mandatory, table 3 is optional
    frames_list = []
    tables = []
    for exp_ind in (1, 2, 3):
        table_path = datapath / f'table{exp_ind}.csv'
        if exp_ind == 3 and not table_path.is_file():
            print('\n Warning! This experiment only has 2 sessions instead of 3 \n')
            break
        data_aux = pd.read_csv(table_path, header = None)
        data_clean, frames_list = clean_format_data(
            data_aux, frames_list, exp_ind, factor_frames, factor_frames2)
        tables.append(data_clean)
    data_total = pd.concat(tables, axis=1)
    frames_info = np.transpose(pd.DataFrame(frames_list))

    # read patch centers: one ROI file per table that was read
    is_two_patches = arena in ('Two_patches', 'Nonnutrient_patches')
    is_eight_patches = arena == 'Eight_patches'
    has_patches = is_two_patches or is_eight_patches
    if has_patches:
        # the first row of the eight-patch ROI files is skipped
        skiprows = [0] if is_eight_patches else None
        n_patches = 8 if is_eight_patches else 2
        patch_by_exp = {
            f'exp{n}': pd.read_csv(
                datapath / f'ROI_coord{n}.csv', skiprows=skiprows, header=None)
            for n in range(1, len(tables) + 1)}

    # experiment number and larva number
    mult = pd.MultiIndex.from_tuples([get_exp_and_larva_str(s) for s in data_total.columns.values])
    data_total.columns=mult
    frames_info.columns=mult
    dfout = pd.DataFrame(index=mult,columns=['x','y','nframes','time',\
        'simple_trajectory','idx_turn_points','rdp_mask','rdp_epsilon','patch_info'])
    for mm in mult:
        # each column stacks the x coordinates on top of the y coordinates
        nfra = int(frames_info.loc[0,mm]/2)
        xy = data_total.loc[:,mm].values/scale # convert to millimeters here
        x = xy[0:nfra]
        y = xy[nfra:2*nfra] # removes extra tail of NaN added by first dataframe operation
        coord_nan = np.isnan(x)
        if coord_nan.all():
            raise AssertionError('All values are NaNs... why?')
        time_points = np.arange(0,nfra,1)
        time = time_points * dt
        total_times = time_points[~coord_nan]
        coord_fix = np.column_stack([x[~coord_nan],y[~coord_nan]])
        # Run RDP once: the simplified trajectory is exactly the set of
        # points selected by the mask, so a second call is not needed.
        rdp_mask = rdp(coord_fix, epsilon=rdp_epsilon, return_mask = True)
        simple_traj = coord_fix[rdp_mask]
        # frame indices (in the original, NaN-including time axis) of the
        # points kept by RDP
        idx_turn_points = total_times[rdp_mask]
        dfout.loc[mm,'nframes'] = nfra
        dfout.loc[mm,'x'] = x
        dfout.loc[mm,'y'] = y
        dfout.loc[mm,'time'] = time
        dfout.loc[mm,'simple_trajectory'] = simple_traj
        dfout.loc[mm,'idx_turn_points'] = idx_turn_points
        dfout.loc[mm,'rdp_mask'] = rdp_mask
        dfout.loc[mm,'rdp_epsilon'] =rdp_epsilon
        # add patch coordinates and radii
        if has_patches:
            try:
                my_patch = patch_by_exp[mm[0]]
            except KeyError:
                raise Exception('Wrong experiment parameter {} '.format(mm[0])) from None
            # also, convert to millimeters; the third ROI value is divided by
            # 2*pi (presumably a perimeter turned into a radius - TODO confirm)
            dfout.loc[mm,'patch_info'] =\
                [[my_patch[1][k]/scale, my_patch[2][k]/scale, my_patch[3][k]/(2*np.pi*scale)] for k in range(n_patches)]

    # shorten the arena name used in the output index
    for short_name in ("Nonnutrient", "Two", "Eight"):
        if short_name in arena:
            arena = short_name
            break
    # add extra indexes for environment, substrate and larva
    supermult = pd.MultiIndex.from_tuples([ \
        (arena,substrate,larva) + others for others in dfout.index.values])
    dfout.index = supermult
    return dfout
#%%
print("""
------------------------------------------
Starting to read and preprocess all data.
*WARNING* this might take up to 30 min.
------------------------------------------
""")
# Process every available session and collect one dataframe per session.
df_read = []
for (arena, substrate, larva), fullpath in tqdm(
        zip(path_elements, all_paths), total=n_sessions):
    df_read.append(read_single_dataset(arena, substrate, larva, fullpath))

# Stack all sessions and store them as one compressed pickle next to this script.
data_all = pd.concat(df_read)
savename = path_thisfile.joinpath('data_all_rdp.pkl.xz')
print(f"""
**READING DONE!**
Now saving as {savename}
""")
data_all.to_pickle(savename)