https://github.com/osmoai/publications
Revision b36b51d687c5562713267577a305d01fa87e3302 authored by Rick Gerkin on 30 March 2023, 18:35:06 UTC, committed by Rick Gerkin on 30 March 2023, 18:35:06 UTC
1 parent c03b494
Tip revision: b36b51d687c5562713267577a305d01fa87e3302 authored by Rick Gerkin on 30 March 2023, 18:35:06 UTC
Import fix
Import fix
Tip revision: b36b51d
Figure1DEF.ipynb
{
"cells": [
{
"cell_type": "code",
"execution_count": 2,
"id": "63685727",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"import matplotlib.pyplot as plt\n",
"import pandas as pd\n",
"import seaborn as sns\n",
"from lee_et_al_2023.src import base"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "28367ee7",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"pc_df = pd.read_csv(base.DATA_PATH / 'fig_1_pc_data.csv')\n",
"\n",
"pca_gnn = pc_df[['gnn_pc1', 'gnn_pc2']].values\n",
"pca_fp = pc_df[['fp_pc1', 'fp_pc2']].values\n",
"pca_label = pc_df[['label_pc1', 'label_pc2']].values\n",
"pc_df"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f4949112",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"def plot_islands(bg_points,\n",
" bg_name,\n",
" *fg_specs,\n",
" alpha = 0.4,\n",
" z_limit = 15,\n",
" colors=None):\n",
" \"\"\"Plot islands in a background sea of points.\n",
"\n",
" This command executes matplotlib.pyplot commands as side effects.\n",
" Use plt.figure() to control where these outputs get generated.\n",
"\n",
" Args:\n",
" bg_points: Array of shape (n_points, 2) indicating x,y coordinates\n",
" of any background points.\n",
" bg_name: Name of background points\n",
" *fg_specs: Repeated island specifications. Allowable dict keys:\n",
" data - array of shape (n_points, 2)\n",
" label - str, name of the island\n",
" color - RGB, color of the island\n",
" scatter - bool, whether to add the point scatters\n",
" scatter_size - int, radius of scatter points in pixel\n",
" filled - bool, whether to fill the island\n",
" levels_to_plot - List[float], percentiles of the KDE to plot.\n",
" alpha: Transparency of island color fill.\n",
" z_limit: Plotting boundaries of the plot.\n",
" \"\"\"\n",
" default_fg_colors = colors or sns.color_palette('Set3', len(fg_specs))\n",
"\n",
" sns.scatterplot(x=bg_points[:, 0], y=bg_points[:, 1],\n",
" s=3, color='0.60', label=bg_name)\n",
" for fg_color, fg_spec in zip(default_fg_colors, fg_specs):\n",
" fg_color = fg_spec.get('color', fg_color)\n",
" fg_scatter = fg_spec.get('scatter', False)\n",
" fg_scatter_size = fg_spec.get('scatter_size', 5)\n",
" fg_filled = fg_spec.get('filled', False)\n",
" fg_level_to_plot = fg_spec.get('level_to_plot', 0.25)\n",
" x, y = fg_spec['data'][:, 0], fg_spec['data'][:, 1]\n",
" label = fg_spec['label']\n",
" if fg_scatter:\n",
" sns.scatterplot(x, y, s=fg_scatter_size, color=fg_color)\n",
" if fg_filled:\n",
" sns.kdeplot(x=x, y=y, color=fg_color, fill=True,\n",
" thresh=fg_level_to_plot, alpha=alpha, levels=2, bw_method=0.3)\n",
" else:\n",
" sns.kdeplot(x=x, y=y, color=fg_color, fill=False,\n",
" thresh=fg_level_to_plot, levels=2, bw_method=0.3)\n",
" # Generate the legend entry\n",
" if fg_filled:\n",
" plt.scatter([], [], marker='s', c=fg_color, label=label)\n",
" else:\n",
" plt.plot([], [], c=fg_color, linewidth=3, label=label)\n",
" plt.xlim([-z_limit - 0.6, z_limit + 0.6])\n",
" plt.ylim([-z_limit - 0.6, z_limit + 0.6])\n",
" plt.gca().set_aspect('equal', adjustable='box')\n",
" \n",
"\n",
"def plot_odor_islands(pca_space, z_limit=15):\n",
" plt.figure(figsize=(12, 8))\n",
" color_palette = sns.color_palette('Set3', 15)\n",
" fg_specs = []\n",
" i = 0\n",
" for main_group, subgroups in [('floral', ['muguet', 'lavender', 'jasmin']),\n",
" ('meaty', ['savory', 'beefy', 'roasted']),\n",
" ('alcoholic', ['cognac', 'fermented', 'winey']),\n",
" ]:\n",
" main_embeddings = pca_space[pc_df[main_group]]\n",
" fg_specs.append({'data': main_embeddings,\n",
" 'filled': True,\n",
" 'scatter': False,\n",
" 'level_to_plot': 0.1,\n",
" 'label': main_group.capitalize()})\n",
"\n",
" for subgroup in subgroups:\n",
" island_embeddings = pca_space[pc_df[subgroup]]\n",
" fg_specs.append({'data': island_embeddings,\n",
" 'filled': False,\n",
" 'scatter': False,\n",
" 'level_to_plot': 0.2,\n",
" 'label': subgroup.capitalize()})\n",
" plot_islands(pca_space, None, *fg_specs, colors=color_palette[i:i+4], z_limit=z_limit)\n",
" i += 5\n",
" fg_specs = []\n",
" plt.legend(loc='center left', bbox_to_anchor=(1, 0.5))\n",
"\n",
"plot_odor_islands(pca_gnn, z_limit=15)\n",
"plt.title('GNN Embeddings')\n",
"plot_odor_islands(pca_fp, z_limit=10)\n",
"plt.title('Fingerprints')\n",
"plot_odor_islands(pca_label, z_limit=8)\n",
"plt.title('True Labels')\n",
"plt.show()"
]
}
],
"metadata": {
"jupytext": {
"formats": "auto:light,ipynb"
},
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.8"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

Computing file changes ...