FEAT - Implement SmoothQuantileRegression #312
@@ -0,0 +1,132 @@
""" | ||
================================================================================ | ||
Smooth Quantile Regression with QuantileHuber | ||
================================================================================ | ||
|
||
This example compares sklearn's standard quantile regression with skglm's smooth | ||
approximation. Skglm's quantile regression uses a smooth Huber-like approximation | ||
(quadratic near zero, linear in the tails) to replace the non-differentiable | ||
pinball loss. Progressive smoothing enables efficient gradient-based optimization, | ||
maintaining speed and accuracy also on large-scale, high-dimensional datasets. | ||
""" | ||
# Author: Florian Kozikowski
import time

import numpy as np
import matplotlib.pyplot as plt

from sklearn.datasets import make_regression
from sklearn.linear_model import QuantileRegressor
from skglm.experimental.quantile_huber import QuantileHuber, SmoothQuantileRegressor

# Generate regression data
X, y = make_regression(n_samples=1000, n_features=10, noise=0.1, random_state=0)
tau = 0.8  # 80th percentile

# %%
# Compare standard vs smooth quantile regression
# ----------------------------------------------
# Both methods solve the same problem but with different loss functions.

# Standard quantile regression (sklearn)
start = time.time()
sk_model = QuantileRegressor(quantile=tau, alpha=0.1)
sk_model.fit(X, y)
sk_time = time.time() - start

# Smooth quantile regression (skglm)
start = time.time()
smooth_model = SmoothQuantileRegressor(
    quantile=tau,
    alpha=0.1,
    delta_init=0.5,    # initial smoothing parameter
    delta_final=0.01,  # final smoothing (smaller = closer to true quantile)
    n_deltas=5,        # number of continuation steps
)
smooth_model.fit(X, y)
smooth_time = time.time() - start

# %%
# Evaluate both methods
# ---------------------
# Coverage: fraction of true values below predictions (should ≈ tau)
# Pinball loss: standard quantile regression evaluation metric
#
# Note: no rigorous benchmarking has been done yet. The speed advantage likely
# only shows on large-scale, high-dimensional datasets; the sklearn
# implementation is likely faster on small datasets.


def pinball_loss(residuals, quantile):
    return np.mean(residuals * (quantile - (residuals < 0)))


sk_pred = sk_model.predict(X)
smooth_pred = smooth_model.predict(X)

print(f"{'Method':<15} {'Coverage':<10} {'Time (s)':<10} {'Pinball Loss':<12}")
print("-" * 50)
print(f"{'Sklearn':<15} {np.mean(y <= sk_pred):<10.3f} {sk_time:<10.3f} "
      f"{pinball_loss(y - sk_pred, tau):<12.4f}")
print(f"{'SmoothQuantile':<15} {np.mean(y <= smooth_pred):<10.3f} {smooth_time:<10.3f} "
      f"{pinball_loss(y - smooth_pred, tau):<12.4f}")

# %%
# Visualize the smooth approximation
# ----------------------------------
# The smooth loss approximates the pinball loss but with continuous gradients.

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 4))

# Show loss and gradient for different quantile levels
residuals = np.linspace(-3, 3, 500)
delta = 0.5
quantiles = [0.1, 0.5, 0.9]

for tau_val in quantiles:
    qh = QuantileHuber(quantile=tau_val, delta=delta)
    loss = [qh._loss_sample(r) for r in residuals]
    grad = [qh._grad_per_sample(r) for r in residuals]

    # Pinball loss for each residual (named to avoid shadowing the
    # pinball_loss function defined above)
    pinball_vals = [r * (tau_val - (r < 0)) for r in residuals]

    # Plot smooth loss and pinball loss
    ax1.plot(residuals, loss, label=f"τ={tau_val}", linewidth=2)
    ax1.plot(residuals, pinball_vals, '--', alpha=0.4, color='gray',
             label=f"Pinball τ={tau_val}")
    ax2.plot(residuals, grad, label=f"τ={tau_val}", linewidth=2)

# Add vertical lines and shading showing delta boundaries
for ax in [ax1, ax2]:
    ax.axvline(-delta, color='gray', linestyle='--', alpha=0.7, linewidth=1.5)
    ax.axvline(delta, color='gray', linestyle='--', alpha=0.7, linewidth=1.5)
    # Add shading for quadratic region
    ax.axvspan(-delta, delta, alpha=0.15, color='gray')

# Add delta labels
ax1.text(-delta, 0.1, '−δ', ha='right', va='bottom', color='gray', fontsize=10)
ax1.text(delta, 0.1, '+δ', ha='left', va='bottom', color='gray', fontsize=10)

ax1.set_title(f"Smooth Quantile Loss (δ={delta})", fontsize=12)
ax1.set_xlabel("Residual")
ax1.set_ylabel("Loss")
ax1.legend(loc='upper left')
ax1.grid(True, alpha=0.3)

ax2.set_title("Gradient (continuous everywhere)", fontsize=12)
ax2.set_xlabel("Residual")
ax2.set_ylabel("Gradient")
ax2.legend(loc='upper left')
ax2.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

# %% [markdown]
# The left plot shows the asymmetric loss: tau=0.1 penalizes overestimation more,
# while tau=0.9 penalizes underestimation. As delta decreases towards zero, the
# loss approaches the standard pinball loss.
# The right plot reveals the key advantage: gradients transition smoothly through
# zero, unlike the standard pinball loss, which has a kink at the origin. This
# smoothing enables fast convergence with gradient-based solvers.
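
As a quick numerical check of the claim that the smooth loss approaches the pinball loss as delta shrinks, the sketch below compares the two directly (this is not part of the diff above; tau, the delta values, and the residual grid are chosen arbitrarily):

import numpy as np
from skglm.experimental.quantile_huber import QuantileHuber

tau = 0.8
r = np.linspace(-3, 3, 7)
pinball = r * (tau - (r < 0))

for delta in [0.5, 0.1, 0.01]:
    qh = QuantileHuber(quantile=tau, delta=delta)
    smooth = np.array([qh._loss_sample(ri) for ri in r])
    # The gap to the pinball loss is at most max(tau, 1 - tau) * delta / 2,
    # so it shrinks linearly as delta goes to 0.
    print(f"delta={delta}: max gap = {np.max(np.abs(smooth - pinball)):.4f}")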
@@ -2,11 +2,14 @@
from .sqrt_lasso import SqrtLasso, SqrtQuadratic
from .pdcd_ws import PDCD_WS
from .quantile_regression import Pinball
from .quantile_huber import QuantileHuber, SmoothQuantileRegressor

__all__ = [
    IterativeReweightedL1,
    PDCD_WS,
    Pinball,
    QuantileHuber,
    SmoothQuantileRegressor,
    SqrtQuadratic,
    SqrtLasso,
]

Review comment (on the new __all__ entries): use alphabetical order
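
With these exports in place, the new classes should be importable directly from skglm.experimental. A minimal sketch (data and parameter values are illustrative):

from sklearn.datasets import make_regression
from skglm.experimental import QuantileHuber, SmoothQuantileRegressor

# Both names are now exposed at the package level.
datafit = QuantileHuber(quantile=0.8, delta=0.1)
print(datafit.params_to_dict())

X, y = make_regression(n_samples=100, n_features=5, noise=0.1, random_state=0)
reg = SmoothQuantileRegressor(quantile=0.8, alpha=0.1).fit(X, y)
print(reg.predict(X)[:5])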
@@ -0,0 +1,205 @@
import numpy as np
from numpy.linalg import norm
from numba import float64

from sklearn.base import BaseEstimator, RegressorMixin
from sklearn.exceptions import NotFittedError

from skglm.datafits.base import BaseDatafit
from skglm.solvers import AndersonCD

Review comment: put all skglm imports together in their own block

from skglm.penalties import L1
from skglm.estimators import GeneralizedLinearEstimator

class QuantileHuber(BaseDatafit):
    r"""Quantile Huber loss for quantile regression.

    Implements the smoothed pinball loss:

    .. math::

        \rho_\tau^\delta(r) =
        \begin{cases}
            \tau \left(r - \dfrac{\delta}{2}\right), & \text{if } r \ge \delta,\\
            \dfrac{\tau r^{2}}{2\delta}, & \text{if } 0 \le r < \delta,\\
            \dfrac{(1-\tau) r^{2}}{2\delta}, & \text{if } -\delta < r < 0,\\
            (1-\tau)\left(-r - \dfrac{\delta}{2}\right), & \text{if } r \le -\delta.
        \end{cases}

    Parameters
    ----------
    quantile : float, default=0.5
        Desired quantile level, strictly between 0 and 1.

    delta : float, default=1.0
        Smoothing parameter; must be positive. Smaller values give a closer
        approximation to the pinball loss.
    """

    def __init__(self, quantile=0.5, delta=1.0):
        if not 0 < quantile < 1:
            raise ValueError("quantile must be between 0 and 1")
        if delta <= 0:
            raise ValueError("delta must be positive")
        self.delta = float(delta)
        self.quantile = float(quantile)

    def get_spec(self):
        return (('delta', float64), ('quantile', float64))

    def params_to_dict(self):
        return dict(delta=self.delta, quantile=self.quantile)

    def value(self, y, w, Xw):
        """Compute the quantile Huber loss value."""
        n_samples = len(y)
        res = 0.0
        for i in range(n_samples):
            residual = y[i] - Xw[i]
            res += self._loss_sample(residual)
        return res / n_samples

    def _loss_sample(self, residual):
        """Calculate loss for a single sample."""
        tau = self.quantile
        delta = self.delta
        r = residual

        if r >= delta:
            # Upper linear tail: r >= delta
            return tau * (r - delta / 2)
        elif r >= 0:
            # Upper quadratic: 0 <= r < delta
            return tau * r**2 / (2 * delta)
        elif r > -delta:
            # Lower quadratic: -delta < r < 0
            return (1 - tau) * r**2 / (2 * delta)
        else:
            # Lower linear tail: r <= -delta
            return (1 - tau) * (-r - delta / 2)

    def gradient_scalar(self, X, y, w, Xw, j):
        """Compute the gradient w.r.t. w_j, following the parent class pattern."""
        n_samples = len(y)
        grad_j = 0.0
        for i in range(n_samples):
            residual = y[i] - Xw[i]
            grad_j += -X[i, j] * self._grad_per_sample(residual)
        return grad_j / n_samples

    def _grad_per_sample(self, residual):
        """Calculate the gradient of the loss for a single sample."""
        tau = self.quantile
        delta = self.delta
        r = residual

        if r >= delta:
            # Upper linear tail: r >= delta
            return tau
        elif r >= 0:
            # Upper quadratic: 0 <= r < delta
            return tau * r / delta
        elif r > -delta:
            # Lower quadratic: -delta < r < 0
            return (1 - tau) * r / delta
        else:
            # Lower linear tail: r <= -delta
            return tau - 1

    def get_lipschitz(self, X, y):
        n_features = X.shape[1]

        lipschitz = np.zeros(n_features, dtype=X.dtype)
        # The curvature of the quadratic region bounds the second derivative
        # of the loss by max(tau, 1 - tau) / delta
        c = max(self.quantile, 1 - self.quantile) / self.delta
        for j in range(n_features):
            lipschitz[j] = c * (X[:, j] ** 2).sum() / len(y)

        return lipschitz

    def get_global_lipschitz(self, X, y):
        c = max(self.quantile, 1 - self.quantile) / self.delta
        return c * norm(X, ord=2) ** 2 / len(y)

    def intercept_update_step(self, y, Xw):
        n_samples = len(y)

        # Compute gradient
        grad = 0.0
        for i in range(n_samples):
            residual = y[i] - Xw[i]
            grad -= self._grad_per_sample(residual)
        grad /= n_samples

        # Apply step size 1/c
        c = max(self.quantile, 1 - self.quantile) / self.delta
        return grad / c


class SmoothQuantileRegressor(BaseEstimator, RegressorMixin):
    """Quantile regression with progressive smoothing."""

    def __init__(self, quantile=0.5, alpha=0.1, delta_init=1.0, delta_final=1e-3,
                 n_deltas=10, max_iter=1000, tol=1e-6, verbose=False,
                 fit_intercept=True):
        self.quantile = quantile
        self.alpha = alpha
        self.delta_init = delta_init
        self.delta_final = delta_final
        self.n_deltas = n_deltas
        self.max_iter = max_iter
        self.tol = tol
        self.verbose = verbose
        self.fit_intercept = fit_intercept

    def fit(self, X, y):
        """Fit using progressive smoothing: delta_init --> delta_final."""
        w = np.zeros(X.shape[1])
        deltas = np.geomspace(self.delta_init, self.delta_final, self.n_deltas)

        if self.verbose:
            print(
                f"Progressive smoothing: delta {self.delta_init:.2e} --> "
                f"{self.delta_final:.2e} in {self.n_deltas} steps")

        datafit = QuantileHuber(quantile=self.quantile, delta=self.delta_init)
        penalty = L1(alpha=self.alpha)

        # Use AndersonCD solver
        solver = AndersonCD(max_iter=self.max_iter, tol=self.tol,
                            warm_start=True, fit_intercept=self.fit_intercept,
                            verbose=max(0, self.verbose - 1))

        est = GeneralizedLinearEstimator(
            datafit=datafit, penalty=penalty, solver=solver)

        for i, delta in enumerate(deltas):
            datafit.delta = float(delta)

            est.fit(X, y)
            w = est.coef_.copy()

            if self.verbose:
                residuals = y - X @ w
                if self.fit_intercept:
                    residuals -= est.intercept_
                pinball_loss = np.mean(residuals * (self.quantile - (residuals < 0)))

                print(
                    f"  Stage {i + 1:2d}: delta={delta:.2e}, "
                    f"pinball_loss={pinball_loss:.6f}, "
                    f"n_iter={est.n_iter_}"
                )

        self.est_ = est
        self.coef_ = est.coef_
        if self.fit_intercept:
            self.intercept_ = est.intercept_

        return self

    def predict(self, X):
        """Predict using the fitted model."""
        if not hasattr(self, "est_"):
            raise NotFittedError(
                "This SmoothQuantileRegressor instance is not fitted yet. "
                "Call 'fit' with appropriate arguments before using this estimator."
            )
        return self.est_.predict(X)
Review comment: look at some other plotting file format
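
For readers who want to use the QuantileHuber datafit directly, without the progressive-smoothing wrapper, a minimal sketch mirroring what SmoothQuantileRegressor.fit does internally for a single, fixed delta (parameter values here are illustrative):

from sklearn.datasets import make_regression

from skglm.estimators import GeneralizedLinearEstimator
from skglm.experimental.quantile_huber import QuantileHuber
from skglm.penalties import L1
from skglm.solvers import AndersonCD

X, y = make_regression(n_samples=200, n_features=20, noise=0.1, random_state=0)

# A single fixed smoothing level; SmoothQuantileRegressor instead anneals
# delta from delta_init down to delta_final over n_deltas stages.
datafit = QuantileHuber(quantile=0.8, delta=0.1)
penalty = L1(alpha=0.1)
solver = AndersonCD(max_iter=1000, tol=1e-6, fit_intercept=True)

est = GeneralizedLinearEstimator(datafit=datafit, penalty=penalty, solver=solver)
est.fit(X, y)
print(est.coef_[:5], est.intercept_)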