Grid Search¶
Unified Grid Search with fit_grid¶
This notebook demonstrates the unified fit_grid function, which provides a single interface for all grid search combinations. The function automatically detects which axes to search based on the inputs provided.
Grid Axes¶
The fit_grid function can search over three axes:
Datasets: Multiple datasets to fit in parallel (detected by x.ndim == 2)
L1 Penalties: Regularization strength for Lasso regression
Num Terms: Number of terms in the metalog distribution
This gives us 8 possible combinations (2³), all handled by a single function.
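As a sanity check on those combinations, the expected result shape for each can be sketched with a tiny pure-Python helper. This is a hypothetical illustration, not part of metalog_jax; it simply mirrors the shape rules demonstrated in the cases below:

```python
from itertools import product


def expected_shape(batched, n_penalties, n_terms):
    """Symbolic result shape fit_grid is expected to produce (a sketch,
    not the library's code).

    batched: True when x.ndim == 2 (stacked datasets), False for one dataset.
    n_penalties / n_terms: grid length, or None when that axis is not searched.
    """
    shape = []
    if batched:
        shape.append("n_datasets")
    if n_penalties is not None:
        shape.append("n_penalties")
    if n_terms is not None:
        shape.append("n_terms")
    return tuple(shape)


# Enumerate all 2**3 = 8 combinations
for batched, n_pen, n_nt in product([False, True], [None, 3], [None, 5]):
    print(batched, n_pen, n_nt, "->", expected_shape(batched, n_pen, n_nt))
```

Each of the eight rows printed above corresponds to one of the cases worked through in this notebook.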
[1]:
import jax.numpy as jnp
from scipy.stats import beta, gamma, lognorm, norm, weibull_min
from metalog_jax.base import (
MetalogBoundedness,
MetalogFitMethod,
MetalogInputData,
MetalogParameters,
)
from metalog_jax.grid_search import find_best_config, fit_grid
from metalog_jax.utils import DEFAULT_Y
Case 1: Single Dataset, No Grid¶
The simplest case: fit a single dataset with fixed parameters. Returns a scalar KS distance and a 1D coefficient array.
[2]:
# Generate sample data from a beta distribution
dist_beta = beta(a=2, b=5)
samples_beta = dist_beta.rvs(size=200, random_state=42)
# Create input data
data_single = MetalogInputData.from_values(samples_beta, DEFAULT_Y, False)
[3]:
# Configure parameters
params_ols = MetalogParameters(
boundedness=MetalogBoundedness.BOUNDED,
lower_bound=0,
upper_bound=1,
method=MetalogFitMethod.OLS,
num_terms=7,
)
# Fit single dataset - returns scalar results
result_single = fit_grid(data_single.x, data_single.y, params_ols)
print(f"KS Distance: {float(result_single.ks_dist):.4f}")
print(f"Coefficients shape: {result_single.metalog.a.shape}")
print(f"Coefficients: {result_single.metalog.a}")
KS Distance: 0.0286
Coefficients shape: (7,)
Coefficients: [ -1.08998096 -1.48153417 0.04950236 8.30422784 -0.78536482
6.29446539 -18.91331443]
Case 2: Single Dataset, Num Terms Grid¶
Search over different numbers of terms to find the optimal complexity. Returns a 1D array of results indexed by num_terms.
[4]:
# Configure base parameters
params_num_terms = MetalogParameters(
boundedness=MetalogBoundedness.BOUNDED,
lower_bound=0,
upper_bound=1,
method=MetalogFitMethod.OLS,
num_terms=3, # Will be overridden
)
# Grid over num_terms
num_terms_grid = [3, 5, 7, 9, 11]
result_num_terms = fit_grid(
data_single.x, data_single.y, params_num_terms, num_terms=num_terms_grid
)
print(f"KS Distances shape: {result_num_terms.ks_dist.shape}")
print(f"Coefficients shape: {result_num_terms.metalog.a.shape}")
print()
print("Results by num_terms:")
for _i, _nt in enumerate(num_terms_grid):
print(f" {_nt} terms: KS = {float(result_num_terms.ks_dist[_i]):.4f}")
# Find best configuration
best_idx_nt, best_ks_nt = find_best_config(result_num_terms.ks_dist)
print(
f"\nBest: {num_terms_grid[int(best_idx_nt)]} terms (KS = {float(best_ks_nt):.4f})"
)
KS Distances shape: (5,)
Coefficients shape: (5, 11)
Results by num_terms:
3 terms: KS = 0.0667
5 terms: KS = 0.0381
7 terms: KS = 0.0286
9 terms: KS = 0.0286
11 terms: KS = 0.0286
Best: 7 terms (KS = 0.0286)
Case 3: Single Dataset, L1 Penalties Grid¶
Search over L1 regularization strengths using Lasso regression. Higher penalties produce sparser coefficients, reducing overfitting.
[5]:
# Configure for Lasso
params_lasso = MetalogParameters(
boundedness=MetalogBoundedness.BOUNDED,
lower_bound=0,
upper_bound=1,
method=MetalogFitMethod.Lasso,
num_terms=9,
)
# Grid over L1 penalties
l1_grid = jnp.array([0.0, 0.001, 0.01, 0.1, 1.0])
result_l1 = fit_grid(
data_single.x, data_single.y, params_lasso, l1_penalties=l1_grid
)
print(f"KS Distances shape: {result_l1.ks_dist.shape}")
print(f"Coefficients shape: {result_l1.metalog.a.shape}")
print()
print("Results by L1 penalty:")
for _i, _l1 in enumerate(l1_grid):
coeff_norm = float(jnp.linalg.norm(result_l1.metalog.a[_i]))
print(
f" L1={float(_l1):5.3f}: KS = {float(result_l1.ks_dist[_i]):.4f}, ||a|| = {coeff_norm:.4f}"
)
best_l1_idx, best_l1_ks = find_best_config(result_l1.ks_dist)
print(
f"\nBest: L1={float(l1_grid[int(best_l1_idx)]):.3f} (KS = {float(best_l1_ks):.4f})"
)
KS Distances shape: (5,)
Coefficients shape: (5, 9)
Results by L1 penalty:
L1=0.000: KS = 0.0381, ||a|| = 1.3180
L1=0.001: KS = 0.0381, ||a|| = 1.3161
L1=0.010: KS = 0.0476, ||a|| = 1.3009
L1=0.100: KS = 0.0571, ||a|| = 1.1992
L1=1.000: KS = 0.1238, ||a|| = 1.0951
Best: L1=0.000 (KS = 0.0381)
Case 4: Single Dataset, Both Grids (2D)¶
Search over both L1 penalties and num_terms simultaneously. Returns a 2D grid of results with shape (n_penalties, n_terms).
[6]:
# Configure for 2D search
params_2d = MetalogParameters(
boundedness=MetalogBoundedness.BOUNDED,
lower_bound=0,
upper_bound=1,
method=MetalogFitMethod.Lasso,
num_terms=3,
)
# Define both grids
l1_grid_2d = jnp.array([0.0, 0.01, 0.1])
num_terms_2d = [5, 7, 9]
result_2d = fit_grid(
data_single.x,
data_single.y,
params_2d,
l1_penalties=l1_grid_2d,
num_terms=num_terms_2d,
)
print(f"KS Distances shape: {result_2d.ks_dist.shape}")
print(f"Coefficients shape: {result_2d.metalog.a.shape}")
print()
print("2D Grid Results (L1 penalty × num_terms):")
print(" ", end="")
for _nt in num_terms_2d:
print(f"{_nt:>8} terms", end="")
print()
for _i, _l1 in enumerate(l1_grid_2d):
print(f"L1={float(_l1):5.3f}", end="")
for _j in range(len(num_terms_2d)):
print(f" {float(result_2d.ks_dist[_i, _j]):.4f}", end="")
print()
# Find best in 2D grid
best_2d_idx, best_2d_ks = find_best_config(result_2d.ks_dist)
best_l1_2d = l1_grid_2d[best_2d_idx[0]]
best_nt_2d = num_terms_2d[best_2d_idx[1]]
print(
f"\nBest: L1={float(best_l1_2d):.3f}, {best_nt_2d} terms (KS = {float(best_2d_ks):.4f})"
)
KS Distances shape: (3, 3)
Coefficients shape: (3, 3, 9)
2D Grid Results (L1 penalty × num_terms):
5 terms 7 terms 9 terms
L1=0.000 0.0476 0.0381 0.0381
L1=0.010 0.0571 0.0476 0.0476
L1=0.100 0.0571 0.0571 0.0571
Best: L1=0.000, 7 terms (KS = 0.0381)
Case 5: Batched Datasets, No Grid¶
Fit multiple datasets in parallel with shared parameters. Useful when you have many distributions with similar characteristics.
[7]:
# Generate multiple distributions (all lower-bounded at 0)
dist_lognorm = lognorm(s=0.5, loc=0, scale=1).rvs(size=200, random_state=42)
dist_weibull = weibull_min(c=2, scale=2).rvs(size=200, random_state=43)
dist_gamma = gamma(a=4, scale=1).rvs(size=200, random_state=44)
# Create input data for each
data1 = MetalogInputData.from_values(dist_lognorm, DEFAULT_Y, False)
data2 = MetalogInputData.from_values(dist_weibull, DEFAULT_Y, False)
data3 = MetalogInputData.from_values(dist_gamma, DEFAULT_Y, False)
# Stack into batched arrays
batched_x = jnp.stack([data1.x, data2.x, data3.x])
batched_y = jnp.stack([data1.y, data2.y, data3.y])
print(f"Batched x shape: {batched_x.shape}")
print(f"Batched y shape: {batched_y.shape}")
Batched x shape: (3, 105)
Batched y shape: (3, 105)
[8]:
# Configure shared parameters
params_batch = MetalogParameters(
boundedness=MetalogBoundedness.STRICTLY_LOWER_BOUND,
lower_bound=0,
upper_bound=0,
method=MetalogFitMethod.OLS,
num_terms=7,
)
# Fit all datasets in parallel
result_batch = fit_grid(batched_x, batched_y, params_batch)
print(f"KS Distances shape: {result_batch.ks_dist.shape}")
print(f"Coefficients shape: {result_batch.metalog.a.shape}")
print()
dist_names = ["Lognormal", "Weibull", "Gamma"]
print("Results per dataset:")
for _i, _name in enumerate(dist_names):
print(f" {_name}: KS = {float(result_batch.ks_dist[_i]):.4f}")
KS Distances shape: (3,)
Coefficients shape: (3, 7)
Results per dataset:
Lognormal: KS = 0.0286
Weibull: KS = 0.0286
Gamma: KS = 0.0286
Case 6: Batched Datasets, Num Terms Grid (2D)¶
Search over num_terms for each dataset in a batch. Returns shape (n_datasets, n_terms).
[9]:
params_batch_nt = MetalogParameters(
boundedness=MetalogBoundedness.STRICTLY_LOWER_BOUND,
lower_bound=0,
upper_bound=0,
method=MetalogFitMethod.OLS,
num_terms=3,
)
num_terms_batch = [5, 7, 9, 11]
result_batch_nt = fit_grid(
batched_x, batched_y, params_batch_nt, num_terms=num_terms_batch
)
print(f"KS Distances shape: {result_batch_nt.ks_dist.shape}")
print(f"Coefficients shape: {result_batch_nt.metalog.a.shape}")
print()
print("Results (datasets × num_terms):")
print(" ", end="")
for _nt in num_terms_batch:
print(f"{_nt:>8} terms", end="")
print()
for _i, _name in enumerate(dist_names):
print(f"{_name:>12}", end="")
for _j in range(len(num_terms_batch)):
print(f" {float(result_batch_nt.ks_dist[_i, _j]):.4f}", end="")
print()
# Find best per dataset
print("\nBest config per dataset:")
for _i, _name in enumerate(dist_names):
_best_idx, _best_ks = find_best_config(result_batch_nt.ks_dist[_i])
print(
f" {_name}: {num_terms_batch[int(_best_idx)]} terms (KS = {float(_best_ks):.4f})"
)
KS Distances shape: (3, 4)
Coefficients shape: (3, 4, 11)
Results (datasets × num_terms):
5 terms 7 terms 9 terms 11 terms
Lognormal 0.0476 0.0286 0.0286 0.0286
Weibull 0.0381 0.0286 0.0286 0.0190
Gamma 0.0190 0.0286 0.0190 0.0190
Best config per dataset:
Lognormal: 7 terms (KS = 0.0286)
Weibull: 11 terms (KS = 0.0190)
Gamma: 5 terms (KS = 0.0190)
Case 7: Batched Datasets, L1 Penalties Grid (2D)¶
Search over L1 penalties for each dataset in a batch. Returns shape (n_datasets, n_penalties).
[10]:
params_batch_l1 = MetalogParameters(
boundedness=MetalogBoundedness.STRICTLY_LOWER_BOUND,
lower_bound=0,
upper_bound=0,
method=MetalogFitMethod.Lasso,
num_terms=9,
)
l1_batch = jnp.array([0.0, 0.01, 0.1, 1.0])
result_batch_l1 = fit_grid(
batched_x, batched_y, params_batch_l1, l1_penalties=l1_batch
)
print(f"KS Distances shape: {result_batch_l1.ks_dist.shape}")
print(f"Coefficients shape: {result_batch_l1.metalog.a.shape}")
print()
print("Results (datasets × L1 penalties):")
print(" ", end="")
for _l1 in l1_batch:
print(f" L1={float(_l1):5.3f}", end="")
print()
for _i, _name in enumerate(dist_names):
print(f"{_name:>12}", end="")
for _j in range(len(l1_batch)):
print(f" {float(result_batch_l1.ks_dist[_i, _j]):.4f}", end="")
print()
# Find best per dataset
print("\nBest L1 per dataset:")
for _i, _name in enumerate(dist_names):
_best_idx, _best_ks = find_best_config(result_batch_l1.ks_dist[_i])
print(
f" {_name}: L1={float(l1_batch[int(_best_idx)]):.3f} (KS = {float(_best_ks):.4f})"
)
KS Distances shape: (3, 4)
Coefficients shape: (3, 4, 9)
Results (datasets × L1 penalties):
L1=0.000 L1=0.010 L1=0.100 L1=1.000
Lognormal 0.0381 0.0381 0.0762 0.1048
Weibull 0.0381 0.0381 0.0667 0.1810
Gamma 0.0286 0.0381 0.0571 0.1714
Best L1 per dataset:
Lognormal: L1=0.000 (KS = 0.0381)
Weibull: L1=0.000 (KS = 0.0381)
Gamma: L1=0.000 (KS = 0.0286)
Case 8: Full 3D Grid (Batched Datasets × L1 × Num Terms)¶
The most comprehensive case: search over all three axes simultaneously. Returns shape (n_datasets, n_penalties, n_terms).
[11]:
params_3d = MetalogParameters(
boundedness=MetalogBoundedness.STRICTLY_LOWER_BOUND,
lower_bound=0,
upper_bound=0,
method=MetalogFitMethod.Lasso,
num_terms=3,
)
l1_3d = jnp.array([0.0, 0.01, 0.1])
num_terms_3d = [5, 7, 9]
result_3d = fit_grid(
batched_x, batched_y, params_3d, l1_penalties=l1_3d, num_terms=num_terms_3d
)
print(f"KS Distances shape: {result_3d.ks_dist.shape}")
print(f"Coefficients shape: {result_3d.metalog.a.shape}")
print()
# Find best configuration for each dataset
print("Best configuration per dataset:")
for _i, _name in enumerate(dist_names):
_best_idx, _best_ks = find_best_config(result_3d.ks_dist[_i])
_best_l1 = l1_3d[_best_idx[0]]
_best_nt = num_terms_3d[_best_idx[1]]
print(
f" {_name}: L1={float(_best_l1):.3f}, {_best_nt} terms (KS = {float(_best_ks):.4f})"
)
print()
print("Full 3D grid for first dataset (Lognormal):")
print(" ", end="")
for _nt in num_terms_3d:
print(f"{_nt:>8} terms", end="")
print()
for _j, _l1 in enumerate(l1_3d):
print(f"L1={float(_l1):5.3f}", end="")
for _k in range(len(num_terms_3d)):
print(f" {float(result_3d.ks_dist[0, _j, _k]):.4f}", end="")
print()
KS Distances shape: (3, 3, 3)
Coefficients shape: (3, 3, 3, 9)
Best configuration per dataset:
Lognormal: L1=0.000, 7 terms (KS = 0.0381)
Weibull: L1=0.000, 5 terms (KS = 0.0381)
Gamma: L1=0.000, 5 terms (KS = 0.0286)
Full 3D grid for first dataset (Lognormal):
5 terms 7 terms 9 terms
L1=0.000 0.0476 0.0381 0.0381
L1=0.010 0.0476 0.0381 0.0381
L1=0.100 0.0762 0.0762 0.0762
Working with Different Boundedness Types¶
The fit_grid function works with all boundedness types. Here’s an example with unbounded data.
[12]:
# Generate normally distributed data (unbounded)
normal_samples = norm(loc=50, scale=10).rvs(size=200, random_state=42)
data_unbounded = MetalogInputData.from_values(normal_samples, DEFAULT_Y, False)
params_unbounded = MetalogParameters(
boundedness=MetalogBoundedness.UNBOUNDED,
lower_bound=0, # ignored
upper_bound=0, # ignored
method=MetalogFitMethod.OLS,
num_terms=7,
)
result_unbounded = fit_grid(data_unbounded.x, data_unbounded.y, params_unbounded)
print(f"Unbounded fit KS distance: {float(result_unbounded.ks_dist):.4f}")
Unbounded fit KS distance: 0.0286
Using Precomputed Quantiles¶
If you already have precomputed quantiles (e.g., from expert elicitation), set precomputed_quantiles=True.
[13]:
# Expert-elicited quantiles for a 0-100 bounded variable
quantiles = jnp.array([10.0, 25.0, 40.0, 50.0, 60.0, 75.0, 90.0])
probabilities = jnp.array([0.05, 0.20, 0.40, 0.50, 0.60, 0.80, 0.95])
params_quantiles = MetalogParameters(
boundedness=MetalogBoundedness.BOUNDED,
lower_bound=0,
upper_bound=100,
method=MetalogFitMethod.OLS,
num_terms=5,
)
result_quantiles = fit_grid(
quantiles, probabilities, params_quantiles, precomputed_quantiles=True
)
print(
f"Precomputed quantiles fit KS distance: {float(result_quantiles.ks_dist):.4f}"
)
print(f"Coefficients: {result_quantiles.metalog.a}")
Precomputed quantiles fit KS distance: 0.1429
Coefficients: [ 2.11320782e-16 5.82969856e-01 2.37162298e-16 1.05904096e+00
-6.04972876e-16]
Summary¶
The fit_grid function provides a unified interface for all grid search operations:
| Inputs | Output Shape | Description |
|---|---|---|
| x.ndim=1, no grids | () | Single fit |
| x.ndim=1, num_terms | (n_terms,) | 1D num_terms search |
| x.ndim=1, l1_penalties | (n_penalties,) | 1D L1 search |
| x.ndim=1, both | (n_penalties, n_terms) | 2D search |
| x.ndim=2, no grids | (n_datasets,) | Batch fit |
| x.ndim=2, num_terms | (n_datasets, n_terms) | 2D batch + terms |
| x.ndim=2, l1_penalties | (n_datasets, n_penalties) | 2D batch + L1 |
| x.ndim=2, both | (n_datasets, n_penalties, n_terms) | Full 3D search |
Use find_best_config(ks_dist) to find the best configuration in any grid.
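Based on its use in this notebook, find_best_config appears to return the argmin index of the KS grid (a scalar index for 1D grids, a per-axis index tuple for higher dimensions) together with the minimum value. A minimal NumPy sketch of that assumed behavior:

```python
import numpy as np


def find_best_config_sketch(ks_dist):
    """Return (index, value) of the smallest KS distance in a grid.

    Sketch of the assumed behavior of metalog_jax's find_best_config:
    a scalar index for 1D grids, a tuple of per-axis indices otherwise.
    """
    ks = np.asarray(ks_dist)
    flat_idx = int(np.argmin(ks))  # first occurrence of the minimum
    best = float(ks.flat[flat_idx])
    if ks.ndim == 1:
        return flat_idx, best
    return tuple(int(i) for i in np.unravel_index(flat_idx, ks.shape)), best


# 2D example mirroring Case 4's (n_penalties, n_terms) layout
grid = np.array([[0.0476, 0.0381, 0.0381],
                 [0.0571, 0.0476, 0.0476],
                 [0.0571, 0.0571, 0.0571]])
idx, ks = find_best_config_sketch(grid)
print(idx, ks)  # (0, 1) 0.0381
```

When several cells tie for the minimum (as in the grid above), argmin picks the first in row-major order, which here favors the smaller num_terms at the lowest L1 penalty.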