Skip to content

ENH - Initialize datafit in Solver.solve method #295

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 9 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions doc/changes/0.5.rst
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,7 @@

Version 0.5 (in progress)
-------------------------

- jit-compile datafits and penalties inside ``Solver.solve`` method (:gh:`270`)
- Datafits are now initialized inside ``Solver.solve`` method (:gh:`295`)
- Add support for fitting an intercept in :ref:`SqrtLasso <skglm.experimental.sqrt_lasso.SqrtLasso>` (PR: :gh:`298`)
3 changes: 1 addition & 2 deletions examples/plot_survival_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,9 +69,8 @@

# skglm internals: init datafit and penalty
datafit = Cox()
penalty = L1(alpha)

datafit.initialize(X, y)
penalty = L1(alpha)

# init solver
solver = ProxNewton(fit_intercept=False, max_iter=50)
Expand Down
11 changes: 0 additions & 11 deletions skglm/estimators.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,11 +102,6 @@ def _glm_fit(X, y, model, datafit, penalty, solver):

n_samples, n_features = X_.shape

if issparse(X):
datafit.initialize_sparse(X_.data, X_.indptr, X_.indices, y)
else:
datafit.initialize(X_, y)

# if model.warm_start and hasattr(model, 'coef_') and model.coef_ is not None:
if solver.warm_start and hasattr(model, 'coef_') and model.coef_ is not None:
if isinstance(datafit, QuadraticSVC):
Expand Down Expand Up @@ -1374,12 +1369,6 @@ def fit(self, X, y):
fit_intercept=False,
)

# solve problem
if not issparse(X):
datafit.initialize(X, y)
else:
datafit.initialize_sparse(X.data, X.indptr, X.indices, y)

w, _, stop_crit = solver.solve(X, y, datafit, penalty)

# save to attribute
Expand Down
2 changes: 0 additions & 2 deletions skglm/solvers/anderson_cd.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,10 +80,8 @@ def _solve(self, X, y, datafit, penalty, w_init=None, Xw_init=None):

is_sparse = sparse.issparse(X)
if is_sparse:
datafit.initialize_sparse(X.data, X.indptr, X.indices, y)
lipschitz = datafit.get_lipschitz_sparse(X.data, X.indptr, X.indices, y)
else:
datafit.initialize(X, y)
lipschitz = datafit.get_lipschitz(X, y)

if len(w) != n_features + self.fit_intercept:
Expand Down
13 changes: 12 additions & 1 deletion skglm/solvers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from abc import abstractmethod, ABC

import numpy as np
from scipy.sparse import issparse

from skglm.utils.validation import check_attrs
from skglm.utils.jit_compilation import compiled_clone
Expand Down Expand Up @@ -40,6 +41,8 @@ class BaseSolver(ABC):
def _solve(self, X, y, datafit, penalty, w_init, Xw_init):
"""Solve an optimization problem.

This method assumes that datafit was already initialized.

Parameters
----------
X : array, shape (n_samples, n_features)
Expand Down Expand Up @@ -95,7 +98,8 @@ def custom_checks(self, X, y, datafit, penalty):
pass

def solve(
self, X, y, datafit, penalty, w_init=None, Xw_init=None, *, run_checks=True
self, X, y, datafit, penalty, w_init=None, Xw_init=None, *,
run_checks=True, initialize_datafit=True
):
"""Solve the optimization problem after validating its compatibility.

Expand Down Expand Up @@ -133,6 +137,13 @@ def solve(
if run_checks:
self._validate(X, y, datafit, penalty)

# check for None as `GramCD` solver take `None` as datafit
if datafit is not None and initialize_datafit:
if issparse(X):
datafit.initialize_sparse(X.data, X.indptr, X.indices, y)
else:
datafit.initialize(X, y)

return self._solve(X, y, datafit, penalty, w_init, Xw_init)

def _validate(self, X, y, datafit, penalty):
Expand Down
2 changes: 0 additions & 2 deletions skglm/solvers/fista.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,12 +51,10 @@ def _solve(self, X, y, datafit, penalty, w_init=None, Xw_init=None):
Xw = Xw_init.copy() if Xw_init is not None else np.zeros(n_samples)

if X_is_sparse:
datafit.initialize_sparse(X.data, X.indptr, X.indices, y)
lipschitz = datafit.get_global_lipschitz_sparse(
X.data, X.indptr, X.indices, y
)
else:
datafit.initialize(X, y)
lipschitz = datafit.get_global_lipschitz(X, y)

for n_iter in range(self.max_iter):
Expand Down
2 changes: 0 additions & 2 deletions skglm/solvers/group_bcd.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,10 +76,8 @@ def _solve(self, X, y, datafit, penalty, w_init=None, Xw_init=None):

is_sparse = issparse(X)
if is_sparse:
datafit.initialize_sparse(X.data, X.indptr, X.indices, y)
lipschitz = datafit.get_lipschitz_sparse(X.data, X.indptr, X.indices, y)
else:
datafit.initialize(X, y)
lipschitz = datafit.get_lipschitz(X, y)

all_groups = np.arange(n_groups)
Expand Down
7 changes: 0 additions & 7 deletions skglm/solvers/group_prox_newton.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,13 +69,6 @@ def _solve(self, X, y, datafit, penalty, w_init=None, Xw_init=None):
stop_crit = 0.
p_objs_out = []

# TODO: to be isolated in a seperated method
is_sparse = issparse(X)
if is_sparse:
datafit.initialize_sparse(X.data, X.indptr, X.indices, y)
else:
datafit.initialize(X, y)

for iter in range(self.max_iter):
grad = _construct_grad(X, y, w, Xw, datafit, all_groups)

Expand Down
7 changes: 0 additions & 7 deletions skglm/solvers/lbfgs.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,13 +38,6 @@ def __init__(self, max_iter=50, tol=1e-4, verbose=False):

def _solve(self, X, y, datafit, penalty, w_init=None, Xw_init=None):

# TODO: to be isolated in a seperated method
is_sparse = issparse(X)
if is_sparse:
datafit.initialize_sparse(X.data, X.indptr, X.indices, y)
else:
datafit.initialize(X, y)

def objective(w):
Xw = X @ w
datafit_value = datafit.value(y, w, Xw)
Expand Down
2 changes: 0 additions & 2 deletions skglm/solvers/multitask_bcd.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,10 +54,8 @@ def _solve(self, X, Y, datafit, penalty, W_init=None, XW_init=None):

is_sparse = sparse.issparse(X)
if is_sparse:
datafit.initialize_sparse(X.data, X.indptr, X.indices, Y)
lipschitz = datafit.get_lipschitz_sparse(X.data, X.indptr, X.indices, Y)
else:
datafit.initialize(X, Y)
lipschitz = datafit.get_lipschitz(X, Y)

for t in range(self.max_iter):
Expand Down
6 changes: 0 additions & 6 deletions skglm/solvers/prox_newton.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,12 +85,6 @@ def _solve(self, X, y, datafit, penalty, w_init=None, Xw_init=None):
if is_sparse:
X_bundles = (X.data, X.indptr, X.indices)

# TODO: to be isolated in a seperated method
if is_sparse:
datafit.initialize_sparse(X.data, X.indptr, X.indices, y)
else:
datafit.initialize(X, y)

if self.ws_strategy == "fixpoint":
X_square = X.multiply(X) if is_sparse else X ** 2

Expand Down
4 changes: 2 additions & 2 deletions skglm/tests/test_estimators.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,8 @@ def test_CoxEstimator(use_efron, use_float_32):
datafit = Cox(use_efron)
penalty = L1(alpha)

# XXX: intialize is needed here although it is done in ProxNewton
# it is used to evaluate the objective
datafit.initialize(X, y)

w, *_ = ProxNewton(
Expand Down Expand Up @@ -256,8 +258,6 @@ def test_CoxEstimator_sparse(use_efron, use_float_32):
datafit = Cox(use_efron)
penalty = L1(alpha)

datafit.initialize_sparse(X.data, X.indptr, X.indices, y)

*_, stop_crit = ProxNewton(
fit_intercept=False, tol=1e-6, max_iter=50
).solve(
Expand Down
2 changes: 1 addition & 1 deletion skglm/tests/test_lbfgs_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def test_L2_Cox(use_efron):
penalty = L2(alpha)

# XXX: intialize is needed here although it is done in LBFGS
# is used to evaluate the objective
# it is used to evaluate the objective
datafit.initialize(X, y)
w, *_ = LBFGS().solve(X, y, datafit, penalty)

Expand Down