
[MLX] Support multiple KV cache sessions, with shared constant data #20408

Merged
metascroy merged 10 commits into main from qwen3-session-mlx
Jun 23, 2026

Conversation

@metascroy

Contributor

The MLX backend already keeps its mutable state in an execution context separate from its constant data. This PR exposes a way for external callers to configure that, and uses it to support serve.py on MLX, as the CUDA backend already does.

@pytorch-bot

pytorch-bot Bot commented Jun 18, 2026

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/20408

Note: Links to docs will display an error until the docs builds have been completed.

❗ 1 Active SEVs

There is 1 currently active SEV. If your PR is affected, please view it below:

❌ 1 New Failure, 15 Pending, 1 Unrelated Failure

As of commit b05bdae with merge base 39c0df6 (image):

NEW FAILURE - The following job has failed:

BROKEN TRUNK - The following job failed but was present on the merge base:

👉 Rebase onto the `viable/strict` branch to avoid these failures

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@meta-cla meta-cla Bot added the CLA Signed label Jun 18, 2026
@metascroy

Contributor Author

@claude review this code

@claude

claude Bot commented Jun 18, 2026

Claude finished @metascroy's task in 4m 21s


Code Review

I reviewed the diff against origin/qwen-mlx-cpp. This is a clean, well-scoped PR that mirrors the CUDA MutableStateContextOwner API for MLX. The core insight — MLX already owns mutable buffers in a swappable MutableBufferData*, so per-session isolation is just a pointer swap with shared constants — is sound and the header docs explain it well. Below are findings, ordered by importance.
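
To make the "pointer swap with shared constants" idea concrete, here is a minimal sketch. Every name in it (SketchMutableBuffers, SketchExecState, SketchSessionTable) is a hypothetical stand-in rather than the PR's actual types, and zero-filling new storage is an assumption of the sketch, not a documented contract.

```cpp
#include <cassert>
#include <cstddef>
#include <memory>
#include <unordered_map>
#include <vector>

// Hypothetical stand-in for the backend's per-session mutable storage.
struct SketchMutableBuffers {
  std::vector<float> kv_cache;
};

// Hypothetical stand-in for the execution state: constants are shared,
// mutable_buffers is re-pointed before each execute.
struct SketchExecState {
  std::shared_ptr<const std::vector<float>> constants;
  SketchMutableBuffers* mutable_buffers = nullptr;
};

class SketchSessionTable {
 public:
  explicit SketchSessionTable(std::size_t cache_size)
      : cache_size_(cache_size) {}

  // Allocates independent storage for a new session and returns its token.
  int create_session() {
    const int token = next_token_++;
    sessions_[token].kv_cache.assign(cache_size_, 0.0f);
    return token;
  }

  // Points the execution state at the session's storage; constants untouched.
  bool bind(int token, SketchExecState& state) {
    auto it = sessions_.find(token);
    if (it == sessions_.end()) {
      return false;  // unknown token: leave the state as it was
    }
    state.mutable_buffers = &it->second;
    return true;
  }

 private:
  std::size_t cache_size_;
  int next_token_ = 0;
  std::unordered_map<int, SketchMutableBuffers> sessions_;
};

// Writes 7.0 through session A, then reads the same slot through session B.
float read_b_after_writing_a() {
  SketchSessionTable table(4);
  SketchExecState state;
  state.constants = std::make_shared<std::vector<float>>(8, 1.0f);
  const int a = table.create_session();
  const int b = table.create_session();
  bool ok = table.bind(a, state);
  assert(ok);
  state.mutable_buffers->kv_cache[0] = 7.0f;
  ok = table.bind(b, state);
  assert(ok);
  (void)ok;
  return state.mutable_buffers->kv_cache[0];
}
```

The constants pointer is never touched by bind, which is the property that lets several sessions share one copy of the weights while each keeps its own cache.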

Correctness / concurrency

1. Use-after-free if a session is destroyed while it (or the engine) is executing — mlx_mutable_state.cpp:262
rebind_for_execute returns state.mutable_buffers = &bit->second (a pointer into Context::sessions), then MLXBackend::execute releases mlx_global_mutex and runs/writes outputs while the GPU is async. mutable_state_destroy_session / forget_handle only take registry_mutex, not the engine's exec_mutex_ or mlx_global_mutex, so a concurrent ~Qwen35MoESession on another thread could erase() the very MutableBufferData the in-flight execute still references → UAF.

In the current engine this is contract-safe (a session executes only under exec_mutex_ and is destroyed only when its LLMSession is dropped), but nothing enforces "don't destroy a session that may be executing." Since this is now a public-ish API surface, it's worth either (a) documenting that destroy must not race execute on the same context, or (b) noting that unordered_map node stability protects other sessions but not the active one. The header comment at mlx_mutable_state.h:131 mentions serializing execution but not the destroy/execute interaction.

2. Stale state.mutable_buffers after session destroy — related
After execute returns, h->state.mutable_buffers still points into the session's MutableBufferData. If that session is then destroyed, the pointer dangles until the next rebind_for_execute overwrites it. Safe today (never dereferenced before the next rebind), but fragile — a one-line note in ExecutionState or here would help future readers.
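
One possible way to guard the destroy-vs-execute interaction from findings 1 and 2, rather than only documenting it, is an in-flight counter that makes destroy refuse while a session's storage is handed out. This is a sketch under that assumption only; GuardedSessions, acquire, and release are hypothetical names and are not part of the PR, which relies on the engine contract instead.

```cpp
#include <cassert>
#include <mutex>
#include <unordered_map>

struct SketchBuffers {
  int value = 0;
};

class GuardedSessions {
 public:
  int create() {
    std::lock_guard<std::mutex> g(mu_);
    const int token = next_++;
    entries_[token];  // default-construct storage for this token
    return token;
  }

  // Marks the session in use and returns its storage, or nullptr if unknown.
  SketchBuffers* acquire(int token) {
    std::lock_guard<std::mutex> g(mu_);
    auto it = entries_.find(token);
    if (it == entries_.end()) return nullptr;
    ++it->second.in_flight;
    return &it->second.buffers;
  }

  void release(int token) {
    std::lock_guard<std::mutex> g(mu_);
    auto it = entries_.find(token);
    if (it != entries_.end() && it->second.in_flight > 0) {
      --it->second.in_flight;
    }
  }

  // Refuses to erase storage that an execute may still be using.
  bool destroy(int token) {
    std::lock_guard<std::mutex> g(mu_);
    auto it = entries_.find(token);
    if (it == entries_.end() || it->second.in_flight > 0) return false;
    entries_.erase(it);
    return true;
  }

 private:
  struct Entry {
    SketchBuffers buffers;
    int in_flight = 0;
  };
  std::mutex mu_;
  int next_ = 0;
  std::unordered_map<int, Entry> entries_;
};

// A pointer handed out earlier survives later insertions (rehash does not
// move unordered_map nodes), and destroy is refused until release.
bool destroy_refused_while_in_flight() {
  GuardedSessions s;
  const int a = s.create();
  SketchBuffers* p = s.acquire(a);
  assert(p != nullptr);
  for (int i = 0; i < 1000; ++i) s.create();  // force rehashes
  p->value = 7;                               // still valid
  const bool refused = !s.destroy(a);
  s.release(a);
  return refused && s.destroy(a);
}
```

The counter only protects the erased entry; it does not address GPU work that is still asynchronous after release, so the documented contract would still be needed for that part.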

Minor / defensive

3. ctx.handles[handle] uses operator[] (mlx_mutable_state.cpp:213)

HandleInfo& info = ctx.handles[handle];

If handle were ever in handle_ctx() but absent from ctx.handles, this silently inserts a default HandleInfo{nullptr, nullptr}, and the later load_mutable_buffers(*info.program, ...) dereferences null. Invariants currently guarantee both maps stay in sync, so this can't fire, but .at(handle) (or a find + guard) would express intent and fail loudly if the invariant ever breaks.
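
The difference between the two lookups can be shown in isolation. The sketch below uses a hypothetical SketchHandleInfo keyed by int; the real map's key and value types are not shown in this thread. In builds without exception support, the find plus guard form may be the more practical of the two suggested options.

```cpp
#include <cassert>
#include <cstddef>
#include <unordered_map>

struct SketchHandleInfo {
  const int* program = nullptr;
  int* default_buffers = nullptr;
};

using SketchHandleMap = std::unordered_map<int, SketchHandleInfo>;

// Lookup that never inserts: a missing handle comes back as nullptr so the
// caller can return an error instead of dereferencing a default entry.
const SketchHandleInfo* find_handle(const SketchHandleMap& handles,
                                    int handle) {
  auto it = handles.find(handle);
  return it == handles.end() ? nullptr : &it->second;
}

// Returns the map size after a guarded miss followed by an operator[] miss.
std::size_t size_after_both_lookups() {
  SketchHandleMap handles;
  const SketchHandleInfo* missing = find_handle(handles, 1);
  assert(missing == nullptr);  // the map is still empty here
  (void)missing;
  SketchHandleInfo& inserted = handles[1];  // silently inserts {nullptr, nullptr}
  (void)inserted;
  return handles.size();
}
```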

4. load_mutable_buffers runs while holding both mlx_global_mutex and registry_mutex (mlx_mutable_state.cpp:249)
First-execute allocation of a fresh KV cache happens inside registry_mutex (and the caller holds mlx_global_mutex). Functionally fine — execution is serialized anyway, and I confirmed the lock order is consistently mlx_global → registry across init/execute/destroy, so no inversion. Just flagging that the registry mutex is held across a potentially large GPU allocation; not a blocker.

Things that look right

  • Lock-order consistency (mlx_global → registry) — no deadlock risk.
  • unordered_map node-pointer stability across rehash is correctly relied upon for the returned buffer pointer (comment at :260).
  • Thread-local tl_active_ctx/tl_active_token set and read on the same thread that runs execute (under exec_mutex_), so the active-session selection is consistent; ActiveSessionScope always restores on scope exit.
  • note_handle/rebind are genuine no-ops for single-session loads (tl_loading_ctx == kInvalidMutableContext), so existing single-session MLX execution is unaffected.
  • The "refuse execute when sessions exist but none active" guard (:221) is a good safety net against silently running on shared default buffers.
  • bytes_per_session is computed from the default buffers, which match the per-session allocations — accurate capacity estimate.
  • Test (mlx_mutable_state_test.cpp) directly exercises the no-bleed guarantee (write 7.0 to A, B sees 0, A still 7) plus the no-active-session refusal — good coverage without needing a model. Wired into CI in mlx.yml. 👍
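
The "always restores on scope exit" property of the thread-local active-session selection is the usual RAII save-and-restore pattern. A minimal sketch follows, using hypothetical names (ScopedActiveSession, tl_sketch_ctx, tl_sketch_token) rather than the PR's ActiveSessionScope, whose actual definition is not quoted here.

```cpp
#include <cassert>

constexpr int kSketchNoContext = -1;
constexpr int kSketchNoSession = -1;

// Per-thread selection of the active (context, session) pair.
thread_local int tl_sketch_ctx = kSketchNoContext;
thread_local int tl_sketch_token = kSketchNoSession;

class ScopedActiveSession {
 public:
  ScopedActiveSession(int ctx, int token)
      : prev_ctx_(tl_sketch_ctx), prev_token_(tl_sketch_token) {
    tl_sketch_ctx = ctx;
    tl_sketch_token = token;
  }
  // Restores whatever was active before, including "nothing active".
  ~ScopedActiveSession() {
    tl_sketch_ctx = prev_ctx_;
    tl_sketch_token = prev_token_;
  }
  ScopedActiveSession(const ScopedActiveSession&) = delete;
  ScopedActiveSession& operator=(const ScopedActiveSession&) = delete;

 private:
  int prev_ctx_;
  int prev_token_;
};

// Returns the token visible after a nested scope has exited.
int token_after_nested_scope() {
  ScopedActiveSession outer(1, 10);
  {
    ScopedActiveSession inner(2, 20);
    assert(tl_sketch_token == 20);
  }
  return tl_sketch_token;
}
```

Because the state is thread-local, the selection is only meaningful when it is set and read on the thread that runs execute, which is the condition the review calls out.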

Nits

  • mlx_mutable_state.h:131: the "with one active session selected" comment block is good; consider also noting that destroy and execute must not race (see finding 1).
  • README MLX serving section is clear and matches the CUDA section's structure.

Overall: solid, well-tested, and consistent with the existing CUDA design. The only item I'd ask to address before merge is documenting (or guarding) the destroy-vs-execute contract from finding 1; the rest are minor.
· branch qwen3-session-mlx

Base automatically changed from qwen-mlx-cpp to main June 22, 2026 17:34
@mergennachin mergennachin self-requested a review June 22, 2026 19:55
//
// Unlike the CUDA backend, the MLX runtime owns mutable buffers directly in a
// swappable container (ExecutionState::mutable_buffers is a MutableBufferData*),
// so per-session isolation is a pointer swap to a freshly zero-allocated

Contributor

Nitpick: "freshly zero-allocated" is an implementation detail. Make it more generic so that you're not bound to this contract in the future.

Comment thread backends/mlx/runtime/mlx_mutable_state.h
return false;
}
std::lock_guard<std::mutex> g(registry_mutex());
return contexts().count(ctx) != 0;

Contributor

This only checks whether the context id exists in the global map.

auto it = contexts().find(ctx);
return it != contexts().end() && !it->second.handles.empty();

This version checks two things:

  • The context exists.
  • At least one backend handle was registered into that context.
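
As a self-contained illustration of the two checks, the sketch below contrasts an existence-only test with the stricter availability test. SketchContextRegistry and its int keys are hypothetical; only the find-then-handles-nonempty logic comes from the suggestion above.

```cpp
#include <cassert>
#include <mutex>
#include <unordered_map>
#include <unordered_set>

class SketchContextRegistry {
 public:
  void add_context(int ctx) {
    std::lock_guard<std::mutex> g(mu_);
    contexts_[ctx];  // creates an empty handle set if absent
  }

  // Registers a handle only if the context already exists.
  bool register_handle(int ctx, int handle) {
    std::lock_guard<std::mutex> g(mu_);
    auto it = contexts_.find(ctx);
    if (it == contexts_.end()) return false;
    it->second.insert(handle);
    return true;
  }

  // Existence only: true as soon as the context id is known.
  bool exists(int ctx) const {
    std::lock_guard<std::mutex> g(mu_);
    return contexts_.count(ctx) != 0;
  }

  // Stricter: the context exists and has at least one registered handle.
  bool available(int ctx) const {
    std::lock_guard<std::mutex> g(mu_);
    auto it = contexts_.find(ctx);
    return it != contexts_.end() && !it->second.empty();
  }

 private:
  mutable std::mutex mu_;
  std::unordered_map<int, std::unordered_set<int>> contexts_;
};

// Returns true when an empty context exists but is not reported available.
bool empty_context_is_unavailable() {
  SketchContextRegistry r;
  r.add_context(1);
  assert(!r.exists(2));
  return r.exists(1) && !r.available(1);
}
```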

Comment on lines +151 to +153
::executorch::runtime::Result<int> create_session() const {
return detail::mutable_state_create_session(ctx_);
}

Contributor

    // Creates an isolated mutable-buffer session for this context.
    // Fails if no loaded MLX backend handle has been associated with the context.
    ET_NODISCARD ::executorch::runtime::Result<int> create_session() const {
      return detail::mutable_state_create_session(ctx_);
    }

Comment thread backends/mlx/runtime/mlx_mutable_state.cpp
}
}
#elif defined(EXECUTORCH_BUILD_MLX)
// MLX owns mutable buffers directly and clones them per session; no FQN

Contributor

I don't think we're cloning session.

// MLX owns mutable buffers directly and selects per-session storage at execute time

Comment thread examples/models/qwen3_5_moe/qwen35_moe_engine.h
@mergennachin

Contributor
  • Similar to CUDA, add a build_error state to Context and use it to make failures sticky: nested load scopes or invalid setup should make available(), validate_coverage(), and create_session() fail consistently afterward.
  • Also similar to CUDA, add a nested load-scope guard in mutable_state_begin_load(): if tl_loading_ctx is already set, mark the active context and the new context invalid instead of silently overwriting the thread-local association.

See cuda_mutable_state.{h,cpp}
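
A rough sketch of the two requested behaviors, a sticky build_error and a nested load-scope guard, is shown below. It is single-threaded and uses hypothetical names (SketchLoadRegistry, tl_sketch_loading); the real code in cuda_mutable_state.{h,cpp} is not reproduced here and may differ in detail.

```cpp
#include <cassert>
#include <unordered_map>

constexpr int kSketchNotLoading = -1;

// Context currently being loaded on this thread, if any.
thread_local int tl_sketch_loading = kSketchNotLoading;

class SketchLoadRegistry {
 public:
  int create_context() {
    const int id = next_++;
    contexts_[id];  // default entry: no error, no handles
    return id;
  }

  // Opens a load scope. A nested call poisons both contexts instead of
  // silently overwriting the thread-local association.
  bool begin_load(int ctx) {
    if (tl_sketch_loading != kSketchNotLoading) {
      poison(tl_sketch_loading);
      poison(ctx);
      return false;
    }
    tl_sketch_loading = ctx;
    return true;
  }

  void end_load() { tl_sketch_loading = kSketchNotLoading; }

  // Records one loaded handle against the context being loaded.
  void note_handle() {
    auto it = contexts_.find(tl_sketch_loading);
    if (it != contexts_.end()) ++it->second.handle_count;
  }

  // Once build_error is set, every query keeps failing (sticky).
  bool available(int ctx) const {
    auto it = contexts_.find(ctx);
    return it != contexts_.end() && !it->second.build_error &&
        it->second.handle_count > 0;
  }

  bool can_create_session(int ctx) const { return available(ctx); }

 private:
  struct Ctx {
    bool build_error = false;
    int handle_count = 0;
  };
  void poison(int ctx) {
    auto it = contexts_.find(ctx);
    if (it != contexts_.end()) it->second.build_error = true;
  }
  int next_ = 0;
  std::unordered_map<int, Ctx> contexts_;
};

// Opens a load for A, then a nested load for B; both end up unusable.
bool nested_load_poisons_both() {
  SketchLoadRegistry r;
  const int a = r.create_context();
  const int b = r.create_context();
  const bool first = r.begin_load(a);
  r.note_handle();
  const bool nested = r.begin_load(b);
  r.end_load();
  assert(first && !nested);
  (void)first;
  (void)nested;
  return !r.available(a) && !r.can_create_session(b);
}
```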

@github-actions

This PR needs a release notes: label

If your change should be included in the release notes (i.e. would users of this library care about this change?), please use a label starting with release notes:. This helps us keep track and include your important work in the next release notes.

To add a label, you can comment to pytorchbot, for example
@pytorchbot label "release notes: none"

For more information, see
https://github.com/pytorch/pytorch/wiki/PyTorch-AutoLabel-Bot#why-categorize-for-release-notes-and-how-does-it-work.

@metascroy metascroy marked this pull request as ready for review June 23, 2026 00:35
Comment on lines +262 to +263
const bool active_for_this_ctx =
tl_active_token != kNoMutableSession && tl_active_ctx == hit->second;

Contributor

If a session token is active for a different context, falling back to default buffers would silently execute against the wrong model/session state.

  const bool has_active_session = tl_active_token != kNoMutableSession;
  const bool active_for_this_ctx = has_active_session && tl_active_ctx == hit->second;

  if (has_active_session && !active_for_this_ctx) {
    ET_LOG(
        Error,
        "mutable_state_rebind_for_execute: active context mismatch");
    return Error::Internal;
  }

@metascroy

Contributor Author

@claude review this code

@claude

claude Bot commented Jun 23, 2026

Claude finished @metascroy's task in 3m 17s


Code Review

I re-reviewed the current state of the branch against the prior review threads. Good news first: most of the earlier feedback from @mergennachin's first pass has been folded in correctly. The one item still outstanding is the context-mismatch case from the second review round (2026-06-23).

Addressed since the last review ✅

  • Sticky build_error — added to Context and consistently checked in available(), validate_coverage(), create_session(), and rebind_for_execute() (mlx_mutable_state.cpp:44,128,159,174,257).
  • Nested load-scope guard: mutable_state_begin_load now marks both the active and nested contexts invalid instead of silently overwriting tl_loading_ctx (:98-114).
  • create_session requires a registered handle — returns Error::NotSupported when handles.empty() (:177-181); session map insert moved to the Context& local.
  • available() checks both context existence and a registered handle (:128-129), with the requested doc block on the header (mlx_mutable_state.h:139-140).
  • rebind: active session set but handle has no context → Error::Internal (:242-248).
  • Engine comments: qwen35_moe_engine.cpp:554 no longer says "clone"; header top comment updated.

Still open — the latest review comment

Context mismatch silently falls back to default buffers — mlx_mutable_state.cpp:262-277

@mergennachin's most recent comment asked to return an error when a session token is active for a different context. The current code doesn't do that — it collapses two distinct situations into the single !active_for_this_ctx branch:

const bool active_for_this_ctx =
    tl_active_token != kNoMutableSession && tl_active_ctx == hit->second;

if (!active_for_this_ctx) {
  if (!ctx.sessions.empty()) {
    return Error::InvalidState;        // sessions exist, none active → refused
  }
  state.mutable_buffers = info.default_buffers;  // <-- reached even when a
  return Error::Ok;                              //     session is active for
}                                                //     a *different* context

If tl_active_token != kNoMutableSession but tl_active_ctx != hit->second and this context happens to have no sessions, execution silently proceeds on default_buffers — i.e. it runs the wrong model/session's state instead of failing loudly. Splitting the "no active session" case from the "active for a different context" case fixes it:

const bool has_active_session = tl_active_token != kNoMutableSession;
const bool active_for_this_ctx = has_active_session && tl_active_ctx == hit->second;

if (has_active_session && !active_for_this_ctx) {
  ET_LOG(Error, "mutable_state_rebind_for_execute: active context mismatch");
  return Error::Internal;
}

This matches the reviewer's suggested patch and is consistent with the existing "active session set but handle has no context" guard at :242.
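
The full decision table is easier to check when pulled out into a pure function. The sketch below is an illustration only: RebindOutcome and classify_rebind are hypothetical names, and in the PR the two refusals correspond to Error::Internal (mismatch) and Error::InvalidState (sessions exist, none active).

```cpp
#include <cassert>

constexpr int kSketchNoToken = -1;

enum class RebindOutcome {
  UseSessionBuffers,
  UseDefaultBuffers,
  RefuseNoActiveSession,
  RefuseContextMismatch,
};

// Classifies one execute call from the thread-local selection
// (active_token, active_ctx) and the context that owns the handle.
RebindOutcome classify_rebind(
    int active_token,
    int active_ctx,
    int handle_ctx,
    bool ctx_has_sessions) {
  const bool has_active_session = active_token != kSketchNoToken;
  const bool active_for_this_ctx =
      has_active_session && active_ctx == handle_ctx;

  if (has_active_session && !active_for_this_ctx) {
    // A session is active, but for a different context: fail loudly.
    return RebindOutcome::RefuseContextMismatch;
  }
  if (!has_active_session) {
    // No session selected: default buffers are fine only if this context
    // has never had sessions created on it.
    return ctx_has_sessions ? RebindOutcome::RefuseNoActiveSession
                            : RebindOutcome::UseDefaultBuffers;
  }
  return RebindOutcome::UseSessionBuffers;
}

// The case the review flags: a session is active for another context, and
// this handle's context has no sessions of its own.
RebindOutcome flagged_case() {
  const RebindOutcome r = classify_rebind(
      /*active_token=*/3,
      /*active_ctx=*/1,
      /*handle_ctx=*/2,
      /*ctx_has_sessions=*/false);
  assert(r != RebindOutcome::UseDefaultBuffers);
  return r;
}
```

Keeping the mismatch check ahead of the sessions.empty() check is what prevents the silent fallback, since the mismatch no longer depends on whether the handle's context happens to have sessions.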

Minor / non-blocking

  • operator[] on ctx.handles (:260). HandleInfo& info = ctx.handles[handle]; still uses operator[]. The invariant (handle present in handle_ctx() ⇒ present in ctx.handles) holds, so this can't fire today, but .at(handle) would fail loudly if that invariant ever breaks rather than inserting a {nullptr,nullptr} and dereferencing null in load_mutable_buffers(*info.program, ...). Same point raised in the first review; still cheap to harden.
  • Destroy-vs-execute contract. mutable_state_destroy_session/forget_handle take only registry_mutex, while rebind_for_execute hands back a pointer into Context::sessions that execute dereferences after releasing the lock. Safe under the current engine contract (a session is destroyed only when its LLMSession drops, never mid-execute, and unordered_map node stability protects other sessions). Worth one line in the mlx_mutable_state.h:130-131 comment block noting destroy must not race execute on the same active session.

Looks right

  • Lock order stays mlx_global → registry across init/execute/destroy — no inversion.
  • note_handle/rebind are genuine no-ops for single-session loads (tl_loading_ctx == kInvalidMutableContext), so existing MLX execution is unaffected.
  • Test (mlx_mutable_state_test.cpp) exercises the no-bleed guarantee and the no-active-session refusal directly on the manager — good coverage without a model.

Overall this is in good shape and the design mirrors the CUDA path cleanly. The context-mismatch fix is the one change I'd land before merge; the two minor items are optional hardening.
· branch qwen3-session-mlx

@metascroy metascroy merged commit 58447b2 into main Jun 23, 2026
243 of 245 checks passed
@metascroy metascroy deleted the qwen3-session-mlx branch June 23, 2026 19:37
Labels

CLA Signed

2 participants