{ "cells": [ { "cell_type": "markdown", "id": "zuvs8pz2hlh", "source": "# Grid Search\n\nThis notebook demonstrates the unified `fit_grid` function for hyperparameter optimization.", "metadata": {} }, { "cell_type": "markdown", "id": "Hbol", "metadata": { "marimo": { "config": { "hide_code": true } } }, "source": "## Unified Grid Search with `fit_grid`\n\n`fit_grid` provides a single interface for every grid-search combination and\nautomatically detects which axes to search from the inputs you pass.\n\n### Grid Axes\n\n`fit_grid` can search over three axes:\n\n1. **Datasets**: multiple datasets fitted in parallel (detected when `x.ndim == 2`)\n2. **L1 Penalties** (`l1_penalties=`): regularization strengths for Lasso regression\n3. **Num Terms** (`num_terms=`): numbers of terms in the metalog distribution\n\nEach axis is either searched or held fixed, giving 2\u00b3 = 8 possible combinations, all handled by a single function.\nEvery fit is scored by its Kolmogorov-Smirnov (KS) distance; lower is better." }, { "cell_type": "code", "execution_count": 1, "id": "MJUe", "metadata": { "execution": { "iopub.execute_input": "2025-12-26T21:57:49.596487Z", "iopub.status.busy": "2025-12-26T21:57:49.596295Z", "iopub.status.idle": "2025-12-26T21:57:50.686207Z", "shell.execute_reply": "2025-12-26T21:57:50.685776Z" } }, "outputs": [], "source": [ "import jax.numpy as jnp\n", "from scipy.stats import beta, gamma, lognorm, norm, weibull_min\n", "\n", "from metalog_jax.base import (\n", " MetalogBoundedness,\n", " MetalogFitMethod,\n", " MetalogInputData,\n", " MetalogParameters,\n", ")\n", "from metalog_jax.grid_search import find_best_config, fit_grid\n", "from metalog_jax.utils import DEFAULT_Y" ] }, { "cell_type": "markdown", "id": "vblA", "metadata": { "marimo": { "config": { "hide_code": true } } }, "source": [ "---\n", "## Case 1: Single Dataset, No Grid\n", "\n", "The simplest case: fit a single dataset with fixed parameters.\n", "It returns a scalar KS distance and a 1D coefficient array." 
] }, { "cell_type": "code", "execution_count": 2, "id": "bkHC", "metadata": { "execution": { "iopub.execute_input": "2025-12-26T21:57:50.687330Z", "iopub.status.busy": "2025-12-26T21:57:50.687235Z", "iopub.status.idle": "2025-12-26T21:57:50.928310Z", "shell.execute_reply": "2025-12-26T21:57:50.927788Z" } }, "outputs": [], "source": [ "# Generate sample data from a beta distribution\n", "dist_beta = beta(a=2, b=5)\n", "samples_beta = dist_beta.rvs(size=200, random_state=42)\n", "\n", "# Create input data\n", "data_single = MetalogInputData.from_values(samples_beta, DEFAULT_Y, False)" ] }, { "cell_type": "code", "execution_count": 3, "id": "lEQa", "metadata": { "execution": { "iopub.execute_input": "2025-12-26T21:57:50.929472Z", "iopub.status.busy": "2025-12-26T21:57:50.929396Z", "iopub.status.idle": "2025-12-26T21:57:51.714621Z", "shell.execute_reply": "2025-12-26T21:57:51.714238Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "KS Distance: 0.0286\n", "Coefficients shape: (7,)\n", "Coefficients: [ -1.08998096 -1.48153417 0.04950236 8.30422784 -0.78536482\n", " 6.29446539 -18.91331443]\n" ] } ], "source": [ "# Configure parameters\n", "params_ols = MetalogParameters(\n", " boundedness=MetalogBoundedness.BOUNDED,\n", " lower_bound=0,\n", " upper_bound=1,\n", " method=MetalogFitMethod.OLS,\n", " num_terms=7,\n", ")\n", "\n", "# Fit single dataset - returns scalar results\n", "result_single = fit_grid(data_single.x, data_single.y, params_ols)\n", "\n", "print(f\"KS Distance: {float(result_single.ks_dist):.4f}\")\n", "print(f\"Coefficients shape: {result_single.metalog.a.shape}\")\n", "print(f\"Coefficients: {result_single.metalog.a}\")" ] }, { "cell_type": "markdown", "id": "PKri", "metadata": { "marimo": { "config": { "hide_code": true } } }, "source": [ "---\n", "## Case 2: Single Dataset, Num Terms Grid\n", "\n", "Search over different numbers of terms to find the best model complexity.\n", "The result is a 1D array indexed by position in `num_terms`. Coefficient arrays are\n", "padded to the largest `num_terms` in the grid, so every fit shares one shape." 
] }, { "cell_type": "code", "execution_count": 4, "id": "Xref", "metadata": { "execution": { "iopub.execute_input": "2025-12-26T21:57:51.716687Z", "iopub.status.busy": "2025-12-26T21:57:51.716589Z", "iopub.status.idle": "2025-12-26T21:57:52.802550Z", "shell.execute_reply": "2025-12-26T21:57:52.802150Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "KS Distances shape: (5,)\n", "Coefficients shape: (5, 11)\n", "\n", "Results by num_terms:\n", " 3 terms: KS = 0.0667\n", " 5 terms: KS = 0.0381\n", " 7 terms: KS = 0.0286\n", " 9 terms: KS = 0.0286\n", " 11 terms: KS = 0.0286\n", "\n", "Best: 7 terms (KS = 0.0286)\n" ] } ], "source": [ "# Configure base parameters\n", "params_num_terms = MetalogParameters(\n", " boundedness=MetalogBoundedness.BOUNDED,\n", " lower_bound=0,\n", " upper_bound=1,\n", " method=MetalogFitMethod.OLS,\n", " num_terms=3, # Will be overridden\n", ")\n", "\n", "# Grid over num_terms\n", "num_terms_grid = [3, 5, 7, 9, 11]\n", "\n", "result_num_terms = fit_grid(\n", " data_single.x, data_single.y, params_num_terms, num_terms=num_terms_grid\n", ")\n", "\n", "print(f\"KS Distances shape: {result_num_terms.ks_dist.shape}\")\n", "print(f\"Coefficients shape: {result_num_terms.metalog.a.shape}\")\n", "print()\n", "print(\"Results by num_terms:\")\n", "for _i, _nt in enumerate(num_terms_grid):\n", " print(f\" {_nt} terms: KS = {float(result_num_terms.ks_dist[_i]):.4f}\")\n", "\n", "# Pick the smallest KS distance. 7, 9 and 11 terms all report 0.0286 here;\n", "# the simplest of them, 7 terms, is the one returned.\n", "best_idx_nt, best_ks_nt = find_best_config(result_num_terms.ks_dist)\n", "print(\n", " f\"\\nBest: {num_terms_grid[int(best_idx_nt)]} terms (KS = {float(best_ks_nt):.4f})\"\n", ")" ] }, { "cell_type": "markdown", "id": "SFPL", "metadata": { "marimo": { "config": { "hide_code": true } } }, "source": [ "---\n", "## Case 3: Single Dataset, L1 Penalties Grid\n", "\n", "Search over L1 regularization strengths using Lasso regression.\n", "Larger penalties shrink the coefficients toward zero (Lasso can set some exactly to zero),\n", "which can reduce overfitting at the cost of a larger KS distance on the fitted data." 
] }, { "cell_type": "code", "execution_count": 5, "id": "BYtC", "metadata": { "execution": { "iopub.execute_input": "2025-12-26T21:57:52.803830Z", "iopub.status.busy": "2025-12-26T21:57:52.803752Z", "iopub.status.idle": "2025-12-26T21:57:53.350784Z", "shell.execute_reply": "2025-12-26T21:57:53.350429Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "KS Distances shape: (5,)\n", "Coefficients shape: (5, 9)\n", "\n", "Results by L1 penalty:\n", " L1=0.000: KS = 0.0381, ||a|| = 1.3180\n", " L1=0.001: KS = 0.0381, ||a|| = 1.3161\n", " L1=0.010: KS = 0.0476, ||a|| = 1.3009\n", " L1=0.100: KS = 0.0571, ||a|| = 1.1992\n", " L1=1.000: KS = 0.1238, ||a|| = 1.0951\n", "\n", "Best: L1=0.000 (KS = 0.0381)\n" ] } ], "source": [ "# Configure for Lasso\n", "params_lasso = MetalogParameters(\n", " boundedness=MetalogBoundedness.BOUNDED,\n", " lower_bound=0,\n", " upper_bound=1,\n", " method=MetalogFitMethod.Lasso,\n", " num_terms=9,\n", ")\n", "\n", "# Grid over L1 penalties\n", "l1_grid = jnp.array([0.0, 0.001, 0.01, 0.1, 1.0])\n", "\n", "result_l1 = fit_grid(\n", " data_single.x, data_single.y, params_lasso, l1_penalties=l1_grid\n", ")\n", "\n", "print(f\"KS Distances shape: {result_l1.ks_dist.shape}\")\n", "print(f\"Coefficients shape: {result_l1.metalog.a.shape}\")\n", "print()\n", "print(\"Results by L1 penalty:\")\n", "for _i, _l1 in enumerate(l1_grid):\n", " coeff_norm = float(jnp.linalg.norm(result_l1.metalog.a[_i]))\n", " print(\n", " f\" L1={float(_l1):5.3f}: KS = {float(result_l1.ks_dist[_i]):.4f}, ||a|| = {coeff_norm:.4f}\"\n", " )\n", "\n", "best_l1_idx, best_l1_ks = find_best_config(result_l1.ks_dist)\n", "print(\n", " f\"\\nBest: L1={float(l1_grid[int(best_l1_idx)]):.3f} (KS = {float(best_l1_ks):.4f})\"\n", ")" ] }, { "cell_type": "markdown", "id": "RGSE", "metadata": { "marimo": { "config": { "hide_code": true } } }, "source": [ "---\n", "## Case 4: Single Dataset, Both Grids (2D)\n", "\n", "Search over both L1 penalties and num_terms 
simultaneously.\n", "Returns a 2D grid of results: (n_penalties, n_terms)." ] }, { "cell_type": "code", "execution_count": 6, "id": "Kclp", "metadata": { "execution": { "iopub.execute_input": "2025-12-26T21:57:53.351801Z", "iopub.status.busy": "2025-12-26T21:57:53.351739Z", "iopub.status.idle": "2025-12-26T21:57:54.538188Z", "shell.execute_reply": "2025-12-26T21:57:54.537763Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "KS Distances shape: (3, 3)\n", "Coefficients shape: (3, 3, 9)\n", "\n", "2D Grid Results (L1 penalty \u00d7 num_terms):\n", " 5 terms 7 terms 9 terms\n", "L1=0.000 0.0476 0.0381 0.0381\n", "L1=0.010 0.0571 0.0476 0.0476\n", "L1=0.100 0.0571 0.0571 0.0571\n", "\n", "Best: L1=0.000, 7 terms (KS = 0.0381)\n" ] } ], "source": [ "# Configure for 2D search\n", "params_2d = MetalogParameters(\n", " boundedness=MetalogBoundedness.BOUNDED,\n", " lower_bound=0,\n", " upper_bound=1,\n", " method=MetalogFitMethod.Lasso,\n", " num_terms=3,\n", ")\n", "\n", "# Define both grids\n", "l1_grid_2d = jnp.array([0.0, 0.01, 0.1])\n", "num_terms_2d = [5, 7, 9]\n", "\n", "result_2d = fit_grid(\n", " data_single.x,\n", " data_single.y,\n", " params_2d,\n", " l1_penalties=l1_grid_2d,\n", " num_terms=num_terms_2d,\n", ")\n", "\n", "print(f\"KS Distances shape: {result_2d.ks_dist.shape}\")\n", "print(f\"Coefficients shape: {result_2d.metalog.a.shape}\")\n", "print()\n", "print(\"2D Grid Results (L1 penalty \u00d7 num_terms):\")\n", "print(\" \", end=\"\")\n", "for _nt in num_terms_2d:\n", " print(f\"{_nt:>8} terms\", end=\"\")\n", "print()\n", "for _i, _l1 in enumerate(l1_grid_2d):\n", " print(f\"L1={float(_l1):5.3f}\", end=\"\")\n", " for _j in range(len(num_terms_2d)):\n", " print(f\" {float(result_2d.ks_dist[_i, _j]):.4f}\", end=\"\")\n", " print()\n", "\n", "# Find best in 2D grid\n", "best_2d_idx, best_2d_ks = find_best_config(result_2d.ks_dist)\n", "best_l1_2d = l1_grid_2d[best_2d_idx[0]]\n", "best_nt_2d = num_terms_2d[best_2d_idx[1]]\n", 
"print(\n", " f\"\\nBest: L1={float(best_l1_2d):.3f}, {best_nt_2d} terms (KS = {float(best_2d_ks):.4f})\"\n", ")" ] }, { "cell_type": "markdown", "id": "emfo", "metadata": { "marimo": { "config": { "hide_code": true } } }, "source": [ "---\n", "## Case 5: Batched Datasets, No Grid\n", "\n", "Fit several datasets in parallel with one shared parameter set by stacking them into\n", "2D arrays of shape (n_datasets, n_points). This is useful when many distributions share\n", "the same support and similar characteristics." ] }, { "cell_type": "code", "execution_count": 7, "id": "Hstk", "metadata": { "execution": { "iopub.execute_input": "2025-12-26T21:57:54.539181Z", "iopub.status.busy": "2025-12-26T21:57:54.539124Z", "iopub.status.idle": "2025-12-26T21:57:54.565697Z", "shell.execute_reply": "2025-12-26T21:57:54.565276Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Batched x shape: (3, 105)\n", "Batched y shape: (3, 105)\n" ] } ], "source": [ "# Draw 200 samples from each of three distributions (all lower-bounded at 0)\n", "dist_lognorm = lognorm(s=0.5, loc=0, scale=1).rvs(size=200, random_state=42)\n", "dist_weibull = weibull_min(c=2, scale=2).rvs(size=200, random_state=43)\n", "dist_gamma = gamma(a=4, scale=1).rvs(size=200, random_state=44)\n", "\n", "# Create input data for each\n", "data1 = MetalogInputData.from_values(dist_lognorm, DEFAULT_Y, False)\n", "data2 = MetalogInputData.from_values(dist_weibull, DEFAULT_Y, False)\n", "data3 = MetalogInputData.from_values(dist_gamma, DEFAULT_Y, False)\n", "\n", "# Stack into batched arrays of shape (n_datasets, n_points); every dataset\n", "# uses DEFAULT_Y, so all rows have the same length\n", "batched_x = jnp.stack([data1.x, data2.x, data3.x])\n", "batched_y = jnp.stack([data1.y, data2.y, data3.y])\n", "\n", "print(f\"Batched x shape: {batched_x.shape}\")\n", "print(f\"Batched y shape: {batched_y.shape}\")" ] }, { "cell_type": "code", "execution_count": 8, "id": "nWHF", "metadata": { "execution": { "iopub.execute_input": "2025-12-26T21:57:54.566799Z", "iopub.status.busy": "2025-12-26T21:57:54.566728Z", 
"iopub.status.idle": "2025-12-26T21:57:55.727247Z", "shell.execute_reply": "2025-12-26T21:57:55.726940Z" } 
}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "KS Distances shape: (3,)\n", "Coefficients shape: (3, 7)\n", "\n", "Results per dataset:\n", " Lognormal: KS = 0.0286\n", " Weibull: KS = 0.0286\n", " Gamma: KS = 0.0286\n" ] } ], "source": [ "# Configure shared parameters\n", "params_batch = MetalogParameters(\n", " boundedness=MetalogBoundedness.STRICTLY_LOWER_BOUND,\n", " lower_bound=0,\n", " upper_bound=0,\n", " method=MetalogFitMethod.OLS,\n", " num_terms=7,\n", ")\n", "\n", "# Fit all datasets in parallel\n", "result_batch = fit_grid(batched_x, batched_y, params_batch)\n", "\n", "print(f\"KS Distances shape: {result_batch.ks_dist.shape}\")\n", "print(f\"Coefficients shape: {result_batch.metalog.a.shape}\")\n", "print()\n", "dist_names = [\"Lognormal\", \"Weibull\", \"Gamma\"]\n", "print(\"Results per dataset:\")\n", "for _i, _name in enumerate(dist_names):\n", " print(f\" {_name}: KS = {float(result_batch.ks_dist[_i]):.4f}\")" ] }, { "cell_type": "markdown", "id": "iLit", "metadata": { "marimo": { "config": { "hide_code": true } } }, "source": [ "---\n", "## Case 6: Batched Datasets, Num Terms Grid (2D)\n", "\n", "Search over num_terms for each dataset in a batch.\n", "Returns shape (n_datasets, n_terms)." 
] }, { "cell_type": "code", "execution_count": 9, "id": "ZHCJ", "metadata": { "execution": { "iopub.execute_input": "2025-12-26T21:57:55.728448Z", "iopub.status.busy": "2025-12-26T21:57:55.728391Z", "iopub.status.idle": "2025-12-26T21:57:57.036084Z", "shell.execute_reply": "2025-12-26T21:57:57.032763Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "KS Distances shape: (3, 4)\n", "Coefficients shape: (3, 4, 11)\n", "\n", "Results (datasets \u00d7 num_terms):\n", " 5 terms 7 terms 9 terms 11 terms\n", " Lognormal 0.0476 0.0286 0.0286 0.0286\n", " Weibull 0.0381 0.0286 0.0286 0.0190\n", " Gamma 0.0190 0.0286 0.0190 0.0190\n", "\n", "Best config per dataset:\n", " Lognormal: 7 terms (KS = 0.0286)\n", " Weibull: 11 terms (KS = 0.0190)\n", " Gamma: 5 terms (KS = 0.0190)\n" ] } ], "source": [ "params_batch_nt = MetalogParameters(\n", " boundedness=MetalogBoundedness.STRICTLY_LOWER_BOUND,\n", " lower_bound=0,\n", " upper_bound=0,\n", " method=MetalogFitMethod.OLS,\n", " num_terms=3,\n", ")\n", "\n", "num_terms_batch = [5, 7, 9, 11]\n", "\n", "result_batch_nt = fit_grid(\n", " batched_x, batched_y, params_batch_nt, num_terms=num_terms_batch\n", ")\n", "\n", "print(f\"KS Distances shape: {result_batch_nt.ks_dist.shape}\")\n", "print(f\"Coefficients shape: {result_batch_nt.metalog.a.shape}\")\n", "print()\n", "print(\"Results (datasets \u00d7 num_terms):\")\n", "print(\" \", end=\"\")\n", "for _nt in num_terms_batch:\n", " print(f\"{_nt:>8} terms\", end=\"\")\n", "print()\n", "for _i, _name in enumerate(dist_names):\n", " print(f\"{_name:>12}\", end=\"\")\n", " for _j in range(len(num_terms_batch)):\n", " print(f\" {float(result_batch_nt.ks_dist[_i, _j]):.4f}\", end=\"\")\n", " print()\n", "\n", "# Find best per dataset\n", "print(\"\\nBest config per dataset:\")\n", "for _i, _name in enumerate(dist_names):\n", " _best_idx, _best_ks = find_best_config(result_batch_nt.ks_dist[_i])\n", " print(\n", " f\" {_name}: {num_terms_batch[int(_best_idx)]} 
terms (KS = {float(_best_ks):.4f})\"\n", " )" ] }, { "cell_type": "markdown", "id": "ROlb", "metadata": { "marimo": { "config": { "hide_code": true } } }, "source": [ "---\n", "## Case 7: Batched Datasets, L1 Penalties Grid (2D)\n", "\n", "Search over L1 penalties for each dataset in a batch.\n", "Returns shape (n_datasets, n_penalties)." ] }, { "cell_type": "code", "execution_count": 10, "id": "qnkX", "metadata": { "execution": { "iopub.execute_input": "2025-12-26T21:57:57.039143Z", "iopub.status.busy": "2025-12-26T21:57:57.039025Z", "iopub.status.idle": "2025-12-26T21:57:57.653543Z", "shell.execute_reply": "2025-12-26T21:57:57.653173Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "KS Distances shape: (3, 4)\n", "Coefficients shape: (3, 4, 9)\n", "\n", "Results (datasets \u00d7 L1 penalties):\n", " L1=0.000 L1=0.010 L1=0.100 L1=1.000\n", " Lognormal 0.0381 0.0381 0.0762 0.1048\n", " Weibull 0.0381 0.0381 0.0667 0.1810\n", " Gamma 0.0286 0.0381 0.0571 0.1714\n", "\n", "Best L1 per dataset:\n", " Lognormal: L1=0.000 (KS = 0.0381)\n", " Weibull: L1=0.000 (KS = 0.0381)\n", " Gamma: L1=0.000 (KS = 0.0286)\n" ] } ], "source": [ "params_batch_l1 = MetalogParameters(\n", " boundedness=MetalogBoundedness.STRICTLY_LOWER_BOUND,\n", " lower_bound=0,\n", " upper_bound=0,\n", " method=MetalogFitMethod.Lasso,\n", " num_terms=9,\n", ")\n", "\n", "l1_batch = jnp.array([0.0, 0.01, 0.1, 1.0])\n", "\n", "result_batch_l1 = fit_grid(\n", " batched_x, batched_y, params_batch_l1, l1_penalties=l1_batch\n", ")\n", "\n", "print(f\"KS Distances shape: {result_batch_l1.ks_dist.shape}\")\n", "print(f\"Coefficients shape: {result_batch_l1.metalog.a.shape}\")\n", "print()\n", "print(\"Results (datasets \u00d7 L1 penalties):\")\n", "print(\" \", end=\"\")\n", "for _l1 in l1_batch:\n", " print(f\" L1={float(_l1):5.3f}\", end=\"\")\n", "print()\n", "for _i, _name in enumerate(dist_names):\n", " print(f\"{_name:>12}\", end=\"\")\n", " for _j in range(len(l1_batch)):\n", " 
print(f\" {float(result_batch_l1.ks_dist[_i, _j]):.4f}\", end=\"\")\n", " print()\n", "\n", "# Find best per dataset\n", "print(\"\\nBest L1 per dataset:\")\n", "for _i, _name in enumerate(dist_names):\n", " _best_idx, _best_ks = find_best_config(result_batch_l1.ks_dist[_i])\n", " print(\n", " f\" {_name}: L1={float(l1_batch[int(_best_idx)]):.3f} (KS = {float(_best_ks):.4f})\"\n", " )" ] }, { "cell_type": "markdown", "id": "TqIu", "metadata": { "marimo": { "config": { "hide_code": true } } }, "source": [ "---\n", "## Case 8: Full 3D Grid (Batched Datasets \u00d7 L1 \u00d7 Num Terms)\n", "\n", "The most comprehensive case: search all three axes simultaneously.\n", "Returns shape (n_datasets, n_penalties, n_terms); `ks_dist[i]` is the\n", "(n_penalties, n_terms) grid for dataset `i`." ] }, { "cell_type": "code", "execution_count": 11, "id": "Vxnm", "metadata": { "execution": { "iopub.execute_input": "2025-12-26T21:57:57.654621Z", "iopub.status.busy": "2025-12-26T21:57:57.654554Z", "iopub.status.idle": "2025-12-26T21:57:58.901102Z", "shell.execute_reply": "2025-12-26T21:57:58.900785Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "KS Distances shape: (3, 3, 3)\n", "Coefficients shape: (3, 3, 3, 9)\n", "\n", "Best configuration per dataset:\n", " Lognormal: L1=0.000, 7 terms (KS = 0.0381)\n", " Weibull: L1=0.000, 5 terms (KS = 0.0381)\n", " Gamma: L1=0.000, 5 terms (KS = 0.0286)\n", "\n", "Full 3D grid for first dataset (Lognormal):\n", " 5 terms 7 terms 9 terms\n", "L1=0.000 0.0476 0.0381 0.0381\n", "L1=0.010 0.0476 0.0381 0.0381\n", "L1=0.100 0.0762 0.0762 0.0762\n" ] } ], "source": [ "params_3d = MetalogParameters(\n", " boundedness=MetalogBoundedness.STRICTLY_LOWER_BOUND,\n", " lower_bound=0,\n", " upper_bound=0,\n", " method=MetalogFitMethod.Lasso,\n", " num_terms=3,\n", ")\n", "\n", "l1_3d = jnp.array([0.0, 0.01, 0.1])\n", "num_terms_3d = [5, 7, 9]\n", "\n", "result_3d = fit_grid(\n", " batched_x, batched_y, params_3d, l1_penalties=l1_3d, num_terms=num_terms_3d\n", ")\n", "\n", "print(f\"KS 
Distances shape: {result_3d.ks_dist.shape}\")\n", "print(f\"Coefficients shape: {result_3d.metalog.a.shape}\")\n", "print()\n", "\n", "# Find best configuration for each dataset\n", "print(\"Best configuration per dataset:\")\n", "for _i, _name in enumerate(dist_names):\n", " _best_idx, _best_ks = find_best_config(result_3d.ks_dist[_i])\n", " _best_l1 = l1_3d[_best_idx[0]]\n", " _best_nt = num_terms_3d[_best_idx[1]]\n", " print(\n", " f\" {_name}: L1={float(_best_l1):.3f}, {_best_nt} terms (KS = {float(_best_ks):.4f})\"\n", " )\n", "\n", "print()\n", "print(\"Full 3D grid for first dataset (Lognormal):\")\n", "print(\" \", end=\"\")\n", "for _nt in num_terms_3d:\n", " print(f\"{_nt:>8} terms\", end=\"\")\n", "print()\n", "for _j, _l1 in enumerate(l1_3d):\n", " print(f\"L1={float(_l1):5.3f}\", end=\"\")\n", " for _k in range(len(num_terms_3d)):\n", " print(f\" {float(result_3d.ks_dist[0, _j, _k]):.4f}\", end=\"\")\n", " print()" ] }, { "cell_type": "markdown", "id": "DnEU", "metadata": { "marimo": { "config": { "hide_code": true } } }, "source": [ "---\n", "## Working with Different Boundedness Types\n", "\n", "`fit_grid` works with every boundedness type. The cases above used `BOUNDED` and\n", "`STRICTLY_LOWER_BOUND`; the example below fits unbounded data, for which\n", "`lower_bound` and `upper_bound` are ignored." 
] }, { "cell_type": "code", "execution_count": 12, "id": "ulZA", "metadata": { "execution": { "iopub.execute_input": "2025-12-26T21:57:58.902316Z", "iopub.status.busy": "2025-12-26T21:57:58.902247Z", "iopub.status.idle": "2025-12-26T21:57:58.972673Z", "shell.execute_reply": "2025-12-26T21:57:58.972332Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Unbounded fit KS distance: 0.0286\n" ] } ], "source": [ "# Generate normally distributed data (unbounded)\n", "normal_samples = norm(loc=50, scale=10).rvs(size=200, random_state=42)\n", "data_unbounded = MetalogInputData.from_values(normal_samples, DEFAULT_Y, False)\n", "\n", "params_unbounded = MetalogParameters(\n", " boundedness=MetalogBoundedness.UNBOUNDED,\n", " lower_bound=0, # ignored\n", " upper_bound=0, # ignored\n", " method=MetalogFitMethod.OLS,\n", " num_terms=7,\n", ")\n", "\n", "result_unbounded = fit_grid(data_unbounded.x, data_unbounded.y, params_unbounded)\n", "print(f\"Unbounded fit KS distance: {float(result_unbounded.ks_dist):.4f}\")" ] }, { "cell_type": "markdown", "id": "ecfG", "metadata": { "marimo": { "config": { "hide_code": true } } }, "source": [ "---\n", "## Using Precomputed Quantiles\n", "\n", "If you already have quantile estimates (e.g., from expert elicitation), pass the quantile\n", "values as `x`, their cumulative probabilities as `y`, and set `precomputed_quantiles=True`." 
] }, { "cell_type": "code", "execution_count": 13, "id": "Pvdt", "metadata": { "execution": { "iopub.execute_input": "2025-12-26T21:57:58.973970Z", "iopub.status.busy": "2025-12-26T21:57:58.973901Z", "iopub.status.idle": "2025-12-26T21:57:59.368227Z", "shell.execute_reply": "2025-12-26T21:57:59.367782Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Precomputed quantiles fit KS distance: 0.1429\n", "Coefficients: [ 2.11320782e-16 5.82969856e-01 2.37162298e-16 1.05904096e+00\n", " -6.04972876e-16]\n" ] } ], "source": [ "# Expert-elicited quantiles for a variable bounded on [0, 100]\n", "quantiles = jnp.array([10.0, 25.0, 40.0, 50.0, 60.0, 75.0, 90.0])\n", "probabilities = jnp.array([0.05, 0.20, 0.40, 0.50, 0.60, 0.80, 0.95])\n", "\n", "params_quantiles = MetalogParameters(\n", " boundedness=MetalogBoundedness.BOUNDED,\n", " lower_bound=0,\n", " upper_bound=100,\n", " method=MetalogFitMethod.OLS,\n", " num_terms=5,\n", ")\n", "\n", "result_quantiles = fit_grid(\n", " quantiles, probabilities, params_quantiles, precomputed_quantiles=True\n", ")\n", "\n", "print(\n", " f\"Precomputed quantiles fit KS distance: {float(result_quantiles.ks_dist):.4f}\"\n", ")\n", "print(f\"Coefficients: {result_quantiles.metalog.a}\")" ] }, { "cell_type": "markdown", "id": "ZBYS", "metadata": { "marimo": { "config": { "hide_code": true } } }, "source": [ "---\n", "## Summary\n", "\n", "`fit_grid` provides a unified interface for all grid-search operations. The table lists\n", "the shape of `ks_dist`; `metalog.a` has the same leading axes plus one trailing coefficient axis:\n", "\n", "| Inputs | `ks_dist` Shape | Description |\n", "|--------|--------------|-------------|\n", "| x.ndim=1, no grids | () | Single fit |\n", "| x.ndim=1, num_terms | (n_terms,) | 1D num_terms search |\n", "| x.ndim=1, l1_penalties | (n_penalties,) | 1D L1 search |\n", "| x.ndim=1, both | (n_penalties, n_terms) | 2D search |\n", "| x.ndim=2, no grids | (n_datasets,) | Batch fit |\n", "| x.ndim=2, num_terms | (n_datasets, n_terms) | 2D batch + terms |\n", "| x.ndim=2, l1_penalties | (n_datasets, 
n_penalties) | 2D batch + L1 |\n", "| x.ndim=2, both | (n_datasets, n_penalties, n_terms) | Full 3D search |\n", "\n", "Use `find_best_config(ks_dist)` to locate the configuration with the smallest KS distance in any grid.\n", "It returns that configuration's index (one entry per axis for multi-dimensional grids) and its KS distance.\n", "For batched results, pass one dataset's slice, e.g. `ks_dist[i]`, to get the best configuration per dataset." ] } ], "metadata": { "kernelspec": { "display_name": "metalog_jax", "language": "python", "name": "metalog_jax" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.11.13" } }, "nbformat": 4, "nbformat_minor": 5 }