refactor: redraw ComputeBackend as an inference-session engine and route the CLI through it #450
Merged
Redraw the compute-backend seam from PR #446 into the inference-session contract ADR 0004 settled on. PR #446 drew the boundary at model load and returned the concrete MLX LoadedModel; that altitude cannot host a graph-compiler backend (FuriosaAI, Tenstorrent, OpenXLA) that never produces an MlxArray. This adds the session layer and moves the MLX CLI path behind it, byte-identical by construction, while leaving the server batched path exactly as it was.

Layered contract: the core single-sequence InferenceSession (mlxcel-core/src/session.rs) is an object-safe, engine-neutral trait that exposes capability advertisement plus the conceptual token-level primitives prefill / decode_step, the shape a future non-MLX backend (issue #449, a separate default-off crate) fills in. The MLX implementation MlxInferenceSession wraps the existing CxxGenerator and delegates every generation method (generate, generate_streaming, generate_streaming_with_embeddings, generate_with_stats, generate_with_stats_and_embeddings, evaluate_loglikelihoods) verbatim, so the exact same decode loop, KV optimizations, and sampling run. The decode loop and CxxGenerator internals are unchanged, so CLI output stays byte-identical. On MLX the prefill / decode_step primitives return a reserved-contract message because the CLI drives the fused generate_* entry points; they are the contract for the compiler-family backend, not the MLX fast path.

ComputeBackend gains create_session (returns a Session) and supports_batched_serving. The MLX backend builds Session::Mlx(MlxInferenceSession) with the same KV mode and token bias the CLI used before; the cfg-gated experimental scaffold returns the same not_implemented error from create_session and reports no batched serving. The retained load_model -> (LoadedModel, MlxcelTokenizer) entry is untouched, so src/server/model_worker.rs and the BatchScheduler keep owning LoadedModel directly.
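The layered contract above can be sketched in a few lines of Rust. This is an illustration only, not the code in this PR: the PR names `InferenceSession`, `SessionCapabilities`, `capabilities`, `prefill`, and `decode_step`, but their signatures are not shown, so every signature, the `granular_stepping` field, `SessionError`, `TokenId`, `FusedSession`, `ToySteppingSession`, and `greedy_decode` are assumptions made for the example.

```rust
/// Hypothetical capability record; the real fields are not shown in the PR.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct SessionCapabilities {
    /// Whether `prefill` / `decode_step` are actually implemented.
    pub granular_stepping: bool,
}

/// Hypothetical error type standing in for the "reserved-contract message".
#[derive(Debug, PartialEq, Eq)]
pub enum SessionError {
    ReservedContract(&'static str),
}

pub type TokenId = u32;

/// Object-safe by design: no generic methods and no `Self` by value,
/// so callers can hold a `Box<dyn InferenceSession>`.
pub trait InferenceSession {
    fn capabilities(&self) -> SessionCapabilities;
    /// Consume the prompt and return the first generated token.
    fn prefill(&mut self, prompt: &[TokenId]) -> Result<TokenId, SessionError>;
    /// Feed the last token back and return the next one.
    fn decode_step(&mut self, last: TokenId) -> Result<TokenId, SessionError>;
}

/// Stand-in for a session that only offers fused generate-style entry
/// points: the stepping primitives report that they are reserved.
pub struct FusedSession;

impl InferenceSession for FusedSession {
    fn capabilities(&self) -> SessionCapabilities {
        SessionCapabilities { granular_stepping: false }
    }
    fn prefill(&mut self, _prompt: &[TokenId]) -> Result<TokenId, SessionError> {
        Err(SessionError::ReservedContract("prefill is reserved for stepping backends"))
    }
    fn decode_step(&mut self, _last: TokenId) -> Result<TokenId, SessionError> {
        Err(SessionError::ReservedContract("decode_step is reserved for stepping backends"))
    }
}

/// Toy stepping backend: the "model" just emits the previous token plus one.
pub struct ToySteppingSession {
    pub position: usize,
}

impl InferenceSession for ToySteppingSession {
    fn capabilities(&self) -> SessionCapabilities {
        SessionCapabilities { granular_stepping: true }
    }
    fn prefill(&mut self, prompt: &[TokenId]) -> Result<TokenId, SessionError> {
        self.position = prompt.len();
        Ok(prompt.last().copied().unwrap_or(0).wrapping_add(1))
    }
    fn decode_step(&mut self, last: TokenId) -> Result<TokenId, SessionError> {
        self.position += 1;
        Ok(last.wrapping_add(1))
    }
}

/// Drive any session through the trait object, one token at a time.
pub fn greedy_decode(
    session: &mut dyn InferenceSession,
    prompt: &[TokenId],
    max_new: usize,
) -> Result<Vec<TokenId>, SessionError> {
    let mut out = Vec::with_capacity(max_new);
    if max_new == 0 {
        return Ok(out);
    }
    let mut token = session.prefill(prompt)?;
    out.push(token);
    while out.len() < max_new {
        token = session.decode_step(token)?;
        out.push(token);
    }
    Ok(out)
}
```

The point of the sketch is the split: a caller written against `&mut dyn InferenceSession` works with any backend that implements the stepping primitives, while a fused-path session can still advertise its capabilities and decline the primitives explicitly instead of silently misbehaving.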
Batched serving is advertised as an MLX backend capability the single-sequence session does not cover yet (the deferred KV / batching abstraction per ADR 0004).

Dispatch folds away under default features: Session is a single-variant enum (Session::Mlx), select_backend folds to the single Backend::Mlx variant, and every Session and Backend method is a single-arm match marked inline, so backend.create_session(...).generate(...) lowers to a direct call into the wrapped CxxGenerator with no runtime indirection added on the hot path. The per-token forward stays inside the session method and KVCache is never type-erased. The experimental-backend module and enum variant stay cfg-gated off, so shipping Apple-Silicon and CUDA builds compile no extra backend code.

CLI call sites rerouted: src/commands/generate.rs (generate_standard and generate_with_embeddings now obtain a Session from the backend instead of constructing CxxGenerator directly) and src/commands/chat.rs (run_chat builds one session for the REPL; stream_turn and the /clear reset drive the session). Same KV mode, same sampling config, same arguments. The offline draft-model load already routes through select_backend().load_model.

Tests: mlxcel-core session_tests assert capability advertisement, token-bias wiring, the object-safe trait bound, and that the MLX step primitives report they are the reserved compiler-backend contract; backend tests assert the MLX backend resolves a session, advertises batched serving, threads the token bias through, and (under the experimental-backend feature) that the scaffold's session creation errors.
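The dispatch-folding claim is easier to see on a stripped-down model. The sketch below is hypothetical and much simpler than the PR: `ToyGenerator` stands in for `CxxGenerator`, `MlxSession` for `MlxInferenceSession`, and `create_session` takes no arguments and cannot fail here, which is an assumption made to keep the example short. What it does keep is the shape described above: single-variant enums, single-arm `match` bodies, and `#[inline]` on every forwarding method.

```rust
/// Stand-in for the wrapped generator (hypothetical, not the real type).
pub struct ToyGenerator;

impl ToyGenerator {
    pub fn generate(&mut self, prompt: &str, max_tokens: usize) -> String {
        format!("{}{}", prompt, "!".repeat(max_tokens))
    }
}

/// Thin wrapper that forwards verbatim to the generator it owns.
pub struct MlxSession {
    generator: ToyGenerator,
}

impl MlxSession {
    #[inline]
    pub fn generate(&mut self, prompt: &str, max_tokens: usize) -> String {
        self.generator.generate(prompt, max_tokens)
    }
}

/// With default features there is exactly one variant, so the `match`
/// below has one arm and nothing to branch on. A second, feature-gated
/// variant would only exist when its feature is enabled.
pub enum Session {
    Mlx(MlxSession),
}

impl Session {
    #[inline]
    pub fn generate(&mut self, prompt: &str, max_tokens: usize) -> String {
        match self {
            Session::Mlx(session) => session.generate(prompt, max_tokens),
        }
    }
}

pub enum Backend {
    Mlx,
}

/// Backend selection with a single candidate always yields that candidate.
#[inline]
pub fn select_backend() -> Backend {
    Backend::Mlx
}

impl Backend {
    #[inline]
    pub fn create_session(&self) -> Session {
        match self {
            Backend::Mlx => Session::Mlx(MlxSession { generator: ToyGenerator }),
        }
    }

    #[inline]
    pub fn supports_batched_serving(&self) -> bool {
        match self {
            Backend::Mlx => true,
        }
    }
}
```

Calling `select_backend().create_session().generate("hi", 2)` goes through two enums and three forwarding methods in the source, yet none of them carries a runtime choice. `#[inline]` is a hint rather than a guarantee, so the "direct call" outcome is what an optimizing build is expected to produce, not something the language promises.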
Summary
Redraw the compute-backend seam from PR #446 into the inference-session contract ADR 0004 settled on, and move the MLX CLI path behind it as a behavior-preserving refactor. PR #446 drew the boundary at model load and returned the concrete MLX `LoadedModel`; that altitude cannot host a graph-compiler backend (FuriosaAI, Tenstorrent, OpenXLA) that never produces an `MlxArray`. This adds the session layer for the CLI path while leaving the server batched path exactly as it was. MLX is the only real backend here; the result is byte-identical on the CLI.

Layered contract
- Core single-sequence session (the new contract): `mlxcel-core/src/session.rs` defines `InferenceSession`, an object-safe, engine-neutral trait that exposes capability advertisement (`capabilities`) plus the conceptual token-level primitives `prefill` / `decode_step`. That is the shape a future non-MLX backend (issue #449, a separate default-off crate) implements; this is the extension point left for it. The MLX implementation `MlxInferenceSession` wraps the existing `CxxGenerator`.
- Extended layer (MLX-only, intentionally untouched): the server `BatchScheduler` keeps doing cross-sequence batched forward and owns `LoadedModel` directly through the retained `ComputeBackend::load_model(path) -> (LoadedModel, MlxcelTokenizer)` entry. `src/server/model_worker.rs` and the scheduler are not changed. Batched serving is advertised as an MLX backend capability that the single-sequence session does not cover yet (the deferred KV / batching abstraction per ADR 0004).

What the MLX session wraps
`MlxInferenceSession` is a thin wrapper around `CxxGenerator`. Every generation method delegates verbatim to the matching `CxxGenerator` method in `src/lib/mlxcel-core/src/generate.rs`: `generate`, `generate_streaming`, `generate_streaming_with_embeddings`, `generate_with_stats`, `generate_with_stats_and_embeddings`, and `evaluate_loglikelihoods`. The decode loop is not rewritten, and `CxxGenerator` internals and the `LanguageModel` trait are unchanged, so the exact same code runs. On MLX the `prefill` / `decode_step` primitives return a reserved-contract message, because the CLI drives the fused `generate_*` entry points and the granular stepping primitives are the contract for the compiler-family backend, not the MLX fast path.

How byte-identical is guaranteed
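Part of the argument in this section is that an empty token-bias map preserves the bit-exact baseline. The sketch below shows why that holds for one plausible design. How `compose_sampling` actually applies the bias is not shown in this PR, so this assumes a simple additive logit bias; `TokenBias`, `apply_token_bias`, and `argmax` are hypothetical names.

```rust
use std::collections::HashMap;

/// Hypothetical token-bias map: token id -> additive logit offset.
pub type TokenBias = HashMap<u32, f32>;

/// Add each offset to its token's logit. Ids outside the vocabulary are
/// skipped. With an empty map the loop body never runs, so `logits` is
/// left untouched down to the last bit.
pub fn apply_token_bias(logits: &mut [f32], bias: &TokenBias) {
    for (&token, &offset) in bias {
        if let Some(logit) = logits.get_mut(token as usize) {
            *logit += offset;
        }
    }
}

/// Greedy pick: index of the largest logit (the first one wins on ties,
/// and an empty slice yields 0).
pub fn argmax(logits: &[f32]) -> usize {
    let mut best = 0;
    for (i, &value) in logits.iter().enumerate() {
        if value > logits[best] {
            best = i;
        }
    }
    best
}
```

Under this assumption, "no bias configured" is a true no-op rather than an addition of `0.0`, which is the kind of property a byte-identical refactor relies on.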
By construction: the same `CxxGenerator` code runs whether a caller reaches it directly or through `Session` -> `MlxInferenceSession`. The session carries the same KV mode and the same pre-resolved token bias the CLI built before (an empty map preserves the bit-exact baseline via `compose_sampling`), and the same sampling config and arguments flow through unchanged.

How the dispatch folds under default features
`Session` is a single-variant enum (`Session::Mlx`). `select_backend()` folds to the single `Backend::Mlx` variant (PR #446), and every `Session` and `Backend` method is a single-arm `match` marked `#[inline]`. After inlining, `backend.create_session(...).generate(...)` lowers to a direct call into the wrapped `CxxGenerator` with no runtime indirection added on the generation hot path. The per-token forward stays inside the session method and `KVCache` is never type-erased. The `experimental-backend` module and enum variant stay `cfg`-gated off, so shipping Apple-Silicon and CUDA builds compile no extra backend code.

What changed
- `src/lib/mlxcel-core/src/session.rs` (new): `SessionCapabilities`, the object-safe `InferenceSession` trait (`capabilities` / `prefill` / `decode_step`), and `MlxInferenceSession` wrapping `CxxGenerator` with verbatim-delegating generation methods.
- `src/lib/mlxcel-core/src/session_tests.rs` (new) and the `pub mod session` + re-export in `src/lib/mlxcel-core/src/lib.rs`.
- `src/backend/session.rs` (new): the root-crate `Session` enum (single `Mlx` variant) with inline single-arm dispatch over the generation methods.
- `src/backend/mod.rs`: `ComputeBackend` gains `create_session` and `supports_batched_serving`; the `Backend` enum dispatches both; module/trait docs updated to the two-layer contract. `load_model` and friends are unchanged.
- `src/backend/mlx.rs`: `MlxBackend::create_session` builds `Session::Mlx(MlxInferenceSession::new_with_kv_mode(...).with_token_bias(...))`; `supports_batched_serving` returns true.
- `src/backend/experimental.rs`: cfg-gated session stub returns the same `not_implemented()` error; `supports_batched_serving` returns false; forward-looking note updated to point at the session contract.
- `src/backend/tests.rs`: MLX backend resolves a session and advertises batched serving, threads the token bias through, and (under `experimental-backend`) the scaffold's session creation errors.
- `src/commands/generate.rs`: `generate_standard` and `generate_with_embeddings` obtain a `Session` from the backend instead of constructing `CxxGenerator`; `generate_standard` now returns `Result`. The offline draft-model load already routes through `select_backend().load_model`.
- `src/commands/chat.rs`: `run_chat` builds one session for the REPL; `stream_turn` and the `/clear` reset drive the session.
- `src/lib.rs`: re-export `Session`, `InferenceSession`, `MlxInferenceSession`, `SessionCapabilities`.

Test plan
- `cargo check --lib --features metal,accelerate`
- `cargo check --bins --features metal,accelerate`
- `cargo clippy --lib --tests --features metal,accelerate -- -D warnings`
- `cargo clippy --lib --tests --features metal,accelerate,experimental-backend -- -D warnings`
- `cargo test --lib --features metal,accelerate backend::` (5 passed)
- `cargo test --lib --features metal,accelerate -p mlxcel-core session::` (5 passed)
- `cargo fmt --check`
- llama-3.2-1b-4bit.

Closes #448