API Reference¶
This section contains the complete API reference for metalog-jax.
Table of Contents
metalog_jax¶
Main Module¶
Metalog JAX library.
Metalog Distribution¶
Metalog distribution implementations and fitting functions.
This module provides the main metalog distribution classes and fitting functions:
- Classes:
Metalog: Standard metalog distribution fitted via regression methods.
SPTMetalog: Symmetric Percentile Triplet metalog with closed-form coefficients.
GridResult: Results container for grid search and hyperparameter optimization.
- Functions:
fit: Fit a standard metalog distribution using configurable regression methods.
fit_spt_metalog: Fit an SPT metalog distribution using closed-form formulas.
The fit function uses a dispatch table pattern to route to the appropriate regression implementation based on the MetalogFitMethod specified in MetalogParameters:
MetalogFitMethod.OLS -> metalog_jax.regression.ols.fit_ordinary_least_squares
MetalogFitMethod.Lasso -> metalog_jax.regression.lasso.fit_lasso
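The dispatch-table pattern described above can be sketched as follows. The enum values and mapping mirror the routing listed, but the regression implementations here are hypothetical stand-ins, not the library's actual code:

```python
from enum import Enum, auto

import jax.numpy as jnp


class MetalogFitMethod(Enum):
    OLS = auto()
    Lasso = auto()


# Hypothetical stand-ins for the regression implementations.
def fit_ordinary_least_squares(X, z):
    # Closed-form least-squares solve for the coefficient vector.
    return jnp.linalg.lstsq(X, z)[0]


def fit_lasso(X, z):
    # Placeholder: a real LASSO fit would run an iterative solver.
    return jnp.linalg.lstsq(X, z)[0]


# The dispatch table maps each fit method to its implementation,
# so fit() routes without an if/elif chain.
_FIT_DISPATCH = {
    MetalogFitMethod.OLS: fit_ordinary_least_squares,
    MetalogFitMethod.Lasso: fit_lasso,
}


def fit(X, z, method):
    try:
        impl = _FIT_DISPATCH[method]
    except KeyError:
        raise TypeError(f"Unknown fit method: {method!r}")
    return impl(X, z)
```

Adding a new regression method then only requires registering one more entry in the table.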
See also
metalog_jax.base: Base classes and parameter configurations.
metalog_jax.regression: Regression implementations for fitting.
- class metalog_jax.metalog.SPTMetalog(metalog_params, a)[source]¶
Bases: MetalogBase
Symmetric Percentile Triplet (SPT) metalog distribution.
SPTMetalog is a specialized metalog implementation that uses exactly three terms fitted from three quantiles: the alpha-quantile, median (0.5), and (1-alpha)-quantile, where alpha < 0.5. This method provides a computationally efficient way to fit metalog distributions with closed-form coefficient solutions, avoiding the need for optimization or linear regression.
The SPT approach is particularly useful when you need:
- A quick approximation using only three data points
- Guaranteed feasibility through explicit feasibility checks
- Analytical coefficient formulas for different boundedness types
- A distribution that matches specific symmetric or asymmetric quantile triplets
Unlike the standard Metalog class which uses OLS or LASSO regression to fit arbitrary numbers of terms, SPTMetalog always has exactly 3 terms and computes coefficients directly from the quantile triplet. This makes it faster for small datasets but less flexible than the full metalog approach.
- metalog_params¶
Configuration parameters specific to SPT fitting, including boundedness type, boundary values, and the alpha parameter that determines which quantiles are used (alpha, 0.5, 1-alpha).
- a¶
Coefficient vector of shape (3,) representing the three fitted metalog distribution parameters. These are computed using closed-form formulas specific to the boundedness type.
Examples
Fit an unbounded SPT metalog from data:
>>> import jax.numpy as jnp
>>> from metalog_jax import fit_spt_metalog, SPTMetalogParameters
>>> from metalog_jax import MetalogBoundedness
>>>
>>> # Generate sample data
>>> data = jnp.array([1.2, 2.3, 2.8, 3.5, 4.1, 5.6, 6.2, 7.8, 9.1])
>>>
>>> # Configure SPT parameters
>>> params = SPTMetalogParameters(
...     boundedness=MetalogBoundedness.UNBOUNDED,
...     alpha=0.1,  # Use 10th, 50th, and 90th percentiles
...     lower_bound=0.0,
...     upper_bound=0.0,
... )
>>>
>>> # Fit the SPT metalog
>>> spt = fit_spt_metalog(data, params)
>>>
>>> # Evaluate quantile function at probability 0.25
>>> q_25 = spt.ppf(0.25)
>>>
>>> # Compute mean and standard deviation
>>> mean = spt.mean()
>>> std = spt.std()
Note
SPT metalog distributions may fail feasibility checks if the data does not satisfy specific constraints on the symmetry ratio r. The feasibility constraints depend on the boundedness type and alpha parameter. When feasibility fails, fit_spt_metalog will raise an AssertionError.
The SPT method is described in Keelin (2016) as a special case of the metalog family that enables rapid approximation with minimal data.
See also
Metalog: Standard metalog implementation using regression fitting.
MetalogBase: Base class defining the common interface.
fit_spt_metalog: Function to fit SPT metalog distributions from data.
References
Keelin, T. W. (2016). The Metalog Distributions. Decision Analysis, 13(4), 243-277. https://doi.org/10.1287/deca.2016.0338
- Parameters:
- metalog_params: SPTMetalogParameters¶
- __eq__(other)[source]¶
Check equality between two SPTMetalog distribution instances.
Delegates to the parent MetalogBase.__eq__ method.
- property num_terms: int¶
Get the number of terms in the metalog expansion.
- Returns:
3 - an SPT Metalog always has three terms.
- __init__(metalog_params, a)¶
- replace(**updates)¶
Returns a new object replacing the specified fields with new values.
- class metalog_jax.metalog.Metalog(metalog_params, a)[source]¶
Bases: MetalogBase
Standard metalog distribution fitted via regression.
Metalog is the full-featured metalog implementation that extends MetalogBase and uses ordinary least squares (OLS) or LASSO regression to fit metalog distributions with arbitrary numbers of terms (typically 2-30). This approach provides maximum flexibility and accuracy for modeling complex distributions from larger datasets.
Unlike SPTMetalog, which uses exactly 3 terms with closed-form coefficient formulas, the standard Metalog class:
- Supports any number of terms (configurable via num_terms parameter)
- Uses regression-based fitting (OLS or LASSO methods)
- Can model more complex distribution shapes with higher-order terms
- Requires more data for reliable fitting (recommended minimum: 3 x num_terms)
- May encounter numerical issues with very high term counts
This class inherits all distribution methods from MetalogBase including ppf (quantile function), cdf, pdf, and statistical moments (mean, variance, std, skewness, kurtosis). It also provides serialization capabilities for saving and loading fitted distributions.
- metalog_params¶
Configuration parameters for the standard metalog fit, including boundedness type (UNBOUNDED, STRICTLY_LOWER_BOUND, STRICTLY_UPPER_BOUND, or BOUNDED), boundary values, fitting method (OLS or LASSO), and number of terms in the expansion.
- a¶
Coefficient vector of shape (num_terms,) representing the fitted metalog distribution parameters. Each coefficient corresponds to one basis function in the metalog expansion. These are computed via regression from the input data.
Examples
Fit an unbounded metalog with 9 terms using OLS:
>>> import jax.numpy as jnp
>>> from metalog_jax import fit_metalog, MetalogParameters
>>> from metalog_jax import MetalogBoundedness, MetalogFitMethod
>>>
>>> # Generate sample data
>>> data = jnp.array([1.2, 2.3, 2.8, 3.5, 4.1, 5.6, 6.2, 7.8, 9.1, 10.5])
>>>
>>> # Configure metalog parameters
>>> params = MetalogParameters(
...     boundedness=MetalogBoundedness.UNBOUNDED,
...     method=MetalogFitMethod.OLS,
...     num_terms=9,
...     lower_bound=0.0,
...     upper_bound=0.0,
... )
>>>
>>> # Fit the metalog distribution
>>> m = fit_metalog(data, params)
>>>
>>> # Evaluate quantile function at probability 0.75
>>> q_75 = m.ppf(0.75)
>>>
>>> # Compute distribution moments
>>> mean = m.mean()
>>> variance = m.variance()
>>> skewness = m.skewness()
Note
The standard Metalog class is the recommended choice for most use cases involving sufficient data. Use SPTMetalog only when you need a quick three-term approximation or when you have exactly three representative quantiles. For datasets with at least 10-30 observations, the standard metalog typically provides superior fit quality.
Higher term counts generally improve accuracy but require more data and may cause overfitting or numerical instability. As a rule of thumb, ensure your sample size is at least 3 times the number of terms.
See also
SPTMetalog: Specialized three-term metalog using closed-form coefficients.
MetalogBase: Base class defining the common distribution interface.
fit_metalog: Function to fit standard metalog distributions from data.
References
Keelin, T. W. (2016). The Metalog Distributions. Decision Analysis, 13(4), 243-277. https://doi.org/10.1287/deca.2016.0338
- metalog_params: MetalogParameters¶
- __eq__(other)[source]¶
Check equality between two Metalog distribution instances.
Delegates to the parent MetalogBase.__eq__ method.
- property num_terms: int¶
Get the number of terms in the metalog expansion.
- Returns:
Integer representing the number of basis functions used in the metalog approximation. Higher values generally provide better accuracy but require more data.
- property method: MetalogFitMethod¶
Get the method used to fit the metalog distribution.
- Returns:
MetalogFitMethod enum of regression methods for fitting metalog distribution.
- __init__(metalog_params, a)¶
- replace(**updates)¶
Returns a new object replacing the specified fields with new values.
- metalog_jax.metalog.fit(data, metalog_params, regression_hyperparams=None)[source]¶
Fit a metalog distribution to data using the specified configuration.
Estimates the metalog distribution coefficients by solving a linear regression problem that maps the basis functions (target matrix) to the transformed quantiles. The regression method specified in metalog_params determines the fitting approach, and optional regression_hyperparams allow fine-tuning of regularization settings.
This function processes a MetalogInputData instance containing validated input data (quantiles and probability levels) and fits a metalog distribution according to the specified parameters. The data must be created using MetalogInputData.from_values() which performs comprehensive validation.
- Parameters:
data (MetalogInputData) – Validated input data created via MetalogInputData.from_values(). Contains:
- x: Quantile values (precomputed or computed from raw samples)
- y: Probability levels in (0, 1)
- precomputed_quantiles: Flag indicating data type
Direct instantiation of MetalogInputData is prevented by __post_init__ validation to ensure data integrity.
metalog_params (MetalogParameters) – Configuration parameters specifying the distribution characteristics (boundedness, bounds, number of terms) and the regression method to use for fitting. The method field determines which regression approach is used:
- MetalogFitMethod.OLS: Ordinary Least Squares (no regularization)
- MetalogFitMethod.Lasso: L1 regularization
regression_hyperparams (RegularizedParameters) – Optional regularization hyperparameters for controlling the fitting process when using the LASSO regression method. This parameter is ignored when method=OLS. If None, default hyperparameters are used. Must be an instance of LassoParameters for MetalogFitMethod.Lasso.
- Return type:
- Returns:
Metalog containing the fitted coefficient vector and the configuration parameters used to fit the metalog distribution. The returned object is validated via assert_fitted() to ensure the fit produces a feasible distribution with strictly positive PDF values.
- Raises:
TypeError – If metalog_params.method is not a valid MetalogFitMethod instance.
checkify.JaxRuntimeError – If the fitted distribution produces non-positive PDF values, indicating an infeasible fit. This validation is performed by assert_fitted() before returning.
Example
Basic usage with OLS (no regularization):
>>> import jax.numpy as jnp
>>> from metalog_jax.base import MetalogInputData, MetalogParameters
>>> from metalog_jax.base import MetalogBoundedness, MetalogFitMethod
>>> from metalog_jax.metalog import fit
>>>
>>> # Create validated input data using from_values()
>>> raw_data = jnp.array([1.2, 2.3, 2.8, 3.5, 4.1, 5.6])
>>> data = MetalogInputData.from_values(
...     x=raw_data,
...     y=jnp.array([0.1, 0.25, 0.5, 0.75, 0.9, 0.95]),
...     precomputed_quantiles=False  # Raw samples, not precomputed quantiles
... )
>>>
>>> # Configure metalog parameters with OLS
>>> metalog_params = MetalogParameters(
...     boundedness=MetalogBoundedness.UNBOUNDED,
...     method=MetalogFitMethod.OLS,
...     lower_bound=0.0,
...     upper_bound=0.0,
...     num_terms=5
... )
>>>
>>> # Fit the metalog distribution
>>> m = fit(data, metalog_params)
Note
The regression_hyperparams parameter is only applicable for LASSO method. It is ignored when using OLS.
If regression_hyperparams is None, sensible defaults are used for each method.
For production use, consider tuning regularization hyperparameters via cross-validation to prevent overfitting while maintaining good fit quality.
See also
fit_spt_metalog: Alternative fitting method using Symmetric Percentile Triplet.
MetalogInputData.from_values: Required method for creating validated input data.
metalog_jax.regression.lasso.LassoParameters: Hyperparameters for LASSO regression.
- metalog_jax.metalog.fit_spt_metalog(array, spt_metalog_params)[source]¶
Fit a Symmetric Percentile Triplet (SPT) metalog distribution to data.
This function fits a 3-term metalog distribution using the Symmetric Percentile Triplet (SPT) parameterization method described in Keelin (2016). Unlike the standard metalog fit which uses many quantiles, the SPT method uses exactly three symmetric quantiles: alpha, 0.5 (median), and (1 - alpha), where 0 < alpha < 0.5.
The SPT approach provides a quick, analytical solution for fitting a metalog distribution when:
- You have limited data or only three quantiles available
- You want a simple, closed-form solution without regression
- You need a fast approximation with 3 terms only
The method computes the 3 metalog coefficients directly from the three quantile values without requiring least squares regression. The formulas vary based on the boundedness type (unbounded, semi-bounded, or bounded).
- Parameters:
array (Union[Array, ndarray, bool, number, float, int]) – Input data array with at least 3 elements from which to compute empirical quantiles. The array will be used to compute the alpha-th, 50th, and (1-alpha)-th percentiles.
spt_metalog_params (SPTMetalogParameters) – Configuration parameters containing:
- alpha: Lower percentile parameter in (0, 0.5). Common values are 0.1 (10-50-90 percentiles) or 0.25 (25-50-75 percentiles/IQR).
- boundedness: Domain constraint type (UNBOUNDED, STRICTLY_LOWER_BOUND, or BOUNDED). STRICTLY_UPPER_BOUND is not supported in the SPT formulation.
- lower_bound: Lower boundary value (used for semi-bounded/bounded distributions).
- upper_bound: Upper boundary value (used for bounded distributions).
- Returns:
- Fitted SPT metalog distribution containing:
metalog_params: The input configuration parameters
a: Coefficient vector of length 3 containing the fitted metalog parameters [a1, a2, a3] that define the quantile function
- Return type:
- Raises:
AssertionError – If array has fewer than 3 elements, is not rank 1, or contains non-numeric values.
AssertionError – If alpha is not positive or alpha >= 0.5.
AssertionError – If the computed quantiles violate feasibility constraints (q_alpha < median < q_complement).
AssertionError – If boundedness-specific feasibility checks fail (e.g., quantiles outside valid range for the given bounds).
NotImplementedError – If boundedness is STRICTLY_UPPER_BOUND, which is undefined in the SPT formulation.
Example
>>> import jax.numpy as jnp
>>> from metalog_jax.metalog import fit_spt_metalog, SPTMetalogParameters
>>> from metalog_jax.metalog import MetalogBoundedness
>>> # Generate sample data
>>> data = jnp.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0])
>>> # Configure SPT parameters using 10-50-90 percentiles
>>> params = SPTMetalogParameters(
...     alpha=0.1,
...     boundedness=MetalogBoundedness.UNBOUNDED,
...     lower_bound=0.0,
...     upper_bound=1.0
... )
>>> # Fit the SPT metalog
>>> spt_metalog = fit_spt_metalog(data, params)
>>> # The result contains 3 coefficients
>>> assert len(spt_metalog.a) == 3
Note
The SPT method always produces exactly 3 metalog terms, regardless of data size. For more flexible fits with more terms, use the standard fit() function.
STRICTLY_UPPER_BOUND boundedness is not supported by the SPT formulation as defined in Keelin’s paper.
The method performs multiple feasibility checks to ensure the quantiles and parameters produce a valid probability distribution.
For unbounded distributions, the median position ratio r must satisfy specific constraints related to k_alpha to ensure feasibility, where r = (median - q_alpha) / (q_complement - q_alpha).
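The median position ratio r defined above can be computed from raw data with plain jax.numpy. This is a sketch of the quantity itself, not the library's internal feasibility check, and the helper name is an assumption:

```python
import jax.numpy as jnp


def median_position_ratio(data, alpha=0.1):
    """Compute r = (median - q_alpha) / (q_complement - q_alpha).

    For a feasible SPT triplet the three quantiles must be strictly
    ordered (q_alpha < median < q_complement), which implies 0 < r < 1.
    """
    q_alpha, median, q_comp = jnp.quantile(
        data, jnp.array([alpha, 0.5, 1.0 - alpha])
    )
    return (median - q_alpha) / (q_comp - q_alpha)
```

A perfectly symmetric sample gives r = 0.5; values near 0 or 1 signal strong asymmetry, which is where the boundedness-specific constraints on r relative to k_alpha come into play.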
References
Keelin, T. W. (2016). The Metalog Distributions. Decision Analysis, 13(4).
- class metalog_jax.metalog.GridResult(metalog, ks_dist)[source]¶
Bases: object
Results from metalog grid search or hyperparameter optimization.
GridResult encapsulates the outcome of fitting a metalog distribution along with a goodness-of-fit metric. This dataclass is typically used when performing grid searches over hyperparameters (e.g., number of terms, regularization penalties) or when comparing multiple metalog configurations to select the best-fitting distribution.
The primary use case is hyperparameter tuning via grid search with JAX’s vmap to fit multiple metalog configurations in parallel, then selecting the best configuration based on the Kolmogorov-Smirnov distance.
The class uses Flax’s struct.dataclass decorator to ensure immutability and compatibility with JAX transformations (jit, vmap, grad), making it suitable for use in vectorized grid search operations.
- metalog¶
The fitted metalog distribution. Can be either a standard Metalog (fitted via regression with configurable terms) or SPTMetalog (fitted using Symmetric Percentile Triplet with exactly 3 terms). Contains:
- metalog_params: Configuration parameters (boundedness, method, etc.)
- a: Coefficient vector representing the fitted distribution
- ks_dist¶
Scalar Kolmogorov-Smirnov distance measuring the goodness-of-fit between the fitted metalog distribution and the input random variable. This is the maximum absolute difference between the empirical CDF of the input data and the fitted metalog’s CDF:
- 0: Perfect fit (empirical CDFs are identical)
- 1: Worst possible fit (completely non-overlapping distributions)
- Lower values indicate better fit quality
Used to select the best metalog configuration from a grid search.
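As a concrete illustration of this metric (a self-contained sketch, not the library's ks_distance function), the two-sided KS statistic between samples and any model CDF can be computed as:

```python
import jax.numpy as jnp


def ks_distance(samples, cdf):
    """Max absolute difference between the empirical CDF of `samples`
    and a model CDF evaluated at the sample points.

    The ECDF jumps at each sorted sample, so we compare the model CDF
    against the ECDF values just before and just after each jump.
    """
    x = jnp.sort(samples)
    n = x.shape[0]
    F = cdf(x)                           # model CDF at sample points
    ecdf_hi = jnp.arange(1, n + 1) / n   # ECDF right after each jump
    ecdf_lo = jnp.arange(0, n) / n       # ECDF just before each jump
    return jnp.maximum(jnp.abs(F - ecdf_hi), jnp.abs(F - ecdf_lo)).max()
```

For a metalog, the model CDF at the sample points would come from the fitted distribution's cdf method.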
Note
This class is immutable due to the struct.dataclass decorator
The ks_dist attribute provides a single goodness-of-fit metric for comparing different metalog configurations
When used with jax.vmap, the metalog and ks_dist fields will be batched arrays/PyTrees, allowing vectorized comparison
Lower KS distances indicate better fit quality, but extremely low values may indicate overfitting (especially with high term counts)
For production use, consider cross-validation instead of or in addition to KS distance for hyperparameter selection
See also
ks_distance: Function to compute the Kolmogorov-Smirnov distance.
fit: Standard metalog fitting function.
fit_spt_metalog: SPT metalog fitting function.
Metalog: Standard metalog distribution class.
SPTMetalog: Symmetric Percentile Triplet metalog class.
References
Kolmogorov, A. N. (1933). “Sulla determinazione empirica di una legge di distribuzione”. Giornale dell’Istituto Italiano degli Attuari, 4: 83-91.
Keelin, T. W. (2016). The Metalog Distributions. Decision Analysis, 13(4), 243-277. https://doi.org/10.1287/deca.2016.0338
- Parameters:
metalog (Union[Metalog, SPTMetalog])
- metalog: Metalog | SPTMetalog¶
- __init__(metalog, ks_dist)¶
- Parameters:
metalog (Union[Metalog, SPTMetalog])
- replace(**updates)¶
Returns a new object replacing the specified fields with new values.
Grid Search¶
Grid search methods for the Metalog JAX library.
This module provides generalized vmap-based grid search functions for fitting metalog distributions across multiple datasets, hyperparameters, and term counts.
The functions are designed to work with JAX’s vectorization primitives for efficient parallel computation on CPU, GPU, or TPU devices.
- Key functions:
fit_grid: Unified grid search over any combination of axes
fit_grid_datasets: 1D grid over datasets (shared params)
fit_grid_hyperparams: 1D grid over hyperparameters for a single dataset
fit_grid_num_terms: 1D grid over num_terms for a single dataset
fit_grid_datasets_hyperparams: 2D grid over datasets x hyperparameters
fit_grid_datasets_num_terms: 2D grid over datasets x num_terms
fit_grid_full: 3D grid over datasets x hyperparameters x num_terms
find_best_config: Find best configuration in a grid
extract_best_from_grid: Extract best result from grid for a dataset
- metalog_jax.grid_search.fit_grid(x, y, params, *, num_terms=None, l1_penalties=None, precomputed_quantiles=False, lasso_params=None, learning_rate=0.01, num_iters=500, tol=1e-06, momentum=0.9)[source]¶
Unified grid search over any combination of axes.
Automatically detects which axes to search based on inputs:
- If x.ndim == 2: grid over datasets (batch dimension)
- If num_terms provided: grid over num_terms values
- If l1_penalties provided: grid over L1 penalties (uses Lasso)
Output shape depends on active axes (ordered: datasets, l1_penalties, num_terms):
- (): single dataset, no grids
- (n_datasets,): only datasets batched
- (n_penalties,): only l1_penalties grid
- (n_terms,): only num_terms grid
- (n_datasets, n_penalties): datasets x l1_penalties
- (n_datasets, n_terms): datasets x num_terms
- (n_penalties, n_terms): l1_penalties x num_terms
- (n_datasets, n_penalties, n_terms): all three axes
- Parameters:
x (Array) – Quantile values. Shape (n,) for single dataset or (batch, n) for batched.
y (Array) – Probability levels. Shape (n,) for single dataset or (batch, n) for batched.
params (MetalogParameters) – MetalogParameters configuration. The num_terms field is ignored if the num_terms argument is provided.
num_terms (Optional[Sequence[int]]) – Optional sequence of term counts to search over.
l1_penalties (Optional[Array]) – Optional array of L1 penalties to search over (implies Lasso).
precomputed_quantiles (bool) – Whether x values are precomputed quantiles.
lasso_params (Optional[LassoParameters]) – Optional fixed LassoParameters (used when l1_penalties is None but params.method is Lasso).
learning_rate (float) – Learning rate for Lasso. Default: 0.01.
num_iters (int) – Max iterations for Lasso. Default: 500.
tol (float) – Convergence tolerance for Lasso. Default: 1e-6.
momentum (float) – Momentum for Lasso optimizer. Default: 0.9.
- Return type:
- Returns:
GridResult with metalog fits and KS distances. Shape depends on active axes.
- metalog_jax.grid_search.stack_leaves(*leaves)[source]¶
Stack leaf values from multiple PyTrees along a new batch dimension.
This helper function is used with jax.tree.map to transform a list of dataclass instances into a single batched dataclass where each field is stacked along axis 0. It enables efficient vectorized operations by converting multiple independent structures into a batched structure.
- Parameters:
*leaves – Variable number of leaf arrays from corresponding positions in multiple PyTrees (e.g., the x field from multiple MetalogInputData instances). Each leaf should be a JAX array or scalar value.
- Returns:
A single JAX array with the leaves stacked along a new first dimension, shape (n_leaves, …) where n_leaves is the number of input leaves and … represents the original shape of each leaf.
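The stacking behavior can be sketched with a couple of toy PyTrees. The helper below is an illustrative equivalent of the documented function, not the library's implementation (jax.tree_util.tree_map is the long-standing spelling of jax.tree.map):

```python
import jax
import jax.numpy as jnp


def stack_leaves(*leaves):
    # Stack corresponding leaves along a new leading batch axis.
    return jnp.stack(leaves, axis=0)


# Two toy PyTrees with identical structure but different values.
tree_a = {"x": jnp.array([1.0, 2.0]), "y": jnp.array(0.1)}
tree_b = {"x": jnp.array([3.0, 4.0]), "y": jnp.array(0.9)}

# tree_map calls stack_leaves once per leaf position, producing a
# single PyTree whose leaves carry a new batch dimension.
batched = jax.tree_util.tree_map(stack_leaves, tree_a, tree_b)
```

Here batched["x"] has shape (2, 2) and batched["y"] has shape (2,), ready for vmap over axis 0.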
- metalog_jax.grid_search.make_batch(data)[source]¶
Convert a list of MetalogBaseData instances into a single batched instance.
This function transforms a list of individual MetalogBaseData instances into a single MetalogBaseData instance where each field is batched along the first dimension. This enables vectorized operations across multiple datasets using JAX’s vmap.
- Parameters:
data (list[MetalogBaseData]) – List of MetalogBaseData instances to batch together. Each instance should have the same structure (same fields), but can have different values. All instances should have arrays with compatible shapes.
- Return type:
MetalogBaseData
- Returns:
A single MetalogBaseData instance where each field is stacked along axis 0. If the input list has n elements and each instance has a field with shape (m,), the output will have that field with shape (n, m).
- metalog_jax.grid_search.unvmap(batched_out)[source]¶
Convert batched vmap output into a list of individual outputs.
This function reverses the batching operation performed by jax.vmap by converting a single batched PyTree into a list of individual PyTrees, one per batch element. Each element in the output list has the same structure as the batched output but without the batch dimension.
- Parameters:
batched_out (Any) – The batched output from jax.vmap(f)(…). Can be a JAX array, tuple, dict, dataclass, or any PyTree structure where each leaf has a batch dimension as its first axis.
- Return type:
- Returns:
A list of length n_batch where each element is a PyTree with the same structure as batched_out but with the batch dimension removed.
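The unbatching operation can be sketched with tree utilities; this is an illustrative equivalent under the stated contract (batch dimension on every leaf's first axis), not the library's implementation:

```python
import jax
import jax.numpy as jnp


def unvmap(batched_out):
    """Split a batched PyTree into a list of per-element PyTrees."""
    leaves = jax.tree_util.tree_leaves(batched_out)
    n_batch = leaves[0].shape[0]
    # Index every leaf at position i to peel off one batch element.
    return [
        jax.tree_util.tree_map(lambda leaf, i=i: leaf[i], batched_out)
        for i in range(n_batch)
    ]
```

Each returned element has the same PyTree structure as the input with the leading axis removed from every leaf.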
- metalog_jax.grid_search.pad_metalog_coeffs(metalog, max_terms)[source]¶
Pad metalog coefficient array to max_terms length with zeros.
Since different term counts produce different coefficient array lengths, we pad shorter arrays with zeros to enable stacking into a consistent batched structure.
- Parameters:
metalog (Union[Metalog, SPTMetalog]) – Metalog or SPTMetalog instance to pad.
max_terms (int) – Target length for the coefficient array.
- Return type:
Union[Metalog, SPTMetalog]
- Returns:
Metalog instance with coefficients padded to max_terms.
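The padding idea itself is simple enough to sketch on bare coefficient arrays (the helper name is illustrative; the documented function operates on whole Metalog instances):

```python
import jax.numpy as jnp


def pad_coeffs(a, max_terms):
    """Zero-pad a 1D coefficient vector to length `max_terms`.

    Zero coefficients contribute nothing to the metalog expansion,
    so the padded and unpadded vectors describe the same quantile
    function while sharing one static shape.
    """
    return jnp.pad(a, (0, max_terms - a.shape[0]))


# Fits with different term counts can now be stacked uniformly:
a3 = jnp.array([1.0, 0.5, 0.1])
a5 = jnp.array([1.0, 0.5, 0.1, 0.02, 0.01])
stacked = jnp.stack([pad_coeffs(a3, 5), pad_coeffs(a5, 5)])
```

The uniform shape is what allows results from a num_terms grid to live in a single batched GridResult.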
- metalog_jax.grid_search.fit_grid_datasets(batched_x, batched_y, params, fit_params=None, precomputed_quantiles=False)[source]¶
Fit multiple datasets with shared parameters using vmap.
This function efficiently fits metalog distributions to multiple datasets in parallel using JAX’s vmap. All datasets share the same MetalogParameters and optional regularization settings.
- Parameters:
batched_x (Array) – Stacked quantile values, shape (n_datasets, n_samples).
batched_y (Array) – Stacked probability levels, shape (n_datasets, n_samples).
params (MetalogParameters) – MetalogParameters shared across all fits.
fit_params (Optional[RegularizedParameters]) – Optional regularization parameters (e.g., LassoParameters).
precomputed_quantiles (bool) – Whether x values are precomputed quantiles.
- Returns:
metalog.a: Coefficients of shape (n_datasets, num_terms)
ks_dist: KS distances of shape (n_datasets,)
- Return type:
GridResult with batched results
- metalog_jax.grid_search.fit_grid_hyperparams(data, params, l1_penalties, learning_rate=0.01, num_iters=500, tol=1e-06, momentum=0.9)[source]¶
Fit a single dataset with a grid of L1 penalty values (Lasso).
This function performs a 1D grid search over L1 regularization penalties for a single dataset using JAX’s vmap for efficient parallel computation.
- Parameters:
data (MetalogBaseData) – MetalogBaseData instance containing the dataset to fit.
params (MetalogParameters) – MetalogParameters configuration. The method should be Lasso.
l1_penalties (Array) – Array of L1 penalty values to test, shape (n_penalties,).
learning_rate (float) – Learning rate for Lasso gradient descent. Default: 0.01.
num_iters (int) – Maximum iterations for Lasso. Default: 500.
tol (float) – Convergence tolerance for Lasso. Default: 1e-6.
momentum (float) – Momentum for Lasso optimizer. Default: 0.9.
- Returns:
metalog.a: Coefficients of shape (n_penalties, num_terms)
ks_dist: KS distances of shape (n_penalties,)
- Return type:
GridResult with batched results
- metalog_jax.grid_search.fit_grid_datasets_hyperparams(batched_x, batched_y, params, l1_penalties, precomputed_quantiles=False, learning_rate=0.01, num_iters=500, tol=1e-06, momentum=0.9)[source]¶
Fit multiple datasets with a grid of L1 penalty values (Lasso).
This function performs a 2D grid search over datasets x L1 regularization penalties using nested vmap for efficient parallel computation.
- Parameters:
batched_x (Array) – Stacked quantile values, shape (n_datasets, n_samples).
batched_y (Array) – Stacked probability levels, shape (n_datasets, n_samples).
params (MetalogParameters) – MetalogParameters configuration. The method should be Lasso.
l1_penalties (Array) – Array of L1 penalty values to test, shape (n_penalties,).
precomputed_quantiles (bool) – Whether x values are precomputed quantiles.
learning_rate (float) – Learning rate for Lasso gradient descent. Default: 0.01.
num_iters (int) – Maximum iterations for Lasso. Default: 500.
tol (float) – Convergence tolerance for Lasso. Default: 1e-6.
momentum (float) – Momentum for Lasso optimizer. Default: 0.9.
- Returns:
metalog.a: Coefficients of shape (n_datasets, n_penalties, num_terms)
ks_dist: KS distances of shape (n_datasets, n_penalties)
- Return type:
GridResult with batched results
- metalog_jax.grid_search.fit_grid_num_terms(data, params, num_terms_list, regression_params=None)[source]¶
Fit a single dataset with different num_terms values.
This function performs a 1D grid search over term counts for a single dataset using JAX’s vmap for efficient parallel computation. Supports OLS and Lasso regression methods based on the params.method setting.
- Vectorization strategy:
num_terms dimension: Fully vectorized via vmap
Uses MAX_TERMS-sized arrays with masking for uniform shapes
- Parameters:
data (MetalogBaseData) – MetalogBaseData instance containing the dataset to fit.
params (MetalogParameters) – MetalogParameters configuration. The num_terms field is ignored and replaced by values from num_terms_list.
regression_params (Any) – Optional regression hyperparameters (e.g., LassoParameters). If None, uses defaults for the method.
- Returns:
- metalog.a: Coefficients of shape (n_terms, max_terms)
ks_dist: KS distances of shape (n_terms,)
Where max_terms = max(num_terms_list) and coefficients for smaller term counts are zero-padded.
- Return type:
GridResult with batched results
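The MAX_TERMS masking strategy mentioned above can be sketched as follows. The constant and helper names here are illustrative assumptions, not the library's internals; the point is that a static maximum width plus a boolean mask gives every term count the same array shape, which is what vmap requires:

```python
import jax.numpy as jnp

MAX_TERMS = 8  # illustrative cap; the library's constant may differ


def active_mask(num_terms, max_terms=MAX_TERMS):
    """Boolean mask selecting the first `num_terms` of `max_terms` slots."""
    return jnp.arange(max_terms) < num_terms


def masked_design(X_full, num_terms):
    # Zero out basis columns beyond the active term count; a regression
    # on the masked matrix effectively fits only `num_terms` coefficients.
    return X_full * active_mask(num_terms)[None, :]
```

Because num_terms enters only through the mask, it can itself be a vmapped axis without producing ragged shapes.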
- metalog_jax.grid_search.fit_grid_datasets_num_terms(batched_x, batched_y, params, num_terms_list, precomputed_quantiles=False, regression_params=None)[source]¶
Fit multiple datasets with different num_terms values.
This function performs a 2D grid search over datasets x term counts using nested vmap for efficient parallel computation. Supports OLS and Lasso regression methods based on the params.method setting.
- Vectorization strategy:
datasets dimension: Fully vectorized via vmap
num_terms dimension: Fully vectorized via vmap
Uses MAX_TERMS-sized arrays with masking for uniform shapes
- Parameters:
batched_x (Array) – Stacked quantile values, shape (n_datasets, n_samples).
batched_y (Array) – Stacked probability levels, shape (n_datasets, n_samples).
params (MetalogParameters) – MetalogParameters configuration. The num_terms field is ignored and replaced by values from num_terms_list.
precomputed_quantiles (bool) – Whether x values are precomputed quantiles.
regression_params (Any) – Optional regression hyperparameters (e.g., LassoParameters). If None, uses defaults for the method.
- Returns:
- metalog.a: Coefficients of shape (n_datasets, n_terms, max_terms)
ks_dist: KS distances of shape (n_datasets, n_terms)
Where max_terms = max(num_terms_list) and coefficients for smaller term counts are zero-padded.
- Return type:
GridResult with batched results
- metalog_jax.grid_search.fit_grid_full(batched_x, batched_y, params, l1_penalties, num_terms_list, precomputed_quantiles=False, learning_rate=0.01, num_iters=500, tol=1e-06, momentum=0.9)[source]¶
Fit multiple datasets with grids of L1 penalties and term counts (Lasso).
This function performs a 3D grid search over datasets x L1 penalties x num_terms.
- Vectorization strategy:
All three dimensions (datasets, penalties, num_terms) are fully vectorized using nested jax.vmap calls
Uses MAX_TERMS-sized arrays with masking to enable vmap over num_terms
No Python for loops in the computation path
- The implementation uses masked arrays throughout to handle varying num_terms:
Design matrix is always MAX_TERMS columns with unused columns masked to 0
Coefficients are always MAX_TERMS with unused terms masked to 0
PPF computation uses active_mask to only use valid terms
- Parameters:
batched_x (Array) – Stacked quantile values, shape (n_datasets, n_samples).
batched_y (Array) – Stacked probability levels, shape (n_datasets, n_samples).
params (MetalogParameters) – Base MetalogParameters configuration. The num_terms field is ignored and replaced by values from num_terms_list.
l1_penalties (Array) – Array of L1 penalty values to test, shape (n_penalties,).
precomputed_quantiles (bool) – Whether x values are precomputed quantiles.
learning_rate (float) – Learning rate for Lasso gradient descent. Default: 0.01.
num_iters (int) – Maximum iterations for Lasso. Default: 500.
tol (float) – Convergence tolerance for Lasso. Default: 1e-6.
momentum (float) – Momentum for Lasso optimizer. Default: 0.9.
- Returns:
metalog.a: Coefficients of shape (n_datasets, n_penalties, n_terms, max_terms)
ks_dist: KS distances of shape (n_datasets, n_penalties, n_terms)
where max_terms = max(num_terms_list) and coefficients for smaller term counts are zero-padded.
- Return type:
GridResult with batched results
- metalog_jax.grid_search.find_best_config(ks_dists)[source]¶
Find the best configuration (minimum KS distance) in a grid.
Works with grids of any dimensionality (1D, 2D, 3D, etc.).
- Parameters:
ks_dists (Array) – Array of KS distances with shape (d1, d2, …, dn).
- Returns:
best_indices: Indices of the minimum, shape (n,) for n dimensions
best_ks: Minimum KS distance value (scalar)
- Return type:
Tuple of (best_indices, best_ks)
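The underlying idea is dimension-agnostic because a flat argmin can be unraveled back into n-dimensional indices. A minimal sketch of that pattern (illustrative only, not the library's actual implementation):

```python
import jax.numpy as jnp

def find_min_in_grid(ks_dists):
    """Locate the minimum of an n-dimensional grid of KS distances."""
    flat_idx = jnp.argmin(ks_dists)                   # argmin over the flattened grid
    best_indices = jnp.unravel_index(flat_idx, ks_dists.shape)
    return jnp.stack(best_indices), ks_dists.reshape(-1)[flat_idx]

grid = jnp.array([[0.3, 0.1],
                  [0.2, 0.4]])                        # 2D grid of KS distances
idx, best = find_min_in_grid(grid)                    # minimum sits at row 0, column 1
```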
- metalog_jax.grid_search.extract_best_from_grid(grid_results, dataset_idx)[source]¶
Extract the best result from a grid for a specific dataset.
Finds the configuration with minimum KS distance for the given dataset and extracts the corresponding metalog and KS distance.
- Parameters:
grid_results (GridResult) – GridResult from fit_grid_datasets_hyperparams or fit_grid_full. For a 2D grid, ks_dist has shape (n_datasets, n_penalties); for a 3D grid, (n_datasets, n_penalties, n_terms).
dataset_idx (int) – Index of the dataset to extract the best result for.
- Returns:
metalog: Best metalog for this dataset
ks_dist: Corresponding KS distance (scalar)
- Return type:
GridResult containing the best metalog and its KS distance
- metalog_jax.grid_search.extract_metalog(grid_result, *indices)[source]¶
Extract a usable Metalog from grid results at given indices.
After grid search, metalog fields become JAX arrays due to vmap stacking. This function extracts a single metalog at the specified indices and converts all scalar fields back to Python types, making the metalog compatible with JIT-compiled methods like ppf, pdf, cdf, etc.
- Parameters:
grid_result (GridResult) – GridResult from fit_grid or related functions.
*indices (int) – Integer indices to extract. The number of indices should match the dimensionality of the grid:
- 1D grid (e.g., L1 penalties): a single index, like extract_metalog(result, 0)
- 2D grid (e.g., L1 x num_terms): two indices, like extract_metalog(result, 0, 1)
- 3D grid (e.g., datasets x L1 x num_terms): three indices
- Return type:
Metalog
- Returns:
Metalog instance with Python-typed fields, ready to use with ppf, pdf, cdf, rvs, and other distribution methods.
Examples
Extract best metalog from a 1D grid search over L1 penalties:
>>> result = fit_grid(data.x, data.y, params, l1_penalties=l1_vals)
>>> best_idx, best_ks = find_best_config(result.ks_dist)
>>> best_metalog = extract_metalog(result, int(best_idx))
>>> median = best_metalog.ppf(jnp.array([0.5]))
Extract best metalog from a 2D grid search:
>>> result = fit_grid(data.x, data.y, params,
...                   l1_penalties=l1_vals, num_terms=[5, 7, 9])
>>> best_idx, best_ks = find_best_config(result.ks_dist)
>>> best_l1_idx, best_terms_idx = int(best_idx[0]), int(best_idx[1])
>>> best_metalog = extract_metalog(result, best_l1_idx, best_terms_idx)
See also
find_best_config: Find indices of best configuration in a grid. extract_best_from_grid: Extract best result for a specific dataset.
Regression¶
Base Classes¶
Base classes for regression models and parameters.
This module provides the abstract base classes for regression models and their parameters used in metalog distribution fitting.
- Classes:
- RegularizedParameters: Base class for regularized regression hyperparameters.
Subclasses: LassoParameters
- RegressionModel: Base class for trained regression model weights.
Subclasses: OLSModel, LassoModel
See also
metalog_jax.regression.ols: OLS regression implementation. metalog_jax.regression.lasso: LASSO regression implementation.
- class metalog_jax.regression.base.RegularizedParameters[source]¶
Bases: object
Base class for regularized regression hyperparameters.
This abstract base class serves as the parent for all regularized regression parameter configurations in the metalog JAX library. It provides a common type hierarchy for regression methods that incorporate regularization penalties (L1, L2, or both) to prevent overfitting and improve generalization.
Regularization adds penalty terms to the regression objective function to constrain model complexity. This base class enables polymorphic handling of different regularization strategies while maintaining type safety and consistency across the library.
The class uses Flax’s struct.dataclass decorator to ensure immutability and compatibility with JAX transformations (jit, vmap, grad, etc.).
- Subclasses:
- LassoParameters: L1 regularization parameters (LASSO regression).
Defined in metalog_jax.regression.lasso. Penalizes the absolute magnitude of coefficients: λ||w||₁
- Design Pattern:
This class follows the Template Method pattern, providing a common interface while allowing subclasses to define specific regularization configurations. It enables functions to accept any regularized regression parameters through polymorphic type hints.
Example
Using the base class for polymorphic type hints:
>>> from typing import Union
>>> from metalog_jax.regression.base import RegularizedParameters
>>> from metalog_jax.regression.lasso import LassoParameters
>>>
>>> def validate_regularization(params: RegularizedParameters):
...     '''Accept any regularized regression parameters.'''
...     if isinstance(params, LassoParameters):
...         print(f"Lasso with L1={params.lam}")
>>>
>>> lasso_params = LassoParameters(
...     lam=1.0,
...     learning_rate=0.01,
...     num_iters=500,
...     tol=1e-6,
...     momentum=0.9
... )
>>> validate_regularization(lasso_params)  # Valid
Lasso with L1=1.0
Creating subclass instances:
>>> # Lasso regression (L1 only)
>>> from metalog_jax.regression.lasso import LassoParameters
>>> lasso = LassoParameters(
...     lam=0.5,
...     learning_rate=0.01,
...     num_iters=500,
...     tol=1e-6,
...     momentum=0.9
... )
>>> isinstance(lasso, RegularizedParameters)
True
Note
This is an abstract base class with no attributes or methods
Direct instantiation is possible but not meaningful (use subclasses)
All instances are immutable due to struct.dataclass decorator
Subclasses must define their own specific regularization parameters
See also
metalog_jax.regression.lasso.LassoParameters: LASSO (L1) regression parameters. metalog_jax.regression.lasso.fit_lasso: Function that uses LassoParameters.
References
Hastie, T., Tibshirani, R., & Friedman, J. (2009). The Elements of Statistical Learning: Data Mining, Inference, and Prediction (2nd ed.). Springer. Chapter 3: Linear Methods for Regression.
- __init__()¶
- replace(**updates)¶
Returns a new object replacing the specified fields with new values.
Ordinary Least Squares¶
Ordinary Least Squares (OLS) regression implementation.
This module provides OLS regression for fitting metalog distributions when no regularization is needed.
- Classes:
OLSModel: Structure for trained OLS model weights.
- Functions:
fit_ordinary_least_squares: Fit an OLS regression model. predict_ordinary_least_squares: Make predictions using a fitted OLS model.
This module is used when MetalogFitMethod.OLS is specified in MetalogParameters.
See also
metalog_jax.regression.lasso: LASSO (L1 regularization). metalog_jax.metalog.fit: High-level fitting function that dispatches to OLS.
- class metalog_jax.regression.ols.OLSModel(weights)[source]¶
Bases: RegressionModel
Structure for trained Ordinary Least Squares model weights.
- weights¶
Coefficient vector of shape (n_features,).
- bias¶
Intercept term (scalar).
- replace(**updates)¶
Returns a new object replacing the specified fields with new values.
- metalog_jax.regression.ols.fit_ordinary_least_squares(X, y)[source]¶
Fit an Ordinary Least Squares (OLS) regression model.
Computes the closed-form solution to linear regression using the normal equations via least squares. This method finds the optimal weights and bias that minimize the sum of squared residuals with no regularization.
The model solves: min ||y - (Xw + b)||^2
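A hedged sketch of what such a closed-form solve can look like, using jnp.linalg.lstsq on a bias-augmented design matrix (the name ols_sketch is illustrative, not the library's code):

```python
import jax.numpy as jnp

def ols_sketch(X, y):
    """Least-squares fit of y ~ Xw + b via a bias-augmented design matrix."""
    X_aug = jnp.hstack([X, jnp.ones((X.shape[0], 1))])  # append a ones column for the intercept
    coef, *_ = jnp.linalg.lstsq(X_aug, y)               # minimize ||y - X_aug @ coef||^2
    return coef[:-1], coef[-1]                          # split into (weights, bias)

X = jnp.array([[1.0], [2.0], [3.0]])
y = jnp.array([3.0, 5.0, 7.0])   # exactly y = 2x + 1
w, b = ols_sketch(X, y)
```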
- Parameters:
X – Feature matrix of shape (n_samples, n_features).
y – Target vector of shape (n_samples,).
- Return type:
OLSModel
- Returns:
OLSModel containing the fitted weights and bias that minimize the squared error.
Example
>>> X = jnp.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
>>> y = jnp.array([1.0, 2.0, 3.0])
>>> model = fit_ordinary_least_squares(X, y)
- metalog_jax.regression.ols.predict_ordinary_least_squares(X, model)[source]¶
Make predictions using a fitted Regression model.
Computes predictions for new data using the trained weights and bias: y_pred = X @ weights + bias
- Parameters:
X – Feature matrix of shape (n_samples, n_features).
model – Fitted regression model containing weights and bias.
- Returns:
Predictions of shape (n_samples,).
Example
>>> X_test = jnp.array([[2.0, 3.0], [4.0, 5.0]])
>>> predictions = predict_ordinary_least_squares(X_test, model)
LASSO¶
LASSO regression implementation with L1 regularization.
This module provides LASSO (Least Absolute Shrinkage and Selection Operator) regression for fitting metalog distributions with L1 regularization using proximal gradient descent with Nesterov acceleration.
- Classes:
LassoParameters: Hyperparameters for LASSO regression. LassoModel: Structure for trained LASSO model weights.
- Functions:
fit_lasso: Fit a LASSO regression model via proximal gradient descent. soft_thresholding: Apply the soft-thresholding operator (proximal for L1).
- Constants:
DEFAULT_LASSO_LAMBDA: Default L1 regularization strength (0). DEFAULT_LASSO_LEARNING_RATE: Default learning rate (0.01). DEFAULT_LASSO_ITERATIONS: Default maximum iterations (500). DEFAULT_LASSO_TOLERANCE: Default convergence tolerance (1e-6). DEFAULT_LASSO_MOMENTUM: Default Nesterov momentum factor (0.9). DEFAULT_LASSO_PARAMETERS: Default LassoParameters instance.
This module is used when MetalogFitMethod.Lasso is specified in MetalogParameters.
See also
metalog_jax.regression.ols: OLS regression (no regularization). metalog_jax.metalog.fit: High-level fitting function that dispatches to LASSO.
- class metalog_jax.regression.lasso.LassoParameters(lam, learning_rate, num_iters, tol, momentum)[source]¶
Bases: RegularizedParameters
Structure for Lasso parameters.
- lam¶
L1 regularization strength λ ≥ 0. Controls sparsity of the solution:
- λ = 0: No regularization (equivalent to OLS, but solved iteratively)
- λ > 0: Promotes sparsity; larger values → more coefficients set to zero
- Typical values: 0.001 to 10.0, depending on data scale
- learning_rate¶
Learning rate (step size) for gradient descent. Default: 0.01. Controls the size of parameter updates and should be tuned per problem:
- Too large: May cause divergence or oscillation
- Too small: Slow convergence
- Typical values: 0.001 to 0.1
- num_iters¶
Maximum number of iterations. Default: 500. Training will stop early if convergence is reached before this limit.
- tol¶
Convergence tolerance for weight changes. Default: 1e-6. Training stops when ||w_new - w|| < tol, indicating the solution has stabilized. Smaller values require more precise convergence but may take longer.
- momentum¶
Nesterov momentum factor, must satisfy 0 < momentum < 1. Default: 0.9. Controls acceleration of convergence:
- Higher values (e.g., 0.9, 0.99): Faster convergence but less stable
- Lower values (e.g., 0.5, 0.7): More stable but slower convergence
- Standard choice: 0.9
- __init__(lam, learning_rate, num_iters, tol, momentum)¶
- replace(**updates)¶
Returns a new object replacing the specified fields with new values.
- class metalog_jax.regression.lasso.LassoModel(weights)[source]¶
Bases: RegressionModel
Structure for trained LASSO Regression model weights.
- weights¶
Coefficient vector of shape (n_features,).
- bias¶
Intercept term (scalar).
- replace(**updates)¶
Returns a new object replacing the specified fields with new values.
- metalog_jax.regression.lasso.soft_thresholding(x, lam)[source]¶
Apply the soft-thresholding operator (proximal operator for L1 norm).
The soft-thresholding operator is the proximal operator for the L1 norm and is fundamental to LASSO regression and other sparse optimization methods. It shrinks values toward zero and sets values below the threshold to exactly zero, promoting sparsity in the solution.
- The operator is defined element-wise as:
soft_threshold(x, λ) = sign(x) * max(|x| - λ, 0)
- Equivalently:
If x > λ: return x - λ
If x < -λ: return x + λ
If |x| ≤ λ: return 0
- This function is the proximal operator for the L1 penalty:
prox_{λ||·||₁}(x) = argmin_z { (1/2)||z - x||² + λ||z||₁ }
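The element-wise formula above translates almost directly into jax.numpy; a minimal standalone sketch (not the library's JIT-compiled version):

```python
import jax.numpy as jnp

def soft_threshold_sketch(x, lam):
    """sign(x) * max(|x| - lam, 0), applied element-wise."""
    return jnp.sign(x) * jnp.maximum(jnp.abs(x) - lam, 0.0)
```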
- Parameters:
x (Union[Array, ndarray, bool, number, float, int]) – Input value(s) to threshold. Can be a scalar or an array of any shape.
lam (float) – Threshold parameter λ ≥ 0. Controls the amount of shrinkage:
- λ = 0: No shrinkage (returns x unchanged)
- λ > 0: Shrinks values toward zero, setting small values to 0
- Larger λ produces sparser solutions with more zeros
- Return type:
Array
- Returns:
Thresholded value(s) with the same shape as x. Values are shrunk toward zero by amount λ, with values |x| ≤ λ set to exactly zero.
Example
Scalar inputs:
>>> import jax.numpy as jnp
>>> soft_thresholding(5.0, 2.0)
Array(3.0, dtype=float32)
>>> soft_thresholding(-3.0, 1.0)
Array(-2.0, dtype=float32)
>>> soft_thresholding(1.5, 2.0)  # Below threshold -> zero
Array(0.0, dtype=float32)
Array inputs:
>>> x = jnp.array([-5.0, -1.0, 0.5, 2.0, 4.0])
>>> soft_thresholding(x, 1.5)
Array([-3.5, 0. , 0. , 0.5, 2.5], dtype=float32)
Note
This function is JIT-compiled for efficient execution
The operation is element-wise and preserves the input shape
Setting lam=0 returns the input unchanged
This is also known as the “shrinkage operator” or “soft thresholding”
See also
fit_lasso: LASSO regression that uses this operator in proximal gradient descent.
References
Parikh, N., & Boyd, S. (2014). Proximal Algorithms. Foundations and Trends in Optimization, 1(3), 127-239.
Beck, A., & Teboulle, M. (2009). A Fast Iterative Shrinkage-Thresholding Algorithm for Linear Inverse Problems. SIAM Journal on Imaging Sciences, 2(1), 183-202.
- metalog_jax.regression.lasso.fit_lasso(X, y, params=LassoParameters(lam=0.0, learning_rate=0.01, num_iters=500, tol=1e-06, momentum=0.9))[source]¶
Fit a LASSO regression model using proximal gradient descent with Nesterov acceleration.
Trains a linear regression model with L1 regularization (LASSO - Least Absolute Shrinkage and Selection Operator) using an accelerated proximal gradient method. LASSO promotes sparse solutions by shrinking coefficients toward zero and setting many coefficients to exactly zero, making it useful for feature selection and interpretable models.
This implementation uses:
- Proximal gradient descent: Combines gradient steps on the smooth loss (MSE) with the proximal operator (soft-thresholding) for the non-smooth L1 penalty
- Nesterov momentum: Accelerates convergence using momentum-based lookahead steps
- Early stopping: Terminates when weight changes fall below the tolerance threshold
- JAX compatibility: Fully vectorized and compatible with JAX transformations
- The LASSO objective function is:
min_w { (1/2n)||y - Xw||² + λ||w||₁ }
where λ is the L1 regularization strength that controls sparsity.
- Algorithm:
1. Initialize weights w and velocity v to zero
2. For each iteration (until convergence or max iterations):
a. Compute lookahead weights: w_lookahead = w + momentum * v
b. Compute the gradient at the lookahead position
c. Update the velocity with momentum and gradient: v = momentum * v - learning_rate * grad
d. Apply the proximal operator: w = soft_threshold(w + v, learning_rate * λ)
e. Check convergence: stop if ||w_new - w|| < tol
3. Return the final weights
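The algorithm can be sketched as a plain Python loop (an illustrative reimplementation, not the library's vectorized fit_lasso; the name lasso_sketch and its defaults are assumptions for the example):

```python
import jax.numpy as jnp

def soft_threshold(x, lam):
    return jnp.sign(x) * jnp.maximum(jnp.abs(x) - lam, 0.0)

def lasso_sketch(X, y, lam=0.1, lr=0.01, num_iters=500, tol=1e-6, momentum=0.9):
    """Proximal gradient descent with Nesterov momentum for the LASSO objective."""
    n = X.shape[0]
    w = jnp.zeros(X.shape[1])
    v = jnp.zeros_like(w)
    for _ in range(num_iters):
        w_look = w + momentum * v                  # a. lookahead weights
        grad = X.T @ (X @ w_look - y) / n          # b. gradient of (1/2n)||y - Xw||^2
        v = momentum * v - lr * grad               # c. velocity update
        w_new = soft_threshold(w + v, lr * lam)    # d. proximal (soft-threshold) step
        done = jnp.linalg.norm(w_new - w) < tol    # e. convergence check
        w = w_new
        if done:
            break
    return w

X = jnp.array([[1.0], [2.0], [3.0]])
y = jnp.array([2.0, 4.0, 6.0])       # y = 2x, so with lam=0 the fit approaches w = 2
w = lasso_sketch(X, y, lam=0.0)
```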
- Parameters:
X (Union[Array, ndarray, bool, number, float, int]) – Feature matrix of shape (n_samples, n_features). The design matrix containing the independent variables for regression.
y (Union[Array, ndarray, bool, number, float, int]) – Target vector of shape (n_samples,). The dependent variable values to predict.
params (LassoParameters) – LassoParameters containing optimization hyperparameters. Defaults to DEFAULT_LASSO_PARAMETERS. The object contains:
- lam: L1 regularization strength λ ≥ 0. Default: 0.
- learning_rate (float): Step size for gradient descent. Default: 0.01.
- num_iters (int): Maximum number of iterations. Default: 500.
- tol (float): Convergence tolerance for weight changes. Default: 1e-6.
- momentum (float): Nesterov momentum factor (0 < momentum < 1). Default: 0.9.
- Return type:
LassoModel
- Returns:
LassoModel containing the fitted weights vector of shape (n_features,). Many coefficients may be exactly zero due to L1 regularization, achieving feature selection and model sparsity.
Example
Basic LASSO regression with default parameters:
>>> import jax.numpy as jnp
>>> from metalog_jax.regression.lasso import fit_lasso
>>>
>>> # Create sample data
>>> X = jnp.array([[1.0, 2.0, 0.5],
...                [3.0, 4.0, 1.5],
...                [5.0, 6.0, 2.5]])
>>> y = jnp.array([1.0, 2.0, 3.0])
>>>
>>> # Fit LASSO with defaults
>>> model = fit_lasso(X, y)
>>> print(model.weights)  # Some coefficients may be exactly zero
Note
Unlike Ridge regression, LASSO does not have a closed-form solution and requires iterative optimization
The algorithm may converge before num_iters if tol is reached
Feature scaling (standardization) is recommended before fitting LASSO
See also
soft_thresholding: The proximal operator used in each iteration. fit_ordinary_least_squares: No regularization (from regression.ols). LassoModel: Return type containing fitted weights. LassoParameters: Configuration dataclass for LASSO hyperparameters.
References
Tibshirani, R. (1996). Regression Shrinkage and Selection via the Lasso. Journal of the Royal Statistical Society: Series B (Methodological), 58(1), 267-288.
Beck, A., & Teboulle, M. (2009). A Fast Iterative Shrinkage-Thresholding Algorithm for Linear Inverse Problems. SIAM Journal on Imaging Sciences, 2(1), 183-202.