https://github.com/osmoai/publications
Revision b36b51d687c5562713267577a305d01fa87e3302 authored by Rick Gerkin on 30 March 2023, 18:35:06 UTC, committed by Rick Gerkin on 30 March 2023, 18:35:06 UTC
1 parent c03b494
Tip revision: b36b51d687c5562713267577a305d01fa87e3302 authored by Rick Gerkin on 30 March 2023, 18:35:06 UTC
Import fix
Import fix
Tip revision: b36b51d
Figure3B.ipynb
{
"cells": [
{
"cell_type": "markdown",
"id": "398cf110",
"metadata": {},
"source": [
"# GNN performance as a function of data\n"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "8ba5c1db",
"metadata": {},
"outputs": [],
"source": [
"# Stock imports\n",
"%load_ext autoreload\n",
"%autoreload 2\n",
"import matplotlib.pyplot as plt\n",
"import seaborn as sns\n",
"import numpy as np\n",
"import pandas as pd\n",
"\n",
"from lee_et_al_2023.src import analysis\n",
"from lee_et_al_2023.src import base\n",
"from lee_et_al_2023.src import data_loaders\n",
"base.set_visual_settings()"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "fa74a4d2",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/Users/rick/code/ratatouille/qian_et_al_2023/src/data_loaders.py:77: FutureWarning: The default value of numeric_only in DataFrameGroupBy.mean is deprecated. In a future version, numeric_only will default to False. Either specify numeric_only or select only columns which should be valid for the function.\n",
" panel = humans.groupby('RedJade Code').mean().loc[mol_codes, base.MONELL_CLASS_LIST]\n"
]
}
],
"source": [
"# Canonical way to load (most of) the data... add lines as needs to include other pieces\n",
"models, humans, panel, subjects = data_loaders.get_clean()"
]
},
{
"cell_type": "markdown",
"id": "0ed401ab",
"metadata": {},
"source": [
"### Run the code to prepare the data to be visualized"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "3853a409",
"metadata": {
"lines_to_next_cell": 2
},
"outputs": [],
"source": [
"transpose_ott = analysis.fast_process(humans, models, axis=1)\n",
"corr_df = transpose_ott.groupby('index').agg(np.nanmedian)\n",
"corr_df['Test data counts'] = (panel > 0.7).sum(axis=0)\n",
"training_class_counts = pd.read_csv(base.DATA_PATH / \"training_class_counts.csv\")\n",
"training_class_counts['label'] = training_class_counts['label'].apply(\n",
" lambda l: {\"jasmin\": \"jasmine\"}.get(l, l).title())\n",
"training_class_counts = training_class_counts.set_index('label')\n",
"training_class_counts = training_class_counts.loc[base.MONELL_CLASS_LIST]\n",
"corr_df['Training data counts'] = training_class_counts"
]
},
{
"cell_type": "markdown",
"id": "3dde87b3",
"metadata": {},
"source": [
"### Name the figure"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "5f48f5f4",
"metadata": {},
"outputs": [],
"source": [
"fig_name = '3B'"
]
},
{
"cell_type": "markdown",
"id": "b0caf3ff",
"metadata": {},
"source": [
"### Run the code to make the figure and save the figure to friendly formats"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "39db34cc",
"metadata": {},
"outputs": [],
"source": [
"plt.figure(figsize=(8,8))\n",
"plot = sns.scatterplot(data=corr_df.reset_index(), x='Training data counts', y='GNN',\n",
" size='Test data counts', sizes=(20, 400), color='#fc8d62')\n",
"plot.legend(loc='lower right', title='Test data counts')\n",
"plot.set(xscale=\"log\")\n",
"plt.xlabel('Training Data Counts', fontsize=20)\n",
"plt.xlim(10**1.5, 10**3.5)\n",
"plt.ylabel('GNN Correlation with Panel Mean', fontsize=20)\n",
"plt.xticks(10 ** np.linspace(1.5, 3.5, 5), fontsize=16)\n",
"plt.yticks(np.linspace(-0.1, 0.5, 7), fontsize=16)\n",
"for label in ('Camphoreous', 'Fishy', 'Cooling', 'Sulfurous', 'Roasted', 'Fruity', 'Floral', 'Sweet', 'Green',\n",
" 'Ozone', 'Sharp', 'Waxy', 'Medicinal', 'Musty', 'Fermented', 'Garlic', 'Alcoholic', 'Musk', 'Meaty'):\n",
" row = corr_df.loc[label]\n",
" offset = (5, 5)\n",
" if label == 'Floral':\n",
" offset = (5, -15)\n",
" if label == 'Fishy':\n",
" offset = (5, -15)\n",
" if label == 'Camphoreous':\n",
" offset = (5, -5)\n",
" plt.annotate(label, (row['Training data counts'], row['GNN']), xytext=offset, textcoords='offset points',\n",
" fontsize=20)\n",
"\n",
"for axis in ['bottom','left']:\n",
" plot.spines[axis].set_linewidth(3)\n",
" plot.spines[axis].set_edgecolor('black')\n",
"for axis in ['top','right']:\n",
" plot.spines[axis].set_visible(False)\n",
"\n",
"plot.grid(False)\n"
]
}
],
"metadata": {
"jupytext": {
"formats": "ipynb,py:light"
},
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.8"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

Computing file changes ...