From 33887d42e101463d2848dd77b717c971070f9b23 Mon Sep 17 00:00:00 2001 From: NiveditJain Date: Fri, 22 Aug 2025 12:11:34 +0530 Subject: [PATCH 01/16] Add state existence check in create_next_states - Introduced an asynchronous function to check if a state already exists before inserting it into the database. - Updated the create_next_states function to gather existence checks for new unit states and only insert those that do not already exist. - This enhancement improves efficiency by preventing duplicate state entries in the database. --- state-manager/app/tasks/create_next_states.py | 27 ++++++++++++++++--- 1 file changed, 24 insertions(+), 3 deletions(-) diff --git a/state-manager/app/tasks/create_next_states.py b/state-manager/app/tasks/create_next_states.py index 556bd601..0369495f 100644 --- a/state-manager/app/tasks/create_next_states.py +++ b/state-manager/app/tasks/create_next_states.py @@ -1,3 +1,4 @@ +import asyncio from beanie import PydanticObjectId from beanie.operators import In, NE from app.singletons.logs_manager import LogsManager @@ -58,6 +59,18 @@ async def check_unites_satisfied(namespace: str, graph_name: str, node_template: return False return True + +async def check_state_exists(state: State) -> bool: + if await State.find( + State.namespace_name == state.namespace_name, + State.graph_name == state.graph_name, + State.run_id == state.run_id, + State.parents == state.parents + ).count() > 0: + return True + return False + + def get_dependents(syntax_string: str) -> DependentString: splits = syntax_string.split("${{") if len(splits) <= 1: @@ -232,9 +245,17 @@ async def get_input_model(node_template: NodeTemplate) -> Type[BaseModel]: new_unit_states.append(generate_next_state(next_state_input_model, next_state_node_template, parents, parent_state)) - if len(new_unit_states) > 0: - await State.insert_many(new_unit_states) - + existence_checks = [check_state_exists(state) for state in new_unit_states] + existence_results = await 
asyncio.gather(*existence_checks) + + not_inserted_new_states = [] + for state, exists in zip(new_unit_states, existence_results): + if not exists: + not_inserted_new_states.append(state) + + if len(not_inserted_new_states) > 0: + await State.insert_many(not_inserted_new_states) + except Exception as e: await State.find( In(State.id, state_ids) From f684e86bb1d9da0c25ded8ecea10e27c207d635b Mon Sep 17 00:00:00 2001 From: NiveditJain Date: Fri, 22 Aug 2025 12:14:22 +0530 Subject: [PATCH 02/16] Refactor check_state_exists function to return State or None - Updated the check_state_exists function to return a State instance or None instead of a boolean value, enhancing type clarity. - Adjusted the create_next_states function to utilize the new return type, ensuring only valid states are appended to the not_inserted_new_states list. - These changes improve the overall type safety and readability of the state management logic. --- state-manager/app/tasks/create_next_states.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/state-manager/app/tasks/create_next_states.py b/state-manager/app/tasks/create_next_states.py index 0369495f..3d9d84c6 100644 --- a/state-manager/app/tasks/create_next_states.py +++ b/state-manager/app/tasks/create_next_states.py @@ -60,15 +60,15 @@ async def check_unites_satisfied(namespace: str, graph_name: str, node_template: return True -async def check_state_exists(state: State) -> bool: +async def check_state_exists(state: State) -> State | None: if await State.find( State.namespace_name == state.namespace_name, State.graph_name == state.graph_name, State.run_id == state.run_id, State.parents == state.parents ).count() > 0: - return True - return False + return None + return state def get_dependents(syntax_string: str) -> DependentString: @@ -249,8 +249,8 @@ async def get_input_model(node_template: NodeTemplate) -> Type[BaseModel]: existence_results = await asyncio.gather(*existence_checks) not_inserted_new_states = [] 
- for state, exists in zip(new_unit_states, existence_results): - if not exists: + for state in existence_results: + if state: not_inserted_new_states.append(state) if len(not_inserted_new_states) > 0: From 5aa5c5bc7059e2227e5b180e78b6661416879253 Mon Sep 17 00:00:00 2001 From: NiveditJain Date: Fri, 22 Aug 2025 12:17:22 +0530 Subject: [PATCH 03/16] Refactor state existence check in check_state_exists function - Updated the check_state_exists function to use find_one instead of find, improving clarity and efficiency in checking for existing states. - Enhanced the condition to return None if a matching state already exists (and the candidate state otherwise), aligning with the recent type safety improvements in the state management logic. --- state-manager/app/tasks/create_next_states.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/state-manager/app/tasks/create_next_states.py b/state-manager/app/tasks/create_next_states.py index 3d9d84c6..3591bf44 100644 --- a/state-manager/app/tasks/create_next_states.py +++ b/state-manager/app/tasks/create_next_states.py @@ -61,12 +61,14 @@ async def check_unites_satisfied(namespace: str, graph_name: str, node_template: async def check_state_exists(state: State) -> State | None: - if await State.find( + if await State.find_one( State.namespace_name == state.namespace_name, + State.node_name == state.node_name, + State.identifier == state.identifier, State.graph_name == state.graph_name, State.run_id == state.run_id, State.parents == state.parents - ).count() > 0: + ) is not None: return None return state From f6f91c631a7a729c5c0e3d0299df8bdd3ed2c4f7 Mon Sep 17 00:00:00 2001 From: NiveditJain Date: Fri, 22 Aug 2025 12:34:51 +0530 Subject: [PATCH 04/16] Add fingerprint generation and unites handling in State model - Introduced a new _fingerprint field in the State model, with a method to generate its value based on relevant state attributes. - Added a does_unites field to indicate if the state unites others. 
- Updated create_next_states function to handle DuplicateKeyError during state insertion, improving error management and preventing crashes from duplicate entries. - Removed the previous existence check for states, streamlining the state creation process and enhancing efficiency. --- state-manager/app/models/db/state.py | 38 ++++++++++++++++++- state-manager/app/tasks/create_next_states.py | 30 +++------------ 2 files changed, 42 insertions(+), 26 deletions(-) diff --git a/state-manager/app/models/db/state.py b/state-manager/app/models/db/state.py index 7f7c75d7..aae27644 100644 --- a/state-manager/app/models/db/state.py +++ b/state-manager/app/models/db/state.py @@ -1,8 +1,11 @@ +from pymongo import IndexModel from .base import BaseDatabaseModel from ..state_status_enum import StateStatusEnum from pydantic import Field -from beanie import PydanticObjectId +from beanie import Insert, PydanticObjectId, Replace, Save, before_event from typing import Any, Optional +import hashlib +import json class State(BaseDatabaseModel): @@ -15,4 +18,35 @@ class State(BaseDatabaseModel): inputs: dict[str, Any] = Field(..., description="Inputs of the state") outputs: dict[str, Any] = Field(..., description="Outputs of the state") error: Optional[str] = Field(None, description="Error message") - parents: dict[str, PydanticObjectId] = Field(default_factory=dict, description="Parents of the state") \ No newline at end of file + parents: dict[str, PydanticObjectId] = Field(default_factory=dict, description="Parents of the state") + does_unites: bool = Field(default=False, description="Whether the state is unites others") + _fingerprint: str = Field(default="", description="Fingerprint of the state") + + @before_event([Insert, Replace, Save]) + def _generate_fingerprint(self): + data = { + "node_name": self.node_name, + "namespace_name": self.namespace_name, + "identifier": self.identifier, + "graph_name": self.graph_name, + "run_id": self.run_id, + "parents": {key: str(value) for key, 
value in self.parents.items()} + } + self._fingerprint = hashlib.sha256(json.dumps(data).encode()).hexdigest() + + @property + def fingerprint(self): + return self._fingerprint + + class Settings: + indexes = [ + IndexModel( + [ + ("_fingerprint", 1) + ], + unique=True, + partialFilterExpression={ + "does_unites": True + } + ) + ] \ No newline at end of file diff --git a/state-manager/app/tasks/create_next_states.py b/state-manager/app/tasks/create_next_states.py index 3591bf44..3b60a841 100644 --- a/state-manager/app/tasks/create_next_states.py +++ b/state-manager/app/tasks/create_next_states.py @@ -1,5 +1,5 @@ -import asyncio from beanie import PydanticObjectId +from pymongo.errors import DuplicateKeyError from beanie.operators import In, NE from app.singletons.logs_manager import LogsManager from app.models.db.graph_template_model import GraphTemplate @@ -60,19 +60,6 @@ async def check_unites_satisfied(namespace: str, graph_name: str, node_template: return True -async def check_state_exists(state: State) -> State | None: - if await State.find_one( - State.namespace_name == state.namespace_name, - State.node_name == state.node_name, - State.identifier == state.identifier, - State.graph_name == state.graph_name, - State.run_id == state.run_id, - State.parents == state.parents - ) is not None: - return None - return state - - def get_dependents(syntax_string: str) -> DependentString: splits = syntax_string.split("${{") if len(splits) <= 1: @@ -246,17 +233,12 @@ async def get_input_model(node_template: NodeTemplate) -> Type[BaseModel]: parent_state = parents[next_state_node_template.unites.identifier] new_unit_states.append(generate_next_state(next_state_input_model, next_state_node_template, parents, parent_state)) - - existence_checks = [check_state_exists(state) for state in new_unit_states] - existence_results = await asyncio.gather(*existence_checks) - - not_inserted_new_states = [] - for state in existence_results: - if state: - 
not_inserted_new_states.append(state) - if len(not_inserted_new_states) > 0: - await State.insert_many(not_inserted_new_states) + try: + if len(new_unit_states) > 0: + await State.insert_many(new_unit_states) + except DuplicateKeyError: + logger.error(f"Duplicate key error for new unit states: {new_unit_states}") except Exception as e: await State.find( From 1b87b06175c63f552c6f3c6826c4e2dc36b25de8 Mon Sep 17 00:00:00 2001 From: Nivedit Jain Date: Fri, 22 Aug 2025 12:39:47 +0530 Subject: [PATCH 05/16] Update state-manager/app/models/db/state.py Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> --- state-manager/app/models/db/state.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/state-manager/app/models/db/state.py b/state-manager/app/models/db/state.py index aae27644..a741e007 100644 --- a/state-manager/app/models/db/state.py +++ b/state-manager/app/models/db/state.py @@ -32,7 +32,7 @@ def _generate_fingerprint(self): "run_id": self.run_id, "parents": {key: str(value) for key, value in self.parents.items()} } - self._fingerprint = hashlib.sha256(json.dumps(data).encode()).hexdigest() + self._fingerprint = hashlib.sha256(json.dumps(data, sort_keys=True).encode()).hexdigest() @property def fingerprint(self): From fd91f2168fc25c0059c438f90acc40906dcf6c05 Mon Sep 17 00:00:00 2001 From: Nivedit Jain Date: Fri, 22 Aug 2025 12:39:53 +0530 Subject: [PATCH 06/16] Update state-manager/app/tasks/create_next_states.py Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> --- state-manager/app/tasks/create_next_states.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/state-manager/app/tasks/create_next_states.py b/state-manager/app/tasks/create_next_states.py index 3b60a841..42965655 100644 --- a/state-manager/app/tasks/create_next_states.py +++ b/state-manager/app/tasks/create_next_states.py @@ -238,7 +238,7 @@ async def 
get_input_model(node_template: NodeTemplate) -> Type[BaseModel]: if len(new_unit_states) > 0: await State.insert_many(new_unit_states) except DuplicateKeyError: - logger.error(f"Duplicate key error for new unit states: {new_unit_states}") + logger.warning(f"Caught an expected duplicate key error for new unit states, likely due to a race condition: {new_unit_states}") except Exception as e: await State.find( From 70a641b5931348a63d780856109b5fd583e42ed2 Mon Sep 17 00:00:00 2001 From: NiveditJain Date: Fri, 22 Aug 2025 12:41:19 +0530 Subject: [PATCH 07/16] Add does_unites handling in generate_next_state function - Updated the generate_next_state function to include a does_unites parameter, indicating whether the state unites others. This change enhances the clarity and functionality of state generation within the state management logic. --- state-manager/app/tasks/create_next_states.py | 1 + 1 file changed, 1 insertion(+) diff --git a/state-manager/app/tasks/create_next_states.py b/state-manager/app/tasks/create_next_states.py index 3b60a841..76fbb540 100644 --- a/state-manager/app/tasks/create_next_states.py +++ b/state-manager/app/tasks/create_next_states.py @@ -136,6 +136,7 @@ def generate_next_state(next_state_input_model: Type[BaseModel], next_state_node parents=new_parents, inputs=next_state_input_data, outputs={}, + does_unites=next_state_node_template.unites is not None, run_id=current_state.run_id, error=None ) From 15ceca4d673553b66b854c481bdcc4f563f11095 Mon Sep 17 00:00:00 2001 From: NiveditJain Date: Fri, 22 Aug 2025 12:43:31 +0530 Subject: [PATCH 08/16] Sort parents dictionary in State model before fingerprint generation - Updated the State model to sort the parents dictionary by key when generating the fingerprint. This change ensures consistent fingerprint values, enhancing the integrity of state management. 
--- state-manager/app/models/db/state.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/state-manager/app/models/db/state.py b/state-manager/app/models/db/state.py index a741e007..d9a34478 100644 --- a/state-manager/app/models/db/state.py +++ b/state-manager/app/models/db/state.py @@ -30,7 +30,7 @@ def _generate_fingerprint(self): "identifier": self.identifier, "graph_name": self.graph_name, "run_id": self.run_id, - "parents": {key: str(value) for key, value in self.parents.items()} + "parents": {key: str(value) for key, value in sorted(self.parents.items(), key=lambda x: x[0])} } self._fingerprint = hashlib.sha256(json.dumps(data, sort_keys=True).encode()).hexdigest() From afd5b72cfab2d10fba306d80486adf8b48e2bd2a Mon Sep 17 00:00:00 2001 From: NiveditJain Date: Fri, 22 Aug 2025 12:45:30 +0530 Subject: [PATCH 09/16] Rename _fingerprint to state_fingerprint in State model - Updated the State model to rename the private _fingerprint field to state_fingerprint for improved clarity and consistency. The fingerprint generation method has been adjusted accordingly to reflect this change. 
--- state-manager/app/models/db/state.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/state-manager/app/models/db/state.py b/state-manager/app/models/db/state.py index d9a34478..e39f892f 100644 --- a/state-manager/app/models/db/state.py +++ b/state-manager/app/models/db/state.py @@ -20,7 +20,7 @@ class State(BaseDatabaseModel): error: Optional[str] = Field(None, description="Error message") parents: dict[str, PydanticObjectId] = Field(default_factory=dict, description="Parents of the state") does_unites: bool = Field(default=False, description="Whether the state is unites others") - _fingerprint: str = Field(default="", description="Fingerprint of the state") + state_fingerprint: str = Field(default="", description="Fingerprint of the state") @before_event([Insert, Replace, Save]) def _generate_fingerprint(self): @@ -32,12 +32,8 @@ def _generate_fingerprint(self): "run_id": self.run_id, "parents": {key: str(value) for key, value in sorted(self.parents.items(), key=lambda x: x[0])} } - self._fingerprint = hashlib.sha256(json.dumps(data, sort_keys=True).encode()).hexdigest() + self.state_fingerprint = hashlib.sha256(json.dumps(data, sort_keys=True).encode()).hexdigest() - @property - def fingerprint(self): - return self._fingerprint - class Settings: indexes = [ IndexModel( From a11c42a52ecede7b134760a60eddcaf1d82aa56a Mon Sep 17 00:00:00 2001 From: NiveditJain Date: Fri, 22 Aug 2025 13:03:52 +0530 Subject: [PATCH 10/16] Add does_unites check in fingerprint generation method - Updated the _generate_fingerprint method in the State model to handle cases where does_unites is False, ensuring that the state_fingerprint is cleared appropriately. This change enhances the robustness of fingerprint generation by preventing unnecessary calculations when uniting states is not applicable. 
--- state-manager/app/models/db/state.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/state-manager/app/models/db/state.py b/state-manager/app/models/db/state.py index e39f892f..8e7b4337 100644 --- a/state-manager/app/models/db/state.py +++ b/state-manager/app/models/db/state.py @@ -24,6 +24,10 @@ class State(BaseDatabaseModel): @before_event([Insert, Replace, Save]) def _generate_fingerprint(self): + if not self.does_unites: + self.state_fingerprint = "" + return + data = { "node_name": self.node_name, "namespace_name": self.namespace_name, @@ -38,7 +42,7 @@ class Settings: indexes = [ IndexModel( [ - ("_fingerprint", 1) + ("state_fingerprint", 1) ], unique=True, partialFilterExpression={ From d9f0fa1c905db9c4b497878d8aa21fd9c0c2e640 Mon Sep 17 00:00:00 2001 From: NiveditJain Date: Fri, 22 Aug 2025 14:07:03 +0530 Subject: [PATCH 11/16] Enhance State model and error handling in create_next_states - Updated the State model to improve the description of the does_unites field for clarity. - Refined the fingerprint generation method to ensure consistent payload formatting and added a unique index for states that unite others. - Enhanced logging in the create_next_states function to provide more context during DuplicateKeyError occurrences, improving error traceability. 
--- state-manager/app/models/db/state.py | 16 +++++++++++----- state-manager/app/tasks/create_next_states.py | 8 ++++++-- 2 files changed, 17 insertions(+), 7 deletions(-) diff --git a/state-manager/app/models/db/state.py b/state-manager/app/models/db/state.py index 8e7b4337..76709451 100644 --- a/state-manager/app/models/db/state.py +++ b/state-manager/app/models/db/state.py @@ -19,9 +19,8 @@ class State(BaseDatabaseModel): outputs: dict[str, Any] = Field(..., description="Outputs of the state") error: Optional[str] = Field(None, description="Error message") parents: dict[str, PydanticObjectId] = Field(default_factory=dict, description="Parents of the state") - does_unites: bool = Field(default=False, description="Whether the state is unites others") + does_unites: bool = Field(default=False, description="Whether this state unites other states") state_fingerprint: str = Field(default="", description="Fingerprint of the state") - @before_event([Insert, Replace, Save]) def _generate_fingerprint(self): if not self.does_unites: @@ -34,10 +33,16 @@ def _generate_fingerprint(self): "identifier": self.identifier, "graph_name": self.graph_name, "run_id": self.run_id, - "parents": {key: str(value) for key, value in sorted(self.parents.items(), key=lambda x: x[0])} + "parents": {k: str(v) for k, v in self.parents.items()}, } - self.state_fingerprint = hashlib.sha256(json.dumps(data, sort_keys=True).encode()).hexdigest() - + payload = json.dumps( + data, + sort_keys=True, # canonical key ordering at all levels + separators=(",", ":"), # no whitespace variance + ensure_ascii=True, # normalized non-ASCII escapes + ).encode("utf-8") + self.state_fingerprint = hashlib.sha256(payload).hexdigest() + class Settings: indexes = [ IndexModel( @@ -45,6 +50,7 @@ class Settings: ("state_fingerprint", 1) ], unique=True, + name="uniq_state_fingerprint_unites", partialFilterExpression={ "does_unites": True } diff --git a/state-manager/app/tasks/create_next_states.py 
b/state-manager/app/tasks/create_next_states.py index 0296da8c..8d984df8 100644 --- a/state-manager/app/tasks/create_next_states.py +++ b/state-manager/app/tasks/create_next_states.py @@ -239,8 +239,12 @@ async def get_input_model(node_template: NodeTemplate) -> Type[BaseModel]: if len(new_unit_states) > 0: await State.insert_many(new_unit_states) except DuplicateKeyError: - logger.warning(f"Caught an expected duplicate key error for new unit states, likely due to a race condition: {new_unit_states}") - + logger.warning( + f"Caught duplicate key error for new unit states in namespace={namespace}, " + f"graph={graph_name}, likely due to a race condition. " + f"Attempted to insert {len(new_unit_states)} states" + ) + except Exception as e: await State.find( In(State.id, state_ids) From b0ab37c0836d21ef9b8159e8880b0b0bfd59d443 Mon Sep 17 00:00:00 2001 From: NiveditJain Date: Fri, 22 Aug 2025 15:27:54 +0530 Subject: [PATCH 12/16] Reduce conflicts among workers when picking states so that no two workers pick the same state --- .../app/controller/enqueue_states.py | 44 ++++++++++++------- 1 file changed, 27 insertions(+), 17 deletions(-) diff --git a/state-manager/app/controller/enqueue_states.py b/state-manager/app/controller/enqueue_states.py index e7379020..46948d29 100644 --- a/state-manager/app/controller/enqueue_states.py +++ b/state-manager/app/controller/enqueue_states.py @@ -1,4 +1,4 @@ -from beanie.operators import In +import asyncio from ..models.enqueue_request import EnqueueRequestModel from ..models.enqueue_response import EnqueueResponseModel, StateModel @@ -10,27 +10,37 @@ logger = LogsManager().get_logger() +async def find_state(namespace_name: str, nodes: list[str]) -> State | None: + return await State.get_pymongo_collection().find_one_and_update( + { + "namespace_name": namespace_name, + "status": StateStatusEnum.CREATED, + "node_name": { + "$in": nodes + } + }, + { + "$set": {"status": StateStatusEnum.QUEUED} + } + ) + async def 
enqueue_states(namespace_name: str, body: EnqueueRequestModel, x_exosphere_request_id: str) -> EnqueueResponseModel: try: logger.info(f"Enqueuing states for namespace {namespace_name}", x_exosphere_request_id=x_exosphere_request_id) - states = await State.find( - State.namespace_name == namespace_name, - State.status == StateStatusEnum.CREATED, - In(State.node_name, body.nodes) - ).limit( - body.batch_size - ).to_list() - - if states: - await State.find( - In(State.id, [state.id for state in states]) - ).set( - { - "status": StateStatusEnum.QUEUED, - } - ) # type: ignore + # Create tasks for parallel execution + tasks = [find_state(namespace_name, body.nodes) for _ in range(body.batch_size)] + results = await asyncio.gather(*tasks, return_exceptions=True) + + # Filter out None results and exceptions + states = [] + for result in results: + if isinstance(result, Exception): + logger.error(f"Error finding state: {result}", x_exosphere_request_id=x_exosphere_request_id) + continue + if result is not None: + states.append(result) response = EnqueueResponseModel( count=len(states), From 4cea669ab842532177f45e013fd7723318c0fbd1 Mon Sep 17 00:00:00 2001 From: NiveditJain Date: Fri, 22 Aug 2025 15:40:00 +0530 Subject: [PATCH 13/16] Refactor tests for enqueue_states to use find_one_and_update mock - Updated unit tests in test_enqueue_states.py to mock State.get_pymongo_collection().find_one_and_update() instead of using find().limit().to_list() for better clarity and efficiency. - Adjusted assertions to reflect changes in the expected results, ensuring that batch sizes are correctly handled. - Enhanced error handling in tests to verify graceful handling of exceptions during database operations. 
--- .../unit/controller/test_enqueue_states.py | 103 +++++++++--------- 1 file changed, 50 insertions(+), 53 deletions(-) diff --git a/state-manager/tests/unit/controller/test_enqueue_states.py b/state-manager/tests/unit/controller/test_enqueue_states.py index 41f767c7..18754b38 100644 --- a/state-manager/tests/unit/controller/test_enqueue_states.py +++ b/state-manager/tests/unit/controller/test_enqueue_states.py @@ -47,18 +47,10 @@ async def test_enqueue_states_success( ): """Test successful enqueuing of states""" # Arrange - # Mock State.find().limit().to_list() chain - mock_query = MagicMock() - mock_query.limit = MagicMock(return_value=mock_query) - mock_query.to_list = AsyncMock(return_value=[mock_state]) - - # Mock State.find().set() chain for updating states - mock_update_query = MagicMock() - mock_update_query.set = AsyncMock() - - # Configure State.find to return different mocks based on call - mock_state_class.find = MagicMock() - mock_state_class.find.side_effect = [mock_query, mock_update_query] + # Mock State.get_pymongo_collection().find_one_and_update() + mock_collection = MagicMock() + mock_collection.find_one_and_update = AsyncMock(return_value=mock_state) + mock_state_class.get_pymongo_collection = MagicMock(return_value=mock_collection) # Act result = await enqueue_states( @@ -68,19 +60,29 @@ async def test_enqueue_states_success( ) # Assert - assert result.count == 1 + assert result.count == 10 # batch_size=10, so 10 states should be returned assert result.namespace == mock_namespace assert result.status == StateStatusEnum.QUEUED - assert len(result.states) == 1 + assert len(result.states) == 10 assert result.states[0].state_id == str(mock_state.id) assert result.states[0].node_name == "node1" assert result.states[0].identifier == "test_identifier" assert result.states[0].inputs == {"key": "value"} - # Verify the find query was called correctly - assert mock_state_class.find.call_count == 2 # Called twice: once for finding, once for updating - 
mock_query.limit.assert_called_once_with(10) - mock_update_query.set.assert_called_once() + # Verify the find_one_and_update was called correctly + assert mock_collection.find_one_and_update.call_count == 10 # Called batch_size times + mock_collection.find_one_and_update.assert_called_with( + { + "namespace_name": mock_namespace, + "status": StateStatusEnum.CREATED, + "node_name": { + "$in": ["node1", "node2"] + } + }, + { + "$set": {"status": StateStatusEnum.QUEUED} + } + ) @patch('app.controller.enqueue_states.State') async def test_enqueue_states_no_states_found( @@ -92,12 +94,10 @@ async def test_enqueue_states_no_states_found( ): """Test when no states are found to enqueue""" # Arrange - mock_query = MagicMock() - mock_query.limit = MagicMock(return_value=mock_query) - mock_query.to_list = AsyncMock(return_value=[]) - - # When no states are found, the second State.find() call won't happen - mock_state_class.find = MagicMock(return_value=mock_query) + # Mock State.get_pymongo_collection().find_one_and_update() returning None + mock_collection = MagicMock() + mock_collection.find_one_and_update = AsyncMock(return_value=None) + mock_state_class.get_pymongo_collection = MagicMock(return_value=mock_collection) # Act result = await enqueue_states( @@ -136,17 +136,10 @@ async def test_enqueue_states_multiple_states( state2.inputs = {"input2": "value2"} state2.created_at = datetime.now() - mock_query = MagicMock() - mock_query.limit = MagicMock(return_value=mock_query) - mock_query.to_list = AsyncMock(return_value=[state1, state2]) - - # Mock State.find().set() chain for updating states - mock_update_query = MagicMock() - mock_update_query.set = AsyncMock() - - # Configure State.find to return different mocks based on call - mock_state_class.find = MagicMock() - mock_state_class.find.side_effect = [mock_query, mock_update_query] + # Mock State.get_pymongo_collection().find_one_and_update() to return different states + mock_collection = MagicMock() + 
mock_collection.find_one_and_update = AsyncMock(side_effect=[state1, state2, None, None, None, None, None, None, None, None]) + mock_state_class.get_pymongo_collection = MagicMock(return_value=mock_collection) # Act result = await enqueue_states( @@ -171,17 +164,23 @@ async def test_enqueue_states_database_error( ): """Test handling of database errors""" # Arrange - mock_state_class.find = MagicMock(side_effect=Exception("Database error")) - - # Act & Assert - with pytest.raises(Exception) as exc_info: - await enqueue_states( - mock_namespace, - mock_enqueue_request, - mock_request_id - ) - - assert str(exc_info.value) == "Database error" + # Mock State.get_pymongo_collection().find_one_and_update() to raise an exception + mock_collection = MagicMock() + mock_collection.find_one_and_update = AsyncMock(side_effect=Exception("Database error")) + mock_state_class.get_pymongo_collection = MagicMock(return_value=mock_collection) + + # Act + result = await enqueue_states( + mock_namespace, + mock_enqueue_request, + mock_request_id + ) + + # Assert - the function should handle exceptions gracefully and return empty result + assert result.count == 0 + assert result.namespace == mock_namespace + assert result.status == StateStatusEnum.QUEUED + assert len(result.states) == 0 @patch('app.controller.enqueue_states.State') async def test_enqueue_states_with_different_batch_size( @@ -197,12 +196,10 @@ async def test_enqueue_states_with_different_batch_size( batch_size=5 ) - mock_query = MagicMock() - mock_query.limit = MagicMock(return_value=mock_query) - mock_query.to_list = AsyncMock(return_value=[]) - - # When no states are found, the second State.find() call won't happen - mock_state_class.find = MagicMock(return_value=mock_query) + # Mock State.get_pymongo_collection().find_one_and_update() returning None + mock_collection = MagicMock() + mock_collection.find_one_and_update = AsyncMock(return_value=None) + mock_state_class.get_pymongo_collection = 
MagicMock(return_value=mock_collection) # Act result = await enqueue_states( @@ -213,4 +210,4 @@ async def test_enqueue_states_with_different_batch_size( # Assert assert result.count == 0 - mock_query.limit.assert_called_once_with(5) + assert mock_collection.find_one_and_update.call_count == 5 # Called batch_size times From d0305bcf45aee0772aa8901fc898e3fbabf3fc5f Mon Sep 17 00:00:00 2001 From: Nivedit Jain Date: Fri, 22 Aug 2025 15:53:02 +0530 Subject: [PATCH 14/16] Update state-manager/app/tasks/create_next_states.py Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com> --- state-manager/app/tasks/create_next_states.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/state-manager/app/tasks/create_next_states.py b/state-manager/app/tasks/create_next_states.py index 8d984df8..0ab7a549 100644 --- a/state-manager/app/tasks/create_next_states.py +++ b/state-manager/app/tasks/create_next_states.py @@ -1,5 +1,5 @@ from beanie import PydanticObjectId -from pymongo.errors import DuplicateKeyError +from pymongo.errors import DuplicateKeyError, BulkWriteError from beanie.operators import In, NE from app.singletons.logs_manager import LogsManager from app.models.db.graph_template_model import GraphTemplate From 424a1e726d6d19e68bdfa35077c1c8656846adf6 Mon Sep 17 00:00:00 2001 From: NiveditJain Date: Fri, 22 Aug 2025 16:21:00 +0530 Subject: [PATCH 15/16] Refactor enqueue_states to use find_state function - Updated the find_state function to return a State object or None, improving clarity in state retrieval. - Modified the create_next_states task to handle BulkWriteError in addition to DuplicateKeyError for better error management. - Refactored unit tests in test_enqueue_states.py to mock find_state instead of directly mocking database calls, enhancing test readability and maintainability. 
--- .../app/controller/enqueue_states.py | 3 +- state-manager/app/tasks/create_next_states.py | 2 +- .../unit/controller/test_enqueue_states.py | 69 +++++++------------ 3 files changed, 27 insertions(+), 47 deletions(-) diff --git a/state-manager/app/controller/enqueue_states.py b/state-manager/app/controller/enqueue_states.py index 46948d29..a7e3ade3 100644 --- a/state-manager/app/controller/enqueue_states.py +++ b/state-manager/app/controller/enqueue_states.py @@ -11,7 +11,7 @@ async def find_state(namespace_name: str, nodes: list[str]) -> State | None: - return await State.get_pymongo_collection().find_one_and_update( + data = await State.get_pymongo_collection().find_one_and_update( { "namespace_name": namespace_name, "status": StateStatusEnum.CREATED, @@ -23,6 +23,7 @@ async def find_state(namespace_name: str, nodes: list[str]) -> State | None: "$set": {"status": StateStatusEnum.QUEUED} } ) + return State(**data) if data else None async def enqueue_states(namespace_name: str, body: EnqueueRequestModel, x_exosphere_request_id: str) -> EnqueueResponseModel: diff --git a/state-manager/app/tasks/create_next_states.py b/state-manager/app/tasks/create_next_states.py index 0ab7a549..cf89134f 100644 --- a/state-manager/app/tasks/create_next_states.py +++ b/state-manager/app/tasks/create_next_states.py @@ -238,7 +238,7 @@ async def get_input_model(node_template: NodeTemplate) -> Type[BaseModel]: try: if len(new_unit_states) > 0: await State.insert_many(new_unit_states) - except DuplicateKeyError: + except (DuplicateKeyError, BulkWriteError): logger.warning( f"Caught duplicate key error for new unit states in namespace={namespace}, " f"graph={graph_name}, likely due to a race condition. 
" diff --git a/state-manager/tests/unit/controller/test_enqueue_states.py b/state-manager/tests/unit/controller/test_enqueue_states.py index 18754b38..67f73dc1 100644 --- a/state-manager/tests/unit/controller/test_enqueue_states.py +++ b/state-manager/tests/unit/controller/test_enqueue_states.py @@ -36,10 +36,10 @@ def mock_state(self): state.created_at = datetime.now() return state - @patch('app.controller.enqueue_states.State') + @patch('app.controller.enqueue_states.find_state') async def test_enqueue_states_success( self, - mock_state_class, + mock_find_state, mock_namespace, mock_enqueue_request, mock_state, @@ -47,10 +47,8 @@ async def test_enqueue_states_success( ): """Test successful enqueuing of states""" # Arrange - # Mock State.get_pymongo_collection().find_one_and_update() - mock_collection = MagicMock() - mock_collection.find_one_and_update = AsyncMock(return_value=mock_state) - mock_state_class.get_pymongo_collection = MagicMock(return_value=mock_collection) + # Mock find_state to return the mock_state for all calls + mock_find_state.return_value = mock_state # Act result = await enqueue_states( @@ -69,35 +67,22 @@ async def test_enqueue_states_success( assert result.states[0].identifier == "test_identifier" assert result.states[0].inputs == {"key": "value"} - # Verify the find_one_and_update was called correctly - assert mock_collection.find_one_and_update.call_count == 10 # Called batch_size times - mock_collection.find_one_and_update.assert_called_with( - { - "namespace_name": mock_namespace, - "status": StateStatusEnum.CREATED, - "node_name": { - "$in": ["node1", "node2"] - } - }, - { - "$set": {"status": StateStatusEnum.QUEUED} - } - ) + # Verify find_state was called correctly + assert mock_find_state.call_count == 10 # Called batch_size times + mock_find_state.assert_called_with(mock_namespace, ["node1", "node2"]) - @patch('app.controller.enqueue_states.State') + @patch('app.controller.enqueue_states.find_state') async def 
test_enqueue_states_no_states_found( self, - mock_state_class, + mock_find_state, mock_namespace, mock_enqueue_request, mock_request_id ): """Test when no states are found to enqueue""" # Arrange - # Mock State.get_pymongo_collection().find_one_and_update() returning None - mock_collection = MagicMock() - mock_collection.find_one_and_update = AsyncMock(return_value=None) - mock_state_class.get_pymongo_collection = MagicMock(return_value=mock_collection) + # Mock find_state to return None for all calls + mock_find_state.return_value = None # Act result = await enqueue_states( @@ -112,10 +97,10 @@ async def test_enqueue_states_no_states_found( assert result.status == StateStatusEnum.QUEUED assert len(result.states) == 0 - @patch('app.controller.enqueue_states.State') + @patch('app.controller.enqueue_states.find_state') async def test_enqueue_states_multiple_states( self, - mock_state_class, + mock_find_state, mock_namespace, mock_enqueue_request, mock_request_id @@ -136,10 +121,8 @@ async def test_enqueue_states_multiple_states( state2.inputs = {"input2": "value2"} state2.created_at = datetime.now() - # Mock State.get_pymongo_collection().find_one_and_update() to return different states - mock_collection = MagicMock() - mock_collection.find_one_and_update = AsyncMock(side_effect=[state1, state2, None, None, None, None, None, None, None, None]) - mock_state_class.get_pymongo_collection = MagicMock(return_value=mock_collection) + # Mock find_state to return different states + mock_find_state.side_effect = [state1, state2, None, None, None, None, None, None, None, None] # Act result = await enqueue_states( @@ -154,20 +137,18 @@ async def test_enqueue_states_multiple_states( assert result.states[0].node_name == "node1" assert result.states[1].node_name == "node2" - @patch('app.controller.enqueue_states.State') + @patch('app.controller.enqueue_states.find_state') async def test_enqueue_states_database_error( self, - mock_state_class, + mock_find_state, mock_namespace, 
mock_enqueue_request, mock_request_id ): """Test handling of database errors""" # Arrange - # Mock State.get_pymongo_collection().find_one_and_update() to raise an exception - mock_collection = MagicMock() - mock_collection.find_one_and_update = AsyncMock(side_effect=Exception("Database error")) - mock_state_class.get_pymongo_collection = MagicMock(return_value=mock_collection) + # Mock find_state to raise an exception + mock_find_state.side_effect = Exception("Database error") # Act result = await enqueue_states( @@ -182,10 +163,10 @@ async def test_enqueue_states_database_error( assert result.status == StateStatusEnum.QUEUED assert len(result.states) == 0 - @patch('app.controller.enqueue_states.State') + @patch('app.controller.enqueue_states.find_state') async def test_enqueue_states_with_different_batch_size( self, - mock_state_class, + mock_find_state, mock_namespace, mock_request_id ): @@ -196,10 +177,8 @@ async def test_enqueue_states_with_different_batch_size( batch_size=5 ) - # Mock State.get_pymongo_collection().find_one_and_update() returning None - mock_collection = MagicMock() - mock_collection.find_one_and_update = AsyncMock(return_value=None) - mock_state_class.get_pymongo_collection = MagicMock(return_value=mock_collection) + # Mock find_state to return None + mock_find_state.return_value = None # Act result = await enqueue_states( @@ -210,4 +189,4 @@ async def test_enqueue_states_with_different_batch_size( # Assert assert result.count == 0 - assert mock_collection.find_one_and_update.call_count == 5 # Called batch_size times + assert mock_find_state.call_count == 5 # Called batch_size times From 42d933c95be9879a5f3478afd77f844fcb24259f Mon Sep 17 00:00:00 2001 From: NiveditJain Date: Fri, 22 Aug 2025 16:22:11 +0530 Subject: [PATCH 16/16] Refactor test_enqueue_states.py to remove AsyncMock import - Removed unused AsyncMock import from test_enqueue_states.py, streamlining the test code for better readability. 
--- state-manager/tests/unit/controller/test_enqueue_states.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/state-manager/tests/unit/controller/test_enqueue_states.py b/state-manager/tests/unit/controller/test_enqueue_states.py index 67f73dc1..683ba867 100644 --- a/state-manager/tests/unit/controller/test_enqueue_states.py +++ b/state-manager/tests/unit/controller/test_enqueue_states.py @@ -1,5 +1,5 @@ import pytest -from unittest.mock import AsyncMock, MagicMock, patch +from unittest.mock import MagicMock, patch from beanie import PydanticObjectId from datetime import datetime