{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "fig2e-title",
   "metadata": {},
   "source": [
    "# Figure 2E\n",
    "\n",
    "Empirical cumulative distributions of the correlation to the panel mean for humans and for the models (RF, GNN, and a shuffled GNN baseline)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ec9cdb03",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Stock imports\n",
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "\n",
    "from lee_et_al_2023.src import analysis\n",
    "from lee_et_al_2023.src import base\n",
    "from lee_et_al_2023.src import data_loaders\n",
    "base.set_visual_settings()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fig2e-data",
   "metadata": {},
   "source": [
    "## Load and process data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ce4130f4",
   "metadata": {},
   "outputs": [],
   "source": [
    "models, humans, panel, subjects = data_loaders.get_clean()\n",
    "\n",
    "# Baseline: the GNN entry passed through analysis.shuffle_df with shuffle='molecules'.\n",
    "# NOTE(review): the shuffle is presumably random -- confirm whether a seed should be\n",
    "# fixed so that Restart & Run All reproduces the same baseline curve.\n",
    "models['GNN (shuffled baseline)'] = analysis.shuffle_df(models['GNN'], shuffle='molecules')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "efa9540f",
   "metadata": {
    "lines_to_next_cell": 2
   },
   "outputs": [],
   "source": [
    "# `ott` is indexed by group name below ('Human', 'RF', 'GNN', ...).\n",
    "ott = analysis.fast_process(humans, models)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fig2e-plot",
   "metadata": {},
   "source": [
    "## Plot"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "755f8e64",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Figure identifier. NOTE(review): not referenced by any other cell in this notebook.\n",
    "fig_name = '2E'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8cf25192",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Line colour per group; the insertion order is also the plotting/legend order.\n",
    "colormap = {\n",
    "    'RF': '#8da0cb',\n",
    "    'GNN': '#fc8d62',\n",
    "    'Human': '#66c2a5',\n",
    "    'GNN (shuffled baseline)': 'grey',\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2c21f94e",
   "metadata": {
    "lines_to_next_cell": 0
   },
   "outputs": [],
   "source": [
    "# Use an explicit Axes so the styling below does not depend on a variable\n",
    "# leaking out of the plotting loop.\n",
    "fig, ax = plt.subplots(figsize=(8, 6))\n",
    "for group, color in colormap.items():\n",
    "    sns.ecdfplot(ott[group], color=color, stat='proportion', label=group, ax=ax)\n",
    "\n",
    "# Horizontal reference line at a proportion of 0.5.\n",
    "ax.axhline(y=0.5, color='black', linestyle='--', linewidth=2)\n",
    "ax.set_xlabel('Correlation to Panel Mean', fontsize=20)\n",
    "ax.set_ylabel('Proportion of molecules', fontsize=20)\n",
    "ax.set_xlim(-0.25, 0.75)\n",
    "ax.set_xticks(np.linspace(-0.25, 0.75, 5))\n",
    "\n",
    "ax.legend()\n",
    "\n",
    "# Emphasise the bottom/left spines and hide the top/right ones.\n",
    "for side in ('bottom', 'left'):\n",
    "    ax.spines[side].set_linewidth(3)\n",
    "    ax.spines[side].set_edgecolor('black')\n",
    "for side in ('top', 'right'):\n",
    "    ax.spines[side].set_visible(False)\n",
    "\n",
    "ax.grid(False)\n",
    "\n",
    "# Compute every median once instead of on each use.\n",
    "medians = {group: ott[group].median() for group in colormap}\n",
    "human_median = medians['Human']\n",
    "\n",
    "for group in ('Human', 'GNN', 'RF', 'GNN (shuffled baseline)'):\n",
    "    # Dotted vertical marker at the group's median.\n",
    "    ax.axvline(x=medians[group], linestyle=':', linewidth=1, color='black')\n",
    "    if group in ('Human', 'GNN (shuffled baseline)'):\n",
    "        continue\n",
    "    # Arrow from the human median to this model's median.\n",
    "    # NOTE(review): the arrow's height is 0.5 + the group's median, which looks\n",
    "    # like a way to stagger the arrows vertically -- confirm this is intended.\n",
    "    ax.arrow(human_median, 0.5 + medians[group],\n",
    "             dx=medians[group] - human_median,\n",
    "             dy=0,\n",
    "             width=0.03,\n",
    "             length_includes_head=True,\n",
    "             head_width=0.03,\n",
    "             head_length=0.01,\n",
    "             color=colormap[group])"
   ]
  }
 ],
 "metadata": {
  "jupytext": {
   "cell_metadata_filter": "-all",
   "formats": "auto:light,ipynb",
   "notebook_metadata_filter": "-all"
  },
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}