diff --git a/README.md b/README.md index a174a4050..c420640b7 100644 --- a/README.md +++ b/README.md @@ -199,6 +199,7 @@ See [examples](examples/README.md) - `transport-sse-server`: Server SSE transport - `transport-child-process`: Client stdio transport - `transport-sse`: Client sse transport +- `transport-streamable-http-server` streamable http server transport ## Related Resources diff --git a/crates/rmcp/Cargo.toml b/crates/rmcp/Cargo.toml index 088c72fb2..09d532d96 100644 --- a/crates/rmcp/Cargo.toml +++ b/crates/rmcp/Cargo.toml @@ -49,10 +49,11 @@ tower-service = { version = "0.3", optional = true } # for ws transport # tokio-tungstenite ={ version = "0.26", optional = true } -# for sse-server transport +# for http-server transport axum = { version = "0.8", features = [], optional = true } rand = { version = "0.9", optional = true } tokio-stream = { version = "0.1", optional = true } +uuid = { version = "1", features = ["v4"], optional = true } # macro rmcp-macros = { version = "0.1", workspace = true, optional = true } @@ -71,6 +72,16 @@ transport-sse-server = [ "dep:axum", "dep:rand", "dep:tokio-stream", + "uuid", +] +transport-streamable-http-server = [ + "transport-streamable-http-server-session", + "dep:axum", + "uuid", +] +transport-streamable-http-server-session = [ + "transport-async-rw", + "dep:tokio-stream", ] # transport-ws = ["transport-io", "dep:tokio-tungstenite"] tower = ["dep:tower-service"] @@ -99,7 +110,7 @@ path = "tests/test_with_python.rs" [[test]] name = "test_with_js" -required-features = ["server", "client", "transport-sse-server", "transport-child-process"] +required-features = ["server", "client", "transport-sse-server", "transport-child-process", "transport-streamable-http-server"] path = "tests/test_with_js.rs" [[test]] diff --git a/crates/rmcp/src/model.rs b/crates/rmcp/src/model.rs index 7ff035eea..601827227 100644 --- a/crates/rmcp/src/model.rs +++ b/crates/rmcp/src/model.rs @@ -186,7 +186,7 @@ impl<'de> Deserialize<'de> for NumberOrString { pub type RequestId = NumberOrString; -#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] +#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Hash, Eq)] #[serde(transparent)] pub struct ProgressToken(pub NumberOrString); #[derive(Debug, Clone)] diff --git a/crates/rmcp/src/model/tool.rs b/crates/rmcp/src/model/tool.rs index 3d973eb53..ca5100d96 100644 --- a/crates/rmcp/src/model/tool.rs +++ b/crates/rmcp/src/model/tool.rs @@ -18,6 +18,7 @@ pub struct Tool { pub description: Option>, /// A JSON Schema object defining the expected parameters for the tool pub input_schema: Arc, + #[serde(skip_serializing_if = "Option::is_none")] /// Optional additional tool information. pub annotations: Option, } diff --git a/crates/rmcp/src/service.rs b/crates/rmcp/src/service.rs index 54aa081bb..0a972475c 100644 --- a/crates/rmcp/src/service.rs +++ b/crates/rmcp/src/service.rs @@ -591,7 +591,7 @@ where } }; - tracing::debug!(?evt, "new event"); + tracing::trace!(?evt, "new event"); match evt { // response and error Event::ToSink(m) => { @@ -657,7 +657,7 @@ where Event::PeerMessage(JsonRpcMessage::Request(JsonRpcRequest { id, request, .. })) => { - tracing::info!(%id, ?request, "received request"); + tracing::debug!(%id, ?request, "received request"); { let service = shared_service.clone(); let sink = sink_proxy_tx.clone(); @@ -675,7 +675,7 @@ where let result = service.handle_request(request, context).await; let response = match result { Ok(result) => { - tracing::info!(%id, ?result, "response message"); + tracing::debug!(%id, ?result, "response message"); JsonRpcMessage::response(result, id) } Err(error) => { diff --git a/crates/rmcp/src/transport.rs b/crates/rmcp/src/transport.rs index 9b49f7622..b601f5b0b 100644 --- a/crates/rmcp/src/transport.rs +++ b/crates/rmcp/src/transport.rs @@ -76,6 +76,13 @@ pub use auth::{AuthError, AuthorizationManager, AuthorizationSession, Authorized // #[cfg(feature = "transport-ws")] // pub mod ws; +#[cfg(feature = "transport-streamable-http-server-session")] +pub mod streamable_http_server; +#[cfg(feature = "transport-streamable-http-server")] +pub use streamable_http_server::axum::StreamableHttpServer; + +/// Common use codes +pub mod common; pub trait IntoTransport: Send + 'static where diff --git a/crates/rmcp/src/transport/common.rs b/crates/rmcp/src/transport/common.rs new file mode 100644 index 000000000..57bbf3f7f --- /dev/null +++ b/crates/rmcp/src/transport/common.rs @@ -0,0 +1,5 @@ +#[cfg(any( + feature = "transport-streamable-http-server", + feature = "transport-sse-server" +))] +pub mod axum; diff --git a/crates/rmcp/src/transport/common/axum.rs b/crates/rmcp/src/transport/common/axum.rs new file mode 100644 index 000000000..a26115752 --- /dev/null +++ b/crates/rmcp/src/transport/common/axum.rs @@ -0,0 +1,9 @@ +use std::{sync::Arc, time::Duration}; + +pub type SessionId = Arc; + +pub fn session_id() -> SessionId { + uuid::Uuid::new_v4().to_string().into() +} + +pub const DEFAULT_AUTO_PING_INTERVAL: Duration = Duration::from_secs(15); diff --git a/crates/rmcp/src/transport/sse_server.rs b/crates/rmcp/src/transport/sse_server.rs index ed04b3d0a..5389ea9da 100644 --- a/crates/rmcp/src/transport/sse_server.rs +++ b/crates/rmcp/src/transport/sse_server.rs @@ -19,14 +19,13 @@ use crate::{ RoleServer, Service, model::ClientJsonRpcMessage, service::{RxJsonRpcMessage, TxJsonRpcMessage}, + transport::common::axum::{DEFAULT_AUTO_PING_INTERVAL, SessionId, session_id}, }; -type SessionId = Arc; + type TxStore = Arc>>>; pub type TransportReceiver = ReceiverStream>; -const DEFAULT_AUTO_PING_INTERVAL: Duration = Duration::from_secs(15); - #[derive(Clone)] struct App { txs: TxStore, @@ -56,11 +55,6 @@ impl App { } } -fn session_id() -> SessionId { - let id = format!("{:016x}", rand::random::()); - Arc::from(id) -} - #[derive(Debug, serde::Deserialize)] #[serde(rename_all = "camelCase")] pub struct PostEventQuery { diff --git a/crates/rmcp/src/transport/streamable_http_server.rs b/crates/rmcp/src/transport/streamable_http_server.rs new file mode 100644 index 000000000..ec128fcc9 --- /dev/null +++ b/crates/rmcp/src/transport/streamable_http_server.rs @@ -0,0 +1,3 @@ +#[cfg(feature = "transport-streamable-http-server")] +pub mod axum; +pub mod session; diff --git a/crates/rmcp/src/transport/streamable_http_server/axum.rs b/crates/rmcp/src/transport/streamable_http_server/axum.rs new file mode 100644 index 000000000..f01ecec2e --- /dev/null +++ b/crates/rmcp/src/transport/streamable_http_server/axum.rs @@ -0,0 +1,335 @@ +use std::{collections::HashMap, io, net::SocketAddr, sync::Arc, time::Duration}; + +use axum::{ + Json, Router, + extract::State, + http::{HeaderMap, HeaderValue, StatusCode}, + response::{ + IntoResponse, Response, + sse::{Event, KeepAlive, Sse}, + }, + routing::get, +}; +use futures::Stream; +use tokio_stream::wrappers::ReceiverStream; +use tokio_util::sync::CancellationToken; +use tracing::Instrument; + +use super::session::{ + EventId, HEADER_LAST_EVENT_ID, Session, SessionTransport, StreamableHttpMessageReceiver, +}; +use crate::{ + RoleServer, Service, + model::ClientJsonRpcMessage, + transport::{ + common::axum::{DEFAULT_AUTO_PING_INTERVAL, SessionId, session_id}, + streamable_http_server::session::HEADER_SESSION_ID, + }, +}; +type SessionManager = Arc>>; + +#[derive(Clone)] +struct App { + session_manager: SessionManager, + transport_tx: tokio::sync::mpsc::UnboundedSender, + sse_ping_interval: Duration, +} + +impl App { + pub fn new( + sse_ping_interval: Duration, + ) -> (Self, tokio::sync::mpsc::UnboundedReceiver) { + let (transport_tx, transport_rx) = tokio::sync::mpsc::unbounded_channel(); + ( + Self { + session_manager: Default::default(), + transport_tx, + sse_ping_interval, + }, + transport_rx, + ) + } +} + +fn receiver_as_stream( + receiver: StreamableHttpMessageReceiver, +) -> impl Stream> { + use futures::StreamExt; + ReceiverStream::new(receiver.inner).map(|message| { + match serde_json::to_string(&message.message) { + Ok(bytes) => Ok(Event::default() + .event("message") + .data(&bytes) + .id(message.event_id.to_string())), + Err(e) => Err(io::Error::new(io::ErrorKind::InvalidData, e)), + } + }) +} + +async fn post_handler( + State(app): State, + header_map: HeaderMap, + Json(message): Json, +) -> Result { + use futures::StreamExt; + if let Some(session_id) = header_map.get(HEADER_SESSION_ID) { + let session_id = session_id + .to_str() + .map_err(|e| (StatusCode::BAD_REQUEST, e.to_string()).into_response())?; + tracing::debug!(session_id, ?message, "new client message"); + let handle = { + let sm = app.session_manager.read().await; + let session = sm + .get(session_id) + .ok_or((StatusCode::NOT_FOUND, "session not found").into_response())?; + session.handle().clone() + }; + match &message { + ClientJsonRpcMessage::Request(_) | ClientJsonRpcMessage::BatchRequest(_) => { + let receiver = handle.establish_request_wise_channel().await.map_err(|e| { + ( + StatusCode::INTERNAL_SERVER_ERROR, + format!("fail to to establish request channel: {e}"), + ) + .into_response() + })?; + let http_request_id = receiver.http_request_id; + if let Err(push_err) = handle.push_message(message, http_request_id).await { + tracing::error!(session_id, ?push_err, "push message error"); + return Err(( + StatusCode::INTERNAL_SERVER_ERROR, + format!("fail to push message: {push_err}"), + ) + .into_response()); + } + let stream = + ReceiverStream::new(receiver.inner).map(|message| match serde_json::to_string( + &message.message, + ) { + Ok(bytes) => Ok(Event::default() + .event("message") + .data(&bytes) + .id(message.event_id.to_string())), + Err(e) => Err(io::Error::new(io::ErrorKind::InvalidData, e)), + }); + Ok(Sse::new(stream) + .keep_alive(KeepAlive::new().interval(app.sse_ping_interval)) + .into_response()) + } + _ => { + let result = handle.push_message(message, None).await; + if result.is_err() { + Err((StatusCode::GONE, "session terminated").into_response()) + } else { + Ok(StatusCode::ACCEPTED.into_response()) + } + } + } + } else { + // expect initialize message + let session_id = session_id(); + let (session, transport) = + super::session::create_session(session_id.clone(), Default::default()); + let Ok(_) = app.transport_tx.send(transport) else { + return Err((StatusCode::GONE, "session terminated").into_response()); + }; + + let response = session.handle().initialize(message).await.map_err(|e| { + ( + StatusCode::INTERNAL_SERVER_ERROR, + format!("fail to initialize: {e}"), + ) + .into_response() + })?; + let mut response = Json(response).into_response(); + response.headers_mut().insert( + HEADER_SESSION_ID, + HeaderValue::from_bytes(session_id.as_bytes()).expect("should be valid header value"), + ); + app.session_manager + .write() + .await + .insert(session_id, session); + Ok(response) + } +} + +async fn get_handler( + State(app): State, + header_map: HeaderMap, +) -> Result>>, Response> { + let session_id = header_map + .get(HEADER_SESSION_ID) + .and_then(|v| v.to_str().ok()); + if let Some(session_id) = session_id { + let last_event_id = header_map + .get(HEADER_LAST_EVENT_ID) + .and_then(|v| v.to_str().ok()); + match last_event_id { + Some(last_event_id) => { + let last_event_id = last_event_id.parse::().map_err(|e| { + (StatusCode::BAD_REQUEST, format!("invalid event_id {e}")).into_response() + })?; + let sm = app.session_manager.read().await; + let session = sm.get(session_id).ok_or_else(|| { + ( + StatusCode::NOT_FOUND, + format!("session {session_id} not found"), + ) + .into_response() + })?; + let handle = session.handle(); + let receiver = handle.resume(last_event_id).await.map_err(|e| { + ( + StatusCode::INTERNAL_SERVER_ERROR, + format!("resume error {e}"), + ) + .into_response() + })?; + let stream = receiver_as_stream(receiver); + Ok(Sse::new(stream).keep_alive(KeepAlive::new().interval(app.sse_ping_interval))) + } + None => { + let sm = app.session_manager.read().await; + let session = sm.get(session_id).ok_or_else(|| { + ( + StatusCode::NOT_FOUND, + format!("session {session_id} not found"), + ) + .into_response() + })?; + let handle = session.handle(); + let receiver = handle.establish_common_channel().await.map_err(|e| { + ( + StatusCode::INTERNAL_SERVER_ERROR, + format!("establish common channel error {e}"), + ) + .into_response() + })?; + let stream = receiver_as_stream(receiver); + Ok(Sse::new(stream).keep_alive(KeepAlive::new().interval(app.sse_ping_interval))) + } + } + } else { + Err((StatusCode::BAD_REQUEST, "missing session id").into_response()) + } +} + +async fn delete_handler( + State(app): State, + header_map: HeaderMap, +) -> Result { + if let Some(session_id) = header_map.get(HEADER_SESSION_ID) { + let session_id = session_id + .to_str() + .map_err(|e| (StatusCode::BAD_REQUEST, e.to_string()).into_response())?; + let mut sm = app.session_manager.write().await; + let session = sm + .remove(session_id) + .ok_or((StatusCode::NOT_FOUND, "session not found").into_response())?; + let cancel_result = session.cancel().await.map_err(|e| { + ( + StatusCode::INTERNAL_SERVER_ERROR, + format!("fail to cancel session {session_id}: tokio join error: {e}"), + ) + .into_response() + })?; + tracing::info!(session_id, ?cancel_result, "session deleted"); + Ok(StatusCode::ACCEPTED) + } else { + Err((StatusCode::BAD_REQUEST, "missing session id").into_response()) + } +} + +#[derive(Debug, Clone)] +pub struct StreamableHttpServerConfig { + pub bind: SocketAddr, + pub path: String, + pub ct: CancellationToken, + pub sse_keep_alive: Option, +} + +#[derive(Debug)] +pub struct StreamableHttpServer { + transport_rx: tokio::sync::mpsc::UnboundedReceiver, + pub config: StreamableHttpServerConfig, +} + +impl StreamableHttpServer { + pub async fn serve(bind: SocketAddr) -> io::Result { + Self::serve_with_config(StreamableHttpServerConfig { + bind, + path: "/".to_string(), + ct: CancellationToken::new(), + sse_keep_alive: None, + }) + .await + } + pub async fn serve_with_config(config: StreamableHttpServerConfig) -> io::Result { + let (streamable_http_server, service) = Self::new(config); + let listener = tokio::net::TcpListener::bind(streamable_http_server.config.bind).await?; + let ct = streamable_http_server.config.ct.child_token(); + let server = axum::serve(listener, service).with_graceful_shutdown(async move { + ct.cancelled().await; + tracing::info!("streamable http server cancelled"); + }); + tokio::spawn( + async move { + if let Err(e) = server.await { + tracing::error!(error = %e, "streamable http server shutdown with error"); + } + } + .instrument(tracing::info_span!("streamable-http-server", bind_address = %streamable_http_server.config.bind)), + ); + Ok(streamable_http_server) + } + + /// Warning: This function creates a new StreamableHttpServer instance with the provided configuration. + /// `App.post_path` may be incorrect if using `Router` as an embedded router. + pub fn new(config: StreamableHttpServerConfig) -> (StreamableHttpServer, Router) { + let (app, transport_rx) = + App::new(config.sse_keep_alive.unwrap_or(DEFAULT_AUTO_PING_INTERVAL)); + let router = Router::new() + .route( + &config.path, + get(get_handler).post(post_handler).delete(delete_handler), + ) + .with_state(app); + + let server = StreamableHttpServer { + transport_rx, + config, + }; + + (server, router) + } + + pub fn with_service(mut self, service_provider: F) -> CancellationToken + where + S: Service, + F: Fn() -> S + Send + 'static, + { + use crate::service::ServiceExt; + let ct = self.config.ct.clone(); + tokio::spawn(async move { + while let Some(transport) = self.next_transport().await { + let service = service_provider(); + let ct = self.config.ct.child_token(); + tokio::spawn(async move { + let server = service.serve_with_ct(transport, ct).await?; + server.waiting().await?; + tokio::io::Result::Ok(()) + }); + } + }); + ct + } + + pub fn cancel(&self) { + self.config.ct.cancel(); + } + + pub async fn next_transport(&mut self) -> Option { + self.transport_rx.recv().await + } +} diff --git a/crates/rmcp/src/transport/streamable_http_server/session.rs b/crates/rmcp/src/transport/streamable_http_server/session.rs new file mode 100644 index 000000000..574ad17d6 --- /dev/null +++ b/crates/rmcp/src/transport/streamable_http_server/session.rs @@ -0,0 +1,767 @@ +use std::{ + borrow::Cow, + collections::{HashMap, HashSet, VecDeque}, + num::ParseIntError, + sync::Arc, +}; + +use futures::{Sink, SinkExt, Stream}; +use thiserror::Error; +use tokio::sync::{ + mpsc::{Receiver, Sender}, + oneshot, +}; +use tokio_stream::wrappers::ReceiverStream; +use tokio_util::sync::{CancellationToken, DropGuard, PollSender}; +use tracing::instrument; + +use crate::{ + RoleServer, + model::{ + CancelledNotificationParam, ClientJsonRpcMessage, ClientNotification, ClientRequest, + JsonRpcNotification, JsonRpcRequest, Notification, ProgressNotificationParam, + ProgressToken, RequestId, ServerJsonRpcMessage, ServerNotification, + }, + transport::IntoTransport, +}; + +pub const HEADER_SESSION_ID: &str = "Mcp-Session-Id"; +pub const HEADER_LAST_EVENT_ID: &str = "Last-Event-Id"; +#[derive(Debug, Clone)] +pub struct ServerSessionMessage { + pub event_id: EventId, + pub message: Arc, +} + +/// - +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub struct EventId { + http_request_id: Option, + index: usize, +} + +impl std::fmt::Display for EventId { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", self.index)?; + match &self.http_request_id { + Some(http_request_id) => write!(f, "/{http_request_id}"), + None => write!(f, ""), + } + } +} + +#[derive(Debug, Clone, Error)] +pub enum EventIdParseError { + #[error("Invalid index: {0}")] + InvalidIndex(ParseIntError), + #[error("Invalid numeric request id: {0}")] + InvalidNumericRequestId(ParseIntError), + #[error("Missing request id type")] + InvalidRequestIdType, + #[error("Missing request id")] + MissingRequestId, +} + +impl std::str::FromStr for EventId { + type Err = EventIdParseError; + fn from_str(s: &str) -> Result { + if let Some((index, request_id)) = s.split_once("/") { + let index = usize::from_str(index).map_err(EventIdParseError::InvalidIndex)?; + let request_id = u64::from_str(request_id).map_err(EventIdParseError::InvalidIndex)?; + Ok(EventId { + http_request_id: Some(request_id), + index, + }) + } else { + let index = usize::from_str(s).map_err(EventIdParseError::InvalidIndex)?; + Ok(EventId { + http_request_id: None, + index, + }) + } + } +} + +pub use crate::transport::common::axum::SessionId; + +struct CachedTx { + tx: Sender, + cache: VecDeque, + http_request_id: Option, + capacity: usize, +} + +impl CachedTx { + fn new(tx: Sender, http_request_id: Option) -> Self { + Self { + cache: VecDeque::with_capacity(tx.capacity()), + capacity: tx.capacity(), + tx, + http_request_id, + } + } + fn new_common(tx: Sender) -> Self { + Self::new(tx, None) + } + + async fn send(&mut self, message: ServerJsonRpcMessage) { + let index = self.cache.back().map_or(0, |m| m.event_id.index + 1); + let event_id = EventId { + http_request_id: self.http_request_id, + index, + }; + let message = ServerSessionMessage { + event_id: event_id.clone(), + message: Arc::new(message), + }; + if self.cache.len() >= self.capacity { + self.cache.pop_front(); + self.cache.push_back(message.clone()); + } else { + self.cache.push_back(message.clone()); + } + let _ = self.tx.send(message).await.inspect_err(|e| { + let event_id = &e.0.event_id; + tracing::trace!(%event_id, "trying to send message in a closed session") + }); + } + + async fn sync(&mut self, index: usize) -> Result<(), SessionError> { + let Some(front) = self.cache.front() else { + return Ok(()); + }; + let sync_index = index.saturating_sub(front.event_id.index); + if sync_index > self.cache.len() { + // invalid index + return Err(SessionError::InvalidEventId); + } + for message in self.cache.iter().skip(sync_index) { + let send_result = self.tx.send(message.clone()).await; + if send_result.is_err() { + return Err(SessionError::ChannelClosed( + message.event_id.http_request_id, + )); + } + } + Ok(()) + } +} + +struct HttpRequestWise { + resources: HashSet, + tx: CachedTx, +} + +type HttpRequestId = u64; +#[derive(Debug, Clone, Hash, PartialEq, Eq)] +enum ResourceKey { + McpRequestId(RequestId), + ProgressToken(ProgressToken), +} + +struct SessionContext { + id: SessionId, + next_http_request_id: HttpRequestId, + tx_router: HashMap, + resource_router: HashMap, + common: CachedTx, + to_service_tx: Sender, + event_rx: Receiver, + session_config: SessionConfig, +} + +#[derive(Debug, Error)] +pub enum SessionError { + #[error("Invalid request id: {0}")] + DuplicatedRequestId(HttpRequestId), + #[error("Channel closed: {0:?}")] + ChannelClosed(Option), + #[error("Cannot parse event id: {0}")] + EventIdParseError(Cow<'static, str>), + #[error("Session service terminated")] + SessionServiceTerminated, + #[error("Invalid event id")] + InvalidEventId, + #[error("Transport closed")] + TransportClosed, + #[error("IO error: {0}")] + Io(std::io::Error), +} + +impl From for std::io::Error { + fn from(value: SessionError) -> Self { + match value { + SessionError::Io(io) => io, + _ => std::io::Error::new(std::io::ErrorKind::Other, format!("Session error: {value}")), + } + } +} +impl From for SessionError { + fn from(value: std::io::Error) -> Self { + SessionError::Io(value) + } +} + +enum OutboundChannel { + RequestWise { id: HttpRequestId, close: bool }, + Common, +} + +pub struct StreamableHttpMessageReceiver { + pub http_request_id: Option, + pub inner: Receiver, +} + +impl SessionContext { + fn unregister_resource(&mut self, resource: &ResourceKey) { + if let Some(http_request_id) = self.resource_router.remove(resource) { + tracing::trace!(?resource, http_request_id, "unregister resource"); + if let Some(channel) = self.tx_router.get_mut(&http_request_id) { + channel.resources.remove(resource); + if channel.resources.is_empty() { + tracing::debug!(http_request_id, "close http request wise channel"); + self.tx_router.remove(&http_request_id); + } + } + } + } + fn register_resource(&mut self, resource: ResourceKey, http_request_id: HttpRequestId) { + tracing::trace!(?resource, http_request_id, "register resource"); + if let Some(channel) = self.tx_router.get_mut(&http_request_id) { + channel.resources.insert(resource.clone()); + self.resource_router.insert(resource, http_request_id); + } + } + fn register_request( + &mut self, + request: &JsonRpcRequest, + http_request_id: HttpRequestId, + ) { + use crate::model::GetMeta; + self.register_resource( + ResourceKey::McpRequestId(request.id.clone()), + http_request_id, + ); + if let Some(progress_token) = request.request.get_meta().get_progress_token() { + self.register_resource( + ResourceKey::ProgressToken(progress_token.clone()), + http_request_id, + ); + } + } + fn catch_cancellation_notification( + &mut self, + notification: &JsonRpcNotification, + ) { + if let ClientNotification::CancelledNotification(n) = ¬ification.notification { + let request_id = n.params.request_id.clone(); + let resource = ResourceKey::McpRequestId(request_id); + self.unregister_resource(&resource); + } + } + fn next_http_request_id(&mut self) -> HttpRequestId { + let id = self.next_http_request_id; + self.next_http_request_id = self.next_http_request_id.wrapping_add(1); + id + } + async fn send_to_service(&self, message: ClientJsonRpcMessage) -> Result<(), SessionError> { + if self.to_service_tx.send(message).await.is_err() { + return Err(SessionError::TransportClosed); + } + Ok(()) + } + async fn establish_request_wise_channel( + &mut self, + ) -> Result { + let http_request_id = self.next_http_request_id(); + let (tx, rx) = tokio::sync::mpsc::channel(self.session_config.channel_capacity); + self.tx_router.insert( + http_request_id, + HttpRequestWise { + resources: Default::default(), + tx: CachedTx::new(tx, Some(http_request_id)), + }, + ); + tracing::debug!(http_request_id, "establish new request wise channel"); + Ok(StreamableHttpMessageReceiver { + http_request_id: Some(http_request_id), + inner: rx, + }) + } + fn resolve_outbound_channel(&self, message: &ServerJsonRpcMessage) -> OutboundChannel { + match &message { + ServerJsonRpcMessage::Request(_) => OutboundChannel::Common, + ServerJsonRpcMessage::Notification(JsonRpcNotification { + notification: + ServerNotification::ProgressNotification(Notification { + params: ProgressNotificationParam { progress_token, .. }, + .. + }), + .. + }) => { + let id = self + .resource_router + .get(&ResourceKey::ProgressToken(progress_token.clone())); + + if let Some(id) = id { + OutboundChannel::RequestWise { + id: *id, + close: false, + } + } else { + OutboundChannel::Common + } + } + ServerJsonRpcMessage::Notification(JsonRpcNotification { + notification: + ServerNotification::CancelledNotification(Notification { + params: CancelledNotificationParam { request_id, .. }, + .. + }), + .. + }) => { + if let Some(id) = self + .resource_router + .get(&ResourceKey::McpRequestId(request_id.clone())) + { + OutboundChannel::RequestWise { + id: *id, + close: false, + } + } else { + OutboundChannel::Common + } + } + ServerJsonRpcMessage::Notification(_) => OutboundChannel::Common, + ServerJsonRpcMessage::Response(json_rpc_response) => { + if let Some(id) = self + .resource_router + .get(&ResourceKey::McpRequestId(json_rpc_response.id.clone())) + { + OutboundChannel::RequestWise { + id: *id, + close: false, + } + } else { + OutboundChannel::Common + } + } + ServerJsonRpcMessage::Error(json_rpc_error) => { + if let Some(id) = self + .resource_router + .get(&ResourceKey::McpRequestId(json_rpc_error.id.clone())) + { + OutboundChannel::RequestWise { + id: *id, + close: false, + } + } else { + OutboundChannel::Common + } + } + ServerJsonRpcMessage::BatchRequest(_) | ServerJsonRpcMessage::BatchResponse(_) => { + // the server side should never yield a batch request or response now + unreachable!("server side won't yield batch request or response") + } + } + } + async fn handle_server_message( + &mut self, + message: ServerJsonRpcMessage, + ) -> Result<(), SessionError> { + let outbound_channel = self.resolve_outbound_channel(&message); + match outbound_channel { + OutboundChannel::RequestWise { id, close } => { + if let Some(request_wise) = self.tx_router.get_mut(&id) { + request_wise.tx.send(message).await; + if close { + self.tx_router.remove(&id); + } + } else { + return Err(SessionError::ChannelClosed(Some(id))); + } + } + OutboundChannel::Common => self.common.send(message).await, + } + Ok(()) + } + async fn resume( + &mut self, + last_event_id: EventId, + ) -> Result { + match last_event_id.http_request_id { + Some(http_request_id) => { + let request_wise = self + .tx_router + .get_mut(&http_request_id) + .ok_or(SessionError::ChannelClosed(Some(http_request_id)))?; + let channel = tokio::sync::mpsc::channel(self.session_config.channel_capacity); + let (tx, rx) = channel; + request_wise.tx.tx = tx; + let index = last_event_id.index; + // sync messages after index + request_wise.tx.sync(index).await?; + Ok(StreamableHttpMessageReceiver { + http_request_id: Some(http_request_id), + inner: rx, + }) + } + None => { + let channel = tokio::sync::mpsc::channel(self.session_config.channel_capacity); + let (tx, rx) = channel; + self.common.tx = tx; + let index = last_event_id.index; + // sync messages after index + self.common.sync(index).await?; + Ok(StreamableHttpMessageReceiver { + http_request_id: None, + inner: rx, + }) + } + } + } +} + +enum SessionEvent { + ServiceMessage(ServerJsonRpcMessage), + ClientMessage { + message: ClientJsonRpcMessage, + http_request_id: Option, + }, + EstablishRequestWiseChannel { + responder: oneshot::Sender>, + }, + CloseRequestWiseChannel { + id: HttpRequestId, + responder: oneshot::Sender>, + }, + Resume { + last_event_id: EventId, + responder: oneshot::Sender>, + }, + InitializeRequest { + request: ClientJsonRpcMessage, + responder: oneshot::Sender>, + }, +} + +#[derive(Debug, Clone)] +pub enum SessionQuitReason { + ServiceTerminated, + ClientTerminated, + ExpectInitializeRequest, + ExpectInitializeResponse, + Cancelled, +} + +impl SessionContext { + #[instrument(name = "streamable_http_session", skip_all, fields(id = self.id.as_ref()))] + pub async fn run(mut self, ct: CancellationToken) -> SessionQuitReason { + // waiting for initialize request + let Some(evt) = self.event_rx.recv().await else { + return SessionQuitReason::ServiceTerminated; + }; + let SessionEvent::InitializeRequest { request, responder } = evt else { + return SessionQuitReason::ExpectInitializeRequest; + }; + let send_result = self.send_to_service(request).await; + if let Err(e) = send_result { + let _ = responder.send(Err(e)); + return SessionQuitReason::ServiceTerminated; + } + let Some(evt) = self.event_rx.recv().await else { + return SessionQuitReason::ServiceTerminated; + }; + let SessionEvent::ServiceMessage(response) = evt else { + return SessionQuitReason::ExpectInitializeResponse; + }; + let response_result = responder.send(Ok(response)); + if response_result.is_err() { + return SessionQuitReason::ClientTerminated; + } + let quit_reason = loop { + let event = tokio::select! { + event = self.event_rx.recv() => { + if let Some(event) = event { + event + } else { + break SessionQuitReason::ServiceTerminated; + } + }, + + _ = ct.cancelled() => { + break SessionQuitReason::Cancelled; + } + }; + match event { + SessionEvent::ServiceMessage(json_rpc_message) => { + // catch response + match &json_rpc_message { + crate::model::JsonRpcMessage::Response(json_rpc_response) => { + let request_id = json_rpc_response.id.clone(); + self.unregister_resource(&ResourceKey::McpRequestId(request_id)); + } + crate::model::JsonRpcMessage::Error(json_rpc_error) => { + let request_id = json_rpc_error.id.clone(); + self.unregister_resource(&ResourceKey::McpRequestId(request_id)); + } + // unlikely happen + crate::model::JsonRpcMessage::BatchResponse( + json_rpc_batch_response_items, + ) => { + for item in json_rpc_batch_response_items { + let request_id = match item { + crate::model::JsonRpcBatchResponseItem::Response( + json_rpc_response, + ) => json_rpc_response.id.clone(), + crate::model::JsonRpcBatchResponseItem::Error( + json_rpc_error, + ) => json_rpc_error.id.clone(), + }; + self.unregister_resource(&ResourceKey::McpRequestId(request_id)); + } + } + _ => { + // no need to unregister resource + } + } + let _handle_result = self.handle_server_message(json_rpc_message).await; + } + SessionEvent::ClientMessage { + message: json_rpc_message, + http_request_id, + } => { + match &json_rpc_message { + crate::model::JsonRpcMessage::Request(request) => { + if let Some(http_request_id) = http_request_id { + self.register_request(request, http_request_id) + } + } + crate::model::JsonRpcMessage::Notification(notification) => { + self.catch_cancellation_notification(notification) + } + crate::model::JsonRpcMessage::BatchRequest(items) => { + for r in items { + match r { + crate::model::JsonRpcBatchRequestItem::Request(request) => { + if let Some(http_request_id) = http_request_id { + self.register_request(request, http_request_id) + } + } + crate::model::JsonRpcBatchRequestItem::Notification( + notification, + ) => self.catch_cancellation_notification(notification), + } + } + } + _ => {} + } + let _handle_result = self.send_to_service(json_rpc_message).await; + } + SessionEvent::EstablishRequestWiseChannel { responder } => { + let handle_result = self.establish_request_wise_channel().await; + let _ = responder.send(handle_result); + } + SessionEvent::CloseRequestWiseChannel { id, responder } => { + let _handle_result = self.tx_router.remove(&id); + let _ = responder.send(Ok(())); + } + SessionEvent::Resume { + last_event_id, + responder, + } => { + let handle_result = self.resume(last_event_id).await; + let _ = responder.send(handle_result); + } + _ => { + // ignore + } + } + }; + tracing::debug!("session terminated: {:?}", quit_reason); + quit_reason + } +} + +pub struct Session { + handle: SessionHandle, + guard: DropGuard, + task_handle: tokio::task::JoinHandle, +} + +impl Session { + pub fn handle(&self) -> &SessionHandle { + &self.handle + } + pub async fn cancel(self) -> Result { + self.guard.disarm().cancel(); + self.task_handle.await + } +} + +#[derive(Debug, Clone)] +pub struct SessionHandle { + // after all event_tx drop, inner task will be terminated + event_tx: Sender, +} + +impl SessionHandle { + pub async fn push_message( + &self, + message: ClientJsonRpcMessage, + http_request_id: Option, + ) -> Result<(), SessionError> { + self.event_tx + .send(SessionEvent::ClientMessage { + message, + http_request_id, + }) + .await + .map_err(|_| SessionError::SessionServiceTerminated)?; + Ok(()) + } + + pub async fn establish_request_wise_channel( + &self, + ) -> Result { + let (tx, rx) = tokio::sync::oneshot::channel(); + self.event_tx + .send(SessionEvent::EstablishRequestWiseChannel { responder: tx }) + .await + .map_err(|_| SessionError::SessionServiceTerminated)?; + rx.await + .map_err(|_| SessionError::SessionServiceTerminated)? + } + pub async fn close_request_wise_channel( + &self, + request_id: HttpRequestId, + ) -> Result<(), SessionError> { + let (tx, rx) = tokio::sync::oneshot::channel(); + self.event_tx + .send(SessionEvent::CloseRequestWiseChannel { + id: request_id, + responder: tx, + }) + .await + .map_err(|_| SessionError::SessionServiceTerminated)?; + rx.await + .map_err(|_| SessionError::SessionServiceTerminated)? + } + pub async fn establish_common_channel( + &self, + ) -> Result { + let (tx, rx) = tokio::sync::oneshot::channel(); + self.event_tx + .send(SessionEvent::Resume { + last_event_id: EventId { + http_request_id: None, + index: 0, + }, + responder: tx, + }) + .await + .map_err(|_| SessionError::SessionServiceTerminated)?; + rx.await + .map_err(|_| SessionError::SessionServiceTerminated)? + } + + pub async fn resume( + &self, + last_event_id: EventId, + ) -> Result { + let (tx, rx) = tokio::sync::oneshot::channel(); + self.event_tx + .send(SessionEvent::Resume { + last_event_id, + responder: tx, + }) + .await + .map_err(|_| SessionError::SessionServiceTerminated)?; + rx.await + .map_err(|_| SessionError::SessionServiceTerminated)? + } + + pub async fn initialize( + &self, + request: ClientJsonRpcMessage, + ) -> Result { + let (tx, rx) = tokio::sync::oneshot::channel(); + self.event_tx + .send(SessionEvent::InitializeRequest { + request, + responder: tx, + }) + .await + .map_err(|_| SessionError::SessionServiceTerminated)?; + rx.await + .map_err(|_| SessionError::SessionServiceTerminated)? + } +} + +pub struct SessionTransport { + session_handle: SessionHandle, + to_service_rx: Receiver, +} + +impl IntoTransport for SessionTransport { + fn into_transport( + self, + ) -> ( + impl Sink + Send + 'static, + impl Stream + Send + 'static, + ) { + let stream = ReceiverStream::new(self.to_service_rx); + let sink = PollSender::new(self.session_handle.event_tx.clone()) + .sink_map_err(|_| SessionError::SessionServiceTerminated) + .with(async |m| Ok(SessionEvent::ServiceMessage(m))); + (sink, stream) + } +} + +#[derive(Debug, Clone)] +pub struct SessionConfig { + channel_capacity: usize, +} + +impl SessionConfig { + pub const DEFAULT_CHANNEL_CAPACITY: usize = 16; +} + +impl Default for SessionConfig { + fn default() -> Self { + Self { + channel_capacity: Self::DEFAULT_CHANNEL_CAPACITY, + } + } +} + +pub fn create_session(id: SessionId, config: SessionConfig) -> (Session, SessionTransport) { + let (to_service_tx, to_service_rx) = tokio::sync::mpsc::channel(config.channel_capacity); + let (event_tx, event_rx) = tokio::sync::mpsc::channel(config.channel_capacity); + let (common_tx, _) = tokio::sync::mpsc::channel(config.channel_capacity); + let common = CachedTx::new_common(common_tx); + tracing::info!(session_id = ?id, "create new session"); + let session_context = SessionContext { + next_http_request_id: 0, + id, + tx_router: HashMap::new(), + resource_router: HashMap::new(), + common, + to_service_tx, + event_rx, + session_config: config.clone(), + }; + let ct = CancellationToken::new(); + let handle = SessionHandle { event_tx }; + let task_handle = tokio::spawn(session_context.run(ct.child_token())); + let session = Session { + handle: handle.clone(), + task_handle, + guard: ct.drop_guard(), + }; + let session_transport = SessionTransport { + to_service_rx, + session_handle: handle, + }; + (session, session_transport) +} diff --git a/crates/rmcp/tests/test_with_js.rs b/crates/rmcp/tests/test_with_js.rs index f5431beef..8db2c4c6e 100644 --- a/crates/rmcp/tests/test_with_js.rs +++ b/crates/rmcp/tests/test_with_js.rs @@ -1,12 +1,13 @@ use rmcp::{ ServiceExt, - transport::{SseServer, TokioChildProcess}, + transport::{SseServer, TokioChildProcess, streamable_http_server::axum::StreamableHttpServer}, }; use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt}; mod common; use common::calculator::Calculator; -const BIND_ADDRESS: &str = "127.0.0.1:8000"; +const SSE_BIND_ADDRESS: &str = "127.0.0.1:8000"; +const STREAMABLE_HTTP_BIND_ADDRESS: &str = "127.0.0.1:8001"; #[tokio::test] async fn test_with_js_client() -> anyhow::Result<()> { @@ -24,7 +25,7 @@ async fn test_with_js_client() -> anyhow::Result<()> { .wait() .await?; - let ct = SseServer::serve(BIND_ADDRESS.parse()?) + let ct = SseServer::serve(SSE_BIND_ADDRESS.parse()?) .await? .with_service(Calculator::default); @@ -66,3 +67,33 @@ async fn test_with_js_server() -> anyhow::Result<()> { client.cancel().await?; Ok(()) } + +#[tokio::test] +async fn test_with_js_streamable_http_client() -> anyhow::Result<()> { + let _ = tracing_subscriber::registry() + .with( + tracing_subscriber::EnvFilter::try_from_default_env() + .unwrap_or_else(|_| "debug".to_string().into()), + ) + .with(tracing_subscriber::fmt::layer()) + .try_init(); + tokio::process::Command::new("npm") + .arg("install") + .current_dir("tests/test_with_js") + .spawn()? + .wait() + .await?; + + let ct = StreamableHttpServer::serve(STREAMABLE_HTTP_BIND_ADDRESS.parse()?) + .await? + .with_service(Calculator::default); + + let exit_status = tokio::process::Command::new("node") + .arg("tests/test_with_js/streamable_client.js") + .spawn()? + .wait() + .await?; + assert!(exit_status.success()); + ct.cancel(); + Ok(()) +} diff --git a/crates/rmcp/tests/test_with_js/package.json b/crates/rmcp/tests/test_with_js/package.json index 6dee815cb..4612a0610 100644 --- a/crates/rmcp/tests/test_with_js/package.json +++ b/crates/rmcp/tests/test_with_js/package.json @@ -1,6 +1,7 @@ { "dependencies": { - "@modelcontextprotocol/sdk": "^1.7.0" + "@modelcontextprotocol/sdk": "^1.10", + "eventsource-parser": "^3.0.1" }, "type": "module", "name": "test_with_ts", diff --git a/crates/rmcp/tests/test_with_js/streamable_client.js b/crates/rmcp/tests/test_with_js/streamable_client.js new file mode 100644 index 000000000..b22acba3b --- /dev/null +++ b/crates/rmcp/tests/test_with_js/streamable_client.js @@ -0,0 +1,28 @@ +import { Client } from "@modelcontextprotocol/sdk/client/index.js"; +import { StreamableHTTPClientTransport } from "@modelcontextprotocol/sdk/client/streamableHttp.js"; + +const transport = new StreamableHTTPClientTransport(new URL(`http://127.0.0.1:8001/`)); + +const client = new Client( + { + name: "example-client", + version: "1.0.0" + }, + { + capabilities: { + prompts: {}, + resources: {}, + tools: {} + } + } +); +await client.connect(transport); +const tools = await client.listTools(); +console.log(tools); +const resources = await client.listResources(); +console.log(resources); +const templates = await client.listResourceTemplates(); +console.log(templates); +const prompts = await client.listPrompts(); +console.log(prompts); +await client.close(); diff --git a/examples/servers/Cargo.toml b/examples/servers/Cargo.toml index 73778c95b..08ecc20e4 100644 --- a/examples/servers/Cargo.toml +++ b/examples/servers/Cargo.toml @@ -1,3 +1,5 @@ + + [package] name = "mcp-server-examples" version = "0.1.5" @@ -5,7 +7,7 @@ edition = "2024" publish = false [dependencies] -rmcp= { path = "../../crates/rmcp", features = ["server", "transport-sse-server", "transport-io", "auth"] } +rmcp= { path = "../../crates/rmcp", features = ["server", "transport-sse-server", "transport-io", "transport-streamable-http-server", "auth"] } tokio = { version = "1", features = ["macros", "rt", "rt-multi-thread", "io-std", "signal"] } serde = { version = "1.0", features = ["derive"] } serde_json = "1.0" @@ -49,6 +51,10 @@ path = "src/axum_router.rs" name = "servers_generic_server" path = "src/generic_service.rs" +[[example]] +name = "servers_axum_streamable_http" +path = "src/axum_streamable_http.rs" + [[example]] name = "servers_auth_sse" path = "src/auth_sse.rs" diff --git a/examples/servers/src/axum_streamable_http.rs b/examples/servers/src/axum_streamable_http.rs new file mode 100644 index 000000000..e65be6b11 --- /dev/null +++ b/examples/servers/src/axum_streamable_http.rs @@ -0,0 +1,29 @@ +use rmcp::transport::streamable_http_server::axum::StreamableHttpServer; +use tracing_subscriber::{ + layer::SubscriberExt, + util::SubscriberInitExt, + {self}, +}; +mod common; +use common::counter::Counter; + +const BIND_ADDRESS: &str = "127.0.0.1:8000"; + +#[tokio::main] +async fn main() -> anyhow::Result<()> { + tracing_subscriber::registry() + .with( + tracing_subscriber::EnvFilter::try_from_default_env() + .unwrap_or_else(|_| "debug".to_string().into()), + ) + .with(tracing_subscriber::fmt::layer()) + .init(); + + let ct = StreamableHttpServer::serve(BIND_ADDRESS.parse()?) + .await? + .with_service(Counter::new); + + tokio::signal::ctrl_c().await?; + ct.cancel(); + Ok(()) +} diff --git a/examples/servers/src/generic_service.rs b/examples/servers/src/generic_service.rs index 258ed2180..546621e34 100644 --- a/examples/servers/src/generic_service.rs +++ b/examples/servers/src/generic_service.rs @@ -13,6 +13,6 @@ async fn main() -> Result<(), Box> { let io = (tokio::io::stdin(), tokio::io::stdout()); - serve_server(generic_service, io).await?.waiting().await?; + serve_server(generic_service, io).await?; Ok(()) }