metalog-jax Grid Search
    Grid Search¶

    Unified Grid Search with fit_grid¶

    This notebook demonstrates the unified fit_grid function, which provides a single interface for all grid search combinations for hyperparameter optimization. The function automatically detects which axes to search based on the inputs provided.

    Grid Axes¶

    The fit_grid function can search over three axes:

    1. Datasets: Multiple datasets fit in parallel (detected when x.ndim == 2)

    2. L1 Penalties: Regularization strengths for Lasso regression (searched when l1_penalties is provided)

    3. Num Terms: Numbers of terms in the metalog distribution (searched when num_terms is provided)

    This gives 2³ = 8 possible combinations, all handled by a single function.
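    As an orientation aid, the axis-detection logic can be sketched as a small helper that predicts the shape of result.ks_dist for a given call. This is illustrative only: expected_ks_shape is a made-up name, not part of metalog_jax.

    ```python
    def expected_ks_shape(x_shape, n_penalties=None, n_terms=None):
        """Predict result.ks_dist's shape for a fit_grid call (sketch).

        Axes appear in the order (datasets, penalties, terms); an axis is
        present only when the corresponding input is supplied.
        """
        shape = []
        if len(x_shape) == 2:  # batched datasets detected via x.ndim == 2
            shape.append(x_shape[0])  # n_datasets
        if n_penalties is not None:  # l1_penalties grid provided
            shape.append(n_penalties)
        if n_terms is not None:  # num_terms grid provided
            shape.append(n_terms)
        return tuple(shape)

    # Case 1 (single dataset, no grid) -> scalar, i.e. shape ()
    print(expected_ks_shape((105,)))  # ()
    # Case 8 (3 datasets x 3 penalties x 3 term counts) -> (3, 3, 3)
    print(expected_ks_shape((3, 105), n_penalties=3, n_terms=3))  # (3, 3, 3)
    ```

    The eight cases below each confirm one row of this dispatch rule against the shapes fit_grid actually returns.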

    [1]:
    
    import jax.numpy as jnp
    from scipy.stats import beta, gamma, lognorm, norm, weibull_min
    
    from metalog_jax.base import (
        MetalogBoundedness,
        MetalogFitMethod,
        MetalogInputData,
        MetalogParameters,
    )
    from metalog_jax.grid_search import find_best_config, fit_grid
    from metalog_jax.utils import DEFAULT_Y
    

    Case 1: Single Dataset, No Grid¶

    The simplest case: fit a single dataset with fixed parameters. This returns a scalar KS distance and a 1D coefficient array.

    [2]:
    
    # Generate sample data from a beta distribution
    dist_beta = beta(a=2, b=5)
    samples_beta = dist_beta.rvs(size=200, random_state=42)
    
    # Create input data
    data_single = MetalogInputData.from_values(samples_beta, DEFAULT_Y, False)
    
    [3]:
    
    # Configure parameters
    params_ols = MetalogParameters(
        boundedness=MetalogBoundedness.BOUNDED,
        lower_bound=0,
        upper_bound=1,
        method=MetalogFitMethod.OLS,
        num_terms=7,
    )
    
    # Fit single dataset - returns scalar results
    result_single = fit_grid(data_single.x, data_single.y, params_ols)
    
    print(f"KS Distance: {float(result_single.ks_dist):.4f}")
    print(f"Coefficients shape: {result_single.metalog.a.shape}")
    print(f"Coefficients: {result_single.metalog.a}")
    
    KS Distance: 0.0286
    Coefficients shape: (7,)
    Coefficients: [ -1.08998096  -1.48153417   0.04950236   8.30422784  -0.78536482
       6.29446539 -18.91331443]
    

    Case 2: Single Dataset, Num Terms Grid¶

    Search over different numbers of terms to find the optimal model complexity. Returns a 1D array of results indexed by num_terms.

    [4]:
    
    # Configure base parameters
    params_num_terms = MetalogParameters(
        boundedness=MetalogBoundedness.BOUNDED,
        lower_bound=0,
        upper_bound=1,
        method=MetalogFitMethod.OLS,
        num_terms=3,  # Will be overridden
    )
    
    # Grid over num_terms
    num_terms_grid = [3, 5, 7, 9, 11]
    
    result_num_terms = fit_grid(
        data_single.x, data_single.y, params_num_terms, num_terms=num_terms_grid
    )
    
    print(f"KS Distances shape: {result_num_terms.ks_dist.shape}")
    print(f"Coefficients shape: {result_num_terms.metalog.a.shape}")
    print()
    print("Results by num_terms:")
    for _i, _nt in enumerate(num_terms_grid):
        print(f"  {_nt} terms: KS = {float(result_num_terms.ks_dist[_i]):.4f}")
    
    # Find best configuration
    best_idx_nt, best_ks_nt = find_best_config(result_num_terms.ks_dist)
    print(
        f"\nBest: {num_terms_grid[int(best_idx_nt)]} terms (KS = {float(best_ks_nt):.4f})"
    )
    
    KS Distances shape: (5,)
    Coefficients shape: (5, 11)
    
    Results by num_terms:
      3 terms: KS = 0.0667
      5 terms: KS = 0.0381
      7 terms: KS = 0.0286
      9 terms: KS = 0.0286
      11 terms: KS = 0.0286
    
    Best: 7 terms (KS = 0.0286)
    

    Case 3: Single Dataset, L1 Penalties Grid¶

    Search over L1 regularization strengths using Lasso regression. Higher penalties produce sparser coefficients, reducing overfitting.

    [5]:
    
    # Configure for Lasso
    params_lasso = MetalogParameters(
        boundedness=MetalogBoundedness.BOUNDED,
        lower_bound=0,
        upper_bound=1,
        method=MetalogFitMethod.Lasso,
        num_terms=9,
    )
    
    # Grid over L1 penalties
    l1_grid = jnp.array([0.0, 0.001, 0.01, 0.1, 1.0])
    
    result_l1 = fit_grid(
        data_single.x, data_single.y, params_lasso, l1_penalties=l1_grid
    )
    
    print(f"KS Distances shape: {result_l1.ks_dist.shape}")
    print(f"Coefficients shape: {result_l1.metalog.a.shape}")
    print()
    print("Results by L1 penalty:")
    for _i, _l1 in enumerate(l1_grid):
        coeff_norm = float(jnp.linalg.norm(result_l1.metalog.a[_i]))
        print(
            f"  L1={float(_l1):5.3f}: KS = {float(result_l1.ks_dist[_i]):.4f}, ||a|| = {coeff_norm:.4f}"
        )
    
    best_l1_idx, best_l1_ks = find_best_config(result_l1.ks_dist)
    print(
        f"\nBest: L1={float(l1_grid[int(best_l1_idx)]):.3f} (KS = {float(best_l1_ks):.4f})"
    )
    
    KS Distances shape: (5,)
    Coefficients shape: (5, 9)
    
    Results by L1 penalty:
      L1=0.000: KS = 0.0381, ||a|| = 1.3180
      L1=0.001: KS = 0.0381, ||a|| = 1.3161
      L1=0.010: KS = 0.0476, ||a|| = 1.3009
      L1=0.100: KS = 0.0571, ||a|| = 1.1992
      L1=1.000: KS = 0.1238, ||a|| = 1.0951
    
    Best: L1=0.000 (KS = 0.0381)
    

    Case 4: Single Dataset, Both Grids (2D)¶

    Search over both L1 penalties and num_terms simultaneously. Returns a 2D grid of results: (n_penalties, n_terms).

    [6]:
    
    # Configure for 2D search
    params_2d = MetalogParameters(
        boundedness=MetalogBoundedness.BOUNDED,
        lower_bound=0,
        upper_bound=1,
        method=MetalogFitMethod.Lasso,
        num_terms=3,
    )
    
    # Define both grids
    l1_grid_2d = jnp.array([0.0, 0.01, 0.1])
    num_terms_2d = [5, 7, 9]
    
    result_2d = fit_grid(
        data_single.x,
        data_single.y,
        params_2d,
        l1_penalties=l1_grid_2d,
        num_terms=num_terms_2d,
    )
    
    print(f"KS Distances shape: {result_2d.ks_dist.shape}")
    print(f"Coefficients shape: {result_2d.metalog.a.shape}")
    print()
    print("2D Grid Results (L1 penalty × num_terms):")
    print("         ", end="")
    for _nt in num_terms_2d:
        print(f"{_nt:>8} terms", end="")
    print()
    for _i, _l1 in enumerate(l1_grid_2d):
        print(f"L1={float(_l1):5.3f}", end="")
        for _j in range(len(num_terms_2d)):
            print(f"     {float(result_2d.ks_dist[_i, _j]):.4f}", end="")
        print()
    
    # Find best in 2D grid
    best_2d_idx, best_2d_ks = find_best_config(result_2d.ks_dist)
    best_l1_2d = l1_grid_2d[best_2d_idx[0]]
    best_nt_2d = num_terms_2d[best_2d_idx[1]]
    print(
        f"\nBest: L1={float(best_l1_2d):.3f}, {best_nt_2d} terms (KS = {float(best_2d_ks):.4f})"
    )
    
    KS Distances shape: (3, 3)
    Coefficients shape: (3, 3, 9)
    
    2D Grid Results (L1 penalty × num_terms):
                    5 terms       7 terms       9 terms
    L1=0.000     0.0476     0.0381     0.0381
    L1=0.010     0.0571     0.0476     0.0476
    L1=0.100     0.0571     0.0571     0.0571
    
    Best: L1=0.000, 7 terms (KS = 0.0381)
    

    Case 5: Batched Datasets, No Grid¶

    Fit multiple datasets in parallel with shared parameters. Useful when you have many distributions with similar characteristics.

    [7]:
    
    # Generate multiple distributions (all lower-bounded at 0)
    dist_lognorm = lognorm(s=0.5, loc=0, scale=1).rvs(size=200, random_state=42)
    dist_weibull = weibull_min(c=2, scale=2).rvs(size=200, random_state=43)
    dist_gamma = gamma(a=4, scale=1).rvs(size=200, random_state=44)
    
    # Create input data for each
    data1 = MetalogInputData.from_values(dist_lognorm, DEFAULT_Y, False)
    data2 = MetalogInputData.from_values(dist_weibull, DEFAULT_Y, False)
    data3 = MetalogInputData.from_values(dist_gamma, DEFAULT_Y, False)
    
    # Stack into batched arrays
    batched_x = jnp.stack([data1.x, data2.x, data3.x])
    batched_y = jnp.stack([data1.y, data2.y, data3.y])
    
    print(f"Batched x shape: {batched_x.shape}")
    print(f"Batched y shape: {batched_y.shape}")
    
    Batched x shape: (3, 105)
    Batched y shape: (3, 105)
    
    [8]:
    
    # Configure shared parameters
    params_batch = MetalogParameters(
        boundedness=MetalogBoundedness.STRICTLY_LOWER_BOUND,
        lower_bound=0,
        upper_bound=0,
        method=MetalogFitMethod.OLS,
        num_terms=7,
    )
    
    # Fit all datasets in parallel
    result_batch = fit_grid(batched_x, batched_y, params_batch)
    
    print(f"KS Distances shape: {result_batch.ks_dist.shape}")
    print(f"Coefficients shape: {result_batch.metalog.a.shape}")
    print()
    dist_names = ["Lognormal", "Weibull", "Gamma"]
    print("Results per dataset:")
    for _i, _name in enumerate(dist_names):
        print(f"  {_name}: KS = {float(result_batch.ks_dist[_i]):.4f}")
    
    KS Distances shape: (3,)
    Coefficients shape: (3, 7)
    
    Results per dataset:
      Lognormal: KS = 0.0286
      Weibull: KS = 0.0286
      Gamma: KS = 0.0286
    

    Case 6: Batched Datasets, Num Terms Grid (2D)¶

    Search over num_terms for each dataset in a batch. Returns shape (n_datasets, n_terms).

    [9]:
    
    params_batch_nt = MetalogParameters(
        boundedness=MetalogBoundedness.STRICTLY_LOWER_BOUND,
        lower_bound=0,
        upper_bound=0,
        method=MetalogFitMethod.OLS,
        num_terms=3,
    )
    
    num_terms_batch = [5, 7, 9, 11]
    
    result_batch_nt = fit_grid(
        batched_x, batched_y, params_batch_nt, num_terms=num_terms_batch
    )
    
    print(f"KS Distances shape: {result_batch_nt.ks_dist.shape}")
    print(f"Coefficients shape: {result_batch_nt.metalog.a.shape}")
    print()
    print("Results (datasets × num_terms):")
    print("              ", end="")
    for _nt in num_terms_batch:
        print(f"{_nt:>8} terms", end="")
    print()
    for _i, _name in enumerate(dist_names):
        print(f"{_name:>12}", end="")
        for _j in range(len(num_terms_batch)):
            print(f"     {float(result_batch_nt.ks_dist[_i, _j]):.4f}", end="")
        print()
    
    # Find best per dataset
    print("\nBest config per dataset:")
    for _i, _name in enumerate(dist_names):
        _best_idx, _best_ks = find_best_config(result_batch_nt.ks_dist[_i])
        print(
            f"  {_name}: {num_terms_batch[int(_best_idx)]} terms (KS = {float(_best_ks):.4f})"
        )
    
    KS Distances shape: (3, 4)
    Coefficients shape: (3, 4, 11)
    
    Results (datasets × num_terms):
                         5 terms       7 terms       9 terms      11 terms
       Lognormal     0.0476     0.0286     0.0286     0.0286
         Weibull     0.0381     0.0286     0.0286     0.0190
           Gamma     0.0190     0.0286     0.0190     0.0190
    
    Best config per dataset:
      Lognormal: 7 terms (KS = 0.0286)
      Weibull: 11 terms (KS = 0.0190)
      Gamma: 5 terms (KS = 0.0190)
    

    Case 7: Batched Datasets, L1 Penalties Grid (2D)¶

    Search over L1 penalties for each dataset in a batch. Returns shape (n_datasets, n_penalties).

    [10]:
    
    params_batch_l1 = MetalogParameters(
        boundedness=MetalogBoundedness.STRICTLY_LOWER_BOUND,
        lower_bound=0,
        upper_bound=0,
        method=MetalogFitMethod.Lasso,
        num_terms=9,
    )
    
    l1_batch = jnp.array([0.0, 0.01, 0.1, 1.0])
    
    result_batch_l1 = fit_grid(
        batched_x, batched_y, params_batch_l1, l1_penalties=l1_batch
    )
    
    print(f"KS Distances shape: {result_batch_l1.ks_dist.shape}")
    print(f"Coefficients shape: {result_batch_l1.metalog.a.shape}")
    print()
    print("Results (datasets × L1 penalties):")
    print("              ", end="")
    for _l1 in l1_batch:
        print(f"  L1={float(_l1):5.3f}", end="")
    print()
    for _i, _name in enumerate(dist_names):
        print(f"{_name:>12}", end="")
        for _j in range(len(l1_batch)):
            print(f"     {float(result_batch_l1.ks_dist[_i, _j]):.4f}", end="")
        print()
    
    # Find best per dataset
    print("\nBest L1 per dataset:")
    for _i, _name in enumerate(dist_names):
        _best_idx, _best_ks = find_best_config(result_batch_l1.ks_dist[_i])
        print(
            f"  {_name}: L1={float(l1_batch[int(_best_idx)]):.3f} (KS = {float(_best_ks):.4f})"
        )
    
    KS Distances shape: (3, 4)
    Coefficients shape: (3, 4, 9)
    
    Results (datasets × L1 penalties):
                    L1=0.000  L1=0.010  L1=0.100  L1=1.000
       Lognormal     0.0381     0.0381     0.0762     0.1048
         Weibull     0.0381     0.0381     0.0667     0.1810
           Gamma     0.0286     0.0381     0.0571     0.1714
    
    Best L1 per dataset:
      Lognormal: L1=0.000 (KS = 0.0381)
      Weibull: L1=0.000 (KS = 0.0381)
      Gamma: L1=0.000 (KS = 0.0286)
    

    Case 8: Full 3D Grid (Batched Datasets × L1 × Num Terms)¶

    The most comprehensive case: search over all three axes simultaneously. Returns shape (n_datasets, n_penalties, n_terms).

    [11]:
    
    params_3d = MetalogParameters(
        boundedness=MetalogBoundedness.STRICTLY_LOWER_BOUND,
        lower_bound=0,
        upper_bound=0,
        method=MetalogFitMethod.Lasso,
        num_terms=3,
    )
    
    l1_3d = jnp.array([0.0, 0.01, 0.1])
    num_terms_3d = [5, 7, 9]
    
    result_3d = fit_grid(
        batched_x, batched_y, params_3d, l1_penalties=l1_3d, num_terms=num_terms_3d
    )
    
    print(f"KS Distances shape: {result_3d.ks_dist.shape}")
    print(f"Coefficients shape: {result_3d.metalog.a.shape}")
    print()
    
    # Find best configuration for each dataset
    print("Best configuration per dataset:")
    for _i, _name in enumerate(dist_names):
        _best_idx, _best_ks = find_best_config(result_3d.ks_dist[_i])
        _best_l1 = l1_3d[_best_idx[0]]
        _best_nt = num_terms_3d[_best_idx[1]]
        print(
            f"  {_name}: L1={float(_best_l1):.3f}, {_best_nt} terms (KS = {float(_best_ks):.4f})"
        )
    
    print()
    print("Full 3D grid for first dataset (Lognormal):")
    print("         ", end="")
    for _nt in num_terms_3d:
        print(f"{_nt:>8} terms", end="")
    print()
    for _j, _l1 in enumerate(l1_3d):
        print(f"L1={float(_l1):5.3f}", end="")
        for _k in range(len(num_terms_3d)):
            print(f"     {float(result_3d.ks_dist[0, _j, _k]):.4f}", end="")
        print()
    
    KS Distances shape: (3, 3, 3)
    Coefficients shape: (3, 3, 3, 9)
    
    Best configuration per dataset:
      Lognormal: L1=0.000, 7 terms (KS = 0.0381)
      Weibull: L1=0.000, 5 terms (KS = 0.0381)
      Gamma: L1=0.000, 5 terms (KS = 0.0286)
    
    Full 3D grid for first dataset (Lognormal):
                    5 terms       7 terms       9 terms
    L1=0.000     0.0476     0.0381     0.0381
    L1=0.010     0.0476     0.0381     0.0381
    L1=0.100     0.0762     0.0762     0.0762
    

    Working with Different Boundedness Types¶

    The fit_grid function works with all boundedness types. Here’s an example with unbounded data.

    [12]:
    
    # Generate normally distributed data (unbounded)
    normal_samples = norm(loc=50, scale=10).rvs(size=200, random_state=42)
    data_unbounded = MetalogInputData.from_values(normal_samples, DEFAULT_Y, False)
    
    params_unbounded = MetalogParameters(
        boundedness=MetalogBoundedness.UNBOUNDED,
        lower_bound=0,  # ignored
        upper_bound=0,  # ignored
        method=MetalogFitMethod.OLS,
        num_terms=7,
    )
    
    result_unbounded = fit_grid(data_unbounded.x, data_unbounded.y, params_unbounded)
    print(f"Unbounded fit KS distance: {float(result_unbounded.ks_dist):.4f}")
    
    Unbounded fit KS distance: 0.0286
    

    Using Precomputed Quantiles¶

    If you already have precomputed quantiles (e.g., from expert elicitation), set precomputed_quantiles=True.

    [13]:
    
    # Expert-elicited quantiles for a 0-100 bounded variable
    quantiles = jnp.array([10.0, 25.0, 40.0, 50.0, 60.0, 75.0, 90.0])
    probabilities = jnp.array([0.05, 0.20, 0.40, 0.50, 0.60, 0.80, 0.95])
    
    params_quantiles = MetalogParameters(
        boundedness=MetalogBoundedness.BOUNDED,
        lower_bound=0,
        upper_bound=100,
        method=MetalogFitMethod.OLS,
        num_terms=5,
    )
    
    result_quantiles = fit_grid(
        quantiles, probabilities, params_quantiles, precomputed_quantiles=True
    )
    
    print(
        f"Precomputed quantiles fit KS distance: {float(result_quantiles.ks_dist):.4f}"
    )
    print(f"Coefficients: {result_quantiles.metalog.a}")
    
    Precomputed quantiles fit KS distance: 0.1429
    Coefficients: [ 2.11320782e-16  5.82969856e-01  2.37162298e-16  1.05904096e+00
     -6.04972876e-16]
    

    Summary¶

    The fit_grid function provides a unified interface for all grid search operations:

    | Inputs                 | Output Shape                       | Description         |
    |------------------------|------------------------------------|---------------------|
    | x.ndim=1, no grids     | ()                                 | Single fit          |
    | x.ndim=1, num_terms    | (n_terms,)                         | 1D num_terms search |
    | x.ndim=1, l1_penalties | (n_penalties,)                     | 1D L1 search        |
    | x.ndim=1, both         | (n_penalties, n_terms)             | 2D search           |
    | x.ndim=2, no grids     | (n_datasets,)                      | Batch fit           |
    | x.ndim=2, num_terms    | (n_datasets, n_terms)              | 2D batch + terms    |
    | x.ndim=2, l1_penalties | (n_datasets, n_penalties)          | 2D batch + L1       |
    | x.ndim=2, both         | (n_datasets, n_penalties, n_terms) | Full 3D search      |

    Use find_best_config(ks_dist) to find the best configuration in any grid.
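    For reference, here is a minimal NumPy sketch of the behavior find_best_config exhibits in the cells above: an argmin over the KS-distance grid, returning a scalar index for 1D grids and a tuple index for higher-dimensional grids. This is an assumed reconstruction, not the actual metalog_jax implementation.

    ```python
    import numpy as np

    def find_best_config_sketch(ks_dist):
        """Locate the grid cell with the smallest KS distance (sketch).

        Returns (index, best_ks). For a 1-D grid the index is a scalar;
        for an N-D grid it is a tuple, matching the indexing used above
        (e.g. best_2d_idx[0], best_2d_idx[1] for a 2-D search).
        """
        ks_dist = np.asarray(ks_dist)
        idx = np.unravel_index(np.argmin(ks_dist), ks_dist.shape)
        best = ks_dist[idx]
        if ks_dist.ndim == 1:
            return idx[0], best
        return idx, best

    # 2-D grid: the minimum 0.03 sits at row 1, column 0
    grid = np.array([[0.05, 0.04], [0.03, 0.06]])
    print(find_best_config_sketch(grid))  # ((1, 0), 0.03)
    ```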

    "Previous" Basic Usage
    "Next" API Reference
    © Copyright 2026, Travis Jefferies.
    Created using Sphinx 9.0.4. and Material for Sphinx