Skip to content

Commit 4d27356

Browse files
authored
set_handlers: changed type of models_to_fetch, removed "models_download_params" (#184)
* set_handlers: `models_to_fetch` and `models_download_params` united in one more flexible parameter. Signed-off-by: Alexander Piskun <bigcat88@icloud.com>
1 parent 16f44f8 commit 4d27356

File tree

6 files changed

+11
-16
lines changed

6 files changed

+11
-16
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ All notable changes to this project will be documented in this file.
1111
### Changed
1212

1313
- set_handlers: `enabled_handler`, `heartbeat_handler`, `init_handler` now can be async(Coroutines). #175 #181
14+
- set_handlers: `models_to_fetch` and `models_download_params` united in one more flexible parameter. #184
1415
- drop Python 3.9 support. #180
1516
- internal code refactoring and clean-up #177
1617

docs/NextcloudTalkBotTransformers.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ This library also provides an additional functionality over this endpoint for ea
6060
6161
@asynccontextmanager
6262
async def lifespan(_app: FastAPI):
63-
set_handlers(APP, enabled_handler, models_to_fetch=[MODEL_NAME])
63+
set_handlers(APP, enabled_handler, models_to_fetch={MODEL_NAME:{}})
6464
yield
6565
6666
This will automatically download models specified in ``models_to_fetch`` parameter to the application persistent storage.

examples/as_app/talk_bot_ai/lib/main.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515

1616
@asynccontextmanager
1717
async def lifespan(_app: FastAPI):
18-
set_handlers(APP, enabled_handler, models_to_fetch=[MODEL_NAME])
18+
set_handlers(APP, enabled_handler, models_to_fetch={MODEL_NAME: {}})
1919
yield
2020

2121

nc_py_api/ex_app/integration_fastapi.py

Lines changed: 6 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -75,8 +75,7 @@ def set_handlers(
7575
enabled_handler: typing.Callable[[bool, AsyncNextcloudApp | NextcloudApp], typing.Awaitable[str] | str],
7676
heartbeat_handler: typing.Callable[[], typing.Awaitable[str] | str] | None = None,
7777
init_handler: typing.Callable[[AsyncNextcloudApp | NextcloudApp], typing.Awaitable[None] | None] | None = None,
78-
models_to_fetch: list[str] | None = None,
79-
models_download_params: dict | None = None,
78+
models_to_fetch: dict[str, dict] | None = None,
8079
map_app_static: bool = True,
8180
):
8281
"""Defines handlers for the application.
@@ -92,7 +91,6 @@ def set_handlers(
9291
9392
.. note:: ```huggingface_hub`` package should be present for automatic models fetching.
9493
95-
:param models_download_params: Parameters to pass to ``snapshot_download`` function from **huggingface_hub**.
9694
:param map_app_static: Should be folders ``js``, ``css``, ``l10n``, ``img`` automatically mounted in FastAPI or not.
9795
9896
.. note:: First, presence of these directories in the current working dir is checked, then one directory higher.
@@ -140,8 +138,7 @@ async def init_callback(
140138
background_tasks.add_task(
141139
__fetch_models_task,
142140
nc,
143-
models_to_fetch if models_to_fetch else [],
144-
models_download_params if models_download_params else {},
141+
models_to_fetch if models_to_fetch else {},
145142
)
146143
return responses.JSONResponse(content={}, status_code=200)
147144

@@ -181,8 +178,7 @@ def __map_app_static_folders(fast_api_app: FastAPI):
181178

182179
def __fetch_models_task(
183180
nc: NextcloudApp,
184-
models: list[str],
185-
params: dict[str, typing.Any],
181+
models: dict[str, dict],
186182
) -> None:
187183
if models:
188184
from huggingface_hub import snapshot_download # noqa isort:skip pylint: disable=C0415 disable=E0401
@@ -193,10 +189,8 @@ def display(self, msg=None, pos=None):
193189
nc.set_init_status(min(int((self.n * 100 / self.total) / len(models)), 100))
194190
return super().display(msg, pos)
195191

196-
if "max_workers" not in params:
197-
params["max_workers"] = 2
198-
if "cache_dir" not in params:
199-
params["cache_dir"] = persistent_storage()
200192
for model in models:
201-
snapshot_download(model, tqdm_class=TqdmProgress, **params) # noqa
193+
workers = models[model].pop("max_workers", 2)
194+
cache = models[model].pop("cache_dir", persistent_storage())
195+
snapshot_download(model, tqdm_class=TqdmProgress, **models[model], max_workers=workers, cache_dir=cache)
202196
nc.set_init_status(100)

tests/_install_init_handler_models.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010

1111
@asynccontextmanager
1212
async def lifespan(_app: FastAPI):
13-
ex_app.set_handlers(APP, enabled_handler, models_to_fetch=[MODEL_NAME])
13+
ex_app.set_handlers(APP, enabled_handler, models_to_fetch={MODEL_NAME: {}})
1414
yield
1515

1616

tests/actual_tests/nc_app_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -116,4 +116,4 @@ async def test_set_user_same_value_async(anc_app):
116116

117117
def test_set_handlers_invalid_param(nc_any):
118118
with pytest.raises(ValueError):
119-
set_handlers(None, None, init_handler=set_handlers, models_to_fetch=["some"]) # noqa
119+
set_handlers(None, None, init_handler=set_handlers, models_to_fetch={"some": {}}) # noqa

0 commit comments

Comments
 (0)