diff --git a/src/google/adk/features/_feature_registry.py b/src/google/adk/features/_feature_registry.py index b5f51f2825..5f67e16607 100644 --- a/src/google/adk/features/_feature_registry.py +++ b/src/google/adk/features/_feature_registry.py @@ -138,7 +138,7 @@ class FeatureConfig: FeatureStage.WIP, default_on=False ), FeatureName._MCP_GRACEFUL_ERROR_HANDLING: FeatureConfig( - FeatureStage.EXPERIMENTAL, default_on=False + FeatureStage.EXPERIMENTAL, default_on=True ), FeatureName.PROGRESSIVE_SSE_STREAMING: FeatureConfig( FeatureStage.EXPERIMENTAL, default_on=True diff --git a/src/google/adk/flows/llm_flows/base_llm_flow.py b/src/google/adk/flows/llm_flows/base_llm_flow.py index db897637c3..51cdd021e0 100644 --- a/src/google/adk/flows/llm_flows/base_llm_flow.py +++ b/src/google/adk/flows/llm_flows/base_llm_flow.py @@ -976,6 +976,7 @@ async def _postprocess_async( not llm_response.content and not llm_response.error_code and not llm_response.interrupted + and not llm_response.grounding_metadata ): return @@ -1040,6 +1041,7 @@ async def _postprocess_live( and not llm_response.output_transcription and not llm_response.usage_metadata and not llm_response.live_session_resumption_update + and not llm_response.grounding_metadata ): return diff --git a/src/google/adk/models/gemini_llm_connection.py b/src/google/adk/models/gemini_llm_connection.py index 11ed8386e1..fb9a3a5163 100644 --- a/src/google/adk/models/gemini_llm_connection.py +++ b/src/google/adk/models/gemini_llm_connection.py @@ -293,10 +293,10 @@ async def receive(self) -> AsyncGenerator[LlmResponse, None]: live_session_id=live_session_id, ) self._output_transcription_text = '' - # The Gemini API might not send a transcription finished signal. + # The Gemini API or Vertex AI might not send a transcription finished signal. # Instead, we rely on generation_complete, turn_complete or # interrupted signals to flush any pending transcriptions. - if self._api_backend == GoogleLLMVariant.GEMINI_API and ( + if ( message.server_content.interrupted or message.server_content.turn_complete or message.server_content.generation_complete diff --git a/src/google/adk/tools/mcp_tool/session_context.py b/src/google/adk/tools/mcp_tool/session_context.py index 0ad63044d4..5249423cd1 100644 --- a/src/google/adk/tools/mcp_tool/session_context.py +++ b/src/google/adk/tools/mcp_tool/session_context.py @@ -130,6 +130,12 @@ async def start(self) -> ClientSession: if not self._task: self._task = asyncio.create_task(self._run()) + def _retrieve_exception(t: asyncio.Task): + if not t.cancelled(): + t.exception() + + self._task.add_done_callback(_retrieve_exception) + await self._ready_event.wait() if self._task.cancelled(): diff --git a/tests/unittests/flows/llm_flows/test_base_llm_flow.py b/tests/unittests/flows/llm_flows/test_base_llm_flow.py index ce2e83b6f7..e3c1530ca3 100644 --- a/tests/unittests/flows/llm_flows/test_base_llm_flow.py +++ b/tests/unittests/flows/llm_flows/test_base_llm_flow.py @@ -1069,3 +1069,63 @@ async def mock_run_live_sub_agent(child_ctx, *args, **kwargs): assert ( invocation_context.run_config.session_resumption.handle == 'test_handle' ) + + +@pytest.mark.asyncio +async def test_postprocess_live_yields_grounding_metadata_only(): + """Test that _postprocess_live yields LlmResponse with only grounding_metadata.""" + agent = Agent(name='test_agent') + invocation_context = await testing_utils.create_invocation_context( + agent=agent + ) + flow = BaseLlmFlowForTesting() + + llm_request = LlmRequest() + grounding_metadata = types.GroundingMetadata( + web_search_queries=['test query'], + ) + llm_response = LlmResponse(grounding_metadata=grounding_metadata) + model_response_event = Event( + id=Event.new_id(), + invocation_id=invocation_context.invocation_id, + author=agent.name, + ) + + events = [] + async for event in flow._postprocess_live( + invocation_context, llm_request, llm_response, model_response_event + ): + events.append(event) + + assert len(events) == 1 + assert events[0].grounding_metadata == grounding_metadata + + +@pytest.mark.asyncio +async def test_postprocess_async_yields_grounding_metadata_only(): + """Test that _postprocess_async yields LlmResponse with only grounding_metadata.""" + agent = Agent(name='test_agent') + invocation_context = await testing_utils.create_invocation_context( + agent=agent + ) + flow = BaseLlmFlowForTesting() + + llm_request = LlmRequest() + grounding_metadata = types.GroundingMetadata( + web_search_queries=['test query'], + ) + llm_response = LlmResponse(grounding_metadata=grounding_metadata) + model_response_event = Event( + id=Event.new_id(), + invocation_id=invocation_context.invocation_id, + author=agent.name, + ) + + events = [] + async for event in flow._postprocess_async( + invocation_context, llm_request, llm_response, model_response_event + ): + events.append(event) + + assert len(events) == 1 + assert events[0].grounding_metadata == grounding_metadata diff --git a/tests/unittests/models/test_gemini_llm_connection.py b/tests/unittests/models/test_gemini_llm_connection.py index 133a455738..58aace30ed 100644 --- a/tests/unittests/models/test_gemini_llm_connection.py +++ b/tests/unittests/models/test_gemini_llm_connection.py @@ -285,11 +285,17 @@ async def mock_receive_generator(): @pytest.mark.asyncio +@pytest.mark.parametrize( + 'conn_fixture', + ['gemini_api_connection', 'gemini_connection'], +) async def test_receive_transcript_finished_on_interrupt( - gemini_api_connection, + conn_fixture, mock_gemini_session, + request, ): """Test receive finishes transcription on interrupt signal.""" + connection = request.getfixturevalue(conn_fixture) message1 = mock.Mock() message1.usage_metadata = None @@ -345,7 +351,7 @@ async def mock_receive_generator(): receive_mock = mock.Mock(return_value=mock_receive_generator()) mock_gemini_session.receive = receive_mock - responses = [resp async for resp in gemini_api_connection.receive()] + responses = [resp async for resp in connection.receive()] assert len(responses) == 5 assert responses[4].interrupted is True @@ -365,11 +371,17 @@ async def mock_receive_generator(): @pytest.mark.asyncio +@pytest.mark.parametrize( + 'conn_fixture', + ['gemini_api_connection', 'gemini_connection'], +) async def test_receive_transcript_finished_on_generation_complete( - gemini_api_connection, + conn_fixture, mock_gemini_session, + request, ): """Test receive finishes transcription on generation_complete signal.""" + connection = request.getfixturevalue(conn_fixture) message1 = mock.Mock() message1.usage_metadata = None @@ -425,7 +437,7 @@ async def mock_receive_generator(): receive_mock = mock.Mock(return_value=mock_receive_generator()) mock_gemini_session.receive = receive_mock - responses = [resp async for resp in gemini_api_connection.receive()] + responses = [resp async for resp in connection.receive()] assert len(responses) == 4 @@ -444,11 +456,17 @@ async def mock_receive_generator(): @pytest.mark.asyncio +@pytest.mark.parametrize( + 'conn_fixture', + ['gemini_api_connection', 'gemini_connection'], +) async def test_receive_transcript_finished_on_turn_complete( - gemini_api_connection, + conn_fixture, mock_gemini_session, + request, ): """Test receive finishes transcription on interrupt or complete signals.""" + connection = request.getfixturevalue(conn_fixture) message1 = mock.Mock() message1.usage_metadata = None @@ -504,7 +522,7 @@ async def mock_receive_generator(): receive_mock = mock.Mock(return_value=mock_receive_generator()) mock_gemini_session.receive = receive_mock - responses = [resp async for resp in gemini_api_connection.receive()] + responses = [resp async for resp in connection.receive()] assert len(responses) == 5 assert responses[4].turn_complete is True @@ -867,6 +885,7 @@ async def test_receive_grounding_metadata_standalone( mock_server_content.interrupted = False mock_server_content.input_transcription = None mock_server_content.output_transcription = None + mock_server_content.generation_complete = False mock_message = mock.create_autospec(types.LiveServerMessage, instance=True) mock_message.usage_metadata = None @@ -911,6 +930,7 @@ async def test_receive_grounding_metadata_with_content( mock_server_content.interrupted = False mock_server_content.input_transcription = None mock_server_content.output_transcription = None + mock_server_content.generation_complete = False mock_message = mock.create_autospec(types.LiveServerMessage, instance=True) mock_message.usage_metadata = None @@ -981,6 +1001,7 @@ async def test_receive_tool_call_and_grounding_metadata_with_native_audio( mock_server_content.interrupted = False mock_server_content.input_transcription = None mock_server_content.output_transcription = None + mock_server_content.generation_complete = False mock_metadata_msg = mock.create_autospec( types.LiveServerMessage, instance=True @@ -1001,6 +1022,7 @@ async def test_receive_tool_call_and_grounding_metadata_with_native_audio( mock_turn_complete_content.interrupted = False mock_turn_complete_content.input_transcription = None mock_turn_complete_content.output_transcription = None + mock_turn_complete_content.generation_complete = False mock_turn_complete_msg = mock.create_autospec( types.LiveServerMessage, instance=True diff --git a/tests/unittests/tools/mcp_tool/test_mcp_session_manager.py b/tests/unittests/tools/mcp_tool/test_mcp_session_manager.py index a94b2eb885..f7e16014ff 100644 --- a/tests/unittests/tools/mcp_tool/test_mcp_session_manager.py +++ b/tests/unittests/tools/mcp_tool/test_mcp_session_manager.py @@ -588,6 +588,9 @@ async def test_create_session_cleans_up_without_aclose_if_loop_is_different( self, ): """Verify that sessions from different loops are cleaned up without calling aclose().""" + from google.adk.features import FeatureName + from google.adk.features._feature_registry import temporary_feature_override + manager = MCPSessionManager(self.mock_stdio_connection_params) # 1. Simulate a session created in a "different" loop @@ -617,8 +620,11 @@ async def test_create_session_cleans_up_without_aclose_if_loop_is_different( mock_wait_for.return_value = new_session mock_session_context_class.return_value = AsyncMock() - # 3. Call create_session - session = await manager.create_session() + # 3. Call create_session with flag off to hit wait_for branch + with temporary_feature_override( + FeatureName._MCP_GRACEFUL_ERROR_HANDLING, False + ): + session = await manager.create_session() # 4. Verify results assert session == new_session @@ -969,8 +975,8 @@ class TestMCPGracefulErrorHandlingFlagContract: loudly so we don't silently break GE's rollout. """ - def test_default_state_is_off_so_cl_is_a_noop(self): - """The CL must be a no-op until GE explicitly enables it.""" + def test_default_state_is_on(self): + """The fix must be enabled by default.""" import os from google.adk.features import FeatureName @@ -981,34 +987,34 @@ def test_default_state_is_off_so_cl_is_a_noop(self): saved = {k: os.environ.pop(k) for k in (enable, disable) if k in os.environ} try: assert ( - is_feature_enabled(FeatureName._MCP_GRACEFUL_ERROR_HANDLING) is False + is_feature_enabled(FeatureName._MCP_GRACEFUL_ERROR_HANDLING) is True ) finally: os.environ.update(saved) - def test_env_var_enable_flips_flag_on_at_runtime(self): - """The env var GE will set must turn the fix on without a rebuild.""" + def test_env_var_disable_flips_flag_off_at_runtime(self): + """The env var must turn the fix off without a rebuild.""" import os from google.adk.features import FeatureName from google.adk.features import is_feature_enabled - enable = "ADK_ENABLE_MCP_GRACEFUL_ERROR_HANDLING" - saved = os.environ.pop(enable, None) + disable = "ADK_DISABLE_MCP_GRACEFUL_ERROR_HANDLING" + saved = os.environ.pop(disable, None) try: - os.environ[enable] = "1" + os.environ[disable] = "1" assert ( - is_feature_enabled(FeatureName._MCP_GRACEFUL_ERROR_HANDLING) is True + is_feature_enabled(FeatureName._MCP_GRACEFUL_ERROR_HANDLING) is False ) # And once it's removed, we revert. Confirms the value is read # live from os.environ on every call (no caching, no binary push). - del os.environ[enable] + del os.environ[disable] assert ( - is_feature_enabled(FeatureName._MCP_GRACEFUL_ERROR_HANDLING) is False + is_feature_enabled(FeatureName._MCP_GRACEFUL_ERROR_HANDLING) is True ) finally: if saved is not None: - os.environ[enable] = saved + os.environ[disable] = saved def test_env_var_disable_acts_as_kill_switch(self): """The disable env var lets consumers turn off without a rebuild."""