@@ -547,32 +547,29 @@ def _gen_model_t(self):
547
547
def _gen_model_final (self ):
548
548
return StatsModelsLinearRegression (fit_intercept = False )
549
549
550
- def _gen_ortho_learner_model_nuisance (self , n_periods ):
550
+ def _gen_ortho_learner_model_nuisance (self ):
551
551
return _DynamicModelNuisance (
552
552
model_t = self ._gen_model_t (),
553
553
model_y = self ._gen_model_y (),
554
- n_periods = n_periods )
554
+ n_periods = self . _n_periods )
555
555
556
- def _gen_ortho_learner_model_final (self , n_periods ):
556
+ def _gen_ortho_learner_model_final (self ):
557
557
wrapped_final_model = _DynamicFinalWrapper (
558
558
StatsModelsLinearRegression (fit_intercept = False ),
559
559
fit_cate_intercept = self .fit_cate_intercept ,
560
560
featurizer = self .featurizer ,
561
561
use_weight_trick = False )
562
- return _LinearDynamicModelFinal (wrapped_final_model , n_periods = n_periods )
562
+ return _LinearDynamicModelFinal (wrapped_final_model , n_periods = self . _n_periods )
563
563
564
564
def _prefit (self , Y , T , * args , groups = None , only_final = False , ** kwargs ):
565
+ # we need to set the number of periods before calling super()._prefit, since that will generate the
566
+ # final and nuisance models, which need to have self._n_periods set
565
567
u_periods = np .unique (np .unique (groups , return_counts = True )[1 ])
566
568
if len (u_periods ) > 1 :
567
569
raise AttributeError (
568
570
"Imbalanced panel. Method currently expects only panels with equal number of periods. Pad your data" )
569
571
self ._n_periods = u_periods [0 ]
570
- # generate an instance of the final model
571
- self ._ortho_learner_model_final = self ._gen_ortho_learner_model_final (self ._n_periods )
572
- if not only_final :
573
- # generate an instance of the nuisance model
574
- self ._ortho_learner_model_nuisance = self ._gen_ortho_learner_model_nuisance (self ._n_periods )
575
- TreatmentExpansionMixin ._prefit (self , Y , T , * args , ** kwargs )
572
+ super ()._prefit (self , Y , T , * args , ** kwargs )
576
573
577
574
def _postfit (self , Y , T , * args , ** kwargs ):
578
575
super ()._postfit (Y , T , * args , ** kwargs )
0 commit comments