{ "cells": [ { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "# T4 - Optimization of Infection Duration" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "In this tutorial, we'll introduce the concept of parameter optimization against an objective function -- in this case, maximizing the mean infection duration in naive infectious challenges by changing the antigenic switching rate parameter." ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "We'll start by defining a function to perform multi-individual challenges (similar to the last tutorial)." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": "import numpy as np\nimport pandas as pd\nimport xarray as xr\nimport matplotlib.pyplot as plt\n\nfrom emodlib.malaria import IntrahostComponent, create_config\n\n\ndef multiple_challenges(config, n_people, duration):\n\n    asexuals = np.zeros((n_people, duration))\n    gametocytes = np.zeros((n_people, duration))\n\n    # one intrahost model per individual, each given a single challenge\n    pp = [IntrahostComponent.create(config) for _ in range(n_people)]\n    _ = [p.challenge() for p in pp]\n\n    # step every individual forward (dt=1) and record both densities\n    for t in range(duration):\n        for i, p in enumerate(pp):\n            p.update(dt=1)\n            asexuals[i, t] = p.parasite_density\n            gametocytes[i, t] = p.gametocyte_density\n\n    da = xr.DataArray(dims=('individual', 'time', 'channel'),\n                      coords=(range(n_people), range(duration), ['parasite_density', 'gametocyte_density']))\n\n    da.loc[dict(channel='parasite_density')] = asexuals\n    da.loc[dict(channel='gametocyte_density')] = gametocytes\n\n    return da" }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "We'll also define a helper function to determine the duration of challenge infections based on the time index of the last non-zero parasite density." ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "def get_last_nonzero_by_row(A):\n", " \"\"\" https://stackoverflow.com/a/39959511 
\"\"\"\n", " return np.arange(A.shape[0]), A.shape[1] - 1 - (A[:, ::-1]!=0).argmax(1)" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": "Then we'll define our objective function:\n- log-uniform sampling of the antigen switching rate within a defined range\n- creating a model config with that parameter value\n- running the multi-individual challenge time-series\n- returning the mean infection-duration value" }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": "def objective(trial):\n \n n_people = 50\n duration = 500\n\n antigen_switch_rate = trial.suggest_float(\"Antigen_Switch_Rate\", 5e-10, 5e-8, log=True)\n config = create_config({'infection_params': {'Antigen_Switch_Rate': antigen_switch_rate}})\n \n da = multiple_challenges(config, duration=duration, n_people=n_people)\n infection_durations = get_last_nonzero_by_row(da.sel(channel='parasite_density').values)[1]\n \n return infection_durations.mean()" }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "Finally, we'll create an optuna study and run a number of optimization trials..." ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "\u001b[32m[I 2023-05-26 18:53:14,086]\u001b[0m A new study created in memory with name: maximize_duration\u001b[0m\n", "\u001b[32m[I 2023-05-26 18:53:14,124]\u001b[0m Trial 0 finished with value: 249.32 and parameters: {'Antigen_Switch_Rate': 1.1915470318960563e-09}. Best is trial 0 with value: 249.32.\u001b[0m\n", "\u001b[32m[I 2023-05-26 18:53:14,171]\u001b[0m Trial 1 finished with value: 399.04 and parameters: {'Antigen_Switch_Rate': 2.750348541347495e-09}. Best is trial 1 with value: 399.04.\u001b[0m\n", "\u001b[32m[I 2023-05-26 18:53:14,218]\u001b[0m Trial 2 finished with value: 321.86 and parameters: {'Antigen_Switch_Rate': 1.2742403528792963e-08}. 
Best is trial 1 with value: 399.04.\u001b[0m\n", "\u001b[32m[I 2023-05-26 18:53:14,257]\u001b[0m Trial 3 finished with value: 291.14 and parameters: {'Antigen_Switch_Rate': 1.4323327599950657e-09}. Best is trial 1 with value: 399.04.\u001b[0m\n", "\u001b[32m[I 2023-05-26 18:53:14,303]\u001b[0m Trial 4 finished with value: 351.2 and parameters: {'Antigen_Switch_Rate': 7.727469068951856e-09}. Best is trial 1 with value: 399.04.\u001b[0m\n", "\u001b[32m[I 2023-05-26 18:53:14,351]\u001b[0m Trial 5 finished with value: 280.38 and parameters: {'Antigen_Switch_Rate': 4.6148422747495935e-08}. Best is trial 1 with value: 399.04.\u001b[0m\n", "\u001b[32m[I 2023-05-26 18:53:14,383]\u001b[0m Trial 6 finished with value: 170.26 and parameters: {'Antigen_Switch_Rate': 5.363872536158119e-10}. Best is trial 1 with value: 399.04.\u001b[0m\n", "\u001b[32m[I 2023-05-26 18:53:14,429]\u001b[0m Trial 7 finished with value: 359.2 and parameters: {'Antigen_Switch_Rate': 6.2266808194259075e-09}. Best is trial 1 with value: 399.04.\u001b[0m\n", "\u001b[32m[I 2023-05-26 18:53:14,475]\u001b[0m Trial 8 finished with value: 342.46 and parameters: {'Antigen_Switch_Rate': 8.715215269434397e-09}. Best is trial 1 with value: 399.04.\u001b[0m\n", "\u001b[32m[I 2023-05-26 18:53:14,508]\u001b[0m Trial 9 finished with value: 184.14 and parameters: {'Antigen_Switch_Rate': 6.224321379915857e-10}. Best is trial 1 with value: 399.04.\u001b[0m\n", "\u001b[32m[I 2023-05-26 18:53:14,557]\u001b[0m Trial 10 finished with value: 400.04 and parameters: {'Antigen_Switch_Rate': 2.7055736726897753e-09}. Best is trial 10 with value: 400.04.\u001b[0m\n", "\u001b[32m[I 2023-05-26 18:53:14,603]\u001b[0m Trial 11 finished with value: 386.52 and parameters: {'Antigen_Switch_Rate': 2.7040378193028815e-09}. Best is trial 10 with value: 400.04.\u001b[0m\n", "\u001b[32m[I 2023-05-26 18:53:14,650]\u001b[0m Trial 12 finished with value: 399.72 and parameters: {'Antigen_Switch_Rate': 3.0708647379292074e-09}. 
Best is trial 10 with value: 400.04.\u001b[0m\n", "\u001b[32m[I 2023-05-26 18:53:14,697]\u001b[0m Trial 13 finished with value: 393.78 and parameters: {'Antigen_Switch_Rate': 3.2067748660456205e-09}. Best is trial 10 with value: 400.04.\u001b[0m\n", "\u001b[32m[I 2023-05-26 18:53:14,744]\u001b[0m Trial 14 finished with value: 394.92 and parameters: {'Antigen_Switch_Rate': 4.169512643087303e-09}. Best is trial 10 with value: 400.04.\u001b[0m\n", "\u001b[32m[I 2023-05-26 18:53:14,786]\u001b[0m Trial 15 finished with value: 332.96 and parameters: {'Antigen_Switch_Rate': 1.6758842380026993e-09}. Best is trial 10 with value: 400.04.\u001b[0m\n", "\u001b[32m[I 2023-05-26 18:53:14,832]\u001b[0m Trial 16 finished with value: 385.66 and parameters: {'Antigen_Switch_Rate': 4.51732112640819e-09}. Best is trial 10 with value: 400.04.\u001b[0m\n", "\u001b[32m[I 2023-05-26 18:53:14,876]\u001b[0m Trial 17 finished with value: 363.84 and parameters: {'Antigen_Switch_Rate': 2.1226773056736917e-09}. Best is trial 10 with value: 400.04.\u001b[0m\n", "\u001b[32m[I 2023-05-26 18:53:14,913]\u001b[0m Trial 18 finished with value: 236.3 and parameters: {'Antigen_Switch_Rate': 9.94876719956596e-10}. Best is trial 10 with value: 400.04.\u001b[0m\n", "\u001b[32m[I 2023-05-26 18:53:14,957]\u001b[0m Trial 19 finished with value: 370.66 and parameters: {'Antigen_Switch_Rate': 1.814725113148747e-09}. Best is trial 10 with value: 400.04.\u001b[0m\n", "\u001b[32m[I 2023-05-26 18:53:15,007]\u001b[0m Trial 20 finished with value: 396.76 and parameters: {'Antigen_Switch_Rate': 4.011620108155682e-09}. Best is trial 10 with value: 400.04.\u001b[0m\n", "\u001b[32m[I 2023-05-26 18:53:15,053]\u001b[0m Trial 21 finished with value: 391.16 and parameters: {'Antigen_Switch_Rate': 2.86932123142975e-09}. Best is trial 10 with value: 400.04.\u001b[0m\n", "\u001b[32m[I 2023-05-26 18:53:15,098]\u001b[0m Trial 22 finished with value: 375.62 and parameters: {'Antigen_Switch_Rate': 2.2297814196138632e-09}. 
Best is trial 10 with value: 400.04.\u001b[0m\n", "\u001b[32m[I 2023-05-26 18:53:15,145]\u001b[0m Trial 23 finished with value: 409.2 and parameters: {'Antigen_Switch_Rate': 2.9439327747062316e-09}. Best is trial 23 with value: 409.2.\u001b[0m\n", "\u001b[32m[I 2023-05-26 18:53:15,182]\u001b[0m Trial 24 finished with value: 238.94 and parameters: {'Antigen_Switch_Rate': 9.439553321717197e-10}. Best is trial 23 with value: 409.2.\u001b[0m\n" ] } ], "source": [ "import optuna\n", "\n", "study = optuna.create_study(study_name='maximize_duration', direction='maximize')\n", "study.optimize(objective, n_trials=25)" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{'Antigen_Switch_Rate': 2.9439327747062316e-09}" ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ "study.best_params # parameter for longest avg duration" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "409.2" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "study.best_value # longest avg duration" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "Now let's look at a few default visualizations of the optuna study:\n", "- the convergence towards maximizing the objective value over successive trials\n", "- the value of the objective as a function of our 1-d parameter range explored" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "data": { "text/html": [ " \n", " " ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.plotly.v1+json": { "config": { "plotlyServerURL": "https://plot.ly" }, "data": [ { "mode": "markers", "name": "Mean infection duration (d)", "type": "scatter", "x": [ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24 ], "y": [ 249.32, 399.04, 321.86, 291.14, 351.2, 
280.38, 170.26, 359.2, 342.46, 184.14, 400.04, 386.52, 399.72, 393.78, 394.92, 332.96, 385.66, 363.84, 236.3, 370.66, 396.76, 391.16, 375.62, 409.2, 238.94 ] }, { "name": "Best Value", "type": "scatter", "x": [ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24 ], "y": [ 249.32, 399.04, 399.04, 399.04, 399.04, 399.04, 399.04, 399.04, 399.04, 399.04, 400.04, 400.04, 400.04, 400.04, 400.04, 400.04, 400.04, 400.04, 400.04, 400.04, 400.04, 400.04, 400.04, 409.2, 409.2 ] } ], "layout": { "template": { "data": { "bar": [ { "error_x": { "color": "#2a3f5f" }, "error_y": { "color": "#2a3f5f" }, "marker": { "line": { "color": "#E5ECF6", "width": 0.5 }, "pattern": { "fillmode": "overlay", "size": 10, "solidity": 0.2 } }, "type": "bar" } ], "barpolar": [ { "marker": { "line": { "color": "#E5ECF6", "width": 0.5 }, "pattern": { "fillmode": "overlay", "size": 10, "solidity": 0.2 } }, "type": "barpolar" } ], "carpet": [ { "aaxis": { "endlinecolor": "#2a3f5f", "gridcolor": "white", "linecolor": "white", "minorgridcolor": "white", "startlinecolor": "#2a3f5f" }, "baxis": { "endlinecolor": "#2a3f5f", "gridcolor": "white", "linecolor": "white", "minorgridcolor": "white", "startlinecolor": "#2a3f5f" }, "type": "carpet" } ], "choropleth": [ { "colorbar": { "outlinewidth": 0, "ticks": "" }, "type": "choropleth" } ], "contour": [ { "colorbar": { "outlinewidth": 0, "ticks": "" }, "colorscale": [ [ 0, "#0d0887" ], [ 0.1111111111111111, "#46039f" ], [ 0.2222222222222222, "#7201a8" ], [ 0.3333333333333333, "#9c179e" ], [ 0.4444444444444444, "#bd3786" ], [ 0.5555555555555556, "#d8576b" ], [ 0.6666666666666666, "#ed7953" ], [ 0.7777777777777778, "#fb9f3a" ], [ 0.8888888888888888, "#fdca26" ], [ 1, "#f0f921" ] ], "type": "contour" } ], "contourcarpet": [ { "colorbar": { "outlinewidth": 0, "ticks": "" }, "type": "contourcarpet" } ], "heatmap": [ { "colorbar": { "outlinewidth": 0, "ticks": "" }, "colorscale": [ [ 0, "#0d0887" ], [ 0.1111111111111111, "#46039f" 
], [ 0.2222222222222222, "#7201a8" ], [ 0.3333333333333333, "#9c179e" ], [ 0.4444444444444444, "#bd3786" ], [ 0.5555555555555556, "#d8576b" ], [ 0.6666666666666666, "#ed7953" ], [ 0.7777777777777778, "#fb9f3a" ], [ 0.8888888888888888, "#fdca26" ], [ 1, "#f0f921" ] ], "type": "heatmap" } ], "heatmapgl": [ { "colorbar": { "outlinewidth": 0, "ticks": "" }, "colorscale": [ [ 0, "#0d0887" ], [ 0.1111111111111111, "#46039f" ], [ 0.2222222222222222, "#7201a8" ], [ 0.3333333333333333, "#9c179e" ], [ 0.4444444444444444, "#bd3786" ], [ 0.5555555555555556, "#d8576b" ], [ 0.6666666666666666, "#ed7953" ], [ 0.7777777777777778, "#fb9f3a" ], [ 0.8888888888888888, "#fdca26" ], [ 1, "#f0f921" ] ], "type": "heatmapgl" } ], "histogram": [ { "marker": { "pattern": { "fillmode": "overlay", "size": 10, "solidity": 0.2 } }, "type": "histogram" } ], "histogram2d": [ { "colorbar": { "outlinewidth": 0, "ticks": "" }, "colorscale": [ [ 0, "#0d0887" ], [ 0.1111111111111111, "#46039f" ], [ 0.2222222222222222, "#7201a8" ], [ 0.3333333333333333, "#9c179e" ], [ 0.4444444444444444, "#bd3786" ], [ 0.5555555555555556, "#d8576b" ], [ 0.6666666666666666, "#ed7953" ], [ 0.7777777777777778, "#fb9f3a" ], [ 0.8888888888888888, "#fdca26" ], [ 1, "#f0f921" ] ], "type": "histogram2d" } ], "histogram2dcontour": [ { "colorbar": { "outlinewidth": 0, "ticks": "" }, "colorscale": [ [ 0, "#0d0887" ], [ 0.1111111111111111, "#46039f" ], [ 0.2222222222222222, "#7201a8" ], [ 0.3333333333333333, "#9c179e" ], [ 0.4444444444444444, "#bd3786" ], [ 0.5555555555555556, "#d8576b" ], [ 0.6666666666666666, "#ed7953" ], [ 0.7777777777777778, "#fb9f3a" ], [ 0.8888888888888888, "#fdca26" ], [ 1, "#f0f921" ] ], "type": "histogram2dcontour" } ], "mesh3d": [ { "colorbar": { "outlinewidth": 0, "ticks": "" }, "type": "mesh3d" } ], "parcoords": [ { "line": { "colorbar": { "outlinewidth": 0, "ticks": "" } }, "type": "parcoords" } ], "pie": [ { "automargin": true, "type": "pie" } ], "scatter": [ { "fillpattern": { "fillmode": "overlay", 
"size": 10, "solidity": 0.2 }, "type": "scatter" } ], "scatter3d": [ { "line": { "colorbar": { "outlinewidth": 0, "ticks": "" } }, "marker": { "colorbar": { "outlinewidth": 0, "ticks": "" } }, "type": "scatter3d" } ], "scattercarpet": [ { "marker": { "colorbar": { "outlinewidth": 0, "ticks": "" } }, "type": "scattercarpet" } ], "scattergeo": [ { "marker": { "colorbar": { "outlinewidth": 0, "ticks": "" } }, "type": "scattergeo" } ], "scattergl": [ { "marker": { "colorbar": { "outlinewidth": 0, "ticks": "" } }, "type": "scattergl" } ], "scattermapbox": [ { "marker": { "colorbar": { "outlinewidth": 0, "ticks": "" } }, "type": "scattermapbox" } ], "scatterpolar": [ { "marker": { "colorbar": { "outlinewidth": 0, "ticks": "" } }, "type": "scatterpolar" } ], "scatterpolargl": [ { "marker": { "colorbar": { "outlinewidth": 0, "ticks": "" } }, "type": "scatterpolargl" } ], "scatterternary": [ { "marker": { "colorbar": { "outlinewidth": 0, "ticks": "" } }, "type": "scatterternary" } ], "surface": [ { "colorbar": { "outlinewidth": 0, "ticks": "" }, "colorscale": [ [ 0, "#0d0887" ], [ 0.1111111111111111, "#46039f" ], [ 0.2222222222222222, "#7201a8" ], [ 0.3333333333333333, "#9c179e" ], [ 0.4444444444444444, "#bd3786" ], [ 0.5555555555555556, "#d8576b" ], [ 0.6666666666666666, "#ed7953" ], [ 0.7777777777777778, "#fb9f3a" ], [ 0.8888888888888888, "#fdca26" ], [ 1, "#f0f921" ] ], "type": "surface" } ], "table": [ { "cells": { "fill": { "color": "#EBF0F8" }, "line": { "color": "white" } }, "header": { "fill": { "color": "#C8D4E3" }, "line": { "color": "white" } }, "type": "table" } ] }, "layout": { "annotationdefaults": { "arrowcolor": "#2a3f5f", "arrowhead": 0, "arrowwidth": 1 }, "autotypenumbers": "strict", "coloraxis": { "colorbar": { "outlinewidth": 0, "ticks": "" } }, "colorscale": { "diverging": [ [ 0, "#8e0152" ], [ 0.1, "#c51b7d" ], [ 0.2, "#de77ae" ], [ 0.3, "#f1b6da" ], [ 0.4, "#fde0ef" ], [ 0.5, "#f7f7f7" ], [ 0.6, "#e6f5d0" ], [ 0.7, "#b8e186" ], [ 0.8, "#7fbc41" ], [ 
0.9, "#4d9221" ], [ 1, "#276419" ] ], "sequential": [ [ 0, "#0d0887" ], [ 0.1111111111111111, "#46039f" ], [ 0.2222222222222222, "#7201a8" ], [ 0.3333333333333333, "#9c179e" ], [ 0.4444444444444444, "#bd3786" ], [ 0.5555555555555556, "#d8576b" ], [ 0.6666666666666666, "#ed7953" ], [ 0.7777777777777778, "#fb9f3a" ], [ 0.8888888888888888, "#fdca26" ], [ 1, "#f0f921" ] ], "sequentialminus": [ [ 0, "#0d0887" ], [ 0.1111111111111111, "#46039f" ], [ 0.2222222222222222, "#7201a8" ], [ 0.3333333333333333, "#9c179e" ], [ 0.4444444444444444, "#bd3786" ], [ 0.5555555555555556, "#d8576b" ], [ 0.6666666666666666, "#ed7953" ], [ 0.7777777777777778, "#fb9f3a" ], [ 0.8888888888888888, "#fdca26" ], [ 1, "#f0f921" ] ] }, "colorway": [ "#636efa", "#EF553B", "#00cc96", "#ab63fa", "#FFA15A", "#19d3f3", "#FF6692", "#B6E880", "#FF97FF", "#FECB52" ], "font": { "color": "#2a3f5f" }, "geo": { "bgcolor": "white", "lakecolor": "white", "landcolor": "#E5ECF6", "showlakes": true, "showland": true, "subunitcolor": "white" }, "hoverlabel": { "align": "left" }, "hovermode": "closest", "mapbox": { "style": "light" }, "paper_bgcolor": "white", "plot_bgcolor": "#E5ECF6", "polar": { "angularaxis": { "gridcolor": "white", "linecolor": "white", "ticks": "" }, "bgcolor": "#E5ECF6", "radialaxis": { "gridcolor": "white", "linecolor": "white", "ticks": "" } }, "scene": { "xaxis": { "backgroundcolor": "#E5ECF6", "gridcolor": "white", "gridwidth": 2, "linecolor": "white", "showbackground": true, "ticks": "", "zerolinecolor": "white" }, "yaxis": { "backgroundcolor": "#E5ECF6", "gridcolor": "white", "gridwidth": 2, "linecolor": "white", "showbackground": true, "ticks": "", "zerolinecolor": "white" }, "zaxis": { "backgroundcolor": "#E5ECF6", "gridcolor": "white", "gridwidth": 2, "linecolor": "white", "showbackground": true, "ticks": "", "zerolinecolor": "white" } }, "shapedefaults": { "line": { "color": "#2a3f5f" } }, "ternary": { "aaxis": { "gridcolor": "white", "linecolor": "white", "ticks": "" }, "baxis": { 
"gridcolor": "white", "linecolor": "white", "ticks": "" }, "bgcolor": "#E5ECF6", "caxis": { "gridcolor": "white", "linecolor": "white", "ticks": "" } }, "title": { "x": 0.05 }, "xaxis": { "automargin": true, "gridcolor": "white", "linecolor": "white", "ticks": "", "title": { "standoff": 15 }, "zerolinecolor": "white", "zerolinewidth": 2 }, "yaxis": { "automargin": true, "gridcolor": "white", "linecolor": "white", "ticks": "", "title": { "standoff": 15 }, "zerolinecolor": "white", "zerolinewidth": 2 } } }, "title": { "text": "Optimization History Plot" }, "xaxis": { "title": { "text": "Trial" } }, "yaxis": { "title": { "text": "Mean infection duration (d)" } } } }, "text/html": [ "