[MLX] Support multiple KV cache sessions, with shared constant data#20408
Conversation
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/20408
Note: Links to docs will display an error until the docs builds have been completed. ❗ 1 Active SEVsThere are 1 currently active SEVs. If your PR is affected, please view them below: ❌ 1 New Failure, 15 Pending, 1 Unrelated FailureAs of commit b05bdae with merge base 39c0df6 ( NEW FAILURE - The following job has failed:
BROKEN TRUNK - The following job failed but were present on the merge base:👉 Rebase onto the `viable/strict` branch to avoid these failures
This comment was automatically generated by Dr. CI and updates every 15 minutes. |
|
@claude review this code |
|
Claude finished @metascroy's task in 4m 21s —— View job Code ReviewI reviewed the diff against Correctness / concurrency1. Use-after-free if a session is destroyed while it (or the engine) is executing — In the current engine this is contract-safe (a session executes only under 2. Stale Minor / defensive3. HandleInfo& info = ctx.handles[handle];If 4. Things that look right
Nits
Overall: solid, well-tested, and consistent with the existing CUDA design. The only item I'd ask to address before merge is documenting (or guarding) the destroy-vs-execute contract from #1; the rest are minor. |
| // | ||
| // Unlike the CUDA backend, the MLX runtime owns mutable buffers directly in a | ||
| // swappable container (ExecutionState::mutable_buffers is a MutableBufferData*), | ||
| // so per-session isolation is a pointer swap to a freshly zero-allocated |
There was a problem hiding this comment.
Nitpick: "freshly zero-allocated" -- this is an implementation detail. Make it more generic so that you're not bound to this contract in the fuure.
| return false; | ||
| } | ||
| std::lock_guard<std::mutex> g(registry_mutex()); | ||
| return contexts().count(ctx) != 0; |
There was a problem hiding this comment.
This only checks whether the context id exists in the global map.
auto it = contexts().find(ctx);
return it != contexts().end() && !it->second.handles.empty();
checks two things:
- The context exists.
- At least one backend handle was registered into that context.
| ::executorch::runtime::Result<int> create_session() const { | ||
| return detail::mutable_state_create_session(ctx_); | ||
| } |
There was a problem hiding this comment.
// Creates an isolated mutable-buffer session for this context.
// Fails if no loaded MLX backend handle has been associated with the context.
ET_NODISCARD ::executorch::runtime::Result<int> create_session() const {
return detail::mutable_state_create_session(ctx_);
}
| } | ||
| } | ||
| #elif defined(EXECUTORCH_BUILD_MLX) | ||
| // MLX owns mutable buffers directly and clones them per session; no FQN |
There was a problem hiding this comment.
I don't think we're cloning session.
// MLX owns mutable buffers directly and selects per-session storage at execute time
See cuda_mutable_state.{h,cpp} |
This PR needs a
|
| const bool active_for_this_ctx = | ||
| tl_active_token != kNoMutableSession && tl_active_ctx == hit->second; |
There was a problem hiding this comment.
If a session token is active for a different context, falling back to default buffers would silently execute the wrong model/session state
const bool has_active_session = tl_active_token != kNoMutableSession;
const bool active_for_this_ctx = has_active_session && tl_active_ctx == hit->second;
if (has_active_session && !active_for_this_ctx) {
ET_LOG(
Error,
"mutable_state_rebind_for_execute: active context mismatch");
return Error::Internal;
}
|
@claude review this code |
|
Claude finished @metascroy's task in 3m 17s —— View job Code ReviewI re-reviewed the current state of the branch against the prior review threads. Good news first: most of the earlier feedback from @mergennachin's first pass has been folded in correctly. The one item still outstanding is the context-mismatch case from the second review round (2026-06-23). Addressed since the last review ✅
Still open — the latest review commentContext mismatch silently falls back to default buffers — @mergennachin's most recent comment asked to return an error when a session token is active for a different context. The current code doesn't do that — it collapses two distinct situations into the single const bool active_for_this_ctx =
tl_active_token != kNoMutableSession && tl_active_ctx == hit->second;
if (!active_for_this_ctx) {
if (!ctx.sessions.empty()) {
return Error::InvalidState; // sessions exist, none active → refused
}
state.mutable_buffers = info.default_buffers; // <-- reached even when a
return Error::Ok; // session is active for
} // a *different* contextIf const bool has_active_session = tl_active_token != kNoMutableSession;
const bool active_for_this_ctx = has_active_session && tl_active_ctx == hit->second;
if (has_active_session && !active_for_this_ctx) {
ET_LOG(Error, "mutable_state_rebind_for_execute: active context mismatch");
return Error::Internal;
}This matches the reviewer's suggested patch and is consistent with the existing "active session set but handle has no context" guard at Minor / non-blocking
Looks right
Overall this is in good shape and the design mirrors the CUDA path cleanly. The context-mismatch fix is the one change I'd land before merge; the two minor items are optional hardening. |
MLX backend already has mutable state in a separate execution context from its constant data. This PR exposes a way to configure that for external callers, and uses this to support serve.py on MLX like CUDA backend.