diff --git a/api/group.py b/api/group.py index d095dfd9..f32e5691 100644 --- a/api/group.py +++ b/api/group.py @@ -21,6 +21,7 @@ from core.dependencies import ( session_dependency, admin_dependency, + amp_admin_dependency, editor_dependency, viewer_dependency, ) @@ -60,7 +61,7 @@ def add_thing_to_group_route( group_id: int, thing_id: int, session: session_dependency, - user: admin_dependency, + user: amp_admin_dependency, ): """ Associate a thing (e.g. a water well) with a group (project). @@ -118,7 +119,7 @@ def remove_thing_from_group_route( group_id: int, thing_id: int, session: session_dependency, - user: admin_dependency, + user: amp_admin_dependency, ): """ Remove the association between a thing and a group. diff --git a/tests/test_group.py b/tests/test_group.py index 8ab05879..fc252f28 100644 --- a/tests/test_group.py +++ b/tests/test_group.py @@ -4,7 +4,12 @@ from geoalchemy2.shape import to_shape from pydantic import ValidationError -from core.dependencies import admin_function, viewer_function, editor_function +from core.dependencies import ( + admin_function, + amp_admin_function, + viewer_function, + editor_function, +) from db import Group, GroupThingAssociation, Thing from db.engine import session_ctx from main import app @@ -24,6 +29,9 @@ def override_authentication_dependency_fixture(): app.dependency_overrides[admin_function] = override_authentication( default={"name": "foobar", "sub": "1234567890"} ) + app.dependency_overrides[amp_admin_function] = override_authentication( + default={"name": "foobar", "sub": "1234567890"} + ) app.dependency_overrides[editor_function] = override_authentication( default={"name": "foobar", "sub": "1234567890"} ) diff --git a/tests/test_lazy_admin.py b/tests/test_lazy_admin.py index 5b70ed88..ac2f2244 100644 --- a/tests/test_lazy_admin.py +++ b/tests/test_lazy_admin.py @@ -1,14 +1,29 @@ import os +from collections.abc import Iterable from core.factory import create_api_app from fastapi.testclient import TestClient +def _iter_route_paths(routes: Iterable) -> Iterable[str]: + for route in routes: + path = getattr(route, "path", None) + if path: + yield path + nested = getattr(route, "routes", None) + if nested: + yield from _iter_route_paths(nested) + + +def _has_admin_route(routes: Iterable) -> bool: + return any(path.startswith("/admin") for path in _iter_route_paths(routes)) + + def test_admin_is_lazy_loaded_on_first_admin_request(): os.environ["SESSION_SECRET_KEY"] = "test-session-secret-key" app = create_api_app() - assert not any(route.path.startswith("/admin") for route in app.routes) + assert not _has_admin_route(app.routes) assert getattr(app.state, "admin_configured", False) is False with TestClient(app) as client: @@ -16,4 +31,4 @@ def test_admin_is_lazy_loaded_on_first_admin_request(): assert response.status_code in {200, 302, 307} assert app.state.admin_configured is True - assert any(route.path.startswith("/admin") for route in app.routes) + assert _has_admin_route(app.routes)