Skip to content

Commit 04ef2aa

Browse files
committed
fix: free stale embedding engines when switching provider classes
Preserve the transactional behavior added to memory_set_model(), but release the previous engine after a successful switch even when the new model uses a different provider class. Without this, switching between custom, local, and remote providers can leave the old engine attached to the context until close or until the same provider class is selected again. For local models this can keep large model / VRAM allocations resident unnecessarily. Also add a regression test covering cross-class provider switches to ensure the previous engine is freed immediately and not double-freed on database close.
1 parent f47d2f9 commit 04ef2aa

2 files changed

Lines changed: 83 additions & 3 deletions

File tree

src/sqlite-memory.c

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1119,17 +1119,35 @@ static void dbmem_set_model (sqlite3_context *context, int argc, sqlite3_value *
11191119
if (old_provider) dbmemory_free(old_provider);
11201120
if (old_model) dbmemory_free(old_model);
11211121
#ifndef DBMEM_OMIT_LOCAL_ENGINE
1122-
if (!is_custom_provider && is_local_provider && old_l_engine && old_l_engine != new_l_engine) {
1122+
if (!is_custom_provider && is_local_provider) {
1123+
if (old_l_engine && old_l_engine != new_l_engine) {
1124+
dbmem_local_engine_free(old_l_engine);
1125+
}
1126+
} else if (old_l_engine) {
1127+
// switching away from local provider: release the previous engine
11231128
dbmem_local_engine_free(old_l_engine);
1129+
ctx->l_engine = NULL;
11241130
}
11251131
#endif
11261132
#ifndef DBMEM_OMIT_REMOTE_ENGINE
1127-
if (!is_custom_provider && !is_local_provider && old_r_engine && old_r_engine != new_r_engine) {
1133+
if (!is_custom_provider && !is_local_provider) {
1134+
if (old_r_engine && old_r_engine != new_r_engine) {
1135+
dbmem_remote_engine_free(old_r_engine);
1136+
}
1137+
} else if (old_r_engine) {
1138+
// switching away from remote provider: release the previous engine
11281139
dbmem_remote_engine_free(old_r_engine);
1140+
ctx->r_engine = NULL;
11291141
}
11301142
#endif
1131-
if (is_custom_provider && old_custom_engine && old_custom_engine != new_custom_engine && ctx->custom_provider.free) {
1143+
if (is_custom_provider) {
1144+
if (old_custom_engine && old_custom_engine != new_custom_engine && ctx->custom_provider.free) {
1145+
ctx->custom_provider.free(old_custom_engine, ctx->custom_provider.xdata);
1146+
}
1147+
} else if (old_custom_engine && ctx->custom_provider.free) {
1148+
// switching away from custom provider: release the previous engine
11321149
ctx->custom_provider.free(old_custom_engine, ctx->custom_provider.xdata);
1150+
ctx->custom_engine = NULL;
11331151
}
11341152

11351153
sqlite3_result_int(context, 1);

test/unittest.c

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2579,6 +2579,67 @@ TEST(sqlite_set_model_failed_reindex_preserves_existing_rows) {
25792579
sqlite3_close(db);
25802580
}
25812581

2582+
// Regression: when memory_set_model() switches from a custom provider to a
2583+
// remote provider (a different provider class), the previous custom engine
2584+
// must be released immediately — not leaked until the database is closed.
2585+
typedef struct {
2586+
int free_count;
2587+
} tracking_free_state_t;
2588+
2589+
static void *tracking_init(const char *model, const char *api_key, void *xdata, char err_msg[1024]) {
2590+
UNUSED_PARAM(model);
2591+
UNUSED_PARAM(api_key);
2592+
UNUSED_PARAM(xdata);
2593+
UNUSED_PARAM(err_msg);
2594+
// any non-NULL pointer is fine; the test only cares about the free callback
2595+
return calloc(1, 1);
2596+
}
2597+
2598+
static int tracking_compute(void *engine, const char *text, int text_len, void *xdata, dbmem_embedding_result_t *result) {
2599+
UNUSED_PARAM(engine);
2600+
UNUSED_PARAM(text);
2601+
UNUSED_PARAM(text_len);
2602+
UNUSED_PARAM(xdata);
2603+
UNUSED_PARAM(result);
2604+
return -1;
2605+
}
2606+
2607+
static void tracking_free(void *engine, void *xdata) {
2608+
tracking_free_state_t *s = (tracking_free_state_t *)xdata;
2609+
if (s) s->free_count++;
2610+
free(engine);
2611+
}
2612+
2613+
TEST(sqlite_set_model_releases_previous_engine_on_class_switch) {
2614+
sqlite3 *db = open_test_db();
2615+
ASSERT(db != NULL);
2616+
2617+
// remote engine init requires an api key to succeed
2618+
sqlite3_int64 result = 0;
2619+
int rc = exec_get_int(db, "SELECT memory_set_apikey('test-key');", &result);
2620+
ASSERT_EQ(rc, SQLITE_OK);
2621+
2622+
tracking_free_state_t state = {0};
2623+
dbmem_provider_t prov = { .init = tracking_init, .compute = tracking_compute, .free = tracking_free, .xdata = &state };
2624+
rc = sqlite3_memory_register_provider(db, "tracker", &prov);
2625+
ASSERT_EQ(rc, SQLITE_OK);
2626+
2627+
// activate the custom provider — ctx->custom_engine is now non-NULL
2628+
rc = exec_get_int(db, "SELECT memory_set_model('tracker', 'm1');", &result);
2629+
ASSERT_EQ(rc, SQLITE_OK);
2630+
ASSERT_EQ(state.free_count, 0);
2631+
2632+
// switch to a provider from a different class (remote). The previous
2633+
// custom engine must be released during this call, not kept alive on ctx.
2634+
rc = exec_get_int(db, "SELECT memory_set_model('openai', 'text-embedding-3-small');", &result);
2635+
ASSERT_EQ(rc, SQLITE_OK);
2636+
ASSERT_EQ(state.free_count, 1);
2637+
2638+
// closing the db must not double-free the already-released custom engine
2639+
sqlite3_close(db);
2640+
ASSERT_EQ(state.free_count, 1);
2641+
}
2642+
25822643
#endif // TEST_SQLITE_EXTENSION
25832644

25842645
// ============================================================================
@@ -2718,6 +2779,7 @@ int main(int argc, char *argv[]) {
27182779
RUN_TEST(sqlite_custom_provider_init_error);
27192780
RUN_TEST(sqlite_custom_provider_apikey_passed);
27202781
RUN_TEST(sqlite_set_model_failed_reindex_preserves_existing_rows);
2782+
RUN_TEST(sqlite_set_model_releases_previous_engine_on_class_switch);
27212783
#endif
27222784

27232785
printf("\n=== Results ===\n");

0 commit comments

Comments
 (0)