Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/google/adk/features/_feature_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ class FeatureConfig:
FeatureStage.WIP, default_on=False
),
FeatureName._MCP_GRACEFUL_ERROR_HANDLING: FeatureConfig(
FeatureStage.EXPERIMENTAL, default_on=False
FeatureStage.EXPERIMENTAL, default_on=True
),
FeatureName.PROGRESSIVE_SSE_STREAMING: FeatureConfig(
FeatureStage.EXPERIMENTAL, default_on=True
Expand Down
2 changes: 2 additions & 0 deletions src/google/adk/flows/llm_flows/base_llm_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -976,6 +976,7 @@ async def _postprocess_async(
not llm_response.content
and not llm_response.error_code
and not llm_response.interrupted
and not llm_response.grounding_metadata
):
return

Expand Down Expand Up @@ -1040,6 +1041,7 @@ async def _postprocess_live(
and not llm_response.output_transcription
and not llm_response.usage_metadata
and not llm_response.live_session_resumption_update
and not llm_response.grounding_metadata
):
return

Expand Down
4 changes: 2 additions & 2 deletions src/google/adk/models/gemini_llm_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,10 +293,10 @@ async def receive(self) -> AsyncGenerator[LlmResponse, None]:
live_session_id=live_session_id,
)
self._output_transcription_text = ''
# The Gemini API might not send a transcription finished signal.
# The Gemini API or Vertex AI might not send a transcription finished signal.
# Instead, we rely on generation_complete, turn_complete or
# interrupted signals to flush any pending transcriptions.
if self._api_backend == GoogleLLMVariant.GEMINI_API and (
if (
message.server_content.interrupted
or message.server_content.turn_complete
or message.server_content.generation_complete
Expand Down
6 changes: 6 additions & 0 deletions src/google/adk/tools/mcp_tool/session_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,12 @@ async def start(self) -> ClientSession:
if not self._task:
self._task = asyncio.create_task(self._run())

def _retrieve_exception(t: asyncio.Task):
if not t.cancelled():
t.exception()

self._task.add_done_callback(_retrieve_exception)

await self._ready_event.wait()

if self._task.cancelled():
Expand Down
60 changes: 60 additions & 0 deletions tests/unittests/flows/llm_flows/test_base_llm_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -1069,3 +1069,63 @@ async def mock_run_live_sub_agent(child_ctx, *args, **kwargs):
assert (
invocation_context.run_config.session_resumption.handle == 'test_handle'
)


@pytest.mark.asyncio
async def test_postprocess_live_yields_grounding_metadata_only():
  """A live response holding only grounding metadata still produces an event.

  Builds an LlmResponse to which nothing but grounding_metadata is passed,
  feeds it through _postprocess_live, and checks that exactly one event is
  yielded and that it carries the same metadata.
  """
  test_agent = Agent(name='test_agent')
  ctx = await testing_utils.create_invocation_context(agent=test_agent)
  flow_under_test = BaseLlmFlowForTesting()

  request = LlmRequest()
  expected_metadata = types.GroundingMetadata(
      web_search_queries=['test query'],
  )
  response = LlmResponse(grounding_metadata=expected_metadata)
  base_event = Event(
      id=Event.new_id(),
      invocation_id=ctx.invocation_id,
      author=test_agent.name,
  )

  # Drain the async generator so every yielded event can be inspected.
  yielded = [
      evt
      async for evt in flow_under_test._postprocess_live(
          ctx, request, response, base_event
      )
  ]

  # grounding_metadata was the only field supplied to the response, so a
  # single event carrying it is the whole expected output.
  assert len(yielded) == 1
  assert yielded[0].grounding_metadata == expected_metadata


@pytest.mark.asyncio
async def test_postprocess_async_yields_grounding_metadata_only():
  """A response holding only grounding metadata still produces an event.

  Builds an LlmResponse to which nothing but grounding_metadata is passed,
  feeds it through _postprocess_async, and checks that exactly one event is
  yielded and that it carries the same metadata.
  """
  test_agent = Agent(name='test_agent')
  ctx = await testing_utils.create_invocation_context(agent=test_agent)
  flow_under_test = BaseLlmFlowForTesting()

  request = LlmRequest()
  expected_metadata = types.GroundingMetadata(
      web_search_queries=['test query'],
  )
  response = LlmResponse(grounding_metadata=expected_metadata)
  base_event = Event(
      id=Event.new_id(),
      invocation_id=ctx.invocation_id,
      author=test_agent.name,
  )

  # Drain the async generator so every yielded event can be inspected.
  yielded = [
      evt
      async for evt in flow_under_test._postprocess_async(
          ctx, request, response, base_event
      )
  ]

  # grounding_metadata was the only field supplied to the response, so a
  # single event carrying it is the whole expected output.
  assert len(yielded) == 1
  assert yielded[0].grounding_metadata == expected_metadata
34 changes: 28 additions & 6 deletions tests/unittests/models/test_gemini_llm_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,11 +285,17 @@ async def mock_receive_generator():


@pytest.mark.asyncio
@pytest.mark.parametrize(
'conn_fixture',
['gemini_api_connection', 'gemini_connection'],
)
async def test_receive_transcript_finished_on_interrupt(
gemini_api_connection,
conn_fixture,
mock_gemini_session,
request,
):
"""Test receive finishes transcription on interrupt signal."""
connection = request.getfixturevalue(conn_fixture)

message1 = mock.Mock()
message1.usage_metadata = None
Expand Down Expand Up @@ -345,7 +351,7 @@ async def mock_receive_generator():
receive_mock = mock.Mock(return_value=mock_receive_generator())
mock_gemini_session.receive = receive_mock

responses = [resp async for resp in gemini_api_connection.receive()]
responses = [resp async for resp in connection.receive()]

assert len(responses) == 5
assert responses[4].interrupted is True
Expand All @@ -365,11 +371,17 @@ async def mock_receive_generator():


@pytest.mark.asyncio
@pytest.mark.parametrize(
'conn_fixture',
['gemini_api_connection', 'gemini_connection'],
)
async def test_receive_transcript_finished_on_generation_complete(
gemini_api_connection,
conn_fixture,
mock_gemini_session,
request,
):
"""Test receive finishes transcription on generation_complete signal."""
connection = request.getfixturevalue(conn_fixture)

message1 = mock.Mock()
message1.usage_metadata = None
Expand Down Expand Up @@ -425,7 +437,7 @@ async def mock_receive_generator():
receive_mock = mock.Mock(return_value=mock_receive_generator())
mock_gemini_session.receive = receive_mock

responses = [resp async for resp in gemini_api_connection.receive()]
responses = [resp async for resp in connection.receive()]

assert len(responses) == 4

Expand All @@ -444,11 +456,17 @@ async def mock_receive_generator():


@pytest.mark.asyncio
@pytest.mark.parametrize(
'conn_fixture',
['gemini_api_connection', 'gemini_connection'],
)
async def test_receive_transcript_finished_on_turn_complete(
gemini_api_connection,
conn_fixture,
mock_gemini_session,
request,
):
"""Test receive finishes transcription on interrupt or complete signals."""
connection = request.getfixturevalue(conn_fixture)

message1 = mock.Mock()
message1.usage_metadata = None
Expand Down Expand Up @@ -504,7 +522,7 @@ async def mock_receive_generator():
receive_mock = mock.Mock(return_value=mock_receive_generator())
mock_gemini_session.receive = receive_mock

responses = [resp async for resp in gemini_api_connection.receive()]
responses = [resp async for resp in connection.receive()]

assert len(responses) == 5
assert responses[4].turn_complete is True
Expand Down Expand Up @@ -867,6 +885,7 @@ async def test_receive_grounding_metadata_standalone(
mock_server_content.interrupted = False
mock_server_content.input_transcription = None
mock_server_content.output_transcription = None
mock_server_content.generation_complete = False

mock_message = mock.create_autospec(types.LiveServerMessage, instance=True)
mock_message.usage_metadata = None
Expand Down Expand Up @@ -911,6 +930,7 @@ async def test_receive_grounding_metadata_with_content(
mock_server_content.interrupted = False
mock_server_content.input_transcription = None
mock_server_content.output_transcription = None
mock_server_content.generation_complete = False

mock_message = mock.create_autospec(types.LiveServerMessage, instance=True)
mock_message.usage_metadata = None
Expand Down Expand Up @@ -981,6 +1001,7 @@ async def test_receive_tool_call_and_grounding_metadata_with_native_audio(
mock_server_content.interrupted = False
mock_server_content.input_transcription = None
mock_server_content.output_transcription = None
mock_server_content.generation_complete = False

mock_metadata_msg = mock.create_autospec(
types.LiveServerMessage, instance=True
Expand All @@ -1001,6 +1022,7 @@ async def test_receive_tool_call_and_grounding_metadata_with_native_audio(
mock_turn_complete_content.interrupted = False
mock_turn_complete_content.input_transcription = None
mock_turn_complete_content.output_transcription = None
mock_turn_complete_content.generation_complete = False

mock_turn_complete_msg = mock.create_autospec(
types.LiveServerMessage, instance=True
Expand Down
34 changes: 20 additions & 14 deletions tests/unittests/tools/mcp_tool/test_mcp_session_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -588,6 +588,9 @@ async def test_create_session_cleans_up_without_aclose_if_loop_is_different(
self,
):
"""Verify that sessions from different loops are cleaned up without calling aclose()."""
from google.adk.features import FeatureName
from google.adk.features._feature_registry import temporary_feature_override

manager = MCPSessionManager(self.mock_stdio_connection_params)

# 1. Simulate a session created in a "different" loop
Expand Down Expand Up @@ -617,8 +620,11 @@ async def test_create_session_cleans_up_without_aclose_if_loop_is_different(
mock_wait_for.return_value = new_session
mock_session_context_class.return_value = AsyncMock()

# 3. Call create_session
session = await manager.create_session()
# 3. Call create_session with flag off to hit wait_for branch
with temporary_feature_override(
FeatureName._MCP_GRACEFUL_ERROR_HANDLING, False
):
session = await manager.create_session()

# 4. Verify results
assert session == new_session
Expand Down Expand Up @@ -969,8 +975,8 @@ class TestMCPGracefulErrorHandlingFlagContract:
loudly so we don't silently break GE's rollout.
"""

def test_default_state_is_off_so_cl_is_a_noop(self):
"""The CL must be a no-op until GE explicitly enables it."""
def test_default_state_is_on(self):
"""The fix must be enabled by default."""
import os

from google.adk.features import FeatureName
Expand All @@ -981,34 +987,34 @@ def test_default_state_is_off_so_cl_is_a_noop(self):
saved = {k: os.environ.pop(k) for k in (enable, disable) if k in os.environ}
try:
assert (
is_feature_enabled(FeatureName._MCP_GRACEFUL_ERROR_HANDLING) is False
is_feature_enabled(FeatureName._MCP_GRACEFUL_ERROR_HANDLING) is True
)
finally:
os.environ.update(saved)

def test_env_var_enable_flips_flag_on_at_runtime(self):
"""The env var GE will set must turn the fix on without a rebuild."""
def test_env_var_disable_flips_flag_off_at_runtime(self):
"""The env var must turn the fix off without a rebuild."""
import os

from google.adk.features import FeatureName
from google.adk.features import is_feature_enabled

enable = "ADK_ENABLE_MCP_GRACEFUL_ERROR_HANDLING"
saved = os.environ.pop(enable, None)
disable = "ADK_DISABLE_MCP_GRACEFUL_ERROR_HANDLING"
saved = os.environ.pop(disable, None)
try:
os.environ[enable] = "1"
os.environ[disable] = "1"
assert (
is_feature_enabled(FeatureName._MCP_GRACEFUL_ERROR_HANDLING) is True
is_feature_enabled(FeatureName._MCP_GRACEFUL_ERROR_HANDLING) is False
)
# And once it's removed, we revert. Confirms the value is read
# live from os.environ on every call (no caching, no binary push).
del os.environ[enable]
del os.environ[disable]
assert (
is_feature_enabled(FeatureName._MCP_GRACEFUL_ERROR_HANDLING) is False
is_feature_enabled(FeatureName._MCP_GRACEFUL_ERROR_HANDLING) is True
)
finally:
if saved is not None:
os.environ[enable] = saved
os.environ[disable] = saved

def test_env_var_disable_acts_as_kill_switch(self):
"""The disable env var lets consumers turn off without a rebuild."""
Expand Down