# --- # jupyter: # jupytext: # text_representation: # extension: .py # format_name: light # format_version: '1.5' # jupytext_version: 1.4.0 # kernelspec: # display_name: Python 3 # language: python # name: python3 # --- # + import numpy as np import matplotlib matplotlib.use('Agg') import matplotlib.pyplot as plt import os import sys sys.path.append('../src/') from theory import geodist_counts from plotting import GeoDistPlot # %matplotlib inline # + plt.rcParams['font.sans-serif'] = "Arial" plt.rcParams['figure.facecolor'] = "w" plt.rcParams['figure.autolayout'] = True from mpl_toolkits.axes_grid1 import make_axes_locatable # Deboxing a particular axis def debox(ax): ax.spines['right'].set_visible(False) ax.spines['top'].set_visible(False) figdir = '../plots/figure4/' os.makedirs(figdir, exist_ok=True) # - # Setting up simulation parameters for divergence times ... n = 200; j=8 t1=0.05 t2=0.5 theta=1e-3 # %%time # Compute Joint SFS with a divergence time of 0.1 jsfs1 = geodist_counts(t=t1, n=n, k_rare=j, c=0.0, theta=theta) cur_geodist1 = GeoDistPlot() cur_geodist1._add_data_jsfs(jsfs=jsfs1) cur_geodist1._add_poplabels_manual(poplabels=np.array(['A','B'])) cur_geodist1.ncat = cur_geodist1.ncat -1 cur_geodist1._add_cmap(str_labels=['u','R','C']) cur_geodist1._filter_data(max_freq=0.0) # %%time # Compute Joint SFS with a divergence time of 0.5 jsfs2 = geodist_counts(t=t2, n=n, k_rare=j, c=0.0, theta=theta) cur_geodist2 = GeoDistPlot() cur_geodist2._add_data_jsfs(jsfs=jsfs2) cur_geodist2._add_poplabels_manual(poplabels=np.array(['A','B'])) cur_geodist2.ncat = cur_geodist2.ncat -1 cur_geodist2._add_cmap(str_labels=['u','R','C']) cur_geodist2._filter_data(max_freq=0.0) # %%time # Compute Joint SFS with a divergence time of 0.1 jsfs3 = geodist_counts(t=t1, n=n, k_rare=j, c=0.025, theta=theta) cur_geodist3 = GeoDistPlot() cur_geodist3._add_data_jsfs(jsfs=jsfs3) cur_geodist3._add_poplabels_manual(poplabels=np.array(['A','B'])) cur_geodist3.ncat = cur_geodist3.ncat -1 cur_geodist3._add_cmap(str_labels=['u','R','C']) cur_geodist3._filter_data(max_freq=0.0) # ## Developing Plotting Functions for Box 2 # Plotting Divergence Model def plot_div_model(ax, tdiv=0.3, admix=0.0, vbar=False, vbarlabel=None, **kwargs): # Plotting the underlying functions ax.plot([0.25,0.35], [0.0,tdiv], **kwargs) ax.plot([0.3,0.4], [0.0,tdiv], **kwargs) ax.plot([0.5,0.4], [0.0,tdiv], **kwargs) ax.plot([0.55,0.45], [0.0,tdiv], **kwargs) ax.plot([0.35,0.35], [tdiv,0.6], **kwargs) ax.plot([0.45,0.45], [tdiv,0.6], **kwargs) # Plotting the limits if vbar: x = 0.2 capsize=0.01 ax.plot([x,x], [0.0, tdiv], color='black') ax.plot([x-capsize, x+capsize], [0.0,0.0], color='black') ax.plot([x-capsize, x+capsize], [tdiv,tdiv], color='black') ax.text(0.2, tdiv/2, vbarlabel, horizontalalignment='right', verticalalignment='center', fontsize=10) if admix > 0: # Drawing an admixture bar ax.arrow(0.325, tdiv/4., 0.475-0.325, 0, head_width=0.01, head_length=0.01, linewidth=1, color='b', length_includes_head=True) ax.arrow(0.475, tdiv/4., -(0.475-0.325), 0, head_width=0.01, head_length=0.01, linewidth=1, color='b', length_includes_head=True) ax.text(0.4, -0.005, r'%0.0f %%' % (100*admix), ha='center', va='top', fontsize=10) # Defining a vertical bar for the divergence times ax.set_xlim(0.1,0.7) ax.axis('off') # + # Using a GridSpec for this plot fig = plt.figure(figsize=(3.3,6)) grid = plt.GridSpec(3, 3, hspace=0.1, wspace=0.2) top_left_ax = fig.add_subplot(grid[0, 0]) top_center_ax = fig.add_subplot(grid[0, 1]) top_right_ax = fig.add_subplot(grid[0, 2]) plot_div_model(top_right_ax, tdiv=t1, vbar=False, lw=2, admix=0.025, color='black', solid_capstyle='round') plot_div_model(top_center_ax, tdiv=t1, vbar=True, vbarlabel = '0.05 ', lw=2, color='black', solid_capstyle='round') plot_div_model(top_left_ax, tdiv=t2, vbar=True, vbarlabel = r'$T/2N= 0.5$ ', lw=2, color='black', solid_capstyle='round') top_center_ax.set_title('Recent\n Divergence', fontsize=12) top_right_ax.set_title('Recent\n Admixture', fontsize=12) top_left_ax.set_title('Deep\n Divergence', fontsize=12) # # Plotting GeoDist Results bottom_left_ax = fig.add_subplot(grid[1:, 0]) bottom_center_ax = fig.add_subplot(grid[1:, 1]) bottom_right_ax = fig.add_subplot(grid[1:, 2]) cur_geodist1.plot_geodist(bottom_center_ax); cur_geodist2.plot_geodist(bottom_left_ax); cur_geodist3.plot_geodist(bottom_right_ax); bottom_left_ax.set_ylabel(r'Cumulative fraction of variants', fontsize=14); bottom_right_ax.set_xticklabels([]); bottom_right_ax.set_xticks([]); bottom_center_ax.set_xticklabels([]); bottom_center_ax.set_xticks([]); bottom_left_ax.set_xticklabels([]); bottom_left_ax.set_xticks([]); bottom_center_ax.set_yticklabels([]); bottom_right_ax.set_yticklabels([]); plt.savefig(figdir + 'fig4A.pdf', bbox_inches='tight')