Raw File
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "id": "ZdiI7rm7X6Tl"
   },
   "outputs": [],
   "source": [
    "!pip install --q dash==2.0.0 jupyter-dash==0.4.0;"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "9vwghilrXnsg",
    "outputId": "f46f1c4e-00da-4bd2-f778-01682c132362",
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "--2021-12-08 18:41:20--  https://raw.githubusercontent.com/saigerutherford/3dbrainplots/main/cross_validation_10fold_evaluation.csv\n",
      "Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.109.133, 185.199.108.133, 185.199.111.133, ...\n",
      "Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.109.133|:443... connected.\n",
      "HTTP request sent, awaiting response... 200 OK\n",
      "Length: 566347 (553K) [text/plain]\n",
      "Saving to: ‘cross_validation_10fold_evaluation.csv’\n",
      "\n",
      "cross_validation_10 100%[===================>] 553.07K  --.-KB/s    in 0.03s   \n",
      "\n",
      "2021-12-08 18:41:21 (16.8 MB/s) - ‘cross_validation_10fold_evaluation.csv’ saved [566347/566347]\n",
      "\n"
     ]
    }
   ],
   "source": [
    "!wget -nc 'https://raw.githubusercontent.com/saigerutherford/brainviz-app/main/transfer_eval_metrics.csv'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "id": "HcDkI5Uh8La-"
   },
   "outputs": [],
   "source": [
    "from jupyter_dash import JupyterDash\n",
    "from dash import dcc, html, Input, Output, no_update\n",
    "import plotly.graph_objects as go\n",
    "import pandas as pd\n",
    "import plotly.io as pio"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "id": "XEIWWYjT-7Om"
   },
   "outputs": [],
   "source": [
    "data_path = 'transfer_eval_metrics.csv'\n",
    "\n",
    "df = pd.read_csv(data_path)\n",
    "df = df.sort_values(by=\"Label\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 671
    },
    "id": "o9-G8EYeXd4d",
    "outputId": "75060621-69d1-4f4c-aa77-09e8523859eb"
   },
   "outputs": [],
   "source": [
    "colors = {\n",
    "    'background': '#FFFAFA',\n",
    "    'text': '#000000'\n",
    "}\n",
    "\n",
    "fig = go.Figure(data=[\n",
    "    go.Scatter(\n",
    "        x=df[\"Label\"],\n",
    "        y=df[\"EV\"],\n",
    "        mode=\"markers\",\n",
    "        marker=dict(\n",
    "            colorscale='plasma',\n",
    "            color=df[\"EV\"],\n",
    "            size=12,\n",
    "            line=dict(width=0.7, color='Black'),\n",
    "            colorbar={\"title\": \"Explained<br>Variance\"},\n",
    "            reversescale=True,\n",
    "            opacity=0.8,\n",
    "        )\n",
    "    )\n",
    "])\n",
    "\n",
    "# turn off native plotly.js hover effects - make sure to use\n",
    "# hoverinfo=\"none\" rather than \"skip\" which also halts events.\n",
    "fig.update_traces(hoverinfo=\"none\", hovertemplate=None)\n",
    "\n",
    "fig.update_yaxes(showgrid=True, gridwidth=1, gridcolor='LightGray')\n",
    "\n",
    "fig.update_layout(\n",
    "    xaxis=dict(title='ROI', showticklabels=False),\n",
    "    yaxis=dict(title='Explained Variance (EV)', showgrid=True),\n",
    "    plot_bgcolor=colors['background'],\n",
    "    paper_bgcolor=colors['background'],\n",
    "    font_color=colors['text'],\n",
    "    autosize=False,\n",
    "    width=1600,\n",
    "    height=800\n",
    ")\n",
    "\n",
    "app = JupyterDash(__name__)\n",
    "\n",
    "app.layout = html.Div([\n",
    "    dcc.Graph(id=\"graph\", figure=fig, clear_on_unhover=True),\n",
    "    dcc.Tooltip(id=\"graph-tooltip\"),\n",
    "])\n",
    "\n",
    "\n",
    "@app.callback(\n",
    "    Output(\"graph-tooltip\", \"show\"),\n",
    "    Output(\"graph-tooltip\", \"bbox\"),\n",
    "    Output(\"graph-tooltip\", \"children\"),\n",
    "    Input(\"graph\", \"hoverData\"),\n",
    ")\n",
    "def display_hover(hoverData):\n",
    "    if hoverData is None:\n",
    "        return False, no_update, no_update\n",
    "\n",
    "    # demo only shows the first point, but other points may also be available\n",
    "    pt = hoverData[\"points\"][0]\n",
    "    bbox = pt[\"bbox\"]\n",
    "    num = pt[\"pointNumber\"]\n",
    "\n",
    "    df_row = df.iloc[num]\n",
    "    img_src = df_row['IMG_URL']\n",
    "    name = df_row['Label']\n",
    "    ev = \"Explained Variance = \" + df_row['EV'].round(3).astype(str)\n",
    "\n",
    "\n",
    "    children = [\n",
    "        html.Div(children=[\n",
    "            html.Img(src=img_src, style={\"width\": \"100%\"}),\n",
    "            html.H2(f\"{name}\", style={\"color\": \"darkblue\"}),\n",
    "            html.P(f\"{ev}\"),\n",
    "        ],\n",
    "        style={'width': '350px', 'white-space': 'normal'})\n",
    "    ]\n",
    "\n",
    "    return True, bbox, children\n",
    "\n",
    "\n",
    "if __name__ == \"__main__\":\n",
    "    app.run_server(debug=True, mode='inline')"
   ]
  }
 ],
 "metadata": {
  "colab": {
   "collapsed_sections": [],
   "name": "cross_validation_10fold_eval_viz",
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
back to top