diff --git a/native/src/lib.rs b/native/src/lib.rs index 18e044a..0721f4f 100644 --- a/native/src/lib.rs +++ b/native/src/lib.rs @@ -328,7 +328,7 @@ impl InvertedIndex { #[napi] pub struct Database { - conn: std::sync::Mutex, + conn: std::sync::Mutex>, } #[napi(object)] @@ -374,41 +374,63 @@ impl Database { let conn = db::init_db(std::path::Path::new(&db_path)) .map_err(|e| Error::from_reason(e.to_string()))?; Ok(Self { - conn: std::sync::Mutex::new(conn), + conn: std::sync::Mutex::new(Some(conn)), }) } + fn closed_error() -> Error { + Error::from_reason("Database is closed") + } + + fn lock_conn(&self) -> Result>> { + self.conn + .lock() + .map_err(|e| Error::from_reason(e.to_string())) + } + + fn with_conn(&self, f: F) -> Result + where + F: FnOnce(&rusqlite::Connection) -> Result, + { + let conn = self.lock_conn()?; + let conn = conn.as_ref().ok_or_else(Self::closed_error)?; + f(conn) + } + + fn with_conn_mut(&self, f: F) -> Result + where + F: FnOnce(&mut rusqlite::Connection) -> Result, + { + let mut conn = self.lock_conn()?; + let conn = conn.as_mut().ok_or_else(Self::closed_error)?; + f(conn) + } + #[napi] pub fn close(&self) -> Result<()> { - let mut conn = self - .conn - .lock() - .map_err(|e| Error::from_reason(e.to_string()))?; - let replacement = rusqlite::Connection::open_in_memory() - .map_err(|e| Error::from_reason(e.to_string()))?; - let old = std::mem::replace(&mut *conn, replacement); + // Best-effort, idempotent shutdown: once the connection is taken, all + // future calls fail fast with `Database is closed` and repeated close() + // calls are harmless. + let mut conn = self.lock_conn()?; + let old = conn.take(); drop(old); Ok(()) } #[napi] pub fn embedding_exists(&self, content_hash: String) -> Result { - let conn = self - .conn - .lock() - .map_err(|e| Error::from_reason(e.to_string()))?; - db::embedding_exists(&conn, &content_hash).map_err(|e| Error::from_reason(e.to_string())) + self.with_conn(|conn| { + db::embedding_exists(conn, &content_hash).map_err(|e| Error::from_reason(e.to_string())) + }) } #[napi] pub fn get_embedding(&self, content_hash: String) -> Result> { - let conn = self - .conn - .lock() - .map_err(|e| Error::from_reason(e.to_string()))?; - let result = db::get_embedding(&conn, &content_hash) - .map_err(|e| Error::from_reason(e.to_string()))?; - Ok(result.map(Buffer::from)) + self.with_conn(|conn| { + let result = db::get_embedding(conn, &content_hash) + .map_err(|e| Error::from_reason(e.to_string()))?; + Ok(result.map(Buffer::from)) + }) } #[napi] @@ -419,75 +441,44 @@ impl Database { chunk_text: String, model: String, ) -> Result<()> { - let conn = self - .conn - .lock() - .map_err(|e| Error::from_reason(e.to_string()))?; - db::upsert_embedding(&conn, &content_hash, &embedding, &chunk_text, &model) - .map_err(|e| Error::from_reason(e.to_string())) + self.with_conn(|conn| { + db::upsert_embedding(conn, &content_hash, &embedding, &chunk_text, &model) + .map_err(|e| Error::from_reason(e.to_string())) + }) } #[napi] pub fn get_missing_embeddings(&self, content_hashes: Vec) -> Result> { - let conn = self - .conn - .lock() - .map_err(|e| Error::from_reason(e.to_string()))?; - db::get_missing_embeddings(&conn, &content_hashes) - .map_err(|e| Error::from_reason(e.to_string())) + self.with_conn(|conn| { + db::get_missing_embeddings(conn, &content_hashes) + .map_err(|e| Error::from_reason(e.to_string())) + }) } #[napi] pub fn upsert_chunk(&self, chunk: ChunkData) -> Result<()> { - let conn = self - .conn - .lock() - .map_err(|e| Error::from_reason(e.to_string()))?; - db::upsert_chunk( - &conn, - &chunk.chunk_id, - &chunk.content_hash, - &chunk.file_path, - chunk.start_line, - chunk.end_line, - chunk.node_type.as_deref(), - chunk.name.as_deref(), - &chunk.language, - ) - .map_err(|e| Error::from_reason(e.to_string())) + self.with_conn(|conn| { + db::upsert_chunk( + conn, + &chunk.chunk_id, + &chunk.content_hash, + &chunk.file_path, + chunk.start_line, + chunk.end_line, + chunk.node_type.as_deref(), + chunk.name.as_deref(), + &chunk.language, + ) + .map_err(|e| Error::from_reason(e.to_string())) + }) } #[napi] pub fn get_chunk(&self, chunk_id: String) -> Result> { - let conn = self - .conn - .lock() - .map_err(|e| Error::from_reason(e.to_string()))?; - let result = - db::get_chunk(&conn, &chunk_id).map_err(|e| Error::from_reason(e.to_string()))?; - Ok(result.map(|row| ChunkData { - chunk_id: row.chunk_id, - content_hash: row.content_hash, - file_path: row.file_path, - start_line: row.start_line, - end_line: row.end_line, - node_type: row.node_type, - name: row.name, - language: row.language, - })) - } - - #[napi] - pub fn get_chunks_by_file(&self, file_path: String) -> Result> { - let conn = self - .conn - .lock() - .map_err(|e| Error::from_reason(e.to_string()))?; - let rows = db::get_chunks_by_file(&conn, &file_path) - .map_err(|e| Error::from_reason(e.to_string()))?; - Ok(rows - .into_iter() - .map(|row| ChunkData { + self.with_conn(|conn| { + let result = + db::get_chunk(conn, &chunk_id).map_err(|e| Error::from_reason(e.to_string()))?; + Ok(result.map(|row| ChunkData { chunk_id: row.chunk_id, content_hash: row.content_hash, file_path: row.file_path, @@ -496,94 +487,101 @@ impl Database { node_type: row.node_type, name: row.name, language: row.language, - }) - .collect()) + })) + }) + } + + #[napi] + pub fn get_chunks_by_file(&self, file_path: String) -> Result> { + self.with_conn(|conn| { + let rows = db::get_chunks_by_file(conn, &file_path) + .map_err(|e| Error::from_reason(e.to_string()))?; + Ok(rows + .into_iter() + .map(|row| ChunkData { + chunk_id: row.chunk_id, + content_hash: row.content_hash, + file_path: row.file_path, + start_line: row.start_line, + end_line: row.end_line, + node_type: row.node_type, + name: row.name, + language: row.language, + }) + .collect()) + }) } #[napi] pub fn get_chunks_by_name(&self, name: String) -> Result> { - let conn = self - .conn - .lock() - .map_err(|e| Error::from_reason(e.to_string()))?; - let rows = - db::get_chunks_by_name(&conn, &name).map_err(|e| Error::from_reason(e.to_string()))?; - Ok(rows - .into_iter() - .map(|row| ChunkData { - chunk_id: row.chunk_id, - content_hash: row.content_hash, - file_path: row.file_path, - start_line: row.start_line, - end_line: row.end_line, - node_type: row.node_type, - name: row.name, - language: row.language, - }) - .collect()) + self.with_conn(|conn| { + let rows = db::get_chunks_by_name(conn, &name) + .map_err(|e| Error::from_reason(e.to_string()))?; + Ok(rows + .into_iter() + .map(|row| ChunkData { + chunk_id: row.chunk_id, + content_hash: row.content_hash, + file_path: row.file_path, + start_line: row.start_line, + end_line: row.end_line, + node_type: row.node_type, + name: row.name, + language: row.language, + }) + .collect()) + }) } #[napi] pub fn get_chunks_by_name_ci(&self, name: String) -> Result> { - let conn = self - .conn - .lock() - .map_err(|e| Error::from_reason(e.to_string()))?; - let rows = db::get_chunks_by_name_ci(&conn, &name) - .map_err(|e| Error::from_reason(e.to_string()))?; - Ok(rows - .into_iter() - .map(|row| ChunkData { - chunk_id: row.chunk_id, - content_hash: row.content_hash, - file_path: row.file_path, - start_line: row.start_line, - end_line: row.end_line, - node_type: row.node_type, - name: row.name, - language: row.language, - }) - .collect()) + self.with_conn(|conn| { + let rows = db::get_chunks_by_name_ci(conn, &name) + .map_err(|e| Error::from_reason(e.to_string()))?; + Ok(rows + .into_iter() + .map(|row| ChunkData { + chunk_id: row.chunk_id, + content_hash: row.content_hash, + file_path: row.file_path, + start_line: row.start_line, + end_line: row.end_line, + node_type: row.node_type, + name: row.name, + language: row.language, + }) + .collect()) + }) } #[napi] pub fn delete_chunks_by_file(&self, file_path: String) -> Result { - let conn = self - .conn - .lock() - .map_err(|e| Error::from_reason(e.to_string()))?; - let count = db::delete_chunks_by_file(&conn, &file_path) - .map_err(|e| Error::from_reason(e.to_string()))?; - Ok(count as u32) + self.with_conn(|conn| { + let count = db::delete_chunks_by_file(conn, &file_path) + .map_err(|e| Error::from_reason(e.to_string()))?; + Ok(count as u32) + }) } #[napi] pub fn delete_chunks_by_ids(&self, chunk_ids: Vec) -> Result { - let conn = self - .conn - .lock() - .map_err(|e| Error::from_reason(e.to_string()))?; - let count = db::delete_chunks_by_ids(&conn, &chunk_ids) - .map_err(|e| Error::from_reason(e.to_string()))?; - Ok(count as u32) + self.with_conn(|conn| { + let count = db::delete_chunks_by_ids(conn, &chunk_ids) + .map_err(|e| Error::from_reason(e.to_string()))?; + Ok(count as u32) + }) } #[napi] pub fn add_chunks_to_branch(&self, branch: String, chunk_ids: Vec) -> Result<()> { - let conn = self - .conn - .lock() - .map_err(|e| Error::from_reason(e.to_string()))?; - db::add_chunks_to_branch(&conn, &branch, &chunk_ids) - .map_err(|e| Error::from_reason(e.to_string())) + self.with_conn(|conn| { + db::add_chunks_to_branch(conn, &branch, &chunk_ids) + .map_err(|e| Error::from_reason(e.to_string())) + }) } #[napi] pub fn upsert_embeddings_batch(&self, items: Vec) -> Result<()> { - let mut conn = self - .conn - .lock() - .map_err(|e| Error::from_reason(e.to_string()))?; let batch: Vec<(String, Vec, String, String)> = items .into_iter() .map(|item| { @@ -595,16 +593,13 @@ impl Database { ) }) .collect(); - db::upsert_embeddings_batch(&mut conn, &batch) - .map_err(|e| Error::from_reason(e.to_string())) + self.with_conn_mut(|conn| { + db::upsert_embeddings_batch(conn, &batch).map_err(|e| Error::from_reason(e.to_string())) + }) } #[napi] pub fn upsert_chunks_batch(&self, chunks: Vec) -> Result<()> { - let mut conn = self - .conn - .lock() - .map_err(|e| Error::from_reason(e.to_string()))?; let batch: Vec = chunks .into_iter() .map(|c| db::ChunkRow { @@ -618,39 +613,35 @@ impl Database { language: c.language, }) .collect(); - db::upsert_chunks_batch(&mut conn, &batch).map_err(|e| Error::from_reason(e.to_string())) + self.with_conn_mut(|conn| { + db::upsert_chunks_batch(conn, &batch).map_err(|e| Error::from_reason(e.to_string())) + }) } #[napi] pub fn add_chunks_to_branch_batch(&self, branch: String, chunk_ids: Vec) -> Result<()> { - let mut conn = self - .conn - .lock() - .map_err(|e| Error::from_reason(e.to_string()))?; - db::add_chunks_to_branch_batch(&mut conn, &branch, &chunk_ids) - .map_err(|e| Error::from_reason(e.to_string())) + self.with_conn_mut(|conn| { + db::add_chunks_to_branch_batch(conn, &branch, &chunk_ids) + .map_err(|e| Error::from_reason(e.to_string())) + }) } #[napi] pub fn clear_branch(&self, branch: String) -> Result { - let conn = self - .conn - .lock() - .map_err(|e| Error::from_reason(e.to_string()))?; - let count = - db::clear_branch(&conn, &branch).map_err(|e| Error::from_reason(e.to_string()))?; - Ok(count as u32) + self.with_conn(|conn| { + let count = + db::clear_branch(conn, &branch).map_err(|e| Error::from_reason(e.to_string()))?; + Ok(count as u32) + }) } #[napi] pub fn delete_branch_chunks_by_chunk_ids(&self, chunk_ids: Vec) -> Result { - let conn = self - .conn - .lock() - .map_err(|e| Error::from_reason(e.to_string()))?; - let count = db::delete_branch_chunks_by_chunk_ids(&conn, &chunk_ids) - .map_err(|e| Error::from_reason(e.to_string()))?; - Ok(count as u32) + self.with_conn(|conn| { + let count = db::delete_branch_chunks_by_chunk_ids(conn, &chunk_ids) + .map_err(|e| Error::from_reason(e.to_string()))?; + Ok(count as u32) + }) } #[napi] @@ -659,149 +650,122 @@ impl Database { branch: String, chunk_ids: Vec, ) -> Result { - let conn = self - .conn - .lock() - .map_err(|e| Error::from_reason(e.to_string()))?; - let count = db::delete_branch_chunks_for_branch(&conn, &branch, &chunk_ids) - .map_err(|e| Error::from_reason(e.to_string()))?; - Ok(count as u32) + self.with_conn(|conn| { + let count = db::delete_branch_chunks_for_branch(conn, &branch, &chunk_ids) + .map_err(|e| Error::from_reason(e.to_string()))?; + Ok(count as u32) + }) } #[napi] pub fn get_branch_chunk_ids(&self, branch: String) -> Result> { - let conn = self - .conn - .lock() - .map_err(|e| Error::from_reason(e.to_string()))?; - db::get_branch_chunk_ids(&conn, &branch).map_err(|e| Error::from_reason(e.to_string())) + self.with_conn(|conn| { + db::get_branch_chunk_ids(conn, &branch).map_err(|e| Error::from_reason(e.to_string())) + }) } #[napi] pub fn get_branch_delta(&self, branch: String, base_branch: String) -> Result { - let conn = self - .conn - .lock() - .map_err(|e| Error::from_reason(e.to_string()))?; - let delta = db::get_branch_delta(&conn, &branch, &base_branch) - .map_err(|e| Error::from_reason(e.to_string()))?; - Ok(BranchDelta { - added: delta.added, - removed: delta.removed, + self.with_conn(|conn| { + let delta = db::get_branch_delta(conn, &branch, &base_branch) + .map_err(|e| Error::from_reason(e.to_string()))?; + Ok(BranchDelta { + added: delta.added, + removed: delta.removed, + }) }) } #[napi] pub fn get_referenced_chunk_ids(&self, chunk_ids: Vec) -> Result> { - let conn = self - .conn - .lock() - .map_err(|e| Error::from_reason(e.to_string()))?; - db::get_referenced_chunk_ids(&conn, &chunk_ids) - .map_err(|e| Error::from_reason(e.to_string())) + self.with_conn(|conn| { + db::get_referenced_chunk_ids(conn, &chunk_ids) + .map_err(|e| Error::from_reason(e.to_string())) + }) } #[napi] pub fn chunk_exists_on_branch(&self, branch: String, chunk_id: String) -> Result { - let conn = self - .conn - .lock() - .map_err(|e| Error::from_reason(e.to_string()))?; - db::chunk_exists_on_branch(&conn, &branch, &chunk_id) - .map_err(|e| Error::from_reason(e.to_string())) + self.with_conn(|conn| { + db::chunk_exists_on_branch(conn, &branch, &chunk_id) + .map_err(|e| Error::from_reason(e.to_string())) + }) } #[napi] pub fn get_all_branches(&self) -> Result> { - let conn = self - .conn - .lock() - .map_err(|e| Error::from_reason(e.to_string()))?; - db::get_all_branches(&conn).map_err(|e| Error::from_reason(e.to_string())) + self.with_conn(|conn| { + db::get_all_branches(conn).map_err(|e| Error::from_reason(e.to_string())) + }) } #[napi] pub fn get_metadata(&self, key: String) -> Result> { - let conn = self - .conn - .lock() - .map_err(|e| Error::from_reason(e.to_string()))?; - db::get_metadata(&conn, &key).map_err(|e| Error::from_reason(e.to_string())) + self.with_conn(|conn| { + db::get_metadata(conn, &key).map_err(|e| Error::from_reason(e.to_string())) + }) } #[napi] pub fn set_metadata(&self, key: String, value: String) -> Result<()> { - let conn = self - .conn - .lock() - .map_err(|e| Error::from_reason(e.to_string()))?; - db::set_metadata(&conn, &key, &value).map_err(|e| Error::from_reason(e.to_string())) + self.with_conn(|conn| { + db::set_metadata(conn, &key, &value).map_err(|e| Error::from_reason(e.to_string())) + }) } #[napi] pub fn delete_metadata(&self, key: String) -> Result { - let conn = self - .conn - .lock() - .map_err(|e| Error::from_reason(e.to_string()))?; - db::delete_metadata(&conn, &key).map_err(|e| Error::from_reason(e.to_string())) + self.with_conn(|conn| { + db::delete_metadata(conn, &key).map_err(|e| Error::from_reason(e.to_string())) + }) } #[napi] pub fn clear_all_indexed_data(&self) -> Result<()> { - let conn = self - .conn - .lock() - .map_err(|e| Error::from_reason(e.to_string()))?; - db::clear_all_indexed_data(&conn).map_err(|e| Error::from_reason(e.to_string())) + self.with_conn(|conn| { + db::clear_all_indexed_data(conn).map_err(|e| Error::from_reason(e.to_string())) + }) } #[napi] pub fn clear_call_edge_targets_for_symbols(&self, symbol_ids: Vec) -> Result { - let conn = self - .conn - .lock() - .map_err(|e| Error::from_reason(e.to_string()))?; - let count = db::clear_call_edge_targets_for_symbols(&conn, &symbol_ids) - .map_err(|e| Error::from_reason(e.to_string()))?; - Ok(count as u32) + self.with_conn(|conn| { + let count = db::clear_call_edge_targets_for_symbols(conn, &symbol_ids) + .map_err(|e| Error::from_reason(e.to_string()))?; + Ok(count as u32) + }) } #[napi] pub fn gc_orphan_embeddings(&self) -> Result { - let conn = self - .conn - .lock() - .map_err(|e| Error::from_reason(e.to_string()))?; - let count = - db::gc_orphan_embeddings(&conn).map_err(|e| Error::from_reason(e.to_string()))?; - Ok(count as u32) + self.with_conn(|conn| { + let count = + db::gc_orphan_embeddings(conn).map_err(|e| Error::from_reason(e.to_string()))?; + Ok(count as u32) + }) } #[napi] pub fn gc_orphan_chunks(&self) -> Result { - let conn = self - .conn - .lock() - .map_err(|e| Error::from_reason(e.to_string()))?; - let count = db::gc_orphan_chunks(&conn).map_err(|e| Error::from_reason(e.to_string()))?; - Ok(count as u32) + self.with_conn(|conn| { + let count = + db::gc_orphan_chunks(conn).map_err(|e| Error::from_reason(e.to_string()))?; + Ok(count as u32) + }) } #[napi] pub fn get_stats(&self) -> Result { - let conn = self - .conn - .lock() - .map_err(|e| Error::from_reason(e.to_string()))?; - let stats = db::get_stats(&conn).map_err(|e| Error::from_reason(e.to_string()))?; - Ok(DatabaseStats { - embedding_count: stats.embedding_count as u32, - chunk_count: stats.chunk_count as u32, - branch_chunk_count: stats.branch_chunk_count as u32, - branch_count: stats.branch_count as u32, - symbol_count: stats.symbol_count as u32, - call_edge_count: stats.call_edge_count as u32, + self.with_conn(|conn| { + let stats = db::get_stats(conn).map_err(|e| Error::from_reason(e.to_string()))?; + Ok(DatabaseStats { + embedding_count: stats.embedding_count as u32, + chunk_count: stats.chunk_count as u32, + branch_chunk_count: stats.branch_chunk_count as u32, + branch_count: stats.branch_count as u32, + symbol_count: stats.symbol_count as u32, + call_edge_count: stats.call_edge_count as u32, + }) }) } @@ -809,10 +773,6 @@ impl Database { #[napi] pub fn upsert_symbol(&self, symbol: SymbolData) -> Result<()> { - let conn = self - .conn - .lock() - .map_err(|e| Error::from_reason(e.to_string()))?; let row = db::SymbolRow { id: symbol.id, file_path: symbol.file_path, @@ -824,15 +784,13 @@ impl Database { end_col: symbol.end_col, language: symbol.language, }; - db::upsert_symbol(&conn, &row).map_err(|e| Error::from_reason(e.to_string())) + self.with_conn(|conn| { + db::upsert_symbol(conn, &row).map_err(|e| Error::from_reason(e.to_string())) + }) } #[napi] pub fn upsert_symbols_batch(&self, symbols: Vec) -> Result<()> { - let mut conn = self - .conn - .lock() - .map_err(|e| Error::from_reason(e.to_string()))?; let rows: Vec = symbols .into_iter() .map(|s| db::SymbolRow { @@ -847,31 +805,31 @@ impl Database { language: s.language, }) .collect(); - db::upsert_symbols_batch(&mut conn, &rows).map_err(|e| Error::from_reason(e.to_string())) + self.with_conn_mut(|conn| { + db::upsert_symbols_batch(conn, &rows).map_err(|e| Error::from_reason(e.to_string())) + }) } #[napi] pub fn get_symbols_by_file(&self, file_path: String) -> Result> { - let conn = self - .conn - .lock() - .map_err(|e| Error::from_reason(e.to_string()))?; - let rows = db::get_symbols_by_file(&conn, &file_path) - .map_err(|e| Error::from_reason(e.to_string()))?; - Ok(rows - .into_iter() - .map(|r| SymbolData { - id: r.id, - file_path: r.file_path, - name: r.name, - kind: r.kind, - start_line: r.start_line, - start_col: r.start_col, - end_line: r.end_line, - end_col: r.end_col, - language: r.language, - }) - .collect()) + self.with_conn(|conn| { + let rows = db::get_symbols_by_file(conn, &file_path) + .map_err(|e| Error::from_reason(e.to_string()))?; + Ok(rows + .into_iter() + .map(|r| SymbolData { + id: r.id, + file_path: r.file_path, + name: r.name, + kind: r.kind, + start_line: r.start_line, + start_col: r.start_col, + end_line: r.end_line, + end_col: r.end_col, + language: r.language, + }) + .collect()) + }) } #[napi] @@ -880,36 +838,10 @@ impl Database { name: String, file_path: String, ) -> Result> { - let conn = self - .conn - .lock() - .map_err(|e| Error::from_reason(e.to_string()))?; - let row = db::get_symbol_by_name(&conn, &name, &file_path) - .map_err(|e| Error::from_reason(e.to_string()))?; - Ok(row.map(|r| SymbolData { - id: r.id, - file_path: r.file_path, - name: r.name, - kind: r.kind, - start_line: r.start_line, - start_col: r.start_col, - end_line: r.end_line, - end_col: r.end_col, - language: r.language, - })) - } - - #[napi] - pub fn get_symbols_by_name(&self, name: String) -> Result> { - let conn = self - .conn - .lock() - .map_err(|e| Error::from_reason(e.to_string()))?; - let rows = - db::get_symbols_by_name(&conn, &name).map_err(|e| Error::from_reason(e.to_string()))?; - Ok(rows - .into_iter() - .map(|r| SymbolData { + self.with_conn(|conn| { + let row = db::get_symbol_by_name(conn, &name, &file_path) + .map_err(|e| Error::from_reason(e.to_string()))?; + Ok(row.map(|r| SymbolData { id: r.id, file_path: r.file_path, name: r.name, @@ -919,53 +851,67 @@ impl Database { end_line: r.end_line, end_col: r.end_col, language: r.language, - }) - .collect()) + })) + }) + } + + #[napi] + pub fn get_symbols_by_name(&self, name: String) -> Result> { + self.with_conn(|conn| { + let rows = db::get_symbols_by_name(conn, &name) + .map_err(|e| Error::from_reason(e.to_string()))?; + Ok(rows + .into_iter() + .map(|r| SymbolData { + id: r.id, + file_path: r.file_path, + name: r.name, + kind: r.kind, + start_line: r.start_line, + start_col: r.start_col, + end_line: r.end_line, + end_col: r.end_col, + language: r.language, + }) + .collect()) + }) } #[napi] pub fn get_symbols_by_name_ci(&self, name: String) -> Result> { - let conn = self - .conn - .lock() - .map_err(|e| Error::from_reason(e.to_string()))?; - let rows = db::get_symbols_by_name_ci(&conn, &name) - .map_err(|e| Error::from_reason(e.to_string()))?; - Ok(rows - .into_iter() - .map(|r| SymbolData { - id: r.id, - file_path: r.file_path, - name: r.name, - kind: r.kind, - start_line: r.start_line, - start_col: r.start_col, - end_line: r.end_line, - end_col: r.end_col, - language: r.language, - }) - .collect()) + self.with_conn(|conn| { + let rows = db::get_symbols_by_name_ci(conn, &name) + .map_err(|e| Error::from_reason(e.to_string()))?; + Ok(rows + .into_iter() + .map(|r| SymbolData { + id: r.id, + file_path: r.file_path, + name: r.name, + kind: r.kind, + start_line: r.start_line, + start_col: r.start_col, + end_line: r.end_line, + end_col: r.end_col, + language: r.language, + }) + .collect()) + }) } #[napi] pub fn delete_symbols_by_file(&self, file_path: String) -> Result { - let conn = self - .conn - .lock() - .map_err(|e| Error::from_reason(e.to_string()))?; - let count = db::delete_symbols_by_file(&conn, &file_path) - .map_err(|e| Error::from_reason(e.to_string()))?; - Ok(count as u32) + self.with_conn(|conn| { + let count = db::delete_symbols_by_file(conn, &file_path) + .map_err(|e| Error::from_reason(e.to_string()))?; + Ok(count as u32) + }) } // ── Call Edge methods ──────────────────────────────────────────── #[napi] pub fn upsert_call_edge(&self, edge: CallEdgeData) -> Result<()> { - let conn = self - .conn - .lock() - .map_err(|e| Error::from_reason(e.to_string()))?; let row = db::CallEdgeRow { id: edge.id, from_symbol_id: edge.from_symbol_id, @@ -976,15 +922,13 @@ impl Database { col: edge.col, is_resolved: edge.is_resolved, }; - db::upsert_call_edge(&conn, &row).map_err(|e| Error::from_reason(e.to_string())) + self.with_conn(|conn| { + db::upsert_call_edge(conn, &row).map_err(|e| Error::from_reason(e.to_string())) + }) } #[napi] pub fn upsert_call_edges_batch(&self, edges: Vec) -> Result<()> { - let mut conn = self - .conn - .lock() - .map_err(|e| Error::from_reason(e.to_string()))?; let rows: Vec = edges .into_iter() .map(|e| db::CallEdgeRow { @@ -998,57 +942,55 @@ impl Database { is_resolved: e.is_resolved, }) .collect(); - db::upsert_call_edges_batch(&mut conn, &rows).map_err(|e| Error::from_reason(e.to_string())) + self.with_conn_mut(|conn| { + db::upsert_call_edges_batch(conn, &rows).map_err(|e| Error::from_reason(e.to_string())) + }) } #[napi] pub fn get_callers(&self, symbol_name: String, branch: String) -> Result> { - let conn = self - .conn - .lock() - .map_err(|e| Error::from_reason(e.to_string()))?; - let rows = db::get_callers(&conn, &symbol_name, &branch) - .map_err(|e| Error::from_reason(e.to_string()))?; - Ok(rows - .into_iter() - .map(|r| CallEdgeData { - id: r.id, - from_symbol_id: r.from_symbol_id, - from_symbol_name: None, - from_symbol_file_path: None, - target_name: r.target_name, - to_symbol_id: r.to_symbol_id, - call_type: r.call_type, - line: r.line, - col: r.col, - is_resolved: r.is_resolved, - }) - .collect()) + self.with_conn(|conn| { + let rows = db::get_callers(conn, &symbol_name, &branch) + .map_err(|e| Error::from_reason(e.to_string()))?; + Ok(rows + .into_iter() + .map(|r| CallEdgeData { + id: r.id, + from_symbol_id: r.from_symbol_id, + from_symbol_name: None, + from_symbol_file_path: None, + target_name: r.target_name, + to_symbol_id: r.to_symbol_id, + call_type: r.call_type, + line: r.line, + col: r.col, + is_resolved: r.is_resolved, + }) + .collect()) + }) } #[napi] pub fn get_callees(&self, symbol_id: String, branch: String) -> Result> { - let conn = self - .conn - .lock() - .map_err(|e| Error::from_reason(e.to_string()))?; - let rows = db::get_callees(&conn, &symbol_id, &branch) - .map_err(|e| Error::from_reason(e.to_string()))?; - Ok(rows - .into_iter() - .map(|r| CallEdgeData { - id: r.id, - from_symbol_id: r.from_symbol_id, - from_symbol_name: None, - from_symbol_file_path: None, - target_name: r.target_name, - to_symbol_id: r.to_symbol_id, - call_type: r.call_type, - line: r.line, - col: r.col, - is_resolved: r.is_resolved, - }) - .collect()) + self.with_conn(|conn| { + let rows = db::get_callees(conn, &symbol_id, &branch) + .map_err(|e| Error::from_reason(e.to_string()))?; + Ok(rows + .into_iter() + .map(|r| CallEdgeData { + id: r.id, + from_symbol_id: r.from_symbol_id, + from_symbol_name: None, + from_symbol_file_path: None, + target_name: r.target_name, + to_symbol_id: r.to_symbol_id, + call_type: r.call_type, + line: r.line, + col: r.col, + is_resolved: r.is_resolved, + }) + .collect()) + }) } #[napi] @@ -1057,60 +999,52 @@ impl Database { symbol_name: String, branch: String, ) -> Result> { - let conn = self - .conn - .lock() - .map_err(|e| Error::from_reason(e.to_string()))?; - let rows = db::get_callers_with_context(&conn, &symbol_name, &branch) - .map_err(|e| Error::from_reason(e.to_string()))?; - Ok(rows - .into_iter() - .map(|r| CallEdgeData { - id: r.id, - from_symbol_id: r.from_symbol_id, - from_symbol_name: Some(r.from_symbol_name), - from_symbol_file_path: Some(r.from_symbol_file_path), - target_name: r.target_name, - to_symbol_id: r.to_symbol_id, - call_type: r.call_type, - line: r.line, - col: r.col, - is_resolved: r.is_resolved, - }) - .collect()) + self.with_conn(|conn| { + let rows = db::get_callers_with_context(conn, &symbol_name, &branch) + .map_err(|e| Error::from_reason(e.to_string()))?; + Ok(rows + .into_iter() + .map(|r| CallEdgeData { + id: r.id, + from_symbol_id: r.from_symbol_id, + from_symbol_name: Some(r.from_symbol_name), + from_symbol_file_path: Some(r.from_symbol_file_path), + target_name: r.target_name, + to_symbol_id: r.to_symbol_id, + call_type: r.call_type, + line: r.line, + col: r.col, + is_resolved: r.is_resolved, + }) + .collect()) + }) } #[napi] pub fn delete_call_edges_by_file(&self, file_path: String) -> Result { - let conn = self - .conn - .lock() - .map_err(|e| Error::from_reason(e.to_string()))?; - let count = db::delete_call_edges_by_file(&conn, &file_path) - .map_err(|e| Error::from_reason(e.to_string()))?; - Ok(count as u32) + self.with_conn(|conn| { + let count = db::delete_call_edges_by_file(conn, &file_path) + .map_err(|e| Error::from_reason(e.to_string()))?; + Ok(count as u32) + }) } #[napi] pub fn resolve_call_edge(&self, edge_id: String, to_symbol_id: String) -> Result<()> { - let conn = self - .conn - .lock() - .map_err(|e| Error::from_reason(e.to_string()))?; - db::resolve_call_edge(&conn, &edge_id, &to_symbol_id) - .map_err(|e| Error::from_reason(e.to_string())) + self.with_conn(|conn| { + db::resolve_call_edge(conn, &edge_id, &to_symbol_id) + .map_err(|e| Error::from_reason(e.to_string())) + }) } // ── Branch Symbol methods ──────────────────────────────────────── #[napi] pub fn add_symbols_to_branch(&self, branch: String, symbol_ids: Vec) -> Result<()> { - let conn = self - .conn - .lock() - .map_err(|e| Error::from_reason(e.to_string()))?; - db::add_symbols_to_branch(&conn, &branch, &symbol_ids) - .map_err(|e| Error::from_reason(e.to_string())) + self.with_conn(|conn| { + db::add_symbols_to_branch(conn, &branch, &symbol_ids) + .map_err(|e| Error::from_reason(e.to_string())) + }) } #[napi] @@ -1119,53 +1053,43 @@ impl Database { branch: String, symbol_ids: Vec, ) -> Result<()> { - let mut conn = self - .conn - .lock() - .map_err(|e| Error::from_reason(e.to_string()))?; - db::add_symbols_to_branch_batch(&mut conn, &branch, &symbol_ids) - .map_err(|e| Error::from_reason(e.to_string())) + self.with_conn_mut(|conn| { + db::add_symbols_to_branch_batch(conn, &branch, &symbol_ids) + .map_err(|e| Error::from_reason(e.to_string())) + }) } #[napi] pub fn get_branch_symbol_ids(&self, branch: String) -> Result> { - let conn = self - .conn - .lock() - .map_err(|e| Error::from_reason(e.to_string()))?; - db::get_branch_symbol_ids(&conn, &branch).map_err(|e| Error::from_reason(e.to_string())) + self.with_conn(|conn| { + db::get_branch_symbol_ids(conn, &branch).map_err(|e| Error::from_reason(e.to_string())) + }) } #[napi] pub fn clear_branch_symbols(&self, branch: String) -> Result { - let conn = self - .conn - .lock() - .map_err(|e| Error::from_reason(e.to_string()))?; - let count = db::clear_branch_symbols(&conn, &branch) - .map_err(|e| Error::from_reason(e.to_string()))?; - Ok(count as u32) + self.with_conn(|conn| { + let count = db::clear_branch_symbols(conn, &branch) + .map_err(|e| Error::from_reason(e.to_string()))?; + Ok(count as u32) + }) } #[napi] pub fn get_referenced_symbol_ids(&self, symbol_ids: Vec) -> Result> { - let conn = self - .conn - .lock() - .map_err(|e| Error::from_reason(e.to_string()))?; - db::get_referenced_symbol_ids(&conn, &symbol_ids) - .map_err(|e| Error::from_reason(e.to_string())) + self.with_conn(|conn| { + db::get_referenced_symbol_ids(conn, &symbol_ids) + .map_err(|e| Error::from_reason(e.to_string())) + }) } #[napi] pub fn delete_branch_symbols_by_symbol_ids(&self, symbol_ids: Vec) -> Result { - let conn = self - .conn - .lock() - .map_err(|e| Error::from_reason(e.to_string()))?; - let count = db::delete_branch_symbols_by_symbol_ids(&conn, &symbol_ids) - .map_err(|e| Error::from_reason(e.to_string()))?; - Ok(count as u32) + self.with_conn(|conn| { + let count = db::delete_branch_symbols_by_symbol_ids(conn, &symbol_ids) + .map_err(|e| Error::from_reason(e.to_string()))?; + Ok(count as u32) + }) } #[napi] @@ -1174,35 +1098,30 @@ impl Database { branch: String, symbol_ids: Vec, ) -> Result { - let conn = self - .conn - .lock() - .map_err(|e| Error::from_reason(e.to_string()))?; - let count = db::delete_branch_symbols_for_branch(&conn, &branch, &symbol_ids) - .map_err(|e| Error::from_reason(e.to_string()))?; - Ok(count as u32) + self.with_conn(|conn| { + let count = db::delete_branch_symbols_for_branch(conn, &branch, &symbol_ids) + .map_err(|e| Error::from_reason(e.to_string()))?; + Ok(count as u32) + }) } // ── GC methods for symbols/edges ───────────────────────────────── #[napi] pub fn gc_orphan_symbols(&self) -> Result { - let conn = self - .conn - .lock() - .map_err(|e| Error::from_reason(e.to_string()))?; - let count = db::gc_orphan_symbols(&conn).map_err(|e| Error::from_reason(e.to_string()))?; - Ok(count as u32) + self.with_conn(|conn| { + let count = + db::gc_orphan_symbols(conn).map_err(|e| Error::from_reason(e.to_string()))?; + Ok(count as u32) + }) } #[napi] pub fn gc_orphan_call_edges(&self) -> Result { - let conn = self - .conn - .lock() - .map_err(|e| Error::from_reason(e.to_string()))?; - let count = - db::gc_orphan_call_edges(&conn).map_err(|e| Error::from_reason(e.to_string()))?; - Ok(count as u32) + self.with_conn(|conn| { + let count = + db::gc_orphan_call_edges(conn).map_err(|e| Error::from_reason(e.to_string()))?; + Ok(count as u32) + }) } } diff --git a/src/native/index.ts b/src/native/index.ts index c69d18c..2fe7a8a 100644 --- a/src/native/index.ts +++ b/src/native/index.ts @@ -683,22 +683,37 @@ export interface DatabaseStats { export class Database { private inner: any; + private closed = false; constructor(dbPath: string) { this.inner = new native.Database(dbPath); } + private throwIfClosed(): void { + if (this.closed) { + throw new Error("Database is closed"); + } + } + close(): void { + if (this.closed) { + return; + } + if (typeof this.inner.close === "function") { this.inner.close(); } + + this.closed = true; } embeddingExists(contentHash: string): boolean { + this.throwIfClosed(); return this.inner.embeddingExists(contentHash); } getEmbedding(contentHash: string): Buffer | null { + this.throwIfClosed(); return this.inner.getEmbedding(contentHash) ?? null; } @@ -708,6 +723,7 @@ export class Database { chunkText: string, model: string ): void { + this.throwIfClosed(); this.inner.upsertEmbedding(contentHash, embedding, chunkText, model); } @@ -719,217 +735,266 @@ export class Database { model: string; }> ): void { + this.throwIfClosed(); if (items.length === 0) return; this.inner.upsertEmbeddingsBatch(items); } getMissingEmbeddings(contentHashes: string[]): string[] { + this.throwIfClosed(); return this.inner.getMissingEmbeddings(contentHashes); } upsertChunk(chunk: ChunkData): void { + this.throwIfClosed(); this.inner.upsertChunk(chunk); } upsertChunksBatch(chunks: ChunkData[]): void { + this.throwIfClosed(); if (chunks.length === 0) return; this.inner.upsertChunksBatch(chunks); } getChunk(chunkId: string): ChunkData | null { + this.throwIfClosed(); return this.inner.getChunk(chunkId) ?? null; } getChunksByFile(filePath: string): ChunkData[] { + this.throwIfClosed(); return this.inner.getChunksByFile(filePath); } getChunksByName(name: string): ChunkData[] { + this.throwIfClosed(); return this.inner.getChunksByName(name); } getChunksByNameCi(name: string): ChunkData[] { + this.throwIfClosed(); return this.inner.getChunksByNameCi(name); } deleteChunksByFile(filePath: string): number { + this.throwIfClosed(); return this.inner.deleteChunksByFile(filePath); } deleteChunksByIds(chunkIds: string[]): number { + this.throwIfClosed(); if (chunkIds.length === 0) return 0; return this.inner.deleteChunksByIds(chunkIds); } addChunksToBranch(branch: string, chunkIds: string[]): void { + this.throwIfClosed(); this.inner.addChunksToBranch(branch, chunkIds); } addChunksToBranchBatch(branch: string, chunkIds: string[]): void { + this.throwIfClosed(); if (chunkIds.length === 0) return; this.inner.addChunksToBranchBatch(branch, chunkIds); } clearBranch(branch: string): number { + this.throwIfClosed(); return this.inner.clearBranch(branch); } deleteBranchChunksByChunkIds(chunkIds: string[]): number { + this.throwIfClosed(); if (chunkIds.length === 0) return 0; return this.inner.deleteBranchChunksByChunkIds(chunkIds); } deleteBranchChunksForBranch(branch: string, chunkIds: string[]): number { + this.throwIfClosed(); if (chunkIds.length === 0) return 0; return this.inner.deleteBranchChunksForBranch(branch, chunkIds); } getBranchChunkIds(branch: string): string[] { + this.throwIfClosed(); return this.inner.getBranchChunkIds(branch); } getBranchDelta(branch: string, baseBranch: string): BranchDelta { + this.throwIfClosed(); return this.inner.getBranchDelta(branch, baseBranch); } getReferencedChunkIds(chunkIds: string[]): string[] { + this.throwIfClosed(); if (chunkIds.length === 0) return []; return this.inner.getReferencedChunkIds(chunkIds); } chunkExistsOnBranch(branch: string, chunkId: string): boolean { + this.throwIfClosed(); return this.inner.chunkExistsOnBranch(branch, chunkId); } getAllBranches(): string[] { + this.throwIfClosed(); return this.inner.getAllBranches(); } getMetadata(key: string): string | null { + this.throwIfClosed(); return this.inner.getMetadata(key) ?? null; } setMetadata(key: string, value: string): void { + this.throwIfClosed(); this.inner.setMetadata(key, value); } deleteMetadata(key: string): boolean { + this.throwIfClosed(); return this.inner.deleteMetadata(key); } clearAllIndexedData(): void { + this.throwIfClosed(); this.inner.clearAllIndexedData(); } clearCallEdgeTargetsForSymbols(symbolIds: string[]): number { + this.throwIfClosed(); if (symbolIds.length === 0) return 0; return this.inner.clearCallEdgeTargetsForSymbols(symbolIds); } gcOrphanEmbeddings(): number { + this.throwIfClosed(); return this.inner.gcOrphanEmbeddings(); } gcOrphanChunks(): number { + this.throwIfClosed(); return this.inner.gcOrphanChunks(); } getStats(): DatabaseStats { + this.throwIfClosed(); return this.inner.getStats(); } // ── Symbol methods ────────────────────────────────────────────── upsertSymbol(symbol: SymbolData): void { + this.throwIfClosed(); this.inner.upsertSymbol(symbol); } upsertSymbolsBatch(symbols: SymbolData[]): void { + this.throwIfClosed(); if (symbols.length === 0) return; this.inner.upsertSymbolsBatch(symbols); } getSymbolsByFile(filePath: string): SymbolData[] { + this.throwIfClosed(); return this.inner.getSymbolsByFile(filePath); } getSymbolByName(name: string, filePath: string): SymbolData | null { + this.throwIfClosed(); return this.inner.getSymbolByName(name, filePath) ?? null; } getSymbolsByName(name: string): SymbolData[] { + this.throwIfClosed(); return this.inner.getSymbolsByName(name); } getSymbolsByNameCi(name: string): SymbolData[] { + this.throwIfClosed(); return this.inner.getSymbolsByNameCi(name); } deleteSymbolsByFile(filePath: string): number { + this.throwIfClosed(); return this.inner.deleteSymbolsByFile(filePath); } // ── Call Edge methods ──────────────────────────────────────────── upsertCallEdge(edge: CallEdgeData): void { + this.throwIfClosed(); this.inner.upsertCallEdge(edge); } upsertCallEdgesBatch(edges: CallEdgeData[]): void { + this.throwIfClosed(); if (edges.length === 0) return; this.inner.upsertCallEdgesBatch(edges); } getCallers(targetName: string, branch: string): CallEdgeData[] { + this.throwIfClosed(); return this.inner.getCallers(targetName, branch); } getCallersWithContext(targetName: string, branch: string): CallEdgeData[] { + this.throwIfClosed(); return this.inner.getCallersWithContext(targetName, branch); } getCallees(symbolId: string, branch: string): CallEdgeData[] { + this.throwIfClosed(); return this.inner.getCallees(symbolId, branch); } deleteCallEdgesByFile(filePath: string): number { + this.throwIfClosed(); return this.inner.deleteCallEdgesByFile(filePath); } resolveCallEdge(edgeId: string, toSymbolId: string): void { + this.throwIfClosed(); this.inner.resolveCallEdge(edgeId, toSymbolId); } // ── Branch Symbol methods ──────────────────────────────────────── addSymbolsToBranch(branch: string, symbolIds: string[]): void { + this.throwIfClosed(); this.inner.addSymbolsToBranch(branch, symbolIds); } addSymbolsToBranchBatch(branch: string, symbolIds: string[]): void { + this.throwIfClosed(); if (symbolIds.length === 0) return; this.inner.addSymbolsToBranchBatch(branch, symbolIds); } getBranchSymbolIds(branch: string): string[] { + this.throwIfClosed(); return this.inner.getBranchSymbolIds(branch); } clearBranchSymbols(branch: string): number { + this.throwIfClosed(); return this.inner.clearBranchSymbols(branch); } getReferencedSymbolIds(symbolIds: string[]): string[] { + this.throwIfClosed(); if (symbolIds.length === 0) return []; return this.inner.getReferencedSymbolIds(symbolIds); } deleteBranchSymbolsBySymbolIds(symbolIds: string[]): number { + this.throwIfClosed(); if (symbolIds.length === 0) return 0; return this.inner.deleteBranchSymbolsBySymbolIds(symbolIds); } deleteBranchSymbolsForBranch(branch: string, symbolIds: string[]): number { + this.throwIfClosed(); if (symbolIds.length === 0) return 0; return this.inner.deleteBranchSymbolsForBranch(branch, symbolIds); } @@ -937,10 +1002,12 @@ export class Database { // ── GC methods for symbols/edges ───────────────────────────────── gcOrphanSymbols(): number { + this.throwIfClosed(); return this.inner.gcOrphanSymbols(); } gcOrphanCallEdges(): number { + this.throwIfClosed(); return this.inner.gcOrphanCallEdges(); } } diff --git a/tests/database.test.ts b/tests/database.test.ts index 00edf71..edab4c2 100644 --- a/tests/database.test.ts +++ b/tests/database.test.ts @@ -18,6 +18,43 @@ describe("Database", () => { fs.rmSync(tempDir, { recursive: true, force: true }); }); + describe("close semantics", () => { + it("should allow repeated close calls", () => { + expect(() => db.close()).not.toThrow(); + expect(() => db.close()).not.toThrow(); + }); + + it("should fail fast after close", () => { + db.close(); + + expect(() => db.getStats()).toThrow("Database is closed"); + expect(() => db.embeddingExists("hash123")).toThrow("Database is closed"); + expect(() => db.setMetadata("key", "value")).toThrow("Database is closed"); + }); + + it("should fail fast for wrapper no-op batch helpers after close", () => { + db.close(); + + expect(() => db.upsertEmbeddingsBatch([])).toThrow("Database is closed"); + expect(() => db.upsertChunksBatch([])).toThrow("Database is closed"); + expect(() => db.addChunksToBranchBatch("main", [])).toThrow("Database is closed"); + expect(() => db.getReferencedChunkIds([])).toThrow("Database is closed"); + expect(() => db.clearCallEdgeTargetsForSymbols([])).toThrow("Database is closed"); + }); + + it("should release the database file when closed", () => { + const dbPath = path.join(tempDir, "test.db"); + + db.close(); + + if (process.platform !== "win32") { + return; + } + + expect(() => fs.rmSync(dbPath, { force: true })).not.toThrow(); + }); + }); + describe("embeddings", () => { it("should check if embedding exists", () => { expect(db.embeddingExists("hash123")).toBe(false);