Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
17 commits
Select commit Hold shift + click to select a range
33887d4
Add state existence check in create_next_states
NiveditJain Aug 22, 2025
f684e86
Refactor check_state_exists function to return State or None
NiveditJain Aug 22, 2025
5aa5c5b
Refactor state existence check in check_state_exists function
NiveditJain Aug 22, 2025
f6f91c6
Add fingerprint generation and unites handling in State model
NiveditJain Aug 22, 2025
1b87b06
Update state-manager/app/models/db/state.py
NiveditJain Aug 22, 2025
fd91f21
Update state-manager/app/tasks/create_next_states.py
NiveditJain Aug 22, 2025
70a641b
Add does_unites handling in generate_next_state function
NiveditJain Aug 22, 2025
4cfe98c
Merge branch 'fixing-multiple-branch-creation' of https://github.com/…
NiveditJain Aug 22, 2025
15ceca4
Sort parents dictionary in State model before fingerprint generation
NiveditJain Aug 22, 2025
afd5b72
Rename _fingerprint to state_fingerprint in State model
NiveditJain Aug 22, 2025
a11c42a
Add does_unites check in fingerprint generation method
NiveditJain Aug 22, 2025
d9f0fa1
Enhance State model and error handling in create_next_states
NiveditJain Aug 22, 2025
b0ab37c
improving conflicts among works while picking state, no 2 works shoul…
NiveditJain Aug 22, 2025
4cea669
Refactor tests for enqueue_states to use find_one_and_update mock
NiveditJain Aug 22, 2025
d0305bc
Update state-manager/app/tasks/create_next_states.py
NiveditJain Aug 22, 2025
424a1e7
Refactor enqueue_states to use find_state function
NiveditJain Aug 22, 2025
42d933c
Refactor test_enqueue_states.py to remove AsyncMock import
NiveditJain Aug 22, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 28 additions & 17 deletions state-manager/app/controller/enqueue_states.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from beanie.operators import In
import asyncio
Comment thread
NiveditJain marked this conversation as resolved.

from ..models.enqueue_request import EnqueueRequestModel
from ..models.enqueue_response import EnqueueResponseModel, StateModel
Expand All @@ -10,27 +10,38 @@
logger = LogsManager().get_logger()


async def find_state(namespace_name: str, nodes: list[str]) -> State | None:
data = await State.get_pymongo_collection().find_one_and_update(
{
"namespace_name": namespace_name,
"status": StateStatusEnum.CREATED,
"node_name": {
"$in": nodes
}
},
{
"$set": {"status": StateStatusEnum.QUEUED}
}
)
return State(**data) if data else None

async def enqueue_states(namespace_name: str, body: EnqueueRequestModel, x_exosphere_request_id: str) -> EnqueueResponseModel:

try:
logger.info(f"Enqueuing states for namespace {namespace_name}", x_exosphere_request_id=x_exosphere_request_id)

states = await State.find(
State.namespace_name == namespace_name,
State.status == StateStatusEnum.CREATED,
In(State.node_name, body.nodes)
).limit(
body.batch_size
).to_list()

if states:
await State.find(
In(State.id, [state.id for state in states])
).set(
{
"status": StateStatusEnum.QUEUED,
}
) # type: ignore
# Create tasks for parallel execution
tasks = [find_state(namespace_name, body.nodes) for _ in range(body.batch_size)]
results = await asyncio.gather(*tasks, return_exceptions=True)

# Filter out None results and exceptions
states = []
for result in results:
if isinstance(result, Exception):
logger.error(f"Error finding state: {result}", x_exosphere_request_id=x_exosphere_request_id)
continue
if result is not None:
states.append(result)

Comment thread
NiveditJain marked this conversation as resolved.
response = EnqueueResponseModel(
count=len(states),
Expand Down
44 changes: 42 additions & 2 deletions state-manager/app/models/db/state.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
from pymongo import IndexModel
from .base import BaseDatabaseModel
from ..state_status_enum import StateStatusEnum
from pydantic import Field
from beanie import PydanticObjectId
from beanie import Insert, PydanticObjectId, Replace, Save, before_event
from typing import Any, Optional
import hashlib
import json


class State(BaseDatabaseModel):
Expand All @@ -15,4 +18,41 @@ class State(BaseDatabaseModel):
inputs: dict[str, Any] = Field(..., description="Inputs of the state")
outputs: dict[str, Any] = Field(..., description="Outputs of the state")
error: Optional[str] = Field(None, description="Error message")
parents: dict[str, PydanticObjectId] = Field(default_factory=dict, description="Parents of the state")
parents: dict[str, PydanticObjectId] = Field(default_factory=dict, description="Parents of the state")
does_unites: bool = Field(default=False, description="Whether this state unites other states")
state_fingerprint: str = Field(default="", description="Fingerprint of the state")
Comment thread
NiveditJain marked this conversation as resolved.
@before_event([Insert, Replace, Save])
def _generate_fingerprint(self):
Comment thread
NiveditJain marked this conversation as resolved.
if not self.does_unites:
self.state_fingerprint = ""
return

data = {
"node_name": self.node_name,
"namespace_name": self.namespace_name,
"identifier": self.identifier,
"graph_name": self.graph_name,
"run_id": self.run_id,
"parents": {k: str(v) for k, v in self.parents.items()},
}
payload = json.dumps(
data,
sort_keys=True, # canonical key ordering at all levels
separators=(",", ":"), # no whitespace variance
ensure_ascii=True, # normalized non-ASCII escapes
).encode("utf-8")
self.state_fingerprint = hashlib.sha256(payload).hexdigest()

class Settings:
indexes = [
IndexModel(
[
("state_fingerprint", 1)
],
unique=True,
name="uniq_state_fingerprint_unites",
partialFilterExpression={
"does_unites": True
}
)
]
18 changes: 14 additions & 4 deletions state-manager/app/tasks/create_next_states.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from beanie import PydanticObjectId
from pymongo.errors import DuplicateKeyError, BulkWriteError
Comment thread
NiveditJain marked this conversation as resolved.
from beanie.operators import In, NE
from app.singletons.logs_manager import LogsManager
from app.models.db.graph_template_model import GraphTemplate
Expand Down Expand Up @@ -58,6 +59,7 @@ async def check_unites_satisfied(namespace: str, graph_name: str, node_template:
return False
return True


def get_dependents(syntax_string: str) -> DependentString:
splits = syntax_string.split("${{")
if len(splits) <= 1:
Expand Down Expand Up @@ -134,6 +136,7 @@ def generate_next_state(next_state_input_model: Type[BaseModel], next_state_node
parents=new_parents,
inputs=next_state_input_data,
outputs={},
does_unites=next_state_node_template.unites is not None,
run_id=current_state.run_id,
error=None
)
Expand Down Expand Up @@ -231,10 +234,17 @@ async def get_input_model(node_template: NodeTemplate) -> Type[BaseModel]:
parent_state = parents[next_state_node_template.unites.identifier]

new_unit_states.append(generate_next_state(next_state_input_model, next_state_node_template, parents, parent_state))

if len(new_unit_states) > 0:
await State.insert_many(new_unit_states)


try:
if len(new_unit_states) > 0:
await State.insert_many(new_unit_states)
except (DuplicateKeyError, BulkWriteError):
logger.warning(
f"Caught duplicate key error for new unit states in namespace={namespace}, "
f"graph={graph_name}, likely due to a race condition. "
f"Attempted to insert {len(new_unit_states)} states"
)

except Exception as e:
await State.find(
In(State.id, state_ids)
Expand Down
106 changes: 41 additions & 65 deletions state-manager/tests/unit/controller/test_enqueue_states.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import pytest
from unittest.mock import AsyncMock, MagicMock, patch
from unittest.mock import MagicMock, patch
from beanie import PydanticObjectId
from datetime import datetime

Expand Down Expand Up @@ -36,29 +36,19 @@ def mock_state(self):
state.created_at = datetime.now()
return state

@patch('app.controller.enqueue_states.State')
@patch('app.controller.enqueue_states.find_state')
async def test_enqueue_states_success(
self,
mock_state_class,
mock_find_state,
mock_namespace,
mock_enqueue_request,
mock_state,
mock_request_id
):
"""Test successful enqueuing of states"""
# Arrange
# Mock State.find().limit().to_list() chain
mock_query = MagicMock()
mock_query.limit = MagicMock(return_value=mock_query)
mock_query.to_list = AsyncMock(return_value=[mock_state])

# Mock State.find().set() chain for updating states
mock_update_query = MagicMock()
mock_update_query.set = AsyncMock()

# Configure State.find to return different mocks based on call
mock_state_class.find = MagicMock()
mock_state_class.find.side_effect = [mock_query, mock_update_query]
# Mock find_state to return the mock_state for all calls
mock_find_state.return_value = mock_state

# Act
result = await enqueue_states(
Expand All @@ -68,36 +58,31 @@ async def test_enqueue_states_success(
)

# Assert
assert result.count == 1
assert result.count == 10 # batch_size=10, so 10 states should be returned
assert result.namespace == mock_namespace
assert result.status == StateStatusEnum.QUEUED
assert len(result.states) == 1
assert len(result.states) == 10
assert result.states[0].state_id == str(mock_state.id)
Comment thread
NiveditJain marked this conversation as resolved.
assert result.states[0].node_name == "node1"
assert result.states[0].identifier == "test_identifier"
assert result.states[0].inputs == {"key": "value"}

# Verify the find query was called correctly
assert mock_state_class.find.call_count == 2 # Called twice: once for finding, once for updating
mock_query.limit.assert_called_once_with(10)
mock_update_query.set.assert_called_once()
# Verify find_state was called correctly
assert mock_find_state.call_count == 10 # Called batch_size times
mock_find_state.assert_called_with(mock_namespace, ["node1", "node2"])

@patch('app.controller.enqueue_states.State')
@patch('app.controller.enqueue_states.find_state')
async def test_enqueue_states_no_states_found(
self,
mock_state_class,
mock_find_state,
mock_namespace,
mock_enqueue_request,
mock_request_id
):
"""Test when no states are found to enqueue"""
# Arrange
mock_query = MagicMock()
mock_query.limit = MagicMock(return_value=mock_query)
mock_query.to_list = AsyncMock(return_value=[])

# When no states are found, the second State.find() call won't happen
mock_state_class.find = MagicMock(return_value=mock_query)
# Mock find_state to return None for all calls
mock_find_state.return_value = None

# Act
result = await enqueue_states(
Expand All @@ -112,10 +97,10 @@ async def test_enqueue_states_no_states_found(
assert result.status == StateStatusEnum.QUEUED
assert len(result.states) == 0

@patch('app.controller.enqueue_states.State')
@patch('app.controller.enqueue_states.find_state')
async def test_enqueue_states_multiple_states(
self,
mock_state_class,
mock_find_state,
mock_namespace,
mock_enqueue_request,
mock_request_id
Expand All @@ -136,17 +121,8 @@ async def test_enqueue_states_multiple_states(
state2.inputs = {"input2": "value2"}
state2.created_at = datetime.now()

mock_query = MagicMock()
mock_query.limit = MagicMock(return_value=mock_query)
mock_query.to_list = AsyncMock(return_value=[state1, state2])

# Mock State.find().set() chain for updating states
mock_update_query = MagicMock()
mock_update_query.set = AsyncMock()

# Configure State.find to return different mocks based on call
mock_state_class.find = MagicMock()
mock_state_class.find.side_effect = [mock_query, mock_update_query]
# Mock find_state to return different states
mock_find_state.side_effect = [state1, state2, None, None, None, None, None, None, None, None]

# Act
result = await enqueue_states(
Expand All @@ -161,32 +137,36 @@ async def test_enqueue_states_multiple_states(
assert result.states[0].node_name == "node1"
assert result.states[1].node_name == "node2"

@patch('app.controller.enqueue_states.State')
@patch('app.controller.enqueue_states.find_state')
async def test_enqueue_states_database_error(
self,
mock_state_class,
mock_find_state,
mock_namespace,
mock_enqueue_request,
mock_request_id
):
"""Test handling of database errors"""
# Arrange
mock_state_class.find = MagicMock(side_effect=Exception("Database error"))

# Act & Assert
with pytest.raises(Exception) as exc_info:
await enqueue_states(
mock_namespace,
mock_enqueue_request,
mock_request_id
)

assert str(exc_info.value) == "Database error"

@patch('app.controller.enqueue_states.State')
# Mock find_state to raise an exception
mock_find_state.side_effect = Exception("Database error")

# Act
result = await enqueue_states(
mock_namespace,
mock_enqueue_request,
mock_request_id
)

# Assert - the function should handle exceptions gracefully and return empty result
assert result.count == 0
assert result.namespace == mock_namespace
assert result.status == StateStatusEnum.QUEUED
assert len(result.states) == 0

@patch('app.controller.enqueue_states.find_state')
async def test_enqueue_states_with_different_batch_size(
self,
mock_state_class,
mock_find_state,
mock_namespace,
mock_request_id
):
Expand All @@ -197,12 +177,8 @@ async def test_enqueue_states_with_different_batch_size(
batch_size=5
)

mock_query = MagicMock()
mock_query.limit = MagicMock(return_value=mock_query)
mock_query.to_list = AsyncMock(return_value=[])

# When no states are found, the second State.find() call won't happen
mock_state_class.find = MagicMock(return_value=mock_query)
# Mock find_state to return None
mock_find_state.return_value = None

# Act
result = await enqueue_states(
Expand All @@ -213,4 +189,4 @@ async def test_enqueue_states_with_different_batch_size(

# Assert
assert result.count == 0
mock_query.limit.assert_called_once_with(5)
assert mock_find_state.call_count == 5 # Called batch_size times
Loading