diff --git a/src/lightning/pytorch/CHANGELOG.md b/src/lightning/pytorch/CHANGELOG.md index 62f43d57f5d1f..eae21264ad5c9 100644 --- a/src/lightning/pytorch/CHANGELOG.md +++ b/src/lightning/pytorch/CHANGELOG.md @@ -8,7 +8,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Added -- +- Allow LightningCLI to use a customized argument parser class ([#20596](https://github.com/Lightning-AI/pytorch-lightning/pull/20596)) ### Changed @@ -19,6 +19,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added a new `checkpoint_path_prefix` parameter to the MLflow logger which can control the path to where the MLflow artifacts for the model checkpoints are stored ([#20538](https://github.com/Lightning-AI/pytorch-lightning/pull/20538)) + ### Removed - diff --git a/src/lightning/pytorch/cli.py b/src/lightning/pytorch/cli.py index c79f2481c8af4..75a6347c95356 100644 --- a/src/lightning/pytorch/cli.py +++ b/src/lightning/pytorch/cli.py @@ -314,6 +314,7 @@ def __init__( trainer_defaults: Optional[dict[str, Any]] = None, seed_everything_default: Union[bool, int] = True, parser_kwargs: Optional[Union[dict[str, Any], dict[str, dict[str, Any]]]] = None, + parser_class: type[LightningArgumentParser] = LightningArgumentParser, subclass_mode_model: bool = False, subclass_mode_data: bool = False, args: ArgsType = None, @@ -367,6 +368,7 @@ def __init__( self.trainer_defaults = trainer_defaults or {} self.seed_everything_default = seed_everything_default self.parser_kwargs = parser_kwargs or {} + self.parser_class = parser_class self.auto_configure_optimizers = auto_configure_optimizers self.model_class = model_class @@ -404,7 +406,7 @@ def _setup_parser_kwargs(self, parser_kwargs: dict[str, Any]) -> tuple[dict[str, def init_parser(self, **kwargs: Any) -> LightningArgumentParser: """Method that instantiates the argument parser.""" kwargs.setdefault("dump_header", [f"lightning.pytorch=={pl.__version__}"]) - parser = LightningArgumentParser(**kwargs) + parser = self.parser_class(**kwargs) parser.add_argument( "-c", "--config", action=ActionConfigFile, help="Path to a configuration file in json or yaml format." )