From 574d9793031b5d22547851ad0ae62783794bf98c Mon Sep 17 00:00:00 2001 From: Jacob Brown Date: Fri, 22 Aug 2025 12:59:17 -0600 Subject: [PATCH 01/11] fix: use model validator convert wkbelement to str --- schemas/group.py | 31 +++++++++++-------------------- 1 file changed, 11 insertions(+), 20 deletions(-) diff --git a/schemas/group.py b/schemas/group.py index f61282331..6013c7c5c 100644 --- a/schemas/group.py +++ b/schemas/group.py @@ -15,7 +15,8 @@ # =============================================================================== from geoalchemy2 import WKBElement from geoalchemy2.shape import to_shape -from pydantic import BaseModel, field_validator +from pydantic import BaseModel, field_validator, model_validator +from typing_extensions import Self from schemas import ORMBaseModel from services.validation.geospatial import validate_wkt_geometry @@ -23,7 +24,6 @@ class ValidateGroup(BaseModel): project_area: str | None = None - description: str | None = None parent_group_id: int | None = None @@ -49,25 +49,16 @@ class GroupResponse(ORMBaseModel): This model can be extended to include additional fields as needed. """ - id: int name: str - description: str | None = None - parent_group_id: int | None = None - - @classmethod - @field_validator("project_area", mode="before") - def project_area_to_wkt(cls, value): - if not value: - return value - - if isinstance(value, WKBElement): - return to_shape(value).wkt - - # If the value is a string, assume it's already in WKT format - if isinstance(value, str): - return value - - return None + project_area: str | None + description: str | None + parent_group_id: int | None + + @model_validator(mode="before") + def project_area_to_wkt(self: Self) -> Self: + if isinstance(self.project_area, WKBElement): + self.project_area = to_shape(self.project_area).wkt + return self # -------- UPDATE ---------- From a8422dcc2bcf45a90d2717e78dbbe4f900ab7b65 Mon Sep 17 00:00:00 2001 From: Jacob Brown Date: Fri, 22 Aug 2025 13:00:05 -0600 Subject: [PATCH 02/11] refactor: update POST group test --- tests/conftest.py | 50 +++++++++++++++++++++++++++++++++++++++++++++ tests/test_group.py | 39 ++++++++++++----------------------- 2 files changed, 63 insertions(+), 26 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index f193441aa..e0b8f680e 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -246,3 +246,53 @@ def geothermal_observation(sensor, sample): yield observation session.close() + + +@pytest.fixture(scope="session") +def group(thing): + with session_ctx() as session: + group = Group( + name="Test Group", + description="This is a test group.", + project_area="MULTIPOLYGON(((-107.2 33.6, -106.6 33.6, -106.6 34.2, -107.2 34.2, -107.2 33.6)))", + ) + + session.add(group) + session.commit() + session.refresh(group) + + group_thing_association = GroupThingAssociation( + group_id=group.id, thing_id=thing.id + ) + session.add(group_thing_association) + session.commit() + session.refresh(group_thing_association) + + yield group + + session.close() + + +@pytest.fixture(scope="function") +def second_group(thing): + with session_ctx() as session: + group = Group( + name="Second Test Group", + description="This is a second test group.", + project_area="MULTIPOLYGON(((-107.2 33.6, -106.6 33.6, -106.6 34.2, 0 0, -107.2 34.2, -107.2 33.6)))", + ) + + session.add(group) + session.commit() + session.refresh(group) + + group_thing_association = GroupThingAssociation( + group_id=group.id, thing_id=thing.id + ) + session.add(group_thing_association) + session.commit() + session.refresh(group_thing_association) + + yield group + + session.close() diff --git a/tests/test_group.py b/tests/test_group.py index e652339a7..c10426f69 100644 --- a/tests/test_group.py +++ b/tests/test_group.py @@ -1,8 +1,9 @@ import pytest +from db import Group from core.dependencies import admin_function, viewer_function, editor_function from main import app -from tests import client, override_authentication +from tests import client, override_authentication, cleanup_post_test @pytest.fixture(scope="module", autouse=True) @@ -25,35 +26,21 @@ def override_authentication_dependency_fixture(): def test_add_group(): - response = client.post("/group", json={"name": "Test Group"}) + payload = { + "name": "Test Group", + "description": "This is a test group.", + "project_area": "MULTIPOLYGON (((0 0, 1 1, 2 2, 3 3, 4 4, 1 2, 0 0)))", + } + response = client.post("/group", json=payload) assert response.status_code == 201 data = response.json() assert "id" in data - assert data["name"] == "Test Group" - - -def test_add_group_with_area(): - response = client.post( - "/group", - json={ - "name": "Test Group with Project Area", - "project_area": "MULTIPOLYGON(((-107.2 33.6, -106.6 33.6, -106.6 34.2, -107.2 34.2, -107.2 33.6)))", - }, - ) - assert response.status_code == 201 - data = response.json() - + assert "created_at" in data + assert data["name"] == payload["name"] + assert data["description"] == payload["description"] + assert data["project_area"] == payload["project_area"] -# def test_add_group_thing(location, thing): -# response = client.post( -# "/group/association", json={"group_id": 2, "thing_id": thing.id} -# ) -# assert response.status_code == 201 -# -# data = response.json() -# assert "id" in data -# assert data["group_id"] == 2 -# assert data["thing_id"] == thing.id + cleanup_post_test(Group, data["id"]) # GET tests ====================================================== From 0356c281d2de48f37730d3be34d670e4404f943d Mon Sep 17 00:00:00 2001 From: Jacob Brown Date: Fri, 22 Aug 2025 14:54:03 -0600 Subject: [PATCH 03/11] feat: implement GET group tests --- api/group.py | 5 ++-- tests/test_group.py | 66 ++++++++++++++++++++++----------------------- 2 files changed, 36 insertions(+), 35 deletions(-) diff --git a/api/group.py b/api/group.py index 02cb03727..466a12495 100644 --- a/api/group.py +++ b/api/group.py @@ -25,9 +25,8 @@ viewer_function, ) from db import adder -from db.group import Group, GroupThingAssociation +from db.group import Group from schemas.group import UpdateGroup, CreateGroup, GroupResponse -from schemas.location import CreateGroupThing from services.crud_helper import model_patcher from services.query_helper import ( simple_get_by_id, @@ -38,6 +37,8 @@ prefix="/group", tags=["group"], dependencies=[Depends(viewer_function)] ) +# POST ========================================================================= + @router.post("", summary="Create a new group", status_code=status.HTTP_201_CREATED) def create_group( diff --git a/tests/test_group.py b/tests/test_group.py index c10426f69..4c224bd51 100644 --- a/tests/test_group.py +++ b/tests/test_group.py @@ -1,3 +1,4 @@ +from geoalchemy2.shape import to_shape import pytest from db import Group @@ -46,44 +47,43 @@ def test_add_group(): # GET tests ====================================================== -def test_get_groups(): +def test_get_groups(group): response = client.get("/group") assert response.status_code == 200 - assert len(response.json()) > 0 - - -@pytest.mark.skip -def test_get_group_things(): - response = client.get("/group/association") - assert response.status_code == 200 - assert len(response.json()) > 0 - - -# test item retrieval via filter =========================================== - - -# Test item retrieval ====================================================== -# @pytest.mark.skip -# def test_item_get_spring(): -# response = client.get("/thing/spring/1") -# assert response.status_code == 200 -# data = response.json() -# assert data["id"] == 1 -# assert data["location_id"] == 1 + data = response.json() + assert data["total"] == 1 + assert data["items"][0]["id"] == group.id + assert data["items"][0]["created_at"] == group.created_at.isoformat().replace( + "+00:00", "Z" + ) + assert data["items"][0]["name"] == group.name + assert data["items"][0]["project_area"] == to_shape(group.project_area).wkt + assert data["items"][0]["description"] == group.description + assert data["items"][0]["parent_group_id"] == group.parent_group_id -def test_item_get_group(): - response = client.get("/group/2") +def test_get_group_by_id(group): + response = client.get(f"/group/{group.id}") assert response.status_code == 200 data = response.json() - assert data["id"] == 2 - assert data["name"] == "Test Group" + assert data["id"] == group.id + assert data["created_at"] == group.created_at.isoformat().replace("+00:00", "Z") + assert data["name"] == group.name + assert data["project_area"] == to_shape(group.project_area).wkt + assert data["description"] == group.description + assert data["parent_group_id"] == group.parent_group_id + + +def test_get_group_by_id_404_not_found(group): + bad_id = 99999 + response = client.get(f"/group/{bad_id}") + assert response.status_code == 404 + data = response.json() + assert data["detail"] == f"Group with ID {bad_id} not found." -# def test_item_get_group_thing(location, thing): -# response = client.get("/group/association/1") -# assert response.status_code == 200 -# data = response.json() -# assert data["id"] == 1 -# assert data["group_id"] == 2 -# assert data["thing_id"] == thing.id +@pytest.mark.skip("associations not yet implemented") +def test_get_group_things(): + response = client.get("/group/association") + assert response.status_code == 200 + assert len(response.json()) > 0 From 04d8b98250b015bfddec58500a7975c2622a5f7b Mon Sep 17 00:00:00 2001 From: Jacob Brown Date: Fri, 22 Aug 2025 15:15:45 -0600 Subject: [PATCH 04/11] feat: implement PATCH group tests --- tests/test_group.py | 27 ++++++++++++++++++++++++++- 1 file changed, 26 insertions(+), 1 deletion(-) diff --git a/tests/test_group.py b/tests/test_group.py index 4c224bd51..b9cbd5d63 100644 --- a/tests/test_group.py +++ b/tests/test_group.py @@ -4,7 +4,7 @@ from db import Group from core.dependencies import admin_function, viewer_function, editor_function from main import app -from tests import client, override_authentication, cleanup_post_test +from tests import client, override_authentication, cleanup_post_test, cleanup_patch_test @pytest.fixture(scope="module", autouse=True) @@ -87,3 +87,28 @@ def test_get_group_things(): response = client.get("/group/association") assert response.status_code == 200 assert len(response.json()) > 0 + + +# PATCH tests ================================================================== + + +def test_patch_group(group): + payload = { + "name": "Updated Group", + } + response = client.patch(f"/group/{group.id}", json=payload) + assert response.status_code == 200 + data = response.json() + assert data["id"] == group.id + assert data["name"] == payload["name"] + + cleanup_patch_test(Group, payload, group) + + +def test_patch_group_404_not_found(group): + payload = {"name": "Failed group patch"} + bad_id = 99999 + response = client.patch(f"/group/{bad_id}", json=payload) + assert response.status_code == 404 + data = response.json() + assert data["detail"] == f"Group with ID {bad_id} not found." From 265d92ccd8bf1e35f44595aff4e04ce1a1b21da2 Mon Sep 17 00:00:00 2001 From: Jacob Brown Date: Fri, 22 Aug 2025 15:23:15 -0600 Subject: [PATCH 05/11] feat: implement group DELETE tests and endpoint --- api/group.py | 16 +++++++++++++--- tests/test_group.py | 22 ++++++++++++++++++++++ 2 files changed, 35 insertions(+), 3 deletions(-) diff --git a/api/group.py b/api/group.py index 466a12495..d1df69f1c 100644 --- a/api/group.py +++ b/api/group.py @@ -15,7 +15,7 @@ # =============================================================================== from fastapi import Depends, APIRouter, Query -from starlette import status +from starlette.status import HTTP_201_CREATED, HTTP_204_NO_CONTENT from api.pagination import CustomPage from core.dependencies import ( @@ -27,7 +27,7 @@ from db import adder from db.group import Group from schemas.group import UpdateGroup, CreateGroup, GroupResponse -from services.crud_helper import model_patcher +from services.crud_helper import model_patcher, model_deleter from services.query_helper import ( simple_get_by_id, paginated_all_getter, @@ -40,7 +40,7 @@ # POST ========================================================================= -@router.post("", summary="Create a new group", status_code=status.HTTP_201_CREATED) +@router.post("", summary="Create a new group", status_code=HTTP_201_CREATED) def create_group( group_data: CreateGroup, session: session_dependency, user: admin_dependency ) -> GroupResponse: @@ -111,4 +111,14 @@ async def update_group( return model_patcher(session, Group, group_id, group_data, user=user) +# DELETE ======================================================================= +@router.delete( + "/{group_id}", summary="Delete a group by ID", status_code=HTTP_204_NO_CONTENT +) +async def delete_group( + user: admin_dependency, group_id: int, session: session_dependency +): + return model_deleter(session, Group, group_id) + + # ============= EOF ============================================= diff --git a/tests/test_group.py b/tests/test_group.py index b9cbd5d63..484e36f9f 100644 --- a/tests/test_group.py +++ b/tests/test_group.py @@ -112,3 +112,25 @@ def test_patch_group_404_not_found(group): assert response.status_code == 404 data = response.json() assert data["detail"] == f"Group with ID {bad_id} not found." + + +# DELETE tests ================================================================= + + +def test_delete_group(second_group): + response = client.delete(f"/group/{second_group.id}") + assert response.status_code == 204 + + # verify deletion + response = client.get(f"/group/{second_group.id}") + assert response.status_code == 404 + data = response.json() + assert data["detail"] == f"Group with ID {second_group.id} not found." + + +def test_delete_group_404_not_found(second_group): + bad_id = 99999 + response = client.delete(f"/group/{bad_id}") + assert response.status_code == 404 + data = response.json() + assert data["detail"] == f"Group with ID {bad_id} not found." From 2a594ec51bc81c09d9a6aed3670617de83720eca Mon Sep 17 00:00:00 2001 From: Jacob Brown Date: Fri, 22 Aug 2025 15:32:44 -0600 Subject: [PATCH 06/11] refactor: put all fixtures in conftest to enable use everywhere --- tests/conftest.py | 152 ++++++++++++++++++++++++++++++++++++++ tests/test_contact.py | 93 ----------------------- tests/test_location.py | 19 ----- tests/test_observation.py | 18 ----- tests/test_sample.py | 65 +++------------- tests/test_sensor.py | 23 ------ 6 files changed, 164 insertions(+), 206 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index e0b8f680e..83a81ea97 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -18,6 +18,21 @@ def location(): session.close() +@pytest.fixture(scope="function") +def second_location(): + with session_ctx() as session: + location = Location( + name="second location", + point="POINT (10.2 10.2)", + release_status="draft", + ) + session.add(location) + session.commit() + yield location + session.delete(location) + session.commit() + + @pytest.fixture(scope="session") def thing(location): with session_ctx() as session: @@ -53,6 +68,26 @@ def sensor(): session.close() +@pytest.fixture(scope="function") +def second_sensor(): + with session_ctx() as session: + sensor = Sensor( + name="Test Sensor 2", + model="Model X", + serial_no="123456", + datetime_installed="2023-01-01T00:00:00Z", + datetime_removed="2023-01-02T00:00:00Z", + recording_interval=60, + notes="Test equipment", + ) + session.add(sensor) + session.commit() + yield sensor + session.delete(sensor) + session.commit() + session.close() + + @pytest.fixture(scope="session") def sample(thing, sensor): with session_ctx() as session: @@ -78,6 +113,32 @@ def sample(thing, sensor): session.close() +@pytest.fixture(scope="function") +def second_sample(thing, sensor): + with session_ctx() as session: + sample = Sample( + thing_id=thing.id, + sample_type="groundwater", + field_sample_id="FS-9999999", + sample_date="2025-01-01T00:00:00Z", + release_status="draft", + sampler_name="Test Sampler", + qc_sample="Duplicate", + sensor_id=sensor.id, + sample_matrix="water", + sample_method="manual", + duplicate_sample_number=3, + sample_top=2, + sample_bottom=3, + ) + session.add(sample) + session.commit() + yield sample + session.delete(sample) + session.commit() + session.close() + + @pytest.fixture(scope="session") def contact(thing): with session_ctx() as session: @@ -148,6 +209,80 @@ def phone(contact): session.close() +@pytest.fixture(scope="function") +def second_contact(): + with session_ctx() as session: + contact = Contact( + name="Test Second Contact", + role="Owner", + ) + session.add(contact) + session.commit() + session.refresh(contact) + + yield contact + + session.delete(contact) + session.commit() + session.close() + + +@pytest.fixture(scope="function") +def second_email(second_contact): + with session_ctx() as session: + email = Email( + email="testsecondcontact@gmail.com", + email_type="Primary", + contact_id=second_contact.id, + ) + session.add(email) + session.commit() + session.refresh(email) + yield email + session.delete(email) + session.commit() + session.close() + + +@pytest.fixture(scope="function") +def second_phone(second_contact): + with session_ctx() as session: + phone = Phone( + phone_number="123-456-7890", + phone_type="Primary", + contact_id=second_contact.id, + ) + session.add(phone) + session.commit() + session.refresh(phone) + yield phone + session.delete(phone) + session.commit() + session.close() + + +@pytest.fixture(scope="function") +def second_address(second_contact): + with session_ctx() as session: + address = Address( + address_line_1="456 Secondary St", + address_line_2="Apt 12A", + city="Test Metropolis", + state="NM", + postal_code="87501", + country="United States", + address_type="Primary", + contact_id=second_contact.id, + ) + session.add(address) + session.commit() + session.refresh(address) + yield address + session.delete(address) + session.commit() + session.close() + + @pytest.fixture(scope="session") def asset(): with session_ctx() as session: @@ -248,6 +383,23 @@ def geothermal_observation(sensor, sample): session.close() +@pytest.fixture(scope="function") +def observation_to_delete(sample, sensor): + with session_ctx() as session: + observation = Observation( + observation_datetime="2019-01-01T00:03:00Z", + sample_id=sample.id, + sensor_id=sensor.id, + observed_property="water chemistry:pH", + release_status="draft", + value=4.0, + unit="dimensionless", + ) + session.add(observation) + session.commit() + yield observation + + @pytest.fixture(scope="session") def group(thing): with session_ctx() as session: diff --git a/tests/test_contact.py b/tests/test_contact.py index dd80ffea6..19ec570e4 100644 --- a/tests/test_contact.py +++ b/tests/test_contact.py @@ -4,7 +4,6 @@ amp_admin_function, ) from db import Contact, Address, Email, Phone -from db.engine import session_ctx from main import app from tests import client, cleanup_post_test, cleanup_patch_test, override_authentication from schemas.contact import ValidateEmail, ValidatePhone @@ -30,98 +29,6 @@ def override_authentication_dependency_fixture(): app.dependency_overrides = {} -# ============= module & function fixtures ======================================= - - -@pytest.fixture(scope="function") -def second_contact(): - with session_ctx() as session: - contact = Contact( - name="Test Second Contact", - role="Owner", - ) - session.add(contact) - session.commit() - session.refresh(contact) - - yield contact - - session.delete(contact) - session.commit() - session.close() - - -@pytest.fixture(scope="function") -def second_email(second_contact): - with session_ctx() as session: - email = Email( - email="testsecondcontact@gmail.com", - email_type="Primary", - contact_id=second_contact.id, - ) - session.add(email) - session.commit() - session.refresh(email) - yield email - session.delete(email) - session.commit() - session.close() - - -@pytest.fixture(scope="function") -def second_phone(second_contact): - with session_ctx() as session: - phone = Phone( - phone_number="123-456-7890", - phone_type="Primary", - contact_id=second_contact.id, - ) - session.add(phone) - session.commit() - session.refresh(phone) - yield phone - session.delete(phone) - session.commit() - session.close() - - -@pytest.fixture(scope="function") -def second_address(second_contact): - with session_ctx() as session: - address = Address( - address_line_1="456 Secondary St", - address_line_2="Apt 12A", - city="Test Metropolis", - state="NM", - postal_code="87501", - country="United States", - address_type="Primary", - contact_id=second_contact.id, - ) - session.add(address) - session.commit() - session.refresh(address) - yield address - session.delete(address) - session.commit() - session.close() - - -# @pytest.fixture(scope="function") -# def second_thing_contact_association(thing, second_contact): -# with session_ctx() as session: -# association = ThingContactAssociation( -# thing_id=thing.id, contact_id=second_contact.id -# ) -# session.add(association) -# session.commit() -# session.refresh(association) -# yield association -# session.delete(association) -# session.commit() -# session.close() - - # VALIDATION tests ============================================================= diff --git a/tests/test_location.py b/tests/test_location.py index ee98770d1..153b644be 100644 --- a/tests/test_location.py +++ b/tests/test_location.py @@ -14,30 +14,11 @@ # limitations under the License. # =============================================================================== from geoalchemy2.shape import to_shape -import pytest from db import Location from db.engine import session_ctx from tests import client -# ============= module & function fixtures ======================================= - - -@pytest.fixture(scope="function") -def second_location(): - with session_ctx() as session: - location = Location( - name="second location", - point="POINT (10.2 10.2)", - release_status="draft", - ) - session.add(location) - session.commit() - yield location - session.delete(location) - session.commit() - - # ============= Post tests for locations ====================================== diff --git a/tests/test_observation.py b/tests/test_observation.py index f1c94483c..6dcfc85de 100644 --- a/tests/test_observation.py +++ b/tests/test_observation.py @@ -14,7 +14,6 @@ # limitations under the License. # =============================================================================== from db import Observation -from db.engine import session_ctx from core.dependencies import ( amp_admin_function, admin_function, @@ -42,23 +41,6 @@ def override_authentication_dependency_fixture(): app.dependency_overrides = {} -@pytest.fixture(scope="function") -def observation_to_delete(sample, sensor): - with session_ctx() as session: - observation = Observation( - observation_datetime="2019-01-01T00:03:00Z", - sample_id=sample.id, - sensor_id=sensor.id, - observed_property="water chemistry:pH", - release_status="draft", - value=4.0, - unit="dimensionless", - ) - session.add(observation) - session.commit() - yield observation - - # ============= Post tests ================= def test_add_water_chemistry_observation(sample, sensor): payload = { diff --git a/tests/test_sample.py b/tests/test_sample.py index 6004a60ed..5f246d613 100644 --- a/tests/test_sample.py +++ b/tests/test_sample.py @@ -16,38 +16,9 @@ import pytest from pydantic import ValidationError -from db.engine import session_ctx from db.sample import Sample from schemas.sample import ValidateSample -from tests import client - -# ============= module & function fixtures ======================================= - - -@pytest.fixture(scope="function") -def second_sample(thing, sensor): - with session_ctx() as session: - sample = Sample( - thing_id=thing.id, - sample_type="groundwater", - field_sample_id="FS-9999999", - sample_date="2025-01-01T00:00:00Z", - release_status="draft", - sampler_name="Test Sampler", - qc_sample="Duplicate", - sensor_id=sensor.id, - sample_matrix="water", - sample_method="manual", - duplicate_sample_number=3, - sample_top=2, - sample_bottom=3, - ) - session.add(sample) - session.commit() - yield sample - session.delete(sample) - session.commit() - session.close() +from tests import client, cleanup_post_test, cleanup_patch_test # ============== Custom validators ================================================= @@ -105,9 +76,7 @@ def test_add_sample(thing, sensor): assert data["sample_bottom"] == payload["sample_bottom"] # cleanup after adding the sample - with session_ctx() as session: - session.delete(session.get(Sample, data["id"])) - session.commit() + cleanup_post_test(Sample, data["id"]) def test_409_add_sample_invalid_field_sample_id(sample, thing): @@ -183,32 +152,22 @@ def test_patch_sample(sample): """ Test updating a sample. """ - new_sampler_name = "Test Sampler B" - new_sample_method = "continuous" - new_sample_date = "2025-01-02T00:00:00Z" - response = client.patch( - f"/sample/{sample.id}", - json={ - "sampler_name": new_sampler_name, - "sample_method": new_sample_method, - "sample_date": new_sample_date, - }, - ) + payload = { + "sampler_name": "test sample b", + "sample_method": "continuous", + "sample_date": "2025-01-02T00:00:00Z", + } + response = client.patch(f"/sample/{sample.id}", json=payload) assert response.status_code == 200 data = response.json() assert data["id"] == sample.id - assert data["sampler_name"] == new_sampler_name - assert data["sample_date"] == new_sample_date - assert data["sample_method"] == new_sample_method + assert data["sampler_name"] == payload["sampler_name"] + assert data["sample_date"] == payload["sample_date"] + assert data["sample_method"] == payload["sample_method"] # rollback after updating the sample - with session_ctx() as session: - updated_sample = session.get(Sample, sample.id) - updated_sample.sampler_name = sample.sampler_name - updated_sample.sample_method = sample.sample_method - updated_sample.sample_date = sample.sample_date - session.commit() + cleanup_patch_test(Sample, payload, sample) def test_patch_sample_404_not_found(sample): diff --git a/tests/test_sensor.py b/tests/test_sensor.py index e01cbb035..d98dceee7 100644 --- a/tests/test_sensor.py +++ b/tests/test_sensor.py @@ -14,35 +14,12 @@ # limitations under the License. # =============================================================================== from db import Sensor -from db.engine import session_ctx from schemas.sensor import ValidateSensor from tests import client, cleanup_post_test, cleanup_patch_test import pytest from pydantic import ValidationError -# ====== module functions and fixtures ========================================= - - -@pytest.fixture(scope="function") -def second_sensor(): - with session_ctx() as session: - sensor = Sensor( - name="Test Sensor 2", - model="Model X", - serial_no="123456", - datetime_installed="2023-01-01T00:00:00Z", - datetime_removed="2023-01-02T00:00:00Z", - recording_interval=60, - notes="Test equipment", - ) - session.add(sensor) - session.commit() - yield sensor - session.delete(sensor) - session.commit() - session.close() - # ====== VALIDATION tests ====================================================== From 802e0fc82448c2d9d25cdd09f04a18cf5c035ce4 Mon Sep 17 00:00:00 2001 From: Jacob Brown Date: Fri, 22 Aug 2025 15:53:53 -0600 Subject: [PATCH 07/11] feat: ensure valid project_area wkts are MULTIPOLYGON --- schemas/group.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/schemas/group.py b/schemas/group.py index 6013c7c5c..c5f525748 100644 --- a/schemas/group.py +++ b/schemas/group.py @@ -27,10 +27,13 @@ class ValidateGroup(BaseModel): description: str | None = None parent_group_id: int | None = None - @classmethod @field_validator("project_area") def validate_area_is_wkt(cls, wkt): - return validate_wkt_geometry(wkt) + valid_wkt = validate_wkt_geometry(wkt) + if "MULTIPOLYGON" not in valid_wkt: + raise ValueError("WKT must be a valid MULTIPOLYGON") + + return valid_wkt # -------- CREATE ---------- From 7918fa41cd093f8e9b739b4443c609c9ae468e81 Mon Sep 17 00:00:00 2001 From: Jacob Brown Date: Fri, 22 Aug 2025 15:55:12 -0600 Subject: [PATCH 08/11] feat: test ValidateGroup validations --- tests/test_group.py | 32 ++++++++++++++++++++++++++++++++ 1 file changed, 32 insertions(+) diff --git a/tests/test_group.py b/tests/test_group.py index 484e36f9f..4b9095591 100644 --- a/tests/test_group.py +++ b/tests/test_group.py @@ -1,9 +1,11 @@ from geoalchemy2.shape import to_shape +from pydantic import ValidationError import pytest from db import Group from core.dependencies import admin_function, viewer_function, editor_function from main import app +from schemas.group import ValidateGroup from tests import client, override_authentication, cleanup_post_test, cleanup_patch_test @@ -23,6 +25,36 @@ def override_authentication_dependency_fixture(): app.dependency_overrides = {} +# VALIDATION tests ============================================================= + + +def test_project_area_not_topologically_valid(): + wkt = "MULTIPOLYGON(((0 0, 1 1, 2 2, 0 0)))" + with pytest.raises( + ValidationError, match="WKT geometry is not topologically valid" + ): + ValidateGroup(project_area=wkt) + + +def test_project_area_invalid_wkt(): + for wkt in [ + "MULTIPOLYGON((0 0, 1 1, 2 2, 0 0))", + "0 0, 1 1, 2 2, 3 3, 4 5, 0 0", + ]: + with pytest.raises(ValidationError, match=r"Invalid WKT geometry: "): + ValidateGroup(project_area=wkt) + + +def test_project_area_not_multipolygon(): + for wkt in [ + "POINT (0 0)", + "LINESTRING (0 0, 1 1, 2 2, 3 3)", + "POLYGON ((0 0, 1 1, 2 2, 1 2, 0 0))", + ]: + with pytest.raises(ValidationError, match="WKT must be a valid MULTIPOLYGON"): + ValidateGroup(project_area=wkt) + + # ADD tests ====================================================== From be055eddf208ef5f17cde7c0b5efe13aeae3df9b Mon Sep 17 00:00:00 2001 From: Jacob Brown Date: Fri, 22 Aug 2025 16:48:11 -0600 Subject: [PATCH 09/11] feat: add auth to sample router --- api/sample.py | 27 +++++++++++++++++---------- tests/test_sample.py | 19 ++++++++++++++++++- 2 files changed, 35 insertions(+), 11 deletions(-) diff --git a/api/sample.py b/api/sample.py index 9a3755093..c9107237c 100644 --- a/api/sample.py +++ b/api/sample.py @@ -14,15 +14,18 @@ # limitations under the License. # =============================================================================== -from fastapi import APIRouter, Depends, Query, Response +from fastapi import APIRouter, Query, Response from sqlalchemy.exc import IntegrityError, ProgrammingError -from sqlalchemy.orm import Session from starlette.status import HTTP_201_CREATED, HTTP_409_CONFLICT from api.pagination import CustomPage -from core.dependencies import session_dependency +from core.dependencies import ( + session_dependency, + admin_dependency, + editor_dependency, + viewer_dependency, +) from db import adder -from db.engine import get_db_session from db.sample import Sample from schemas import ResourceNotFoundResponse from schemas.sample import SampleResponse, CreateSample, UpdateSample @@ -70,13 +73,13 @@ def database_error_handler( # ============= Post ============================================= @router.post("", status_code=HTTP_201_CREATED) def add_sample( - sample_data: CreateSample, session: session_dependency + sample_data: CreateSample, session: session_dependency, user: admin_dependency ) -> SampleResponse: """ Endpoint to add a sample. """ try: - return adder(session, Sample, sample_data) + return adder(session, Sample, sample_data, user=user) except (IntegrityError, ProgrammingError) as e: database_error_handler(sample_data, e) @@ -86,7 +89,8 @@ def add_sample( def update_sample( sample_id: int, sample_data: UpdateSample, - session: Session = Depends(get_db_session), + session: session_dependency, + user: editor_dependency, ) -> SampleResponse | ResourceNotFoundResponse: """ Endpoint to update a sample. @@ -105,7 +109,7 @@ def update_sample( the update. """ try: - return model_patcher(session, Sample, sample_id, sample_data) + return model_patcher(session, Sample, sample_id, sample_data, user=user) except (IntegrityError, ProgrammingError) as e: database_error_handler(sample_data, e) @@ -114,6 +118,7 @@ def update_sample( @router.get("", summary="Get Samples") def get_samples( session: session_dependency, + user: viewer_dependency, sort: str = None, order: str = None, filter_: str = Query(alias="filter", default=None), @@ -129,7 +134,7 @@ def get_samples( @router.get("/{sample_id}", summary="Get Sample by ID") def get_sample_by_id( - sample_id: int, session: session_dependency + sample_id: int, session: session_dependency, user: viewer_dependency ) -> SampleResponse | ResourceNotFoundResponse: """ Endpoint to retrieve a sample by its ID. @@ -141,7 +146,9 @@ def get_sample_by_id( @router.delete("/{sample_id}", summary="Delete Sample by ID") -def delete_sample_by_id(sample_id: int, session: session_dependency) -> Response: +def delete_sample_by_id( + sample_id: int, session: session_dependency, user: admin_dependency +) -> Response: return model_deleter(session, Sample, sample_id) diff --git a/tests/test_sample.py b/tests/test_sample.py index 5f246d613..d3d432413 100644 --- a/tests/test_sample.py +++ b/tests/test_sample.py @@ -16,9 +16,26 @@ import pytest from pydantic import ValidationError +from main import app +from core.dependencies import admin_function, editor_function, viewer_function from db.sample import Sample from schemas.sample import ValidateSample -from tests import client, cleanup_post_test, cleanup_patch_test +from tests import client, cleanup_post_test, cleanup_patch_test, override_authentication + + +@pytest.fixture(scope="module", autouse=True) +def override_dependencies_fixture(): + app.dependency_overrides[admin_function] = override_authentication( + default={"name": "foobar", "sub": "1234567890"} + ) + app.dependency_overrides[editor_function] = override_authentication( + default={"name": "foobar", "sub": "1234567890"} + ) + app.dependency_overrides[viewer_function] = override_authentication() + + yield + + app.dependency_overrides = {} # ============== Custom validators ================================================= From ca6478897375aa0b776c68fd3f1ffdca192f2068 Mon Sep 17 00:00:00 2001 From: Jacob Brown Date: Fri, 22 Aug 2025 16:51:30 -0600 Subject: [PATCH 10/11] feat: add auth to sensor router --- api/sensor.py | 27 ++++++++++++++++++++------- tests/test_sensor.py | 19 ++++++++++++++++++- 2 files changed, 38 insertions(+), 8 deletions(-) diff --git a/api/sensor.py b/api/sensor.py index bb92c4232..5239f3175 100644 --- a/api/sensor.py +++ b/api/sensor.py @@ -20,7 +20,12 @@ from starlette import status from api.pagination import CustomPage -from core.dependencies import session_dependency +from core.dependencies import ( + session_dependency, + admin_dependency, + editor_dependency, + viewer_dependency, +) from db import adder, Observation, Sample from db.sensor import Sensor from schemas.sensor import SensorResponse, CreateSensor, UpdateSensor @@ -35,12 +40,12 @@ @router.post("", status_code=status.HTTP_201_CREATED) def add_sensor( - sensor_data: CreateSensor, session: session_dependency + sensor_data: CreateSensor, session: session_dependency, user: admin_dependency ) -> SensorResponse: """ Add a sensor to the system. """ - return adder(session, Sensor, sensor_data) + return adder(session, Sensor, sensor_data, user=user) # ====== PATCH ================================================================= @@ -48,7 +53,10 @@ def add_sensor( @router.patch("/{sensor_id}", status_code=status.HTTP_200_OK) def update_sensor( - sensor_id: int, sensor_data: UpdateSensor, session: session_dependency + sensor_id: int, + sensor_data: UpdateSensor, + session: session_dependency, + user: editor_dependency, ) -> SensorResponse: """ Update a sensor in the system. @@ -97,14 +105,16 @@ def update_sensor( status_code=status.HTTP_409_CONFLICT, detail=[detail] ) - return model_patcher(session, Sensor, sensor_id, sensor_data) + return model_patcher(session, Sensor, sensor_id, sensor_data, user=user) # ====== DELETE ================================================================ @router.delete("/{sensor_id}") -def delete_sensor(sensor_id: int, session: session_dependency) -> Response: +def delete_sensor( + sensor_id: int, session: session_dependency, user: admin_dependency +) -> Response: """ Delete a sensor in the system """ @@ -117,6 +127,7 @@ def delete_sensor(sensor_id: int, session: session_dependency) -> Response: @router.get("", status_code=status.HTTP_200_OK) def get_sensors( session: session_dependency, + user: viewer_dependency, thing_id: int = None, # Optional filter for thing_id observed_property: str = None, # Optional filter for observed_property sort: str | None = None, @@ -151,7 +162,9 @@ def get_sensors( @router.get("/{sensor_id}", status_code=status.HTTP_200_OK) -def get_sensor(sensor_id: int, session: session_dependency) -> SensorResponse: +def get_sensor( + sensor_id: int, session: session_dependency, user: viewer_dependency +) -> SensorResponse: """ Retrieve a sensor by its ID. """ diff --git a/tests/test_sensor.py b/tests/test_sensor.py index d98dceee7..7b514059b 100644 --- a/tests/test_sensor.py +++ b/tests/test_sensor.py @@ -13,14 +13,31 @@ # See the License for the specific language governing permissions and # limitations under the License. # =============================================================================== +from core.dependencies import admin_function, editor_function, viewer_function from db import Sensor +from main import app from schemas.sensor import ValidateSensor -from tests import client, cleanup_post_test, cleanup_patch_test +from tests import client, cleanup_post_test, cleanup_patch_test, override_authentication import pytest from pydantic import ValidationError +@pytest.fixture(scope="module", autouse=True) +def override_dependencies_fixture(): + app.dependency_overrides[admin_function] = override_authentication( + default={"name": "foobar", "sub": "1234567890"} + ) + app.dependency_overrides[editor_function] = override_authentication( + default={"name": "foobar", "sub": "1234567890"} + ) + app.dependency_overrides[viewer_function] = override_authentication() + + yield + + app.dependency_overrides = {} + + # ====== VALIDATION tests ====================================================== From bd4634e818690d9eb3ec87db1257ffdb562bbc09 Mon Sep 17 00:00:00 2001 From: Jacob Brown Date: Fri, 22 Aug 2025 17:02:32 -0600 Subject: [PATCH 11/11] feat: add auth to location router --- api/location.py | 25 +++++++++++++--------- tests/test_location.py | 48 ++++++++++++++++++++++++------------------ 2 files changed, 42 insertions(+), 31 deletions(-) diff --git a/api/location.py b/api/location.py index 666e09fe3..3352168ae 100644 --- a/api/location.py +++ b/api/location.py @@ -13,17 +13,20 @@ # See the License for the specific language governing permissions and # limitations under the License. # =============================================================================== -from fastapi import Depends, Query, Response +from fastapi import Query, Response from fastapi_pagination.ext.sqlalchemy import paginate from sqlalchemy import select, func -from sqlalchemy.orm import Session from starlette import status from api.pagination import CustomPage from constants import SRID_WGS84 -from core.dependencies import session_dependency +from core.dependencies import ( + session_dependency, + admin_dependency, + editor_dependency, + viewer_dependency, +) from db import adder from db.location import Location -from db.engine import get_db_session from schemas.location import CreateLocation, LocationResponse, UpdateLocation from services.geospatial_helper import make_within_wkt from services.query_helper import make_query, order_sort_filter, simple_get_by_id @@ -41,12 +44,12 @@ status_code=status.HTTP_201_CREATED, ) def create_location( - location_data: CreateLocation, session: Session = Depends(get_db_session) + location_data: CreateLocation, session: session_dependency, user: admin_dependency ) -> LocationResponse: """ Create a new sample location in the database. """ - return adder(session, Location, location_data) + return adder(session, Location, location_data, user=user) @router.patch( @@ -56,12 +59,13 @@ def create_location( def update_location( location_id: int, location_data: UpdateLocation, - session: Session = Depends(get_db_session), + session: session_dependency, + user: editor_dependency, ) -> LocationResponse: """ Update a sample location in the database. """ - return model_patcher(session, Location, location_id, location_data) + return model_patcher(session, Location, location_id, location_data, user=user) # @router.get("/shapefile", summary="Get location as shapefile") @@ -125,6 +129,7 @@ def update_location( ) async def get_location( session: session_dependency, + user: viewer_dependency, nearby_point: str = None, nearby_distance_km: float = 1, within: str = None, @@ -160,7 +165,7 @@ async def get_location( summary="Get location by ID", ) async def get_location_by_id( - location_id: int, session: Session = Depends(get_db_session) + location_id: int, session: session_dependency, user: viewer_dependency ) -> LocationResponse: """ Retrieve a sample location by ID from the database. @@ -171,7 +176,7 @@ async def get_location_by_id( @router.delete("/{location_id}", summary="Delete location by ID") async def delete_location( - location_id: int, session: Session = Depends(get_db_session) + location_id: int, session: session_dependency, user: admin_dependency ) -> Response: """ Delete a sample location by ID from the database. diff --git a/tests/test_location.py b/tests/test_location.py index 153b644be..cee06ab2b 100644 --- a/tests/test_location.py +++ b/tests/test_location.py @@ -14,10 +14,28 @@ # limitations under the License. # =============================================================================== from geoalchemy2.shape import to_shape +import pytest +from core.dependencies import admin_function, editor_function, viewer_function from db import Location -from db.engine import session_ctx -from tests import client +from main import app +from tests import client, override_authentication, cleanup_post_test, cleanup_patch_test + + +@pytest.fixture(scope="module", autouse=True) +def override_dependencies_fixture(): + app.dependency_overrides[admin_function] = override_authentication( + default={"name": "foobar", "sub": "1234567890"} + ) + app.dependency_overrides[editor_function] = override_authentication( + default={"name": "foobar", "sub": "1234567890"} + ) + app.dependency_overrides[viewer_function] = override_authentication() + + yield + + app.dependency_overrides = {} + # ============= Post tests for locations ====================================== @@ -38,35 +56,23 @@ def test_add_location(): assert data["release_status"] == payload["release_status"] # cleanup after test - with session_ctx() as session: - session.delete(session.get(Location, data["id"])) - session.commit() + cleanup_post_test(Location, data["id"]) # ============= Patch tests for locations ===================================== def test_update_location(location): - location_id = location.id - response = client.patch( - f"/location/{location_id}", - json={ - "point": "POINT (10.1 20.2)", - "release_status": "draft", - }, - ) + payload = {"point": "POINT (10.1 20.2)", "release_status": "draft"} + response = client.patch(f"/location/{location.id}", json=payload) assert response.status_code == 200 data = response.json() - assert data["id"] == location_id - assert data["point"] == "POINT (10.1 20.2)" - assert data["release_status"] == "draft" + assert data["id"] == location.id + assert data["point"] == payload["point"] + assert data["release_status"] == payload["release_status"] # cleanup after test - with session_ctx() as session: - updated_location = session.get(Location, location_id) - updated_location.point = location.point - updated_location.release_status = location.release_status - session.commit() + cleanup_patch_test(Location, payload, location) def test_patch_location_404_not_found(location):