Commit 4cc8462
[REFACTOR][OrcJIT] Add SlabPoolMemoryManager — growable per-session pool
Stage B of the slab-pool refactor. `ArenaJITLinkMemoryManager` is replaced by
`SlabPoolMemoryManager`, which holds a vector of `Slab`s per session and grows
on demand instead of pre-reserving one giant arena.

Behavior changes:

- **Default capacity**: per-slab `slab_size` drops from 1 GB to 64 MB. Typical
  ML JIT graphs are well under 10 MB; the first slab now reserves 64 MB instead
  of 1 GB, dramatically reducing VA footprint for small workloads on
  memory-constrained hosts.
- **Growth on demand**: when no existing slab can fit a graph, a fresh slab is
  mmap'd at `slab_size` bytes and appended to the pool. Sessions that
  previously failed with "pool exhausted" for cumulative allocations beyond
  1 GB now succeed transparently.
- **Oversize path**: a single graph whose footprint exceeds
  `slab_size - 2 * kCommitGranularity` gets its own dedicated slab sized to
  fit it, rounded to the commit granularity. One graph per oversize slab; the
  slab becomes available to other allocations after the graph is freed but
  usually isn't reused (sized tightly).

New infrastructure:

- `SlabPoolExhaustedError` — retriable error class. Emitted by
  `Slab::bumpAllocate` when the requested region exceeds the pool limit;
  caught by `SlabPoolMemoryManager::allocate` to fall through to the next slab
  or mmap a new one. Other errors (mmap, mprotect, JITLink) keep their
  existing types and are propagated to the caller without retry.
- `Slab::computeGraphFootprint(G, page_size)` — static helper that
  pre-computes per-pool byte totals for a graph. The pool manager uses it to
  make the normal-vs-oversize decision without first attempting a normal
  allocation.
- `classifyOverflowSections(G)` — file-scope helper extracted from
  `Slab::allocate`; also used by `computeGraphFootprint`. Keeps the two entry
  points consistent on which sections go to the overflow (separate-mmap) path.

Concurrency: `pool_mu_` guards only the `slabs_` vector. It is dropped before
calling into `Slab::allocate` or the caller's `OnAllocated` callback, because
LLJIT materialization frequently invokes nested lookups from inside those
callbacks — a coarse lock here deadlocks. Existing slab pointers are stable
across concurrent grows (Stage B never removes slabs), so a snapshot taken
under the lock is safe to iterate afterwards.

Initial-slab retry: the session constructor still halves capacity on mmap
failure, now down to `kMinSlabSize = 8 MB`. Subsequent grows use exactly
`slab_size_` with no retry — mmap failures propagate.

Tests:

- `test_arena.py`: `_ARENA_SIZE` bumped from 16 MB → 256 MB so the co-location
  tests continue to exercise single-slab invariants. The existing
  overflow-section test's contiguous-region assertion still passes because
  256 MB is enough for one slab.
- `test_basic.py`: 3 new parametrized tests (×2 C/C++ variants = 6 cases)
  under the "Slab-pool growth" section — `test_pool_grows_under_small_slab`
  (16 libs, 8 MB slab, pool must grow), `test_small_slab_recycles_after_drop`
  (32-iter load/drop exercises free-list within a slab),
  `test_pool_survives_mixed_load_drop_create` (interleaved paths).

Verification: `pytest addons/tvm_ffi_orcjit/tests` — 70 passed, 3 skipped
(was 64 + 3).

Scope: Stage C (warm-slab eviction + real munmap of drained slabs) remains a
planned follow-up. Drained slabs in Stage B stay mapped until the session is
destroyed, same as today's single arena.
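The decision tree described above (footprint gate, oversize slab, first-fit, grow on demand) can be sketched as a self-contained model. `Slab`, `Pool`, and the plain bump counter below are simplified stand-ins invented for illustration, not the real JITLink-facing types; only the `slab_size - 2 * kCommitGranularity` usable estimate is taken from the commit.

```cpp
#include <cassert>
#include <cstddef>
#include <memory>
#include <vector>

// Simplified model of the Stage B allocate() decision tree. Real slabs
// use a dual-pool midpoint layout and mmap; here a bump counter stands in.
namespace model {

constexpr std::size_t kCommitGranularity = std::size_t{2} << 20;  // 2 MB

struct Slab {
  std::size_t capacity;
  std::size_t used = 0;
  bool oversize = false;
  Slab(std::size_t cap, bool big = false) : capacity(cap), oversize(big) {}
  // Bump-allocate; returning false models SlabPoolExhaustedError
  // ("this graph did not fit in this slab, try the next one").
  bool bumpAllocate(std::size_t bytes) {
    if (used + bytes > capacity) return false;
    used += bytes;
    return true;
  }
};

struct Pool {
  std::size_t slab_size = std::size_t{64} << 20;  // 64 MB default
  std::vector<std::unique_ptr<Slab>> slabs;

  // Returns the slab that received the graph, growing the pool on demand.
  Slab* allocate(std::size_t total) {
    // Conservative usable-per-slab estimate (midpoint-split slack).
    std::size_t usable = slab_size > 2 * kCommitGranularity
                             ? slab_size - 2 * kCommitGranularity
                             : slab_size / 2;
    if (total > usable) {
      // Oversize path: one graph per dedicated slab, sized to fit.
      slabs.push_back(std::make_unique<Slab>(total + 2 * kCommitGranularity,
                                             /*big=*/true));
      slabs.back()->bumpAllocate(total);
      return slabs.back().get();
    }
    // First fit over existing normal slabs.
    for (auto& s : slabs)
      if (!s->oversize && s->bumpAllocate(total)) return s.get();
    // No slab fits: grow the pool by one normal-size slab.
    slabs.push_back(std::make_unique<Slab>(slab_size));
    slabs.back()->bumpAllocate(total);
    return slabs.back().get();
  }
};

}  // namespace model
```

The model starts with an empty pool (the real constructor reserves one initial slab up front), but the first-fit-then-grow order is the same.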
1 parent e9992ca commit 4cc8462

9 files changed

Lines changed: 443 additions & 143 deletions


addons/tvm_ffi_orcjit/python/tvm_ffi_orcjit/session.py

Lines changed: 11 additions & 10 deletions
@@ -66,18 +66,19 @@ def __init__(self, orc_rt_path: str | None = None, slab_size: int = 0) -> None:
         Args:
             orc_rt_path: Optional path to the liborc_rt library. If not provided,
                 it will be automatically discovered using clang.
-            slab_size: Slab capacity in bytes for the JIT memory manager.
+            slab_size: Per-slab capacity in bytes for the JIT memory manager.
                 Linux only — ignored on macOS and Windows, where the
                 slab allocator is compiled out.
-                0 = arch default (1 GB; falls back to smaller sizes
-                down to 256 MB under RLIMIT_AS / container limits),
-                >0 = custom size, <0 = disable slab allocator (LLJIT
-                uses its default scattered-mmap allocator).
-
-                Stage A of the slab-pool refactor: one Slab per session,
-                so this is also the total arena capacity. Later stages
-                will grow session memory by adding slabs of this size
-                on demand.
+                0 = arch default (64 MB; initial slab halves on mmap
+                failure down to 8 MB under RLIMIT_AS / container
+                limits), >0 = custom size, <0 = disable slab allocator
+                (LLJIT uses its default scattered-mmap allocator).
+
+                The session holds a growable pool of slabs: a fresh
+                slab is mmap'd on demand when no existing one can fit
+                a graph, and graphs larger than slab_size go to a
+                dedicated oversize slab sized to fit. Drained slabs
+                stay mapped until the session is destroyed.
 
         """
         if orc_rt_path is None:
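The tri-state `slab_size` contract in this docstring can be captured in a small helper; `effective_slab_size` is a hypothetical name used only to illustrate the documented semantics, not a function in the package.

```python
def effective_slab_size(slab_size, default=64 << 20):
    """Model of the documented slab_size contract.

    0  -> arch default (64 MB; the real constructor may still halve the
          initial reservation down to 8 MB on mmap failure)
    >0 -> custom per-slab capacity in bytes
    <0 -> slab allocator disabled (None here; LLJIT then uses its
          default scattered-mmap allocator)
    """
    if slab_size < 0:
        return None
    if slab_size == 0:
        return default
    return slab_size
```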

addons/tvm_ffi_orcjit/src/ffi/orcjit_memory_manager.cc

Lines changed: 129 additions & 19 deletions
@@ -19,52 +19,162 @@
 
 /*!
  * \file orcjit_memory_manager.cc
- * \brief Thin wrapper delegating JITLinkMemoryManager ops to a Slab.
+ * \brief Growable per-session pool of `Slab`s.
  */
 #include "orcjit_memory_manager.h"
 
 #ifdef __linux__
 
+#include <llvm/Support/Alignment.h>
 #include <llvm/Support/Error.h>
 #include <llvm/Support/FormatVariadic.h>
 
 #include <algorithm>
+#include <optional>
+#include <utility>
 
 namespace tvm {
 namespace ffi {
 namespace orcjit {
 
 using llvm::Error;
+using llvm::Expected;
 
-ArenaJITLinkMemoryManager::ArenaJITLinkMemoryManager(std::size_t page_size,
-                                                     std::size_t slab_capacity) {
-  // Try requested capacity, halve on failure down to a minimum floor.
-  // The floor is the smaller of kMinSlabCapacity and the requested size,
-  // so explicit small slabs (e.g. 16 MB for tests) are honoured.
-  // mmap(PROT_NONE | MAP_NORESERVE) can still fail under RLIMIT_AS or
-  // extreme VA fragmentation.
-  std::size_t floor = std::min(slab_capacity, kMinSlabCapacity);
-  std::size_t cap = slab_capacity;
+SlabPoolMemoryManager::SlabPoolMemoryManager(std::size_t page_size, std::size_t slab_size)
+    : page_size_(page_size), slab_size_(slab_size) {
+  // Reserve the initial slab. Halving retry only applies here: if the
+  // very first mmap fails (RLIMIT_AS, container limits), we halve the
+  // requested size down to kMinSlabSize before giving up. Subsequent
+  // slabs added during allocate() use exactly slab_size_ and propagate
+  // errors on mmap failure.
+  std::size_t floor = std::min(slab_size_, kMinSlabSize);
+  std::size_t cap = slab_size_;
   while (cap >= floor) {
-    auto slab = std::make_unique<Slab>(page_size, cap);
+    auto slab = std::make_unique<Slab>(page_size_, cap);
     if (slab->isValid()) {
-      slab_ = std::move(slab);
+      // Pin the actual initial-slab size to whatever we succeeded with.
+      // If RLIMIT_AS forced us to 8 MB, we keep 8 MB as the working slab
+      // size; growing later at 64 MB would just fail again.
+      slab_size_ = cap;
+      slabs_.push_back(std::move(slab));
       return;
     }
     cap /= 2;
   }
-  llvm::report_fatal_error("ArenaJITLinkMemoryManager: failed to reserve at least " +
+  llvm::report_fatal_error("SlabPoolMemoryManager: failed to reserve at least " +
                            llvm::Twine(floor / (1024 * 1024)) + " MB of virtual address space");
 }
 
-void ArenaJITLinkMemoryManager::allocate(const llvm::jitlink::JITLinkDylib* /*JD*/,
-                                         llvm::jitlink::LinkGraph& G,
-                                         OnAllocatedFunction OnAllocated) {
-  slab_->allocate(G, std::move(OnAllocated));
+std::unique_ptr<Slab> SlabPoolMemoryManager::createSlab(std::size_t capacity) {
+  auto slab = std::make_unique<Slab>(page_size_, capacity);
+  if (!slab->isValid()) return nullptr;
+  return slab;
 }
 
-void ArenaJITLinkMemoryManager::deallocate(std::vector<FinalizedAlloc> Allocs,
-                                           OnDeallocatedFunction OnDeallocated) {
+void SlabPoolMemoryManager::allocate(const llvm::jitlink::JITLinkDylib* /*JD*/,
+                                     llvm::jitlink::LinkGraph& G,
+                                     OnAllocatedFunction OnAllocated) {
+  // Step 1: pre-compute footprint to decide normal vs oversize.
+  auto footprint = Slab::computeGraphFootprint(G, page_size_);
+  std::size_t total = footprint.total();
+
+  // Step 2: conservative usable-per-slab estimate. The dual-pool
+  // midpoint split means a graph cannot use the entire slab — one
+  // pool's cursor is bounded at midpoint, the other at
+  // exec_bump_limit. 2 MB of slack covers midpoint alignment. A
+  // false-positive oversize just costs one extra mmap sized to fit,
+  // never a crash.
+  std::size_t usable = slab_size_ > 2 * Slab::kCommitGranularity
+                           ? slab_size_ - 2 * Slab::kCommitGranularity
+                           : slab_size_ / 2;
+
+  // `pool_mu_` only protects the slabs_ vector itself. We never hold
+  // it across a call into Slab::allocate or a user callback: the
+  // LLJIT linker will frequently invoke nested lookups (which trigger
+  // recursive allocate() calls via materialization) from inside
+  // OnAllocated, and a coarse lock here would deadlock. Slabs we've
+  // seen in a snapshot are guaranteed to outlive this call because
+  // Stage B never removes slabs from the pool.
+  using AllocResult = Expected<std::unique_ptr<InFlightAlloc>>;
+
+  // Step 3: oversize path — one graph per dedicated slab.
+  if (total > usable) {
+    std::size_t needed = total + 2 * Slab::kCommitGranularity;
+    std::size_t cap = llvm::alignTo(needed, Slab::kCommitGranularity);
+    if (cap < slab_size_) cap = slab_size_;
+    auto slab = createSlab(cap);
+    if (!slab) {
+      OnAllocated(llvm::make_error<llvm::StringError>(
+          "SlabPoolMemoryManager: mmap failed for oversize slab of " +
+              llvm::formatv("{0:x}", cap).str() + " bytes",
+          llvm::inconvertibleErrorCode()));
+      return;
+    }
+    Slab* raw = slab.get();
+    {
+      std::lock_guard<std::mutex> lock(pool_mu_);
+      slabs_.push_back(std::move(slab));
+    }
+    raw->allocate(G, std::move(OnAllocated));
+    return;
+  }
+
+  // Step 4: first-fit over existing slabs. Take a snapshot of raw
+  // pointers under the lock, then iterate without holding it.
+  // Slab::allocate is synchronous (invokes its callback inline on
+  // every code path), so we observe the result via a captured
+  // std::optional that the callback fills before the call returns.
+  std::vector<Slab*> snapshot;
+  {
+    std::lock_guard<std::mutex> lock(pool_mu_);
+    snapshot.reserve(slabs_.size());
+    for (auto& s : slabs_) snapshot.push_back(s.get());
+  }
+  for (Slab* slab : snapshot) {
+    std::optional<AllocResult> observed;
+    slab->allocate(G, [&](AllocResult R) { observed.emplace(std::move(R)); });
+    AllocResult result = std::move(*observed);
+    if (result) {
+      OnAllocated(std::move(result));
+      return;
+    }
+    Error E = result.takeError();
+    if (E.isA<SlabPoolExhaustedError>()) {
+      // Retriable: this graph didn't fit in this slab. Try next.
+      llvm::consumeError(std::move(E));
+      continue;
+    }
+    // Terminal error (mmap, mprotect, JITLink, BasicLayout).
+    OnAllocated(std::move(E));
+    return;
+  }
+
+  // Step 5: no existing slab fits. Mmap a new normal-size slab.
+  // Another thread may have added slabs meanwhile; we don't re-scan
+  // (would duplicate work under contention). Concurrent creates
+  // would at worst make the pool grow faster than strictly necessary
+  // — never incorrect.
+  auto slab = createSlab(slab_size_);
+  if (!slab) {
+    OnAllocated(llvm::make_error<llvm::StringError>(
+        "SlabPoolMemoryManager: mmap failed for new slab of " +
+            llvm::formatv("{0:x}", slab_size_).str() + " bytes",
+        llvm::inconvertibleErrorCode()));
+    return;
+  }
+  Slab* raw = slab.get();
+  {
+    std::lock_guard<std::mutex> lock(pool_mu_);
+    slabs_.push_back(std::move(slab));
+  }
+  // A fresh slab must fit any graph we've already decided is in-range
+  // (step 2 + step 3 gate). If somehow it doesn't, the error is
+  // propagated through — not retried.
+  raw->allocate(G, std::move(OnAllocated));
+}
+
+void SlabPoolMemoryManager::deallocate(std::vector<FinalizedAlloc> Allocs,
+                                       OnDeallocatedFunction OnDeallocated) {
   Error DeallocErr = Error::success();
   for (auto& Alloc : Allocs) {
     auto* FA = Alloc.release().toPtr<FinalizedAllocInfo*>();
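Step 4 above relies on a reusable idiom: when an API reports its result only through a callback but is known to invoke that callback synchronously, a captured `std::optional` turns the callback back into a return value. The sketch below is self-contained; `allocate_with_callback` and `try_allocate` are illustrative names, not the real `Slab` API.

```cpp
#include <cassert>
#include <functional>
#include <optional>
#include <string>
#include <utility>

// nullopt models "did not fit" (the retriable case in step 4).
using Result = std::optional<std::string>;

// Stand-in for Slab::allocate: invokes its callback inline on every path.
void allocate_with_callback(bool fits, std::function<void(Result)> on_done) {
  on_done(fits ? Result{"alloc@slab0"} : std::nullopt);
}

Result try_allocate(bool fits) {
  // The callback fills `observed` before allocate_with_callback returns,
  // so dereferencing it afterwards is safe. This only works because the
  // callee is synchronous; an async callee would leave `observed` empty.
  std::optional<Result> observed;
  allocate_with_callback(fits, [&](Result r) { observed.emplace(std::move(r)); });
  return std::move(*observed);
}
```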

addons/tvm_ffi_orcjit/src/ffi/orcjit_memory_manager.h

Lines changed: 60 additions & 57 deletions
@@ -19,44 +19,28 @@
 
 /*!
  * \file orcjit_memory_manager.h
- * \brief JITLinkMemoryManager backed by one fixed-size Slab (Stage A).
+ * \brief Per-session growable slab pool (Stage B).
  *
- * Holds a single `Slab` (see orcjit_slab.h) and delegates all
- * `JITLinkMemoryManager` operations to it. The capacity-negotiation
- * retry loop (halve-on-mmap-failure) lives here because it is a
- * caller/kernel negotiation, not something a Slab itself should do.
+ * `SlabPoolMemoryManager` implements `JITLinkMemoryManager` on top of a
+ * per-session `std::vector<std::unique_ptr<Slab>>`. On each `allocate`
+ * it picks the first `Slab` that can fit the graph; if none do, it
+ * `mmap`s a new fixed-size (`slab_size`) slab and appends it. Graphs
+ * larger than a single normal slab go to the oversize path — a
+ * dedicated `Slab` sized to fit that one graph.
  *
- * In later stages this class is superseded by a `SlabPoolMemoryManager`
- * that owns multiple Slabs per session and grows on demand. The public
- * API surface (one memory manager per LLJIT, constructed with a
- * page-size and a slab-capacity) stays the same.
+ * ## Lifecycle (Stage B)
  *
- * ## GOTPCRELX relaxation workaround
- *
- * The arena triggers a latent bug in LLVM JITLink's
- * `optimizeGOTAndStubAccesses()` (x86_64.cpp). That pass relaxes
- * `call *foo@GOTPCREL(%rip)` (ff 15) → `addr32 call foo` (67 e8) and
- * sets the edge kind to `Pointer32` (absolute 32-bit address). However
- * the `call rel32` instruction is always **PC-relative** — the `67`
- * prefix is just padding — so the fixup should be PC-relative too
- * (matching the static linker's `R_X86_64_PC32`).
+ * Once a slab is added to the pool it stays mapped until the pool
+ * (and its enclosing session) is destroyed. Individual graphs are
+ * deallocated via `FA->owner->deallocateOne(...)`, returning bytes to
+ * the slab's free list, but the slab's VA reservation is not reclaimed.
+ * Stage C will add warm-slab eviction that munmaps drained slabs.
  *
- * The bug is latent because the relaxation only fires when the target
- * address fits in 32 bits (`isUInt<32>`). On PIE executables every
- * resolved symbol is at a high address, so the guard is never true and
- * the relaxation never runs. On **non-PIE** executables the PLT
- * entries for libc functions (malloc, free, …) live near 0x400000, the
- * guard passes, and the wrong fixup produces a garbage displacement →
- * SIGSEGV during ORC-runtime teardown.
+ * ## GOTPCRELX relaxation workaround
  *
- * `GOTPCRELXFixPlugin` in llvm_patches/gotpcrelx_fix.cc works around
- * this: a PreFixupPass that runs *after* `optimizeGOTAndStubAccesses`
- * detects `Pointer32` edges on `67 e8` / `e9` instructions and either
- *   (a) converts to `BranchPCRel32` when the PC-relative displacement
- *       fits in int32, or
- *   (b) reverts the relaxation entirely — restores the `ff 15` /
- *       `ff 25` opcode bytes and retargets the edge to the GOT entry
- *       with `PCRel32` + addend 0.
+ * Unchanged from Stage A — see `llvm_patches/gotpcrelx_fix.cc`. The
+ * plugin is added per-session to the `ObjectLinkingLayer` alongside
+ * this memory manager and is orthogonal to pool growth.
  */
 #ifndef TVM_FFI_ORCJIT_ORCJIT_MEMORY_MANAGER_H_
 #define TVM_FFI_ORCJIT_ORCJIT_MEMORY_MANAGER_H_
@@ -65,48 +49,67 @@
 
 #include <cstddef>
 #include <memory>
+#include <mutex>
+#include <vector>
 
 #include "orcjit_slab.h"
 
 namespace tvm {
 namespace ffi {
 namespace orcjit {
 
-/*! \brief JITLink memory manager backed by a single `Slab`.
- *
- * Reserves `slab_capacity` bytes of contiguous VA at construction time,
- * halving the request down to `kMinSlabCapacity` if the initial `mmap`
- * fails (RLIMIT_AS, container limits).
+/*!
+ * \brief `JITLinkMemoryManager` backed by a growable pool of `Slab`s.
  *
- * `allocate` and `deallocate` forward to the underlying `Slab`.
+ * The constructor reserves one initial slab (halving its capacity down
+ * to `kMinSlabSize` if `mmap` fails under RLIMIT_AS). Subsequent
+ * slabs are added on demand by `allocate()` and reserved at exactly
+ * `slab_size_` bytes each — no retry, no halving, errors propagate.
  */
-class ArenaJITLinkMemoryManager : public llvm::jitlink::JITLinkMemoryManager {
+class SlabPoolMemoryManager : public llvm::jitlink::JITLinkMemoryManager {
 public:
-  // Default slab capacity: 1 GB on both architectures. Well within the
-  // PC-relative relocation limit (x86_64 ±2 GB, AArch64 ±4 GB) so
-  // cross-section fixups always fit; large enough to cover typical ML
-  // JIT workloads without oversubscribing virtual address space on
-  // memory-constrained hosts (containers, CI runners).
-  static constexpr std::size_t kDefaultSlabCapacity_x86_64 = std::size_t{1} << 30;   // 1 GB
-  static constexpr std::size_t kDefaultSlabCapacity_AArch64 = std::size_t{1} << 30;  // 1 GB
-  static constexpr std::size_t kMinSlabCapacity = std::size_t{256} << 20;            // 256 MB floor
-
-  explicit ArenaJITLinkMemoryManager(std::size_t page_size, std::size_t slab_capacity);
-  ~ArenaJITLinkMemoryManager() override = default;
-
-  ArenaJITLinkMemoryManager(const ArenaJITLinkMemoryManager&) = delete;
-  ArenaJITLinkMemoryManager& operator=(const ArenaJITLinkMemoryManager&) = delete;
-  ArenaJITLinkMemoryManager(ArenaJITLinkMemoryManager&&) = delete;
-  ArenaJITLinkMemoryManager& operator=(ArenaJITLinkMemoryManager&&) = delete;
+  // Default per-slab capacity. 64 MB is above the p99 size of typical
+  // ML JIT graphs (single-kernel bindings, fused kernels), below the
+  // PC-relative relocation limit, and a multiple of the 2 MB THP
+  // granule. Small enough that a pinned slab only wastes 64 MB of RSS.
+  static constexpr std::size_t kDefaultSlabSize = std::size_t{64} << 20;  // 64 MB
+
+  // Lower bound on initial-slab reservation. If the first `mmap`
+  // fails and halving drops below this, the constructor aborts.
+  // 8 MB is enough for a minimal JITDylib setup under very tight
+  // RLIMIT_AS.
+  static constexpr std::size_t kMinSlabSize = std::size_t{8} << 20;  // 8 MB
+
+  explicit SlabPoolMemoryManager(std::size_t page_size, std::size_t slab_size);
+  ~SlabPoolMemoryManager() override = default;
+
+  SlabPoolMemoryManager(const SlabPoolMemoryManager&) = delete;
+  SlabPoolMemoryManager& operator=(const SlabPoolMemoryManager&) = delete;
+  SlabPoolMemoryManager(SlabPoolMemoryManager&&) = delete;
+  SlabPoolMemoryManager& operator=(SlabPoolMemoryManager&&) = delete;
 
   void allocate(const llvm::jitlink::JITLinkDylib* JD, llvm::jitlink::LinkGraph& G,
                 OnAllocatedFunction OnAllocated) override;
 
   void deallocate(std::vector<FinalizedAlloc> Allocs,
                   OnDeallocatedFunction OnDeallocated) override;
 
+  /*! \brief Number of slabs currently held (test introspection). */
+  std::size_t numSlabs() const {
+    std::lock_guard<std::mutex> lock(pool_mu_);
+    return slabs_.size();
+  }
+
 private:
-  std::unique_ptr<Slab> slab_;
+  /*! \brief Reserve a fresh slab at exactly \p capacity bytes. Returns
+   *  nullptr on mmap failure (caller reports the error). */
+  std::unique_ptr<Slab> createSlab(std::size_t capacity);
+
+  std::size_t page_size_;
+  std::size_t slab_size_;
+
+  mutable std::mutex pool_mu_;
+  std::vector<std::unique_ptr<Slab>> slabs_;
 };
 
 } // namespace orcjit
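The locking discipline this header relies on (mutex guards only the vector; snapshot raw pointers under the lock, then iterate without it; correctness depends on elements never being removed) generalizes to the small self-contained sketch below. `StablePool` and `Widget` are illustrative names, not types from this commit.

```cpp
#include <cassert>
#include <memory>
#include <mutex>
#include <vector>

struct Widget { int id; };

// Pool whose mutex protects only the vector of owning pointers.
// Because elements are heap-allocated and never removed, a pointer
// snapshot stays valid while other threads append concurrently, even
// if the vector itself reallocates.
class StablePool {
 public:
  void add(int id) {
    std::lock_guard<std::mutex> lock(mu_);
    items_.push_back(std::make_unique<Widget>(Widget{id}));
  }
  std::vector<Widget*> snapshot() const {
    std::lock_guard<std::mutex> lock(mu_);
    std::vector<Widget*> out;
    out.reserve(items_.size());
    for (const auto& w : items_) out.push_back(w.get());
    return out;  // caller iterates without holding mu_
  }
 private:
  mutable std::mutex mu_;
  std::vector<std::unique_ptr<Widget>> items_;
};
```

Callbacks invoked while iterating such a snapshot can safely re-enter `add` (as nested LLJIT lookups re-enter `allocate`), since no lock is held during the iteration.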
