Skip to content

Fix needs_fit logic for model selection with a fixed model #855

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Feb 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 1 addition & 11 deletions econml/_ortho_learner.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,9 +180,6 @@ def train(self, is_selecting, folds, X, y, W=None):
return self
def predict(self, X, y, W=None):
return self._model.predict(X)
@property
def needs_fit(self):
return False
np.random.seed(123)
X = np.random.normal(size=(5000, 3))
y = X[:, 0] + np.random.normal(size=(5000,))
Expand Down Expand Up @@ -245,8 +242,7 @@ def needs_fit(self):
# when there is more than one model, nuisances from previous models
# come first as positional arguments
accumulated_args = accumulated_nuisances + args
if model.needs_fit:
model.train(True, fold_vals if folds is None else folds, *accumulated_args, **kwargs)
model.train(True, fold_vals if folds is None else folds, *accumulated_args, **kwargs)

calculate_scores &= hasattr(model, 'score')

Expand Down Expand Up @@ -452,9 +448,6 @@ def train(self, is_selecting, folds, Y, T, W=None):
return self
def predict(self, Y, T, W=None):
return Y - self._model_y.predict(W), T - self._model_t.predict(W)
@property
def needs_fit(self):
return False
class ModelFinal:
def __init__(self):
return
Expand Down Expand Up @@ -509,9 +502,6 @@ def train(self, is_selecting, folds, Y, T, W=None):
return self
def predict(self, Y, T, W=None):
return Y - self._model_y.predict(W), T - self._model_t.predict_proba(W)[:, 1:]
@property
def needs_fit(self):
return False
class ModelFinal:
def __init__(self):
return
Expand Down
7 changes: 0 additions & 7 deletions econml/dml/_rlearner.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,10 +72,6 @@ def predict(self, Y, T, X=None, W=None, Z=None, sample_weight=None, groups=None)
T_res = T - T_pred.reshape(T.shape)
return Y_res, T_res

@property
def needs_fit(self):
return self._model_y.needs_fit or self._model_t.needs_fit


class _ModelFinal:
"""
Expand Down Expand Up @@ -234,9 +230,6 @@ def best_model(self):
@property
def best_score(self):
return 0
@property
def needs_fit(self):
return False
class ModelFinal:
def fit(self, X, T, T_res, Y_res, sample_weight=None, freq_weight=None, sample_var=None):
self.model = LinearRegression(fit_intercept=False).fit(X * T_res.reshape(-1, 1),
Expand Down
4 changes: 0 additions & 4 deletions econml/dml/dml.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,10 +104,6 @@ def best_model(self):
def best_score(self):
return self._model.best_score

@property
def needs_fit(self):
return self._model.needs_fit


def _make_first_stage_selector(model, is_discrete, random_state):
if model == 'auto':
Expand Down
4 changes: 0 additions & 4 deletions econml/dr/_drlearner.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,10 +126,6 @@ def predict(self, Y, T, X=None, W=None, *, sample_weight=None, groups=None):
propensities_weight = np.sum(propensities * T_complete, axis=1)
return Y_pred.reshape(Y.shape + (T.shape[1] + 1,)), propensities_weight.reshape((n,))

@property
def needs_fit(self):
return self._model_propensity.needs_fit or self._model_regression.needs_fit


def _make_first_stage_selector(model, is_discrete, random_state):
if model == "auto":
Expand Down
10 changes: 0 additions & 10 deletions econml/iv/dml/_dml.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,12 +119,6 @@ def predict(self, Y, T, X=None, W=None, Z=None, sample_weight=None, groups=None)
Z_res = Z - Z_pred.reshape(Z.shape)
return Y_res, T_res, Z_res

@property
def needs_fit(self):
return (self._model_y_xw.needs_fit or self._model_t_xw.needs_fit or
(self._projection and self._model_t_xwz.needs_fit) or
(not self._projection and self._model_z_xw.needs_fit))


class _OrthoIVModelFinal:
def __init__(self, model_final, featurizer, fit_cate_intercept):
Expand Down Expand Up @@ -773,10 +767,6 @@ def predict(self, Y, T, X=None, W=None, Z=None, sample_weight=None, groups=None)
T_res = TXZ_pred.reshape(T.shape) - TX_pred.reshape(T.shape)
return Y_res, T_res

@property
def needs_fit(self):
return self._model_y_xw.needs_fit or self._model_t_xw.needs_fit or self._model_t_xwz.needs_fit


class _BaseDMLIVModelFinal(_ModelFinal):
"""
Expand Down
14 changes: 0 additions & 14 deletions econml/iv/dr/_dr.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,12 +163,6 @@ def predict(self, Y, T, X=None, W=None, Z=None, sample_weight=None, groups=None)

return prel_theta, Y_res, T_res, Z_res

@property
def needs_fit(self):
return (self._model_y_xw.needs_fit or self._model_t_xw.needs_fit or
(self._projection and self._model_t_xwz.needs_fit) or
(not self._projection and self._model_z_xw.needs_fit))


class _BaseDRIVNuisanceCovarianceSelector(ModelSelector):
def __init__(self, *, model_tz_xw,
Expand Down Expand Up @@ -275,10 +269,6 @@ def predict(self, prel_theta, Y_res, T_res, Z_res, Y, T, X=None, W=None, Z=None,

return (cov,)

@property
def needs_fit(self):
return self._model_tz_xw.needs_fit


class _BaseDRIVModelFinal:
def __init__(self, model_final, featurizer, fit_cate_intercept, cov_clip, opt_reweighted):
Expand Down Expand Up @@ -2464,10 +2454,6 @@ def predict(self, Y, T, X=None, W=None, Z=None, sample_weight=None, groups=None)

return prel_theta, Y_res, T_res, Z_res, beta

@property
def needs_fit(self):
return self._model_y_xw.needs_fit or self._model_t_xwz.needs_fit or self._dummy_z.needs_fit


class _DummyClassifier:
"""
Expand Down
4 changes: 0 additions & 4 deletions econml/panel/dml/_dml.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,10 +171,6 @@ def _get_shape_formatter(self, X, W):
def _index_or_None(self, X, filter_idx):
return None if X is None else X[filter_idx]

@property
def needs_fit(self):
return self._model_t.needs_fit or self._model_y.needs_fit


class _DynamicModelFinal:
"""
Expand Down
68 changes: 25 additions & 43 deletions econml/sklearn_extensions/model_selection.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,15 +305,6 @@ def score(self, *args, **kwargs):
"""
raise NotImplementedError("Abstract method")

@property
@abc.abstractmethod
def needs_fit(self):
"""
Whether the model selector needs to be fit before it can be used for prediction or scoring;
in many cases this is equivalent to whether the selector is choosing between multiple models
"""
raise NotImplementedError("Abstract method")


class SingleModelSelector(ModelSelector):
"""
Expand Down Expand Up @@ -392,24 +383,25 @@ class FixedModelSelector(SingleModelSelector):
Model selection class that always selects the given sklearn-compatible model
"""

def __init__(self, model):
def __init__(self, model, score_during_selection):
self.model = clone(model, safe=False)
self.score_during_selection = score_during_selection

def train(self, is_selecting, folds: Optional[List], X, y, groups=None, **kwargs):
if is_selecting:
# since needs_fit is False, is_selecting will only be true if
# the score needs to be compared to another model's
# so we don't need to fit the model itself, just get the out-of-sample score
assert hasattr(self.model, 'score'), (f"Can't select between a fixed {type(self.model)} model and others "
"because it doesn't have a score method")
scores = []
for train, test in folds:
# use _fit_with_groups instead of just fit to handle nested grouping
_fit_with_groups(self.model, X[train], y[train],
groups=None if groups is None else groups[train],
**{key: val[train] for key, val in kwargs.items()})
scores.append(self.model.score(X[test], y[test]))
self._score = np.mean(scores)
if self.score_during_selection:
# the score needs to be compared to another model's
# so we don't need to fit the model itself on all of the data, just get the out-of-sample score
assert hasattr(self.model, 'score'), (f"Can't select between a fixed {type(self.model)} model "
"and others because it doesn't have a score method")
scores = []
for train, test in folds:
# use _fit_with_groups instead of just fit to handle nested grouping
_fit_with_groups(self.model, X[train], y[train],
groups=None if groups is None else groups[train],
**{key: val[train] for key, val in kwargs.items()})
scores.append(self.model.score(X[test], y[test]))
self._score = np.mean(scores)
else:
# we need to train the model on the data
_fit_with_groups(self.model, X, y, groups=groups, **kwargs)
Expand All @@ -422,11 +414,10 @@ def best_model(self):

@property
def best_score(self):
return self._score

@property
def needs_fit(self):
return False # We have only a single model so we can skip the selection process
if hasattr(self, '_score'):
return self._score
else:
raise ValueError("No score was computed during selection")


def _copy_to(m1, m2, attrs, insert_underscore=False):
Expand Down Expand Up @@ -579,11 +570,6 @@ def best_model(self):
def best_score(self):
return self._best_score

@property
def needs_fit(self):
return True # strictly speaking, could be false if the hyperparameters are fixed
# but it would be complicated to check that


class ListSelector(SingleModelSelector):
"""
Expand Down Expand Up @@ -627,14 +613,8 @@ def best_model(self):
def best_score(self):
return self._best_score

@property
def needs_fit(self):
# technically, if there is just one model and it doesn't need to be fit we don't need to fit it,
# but that complicates the training logic so we don't bother with that case
return True


def get_selector(input, is_discrete, *, random_state=None, cv=None, wrapper=GridSearchCV):
def get_selector(input, is_discrete, *, random_state=None, cv=None, wrapper=GridSearchCV, needs_scoring=False):
named_models = {
'linear': (LogisticRegressionCV(random_state=random_state, cv=cv) if is_discrete
else WeightedLassoCVWrapper(random_state=random_state, cv=cv)),
Expand All @@ -657,19 +637,21 @@ def get_selector(input, is_discrete, *, random_state=None, cv=None, wrapper=Grid
return input
elif isinstance(input, list): # we've got a list; call get_selector on each element, then wrap in a ListSelector
models = [get_selector(model, is_discrete,
random_state=random_state, cv=cv, wrapper=wrapper)
random_state=random_state, cv=cv, wrapper=wrapper,
needs_scoring=True) # we need to score to compare outputs to each other
for model in input]
return ListSelector(models)
elif isinstance(input, str): # we've got a string; look it up
if input in named_models:
return get_selector(named_models[input], is_discrete,
random_state=random_state, cv=cv, wrapper=wrapper)
random_state=random_state, cv=cv, wrapper=wrapper,
needs_scoring=needs_scoring)
else:
raise ValueError(f"Unknown model type: {input}, must be one of {named_models.keys()}")
elif SklearnCVSelector.can_wrap(input):
return SklearnCVSelector(input)
else: # assume this is an sklearn-compatible model
return FixedModelSelector(input)
return FixedModelSelector(input, needs_scoring)


class GridSearchCVList(BaseEstimator):
Expand Down
4 changes: 0 additions & 4 deletions econml/tests/test_missing_values.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,6 @@ def train(self, is_selecting, folds, Y, T, W=None):
def predict(self, Y, T, W=None):
return Y - self._model_y.predict(W), T - self._model_t.predict(W)

@property
def needs_fit(self):
return False


class ModelFinal:

Expand Down
16 changes: 16 additions & 0 deletions econml/tests/test_model_selection.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
from sklearn.preprocessing import PolynomialFeatures
from econml.dml import LinearDML
from econml.sklearn_extensions.linear_model import WeightedLassoCVWrapper
from econml.utilities import SeparateModel
from econml.dr import LinearDRLearner


class TestModelSelection(unittest.TestCase):
Expand Down Expand Up @@ -133,3 +135,17 @@ def test_sklearn_model_selection(self):
discrete_treatment=is_discrete,
model_y=LinearRegression())
est.fit(Y, T2 if use_array else T, X=X, W=W)

def test_fixed_model_scoring(self):
# Purpose: a fixed first-stage model that cannot be scored must still be usable on its own,
# but must be rejected as soon as its score has to be compared against another candidate model.
# NOTE(review): presumably a list for model_regression is what triggers the score comparison
# (list inputs are wrapped with needs_scoring=True in get_selector) - confirm the estimator routes through it.
Y, T, X, W = self._simple_dgp(500, 2, 3, True)

# SeparateModel doesn't support scoring; that should be fine when not compared to other models
mdl = LinearDRLearner(model_regression=SeparateModel(LassoCV(), LassoCV()),
model_propensity=LogisticRegressionCV())
mdl.fit(Y, T, X=X, W=W)

# on the other hand, when we need to compare the score to other models, it should raise an error
# (the construction and the fit are both inside the assertRaises block, so either may raise)
with self.assertRaises(Exception):
mdl = LinearDRLearner(model_regression=[SeparateModel(LassoCV(), LassoCV()), Lasso()],
model_propensity=LogisticRegressionCV())
mdl.fit(Y, T, X=X, W=W)
28 changes: 0 additions & 28 deletions econml/tests/test_ortho_learner.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,10 +37,6 @@ def predict(self, X, y, Q, W=None):
def score(self, X, y, Q, W=None):
return self._model.score(X, y)

@property
def needs_fit(self):
return False

np.random.seed(123)
X = np.random.normal(size=(5000, 3))
y = X[:, 0] + np.random.normal(size=(5000,))
Expand Down Expand Up @@ -118,10 +114,6 @@ def train(self, is_selecting, folds, X, y, W=None):
def predict(self, X, y, W=None):
return self._model.predict(X), y - self._model.predict(X), X

@property
def needs_fit(self):
return False

np.random.seed(123)
X = np.random.normal(size=(5000, 3))
y = X[:, 0] + np.random.normal(size=(5000,))
Expand Down Expand Up @@ -195,10 +187,6 @@ def predict(self, X, y, Q, W=None):
def score(self, X, y, Q, W=None):
return self._model.score(X, y)

@property
def needs_fit(self):
return False

# Generate synthetic data
X, y = make_regression(n_samples=10, n_features=5, noise=0.1, random_state=42)
folds = list(KFold(2).split(X, y))
Expand Down Expand Up @@ -237,10 +225,6 @@ def train(self, is_selecting, folds, Y, T, W=None):
def predict(self, Y, T, W=None):
return Y - self._model_y.predict(W), T - self._model_t.predict(W)

@property
def needs_fit(self):
return False

class ModelFinal:

def __init__(self):
Expand Down Expand Up @@ -353,10 +337,6 @@ def train(self, is_selecting, folds, Y, T, W=None):
def predict(self, Y, T, W=None):
return Y - self._model_y.predict(W), T - self._model_t.predict(W)

@property
def needs_fit(self):
return False

class ModelFinal:

def __init__(self):
Expand Down Expand Up @@ -408,10 +388,6 @@ def predict(self, Y, T, W=None):
def score(self, Y, T, W=None):
return (self._model_t.score(W, Y), self._model_y.score(W, T))

@property
def needs_fit(self):
return False

class ModelFinal:

def __init__(self):
Expand Down Expand Up @@ -466,10 +442,6 @@ def train(self, is_selecting, folds, Y, T, W=None):
def predict(self, Y, T, W=None):
return Y - self._model_y.predict(W), T - self._model_t.predict_proba(W)[:, 1:]

@property
def needs_fit(self):
return False

class ModelFinal:

def __init__(self):
Expand Down