Skip to content

refactor: redraw ComputeBackend as an inference-session engine and route the CLI through it#450

Merged
inureyes merged 1 commit into
mainfrom
refactor/448-inference-session-seam
Jun 26, 2026
Merged

refactor: redraw ComputeBackend as an inference-session engine and route the CLI through it#450
inureyes merged 1 commit into
mainfrom
refactor/448-inference-session-seam

Conversation

@inureyes

Copy link
Copy Markdown
Member

Summary

Redraw the compute-backend seam from PR #446 into the inference-session contract ADR 0004 settled on, and move the MLX CLI path behind it as a behavior-preserving refactor. PR #446 drew the boundary at model load and returned the concrete MLX LoadedModel; that altitude cannot host a graph-compiler backend (FuriosaAI, Tenstorrent, OpenXLA) that never produces an MlxArray. This adds the session layer for the CLI path while leaving the server batched path exactly as it was. MLX is the only real backend here; the result is byte-identical on the CLI.

Layered contract

Core single-sequence session (the new contract): mlxcel-core/src/session.rs defines InferenceSession, an object-safe, engine-neutral trait that exposes capability advertisement (capabilities) plus the conceptual token-level primitives prefill / decode_step. That is the shape a future non-MLX backend (issue #449, a separate default-off crate) implements; this is the extension point left for it. The MLX implementation MlxInferenceSession wraps the existing CxxGenerator.

Extended layer (MLX-only, intentionally untouched): the server BatchScheduler keeps doing cross-sequence batched forward and owns LoadedModel directly through the retained ComputeBackend::load_model(path) -> (LoadedModel, MlxcelTokenizer) entry. src/server/model_worker.rs and the scheduler are not changed. Batched serving is advertised as an MLX backend capability that the single-sequence session does not cover yet (the deferred KV / batching abstraction per ADR 0004).

What the MLX session wraps

MlxInferenceSession is a thin wrapper around CxxGenerator. Every generation method delegates verbatim to the matching CxxGenerator method in src/lib/mlxcel-core/src/generate.rs: generate, generate_streaming, generate_streaming_with_embeddings, generate_with_stats, generate_with_stats_and_embeddings, and evaluate_loglikelihoods. The decode loop is not rewritten and CxxGenerator internals and the LanguageModel trait are unchanged, so the exact same code runs. On MLX the prefill / decode_step primitives return a reserved-contract message, because the CLI drives the fused generate_* entry points and the granular stepping primitives are the contract for the compiler-family backend, not the MLX fast path.

How byte-identical is guaranteed

By construction: the same CxxGenerator code runs whether a caller reaches it directly or through Session -> MlxInferenceSession. The session carries the same KV mode and the same pre-resolved token bias the CLI built before (empty map preserves the bit-exact baseline via compose_sampling), and the same sampling config and arguments flow through unchanged.

How the dispatch folds under default features

Session is a single-variant enum (Session::Mlx). select_backend() folds to the single Backend::Mlx variant (PR #446), and every Session and Backend method is a single-arm match marked #[inline]. After inlining, backend.create_session(...).generate(...) lowers to a direct call into the wrapped CxxGenerator with no runtime indirection added on the generation hot path. The per-token forward stays inside the session method and KVCache is never type-erased. The experimental-backend module and enum variant stay cfg-gated off, so shipping Apple-Silicon and CUDA builds compile no extra backend code.

What changed

  • src/lib/mlxcel-core/src/session.rs (new): SessionCapabilities, the object-safe InferenceSession trait (capabilities / prefill / decode_step), and MlxInferenceSession wrapping CxxGenerator with verbatim-delegating generation methods.
  • src/lib/mlxcel-core/src/session_tests.rs (new) and the pub mod session + re-export in src/lib/mlxcel-core/src/lib.rs.
  • src/backend/session.rs (new): the root-crate Session enum (single Mlx variant) with inline single-arm dispatch over the generation methods.
  • src/backend/mod.rs: ComputeBackend gains create_session and supports_batched_serving; the Backend enum dispatches both; module/trait docs updated to the two-layer contract. load_model and friends are unchanged.
  • src/backend/mlx.rs: MlxBackend::create_session builds Session::Mlx(MlxInferenceSession::new_with_kv_mode(...).with_token_bias(...)); supports_batched_serving returns true.
  • src/backend/experimental.rs: cfg-gated session stub returns the same not_implemented() error; supports_batched_serving returns false; forward-looking note updated to point at the session contract.
  • src/backend/tests.rs: MLX backend resolves a session and advertises batched serving, threads the token bias through, and (under experimental-backend) the scaffold's session creation errors.
  • src/commands/generate.rs: generate_standard and generate_with_embeddings obtain a Session from the backend instead of constructing CxxGenerator; generate_standard now returns Result. The offline draft-model load already routes through select_backend().load_model.
  • src/commands/chat.rs: run_chat builds one session for the REPL; stream_turn and the /clear reset drive the session.
  • src/lib.rs: re-export Session, InferenceSession, MlxInferenceSession, SessionCapabilities.

Test plan

  • cargo check --lib --features metal,accelerate
  • cargo check --bins --features metal,accelerate
  • cargo clippy --lib --tests --features metal,accelerate -- -D warnings
  • cargo clippy --lib --tests --features metal,accelerate,experimental-backend -- -D warnings
  • cargo test --lib --features metal,accelerate backend:: (5 passed)
  • cargo test --lib --features metal,accelerate -p mlxcel-core session:: (5 passed)
  • cargo fmt --check
  • Orchestrator owns the release build and the temp-0 greedy byte-identical parity gate on llama-3.2-1b-4bit.

Closes #448

Redraw the compute-backend seam from PR #446 into the inference-session contract ADR 0004 settled on. PR #446 drew the boundary at model load and returned the concrete MLX LoadedModel; that altitude cannot host a graph-compiler backend (FuriosaAI, Tenstorrent, OpenXLA) that never produces an MlxArray. This adds the session layer and moves the MLX CLI path behind it, byte-identical by construction, while leaving the server batched path exactly as it was.

Layered contract: the core single-sequence InferenceSession (mlxcel-core/src/session.rs) is an object-safe, engine-neutral trait that exposes capability advertisement plus the conceptual token-level primitives prefill / decode_step, the shape a future non-MLX backend (issue #449, a separate default-off crate) fills in. The MLX implementation MlxInferenceSession wraps the existing CxxGenerator and delegates every generation method (generate, generate_streaming, generate_streaming_with_embeddings, generate_with_stats, generate_with_stats_and_embeddings, evaluate_loglikelihoods) verbatim, so the exact same decode loop, KV optimizations, and sampling run. The decode loop and CxxGenerator internals are unchanged, so CLI output stays byte-identical. On MLX the prefill / decode_step primitives return a reserved-contract message because the CLI drives the fused generate_* entry points; they are the contract for the compiler-family backend, not the MLX fast path.

ComputeBackend gains create_session (returns a Session) and supports_batched_serving. The MLX backend builds Session::Mlx(MlxInferenceSession) with the same KV mode and token bias the CLI used before; the cfg-gated experimental scaffold returns the same not_implemented error from create_session and reports no batched serving. The retained load_model -> (LoadedModel, MlxcelTokenizer) entry is untouched, so src/server/model_worker.rs and the BatchScheduler keep owning LoadedModel directly. Batched serving is advertised as an MLX backend capability the single-sequence session does not cover yet (the deferred KV / batching abstraction per ADR 0004).

Dispatch folds away under default features: Session is a single-variant enum (Session::Mlx), select_backend folds to the single Backend::Mlx variant, and every Session and Backend method is a single-arm match marked inline, so backend.create_session(...).generate(...) lowers to a direct call into the wrapped CxxGenerator with no runtime indirection added on the hot path. The per-token forward stays inside the session method and KVCache is never type-erased. The experimental-backend module and enum variant stay cfg-gated off, so shipping Apple-Silicon and CUDA builds compile no extra backend code.

CLI call sites rerouted: src/commands/generate.rs (generate_standard and generate_with_embeddings now obtain a Session from the backend instead of constructing CxxGenerator directly) and src/commands/chat.rs (run_chat builds one session for the REPL; stream_turn and the /clear reset drive the session). Same KV mode, same sampling config, same arguments. The offline draft-model load already routes through select_backend().load_model.

Tests: mlxcel-core session_tests assert capability advertisement, token-bias wiring, the object-safe trait bound, and that the MLX step primitives report they are the reserved compiler-backend contract; backend tests assert the MLX backend resolves a session, advertises batched serving, threads the token bias through, and (under the experimental-backend feature) that the scaffold's session creation errors.
@inureyes inureyes added type:refactor Code restructuring without changing functionality priority:medium Medium priority area:architecture Architecture and code structure changes status:review Under review status:done Completed and removed status:review Under review labels Jun 26, 2026
@inureyes inureyes merged commit 552dcdd into main Jun 26, 2026
5 checks passed
@inureyes inureyes deleted the refactor/448-inference-session-seam branch June 26, 2026 21:20
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

area:architecture Architecture and code structure changes priority:medium Medium priority status:done Completed type:refactor Code restructuring without changing functionality

Projects

None yet

Development

Successfully merging this pull request may close these issues.

refactor: redraw ComputeBackend as an inference-session engine contract and move the MLX path behind it (byte-identical)

1 participant