API Reference

This section contains the complete API reference for metalog-jax.

metalog_jax

Main Module

Metalog JAX library.

Metalog Distribution

Metalog distribution implementations and fitting functions.

This module provides the main metalog distribution classes and fitting functions:

Classes:

Metalog: Standard metalog distribution fitted via regression methods.

SPTMetalog: Symmetric Percentile Triplet metalog with closed-form coefficients.

GridResult: Results container for grid search and hyperparameter optimization.

Functions:

fit: Fit a standard metalog distribution using configurable regression methods.

fit_spt_metalog: Fit an SPT metalog distribution using closed-form formulas.

The fit function uses a dispatch table pattern to route to the appropriate regression implementation based on the MetalogFitMethod specified in MetalogParameters:

  • MetalogFitMethod.OLS -> metalog_jax.regression.ols.fit_ordinary_least_squares

  • MetalogFitMethod.Lasso -> metalog_jax.regression.lasso.fit_lasso
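The dispatch-table pattern might be sketched as follows; every name in this snippet is an illustrative stand-in, not the library's actual internals:

```python
# Hypothetical sketch of an enum-keyed dispatch table: each fit method maps
# to the regression routine that implements it. Illustration only.
from enum import Enum

class MetalogFitMethod(Enum):
    OLS = "ols"
    Lasso = "lasso"

def fit_ordinary_least_squares(X, y):
    return "ols-fit"      # placeholder for the real OLS solver

def fit_lasso(X, y):
    return "lasso-fit"    # placeholder for the real LASSO solver

_DISPATCH = {
    MetalogFitMethod.OLS: fit_ordinary_least_squares,
    MetalogFitMethod.Lasso: fit_lasso,
}

def dispatch_fit(method, X, y):
    try:
        return _DISPATCH[method](X, y)
    except KeyError:
        raise TypeError(f"Unsupported fit method: {method!r}")
```

An unrecognized method falls through to a TypeError, matching the error behavior documented for fit below.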

See also

metalog_jax.base: Base classes and parameter configurations.

metalog_jax.regression: Regression implementations for fitting.

class metalog_jax.metalog.SPTMetalog(metalog_params, a)[source]

Bases: MetalogBase

Symmetric Percentile Triplet (SPT) metalog distribution.

SPTMetalog is a specialized metalog implementation that uses exactly three terms fitted from three quantiles: the alpha-quantile, median (0.5), and (1-alpha)-quantile, where alpha < 0.5. This method provides a computationally efficient way to fit metalog distributions with closed-form coefficient solutions, avoiding the need for optimization or linear regression.

The SPT approach is particularly useful when you need:

  • A quick approximation using only three data points

  • Guaranteed feasibility through explicit feasibility checks

  • Analytical coefficient formulas for different boundedness types

  • A distribution that matches specific symmetric or asymmetric quantile triplets

Unlike the standard Metalog class, which uses OLS or LASSO regression to fit an arbitrary number of terms, SPTMetalog always has exactly 3 terms and computes its coefficients directly from the quantile triplet. This makes it faster for small datasets but less flexible than the full metalog approach.

metalog_params

Configuration parameters specific to SPT fitting, including boundedness type, boundary values, and the alpha parameter that determines which quantiles are used (alpha, 0.5, 1-alpha).
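As an illustration (a numpy stand-in, not the library API), with alpha = 0.1 an SPT fit reads the 10th, 50th, and 90th empirical percentiles from the data:

```python
# Which quantile triplet alpha = 0.1 selects -- illustration only.
import numpy as np

data = np.array([1.2, 2.3, 2.8, 3.5, 4.1, 5.6, 6.2, 7.8, 9.1])
alpha = 0.1

q_alpha, median, q_complement = np.quantile(data, [alpha, 0.5, 1.0 - alpha])
# A feasible triplet must be strictly increasing: q_alpha < median < q_complement.
```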

a

Coefficient vector of shape (3,) representing the three fitted metalog distribution parameters. These are computed using closed-form formulas specific to the boundedness type.

Examples

Fit an unbounded SPT metalog from data:

>>> import jax.numpy as jnp
>>> from metalog_jax import fit_spt_metalog, SPTMetalogParameters
>>> from metalog_jax import MetalogBoundedness
>>>
>>> # Generate sample data
>>> data = jnp.array([1.2, 2.3, 2.8, 3.5, 4.1, 5.6, 6.2, 7.8, 9.1])
>>>
>>> # Configure SPT parameters
>>> params = SPTMetalogParameters(
...     boundedness=MetalogBoundedness.UNBOUNDED,
...     alpha=0.1,  # Use 10th, 50th, and 90th percentiles
...     lower_bound=0.0,
...     upper_bound=0.0,
... )
>>>
>>> # Fit the SPT metalog
>>> spt = fit_spt_metalog(data, params)
>>>
>>> # Evaluate quantile function at probability 0.25
>>> q_25 = spt.ppf(0.25)
>>>
>>> # Compute mean and standard deviation
>>> mean = spt.mean()
>>> std = spt.std()

Note

SPT metalog distributions may fail feasibility checks if the data does not satisfy specific constraints on the symmetry ratio r. The feasibility constraints depend on the boundedness type and alpha parameter. When feasibility fails, fit_spt_metalog will raise an AssertionError.

The SPT method is described in Keelin (2016) as a special case of the metalog family that enables rapid approximation with minimal data.

See also

Metalog: Standard metalog implementation using regression fitting.

MetalogBase: Base class defining the common interface.

fit_spt_metalog: Function to fit SPT metalog distributions from data.

References

Keelin, T. W. (2016). The Metalog Distributions. Decision Analysis, 13(4), 243-277. https://doi.org/10.1287/deca.2016.0338

Parameters:
metalog_params: SPTMetalogParameters
a: Array | ndarray | bool | number | float | int
__eq__(other)[source]

Check equality between two SPTMetalog distribution instances.

Delegates to the parent MetalogBase.__eq__ method.

Parameters:

other (object) – The object to compare against.

Returns:

True if the two instances are equal, False otherwise.

Return type:

bool

property num_terms: int

Get the number of terms in the metalog expansion.

Returns:

3 - an SPT Metalog always has three terms.

__init__(metalog_params, a)
Parameters:
replace(**updates)

Returns a new object replacing the specified fields with new values.

class metalog_jax.metalog.Metalog(metalog_params, a)[source]

Bases: MetalogBase

Standard metalog distribution fitted via regression.

Metalog is the full-featured metalog implementation that extends MetalogBase and uses ordinary least squares (OLS) or LASSO regression to fit metalog distributions with arbitrary numbers of terms (typically 2-30). This approach provides maximum flexibility and accuracy for modeling complex distributions from larger datasets.

Unlike SPTMetalog which uses exactly 3 terms with closed-form coefficient formulas, the standard Metalog class:

  • Supports any number of terms (configurable via num_terms parameter)

  • Uses regression-based fitting (OLS or LASSO methods)

  • Can model more complex distribution shapes with higher-order terms

  • Requires more data for reliable fitting (recommended minimum: 3 x num_terms)

  • May encounter numerical issues with very high term counts

This class inherits all distribution methods from MetalogBase including ppf (quantile function), cdf, pdf, and statistical moments (mean, variance, std, skewness, kurtosis). It also provides serialization capabilities for saving and loading fitted distributions.

metalog_params

Configuration parameters for the standard metalog fit, including boundedness type (UNBOUNDED, STRICTLY_LOWER_BOUND, STRICTLY_UPPER_BOUND, or BOUNDED), boundary values, fitting method (OLS or LASSO), and number of terms in the expansion.

a

Coefficient vector of shape (num_terms,) representing the fitted metalog distribution parameters. Each coefficient corresponds to one basis function in the metalog expansion. These are computed via regression from the input data.

Examples

Fit an unbounded metalog with 9 terms using OLS:

>>> import jax.numpy as jnp
>>> from metalog_jax import fit_metalog, MetalogParameters
>>> from metalog_jax import MetalogBoundedness, MetalogFitMethod
>>>
>>> # Generate sample data
>>> data = jnp.array([1.2, 2.3, 2.8, 3.5, 4.1, 5.6, 6.2, 7.8, 9.1, 10.5])
>>>
>>> # Configure metalog parameters
>>> params = MetalogParameters(
...     boundedness=MetalogBoundedness.UNBOUNDED,
...     method=MetalogFitMethod.OLS,
...     num_terms=9,
...     lower_bound=0.0,
...     upper_bound=0.0,
... )
>>>
>>> # Fit the metalog distribution
>>> m = fit_metalog(data, params)
>>>
>>> # Evaluate quantile function at probability 0.75
>>> q_75 = m.ppf(0.75)
>>>
>>> # Compute distribution moments
>>> mean = m.mean()
>>> variance = m.variance()
>>> skewness = m.skewness()

Note

The standard Metalog class is the recommended choice for most use cases involving sufficient data. Use SPTMetalog only when you need a quick three-term approximation or when you have exactly three representative quantiles. For datasets with at least 10-30 observations, the standard metalog typically provides superior fit quality.

Higher term counts generally improve accuracy but require more data and may cause overfitting or numerical instability. As a rule of thumb, ensure your sample size is at least 3 times the number of terms.
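The 3x rule of thumb can be expressed as a tiny helper; `max_feasible_terms` is hypothetical, not part of metalog-jax:

```python
# Hypothetical helper expressing the rule of thumb above: the largest term
# count a sample of size n comfortably supports under n >= factor * num_terms.
def max_feasible_terms(n_samples: int, factor: int = 3) -> int:
    # A metalog expansion needs at least 2 terms.
    return max(2, n_samples // factor)
```

Under this rule, 10 observations support at most 3 terms and 30 observations support 10.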

See also

SPTMetalog: Specialized three-term metalog using closed-form coefficients.

MetalogBase: Base class defining the common distribution interface.

fit_metalog: Function to fit standard metalog distributions from data.

References

Keelin, T. W. (2016). The Metalog Distributions. Decision Analysis, 13(4), 243-277. https://doi.org/10.1287/deca.2016.0338

Parameters:
metalog_params: MetalogParameters
a: Array | ndarray | bool | number | float | int
__eq__(other)[source]

Check equality between two Metalog distribution instances.

Delegates to the parent MetalogBase.__eq__ method.

Parameters:

other (object) – The object to compare against.

Returns:

True if the two instances are equal, False otherwise.

Return type:

bool

property num_terms: int

Get the number of terms in the metalog expansion.

Returns:

Integer representing the number of basis functions used in the metalog approximation. Higher values generally provide better accuracy but require more data.

property method: MetalogFitMethod

Get the method used to fit the metalog distribution.

Returns:

MetalogFitMethod enum of regression methods for fitting metalog distribution.

__init__(metalog_params, a)
Parameters:
replace(**updates)

Returns a new object replacing the specified fields with new values.

metalog_jax.metalog.fit(data, metalog_params, regression_hyperparams=None)[source]

Fit a metalog distribution to data using the specified configuration.

Estimates the metalog distribution coefficients by solving a linear regression problem that maps the basis functions (target matrix) to the transformed quantiles. The regression method specified in metalog_params determines the fitting approach, and optional regression_hyperparams allow fine-tuning of regularization settings.

This function processes a MetalogInputData instance containing validated input data (quantiles and probability levels) and fits a metalog distribution according to the specified parameters. The data must be created using MetalogInputData.from_values() which performs comprehensive validation.

Parameters:
  • data (MetalogInputData) – Validated input data created via MetalogInputData.from_values(). Contains:

      • x: Quantile values (precomputed or computed from raw samples)

      • y: Probability levels in (0, 1)

      • precomputed_quantiles: Flag indicating data type

    Direct instantiation of MetalogInputData is prevented by __post_init__ validation to ensure data integrity.

  • metalog_params (MetalogParameters) – Configuration parameters specifying the distribution characteristics (boundedness, bounds, number of terms) and the regression method to use for fitting. The method field determines which regression approach is used:

      • MetalogFitMethod.OLS: Ordinary Least Squares (no regularization)

      • MetalogFitMethod.Lasso: L1 regularization

  • regression_hyperparams (RegularizedParameters) – Optional regularization hyperparameters for controlling the fitting process when using LASSO regression method. This parameter is ignored when method=OLS. If None, default hyperparameters are used. Must be an instance of LassoParameters for MetalogFitMethod.Lasso.

Return type:

Metalog

Returns:

Metalog containing the fitted coefficient vector and the configuration parameters used to fit the metalog distribution. The returned object is validated via assert_fitted() to ensure the fit produces a feasible distribution with strictly positive PDF values.

Raises:
  • TypeError – If metalog_params.method is not a valid MetalogFitMethod instance.

  • checkify.JaxRuntimeError – If the fitted distribution produces non-positive PDF values, indicating an infeasible fit. This validation is performed by assert_fitted() before returning.

Example

Basic usage with OLS (no regularization):

>>> import jax.numpy as jnp
>>> from metalog_jax.base import MetalogInputData, MetalogParameters
>>> from metalog_jax.base import MetalogBoundedness, MetalogFitMethod
>>> from metalog_jax.metalog import fit
>>>
>>> # Create validated input data using from_values()
>>> raw_data = jnp.array([1.2, 2.3, 2.8, 3.5, 4.1, 5.6])
>>> data = MetalogInputData.from_values(
...     x=raw_data,
...     y=jnp.array([0.1, 0.25, 0.5, 0.75, 0.9, 0.95]),
...     precomputed_quantiles=False  # Raw samples, not precomputed quantiles
... )
>>>
>>> # Configure metalog parameters with OLS
>>> metalog_params = MetalogParameters(
...     boundedness=MetalogBoundedness.UNBOUNDED,
...     method=MetalogFitMethod.OLS,
...     lower_bound=0.0,
...     upper_bound=0.0,
...     num_terms=5
... )
>>>
>>> # Fit the metalog distribution
>>> m = fit(data, metalog_params)

Note

  • The regression_hyperparams parameter is only applicable for LASSO method. It is ignored when using OLS.

  • If regression_hyperparams is None, sensible defaults are used for each method.

  • For production use, consider tuning regularization hyperparameters via cross-validation to prevent overfitting while maintaining good fit quality.

See also

fit_spt_metalog: Alternative fitting method using Symmetric Percentile Triplet.

MetalogInputData.from_values: Required method for creating validated input data.

metalog_jax.regression.lasso.LassoParameters: Hyperparameters for LASSO regression.

metalog_jax.metalog.fit_spt_metalog(array, spt_metalog_params)[source]

Fit a Symmetric Percentile Triplet (SPT) metalog distribution to data.

This function fits a 3-term metalog distribution using the Symmetric Percentile Triplet (SPT) parameterization method described in Keelin (2016). Unlike the standard metalog fit which uses many quantiles, the SPT method uses exactly three symmetric quantiles: alpha, 0.5 (median), and (1 - alpha), where 0 < alpha < 0.5.

The SPT approach provides a quick, analytical solution for fitting a metalog distribution when:

  • You have limited data or only three quantiles available

  • You want a simple, closed-form solution without regression

  • You need a fast approximation with 3 terms only

The method computes the 3 metalog coefficients directly from the three quantile values without requiring least squares regression. The formulas vary based on the boundedness type (unbounded, semi-bounded, or bounded).

Parameters:
  • array (Union[Array, ndarray, bool, number, float, int]) – Input data array with at least 3 elements from which to compute empirical quantiles. The array will be used to compute the alpha-th, 50th, and (1-alpha)-th percentiles.

  • spt_metalog_params (SPTMetalogParameters) – Configuration parameters containing:

      • alpha: Lower percentile parameter in (0, 0.5). Common values are 0.1 (10-50-90 percentiles) or 0.25 (25-50-75 percentiles/IQR).

      • boundedness: Domain constraint type (UNBOUNDED, STRICTLY_LOWER_BOUND, or BOUNDED). STRICTLY_UPPER_BOUND is not supported in the SPT formulation.

      • lower_bound: Lower boundary value (used for semi-bounded/bounded distributions).

      • upper_bound: Upper boundary value (used for bounded distributions).

Returns:

Fitted SPT metalog distribution containing:
  • metalog_params: The input configuration parameters

  • a: Coefficient vector of length 3 containing the fitted metalog parameters [a1, a2, a3] that define the quantile function

Return type:

SPTMetalog

Raises:
  • AssertionError – If array has fewer than 3 elements, is not rank 1, or contains non-numeric values.

  • AssertionError – If alpha is not positive or alpha >= 0.5.

  • AssertionError – If the computed quantiles violate feasibility constraints (q_alpha < median < q_complement).

  • AssertionError – If boundedness-specific feasibility checks fail (e.g., quantiles outside valid range for the given bounds).

  • NotImplementedError – If boundedness is STRICTLY_UPPER_BOUND, which is undefined in the SPT formulation.

Example

>>> import jax.numpy as jnp
>>> from metalog_jax.metalog import fit_spt_metalog, SPTMetalogParameters
>>> from metalog_jax.metalog import MetalogBoundedness
>>> # Generate sample data
>>> data = jnp.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0])
>>> # Configure SPT parameters using 10-50-90 percentiles
>>> params = SPTMetalogParameters(
...     alpha=0.1,
...     boundedness=MetalogBoundedness.UNBOUNDED,
...     lower_bound=0.0,
...     upper_bound=1.0
... )
>>> # Fit the SPT metalog
>>> spt_metalog = fit_spt_metalog(data, params)
>>> # The result contains 3 coefficients
>>> assert len(spt_metalog.a) == 3

Note

  • The SPT method always produces exactly 3 metalog terms, regardless of data size. For more flexible fits with more terms, use the standard fit() function.

  • STRICTLY_UPPER_BOUND boundedness is not supported by the SPT formulation as defined in Keelin’s paper.

  • The method performs multiple feasibility checks to ensure the quantiles and parameters produce a valid probability distribution.

  • For unbounded distributions, the median position ratio r must satisfy specific constraints related to k_alpha to ensure feasibility, where r = (median - q_alpha) / (q_complement - q_alpha).
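The ratio in the last bullet can be computed directly from an empirical triplet; a numpy illustration (the exact k_alpha feasibility bounds are not reproduced here):

```python
# Median position ratio r = (median - q_alpha) / (q_complement - q_alpha),
# computed from a 10-50-90 triplet. Illustration only.
import numpy as np

data = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0])
alpha = 0.1

q_alpha, median, q_complement = np.quantile(data, [alpha, 0.5, 1.0 - alpha])
r = (median - q_alpha) / (q_complement - q_alpha)
# r lies in (0, 1); r = 0.5 means the median sits midway between the outer quantiles.
```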

References

Keelin, T. W. (2016). The Metalog Distributions. Decision Analysis, 13(4), 243-277. https://doi.org/10.1287/deca.2016.0338

class metalog_jax.metalog.GridResult(metalog, ks_dist)[source]

Bases: object

Results from metalog grid search or hyperparameter optimization.

GridResult encapsulates the outcome of fitting a metalog distribution along with a goodness-of-fit metric. This dataclass is typically used when performing grid searches over hyperparameters (e.g., number of terms, regularization penalties) or when comparing multiple metalog configurations to select the best-fitting distribution.

The primary use case is hyperparameter tuning via grid search with JAX’s vmap to fit multiple metalog configurations in parallel, then selecting the best configuration based on the Kolmogorov-Smirnov distance.

The class uses Flax’s struct.dataclass decorator to ensure immutability and compatibility with JAX transformations (jit, vmap, grad), making it suitable for use in vectorized grid search operations.

metalog

The fitted metalog distribution. Can be either a standard Metalog (fitted via regression with configurable terms) or SPTMetalog (fitted using Symmetric Percentile Triplet with exactly 3 terms). Contains:

  • metalog_params: Configuration parameters (boundedness, method, etc.)

  • a: Coefficient vector representing the fitted distribution

ks_dist

Scalar Kolmogorov-Smirnov distance measuring the goodness-of-fit between the fitted metalog distribution and the input random variable. This is the maximum absolute difference between the empirical CDF of the input data and the fitted metalog’s CDF:

  • 0: Perfect fit (empirical CDFs are identical)

  • 1: Worst possible fit (completely non-overlapping distributions)

  • Lower values indicate better fit quality

Used to select the best metalog configuration from a grid search.
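A minimal sketch of this metric, with `fitted_cdf` as a hypothetical callable standing in for the fitted metalog's cdf method (the library's own ks_distance may differ in detail, e.g. in how the empirical CDF steps are evaluated):

```python
# Maximum absolute gap between the empirical CDF of the samples and a
# fitted CDF, evaluated at the sorted sample points. Illustration only.
import numpy as np

def ks_distance(samples, fitted_cdf):
    x = np.sort(np.asarray(samples))
    n = len(x)
    ecdf = np.arange(1, n + 1) / n          # empirical CDF at the sorted points
    return float(np.max(np.abs(ecdf - fitted_cdf(x))))
```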

Note

  • This class is immutable due to the struct.dataclass decorator

  • The ks_dist attribute provides a single goodness-of-fit metric for comparing different metalog configurations

  • When used with jax.vmap, the metalog and ks_dist fields will be batched arrays/PyTrees, allowing vectorized comparison

  • Lower KS distances indicate better fit quality, but extremely low values may indicate overfitting (especially with high term counts)

  • For production use, consider cross-validation instead of or in addition to KS distance for hyperparameter selection

See also

ks_distance: Function to compute the Kolmogorov-Smirnov distance.

fit: Standard metalog fitting function.

fit_spt_metalog: SPT metalog fitting function.

Metalog: Standard metalog distribution class.

SPTMetalog: Symmetric Percentile Triplet metalog class.

References

Kolmogorov, A. N. (1933). “Sulla determinazione empirica di una legge di distribuzione”. Giornale dell’Istituto Italiano degli Attuari, 4: 83-91.

Keelin, T. W. (2016). The Metalog Distributions. Decision Analysis, 13(4), 243-277. https://doi.org/10.1287/deca.2016.0338

Parameters:
metalog: Metalog | SPTMetalog
ks_dist: float | int
__init__(metalog, ks_dist)
Parameters:
replace(**updates)

Returns a new object replacing the specified fields with new values.

Regression

Base Classes

Base classes for regression models and parameters.

This module provides the abstract base classes for regression models and their parameters used in metalog distribution fitting.

Classes:
RegularizedParameters: Base class for regularized regression hyperparameters.

Subclasses: LassoParameters

RegressionModel: Base class for trained regression model weights.

Subclasses: OLSModel, LassoModel

See also

metalog_jax.regression.ols: OLS regression implementation.

metalog_jax.regression.lasso: LASSO regression implementation.

class metalog_jax.regression.base.RegularizedParameters[source]

Bases: object

Base class for regularized regression hyperparameters.

This abstract base class serves as the parent for all regularized regression parameter configurations in the metalog JAX library. It provides a common type hierarchy for regression methods that incorporate regularization penalties (L1, L2, or both) to prevent overfitting and improve generalization.

Regularization adds penalty terms to the regression objective function to constrain model complexity. This base class enables polymorphic handling of different regularization strategies while maintaining type safety and consistency across the library.

The class uses Flax’s struct.dataclass decorator to ensure immutability and compatibility with JAX transformations (jit, vmap, grad, etc.).

Subclasses:
LassoParameters: L1 regularization parameters (LASSO regression). Defined in metalog_jax.regression.lasso. Penalizes the absolute magnitude of coefficients: λ||w||₁.

Design Pattern:

This class follows the Template Method pattern, providing a common interface while allowing subclasses to define specific regularization configurations. It enables functions to accept any regularized regression parameters through polymorphic type hints.

Example

Using the base class for polymorphic type hints:

>>> from typing import Union
>>> from metalog_jax.regression.base import RegularizedParameters
>>> from metalog_jax.regression.lasso import LassoParameters
>>>
>>> def validate_regularization(params: RegularizedParameters):
...     '''Accept any regularized regression parameters.'''
...     if isinstance(params, LassoParameters):
...         print(f"Lasso with L1={params.lam}")
>>>
>>> lasso_params = LassoParameters(
...     lam=1.0,
...     learning_rate=0.01,
...     num_iters=500,
...     tol=1e-6,
...     momentum=0.9
... )
>>> validate_regularization(lasso_params)  # Valid
Lasso with L1=1.0

Creating subclass instances:

>>> # Lasso regression (L1 only)
>>> from metalog_jax.regression.lasso import LassoParameters
>>> lasso = LassoParameters(
...     lam=0.5,
...     learning_rate=0.01,
...     num_iters=500,
...     tol=1e-6,
...     momentum=0.9
... )
>>> isinstance(lasso, RegularizedParameters)
True

Note

  • This is an abstract base class with no attributes or methods

  • Direct instantiation is possible but not meaningful (use subclasses)

  • All instances are immutable due to struct.dataclass decorator

  • Subclasses must define their own specific regularization parameters

See also

metalog_jax.regression.lasso.LassoParameters: LASSO (L1) regression parameters.

metalog_jax.regression.lasso.fit_lasso: Function that uses LassoParameters.

References

Hastie, T., Tibshirani, R., & Friedman, J. (2009). The Elements of Statistical Learning: Data Mining, Inference, and Prediction (2nd ed.). Springer. Chapter 3: Linear Methods for Regression.

__init__()
replace(**updates)

Returns a new object replacing the specified fields with new values.

class metalog_jax.regression.base.RegressionModel(weights)[source]

Bases: object

Structure for trained Regression model weights.

weights

Coefficient vector of shape (n_features,).

Parameters:

weights (Union[Array, ndarray, bool, number])

weights: Array | ndarray | bool | number
__init__(weights)
Parameters:

weights (Union[Array, ndarray, bool, number])

replace(**updates)

Returns a new object replacing the specified fields with new values.

Ordinary Least Squares

Ordinary Least Squares (OLS) regression implementation.

This module provides OLS regression for fitting metalog distributions when no regularization is needed.

Classes:

OLSModel: Structure for trained OLS model weights.

Functions:

fit_ordinary_least_squares: Fit an OLS regression model.

predict_ordinary_least_squares: Make predictions using a fitted OLS model.

This module is used when MetalogFitMethod.OLS is specified in MetalogParameters.

See also

metalog_jax.regression.lasso: LASSO (L1 regularization).

metalog_jax.metalog.fit: High-level fitting function that dispatches to OLS.

class metalog_jax.regression.ols.OLSModel(weights)[source]

Bases: RegressionModel

Structure for trained Ordinary Least Squares model weights.

weights

Coefficient vector of shape (n_features,).

bias

Intercept term (scalar).

Parameters:

weights (Union[Array, ndarray, bool, number])

__init__(weights)
Parameters:

weights (Union[Array, ndarray, bool, number])

replace(**updates)

Returns a new object replacing the specified fields with new values.

metalog_jax.regression.ols.fit_ordinary_least_squares(X, y)[source]

Fit an Ordinary Least Squares (OLS) regression model.

Computes the closed-form least-squares solution to linear regression via the normal equations. This method finds the optimal weights and bias that minimize the sum of squared residuals, with no regularization.

The model solves: min ||y - (Xw + b)||^2
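A numpy sketch of that closed-form solve (illustrative only; the library operates on JAX arrays and its internal representation may differ):

```python
# Least-squares fit of weights and intercept by appending a ones column,
# so min ||y - (Xw + b)||^2 becomes a single lstsq solve. Illustration only.
import numpy as np

def ols_fit(X, y):
    Xb = np.column_stack([X, np.ones(len(X))])   # augment with intercept column
    coef, *_ = np.linalg.lstsq(Xb, y, rcond=None)
    return coef[:-1], coef[-1]                   # (weights, bias)
```

For data generated by y = 2x + 1 this recovers weight 2 and bias 1 exactly.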

Parameters:
  • X (Union[Array, ndarray, bool, number]) – Feature matrix of shape (n_samples, n_features).

  • y (Union[Array, ndarray, bool, number]) – Target vector of shape (n_samples,).

Return type:

OLSModel

Returns:

OLSModel containing the fitted weights and bias that minimize the squared error.

Example

>>> import jax.numpy as jnp
>>> X = jnp.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
>>> y = jnp.array([1.0, 2.0, 3.0])
>>> model = fit_ordinary_least_squares(X, y)

metalog_jax.regression.ols.predict_ordinary_least_squares(X, model)[source]

Make predictions using a fitted Regression model.

Computes predictions for new data using the trained weights and bias: y_pred = X @ weights + bias

Parameters:
  • X (Union[Array, ndarray, bool, number]) – Feature matrix of shape (n_samples, n_features).

  • model (OLSModel) – OLSModel containing fitted weights and bias.

Return type:

Union[Array, ndarray, bool, number]

Returns:

Predictions of shape (n_samples,).

Example

>>> import jax.numpy as jnp
>>> X_test = jnp.array([[2.0, 3.0], [4.0, 5.0]])
>>> predictions = predict_ordinary_least_squares(X_test, model)

LASSO

LASSO regression implementation with L1 regularization.

This module provides LASSO (Least Absolute Shrinkage and Selection Operator) regression for fitting metalog distributions with L1 regularization using proximal gradient descent with Nesterov acceleration.

Classes:

LassoParameters: Hyperparameters for LASSO regression.

LassoModel: Structure for trained LASSO model weights.

Functions:

fit_lasso: Fit a LASSO regression model via proximal gradient descent.

soft_thresholding: Apply the soft-thresholding operator (proximal for L1).

Constants:

DEFAULT_LASSO_LAMBDA: Default L1 regularization strength (0).

DEFAULT_LASSO_LEARNING_RATE: Default learning rate (0.01).

DEFAULT_LASSO_ITERATIONS: Default maximum iterations (500).

DEFAULT_LASSO_TOLERANCE: Default convergence tolerance (1e-6).

DEFAULT_LASSO_MOMENTUM: Default Nesterov momentum factor (0.9).

DEFAULT_LASSO_PARAMETERS: Default LassoParameters instance.

This module is used when MetalogFitMethod.Lasso is specified in MetalogParameters.

See also

metalog_jax.regression.ols: OLS regression (no regularization).

metalog_jax.metalog.fit: High-level fitting function that dispatches to LASSO.

class metalog_jax.regression.lasso.LassoParameters(lam, learning_rate, num_iters, tol, momentum)[source]

Bases: RegularizedParameters

Structure for Lasso parameters.

lam

L1 regularization strength λ ≥ 0. Controls sparsity of the solution:

  • λ = 0: No regularization (equivalent to OLS, but solved iteratively)

  • λ > 0: Promotes sparsity; larger values → more coefficients set to zero

  • Typical values: 0.001 to 10.0, depending on data scale

learning_rate

Learning rate (step size) for gradient descent. Default: 0.01. Controls the size of parameter updates and should be tuned to the problem:

  • Too large: May cause divergence or oscillation

  • Too small: Slow convergence

  • Typical values: 0.001 to 0.1

num_iters

Maximum number of iterations. Default: 500. Training will stop early if convergence is reached before this limit.

tol

Convergence tolerance for weight changes. Default: 1e-6. Training stops when ||w_new - w|| < tol, indicating the solution has stabilized. Smaller values require more precise convergence but may take longer.

momentum

Nesterov momentum factor, must satisfy 0 < momentum < 1. Default: 0.9. Controls acceleration of convergence:

  • Higher values (e.g., 0.9, 0.99): Faster convergence but less stable

  • Lower values (e.g., 0.5, 0.7): More stable but slower convergence

  • Standard choice: 0.9

Parameters:
lam: float
learning_rate: float
num_iters: int
tol: float
momentum: float
__init__(lam, learning_rate, num_iters, tol, momentum)
Parameters:
replace(**updates)

Returns a new object replacing the specified fields with new values.

class metalog_jax.regression.lasso.LassoModel(weights)[source]

Bases: RegressionModel

Structure for trained LASSO Regression model weights.

weights

Coefficient vector of shape (n_features,).

bias

Intercept term (scalar).

Parameters:

weights (Union[Array, ndarray, bool, number])

__init__(weights)
Parameters:

weights (Union[Array, ndarray, bool, number])

replace(**updates)

Returns a new object replacing the specified fields with new values.

metalog_jax.regression.lasso.soft_thresholding(x, lam)[source]

Apply the soft-thresholding operator (proximal operator for L1 norm).

The soft-thresholding operator is the proximal operator for the L1 norm and is fundamental to LASSO regression and other sparse optimization methods. It shrinks values toward zero and sets values below the threshold to exactly zero, promoting sparsity in the solution.

The operator is defined element-wise as:

soft_threshold(x, λ) = sign(x) * max(|x| - λ, 0)

Equivalently:
  • If x > λ: return x - λ

  • If x < -λ: return x + λ

  • If |x| ≤ λ: return 0

This function is the proximal operator for the L1 penalty:

prox_{λ||·||₁}(x) = argmin_z { (1/2)||z - x||² + λ||z||₁ }
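The operator written out in numpy, as a sketch mirroring the formula above (metalog-jax's own version is the JIT-compiled jnp equivalent):

```python
# Element-wise soft-thresholding: sign(x) * max(|x| - lam, 0).
import numpy as np

def soft_threshold(x, lam):
    return np.sign(x) * np.maximum(np.abs(x) - lam, 0.0)
```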

Parameters:
  • x (Union[Array, ndarray, bool, number, float, int]) – Input value(s) to threshold. Can be a scalar or array of any shape.

  • lam (float) – Threshold parameter λ ≥ 0. Controls the amount of shrinkage:

      • λ = 0: No shrinkage (returns x unchanged)

      • λ > 0: Shrinks values toward zero, setting small values to 0

      • Larger λ produces sparser solutions with more zeros

Return type:

Union[Array, ndarray, bool, number, float, int]

Returns:

Thresholded value(s) with the same shape as x. Values are shrunk toward zero by amount λ, with values |x| ≤ λ set to exactly zero.

Example

Scalar inputs:

>>> import jax.numpy as jnp
>>> soft_thresholding(5.0, 2.0)
Array(3.0, dtype=float32)
>>> soft_thresholding(-3.0, 1.0)
Array(-2.0, dtype=float32)
>>> soft_thresholding(1.5, 2.0)  # Below threshold -> zero
Array(0.0, dtype=float32)

Array inputs:

>>> x = jnp.array([-5.0, -1.0, 0.5, 2.0, 4.0])
>>> soft_thresholding(x, 1.5)
Array([-3.5,  0. ,  0. ,  0.5,  2.5], dtype=float32)

Note

  • This function is JIT-compiled for efficient execution

  • The operation is element-wise and preserves the input shape

  • Setting lam=0 returns the input unchanged

  • This is also known as the “shrinkage operator” or “soft thresholding”

See also

fit_lasso: LASSO regression that uses this operator in proximal gradient descent.

References

Parikh, N., & Boyd, S. (2014). Proximal Algorithms. Foundations and Trends in Optimization, 1(3), 127-239.

Beck, A., & Teboulle, M. (2009). A Fast Iterative Shrinkage-Thresholding Algorithm for Linear Inverse Problems. SIAM Journal on Imaging Sciences, 2(1), 183-202.

metalog_jax.regression.lasso.fit_lasso(X, y, params=LassoParameters(lam=0.0, learning_rate=0.01, num_iters=500, tol=1e-06, momentum=0.9))[source]

Fit a LASSO regression model using proximal gradient descent with Nesterov acceleration.

Trains a linear regression model with L1 regularization (LASSO - Least Absolute Shrinkage and Selection Operator) using an accelerated proximal gradient method. LASSO promotes sparse solutions by shrinking coefficients toward zero and setting many coefficients to exactly zero, making it useful for feature selection and interpretable models.

This implementation uses:

  • Proximal gradient descent: Combines gradient steps on the smooth loss (MSE) with the proximal operator (soft-thresholding) for the non-smooth L1 penalty

  • Nesterov momentum: Accelerates convergence using momentum-based lookahead steps

  • Early stopping: Terminates when weight changes fall below the tolerance threshold

  • JAX compatibility: Fully vectorized and compatible with JAX transformations

The LASSO objective function is:

min_w { (1/2n)||y - Xw||² + λ||w||₁ }

where λ is the L1 regularization strength that controls sparsity.

Algorithm:
  1. Initialize weights w and velocity v to zero

  2. For each iteration (until convergence or max iterations):

     a. Compute lookahead weights: w_lookahead = w + momentum * v

     b. Compute gradient at the lookahead position

     c. Update velocity with momentum and gradient: v = momentum * v - learning_rate * grad

     d. Apply proximal operator: w = soft_threshold(w + v, learning_rate * λ)

     e. Check convergence: stop if ||w_new - w|| < tol

  3. Return final weights
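The steps above can be sketched in numpy; this is an illustration of the algorithm as described, not the library's exact JAX implementation:

```python
# Nesterov-accelerated proximal gradient descent for the LASSO objective
# (1/2n)||y - Xw||^2 + lam*||w||_1. Illustration only.
import numpy as np

def soft_threshold(x, lam):
    return np.sign(x) * np.maximum(np.abs(x) - lam, 0.0)

def lasso_pgd(X, y, lam=0.0, lr=0.01, num_iters=500, tol=1e-6, momentum=0.9):
    n, d = X.shape
    w = np.zeros(d)
    v = np.zeros(d)
    for _ in range(num_iters):
        w_look = w + momentum * v                 # a. lookahead weights
        grad = X.T @ (X @ w_look - y) / n         # b. gradient at lookahead
        v = momentum * v - lr * grad              # c. velocity update
        w_new = soft_threshold(w + v, lr * lam)   # d. proximal step
        if np.linalg.norm(w_new - w) < tol:       # e. convergence check
            return w_new
        w = w_new
    return w
```

With lam=0 this reduces to accelerated gradient descent on the MSE loss and recovers the OLS solution on well-conditioned problems.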

Parameters:
  • X (Union[Array, ndarray, bool, number, float, int]) – Feature matrix of shape (n_samples, n_features). The design matrix containing the independent variables for regression.

  • y (Union[Array, ndarray, bool, number, float, int]) – Target vector of shape (n_samples,). The dependent variable values to predict.

  • params (LassoParameters) – LassoParameters containing optimization hyperparameters. Defaults to DEFAULT_LASSO_PARAMETERS. The object contains:

      • lam: L1 regularization strength λ ≥ 0. Default: 0.

      • learning_rate (float): Step size for gradient descent. Default: 0.01.

      • num_iters (int): Maximum number of iterations. Default: 500.

      • tol (float): Convergence tolerance for weight changes. Default: 1e-6.

      • momentum (float): Nesterov momentum factor (0 < momentum < 1). Default: 0.9.

Return type:

LassoModel

Returns:

LassoModel containing the fitted weights vector of shape (n_features,). Many coefficients may be exactly zero due to L1 regularization, achieving feature selection and model sparsity.

Example

Basic LASSO regression with default parameters:

>>> import jax.numpy as jnp
>>> from metalog_jax.regression.lasso import fit_lasso
>>>
>>> # Create sample data
>>> X = jnp.array([[1.0, 2.0, 0.5],
...                [3.0, 4.0, 1.5],
...                [5.0, 6.0, 2.5]])
>>> y = jnp.array([1.0, 2.0, 3.0])
>>>
>>> # Fit LASSO with defaults
>>> model = fit_lasso(X, y)
>>> print(model.weights)  # Some coefficients may be exactly zero

Note

  • Unlike Ridge regression, LASSO does not have a closed-form solution and requires iterative optimization

  • The algorithm may converge before num_iters if tol is reached

  • Feature scaling (standardization) is recommended before fitting LASSO

See also

soft_thresholding: The proximal operator used in each iteration.

fit_ordinary_least_squares: No regularization (from regression.ols).

LassoModel: Return type containing fitted weights.

LassoParameters: Configuration dataclass for LASSO hyperparameters.

References

Tibshirani, R. (1996). Regression Shrinkage and Selection via the Lasso. Journal of the Royal Statistical Society: Series B (Methodological), 58(1), 267-288.

Beck, A., & Teboulle, M. (2009). A Fast Iterative Shrinkage-Thresholding Algorithm for Linear Inverse Problems. SIAM Journal on Imaging Sciences, 2(1), 183-202.