Getting Started¶
This guide will help you get started with metalog-jax.
Installation¶
Install metalog-jax using pip:
pip install metalog-jax
Or with uv:
uv pip install metalog-jax
Basic Usage¶
Fitting a Metalog Distribution¶
The simplest way to use metalog-jax is to fit a distribution to your data using
the fit function from metalog_jax.metalog:
import jax.numpy as jnp
from metalog_jax.base import MetalogInputData, MetalogParameters
from metalog_jax.base import MetalogBoundedness, MetalogFitMethod
from metalog_jax.metalog import fit
# Create validated input data using from_values()
raw_data = jnp.array([1.2, 2.3, 2.8, 3.5, 4.1, 5.6])
data = MetalogInputData.from_values(
    x=raw_data,
    y=jnp.array([0.1, 0.25, 0.5, 0.75, 0.9, 0.95]),
    precomputed_quantiles=False,  # Raw samples, not precomputed quantiles
)
# Configure metalog parameters with OLS
metalog_params = MetalogParameters(
    boundedness=MetalogBoundedness.UNBOUNDED,
    method=MetalogFitMethod.OLS,
    lower_bound=0.0,
    upper_bound=0.0,
    num_terms=5,
)
# Fit the metalog distribution
m = fit(data, metalog_params)
# Evaluate the quantile function (inverse CDF)
q_75 = m.ppf(jnp.array([0.75]))
# Compute distribution properties
mean_val = m.mean
median_val = m.median
variance_val = m.var
std_val = m.std
SPT Metalog (3-Term Closed-Form)¶
For a simpler 3-term metalog with closed-form solutions, use the SPT (Symmetric Percentile Triplet) variant. SPT uses exactly three quantiles: alpha, 0.5 (median), and (1 - alpha):
import jax.numpy as jnp
from metalog_jax.base import MetalogBoundedness, SPTMetalogParameters
from metalog_jax.metalog import fit_spt_metalog
# Generate sample data
data = jnp.array([1.2, 2.3, 2.8, 3.5, 4.1, 5.6, 6.2, 7.8, 9.1])
# Configure SPT parameters (uses 10th, 50th, 90th percentiles with alpha=0.1)
params = SPTMetalogParameters(
    boundedness=MetalogBoundedness.UNBOUNDED,
    alpha=0.1,
    lower_bound=0.0,
    upper_bound=0.0,
)
spt = fit_spt_metalog(data, params)
# Evaluate quantile function
q_25 = spt.ppf(jnp.array([0.25]))
# Compute properties
mean_val = spt.mean
std_val = spt.std
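The closed form mentioned above can be worked out by hand for the unbounded case. The sketch below is independent of metalog-jax and does not claim to mirror its internals: the function names spt_coefficients and spt_quantile are hypothetical, and the formulas follow from requiring the 3-term quantile function a1 + a2*ln(y/(1-y)) + a3*(y-0.5)*ln(y/(1-y)) to pass through the three given quantiles.

```python
import math


def spt_coefficients(q_low, q_mid, q_high, alpha):
    """Closed-form coefficients of the unbounded 3-term metalog.

    q_low, q_mid, q_high are the quantiles at alpha, 0.5 and 1 - alpha.
    """
    if not 0.0 < alpha < 0.5:
        raise ValueError("alpha must lie strictly between 0 and 0.5")
    ell = math.log((1.0 - alpha) / alpha)
    a1 = q_mid  # the logit terms vanish at y = 0.5
    a2 = (q_high - q_low) / (2.0 * ell)
    a3 = (q_high + q_low - 2.0 * q_mid) / ((1.0 - 2.0 * alpha) * ell)
    return a1, a2, a3


def spt_quantile(y, coeffs):
    """Evaluate the 3-term metalog quantile function at probability y."""
    a1, a2, a3 = coeffs
    logit = math.log(y / (1.0 - y))
    return a1 + a2 * logit + a3 * (y - 0.5) * logit


# Illustrative quantiles at the 10th, 50th and 90th percentiles.
coeffs = spt_coefficients(q_low=5.0, q_mid=10.0, q_high=20.0, alpha=0.1)
```

Evaluating spt_quantile at 0.1, 0.5 and 0.9 with these coefficients returns the three input quantiles up to floating-point error, which is what makes the SPT fit exact rather than a regression.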
Using Pre-computed Quantiles¶
If you already have quantile values (e.g., from expert elicitation), use
precomputed_quantiles=True:
import jax.numpy as jnp
from metalog_jax.base import MetalogInputData, MetalogParameters
from metalog_jax.base import MetalogBoundedness, MetalogFitMethod
from metalog_jax.metalog import fit
# Pre-computed quantiles (must be strictly ascending)
quantiles = jnp.array([5.0, 10.0, 20.0])
probability_levels = jnp.array([0.1, 0.5, 0.9])
# Create input data with precomputed quantiles
data = MetalogInputData.from_values(
    x=quantiles,
    y=probability_levels,
    precomputed_quantiles=True,
)
params = MetalogParameters(
    boundedness=MetalogBoundedness.UNBOUNDED,
    method=MetalogFitMethod.OLS,
    lower_bound=0.0,
    upper_bound=0.0,
    num_terms=3,
)
m = fit(data, params)
Bounded Distributions¶
For data with known bounds, use bounded metalog variants:
from metalog_jax.base import MetalogInputData, MetalogParameters
from metalog_jax.base import MetalogBoundedness, MetalogFitMethod
from metalog_jax.metalog import fit
# Bounded metalog for data in [0, 100]
params = MetalogParameters(
    boundedness=MetalogBoundedness.BOUNDED,
    method=MetalogFitMethod.OLS,
    num_terms=5,
    lower_bound=0.0,
    upper_bound=100.0,
)
m = fit(data, params)
Boundedness options:
MetalogBoundedness.UNBOUNDED: Full real line support (-∞, ∞)
MetalogBoundedness.STRICTLY_LOWER_BOUND: Support on (lower_bound, ∞)
MetalogBoundedness.STRICTLY_UPPER_BOUND: Support on (-∞, upper_bound)
MetalogBoundedness.BOUNDED: Support on (lower_bound, upper_bound)
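In the metalog literature, the three bounded variants are obtained by passing an unbounded quantile value through a log or logit transform. The sketch below illustrates those standard transforms in plain Python; the function names are hypothetical, and it is an assumption (not confirmed by this guide) that metalog-jax applies the same transforms internally.

```python
import math


def to_lower_bounded(z, lower):
    # Maps any real z into (lower, inf).
    return lower + math.exp(z)


def to_upper_bounded(z, upper):
    # Maps any real z into (-inf, upper).
    return upper - math.exp(-z)


def to_bounded(z, lower, upper):
    # Maps any real z into (lower, upper) via a logistic squashing.
    e = math.exp(z)
    return (lower + upper * e) / (1.0 + e)


def from_bounded(x, lower, upper):
    # Inverse of to_bounded: the transform applied to data before fitting.
    return math.log((x - lower) / (upper - x))


# An unbounded quantile value of 0.0 lands at the midpoint of [0, 100].
midpoint = to_bounded(0.0, lower=0.0, upper=100.0)
```

Under this view, fitting a bounded metalog amounts to transforming the data with the inverse map, fitting an unbounded metalog, and mapping quantiles back.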
Regression Methods¶
metalog-jax supports two regression methods for fitting, organized in the
metalog_jax.regression module:
OLS (metalog_jax.regression.ols): No regularization, closed-form solution
LASSO (metalog_jax.regression.lasso): L1 regularization for sparse solutions
Ordinary Least Squares (OLS)¶
Default method, no regularization:
from metalog_jax.base import MetalogInputData, MetalogParameters
from metalog_jax.base import MetalogBoundedness, MetalogFitMethod
from metalog_jax.metalog import fit
metalog_params = MetalogParameters(
    boundedness=MetalogBoundedness.UNBOUNDED,
    method=MetalogFitMethod.OLS,
    lower_bound=0.0,
    upper_bound=0.0,
    num_terms=5,
)
m = fit(data, metalog_params)
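To make the OLS step concrete, here is a minimal NumPy sketch of a least-squares metalog fit: build a design matrix from the standard metalog basis functions evaluated at the probability levels, then solve for the coefficients. The helper name metalog_basis is hypothetical, and the sketch is not the library's implementation.

```python
import numpy as np


def metalog_basis(y, num_terms):
    """Design matrix whose columns are the metalog basis functions at y."""
    y = np.asarray(y, dtype=float)
    logit = np.log(y / (1.0 - y))
    centered = y - 0.5
    columns = [np.ones_like(y)]
    if num_terms >= 2:
        columns.append(logit)
    if num_terms >= 3:
        columns.append(centered * logit)
    if num_terms >= 4:
        columns.append(centered)
    for k in range(5, num_terms + 1):
        if k % 2 == 1:
            # Odd terms: pure powers of (y - 0.5).
            columns.append(centered ** ((k - 1) // 2))
        else:
            # Even terms: powers of (y - 0.5) times the logit.
            columns.append(centered ** (k // 2 - 1) * logit)
    return np.stack(columns, axis=1)


# Probability levels and matching quantile values (illustrative numbers).
y = np.array([0.1, 0.25, 0.5, 0.75, 0.9, 0.95])
x = np.array([1.2, 2.3, 2.8, 3.5, 4.1, 5.6])

design = metalog_basis(y, num_terms=3)
coeffs, *_ = np.linalg.lstsq(design, x, rcond=None)
fitted = design @ coeffs  # fitted quantiles at the same probability levels
```

With more data points than terms, the fit is a projection, so fitted generally differs from x; with exactly num_terms points it interpolates them.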
LASSO Regression¶
L1 regularization for sparse solutions using proximal gradient descent:
from metalog_jax.base import MetalogInputData, MetalogParameters
from metalog_jax.base import MetalogBoundedness, MetalogFitMethod
from metalog_jax.metalog import fit
from metalog_jax.regression.lasso import LassoParameters
# Configure metalog with Lasso method
metalog_params = MetalogParameters(
    boundedness=MetalogBoundedness.UNBOUNDED,
    method=MetalogFitMethod.Lasso,
    lower_bound=0.0,
    upper_bound=0.0,
    num_terms=9,
)
# Specify LASSO hyperparameters
lasso_params = LassoParameters(
    lam=0.1,  # L1 regularization strength
    learning_rate=0.01,
    num_iters=500,
    tol=1e-6,
    momentum=0.9,
)
# Fit with LASSO regression
m = fit(data, metalog_params, regression_hyperparams=lasso_params)
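For readers unfamiliar with proximal gradient descent, the following NumPy sketch shows the basic idea (often called ISTA): take a gradient step on the squared-error term, then apply soft-thresholding, which is the proximal operator of the L1 penalty. It is a simplified illustration, not the library's solver: it uses a fixed step size derived from the design matrix and omits the momentum and tol settings shown above. The names soft_threshold and lasso_ista are hypothetical.

```python
import numpy as np


def soft_threshold(v, t):
    """Proximal operator of t * ||v||_1: shrink each entry toward zero by t."""
    return np.sign(v) * np.maximum(np.abs(v) - t, 0.0)


def lasso_ista(A, b, lam, num_iters=500):
    """Minimize 0.5/n * ||A w - b||^2 + lam * ||w||_1 by proximal gradient."""
    n = A.shape[0]
    # Step size 1/L, where L = sigma_max(A)^2 / n bounds the gradient's slope.
    step = n / (np.linalg.norm(A, 2) ** 2)
    w = np.zeros(A.shape[1])
    for _ in range(num_iters):
        grad = A.T @ (A @ w - b) / n
        w = soft_threshold(w - step * grad, step * lam)
    return w


# Synthetic problem with a sparse ground truth (illustrative data only).
rng = np.random.default_rng(0)
A = rng.normal(size=(50, 5))
w_true = np.array([2.0, 0.0, 0.0, -1.0, 0.0])
b = A @ w_true

w_sparse = lasso_ista(A, b, lam=0.1)
```

Larger values of lam push more coefficients exactly to zero, which is how LASSO can prune higher-order metalog terms; lam=0 reduces the iteration to plain gradient descent on the least-squares objective.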
Grid Search with JAX vmap¶
Use JAX’s vmap for efficient hyperparameter search:
import jax
import jax.numpy as jnp
from metalog_jax.base import MetalogInputData, MetalogParameters
from metalog_jax.base import MetalogBoundedness, MetalogFitMethod
from metalog_jax.metalog import fit, GridResult
from metalog_jax.regression.lasso import LassoParameters
from metalog_jax.utils import DEFAULT_Y, ks_distance
# Create data (rvs is assumed to be a 1-D array of raw samples defined elsewhere)
data = MetalogInputData.from_values(rvs, DEFAULT_Y, False)
# Define fit function for a single L1 penalty
def fit_with_lasso_penalty(l1_penalty):
    lasso_params = LassoParameters(lam=l1_penalty)
    lasso_metalog_params = MetalogParameters(
        boundedness=MetalogBoundedness.BOUNDED,
        lower_bound=0,
        upper_bound=1,
        method=MetalogFitMethod.Lasso,
        num_terms=11,
    )
    metalog = fit(data, lasso_metalog_params, lasso_params)
    fitted_quantiles = metalog.ppf(data.y)
    ks_dist = ks_distance(data.x, fitted_quantiles)
    return GridResult(metalog=metalog, ks_dist=ks_dist)
# Grid search over L1 penalties
l1_penalties = jnp.array([0.0, 0.01, 0.1, 1.0])
vmapped_lasso_fit = jax.vmap(fit_with_lasso_penalty)
lasso_results = vmapped_lasso_fit(l1_penalties)
Goodness of Fit¶
Use the Kolmogorov-Smirnov distance to evaluate fit quality:
import jax.numpy as jnp
from metalog_jax.utils import ks_distance
# Two samples from the same distribution
x = jnp.array([0.1, 0.5, 0.3, 0.7, 0.9])
y = jnp.array([0.2, 0.4, 0.6, 0.8])
distance = ks_distance(x, y)
# Comparing different distributions
uniform_sample = jnp.array([0.1, 0.3, 0.5, 0.7, 0.9])
skewed_sample = jnp.array([0.05, 0.1, 0.15, 0.2, 0.25])
distance = ks_distance(uniform_sample, skewed_sample)
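As a reference point for interpreting these distances, the textbook two-sample Kolmogorov-Smirnov statistic is the largest vertical gap between the two empirical CDFs, so it ranges from 0 (identical samples) to 1 (no overlap). The pure-Python sketch below computes that statistic; the function name two_sample_ks is hypothetical, and whether ks_distance uses exactly this definition is an assumption this guide does not confirm.

```python
from bisect import bisect_right


def two_sample_ks(a, b):
    """Largest gap between the empirical CDFs of samples a and b."""
    a_sorted, b_sorted = sorted(a), sorted(b)
    gap = 0.0
    # The supremum is attained at one of the observed sample points.
    for point in a_sorted + b_sorted:
        cdf_a = bisect_right(a_sorted, point) / len(a_sorted)
        cdf_b = bisect_right(b_sorted, point) / len(b_sorted)
        gap = max(gap, abs(cdf_a - cdf_b))
    return gap


uniform_sample = [0.1, 0.3, 0.5, 0.7, 0.9]
skewed_sample = [0.05, 0.1, 0.15, 0.2, 0.25]
gap = two_sample_ks(uniform_sample, skewed_sample)
```

For these two samples the gap is largest at 0.25, where the skewed sample's empirical CDF has already reached 1 while the uniform sample's is still at 0.2.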
Serialization¶
Save and load fitted metalog distributions:
from pathlib import Path
from metalog_jax.metalog import Metalog
# Save to JSON file
m.save(Path("my_metalog.json"))
# Load from JSON file
loaded = Metalog.load(Path("my_metalog.json"))
# Or use string serialization
json_str = m.dumps()
loaded = Metalog.loads(json_str)
Distribution Methods¶
All metalog distributions (Metalog and SPTMetalog) provide these methods
and properties:
Methods:
ppf(y): Quantile function (percent point function / inverse CDF)
cdf(q): Cumulative distribution function
pdf(x): Probability density function
logpdf(x): Natural log of the PDF
sf(x): Survival function (1 - CDF)
logsf(x): Natural log of the survival function
isf(q): Inverse survival function
rvs(params): Random variable generation
plot(option): Visualize the distribution (PDF, CDF, or SF)
save(path): Save to JSON file
dumps(): Serialize to JSON string
Properties:
mean: Expected value
median: Median (50th percentile)
var: Variance
std: Standard deviation
mode: Most likely value
num_terms: Number of terms in the metalog expansion
boundedness: Boundedness type
lower_bound: Lower bound value
upper_bound: Upper bound value
R Compatibility Aliases:
q(x): Alias for quantile function (R metalog compatibility)
p(q): Alias for CDF (R metalog compatibility)
d(q): Alias for density function (R metalog compatibility)
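A metalog is defined by its quantile function, so the other methods listed above can all be derived from ppf. The sketch below shows one way to do this for a 3-term unbounded metalog in plain Python: the CDF by bisection on the quantile function, the PDF as the reciprocal of the quantile function's derivative, and the survival functions by complement. All function names are hypothetical, and how metalog-jax actually computes these methods is not stated in this guide.

```python
import math


def quantile(y, a):
    """3-term unbounded metalog quantile function (the role of ppf)."""
    logit = math.log(y / (1.0 - y))
    return a[0] + a[1] * logit + a[2] * (y - 0.5) * logit


def cdf(x, a, tol=1e-12):
    """Solve quantile(y) = x for y by bisection; assumes quantile is increasing."""
    lo, hi = 1e-15, 1.0 - 1e-15
    while hi - lo > tol:
        mid = 0.5 * (lo + hi)
        if quantile(mid, a) < x:
            lo = mid
        else:
            hi = mid
    return 0.5 * (lo + hi)


def pdf(x, a):
    """Density at x: reciprocal of d(quantile)/dy evaluated at y = cdf(x)."""
    y = cdf(x, a)
    logit = math.log(y / (1.0 - y))
    slope = a[1] / (y * (1.0 - y)) + a[2] * ((y - 0.5) / (y * (1.0 - y)) + logit)
    return 1.0 / slope


def sf(x, a):
    """Survival function: 1 - CDF."""
    return 1.0 - cdf(x, a)


def isf(q, a):
    """Inverse survival function: the quantile at probability 1 - q."""
    return quantile(1.0 - q, a)


a = (10.0, 3.0, 1.0)  # illustrative coefficients
round_trip = cdf(quantile(0.3, a), a)
```

With coefficients (0, 1, 0) the quantile function reduces to the logit, so the sketch reproduces the standard logistic distribution, which gives a simple sanity check.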
Module Structure¶
The library is organized into the following modules:
metalog_jax.base: Core classes and enumerations
    metalog_jax.base.core: MetalogBase class
    metalog_jax.base.data: MetalogInputData and MetalogBaseData
    metalog_jax.base.enums: MetalogBoundedness, MetalogFitMethod, MetalogPlotOptions
    metalog_jax.base.parameters: MetalogParameters, SPTMetalogParameters
metalog_jax.metalog: Fitting functions and distribution classes
    fit(): Fit a standard metalog distribution
    fit_spt_metalog(): Fit an SPT metalog distribution
    Metalog: Standard metalog class
    SPTMetalog: SPT metalog class
    GridResult: Results container for grid search
metalog_jax.regression: Regression implementations
    metalog_jax.regression.base: RegressionModel, RegularizedParameters
    metalog_jax.regression.ols: OLS regression
    metalog_jax.regression.lasso: LASSO regression
Next Steps¶
Explore the API Reference for detailed API documentation
Check out the examples in the repository