diff --git a/api/asset.py b/api/asset.py index 78eae02e6..6e5b8fde9 100644 --- a/api/asset.py +++ b/api/asset.py @@ -23,9 +23,8 @@ from api.pagination import CustomPage from core.dependencies import ( session_dependency, - viewer_function, + viewer_dependency, admin_dependency, - admin_function, editor_dependency, ) from db import Thing @@ -43,9 +42,7 @@ ) from services.exceptions_helper import PydanticStyleException -router = APIRouter( - prefix="/asset", tags=["asset"], dependencies=[Depends(viewer_function)] -) +router = APIRouter(prefix="/asset", tags=["asset"]) def database_error_handler(payload: CreateAsset, error: ProgrammingError) -> None: @@ -80,10 +77,11 @@ def database_error_handler(payload: CreateAsset, error: ProgrammingError) -> Non @router.post( "/upload", status_code=HTTP_201_CREATED, - dependencies=[Depends(admin_function)], ) async def upload_asset( - bucket=Depends(get_storage_bucket), file: UploadFile = File(...) + user: admin_dependency, + bucket=Depends(get_storage_bucket), + file: UploadFile = File(...), ) -> dict: uri, blob_name = gcs_upload(file, bucket) return { @@ -148,6 +146,7 @@ async def add_asset( @router.get("") async def list_assets( + user: viewer_dependency, session: session_dependency, thing_id: int = None, ) -> CustomPage[AssetResponse]: @@ -171,6 +170,7 @@ def transformer(records: list[Asset]): @router.get("/{asset_id}") async def get_asset( + user: viewer_dependency, asset_id: int, session: session_dependency, bucket=Depends(get_storage_bucket), @@ -213,9 +213,9 @@ async def delete_asset( @router.delete( "/{asset_id}/remove", status_code=HTTP_204_NO_CONTENT, - dependencies=[Depends(admin_function)], ) async def remove_asset( + user: admin_dependency, asset_id: int, session: session_dependency, bucket=Depends(get_storage_bucket), diff --git a/api/author.py b/api/author.py index 1761dec6e..a54b1139e 100644 --- a/api/author.py +++ b/api/author.py @@ -13,11 +13,10 @@ # See the License for the specific language governing permissions and # limitations under the License. # =============================================================================== -from fastapi import APIRouter, Depends +from fastapi import APIRouter from sqlalchemy import select -from sqlalchemy.orm import Session -from db.engine import get_db_session +from core.dependencies import viewer_dependency, session_dependency from db.publication import Author from schemas.publication import PublicationResponse @@ -32,7 +31,7 @@ response_model=list[PublicationResponse], ) async def get_author_publications( - author_id: int, session: Session = Depends(get_db_session) + user: viewer_dependency, author_id: int, session: session_dependency ): """ Retrieve all publications for a specific author. diff --git a/api/geochronology.py b/api/geochronology.py index 411071d86..af3d984cf 100644 --- a/api/geochronology.py +++ b/api/geochronology.py @@ -14,19 +14,18 @@ # limitations under the License. # =============================================================================== from db.geochronology import GeochronologyAge -from fastapi import APIRouter, Depends, status +from fastapi import APIRouter, status from services.crud_helper import model_adder -from db.engine import get_db_session from schemas.geochronology import CreateGeochronologyAge -from sqlalchemy.orm import Session from sqlalchemy import select +from core.dependencies import viewer_dependency, session_dependency router = APIRouter(prefix="/geochronology", tags=["geochronology"]) @router.post("/age", tags=["geochronology"], status_code=status.HTTP_201_CREATED) async def create_age( - age: CreateGeochronologyAge, session: Session = Depends(get_db_session) + user: viewer_dependency, age: CreateGeochronologyAge, session: session_dependency ): """ Create a new geochronology age entry. @@ -38,7 +37,7 @@ async def create_age( @router.get("/age", tags=["geochronology"]) async def get_geochronology_age( - method: str = "arar", session: Session = Depends(get_db_session) + user: viewer_dependency, session: session_dependency, method: str = "arar" ): """ Retrieve geochronology age data. diff --git a/api/geospatial.py b/api/geospatial.py index 86951601d..c3b39b3b1 100644 --- a/api/geospatial.py +++ b/api/geospatial.py @@ -23,7 +23,7 @@ # from starlette.responses import FileResponse -from core.dependencies import session_dependency +from core.dependencies import session_dependency, viewer_dependency from db import Group from schemas.thing import FeatureCollectionResponse from services.geospatial_helper import create_shapefile, get_thing_features @@ -34,6 +34,7 @@ @router.get("") async def get_geospatial( + user: viewer_dependency, session: session_dependency, thing_type: Annotated[List[str], Query(title="thing_type")] = None, group: Annotated[str | int, Query(title="group")] = None, @@ -61,7 +62,7 @@ async def get_geospatial( @router.get("/project-area/{group_id}", summary="Get project area for group") async def get_project_area( - session: session_dependency, group_id: int + user: viewer_dependency, session: session_dependency, group_id: int ) -> FeatureCollectionResponse: group = simple_get_by_id(session, Group, group_id) diff --git a/api/group.py b/api/group.py index a0983b0bf..39b53791b 100644 --- a/api/group.py +++ b/api/group.py @@ -14,7 +14,7 @@ # limitations under the License. # =============================================================================== -from fastapi import Depends, APIRouter, Query +from fastapi import APIRouter, Query from starlette.status import HTTP_201_CREATED, HTTP_204_NO_CONTENT from api.pagination import CustomPage @@ -22,7 +22,7 @@ session_dependency, admin_dependency, editor_dependency, - viewer_function, + viewer_dependency, ) from db.group import Group from schemas.group import UpdateGroup, CreateGroup, GroupResponse @@ -32,9 +32,7 @@ paginated_all_getter, ) -router = APIRouter( - prefix="/group", tags=["group"], dependencies=[Depends(viewer_function)] -) +router = APIRouter(prefix="/group", tags=["group"]) # POST ========================================================================= @@ -69,7 +67,9 @@ async def create_group( # ============= Get ============================================= @router.get("", summary="Get groups") async def get_groups( - session: session_dependency, filter_: str = Query(alias="filter", default=None) + user: viewer_dependency, + session: session_dependency, + filter_: str = Query(alias="filter", default=None), ) -> CustomPage[GroupResponse]: """ Retrieve all groups from the database. @@ -78,7 +78,9 @@ async def get_groups( @router.get("/{group_id}", summary="Get group by ID") -async def get_group_by_id(group_id: int, session: session_dependency) -> GroupResponse: +async def get_group_by_id( + user: viewer_dependency, group_id: int, session: session_dependency +) -> GroupResponse: """ Retrieve a group by ID from the database. """ diff --git a/api/observation.py b/api/observation.py index 74719912a..f4fa65ec4 100644 --- a/api/observation.py +++ b/api/observation.py @@ -21,6 +21,7 @@ from core.dependencies import ( session_dependency, amp_admin_dependency, + amp_editor_dependency, amp_viewer_dependency, ) from db import Observation @@ -101,7 +102,7 @@ async def update_groundwater_level_observation( observation_id: int, obs_data: UpdateGroundwaterLevelObservation, session: session_dependency, - user: amp_admin_dependency, + user: amp_editor_dependency, request: Request, ) -> GroundwaterLevelObservationResponse: """ @@ -115,7 +116,7 @@ async def update_water_chemistry_observation( observation_id: int, obs_data: UpdateWaterChemistryObservation, session: session_dependency, - user: amp_admin_dependency, + user: amp_editor_dependency, request: Request, ) -> WaterChemistryObservationResponse: """ diff --git a/api/publication.py b/api/publication.py index 53ffe69d6..751c0ec88 100644 --- a/api/publication.py +++ b/api/publication.py @@ -18,7 +18,7 @@ from schemas.publication import PublicationResponse, CreatePublication from services.publication_helper import add_publication from sqlalchemy.orm import Session - +from core.dependencies import admin_dependency router = APIRouter( prefix="/publication", @@ -30,6 +30,7 @@ "/add", response_model=PublicationResponse, status_code=status.HTTP_201_CREATED ) async def post_publication( + user: admin_dependency, publication_data: CreatePublication, # Replace with your actual schema session: Session = Depends( get_db_session diff --git a/api/search.py b/api/search.py index db1d9b661..852d09919 100644 --- a/api/search.py +++ b/api/search.py @@ -20,7 +20,7 @@ from fastapi_pagination import paginate from fastapi_pagination.utils import disable_installed_extensions_check -from core.dependencies import session_dependency +from core.dependencies import session_dependency, viewer_dependency from db import ( Contact, Email, @@ -158,6 +158,7 @@ def _get_asset_results(session: Session, q: str, limit: int) -> list[dict]: @router.get("") async def search_api( + user: viewer_dependency, session: session_dependency, q: str, limit: int = 25, diff --git a/api/thing.py b/api/thing.py index 4f8b6a81e..7e673de08 100644 --- a/api/thing.py +++ b/api/thing.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # =============================================================================== -from fastapi import APIRouter, Depends, Query, Request +from fastapi import APIRouter, Query, Request from fastapi_pagination.ext.sqlalchemy import paginate from sqlalchemy import select from sqlalchemy.exc import ProgrammingError @@ -28,16 +28,9 @@ from core.app import public_route from core.dependencies import ( session_dependency, - amp_admin_dependency, admin_dependency, editor_dependency, - # amp_viewer_dependency, - # viewer_dependency, - # no_permission_dependency, - viewer_function, - amp_viewer_function, - # no_permission_function, - amp_editor_dependency, + viewer_dependency, ) from db.thing import Thing, WellScreen from db.thing import ThingIdLink @@ -73,9 +66,7 @@ ) from services.lexicon_helper import get_terms_by_category -router = APIRouter( - prefix="/thing", tags=["thing"], dependencies=[Depends(viewer_function)] -) +router = APIRouter(prefix="/thing", tags=["thing"]) def database_error_handler( @@ -150,6 +141,7 @@ def database_error_handler( @router.get("/water-well", summary="Get all water wells", status_code=HTTP_200_OK) async def get_water_wells( + user: viewer_dependency, session: session_dependency, request: Request, sort: str = None, @@ -168,7 +160,10 @@ async def get_water_wells( "/water-well/{thing_id}", summary="Get water well by ID", status_code=HTTP_200_OK ) async def get_well_by_id( - thing_id: int, session: session_dependency, request: Request + user: viewer_dependency, + thing_id: int, + session: session_dependency, + request: Request, ) -> WellResponse: """ Retrieve a water well by ID from the database. @@ -182,7 +177,10 @@ async def get_well_by_id( status_code=HTTP_200_OK, ) async def get_well_screens_by_well_id( - thing_id: int, session: session_dependency, request: Request + user: viewer_dependency, + thing_id: int, + session: session_dependency, + request: Request, ) -> CustomPage[WellScreenResponse]: """ Retrieve all well screens for a specific water well by its ID. @@ -195,9 +193,9 @@ async def get_well_screens_by_well_id( @router.get( "/well-screen", summary="Get well screens", - dependencies=[Depends(amp_viewer_function)], ) async def get_well_screens( + user: viewer_dependency, session: session_dependency, thing_id: int = None, ) -> CustomPage[WellScreenResponse]: @@ -213,10 +211,10 @@ async def get_well_screens( @router.get( "/well-screen/{wellscreen_id}", - dependencies=[Depends(amp_viewer_function)], summary="Get well screen by ID", ) async def get_well_screen_by_id( + user: viewer_dependency, session: session_dependency, wellscreen_id: int, ) -> WellScreenResponse: @@ -229,6 +227,7 @@ async def get_well_screen_by_id( @router.get("/spring", summary="Get all springs") async def get_springs( + user: viewer_dependency, session: session_dependency, request: Request, sort: str = None, @@ -245,7 +244,10 @@ async def get_springs( @router.get("/spring/{thing_id}", summary="Get spring by ID", status_code=HTTP_200_OK) async def get_spring_by_id( - thing_id: int, session: session_dependency, request: Request + user: viewer_dependency, + thing_id: int, + session: session_dependency, + request: Request, ) -> SpringResponse: """ Retrieve a spring by ID from the database. @@ -258,6 +260,7 @@ async def get_spring_by_id( summary="Get all thing links", ) async def get_thing_id_links( + user: viewer_dependency, session: session_dependency, filter_: str = Query(alias="filter", default=None), sort: str = None, @@ -275,6 +278,7 @@ async def get_thing_id_links( @public_route @router.get("/id-link/{link_id}", summary="Get thing links by link ID") async def get_thing_id_links( + user: viewer_dependency, link_id: int, session: session_dependency, ) -> ThingIdLinkResponse: @@ -287,6 +291,7 @@ async def get_thing_id_links( @public_route @router.get("", summary="Get all things", status_code=HTTP_200_OK) async def get_things( + user: viewer_dependency, session: session_dependency, # thing_id: int = None, within: str = None, @@ -314,7 +319,10 @@ async def get_things( @router.get("/{thing_id}", summary="Get thing by ID", status_code=HTTP_200_OK) async def get_thing_by_id( - thing_id: int, session: session_dependency, request: Request + user: viewer_dependency, + thing_id: int, + session: session_dependency, + request: Request, ) -> ThingResponse: """ Retrieve a thing by ID from the database. @@ -326,6 +334,7 @@ async def get_thing_by_id( @router.get("/{thing_id}/id-link", summary="Get thing links by thing ID") async def get_thing_id_links( + user: viewer_dependency, thing_id: int, session: session_dependency, ) -> CustomPage[ThingIdLinkResponse]: @@ -366,7 +375,7 @@ async def create_well( thing_data: CreateWell, session: session_dependency, request: Request, - user: amp_admin_dependency, + user: admin_dependency, ) -> WellResponse: """ Create a new water well in the database. @@ -386,7 +395,7 @@ async def create_spring( thing_data: CreateSpring, session: session_dependency, request: Request, - user: amp_admin_dependency, + user: admin_dependency, ) -> SpringResponse: """ Create a new well in the database. @@ -404,7 +413,7 @@ async def create_spring( ) async def create_wellscreen( session: session_dependency, - user: amp_admin_dependency, + user: admin_dependency, well_screen_data: CreateWellScreen, ) -> WellScreenResponse: """ @@ -430,7 +439,7 @@ async def update_water_well( thing_id: int, thing_data: UpdateWell, session: session_dependency, - user: amp_editor_dependency, + user: editor_dependency, request: Request, ) -> WellResponse: """ @@ -448,7 +457,7 @@ async def update_spring( thing_id: int, thing_data: UpdateSpring, session: session_dependency, - user: amp_editor_dependency, + user: editor_dependency, request: Request, ) -> SpringResponse: """ @@ -496,7 +505,7 @@ async def update_well_screen( async def delete_thing( thing_id: int, session: session_dependency, - user: editor_dependency, + user: admin_dependency, ) -> None: """ Delete a thing by ID. @@ -512,7 +521,7 @@ async def delete_thing( async def delete_well_screen( well_screen_id: int, session: session_dependency, - user: editor_dependency, + user: admin_dependency, ) -> None: """ Delete a well screen by ID. @@ -528,7 +537,7 @@ async def delete_well_screen( async def delete_thing_id_link( link_id: int, session: session_dependency, - user: editor_dependency, + user: admin_dependency, ) -> None: """ Delete a thing link by ID. diff --git a/tests/test_observation.py b/tests/test_observation.py index f7c206699..ec1c73060 100644 --- a/tests/test_observation.py +++ b/tests/test_observation.py @@ -18,6 +18,7 @@ amp_admin_function, admin_function, amp_viewer_function, + amp_editor_function, viewer_function, ) from main import app @@ -33,6 +34,9 @@ def override_authentication_dependency_fixture(): app.dependency_overrides[admin_function] = override_authentication( default={"name": "foobar", "sub": "1234567890"} ) + app.dependency_overrides[amp_editor_function] = override_authentication( + default={"name": "foobar", "sub": "1234567890"} + ) app.dependency_overrides[amp_viewer_function] = override_authentication() app.dependency_overrides[viewer_function] = override_authentication() diff --git a/tests/test_publication.py b/tests/test_publication.py index 5457bd9b2..8fa8faf12 100644 --- a/tests/test_publication.py +++ b/tests/test_publication.py @@ -13,7 +13,26 @@ # See the License for the specific language governing permissions and # limitations under the License. # =============================================================================== -from tests import client +import pytest + +from main import app +from core.dependencies import admin_function, editor_function, viewer_function +from tests import client, override_authentication + + +@pytest.fixture(scope="module", autouse=True) +def override_dependencies_fixture(): + app.dependency_overrides[admin_function] = override_authentication( + default={"name": "foobar", "sub": "1234567890"} + ) + app.dependency_overrides[editor_function] = override_authentication( + default={"name": "foobar", "sub": "1234567890"} + ) + app.dependency_overrides[viewer_function] = override_authentication() + + yield + + app.dependency_overrides = {} def test_add_publication(): diff --git a/tests/test_search.py b/tests/test_search.py index e7619b0c9..6109bdf72 100644 --- a/tests/test_search.py +++ b/tests/test_search.py @@ -19,7 +19,24 @@ from db import search from db.contact import Contact, Phone, Email from db.engine import session_ctx -from tests import client +from main import app +from core.dependencies import admin_function, editor_function, viewer_function +from tests import client, override_authentication + + +@pytest.fixture(scope="module", autouse=True) +def override_dependencies_fixture(): + app.dependency_overrides[admin_function] = override_authentication( + default={"name": "foobar", "sub": "1234567890"} + ) + app.dependency_overrides[editor_function] = override_authentication( + default={"name": "foobar", "sub": "1234567890"} + ) + app.dependency_overrides[viewer_function] = override_authentication() + + yield + + app.dependency_overrides = {} def test_search_api(water_well_thing, spring_thing, contact):