diff --git a/api/group.py b/api/group.py index 962870c1..d095dfd9 100644 --- a/api/group.py +++ b/api/group.py @@ -28,9 +28,11 @@ from schemas.group import UpdateGroup, CreateGroup, GroupResponse from services.crud_helper import model_patcher, model_deleter, model_adder from services.group_helper import ( + add_thing_to_group, get_well_counts_by_group_id, group_to_response, paginated_groups_getter, + remove_thing_from_group, ) from services.query_helper import simple_get_by_id @@ -49,21 +51,22 @@ def create_group( return model_adder(session, Group, group_data, user=user) -# @router.post( -# "/association", -# summary="Create a new group-thing association", -# status_code=status.HTTP_201_CREATED, -# ) -# def create_group_thing( -# group_location_data: CreateGroupThing, -# session: session_dependency, -# user: admin_dependency, -# ): -# """ -# Create a new group location association in the database. -# """ -# return adder(session, GroupThingAssociation, group_location_data, user=user) -# +@router.post( + "/{group_id}/things/{thing_id}", + summary="Add a thing to a group", + status_code=HTTP_201_CREATED, +) +def add_thing_to_group_route( + group_id: int, + thing_id: int, + session: session_dependency, + user: admin_dependency, +): + """ + Associate a thing (e.g. a water well) with a group (project). + Returns 409 if the association already exists. + """ + return add_thing_to_group(session, group_id, thing_id, user) # ============= Get ============================================= @@ -91,17 +94,6 @@ def get_group_by_id( return group_to_response(group, counts.get(group.id, 0)) -# @router.get( -# "/association/{association_id}", -# summary="Get group-thing association by ID", -# ) -# async def get_group_thing_by_id(association_id: int, session: session_dependency): -# """ -# Retrieve a group-thing association by ID from the database. -# """ -# return simple_get_by_id(session, GroupThingAssociation, association_id) - - # ============= Patch ============================================= @router.patch("/{group_id}", summary="Update a group by ID") def update_group( @@ -117,6 +109,24 @@ def update_group( # DELETE ======================================================================= +@router.delete( + "/{group_id}/things/{thing_id}", + summary="Remove a thing from a group", + status_code=HTTP_204_NO_CONTENT, +) +def remove_thing_from_group_route( + group_id: int, + thing_id: int, + session: session_dependency, + user: admin_dependency, +): + """ + Remove the association between a thing and a group. + Returns 404 if the association does not exist. + """ + remove_thing_from_group(session, group_id, thing_id, user) + + @router.delete( "/{group_id}", summary="Delete a group by ID", status_code=HTTP_204_NO_CONTENT ) diff --git a/services/audit_helper.py b/services/audit_helper.py index 425e8ca8..7efa1e7b 100644 --- a/services/audit_helper.py +++ b/services/audit_helper.py @@ -25,4 +25,10 @@ def audit_add(user: dict, obj: DeclarativeBase) -> None: obj.created_by_name = user["name"] +def audit_update(user: dict, obj: DeclarativeBase) -> None: + if user and isinstance(user, dict): + obj.updated_by_id = user["sub"] + obj.updated_by_name = user["name"] + + # ============= EOF ============================================= diff --git a/services/group_helper.py b/services/group_helper.py index b81dd81c..58de1dce 100644 --- a/services/group_helper.py +++ b/services/group_helper.py @@ -15,13 +15,16 @@ # =============================================================================== from typing import Any +from fastapi import HTTPException from fastapi_pagination.ext.sqlalchemy import paginate from sqlalchemy import func, select from sqlalchemy.orm import Session +from starlette.status import HTTP_404_NOT_FOUND, HTTP_409_CONFLICT from db.group import Group, GroupThingAssociation from db.thing import Thing from schemas.group import GroupResponse +from services.audit_helper import audit_add, audit_update from services.query_helper import order_sort_filter @@ -49,6 +52,69 @@ def group_to_response(group: Group, well_count: int = 0) -> GroupResponse: return response.model_copy(update={"well_count": well_count}) +def add_thing_to_group( + session: Session, group_id: int, thing_id: int, user: dict +) -> GroupThingAssociation: + group = session.get(Group, group_id) + if group is None: + raise HTTPException( + status_code=HTTP_404_NOT_FOUND, + detail=f"Group with ID {group_id} not found.", + ) + + thing = session.get(Thing, thing_id) + if thing is None: + raise HTTPException( + status_code=HTTP_404_NOT_FOUND, + detail=f"Thing with ID {thing_id} not found.", + ) + + existing = session.execute( + select(GroupThingAssociation).where( + GroupThingAssociation.group_id == group_id, + GroupThingAssociation.thing_id == thing_id, + ) + ).scalar_one_or_none() + + if existing is not None: + msg = f"Thing {thing_id} is already a member of group {group_id}." + raise HTTPException(status_code=HTTP_409_CONFLICT, detail=msg) + + assoc = GroupThingAssociation(group_id=group_id, thing_id=thing_id) + audit_add(user, assoc) + session.add(assoc) + session.commit() + session.refresh(assoc) + return assoc + + +def remove_thing_from_group( + session: Session, + group_id: int, + thing_id: int, + user: dict, +) -> None: + assoc = session.execute( + select(GroupThingAssociation).where( + GroupThingAssociation.group_id == group_id, + GroupThingAssociation.thing_id == thing_id, + ) + ).scalar_one_or_none() + + if assoc is None: + raise HTTPException( + status_code=HTTP_404_NOT_FOUND, + detail=( + f"No association found between group {group_id} " + f"and thing {thing_id}." + ), + ) + + audit_update(user, assoc) + session.delete(assoc) + session.commit() + + def paginated_groups_getter( session: Session, filter_: str | None = None, diff --git a/tests/test_group.py b/tests/test_group.py index de4c6672..8ab05879 100644 --- a/tests/test_group.py +++ b/tests/test_group.py @@ -230,3 +230,45 @@ def test_delete_group_404_not_found(second_group): assert response.status_code == 404 data = response.json() assert data["detail"] == f"Group with ID {bad_id} not found." + + +# GROUP-THING association tests ================================================ + + +def test_add_thing_to_group_route(spring_thing): + payload = { + "release_status": "private", + "name": "Association Test Group", + "description": "Temporary group for association test.", + } + create_response = client.post("/group", json=payload) + assert create_response.status_code == 201 + group_id = create_response.json()["id"] + + response = client.post(f"/group/{group_id}/things/{spring_thing.id}") + assert response.status_code == 201 + data = response.json() + assert data["group_id"] == group_id + assert data["thing_id"] == spring_thing.id + assert data["created_by_id"] == "1234567890" + assert data["created_by_name"] == "foobar" + + cleanup_post_test(GroupThingAssociation, data["id"]) + cleanup_post_test(Group, group_id) + + +def test_add_thing_to_group_route_409_duplicate(group, water_well_thing): + response = client.post(f"/group/{group.id}/things/{water_well_thing.id}") + assert response.status_code == 409 + + +def test_remove_thing_from_group_route(group, water_well_thing): + response = client.delete(f"/group/{group.id}/things/{water_well_thing.id}") + assert response.status_code == 204 + + # restore association for other tests using this fixture + with session_ctx() as session: + session.add( + GroupThingAssociation(group_id=group.id, thing_id=water_well_thing.id) + ) + session.commit()