FEAT - Implement SmoothQuantileRegression #312

Open: wants to merge 18 commits into base: main
Commits (18)
21f1459 first try at simple quantile huber (floriankozikowski, May 23, 2025)
010c399 make basic version without intercept handling and progressive smoothing (floriankozikowski, May 26, 2025)
e4888d9 add progressive smoothing (floriankozikowski, May 26, 2025)
be991c1 din't inherit from Huber (mathurinm, May 28, 2025)
575ffbb implemented lipschitz for dense case, support for AndersonCD and adre… (floriankozikowski, May 30, 2025)
938d842 resolve merge conflict (floriankozikowski, May 30, 2025)
39ced0d add intercept method (only works for AndersonCD so far) (floriankozikowski, May 30, 2025)
4375b4c script debug quantile (mathurinm, Jun 13, 2025)
3c9c320 set inner solver to verbose for debug + pinpoint failing case (mathurinm, Jun 13, 2025)
693ca06 check fit_intercept (mathurinm, Jun 13, 2025)
303b167 remove FISTA, implement comments, add unit test, add lipschitz on int… (floriankozikowski, Jun 13, 2025)
f9e2e79 Merge branch 'quantilehuber' of https://github.com/floriankozikowski/… (floriankozikowski, Jun 13, 2025)
2c84a44 remove solver selection from plotting example, fixes CircleCI (floriankozikowski, Jun 13, 2025)
7d7106c add fit_intercept=False to pytest, but failing (floriankozikowski, Jun 16, 2025)
34b083d parametrize intercept works now in unit test, still warnings though (floriankozikowski, Jun 16, 2025)
ba3d6b8 adress remaining comments, loosen tolerance, warnings are still there (floriankozikowski, Jun 16, 2025)
f43c190 suppress warnings, loosen tolerance, edit whats new (floriankozikowski, Jun 16, 2025)
cb0c3e4 adress final comments (api, remove debug, etc.), improve example for … (floriankozikowski, Jun 18, 2025)
2 changes: 2 additions & 0 deletions doc/api.rst
@@ -105,5 +105,7 @@ Experimental
IterativeReweightedL1
PDCD_WS
Pinball
QuantileHuber
SmoothQuantileRegressor
SqrtQuadratic
SqrtLasso
1 change: 1 addition & 0 deletions doc/changes/0.5.rst
@@ -3,3 +3,4 @@
Version 0.5 (in progress)
-------------------------
- Add support for fitting an intercept in :ref:`SqrtLasso <skglm.experimental.sqrt_lasso.SqrtLasso>` (PR: :gh:`298`)
- Add experimental :ref:`QuantileHuber <skglm.experimental.quantile_huber.QuantileHuber>` and :ref:`SmoothQuantileRegressor <skglm.experimental.quantile_huber.SmoothQuantileRegressor>` for quantile regression, and an example script (PR: :gh:`312`).
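As a quick illustration of the new estimator named in this entry, here is a minimal usage sketch (editor's addition, not part of the diff; it only relies on the class names, parameters, and import path shown in quantile_huber.py below, and exact results may differ):

import numpy as np
from sklearn.datasets import make_regression
from skglm.experimental.quantile_huber import SmoothQuantileRegressor

# Estimate the 80th conditional percentile with an L1 penalty.
X, y = make_regression(n_samples=200, n_features=20, noise=0.1, random_state=0)
reg = SmoothQuantileRegressor(quantile=0.8, alpha=0.1).fit(X, y)

# Coverage should be close to the requested quantile level.
print("coverage:", np.mean(y <= reg.predict(X)))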
132 changes: 132 additions & 0 deletions examples/plot_smooth_quantile.py
@@ -0,0 +1,132 @@
"""
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

look at some other plotting file format

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

  • write some short text to explain what each 3 or 4 blocks of code does

================================================================================
Smooth Quantile Regression with QuantileHuber
================================================================================

This example compares sklearn's standard quantile regression with skglm's smooth
approximation. Skglm's quantile regression uses a smooth Huber-like approximation
(quadratic near zero, linear in the tails) to replace the non-differentiable
pinball loss. Progressive smoothing enables efficient gradient-based optimization,
maintaining speed and accuracy even on large-scale, high-dimensional datasets.
"""

# Author: Florian Kozikowski
import numpy as np
import time
import matplotlib.pyplot as plt

from sklearn.datasets import make_regression
from sklearn.linear_model import QuantileRegressor
from skglm.experimental.quantile_huber import QuantileHuber, SmoothQuantileRegressor

# Generate regression data
X, y = make_regression(n_samples=1000, n_features=10, noise=0.1, random_state=0)
tau = 0.8 # 80th percentile

# %%
# Compare standard vs smooth quantile regression
# ----------------------------------------------
# Both methods solve the same problem but with different loss functions.

# Standard quantile regression (sklearn)
start = time.time()
sk_model = QuantileRegressor(quantile=tau, alpha=0.1)
sk_model.fit(X, y)
sk_time = time.time() - start

# Smooth quantile regression (skglm)
start = time.time()
smooth_model = SmoothQuantileRegressor(
    quantile=tau,
    alpha=0.1,
    delta_init=0.5,    # Initial smoothing parameter
    delta_final=0.01,  # Final smoothing (smaller = closer to true quantile)
    n_deltas=5         # Number of continuation steps
)
smooth_model.fit(X, y)
smooth_time = time.time() - start

# %%
# Evaluate both methods
# ---------------------
# Coverage: fraction of true values below predictions (should ≈ tau)
# Pinball loss: standard quantile regression evaluation metric
#
# Note: no rigorous benchmarking has been done yet. The speed advantage likely only
# shows on large-scale, high-dimensional datasets; the sklearn implementation is
# likely faster on small datasets.


def pinball_loss(residuals, quantile):
    return np.mean(residuals * (quantile - (residuals < 0)))


sk_pred = sk_model.predict(X)
smooth_pred = smooth_model.predict(X)

print(f"{'Method':<15} {'Coverage':<10} {'Time (s)':<10} {'Pinball Loss':<12}")
print("-" * 50)
print(f"{'Sklearn':<15} {np.mean(y <= sk_pred):<10.3f} {sk_time:<10.3f} "
f"{pinball_loss(y - sk_pred, tau):<12.4f}")
print(f"{'SmoothQuantile':<15} {np.mean(y <= smooth_pred):<10.3f} {smooth_time:<10.3f} "
f"{pinball_loss(y - smooth_pred, tau):<12.4f}")

# %%
# Visualize the smooth approximation
# ----------------------------------
# The smooth loss approximates the pinball loss but with continuous gradients

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 4))

# Show loss and gradient for different quantile levels
residuals = np.linspace(-3, 3, 500)
delta = 0.5
quantiles = [0.1, 0.5, 0.9]

for tau_val in quantiles:
    qh = QuantileHuber(quantile=tau_val, delta=delta)
    loss = [qh._loss_sample(r) for r in residuals]
    grad = [qh._grad_per_sample(r) for r in residuals]

    # Compute the non-smooth pinball loss for each residual; use a distinct name
    # so the pinball_loss function defined above is not shadowed
    pinball_vals = [r * (tau_val - (r < 0)) for r in residuals]

    # Plot smooth loss and pinball loss
    ax1.plot(residuals, loss, label=f"τ={tau_val}", linewidth=2)
    ax1.plot(residuals, pinball_vals, '--', alpha=0.4, color='gray',
             label=f"Pinball τ={tau_val}")
    ax2.plot(residuals, grad, label=f"τ={tau_val}", linewidth=2)

# Add vertical lines and shading showing delta boundaries
for ax in [ax1, ax2]:
    ax.axvline(-delta, color='gray', linestyle='--', alpha=0.7, linewidth=1.5)
    ax.axvline(delta, color='gray', linestyle='--', alpha=0.7, linewidth=1.5)
    # Add shading for quadratic region
    ax.axvspan(-delta, delta, alpha=0.15, color='gray')

# Add delta labels
ax1.text(-delta, 0.1, '−δ', ha='right', va='bottom', color='gray', fontsize=10)
ax1.text(delta, 0.1, '+δ', ha='left', va='bottom', color='gray', fontsize=10)

ax1.set_title(f"Smooth Quantile Loss (δ={delta})", fontsize=12)
ax1.set_xlabel("Residual")
ax1.set_ylabel("Loss")
ax1.legend(loc='upper left')
ax1.grid(True, alpha=0.3)

ax2.set_title("Gradient (continuous everywhere)", fontsize=12)
ax2.set_xlabel("Residual")
ax2.set_ylabel("Gradient")
ax2.legend(loc='upper left')
ax2.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

# %% [markdown]
# The left plot shows the asymmetric loss: tau=0.1 penalizes overestimation more,
# while tau=0.9 penalizes underestimation. As delta decreases towards zero, the
# loss function approaches the standard pinball loss.
# The right plot reveals the key advantage: gradients transition smoothly through
# zero, unlike standard quantile regression which has a kink. This smoothing
# enables fast convergence with gradient-based solvers.
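# A small numerical check (editor's sketch, not part of the PR; it reuses the
# QuantileHuber class imported above) makes the last point concrete: at a fixed
# residual, the smooth loss value approaches the pinball loss value as delta shrinks.

r_fixed, tau_fixed = 1.5, 0.8
pinball_ref = r_fixed * (tau_fixed - (r_fixed < 0))  # pinball loss at r_fixed
for d in [1.0, 0.1, 0.01, 0.001]:
    qh_d = QuantileHuber(quantile=tau_fixed, delta=d)
    print(f"delta={d:<6} smooth={qh_d._loss_sample(r_fixed):.4f} "
          f"pinball={pinball_ref:.4f}")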
3 changes: 3 additions & 0 deletions skglm/experimental/__init__.py
@@ -2,11 +2,14 @@
from .sqrt_lasso import SqrtLasso, SqrtQuadratic

Review comment (Collaborator): edit api.rst

from .pdcd_ws import PDCD_WS
from .quantile_regression import Pinball
from .quantile_huber import QuantileHuber, SmoothQuantileRegressor

__all__ = [
    IterativeReweightedL1,
    PDCD_WS,
    Pinball,
    QuantileHuber,
    SmoothQuantileRegressor,
    SqrtQuadratic,
    SqrtLasso,
]

Review comment (Collaborator, on the QuantileHuber entry): use alphabetical order
205 changes: 205 additions & 0 deletions skglm/experimental/quantile_huber.py
@@ -0,0 +1,205 @@
import numpy as np
from numpy.linalg import norm
from numba import float64

from sklearn.base import BaseEstimator, RegressorMixin
from sklearn.exceptions import NotFittedError

from skglm.datafits.base import BaseDatafit
from skglm.solvers import AndersonCD

Review comment (Collaborator): put all skglm imports together in their own block

from skglm.penalties import L1
from skglm.estimators import GeneralizedLinearEstimator


class QuantileHuber(BaseDatafit):
    r"""Quantile Huber loss for quantile regression.

    Implements the smoothed pinball loss:

    .. math::

        \rho_\tau^\delta(r) =
        \begin{cases}
            \tau \left(r - \dfrac{\delta}{2}\right), & \text{if } r \ge \delta,\\
            \dfrac{\tau r^{2}}{2\delta}, & \text{if } 0 \le r < \delta,\\
            \dfrac{(1-\tau) r^{2}}{2\delta}, & \text{if } -\delta < r < 0,\\
            (1-\tau) \left(-r - \dfrac{\delta}{2}\right), & \text{if } r \le -\delta.
        \end{cases}

    Parameters
    ----------
    quantile : float, default=0.5
        Desired quantile level between 0 and 1.
    delta : float, default=1.0
        Smoothing parameter; must be positive. Smaller values give a closer
        approximation to the pinball loss.
    """

    def __init__(self, quantile=0.5, delta=1.0):
        if not 0 < quantile < 1:
            raise ValueError("quantile must be between 0 and 1")
        if delta <= 0:
            raise ValueError("delta must be positive")
        self.delta = float(delta)
        self.quantile = float(quantile)

    def get_spec(self):
        return (('delta', float64), ('quantile', float64))

    def params_to_dict(self):
        return dict(delta=self.delta, quantile=self.quantile)

    def value(self, y, w, Xw):
        """Compute the quantile Huber loss value."""
        n_samples = len(y)
        res = 0.0
        for i in range(n_samples):
            residual = y[i] - Xw[i]
            res += self._loss_sample(residual)
        return res / n_samples

    def _loss_sample(self, residual):
        """Calculate loss for a single sample."""
        tau = self.quantile
        delta = self.delta
        r = residual

        if r >= delta:
            # Upper linear tail: r >= delta
            return tau * (r - delta/2)
        elif r >= 0:
            # Upper quadratic: 0 <= r < delta
            return tau * r**2 / (2 * delta)
        elif r > -delta:
            # Lower quadratic: -delta < r < 0
            return (1 - tau) * r**2 / (2 * delta)
        else:
            # Lower linear tail: r <= -delta
            return (1 - tau) * (-r - delta/2)

    def gradient_scalar(self, X, y, w, Xw, j):
        """Compute gradient w.r.t. w_j - following parent class pattern."""
        n_samples = len(y)
        grad_j = 0.0
        for i in range(n_samples):
            residual = y[i] - Xw[i]
            grad_j += -X[i, j] * self._grad_per_sample(residual)
        return grad_j / n_samples

    def _grad_per_sample(self, residual):
        """Calculate gradient for a single sample."""
        tau = self.quantile
        delta = self.delta
        r = residual

        if r >= delta:
            # Upper linear tail: r >= delta
            return tau
        elif r >= 0:
            # Upper quadratic: 0 <= r < delta
            return tau * r / delta
        elif r > -delta:
            # Lower quadratic: -delta < r < 0
            return (1 - tau) * r / delta
        else:
            # Lower linear tail: r <= -delta
            return tau - 1

    def get_lipschitz(self, X, y):
        n_features = X.shape[1]

        lipschitz = np.zeros(n_features, dtype=X.dtype)
        c = max(self.quantile, 1 - self.quantile) / self.delta
        for j in range(n_features):
            lipschitz[j] = c * (X[:, j] ** 2).sum() / len(y)

        return lipschitz

    def get_global_lipschitz(self, X, y):
        c = max(self.quantile, 1 - self.quantile) / self.delta
        return c * norm(X, ord=2) ** 2 / len(y)

    def intercept_update_step(self, y, Xw):
        n_samples = len(y)

        # Compute gradient
        grad = 0.0
        for i in range(n_samples):
            residual = y[i] - Xw[i]
            grad -= self._grad_per_sample(residual)
        grad /= n_samples

        # Apply step size 1/c
        c = max(self.quantile, 1 - self.quantile) / self.delta
        return grad / c


class SmoothQuantileRegressor(BaseEstimator, RegressorMixin):
    """Quantile regression with progressive smoothing."""

    def __init__(self, quantile=0.5, alpha=0.1, delta_init=1.0, delta_final=1e-3,
                 n_deltas=10, max_iter=1000, tol=1e-6, verbose=False,
                 fit_intercept=True):
        self.quantile = quantile
        self.alpha = alpha
        self.delta_init = delta_init
        self.delta_final = delta_final
        self.n_deltas = n_deltas
        self.max_iter = max_iter
        self.tol = tol
        self.verbose = verbose
        self.fit_intercept = fit_intercept

    def fit(self, X, y):
        """Fit using progressive smoothing: delta_init --> delta_final."""
        w = np.zeros(X.shape[1])
        deltas = np.geomspace(self.delta_init, self.delta_final, self.n_deltas)

        if self.verbose:
            print(
                f"Progressive smoothing: delta {self.delta_init:.2e} --> "
                f"{self.delta_final:.2e} in {self.n_deltas} steps")

        datafit = QuantileHuber(quantile=self.quantile, delta=self.delta_init)
        penalty = L1(alpha=self.alpha)

        # Use AndersonCD solver
        solver = AndersonCD(max_iter=self.max_iter, tol=self.tol,
                            warm_start=True, fit_intercept=self.fit_intercept,
                            verbose=max(0, self.verbose - 1))

        est = GeneralizedLinearEstimator(
            datafit=datafit, penalty=penalty, solver=solver)

        for i, delta in enumerate(deltas):
            datafit.delta = float(delta)

            est.fit(X, y)
            w = est.coef_.copy()

            if self.verbose:
                residuals = y - X @ w
                if self.fit_intercept:
                    residuals -= est.intercept_
                pinball_loss = np.mean(residuals * (self.quantile - (residuals < 0)))

                print(
                    f" Stage {i+1:2d}: delta={delta:.2e}, "
                    f"pinball_loss={pinball_loss:.6f}, "
                    f"n_iter={est.n_iter_}"
                )

        self.est_ = est
        self.coef_ = est.coef_
        if self.fit_intercept:
            self.intercept_ = est.intercept_

        return self

    def predict(self, X):
        """Predict using the fitted model."""
        if not hasattr(self, "est_"):
            raise NotFittedError(
                "This SmoothQuantileRegressor instance is not fitted yet. "
                "Call 'fit' with appropriate arguments before using this estimator."
            )
        return self.est_.predict(X)
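A quick sanity check on the datafit (editor's sketch, not part of the PR): the per-sample gradient should match a finite-difference approximation of the per-sample loss. The curvature bound max(quantile, 1 - quantile) / delta used in get_lipschitz then upper-bounds the second derivative on every branch, since the curvature is quantile / delta on the upper quadratic piece, (1 - quantile) / delta on the lower one, and zero in the linear tails.

import numpy as np
from skglm.experimental.quantile_huber import QuantileHuber

qh = QuantileHuber(quantile=0.3, delta=0.2)
eps = 1e-6
for r in [-1.0, -0.1, 0.05, 2.0]:
    # Central finite difference of the per-sample loss around r
    fd_grad = (qh._loss_sample(r + eps) - qh._loss_sample(r - eps)) / (2 * eps)
    assert abs(fd_grad - qh._grad_per_sample(r)) < 1e-4
print("finite-difference gradient check passed")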