From c0c1a2ff4be696f328ff8a06e8d8847ed0c1bb1e Mon Sep 17 00:00:00 2001 From: Ben Brandt Date: Wed, 20 May 2026 13:34:08 +0200 Subject: [PATCH 01/18] feat(acp): add request cancellation support --- md/SUMMARY.md | 1 + md/request-cancellation.md | 100 +++ src/agent-client-protocol/CHANGELOG.md | 4 + src/agent-client-protocol/Cargo.toml | 2 + src/agent-client-protocol/src/jsonrpc.rs | 512 +++++++++++++- .../src/jsonrpc/incoming_actor.rs | 35 +- src/agent-client-protocol/src/lib.rs | 2 + src/agent-client-protocol/src/schema/mod.rs | 1 + .../src/schema/protocol_level.rs | 38 + .../src/schema/v2_impls.rs | 32 + .../tests/jsonrpc_request_cancellation.rs | 660 ++++++++++++++++++ 11 files changed, 1379 insertions(+), 8 deletions(-) create mode 100644 md/request-cancellation.md create mode 100644 src/agent-client-protocol/src/schema/protocol_level.rs create mode 100644 src/agent-client-protocol/tests/jsonrpc_request_cancellation.rs diff --git a/md/SUMMARY.md b/md/SUMMARY.md index 6750364..d5cf434 100644 --- a/md/SUMMARY.md +++ b/md/SUMMARY.md @@ -6,6 +6,7 @@ - [Design Overview](./design.md) - [Protocol Reference](./protocol.md) +- [Request Cancellation](./request-cancellation.md) - [Protocol V2](./protocol-v2.md) # Conductor (agent-client-protocol-conductor) diff --git a/md/request-cancellation.md b/md/request-cancellation.md new file mode 100644 index 0000000..2d31f07 --- /dev/null +++ b/md/request-cancellation.md @@ -0,0 +1,100 @@ +# Request Cancellation + +The SDK exposes the ACP `$/cancel_request` notification behind the +`unstable_cancel_request` feature. The notification is protocol-level: either +side may send it to ask the peer to cancel one outstanding JSON-RPC request by +ID. + +Enable the feature when depending on the crate: + +```toml +agent-client-protocol = { version = "...", features = ["unstable_cancel_request"] } +``` + +To cancel a request sent through `ConnectionTo::send_request`, keep the +returned `SentRequest` and call `cancel` on it: + +```rust +# use agent_client_protocol::{ConnectionTo, Error, UntypedRole}; +# use agent_client_protocol_test::MyRequest; +# async fn example(cx: ConnectionTo) -> Result<(), Error> { +let request = cx.send_request(MyRequest {}); +request.cancel()?; +# Ok(()) +# } +``` + +The `SentRequest` remembers the peer and any proxy wrapping used for the +original request, so this also works for requests sent through +`ConnectionTo::send_request_to`. + +If you already have the JSON-RPC request ID, send the notification directly: + +```rust +# use agent_client_protocol::{ConnectionTo, Error, UntypedRole}; +# async fn example(cx: ConnectionTo) -> Result<(), Error> { +cx.send_cancel_request("request-id".to_string())?; +# Ok(()) +# } +``` + +For incoming requests, get the request-local cancellation marker from the +`Responder`. This keeps cancellation handling next to the request work it +controls: + +```rust +# use agent_client_protocol::{ConnectionTo, Error, Responder, UntypedRole}; +# use agent_client_protocol_test::{MyRequest, MyResponse}; +# async fn example(request: MyRequest, responder: Responder, cx: ConnectionTo) -> Result<(), Error> { +# async fn run_request(_request: MyRequest) -> Result { todo!() } +let cancellation = responder.cancellation(); + +cx.spawn(async move { + let response = cancellation.run_until_cancelled(run_request(request)).await; + responder.respond_with_result(response) +})?; +Ok(()) +# } +``` + +`run_until_cancelled` is the simple path for handlers that should stop work and +reply with the standard cancellation error as soon as cancellation is requested. +If the handler needs cleanup, partial results, or custom cancellation behavior, +use `cancellation.cancelled()` or `cancellation.is_cancelled()` directly inside +the request work instead. + +Cancellation markers are only updated when the connection can process the +incoming `$/cancel_request` notification. Long-running handlers should return +quickly and move work into `ConnectionTo::spawn`, `SentRequest` callbacks, or +another task. + +When proxying with `SentRequest::forward_response_to`, the SDK observes the +upstream `Responder` cancellation marker and forwards cancellation to the +downstream request automatically. + +Register `CancelRequestNotification` or `ProtocolLevelNotification` directly +only when you need low-level access to cancellation notifications, such as +custom routing or protocol tracing: + +```rust +# use agent_client_protocol::{ConnectionTo, Error, UntypedRole}; +use agent_client_protocol::schema::CancelRequestNotification; + +# fn builder() -> agent_client_protocol::Builder { +UntypedRole.builder() + .on_receive_notification( + async |cancel: CancelRequestNotification, _cx: ConnectionTo| { + let request_id = cancel.request_id; + // Mark the matching in-flight operation cancelled. + Ok(()) + }, + agent_client_protocol::on_receive_notification!(), + ) +# } +``` + +Cancellation is cooperative. A peer may ignore `$/cancel_request`, may finish +with normal data, or may respond to the original request with +`Error::request_cancelled()` (`-32800`). The SDK ignores unhandled `$/...` +notifications so unsupported protocol-level notifications do not produce +method-not-found errors. diff --git a/src/agent-client-protocol/CHANGELOG.md b/src/agent-client-protocol/CHANGELOG.md index f1aadd0..d6bbf57 100644 --- a/src/agent-client-protocol/CHANGELOG.md +++ b/src/agent-client-protocol/CHANGELOG.md @@ -2,6 +2,10 @@ ## [Unreleased] +### Added + +- *(unstable)* Add SDK support for protocol-level request cancellation, including `SentRequest::cancel`, request-local cancellation helpers on `Responder`, and forwarded cancellation propagation. + ## [0.12.1](https://github.com/agentclientprotocol/rust-sdk/compare/v0.12.0...v0.12.1) - 2026-05-17 ### Other diff --git a/src/agent-client-protocol/Cargo.toml b/src/agent-client-protocol/Cargo.toml index 7946942..4820782 100644 --- a/src/agent-client-protocol/Cargo.toml +++ b/src/agent-client-protocol/Cargo.toml @@ -18,6 +18,7 @@ default = [] unstable = [ "unstable_auth_methods", "unstable_boolean_config", + "unstable_cancel_request", "unstable_logout", "unstable_mcp_over_acp", "unstable_message_id", @@ -29,6 +30,7 @@ unstable = [ ] unstable_auth_methods = ["agent-client-protocol-schema/unstable_auth_methods"] unstable_boolean_config = ["agent-client-protocol-schema/unstable_boolean_config"] +unstable_cancel_request = ["agent-client-protocol-schema/unstable_cancel_request"] unstable_logout = ["agent-client-protocol-schema/unstable_logout"] unstable_mcp_over_acp = ["agent-client-protocol-schema/unstable_mcp_over_acp"] unstable_message_id = ["agent-client-protocol-schema/unstable_message_id"] diff --git a/src/agent-client-protocol/src/jsonrpc.rs b/src/agent-client-protocol/src/jsonrpc.rs index 73c553d..6b1b896 100644 --- a/src/agent-client-protocol/src/jsonrpc.rs +++ b/src/agent-client-protocol/src/jsonrpc.rs @@ -7,11 +7,20 @@ pub use jsonrpcmsg; // Types re-exported from crate root use serde::{Deserialize, Serialize}; use std::any::TypeId; +#[cfg(feature = "unstable_cancel_request")] +use std::collections::HashMap; use std::fmt::Debug; use std::panic::Location; use std::pin::pin; +#[cfg(feature = "unstable_cancel_request")] +use std::sync::{ + Arc, Mutex, + atomic::{AtomicBool, Ordering}, +}; use uuid::Uuid; +#[cfg(feature = "unstable_cancel_request")] +use futures::FutureExt; use futures::channel::{mpsc, oneshot}; use futures::future::{self, BoxFuture, Either}; use futures::{AsyncRead, AsyncWrite, StreamExt}; @@ -1364,6 +1373,250 @@ impl std::fmt::Debug for ReplyMessage { } } +/// A request-local marker that is set when the peer asks to cancel the request. +/// +/// Request handlers can get this handle from [`Responder::cancellation`] and +/// use it from spawned work to stop long-running request processing +/// cooperatively. +#[cfg(feature = "unstable_cancel_request")] +#[derive(Clone)] +pub struct RequestCancellation { + state: Arc, +} + +#[cfg(feature = "unstable_cancel_request")] +struct RequestCancellationState { + cancelled: AtomicBool, + signal_tx: Mutex>>, + signal_rx: future::Shared>, +} + +#[cfg(feature = "unstable_cancel_request")] +impl RequestCancellation { + fn new() -> Self { + let (signal_tx, signal_rx) = oneshot::channel(); + let signal_rx = signal_rx.map(|_| ()).boxed().shared(); + Self { + state: Arc::new(RequestCancellationState { + cancelled: AtomicBool::new(false), + signal_tx: Mutex::new(Some(signal_tx)), + signal_rx, + }), + } + } + + /// Wait until the peer sends `$/cancel_request` for this request. + /// + /// If cancellation was already requested, this returns immediately. + pub async fn cancelled(&self) { + self.state.signal_rx.clone().await; + } + + /// Run request work until it completes or the peer asks to cancel it. + /// + /// If cancellation is requested first, this returns + /// [`Error::request_cancelled`]. This is a convenience for request handlers + /// that want to respond with the normal result or the standard + /// cancellation error. + /// + /// [`Error::request_cancelled`]: crate::Error::request_cancelled + pub async fn run_until_cancelled( + &self, + future: impl std::future::Future>, + ) -> Result { + if self.is_cancelled() { + return Err(crate::Error::request_cancelled()); + } + + match future::select(Box::pin(future), Box::pin(self.cancelled())).await { + Either::Left((result, _)) => result, + Either::Right(((), _)) => Err(crate::Error::request_cancelled()), + } + } + + /// Returns whether the peer has already requested cancellation. + #[must_use] + pub fn is_cancelled(&self) -> bool { + self.state.cancelled.load(Ordering::Acquire) + } + + fn cancel(&self) { + if self.state.cancelled.swap(true, Ordering::AcqRel) { + return; + } + + if let Some(signal_tx) = self + .state + .signal_tx + .lock() + .expect("request cancellation signal mutex poisoned") + .take() + { + let _ = signal_tx.send(()); + } + } +} + +#[cfg(feature = "unstable_cancel_request")] +impl Debug for RequestCancellation { + fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + formatter + .debug_struct("RequestCancellation") + .field("is_cancelled", &self.is_cancelled()) + .finish_non_exhaustive() + } +} + +#[cfg(feature = "unstable_cancel_request")] +#[derive(Clone, Debug, Default)] +struct RequestCancellationRegistry { + inner: Arc>>, +} + +#[cfg(not(feature = "unstable_cancel_request"))] +#[derive(Clone, Debug, Default)] +struct RequestCancellationRegistry; + +#[cfg(feature = "unstable_cancel_request")] +#[derive(Debug)] +struct ResponderCancellation { + id: serde_json::Value, + registry: RequestCancellationRegistry, + cancellation: RequestCancellation, +} + +#[cfg(not(feature = "unstable_cancel_request"))] +#[derive(Debug)] +struct ResponderCancellation; + +#[cfg(feature = "unstable_cancel_request")] +impl RequestCancellationRegistry { + fn register(&self, id: serde_json::Value) -> ResponderCancellation { + let cancellation = RequestCancellation::new(); + self.inner + .lock() + .expect("request cancellation registry mutex poisoned") + .insert(id.clone(), cancellation.clone()); + ResponderCancellation { + id, + registry: self.clone(), + cancellation, + } + } + + fn cancel_if_requested(&self, dispatch: &Dispatch) -> Result { + let Some(request_id) = cancellation_request_id(dispatch)? else { + return Ok(false); + }; + Ok(self.cancel(&request_id)) + } + + fn cancel(&self, request_id: &serde_json::Value) -> bool { + let cancellation = self + .inner + .lock() + .expect("request cancellation registry mutex poisoned") + .get(request_id) + .cloned(); + if let Some(cancellation) = cancellation { + cancellation.cancel(); + true + } else { + false + } + } + + fn remove(&self, request_id: &serde_json::Value) { + self.inner + .lock() + .expect("request cancellation registry mutex poisoned") + .remove(request_id); + } +} + +#[cfg(not(feature = "unstable_cancel_request"))] +impl RequestCancellationRegistry { + fn register(&self, _id: serde_json::Value) -> ResponderCancellation { + ResponderCancellation + } + + fn cancel_if_requested(&self, _dispatch: &Dispatch) -> Result { + Ok(false) + } +} + +#[cfg(feature = "unstable_cancel_request")] +impl ResponderCancellation { + fn cancellation(&self) -> RequestCancellation { + self.cancellation.clone() + } +} + +#[cfg(feature = "unstable_cancel_request")] +impl Drop for ResponderCancellation { + fn drop(&mut self) { + self.registry.remove(&self.id); + } +} + +#[cfg(feature = "unstable_cancel_request")] +fn cancellation_request_id(dispatch: &Dispatch) -> Result, crate::Error> { + let Dispatch::Notification(message) = dispatch else { + return Ok(None); + }; + cancellation_request_id_from_message(message) +} + +#[cfg(feature = "unstable_cancel_request")] +fn cancellation_request_id_from_message( + message: &UntypedMessage, +) -> Result, crate::Error> { + if crate::schema::CancelRequestNotification::matches_method(&message.method) { + let notification = crate::schema::CancelRequestNotification::parse_message( + &message.method, + &message.params, + )?; + return serde_json::to_value(notification.request_id) + .map(Some) + .map_err(crate::Error::into_internal_error); + } + + if crate::schema::SuccessorMessage::::matches_method(&message.method) { + let successor = crate::schema::SuccessorMessage::::parse_message( + &message.method, + &message.params, + )?; + return cancellation_request_id_from_message(&successor.message); + } + + Ok(None) +} + +fn is_protocol_level_notification(dispatch: &Dispatch) -> bool { + let Dispatch::Notification(message) = dispatch else { + return false; + }; + is_protocol_level_notification_message(message) +} + +fn is_protocol_level_notification_message(message: &UntypedMessage) -> bool { + if message.method.starts_with("$/") { + return true; + } + + if crate::schema::SuccessorMessage::::matches_method(&message.method) { + let Ok(successor) = crate::schema::SuccessorMessage::::parse_message( + &message.method, + &message.params, + ) else { + return false; + }; + return is_protocol_level_notification_message(&successor.message); + } + + false +} + /// Messages send to be serialized over the transport. #[derive(Debug)] enum OutgoingMessage { @@ -1721,6 +1974,9 @@ impl ConnectionTo { let (response_tx, response_rx) = oneshot::channel(); let role_id = peer.role_id(); let remote_style = self.counterpart.remote_style(peer); + #[cfg(feature = "unstable_cancel_request")] + let cancellation = + SentRequestCancellation::new(self.message_tx.clone(), &remote_style, &id); match remote_style.transform_outgoing_message(request) { Ok(untyped) => { // Transform the message for the target role @@ -1768,8 +2024,15 @@ impl ConnectionTo { } } - SentRequest::new(id, method.clone(), self.task_tx.clone(), response_rx) - .map(move |json| ::from_value(&method, json)) + SentRequest::new( + id, + method.clone(), + self.task_tx.clone(), + response_rx, + #[cfg(feature = "unstable_cancel_request")] + cancellation, + ) + .map(move |json| ::from_value(&method, json)) } /// Send an outgoing notification to the default counterpart peer (no reply expected). @@ -1833,6 +2096,50 @@ impl ConnectionTo { ) } + /// Send a `$/cancel_request` notification for an outgoing request. + /// + /// This is a convenience wrapper around [`SentRequest::cancel`]. + /// + /// Cancellation is cooperative: the peer may ignore the notification, may + /// reply to the original request with [`Error::request_cancelled`], or may + /// return a normal response with partial data. + /// + /// [`Error::request_cancelled`]: crate::Error::request_cancelled + #[cfg(feature = "unstable_cancel_request")] + pub fn cancel_request(&self, request: &SentRequest) -> Result<(), crate::Error> { + request.cancel() + } + + /// Send a `$/cancel_request` notification for an arbitrary request ID to + /// the default counterpart peer. + #[cfg(feature = "unstable_cancel_request")] + pub fn send_cancel_request( + &self, + request_id: impl Into, + ) -> Result<(), crate::Error> + where + Counterpart: HasPeer, + { + self.send_cancel_request_to(self.counterpart.clone(), request_id) + } + + /// Send a `$/cancel_request` notification for an arbitrary request ID to a + /// specific peer. + #[cfg(feature = "unstable_cancel_request")] + pub fn send_cancel_request_to( + &self, + peer: Peer, + request_id: impl Into, + ) -> Result<(), crate::Error> + where + Counterpart: HasPeer, + { + self.send_notification_to( + peer, + crate::schema::CancelRequestNotification::new(request_id), + ) + } + /// Send an error notification (no reply expected). pub fn send_error_notification(&self, error: crate::Error) -> Result<(), crate::Error> { send_raw_message(&self.message_tx, OutgoingMessage::Error { error }) @@ -1943,6 +2250,9 @@ pub struct Responder { /// The `id` of the message we are replying to. id: jsonrpcmsg::Id, + /// Request-local cancellation state. + cancellation: ResponderCancellation, + /// Function to send the response to its destination. /// /// For incoming requests: serializes to JSON and sends over the wire. @@ -1964,12 +2274,19 @@ impl Responder { /// Create a new request context for an incoming request. /// /// The response will be serialized to JSON and sent over the wire. - fn new(message_tx: OutgoingMessageTx, method: String, id: jsonrpcmsg::Id) -> Self { + fn new( + message_tx: OutgoingMessageTx, + method: String, + id: jsonrpcmsg::Id, + cancellation_registry: &RequestCancellationRegistry, + ) -> Self { let id_clone = id.clone(); let method_clone = method.clone(); + let cancellation = cancellation_registry.register(crate::util::id_to_json(&id)); Self { method, id, + cancellation, send_fn: Box::new(move |response: Result| { send_raw_message( &message_tx, @@ -2007,6 +2324,19 @@ impl Responder { crate::util::id_to_json(&self.id) } + /// Returns the cancellation marker for this request. + /// + /// The marker is set when the peer sends `$/cancel_request` for this + /// request's JSON-RPC ID. Cancellation is cooperative: handlers should use + /// the marker to stop long-running work and then decide whether to respond + /// with [`Error::request_cancelled`] or partial data. + /// + /// [`Error::request_cancelled`]: crate::Error::request_cancelled + #[cfg(feature = "unstable_cancel_request")] + pub fn cancellation(&self) -> RequestCancellation { + self.cancellation.cancellation() + } + /// Convert to a `Responder` that expects a JSON value /// and which checks (dynamically) that the JSON value it receives /// can be converted to `T`. @@ -2019,6 +2349,7 @@ impl Responder { Responder { method, id: self.id, + cancellation: self.cancellation, send_fn: self.send_fn, } } @@ -2035,6 +2366,7 @@ impl Responder { Responder { method: self.method, id: self.id, + cancellation: self.cancellation, send_fn: Box::new(move |input: Result| { let t_value = wrap_fn(&method, input); (self.send_fn)(t_value) @@ -2813,16 +3145,106 @@ pub struct SentRequest { task_tx: TaskTx, response_rx: oneshot::Receiver, to_result: Box Result + Send>, + #[cfg(feature = "unstable_cancel_request")] + cancellation: SentRequestCancellation, +} + +#[cfg(feature = "unstable_cancel_request")] +fn jsonrpc_id_to_request_id(id: &jsonrpcmsg::Id) -> Result { + match id { + jsonrpcmsg::Id::String(value) => Ok(crate::schema::RequestId::Str(value.clone())), + jsonrpcmsg::Id::Number(value) => Ok(crate::schema::RequestId::Number( + i64::try_from(*value).map_err(|_| { + crate::util::internal_error(format!( + "request ID `{value}` cannot be represented as an ACP request ID" + )) + })?, + )), + jsonrpcmsg::Id::Null => Ok(crate::schema::RequestId::Null), + } +} + +#[cfg(feature = "unstable_cancel_request")] +#[derive(Clone)] +enum SentRequestCancellation { + Send { + message_tx: OutgoingMessageTx, + notification: UntypedMessage, + }, + Failed { + error: String, + }, +} + +#[cfg(feature = "unstable_cancel_request")] +impl SentRequestCancellation { + fn new( + message_tx: OutgoingMessageTx, + remote_style: &crate::role::RemoteStyle, + request_id: &jsonrpcmsg::Id, + ) -> Self { + let notification = jsonrpc_id_to_request_id(request_id) + .and_then(|request_id| { + remote_style.transform_outgoing_message( + crate::schema::CancelRequestNotification::new(request_id), + ) + }) + .map_err(|error| error.to_string()); + + match notification { + Ok(notification) => Self::Send { + message_tx, + notification, + }, + Err(error) => Self::Failed { error }, + } + } + + fn send(&self) -> Result<(), crate::Error> { + match self { + Self::Send { + message_tx, + notification, + } => send_raw_message( + message_tx, + OutgoingMessage::Notification { + untyped: notification.clone(), + }, + ), + Self::Failed { error } => Err(crate::util::internal_error(format!( + "failed to create cancel request notification: {error}" + ))), + } + } +} + +#[cfg(feature = "unstable_cancel_request")] +impl Debug for SentRequestCancellation { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::Send { notification, .. } => f + .debug_struct("SentRequestCancellation") + .field("notification", notification) + .finish(), + Self::Failed { error } => f + .debug_struct("SentRequestCancellation") + .field("error", error) + .finish(), + } + } } impl Debug for SentRequest { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.debug_struct("SentRequest") + let mut debug = f.debug_struct("SentRequest"); + debug .field("id", &self.id) .field("method", &self.method) .field("task_tx", &self.task_tx) - .field("response_rx", &self.response_rx) - .finish_non_exhaustive() + .field("response_rx", &self.response_rx); + #[cfg(feature = "unstable_cancel_request")] + debug.field("cancellation", &self.cancellation); + debug.finish_non_exhaustive() } } @@ -2832,6 +3254,7 @@ impl SentRequest { method: String, task_tx: mpsc::UnboundedSender, response_rx: oneshot::Receiver, + #[cfg(feature = "unstable_cancel_request")] cancellation: SentRequestCancellation, ) -> Self { Self { id, @@ -2839,10 +3262,24 @@ impl SentRequest { response_rx, task_tx, to_result: Box::new(Ok), + #[cfg(feature = "unstable_cancel_request")] + cancellation, } } } +impl SentRequest { + /// Send a `$/cancel_request` notification for this outgoing request. + /// + /// This uses the same peer and message wrapping that were used to send the + /// original request, so it is the preferred way to cancel a [`SentRequest`] + /// when the request handle is still available. + #[cfg(feature = "unstable_cancel_request")] + pub fn cancel(&self) -> Result<(), crate::Error> { + self.cancellation.send() + } +} + impl SentRequest { /// The id of the outgoing request. #[must_use] @@ -2867,6 +3304,8 @@ impl SentRequest { response_rx: self.response_rx, task_tx: self.task_tx, to_result: Box::new(move |value| map_fn((self.to_result)(value)?)), + #[cfg(feature = "unstable_cancel_request")] + cancellation: self.cancellation, } } @@ -2925,7 +3364,66 @@ impl SentRequest { where T: Send, { - self.on_receiving_result(async move |result| responder.respond_with_result(result)) + #[cfg(feature = "unstable_cancel_request")] + { + self.forward_response_to_observing_cancellation(responder) + } + #[cfg(not(feature = "unstable_cancel_request"))] + { + self.on_receiving_result(async move |result| responder.respond_with_result(result)) + } + } + + #[cfg(feature = "unstable_cancel_request")] + #[track_caller] + fn forward_response_to_observing_cancellation( + self, + responder: Responder, + ) -> Result<(), crate::Error> + where + T: Send, + { + let task_tx = self.task_tx.clone(); + let method = self.method; + let response_rx = self.response_rx; + let to_result = self.to_result; + let downstream_cancellation = self.cancellation; + let upstream_cancellation = responder.cancellation(); + let location = Location::caller(); + + Task::new(location, async move { + let response = if upstream_cancellation.is_cancelled() { + downstream_cancellation.send()?; + response_rx.await + } else { + match future::select(Box::pin(upstream_cancellation.cancelled()), response_rx).await + { + Either::Left(((), response_rx)) => { + downstream_cancellation.send()?; + response_rx.await + } + Either::Right((response, _)) => response, + } + }; + + let ResponsePayload { result, ack_tx } = response.map_err(|err| { + crate::util::internal_error(format!("response to `{method}` never received: {err}")) + })?; + + let typed_result = match result { + Ok(json_value) => to_result(json_value), + Err(err) => Err(err), + }; + + let outcome = responder.respond_with_result(typed_result); + + if let Some(tx) = ack_tx { + let _ = tx.send(()); + } + + outcome + }) + .spawn(&task_tx) } /// Block the current task until the response is received. diff --git a/src/agent-client-protocol/src/jsonrpc/incoming_actor.rs b/src/agent-client-protocol/src/jsonrpc/incoming_actor.rs index 302554f..6c5caea 100644 --- a/src/agent-client-protocol/src/jsonrpc/incoming_actor.rs +++ b/src/agent-client-protocol/src/jsonrpc/incoming_actor.rs @@ -62,6 +62,8 @@ pub(super) async fn incoming_protocol_actor( FxHashMap::default(); let mut pending_messages: Vec = vec![]; + let request_cancellations = super::RequestCancellationRegistry::default(); + // Map from request ID to (method, sender) for response dispatch. // Keys are JSON values because jsonrpcmsg::Id doesn't implement Eq. // The method is stored to allow routing responses through typed handlers. @@ -135,7 +137,12 @@ pub(super) async fn incoming_protocol_actor( tracing::trace!(method = %request.method, id = ?request.id, "Handling request"); let request_method = request.method.clone(); let request_id = request.id.clone(); - match dispatch_from_request(connection, request, &protocol_compat) { + match dispatch_from_request( + connection, + request, + &protocol_compat, + &request_cancellations, + ) { Ok(dispatch) => { dispatch_dispatch( counterpart.clone(), @@ -144,6 +151,7 @@ pub(super) async fn incoming_protocol_actor( &mut dynamic_handlers, &mut handler, &mut pending_messages, + &request_cancellations, ) .await?; } @@ -183,6 +191,7 @@ pub(super) async fn incoming_protocol_actor( &mut dynamic_handlers, &mut handler, &mut pending_messages, + &request_cancellations, ) .await?; } else { @@ -218,6 +227,7 @@ fn dispatch_from_request( connection: &ConnectionTo, request: jsonrpcmsg::Request, protocol_compat: &ProtocolCompat, + request_cancellations: &super::RequestCancellationRegistry, ) -> Result { let message = UntypedMessage::new(&request.method, &request.params).expect("well-formed JSON"); let message = protocol_compat.incoming_message(message)?; @@ -229,6 +239,7 @@ fn dispatch_from_request( connection.message_tx.clone(), request.method.clone(), id.clone(), + request_cancellations, ), )), None => Ok(Dispatch::Notification(message)), @@ -268,6 +279,7 @@ async fn dispatch_dispatch( dynamic_handlers: &mut FxHashMap>>, handler: &mut impl HandleDispatchFrom, pending_messages: &mut Vec, + request_cancellations: &super::RequestCancellationRegistry, ) -> Result<(), crate::Error> { tracing::trace!(?dispatch, "dispatch_dispatch"); @@ -276,6 +288,22 @@ async fn dispatch_dispatch( let id = dispatch.id(); let method = dispatch.method().to_string(); + match request_cancellations.cancel_if_requested(&dispatch) { + Ok(true) => { + tracing::debug!(?method, "Marked request as cancelled"); + } + Ok(false) => {} + Err(err) => { + tracing::warn!( + ?method, + ?id, + ?err, + "Request cancellation notification errored" + ); + return report_handler_error(connection, id, method, err); + } + } + // First, apply the handlers given by the user. tracing::trace!(handler = ?handler.describe_chain(), "Attempting handler chain"); match handler @@ -351,6 +379,11 @@ async fn dispatch_dispatch( } } + if super::is_protocol_level_notification(&dispatch) { + tracing::debug!(?method, "Ignoring unhandled protocol-level notification"); + return Ok(()); + } + // If the message was never handled, check whether the retry flag was set. // If so, enqueue it for later processing. Else, reject it. if retry_any { diff --git a/src/agent-client-protocol/src/lib.rs b/src/agent-client-protocol/src/lib.rs index 9a28f59..103666f 100644 --- a/src/agent-client-protocol/src/lib.rs +++ b/src/agent-client-protocol/src/lib.rs @@ -108,6 +108,8 @@ pub mod jsonrpcmsg { pub use jsonrpcmsg::{Error, Id, Message, Params, Request, Response}; } +#[cfg(feature = "unstable_cancel_request")] +pub use jsonrpc::RequestCancellation; pub use jsonrpc::{ Builder, ByteStreams, Channel, ConnectionTo, Dispatch, HandleDispatchFrom, Handled, IntoHandled, JsonRpcMessage, JsonRpcNotification, JsonRpcRequest, JsonRpcResponse, Lines, diff --git a/src/agent-client-protocol/src/schema/mod.rs b/src/agent-client-protocol/src/schema/mod.rs index 6279701..8724839 100644 --- a/src/agent-client-protocol/src/schema/mod.rs +++ b/src/agent-client-protocol/src/schema/mod.rs @@ -257,6 +257,7 @@ macro_rules! impl_jsonrpc_response_enum { mod agent_to_client; mod client_to_agent; mod enum_impls; +mod protocol_level; mod proxy_protocol; #[cfg(feature = "unstable_protocol_v2")] mod v2_impls; diff --git a/src/agent-client-protocol/src/schema/protocol_level.rs b/src/agent-client-protocol/src/schema/protocol_level.rs new file mode 100644 index 0000000..81ea0da --- /dev/null +++ b/src/agent-client-protocol/src/schema/protocol_level.rs @@ -0,0 +1,38 @@ +#[cfg(feature = "unstable_cancel_request")] +use crate::{ + JsonRpcMessage, JsonRpcNotification, UntypedMessage, + schema::{CancelRequestNotification, ProtocolLevelNotification}, +}; + +#[cfg(feature = "unstable_cancel_request")] +impl_jsonrpc_notification!(CancelRequestNotification, "$/cancel_request"); + +#[cfg(feature = "unstable_cancel_request")] +impl JsonRpcMessage for ProtocolLevelNotification { + fn matches_method(method: &str) -> bool { + method == "$/cancel_request" + } + + fn method(&self) -> &str { + match self { + Self::CancelRequestNotification(_) => "$/cancel_request", + _ => "_unknown", + } + } + + fn to_untyped_message(&self) -> Result { + UntypedMessage::new(self.method(), self) + } + + fn parse_message(method: &str, params: &impl serde::Serialize) -> Result { + match method { + "$/cancel_request" => { + crate::util::json_cast_params(params).map(Self::CancelRequestNotification) + } + _ => Err(crate::Error::method_not_found()), + } + } +} + +#[cfg(feature = "unstable_cancel_request")] +impl JsonRpcNotification for ProtocolLevelNotification {} diff --git a/src/agent-client-protocol/src/schema/v2_impls.rs b/src/agent-client-protocol/src/schema/v2_impls.rs index 5b87c8a..e47117a 100644 --- a/src/agent-client-protocol/src/schema/v2_impls.rs +++ b/src/agent-client-protocol/src/schema/v2_impls.rs @@ -270,6 +270,8 @@ impl_v2_jsonrpc_request!( #[cfg(feature = "unstable_mcp_over_acp")] impl_v2_jsonrpc_request!(v2::MessageMcpRequest, v2::MessageMcpResponse, "mcp/message"); +#[cfg(feature = "unstable_cancel_request")] +impl_v2_jsonrpc_notification!(v2::CancelRequestNotification, "$/cancel_request"); impl_v2_jsonrpc_notification!(v2::CancelNotification, "session/cancel"); #[cfg(feature = "unstable_mcp_over_acp")] impl_v2_jsonrpc_notification!(v2::MessageMcpNotification, "mcp/message"); @@ -325,6 +327,36 @@ impl_v2_jsonrpc_request!( impl_v2_jsonrpc_notification!(v2::SessionNotification, "session/update"); +#[cfg(feature = "unstable_cancel_request")] +impl JsonRpcMessage for v2::ProtocolLevelNotification { + fn matches_method(method: &str) -> bool { + method == "$/cancel_request" + } + + fn method(&self) -> &str { + match self { + Self::CancelRequestNotification(_) => "$/cancel_request", + _ => "_unknown", + } + } + + fn to_untyped_message(&self) -> Result { + UntypedMessage::new(self.method(), self) + } + + fn parse_message(method: &str, params: &impl serde::Serialize) -> Result { + match method { + "$/cancel_request" => { + crate::util::json_cast_params(params).map(Self::CancelRequestNotification) + } + _ => Err(crate::Error::method_not_found()), + } + } +} + +#[cfg(feature = "unstable_cancel_request")] +impl JsonRpcNotification for v2::ProtocolLevelNotification {} + impl_v2_jsonrpc_request_enum!(v2::ClientRequest { InitializeRequest => "initialize", AuthenticateRequest => "authenticate", diff --git a/src/agent-client-protocol/tests/jsonrpc_request_cancellation.rs b/src/agent-client-protocol/tests/jsonrpc_request_cancellation.rs new file mode 100644 index 0000000..1d59a69 --- /dev/null +++ b/src/agent-client-protocol/tests/jsonrpc_request_cancellation.rs @@ -0,0 +1,660 @@ +#![cfg(feature = "unstable_cancel_request")] + +use std::sync::{Arc, Mutex}; + +use agent_client_protocol::{ + Channel, ConnectionTo, Dispatch, Handled, JsonRpcMessage, JsonRpcRequest, JsonRpcResponse, + Responder, Role, RoleId, SentRequest, + role::UntypedRole, + schema::{CancelRequestNotification, ProtocolLevelNotification, RequestId}, +}; +use expect_test::expect; +use futures::{AsyncRead, AsyncWrite}; +use serde::{Deserialize, Serialize}; +use tokio_util::compat::{TokioAsyncReadCompatExt, TokioAsyncWriteCompatExt}; + +fn setup_test_streams() -> ( + impl AsyncRead, + impl AsyncWrite, + impl AsyncRead, + impl AsyncWrite, +) { + let (client_writer, server_reader) = tokio::io::duplex(4096); + let (server_writer, client_reader) = tokio::io::duplex(4096); + + let server_reader = server_reader.compat(); + let server_writer = server_writer.compat_write(); + let client_reader = client_reader.compat(); + let client_writer = client_writer.compat_write(); + + (server_reader, server_writer, client_reader, client_writer) +} + +async fn read_jsonrpc_response_line( + reader: &mut tokio::io::BufReader, +) -> serde_json::Value { + use tokio::io::AsyncBufReadExt as _; + + let mut line = String::new(); + match tokio::time::timeout( + tokio::time::Duration::from_secs(1), + reader.read_line(&mut line), + ) + .await + { + Ok(Ok(0)) | Err(_) => panic!("timed out waiting for JSON-RPC response"), + Ok(Ok(_)) => serde_json::from_str(line.trim()).expect("response should be valid JSON"), + Ok(Err(error)) => panic!("failed to read JSON-RPC response line: {error}"), + } +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +struct SimpleRequest { + message: String, +} + +impl JsonRpcMessage for SimpleRequest { + fn matches_method(method: &str) -> bool { + method == "simple_method" + } + + fn method(&self) -> &str { + "simple_method" + } + + fn to_untyped_message( + &self, + ) -> Result { + agent_client_protocol::UntypedMessage::new(self.method(), self) + } + + fn parse_message( + method: &str, + params: &impl Serialize, + ) -> Result { + if !Self::matches_method(method) { + return Err(agent_client_protocol::Error::method_not_found()); + } + agent_client_protocol::util::json_cast_params(params) + } +} + +impl JsonRpcRequest for SimpleRequest { + type Response = SimpleResponse; +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +struct SimpleResponse { + result: String, +} + +impl JsonRpcResponse for SimpleResponse { + fn into_json(self, _method: &str) -> Result { + serde_json::to_value(self).map_err(agent_client_protocol::Error::into_internal_error) + } + + fn from_value( + _method: &str, + value: serde_json::Value, + ) -> Result { + agent_client_protocol::util::json_cast(&value) + } +} + +#[derive(Debug, Default, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] +struct WrappedHost; + +#[derive(Debug, Default, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] +struct WrappedCounterpart; + +#[derive(Debug, Default, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] +struct WrappedSuccessor; + +#[derive(Debug, Default, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] +struct WrappedSuccessorCounterpart; + +impl Role for WrappedHost { + type Counterpart = WrappedCounterpart; + + fn role_id(&self) -> RoleId { + RoleId::from_singleton(self) + } + + async fn default_handle_dispatch_from( + &self, + message: Dispatch, + _connection: ConnectionTo, + ) -> Result, agent_client_protocol::Error> { + Ok(Handled::No { + message, + retry: false, + }) + } + + fn counterpart(&self) -> Self::Counterpart { + WrappedCounterpart + } +} + +impl Role for WrappedCounterpart { + type Counterpart = WrappedHost; + + fn role_id(&self) -> RoleId { + RoleId::from_singleton(self) + } + + async fn default_handle_dispatch_from( + &self, + message: Dispatch, + _connection: ConnectionTo, + ) -> Result, agent_client_protocol::Error> { + Ok(Handled::No { + message, + retry: false, + }) + } + + fn counterpart(&self) -> Self::Counterpart { + WrappedHost + } +} + +impl Role for WrappedSuccessor { + type Counterpart = WrappedSuccessorCounterpart; + + fn role_id(&self) -> RoleId { + RoleId::from_singleton(self) + } + + async fn default_handle_dispatch_from( + &self, + message: Dispatch, + _connection: ConnectionTo, + ) -> Result, agent_client_protocol::Error> { + Ok(Handled::No { + message, + retry: false, + }) + } + + fn counterpart(&self) -> Self::Counterpart { + WrappedSuccessorCounterpart + } +} + +impl Role for WrappedSuccessorCounterpart { + type Counterpart = WrappedSuccessor; + + fn role_id(&self) -> RoleId { + RoleId::from_singleton(self) + } + + async fn default_handle_dispatch_from( + &self, + message: Dispatch, + _connection: ConnectionTo, + ) -> Result, agent_client_protocol::Error> { + Ok(Handled::No { + message, + retry: false, + }) + } + + fn counterpart(&self) -> Self::Counterpart { + WrappedSuccessor + } +} + +impl agent_client_protocol::role::HasPeer for WrappedCounterpart { + fn remote_style(&self, _peer: WrappedCounterpart) -> agent_client_protocol::role::RemoteStyle { + agent_client_protocol::role::RemoteStyle::Counterpart + } +} + +impl agent_client_protocol::role::HasPeer for WrappedCounterpart { + fn remote_style(&self, _peer: WrappedSuccessor) -> agent_client_protocol::role::RemoteStyle { + agent_client_protocol::role::RemoteStyle::Successor + } +} + +#[tokio::test(flavor = "current_thread")] +async fn unhandled_protocol_level_notifications_are_ignored() { + use tokio::io::{AsyncWriteExt, BufReader}; + use tokio::task::LocalSet; + + let local = LocalSet::new(); + + local + .run_until(async { + let (mut client_writer, server_reader) = tokio::io::duplex(4096); + let (server_writer, client_reader) = tokio::io::duplex(4096); + + let server_transport = agent_client_protocol::ByteStreams::new( + server_writer.compat_write(), + server_reader.compat(), + ); + let server = UntypedRole.builder().on_receive_request( + async |request: SimpleRequest, + responder: Responder, + _connection: ConnectionTo| { + responder.respond(SimpleResponse { + result: format!("echo: {}", request.message), + }) + }, + agent_client_protocol::on_receive_request!(), + ); + + tokio::task::spawn_local(async move { + if let Err(error) = server.connect_to(server_transport).await { + panic!("server should stay alive: {error:?}"); + } + }); + + let mut client_reader = BufReader::new(client_reader); + + client_writer + .write_all( + br#"{"jsonrpc":"2.0","method":"$/cancel_request","params":{"requestId":"req-1"}} +"#, + ) + .await + .unwrap(); + client_writer.flush().await.unwrap(); + + client_writer + .write_all( + br#"{"jsonrpc":"2.0","id":2,"method":"simple_method","params":{"message":"after cancel"}} +"#, + ) + .await + .unwrap(); + client_writer.flush().await.unwrap(); + + let response = read_jsonrpc_response_line(&mut client_reader).await; + expect![[r#" + { + "id": 2, + "jsonrpc": "2.0", + "result": { + "result": "echo: after cancel" + } + }"#]] + .assert_eq(&serde_json::to_string_pretty(&response).unwrap()); + }) + .await; +} + +#[tokio::test(flavor = "current_thread")] +async fn unhandled_wrapped_protocol_level_notifications_are_ignored() { + use tokio::io::{AsyncWriteExt, BufReader}; + use tokio::task::LocalSet; + + let local = LocalSet::new(); + + local + .run_until(async { + let (mut client_writer, server_reader) = tokio::io::duplex(4096); + let (server_writer, client_reader) = tokio::io::duplex(4096); + + let server_transport = agent_client_protocol::ByteStreams::new( + server_writer.compat_write(), + server_reader.compat(), + ); + let server = WrappedHost + .builder() + .on_receive_notification_from( + WrappedSuccessor, + async |cancel: CancelRequestNotification, + cx: ConnectionTo| { + Ok::<_, agent_client_protocol::Error>(Handled::No { + message: (cancel, cx), + retry: false, + }) + }, + agent_client_protocol::on_receive_notification!(), + ) + .on_receive_request( + async |request: SimpleRequest, + responder: Responder, + _connection: ConnectionTo| { + responder.respond(SimpleResponse { + result: format!("echo: {}", request.message), + }) + }, + agent_client_protocol::on_receive_request!(), + ); + + tokio::task::spawn_local(async move { + if let Err(error) = server.connect_to(server_transport).await { + panic!("server should stay alive: {error:?}"); + } + }); + + let mut client_reader = BufReader::new(client_reader); + + client_writer + .write_all( + br#"{"jsonrpc":"2.0","method":"_proxy/successor","params":{"method":"$/cancel_request","params":{"requestId":"req-1"}}} +"#, + ) + .await + .unwrap(); + client_writer.flush().await.unwrap(); + + client_writer + .write_all( + br#"{"jsonrpc":"2.0","id":2,"method":"simple_method","params":{"message":"after wrapped cancel"}} +"#, + ) + .await + .unwrap(); + client_writer.flush().await.unwrap(); + + let response = read_jsonrpc_response_line(&mut client_reader).await; + expect![[r#" + { + "id": 2, + "jsonrpc": "2.0", + "result": { + "result": "echo: after wrapped cancel" + } + }"#]] + .assert_eq(&serde_json::to_string_pretty(&response).unwrap()); + }) + .await; +} + +#[tokio::test(flavor = "current_thread")] +async fn cancel_request_notification_can_be_sent_and_handled() { + use tokio::task::LocalSet; + + let local = LocalSet::new(); + + local + .run_until(async { + let received = Arc::new(Mutex::new(Vec::new())); + let received_for_handler = received.clone(); + + let (server_reader, server_writer, client_reader, client_writer) = setup_test_streams(); + let server_transport = + agent_client_protocol::ByteStreams::new(server_writer, server_reader); + let server = UntypedRole.builder().on_receive_notification( + async move |notification: CancelRequestNotification, + _connection: ConnectionTo| { + received_for_handler + .lock() + .unwrap() + .push(notification.request_id); + Ok(()) + }, + agent_client_protocol::on_receive_notification!(), + ); + + tokio::task::spawn_local(async move { + if let Err(error) = server.connect_to(server_transport).await { + panic!("server should stay alive: {error:?}"); + } + }); + + let client_transport = + agent_client_protocol::ByteStreams::new(client_writer, client_reader); + UntypedRole + .builder() + .connect_with(client_transport, async |cx| { + cx.send_cancel_request("request-42".to_string())?; + tokio::time::sleep(tokio::time::Duration::from_millis(50)).await; + Ok(()) + }) + .await + .unwrap(); + + assert_eq!( + *received.lock().unwrap(), + vec![RequestId::Str("request-42".into())] + ); + }) + .await; +} + +#[tokio::test(flavor = "current_thread")] +async fn sent_request_can_send_cancellation_for_its_id() { + use tokio::task::LocalSet; + + let local = LocalSet::new(); + + local + .run_until(async { + let received = Arc::new(Mutex::new(Vec::new())); + let received_for_handler = received.clone(); + + let (server_reader, server_writer, client_reader, client_writer) = setup_test_streams(); + let server_transport = + agent_client_protocol::ByteStreams::new(server_writer, server_reader); + let server = UntypedRole + .builder() + .on_receive_request( + async |_request: SimpleRequest, + _responder: Responder, + _connection: ConnectionTo| { Ok(()) }, + agent_client_protocol::on_receive_request!(), + ) + .on_receive_notification( + async move |notification: CancelRequestNotification, + _connection: ConnectionTo| { + received_for_handler + .lock() + .unwrap() + .push(notification.request_id); + Ok(()) + }, + agent_client_protocol::on_receive_notification!(), + ); + + tokio::task::spawn_local(async move { + if let Err(error) = server.connect_to(server_transport).await { + panic!("server should stay alive: {error:?}"); + } + }); + + let client_transport = + agent_client_protocol::ByteStreams::new(client_writer, client_reader); + let expected_id = UntypedRole + .builder() + .connect_with(client_transport, async |cx| { + let request: SentRequest = cx.send_request(SimpleRequest { + message: "slow".into(), + }); + let expected_id = request.id(); + request.cancel()?; + tokio::time::sleep(tokio::time::Duration::from_millis(50)).await; + Ok(expected_id) + }) + .await + .unwrap(); + + let received = received.lock().unwrap(); + assert_eq!(received.len(), 1); + assert_eq!(serde_json::to_value(&received[0]).unwrap(), expected_id); + }) + .await; +} + +#[tokio::test(flavor = "current_thread")] +async fn forward_response_to_propagates_cancellation_to_downstream_request() { + use tokio::task::LocalSet; + + let local = LocalSet::new(); + + local + .run_until(async { + let backend_cancellations = Arc::new(Mutex::new(Vec::new())); + let backend_cancellations_for_handler = backend_cancellations.clone(); + + let (backend_for_proxy, backend_for_server) = Channel::duplex(); + let (backend_connection_tx, backend_connection_rx) = + futures::channel::oneshot::channel(); + + tokio::task::spawn_local(async move { + let result = UntypedRole + .builder() + .connect_with(backend_for_proxy, async |connection| { + drop(backend_connection_tx.send(connection.clone())); + std::future::pending::>().await + }) + .await; + if let Err(error) = result { + panic!("proxy-to-backend connection should stay alive: {error:?}"); + } + }); + + let backend_server = UntypedRole + .builder() + .on_receive_request( + async |_request: SimpleRequest, + _responder: Responder, + _connection: ConnectionTo| { Ok(()) }, + agent_client_protocol::on_receive_request!(), + ) + .on_receive_notification( + async move |notification: CancelRequestNotification, + _connection: ConnectionTo| { + backend_cancellations_for_handler + .lock() + .unwrap() + .push(notification.request_id); + Ok(()) + }, + agent_client_protocol::on_receive_notification!(), + ); + + tokio::task::spawn_local(async move { + if let Err(error) = backend_server.connect_to(backend_for_server).await { + panic!("backend server should stay alive: {error:?}"); + } + }); + + let backend_connection = backend_connection_rx + .await + .expect("backend connection should start"); + + let (server_reader, server_writer, client_reader, client_writer) = setup_test_streams(); + let proxy_transport = + agent_client_protocol::ByteStreams::new(server_writer, server_reader); + let proxy = UntypedRole.builder().on_receive_request( + { + let backend_connection = backend_connection.clone(); + async move |request: SimpleRequest, + responder: Responder, + _connection: ConnectionTo| { + backend_connection + .send_request(request) + .forward_response_to(responder)?; + Ok(()) + } + }, + agent_client_protocol::on_receive_request!(), + ); + + tokio::task::spawn_local(async move { + if let Err(error) = proxy.connect_to(proxy_transport).await { + panic!("proxy should stay alive: {error:?}"); + } + }); + + let client_transport = + agent_client_protocol::ByteStreams::new(client_writer, client_reader); + UntypedRole + .builder() + .connect_with(client_transport, async |connection| { + let request: SentRequest = + connection.send_request(SimpleRequest { + message: "cancel downstream".into(), + }); + request.cancel()?; + tokio::time::sleep(tokio::time::Duration::from_millis(100)).await; + Ok(()) + }) + .await + .unwrap(); + + let backend_cancellations = backend_cancellations.lock().unwrap(); + assert_eq!(backend_cancellations.len(), 1); + }) + .await; +} + +#[tokio::test(flavor = "current_thread")] +async fn request_handler_can_observe_cancellation_from_responder() { + use tokio::task::LocalSet; + + let local = LocalSet::new(); + + local + .run_until(async { + let (server_reader, server_writer, client_reader, client_writer) = setup_test_streams(); + let server_transport = + agent_client_protocol::ByteStreams::new(server_writer, server_reader); + let server = UntypedRole.builder().on_receive_request( + async |_request: SimpleRequest, + responder: Responder, + connection: ConnectionTo| { + let cancellation = responder.cancellation(); + assert!(!cancellation.is_cancelled()); + + connection.spawn(async move { + let response = cancellation + .run_until_cancelled(futures::future::pending::< + Result, + >()) + .await; + assert!(cancellation.is_cancelled()); + responder.respond_with_result(response) + })?; + + Ok(()) + }, + agent_client_protocol::on_receive_request!(), + ); + + tokio::task::spawn_local(async move { + if let Err(error) = server.connect_to(server_transport).await { + panic!("server should stay alive: {error:?}"); + } + }); + + let client_transport = + agent_client_protocol::ByteStreams::new(client_writer, client_reader); + let error = UntypedRole + .builder() + .connect_with(client_transport, async |cx| { + let request: SentRequest = cx.send_request(SimpleRequest { + message: "cancel me".into(), + }); + cx.cancel_request(&request)?; + Ok(request + .block_task() + .await + .expect_err("request should be cancelled")) + }) + .await + .unwrap(); + + assert_eq!(i32::from(error.code), -32800); + assert_eq!(error.message, "Request cancelled"); + }) + .await; +} + +#[test] +fn protocol_level_notification_and_cancelled_error_code_are_typed() { + let notification = ProtocolLevelNotification::parse_message( + "$/cancel_request", + &serde_json::json!({ "requestId": "req-1" }), + ) + .unwrap(); + assert_eq!(notification.method(), "$/cancel_request"); + + let error = agent_client_protocol::Error::request_cancelled(); + assert_eq!(i32::from(error.code), -32800); + assert_eq!(error.message, "Request cancelled"); +} From 7bfa50ab43f66936d584a368345a0a7e3c0df6a2 Mon Sep 17 00:00:00 2001 From: Ben Brandt Date: Wed, 20 May 2026 13:36:59 +0200 Subject: [PATCH 02/18] refactor(acp): remove cancel request wrapper --- src/agent-client-protocol/src/jsonrpc.rs | 14 -------------- .../tests/jsonrpc_request_cancellation.rs | 2 +- 2 files changed, 1 insertion(+), 15 deletions(-) diff --git a/src/agent-client-protocol/src/jsonrpc.rs b/src/agent-client-protocol/src/jsonrpc.rs index 6b1b896..409ead2 100644 --- a/src/agent-client-protocol/src/jsonrpc.rs +++ b/src/agent-client-protocol/src/jsonrpc.rs @@ -2096,20 +2096,6 @@ impl ConnectionTo { ) } - /// Send a `$/cancel_request` notification for an outgoing request. - /// - /// This is a convenience wrapper around [`SentRequest::cancel`]. - /// - /// Cancellation is cooperative: the peer may ignore the notification, may - /// reply to the original request with [`Error::request_cancelled`], or may - /// return a normal response with partial data. - /// - /// [`Error::request_cancelled`]: crate::Error::request_cancelled - #[cfg(feature = "unstable_cancel_request")] - pub fn cancel_request(&self, request: &SentRequest) -> Result<(), crate::Error> { - request.cancel() - } - /// Send a `$/cancel_request` notification for an arbitrary request ID to /// the default counterpart peer. #[cfg(feature = "unstable_cancel_request")] diff --git a/src/agent-client-protocol/tests/jsonrpc_request_cancellation.rs b/src/agent-client-protocol/tests/jsonrpc_request_cancellation.rs index 1d59a69..ce35a06 100644 --- a/src/agent-client-protocol/tests/jsonrpc_request_cancellation.rs +++ b/src/agent-client-protocol/tests/jsonrpc_request_cancellation.rs @@ -630,7 +630,7 @@ async fn request_handler_can_observe_cancellation_from_responder() { let request: SentRequest = cx.send_request(SimpleRequest { message: "cancel me".into(), }); - cx.cancel_request(&request)?; + request.cancel()?; Ok(request .block_task() .await From 6461dadb24faae00e0fce00076daedb8aa8ea2e8 Mon Sep 17 00:00:00 2001 From: Ben Brandt Date: Thu, 21 May 2026 18:51:18 +0200 Subject: [PATCH 03/18] Clippy --- src/agent-client-protocol/src/jsonrpc.rs | 18 ++++++++++++++++++ .../src/jsonrpc/incoming_actor.rs | 2 +- .../tests/jsonrpc_request_cancellation.rs | 2 +- 3 files changed, 20 insertions(+), 2 deletions(-) diff --git a/src/agent-client-protocol/src/jsonrpc.rs b/src/agent-client-protocol/src/jsonrpc.rs index 409ead2..ee321c5 100644 --- a/src/agent-client-protocol/src/jsonrpc.rs +++ b/src/agent-client-protocol/src/jsonrpc.rs @@ -1491,6 +1491,10 @@ struct ResponderCancellation; #[cfg(feature = "unstable_cancel_request")] impl RequestCancellationRegistry { + fn new() -> Self { + Self::default() + } + fn register(&self, id: serde_json::Value) -> ResponderCancellation { let cancellation = RequestCancellation::new(); self.inner @@ -1536,10 +1540,23 @@ impl RequestCancellationRegistry { #[cfg(not(feature = "unstable_cancel_request"))] impl RequestCancellationRegistry { + fn new() -> Self { + Self + } + + #[expect( + clippy::unused_self, + reason = "feature-disabled stub mirrors the real registry API" + )] fn register(&self, _id: serde_json::Value) -> ResponderCancellation { ResponderCancellation } + #[expect( + clippy::unused_self, + clippy::unnecessary_wraps, + reason = "feature-disabled stub mirrors the real registry API" + )] fn cancel_if_requested(&self, _dispatch: &Dispatch) -> Result { Ok(false) } @@ -2319,6 +2336,7 @@ impl Responder { /// /// [`Error::request_cancelled`]: crate::Error::request_cancelled #[cfg(feature = "unstable_cancel_request")] + #[must_use] pub fn cancellation(&self) -> RequestCancellation { self.cancellation.cancellation() } diff --git a/src/agent-client-protocol/src/jsonrpc/incoming_actor.rs b/src/agent-client-protocol/src/jsonrpc/incoming_actor.rs index 6c5caea..21a815b 100644 --- a/src/agent-client-protocol/src/jsonrpc/incoming_actor.rs +++ b/src/agent-client-protocol/src/jsonrpc/incoming_actor.rs @@ -62,7 +62,7 @@ pub(super) async fn incoming_protocol_actor( FxHashMap::default(); let mut pending_messages: Vec = vec![]; - let request_cancellations = super::RequestCancellationRegistry::default(); + let request_cancellations = super::RequestCancellationRegistry::new(); // Map from request ID to (method, sender) for response dispatch. // Keys are JSON values because jsonrpcmsg::Id doesn't implement Eq. diff --git a/src/agent-client-protocol/tests/jsonrpc_request_cancellation.rs b/src/agent-client-protocol/tests/jsonrpc_request_cancellation.rs index ce35a06..24833f9 100644 --- a/src/agent-client-protocol/tests/jsonrpc_request_cancellation.rs +++ b/src/agent-client-protocol/tests/jsonrpc_request_cancellation.rs @@ -58,7 +58,7 @@ impl JsonRpcMessage for SimpleRequest { method == "simple_method" } - fn method(&self) -> &str { + fn method(&self) -> &'static str { "simple_method" } From ddb19ba1f94124e5503601583997f3e68602fef6 Mon Sep 17 00:00:00 2001 From: Ben Brandt Date: Thu, 21 May 2026 19:27:23 +0200 Subject: [PATCH 04/18] feat(acp): Auto-cancel dropped sent requests --- md/request-cancellation.md | 5 + src/agent-client-protocol/CHANGELOG.md | 2 +- src/agent-client-protocol/src/jsonrpc.rs | 128 +++++++++-- .../tests/jsonrpc_request_cancellation.rs | 216 ++++++++++++++++++ 4 files changed, 326 insertions(+), 25 deletions(-) diff --git a/md/request-cancellation.md b/md/request-cancellation.md index 2d31f07..1044fe0 100644 --- a/md/request-cancellation.md +++ b/md/request-cancellation.md @@ -28,6 +28,11 @@ The `SentRequest` remembers the peer and any proxy wrapping used for the original request, so this also works for requests sent through `ConnectionTo::send_request_to`. +Dropping a `SentRequest` before a response is received also sends +`$/cancel_request`. This covers abandoned request handles and futures: once a +response is received by `block_task`, `on_receiving_result`, or +`forward_response_to`, the SDK disarms the automatic cancellation. + If you already have the JSON-RPC request ID, send the notification directly: ```rust diff --git a/src/agent-client-protocol/CHANGELOG.md b/src/agent-client-protocol/CHANGELOG.md index d6bbf57..b20dc90 100644 --- a/src/agent-client-protocol/CHANGELOG.md +++ b/src/agent-client-protocol/CHANGELOG.md @@ -4,7 +4,7 @@ ### Added -- *(unstable)* Add SDK support for protocol-level request cancellation, including `SentRequest::cancel`, request-local cancellation helpers on `Responder`, and forwarded cancellation propagation. +- *(unstable)* Add SDK support for protocol-level request cancellation, including `SentRequest::cancel`, automatic cancellation when a `SentRequest` is dropped before receiving a response, request-local cancellation helpers on `Responder`, and forwarded cancellation propagation. ## [0.12.1](https://github.com/agentclientprotocol/rust-sdk/compare/v0.12.0...v0.12.1) - 2026-05-17 diff --git a/src/agent-client-protocol/src/jsonrpc.rs b/src/agent-client-protocol/src/jsonrpc.rs index ee321c5..5788cce 100644 --- a/src/agent-client-protocol/src/jsonrpc.rs +++ b/src/agent-client-protocol/src/jsonrpc.rs @@ -2008,6 +2008,9 @@ impl ConnectionTo { match self.message_tx.unbounded_send(message) { Ok(()) => (), Err(error) => { + #[cfg(feature = "unstable_cancel_request")] + cancellation.disarm(); + let OutgoingMessage::Request { method, response_tx, @@ -2030,6 +2033,9 @@ impl ConnectionTo { } Err(err) => { + #[cfg(feature = "unstable_cancel_request")] + cancellation.disarm(); + response_tx .send(ResponsePayload { result: Err(crate::util::internal_error(format!( @@ -2442,26 +2448,35 @@ impl ResponseRouter { /// Create a new response context for routing a response to a local awaiter. /// /// When `respond_with_result` is called, the response is sent through the oneshot - /// channel to the code that originally sent the request. + /// channel to the code that originally sent the request. If that receiver was + /// dropped, the response is discarded because there is no local awaiter left. pub(crate) fn new( method: String, id: jsonrpcmsg::Id, role_id: RoleId, sender: oneshot::Sender, ) -> Self { + let response_method = method.clone(); + let response_id = id.clone(); Self { method, id, role_id, send_fn: Box::new(move |response: Result| { - sender + if sender .send(ResponsePayload { result: response, ack_tx: None, }) - .map_err(|_| { - crate::util::internal_error("failed to send response, receiver dropped") - }) + .is_err() + { + tracing::debug!( + method = %response_method, + id = ?response_id, + "dropped response because local receiver was gone" + ); + } + Ok(()) }), } } @@ -3174,9 +3189,11 @@ enum SentRequestCancellation { Send { message_tx: OutgoingMessageTx, notification: UntypedMessage, + armed: Arc, }, Failed { error: String, + armed: Arc, }, } @@ -3199,8 +3216,20 @@ impl SentRequestCancellation { Ok(notification) => Self::Send { message_tx, notification, + armed: Arc::new(AtomicBool::new(true)), }, - Err(error) => Self::Failed { error }, + Err(error) => Self::Failed { + error, + armed: Arc::new(AtomicBool::new(true)), + }, + } + } + + fn disarm(&self) { + match self { + Self::Send { armed, .. } | Self::Failed { armed, .. } => { + armed.store(false, Ordering::Release); + } } } @@ -3209,15 +3238,37 @@ impl SentRequestCancellation { Self::Send { message_tx, notification, - } => send_raw_message( - message_tx, - OutgoingMessage::Notification { - untyped: notification.clone(), - }, - ), - Self::Failed { error } => Err(crate::util::internal_error(format!( - "failed to create cancel request notification: {error}" - ))), + armed, + } => { + if !armed.swap(false, Ordering::AcqRel) { + return Ok(()); + } + + send_raw_message( + message_tx, + OutgoingMessage::Notification { + untyped: notification.clone(), + }, + ) + } + Self::Failed { error, armed } => { + if !armed.swap(false, Ordering::AcqRel) { + return Ok(()); + } + + Err(crate::util::internal_error(format!( + "failed to create cancel request notification: {error}" + ))) + } + } + } +} + +#[cfg(feature = "unstable_cancel_request")] +impl Drop for SentRequestCancellation { + fn drop(&mut self) { + if let Err(error) = self.send() { + tracing::debug!(?error, "failed to auto-cancel dropped request"); } } } @@ -3226,13 +3277,19 @@ impl SentRequestCancellation { impl Debug for SentRequestCancellation { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { - Self::Send { notification, .. } => f + Self::Send { + notification, + armed, + .. + } => f .debug_struct("SentRequestCancellation") .field("notification", notification) + .field("armed", &armed.load(Ordering::Acquire)) .finish(), - Self::Failed { error } => f + Self::Failed { error, armed } => f .debug_struct("SentRequestCancellation") .field("error", error) + .field("armed", &armed.load(Ordering::Acquire)) .finish(), } } @@ -3410,6 +3467,8 @@ impl SentRequest { } }; + downstream_cancellation.disarm(); + let ResponsePayload { result, ack_tx } = response.map_err(|err| { crate::util::internal_error(format!("response to `{method}` never received: {err}")) })?; @@ -3501,6 +3560,9 @@ impl SentRequest { result: Ok(json_value), ack_tx, }) => { + #[cfg(feature = "unstable_cancel_request")] + self.cancellation.disarm(); + // Ack immediately - we're in a spawned task, so the dispatch loop // can continue while we process the value. if let Some(tx) = ack_tx { @@ -3515,15 +3577,23 @@ impl SentRequest { result: Err(err), ack_tx, }) => { + #[cfg(feature = "unstable_cancel_request")] + self.cancellation.disarm(); + if let Some(tx) = ack_tx { let _ = tx.send(()); } Err(err) } - Err(err) => Err(crate::util::internal_error(format!( - "response to `{}` never received: {}", - self.method, err - ))), + Err(err) => { + #[cfg(feature = "unstable_cancel_request")] + self.cancellation.disarm(); + + Err(crate::util::internal_error(format!( + "response to `{}` never received: {}", + self.method, err + ))) + } } } @@ -3673,11 +3743,16 @@ impl SentRequest { let method = self.method; let response_rx = self.response_rx; let to_result = self.to_result; + #[cfg(feature = "unstable_cancel_request")] + let cancellation = self.cancellation; let location = Location::caller(); Task::new(location, async move { match response_rx.await { Ok(ResponsePayload { result, ack_tx }) => { + #[cfg(feature = "unstable_cancel_request")] + cancellation.disarm(); + // Convert the result using to_result for Ok values let typed_result = match result { Ok(json_value) => to_result(json_value), @@ -3695,9 +3770,14 @@ impl SentRequest { outcome } - Err(err) => Err(crate::util::internal_error(format!( - "response to `{method}` never received: {err}" - ))), + Err(err) => { + #[cfg(feature = "unstable_cancel_request")] + cancellation.disarm(); + + Err(crate::util::internal_error(format!( + "response to `{method}` never received: {err}" + ))) + } } }) .spawn(&task_tx) diff --git a/src/agent-client-protocol/tests/jsonrpc_request_cancellation.rs b/src/agent-client-protocol/tests/jsonrpc_request_cancellation.rs index 24833f9..79f34af 100644 --- a/src/agent-client-protocol/tests/jsonrpc_request_cancellation.rs +++ b/src/agent-client-protocol/tests/jsonrpc_request_cancellation.rs @@ -479,6 +479,222 @@ async fn sent_request_can_send_cancellation_for_its_id() { .await; } +#[tokio::test(flavor = "current_thread")] +async fn dropped_sent_request_sends_cancellation_for_its_id() { + use tokio::task::LocalSet; + + let local = LocalSet::new(); + + local + .run_until(async { + let received = Arc::new(Mutex::new(Vec::new())); + let received_for_handler = received.clone(); + + let (server_reader, server_writer, client_reader, client_writer) = setup_test_streams(); + let server_transport = + agent_client_protocol::ByteStreams::new(server_writer, server_reader); + let server = UntypedRole + .builder() + .on_receive_request( + async |_request: SimpleRequest, + _responder: Responder, + _connection: ConnectionTo| { Ok(()) }, + agent_client_protocol::on_receive_request!(), + ) + .on_receive_notification( + async move |notification: CancelRequestNotification, + _connection: ConnectionTo| { + received_for_handler + .lock() + .unwrap() + .push(notification.request_id); + Ok(()) + }, + agent_client_protocol::on_receive_notification!(), + ); + + tokio::task::spawn_local(async move { + if let Err(error) = server.connect_to(server_transport).await { + panic!("server should stay alive: {error:?}"); + } + }); + + let client_transport = + agent_client_protocol::ByteStreams::new(client_writer, client_reader); + let expected_id = UntypedRole + .builder() + .connect_with(client_transport, async |cx| { + let request: SentRequest = cx.send_request(SimpleRequest { + message: "abandoned".into(), + }); + let expected_id = request.id(); + drop(request); + tokio::time::sleep(tokio::time::Duration::from_millis(50)).await; + Ok(expected_id) + }) + .await + .unwrap(); + + let received = received.lock().unwrap(); + assert_eq!(received.len(), 1); + assert_eq!(serde_json::to_value(&received[0]).unwrap(), expected_id); + }) + .await; +} + +#[tokio::test(flavor = "current_thread")] +async fn late_response_after_dropped_sent_request_does_not_close_connection() { + use tokio::task::LocalSet; + + let local = LocalSet::new(); + + local + .run_until(async { + let received = Arc::new(Mutex::new(Vec::new())); + let received_for_handler = received.clone(); + + let (server_reader, server_writer, client_reader, client_writer) = setup_test_streams(); + let server_transport = + agent_client_protocol::ByteStreams::new(server_writer, server_reader); + let server = UntypedRole + .builder() + .on_receive_request( + async |request: SimpleRequest, + responder: Responder, + connection: ConnectionTo| { + if request.message == "late" { + connection.spawn(async move { + tokio::time::sleep(tokio::time::Duration::from_millis(10)).await; + responder.respond(SimpleResponse { + result: "late response".into(), + }) + })?; + return Ok(()); + } + + responder.respond(SimpleResponse { + result: format!("echo: {}", request.message), + }) + }, + agent_client_protocol::on_receive_request!(), + ) + .on_receive_notification( + async move |notification: CancelRequestNotification, + _connection: ConnectionTo| { + received_for_handler + .lock() + .unwrap() + .push(notification.request_id); + Ok(()) + }, + agent_client_protocol::on_receive_notification!(), + ); + + tokio::task::spawn_local(async move { + if let Err(error) = server.connect_to(server_transport).await { + panic!("server should stay alive: {error:?}"); + } + }); + + let client_transport = + agent_client_protocol::ByteStreams::new(client_writer, client_reader); + let (expected_id, response) = UntypedRole + .builder() + .connect_with(client_transport, async |cx| { + let request: SentRequest = cx.send_request(SimpleRequest { + message: "late".into(), + }); + let expected_id = request.id(); + drop(request); + + tokio::time::sleep(tokio::time::Duration::from_millis(50)).await; + + let response = cx + .send_request(SimpleRequest { + message: "after late".into(), + }) + .block_task() + .await?; + Ok((expected_id, response)) + }) + .await + .unwrap(); + + assert_eq!(response.result, "echo: after late"); + let received = received.lock().unwrap(); + assert_eq!(received.len(), 1); + assert_eq!(serde_json::to_value(&received[0]).unwrap(), expected_id); + }) + .await; +} + +#[tokio::test(flavor = "current_thread")] +async fn completed_sent_request_does_not_send_cancellation_on_drop() { + use tokio::task::LocalSet; + + let local = LocalSet::new(); + + local + .run_until(async { + let received = Arc::new(Mutex::new(Vec::new())); + let received_for_handler = received.clone(); + + let (server_reader, server_writer, client_reader, client_writer) = setup_test_streams(); + let server_transport = + agent_client_protocol::ByteStreams::new(server_writer, server_reader); + let server = UntypedRole + .builder() + .on_receive_request( + async |request: SimpleRequest, + responder: Responder, + _connection: ConnectionTo| { + responder.respond(SimpleResponse { + result: format!("echo: {}", request.message), + }) + }, + agent_client_protocol::on_receive_request!(), + ) + .on_receive_notification( + async move |notification: CancelRequestNotification, + _connection: ConnectionTo| { + received_for_handler + .lock() + .unwrap() + .push(notification.request_id); + Ok(()) + }, + agent_client_protocol::on_receive_notification!(), + ); + + tokio::task::spawn_local(async move { + if let Err(error) = server.connect_to(server_transport).await { + panic!("server should stay alive: {error:?}"); + } + }); + + let client_transport = + agent_client_protocol::ByteStreams::new(client_writer, client_reader); + let response = UntypedRole + .builder() + .connect_with(client_transport, async |cx| { + let response = cx + .send_request(SimpleRequest { + message: "complete".into(), + }) + .block_task() + .await?; + tokio::time::sleep(tokio::time::Duration::from_millis(50)).await; + Ok(response) + }) + .await + .unwrap(); + + assert_eq!(response.result, "echo: complete"); + assert!(received.lock().unwrap().is_empty()); + }) + .await; +} + #[tokio::test(flavor = "current_thread")] async fn forward_response_to_propagates_cancellation_to_downstream_request() { use tokio::task::LocalSet; From 25bea0070e6cd38514a46c2756c14dd779106543 Mon Sep 17 00:00:00 2001 From: Ben Brandt Date: Thu, 21 May 2026 20:14:52 +0200 Subject: [PATCH 05/18] fix(acp): disarm cancellation for buffered responses --- src/agent-client-protocol/src/jsonrpc.rs | 69 +++++++++++++----- .../src/jsonrpc/incoming_actor.rs | 17 ++++- .../src/jsonrpc/outgoing_actor.rs | 8 +++ .../tests/jsonrpc_request_cancellation.rs | 72 +++++++++++++++++++ 4 files changed, 146 insertions(+), 20 deletions(-) diff --git a/src/agent-client-protocol/src/jsonrpc.rs b/src/agent-client-protocol/src/jsonrpc.rs index 5788cce..a8e8ca4 100644 --- a/src/agent-client-protocol/src/jsonrpc.rs +++ b/src/agent-client-protocol/src/jsonrpc.rs @@ -1358,6 +1358,9 @@ enum ReplyMessage { method: String, sender: oneshot::Sender, + + #[cfg(feature = "unstable_cancel_request")] + cancellation_disarm: SentRequestCancellationDisarm, }, } @@ -1654,6 +1657,9 @@ enum OutgoingMessage { /// where to send the response when it arrives (includes ack channel) response_tx: oneshot::Sender, + + #[cfg(feature = "unstable_cancel_request")] + cancellation_disarm: SentRequestCancellationDisarm, }, /// Send a notification to the server. @@ -2003,6 +2009,8 @@ impl ConnectionTo { role_id, untyped, response_tx, + #[cfg(feature = "unstable_cancel_request")] + cancellation_disarm: cancellation.disarm_handle(), }; match self.message_tx.unbounded_send(message) { @@ -2455,6 +2463,8 @@ impl ResponseRouter { id: jsonrpcmsg::Id, role_id: RoleId, sender: oneshot::Sender, + #[cfg(feature = "unstable_cancel_request")] + cancellation_disarm: SentRequestCancellationDisarm, ) -> Self { let response_method = method.clone(); let response_id = id.clone(); @@ -2475,6 +2485,9 @@ impl ResponseRouter { id = ?response_id, "dropped response because local receiver was gone" ); + } else { + #[cfg(feature = "unstable_cancel_request")] + cancellation_disarm.disarm(); } Ok(()) }), @@ -3184,16 +3197,34 @@ fn jsonrpc_id_to_request_id(id: &jsonrpcmsg::Id) -> Result, +} + +#[cfg(feature = "unstable_cancel_request")] +impl SentRequestCancellationDisarm { + fn new() -> Self { + Self { + armed: Arc::new(AtomicBool::new(true)), + } + } + + fn disarm(&self) { + self.armed.store(false, Ordering::Release); + } +} + +#[cfg(feature = "unstable_cancel_request")] enum SentRequestCancellation { Send { message_tx: OutgoingMessageTx, notification: UntypedMessage, - armed: Arc, + disarm: SentRequestCancellationDisarm, }, Failed { error: String, - armed: Arc, + disarm: SentRequestCancellationDisarm, }, } @@ -3211,25 +3242,25 @@ impl SentRequestCancellation { ) }) .map_err(|error| error.to_string()); + let disarm = SentRequestCancellationDisarm::new(); match notification { Ok(notification) => Self::Send { message_tx, notification, - armed: Arc::new(AtomicBool::new(true)), - }, - Err(error) => Self::Failed { - error, - armed: Arc::new(AtomicBool::new(true)), + disarm, }, + Err(error) => Self::Failed { error, disarm }, } } fn disarm(&self) { + self.disarm_handle().disarm(); + } + + fn disarm_handle(&self) -> SentRequestCancellationDisarm { match self { - Self::Send { armed, .. } | Self::Failed { armed, .. } => { - armed.store(false, Ordering::Release); - } + Self::Send { disarm, .. } | Self::Failed { disarm, .. } => disarm.clone(), } } @@ -3238,9 +3269,9 @@ impl SentRequestCancellation { Self::Send { message_tx, notification, - armed, + disarm, } => { - if !armed.swap(false, Ordering::AcqRel) { + if !disarm.armed.swap(false, Ordering::AcqRel) { return Ok(()); } @@ -3251,8 +3282,8 @@ impl SentRequestCancellation { }, ) } - Self::Failed { error, armed } => { - if !armed.swap(false, Ordering::AcqRel) { + Self::Failed { error, disarm } => { + if !disarm.armed.swap(false, Ordering::AcqRel) { return Ok(()); } @@ -3279,17 +3310,17 @@ impl Debug for SentRequestCancellation { match self { Self::Send { notification, - armed, + disarm, .. } => f .debug_struct("SentRequestCancellation") .field("notification", notification) - .field("armed", &armed.load(Ordering::Acquire)) + .field("armed", &disarm.armed.load(Ordering::Acquire)) .finish(), - Self::Failed { error, armed } => f + Self::Failed { error, disarm } => f .debug_struct("SentRequestCancellation") .field("error", error) - .field("armed", &armed.load(Ordering::Acquire)) + .field("armed", &disarm.armed.load(Ordering::Acquire)) .finish(), } } diff --git a/src/agent-client-protocol/src/jsonrpc/incoming_actor.rs b/src/agent-client-protocol/src/jsonrpc/incoming_actor.rs index 21a815b..1a5b80d 100644 --- a/src/agent-client-protocol/src/jsonrpc/incoming_actor.rs +++ b/src/agent-client-protocol/src/jsonrpc/incoming_actor.rs @@ -30,6 +30,8 @@ struct PendingReply { method: String, role_id: RoleId, sender: oneshot::Sender, + #[cfg(feature = "unstable_cancel_request")] + cancellation_disarm: super::SentRequestCancellationDisarm, } /// Incoming protocol actor: The central dispatch loop for a connection. @@ -78,6 +80,8 @@ pub(super) async fn incoming_protocol_actor( role_id, method, sender, + #[cfg(feature = "unstable_cancel_request")] + cancellation_disarm, } => { tracing::trace!(?id, %method, "incoming_actor: subscribing to response"); let id = serde_json::to_value(&id).unwrap(); @@ -87,6 +91,8 @@ pub(super) async fn incoming_protocol_actor( method, role_id, sender, + #[cfg(feature = "unstable_cancel_request")] + cancellation_disarm, }, ); } @@ -260,10 +266,19 @@ fn dispatch_from_response( method, role_id, sender, + #[cfg(feature = "unstable_cancel_request")] + cancellation_disarm, } = pending_reply; // Create a Dispatch::Response with a ResponseRouter that routes to the oneshot - let router = ResponseRouter::new(method.clone(), id.clone(), role_id, sender); + let router = ResponseRouter::new( + method.clone(), + id.clone(), + role_id, + sender, + #[cfg(feature = "unstable_cancel_request")] + cancellation_disarm, + ); Dispatch::Response(result, router) } diff --git a/src/agent-client-protocol/src/jsonrpc/outgoing_actor.rs b/src/agent-client-protocol/src/jsonrpc/outgoing_actor.rs index 0b54ff7..65a5611 100644 --- a/src/agent-client-protocol/src/jsonrpc/outgoing_actor.rs +++ b/src/agent-client-protocol/src/jsonrpc/outgoing_actor.rs @@ -41,6 +41,8 @@ pub(super) async fn outgoing_protocol_actor( method, untyped, response_tx, + #[cfg(feature = "unstable_cancel_request")] + cancellation_disarm, } => { let request = match protocol_compat .outgoing_message(untyped) @@ -49,6 +51,8 @@ pub(super) async fn outgoing_protocol_actor( Ok(request) => request, Err(error) => { tracing::warn!(?id, %method, ?error, "Failed to convert outgoing request"); + #[cfg(feature = "unstable_cancel_request")] + cancellation_disarm.disarm(); complete_request_with_error(response_tx, error); continue; } @@ -61,6 +65,8 @@ pub(super) async fn outgoing_protocol_actor( role_id, method, sender: response_tx, + #[cfg(feature = "unstable_cancel_request")] + cancellation_disarm, }) .map_err(crate::Error::into_internal_error)?; @@ -167,6 +173,8 @@ mod tests { method: "session/new".into(), untyped: malformed_v2_known_method()?, response_tx, + #[cfg(feature = "unstable_cancel_request")] + cancellation_disarm: crate::jsonrpc::SentRequestCancellationDisarm::new(), }) .map_err(crate::Error::into_internal_error)?; drop(outgoing_tx); diff --git a/src/agent-client-protocol/tests/jsonrpc_request_cancellation.rs b/src/agent-client-protocol/tests/jsonrpc_request_cancellation.rs index 79f34af..9383fd0 100644 --- a/src/agent-client-protocol/tests/jsonrpc_request_cancellation.rs +++ b/src/agent-client-protocol/tests/jsonrpc_request_cancellation.rs @@ -628,6 +628,78 @@ async fn late_response_after_dropped_sent_request_does_not_close_connection() { .await; } +#[tokio::test(flavor = "current_thread")] +async fn response_buffered_before_drop_disarms_auto_cancellation() { + use tokio::task::LocalSet; + + let local = LocalSet::new(); + + local + .run_until(async { + let received = Arc::new(Mutex::new(Vec::new())); + let received_for_handler = received.clone(); + + let (server_reader, server_writer, client_reader, client_writer) = setup_test_streams(); + let server_transport = + agent_client_protocol::ByteStreams::new(server_writer, server_reader); + let server = UntypedRole + .builder() + .on_receive_request( + async |request: SimpleRequest, + responder: Responder, + _connection: ConnectionTo| { + responder.respond(SimpleResponse { + result: format!("echo: {}", request.message), + }) + }, + agent_client_protocol::on_receive_request!(), + ) + .on_receive_notification( + async move |notification: CancelRequestNotification, + _connection: ConnectionTo| { + received_for_handler + .lock() + .unwrap() + .push(notification.request_id); + Ok(()) + }, + agent_client_protocol::on_receive_notification!(), + ); + + tokio::task::spawn_local(async move { + if let Err(error) = server.connect_to(server_transport).await { + panic!("server should stay alive: {error:?}"); + } + }); + + let client_transport = + agent_client_protocol::ByteStreams::new(client_writer, client_reader); + let response = UntypedRole + .builder() + .connect_with(client_transport, async |cx| { + let request: SentRequest = cx.send_request(SimpleRequest { + message: "buffered".into(), + }); + + tokio::time::sleep(tokio::time::Duration::from_millis(50)).await; + drop(request); + tokio::time::sleep(tokio::time::Duration::from_millis(50)).await; + + cx.send_request(SimpleRequest { + message: "after buffered".into(), + }) + .block_task() + .await + }) + .await + .unwrap(); + + assert_eq!(response.result, "echo: after buffered"); + assert!(received.lock().unwrap().is_empty()); + }) + .await; +} + #[tokio::test(flavor = "current_thread")] async fn completed_sent_request_does_not_send_cancellation_on_drop() { use tokio::task::LocalSet; From 506ec3fd2b53b8f41dd3083dd6765a060ae47366 Mon Sep 17 00:00:00 2001 From: Ben Brandt Date: Thu, 21 May 2026 20:21:57 +0200 Subject: [PATCH 06/18] docs: clarify automatic request cancellation behavior --- md/request-cancellation.md | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/md/request-cancellation.md b/md/request-cancellation.md index 1044fe0..4e8dc64 100644 --- a/md/request-cancellation.md +++ b/md/request-cancellation.md @@ -28,10 +28,11 @@ The `SentRequest` remembers the peer and any proxy wrapping used for the original request, so this also works for requests sent through `ConnectionTo::send_request_to`. -Dropping a `SentRequest` before a response is received also sends -`$/cancel_request`. This covers abandoned request handles and futures: once a -response is received by `block_task`, `on_receiving_result`, or -`forward_response_to`, the SDK disarms the automatic cancellation. +Dropping a `SentRequest` before the SDK receives a response also sends +`$/cancel_request`. This covers abandoned request handles and futures. Once the +SDK routes a response to the waiting request handle, automatic cancellation is +disarmed, even if caller code has not yet consumed it with `block_task`, +`on_receiving_result`, or `forward_response_to`. If you already have the JSON-RPC request ID, send the notification directly: From 0065bd334f0da39d6fa9137cfa26abcce584f3f5 Mon Sep 17 00:00:00 2001 From: Ben Brandt Date: Wed, 10 Jun 2026 11:30:48 +0200 Subject: [PATCH 07/18] chore(acp): acknowledge large session enum variant --- src/agent-client-protocol/src/session.rs | 4 ++++ .../tests/jsonrpc_request_cancellation.rs | 8 ++++---- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/src/agent-client-protocol/src/session.rs b/src/agent-client-protocol/src/session.rs index 97f70d1..087dab4 100644 --- a/src/agent-client-protocol/src/session.rs +++ b/src/agent-client-protocol/src/session.rs @@ -514,6 +514,10 @@ where /// Incoming message from the agent #[non_exhaustive] #[derive(Debug)] +#[expect( + clippy::large_enum_variant, + reason = "Dispatch messages vastly outnumber StopReason; boxing would add a heap allocation" +)] pub enum SessionMessage { /// Periodic updates with new content, tool requests, etc. /// Use [`MatchDispatch`] to match on the message type. diff --git a/src/agent-client-protocol/tests/jsonrpc_request_cancellation.rs b/src/agent-client-protocol/tests/jsonrpc_request_cancellation.rs index 9383fd0..4a11938 100644 --- a/src/agent-client-protocol/tests/jsonrpc_request_cancellation.rs +++ b/src/agent-client-protocol/tests/jsonrpc_request_cancellation.rs @@ -273,11 +273,11 @@ async fn unhandled_protocol_level_notifications_are_ignored() { let response = read_jsonrpc_response_line(&mut client_reader).await; expect![[r#" { - "id": 2, "jsonrpc": "2.0", "result": { "result": "echo: after cancel" - } + }, + "id": 2 }"#]] .assert_eq(&serde_json::to_string_pretty(&response).unwrap()); }) @@ -353,11 +353,11 @@ async fn unhandled_wrapped_protocol_level_notifications_are_ignored() { let response = read_jsonrpc_response_line(&mut client_reader).await; expect![[r#" { - "id": 2, "jsonrpc": "2.0", "result": { "result": "echo: after wrapped cancel" - } + }, + "id": 2 }"#]] .assert_eq(&serde_json::to_string_pretty(&response).unwrap()); }) From ca4d3dd80ca5835f7c8b265d0f9a4b9e65b602ac Mon Sep 17 00:00:00 2001 From: Ben Brandt Date: Wed, 10 Jun 2026 12:11:14 +0200 Subject: [PATCH 08/18] revert changelog changes --- src/agent-client-protocol/CHANGELOG.md | 46 +++++++++++++------------- 1 file changed, 23 insertions(+), 23 deletions(-) diff --git a/src/agent-client-protocol/CHANGELOG.md b/src/agent-client-protocol/CHANGELOG.md index 1f85e35..b27cf30 100644 --- a/src/agent-client-protocol/CHANGELOG.md +++ b/src/agent-client-protocol/CHANGELOG.md @@ -7,11 +7,11 @@ ### Added - Stabilize session/delete, message ids, and context usage ([#199](https://github.com/agentclientprotocol/rust-sdk/pull/199)) -- _(acp)_ add unstable elicitation support ([#197](https://github.com/agentclientprotocol/rust-sdk/pull/197)) +- *(acp)* add unstable elicitation support ([#197](https://github.com/agentclientprotocol/rust-sdk/pull/197)) ### Fixed -- _(acp)_ Serialize proxy metadata as \_meta ([#198](https://github.com/agentclientprotocol/rust-sdk/pull/198)) +- *(acp)* Serialize proxy metadata as _meta ([#198](https://github.com/agentclientprotocol/rust-sdk/pull/198)) ### Other @@ -19,21 +19,21 @@ ### Added -- _(unstable)_ Add JSON-RPC support for elicitation requests and notifications. +- *(unstable)* Add JSON-RPC support for elicitation requests and notifications. ## [0.13.1](https://github.com/agentclientprotocol/rust-sdk/compare/v0.13.0...v0.13.1) - 2026-06-01 ### Added -- _(deps)_ bump schema to 0.13.5 ([#188](https://github.com/agentclientprotocol/rust-sdk/pull/188)) +- *(deps)* bump schema to 0.13.5 ([#188](https://github.com/agentclientprotocol/rust-sdk/pull/188)) ## [0.13.0](https://github.com/agentclientprotocol/rust-sdk/compare/v0.12.1...v0.13.0) - 2026-06-01 ### Added -- _(acp)_ stabilize logout support ([#185](https://github.com/agentclientprotocol/rust-sdk/pull/185)) -- _(acp)_ Extract all rmcp logic to the rmcp crate ([#180](https://github.com/agentclientprotocol/rust-sdk/pull/180)) -- _(acp)_ Add unstable (very experimental!) protocol v2 support ([#170](https://github.com/agentclientprotocol/rust-sdk/pull/170)) +- *(acp)* stabilize logout support ([#185](https://github.com/agentclientprotocol/rust-sdk/pull/185)) +- *(acp)* Extract all rmcp logic to the rmcp crate ([#180](https://github.com/agentclientprotocol/rust-sdk/pull/180)) +- *(acp)* Add unstable (very experimental!) protocol v2 support ([#170](https://github.com/agentclientprotocol/rust-sdk/pull/170)) ### Changed @@ -49,10 +49,10 @@ ### Added -- _(acp)_ add unstable session delete support ([#165](https://github.com/agentclientprotocol/rust-sdk/pull/165)) +- *(acp)* add unstable session delete support ([#165](https://github.com/agentclientprotocol/rust-sdk/pull/165)) - extract mcp-over-acp proxy ([#146](https://github.com/agentclientprotocol/rust-sdk/pull/146)) - Stabilize session/close and session/resume ([#147](https://github.com/agentclientprotocol/rust-sdk/pull/147)) -- remove direct dependency on tokio ([#145](https://github.com/agentclientprotocol/rust-sdk/pull/145)) +- remove direct dependency on tokio ([#145](https://github.com/agentclientprotocol/rust-sdk/pull/145)) ### Fixed @@ -70,7 +70,7 @@ ### Added -- _(unstable)_ Add support for `session/delete` method. +- *(unstable)* Add support for `session/delete` method. - `McpConnectionTo::acp_id()` method. ### Deprecated @@ -81,7 +81,7 @@ ### Fixed -- _(acp)_ remove `boxfnonce` dependency in favor of `Box` ([#137](https://github.com/agentclientprotocol/rust-sdk/pull/137)) +- *(acp)* remove `boxfnonce` dependency in favor of `Box` ([#137](https://github.com/agentclientprotocol/rust-sdk/pull/137)) ## [0.11.0](https://github.com/agentclientprotocol/rust-sdk/compare/v0.10.4...v0.11.0) - 2026-04-20 @@ -92,9 +92,9 @@ ### Fixed -- _(rpc)_ log errors when sending response to peer fails ([#101](https://github.com/agentclientprotocol/rust-sdk/pull/101)) -- _(rpc)_ handle write failures in handle_io loop ([#99](https://github.com/agentclientprotocol/rust-sdk/pull/99)) -- _(rpc)_ use RawValue::NULL constant instead of from_string().unwrap() ([#96](https://github.com/agentclientprotocol/rust-sdk/pull/96)) +- *(rpc)* log errors when sending response to peer fails ([#101](https://github.com/agentclientprotocol/rust-sdk/pull/101)) +- *(rpc)* handle write failures in handle_io loop ([#99](https://github.com/agentclientprotocol/rust-sdk/pull/99)) +- *(rpc)* use RawValue::NULL constant instead of from_string().unwrap() ([#96](https://github.com/agentclientprotocol/rust-sdk/pull/96)) ### Other @@ -102,13 +102,13 @@ - Add mdbook build ([#120](https://github.com/agentclientprotocol/rust-sdk/pull/120)) - Add migration guide for next release ([#111](https://github.com/agentclientprotocol/rust-sdk/pull/111)) - remove debug code from rpc_tests ([#100](https://github.com/agentclientprotocol/rust-sdk/pull/100)) -- _(test)_ add conditional compilation ([#98](https://github.com/agentclientprotocol/rust-sdk/pull/98)) +- *(test)* add conditional compilation ([#98](https://github.com/agentclientprotocol/rust-sdk/pull/98)) ## [0.10.4](https://github.com/agentclientprotocol/rust-sdk/compare/v0.10.3...v0.10.4) - 2026-03-31 ### Added -- _(schema)_ Update schema to 0.11.4 ([#95](https://github.com/agentclientprotocol/rust-sdk/pull/95)) +- *(schema)* Update schema to 0.11.4 ([#95](https://github.com/agentclientprotocol/rust-sdk/pull/95)) ### Fixed @@ -123,14 +123,14 @@ ### Added -- _(unstable)_ Add logout support ([#84](https://github.com/agentclientprotocol/rust-sdk/pull/84)) -- _(schema)_ Update schema to 0.11.3 ([#82](https://github.com/agentclientprotocol/rust-sdk/pull/82)) +- *(unstable)* Add logout support ([#84](https://github.com/agentclientprotocol/rust-sdk/pull/84)) +- *(schema)* Update schema to 0.11.3 ([#82](https://github.com/agentclientprotocol/rust-sdk/pull/82)) ## [0.10.2](https://github.com/agentclientprotocol/rust-sdk/compare/v0.10.1...v0.10.2) - 2026-03-11 ### Added -- _(unstable)_ Add support for session/close methods ([#77](https://github.com/agentclientprotocol/rust-sdk/pull/77)) +- *(unstable)* Add support for session/close methods ([#77](https://github.com/agentclientprotocol/rust-sdk/pull/77)) ## [0.10.1](https://github.com/agentclientprotocol/rust-sdk/compare/v0.10.0...v0.10.1) - 2026-03-10 @@ -171,15 +171,15 @@ ### Added -- _(unstable)_ Add initial support for session config options ([#36](https://github.com/agentclientprotocol/rust-sdk/pull/36)) +- *(unstable)* Add initial support for session config options ([#36](https://github.com/agentclientprotocol/rust-sdk/pull/36)) ## [0.9.1](https://github.com/agentclientprotocol/rust-sdk/compare/v0.9.0...v0.9.1) - 2025-12-17 ### Added -- _(unstable)_ Add initial support for resuming sessions ([#34](https://github.com/agentclientprotocol/rust-sdk/pull/34)) -- _(unstable)_ Add initial support for forking sessions ([#33](https://github.com/agentclientprotocol/rust-sdk/pull/33)) -- _(unstable)_ Add initial support for listing sessions ([#31](https://github.com/agentclientprotocol/rust-sdk/pull/31)) +- *(unstable)* Add initial support for resuming sessions ([#34](https://github.com/agentclientprotocol/rust-sdk/pull/34)) +- *(unstable)* Add initial support for forking sessions ([#33](https://github.com/agentclientprotocol/rust-sdk/pull/33)) +- *(unstable)* Add initial support for listing sessions ([#31](https://github.com/agentclientprotocol/rust-sdk/pull/31)) ### Other From 5df098d033fa3ef21a81088fc4127a21b8713c38 Mon Sep 17 00:00:00 2001 From: Ben Brandt Date: Wed, 10 Jun 2026 12:35:13 +0200 Subject: [PATCH 09/18] docs(acp): expand request cancellation guidance and various fixes --- md/request-cancellation.md | 60 +++-- .../src/concepts/cancellation.rs | 136 ++++++++++ .../src/concepts/connections.rs | 10 + src/agent-client-protocol/src/concepts/mod.rs | 6 + src/agent-client-protocol/src/jsonrpc.rs | 117 ++++++--- src/agent-client-protocol/src/schema/mod.rs | 57 +++++ .../src/schema/protocol_level.rs | 38 +-- .../src/schema/v2_impls.rs | 31 +-- src/agent-client-protocol/src/session.rs | 9 +- .../tests/jsonrpc_request_cancellation.rs | 232 +++++++++++------- .../tests/protocol_v2.rs | 78 ++++++ 11 files changed, 559 insertions(+), 215 deletions(-) create mode 100644 src/agent-client-protocol/src/concepts/cancellation.rs diff --git a/md/request-cancellation.md b/md/request-cancellation.md index 4e8dc64..0d58f73 100644 --- a/md/request-cancellation.md +++ b/md/request-cancellation.md @@ -11,6 +11,16 @@ Enable the feature when depending on the crate: agent-client-protocol = { version = "...", features = ["unstable_cancel_request"] } ``` +Cancellation is cooperative. A peer may ignore `$/cancel_request`, may finish +with normal data, or may respond to the original request with +`Error::request_cancelled()` (`-32800`). The requesting side always receives a +response to the original request; cancellation only changes _which_ response +that is. The SDK ignores unhandled `$/...` notifications (even when the +feature is disabled) so unsupported protocol-level notifications do not +produce method-not-found errors. + +## Cancelling outgoing requests + To cancel a request sent through `ConnectionTo::send_request`, keep the returned `SentRequest` and call `cancel` on it: @@ -20,6 +30,11 @@ returned `SentRequest` and call `cancel` on it: # async fn example(cx: ConnectionTo) -> Result<(), Error> { let request = cx.send_request(MyRequest {}); request.cancel()?; + +// The peer still responds to the request: with normal data if it raced +// ahead, or with the standard cancellation error. +let result = request.block_task().await; +# let _ = result; # Ok(()) # } ``` @@ -44,6 +59,8 @@ cx.send_cancel_request("request-id".to_string())?; # } ``` +## Handling cancellation of incoming requests + For incoming requests, get the request-local cancellation marker from the `Responder`. This keeps cancellation handling next to the request work it controls: @@ -59,14 +76,16 @@ cx.spawn(async move { let response = cancellation.run_until_cancelled(run_request(request)).await; responder.respond_with_result(response) })?; -Ok(()) +# Ok(()) # } ``` `run_until_cancelled` is the simple path for handlers that should stop work and -reply with the standard cancellation error as soon as cancellation is requested. -If the handler needs cleanup, partial results, or custom cancellation behavior, -use `cancellation.cancelled()` or `cancellation.is_cancelled()` directly inside +reply with the standard cancellation error as soon as cancellation is +requested; it drops the work future when cancellation wins, so cleanup must +happen in `Drop` implementations and partial results are lost. If the handler +needs cleanup, partial results, or custom cancellation behavior, use +`cancellation.cancelled()` or `cancellation.is_cancelled()` directly inside the request work instead. Cancellation markers are only updated when the connection can process the @@ -74,9 +93,14 @@ incoming `$/cancel_request` notification. Long-running handlers should return quickly and move work into `ConnectionTo::spawn`, `SentRequest` callbacks, or another task. +## Proxies + When proxying with `SentRequest::forward_response_to`, the SDK observes the upstream `Responder` cancellation marker and forwards cancellation to the -downstream request automatically. +downstream request automatically. The downstream response (normal data or a +cancellation error) is still forwarded back upstream. + +## Low-level access Register `CancelRequestNotification` or `ProtocolLevelNotification` directly only when you need low-level access to cancellation notifications, such as @@ -86,21 +110,15 @@ custom routing or protocol tracing: # use agent_client_protocol::{ConnectionTo, Error, UntypedRole}; use agent_client_protocol::schema::CancelRequestNotification; -# fn builder() -> agent_client_protocol::Builder { -UntypedRole.builder() - .on_receive_notification( - async |cancel: CancelRequestNotification, _cx: ConnectionTo| { - let request_id = cancel.request_id; - // Mark the matching in-flight operation cancelled. - Ok(()) - }, - agent_client_protocol::on_receive_notification!(), - ) +# fn example() { +let builder = UntypedRole.builder().on_receive_notification( + async |cancel: CancelRequestNotification, _cx: ConnectionTo| { + // Mark the matching in-flight operation cancelled. + let _request_id = cancel.request_id; + Ok(()) + }, + agent_client_protocol::on_receive_notification!(), +); +# let _ = builder; # } ``` - -Cancellation is cooperative. A peer may ignore `$/cancel_request`, may finish -with normal data, or may respond to the original request with -`Error::request_cancelled()` (`-32800`). The SDK ignores unhandled `$/...` -notifications so unsupported protocol-level notifications do not produce -method-not-found errors. diff --git a/src/agent-client-protocol/src/concepts/cancellation.rs b/src/agent-client-protocol/src/concepts/cancellation.rs new file mode 100644 index 0000000..1de7c0a --- /dev/null +++ b/src/agent-client-protocol/src/concepts/cancellation.rs @@ -0,0 +1,136 @@ +//! Request cancellation with `$/cancel_request`. +//! +//! The SDK exposes the ACP `$/cancel_request` notification behind the +//! `unstable_cancel_request` feature. The notification is protocol-level: +//! either side may send it to ask the peer to cancel one outstanding JSON-RPC +//! request by ID. +//! +//! Cancellation is **cooperative**. A peer may ignore `$/cancel_request`, may +//! finish with normal data, or may respond to the original request with +//! [`Error::request_cancelled`] (`-32800`). The requesting side always +//! receives a response to the original request; cancellation only changes +//! *which* response that is. Unhandled `$/`-prefixed notifications are ignored +//! by the SDK (even without this feature), so peers that do not support +//! cancellation simply will not act on it. +//! +//! # Cancelling outgoing requests +//! +//! To cancel a request sent through [`ConnectionTo::send_request`], keep the +//! returned [`SentRequest`] and call [`cancel`][`SentRequest::cancel`] on it: +//! +//! ``` +//! # use agent_client_protocol::{ConnectionTo, Error, UntypedRole}; +//! # use agent_client_protocol_test::MyRequest; +//! # async fn example(cx: ConnectionTo) -> Result<(), Error> { +//! let request = cx.send_request(MyRequest {}); +//! request.cancel()?; +//! +//! // The peer still responds to the request: with normal data if it raced +//! // ahead, or with the standard cancellation error. +//! let result = request.block_task().await; +//! # let _ = result; +//! # Ok(()) +//! # } +//! ``` +//! +//! The [`SentRequest`] remembers the peer and any proxy wrapping used for the +//! original request, so this also works for requests sent through +//! [`ConnectionTo::send_request_to`]. +//! +//! Dropping a [`SentRequest`] before the SDK receives a response also sends +//! `$/cancel_request`. This covers abandoned request handles and futures. Once +//! the SDK routes a response to the waiting request handle, automatic +//! cancellation is disarmed, even if caller code has not yet consumed it with +//! [`block_task`], [`on_receiving_result`], or [`forward_response_to`]. +//! +//! If you already have the JSON-RPC request ID, send the notification +//! directly: +//! +//! ``` +//! # use agent_client_protocol::{ConnectionTo, Error, UntypedRole}; +//! # async fn example(cx: ConnectionTo) -> Result<(), Error> { +//! cx.send_cancel_request("request-id".to_string())?; +//! # Ok(()) +//! # } +//! ``` +//! +//! # Handling cancellation of incoming requests +//! +//! For incoming requests, get the request-local cancellation marker from the +//! [`Responder`]. This keeps cancellation handling next to the request work it +//! controls: +//! +//! ``` +//! # use agent_client_protocol::{ConnectionTo, Error, Responder, UntypedRole}; +//! # use agent_client_protocol_test::{MyRequest, MyResponse}; +//! # async fn example(request: MyRequest, responder: Responder, cx: ConnectionTo) -> Result<(), Error> { +//! # async fn run_request(_request: MyRequest) -> Result { todo!() } +//! let cancellation = responder.cancellation(); +//! +//! cx.spawn(async move { +//! let response = cancellation.run_until_cancelled(run_request(request)).await; +//! responder.respond_with_result(response) +//! })?; +//! # Ok(()) +//! # } +//! ``` +//! +//! [`run_until_cancelled`] is the simple path for handlers that should stop +//! work and reply with the standard cancellation error as soon as cancellation +//! is requested; it drops the work future when cancellation wins. If the +//! handler needs cleanup, partial results, or custom cancellation behavior, +//! use [`cancelled`][`RequestCancellation::cancelled`] or +//! [`is_cancelled`][`RequestCancellation::is_cancelled`] directly inside the +//! request work instead. +//! +//! Cancellation markers are only updated when the connection can process the +//! incoming `$/cancel_request` notification. Long-running handlers should +//! return quickly and move work into [`ConnectionTo::spawn`], [`SentRequest`] +//! callbacks, or another task; see the [ordering](super::ordering) chapter. +//! +//! # Proxies +//! +//! When proxying with [`forward_response_to`], the SDK observes the upstream +//! [`Responder`] cancellation marker and forwards cancellation to the +//! downstream request automatically. The downstream response (normal data or a +//! cancellation error) is still forwarded back upstream. +//! +//! # Low-level access +//! +//! Register [`CancelRequestNotification`] (or [`ProtocolLevelNotification`]) +//! directly only when you need low-level access to cancellation notifications, +//! such as custom routing or protocol tracing: +//! +//! ``` +//! # use agent_client_protocol::{ConnectionTo, Error, UntypedRole}; +//! use agent_client_protocol::schema::CancelRequestNotification; +//! +//! # fn example() { +//! let builder = UntypedRole.builder().on_receive_notification( +//! async |cancel: CancelRequestNotification, _cx: ConnectionTo| { +//! // Mark the matching in-flight operation cancelled. +//! let _request_id = cancel.request_id; +//! Ok(()) +//! }, +//! agent_client_protocol::on_receive_notification!(), +//! ); +//! # let _ = builder; +//! # } +//! ``` +//! +//! [`block_task`]: crate::SentRequest::block_task +//! [`on_receiving_result`]: crate::SentRequest::on_receiving_result +//! [`forward_response_to`]: crate::SentRequest::forward_response_to +//! [`run_until_cancelled`]: crate::RequestCancellation::run_until_cancelled +//! [`RequestCancellation`]: crate::RequestCancellation +//! [`RequestCancellation::cancelled`]: crate::RequestCancellation::cancelled +//! [`RequestCancellation::is_cancelled`]: crate::RequestCancellation::is_cancelled +//! [`ConnectionTo::send_request`]: crate::ConnectionTo::send_request +//! [`ConnectionTo::send_request_to`]: crate::ConnectionTo::send_request_to +//! [`ConnectionTo::spawn`]: crate::ConnectionTo::spawn +//! [`SentRequest`]: crate::SentRequest +//! [`SentRequest::cancel`]: crate::SentRequest::cancel +//! [`Responder`]: crate::Responder +//! [`Error::request_cancelled`]: crate::Error::request_cancelled +//! [`CancelRequestNotification`]: crate::schema::CancelRequestNotification +//! [`ProtocolLevelNotification`]: crate::schema::ProtocolLevelNotification diff --git a/src/agent-client-protocol/src/concepts/connections.rs b/src/agent-client-protocol/src/concepts/connections.rs index 8c05fe7..305649a 100644 --- a/src/agent-client-protocol/src/concepts/connections.rs +++ b/src/agent-client-protocol/src/concepts/connections.rs @@ -108,6 +108,16 @@ //! //! See [Ordering](super::ordering) for important details about how these differ. //! +//! ## Dropping a `SentRequest` +//! +//! By default, dropping a [`SentRequest`] without consuming it simply discards +//! the response when it arrives. When the `unstable_cancel_request` feature is +//! enabled, dropping an unconsumed [`SentRequest`] additionally sends a +//! `$/cancel_request` notification asking the peer to cancel the request, so +//! fire-and-forget requests should consume their handle (for example with +//! `on_receiving_result`). See the request cancellation chapter +//! (`concepts::cancellation`, feature-gated) for details. +//! //! # Next Steps //! //! - [Sessions](super::sessions) - Create multi-turn conversations diff --git a/src/agent-client-protocol/src/concepts/mod.rs b/src/agent-client-protocol/src/concepts/mod.rs index 01b77ef..1aee662 100644 --- a/src/agent-client-protocol/src/concepts/mod.rs +++ b/src/agent-client-protocol/src/concepts/mod.rs @@ -29,9 +29,15 @@ //! //! 8. [Error Handling][`crate::concepts::error_handling`] - Protocol errors vs //! connection errors, and how to handle them. +//! +//! When the `unstable_cancel_request` feature is enabled, there is also a +//! chapter on request cancellation (`crate::concepts::cancellation`). pub mod acp_basics; pub mod callbacks; +#[cfg(feature = "unstable_cancel_request")] +#[cfg_attr(docsrs, doc(cfg(feature = "unstable_cancel_request")))] +pub mod cancellation; pub mod connections; pub mod error_handling; pub mod ordering; diff --git a/src/agent-client-protocol/src/jsonrpc.rs b/src/agent-client-protocol/src/jsonrpc.rs index a8e8ca4..4bd7ecd 100644 --- a/src/agent-client-protocol/src/jsonrpc.rs +++ b/src/agent-client-protocol/src/jsonrpc.rs @@ -1422,6 +1422,12 @@ impl RequestCancellation { /// that want to respond with the normal result or the standard /// cancellation error. /// + /// When cancellation wins, `future` is dropped: work stops at its next + /// await point, partial results are lost, and any cleanup must happen in + /// `Drop` implementations. Handlers that need to flush partial results or + /// run async cleanup should instead watch [`cancelled`](Self::cancelled) + /// or poll [`is_cancelled`](Self::is_cancelled) from inside the work. + /// /// [`Error::request_cancelled`]: crate::Error::request_cancelled pub async fn run_until_cancelled( &self, @@ -1612,6 +1618,16 @@ fn cancellation_request_id_from_message( Ok(None) } +/// Whether the dispatch is a protocol-level (`$/`-prefixed) notification, +/// possibly wrapped in a [`SuccessorMessage`] envelope. +/// +/// Unhandled protocol-level notifications are ignored rather than rejected +/// with a method-not-found error. This is deliberately *not* feature-gated: +/// protocol-level notifications are optional by design, so a peer that sends +/// `$/cancel_request` must be able to interoperate with an SDK built without +/// `unstable_cancel_request` (which simply won't act on it). +/// +/// [`SuccessorMessage`]: crate::schema::SuccessorMessage fn is_protocol_level_notification(dispatch: &Dispatch) -> bool { let Dispatch::Notification(message) = dispatch else { return false; @@ -3171,6 +3187,15 @@ impl JsonRpcNotification for UntypedMessage {} /// If you block the event loop while waiting for a response, the connection cannot process /// the incoming response message, creating a deadlock. This API design prevents that footgun /// by making blocking explicit and encouraging non-blocking patterns. +/// +/// # Drop Behavior +/// +/// By default, dropping a `SentRequest` without consuming it discards the +/// response when it arrives. When the `unstable_cancel_request` feature is +/// enabled, dropping a `SentRequest` before the SDK has received the response +/// additionally sends a `$/cancel_request` notification asking the peer to +/// cancel the request; fire-and-forget requests should consume their handle +/// (for example with [`on_receiving_result`](Self::on_receiving_result)). pub struct SentRequest { id: jsonrpcmsg::Id, method: String, @@ -3451,27 +3476,17 @@ impl SentRequest { /// - The response types match between the outgoing request and incoming request /// /// This is equivalent to calling `on_receiving_result` and manually forwarding - /// the result, but more concise. - pub fn forward_response_to(self, responder: Responder) -> Result<(), crate::Error> - where - T: Send, - { - #[cfg(feature = "unstable_cancel_request")] - { - self.forward_response_to_observing_cancellation(responder) - } - #[cfg(not(feature = "unstable_cancel_request"))] - { - self.on_receiving_result(async move |result| responder.respond_with_result(result)) - } - } - - #[cfg(feature = "unstable_cancel_request")] + /// the result, with two proxy-specific additions: + /// + /// - If the pending response is dropped without ever being delivered (for + /// example, the downstream connection closed), the incoming request is + /// answered with an internal error instead of being left unanswered. + /// - When the `unstable_cancel_request` feature is enabled and the peer + /// cancels the incoming request, the cancellation is forwarded to the + /// outgoing request, and the downstream response (normal data or a + /// cancellation error) is still forwarded back. #[track_caller] - fn forward_response_to_observing_cancellation( - self, - responder: Responder, - ) -> Result<(), crate::Error> + pub fn forward_response_to(self, responder: Responder) -> Result<(), crate::Error> where T: Send, { @@ -3479,30 +3494,60 @@ impl SentRequest { let method = self.method; let response_rx = self.response_rx; let to_result = self.to_result; + #[cfg(feature = "unstable_cancel_request")] let downstream_cancellation = self.cancellation; + #[cfg(feature = "unstable_cancel_request")] let upstream_cancellation = responder.cancellation(); let location = Location::caller(); Task::new(location, async move { - let response = if upstream_cancellation.is_cancelled() { - downstream_cancellation.send()?; - response_rx.await - } else { - match future::select(Box::pin(upstream_cancellation.cancelled()), response_rx).await - { - Either::Left(((), response_rx)) => { - downstream_cancellation.send()?; - response_rx.await + #[cfg(feature = "unstable_cancel_request")] + let response = { + // Failing to forward the cancellation must not abort this + // task: the downstream response (normal data or a + // cancellation error) may still arrive and must still be + // forwarded upstream. + let forward_cancellation = |cancellation: &SentRequestCancellation| { + if let Err(error) = cancellation.send() { + tracing::debug!( + ?error, + "failed to forward cancellation to downstream request" + ); } - Either::Right((response, _)) => response, - } - }; + }; - downstream_cancellation.disarm(); + let response = if upstream_cancellation.is_cancelled() { + forward_cancellation(&downstream_cancellation); + response_rx.await + } else { + match future::select(Box::pin(upstream_cancellation.cancelled()), response_rx) + .await + { + Either::Left(((), response_rx)) => { + forward_cancellation(&downstream_cancellation); + response_rx.await + } + Either::Right((response, _)) => response, + } + }; - let ResponsePayload { result, ack_tx } = response.map_err(|err| { - crate::util::internal_error(format!("response to `{method}` never received: {err}")) - })?; + downstream_cancellation.disarm(); + response + }; + #[cfg(not(feature = "unstable_cancel_request"))] + let response = response_rx.await; + + let ResponsePayload { result, ack_tx } = match response { + Ok(payload) => payload, + Err(err) => { + // The pending response was dropped (e.g. the downstream + // connection closed). Answer the incoming request instead + // of leaving the peer waiting forever. + return responder.respond_with_result(Err(crate::util::internal_error( + format!("response to `{method}` never received: {err}"), + ))); + } + }; let typed_result = match result { Ok(json_value) => to_result(json_value), diff --git a/src/agent-client-protocol/src/schema/mod.rs b/src/agent-client-protocol/src/schema/mod.rs index efc22a5..b9131e1 100644 --- a/src/agent-client-protocol/src/schema/mod.rs +++ b/src/agent-client-protocol/src/schema/mod.rs @@ -217,6 +217,63 @@ macro_rules! impl_jsonrpc_notification_enum { }; } +/// Implement `JsonRpcMessage` and `JsonRpcNotification` for a protocol-level +/// notification enum (`$/`-prefixed methods), shared between the v1 and v2 +/// schema namespaces. +/// +/// The schema enums are `#[non_exhaustive]`, so the matches need wildcard +/// arms: when the schema crate adds a protocol-level notification, list it +/// here as well. Unknown variants fail to serialize rather than producing a +/// bogus method name on the wire. +/// +/// ```ignore +/// impl_jsonrpc_protocol_level_notification_enum!(ProtocolLevelNotification { +/// CancelRequestNotification => "$/cancel_request", +/// }); +/// ``` +#[cfg(feature = "unstable_cancel_request")] +macro_rules! impl_jsonrpc_protocol_level_notification_enum { + ($enum:ty { + $( $variant:ident => $method:literal, )* + }) => { + impl $crate::JsonRpcMessage for $enum { + fn matches_method(method: &str) -> bool { + matches!(method, $( $method )|*) + } + + fn method(&self) -> &str { + match self { + $( Self::$variant(_) => $method, )* + _ => "_unknown", + } + } + + fn to_untyped_message(&self) -> Result<$crate::UntypedMessage, $crate::Error> { + match self { + $( Self::$variant(notification) => { + $crate::UntypedMessage::new($method, notification) + } )* + _ => Err($crate::util::internal_error( + "protocol-level notification variant is not supported by this SDK version", + )), + } + } + + fn parse_message( + method: &str, + params: &impl serde::Serialize, + ) -> Result { + match method { + $( $method => $crate::util::json_cast_params(params).map(Self::$variant), )* + _ => Err($crate::Error::method_not_found()), + } + } + } + + impl $crate::JsonRpcNotification for $enum {} + }; +} + /// Implement `JsonRpcResponse` for an enum that dispatches across multiple /// response types, with an extension method fallback. macro_rules! impl_jsonrpc_response_enum { diff --git a/src/agent-client-protocol/src/schema/protocol_level.rs b/src/agent-client-protocol/src/schema/protocol_level.rs index 81ea0da..61aebd9 100644 --- a/src/agent-client-protocol/src/schema/protocol_level.rs +++ b/src/agent-client-protocol/src/schema/protocol_level.rs @@ -1,38 +1,12 @@ -#[cfg(feature = "unstable_cancel_request")] -use crate::{ - JsonRpcMessage, JsonRpcNotification, UntypedMessage, - schema::{CancelRequestNotification, ProtocolLevelNotification}, -}; +//! JSON-RPC trait implementations for protocol-level (`$/`-prefixed) messages. #[cfg(feature = "unstable_cancel_request")] -impl_jsonrpc_notification!(CancelRequestNotification, "$/cancel_request"); +use crate::schema::{CancelRequestNotification, ProtocolLevelNotification}; #[cfg(feature = "unstable_cancel_request")] -impl JsonRpcMessage for ProtocolLevelNotification { - fn matches_method(method: &str) -> bool { - method == "$/cancel_request" - } - - fn method(&self) -> &str { - match self { - Self::CancelRequestNotification(_) => "$/cancel_request", - _ => "_unknown", - } - } - - fn to_untyped_message(&self) -> Result { - UntypedMessage::new(self.method(), self) - } - - fn parse_message(method: &str, params: &impl serde::Serialize) -> Result { - match method { - "$/cancel_request" => { - crate::util::json_cast_params(params).map(Self::CancelRequestNotification) - } - _ => Err(crate::Error::method_not_found()), - } - } -} +impl_jsonrpc_notification!(CancelRequestNotification, "$/cancel_request"); #[cfg(feature = "unstable_cancel_request")] -impl JsonRpcNotification for ProtocolLevelNotification {} +impl_jsonrpc_protocol_level_notification_enum!(ProtocolLevelNotification { + CancelRequestNotification => "$/cancel_request", +}); diff --git a/src/agent-client-protocol/src/schema/v2_impls.rs b/src/agent-client-protocol/src/schema/v2_impls.rs index 7f9d59d..d87ad71 100644 --- a/src/agent-client-protocol/src/schema/v2_impls.rs +++ b/src/agent-client-protocol/src/schema/v2_impls.rs @@ -288,34 +288,9 @@ impl_v2_jsonrpc_notification!(v2::SessionNotification, "session/update"); impl_v2_jsonrpc_notification!(v2::CompleteElicitationNotification, "elicitation/complete"); #[cfg(feature = "unstable_cancel_request")] -impl JsonRpcMessage for v2::ProtocolLevelNotification { - fn matches_method(method: &str) -> bool { - method == "$/cancel_request" - } - - fn method(&self) -> &str { - match self { - Self::CancelRequestNotification(_) => "$/cancel_request", - _ => "_unknown", - } - } - - fn to_untyped_message(&self) -> Result { - UntypedMessage::new(self.method(), self) - } - - fn parse_message(method: &str, params: &impl serde::Serialize) -> Result { - match method { - "$/cancel_request" => { - crate::util::json_cast_params(params).map(Self::CancelRequestNotification) - } - _ => Err(crate::Error::method_not_found()), - } - } -} - -#[cfg(feature = "unstable_cancel_request")] -impl JsonRpcNotification for v2::ProtocolLevelNotification {} +impl_jsonrpc_protocol_level_notification_enum!(v2::ProtocolLevelNotification { + CancelRequestNotification => "$/cancel_request", +}); impl_v2_jsonrpc_request_enum!(v2::ClientRequest { InitializeRequest => "initialize", diff --git a/src/agent-client-protocol/src/session.rs b/src/agent-client-protocol/src/session.rs index 087dab4..c12f1b5 100644 --- a/src/agent-client-protocol/src/session.rs +++ b/src/agent-client-protocol/src/session.rs @@ -514,9 +514,12 @@ where /// Incoming message from the agent #[non_exhaustive] #[derive(Debug)] -#[expect( - clippy::large_enum_variant, - reason = "Dispatch messages vastly outnumber StopReason; boxing would add a heap allocation" +#[cfg_attr( + feature = "unstable_cancel_request", + expect( + clippy::large_enum_variant, + reason = "Dispatch messages vastly outnumber StopReason; boxing would add a heap allocation" + ) )] pub enum SessionMessage { /// Periodic updates with new content, tool requests, etc. diff --git a/src/agent-client-protocol/tests/jsonrpc_request_cancellation.rs b/src/agent-client-protocol/tests/jsonrpc_request_cancellation.rs index 4a11938..d3ebb6d 100644 --- a/src/agent-client-protocol/tests/jsonrpc_request_cancellation.rs +++ b/src/agent-client-protocol/tests/jsonrpc_request_cancellation.rs @@ -1,5 +1,17 @@ #![cfg(feature = "unstable_cancel_request")] +//! Integration tests for `$/cancel_request` support. +//! +//! These tests avoid sleeps by relying on two ordering guarantees: +//! +//! - Messages are delivered in the order they were sent, and each side's +//! dispatch loop processes incoming messages sequentially. A request/response +//! round trip therefore acts as a barrier: by the time the response arrives, +//! every message sent before the request (including any `$/cancel_request`) +//! has been fully processed by the peer. +//! - Test handlers report observed cancellations through in-process channels, +//! which the test awaits (with a timeout) instead of sleeping. + use std::sync::{Arc, Mutex}; use agent_client_protocol::{ @@ -9,7 +21,8 @@ use agent_client_protocol::{ schema::{CancelRequestNotification, ProtocolLevelNotification, RequestId}, }; use expect_test::expect; -use futures::{AsyncRead, AsyncWrite}; +use futures::channel::mpsc; +use futures::{AsyncRead, AsyncWrite, StreamExt as _}; use serde::{Deserialize, Serialize}; use tokio_util::compat::{TokioAsyncReadCompatExt, TokioAsyncWriteCompatExt}; @@ -30,6 +43,26 @@ fn setup_test_streams() -> ( (server_reader, server_writer, client_reader, client_writer) } +/// Await the next item on `rx`, panicking instead of hanging if it never +/// arrives. +async fn next_with_timeout(rx: &mut mpsc::UnboundedReceiver) -> T { + tokio::time::timeout(tokio::time::Duration::from_secs(10), rx.next()) + .await + .expect("timed out waiting for channel event") + .expect("channel closed before expected event") +} + +/// Assert that no item is currently buffered on `rx`. +/// +/// Callers must first establish an ordering barrier (such as a +/// request/response round trip) that guarantees any erroneously sent +/// notification would already have been observed. +fn assert_no_event(rx: &mut mpsc::UnboundedReceiver) { + if let Ok(event) = rx.try_recv() { + panic!("unexpected event: {event:?}"); + } +} + async fn read_jsonrpc_response_line( reader: &mut tokio::io::BufReader, ) -> serde_json::Value { @@ -37,7 +70,7 @@ async fn read_jsonrpc_response_line( let mut line = String::new(); match tokio::time::timeout( - tokio::time::Duration::from_secs(1), + tokio::time::Duration::from_secs(10), reader.read_line(&mut line), ) .await @@ -261,6 +294,9 @@ async fn unhandled_protocol_level_notifications_are_ignored() { .unwrap(); client_writer.flush().await.unwrap(); + // The server processes messages in order: a response to this + // request proves the unknown `$/` notification before it was + // ignored without erroring or closing the connection. client_writer .write_all( br#"{"jsonrpc":"2.0","id":2,"method":"simple_method","params":{"message":"after cancel"}} @@ -372,8 +408,7 @@ async fn cancel_request_notification_can_be_sent_and_handled() { local .run_until(async { - let received = Arc::new(Mutex::new(Vec::new())); - let received_for_handler = received.clone(); + let (cancel_tx, mut cancel_rx) = mpsc::unbounded(); let (server_reader, server_writer, client_reader, client_writer) = setup_test_streams(); let server_transport = @@ -381,10 +416,7 @@ async fn cancel_request_notification_can_be_sent_and_handled() { let server = UntypedRole.builder().on_receive_notification( async move |notification: CancelRequestNotification, _connection: ConnectionTo| { - received_for_handler - .lock() - .unwrap() - .push(notification.request_id); + cancel_tx.unbounded_send(notification.request_id).unwrap(); Ok(()) }, agent_client_protocol::on_receive_notification!(), @@ -398,20 +430,16 @@ async fn cancel_request_notification_can_be_sent_and_handled() { let client_transport = agent_client_protocol::ByteStreams::new(client_writer, client_reader); - UntypedRole + let received = UntypedRole .builder() .connect_with(client_transport, async |cx| { cx.send_cancel_request("request-42".to_string())?; - tokio::time::sleep(tokio::time::Duration::from_millis(50)).await; - Ok(()) + Ok(next_with_timeout(&mut cancel_rx).await) }) .await .unwrap(); - assert_eq!( - *received.lock().unwrap(), - vec![RequestId::Str("request-42".into())] - ); + assert_eq!(received, RequestId::Str("request-42".into())); }) .await; } @@ -424,8 +452,7 @@ async fn sent_request_can_send_cancellation_for_its_id() { local .run_until(async { - let received = Arc::new(Mutex::new(Vec::new())); - let received_for_handler = received.clone(); + let (cancel_tx, mut cancel_rx) = mpsc::unbounded(); let (server_reader, server_writer, client_reader, client_writer) = setup_test_streams(); let server_transport = @@ -441,10 +468,7 @@ async fn sent_request_can_send_cancellation_for_its_id() { .on_receive_notification( async move |notification: CancelRequestNotification, _connection: ConnectionTo| { - received_for_handler - .lock() - .unwrap() - .push(notification.request_id); + cancel_tx.unbounded_send(notification.request_id).unwrap(); Ok(()) }, agent_client_protocol::on_receive_notification!(), @@ -458,7 +482,7 @@ async fn sent_request_can_send_cancellation_for_its_id() { let client_transport = agent_client_protocol::ByteStreams::new(client_writer, client_reader); - let expected_id = UntypedRole + let (expected_id, received) = UntypedRole .builder() .connect_with(client_transport, async |cx| { let request: SentRequest = cx.send_request(SimpleRequest { @@ -466,15 +490,14 @@ async fn sent_request_can_send_cancellation_for_its_id() { }); let expected_id = request.id(); request.cancel()?; - tokio::time::sleep(tokio::time::Duration::from_millis(50)).await; - Ok(expected_id) + let received = next_with_timeout(&mut cancel_rx).await; + Ok((expected_id, received)) }) .await .unwrap(); - let received = received.lock().unwrap(); - assert_eq!(received.len(), 1); - assert_eq!(serde_json::to_value(&received[0]).unwrap(), expected_id); + assert_eq!(serde_json::to_value(received).unwrap(), expected_id); + assert_no_event(&mut cancel_rx); }) .await; } @@ -487,8 +510,7 @@ async fn dropped_sent_request_sends_cancellation_for_its_id() { local .run_until(async { - let received = Arc::new(Mutex::new(Vec::new())); - let received_for_handler = received.clone(); + let (cancel_tx, mut cancel_rx) = mpsc::unbounded(); let (server_reader, server_writer, client_reader, client_writer) = setup_test_streams(); let server_transport = @@ -504,10 +526,7 @@ async fn dropped_sent_request_sends_cancellation_for_its_id() { .on_receive_notification( async move |notification: CancelRequestNotification, _connection: ConnectionTo| { - received_for_handler - .lock() - .unwrap() - .push(notification.request_id); + cancel_tx.unbounded_send(notification.request_id).unwrap(); Ok(()) }, agent_client_protocol::on_receive_notification!(), @@ -521,7 +540,7 @@ async fn dropped_sent_request_sends_cancellation_for_its_id() { let client_transport = agent_client_protocol::ByteStreams::new(client_writer, client_reader); - let expected_id = UntypedRole + let (expected_id, received) = UntypedRole .builder() .connect_with(client_transport, async |cx| { let request: SentRequest = cx.send_request(SimpleRequest { @@ -529,15 +548,14 @@ async fn dropped_sent_request_sends_cancellation_for_its_id() { }); let expected_id = request.id(); drop(request); - tokio::time::sleep(tokio::time::Duration::from_millis(50)).await; - Ok(expected_id) + let received = next_with_timeout(&mut cancel_rx).await; + Ok((expected_id, received)) }) .await .unwrap(); - let received = received.lock().unwrap(); - assert_eq!(received.len(), 1); - assert_eq!(serde_json::to_value(&received[0]).unwrap(), expected_id); + assert_eq!(serde_json::to_value(received).unwrap(), expected_id); + assert_no_event(&mut cancel_rx); }) .await; } @@ -550,8 +568,11 @@ async fn late_response_after_dropped_sent_request_does_not_close_connection() { local .run_until(async { - let received = Arc::new(Mutex::new(Vec::new())); - let received_for_handler = received.clone(); + let (cancel_tx, mut cancel_rx) = mpsc::unbounded(); + // The responder for the abandoned request, held by the server + // until the cancellation notification arrives. + let pending_responder: Arc>>> = + Arc::new(Mutex::new(None)); let (server_reader, server_writer, client_reader, client_writer) = setup_test_streams(); let server_transport = @@ -559,33 +580,38 @@ async fn late_response_after_dropped_sent_request_does_not_close_connection() { let server = UntypedRole .builder() .on_receive_request( - async |request: SimpleRequest, - responder: Responder, - connection: ConnectionTo| { - if request.message == "late" { - connection.spawn(async move { - tokio::time::sleep(tokio::time::Duration::from_millis(10)).await; - responder.respond(SimpleResponse { - result: "late response".into(), - }) - })?; - return Ok(()); + { + let pending_responder = pending_responder.clone(); + async move |request: SimpleRequest, + responder: Responder, + _connection: ConnectionTo| { + if request.message == "late" { + *pending_responder.lock().unwrap() = Some(responder); + return Ok(()); + } + + responder.respond(SimpleResponse { + result: format!("echo: {}", request.message), + }) } - - responder.respond(SimpleResponse { - result: format!("echo: {}", request.message), - }) }, agent_client_protocol::on_receive_request!(), ) .on_receive_notification( - async move |notification: CancelRequestNotification, - _connection: ConnectionTo| { - received_for_handler - .lock() - .unwrap() - .push(notification.request_id); - Ok(()) + { + let pending_responder = pending_responder.clone(); + async move |notification: CancelRequestNotification, + _connection: ConnectionTo| { + // Ignore the cancellation and answer the abandoned + // request anyway: the client must tolerate this. + if let Some(responder) = pending_responder.lock().unwrap().take() { + responder.respond(SimpleResponse { + result: "late response".into(), + })?; + } + cancel_tx.unbounded_send(notification.request_id).unwrap(); + Ok(()) + } }, agent_client_protocol::on_receive_notification!(), ); @@ -598,7 +624,7 @@ async fn late_response_after_dropped_sent_request_does_not_close_connection() { let client_transport = agent_client_protocol::ByteStreams::new(client_writer, client_reader); - let (expected_id, response) = UntypedRole + let (expected_id, received, response) = UntypedRole .builder() .connect_with(client_transport, async |cx| { let request: SentRequest = cx.send_request(SimpleRequest { @@ -607,23 +633,25 @@ async fn late_response_after_dropped_sent_request_does_not_close_connection() { let expected_id = request.id(); drop(request); - tokio::time::sleep(tokio::time::Duration::from_millis(50)).await; + let received = next_with_timeout(&mut cancel_rx).await; + // The server sent the late response before answering this + // follow-up, so a successful round trip proves the late + // response for the dropped request was routed without + // closing the connection. let response = cx .send_request(SimpleRequest { message: "after late".into(), }) .block_task() .await?; - Ok((expected_id, response)) + Ok((expected_id, received, response)) }) .await .unwrap(); assert_eq!(response.result, "echo: after late"); - let received = received.lock().unwrap(); - assert_eq!(received.len(), 1); - assert_eq!(serde_json::to_value(&received[0]).unwrap(), expected_id); + assert_eq!(serde_json::to_value(received).unwrap(), expected_id); }) .await; } @@ -636,8 +664,7 @@ async fn response_buffered_before_drop_disarms_auto_cancellation() { local .run_until(async { - let received = Arc::new(Mutex::new(Vec::new())); - let received_for_handler = received.clone(); + let (cancel_tx, mut cancel_rx) = mpsc::unbounded(); let (server_reader, server_writer, client_reader, client_writer) = setup_test_streams(); let server_transport = @@ -657,10 +684,7 @@ async fn response_buffered_before_drop_disarms_auto_cancellation() { .on_receive_notification( async move |notification: CancelRequestNotification, _connection: ConnectionTo| { - received_for_handler - .lock() - .unwrap() - .push(notification.request_id); + cancel_tx.unbounded_send(notification.request_id).unwrap(); Ok(()) }, agent_client_protocol::on_receive_notification!(), @@ -681,10 +705,22 @@ async fn response_buffered_before_drop_disarms_auto_cancellation() { message: "buffered".into(), }); - tokio::time::sleep(tokio::time::Duration::from_millis(50)).await; + // The server answers requests in order, so once this round + // trip completes, the response to `buffered` has already + // been routed into the unconsumed request handle above, + // disarming its auto-cancellation. + let barrier = cx + .send_request(SimpleRequest { + message: "barrier".into(), + }) + .block_task() + .await?; + assert_eq!(barrier.result, "echo: barrier"); + drop(request); - tokio::time::sleep(tokio::time::Duration::from_millis(50)).await; + // Another round trip: any cancellation sent by the drop + // above would reach the server before this request. cx.send_request(SimpleRequest { message: "after buffered".into(), }) @@ -695,7 +731,7 @@ async fn response_buffered_before_drop_disarms_auto_cancellation() { .unwrap(); assert_eq!(response.result, "echo: after buffered"); - assert!(received.lock().unwrap().is_empty()); + assert_no_event(&mut cancel_rx); }) .await; } @@ -708,8 +744,7 @@ async fn completed_sent_request_does_not_send_cancellation_on_drop() { local .run_until(async { - let received = Arc::new(Mutex::new(Vec::new())); - let received_for_handler = received.clone(); + let (cancel_tx, mut cancel_rx) = mpsc::unbounded(); let (server_reader, server_writer, client_reader, client_writer) = setup_test_streams(); let server_transport = @@ -729,10 +764,7 @@ async fn completed_sent_request_does_not_send_cancellation_on_drop() { .on_receive_notification( async move |notification: CancelRequestNotification, _connection: ConnectionTo| { - received_for_handler - .lock() - .unwrap() - .push(notification.request_id); + cancel_tx.unbounded_send(notification.request_id).unwrap(); Ok(()) }, agent_client_protocol::on_receive_notification!(), @@ -755,14 +787,25 @@ async fn completed_sent_request_does_not_send_cancellation_on_drop() { }) .block_task() .await?; - tokio::time::sleep(tokio::time::Duration::from_millis(50)).await; + + // Barrier round trip: any cancellation erroneously sent + // when the completed request handle was dropped would + // reach the server before this request. + let barrier = cx + .send_request(SimpleRequest { + message: "barrier".into(), + }) + .block_task() + .await?; + assert_eq!(barrier.result, "echo: barrier"); + Ok(response) }) .await .unwrap(); assert_eq!(response.result, "echo: complete"); - assert!(received.lock().unwrap().is_empty()); + assert_no_event(&mut cancel_rx); }) .await; } @@ -775,8 +818,7 @@ async fn forward_response_to_propagates_cancellation_to_downstream_request() { local .run_until(async { - let backend_cancellations = Arc::new(Mutex::new(Vec::new())); - let backend_cancellations_for_handler = backend_cancellations.clone(); + let (backend_cancel_tx, mut backend_cancel_rx) = mpsc::unbounded(); let (backend_for_proxy, backend_for_server) = Channel::duplex(); let (backend_connection_tx, backend_connection_rx) = @@ -806,10 +848,9 @@ async fn forward_response_to_propagates_cancellation_to_downstream_request() { .on_receive_notification( async move |notification: CancelRequestNotification, _connection: ConnectionTo| { - backend_cancellations_for_handler - .lock() - .unwrap() - .push(notification.request_id); + backend_cancel_tx + .unbounded_send(notification.request_id) + .unwrap(); Ok(()) }, agent_client_protocol::on_receive_notification!(), @@ -859,14 +900,15 @@ async fn forward_response_to_propagates_cancellation_to_downstream_request() { message: "cancel downstream".into(), }); request.cancel()?; - tokio::time::sleep(tokio::time::Duration::from_millis(100)).await; + // Wait until the proxy has forwarded the cancellation all + // the way to the backend. + next_with_timeout(&mut backend_cancel_rx).await; Ok(()) }) .await .unwrap(); - let backend_cancellations = backend_cancellations.lock().unwrap(); - assert_eq!(backend_cancellations.len(), 1); + assert_no_event(&mut backend_cancel_rx); }) .await; } diff --git a/src/agent-client-protocol/tests/protocol_v2.rs b/src/agent-client-protocol/tests/protocol_v2.rs index d6692ec..01a0577 100644 --- a/src/agent-client-protocol/tests/protocol_v2.rs +++ b/src/agent-client-protocol/tests/protocol_v2.rs @@ -547,3 +547,81 @@ async fn v2_client_and_agent_negotiate_v2() -> Result<(), Error> { }) .await } + +/// A v2 agent whose `session/new` handler only responds once the peer cancels +/// the request via `$/cancel_request`. +#[cfg(feature = "unstable_cancel_request")] +fn v2_agent_with_cancellable_new_session() +-> Builder> { + Agent + .v2() + .on_receive_request( + async |initialize: v2::InitializeRequest, responder, _cx| { + responder.respond(v2::InitializeResponse::new(initialize.protocol_version)) + }, + agent_client_protocol::on_receive_request!(), + ) + .on_receive_request( + async |_request: v2::NewSessionRequest, responder, cx| { + let cancellation = responder.cancellation(); + cx.spawn(async move { + let response = cancellation + .run_until_cancelled(std::future::pending::< + Result, + >()) + .await; + responder.respond_with_result(response) + })?; + Ok(()) + }, + agent_client_protocol::on_receive_request!(), + ) +} + +#[cfg(feature = "unstable_cancel_request")] +#[tokio::test(flavor = "current_thread")] +async fn v2_client_can_cancel_request_to_v2_agent() -> Result<(), Error> { + Client + .v2() + .connect_with(v2_agent_with_cancellable_new_session(), async |cx| { + let initialize = cx + .send_request(v2::InitializeRequest::new(ProtocolVersion::V2)) + .block_task() + .await?; + assert_eq!(initialize.protocol_version, ProtocolVersion::V2); + + let request = cx.send_request(v2::NewSessionRequest::new(cwd()?)); + request.cancel()?; + let error = request + .block_task() + .await + .expect_err("request should be cancelled"); + assert_eq!(i32::from(error.code), -32800); + Ok(()) + }) + .await +} + +#[cfg(feature = "unstable_cancel_request")] +#[tokio::test(flavor = "current_thread")] +async fn v1_client_can_cancel_request_to_v2_agent() -> Result<(), Error> { + Client + .builder() + .connect_with(v2_agent_with_cancellable_new_session(), async |cx| { + let initialize = cx + .send_request(schema::InitializeRequest::new(ProtocolVersion::V1)) + .block_task() + .await?; + assert_eq!(initialize.protocol_version, ProtocolVersion::V1); + + let request = cx.send_request(schema::NewSessionRequest::new(cwd()?)); + request.cancel()?; + let error = request + .block_task() + .await + .expect_err("request should be cancelled"); + assert_eq!(i32::from(error.code), -32800); + Ok(()) + }) + .await +} From f02ea5665099a549d0127f264f47a6478df98722 Mon Sep 17 00:00:00 2001 From: Ben Brandt Date: Wed, 10 Jun 2026 12:53:27 +0200 Subject: [PATCH 10/18] fix(acp): Defer cancellation marker allocation --- src/agent-client-protocol/src/jsonrpc.rs | 96 ++++++-- .../src/jsonrpc/incoming_actor.rs | 14 +- .../tests/jsonrpc_request_cancellation.rs | 220 +++++++++++++++++- 3 files changed, 305 insertions(+), 25 deletions(-) diff --git a/src/agent-client-protocol/src/jsonrpc.rs b/src/agent-client-protocol/src/jsonrpc.rs index 4bd7ecd..e40946b 100644 --- a/src/agent-client-protocol/src/jsonrpc.rs +++ b/src/agent-client-protocol/src/jsonrpc.rs @@ -1476,10 +1476,27 @@ impl Debug for RequestCancellation { } } +/// Per-request cancellation state tracked by [`RequestCancellationRegistry`]. +/// +/// The full [`RequestCancellation`] marker (with its wakeup machinery) is only +/// allocated once a handler asks for it via [`Responder::cancellation`]; until +/// then an incoming `$/cancel_request` just flips the entry to `Cancelled`. +/// This keeps the per-request cost of the registry to a single map entry. +#[cfg(feature = "unstable_cancel_request")] +#[derive(Debug)] +enum RequestCancellationEntry { + /// The request is in flight; no marker handed out, no cancellation yet. + Armed, + /// `$/cancel_request` arrived before a marker was handed out. + Cancelled, + /// A marker was handed out via [`Responder::cancellation`]. + Marker(RequestCancellation), +} + #[cfg(feature = "unstable_cancel_request")] #[derive(Clone, Debug, Default)] struct RequestCancellationRegistry { - inner: Arc>>, + inner: Arc>>, } #[cfg(not(feature = "unstable_cancel_request"))] @@ -1491,7 +1508,6 @@ struct RequestCancellationRegistry; struct ResponderCancellation { id: serde_json::Value, registry: RequestCancellationRegistry, - cancellation: RequestCancellation, } #[cfg(not(feature = "unstable_cancel_request"))] @@ -1505,15 +1521,45 @@ impl RequestCancellationRegistry { } fn register(&self, id: serde_json::Value) -> ResponderCancellation { - let cancellation = RequestCancellation::new(); self.inner .lock() .expect("request cancellation registry mutex poisoned") - .insert(id.clone(), cancellation.clone()); + .insert(id.clone(), RequestCancellationEntry::Armed); ResponderCancellation { id, registry: self.clone(), - cancellation, + } + } + + /// Get the cancellation marker for a registered request, creating it on + /// first use. Repeated calls return markers that share the same state. + fn marker(&self, id: &serde_json::Value) -> RequestCancellation { + let mut inner = self + .inner + .lock() + .expect("request cancellation registry mutex poisoned"); + let Some(entry) = inner.get_mut(id) else { + // The entry lives as long as the responder that owns it, so this + // is only reachable if the peer reused a request ID and the + // earlier request's responder already removed the shared entry. + // Hand out a detached marker rather than panicking. + return RequestCancellation::new(); + }; + match entry { + RequestCancellationEntry::Marker(marker) => marker.clone(), + RequestCancellationEntry::Armed => { + let marker = RequestCancellation::new(); + *entry = RequestCancellationEntry::Marker(marker.clone()); + marker + } + RequestCancellationEntry::Cancelled => { + // No one can be waiting on a marker that did not exist yet, + // so firing it while holding the registry lock is fine. + let marker = RequestCancellation::new(); + marker.cancel(); + *entry = RequestCancellationEntry::Marker(marker.clone()); + marker + } } } @@ -1525,18 +1571,28 @@ impl RequestCancellationRegistry { } fn cancel(&self, request_id: &serde_json::Value) -> bool { - let cancellation = self - .inner - .lock() - .expect("request cancellation registry mutex poisoned") - .get(request_id) - .cloned(); - if let Some(cancellation) = cancellation { - cancellation.cancel(); - true - } else { - false - } + let marker = { + let mut inner = self + .inner + .lock() + .expect("request cancellation registry mutex poisoned"); + let Some(entry) = inner.get_mut(request_id) else { + return false; + }; + match entry { + RequestCancellationEntry::Marker(marker) => marker.clone(), + RequestCancellationEntry::Cancelled => return true, + RequestCancellationEntry::Armed => { + *entry = RequestCancellationEntry::Cancelled; + return true; + } + } + }; + + // Fire the marker outside the registry lock: waking waiters runs + // arbitrary waker code that must not observe the lock held. + marker.cancel(); + true } fn remove(&self, request_id: &serde_json::Value) { @@ -1574,7 +1630,7 @@ impl RequestCancellationRegistry { #[cfg(feature = "unstable_cancel_request")] impl ResponderCancellation { fn cancellation(&self) -> RequestCancellation { - self.cancellation.clone() + self.registry.marker(&self.id) } } @@ -1627,6 +1683,10 @@ fn cancellation_request_id_from_message( /// `$/cancel_request` must be able to interoperate with an SDK built without /// `unstable_cancel_request` (which simply won't act on it). /// +/// A handler that explicitly declines with `retry: true` takes precedence +/// over this fallback: the notification is queued for newly registered +/// dynamic handlers like any other retried message. +/// /// [`SuccessorMessage`]: crate::schema::SuccessorMessage fn is_protocol_level_notification(dispatch: &Dispatch) -> bool { let Dispatch::Notification(message) = dispatch else { diff --git a/src/agent-client-protocol/src/jsonrpc/incoming_actor.rs b/src/agent-client-protocol/src/jsonrpc/incoming_actor.rs index 1a5b80d..f95d2ee 100644 --- a/src/agent-client-protocol/src/jsonrpc/incoming_actor.rs +++ b/src/agent-client-protocol/src/jsonrpc/incoming_actor.rs @@ -394,13 +394,12 @@ async fn dispatch_dispatch( } } - if super::is_protocol_level_notification(&dispatch) { - tracing::debug!(?method, "Ignoring unhandled protocol-level notification"); - return Ok(()); - } - // If the message was never handled, check whether the retry flag was set. // If so, enqueue it for later processing. Else, reject it. + // + // An explicit retry request takes precedence over the protocol-level + // fallback below, so that handlers may defer `$/` notifications to a + // dynamic handler that has not been registered yet. if retry_any { tracing::debug!( ?method, @@ -408,6 +407,11 @@ async fn dispatch_dispatch( ); pending_messages.push(dispatch); Ok(()) + } else if super::is_protocol_level_notification(&dispatch) { + // Unsupported protocol-level notifications are ignored rather than + // rejected; see `is_protocol_level_notification` for the rationale. + tracing::debug!(?method, "Ignoring unhandled protocol-level notification"); + Ok(()) } else { match dispatch { Dispatch::Request(..) | Dispatch::Notification(_) => { diff --git a/src/agent-client-protocol/tests/jsonrpc_request_cancellation.rs b/src/agent-client-protocol/tests/jsonrpc_request_cancellation.rs index d3ebb6d..37e5a9f 100644 --- a/src/agent-client-protocol/tests/jsonrpc_request_cancellation.rs +++ b/src/agent-client-protocol/tests/jsonrpc_request_cancellation.rs @@ -15,8 +15,8 @@ use std::sync::{Arc, Mutex}; use agent_client_protocol::{ - Channel, ConnectionTo, Dispatch, Handled, JsonRpcMessage, JsonRpcRequest, JsonRpcResponse, - Responder, Role, RoleId, SentRequest, + Channel, ConnectionTo, Dispatch, HandleDispatchFrom, Handled, JsonRpcMessage, JsonRpcRequest, + JsonRpcResponse, Responder, Role, RoleId, SentRequest, role::UntypedRole, schema::{CancelRequestNotification, ProtocolLevelNotification, RequestId}, }; @@ -913,6 +913,222 @@ async fn forward_response_to_propagates_cancellation_to_downstream_request() { .await; } +#[tokio::test(flavor = "current_thread")] +async fn cancellation_marker_requested_after_cancel_is_already_cancelled() { + use tokio::task::LocalSet; + + let local = LocalSet::new(); + + local + .run_until(async { + // The responder is parked here by the request handler *without* + // requesting a cancellation marker; the marker is only created + // after the cancellation has already been recorded. + let pending_responder: Arc>>> = + Arc::new(Mutex::new(None)); + + let (server_reader, server_writer, client_reader, client_writer) = setup_test_streams(); + let server_transport = + agent_client_protocol::ByteStreams::new(server_writer, server_reader); + let server = UntypedRole + .builder() + .on_receive_request( + { + let pending_responder = pending_responder.clone(); + async move |_request: SimpleRequest, + responder: Responder, + _connection: ConnectionTo| { + *pending_responder.lock().unwrap() = Some(responder); + Ok(()) + } + }, + agent_client_protocol::on_receive_request!(), + ) + .on_receive_notification( + { + let pending_responder = pending_responder.clone(); + async move |_cancel: CancelRequestNotification, + _connection: ConnectionTo| { + // The registry recorded the cancellation before + // this handler ran, so markers created only now + // must already report it. + let responder = pending_responder + .lock() + .unwrap() + .take() + .expect("request should have arrived before its cancellation"); + let marker = responder.cancellation(); + let second_marker = responder.cancellation(); + if marker.is_cancelled() && second_marker.is_cancelled() { + responder.respond_with_result(Err( + agent_client_protocol::Error::request_cancelled(), + )) + } else { + responder.respond(SimpleResponse { + result: "marker not cancelled".into(), + }) + } + } + }, + agent_client_protocol::on_receive_notification!(), + ); + + tokio::task::spawn_local(async move { + if let Err(error) = server.connect_to(server_transport).await { + panic!("server should stay alive: {error:?}"); + } + }); + + let client_transport = + agent_client_protocol::ByteStreams::new(client_writer, client_reader); + let error = UntypedRole + .builder() + .connect_with(client_transport, async |cx| { + let request: SentRequest = cx.send_request(SimpleRequest { + message: "cancel before marker".into(), + }); + request.cancel()?; + Ok(request + .block_task() + .await + .expect_err("request should be cancelled")) + }) + .await + .unwrap(); + + assert_eq!(i32::from(error.code), -32800); + assert_eq!(error.message, "Request cancelled"); + }) + .await; +} + +/// A dynamic handler that claims `$/cancel_request` notifications and reports +/// them on a channel. +struct CancelCollector { + tx: mpsc::UnboundedSender, +} + +impl HandleDispatchFrom for CancelCollector { + async fn handle_dispatch_from( + &mut self, + message: Dispatch, + _connection: ConnectionTo, + ) -> Result, agent_client_protocol::Error> { + if let Dispatch::Notification(notification) = &message + && CancelRequestNotification::matches_method(¬ification.method) + { + let cancel = CancelRequestNotification::parse_message( + ¬ification.method, + ¬ification.params, + )?; + self.tx.unbounded_send(cancel.request_id).unwrap(); + return Ok(Handled::Yes); + } + + Ok(Handled::No { + message, + retry: false, + }) + } + + fn describe_chain(&self) -> impl std::fmt::Debug { + "CancelCollector" + } +} + +#[tokio::test(flavor = "current_thread")] +async fn retried_protocol_level_notification_reaches_later_dynamic_handler() { + use tokio::task::LocalSet; + + let local = LocalSet::new(); + + local + .run_until(async { + let (collector_tx, mut collector_rx) = mpsc::unbounded(); + + let (server_reader, server_writer, client_reader, client_writer) = setup_test_streams(); + let server_transport = + agent_client_protocol::ByteStreams::new(server_writer, server_reader); + let server = UntypedRole + .builder() + .on_receive_notification( + // Decline the notification but ask for a retry: this must + // take precedence over the "ignore unhandled `$/` + // notifications" fallback. + async |cancel: CancelRequestNotification, cx: ConnectionTo| { + Ok::<_, agent_client_protocol::Error>(Handled::No { + message: (cancel, cx), + retry: true, + }) + }, + agent_client_protocol::on_receive_notification!(), + ) + .on_receive_request( + { + let collector_tx = collector_tx.clone(); + async move |request: SimpleRequest, + responder: Responder, + connection: ConnectionTo| { + if request.message == "register" { + connection + .add_dynamic_handler(CancelCollector { + tx: collector_tx.clone(), + })? + .run_indefinitely(); + } + responder.respond(SimpleResponse { + result: format!("echo: {}", request.message), + }) + } + }, + agent_client_protocol::on_receive_request!(), + ); + + tokio::task::spawn_local(async move { + if let Err(error) = server.connect_to(server_transport).await { + panic!("server should stay alive: {error:?}"); + } + }); + + let client_transport = + agent_client_protocol::ByteStreams::new(client_writer, client_reader); + let received = UntypedRole + .builder() + .connect_with(client_transport, async |cx| { + cx.send_cancel_request("req-1".to_string())?; + + // Barrier: the notification has now been declined and + // queued for retry, and no dynamic handler has seen it. + let barrier = cx + .send_request(SimpleRequest { + message: "barrier".into(), + }) + .block_task() + .await?; + assert_eq!(barrier.result, "echo: barrier"); + assert_no_event(&mut collector_rx); + + // Registering the dynamic handler replays the queued + // notification to it. + let register = cx + .send_request(SimpleRequest { + message: "register".into(), + }) + .block_task() + .await?; + assert_eq!(register.result, "echo: register"); + + Ok(next_with_timeout(&mut collector_rx).await) + }) + .await + .unwrap(); + + assert_eq!(received, RequestId::Str("req-1".into())); + assert_no_event(&mut collector_rx); + }) + .await; +} + #[tokio::test(flavor = "current_thread")] async fn request_handler_can_observe_cancellation_from_responder() { use tokio::task::LocalSet; From 6b79fff877bd9bf89092b6f4aa30921a8aa4cde3 Mon Sep 17 00:00:00 2001 From: Ben Brandt Date: Wed, 10 Jun 2026 17:33:57 +0200 Subject: [PATCH 11/18] fix(acp): handle wrapped request cancellation correctly --- md/request-cancellation.md | 172 +++---- .../src/concepts/cancellation.rs | 5 + .../src/concepts/peers.rs | 5 +- src/agent-client-protocol/src/jsonrpc.rs | 423 ++++++++++++------ src/agent-client-protocol/src/role.rs | 4 +- .../tests/jsonrpc_request_cancellation.rs | 134 ++++++ 6 files changed, 502 insertions(+), 241 deletions(-) diff --git a/md/request-cancellation.md b/md/request-cancellation.md index 0d58f73..9bf6a0f 100644 --- a/md/request-cancellation.md +++ b/md/request-cancellation.md @@ -1,124 +1,78 @@ # Request Cancellation -The SDK exposes the ACP `$/cancel_request` notification behind the -`unstable_cancel_request` feature. The notification is protocol-level: either -side may send it to ask the peer to cancel one outstanding JSON-RPC request by -ID. +This chapter documents the `$/cancel_request` protocol-level notification and +how the SDK implements it. -Enable the feature when depending on the crate: +For API usage (cancelling a `SentRequest`, observing cancellation from a +`Responder`), see the `concepts::cancellation` chapter in the +[agent-client-protocol rustdoc](https://docs.rs/agent-client-protocol). The +SDK support is gated behind the `unstable_cancel_request` feature: ```toml agent-client-protocol = { version = "...", features = ["unstable_cancel_request"] } ``` -Cancellation is cooperative. A peer may ignore `$/cancel_request`, may finish -with normal data, or may respond to the original request with -`Error::request_cancelled()` (`-32800`). The requesting side always receives a -response to the original request; cancellation only changes _which_ response -that is. The SDK ignores unhandled `$/...` notifications (even when the -feature is disabled) so unsupported protocol-level notifications do not -produce method-not-found errors. - -## Cancelling outgoing requests - -To cancel a request sent through `ConnectionTo::send_request`, keep the -returned `SentRequest` and call `cancel` on it: - -```rust -# use agent_client_protocol::{ConnectionTo, Error, UntypedRole}; -# use agent_client_protocol_test::MyRequest; -# async fn example(cx: ConnectionTo) -> Result<(), Error> { -let request = cx.send_request(MyRequest {}); -request.cancel()?; - -// The peer still responds to the request: with normal data if it raced -// ahead, or with the standard cancellation error. -let result = request.block_task().await; -# let _ = result; -# Ok(()) -# } +## The `$/cancel_request` Notification + +Either side of a connection may send `$/cancel_request` to ask the peer to +cancel one outstanding JSON-RPC request, identified by its ID: + +```json +{ + "jsonrpc": "2.0", + "method": "$/cancel_request", + "params": { + "requestId": "70b9f1c9-c2a3-4bd2-b6b9-65a06d96b675" + } +} ``` -The `SentRequest` remembers the peer and any proxy wrapping used for the -original request, so this also works for requests sent through -`ConnectionTo::send_request_to`. +`requestId` is the JSON-RPC `id` of the request to cancel, as allocated by the +sender of that request (a string, number, or null). -Dropping a `SentRequest` before the SDK receives a response also sends -`$/cancel_request`. This covers abandoned request handles and futures. Once the -SDK routes a response to the waiting request handle, automatic cancellation is -disarmed, even if caller code has not yet consumed it with `block_task`, -`on_receiving_result`, or `forward_response_to`. +## Semantics -If you already have the JSON-RPC request ID, send the notification directly: +Cancellation is **cooperative**. After receiving `$/cancel_request`, the peer +may: -```rust -# use agent_client_protocol::{ConnectionTo, Error, UntypedRole}; -# async fn example(cx: ConnectionTo) -> Result<(), Error> { -cx.send_cancel_request("request-id".to_string())?; -# Ok(()) -# } -``` +- ignore it and respond to the request normally, +- finish early with whatever data it has, or +- respond to the original request with the standard cancellation error, + code `-32800` ("Request cancelled"). -## Handling cancellation of incoming requests - -For incoming requests, get the request-local cancellation marker from the -`Responder`. This keeps cancellation handling next to the request work it -controls: - -```rust -# use agent_client_protocol::{ConnectionTo, Error, Responder, UntypedRole}; -# use agent_client_protocol_test::{MyRequest, MyResponse}; -# async fn example(request: MyRequest, responder: Responder, cx: ConnectionTo) -> Result<(), Error> { -# async fn run_request(_request: MyRequest) -> Result { todo!() } -let cancellation = responder.cancellation(); - -cx.spawn(async move { - let response = cancellation.run_until_cancelled(run_request(request)).await; - responder.respond_with_result(response) -})?; -# Ok(()) -# } -``` +The requesting side always receives a response to the original request; +cancellation only changes _which_ response that is. A `$/cancel_request` for +an unknown or already-completed request ID is silently ignored. -`run_until_cancelled` is the simple path for handlers that should stop work and -reply with the standard cancellation error as soon as cancellation is -requested; it drops the work future when cancellation wins, so cleanup must -happen in `Drop` implementations and partial results are lost. If the handler -needs cleanup, partial results, or custom cancellation behavior, use -`cancellation.cancelled()` or `cancellation.is_cancelled()` directly inside -the request work instead. - -Cancellation markers are only updated when the connection can process the -incoming `$/cancel_request` notification. Long-running handlers should return -quickly and move work into `ConnectionTo::spawn`, `SentRequest` callbacks, or -another task. - -## Proxies - -When proxying with `SentRequest::forward_response_to`, the SDK observes the -upstream `Responder` cancellation marker and forwards cancellation to the -downstream request automatically. The downstream response (normal data or a -cancellation error) is still forwarded back upstream. - -## Low-level access - -Register `CancelRequestNotification` or `ProtocolLevelNotification` directly -only when you need low-level access to cancellation notifications, such as -custom routing or protocol tracing: - -```rust -# use agent_client_protocol::{ConnectionTo, Error, UntypedRole}; -use agent_client_protocol::schema::CancelRequestNotification; - -# fn example() { -let builder = UntypedRole.builder().on_receive_notification( - async |cancel: CancelRequestNotification, _cx: ConnectionTo| { - // Mark the matching in-flight operation cancelled. - let _request_id = cancel.request_id; - Ok(()) - }, - agent_client_protocol::on_receive_notification!(), -); -# let _ = builder; -# } -``` +## Interoperability + +Protocol-level (`$/`-prefixed) notifications are optional by design. The SDK +ignores unhandled `$/` notifications instead of rejecting them with a +method-not-found error, and does so even when the `unstable_cancel_request` +feature is disabled. A peer that sends `$/cancel_request` to a component built +without cancellation support therefore loses nothing: the request simply runs +to completion. + +## Proxy Chains + +Cancellation propagates **hop by hop** rather than end to end. Request IDs are +allocated per connection, so a `$/cancel_request` only ever refers to a +request on the connection it is sent over: + +1. The client sends `$/cancel_request` for a request it made to its direct + peer (for example, a proxy). +2. A proxy that forwarded the request downstream (the SDK does this with + `forward_response_to`) reacts by sending its own `$/cancel_request` for the + downstream request, using the downstream connection's request ID. +3. The downstream response — normal data or the cancellation error — flows + back up the chain as the response to each hop's request. + +When the notification targets a request that was wrapped in a +`_proxy/successor` envelope (see the [Protocol Reference](./protocol.md)), the +`$/cancel_request` is wrapped in the same envelope, and `requestId` refers to +the JSON-RPC `id` of the wrapped request on that connection. + +## Related Documentation + +- [Protocol Reference](./protocol.md) - The `_proxy/successor/*` envelope protocol +- [agent-client-protocol rustdoc](https://docs.rs/agent-client-protocol) - SDK API for sending, observing, and forwarding cancellations (see `concepts::cancellation`) diff --git a/src/agent-client-protocol/src/concepts/cancellation.rs b/src/agent-client-protocol/src/concepts/cancellation.rs index 1de7c0a..13c6d3a 100644 --- a/src/agent-client-protocol/src/concepts/cancellation.rs +++ b/src/agent-client-protocol/src/concepts/cancellation.rs @@ -118,6 +118,11 @@ //! # } //! ``` //! +//! Such a handler observes cancellation notifications but does not replace +//! the built-in handling: the SDK updates the [`Responder`] cancellation +//! markers for every incoming `$/cancel_request` before the handler chain +//! runs, even when a handler claims the notification. +//! //! [`block_task`]: crate::SentRequest::block_task //! [`on_receiving_result`]: crate::SentRequest::on_receiving_result //! [`forward_response_to`]: crate::SentRequest::forward_response_to diff --git a/src/agent-client-protocol/src/concepts/peers.rs b/src/agent-client-protocol/src/concepts/peers.rs index 29978ff..2bde50b 100644 --- a/src/agent-client-protocol/src/concepts/peers.rs +++ b/src/agent-client-protocol/src/concepts/peers.rs @@ -51,8 +51,9 @@ //! # Client.builder().connect_with(transport, async |cx| { //! # let req = MyRequest {}; //! // These are equivalent for Client: -//! cx.send_request(req.clone()); -//! cx.send_request_to(Agent, req); +//! let request = cx.send_request(req.clone()); +//! let same = cx.send_request_to(Agent, req); +//! # let _ = (request, same); //! # Ok(()) //! # }).await?; //! # Ok(()) diff --git a/src/agent-client-protocol/src/jsonrpc.rs b/src/agent-client-protocol/src/jsonrpc.rs index e40946b..c14f9a3 100644 --- a/src/agent-client-protocol/src/jsonrpc.rs +++ b/src/agent-client-protocol/src/jsonrpc.rs @@ -1437,7 +1437,7 @@ impl RequestCancellation { return Err(crate::Error::request_cancelled()); } - match future::select(Box::pin(future), Box::pin(self.cancelled())).await { + match future::select(pin!(future), pin!(self.cancelled())).await { Either::Left((result, _)) => result, Either::Right(((), _)) => Err(crate::Error::request_cancelled()), } @@ -1493,10 +1493,31 @@ enum RequestCancellationEntry { Marker(RequestCancellation), } +/// A registered request's cancellation state, tagged with the generation of +/// its registration. +/// +/// The generation distinguishes a registration from earlier ones that used +/// the same request ID, so that when a (protocol-violating) peer reuses the +/// ID of a request that is still in flight, the stale request's responder can +/// neither remove nor observe the cancellation state of the newer request. +#[cfg(feature = "unstable_cancel_request")] +#[derive(Debug)] +struct RequestCancellationSlot { + generation: u64, + entry: RequestCancellationEntry, +} + +#[cfg(feature = "unstable_cancel_request")] +#[derive(Debug, Default)] +struct RequestCancellationRegistryInner { + slots: HashMap, + next_generation: u64, +} + #[cfg(feature = "unstable_cancel_request")] #[derive(Clone, Debug, Default)] struct RequestCancellationRegistry { - inner: Arc>>, + inner: Arc>, } #[cfg(not(feature = "unstable_cancel_request"))] @@ -1507,6 +1528,7 @@ struct RequestCancellationRegistry; #[derive(Debug)] struct ResponderCancellation { id: serde_json::Value, + generation: u64, registry: RequestCancellationRegistry, } @@ -1521,30 +1543,60 @@ impl RequestCancellationRegistry { } fn register(&self, id: serde_json::Value) -> ResponderCancellation { - self.inner - .lock() - .expect("request cancellation registry mutex poisoned") - .insert(id.clone(), RequestCancellationEntry::Armed); + let generation = { + let mut inner = self + .inner + .lock() + .expect("request cancellation registry mutex poisoned"); + let generation = inner.next_generation; + inner.next_generation += 1; + if inner + .slots + .insert( + id.clone(), + RequestCancellationSlot { + generation, + entry: RequestCancellationEntry::Armed, + }, + ) + .is_some() + { + tracing::debug!( + ?id, + "peer reused the ID of a request that is still in flight" + ); + } + generation + }; ResponderCancellation { id, + generation, registry: self.clone(), } } /// Get the cancellation marker for a registered request, creating it on /// first use. Repeated calls return markers that share the same state. - fn marker(&self, id: &serde_json::Value) -> RequestCancellation { + fn marker(&self, id: &serde_json::Value, generation: u64) -> RequestCancellation { let mut inner = self .inner .lock() .expect("request cancellation registry mutex poisoned"); - let Some(entry) = inner.get_mut(id) else { - // The entry lives as long as the responder that owns it, so this - // is only reachable if the peer reused a request ID and the - // earlier request's responder already removed the shared entry. + let Some(slot) = inner.slots.get_mut(id) else { + // The slot lives as long as the responder that owns it, so this + // is only reachable if the peer reused this request ID and the + // newer request's responder already removed the replacement slot. // Hand out a detached marker rather than panicking. return RequestCancellation::new(); }; + if slot.generation != generation { + // The peer reused this request ID while the request was still in + // flight, and the slot now belongs to the newer request. Hand the + // stale responder a detached marker instead of cross-wiring the + // two requests' cancellation states. + return RequestCancellation::new(); + } + let entry = &mut slot.entry; match entry { RequestCancellationEntry::Marker(marker) => marker.clone(), RequestCancellationEntry::Armed => { @@ -1570,15 +1622,17 @@ impl RequestCancellationRegistry { Ok(self.cancel(&request_id)) } + /// Mark whichever request currently owns `request_id` as cancelled. fn cancel(&self, request_id: &serde_json::Value) -> bool { let marker = { let mut inner = self .inner .lock() .expect("request cancellation registry mutex poisoned"); - let Some(entry) = inner.get_mut(request_id) else { + let Some(slot) = inner.slots.get_mut(request_id) else { return false; }; + let entry = &mut slot.entry; match entry { RequestCancellationEntry::Marker(marker) => marker.clone(), RequestCancellationEntry::Cancelled => return true, @@ -1595,11 +1649,20 @@ impl RequestCancellationRegistry { true } - fn remove(&self, request_id: &serde_json::Value) { - self.inner + /// Remove the slot for `request_id`, but only if it still belongs to the + /// registration identified by `generation`. + fn remove(&self, request_id: &serde_json::Value, generation: u64) { + let mut inner = self + .inner .lock() - .expect("request cancellation registry mutex poisoned") - .remove(request_id); + .expect("request cancellation registry mutex poisoned"); + if inner + .slots + .get(request_id) + .is_some_and(|slot| slot.generation == generation) + { + inner.slots.remove(request_id); + } } } @@ -1630,14 +1693,14 @@ impl RequestCancellationRegistry { #[cfg(feature = "unstable_cancel_request")] impl ResponderCancellation { fn cancellation(&self) -> RequestCancellation { - self.registry.marker(&self.id) + self.registry.marker(&self.id, self.generation) } } #[cfg(feature = "unstable_cancel_request")] impl Drop for ResponderCancellation { fn drop(&mut self) { - self.registry.remove(&self.id); + self.registry.remove(&self.id, self.generation); } } @@ -1653,25 +1716,41 @@ fn cancellation_request_id(dispatch: &Dispatch) -> Result Result, crate::Error> { - if crate::schema::CancelRequestNotification::matches_method(&message.method) { - let notification = crate::schema::CancelRequestNotification::parse_message( - &message.method, - &message.params, - )?; - return serde_json::to_value(notification.request_id) - .map(Some) - .map_err(crate::Error::into_internal_error); + let (method, params) = peel_successor_envelopes(&message.method, &message.params); + if !crate::schema::CancelRequestNotification::matches_method(method) { + return Ok(None); } - if crate::schema::SuccessorMessage::::matches_method(&message.method) { - let successor = crate::schema::SuccessorMessage::::parse_message( - &message.method, - &message.params, - )?; - return cancellation_request_id_from_message(&successor.message); - } + let notification = crate::schema::CancelRequestNotification::parse_message(method, params)?; + serde_json::to_value(notification.request_id) + .map(Some) + .map_err(crate::Error::into_internal_error) +} - Ok(None) +/// Peel any [`SuccessorMessage`] envelopes off a notification by reference, +/// returning the innermost method and params. +/// +/// This only peeks at the envelope's `method`/`params` fields instead of +/// deserializing the envelope, for two reasons: +/// +/// - It avoids deep-cloning the params of every wrapped notification on the +/// hot dispatch path just to inspect the inner method name. +/// - It is deliberately lenient: a malformed envelope is left as-is here and +/// flows on to the handler chain, which is responsible for reporting it. +/// +/// [`SuccessorMessage`]: crate::schema::SuccessorMessage +fn peel_successor_envelopes<'message>( + mut method: &'message str, + mut params: &'message serde_json::Value, +) -> (&'message str, &'message serde_json::Value) { + while crate::schema::SuccessorMessage::::matches_method(method) { + let Some(inner_method) = params.get("method").and_then(serde_json::Value::as_str) else { + break; + }; + method = inner_method; + params = params.get("params").unwrap_or(&serde_json::Value::Null); + } + (method, params) } /// Whether the dispatch is a protocol-level (`$/`-prefixed) notification, @@ -1692,25 +1771,8 @@ fn is_protocol_level_notification(dispatch: &Dispatch) -> bool { let Dispatch::Notification(message) = dispatch else { return false; }; - is_protocol_level_notification_message(message) -} - -fn is_protocol_level_notification_message(message: &UntypedMessage) -> bool { - if message.method.starts_with("$/") { - return true; - } - - if crate::schema::SuccessorMessage::::matches_method(&message.method) { - let Ok(successor) = crate::schema::SuccessorMessage::::parse_message( - &message.method, - &message.params, - ) else { - return false; - }; - return is_protocol_level_notification_message(&successor.message); - } - - false + let (method, _params) = peel_successor_envelopes(&message.method, &message.params); + method.starts_with("$/") } /// Messages send to be serialized over the transport. @@ -2075,7 +2137,7 @@ impl ConnectionTo { let remote_style = self.counterpart.remote_style(peer); #[cfg(feature = "unstable_cancel_request")] let cancellation = - SentRequestCancellation::new(self.message_tx.clone(), &remote_style, &id); + SentRequestCancellation::new(self.message_tx.clone(), remote_style, id.clone()); match remote_style.transform_outgoing_message(request) { Ok(untyped) => { // Transform the message for the target role @@ -3256,6 +3318,10 @@ impl JsonRpcNotification for UntypedMessage {} /// additionally sends a `$/cancel_request` notification asking the peer to /// cancel the request; fire-and-forget requests should consume their handle /// (for example with [`on_receiving_result`](Self::on_receiving_result)). +#[must_use = "dropping a SentRequest discards the response (and, with the \ + `unstable_cancel_request` feature, asks the peer to cancel the \ + request); consume it with `block_task`, `on_receiving_result`, \ + or `forward_response_to`"] pub struct SentRequest { id: jsonrpcmsg::Id, method: String, @@ -3301,82 +3367,56 @@ impl SentRequestCancellationDisarm { } #[cfg(feature = "unstable_cancel_request")] -enum SentRequestCancellation { - Send { - message_tx: OutgoingMessageTx, - notification: UntypedMessage, - disarm: SentRequestCancellationDisarm, - }, - Failed { - error: String, - disarm: SentRequestCancellationDisarm, - }, +struct SentRequestCancellation { + message_tx: OutgoingMessageTx, + remote_style: crate::role::RemoteStyle, + request_id: jsonrpcmsg::Id, + disarm: SentRequestCancellationDisarm, } #[cfg(feature = "unstable_cancel_request")] impl SentRequestCancellation { fn new( message_tx: OutgoingMessageTx, - remote_style: &crate::role::RemoteStyle, - request_id: &jsonrpcmsg::Id, + remote_style: crate::role::RemoteStyle, + request_id: jsonrpcmsg::Id, ) -> Self { - let notification = jsonrpc_id_to_request_id(request_id) - .and_then(|request_id| { - remote_style.transform_outgoing_message( - crate::schema::CancelRequestNotification::new(request_id), - ) - }) - .map_err(|error| error.to_string()); - let disarm = SentRequestCancellationDisarm::new(); - - match notification { - Ok(notification) => Self::Send { - message_tx, - notification, - disarm, - }, - Err(error) => Self::Failed { error, disarm }, + Self { + message_tx, + remote_style, + request_id, + disarm: SentRequestCancellationDisarm::new(), } } fn disarm(&self) { - self.disarm_handle().disarm(); + self.disarm.disarm(); } fn disarm_handle(&self) -> SentRequestCancellationDisarm { - match self { - Self::Send { disarm, .. } | Self::Failed { disarm, .. } => disarm.clone(), - } + self.disarm.clone() } fn send(&self) -> Result<(), crate::Error> { - match self { - Self::Send { - message_tx, - notification, - disarm, - } => { - if !disarm.armed.swap(false, Ordering::AcqRel) { - return Ok(()); - } + if !self.disarm.armed.swap(false, Ordering::AcqRel) { + return Ok(()); + } - send_raw_message( - message_tx, - OutgoingMessage::Notification { - untyped: notification.clone(), - }, + // Build the notification lazily: most requests are never cancelled, + // so this avoids serializing a notification per outgoing request. + let untyped = jsonrpc_id_to_request_id(&self.request_id) + .and_then(|request_id| { + self.remote_style.transform_outgoing_message( + crate::schema::CancelRequestNotification::new(request_id), ) - } - Self::Failed { error, disarm } => { - if !disarm.armed.swap(false, Ordering::AcqRel) { - return Ok(()); - } - - Err(crate::util::internal_error(format!( + }) + .map_err(|error| { + crate::util::internal_error(format!( "failed to create cancel request notification: {error}" - ))) - } - } + )) + })?; + + send_raw_message(&self.message_tx, OutgoingMessage::Notification { untyped }) } } @@ -3392,22 +3432,11 @@ impl Drop for SentRequestCancellation { #[cfg(feature = "unstable_cancel_request")] impl Debug for SentRequestCancellation { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - Self::Send { - notification, - disarm, - .. - } => f - .debug_struct("SentRequestCancellation") - .field("notification", notification) - .field("armed", &disarm.armed.load(Ordering::Acquire)) - .finish(), - Self::Failed { error, disarm } => f - .debug_struct("SentRequestCancellation") - .field("error", error) - .field("armed", &disarm.armed.load(Ordering::Acquire)) - .finish(), - } + f.debug_struct("SentRequestCancellation") + .field("request_id", &self.request_id) + .field("remote_style", &self.remote_style) + .field("armed", &self.disarm.armed.load(Ordering::Acquire)) + .finish_non_exhaustive() } } @@ -3451,6 +3480,16 @@ impl SentRequest { /// This uses the same peer and message wrapping that were used to send the /// original request, so it is the preferred way to cancel a [`SentRequest`] /// when the request handle is still available. + /// + /// At most one `$/cancel_request` is ever sent per request: the first + /// `cancel` call sends it (and also prevents the drop-time automatic + /// cancellation described in [Drop Behavior](Self#drop-behavior)), while + /// later calls return `Ok(())` without sending anything. Likewise, once + /// the SDK has routed the response to this handle, `cancel` becomes a + /// no-op: there is nothing left to cancel. + /// + /// Errors are only reported by the call that attempts to send the + /// notification. #[cfg(feature = "unstable_cancel_request")] pub fn cancel(&self) -> Result<(), crate::Error> { self.cancellation.send() @@ -3580,8 +3619,7 @@ impl SentRequest { forward_cancellation(&downstream_cancellation); response_rx.await } else { - match future::select(Box::pin(upstream_cancellation.cancelled()), response_rx) - .await + match future::select(pin!(upstream_cancellation.cancelled()), response_rx).await { Either::Left(((), response_rx)) => { forward_cancellation(&downstream_cancellation); @@ -4190,3 +4228,132 @@ impl ConnectTo for Channel { (self, Box::pin(future::ready(Ok(())))) } } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn peel_successor_envelopes_returns_plain_messages_unchanged() { + let params = serde_json::json!({ "key": "value" }); + let (method, peeled) = peel_successor_envelopes("session/update", ¶ms); + assert_eq!(method, "session/update"); + assert_eq!(peeled, ¶ms); + } + + #[test] + fn peel_successor_envelopes_unwraps_nested_envelopes() { + let params = serde_json::json!({ + "method": "_proxy/successor", + "params": { + "method": "$/cancel_request", + "params": { "requestId": "req-1" } + } + }); + let (method, peeled) = peel_successor_envelopes("_proxy/successor", ¶ms); + assert_eq!(method, "$/cancel_request"); + assert_eq!(peeled, &serde_json::json!({ "requestId": "req-1" })); + } + + #[test] + fn peel_successor_envelopes_leaves_malformed_envelopes_intact() { + // No string `method` field: the envelope cannot be peeled, so the + // message is returned as-is for the handler chain to deal with. + let params = serde_json::json!({ "unexpected": true }); + let (method, peeled) = peel_successor_envelopes("_proxy/successor", ¶ms); + assert_eq!(method, "_proxy/successor"); + assert_eq!(peeled, ¶ms); + } + + #[cfg(feature = "unstable_cancel_request")] + mod cancel_request { + use super::super::*; + + fn notification(method: &str, params: serde_json::Value) -> UntypedMessage { + UntypedMessage::new(method, params).expect("well-formed JSON") + } + + #[test] + fn cancellation_request_id_is_extracted_from_wrapped_notifications() { + let message = notification( + "_proxy/successor", + serde_json::json!({ + "method": "$/cancel_request", + "params": { "requestId": "req-1" } + }), + ); + let request_id = cancellation_request_id_from_message(&message) + .expect("wrapped cancel should parse"); + assert_eq!(request_id, Some(serde_json::json!("req-1"))); + } + + #[test] + fn malformed_successor_envelope_is_not_treated_as_cancellation() { + // The envelope cannot be peeled; the message must flow on to the + // handler chain instead of erroring the dispatch. + let message = notification("_proxy/successor", serde_json::json!({ "bogus": true })); + let request_id = cancellation_request_id_from_message(&message) + .expect("malformed envelope should be left to the handler chain"); + assert_eq!(request_id, None); + } + + #[test] + fn malformed_cancel_request_params_error() { + let message = notification( + "$/cancel_request", + serde_json::json!({ "requestId": { "not": "an id" } }), + ); + cancellation_request_id_from_message(&message) + .expect_err("malformed cancel params should error"); + } + + #[test] + fn registry_marks_and_removes_requests() { + let registry = RequestCancellationRegistry::new(); + let id = serde_json::json!("req-1"); + + let responder_cancellation = registry.register(id.clone()); + let marker = responder_cancellation.cancellation(); + assert!(!marker.is_cancelled()); + + assert!(registry.cancel(&id)); + assert!(marker.is_cancelled()); + assert!(responder_cancellation.cancellation().is_cancelled()); + + drop(responder_cancellation); + assert!(!registry.cancel(&id), "slot should be removed on drop"); + } + + #[test] + fn reused_request_id_does_not_cross_wire_cancellation_state() { + let registry = RequestCancellationRegistry::new(); + let id = serde_json::json!("dup"); + + // A protocol-violating peer reuses an in-flight request ID. + let first = registry.register(id.clone()); + let first_marker = first.cancellation(); + let second = registry.register(id.clone()); + let second_marker = second.cancellation(); + + // A cancellation targets whichever request currently owns the ID. + assert!(registry.cancel(&id)); + assert!(second_marker.is_cancelled()); + assert!( + !first_marker.is_cancelled(), + "the stale request must not observe the newer request's cancellation" + ); + + // The stale responder must hand out detached markers, not the + // newer request's marker. + assert!(!first.cancellation().is_cancelled()); + + // Dropping the stale responder must not remove the newer + // request's slot. + drop(first); + assert!(registry.cancel(&id), "newer slot should still be present"); + + drop(second); + assert!(!registry.cancel(&id), "slot should be removed on drop"); + } + } +} diff --git a/src/agent-client-protocol/src/role.rs b/src/agent-client-protocol/src/role.rs index e1c28aa..fc52e4a 100644 --- a/src/agent-client-protocol/src/role.rs +++ b/src/agent-client-protocol/src/role.rs @@ -81,7 +81,7 @@ pub trait HasPeer: Role { } /// Describes how messages are transformed when sent to a remote peer. -#[derive(Clone, Debug, PartialEq, Eq)] +#[derive(Copy, Clone, Debug, PartialEq, Eq)] #[non_exhaustive] pub enum RemoteStyle { /// Pass each message through exactly as it is. @@ -96,7 +96,7 @@ pub enum RemoteStyle { impl RemoteStyle { pub(crate) fn transform_outgoing_message( - &self, + self, msg: M, ) -> Result { match self { diff --git a/src/agent-client-protocol/tests/jsonrpc_request_cancellation.rs b/src/agent-client-protocol/tests/jsonrpc_request_cancellation.rs index 37e5a9f..d391459 100644 --- a/src/agent-client-protocol/tests/jsonrpc_request_cancellation.rs +++ b/src/agent-client-protocol/tests/jsonrpc_request_cancellation.rs @@ -400,6 +400,140 @@ async fn unhandled_wrapped_protocol_level_notifications_are_ignored() { .await; } +#[tokio::test(flavor = "current_thread")] +async fn malformed_successor_envelope_still_reaches_handlers() { + use tokio::io::AsyncWriteExt; + use tokio::task::LocalSet; + + let local = LocalSet::new(); + + local + .run_until(async { + let (notification_tx, mut notification_rx) = mpsc::unbounded(); + + let (mut client_writer, server_reader) = tokio::io::duplex(4096); + let (server_writer, _client_reader) = tokio::io::duplex(4096); + + let server_transport = agent_client_protocol::ByteStreams::new( + server_writer.compat_write(), + server_reader.compat(), + ); + // A catch-all notification handler: a successor envelope whose + // params cannot be peeled (no inner `method`) must not be + // mistaken for a cancellation and short-circuited; it must flow + // through the handler chain like any other notification. + let server = UntypedRole.builder().on_receive_notification( + async move |notification: agent_client_protocol::UntypedMessage, + _connection: ConnectionTo| { + notification_tx + .unbounded_send((notification.method, notification.params)) + .unwrap(); + Ok(()) + }, + agent_client_protocol::on_receive_notification!(), + ); + + tokio::task::spawn_local(async move { + if let Err(error) = server.connect_to(server_transport).await { + panic!("server should stay alive: {error:?}"); + } + }); + + client_writer + .write_all( + br#"{"jsonrpc":"2.0","method":"_proxy/successor","params":{"bogus":true}} +"#, + ) + .await + .unwrap(); + client_writer.flush().await.unwrap(); + + let (method, params) = next_with_timeout(&mut notification_rx).await; + assert_eq!(method, "_proxy/successor"); + assert_eq!(params, serde_json::json!({ "bogus": true })); + }) + .await; +} + +#[tokio::test(flavor = "current_thread")] +async fn wrapped_cancel_request_cancels_wrapped_request() { + use tokio::io::{AsyncWriteExt, BufReader}; + use tokio::task::LocalSet; + + let local = LocalSet::new(); + + local + .run_until(async { + let (mut client_writer, server_reader) = tokio::io::duplex(4096); + let (server_writer, client_reader) = tokio::io::duplex(4096); + + let server_transport = agent_client_protocol::ByteStreams::new( + server_writer.compat_write(), + server_reader.compat(), + ); + let server = WrappedHost.builder().on_receive_request_from( + WrappedSuccessor, + async |_request: SimpleRequest, + responder: Responder, + cx: ConnectionTo| { + let cancellation = responder.cancellation(); + cx.spawn(async move { + let response = cancellation + .run_until_cancelled(futures::future::pending::< + Result, + >()) + .await; + responder.respond_with_result(response) + })?; + Ok(()) + }, + agent_client_protocol::on_receive_request!(), + ); + + tokio::task::spawn_local(async move { + if let Err(error) = server.connect_to(server_transport).await { + panic!("server should stay alive: {error:?}"); + } + }); + + let mut client_reader = BufReader::new(client_reader); + + // A request wrapped in a successor envelope is registered under + // its outer JSON-RPC id, so a wrapped `$/cancel_request` for that + // outer id must cancel it. + client_writer + .write_all( + br#"{"jsonrpc":"2.0","id":7,"method":"_proxy/successor","params":{"method":"simple_method","params":{"message":"wrapped"}}} +"#, + ) + .await + .unwrap(); + client_writer.flush().await.unwrap(); + + client_writer + .write_all( + br#"{"jsonrpc":"2.0","method":"_proxy/successor","params":{"method":"$/cancel_request","params":{"requestId":7}}} +"#, + ) + .await + .unwrap(); + client_writer.flush().await.unwrap(); + + let response = read_jsonrpc_response_line(&mut client_reader).await; + expect![[r#" + { + "jsonrpc": "2.0", + "error": { + "code": -32800, + "message": "Request cancelled" + }, + "id": 7 + }"#]] + .assert_eq(&serde_json::to_string_pretty(&response).unwrap()); + }) + .await; +} + #[tokio::test(flavor = "current_thread")] async fn cancel_request_notification_can_be_sent_and_handled() { use tokio::task::LocalSet; From 9f9c6162ac8faffadb2c17c0aba394f4f57b34f7 Mon Sep 17 00:00:00 2001 From: Ben Brandt Date: Wed, 10 Jun 2026 18:15:02 +0200 Subject: [PATCH 12/18] fix(acp): Serialize all cancel notifications correctly --- md/request-cancellation.md | 10 ++- .../Cargo.toml | 7 ++ src/agent-client-protocol/src/jsonrpc.rs | 24 +++---- src/agent-client-protocol/src/schema/mod.rs | 30 ++++----- .../tests/jsonrpc_request_cancellation.rs | 67 ++++++++++++++++--- 5 files changed, 96 insertions(+), 42 deletions(-) diff --git a/md/request-cancellation.md b/md/request-cancellation.md index 9bf6a0f..87a5c63 100644 --- a/md/request-cancellation.md +++ b/md/request-cancellation.md @@ -42,7 +42,10 @@ may: The requesting side always receives a response to the original request; cancellation only changes _which_ response that is. A `$/cancel_request` for -an unknown or already-completed request ID is silently ignored. +an unknown or already-completed request ID is silently ignored. A +`$/cancel_request` with malformed params (for example, a `requestId` that is +not a string, number, or null) is different: like any other malformed +notification, it is reported back with an out-of-band error notification. ## Interoperability @@ -72,6 +75,11 @@ When the notification targets a request that was wrapped in a `$/cancel_request` is wrapped in the same envelope, and `requestId` refers to the JSON-RPC `id` of the wrapped request on that connection. +The conductor forwards cancellations between hops when it is built with its +`unstable_cancel_request` feature, which forwards the feature of the same name +to the SDK. Without it, the conductor ignores `$/cancel_request` as described +in [Interoperability](#interoperability). + ## Related Documentation - [Protocol Reference](./protocol.md) - The `_proxy/successor/*` envelope protocol diff --git a/src/agent-client-protocol-conductor/Cargo.toml b/src/agent-client-protocol-conductor/Cargo.toml index 6927a38..f31ec43 100644 --- a/src/agent-client-protocol-conductor/Cargo.toml +++ b/src/agent-client-protocol-conductor/Cargo.toml @@ -14,6 +14,13 @@ categories = ["development-tools"] name = "agent-client-protocol-conductor" path = "src/main.rs" +[features] +default = [] + +# Forwarded from agent-client-protocol. Enable to let the conductor forward +# `$/cancel_request` hop by hop through the proxy chain. +unstable_cancel_request = ["agent-client-protocol/unstable_cancel_request"] + [dependencies] agent-client-protocol = { workspace = true } agent-client-protocol-trace-viewer.workspace = true diff --git a/src/agent-client-protocol/src/jsonrpc.rs b/src/agent-client-protocol/src/jsonrpc.rs index c14f9a3..6aca98b 100644 --- a/src/agent-client-protocol/src/jsonrpc.rs +++ b/src/agent-client-protocol/src/jsonrpc.rs @@ -1454,13 +1454,16 @@ impl RequestCancellation { return; } - if let Some(signal_tx) = self + let signal_tx = self .state .signal_tx .lock() .expect("request cancellation signal mutex poisoned") - .take() - { + .take(); + + // Complete the oneshot outside the lock: it wakes waiters, and + // arbitrary waker code must not observe the lock held. + if let Some(signal_tx) = signal_tx { let _ = signal_tx.send(()); } } @@ -3404,17 +3407,10 @@ impl SentRequestCancellation { // Build the notification lazily: most requests are never cancelled, // so this avoids serializing a notification per outgoing request. - let untyped = jsonrpc_id_to_request_id(&self.request_id) - .and_then(|request_id| { - self.remote_style.transform_outgoing_message( - crate::schema::CancelRequestNotification::new(request_id), - ) - }) - .map_err(|error| { - crate::util::internal_error(format!( - "failed to create cancel request notification: {error}" - )) - })?; + let request_id = jsonrpc_id_to_request_id(&self.request_id)?; + let untyped = self.remote_style.transform_outgoing_message( + crate::schema::CancelRequestNotification::new(request_id), + )?; send_raw_message(&self.message_tx, OutgoingMessage::Notification { untyped }) } diff --git a/src/agent-client-protocol/src/schema/mod.rs b/src/agent-client-protocol/src/schema/mod.rs index b9131e1..321aa66 100644 --- a/src/agent-client-protocol/src/schema/mod.rs +++ b/src/agent-client-protocol/src/schema/mod.rs @@ -221,10 +221,12 @@ macro_rules! impl_jsonrpc_notification_enum { /// notification enum (`$/`-prefixed methods), shared between the v1 and v2 /// schema namespaces. /// -/// The schema enums are `#[non_exhaustive]`, so the matches need wildcard -/// arms: when the schema crate adds a protocol-level notification, list it -/// here as well. Unknown variants fail to serialize rather than producing a -/// bogus method name on the wire. +/// The incoming side (`matches_method`, `parse_message`) only recognizes the +/// methods listed in the macro invocation: when the schema crate adds a +/// protocol-level notification, list it here to parse it. The outgoing side +/// (`method`, `to_untyped_message`) instead delegates to the schema enum's +/// inherent `method()` and untagged serialization, which cover every variant, +/// so unlisted variants still serialize with the correct method name. /// /// ```ignore /// impl_jsonrpc_protocol_level_notification_enum!(ProtocolLevelNotification { @@ -242,21 +244,17 @@ macro_rules! impl_jsonrpc_protocol_level_notification_enum { } fn method(&self) -> &str { - match self { - $( Self::$variant(_) => $method, )* - _ => "_unknown", - } + // Resolves to the schema enum's *inherent* `method()` (path + // syntax prefers inherent items over trait items), which + // matches its variants exhaustively: the enum is only + // non-exhaustive downstream. + <$enum>::method(self) } fn to_untyped_message(&self) -> Result<$crate::UntypedMessage, $crate::Error> { - match self { - $( Self::$variant(notification) => { - $crate::UntypedMessage::new($method, notification) - } )* - _ => Err($crate::util::internal_error( - "protocol-level notification variant is not supported by this SDK version", - )), - } + // The schema enum is `#[serde(untagged)]`, so serializing the + // enum serializes the inner notification. + $crate::UntypedMessage::new(<$enum>::method(self), self) } fn parse_message( diff --git a/src/agent-client-protocol/tests/jsonrpc_request_cancellation.rs b/src/agent-client-protocol/tests/jsonrpc_request_cancellation.rs index d391459..b6a44cf 100644 --- a/src/agent-client-protocol/tests/jsonrpc_request_cancellation.rs +++ b/src/agent-client-protocol/tests/jsonrpc_request_cancellation.rs @@ -953,6 +953,10 @@ async fn forward_response_to_propagates_cancellation_to_downstream_request() { local .run_until(async { let (backend_cancel_tx, mut backend_cancel_rx) = mpsc::unbounded(); + // The responder for the cancelled request, parked by the backend + // until the forwarded cancellation arrives. + let pending_responder: Arc>>> = + Arc::new(Mutex::new(None)); let (backend_for_proxy, backend_for_server) = Channel::duplex(); let (backend_connection_tx, backend_connection_rx) = @@ -974,18 +978,40 @@ async fn forward_response_to_propagates_cancellation_to_downstream_request() { let backend_server = UntypedRole .builder() .on_receive_request( - async |_request: SimpleRequest, - _responder: Responder, - _connection: ConnectionTo| { Ok(()) }, + { + let pending_responder = pending_responder.clone(); + async move |request: SimpleRequest, + responder: Responder, + _connection: ConnectionTo| { + if request.message == "cancel downstream" { + *pending_responder.lock().unwrap() = Some(responder); + return Ok(()); + } + + responder.respond(SimpleResponse { + result: format!("echo: {}", request.message), + }) + } + }, agent_client_protocol::on_receive_request!(), ) .on_receive_notification( - async move |notification: CancelRequestNotification, - _connection: ConnectionTo| { - backend_cancel_tx - .unbounded_send(notification.request_id) - .unwrap(); - Ok(()) + { + let pending_responder = pending_responder.clone(); + async move |notification: CancelRequestNotification, + _connection: ConnectionTo| { + // Honor the forwarded cancellation: answer the + // parked request with the cancellation error. + if let Some(responder) = pending_responder.lock().unwrap().take() { + responder.respond_with_result(Err( + agent_client_protocol::Error::request_cancelled(), + ))?; + } + backend_cancel_tx + .unbounded_send(notification.request_id) + .unwrap(); + Ok(()) + } }, agent_client_protocol::on_receive_notification!(), ); @@ -1034,9 +1060,28 @@ async fn forward_response_to_propagates_cancellation_to_downstream_request() { message: "cancel downstream".into(), }); request.cancel()?; - // Wait until the proxy has forwarded the cancellation all - // the way to the backend. + + // The backend answers the parked request only once the + // proxy has forwarded the cancellation to it, and the + // proxy forwards the backend's cancellation error back + // upstream as the response. + let error = request + .block_task() + .await + .expect_err("request should be cancelled"); + assert_eq!(i32::from(error.code), -32800); next_with_timeout(&mut backend_cancel_rx).await; + + // Barrier: this round trip traverses both hops after the + // cancellation, so a duplicate `$/cancel_request` would + // already have been recorded by the backend. + let barrier = connection + .send_request(SimpleRequest { + message: "barrier".into(), + }) + .block_task() + .await?; + assert_eq!(barrier.result, "echo: barrier"); Ok(()) }) .await From 1965e6c8677ce2c71894db801c40cdc2e426878e Mon Sep 17 00:00:00 2001 From: Ben Brandt Date: Thu, 11 Jun 2026 06:39:52 +0200 Subject: [PATCH 13/18] fix(acp): Drop proxied cancel notifications --- md/request-cancellation.md | 23 +- .../src/conductor.rs | 27 +- .../tests/request_cancellation.rs | 606 ++++++++++++++++++ .../src/concepts/cancellation.rs | 7 + src/agent-client-protocol/src/jsonrpc.rs | 32 +- src/agent-client-protocol/src/session.rs | 2 +- .../tests/jsonrpc_request_cancellation.rs | 294 ++++++++- 7 files changed, 976 insertions(+), 15 deletions(-) create mode 100644 src/agent-client-protocol-conductor/tests/request_cancellation.rs diff --git a/md/request-cancellation.md b/md/request-cancellation.md index 87a5c63..81ffd92 100644 --- a/md/request-cancellation.md +++ b/md/request-cancellation.md @@ -44,8 +44,12 @@ The requesting side always receives a response to the original request; cancellation only changes _which_ response that is. A `$/cancel_request` for an unknown or already-completed request ID is silently ignored. A `$/cancel_request` with malformed params (for example, a `requestId` that is -not a string, number, or null) is different: like any other malformed -notification, it is reported back with an out-of-band error notification. +not a string, number, or null) is different: when the receiver is built with +the `unstable_cancel_request` feature, it is reported back with an +out-of-band error notification, like any other malformed notification. A +receiver built without the feature never parses the params and ignores the +notification like any other unhandled `$/` notification (see +[Interoperability](#interoperability)). ## Interoperability @@ -70,15 +74,24 @@ request on the connection it is sent over: 3. The downstream response — normal data or the cancellation error — flows back up the chain as the response to each hop's request. +Because the notification is hop-scoped, it is never tunneled across hops: +when the feature is enabled, generic forwarding helpers +(`send_proxied_message_to` in the SDK, and the conductor's internal routing) +drop a raw `$/cancel_request` instead of forwarding a request ID that means +nothing on the next connection. The cancellation still reaches the next hop, +re-issued by `forward_response_to` with that hop's own request ID. + When the notification targets a request that was wrapped in a `_proxy/successor` envelope (see the [Protocol Reference](./protocol.md)), the `$/cancel_request` is wrapped in the same envelope, and `requestId` refers to the JSON-RPC `id` of the wrapped request on that connection. -The conductor forwards cancellations between hops when it is built with its +The conductor translates cancellations between hops when it is built with its `unstable_cancel_request` feature, which forwards the feature of the same name -to the SDK. Without it, the conductor ignores `$/cancel_request` as described -in [Interoperability](#interoperability). +to the SDK. Without it, no per-hop cancellation is issued; since request IDs +are reallocated at every hop, a `$/cancel_request` cannot match anything +beyond the hop it was sent over, and the affected request simply runs to +completion as described in [Interoperability](#interoperability). ## Related Documentation diff --git a/src/agent-client-protocol-conductor/src/conductor.rs b/src/agent-client-protocol-conductor/src/conductor.rs index a12ed7f..844c909 100644 --- a/src/agent-client-protocol-conductor/src/conductor.rs +++ b/src/agent-client-protocol-conductor/src/conductor.rs @@ -462,11 +462,28 @@ where Dispatch::Request(request, responder) => self .send_request_to_predecessor_of(client, source_component_index, request) .forward_response_to(responder), - Dispatch::Notification(notification) => self.send_notification_to_predecessor_of( - client, - source_component_index, - notification, - ), + Dispatch::Notification(notification) => { + // `$/cancel_request` is connection-scoped: its `requestId` was + // allocated on the connection the notification arrived over + // and means nothing on the predecessor's connection. The SDK + // already propagates the cancellation hop by hop through the + // `forward_response_to` calls above, so drop the raw + // notification instead of tunneling a meaningless ID. + #[cfg(feature = "unstable_cancel_request")] + if agent_client_protocol::schema::CancelRequestNotification::matches_method( + notification.method(), + ) { + tracing::debug!( + "not forwarding hop-scoped `$/cancel_request` notification to predecessor" + ); + return Ok(()); + } + self.send_notification_to_predecessor_of( + client, + source_component_index, + notification, + ) + } Dispatch::Response(result, router) => router.respond_with_result(result), } } diff --git a/src/agent-client-protocol-conductor/tests/request_cancellation.rs b/src/agent-client-protocol-conductor/tests/request_cancellation.rs new file mode 100644 index 0000000..c016edd --- /dev/null +++ b/src/agent-client-protocol-conductor/tests/request_cancellation.rs @@ -0,0 +1,606 @@ +#![cfg(feature = "unstable_cancel_request")] + +//! Integration tests for `$/cancel_request` propagation through the conductor. +//! +//! Cancellation is hop-by-hop: every hop re-issues `$/cancel_request` with +//! its own connection's request ID, and the raw notification (whose +//! `requestId` is only meaningful on the connection it arrived over) must +//! *not* be tunneled verbatim through the chain. +//! +//! These tests avoid sleeps: +//! +//! - Channels report what each endpoint observed, awaited with a timeout. +//! - "Exactly one cancellation" assertions rely on a barrier round trip +//! through the whole chain: the conductor's routing loop and each +//! connection deliver messages in order, so by the time the barrier +//! response arrives, any erroneously tunneled notification would already +//! have been observed. + +use std::time::Duration; + +use agent_client_protocol::schema::{ + CancelRequestNotification, ContentBlock, ContentChunk, InitializeRequest, InitializeResponse, + NewSessionRequest, NewSessionResponse, PermissionOption, PermissionOptionKind, PromptRequest, + PromptResponse, ProtocolVersion, RequestPermissionRequest, RequestPermissionResponse, + SessionId, SessionNotification, SessionUpdate, StopReason, ToolCallUpdate, + ToolCallUpdateFields, +}; +use agent_client_protocol::{ + Agent, ByteStreams, Client, Conductor, ConnectTo, ConnectionTo, Error, JsonRpcRequest, + JsonRpcResponse, Proxy, Responder, SentRequest, +}; +use agent_client_protocol_conductor::{ConductorImpl, ProxiesAndAgent}; +use futures::StreamExt as _; +use futures::channel::mpsc; +use serde::{Deserialize, Serialize}; +use tokio::io::duplex; +use tokio_util::compat::{TokioAsyncReadCompatExt, TokioAsyncWriteCompatExt}; + +#[derive(Debug, Clone, Serialize, Deserialize, JsonRpcRequest)] +#[request(method = "test/simple", response = SimpleResponse)] +struct SimpleRequest { + message: String, +} + +#[derive(Debug, Clone, Serialize, Deserialize, JsonRpcResponse)] +struct SimpleResponse { + result: String, +} + +/// Await the next item on `rx`, panicking instead of hanging if it never +/// arrives. +async fn next_with_timeout(rx: &mut mpsc::UnboundedReceiver) -> T { + tokio::time::timeout(Duration::from_secs(10), rx.next()) + .await + .expect("timed out waiting for channel event") + .expect("channel closed before expected event") +} + +/// Assert that no item is currently buffered on `rx`. +/// +/// Callers must first establish an ordering barrier (such as a round trip +/// through the whole chain) that guarantees any erroneously sent +/// notification would already have been observed. +fn assert_no_event(rx: &mut mpsc::UnboundedReceiver) { + if let Ok(event) = rx.try_recv() { + panic!("unexpected event: {event:?}"); + } +} + +/// The real intercepting proxy from the test fixtures, run in-process. +struct InProcessArrowProxy; + +impl ConnectTo for InProcessArrowProxy { + async fn connect_to(self, client: impl ConnectTo) -> Result<(), Error> { + agent_client_protocol_test::arrow_proxy::run_arrow_proxy(client).await + } +} + +fn prompt_text(request: &PromptRequest) -> String { + request + .prompt + .iter() + .filter_map(|block| match block { + ContentBlock::Text(text) => Some(text.text.as_str()), + _ => None, + }) + .collect() +} + +/// A client cancels a request it sent through the conductor (and a +/// passthrough proxy) to the agent. +/// +/// The agent must observe exactly one `$/cancel_request`, carrying the ID of +/// the request on the conductor-to-agent connection — not the client's raw +/// notification with its hop-local request ID. +#[tokio::test] +async fn client_cancellation_propagates_hop_by_hop_to_agent() -> Result<(), Error> { + let (agent_cancel_tx, mut agent_cancel_rx) = mpsc::unbounded(); + // The JSON-RPC id of the parked request, as seen by the agent. + let (parked_id_tx, mut parked_id_rx) = mpsc::unbounded(); + + let agent = Agent + .builder() + .on_receive_request( + async |initialize: InitializeRequest, responder, _cx: ConnectionTo| { + responder.respond(InitializeResponse::new(initialize.protocol_version)) + }, + agent_client_protocol::on_receive_request!(), + ) + .on_receive_request( + async move |request: SimpleRequest, + responder: Responder, + cx: ConnectionTo| { + if request.message == "park" { + parked_id_tx.unbounded_send(responder.id()).unwrap(); + let cancellation = responder.cancellation(); + cx.spawn(async move { + let response = cancellation + .run_until_cancelled(std::future::pending::< + Result, + >()) + .await; + responder.respond_with_result(response) + })?; + return Ok(()); + } + + responder.respond(SimpleResponse { + result: format!("echo: {}", request.message), + }) + }, + agent_client_protocol::on_receive_request!(), + ) + .on_receive_notification( + async move |cancel: CancelRequestNotification, _cx: ConnectionTo| { + agent_cancel_tx.unbounded_send(cancel.request_id).unwrap(); + Ok(()) + }, + agent_client_protocol::on_receive_notification!(), + ); + + let (editor_write, conductor_read) = duplex(8192); + let (conductor_write, editor_read) = duplex(8192); + + // One passthrough proxy in the chain, so the cancellation crosses every + // kind of hop: client→conductor, conductor→proxy, proxy→conductor + // (successor-wrapped), and conductor→agent. + let conductor_handle = tokio::spawn(async move { + ConductorImpl::new_agent( + "cancellation-conductor".to_string(), + ProxiesAndAgent::new(agent).proxy(Proxy.builder()), + ) + .run(ByteStreams::new( + conductor_write.compat_write(), + conductor_read.compat(), + )) + .await + }); + + let client_request_id = tokio::time::timeout(Duration::from_secs(30), async move { + Client + .builder() + .connect_with( + ByteStreams::new(editor_write.compat_write(), editor_read.compat()), + async |cx| { + let initialize = cx + .send_request(InitializeRequest::new(ProtocolVersion::V1)) + .block_task() + .await?; + assert_eq!(initialize.protocol_version, ProtocolVersion::V1); + + let request: SentRequest = cx.send_request(SimpleRequest { + message: "park".into(), + }); + let client_request_id = request.id(); + request.cancel()?; + + // The cancellation reaches the agent hop by hop, and the + // agent's cancellation error flows back the same way. + let error = request + .block_task() + .await + .expect_err("request should be cancelled"); + assert_eq!(i32::from(error.code), -32800); + + // Barrier: this round trip traverses the whole chain + // after the cancellation, so a tunneled raw + // `$/cancel_request` would already have been recorded by + // the agent. + let barrier = cx + .send_request(SimpleRequest { + message: "barrier".into(), + }) + .block_task() + .await?; + assert_eq!(barrier.result, "echo: barrier"); + + Ok(client_request_id) + }, + ) + .await + }) + .await + .expect("test timed out") + .expect("client failed"); + + // The agent saw exactly one `$/cancel_request`, for the request ID on + // its own connection. + let parked_id = next_with_timeout(&mut parked_id_rx).await; + assert_ne!( + parked_id, client_request_id, + "each hop must re-issue the request under its own ID" + ); + let observed = next_with_timeout(&mut agent_cancel_rx).await; + assert_eq!(serde_json::to_value(observed).unwrap(), parked_id); + assert_no_event(&mut agent_cancel_rx); + + conductor_handle.abort(); + Ok(()) +} + +/// The agent cancels a request it sent through the conductor to the client +/// (the right-to-left direction). +/// +/// The client must observe exactly one `$/cancel_request`, carrying the ID +/// of the request on the client-to-conductor connection — not the agent's +/// raw notification with its hop-local request ID. +#[tokio::test] +async fn agent_cancellation_propagates_hop_by_hop_to_client() -> Result<(), Error> { + let (client_cancel_tx, mut client_cancel_rx) = mpsc::unbounded(); + // The JSON-RPC id of the parked request, as seen by the client. + let (parked_id_tx, mut parked_id_rx) = mpsc::unbounded(); + + let agent = Agent + .builder() + .on_receive_request( + async |initialize: InitializeRequest, responder, _cx: ConnectionTo| { + responder.respond(InitializeResponse::new(initialize.protocol_version)) + }, + agent_client_protocol::on_receive_request!(), + ) + .on_receive_request( + async |request: SimpleRequest, + responder: Responder, + cx: ConnectionTo| { + if request.message == "trigger reverse cancel" { + let connection = cx.clone(); + cx.spawn(async move { + // Send a request to the client, cancel it, and report + // how it concluded as the response to the trigger. + let upstream: SentRequest = + connection.send_request(SimpleRequest { + message: "park".into(), + }); + upstream.cancel()?; + let error = upstream + .block_task() + .await + .expect_err("request to the client should be cancelled"); + responder.respond(SimpleResponse { + result: format!("client request error: {}", i32::from(error.code)), + }) + })?; + return Ok(()); + } + + responder.respond(SimpleResponse { + result: format!("echo: {}", request.message), + }) + }, + agent_client_protocol::on_receive_request!(), + ); + + let (editor_write, conductor_read) = duplex(8192); + let (conductor_write, editor_read) = duplex(8192); + + let conductor_handle = tokio::spawn(async move { + ConductorImpl::new_agent( + "cancellation-conductor".to_string(), + ProxiesAndAgent::new(agent), + ) + .run(ByteStreams::new( + conductor_write.compat_write(), + conductor_read.compat(), + )) + .await + }); + + tokio::time::timeout(Duration::from_secs(30), async move { + Client + .builder() + .on_receive_request( + async move |request: SimpleRequest, + responder: Responder, + cx: ConnectionTo| { + assert_eq!(request.message, "park"); + parked_id_tx.unbounded_send(responder.id()).unwrap(); + let cancellation = responder.cancellation(); + cx.spawn(async move { + let response = cancellation + .run_until_cancelled(std::future::pending::< + Result, + >()) + .await; + responder.respond_with_result(response) + })?; + Ok(()) + }, + agent_client_protocol::on_receive_request!(), + ) + .on_receive_notification( + async move |cancel: CancelRequestNotification, _cx: ConnectionTo| { + client_cancel_tx.unbounded_send(cancel.request_id).unwrap(); + Ok(()) + }, + agent_client_protocol::on_receive_notification!(), + ) + .connect_with( + ByteStreams::new(editor_write.compat_write(), editor_read.compat()), + async |cx| { + let initialize = cx + .send_request(InitializeRequest::new(ProtocolVersion::V1)) + .block_task() + .await?; + assert_eq!(initialize.protocol_version, ProtocolVersion::V1); + + // The agent answers the trigger only after its request to + // the client was cancelled and answered with the standard + // cancellation error. + let response = cx + .send_request(SimpleRequest { + message: "trigger reverse cancel".into(), + }) + .block_task() + .await?; + assert_eq!(response.result, "client request error: -32800"); + + // Barrier: a tunneled raw `$/cancel_request` was queued + // in the conductor's routing loop before this request, so + // it would already have been recorded by the client. + let barrier = cx + .send_request(SimpleRequest { + message: "barrier".into(), + }) + .block_task() + .await?; + assert_eq!(barrier.result, "echo: barrier"); + + Ok(()) + }, + ) + .await + }) + .await + .expect("test timed out") + .expect("client failed"); + + // The client saw exactly one `$/cancel_request`, for the request ID on + // its own connection. + let parked_id = next_with_timeout(&mut parked_id_rx).await; + let observed = next_with_timeout(&mut client_cancel_rx).await; + assert_eq!(serde_json::to_value(observed).unwrap(), parked_id); + assert_no_event(&mut client_cancel_rx); + + conductor_handle.abort(); + Ok(()) +} + +/// The canonical real-world cancellation cascade, through a real intercepting +/// proxy (`arrow_proxy`) and a full ACP session: +/// +/// 1. The client initializes, creates a session, and sends `session/prompt`. +/// 2. The agent asks the client for permission (`session/request_permission`) +/// and waits. +/// 3. The client cancels the *prompt*; the agent reacts by cancelling its +/// outstanding *permission request*, then answers the prompt with the +/// cancellation error. +/// +/// Both cancellations must arrive re-issued with the receiving connection's +/// own request ID — exactly once each — and the chain (including the +/// transforming proxy) must keep working afterwards. +#[tokio::test] +async fn prompt_cancellation_cascades_through_real_proxy_chain() -> Result<(), Error> { + // What the agent observed: incoming `$/cancel_request`s and the id of the + // parked prompt. + let (agent_cancel_tx, mut agent_cancel_rx) = mpsc::unbounded(); + let (prompt_id_tx, mut prompt_id_rx) = mpsc::unbounded(); + // What the client observed: incoming `$/cancel_request`s, the id of the + // parked permission request, and session updates. + let (client_cancel_tx, mut client_cancel_rx) = mpsc::unbounded(); + let (permission_id_tx, mut permission_id_rx) = mpsc::unbounded(); + let (session_update_tx, mut session_update_rx) = mpsc::unbounded(); + + let agent = Agent + .builder() + .on_receive_request( + async |initialize: InitializeRequest, responder, _cx: ConnectionTo| { + responder.respond(InitializeResponse::new(initialize.protocol_version)) + }, + agent_client_protocol::on_receive_request!(), + ) + .on_receive_request( + async |request: NewSessionRequest, responder, _cx: ConnectionTo| { + assert!(request.mcp_servers.is_empty()); + responder.respond(NewSessionResponse::new(SessionId::new("test-session"))) + }, + agent_client_protocol::on_receive_request!(), + ) + .on_receive_request( + async move |request: PromptRequest, + responder: Responder, + cx: ConnectionTo| { + let text = prompt_text(&request); + if text != "park" { + // Echo prompts complete normally, with a session update + // the arrow proxy will transform on its way back. + cx.send_notification(SessionNotification::new( + request.session_id, + SessionUpdate::AgentMessageChunk(ContentChunk::new(text.into())), + ))?; + return responder.respond(PromptResponse::new(StopReason::EndTurn)); + } + + prompt_id_tx.unbounded_send(responder.id()).unwrap(); + let cancellation = responder.cancellation(); + let connection = cx.clone(); + cx.spawn(async move { + // Ask the client for permission through the chain. + let permission: SentRequest = connection + .send_request(RequestPermissionRequest::new( + request.session_id, + ToolCallUpdate::new("tool-1", ToolCallUpdateFields::default()), + vec![PermissionOption::new( + "allow", + "Allow", + PermissionOptionKind::AllowOnce, + )], + )); + + // The client cancels the prompt rather than answering the + // permission request. + cancellation.cancelled().await; + + // React like a real agent: withdraw the outstanding + // permission request, then report the prompt as + // cancelled. + permission.cancel()?; + let permission_error = permission + .block_task() + .await + .expect_err("permission request should be cancelled"); + + if i32::from(permission_error.code) == -32800 { + responder.respond_with_result(Err(Error::request_cancelled())) + } else { + responder.respond_with_result(Err( + agent_client_protocol::util::internal_error(format!( + "unexpected permission error: {permission_error:?}" + )), + )) + } + })?; + Ok(()) + }, + agent_client_protocol::on_receive_request!(), + ) + .on_receive_notification( + async move |cancel: CancelRequestNotification, _cx: ConnectionTo| { + agent_cancel_tx.unbounded_send(cancel.request_id).unwrap(); + Ok(()) + }, + agent_client_protocol::on_receive_notification!(), + ); + + let (editor_write, conductor_read) = duplex(8192); + let (conductor_write, editor_read) = duplex(8192); + + let conductor_handle = tokio::spawn(async move { + ConductorImpl::new_agent( + "cancellation-conductor".to_string(), + ProxiesAndAgent::new(agent).proxy(InProcessArrowProxy), + ) + .run(ByteStreams::new( + conductor_write.compat_write(), + conductor_read.compat(), + )) + .await + }); + + let client_prompt_id = tokio::time::timeout(Duration::from_secs(30), async move { + Client + .builder() + .on_receive_request( + async move |_request: RequestPermissionRequest, + responder: Responder, + cx: ConnectionTo| { + permission_id_tx.unbounded_send(responder.id()).unwrap(); + let cancellation = responder.cancellation(); + cx.spawn(async move { + let response = cancellation + .run_until_cancelled(std::future::pending::< + Result, + >()) + .await; + responder.respond_with_result(response) + })?; + Ok(()) + }, + agent_client_protocol::on_receive_request!(), + ) + .on_receive_notification( + async move |notification: SessionNotification, _cx: ConnectionTo| { + if let SessionUpdate::AgentMessageChunk(ContentChunk { + content: ContentBlock::Text(text), + .. + }) = notification.update + { + session_update_tx.unbounded_send(text.text).unwrap(); + } + Ok(()) + }, + agent_client_protocol::on_receive_notification!(), + ) + .on_receive_notification( + async move |cancel: CancelRequestNotification, _cx: ConnectionTo| { + client_cancel_tx.unbounded_send(cancel.request_id).unwrap(); + Ok(()) + }, + agent_client_protocol::on_receive_notification!(), + ) + .connect_with( + ByteStreams::new(editor_write.compat_write(), editor_read.compat()), + async |cx| { + let initialize = cx + .send_request(InitializeRequest::new(ProtocolVersion::V1)) + .block_task() + .await?; + assert_eq!(initialize.protocol_version, ProtocolVersion::V1); + + let session = cx + .send_request(NewSessionRequest::new( + std::env::current_dir().map_err(Error::into_internal_error)?, + )) + .block_task() + .await?; + + let prompt: SentRequest = cx.send_request(PromptRequest::new( + session.session_id.clone(), + vec!["park".into()], + )); + let client_prompt_id = prompt.id(); + prompt.cancel()?; + + let error = prompt + .block_task() + .await + .expect_err("prompt should be cancelled"); + assert_eq!(i32::from(error.code), -32800); + + // The chain still works end to end: a normal prompt + // completes, and its session update comes back through + // the arrow proxy, which prefixes `>` — proving the real + // proxy sits in the message path. + let barrier: PromptResponse = cx + .send_request(PromptRequest::new( + session.session_id.clone(), + vec!["barrier".into()], + )) + .block_task() + .await?; + assert_eq!(barrier.stop_reason, StopReason::EndTurn); + + Ok(client_prompt_id) + }, + ) + .await + }) + .await + .expect("test timed out") + .expect("client failed"); + + // The agent saw exactly one `$/cancel_request` (for the prompt), with the + // ID of the prompt on the conductor-to-agent connection. + let prompt_id = next_with_timeout(&mut prompt_id_rx).await; + assert_ne!( + prompt_id, client_prompt_id, + "each hop must re-issue the request under its own ID" + ); + let observed = next_with_timeout(&mut agent_cancel_rx).await; + assert_eq!(serde_json::to_value(observed).unwrap(), prompt_id); + assert_no_event(&mut agent_cancel_rx); + + // The client saw exactly one `$/cancel_request` (for the permission + // request), with the ID of that request on the client's own connection. + let permission_id = next_with_timeout(&mut permission_id_rx).await; + let observed = next_with_timeout(&mut client_cancel_rx).await; + assert_eq!(serde_json::to_value(observed).unwrap(), permission_id); + assert_no_event(&mut client_cancel_rx); + + // The barrier prompt's session update was transformed by the arrow proxy. + let update = next_with_timeout(&mut session_update_rx).await; + assert_eq!(update, ">barrier"); + + conductor_handle.abort(); + Ok(()) +} diff --git a/src/agent-client-protocol/src/concepts/cancellation.rs b/src/agent-client-protocol/src/concepts/cancellation.rs index 13c6d3a..07ba06a 100644 --- a/src/agent-client-protocol/src/concepts/cancellation.rs +++ b/src/agent-client-protocol/src/concepts/cancellation.rs @@ -95,6 +95,12 @@ //! downstream request automatically. The downstream response (normal data or a //! cancellation error) is still forwarded back upstream. //! +//! Because cancellation propagates per hop this way, the raw notification is +//! never tunneled across hops: [`ConnectionTo::send_proxied_message_to`] drops +//! `$/cancel_request` notifications rather than forwarding a `requestId` that +//! was allocated on a different connection and would be meaningless to the +//! next peer. +//! //! # Low-level access //! //! Register [`CancelRequestNotification`] (or [`ProtocolLevelNotification`]) @@ -132,6 +138,7 @@ //! [`RequestCancellation::is_cancelled`]: crate::RequestCancellation::is_cancelled //! [`ConnectionTo::send_request`]: crate::ConnectionTo::send_request //! [`ConnectionTo::send_request_to`]: crate::ConnectionTo::send_request_to +//! [`ConnectionTo::send_proxied_message_to`]: crate::ConnectionTo::send_proxied_message_to //! [`ConnectionTo::spawn`]: crate::ConnectionTo::spawn //! [`SentRequest`]: crate::SentRequest //! [`SentRequest::cancel`]: crate::SentRequest::cancel diff --git a/src/agent-client-protocol/src/jsonrpc.rs b/src/agent-client-protocol/src/jsonrpc.rs index 5022513..8043503 100644 --- a/src/agent-client-protocol/src/jsonrpc.rs +++ b/src/agent-client-protocol/src/jsonrpc.rs @@ -1756,6 +1756,12 @@ impl RequestCancellationRegistry { /// Get the cancellation marker for a registered request, creating it on /// first use. Repeated calls return markers that share the same state. + /// + /// Exception: when the registration is stale (a protocol-violating peer + /// reused this request ID and the slot now belongs to a newer request, or + /// was already removed by it), every call returns a fresh *detached* + /// marker. Detached markers can never fire, and detached markers from + /// repeated calls do not share state with each other. fn marker(&self, id: &serde_json::Value, generation: u64) -> RequestCancellation { let mut inner = self .inner @@ -2214,6 +2220,14 @@ impl ConnectionTo { /// /// The request context's response type matches the request's response type, /// enabling type-safe message forwarding. + /// + /// When the `unstable_cancel_request` feature is enabled, `$/cancel_request` + /// notifications are *not* forwarded: their `requestId` refers to a request + /// on the connection they arrived over and would be meaningless to `peer`. + /// Cancellation instead propagates hop by hop, because the responders + /// passed to [`forward_response_to`](SentRequest::forward_response_to) + /// observe it and re-issue the cancellation with the forwarded request's + /// own ID. pub fn send_proxied_message_to< Peer: Role, Req: JsonRpcRequest, @@ -2230,7 +2244,23 @@ impl ConnectionTo { Dispatch::Request(request, responder) => self .send_request_to(peer, request) .forward_response_to(responder), - Dispatch::Notification(notification) => self.send_notification_to(peer, notification), + Dispatch::Notification(notification) => { + // `$/cancel_request` is connection-scoped: its `requestId` was + // allocated on the connection the notification arrived over + // and means nothing to `peer`. The cancellation has already + // been recorded on this connection's responder markers, and + // `forward_response_to` re-issues it for the forwarded request + // with the correct per-hop ID, so drop the raw notification + // instead of tunneling a meaningless ID across the hop. + #[cfg(feature = "unstable_cancel_request")] + if crate::schema::CancelRequestNotification::matches_method(notification.method()) { + tracing::debug!( + "not forwarding hop-scoped `$/cancel_request` notification across proxy hop" + ); + return Ok(()); + } + self.send_notification_to(peer, notification) + } Dispatch::Response(result, router) => { // Responses are forwarded directly to their destination router.respond_with_result(result) diff --git a/src/agent-client-protocol/src/session.rs b/src/agent-client-protocol/src/session.rs index c12f1b5..424895b 100644 --- a/src/agent-client-protocol/src/session.rs +++ b/src/agent-client-protocol/src/session.rs @@ -516,7 +516,7 @@ where #[derive(Debug)] #[cfg_attr( feature = "unstable_cancel_request", - expect( + allow( clippy::large_enum_variant, reason = "Dispatch messages vastly outnumber StopReason; boxing would add a heap allocation" ) diff --git a/src/agent-client-protocol/tests/jsonrpc_request_cancellation.rs b/src/agent-client-protocol/tests/jsonrpc_request_cancellation.rs index 9a09dd6..fa976b4 100644 --- a/src/agent-client-protocol/tests/jsonrpc_request_cancellation.rs +++ b/src/agent-client-protocol/tests/jsonrpc_request_cancellation.rs @@ -250,6 +250,12 @@ impl agent_client_protocol::role::HasPeer for WrappedCounterpa } } +impl agent_client_protocol::role::HasPeer for WrappedHost { + fn remote_style(&self, _peer: WrappedSuccessor) -> agent_client_protocol::role::RemoteStyle { + agent_client_protocol::role::RemoteStyle::Successor + } +} + #[tokio::test(flavor = "current_thread")] async fn unhandled_protocol_level_notifications_are_ignored() { use tokio::io::{AsyncWriteExt, BufReader}; @@ -534,6 +540,104 @@ async fn wrapped_cancel_request_cancels_wrapped_request() { .await; } +#[tokio::test(flavor = "current_thread")] +async fn cancelling_request_sent_to_successor_peer_sends_wrapped_cancel() { + use tokio::task::LocalSet; + + let local = LocalSet::new(); + + local + .run_until(async { + let (wrapped_cancel_tx, mut wrapped_cancel_rx) = mpsc::unbounded(); + let (plain_cancel_tx, mut plain_cancel_rx) = mpsc::unbounded(); + + let (server_reader, server_writer, client_reader, client_writer) = setup_test_streams(); + let server_transport = + agent_client_protocol::ByteStreams::new(server_writer, server_reader); + let server = WrappedHost + .builder() + .on_receive_request_from( + WrappedSuccessor, + async |_request: SimpleRequest, + responder: Responder, + cx: ConnectionTo| { + let cancellation = responder.cancellation(); + cx.spawn(async move { + let response = cancellation + .run_until_cancelled(futures::future::pending::< + Result, + >()) + .await; + responder.respond_with_result(response) + })?; + Ok(()) + }, + agent_client_protocol::on_receive_request!(), + ) + // Matches only a `$/cancel_request` wrapped in a + // `_proxy/successor` envelope: observing it here proves the + // client wrapped the outgoing cancellation the same way as + // the request it refers to. + .on_receive_notification_from( + WrappedSuccessor, + async move |cancel: CancelRequestNotification, + _cx: ConnectionTo| { + wrapped_cancel_tx.unbounded_send(cancel.request_id).unwrap(); + Ok(()) + }, + agent_client_protocol::on_receive_notification!(), + ) + // Matches only an *unwrapped* `$/cancel_request`; the client + // must never send one for a successor-wrapped request. + .on_receive_notification( + async move |cancel: CancelRequestNotification, + _cx: ConnectionTo| { + plain_cancel_tx.unbounded_send(cancel.request_id).unwrap(); + Ok(()) + }, + agent_client_protocol::on_receive_notification!(), + ); + + tokio::task::spawn_local(async move { + if let Err(error) = server.connect_to(server_transport).await { + panic!("server should stay alive: {error:?}"); + } + }); + + let client_transport = + agent_client_protocol::ByteStreams::new(client_writer, client_reader); + let (expected_id, error) = WrappedCounterpart + .builder() + .connect_with(client_transport, async |cx| { + let request: SentRequest = cx.send_request_to( + WrappedSuccessor, + SimpleRequest { + message: "wrapped cancel".into(), + }, + ); + let expected_id = request.id(); + request.cancel()?; + let error = request + .block_task() + .await + .expect_err("request should be cancelled"); + Ok((expected_id, error)) + }) + .await + .unwrap(); + + assert_eq!(i32::from(error.code), -32800); + + // The cancellation arrived wrapped, for the wrapped request's + // outer JSON-RPC id, and never in unwrapped form. + let received = next_with_timeout(&mut wrapped_cancel_rx).await; + assert_eq!(serde_json::to_value(received).unwrap(), expected_id); + assert_no_event(&mut wrapped_cancel_rx); + assert_no_event(&mut plain_cancel_rx); + }) + .await; +} + #[tokio::test(flavor = "current_thread")] async fn cancel_request_notification_can_be_sent_and_handled() { use tokio::task::LocalSet; @@ -594,9 +698,19 @@ async fn sent_request_can_send_cancellation_for_its_id() { let server = UntypedRole .builder() .on_receive_request( - async |_request: SimpleRequest, - _responder: Responder, - _connection: ConnectionTo| { Ok(()) }, + async |request: SimpleRequest, + responder: Responder, + _connection: ConnectionTo| { + if request.message == "barrier" { + return responder.respond(SimpleResponse { + result: format!("echo: {}", request.message), + }); + } + // Park other requests (by dropping the responder) so + // the cancelled request is never answered and the + // client handle stays unconsumed. + Ok(()) + }, agent_client_protocol::on_receive_request!(), ) .on_receive_notification( @@ -625,6 +739,21 @@ async fn sent_request_can_send_cancellation_for_its_id() { let expected_id = request.id(); request.cancel()?; let received = next_with_timeout(&mut cancel_rx).await; + + // Dropping the handle after an explicit cancel must not + // send a second `$/cancel_request`. + drop(request); + + // Barrier round trip: a duplicate cancel sent by the drop + // above would reach the server before this request. + let barrier = cx + .send_request(SimpleRequest { + message: "barrier".into(), + }) + .block_task() + .await?; + assert_eq!(barrier.result, "echo: barrier"); + Ok((expected_id, received)) }) .await @@ -1092,6 +1221,165 @@ async fn forward_response_to_propagates_cancellation_to_downstream_request() { .await; } +#[tokio::test(flavor = "current_thread")] +async fn send_proxied_message_does_not_tunnel_cancel_notifications() { + use tokio::task::LocalSet; + + let local = LocalSet::new(); + + local + .run_until(async { + let (backend_cancel_tx, mut backend_cancel_rx) = mpsc::unbounded(); + // The downstream JSON-RPC id of the parked request, as seen by + // the backend. + let (parked_id_tx, mut parked_id_rx) = mpsc::unbounded(); + // The responder for the cancelled request, parked by the backend + // until the forwarded cancellation arrives. + let pending_responder: Arc>>> = + Arc::new(Mutex::new(None)); + + let (backend_for_proxy, backend_for_server) = Channel::duplex(); + let (backend_connection_tx, backend_connection_rx) = + futures::channel::oneshot::channel(); + + tokio::task::spawn_local(async move { + let result = UntypedRole + .builder() + .connect_with(backend_for_proxy, async |connection| { + drop(backend_connection_tx.send(connection.clone())); + std::future::pending::>().await + }) + .await; + if let Err(error) = result { + panic!("proxy-to-backend connection should stay alive: {error:?}"); + } + }); + + let backend_server = UntypedRole + .builder() + .on_receive_request( + { + let pending_responder = pending_responder.clone(); + async move |request: SimpleRequest, + responder: Responder, + _connection: ConnectionTo| { + if request.message == "park" { + parked_id_tx.unbounded_send(responder.id()).unwrap(); + *pending_responder.lock().unwrap() = Some(responder); + return Ok(()); + } + + responder.respond(SimpleResponse { + result: format!("echo: {}", request.message), + }) + } + }, + agent_client_protocol::on_receive_request!(), + ) + .on_receive_notification( + { + let pending_responder = pending_responder.clone(); + async move |notification: CancelRequestNotification, + _connection: ConnectionTo| { + // Honor the cancellation: answer the parked + // request with the cancellation error. + if let Some(responder) = pending_responder.lock().unwrap().take() { + responder.respond_with_result(Err( + agent_client_protocol::Error::request_cancelled(), + ))?; + } + backend_cancel_tx + .unbounded_send(notification.request_id) + .unwrap(); + Ok(()) + } + }, + agent_client_protocol::on_receive_notification!(), + ); + + tokio::task::spawn_local(async move { + if let Err(error) = backend_server.connect_to(backend_for_server).await { + panic!("backend server should stay alive: {error:?}"); + } + }); + + let backend_connection = backend_connection_rx + .await + .expect("backend connection should start"); + + let (server_reader, server_writer, client_reader, client_writer) = setup_test_streams(); + let proxy_transport = + agent_client_protocol::ByteStreams::new(server_writer, server_reader); + // The proxy forwards *every* incoming dispatch with + // `send_proxied_message`. Without the hop-scoped filter, the + // client's raw `$/cancel_request` (whose request ID only means + // something on the client-to-proxy connection) would be tunneled + // to the backend verbatim, alongside the cancellation that + // `forward_response_to` re-issues with the downstream ID. + let proxy = UntypedRole.builder().on_receive_dispatch( + { + let backend_connection = backend_connection.clone(); + async move |dispatch: Dispatch, _connection: ConnectionTo| { + backend_connection.send_proxied_message(dispatch) + } + }, + agent_client_protocol::on_receive_dispatch!(), + ); + + tokio::task::spawn_local(async move { + if let Err(error) = proxy.connect_to(proxy_transport).await { + panic!("proxy should stay alive: {error:?}"); + } + }); + + let client_transport = + agent_client_protocol::ByteStreams::new(client_writer, client_reader); + let client_request_id = UntypedRole + .builder() + .connect_with(client_transport, async |connection| { + let request: SentRequest = + connection.send_request(SimpleRequest { + message: "park".into(), + }); + let client_request_id = request.id(); + request.cancel()?; + + let error = request + .block_task() + .await + .expect_err("request should be cancelled"); + assert_eq!(i32::from(error.code), -32800); + + // Barrier: this round trip traverses both hops after the + // cancellation, so a tunneled raw `$/cancel_request` + // would already have been recorded by the backend. + let barrier = connection + .send_request(SimpleRequest { + message: "barrier".into(), + }) + .block_task() + .await?; + assert_eq!(barrier.result, "echo: barrier"); + Ok(client_request_id) + }) + .await + .unwrap(); + + // The backend saw exactly one `$/cancel_request`: the one + // re-issued for the downstream request, not the client's raw + // notification with its hop-local request ID. + let parked_id = next_with_timeout(&mut parked_id_rx).await; + assert_ne!( + parked_id, client_request_id, + "the proxy must re-issue the request under its own ID" + ); + let observed = next_with_timeout(&mut backend_cancel_rx).await; + assert_eq!(serde_json::to_value(observed).unwrap(), parked_id); + assert_no_event(&mut backend_cancel_rx); + }) + .await; +} + #[tokio::test(flavor = "current_thread")] async fn cancellation_marker_requested_after_cancel_is_already_cancelled() { use tokio::task::LocalSet; From ce4f6beb10aa646a7958ea6292560bce29f1bdb9 Mon Sep 17 00:00:00 2001 From: Ben Brandt Date: Thu, 11 Jun 2026 07:15:11 +0200 Subject: [PATCH 14/18] feat(acp): Allow custom cancellation forwarding --- md/request-cancellation.md | 10 + .../src/conductor.rs | 18 +- .../tests/request_cancellation.rs | 233 +++++++++++++++ .../src/concepts/cancellation.rs | 42 +++ src/agent-client-protocol/src/jsonrpc.rs | 211 ++++++++++---- src/agent-client-protocol/src/role/acp.rs | 9 +- .../tests/jsonrpc_request_cancellation.rs | 268 ++++++++++++++++++ 7 files changed, 722 insertions(+), 69 deletions(-) diff --git a/md/request-cancellation.md b/md/request-cancellation.md index 81ffd92..5022294 100644 --- a/md/request-cancellation.md +++ b/md/request-cancellation.md @@ -81,6 +81,16 @@ drop a raw `$/cancel_request` instead of forwarding a request ID that means nothing on the next connection. The cancellation still reaches the next hop, re-issued by `forward_response_to` with that hop's own request ID. +Proxies that intercept methods with custom handlers stay in control: the +request's cancellation marker is their decision point, and handlers see the +raw notification before any generic forwarding fallback. A custom handler can +handle the cancellation locally, propagate it to a forwarded request +(`forward_response_to`, or `forward_cancellation_from` when the forwarding +needs custom logic), absorb it, or claim the notification and route it itself. +See the `concepts::cancellation` chapter in the +[agent-client-protocol rustdoc](https://docs.rs/agent-client-protocol) for +the full decision matrix. + When the notification targets a request that was wrapped in a `_proxy/successor` envelope (see the [Protocol Reference](./protocol.md)), the `$/cancel_request` is wrapped in the same envelope, and `requestId` refers to diff --git a/src/agent-client-protocol-conductor/src/conductor.rs b/src/agent-client-protocol-conductor/src/conductor.rs index 844c909..badb90f 100644 --- a/src/agent-client-protocol-conductor/src/conductor.rs +++ b/src/agent-client-protocol-conductor/src/conductor.rs @@ -780,12 +780,18 @@ where // // The proxy will then initialize itself and forward an `Initialize` // request to its successor. - self.proxies[target_component_index] - .send_request(InitializeProxyRequest::from(request)) - .on_receiving_result(async move |result| { - tracing::debug!(?result, "got initialize_proxy response from proxy"); - responder.respond_with_result(result) - }) + let sent = self.proxies[target_component_index] + .send_request(InitializeProxyRequest::from(request)); + // The request is rewritten, so `forward_response_to` cannot be + // used here; wire up cancellation forwarding explicitly to + // keep `initialize` cancellable like every other forwarded + // request. + #[cfg(feature = "unstable_cancel_request")] + let sent = sent.forward_cancellation_from(responder.cancellation()); + sent.on_receiving_result(async move |result| { + tracing::debug!(?result, "got initialize_proxy response from proxy"); + responder.respond_with_result(result) + }) }) .await .otherwise(async |message| { diff --git a/src/agent-client-protocol-conductor/tests/request_cancellation.rs b/src/agent-client-protocol-conductor/tests/request_cancellation.rs index c016edd..6d12ffc 100644 --- a/src/agent-client-protocol-conductor/tests/request_cancellation.rs +++ b/src/agent-client-protocol-conductor/tests/request_cancellation.rs @@ -16,6 +16,8 @@ //! response arrives, any erroneously tunneled notification would already //! have been observed. +use std::sync::Arc; +use std::sync::atomic::{AtomicBool, Ordering}; use std::time::Duration; use agent_client_protocol::schema::{ @@ -604,3 +606,234 @@ async fn prompt_cancellation_cascades_through_real_proxy_chain() -> Result<(), E conductor_handle.abort(); Ok(()) } + +/// `session/new` is forwarded by proxies with a result hook (to register the +/// session's dynamic handler), not with `forward_response_to` — cancellation +/// must still propagate hop by hop, exactly like every other request. +#[tokio::test] +async fn session_new_cancellation_propagates_through_proxy() -> Result<(), Error> { + let (agent_cancel_tx, mut agent_cancel_rx) = mpsc::unbounded(); + let (parked_id_tx, mut parked_id_rx) = mpsc::unbounded(); + + let agent = Agent + .builder() + .on_receive_request( + async |initialize: InitializeRequest, responder, _cx: ConnectionTo| { + responder.respond(InitializeResponse::new(initialize.protocol_version)) + }, + agent_client_protocol::on_receive_request!(), + ) + .on_receive_request( + async move |request: NewSessionRequest, + responder: Responder, + cx: ConnectionTo| { + if request.cwd.ends_with("park-session") { + parked_id_tx.unbounded_send(responder.id()).unwrap(); + let cancellation = responder.cancellation(); + cx.spawn(async move { + let response = cancellation + .run_until_cancelled(std::future::pending::< + Result, + >()) + .await; + responder.respond_with_result(response) + })?; + return Ok(()); + } + + responder.respond(NewSessionResponse::new(SessionId::new("normal-session"))) + }, + agent_client_protocol::on_receive_request!(), + ) + .on_receive_notification( + async move |cancel: CancelRequestNotification, _cx: ConnectionTo| { + agent_cancel_tx.unbounded_send(cancel.request_id).unwrap(); + Ok(()) + }, + agent_client_protocol::on_receive_notification!(), + ); + + let (editor_write, conductor_read) = duplex(8192); + let (conductor_write, editor_read) = duplex(8192); + + // The passthrough proxy is what exercises the proxy-side `session/new` + // forwarding hook. + let conductor_handle = tokio::spawn(async move { + ConductorImpl::new_agent( + "cancellation-conductor".to_string(), + ProxiesAndAgent::new(agent).proxy(Proxy.builder()), + ) + .run(ByteStreams::new( + conductor_write.compat_write(), + conductor_read.compat(), + )) + .await + }); + + let client_request_id = tokio::time::timeout(Duration::from_secs(30), async move { + Client + .builder() + .connect_with( + ByteStreams::new(editor_write.compat_write(), editor_read.compat()), + async |cx| { + let initialize = cx + .send_request(InitializeRequest::new(ProtocolVersion::V1)) + .block_task() + .await?; + assert_eq!(initialize.protocol_version, ProtocolVersion::V1); + + let request: SentRequest = + cx.send_request(NewSessionRequest::new("/park-session")); + let client_request_id = request.id(); + request.cancel()?; + + let error = request + .block_task() + .await + .expect_err("session/new should be cancelled"); + assert_eq!(i32::from(error.code), -32800); + + // Barrier through the whole chain: a fresh session still + // works after the cancelled one. + let session = cx + .send_request(NewSessionRequest::new( + std::env::current_dir().map_err(Error::into_internal_error)?, + )) + .block_task() + .await?; + assert_eq!(session.session_id, SessionId::new("normal-session")); + + Ok(client_request_id) + }, + ) + .await + }) + .await + .expect("test timed out") + .expect("client failed"); + + // The agent saw exactly one `$/cancel_request`, for the `session/new` ID + // on its own connection. + let parked_id = next_with_timeout(&mut parked_id_rx).await; + assert_ne!( + parked_id, client_request_id, + "each hop must re-issue the request under its own ID" + ); + let observed = next_with_timeout(&mut agent_cancel_rx).await; + assert_eq!(serde_json::to_value(observed).unwrap(), parked_id); + assert_no_event(&mut agent_cancel_rx); + + conductor_handle.abort(); + Ok(()) +} + +/// `initialize` is rewritten to `_proxy/initialize` at the conductor-to-proxy +/// hop and forwarded with a result hook — cancellation must still propagate +/// hop by hop, exactly like every other request. +#[tokio::test] +async fn initialize_cancellation_propagates_through_proxy() -> Result<(), Error> { + let (agent_cancel_tx, mut agent_cancel_rx) = mpsc::unbounded(); + let (parked_id_tx, mut parked_id_rx) = mpsc::unbounded(); + let parked_first = Arc::new(AtomicBool::new(false)); + + let agent = Agent.builder().on_receive_request( + { + let parked_first = parked_first.clone(); + async move |initialize: InitializeRequest, + responder: Responder, + cx: ConnectionTo| { + if !parked_first.swap(true, Ordering::SeqCst) { + parked_id_tx.unbounded_send(responder.id()).unwrap(); + let cancellation = responder.cancellation(); + cx.spawn(async move { + let response = cancellation + .run_until_cancelled(std::future::pending::< + Result, + >()) + .await; + responder.respond_with_result(response) + })?; + return Ok(()); + } + + responder.respond(InitializeResponse::new(initialize.protocol_version)) + } + }, + agent_client_protocol::on_receive_request!(), + ); + + // The cancel observer must be registered on the same builder; do it + // separately so the request handler above can keep its captures. + let agent = agent.on_receive_notification( + async move |cancel: CancelRequestNotification, _cx: ConnectionTo| { + agent_cancel_tx.unbounded_send(cancel.request_id).unwrap(); + Ok(()) + }, + agent_client_protocol::on_receive_notification!(), + ); + + let (editor_write, conductor_read) = duplex(8192); + let (conductor_write, editor_read) = duplex(8192); + + // The passthrough proxy is what exercises the conductor's + // `initialize` -> `_proxy/initialize` rewriting hop. + let conductor_handle = tokio::spawn(async move { + ConductorImpl::new_agent( + "cancellation-conductor".to_string(), + ProxiesAndAgent::new(agent).proxy(Proxy.builder()), + ) + .run(ByteStreams::new( + conductor_write.compat_write(), + conductor_read.compat(), + )) + .await + }); + + let client_request_id = tokio::time::timeout(Duration::from_secs(30), async move { + Client + .builder() + .connect_with( + ByteStreams::new(editor_write.compat_write(), editor_read.compat()), + async |cx| { + let request: SentRequest = + cx.send_request(InitializeRequest::new(ProtocolVersion::V1)); + let client_request_id = request.id(); + request.cancel()?; + + let error = request + .block_task() + .await + .expect_err("initialize should be cancelled"); + assert_eq!(i32::from(error.code), -32800); + + // Barrier through the whole chain: initializing again + // still works after the cancelled attempt. + let initialize = cx + .send_request(InitializeRequest::new(ProtocolVersion::V1)) + .block_task() + .await?; + assert_eq!(initialize.protocol_version, ProtocolVersion::V1); + + Ok(client_request_id) + }, + ) + .await + }) + .await + .expect("test timed out") + .expect("client failed"); + + // The agent saw exactly one `$/cancel_request`, for the `initialize` ID + // on its own connection. + let parked_id = next_with_timeout(&mut parked_id_rx).await; + assert_ne!( + parked_id, client_request_id, + "each hop must re-issue the request under its own ID" + ); + let observed = next_with_timeout(&mut agent_cancel_rx).await; + assert_eq!(serde_json::to_value(observed).unwrap(), parked_id); + assert_no_event(&mut agent_cancel_rx); + + conductor_handle.abort(); + Ok(()) +} diff --git a/src/agent-client-protocol/src/concepts/cancellation.rs b/src/agent-client-protocol/src/concepts/cancellation.rs index 07ba06a..975a870 100644 --- a/src/agent-client-protocol/src/concepts/cancellation.rs +++ b/src/agent-client-protocol/src/concepts/cancellation.rs @@ -101,6 +101,45 @@ //! was allocated on a different connection and would be meaningless to the //! next peer. //! +//! ## Custom methods on proxies +//! +//! A proxy that intercepts a method with its own handler decides what +//! cancellation means for it. The SDK always records the cancellation on the +//! request's [`Responder`] marker before the handler chain runs; what happens +//! next is up to the handler that owns the request: +//! +//! - **Handle locally**: react to [`Responder::cancellation`] like any +//! request handler (ignore it, finish early, or respond with +//! [`Error::request_cancelled`]). +//! - **Forward and propagate**: use [`forward_response_to`], or, when the +//! forwarding needs custom logic (rewriting the request, post-processing +//! the result), register the upstream marker explicitly with +//! [`forward_cancellation_from`] before consuming the handle: +//! +//! ``` +//! # use agent_client_protocol::{ConnectionTo, Error, Responder, UntypedRole}; +//! # use agent_client_protocol_test::{MyRequest, MyResponse}; +//! # async fn example(request: MyRequest, responder: Responder, backend: ConnectionTo) -> Result<(), Error> { +//! backend +//! .send_request(request) +//! .forward_cancellation_from(responder.cancellation()) +//! .on_receiving_result(async move |result| { +//! // Custom result handling before responding upstream. +//! responder.respond_with_result(result) +//! })?; +//! # Ok(()) +//! # } +//! ``` +//! +//! - **Absorb**: consume the handle without registering the marker +//! ([`on_receiving_result`] or [`block_task`] alone); the upstream marker is +//! still set, but nothing is sent downstream and the request runs to +//! completion there. +//! - **Custom routing**: claim the `$/cancel_request` notification itself in a +//! handler (user handlers run before the generic forwarding fallbacks) and +//! translate it manually, for example with +//! [`ConnectionTo::send_cancel_request_to`]. +//! //! # Low-level access //! //! Register [`CancelRequestNotification`] (or [`ProtocolLevelNotification`]) @@ -142,6 +181,9 @@ //! [`ConnectionTo::spawn`]: crate::ConnectionTo::spawn //! [`SentRequest`]: crate::SentRequest //! [`SentRequest::cancel`]: crate::SentRequest::cancel +//! [`forward_cancellation_from`]: crate::SentRequest::forward_cancellation_from +//! [`ConnectionTo::send_cancel_request_to`]: crate::ConnectionTo::send_cancel_request_to +//! [`Responder::cancellation`]: crate::Responder::cancellation //! [`Responder`]: crate::Responder //! [`Error::request_cancelled`]: crate::Error::request_cancelled //! [`CancelRequestNotification`]: crate::schema::CancelRequestNotification diff --git a/src/agent-client-protocol/src/jsonrpc.rs b/src/agent-client-protocol/src/jsonrpc.rs index 8043503..e55d36a 100644 --- a/src/agent-client-protocol/src/jsonrpc.rs +++ b/src/agent-client-protocol/src/jsonrpc.rs @@ -3542,6 +3542,11 @@ pub struct SentRequest { to_result: Box Result + Send>, #[cfg(feature = "unstable_cancel_request")] cancellation: SentRequestCancellation, + /// Cancellation markers of other (incoming) requests whose cancellation + /// should be forwarded to this request. See + /// [`forward_cancellation_from`](Self::forward_cancellation_from). + #[cfg(feature = "unstable_cancel_request")] + cancellation_sources: Vec, } #[cfg(feature = "unstable_cancel_request")] @@ -3629,6 +3634,54 @@ impl Debug for SentRequestCancellation { } } +/// Await the response payload for an outgoing request, watching `sources` for +/// cancellation of the upstream requests it was registered with. +/// +/// When any source reports cancellation, a `$/cancel_request` is forwarded to +/// the outgoing request (at most once, shared with [`SentRequest::cancel`] and +/// drop-time auto-cancellation), and the response is *still* awaited: the peer +/// always answers, with normal data or a cancellation error. +/// +/// Watching is deliberately bounded by response arrival so that completed +/// requests do not leak waiters on markers that will never fire. +#[cfg(feature = "unstable_cancel_request")] +async fn await_response_forwarding_cancellation( + response_rx: oneshot::Receiver, + cancellation: &SentRequestCancellation, + sources: &[RequestCancellation], +) -> Result { + // Failing to forward the cancellation must not abort the wait: the + // response (normal data or a cancellation error) may still arrive and + // must still be processed. + let forward_cancellation = || { + if let Err(error) = cancellation.send() { + tracing::debug!( + ?error, + "failed to forward cancellation to downstream request" + ); + } + }; + + let response = if sources.is_empty() { + response_rx.await + } else if sources.iter().any(RequestCancellation::is_cancelled) { + forward_cancellation(); + response_rx.await + } else { + let cancelled = sources.iter().map(|source| source.state.signal_rx.clone()); + match future::select(future::select_all(cancelled), response_rx).await { + Either::Left((_, response_rx)) => { + forward_cancellation(); + response_rx.await + } + Either::Right((response, _)) => response, + } + }; + + cancellation.disarm(); + response +} + impl Debug for SentRequest { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { let mut debug = f.debug_struct("SentRequest"); @@ -3638,7 +3691,9 @@ impl Debug for SentRequest { .field("task_tx", &self.task_tx) .field("response_rx", &self.response_rx); #[cfg(feature = "unstable_cancel_request")] - debug.field("cancellation", &self.cancellation); + debug + .field("cancellation", &self.cancellation) + .field("cancellation_sources", &self.cancellation_sources); debug.finish_non_exhaustive() } } @@ -3659,6 +3714,8 @@ impl SentRequest { to_result: Box::new(Ok), #[cfg(feature = "unstable_cancel_request")] cancellation, + #[cfg(feature = "unstable_cancel_request")] + cancellation_sources: Vec::new(), } } } @@ -3683,6 +3740,50 @@ impl SentRequest { pub fn cancel(&self) -> Result<(), crate::Error> { self.cancellation.send() } + + /// Forward cancellation of another request to this one. + /// + /// When the request that `source` belongs to is cancelled by its peer, + /// a `$/cancel_request` for *this* request is sent to its peer, using the + /// same wrapping as the original request. The response is still awaited + /// and delivered as usual (normal data or a cancellation error), so this + /// composes with [`block_task`](Self::block_task) and + /// [`on_receiving_result`](Self::on_receiving_result). + /// + /// This is the building block for proxies that forward a request with + /// custom logic instead of [`forward_response_to`](Self::forward_response_to) + /// (which wires this up automatically from its responder). Without it, + /// custom forwarding *absorbs* cancellation: the upstream marker is still + /// set, but nothing is sent downstream. + /// + /// ``` + /// # use agent_client_protocol::{ConnectionTo, Error, Responder, UntypedRole}; + /// # use agent_client_protocol_test::{MyRequest, MyResponse}; + /// # async fn example(request: MyRequest, responder: Responder, backend: ConnectionTo) -> Result<(), Error> { + /// backend + /// .send_request(request) + /// .forward_cancellation_from(responder.cancellation()) + /// .on_receiving_result(async move |result| { + /// // Custom result handling, e.g. bookkeeping or rewriting. + /// responder.respond_with_result(result) + /// })?; + /// # Ok(()) + /// # } + /// ``` + /// + /// May be called multiple times; cancellation of any registered source + /// triggers the forwarding (at most one `$/cancel_request` is ever sent + /// per request). Sources are observed while the response is being + /// awaited — that is, once the handle is consumed with + /// [`block_task`](Self::block_task), + /// [`on_receiving_result`](Self::on_receiving_result), or + /// [`forward_response_to`](Self::forward_response_to); a source that was + /// already cancelled by then is honored immediately. + #[cfg(feature = "unstable_cancel_request")] + pub fn forward_cancellation_from(mut self, source: RequestCancellation) -> Self { + self.cancellation_sources.push(source); + self + } } impl SentRequest { @@ -3711,6 +3812,8 @@ impl SentRequest { to_result: Box::new(move |value| map_fn((self.to_result)(value)?)), #[cfg(feature = "unstable_cancel_request")] cancellation: self.cancellation, + #[cfg(feature = "unstable_cancel_request")] + cancellation_sources: self.cancellation_sources, } } @@ -3772,7 +3875,9 @@ impl SentRequest { /// - When the `unstable_cancel_request` feature is enabled and the peer /// cancels the incoming request, the cancellation is forwarded to the /// outgoing request, and the downstream response (normal data or a - /// cancellation error) is still forwarded back. + /// cancellation error) is still forwarded back. This is equivalent to + /// registering the responder's marker with + /// [`forward_cancellation_from`](Self::forward_cancellation_from). #[track_caller] pub fn forward_response_to(self, responder: Responder) -> Result<(), crate::Error> where @@ -3785,42 +3890,21 @@ impl SentRequest { #[cfg(feature = "unstable_cancel_request")] let downstream_cancellation = self.cancellation; #[cfg(feature = "unstable_cancel_request")] - let upstream_cancellation = responder.cancellation(); + let cancellation_sources = { + let mut sources = self.cancellation_sources; + sources.push(responder.cancellation()); + sources + }; let location = Location::caller(); Task::new(location, async move { #[cfg(feature = "unstable_cancel_request")] - let response = { - // Failing to forward the cancellation must not abort this - // task: the downstream response (normal data or a - // cancellation error) may still arrive and must still be - // forwarded upstream. - let forward_cancellation = |cancellation: &SentRequestCancellation| { - if let Err(error) = cancellation.send() { - tracing::debug!( - ?error, - "failed to forward cancellation to downstream request" - ); - } - }; - - let response = if upstream_cancellation.is_cancelled() { - forward_cancellation(&downstream_cancellation); - response_rx.await - } else { - match future::select(pin!(upstream_cancellation.cancelled()), response_rx).await - { - Either::Left(((), response_rx)) => { - forward_cancellation(&downstream_cancellation); - response_rx.await - } - Either::Right((response, _)) => response, - } - }; - - downstream_cancellation.disarm(); - response - }; + let response = await_response_forwarding_cancellation( + response_rx, + &downstream_cancellation, + &cancellation_sources, + ) + .await; #[cfg(not(feature = "unstable_cancel_request"))] let response = response_rx.await; @@ -3918,14 +4002,21 @@ impl SentRequest { where T: Send, { - match self.response_rx.await { + #[cfg(feature = "unstable_cancel_request")] + let response = await_response_forwarding_cancellation( + self.response_rx, + &self.cancellation, + &self.cancellation_sources, + ) + .await; + #[cfg(not(feature = "unstable_cancel_request"))] + let response = self.response_rx.await; + + match response { Ok(ResponsePayload { result: Ok(json_value), ack_tx, }) => { - #[cfg(feature = "unstable_cancel_request")] - self.cancellation.disarm(); - // Ack immediately - we're in a spawned task, so the dispatch loop // can continue while we process the value. if let Some(tx) = ack_tx { @@ -3940,23 +4031,15 @@ impl SentRequest { result: Err(err), ack_tx, }) => { - #[cfg(feature = "unstable_cancel_request")] - self.cancellation.disarm(); - if let Some(tx) = ack_tx { let _ = tx.send(()); } Err(err) } - Err(err) => { - #[cfg(feature = "unstable_cancel_request")] - self.cancellation.disarm(); - - Err(crate::util::internal_error(format!( - "response to `{}` never received: {}", - self.method, err - ))) - } + Err(err) => Err(crate::util::internal_error(format!( + "response to `{}` never received: {}", + self.method, err + ))), } } @@ -4108,14 +4191,23 @@ impl SentRequest { let to_result = self.to_result; #[cfg(feature = "unstable_cancel_request")] let cancellation = self.cancellation; + #[cfg(feature = "unstable_cancel_request")] + let cancellation_sources = self.cancellation_sources; let location = Location::caller(); Task::new(location, async move { - match response_rx.await { - Ok(ResponsePayload { result, ack_tx }) => { - #[cfg(feature = "unstable_cancel_request")] - cancellation.disarm(); + #[cfg(feature = "unstable_cancel_request")] + let response = await_response_forwarding_cancellation( + response_rx, + &cancellation, + &cancellation_sources, + ) + .await; + #[cfg(not(feature = "unstable_cancel_request"))] + let response = response_rx.await; + match response { + Ok(ResponsePayload { result, ack_tx }) => { // Convert the result using to_result for Ok values let typed_result = match result { Ok(json_value) => to_result(json_value), @@ -4133,14 +4225,9 @@ impl SentRequest { outcome } - Err(err) => { - #[cfg(feature = "unstable_cancel_request")] - cancellation.disarm(); - - Err(crate::util::internal_error(format!( - "response to `{method}` never received: {err}" - ))) - } + Err(err) => Err(crate::util::internal_error(format!( + "response to `{method}` never received: {err}" + ))), } }) .spawn(&task_tx) diff --git a/src/agent-client-protocol/src/role/acp.rs b/src/agent-client-protocol/src/role/acp.rs index b103890..5917655 100644 --- a/src/agent-client-protocol/src/role/acp.rs +++ b/src/agent-client-protocol/src/role/acp.rs @@ -239,7 +239,14 @@ impl Role for Conductor { // New session coming from the client -- proxy to the agent // and add a dynamic handler for that session-id. .if_request_from(Client, async |request: NewSessionRequest, responder| { - cx.send_request_to(Agent, request).on_receiving_result({ + let sent = cx.send_request_to(Agent, request); + // The dynamic-handler hook below means we cannot use + // `forward_response_to`, so wire up cancellation forwarding + // explicitly to keep `session/new` cancellable like every + // other proxied request. + #[cfg(feature = "unstable_cancel_request")] + let sent = sent.forward_cancellation_from(responder.cancellation()); + sent.on_receiving_result({ let cx = cx.clone(); async move |result| { if let Ok(NewSessionResponse { session_id, .. }) = &result { diff --git a/src/agent-client-protocol/tests/jsonrpc_request_cancellation.rs b/src/agent-client-protocol/tests/jsonrpc_request_cancellation.rs index fa976b4..f922883 100644 --- a/src/agent-client-protocol/tests/jsonrpc_request_cancellation.rs +++ b/src/agent-client-protocol/tests/jsonrpc_request_cancellation.rs @@ -1380,6 +1380,274 @@ async fn send_proxied_message_does_not_tunnel_cancel_notifications() { .await; } +/// Spawn a backend whose `park` requests wait until the cancel observer +/// releases them, reporting parked request ids and observed cancellations. +/// +/// Returns the proxy-side connection to the backend. +async fn spawn_parking_backend( + honor_cancellations: bool, + parked_id_tx: mpsc::UnboundedSender, + backend_cancel_tx: mpsc::UnboundedSender, +) -> ConnectionTo { + let pending_responder: Arc>>> = + Arc::new(Mutex::new(None)); + + let (backend_for_proxy, backend_for_server) = Channel::duplex(); + let (backend_connection_tx, backend_connection_rx) = futures::channel::oneshot::channel(); + + tokio::task::spawn_local(async move { + let result = UntypedRole + .builder() + .connect_with(backend_for_proxy, async |connection| { + drop(backend_connection_tx.send(connection.clone())); + std::future::pending::>().await + }) + .await; + if let Err(error) = result { + panic!("proxy-to-backend connection should stay alive: {error:?}"); + } + }); + + let backend_server = UntypedRole + .builder() + .on_receive_request( + { + let pending_responder = pending_responder.clone(); + async move |request: SimpleRequest, + responder: Responder, + _connection: ConnectionTo| { + match request.message.as_str() { + "park" => { + parked_id_tx.unbounded_send(responder.id()).unwrap(); + *pending_responder.lock().unwrap() = Some(responder); + Ok(()) + } + "release" => { + if let Some(parked) = pending_responder.lock().unwrap().take() { + parked.respond(SimpleResponse { + result: "released".into(), + })?; + } + responder.respond(SimpleResponse { + result: "echo: release".into(), + }) + } + other => responder.respond(SimpleResponse { + result: format!("echo: {other}"), + }), + } + } + }, + agent_client_protocol::on_receive_request!(), + ) + .on_receive_notification( + { + let pending_responder = pending_responder.clone(); + async move |notification: CancelRequestNotification, + _connection: ConnectionTo| { + if honor_cancellations + && let Some(responder) = pending_responder.lock().unwrap().take() + { + responder.respond_with_result(Err( + agent_client_protocol::Error::request_cancelled(), + ))?; + } + backend_cancel_tx + .unbounded_send(notification.request_id) + .unwrap(); + Ok(()) + } + }, + agent_client_protocol::on_receive_notification!(), + ); + + tokio::task::spawn_local(async move { + if let Err(error) = backend_server.connect_to(backend_for_server).await { + panic!("backend server should stay alive: {error:?}"); + } + }); + + backend_connection_rx + .await + .expect("backend connection should start") +} + +#[tokio::test(flavor = "current_thread")] +async fn custom_forwarding_propagates_cancellation_when_opted_in() { + use tokio::task::LocalSet; + + let local = LocalSet::new(); + + local + .run_until(async { + let (backend_cancel_tx, mut backend_cancel_rx) = mpsc::unbounded(); + let (parked_id_tx, mut parked_id_rx) = mpsc::unbounded(); + + let backend_connection = + spawn_parking_backend(true, parked_id_tx, backend_cancel_tx).await; + + let (server_reader, server_writer, client_reader, client_writer) = setup_test_streams(); + let proxy_transport = + agent_client_protocol::ByteStreams::new(server_writer, server_reader); + // A proxy with a *custom* method handler: it forwards with + // `on_receiving_result` (so it could post-process the result) and + // opts into cancellation propagation explicitly. + let proxy = UntypedRole.builder().on_receive_request( + { + let backend_connection = backend_connection.clone(); + async move |request: SimpleRequest, + responder: Responder, + _connection: ConnectionTo| { + backend_connection + .send_request(request) + .forward_cancellation_from(responder.cancellation()) + .on_receiving_result(async move |result| { + responder.respond_with_result(result) + }) + } + }, + agent_client_protocol::on_receive_request!(), + ); + + tokio::task::spawn_local(async move { + if let Err(error) = proxy.connect_to(proxy_transport).await { + panic!("proxy should stay alive: {error:?}"); + } + }); + + let client_transport = + agent_client_protocol::ByteStreams::new(client_writer, client_reader); + let client_request_id = UntypedRole + .builder() + .connect_with(client_transport, async |connection| { + let request: SentRequest = + connection.send_request(SimpleRequest { + message: "park".into(), + }); + let client_request_id = request.id(); + request.cancel()?; + + let error = request + .block_task() + .await + .expect_err("request should be cancelled"); + assert_eq!(i32::from(error.code), -32800); + + let barrier = connection + .send_request(SimpleRequest { + message: "barrier".into(), + }) + .block_task() + .await?; + assert_eq!(barrier.result, "echo: barrier"); + Ok(client_request_id) + }) + .await + .unwrap(); + + // Exactly one cancellation reached the backend, re-issued under + // the proxy's downstream request ID. + let parked_id = next_with_timeout(&mut parked_id_rx).await; + assert_ne!(parked_id, client_request_id); + let observed = next_with_timeout(&mut backend_cancel_rx).await; + assert_eq!(serde_json::to_value(observed).unwrap(), parked_id); + assert_no_event(&mut backend_cancel_rx); + }) + .await; +} + +#[tokio::test(flavor = "current_thread")] +async fn custom_forwarding_absorbs_cancellation_by_default() { + use tokio::task::LocalSet; + + let local = LocalSet::new(); + + local + .run_until(async { + let (backend_cancel_tx, mut backend_cancel_rx) = mpsc::unbounded(); + let (parked_id_tx, mut parked_id_rx) = mpsc::unbounded(); + + let backend_connection = + spawn_parking_backend(false, parked_id_tx, backend_cancel_tx).await; + + let (server_reader, server_writer, client_reader, client_writer) = setup_test_streams(); + let proxy_transport = + agent_client_protocol::ByteStreams::new(server_writer, server_reader); + // The same custom forwarding *without* opting into propagation: + // the implementor decided cancellation stops at this hop. + let proxy = UntypedRole.builder().on_receive_request( + { + let backend_connection = backend_connection.clone(); + async move |request: SimpleRequest, + responder: Responder, + _connection: ConnectionTo| { + backend_connection + .send_request(request) + .on_receiving_result(async move |result| { + responder.respond_with_result(result) + }) + } + }, + agent_client_protocol::on_receive_request!(), + ); + + tokio::task::spawn_local(async move { + if let Err(error) = proxy.connect_to(proxy_transport).await { + panic!("proxy should stay alive: {error:?}"); + } + }); + + let client_transport = + agent_client_protocol::ByteStreams::new(client_writer, client_reader); + UntypedRole + .builder() + .connect_with(client_transport, async |connection| { + let request: SentRequest = + connection.send_request(SimpleRequest { + message: "park".into(), + }); + request.cancel()?; + + // Barrier: the cancellation has now been processed by the + // proxy (and would have been processed by the backend if + // it had been forwarded). + let barrier = connection + .send_request(SimpleRequest { + message: "barrier".into(), + }) + .block_task() + .await?; + assert_eq!(barrier.result, "echo: barrier"); + assert_no_event(&mut backend_cancel_rx); + + // Release the parked request: the cancelled request still + // completes with normal data, because the proxy absorbed + // the cancellation. + let release = connection + .send_request(SimpleRequest { + message: "release".into(), + }) + .block_task() + .await?; + assert_eq!(release.result, "echo: release"); + + let response = request + .block_task() + .await + .expect("absorbed cancellation must not fail the request"); + assert_eq!(response.result, "released"); + Ok(()) + }) + .await + .unwrap(); + + // The backend never saw any `$/cancel_request`. + let _parked_id = next_with_timeout(&mut parked_id_rx).await; + assert_no_event(&mut backend_cancel_rx); + }) + .await; +} + #[tokio::test(flavor = "current_thread")] async fn cancellation_marker_requested_after_cancel_is_already_cancelled() { use tokio::task::LocalSet; From cee968622173bb1dafd3cf124ea70b75ee7ecd7f Mon Sep 17 00:00:00 2001 From: Ben Brandt Date: Thu, 11 Jun 2026 08:16:15 +0200 Subject: [PATCH 15/18] fix(acp): Drop wrapped cancel notifications --- .../src/conductor.rs | 4 +- src/agent-client-protocol/src/concepts/mod.rs | 1 - src/agent-client-protocol/src/jsonrpc.rs | 108 ++++- src/agent-client-protocol/src/lib.rs | 4 +- .../tests/jsonrpc_request_cancellation.rs | 247 +++++----- .../tests/jsonrpc_unhandled_messages.rs | 455 ++++++++++++++++++ 6 files changed, 669 insertions(+), 150 deletions(-) create mode 100644 src/agent-client-protocol/tests/jsonrpc_unhandled_messages.rs diff --git a/src/agent-client-protocol-conductor/src/conductor.rs b/src/agent-client-protocol-conductor/src/conductor.rs index badb90f..eb7c0cf 100644 --- a/src/agent-client-protocol-conductor/src/conductor.rs +++ b/src/agent-client-protocol-conductor/src/conductor.rs @@ -470,9 +470,7 @@ where // `forward_response_to` calls above, so drop the raw // notification instead of tunneling a meaningless ID. #[cfg(feature = "unstable_cancel_request")] - if agent_client_protocol::schema::CancelRequestNotification::matches_method( - notification.method(), - ) { + if agent_client_protocol::is_cancel_request_notification(¬ification) { tracing::debug!( "not forwarding hop-scoped `$/cancel_request` notification to predecessor" ); diff --git a/src/agent-client-protocol/src/concepts/mod.rs b/src/agent-client-protocol/src/concepts/mod.rs index 1aee662..d752c01 100644 --- a/src/agent-client-protocol/src/concepts/mod.rs +++ b/src/agent-client-protocol/src/concepts/mod.rs @@ -36,7 +36,6 @@ pub mod acp_basics; pub mod callbacks; #[cfg(feature = "unstable_cancel_request")] -#[cfg_attr(docsrs, doc(cfg(feature = "unstable_cancel_request")))] pub mod cancellation; pub mod connections; pub mod error_handling; diff --git a/src/agent-client-protocol/src/jsonrpc.rs b/src/agent-client-protocol/src/jsonrpc.rs index e55d36a..bdb7f56 100644 --- a/src/agent-client-protocol/src/jsonrpc.rs +++ b/src/agent-client-protocol/src/jsonrpc.rs @@ -1689,7 +1689,7 @@ struct RequestCancellationSlot { #[cfg(feature = "unstable_cancel_request")] #[derive(Debug, Default)] struct RequestCancellationRegistryInner { - slots: HashMap, + slots: HashMap, next_generation: u64, } @@ -1706,7 +1706,7 @@ struct RequestCancellationRegistry; #[cfg(feature = "unstable_cancel_request")] #[derive(Debug)] struct ResponderCancellation { - id: serde_json::Value, + id: RequestId, generation: u64, registry: RequestCancellationRegistry, } @@ -1721,7 +1721,7 @@ impl RequestCancellationRegistry { Self::default() } - fn register(&self, id: serde_json::Value) -> ResponderCancellation { + fn register(&self, id: &RequestId) -> ResponderCancellation { let generation = { let mut inner = self .inner @@ -1748,7 +1748,7 @@ impl RequestCancellationRegistry { generation }; ResponderCancellation { - id, + id: id.clone(), generation, registry: self.clone(), } @@ -1762,7 +1762,7 @@ impl RequestCancellationRegistry { /// was already removed by it), every call returns a fresh *detached* /// marker. Detached markers can never fire, and detached markers from /// repeated calls do not share state with each other. - fn marker(&self, id: &serde_json::Value, generation: u64) -> RequestCancellation { + fn marker(&self, id: &RequestId, generation: u64) -> RequestCancellation { let mut inner = self .inner .lock() @@ -1808,7 +1808,7 @@ impl RequestCancellationRegistry { } /// Mark whichever request currently owns `request_id` as cancelled. - fn cancel(&self, request_id: &serde_json::Value) -> bool { + fn cancel(&self, request_id: &RequestId) -> bool { let marker = { let mut inner = self .inner @@ -1836,7 +1836,7 @@ impl RequestCancellationRegistry { /// Remove the slot for `request_id`, but only if it still belongs to the /// registration identified by `generation`. - fn remove(&self, request_id: &serde_json::Value, generation: u64) { + fn remove(&self, request_id: &RequestId, generation: u64) { let mut inner = self .inner .lock() @@ -1861,7 +1861,7 @@ impl RequestCancellationRegistry { clippy::unused_self, reason = "feature-disabled stub mirrors the real registry API" )] - fn register(&self, _id: serde_json::Value) -> ResponderCancellation { + fn register(&self, _id: &RequestId) -> ResponderCancellation { ResponderCancellation } @@ -1890,7 +1890,7 @@ impl Drop for ResponderCancellation { } #[cfg(feature = "unstable_cancel_request")] -fn cancellation_request_id(dispatch: &Dispatch) -> Result, crate::Error> { +fn cancellation_request_id(dispatch: &Dispatch) -> Result, crate::Error> { let Dispatch::Notification(message) = dispatch else { return Ok(None); }; @@ -1900,16 +1900,14 @@ fn cancellation_request_id(dispatch: &Dispatch) -> Result Result, crate::Error> { +) -> Result, crate::Error> { let (method, params) = peel_successor_envelopes(&message.method, &message.params); if !crate::schema::CancelRequestNotification::matches_method(method) { return Ok(None); } let notification = crate::schema::CancelRequestNotification::parse_message(method, params)?; - serde_json::to_value(notification.request_id) - .map(Some) - .map_err(crate::Error::into_internal_error) + Ok(Some(notification.request_id)) } /// Peel any [`SuccessorMessage`] envelopes off a notification by reference, @@ -1938,6 +1936,46 @@ fn peel_successor_envelopes<'message>( (method, params) } +/// Whether a notification is a `$/cancel_request`, even when it is still +/// wrapped in `_proxy/successor` envelopes. +/// +/// `$/cancel_request` is connection-scoped: its `requestId` was allocated on +/// the connection the notification arrived over and means nothing on any +/// other connection. Generic forwarding code (such as +/// [`ConnectionTo::send_proxied_message_to`]) uses this check to drop the raw +/// notification instead of tunneling it across a hop; the cancellation still +/// propagates because [`forward_response_to`](SentRequest::forward_response_to) +/// re-issues it with the forwarded request's own ID. +/// +/// Checking a notification whose method is not the successor envelope is a +/// plain method-name comparison. Only successor-wrapped notifications pay for +/// a serialization to peel the envelope. +#[cfg(feature = "unstable_cancel_request")] +#[must_use] +pub fn is_cancel_request_notification(notification: &N) -> bool { + let method = notification.method(); + if crate::schema::CancelRequestNotification::matches_method(method) { + return true; + } + if !crate::schema::SuccessorMessage::::matches_method(method) { + return false; + } + + match notification.to_untyped_message() { + Ok(untyped) => { + let (method, _params) = peel_successor_envelopes(&untyped.method, &untyped.params); + crate::schema::CancelRequestNotification::matches_method(method) + } + Err(error) => { + tracing::debug!( + ?error, + "failed to inspect successor-wrapped notification for cancellation" + ); + false + } + } +} + /// Whether the dispatch is a protocol-level (`$/`-prefixed) notification, /// possibly wrapped in a [`SuccessorMessage`] envelope. /// @@ -2253,7 +2291,7 @@ impl ConnectionTo { // with the correct per-hop ID, so drop the raw notification // instead of tunneling a meaningless ID across the hop. #[cfg(feature = "unstable_cancel_request")] - if crate::schema::CancelRequestNotification::matches_method(notification.method()) { + if is_cancel_request_notification(¬ification) { tracing::debug!( "not forwarding hop-scoped `$/cancel_request` notification across proxy hop" ); @@ -2646,7 +2684,7 @@ impl Responder { ) -> Self { let id_clone = id.clone(); let method_clone = method.clone(); - let cancellation = cancellation_registry.register(crate::util::id_to_json(&id)); + let cancellation = cancellation_registry.register(&id); Self { method, id, @@ -4560,7 +4598,7 @@ mod tests { ); let request_id = cancellation_request_id_from_message(&message) .expect("wrapped cancel should parse"); - assert_eq!(request_id, Some(serde_json::json!("req-1"))); + assert_eq!(request_id, Some(RequestId::Str("req-1".into()))); } #[test] @@ -4573,6 +4611,34 @@ mod tests { assert_eq!(request_id, None); } + #[test] + fn cancel_request_notifications_are_detected_even_when_wrapped() { + let plain = notification("$/cancel_request", serde_json::json!({ "requestId": 1 })); + assert!(is_cancel_request_notification(&plain)); + + let wrapped = notification( + "_proxy/successor", + serde_json::json!({ + "method": "$/cancel_request", + "params": { "requestId": 1 } + }), + ); + assert!(is_cancel_request_notification(&wrapped)); + + let other_wrapped = notification( + "_proxy/successor", + serde_json::json!({ + "method": "session/update", + "params": {} + }), + ); + assert!(!is_cancel_request_notification(&other_wrapped)); + + let malformed_envelope = + notification("_proxy/successor", serde_json::json!({ "bogus": true })); + assert!(!is_cancel_request_notification(&malformed_envelope)); + } + #[test] fn malformed_cancel_request_params_error() { let message = notification( @@ -4586,9 +4652,9 @@ mod tests { #[test] fn registry_marks_and_removes_requests() { let registry = RequestCancellationRegistry::new(); - let id = serde_json::json!("req-1"); + let id = RequestId::Str("req-1".into()); - let responder_cancellation = registry.register(id.clone()); + let responder_cancellation = registry.register(&id); let marker = responder_cancellation.cancellation(); assert!(!marker.is_cancelled()); @@ -4603,12 +4669,12 @@ mod tests { #[test] fn reused_request_id_does_not_cross_wire_cancellation_state() { let registry = RequestCancellationRegistry::new(); - let id = serde_json::json!("dup"); + let id = RequestId::Str("dup".into()); // A protocol-violating peer reuses an in-flight request ID. - let first = registry.register(id.clone()); + let first = registry.register(&id); let first_marker = first.cancellation(); - let second = registry.register(id.clone()); + let second = registry.register(&id); let second_marker = second.cancellation(); // A cancellation targets whichever request currently owns the ID. diff --git a/src/agent-client-protocol/src/lib.rs b/src/agent-client-protocol/src/lib.rs index 3c36d12..2edd3de 100644 --- a/src/agent-client-protocol/src/lib.rs +++ b/src/agent-client-protocol/src/lib.rs @@ -94,8 +94,6 @@ pub mod util; pub use capabilities::*; -#[cfg(feature = "unstable_cancel_request")] -pub use jsonrpc::RequestCancellation; pub use jsonrpc::{ Builder, ByteStreams, Channel, ConnectionTo, Dispatch, HandleDispatchFrom, Handled, IntoHandled, JsonRpcMessage, JsonRpcNotification, JsonRpcRequest, JsonRpcResponse, Lines, @@ -103,6 +101,8 @@ pub use jsonrpc::{ UntypedMessage, run::{ChainRun, NullRun, RunWithConnectionTo}, }; +#[cfg(feature = "unstable_cancel_request")] +pub use jsonrpc::{RequestCancellation, is_cancel_request_notification}; pub use role::{ Role, RoleId, UntypedRole, diff --git a/src/agent-client-protocol/tests/jsonrpc_request_cancellation.rs b/src/agent-client-protocol/tests/jsonrpc_request_cancellation.rs index f922883..08d135f 100644 --- a/src/agent-client-protocol/tests/jsonrpc_request_cancellation.rs +++ b/src/agent-client-protocol/tests/jsonrpc_request_cancellation.rs @@ -256,74 +256,10 @@ impl agent_client_protocol::role::HasPeer for WrappedHost { } } -#[tokio::test(flavor = "current_thread")] -async fn unhandled_protocol_level_notifications_are_ignored() { - use tokio::io::{AsyncWriteExt, BufReader}; - use tokio::task::LocalSet; - - let local = LocalSet::new(); - - local - .run_until(async { - let (mut client_writer, server_reader) = tokio::io::duplex(4096); - let (server_writer, client_reader) = tokio::io::duplex(4096); - - let server_transport = agent_client_protocol::ByteStreams::new( - server_writer.compat_write(), - server_reader.compat(), - ); - let server = UntypedRole.builder().on_receive_request( - async |request: SimpleRequest, - responder: Responder, - _connection: ConnectionTo| { - responder.respond(SimpleResponse { - result: format!("echo: {}", request.message), - }) - }, - agent_client_protocol::on_receive_request!(), - ); - - tokio::task::spawn_local(async move { - if let Err(error) = server.connect_to(server_transport).await { - panic!("server should stay alive: {error:?}"); - } - }); - - let mut client_reader = BufReader::new(client_reader); - - client_writer - .write_all( - br#"{"jsonrpc":"2.0","method":"$/cancel_request","params":{"requestId":"req-1"}} -"#, - ) - .await - .unwrap(); - client_writer.flush().await.unwrap(); - - // The server processes messages in order: a response to this - // request proves the unknown `$/` notification before it was - // ignored without erroring or closing the connection. - client_writer - .write_all( - br#"{"jsonrpc":"2.0","id":2,"method":"simple_method","params":{"message":"after cancel"}} -"#, - ) - .await - .unwrap(); - client_writer.flush().await.unwrap(); - - let response = read_jsonrpc_response_line(&mut client_reader).await; - expect![[r#" - { - "jsonrpc": "2.0", - "id": 2, - "result": { - "result": "echo: after cancel" - } - }"#]] - .assert_eq(&serde_json::to_string_pretty(&response).unwrap()); - }) - .await; +impl agent_client_protocol::role::HasPeer for WrappedHost { + fn remote_style(&self, _peer: WrappedHost) -> agent_client_protocol::role::RemoteStyle { + agent_client_protocol::role::RemoteStyle::Counterpart + } } #[tokio::test(flavor = "current_thread")] @@ -406,61 +342,6 @@ async fn unhandled_wrapped_protocol_level_notifications_are_ignored() { .await; } -#[tokio::test(flavor = "current_thread")] -async fn malformed_successor_envelope_still_reaches_handlers() { - use tokio::io::AsyncWriteExt; - use tokio::task::LocalSet; - - let local = LocalSet::new(); - - local - .run_until(async { - let (notification_tx, mut notification_rx) = mpsc::unbounded(); - - let (mut client_writer, server_reader) = tokio::io::duplex(4096); - let (server_writer, _client_reader) = tokio::io::duplex(4096); - - let server_transport = agent_client_protocol::ByteStreams::new( - server_writer.compat_write(), - server_reader.compat(), - ); - // A catch-all notification handler: a successor envelope whose - // params cannot be peeled (no inner `method`) must not be - // mistaken for a cancellation and short-circuited; it must flow - // through the handler chain like any other notification. - let server = UntypedRole.builder().on_receive_notification( - async move |notification: agent_client_protocol::UntypedMessage, - _connection: ConnectionTo| { - notification_tx - .unbounded_send((notification.method, notification.params)) - .unwrap(); - Ok(()) - }, - agent_client_protocol::on_receive_notification!(), - ); - - tokio::task::spawn_local(async move { - if let Err(error) = server.connect_to(server_transport).await { - panic!("server should stay alive: {error:?}"); - } - }); - - client_writer - .write_all( - br#"{"jsonrpc":"2.0","method":"_proxy/successor","params":{"bogus":true}} -"#, - ) - .await - .unwrap(); - client_writer.flush().await.unwrap(); - - let (method, params) = next_with_timeout(&mut notification_rx).await; - assert_eq!(method, "_proxy/successor"); - assert_eq!(params, serde_json::json!({ "bogus": true })); - }) - .await; -} - #[tokio::test(flavor = "current_thread")] async fn wrapped_cancel_request_cancels_wrapped_request() { use tokio::io::{AsyncWriteExt, BufReader}; @@ -1380,6 +1261,126 @@ async fn send_proxied_message_does_not_tunnel_cancel_notifications() { .await; } +/// A proxy that forwards raw dispatches with `send_proxied_message` can see a +/// `$/cancel_request` that is still wrapped in a `_proxy/successor` envelope: +/// raw dispatch handlers run before any peer-specific unwrapping. The +/// hop-scoped filter must peel the envelope and drop the notification rather +/// than tunnel it to the next peer. +#[tokio::test(flavor = "current_thread")] +async fn send_proxied_message_does_not_tunnel_wrapped_cancel_notifications() { + use tokio::task::LocalSet; + + let local = LocalSet::new(); + + local + .run_until(async { + // Every notification method the backend observes. + let (backend_notification_tx, mut backend_notification_rx) = mpsc::unbounded(); + + let (backend_for_proxy, backend_for_server) = Channel::duplex(); + let (backend_connection_tx, backend_connection_rx) = + futures::channel::oneshot::channel(); + + tokio::task::spawn_local(async move { + let result = UntypedRole + .builder() + .connect_with(backend_for_proxy, async |connection| { + drop(backend_connection_tx.send(connection.clone())); + std::future::pending::>().await + }) + .await; + if let Err(error) = result { + panic!("proxy-to-backend connection should stay alive: {error:?}"); + } + }); + + let backend_server = UntypedRole + .builder() + .on_receive_request( + async |request: SimpleRequest, + responder: Responder, + _connection: ConnectionTo| { + responder.respond(SimpleResponse { + result: format!("echo: {}", request.message), + }) + }, + agent_client_protocol::on_receive_request!(), + ) + .on_receive_notification( + async move |notification: agent_client_protocol::UntypedMessage, + _connection: ConnectionTo| { + backend_notification_tx + .unbounded_send(notification.method) + .unwrap(); + Ok(()) + }, + agent_client_protocol::on_receive_notification!(), + ); + + tokio::task::spawn_local(async move { + if let Err(error) = backend_server.connect_to(backend_for_server).await { + panic!("backend server should stay alive: {error:?}"); + } + }); + + let backend_connection = backend_connection_rx + .await + .expect("backend connection should start"); + + let (server_reader, server_writer, client_reader, client_writer) = setup_test_streams(); + let proxy_transport = + agent_client_protocol::ByteStreams::new(server_writer, server_reader); + // The raw dispatch handler receives successor-addressed messages + // still wrapped in their envelope and forwards them verbatim. + let proxy = WrappedHost.builder().on_receive_dispatch( + { + let backend_connection = backend_connection.clone(); + async move |dispatch: Dispatch, + _connection: ConnectionTo| { + backend_connection.send_proxied_message(dispatch) + } + }, + agent_client_protocol::on_receive_dispatch!(), + ); + + tokio::task::spawn_local(async move { + if let Err(error) = proxy.connect_to(proxy_transport).await { + panic!("proxy should stay alive: {error:?}"); + } + }); + + let client_transport = + agent_client_protocol::ByteStreams::new(client_writer, client_reader); + WrappedCounterpart + .builder() + .connect_with(client_transport, async |cx| { + // A successor-wrapped `$/cancel_request`, exactly as + // produced when cancelling a request sent to a successor + // peer. + cx.send_cancel_request_to(WrappedSuccessor, "req-1".to_string())?; + + // Barrier: both hops have processed the notification by + // the time this completes, so a tunneled wrapped cancel + // would already have been recorded by the backend. + let barrier = cx + .send_request(SimpleRequest { + message: "barrier".into(), + }) + .block_task() + .await?; + assert_eq!(barrier.result, "echo: barrier"); + Ok(()) + }) + .await + .unwrap(); + + // The backend saw no notification at all: the wrapped cancel was + // dropped at the proxy hop instead of being tunneled. + assert_no_event(&mut backend_notification_rx); + }) + .await; +} + /// Spawn a backend whose `park` requests wait until the cancel observer /// releases them, reporting parked request ids and observed cancellations. /// diff --git a/src/agent-client-protocol/tests/jsonrpc_unhandled_messages.rs b/src/agent-client-protocol/tests/jsonrpc_unhandled_messages.rs new file mode 100644 index 0000000..8b0a73c --- /dev/null +++ b/src/agent-client-protocol/tests/jsonrpc_unhandled_messages.rs @@ -0,0 +1,455 @@ +//! Tests for messages that nobody is waiting for or that no handler claims. +//! +//! Everything in this file holds **regardless of feature flags** — these are +//! baseline guarantees of the dispatch loop: +//! +//! - Unhandled protocol-level (`$/`-prefixed) notifications are ignored +//! instead of rejected, so peers that use optional protocol-level +//! extensions (such as `$/cancel_request`) interoperate with builds that do +//! not support them. +//! - A `_proxy/successor` envelope that cannot be peeled still reaches the +//! handler chain unchanged. +//! - A response routed to a request handle that was already dropped is +//! discarded without disturbing the connection. +//! - `forward_response_to` answers the incoming request with an error when +//! the pending response is dropped without ever being delivered, instead of +//! leaving the peer waiting forever. +//! +//! Like the other JSON-RPC tests, these avoid sleeps: messages are delivered +//! in order and each side's dispatch loop processes them sequentially, so a +//! request/response round trip acts as a barrier. + +use std::sync::{Arc, Mutex}; + +use agent_client_protocol::{ + Channel, ConnectionTo, Dispatch, Handled, JsonRpcMessage, JsonRpcRequest, JsonRpcResponse, + Responder, SentRequest, role::UntypedRole, +}; +use expect_test::expect; +use futures::StreamExt as _; +use futures::channel::mpsc; +use serde::{Deserialize, Serialize}; +use tokio_util::compat::{TokioAsyncReadCompatExt, TokioAsyncWriteCompatExt}; + +/// Await the next item on `rx`, panicking instead of hanging if it never +/// arrives. +async fn next_with_timeout(rx: &mut mpsc::UnboundedReceiver) -> T { + tokio::time::timeout(tokio::time::Duration::from_secs(10), rx.next()) + .await + .expect("timed out waiting for channel event") + .expect("channel closed before expected event") +} + +async fn read_jsonrpc_response_line( + reader: &mut tokio::io::BufReader, +) -> serde_json::Value { + use tokio::io::AsyncBufReadExt as _; + + let mut line = String::new(); + match tokio::time::timeout( + tokio::time::Duration::from_secs(10), + reader.read_line(&mut line), + ) + .await + { + Ok(Ok(0)) | Err(_) => panic!("timed out waiting for JSON-RPC response"), + Ok(Ok(_)) => serde_json::from_str(line.trim()).expect("response should be valid JSON"), + Ok(Err(error)) => panic!("failed to read JSON-RPC response line: {error}"), + } +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +struct SimpleRequest { + message: String, +} + +impl JsonRpcMessage for SimpleRequest { + fn matches_method(method: &str) -> bool { + method == "simple_method" + } + + fn method(&self) -> &'static str { + "simple_method" + } + + fn to_untyped_message( + &self, + ) -> Result { + agent_client_protocol::UntypedMessage::new(self.method(), self) + } + + fn parse_message( + method: &str, + params: &impl Serialize, + ) -> Result { + if !Self::matches_method(method) { + return Err(agent_client_protocol::Error::method_not_found()); + } + agent_client_protocol::util::json_cast_params(params) + } +} + +impl JsonRpcRequest for SimpleRequest { + type Response = SimpleResponse; +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +struct SimpleResponse { + result: String, +} + +impl JsonRpcResponse for SimpleResponse { + fn into_json(self, _method: &str) -> Result { + serde_json::to_value(self).map_err(agent_client_protocol::Error::into_internal_error) + } + + fn from_value( + _method: &str, + value: serde_json::Value, + ) -> Result { + agent_client_protocol::util::json_cast(&value) + } +} + +#[tokio::test(flavor = "current_thread")] +async fn unhandled_protocol_level_notifications_are_ignored() { + use tokio::io::{AsyncWriteExt, BufReader}; + use tokio::task::LocalSet; + + let local = LocalSet::new(); + + local + .run_until(async { + let (mut client_writer, server_reader) = tokio::io::duplex(4096); + let (server_writer, client_reader) = tokio::io::duplex(4096); + + let server_transport = agent_client_protocol::ByteStreams::new( + server_writer.compat_write(), + server_reader.compat(), + ); + let server = UntypedRole.builder().on_receive_request( + async |request: SimpleRequest, + responder: Responder, + _connection: ConnectionTo| { + responder.respond(SimpleResponse { + result: format!("echo: {}", request.message), + }) + }, + agent_client_protocol::on_receive_request!(), + ); + + tokio::task::spawn_local(async move { + if let Err(error) = server.connect_to(server_transport).await { + panic!("server should stay alive: {error:?}"); + } + }); + + let mut client_reader = BufReader::new(client_reader); + + client_writer + .write_all( + br#"{"jsonrpc":"2.0","method":"$/cancel_request","params":{"requestId":"req-1"}} +"#, + ) + .await + .unwrap(); + client_writer.flush().await.unwrap(); + + // The server processes messages in order: a response to this + // request proves the unknown `$/` notification before it was + // ignored without erroring or closing the connection. + client_writer + .write_all( + br#"{"jsonrpc":"2.0","id":2,"method":"simple_method","params":{"message":"after cancel"}} +"#, + ) + .await + .unwrap(); + client_writer.flush().await.unwrap(); + + let response = read_jsonrpc_response_line(&mut client_reader).await; + expect![[r#" + { + "jsonrpc": "2.0", + "id": 2, + "result": { + "result": "echo: after cancel" + } + }"#]] + .assert_eq(&serde_json::to_string_pretty(&response).unwrap()); + }) + .await; +} + +#[tokio::test(flavor = "current_thread")] +async fn malformed_successor_envelope_still_reaches_handlers() { + use tokio::io::AsyncWriteExt; + use tokio::task::LocalSet; + + let local = LocalSet::new(); + + local + .run_until(async { + let (notification_tx, mut notification_rx) = mpsc::unbounded(); + + let (mut client_writer, server_reader) = tokio::io::duplex(4096); + let (server_writer, _client_reader) = tokio::io::duplex(4096); + + let server_transport = agent_client_protocol::ByteStreams::new( + server_writer.compat_write(), + server_reader.compat(), + ); + // A catch-all notification handler: a successor envelope whose + // params cannot be peeled (no inner `method`) must not be + // mistaken for a protocol-level notification and ignored; it must + // flow through the handler chain like any other notification. + let server = UntypedRole.builder().on_receive_notification( + async move |notification: agent_client_protocol::UntypedMessage, + _connection: ConnectionTo| { + notification_tx + .unbounded_send((notification.method, notification.params)) + .unwrap(); + Ok(()) + }, + agent_client_protocol::on_receive_notification!(), + ); + + tokio::task::spawn_local(async move { + if let Err(error) = server.connect_to(server_transport).await { + panic!("server should stay alive: {error:?}"); + } + }); + + client_writer + .write_all( + br#"{"jsonrpc":"2.0","method":"_proxy/successor","params":{"bogus":true}} +"#, + ) + .await + .unwrap(); + client_writer.flush().await.unwrap(); + + let (method, params) = next_with_timeout(&mut notification_rx).await; + assert_eq!(method, "_proxy/successor"); + assert_eq!(params, serde_json::json!({ "bogus": true })); + }) + .await; +} + +#[tokio::test(flavor = "current_thread")] +async fn late_response_to_dropped_request_is_discarded_without_closing_connection() { + use tokio::task::LocalSet; + + let local = LocalSet::new(); + + local + .run_until(async { + // The responder for the abandoned request, parked by the server + // until the client asks for its release. + let pending_responder: Arc>>> = + Arc::new(Mutex::new(None)); + + let (client_end, server_end) = Channel::duplex(); + + let server = UntypedRole.builder().on_receive_request( + { + let pending_responder = pending_responder.clone(); + async move |request: SimpleRequest, + responder: Responder, + _connection: ConnectionTo| { + match request.message.as_str() { + "late" => { + *pending_responder.lock().unwrap() = Some(responder); + Ok(()) + } + "release" => { + // Answer the abandoned request first, then the + // release request: the late response is routed + // by the client before the release response. + if let Some(parked) = pending_responder.lock().unwrap().take() { + parked.respond(SimpleResponse { + result: "late response".into(), + })?; + } + responder.respond(SimpleResponse { + result: "echo: release".into(), + }) + } + other => responder.respond(SimpleResponse { + result: format!("echo: {other}"), + }), + } + } + }, + agent_client_protocol::on_receive_request!(), + ); + + tokio::task::spawn_local(async move { + if let Err(error) = server.connect_to(server_end).await { + panic!("server should stay alive: {error:?}"); + } + }); + + UntypedRole + .builder() + .connect_with(client_end, async |cx| { + let request: SentRequest = cx.send_request(SimpleRequest { + message: "late".into(), + }); + drop(request); + + // By the time this round trip completes, the late response + // has already been routed into the dropped handle above. + let release = cx + .send_request(SimpleRequest { + message: "release".into(), + }) + .block_task() + .await?; + assert_eq!(release.result, "echo: release"); + + // The connection survived the unroutable response. + let after = cx + .send_request(SimpleRequest { + message: "after".into(), + }) + .block_task() + .await?; + assert_eq!(after.result, "echo: after"); + + Ok(()) + }) + .await + .unwrap(); + }) + .await; +} + +#[tokio::test(flavor = "current_thread")] +async fn forward_response_to_answers_upstream_when_response_is_never_delivered() { + use tokio::task::LocalSet; + + let local = LocalSet::new(); + + local + .run_until(async { + let (backend_for_proxy, backend_for_server) = Channel::duplex(); + let (backend_connection_tx, backend_connection_rx) = + futures::channel::oneshot::channel(); + + // The proxy's connection to the backend swallows every response: + // the `ResponseRouter` (and with it the pending response sender) + // is dropped without ever delivering the response, as also + // happens when a downstream connection closes mid-request. + tokio::task::spawn_local(async move { + let result = UntypedRole + .builder() + .on_receive_dispatch( + async |dispatch: Dispatch, _connection: ConnectionTo| { + if matches!(dispatch, Dispatch::Response(..)) { + return Ok(Handled::Yes); + } + Ok(Handled::No { + message: dispatch, + retry: false, + }) + }, + agent_client_protocol::on_receive_dispatch!(), + ) + .connect_with(backend_for_proxy, async |connection| { + drop(backend_connection_tx.send(connection.clone())); + std::future::pending::>().await + }) + .await; + if let Err(error) = result { + panic!("proxy-to-backend connection should stay alive: {error:?}"); + } + }); + + // The backend itself answers promptly; its response is then + // swallowed on the proxy side. + let backend_server = UntypedRole.builder().on_receive_request( + async |request: SimpleRequest, + responder: Responder, + _connection: ConnectionTo| { + responder.respond(SimpleResponse { + result: format!("echo: {}", request.message), + }) + }, + agent_client_protocol::on_receive_request!(), + ); + + tokio::task::spawn_local(async move { + if let Err(error) = backend_server.connect_to(backend_for_server).await { + panic!("backend server should stay alive: {error:?}"); + } + }); + + let backend_connection = backend_connection_rx + .await + .expect("backend connection should start"); + + let (client_end, proxy_end) = Channel::duplex(); + let proxy = UntypedRole.builder().on_receive_request( + { + let backend_connection = backend_connection.clone(); + async move |request: SimpleRequest, + responder: Responder, + _connection: ConnectionTo| { + if request.message == "forward" { + return backend_connection + .send_request(request) + .forward_response_to(responder); + } + responder.respond(SimpleResponse { + result: format!("local: {}", request.message), + }) + } + }, + agent_client_protocol::on_receive_request!(), + ); + + tokio::task::spawn_local(async move { + if let Err(error) = proxy.connect_to(proxy_end).await { + panic!("proxy should stay alive: {error:?}"); + } + }); + + UntypedRole + .builder() + .connect_with(client_end, async |cx| { + // The forwarded request must not be left unanswered when + // its response is dropped downstream. + let error = cx + .send_request(SimpleRequest { + message: "forward".into(), + }) + .block_task() + .await + .expect_err("the response was dropped downstream"); + let detail = error + .data + .as_ref() + .and_then(serde_json::Value::as_str) + .unwrap_or_default(); + assert!( + detail.contains("never received"), + "unexpected error: {error:?}" + ); + + // The proxy and its connection to the client still work. + let local = cx + .send_request(SimpleRequest { + message: "ping".into(), + }) + .block_task() + .await?; + assert_eq!(local.result, "local: ping"); + + Ok(()) + }) + .await + .unwrap(); + }) + .await; +} From 637da9616b8b56464773b6c20af8e19d2682cc1d Mon Sep 17 00:00:00 2001 From: Ben Brandt Date: Thu, 11 Jun 2026 12:33:36 +0200 Subject: [PATCH 16/18] fix(acp): Disarm cancellation for claimed responses --- .../src/concepts/cancellation.rs | 8 +- src/agent-client-protocol/src/jsonrpc.rs | 177 ++++++++++-------- .../tests/jsonrpc_request_cancellation.rs | 109 +++++++++++ 3 files changed, 218 insertions(+), 76 deletions(-) diff --git a/src/agent-client-protocol/src/concepts/cancellation.rs b/src/agent-client-protocol/src/concepts/cancellation.rs index 975a870..6013f80 100644 --- a/src/agent-client-protocol/src/concepts/cancellation.rs +++ b/src/agent-client-protocol/src/concepts/cancellation.rs @@ -39,9 +39,11 @@ //! //! Dropping a [`SentRequest`] before the SDK receives a response also sends //! `$/cancel_request`. This covers abandoned request handles and futures. Once -//! the SDK routes a response to the waiting request handle, automatic -//! cancellation is disarmed, even if caller code has not yet consumed it with -//! [`block_task`], [`on_receiving_result`], or [`forward_response_to`]. +//! the SDK routes a response for the request, automatic cancellation is +//! disarmed: the peer has already answered, even if caller code has not yet +//! consumed the handle with [`block_task`], [`on_receiving_result`], or +//! [`forward_response_to`], and even if a dispatch handler claimed the +//! response. //! //! If you already have the JSON-RPC request ID, send the notification //! directly: diff --git a/src/agent-client-protocol/src/jsonrpc.rs b/src/agent-client-protocol/src/jsonrpc.rs index bdb7f56..6a5a68c 100644 --- a/src/agent-client-protocol/src/jsonrpc.rs +++ b/src/agent-client-protocol/src/jsonrpc.rs @@ -2811,6 +2811,16 @@ impl Responder { /// /// Both are fundamentally "sinks" that push the message through a `send_fn`, but they /// represent different points in the message lifecycle and carry different metadata. +/// +/// # Drop Behavior +/// +/// Dropping a `ResponseRouter` without responding (for example, from a +/// dispatch handler that claims a [`Dispatch::Response`]) discards the +/// response: the local awaiter observes the response as never received. The +/// request still counts as settled — when the `unstable_cancel_request` +/// feature is enabled, routing a response this far disarms the originating +/// [`SentRequest`]'s drop-time auto-cancellation even if the router is never +/// invoked, since the peer has already answered. #[must_use] pub struct ResponseRouter { /// The method of the original request. @@ -2853,11 +2863,21 @@ impl ResponseRouter { ) -> Self { let response_method = method.clone(); let response_id = id.clone(); + // A response for the request reached this router, so the request is + // settled from the peer's perspective and a `$/cancel_request` could + // only ever be redundant. The guard disarms the drop-time + // auto-cancellation when the router responds *or* when it is dropped + // without ever being invoked (a dispatch handler claimed the + // response). + #[cfg(feature = "unstable_cancel_request")] + let cancellation_disarm = DisarmOnDrop(cancellation_disarm); Self { method, id, role_id, send_fn: Box::new(move |response: Result| { + #[cfg(feature = "unstable_cancel_request")] + let _cancellation_disarm: DisarmOnDrop = cancellation_disarm; if sender .send(ResponsePayload { result: response, @@ -2870,9 +2890,6 @@ impl ResponseRouter { id = ?response_id, "dropped response because local receiver was gone" ); - } else { - #[cfg(feature = "unstable_cancel_request")] - cancellation_disarm.disarm(); } Ok(()) }), @@ -3606,6 +3623,24 @@ impl SentRequestCancellationDisarm { } } +/// Disarms a [`SentRequest`]'s drop-time auto-cancellation when dropped. +/// +/// A [`ResponseRouter`]'s send function holds this guard so that *routing* a +/// response settles the request: whether the response is delivered to the +/// local awaiter, discarded because the awaiter is gone, or claimed by a +/// dispatch handler that drops the router without invoking it, a later +/// `$/cancel_request` for the request could only ever be redundant. +#[cfg(feature = "unstable_cancel_request")] +#[derive(Debug)] +struct DisarmOnDrop(SentRequestCancellationDisarm); + +#[cfg(feature = "unstable_cancel_request")] +impl Drop for DisarmOnDrop { + fn drop(&mut self) { + self.0.disarm(); + } +} + #[cfg(feature = "unstable_cancel_request")] struct SentRequestCancellation { message_tx: OutgoingMessageTx, @@ -3920,56 +3955,89 @@ impl SentRequest { pub fn forward_response_to(self, responder: Responder) -> Result<(), crate::Error> where T: Send, + { + #[cfg(feature = "unstable_cancel_request")] + let this = self.forward_cancellation_from(responder.cancellation()); + #[cfg(not(feature = "unstable_cancel_request"))] + let this = self; + + this.consume_with(async move |response| { + // A response that was never delivered (outer `Err`, e.g. the + // downstream connection closed) is forwarded as an error: the + // incoming request must not be left unanswered. + responder.respond_with_result(response.unwrap_or_else(Err)) + }) + } + + /// Spawn the response-consumption task shared by + /// [`on_receiving_result`](Self::on_receiving_result) and + /// [`forward_response_to`](Self::forward_response_to). + /// + /// The task awaits the response (forwarding cancellation from registered + /// sources while waiting, when the `unstable_cancel_request` feature is + /// enabled), converts the payload, and invokes `handle` with the typed + /// result (`Ok(Result)`). The dispatch loop's ack, if any, is sent + /// after `handle` completes. + /// + /// If the pending response is dropped without ever being delivered (for + /// example, the connection closed), `handle` receives the outer `Err` + /// describing the loss; there is no ack in that case. + #[track_caller] + fn consume_with( + self, + handle: impl FnOnce(Result, crate::Error>) -> F + 'static + Send, + ) -> Result<(), crate::Error> + where + F: Future> + 'static + Send, + T: Send, { let task_tx = self.task_tx.clone(); let method = self.method; let response_rx = self.response_rx; let to_result = self.to_result; #[cfg(feature = "unstable_cancel_request")] - let downstream_cancellation = self.cancellation; + let cancellation = self.cancellation; #[cfg(feature = "unstable_cancel_request")] - let cancellation_sources = { - let mut sources = self.cancellation_sources; - sources.push(responder.cancellation()); - sources - }; + let cancellation_sources = self.cancellation_sources; let location = Location::caller(); Task::new(location, async move { #[cfg(feature = "unstable_cancel_request")] let response = await_response_forwarding_cancellation( response_rx, - &downstream_cancellation, + &cancellation, &cancellation_sources, ) .await; #[cfg(not(feature = "unstable_cancel_request"))] let response = response_rx.await; - let ResponsePayload { result, ack_tx } = match response { - Ok(payload) => payload, - Err(err) => { - // The pending response was dropped (e.g. the downstream - // connection closed). Answer the incoming request instead - // of leaving the peer waiting forever. - return responder.respond_with_result(Err(crate::util::internal_error( - format!("response to `{method}` never received: {err}"), - ))); - } - }; + match response { + Ok(ResponsePayload { result, ack_tx }) => { + // Convert the result using to_result for Ok values + let typed_result = match result { + Ok(json_value) => to_result(json_value), + Err(err) => Err(err), + }; - let typed_result = match result { - Ok(json_value) => to_result(json_value), - Err(err) => Err(err), - }; + let outcome = handle(Ok(typed_result)).await; - let outcome = responder.respond_with_result(typed_result); + // Ack AFTER the handler completes - this is the key + // difference from block_task. The dispatch loop waits for + // this ack. + if let Some(tx) = ack_tx { + let _ = tx.send(()); + } - if let Some(tx) = ack_tx { - let _ = tx.send(()); + outcome + } + Err(err) => { + handle(Err(crate::util::internal_error(format!( + "response to `{method}` never received: {err}" + )))) + .await + } } - - outcome }) .spawn(&task_tx) } @@ -4223,52 +4291,15 @@ impl SentRequest { F: Future> + 'static + Send, T: Send, { - let task_tx = self.task_tx.clone(); - let method = self.method; - let response_rx = self.response_rx; - let to_result = self.to_result; - #[cfg(feature = "unstable_cancel_request")] - let cancellation = self.cancellation; - #[cfg(feature = "unstable_cancel_request")] - let cancellation_sources = self.cancellation_sources; - let location = Location::caller(); - - Task::new(location, async move { - #[cfg(feature = "unstable_cancel_request")] - let response = await_response_forwarding_cancellation( - response_rx, - &cancellation, - &cancellation_sources, - ) - .await; - #[cfg(not(feature = "unstable_cancel_request"))] - let response = response_rx.await; - + self.consume_with(async move |response| { match response { - Ok(ResponsePayload { result, ack_tx }) => { - // Convert the result using to_result for Ok values - let typed_result = match result { - Ok(json_value) => to_result(json_value), - Err(err) => Err(err), - }; - - // Run the user's callback - let outcome = task(typed_result).await; - - // Ack AFTER the callback completes - this is the key difference - // from block_task. The dispatch loop waits for this ack. - if let Some(tx) = ack_tx { - let _ = tx.send(()); - } - - outcome - } - Err(err) => Err(crate::util::internal_error(format!( - "response to `{method}` never received: {err}" - ))), + // Run the user's callback on the peer's result. + Ok(result) => task(result).await, + // A response that was never delivered fails the consuming + // task instead of invoking the callback. + Err(err) => Err(err), } }) - .spawn(&task_tx) } } diff --git a/src/agent-client-protocol/tests/jsonrpc_request_cancellation.rs b/src/agent-client-protocol/tests/jsonrpc_request_cancellation.rs index 08d135f..4190571 100644 --- a/src/agent-client-protocol/tests/jsonrpc_request_cancellation.rs +++ b/src/agent-client-protocol/tests/jsonrpc_request_cancellation.rs @@ -880,6 +880,115 @@ async fn response_buffered_before_drop_disarms_auto_cancellation() { .await; } +/// A dispatch handler may claim a `Dispatch::Response` and drop the router +/// without invoking it. Routing the response settles the request all the +/// same, so dropping the (never-delivered-to) request handle afterwards must +/// not ask the peer to cancel a request it has already answered. +#[tokio::test(flavor = "current_thread")] +async fn response_claimed_by_dispatch_handler_disarms_auto_cancellation() { + use tokio::task::LocalSet; + + let local = LocalSet::new(); + + local + .run_until(async { + let (cancel_tx, mut cancel_rx) = mpsc::unbounded(); + + let (server_reader, server_writer, client_reader, client_writer) = setup_test_streams(); + let server_transport = + agent_client_protocol::ByteStreams::new(server_writer, server_reader); + let server = UntypedRole + .builder() + .on_receive_request( + async |request: SimpleRequest, + responder: Responder, + _connection: ConnectionTo| { + responder.respond(SimpleResponse { + result: format!("echo: {}", request.message), + }) + }, + agent_client_protocol::on_receive_request!(), + ) + .on_receive_notification( + async move |notification: CancelRequestNotification, + _connection: ConnectionTo| { + cancel_tx.unbounded_send(notification.request_id).unwrap(); + Ok(()) + }, + agent_client_protocol::on_receive_notification!(), + ); + + tokio::task::spawn_local(async move { + if let Err(error) = server.connect_to(server_transport).await { + panic!("server should stay alive: {error:?}"); + } + }); + + // The JSON-RPC id whose response the dispatch handler below + // claims (and discards) without ever invoking the router. + let claimed_id: Arc>> = Arc::new(Mutex::new(None)); + + let client_transport = + agent_client_protocol::ByteStreams::new(client_writer, client_reader); + UntypedRole + .builder() + .on_receive_dispatch( + { + let claimed_id = claimed_id.clone(); + async move |dispatch: Dispatch, _connection: ConnectionTo| { + if let Dispatch::Response(_, router) = &dispatch + && claimed_id.lock().unwrap().as_ref() == Some(&router.id()) + { + // Claim the response; the router is dropped + // without responding. + return Ok(Handled::Yes); + } + Ok(Handled::No { + message: dispatch, + retry: false, + }) + } + }, + agent_client_protocol::on_receive_dispatch!(), + ) + .connect_with(client_transport, async |cx| { + let request: SentRequest = cx.send_request(SimpleRequest { + message: "claimed".into(), + }); + *claimed_id.lock().unwrap() = Some(request.id()); + + // The server answers requests in order, so once this + // round trip completes, the response to `claimed` has + // been routed and discarded by the dispatch handler. + let barrier = cx + .send_request(SimpleRequest { + message: "barrier".into(), + }) + .block_task() + .await?; + assert_eq!(barrier.result, "echo: barrier"); + + drop(request); + + // Another round trip: any cancellation sent by the drop + // above would reach the server before this request. + let after = cx + .send_request(SimpleRequest { + message: "after claimed".into(), + }) + .block_task() + .await?; + assert_eq!(after.result, "echo: after claimed"); + Ok(()) + }) + .await + .unwrap(); + + assert_no_event(&mut cancel_rx); + }) + .await; +} + #[tokio::test(flavor = "current_thread")] async fn completed_sent_request_does_not_send_cancellation_on_drop() { use tokio::task::LocalSet; From d96af2bf272114f6f0dc01257d9f538c93754e73 Mon Sep 17 00:00:00 2001 From: Ben Brandt Date: Sat, 13 Jun 2026 14:26:16 +0200 Subject: [PATCH 17/18] fix(acp): disarm cancellation when routing responses --- src/agent-client-protocol/src/jsonrpc.rs | 28 +---- .../tests/jsonrpc_request_cancellation.rs | 113 ++++++++++++++++++ 2 files changed, 116 insertions(+), 25 deletions(-) diff --git a/src/agent-client-protocol/src/jsonrpc.rs b/src/agent-client-protocol/src/jsonrpc.rs index 6a5a68c..e6cb348 100644 --- a/src/agent-client-protocol/src/jsonrpc.rs +++ b/src/agent-client-protocol/src/jsonrpc.rs @@ -2865,19 +2865,15 @@ impl ResponseRouter { let response_id = id.clone(); // A response for the request reached this router, so the request is // settled from the peer's perspective and a `$/cancel_request` could - // only ever be redundant. The guard disarms the drop-time - // auto-cancellation when the router responds *or* when it is dropped - // without ever being invoked (a dispatch handler claimed the - // response). + // only ever be redundant. Disarm immediately so handlers may retain + // the router without leaving auto-cancellation armed. #[cfg(feature = "unstable_cancel_request")] - let cancellation_disarm = DisarmOnDrop(cancellation_disarm); + cancellation_disarm.disarm(); Self { method, id, role_id, send_fn: Box::new(move |response: Result| { - #[cfg(feature = "unstable_cancel_request")] - let _cancellation_disarm: DisarmOnDrop = cancellation_disarm; if sender .send(ResponsePayload { result: response, @@ -3623,24 +3619,6 @@ impl SentRequestCancellationDisarm { } } -/// Disarms a [`SentRequest`]'s drop-time auto-cancellation when dropped. -/// -/// A [`ResponseRouter`]'s send function holds this guard so that *routing* a -/// response settles the request: whether the response is delivered to the -/// local awaiter, discarded because the awaiter is gone, or claimed by a -/// dispatch handler that drops the router without invoking it, a later -/// `$/cancel_request` for the request could only ever be redundant. -#[cfg(feature = "unstable_cancel_request")] -#[derive(Debug)] -struct DisarmOnDrop(SentRequestCancellationDisarm); - -#[cfg(feature = "unstable_cancel_request")] -impl Drop for DisarmOnDrop { - fn drop(&mut self) { - self.0.disarm(); - } -} - #[cfg(feature = "unstable_cancel_request")] struct SentRequestCancellation { message_tx: OutgoingMessageTx, diff --git a/src/agent-client-protocol/tests/jsonrpc_request_cancellation.rs b/src/agent-client-protocol/tests/jsonrpc_request_cancellation.rs index 4190571..49f75f0 100644 --- a/src/agent-client-protocol/tests/jsonrpc_request_cancellation.rs +++ b/src/agent-client-protocol/tests/jsonrpc_request_cancellation.rs @@ -989,6 +989,119 @@ async fn response_claimed_by_dispatch_handler_disarms_auto_cancellation() { .await; } +/// A dispatch handler may keep the `ResponseRouter` alive after the peer has +/// answered. The original `SentRequest` is settled as soon as the response is +/// routed into the handler, so dropping it must not ask the peer to cancel. +#[tokio::test(flavor = "current_thread")] +async fn response_retained_by_dispatch_handler_disarms_auto_cancellation() { + use tokio::task::LocalSet; + + let local = LocalSet::new(); + + local + .run_until(async { + let (cancel_tx, mut cancel_rx) = mpsc::unbounded(); + + let (server_reader, server_writer, client_reader, client_writer) = setup_test_streams(); + let server_transport = + agent_client_protocol::ByteStreams::new(server_writer, server_reader); + let server = UntypedRole + .builder() + .on_receive_request( + async |request: SimpleRequest, + responder: Responder, + _connection: ConnectionTo| { + responder.respond(SimpleResponse { + result: format!("echo: {}", request.message), + }) + }, + agent_client_protocol::on_receive_request!(), + ) + .on_receive_notification( + async move |notification: CancelRequestNotification, + _connection: ConnectionTo| { + cancel_tx.unbounded_send(notification.request_id).unwrap(); + Ok(()) + }, + agent_client_protocol::on_receive_notification!(), + ); + + tokio::task::spawn_local(async move { + if let Err(error) = server.connect_to(server_transport).await { + panic!("server should stay alive: {error:?}"); + } + }); + + let claimed_id: Arc>> = Arc::new(Mutex::new(None)); + let retained_response: Arc>> = Arc::new(Mutex::new(None)); + + let client_transport = + agent_client_protocol::ByteStreams::new(client_writer, client_reader); + UntypedRole + .builder() + .on_receive_dispatch( + { + let claimed_id = claimed_id.clone(); + let retained_response = retained_response.clone(); + async move |dispatch: Dispatch, _connection: ConnectionTo| { + let should_claim = match &dispatch { + Dispatch::Response(_, router) => { + claimed_id.lock().unwrap().as_ref() == Some(&router.id()) + } + Dispatch::Request(_, _) | Dispatch::Notification(_) => false, + }; + + if should_claim { + *retained_response.lock().unwrap() = Some(dispatch); + return Ok(Handled::Yes); + } + + Ok(Handled::No { + message: dispatch, + retry: false, + }) + } + }, + agent_client_protocol::on_receive_dispatch!(), + ) + .connect_with(client_transport, async |cx| { + let request: SentRequest = cx.send_request(SimpleRequest { + message: "retained".into(), + }); + *claimed_id.lock().unwrap() = Some(request.id()); + + // This proves the earlier response was routed into the + // handler and is still retained rather than dropped. + let barrier = cx + .send_request(SimpleRequest { + message: "barrier".into(), + }) + .block_task() + .await?; + assert_eq!(barrier.result, "echo: barrier"); + assert!(retained_response.lock().unwrap().is_some()); + + drop(request); + + // Any auto-cancel from dropping the request would be + // delivered before this follow-up request. + let after = cx + .send_request(SimpleRequest { + message: "after retained".into(), + }) + .block_task() + .await?; + assert_eq!(after.result, "echo: after retained"); + Ok(()) + }) + .await + .unwrap(); + + assert_no_event(&mut cancel_rx); + }) + .await; +} + #[tokio::test(flavor = "current_thread")] async fn completed_sent_request_does_not_send_cancellation_on_drop() { use tokio::task::LocalSet; From 74c89f0d100ad12ce273de3ff5e8cc407c32dae6 Mon Sep 17 00:00:00 2001 From: Ben Brandt Date: Sat, 13 Jun 2026 14:56:12 +0200 Subject: [PATCH 18/18] fix(acp): Propagate proxy session cancellation --- .../tests/request_cancellation.rs | 124 ++++++++++++++++++ src/agent-client-protocol/src/session.rs | 42 +++--- 2 files changed, 145 insertions(+), 21 deletions(-) diff --git a/src/agent-client-protocol-conductor/tests/request_cancellation.rs b/src/agent-client-protocol-conductor/tests/request_cancellation.rs index 6d12ffc..48cdb6b 100644 --- a/src/agent-client-protocol-conductor/tests/request_cancellation.rs +++ b/src/agent-client-protocol-conductor/tests/request_cancellation.rs @@ -727,6 +727,130 @@ async fn session_new_cancellation_propagates_through_proxy() -> Result<(), Error Ok(()) } +/// The SDK's documented proxy session helper also forwards `session/new` with +/// a result hook, so it must opt into cancellation propagation explicitly. +#[tokio::test] +async fn proxy_session_helper_cancellation_propagates_to_agent() -> Result<(), Error> { + let (agent_cancel_tx, mut agent_cancel_rx) = mpsc::unbounded(); + let (parked_id_tx, mut parked_id_rx) = mpsc::unbounded(); + + let agent = Agent + .builder() + .on_receive_request( + async |initialize: InitializeRequest, responder, _cx: ConnectionTo| { + responder.respond(InitializeResponse::new(initialize.protocol_version)) + }, + agent_client_protocol::on_receive_request!(), + ) + .on_receive_request( + async move |request: NewSessionRequest, + responder: Responder, + cx: ConnectionTo| { + if request.cwd.ends_with("park-session") { + parked_id_tx.unbounded_send(responder.id()).unwrap(); + let cancellation = responder.cancellation(); + cx.spawn(async move { + let response = cancellation + .run_until_cancelled(std::future::pending::< + Result, + >()) + .await; + responder.respond_with_result(response) + })?; + return Ok(()); + } + + responder.respond(NewSessionResponse::new(SessionId::new("normal-session"))) + }, + agent_client_protocol::on_receive_request!(), + ) + .on_receive_notification( + async move |cancel: CancelRequestNotification, _cx: ConnectionTo| { + agent_cancel_tx.unbounded_send(cancel.request_id).unwrap(); + Ok(()) + }, + agent_client_protocol::on_receive_notification!(), + ); + + let proxy = Proxy.builder().on_receive_request_from( + Client, + async |request: NewSessionRequest, + responder: Responder, + cx: ConnectionTo| { + cx.build_session_from(request) + .on_proxy_session_start(responder, async |_session_id| Ok::<(), Error>(())) + }, + agent_client_protocol::on_receive_request!(), + ); + + let (editor_write, conductor_read) = duplex(8192); + let (conductor_write, editor_read) = duplex(8192); + + let conductor_handle = tokio::spawn(async move { + ConductorImpl::new_agent( + "helper-cancellation-conductor".to_string(), + ProxiesAndAgent::new(agent).proxy(proxy), + ) + .run(ByteStreams::new( + conductor_write.compat_write(), + conductor_read.compat(), + )) + .await + }); + + let client_request_id = tokio::time::timeout(Duration::from_secs(30), async move { + Client + .builder() + .connect_with( + ByteStreams::new(editor_write.compat_write(), editor_read.compat()), + async |cx| { + let initialize = cx + .send_request(InitializeRequest::new(ProtocolVersion::V1)) + .block_task() + .await?; + assert_eq!(initialize.protocol_version, ProtocolVersion::V1); + + let request: SentRequest = + cx.send_request(NewSessionRequest::new("/park-session")); + let client_request_id = request.id(); + request.cancel()?; + + let error = request + .block_task() + .await + .expect_err("session/new should be cancelled"); + assert_eq!(i32::from(error.code), -32800); + + let session = cx + .send_request(NewSessionRequest::new( + std::env::current_dir().map_err(Error::into_internal_error)?, + )) + .block_task() + .await?; + assert_eq!(session.session_id, SessionId::new("normal-session")); + + Ok(client_request_id) + }, + ) + .await + }) + .await + .expect("test timed out") + .expect("client failed"); + + let parked_id = next_with_timeout(&mut parked_id_rx).await; + assert_ne!( + parked_id, client_request_id, + "each hop must re-issue the request under its own ID" + ); + let observed = next_with_timeout(&mut agent_cancel_rx).await; + assert_eq!(serde_json::to_value(observed).unwrap(), parked_id); + assert_no_event(&mut agent_cancel_rx); + + conductor_handle.abort(); + Ok(()) +} + /// `initialize` is rewritten to `_proxy/initialize` at the conductor-to-proxy /// hop and forwarded with a result hook — cancellation must still propagate /// hop by hop, exactly like every other request. diff --git a/src/agent-client-protocol/src/session.rs b/src/agent-client-protocol/src/session.rs index 424895b..c8389b1 100644 --- a/src/agent-client-protocol/src/session.rs +++ b/src/agent-client-protocol/src/session.rs @@ -307,27 +307,27 @@ where .into_iter() .for_each(super::jsonrpc::DynamicHandlerRegistration::run_indefinitely); - // Send the "new session" request to the agent - connection - .send_request_to(Agent, request) - .on_receiving_result({ - let connection = connection.clone(); - async move |result| { - let response = result?; - - // Extract the session-id from the response and forward - // the response back to the client - let session_id = response.session_id.clone(); - responder.respond(response)?; - - // Install a dynamic handler to proxy messages from this session - connection - .add_dynamic_handler(ProxySessionMessages::new(session_id.clone()))? - .run_indefinitely(); - - op(session_id).await - } - }) + // Send the "new session" request to the agent. + let sent = connection.send_request_to(Agent, request); + #[cfg(feature = "unstable_cancel_request")] + let sent = sent.forward_cancellation_from(responder.cancellation()); + + sent.on_receiving_ok_result(responder, { + let connection = connection.clone(); + async move |response, responder| { + // Extract the session-id from the response and forward + // the response back to the client + let session_id = response.session_id.clone(); + responder.respond(response)?; + + // Install a dynamic handler to proxy messages from this session + connection + .add_dynamic_handler(ProxySessionMessages::new(session_id.clone()))? + .run_indefinitely(); + + op(session_id).await + } + }) } }