diff --git a/Cargo.lock b/Cargo.lock index 66d276de..f1683137 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -114,9 +114,9 @@ dependencies = [ [[package]] name = "agent-client-protocol-schema" -version = "0.13.6" +version = "0.13.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c290bfa00c6b52339db66f8e9cf711d5f08530800529f7d619ff24d6cba253d0" +checksum = "1fec685c82933a27b9d0c34594749b8b47d26ab4c2fb0e7bee268798cebe1c8b" dependencies = [ "anyhow", "derive_more", diff --git a/Cargo.toml b/Cargo.toml index 480fda39..efcfc95c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -35,7 +35,7 @@ agent-client-protocol-trace-viewer = { path = "src/agent-client-protocol-trace-v yopo = { package = "agent-client-protocol-yopo", path = "src/yopo" } # Protocol -agent-client-protocol-schema = { version = "=0.13.6", features = ["tracing"] } +agent-client-protocol-schema = { version = "=0.13.8", features = ["tracing"] } # Core async runtime tokio = { version = "1.52", features = ["full"] } diff --git a/src/agent-client-protocol-conductor/tests/trace_snapshot.rs b/src/agent-client-protocol-conductor/tests/trace_snapshot.rs index 3a3310bf..002f2098 100644 --- a/src/agent-client-protocol-conductor/tests/trace_snapshot.rs +++ b/src/agent-client-protocol-conductor/tests/trace_snapshot.rs @@ -74,7 +74,7 @@ impl EventNormalizer { fn normalize_json(&mut self, value: serde_json::Value) -> serde_json::Value { match value { serde_json::Value::Object(map) => { - let normalized: serde_json::Map = map + let mut normalized: serde_json::Map = map .into_iter() .map(|(k, v)| { let v = if k == "sessionId" { @@ -89,6 +89,14 @@ impl EventNormalizer { (k, v) }) .collect(); + if matches!( + normalized.get("auth"), + Some(serde_json::Value::Object(auth)) + if auth.len() == 1 + && auth.get("terminal") == Some(&serde_json::Value::Bool(false)) + ) { + normalized.remove("auth"); + } serde_json::Value::Object(normalized) } serde_json::Value::Array(arr) => { @@ -197,9 +205,6 @@ async fn test_trace_snapshot() -> Result<(), agent_client_protocol::Error> { "writeTextFile": Bool(false), }, "terminal": Bool(false), - "auth": Object { - "terminal": Bool(false), - }, }, }, }, diff --git a/src/agent-client-protocol/Cargo.toml b/src/agent-client-protocol/Cargo.toml index 9a1ecc00..52eca366 100644 --- a/src/agent-client-protocol/Cargo.toml +++ b/src/agent-client-protocol/Cargo.toml @@ -26,6 +26,7 @@ unstable = [ "unstable_elicitation", "unstable_end_turn_token_usage", "unstable_mcp_over_acp", + "unstable_model_config_category", "unstable_session_fork", ] unstable_auth_methods = ["agent-client-protocol-schema/unstable_auth_methods"] @@ -34,6 +35,7 @@ unstable_cancel_request = ["agent-client-protocol-schema/unstable_cancel_request unstable_elicitation = ["agent-client-protocol-schema/unstable_elicitation"] unstable_end_turn_token_usage = ["agent-client-protocol-schema/unstable_end_turn_token_usage"] unstable_mcp_over_acp = ["agent-client-protocol-schema/unstable_mcp_over_acp"] +unstable_model_config_category = ["agent-client-protocol-schema/unstable_model_config_category"] unstable_session_fork = ["agent-client-protocol-schema/unstable_session_fork"] unstable_protocol_v2 = ["agent-client-protocol-schema/unstable_protocol_v2"] diff --git a/src/agent-client-protocol/src/jsonrpc/incoming_actor.rs b/src/agent-client-protocol/src/jsonrpc/incoming_actor.rs index b41280ee..239606bc 100644 --- a/src/agent-client-protocol/src/jsonrpc/incoming_actor.rs +++ b/src/agent-client-protocol/src/jsonrpc/incoming_actor.rs @@ -152,17 +152,19 @@ pub(super) async fn incoming_protocol_actor( &protocol_compat, &request_cancellations, ) { - Ok(dispatch) => { - dispatch_dispatch( - counterpart.clone(), - connection, - dispatch, - &mut dynamic_handlers, - &mut handler, - &mut pending_messages, - &request_cancellations, - ) - .await?; + Ok(dispatches) => { + for dispatch in dispatches { + dispatch_dispatch( + counterpart.clone(), + connection, + dispatch, + &mut dynamic_handlers, + &mut handler, + &mut pending_messages, + &request_cancellations, + ) + .await?; + } } Err(error) => { report_handler_error( @@ -188,17 +190,19 @@ pub(super) async fn incoming_protocol_actor( &protocol_compat, &request_cancellations, ) { - Ok(dispatch) => { - dispatch_dispatch( - counterpart.clone(), - connection, - dispatch, - &mut dynamic_handlers, - &mut handler, - &mut pending_messages, - &request_cancellations, - ) - .await?; + Ok(dispatches) => { + for dispatch in dispatches { + dispatch_dispatch( + counterpart.clone(), + connection, + dispatch, + &mut dynamic_handlers, + &mut handler, + &mut pending_messages, + &request_cancellations, + ) + .await?; + } } Err(error) => { report_handler_error(connection, None, request_method, error)?; @@ -266,22 +270,28 @@ fn dispatch_from_message( id: Option, protocol_compat: &ProtocolCompat, request_cancellations: &super::RequestCancellationRegistry, -) -> Result { +) -> Result, crate::Error> { let message = UntypedMessage::new(&method, crate::jsonrpc::params_from_transport(params)) .expect("well-formed JSON"); - let message = protocol_compat.incoming_message(message)?; match id { - Some(id) => Ok(Dispatch::Request( - message, - Responder::new( - connection.message_tx.clone(), - method.to_string(), - id, - request_cancellations, - ), - )), - None => Ok(Dispatch::Notification(message)), + Some(id) => { + let message = protocol_compat.incoming_message(message)?; + Ok(vec![Dispatch::Request( + message, + Responder::new( + connection.message_tx.clone(), + method.to_string(), + id, + request_cancellations, + ), + )]) + } + None => Ok(protocol_compat + .incoming_notification(message)? + .into_iter() + .map(Dispatch::Notification) + .collect()), } } diff --git a/src/agent-client-protocol/src/jsonrpc/outgoing_actor.rs b/src/agent-client-protocol/src/jsonrpc/outgoing_actor.rs index 3b26cdb7..ee9a5712 100644 --- a/src/agent-client-protocol/src/jsonrpc/outgoing_actor.rs +++ b/src/agent-client-protocol/src/jsonrpc/outgoing_actor.rs @@ -73,11 +73,8 @@ pub(super) async fn outgoing_protocol_actor( request } OutgoingMessage::Notification { untyped } => { - match protocol_compat - .outgoing_message(untyped) - .and_then(|untyped| untyped.into_raw_jsonrpc_message(None)) - { - Ok(msg) => msg, + let messages = match protocol_compat.outgoing_notification(untyped) { + Ok(messages) => messages, Err(error) => { tracing::warn!( ?error, @@ -85,7 +82,24 @@ pub(super) async fn outgoing_protocol_actor( ); continue; } + }; + + for untyped in messages { + let message = match untyped.into_raw_jsonrpc_message(None) { + Ok(message) => message, + Err(error) => { + tracing::warn!( + ?error, + "Dropping outgoing notification after serialization failed" + ); + continue; + } + }; + transport_tx + .unbounded_send(Ok(message)) + .map_err(crate::Error::into_internal_error)?; } + continue; } OutgoingMessage::Response { id, @@ -146,6 +160,10 @@ mod tests { crate::UntypedMessage::new("session/new", serde_json::json!({})) } + fn malformed_v2_known_notification() -> Result { + crate::UntypedMessage::new("session/update", serde_json::json!({})) + } + #[tokio::test(flavor = "current_thread")] async fn failed_request_conversion_completes_request_locally() -> Result<(), crate::Error> { let (outgoing_tx, outgoing_rx) = mpsc::unbounded(); @@ -196,7 +214,7 @@ mod tests { outgoing_tx .unbounded_send(OutgoingMessage::Notification { - untyped: malformed_v2_known_method()?, + untyped: malformed_v2_known_notification()?, }) .map_err(crate::Error::into_internal_error)?; outgoing_tx diff --git a/src/agent-client-protocol/src/jsonrpc/protocol_compat.rs b/src/agent-client-protocol/src/jsonrpc/protocol_compat.rs index 1d131428..2e7c8f8d 100644 --- a/src/agent-client-protocol/src/jsonrpc/protocol_compat.rs +++ b/src/agent-client-protocol/src/jsonrpc/protocol_compat.rs @@ -46,6 +46,20 @@ mod imp { Ok(message) } + pub(crate) fn incoming_notification( + &self, + message: UntypedMessage, + ) -> Result, crate::Error> { + Ok(vec![message]) + } + + pub(crate) fn outgoing_notification( + &self, + message: UntypedMessage, + ) -> Result, crate::Error> { + Ok(vec![message]) + } + pub(crate) fn incoming_response( &self, _method: &str, @@ -70,7 +84,7 @@ mod imp { use agent_client_protocol_schema::v2::{ self, - conversion::{IntoV1, IntoV2, v1_to_v2, v2_to_v1}, + conversion::{IntoV1, IntoV1Many, IntoV2, v1_to_v2, v2_to_v1, v2_to_v1_many}, }; use crate::schema::{ @@ -234,6 +248,28 @@ mod imp { convert_message(message, mode.api, wire_version) } + pub(crate) fn incoming_notification( + &self, + message: UntypedMessage, + ) -> Result, crate::Error> { + let Some(mode) = self.mode else { + return Ok(vec![message]); + }; + + convert_notification(message, self.active_wire_version(), mode.api) + } + + pub(crate) fn outgoing_notification( + &self, + message: UntypedMessage, + ) -> Result, crate::Error> { + let Some(mode) = self.mode else { + return Ok(vec![message]); + }; + + convert_notification(message, mode.api, self.active_wire_version()) + } + pub(crate) fn incoming_response( &self, method: &str, @@ -453,6 +489,51 @@ mod imp { } } + fn convert_notification( + message: UntypedMessage, + from: ProtocolVersionKind, + to: ProtocolVersionKind, + ) -> Result, crate::Error> { + if message.method().starts_with('_') || from == to { + return Ok(vec![message]); + } + + match (from, to) { + (ProtocolVersionKind::V1, ProtocolVersionKind::V2) => { + public_to_v2_notification(message) + } + (ProtocolVersionKind::V2, ProtocolVersionKind::V1) => { + v2_to_public_notification(message) + } + _ => Ok(vec![message]), + } + } + + fn public_to_v2_notification( + message: UntypedMessage, + ) -> Result, crate::Error> { + public_to_v2_message(message).map(|message| vec![message]) + } + + fn v2_to_public_notification( + message: UntypedMessage, + ) -> Result, crate::Error> { + let UntypedMessage { method, params } = message; + + if let Some(message) = + try_convert_message_to_v1::(&method, ¶ms)? + { + return Ok(vec![message]); + } + if let Some(messages) = + try_convert_message_to_v1_many::(&method, ¶ms)? + { + return Ok(messages); + } + + Ok(vec![UntypedMessage { method, params }]) + } + fn public_to_v2_message(message: UntypedMessage) -> Result { let UntypedMessage { method, params } = message; @@ -486,11 +567,6 @@ mod imp { { return Ok(message); } - if let Some(message) = try_convert_message_to_v1::(&method, ¶ms)? - { - return Ok(message); - } - Ok(UntypedMessage { method, params }) } @@ -552,6 +628,24 @@ mod imp { public_message.to_untyped_message().map(Some) } + fn try_convert_message_to_v1_many( + method: &str, + params: &serde_json::Value, + ) -> Result>, crate::Error> + where + T: JsonRpcMessage + IntoV1Many, + T::Output: JsonRpcMessage, + { + let Some(message) = try_parse_message::(method, params)? else { + return Ok(None); + }; + v2_to_v1_many(message)? + .into_iter() + .map(|public_message| public_message.to_untyped_message()) + .collect::, _>>() + .map(Some) + } + fn try_parse_message( method: &str, params: &serde_json::Value, @@ -738,6 +832,61 @@ mod imp { Ok(()) } + #[test] + fn outgoing_v2_agent_notification_fans_out_for_v1_wire() -> Result<(), crate::Error> { + let compat = ProtocolCompat::new(ProtocolMode::v2_agent()); + let messages = compat.outgoing_notification(UntypedMessage::new( + "session/update", + v2::SessionNotification::new( + "sess", + v2::SessionUpdate::AgentMessage(v2::AgentMessage::new("msg_agent").content( + vec![ + v2::ContentBlock::Text(v2::TextContent::new("hello")), + v2::ContentBlock::Text(v2::TextContent::new("world")), + ], + )), + ), + )?)?; + + assert_eq!(messages.len(), 2); + let json = messages + .into_iter() + .map(|message| { + assert_eq!(message.method(), "session/update"); + Ok(message.params) + }) + .collect::, crate::Error>>()?; + assert_eq!( + json, + vec![ + serde_json::json!({ + "sessionId": "sess", + "update": { + "sessionUpdate": "agent_message_chunk", + "content": { + "type": "text", + "text": "hello" + }, + "messageId": "msg_agent" + } + }), + serde_json::json!({ + "sessionId": "sess", + "update": { + "sessionUpdate": "agent_message_chunk", + "content": { + "type": "text", + "text": "world" + }, + "messageId": "msg_agent" + } + }), + ] + ); + + Ok(()) + } + #[test] #[should_panic(expected = "cannot merge ACP builders with different API protocol versions")] fn merging_different_api_protocol_modes_panics() { diff --git a/src/agent-client-protocol/tests/protocol_v2.rs b/src/agent-client-protocol/tests/protocol_v2.rs index 1fdc7090..01e7e93a 100644 --- a/src/agent-client-protocol/tests/protocol_v2.rs +++ b/src/agent-client-protocol/tests/protocol_v2.rs @@ -49,6 +49,13 @@ fn cwd() -> Result { std::env::current_dir().map_err(Error::into_internal_error) } +fn v2_initialize_response_with_session( + protocol_version: ProtocolVersion, +) -> v2::InitializeResponse { + v2::InitializeResponse::new(protocol_version) + .capabilities(v2::AgentCapabilities::new().session(v2::SessionCapabilities::new())) +} + #[cfg(feature = "unstable_mcp_over_acp")] fn json_value(value: impl Serialize) -> Result { serde_json::to_value(value).map_err(Error::into_internal_error) @@ -215,7 +222,9 @@ async fn role_builder_v1_client_downgrades_initialize_for_v2_agent() -> Result<( let agent = Agent.v2().on_receive_request( async |initialize: v2::InitializeRequest, responder, _cx| { assert_eq!(initialize.protocol_version, ProtocolVersion::V2); - responder.respond(v2::InitializeResponse::new(initialize.protocol_version)) + responder.respond(v2_initialize_response_with_session( + initialize.protocol_version, + )) }, agent_client_protocol::on_receive_request!(), ); @@ -451,7 +460,7 @@ async fn v2_agent_serves_v1_client_with_v2_handlers() -> Result<(), Error> { async |initialize: v2::InitializeRequest, responder, _cx| { assert_eq!(initialize.protocol_version, ProtocolVersion::V2); // The compatibility layer should force this back to the negotiated v1 wire version. - responder.respond(v2::InitializeResponse::new(ProtocolVersion::V2)) + responder.respond(v2_initialize_response_with_session(ProtocolVersion::V2)) }, agent_client_protocol::on_receive_request!(), ) @@ -557,7 +566,9 @@ fn v2_agent_with_cancellable_new_session() .v2() .on_receive_request( async |initialize: v2::InitializeRequest, responder, _cx| { - responder.respond(v2::InitializeResponse::new(initialize.protocol_version)) + responder.respond(v2_initialize_response_with_session( + initialize.protocol_version, + )) }, agent_client_protocol::on_receive_request!(), ) diff --git a/src/agent-client-protocol/tests/schema_elicitation.rs b/src/agent-client-protocol/tests/schema_elicitation.rs index 49e1c33a..e47fdada 100644 --- a/src/agent-client-protocol/tests/schema_elicitation.rs +++ b/src/agent-client-protocol/tests/schema_elicitation.rs @@ -186,17 +186,24 @@ fn protocol_v2_elicitation_variants_are_jsonrpc_mapped() -> Result<(), Error> { #[cfg(feature = "unstable_protocol_v2")] #[tokio::test(flavor = "current_thread")] -async fn v2_agent_can_elicit_from_v1_client() -> Result<(), Error> { +async fn v2_agent_can_elicit_from_v1_client_before_prompt_completion() -> Result<(), Error> { use agent_client_protocol::schema::{self, ProtocolVersion, v2}; use agent_client_protocol::{Agent, Client}; use std::collections::BTreeMap; + fn v2_initialize_response_with_session( + protocol_version: ProtocolVersion, + ) -> v2::InitializeResponse { + v2::InitializeResponse::new(protocol_version) + .capabilities(v2::AgentCapabilities::new().session(v2::SessionCapabilities::new())) + } + let agent = Agent .v2() .on_receive_request( async |initialize: v2::InitializeRequest, responder, _cx| { assert_eq!(initialize.protocol_version, ProtocolVersion::V2); - responder.respond(v2::InitializeResponse::new(ProtocolVersion::V2)) + responder.respond(v2_initialize_response_with_session(ProtocolVersion::V2)) }, agent_client_protocol::on_receive_request!(), ) @@ -223,7 +230,7 @@ async fn v2_agent_can_elicit_from_v1_client() -> Result<(), Error> { content.get("name"), Some(&v2::ElicitationContentValue::String("Ada".into())) ); - responder.respond(v2::PromptResponse::new(v2::StopReason::EndTurn)) + responder.respond_with_error(Error::request_cancelled()) })?; Ok(()) @@ -255,14 +262,15 @@ async fn v2_agent_can_elicit_from_v1_client() -> Result<(), Error> { .await?; assert_eq!(initialize.protocol_version, ProtocolVersion::V1); - let prompt = cx + let error = cx .send_request(schema::PromptRequest::new( "sess_abc123", vec!["continue".into()], )) .block_task() - .await?; - assert_eq!(prompt.stop_reason, schema::StopReason::EndTurn); + .await + .expect_err("test agent cancels the prompt after the elicitation round trip"); + assert_eq!(i32::from(error.code), -32800); Ok(()) }) .await