Skip to content

Commit 8120102

Browse files
committed
Refactor DynamicDML to remove incompatible method signatures
1 parent e67bff7 commit 8120102

File tree

1 file changed

+7
-10
lines changed

1 file changed

+7
-10
lines changed

econml/panel/dml/_dml.py

Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -547,32 +547,29 @@ def _gen_model_t(self):
547547
def _gen_model_final(self):
548548
return StatsModelsLinearRegression(fit_intercept=False)
549549

550-
def _gen_ortho_learner_model_nuisance(self, n_periods):
550+
def _gen_ortho_learner_model_nuisance(self):
551551
return _DynamicModelNuisance(
552552
model_t=self._gen_model_t(),
553553
model_y=self._gen_model_y(),
554-
n_periods=n_periods)
554+
n_periods=self._n_periods)
555555

556-
def _gen_ortho_learner_model_final(self, n_periods):
556+
def _gen_ortho_learner_model_final(self):
557557
wrapped_final_model = _DynamicFinalWrapper(
558558
StatsModelsLinearRegression(fit_intercept=False),
559559
fit_cate_intercept=self.fit_cate_intercept,
560560
featurizer=self.featurizer,
561561
use_weight_trick=False)
562-
return _LinearDynamicModelFinal(wrapped_final_model, n_periods=n_periods)
562+
return _LinearDynamicModelFinal(wrapped_final_model, n_periods=self._n_periods)
563563

564564
def _prefit(self, Y, T, *args, groups=None, only_final=False, **kwargs):
565+
# we need to set the number of periods before calling super()._prefit, since that will generate the
566+
# final and nuisance models, which need to have self._n_periods set
565567
u_periods = np.unique(np.unique(groups, return_counts=True)[1])
566568
if len(u_periods) > 1:
567569
raise AttributeError(
568570
"Imbalanced panel. Method currently expects only panels with equal number of periods. Pad your data")
569571
self._n_periods = u_periods[0]
570-
# generate an instance of the final model
571-
self._ortho_learner_model_final = self._gen_ortho_learner_model_final(self._n_periods)
572-
if not only_final:
573-
# generate an instance of the nuisance model
574-
self._ortho_learner_model_nuisance = self._gen_ortho_learner_model_nuisance(self._n_periods)
575-
TreatmentExpansionMixin._prefit(self, Y, T, *args, **kwargs)
572+
super()._prefit(self, Y, T, *args, **kwargs)
576573

577574
def _postfit(self, Y, T, *args, **kwargs):
578575
super()._postfit(Y, T, *args, **kwargs)

0 commit comments

Comments
 (0)