diff --git a/api/group.py b/api/group.py index 02cb03727..d1df69f1c 100644 --- a/api/group.py +++ b/api/group.py @@ -15,7 +15,7 @@ # =============================================================================== from fastapi import Depends, APIRouter, Query -from starlette import status +from starlette.status import HTTP_201_CREATED, HTTP_204_NO_CONTENT from api.pagination import CustomPage from core.dependencies import ( @@ -25,10 +25,9 @@ viewer_function, ) from db import adder -from db.group import Group, GroupThingAssociation +from db.group import Group from schemas.group import UpdateGroup, CreateGroup, GroupResponse -from schemas.location import CreateGroupThing -from services.crud_helper import model_patcher +from services.crud_helper import model_patcher, model_deleter from services.query_helper import ( simple_get_by_id, paginated_all_getter, @@ -38,8 +37,10 @@ prefix="/group", tags=["group"], dependencies=[Depends(viewer_function)] ) +# POST ========================================================================= -@router.post("", summary="Create a new group", status_code=status.HTTP_201_CREATED) + +@router.post("", summary="Create a new group", status_code=HTTP_201_CREATED) def create_group( group_data: CreateGroup, session: session_dependency, user: admin_dependency ) -> GroupResponse: @@ -110,4 +111,14 @@ async def update_group( return model_patcher(session, Group, group_id, group_data, user=user) +# DELETE ======================================================================= +@router.delete( + "/{group_id}", summary="Delete a group by ID", status_code=HTTP_204_NO_CONTENT +) +async def delete_group( + user: admin_dependency, group_id: int, session: session_dependency +): + return model_deleter(session, Group, group_id) + + # ============= EOF ============================================= diff --git a/api/location.py b/api/location.py index 666e09fe3..3352168ae 100644 --- a/api/location.py +++ b/api/location.py @@ -13,17 +13,20 @@ # See the License for the specific language governing permissions and # limitations under the License. # =============================================================================== -from fastapi import Depends, Query, Response +from fastapi import Query, Response from fastapi_pagination.ext.sqlalchemy import paginate from sqlalchemy import select, func -from sqlalchemy.orm import Session from starlette import status from api.pagination import CustomPage from constants import SRID_WGS84 -from core.dependencies import session_dependency +from core.dependencies import ( + session_dependency, + admin_dependency, + editor_dependency, + viewer_dependency, +) from db import adder from db.location import Location -from db.engine import get_db_session from schemas.location import CreateLocation, LocationResponse, UpdateLocation from services.geospatial_helper import make_within_wkt from services.query_helper import make_query, order_sort_filter, simple_get_by_id @@ -41,12 +44,12 @@ status_code=status.HTTP_201_CREATED, ) def create_location( - location_data: CreateLocation, session: Session = Depends(get_db_session) + location_data: CreateLocation, session: session_dependency, user: admin_dependency ) -> LocationResponse: """ Create a new sample location in the database. """ - return adder(session, Location, location_data) + return adder(session, Location, location_data, user=user) @router.patch( @@ -56,12 +59,13 @@ def create_location( def update_location( location_id: int, location_data: UpdateLocation, - session: Session = Depends(get_db_session), + session: session_dependency, + user: editor_dependency, ) -> LocationResponse: """ Update a sample location in the database. """ - return model_patcher(session, Location, location_id, location_data) + return model_patcher(session, Location, location_id, location_data, user=user) # @router.get("/shapefile", summary="Get location as shapefile") @@ -125,6 +129,7 @@ def update_location( ) async def get_location( session: session_dependency, + user: viewer_dependency, nearby_point: str = None, nearby_distance_km: float = 1, within: str = None, @@ -160,7 +165,7 @@ async def get_location( summary="Get location by ID", ) async def get_location_by_id( - location_id: int, session: Session = Depends(get_db_session) + location_id: int, session: session_dependency, user: viewer_dependency ) -> LocationResponse: """ Retrieve a sample location by ID from the database. @@ -171,7 +176,7 @@ async def get_location_by_id( @router.delete("/{location_id}", summary="Delete location by ID") async def delete_location( - location_id: int, session: Session = Depends(get_db_session) + location_id: int, session: session_dependency, user: admin_dependency ) -> Response: """ Delete a sample location by ID from the database. diff --git a/api/sample.py b/api/sample.py index 9a3755093..c9107237c 100644 --- a/api/sample.py +++ b/api/sample.py @@ -14,15 +14,18 @@ # limitations under the License. # =============================================================================== -from fastapi import APIRouter, Depends, Query, Response +from fastapi import APIRouter, Query, Response from sqlalchemy.exc import IntegrityError, ProgrammingError -from sqlalchemy.orm import Session from starlette.status import HTTP_201_CREATED, HTTP_409_CONFLICT from api.pagination import CustomPage -from core.dependencies import session_dependency +from core.dependencies import ( + session_dependency, + admin_dependency, + editor_dependency, + viewer_dependency, +) from db import adder -from db.engine import get_db_session from db.sample import Sample from schemas import ResourceNotFoundResponse from schemas.sample import SampleResponse, CreateSample, UpdateSample @@ -70,13 +73,13 @@ def database_error_handler( # ============= Post ============================================= @router.post("", status_code=HTTP_201_CREATED) def add_sample( - sample_data: CreateSample, session: session_dependency + sample_data: CreateSample, session: session_dependency, user: admin_dependency ) -> SampleResponse: """ Endpoint to add a sample. """ try: - return adder(session, Sample, sample_data) + return adder(session, Sample, sample_data, user=user) except (IntegrityError, ProgrammingError) as e: database_error_handler(sample_data, e) @@ -86,7 +89,8 @@ def add_sample( def update_sample( sample_id: int, sample_data: UpdateSample, - session: Session = Depends(get_db_session), + session: session_dependency, + user: editor_dependency, ) -> SampleResponse | ResourceNotFoundResponse: """ Endpoint to update a sample. @@ -105,7 +109,7 @@ def update_sample( the update. """ try: - return model_patcher(session, Sample, sample_id, sample_data) + return model_patcher(session, Sample, sample_id, sample_data, user=user) except (IntegrityError, ProgrammingError) as e: database_error_handler(sample_data, e) @@ -114,6 +118,7 @@ def update_sample( @router.get("", summary="Get Samples") def get_samples( session: session_dependency, + user: viewer_dependency, sort: str = None, order: str = None, filter_: str = Query(alias="filter", default=None), @@ -129,7 +134,7 @@ def get_samples( @router.get("/{sample_id}", summary="Get Sample by ID") def get_sample_by_id( - sample_id: int, session: session_dependency + sample_id: int, session: session_dependency, user: viewer_dependency ) -> SampleResponse | ResourceNotFoundResponse: """ Endpoint to retrieve a sample by its ID. @@ -141,7 +146,9 @@ def get_sample_by_id( @router.delete("/{sample_id}", summary="Delete Sample by ID") -def delete_sample_by_id(sample_id: int, session: session_dependency) -> Response: +def delete_sample_by_id( + sample_id: int, session: session_dependency, user: admin_dependency +) -> Response: return model_deleter(session, Sample, sample_id) diff --git a/api/sensor.py b/api/sensor.py index bb92c4232..5239f3175 100644 --- a/api/sensor.py +++ b/api/sensor.py @@ -20,7 +20,12 @@ from starlette import status from api.pagination import CustomPage -from core.dependencies import session_dependency +from core.dependencies import ( + session_dependency, + admin_dependency, + editor_dependency, + viewer_dependency, +) from db import adder, Observation, Sample from db.sensor import Sensor from schemas.sensor import SensorResponse, CreateSensor, UpdateSensor @@ -35,12 +40,12 @@ @router.post("", status_code=status.HTTP_201_CREATED) def add_sensor( - sensor_data: CreateSensor, session: session_dependency + sensor_data: CreateSensor, session: session_dependency, user: admin_dependency ) -> SensorResponse: """ Add a sensor to the system. """ - return adder(session, Sensor, sensor_data) + return adder(session, Sensor, sensor_data, user=user) # ====== PATCH ================================================================= @@ -48,7 +53,10 @@ def add_sensor( @router.patch("/{sensor_id}", status_code=status.HTTP_200_OK) def update_sensor( - sensor_id: int, sensor_data: UpdateSensor, session: session_dependency + sensor_id: int, + sensor_data: UpdateSensor, + session: session_dependency, + user: editor_dependency, ) -> SensorResponse: """ Update a sensor in the system. @@ -97,14 +105,16 @@ def update_sensor( status_code=status.HTTP_409_CONFLICT, detail=[detail] ) - return model_patcher(session, Sensor, sensor_id, sensor_data) + return model_patcher(session, Sensor, sensor_id, sensor_data, user=user) # ====== DELETE ================================================================ @router.delete("/{sensor_id}") -def delete_sensor(sensor_id: int, session: session_dependency) -> Response: +def delete_sensor( + sensor_id: int, session: session_dependency, user: admin_dependency +) -> Response: """ Delete a sensor in the system """ @@ -117,6 +127,7 @@ def delete_sensor(sensor_id: int, session: session_dependency) -> Response: @router.get("", status_code=status.HTTP_200_OK) def get_sensors( session: session_dependency, + user: viewer_dependency, thing_id: int = None, # Optional filter for thing_id observed_property: str = None, # Optional filter for observed_property sort: str | None = None, @@ -151,7 +162,9 @@ def get_sensors( @router.get("/{sensor_id}", status_code=status.HTTP_200_OK) -def get_sensor(sensor_id: int, session: session_dependency) -> SensorResponse: +def get_sensor( + sensor_id: int, session: session_dependency, user: viewer_dependency +) -> SensorResponse: """ Retrieve a sensor by its ID. """ diff --git a/schemas/group.py b/schemas/group.py index f61282331..c5f525748 100644 --- a/schemas/group.py +++ b/schemas/group.py @@ -15,7 +15,8 @@ # =============================================================================== from geoalchemy2 import WKBElement from geoalchemy2.shape import to_shape -from pydantic import BaseModel, field_validator +from pydantic import BaseModel, field_validator, model_validator +from typing_extensions import Self from schemas import ORMBaseModel from services.validation.geospatial import validate_wkt_geometry @@ -23,14 +24,16 @@ class ValidateGroup(BaseModel): project_area: str | None = None - description: str | None = None parent_group_id: int | None = None - @classmethod @field_validator("project_area") def validate_area_is_wkt(cls, wkt): - return validate_wkt_geometry(wkt) + valid_wkt = validate_wkt_geometry(wkt) + if "MULTIPOLYGON" not in valid_wkt: + raise ValueError("WKT must be a valid MULTIPOLYGON") + + return valid_wkt # -------- CREATE ---------- @@ -49,25 +52,16 @@ class GroupResponse(ORMBaseModel): This model can be extended to include additional fields as needed. """ - id: int name: str - description: str | None = None - parent_group_id: int | None = None - - @classmethod - @field_validator("project_area", mode="before") - def project_area_to_wkt(cls, value): - if not value: - return value - - if isinstance(value, WKBElement): - return to_shape(value).wkt - - # If the value is a string, assume it's already in WKT format - if isinstance(value, str): - return value - - return None + project_area: str | None + description: str | None + parent_group_id: int | None + + @model_validator(mode="before") + def project_area_to_wkt(self: Self) -> Self: + if isinstance(self.project_area, WKBElement): + self.project_area = to_shape(self.project_area).wkt + return self # -------- UPDATE ---------- diff --git a/tests/conftest.py b/tests/conftest.py index 0fc57fb1f..db72abf53 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -18,6 +18,21 @@ def location(): session.close() +@pytest.fixture(scope="function") +def second_location(): + with session_ctx() as session: + location = Location( + name="second location", + point="POINT (10.2 10.2)", + release_status="draft", + ) + session.add(location) + session.commit() + yield location + session.delete(location) + session.commit() + + @pytest.fixture(scope="session") def thing(location): with session_ctx() as session: @@ -53,6 +68,26 @@ def sensor(): session.close() +@pytest.fixture(scope="function") +def second_sensor(): + with session_ctx() as session: + sensor = Sensor( + name="Test Sensor 2", + model="Model X", + serial_no="123456", + datetime_installed="2023-01-01T00:00:00Z", + datetime_removed="2023-01-02T00:00:00Z", + recording_interval=60, + notes="Test equipment", + ) + session.add(sensor) + session.commit() + yield sensor + session.delete(sensor) + session.commit() + session.close() + + @pytest.fixture(scope="session") def sample(thing, sensor): with session_ctx() as session: @@ -78,6 +113,32 @@ def sample(thing, sensor): session.close() +@pytest.fixture(scope="function") +def second_sample(thing, sensor): + with session_ctx() as session: + sample = Sample( + thing_id=thing.id, + sample_type="groundwater", + field_sample_id="FS-9999999", + sample_date="2025-01-01T00:00:00Z", + release_status="draft", + sampler_name="Test Sampler", + qc_sample="Duplicate", + sensor_id=sensor.id, + sample_matrix="water", + sample_method="manual", + duplicate_sample_number=3, + sample_top=2, + sample_bottom=3, + ) + session.add(sample) + session.commit() + yield sample + session.delete(sample) + session.commit() + session.close() + + @pytest.fixture(scope="session") def contact(thing): with session_ctx() as session: @@ -148,6 +209,80 @@ def phone(contact): session.close() +@pytest.fixture(scope="function") +def second_contact(): + with session_ctx() as session: + contact = Contact( + name="Test Second Contact", + role="Owner", + ) + session.add(contact) + session.commit() + session.refresh(contact) + + yield contact + + session.delete(contact) + session.commit() + session.close() + + +@pytest.fixture(scope="function") +def second_email(second_contact): + with session_ctx() as session: + email = Email( + email="testsecondcontact@gmail.com", + email_type="Primary", + contact_id=second_contact.id, + ) + session.add(email) + session.commit() + session.refresh(email) + yield email + session.delete(email) + session.commit() + session.close() + + +@pytest.fixture(scope="function") +def second_phone(second_contact): + with session_ctx() as session: + phone = Phone( + phone_number="123-456-7890", + phone_type="Primary", + contact_id=second_contact.id, + ) + session.add(phone) + session.commit() + session.refresh(phone) + yield phone + session.delete(phone) + session.commit() + session.close() + + +@pytest.fixture(scope="function") +def second_address(second_contact): + with session_ctx() as session: + address = Address( + address_line_1="456 Secondary St", + address_line_2="Apt 12A", + city="Test Metropolis", + state="NM", + postal_code="87501", + country="United States", + address_type="Primary", + contact_id=second_contact.id, + ) + session.add(address) + session.commit() + session.refresh(address) + yield address + session.delete(address) + session.commit() + session.close() + + @pytest.fixture(scope="session") def asset(): with session_ctx() as session: @@ -274,3 +409,70 @@ def geothermal_observation(sensor, sample): yield observation session.close() + + +@pytest.fixture(scope="function") +def observation_to_delete(sample, sensor): + with session_ctx() as session: + observation = Observation( + observation_datetime="2019-01-01T00:03:00Z", + sample_id=sample.id, + sensor_id=sensor.id, + observed_property="water chemistry:pH", + release_status="draft", + value=4.0, + unit="dimensionless", + ) + session.add(observation) + session.commit() + yield observation + + +@pytest.fixture(scope="session") +def group(thing): + with session_ctx() as session: + group = Group( + name="Test Group", + description="This is a test group.", + project_area="MULTIPOLYGON(((-107.2 33.6, -106.6 33.6, -106.6 34.2, -107.2 34.2, -107.2 33.6)))", + ) + + session.add(group) + session.commit() + session.refresh(group) + + group_thing_association = GroupThingAssociation( + group_id=group.id, thing_id=thing.id + ) + session.add(group_thing_association) + session.commit() + session.refresh(group_thing_association) + + yield group + + session.close() + + +@pytest.fixture(scope="function") +def second_group(thing): + with session_ctx() as session: + group = Group( + name="Second Test Group", + description="This is a second test group.", + project_area="MULTIPOLYGON(((-107.2 33.6, -106.6 33.6, -106.6 34.2, 0 0, -107.2 34.2, -107.2 33.6)))", + ) + + session.add(group) + session.commit() + session.refresh(group) + + group_thing_association = GroupThingAssociation( + group_id=group.id, thing_id=thing.id + ) + session.add(group_thing_association) + session.commit() + session.refresh(group_thing_association) + + yield group + + session.close() diff --git a/tests/test_contact.py b/tests/test_contact.py index dd80ffea6..19ec570e4 100644 --- a/tests/test_contact.py +++ b/tests/test_contact.py @@ -4,7 +4,6 @@ amp_admin_function, ) from db import Contact, Address, Email, Phone -from db.engine import session_ctx from main import app from tests import client, cleanup_post_test, cleanup_patch_test, override_authentication from schemas.contact import ValidateEmail, ValidatePhone @@ -30,98 +29,6 @@ def override_authentication_dependency_fixture(): app.dependency_overrides = {} -# ============= module & function fixtures ======================================= - - -@pytest.fixture(scope="function") -def second_contact(): - with session_ctx() as session: - contact = Contact( - name="Test Second Contact", - role="Owner", - ) - session.add(contact) - session.commit() - session.refresh(contact) - - yield contact - - session.delete(contact) - session.commit() - session.close() - - -@pytest.fixture(scope="function") -def second_email(second_contact): - with session_ctx() as session: - email = Email( - email="testsecondcontact@gmail.com", - email_type="Primary", - contact_id=second_contact.id, - ) - session.add(email) - session.commit() - session.refresh(email) - yield email - session.delete(email) - session.commit() - session.close() - - -@pytest.fixture(scope="function") -def second_phone(second_contact): - with session_ctx() as session: - phone = Phone( - phone_number="123-456-7890", - phone_type="Primary", - contact_id=second_contact.id, - ) - session.add(phone) - session.commit() - session.refresh(phone) - yield phone - session.delete(phone) - session.commit() - session.close() - - -@pytest.fixture(scope="function") -def second_address(second_contact): - with session_ctx() as session: - address = Address( - address_line_1="456 Secondary St", - address_line_2="Apt 12A", - city="Test Metropolis", - state="NM", - postal_code="87501", - country="United States", - address_type="Primary", - contact_id=second_contact.id, - ) - session.add(address) - session.commit() - session.refresh(address) - yield address - session.delete(address) - session.commit() - session.close() - - -# @pytest.fixture(scope="function") -# def second_thing_contact_association(thing, second_contact): -# with session_ctx() as session: -# association = ThingContactAssociation( -# thing_id=thing.id, contact_id=second_contact.id -# ) -# session.add(association) -# session.commit() -# session.refresh(association) -# yield association -# session.delete(association) -# session.commit() -# session.close() - - # VALIDATION tests ============================================================= diff --git a/tests/test_group.py b/tests/test_group.py index e652339a7..4b9095591 100644 --- a/tests/test_group.py +++ b/tests/test_group.py @@ -1,8 +1,12 @@ +from geoalchemy2.shape import to_shape +from pydantic import ValidationError import pytest +from db import Group from core.dependencies import admin_function, viewer_function, editor_function from main import app -from tests import client, override_authentication +from schemas.group import ValidateGroup +from tests import client, override_authentication, cleanup_post_test, cleanup_patch_test @pytest.fixture(scope="module", autouse=True) @@ -21,82 +25,144 @@ def override_authentication_dependency_fixture(): app.dependency_overrides = {} +# VALIDATION tests ============================================================= + + +def test_project_area_not_topologically_valid(): + wkt = "MULTIPOLYGON(((0 0, 1 1, 2 2, 0 0)))" + with pytest.raises( + ValidationError, match="WKT geometry is not topologically valid" + ): + ValidateGroup(project_area=wkt) + + +def test_project_area_invalid_wkt(): + for wkt in [ + "MULTIPOLYGON((0 0, 1 1, 2 2, 0 0))", + "0 0, 1 1, 2 2, 3 3, 4 5, 0 0", + ]: + with pytest.raises(ValidationError, match=r"Invalid WKT geometry: "): + ValidateGroup(project_area=wkt) + + +def test_project_area_not_multipolygon(): + for wkt in [ + "POINT (0 0)", + "LINESTRING (0 0, 1 1, 2 2, 3 3)", + "POLYGON ((0 0, 1 1, 2 2, 1 2, 0 0))", + ]: + with pytest.raises(ValidationError, match="WKT must be a valid MULTIPOLYGON"): + ValidateGroup(project_area=wkt) + + # ADD tests ====================================================== def test_add_group(): - response = client.post("/group", json={"name": "Test Group"}) + payload = { + "name": "Test Group", + "description": "This is a test group.", + "project_area": "MULTIPOLYGON (((0 0, 1 1, 2 2, 3 3, 4 4, 1 2, 0 0)))", + } + response = client.post("/group", json=payload) assert response.status_code == 201 data = response.json() assert "id" in data - assert data["name"] == "Test Group" + assert "created_at" in data + assert data["name"] == payload["name"] + assert data["description"] == payload["description"] + assert data["project_area"] == payload["project_area"] - -def test_add_group_with_area(): - response = client.post( - "/group", - json={ - "name": "Test Group with Project Area", - "project_area": "MULTIPOLYGON(((-107.2 33.6, -106.6 33.6, -106.6 34.2, -107.2 34.2, -107.2 33.6)))", - }, - ) - assert response.status_code == 201 - data = response.json() - - -# def test_add_group_thing(location, thing): -# response = client.post( -# "/group/association", json={"group_id": 2, "thing_id": thing.id} -# ) -# assert response.status_code == 201 -# -# data = response.json() -# assert "id" in data -# assert data["group_id"] == 2 -# assert data["thing_id"] == thing.id + cleanup_post_test(Group, data["id"]) # GET tests ====================================================== -def test_get_groups(): +def test_get_groups(group): response = client.get("/group") assert response.status_code == 200 - assert len(response.json()) > 0 + data = response.json() + assert data["total"] == 1 + assert data["items"][0]["id"] == group.id + assert data["items"][0]["created_at"] == group.created_at.isoformat().replace( + "+00:00", "Z" + ) + assert data["items"][0]["name"] == group.name + assert data["items"][0]["project_area"] == to_shape(group.project_area).wkt + assert data["items"][0]["description"] == group.description + assert data["items"][0]["parent_group_id"] == group.parent_group_id -@pytest.mark.skip +def test_get_group_by_id(group): + response = client.get(f"/group/{group.id}") + assert response.status_code == 200 + data = response.json() + assert data["id"] == group.id + assert data["created_at"] == group.created_at.isoformat().replace("+00:00", "Z") + assert data["name"] == group.name + assert data["project_area"] == to_shape(group.project_area).wkt + assert data["description"] == group.description + assert data["parent_group_id"] == group.parent_group_id + + +def test_get_group_by_id_404_not_found(group): + bad_id = 99999 + response = client.get(f"/group/{bad_id}") + assert response.status_code == 404 + data = response.json() + assert data["detail"] == f"Group with ID {bad_id} not found." + + +@pytest.mark.skip("associations not yet implemented") def test_get_group_things(): response = client.get("/group/association") assert response.status_code == 200 assert len(response.json()) > 0 -# test item retrieval via filter =========================================== +# PATCH tests ================================================================== -# Test item retrieval ====================================================== -# @pytest.mark.skip -# def test_item_get_spring(): -# response = client.get("/thing/spring/1") -# assert response.status_code == 200 -# data = response.json() -# assert data["id"] == 1 -# assert data["location_id"] == 1 +def test_patch_group(group): + payload = { + "name": "Updated Group", + } + response = client.patch(f"/group/{group.id}", json=payload) + assert response.status_code == 200 + data = response.json() + assert data["id"] == group.id + assert data["name"] == payload["name"] + cleanup_patch_test(Group, payload, group) -def test_item_get_group(): - response = client.get("/group/2") - assert response.status_code == 200 + +def test_patch_group_404_not_found(group): + payload = {"name": "Failed group patch"} + bad_id = 99999 + response = client.patch(f"/group/{bad_id}", json=payload) + assert response.status_code == 404 + data = response.json() + assert data["detail"] == f"Group with ID {bad_id} not found." + + +# DELETE tests ================================================================= + + +def test_delete_group(second_group): + response = client.delete(f"/group/{second_group.id}") + assert response.status_code == 204 + + # verify deletion + response = client.get(f"/group/{second_group.id}") + assert response.status_code == 404 data = response.json() - assert data["id"] == 2 - assert data["name"] == "Test Group" + assert data["detail"] == f"Group with ID {second_group.id} not found." -# def test_item_get_group_thing(location, thing): -# response = client.get("/group/association/1") -# assert response.status_code == 200 -# data = response.json() -# assert data["id"] == 1 -# assert data["group_id"] == 2 -# assert data["thing_id"] == thing.id +def test_delete_group_404_not_found(second_group): + bad_id = 99999 + response = client.delete(f"/group/{bad_id}") + assert response.status_code == 404 + data = response.json() + assert data["detail"] == f"Group with ID {bad_id} not found." diff --git a/tests/test_location.py b/tests/test_location.py index ee98770d1..cee06ab2b 100644 --- a/tests/test_location.py +++ b/tests/test_location.py @@ -16,26 +16,25 @@ from geoalchemy2.shape import to_shape import pytest +from core.dependencies import admin_function, editor_function, viewer_function from db import Location -from db.engine import session_ctx -from tests import client +from main import app +from tests import client, override_authentication, cleanup_post_test, cleanup_patch_test -# ============= module & function fixtures ======================================= +@pytest.fixture(scope="module", autouse=True) +def override_dependencies_fixture(): + app.dependency_overrides[admin_function] = override_authentication( + default={"name": "foobar", "sub": "1234567890"} + ) + app.dependency_overrides[editor_function] = override_authentication( + default={"name": "foobar", "sub": "1234567890"} + ) + app.dependency_overrides[viewer_function] = override_authentication() -@pytest.fixture(scope="function") -def second_location(): - with session_ctx() as session: - location = Location( - name="second location", - point="POINT (10.2 10.2)", - release_status="draft", - ) - session.add(location) - session.commit() - yield location - session.delete(location) - session.commit() + yield + + app.dependency_overrides = {} # ============= Post tests for locations ====================================== @@ -57,35 +56,23 @@ def test_add_location(): assert data["release_status"] == payload["release_status"] # cleanup after test - with session_ctx() as session: - session.delete(session.get(Location, data["id"])) - session.commit() + cleanup_post_test(Location, data["id"]) # ============= Patch tests for locations ===================================== def test_update_location(location): - location_id = location.id - response = client.patch( - f"/location/{location_id}", - json={ - "point": "POINT (10.1 20.2)", - "release_status": "draft", - }, - ) + payload = {"point": "POINT (10.1 20.2)", "release_status": "draft"} + response = client.patch(f"/location/{location.id}", json=payload) assert response.status_code == 200 data = response.json() - assert data["id"] == location_id - assert data["point"] == "POINT (10.1 20.2)" - assert data["release_status"] == "draft" + assert data["id"] == location.id + assert data["point"] == payload["point"] + assert data["release_status"] == payload["release_status"] # cleanup after test - with session_ctx() as session: - updated_location = session.get(Location, location_id) - updated_location.point = location.point - updated_location.release_status = location.release_status - session.commit() + cleanup_patch_test(Location, payload, location) def test_patch_location_404_not_found(location): diff --git a/tests/test_observation.py b/tests/test_observation.py index f1c94483c..6dcfc85de 100644 --- a/tests/test_observation.py +++ b/tests/test_observation.py @@ -14,7 +14,6 @@ # limitations under the License. # =============================================================================== from db import Observation -from db.engine import session_ctx from core.dependencies import ( amp_admin_function, admin_function, @@ -42,23 +41,6 @@ def override_authentication_dependency_fixture(): app.dependency_overrides = {} -@pytest.fixture(scope="function") -def observation_to_delete(sample, sensor): - with session_ctx() as session: - observation = Observation( - observation_datetime="2019-01-01T00:03:00Z", - sample_id=sample.id, - sensor_id=sensor.id, - observed_property="water chemistry:pH", - release_status="draft", - value=4.0, - unit="dimensionless", - ) - session.add(observation) - session.commit() - yield observation - - # ============= Post tests ================= def test_add_water_chemistry_observation(sample, sensor): payload = { diff --git a/tests/test_sample.py b/tests/test_sample.py index 6004a60ed..d3d432413 100644 --- a/tests/test_sample.py +++ b/tests/test_sample.py @@ -16,38 +16,26 @@ import pytest from pydantic import ValidationError -from db.engine import session_ctx +from main import app +from core.dependencies import admin_function, editor_function, viewer_function from db.sample import Sample from schemas.sample import ValidateSample -from tests import client - -# ============= module & function fixtures ======================================= - - -@pytest.fixture(scope="function") -def second_sample(thing, sensor): - with session_ctx() as session: - sample = Sample( - thing_id=thing.id, - sample_type="groundwater", - field_sample_id="FS-9999999", - sample_date="2025-01-01T00:00:00Z", - release_status="draft", - sampler_name="Test Sampler", - qc_sample="Duplicate", - sensor_id=sensor.id, - sample_matrix="water", - sample_method="manual", - duplicate_sample_number=3, - sample_top=2, - sample_bottom=3, - ) - session.add(sample) - session.commit() - yield sample - session.delete(sample) - session.commit() - session.close() +from tests import client, cleanup_post_test, cleanup_patch_test, override_authentication + + +@pytest.fixture(scope="module", autouse=True) +def override_dependencies_fixture(): + app.dependency_overrides[admin_function] = override_authentication( + default={"name": "foobar", "sub": "1234567890"} + ) + app.dependency_overrides[editor_function] = override_authentication( + default={"name": "foobar", "sub": "1234567890"} + ) + app.dependency_overrides[viewer_function] = override_authentication() + + yield + + app.dependency_overrides = {} # ============== Custom validators ================================================= @@ -105,9 +93,7 @@ def test_add_sample(thing, sensor): assert data["sample_bottom"] == payload["sample_bottom"] # cleanup after adding the sample - with session_ctx() as session: - session.delete(session.get(Sample, data["id"])) - session.commit() + cleanup_post_test(Sample, data["id"]) def test_409_add_sample_invalid_field_sample_id(sample, thing): @@ -183,32 +169,22 @@ def test_patch_sample(sample): """ Test updating a sample. """ - new_sampler_name = "Test Sampler B" - new_sample_method = "continuous" - new_sample_date = "2025-01-02T00:00:00Z" - response = client.patch( - f"/sample/{sample.id}", - json={ - "sampler_name": new_sampler_name, - "sample_method": new_sample_method, - "sample_date": new_sample_date, - }, - ) + payload = { + "sampler_name": "test sample b", + "sample_method": "continuous", + "sample_date": "2025-01-02T00:00:00Z", + } + response = client.patch(f"/sample/{sample.id}", json=payload) assert response.status_code == 200 data = response.json() assert data["id"] == sample.id - assert data["sampler_name"] == new_sampler_name - assert data["sample_date"] == new_sample_date - assert data["sample_method"] == new_sample_method + assert data["sampler_name"] == payload["sampler_name"] + assert data["sample_date"] == payload["sample_date"] + assert data["sample_method"] == payload["sample_method"] # rollback after updating the sample - with session_ctx() as session: - updated_sample = session.get(Sample, sample.id) - updated_sample.sampler_name = sample.sampler_name - updated_sample.sample_method = sample.sample_method - updated_sample.sample_date = sample.sample_date - session.commit() + cleanup_patch_test(Sample, payload, sample) def test_patch_sample_404_not_found(sample): diff --git a/tests/test_sensor.py b/tests/test_sensor.py index e01cbb035..7b514059b 100644 --- a/tests/test_sensor.py +++ b/tests/test_sensor.py @@ -13,35 +13,29 @@ # See the License for the specific language governing permissions and # limitations under the License. # =============================================================================== +from core.dependencies import admin_function, editor_function, viewer_function from db import Sensor -from db.engine import session_ctx +from main import app from schemas.sensor import ValidateSensor -from tests import client, cleanup_post_test, cleanup_patch_test +from tests import client, cleanup_post_test, cleanup_patch_test, override_authentication import pytest from pydantic import ValidationError -# ====== module functions and fixtures ========================================= +@pytest.fixture(scope="module", autouse=True) +def override_dependencies_fixture(): + app.dependency_overrides[admin_function] = override_authentication( + default={"name": "foobar", "sub": "1234567890"} + ) + app.dependency_overrides[editor_function] = override_authentication( + default={"name": "foobar", "sub": "1234567890"} + ) + app.dependency_overrides[viewer_function] = override_authentication() -@pytest.fixture(scope="function") -def second_sensor(): - with session_ctx() as session: - sensor = Sensor( - name="Test Sensor 2", - model="Model X", - serial_no="123456", - datetime_installed="2023-01-01T00:00:00Z", - datetime_removed="2023-01-02T00:00:00Z", - recording_interval=60, - notes="Test equipment", - ) - session.add(sensor) - session.commit() - yield sensor - session.delete(sensor) - session.commit() - session.close() + yield + + app.dependency_overrides = {} # ====== VALIDATION tests ======================================================