-
Notifications
You must be signed in to change notification settings - Fork 0
feat(indexing): port BM25 enrichment + index from semble #9
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
2 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,162 @@ | ||
| import { mkdtemp, rm } from 'node:fs/promises' | ||
| import { tmpdir } from 'node:os' | ||
| import path from 'node:path' | ||
| import { afterEach, beforeEach, describe, expect, test } from 'bun:test' | ||
|
|
||
| import { Bm25Index, type Chunk, enrichForBm25, selectorToMask } from './sparse.ts' | ||
|
|
||
| function makeChunk(overrides: Partial<Chunk> & { filePath: string, content?: string }): Chunk { | ||
| return { | ||
| content: overrides.content ?? '', | ||
| filePath: overrides.filePath, | ||
| startLine: overrides.startLine ?? 1, | ||
| endLine: overrides.endLine ?? 1, | ||
| language: overrides.language ?? null, | ||
| } | ||
| } | ||
|
|
||
| describe('enrichForBm25', () => { | ||
| test('appends repeated stem and last 3 dir parts (2-part dir)', () => { | ||
| // Mirrors upstream Python: Path('src/utils/format.ts').parent.parts == ('src', 'utils'), | ||
| // so last-3 is the full ['src', 'utils']. | ||
| const out = enrichForBm25(makeChunk({ filePath: 'src/utils/format.ts', content: 'hello world' })) | ||
| expect(out).toBe('hello world format format src utils') | ||
| }) | ||
|
|
||
| test('trims to the last 3 dir parts (4-part dir)', () => { | ||
| const out = enrichForBm25(makeChunk({ filePath: 'a/b/c/d/foo.py', content: 'x' })) | ||
| expect(out).toBe('x foo foo b c d') | ||
| }) | ||
|
|
||
| test('handles a top-level file with no directory components', () => { | ||
| const out = enrichForBm25(makeChunk({ filePath: 'foo.py', content: 'x' })) | ||
| expect(out).toBe('x foo foo ') | ||
| }) | ||
|
|
||
| test('drops "." pseudo-segments from relative paths', () => { | ||
| const out = enrichForBm25(makeChunk({ filePath: './a/b/foo.ts', content: 'x' })) | ||
| expect(out).toBe('x foo foo a b') | ||
| }) | ||
|
|
||
| test('normalizes backslashes for cross-platform consistency', () => { | ||
| // Repo-relative paths must produce the same enrichment regardless of | ||
| // host OS — Windows hosts may surface back-slashes if a caller forgets | ||
| // to normalize before passing the chunk through. | ||
| const out = enrichForBm25(makeChunk({ filePath: 'src\\utils\\format.ts', content: 'hello world' })) | ||
| expect(out).toBe('hello world format format src utils') | ||
| }) | ||
| }) | ||
|
|
||
| describe('selectorToMask', () => { | ||
| test('builds a 0/1 mask the same length as `size`', () => { | ||
| const mask = selectorToMask(new Uint32Array([0, 2, 5]), 6) | ||
| expect(mask).not.toBeNull() | ||
| expect(Array.from(mask!)).toEqual([1, 0, 1, 0, 0, 1]) | ||
| }) | ||
|
|
||
| test('returns null for a null selector', () => { | ||
| expect(selectorToMask(null, 6)).toBeNull() | ||
| }) | ||
|
|
||
| test('returns null for an undefined selector', () => { | ||
| expect(selectorToMask(undefined, 6)).toBeNull() | ||
| }) | ||
|
|
||
| test('ignores indices outside the mask bounds', () => { | ||
| // Out-of-bounds indices are silently dropped rather than crashing — | ||
| // upstream relies on the selector being well-formed but we want to be | ||
| // defensive in the TS port. | ||
| const mask = selectorToMask(new Uint32Array([0, 10]), 3) | ||
| expect(Array.from(mask!)).toEqual([1, 0, 0]) | ||
| }) | ||
| }) | ||
|
|
||
| describe('Bm25Index.build / getScores', () => { | ||
| test('ranks documents containing the query term higher', () => { | ||
| const index = Bm25Index.build([ | ||
| ['hello', 'world'], | ||
| ['hello'], | ||
| ['world'], | ||
| ]) | ||
| const scores = index.getScores(['hello']) | ||
| expect(scores).toHaveLength(3) | ||
| expect(scores[0]).toBeGreaterThan(0) | ||
| expect(scores[1]).toBeGreaterThan(0) | ||
| expect(scores[2]).toBe(0) | ||
| }) | ||
|
|
||
| test('returns zero scores for unknown query tokens', () => { | ||
| const index = Bm25Index.build([['hello'], ['world']]) | ||
| const scores = index.getScores(['unknown']) | ||
| expect(Array.from(scores)).toEqual([0, 0]) | ||
| }) | ||
|
|
||
| test('returns an empty-array-equivalent for an empty corpus', () => { | ||
| const index = Bm25Index.build([]) | ||
| const scores = index.getScores(['anything']) | ||
| expect(scores).toHaveLength(0) | ||
| }) | ||
|
|
||
| test('returns zero scores when query tokens are empty', () => { | ||
| const index = Bm25Index.build([['hello'], ['world']]) | ||
| const scores = index.getScores([]) | ||
| expect(Array.from(scores)).toEqual([0, 0]) | ||
| }) | ||
|
|
||
| test('weightMask zeros out masked-out documents', () => { | ||
| const index = Bm25Index.build([ | ||
| ['hello', 'world'], | ||
| ['hello'], | ||
| ['world'], | ||
| ]) | ||
| // Mask in docs 0 and 2 only. | ||
| const mask = new Uint8Array([1, 0, 1]) | ||
| const scores = index.getScores(['hello'], mask) | ||
| expect(scores[0]).toBeGreaterThan(0) | ||
| expect(scores[1]).toBe(0) | ||
| expect(scores[2]).toBe(0) // doc 2 doesn't contain 'hello' | ||
| }) | ||
|
|
||
| test('weightMask only suppresses scores; matched-in docs are unchanged', () => { | ||
| const index = Bm25Index.build([ | ||
| ['hello', 'world'], | ||
| ['hello'], | ||
| ['world'], | ||
| ]) | ||
| const baseline = index.getScores(['hello']) | ||
| const masked = index.getScores(['hello'], new Uint8Array([1, 1, 1])) | ||
| expect(Array.from(masked)).toEqual(Array.from(baseline)) | ||
| }) | ||
|
|
||
| test('repeated query tokens do not compound scores', () => { | ||
| const index = Bm25Index.build([['hello']]) | ||
| const single = index.getScores(['hello']) | ||
| const repeated = index.getScores(['hello', 'hello', 'hello']) | ||
| expect(Array.from(repeated)).toEqual(Array.from(single)) | ||
| }) | ||
| }) | ||
|
|
||
| describe('Bm25Index.save / load', () => { | ||
| let tmp: string | ||
|
|
||
| beforeEach(async () => { | ||
| tmp = await mkdtemp(path.join(tmpdir(), 'csp-bm25-')) | ||
| }) | ||
|
|
||
| afterEach(async () => { | ||
| await rm(tmp, { recursive: true, force: true }) | ||
| }) | ||
|
|
||
| test('round-trips an index and preserves scores', async () => { | ||
| const index = Bm25Index.build([ | ||
| ['alpha', 'beta'], | ||
| ['alpha'], | ||
| ['beta', 'gamma'], | ||
| ]) | ||
| await index.save(tmp) | ||
| const loaded = await Bm25Index.load(tmp) | ||
| const original = index.getScores(['alpha']) | ||
| const restored = loaded.getScores(['alpha']) | ||
| expect(Array.from(restored)).toEqual(Array.from(original)) | ||
| }) | ||
| }) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,219 @@ | ||
| // Port of src/semble/index/sparse.py | ||
| // | ||
| // Implements the two helpers from the upstream module plus a minimal BM25 | ||
| // index (Bm25Index) that stands in for Python's `bm25s` library. | ||
| // | ||
| // BM25 backend choice (see PR body for full discussion): | ||
| // Option B (inline minimal BM25+ with k1=1.5, b=0.75) was chosen over a | ||
| // third-party npm such as wink-bm25-text-search because: | ||
| // - The dependency tree stays self-contained while the project is still | ||
| // a scaffold (no other indexing deps are pinned yet). | ||
| // - The required surface is tiny (build / getScores / save / load) and | ||
| // getScores must respect a weight_mask that maps cleanly to BM25's | ||
| // per-document scoring loop. | ||
| // - Replacing this backend later is a localized change because all | ||
| // callers go through the Bm25Index class. | ||
|
|
||
| import { mkdir, readFile, writeFile } from 'node:fs/promises' | ||
| import path from 'node:path' | ||
|
|
||
| // Stopgap structural type until ./types.ts lands from Unit 1. | ||
| // Mirrors semble.types.Chunk with camelCase field names per | ||
| // @pleaseai/csp public-API conventions. | ||
| export interface Chunk { | ||
| content: string | ||
| filePath: string | ||
| startLine: number | ||
| endLine: number | ||
| language?: string | null | ||
| } | ||
|
|
||
| /** | ||
| * Append file path components to BM25 content to boost path-based queries. | ||
| * | ||
| * Assumes `chunk.filePath` is already repo-relative (set during indexing) so | ||
| * machine-specific directory components are never indexed. The stem is | ||
| * repeated twice to up-weight file-path matches in BM25. | ||
| * | ||
| * Repo-relative paths are normalized to POSIX (forward slashes) before | ||
| * parsing so Windows-host indexes produce the same enriched text as POSIX | ||
| * hosts. Without this, `path.parse` on Windows would split on `\\` while | ||
| * the indexer stores forward-slash paths, leading to inconsistent BM25 | ||
| * tokenization across platforms. | ||
| */ | ||
| export function enrichForBm25(chunk: Chunk): string { | ||
| const normalized = chunk.filePath.replace(/\\/g, '/') | ||
| const parsed = path.posix.parse(normalized) | ||
| const stem = parsed.name | ||
| const dirParts = parsed.dir | ||
| .split('/') | ||
| .filter(part => part !== '' && part !== '.') | ||
| const dirText = dirParts.slice(-3).join(' ') | ||
| return `${chunk.content} ${stem} ${stem} ${dirText}` | ||
| } | ||
|
|
||
| /** | ||
| * Convert a selector array of indices into a boolean mask of length `size`. | ||
| * | ||
| * Returns `null` when `selector` is null/undefined so callers can skip mask | ||
| * application entirely (matching the upstream semantics). | ||
| */ | ||
| export function selectorToMask( | ||
| selector: Uint32Array | null | undefined, | ||
| size: number, | ||
| ): Uint8Array | null { | ||
| if (selector === null || selector === undefined) | ||
| return null | ||
| const mask = new Uint8Array(size) | ||
| for (const idx of selector) { | ||
| if (idx < size) | ||
| mask[idx] = 1 | ||
| } | ||
| return mask | ||
| } | ||
|
|
||
| // --------------------------------------------------------------------------- | ||
| // Minimal BM25 index | ||
| // --------------------------------------------------------------------------- | ||
|
|
||
| // Standard Okapi BM25 hyperparameters used by bm25s' default Lucene scorer. | ||
| const K1 = 1.5 | ||
| const B = 0.75 | ||
|
|
||
| interface Bm25State { | ||
| // Number of documents indexed. | ||
| numDocs: number | ||
| // Document length (token count) per document, in doc order. | ||
| docLengths: Float32Array | ||
| // Average document length across the corpus. | ||
| avgDocLength: number | ||
| // Term -> array of [docId, termFreq] entries (postings list). | ||
| postings: Map<string, Array<[number, number]>> | ||
| // Term -> document frequency (count of docs containing the term). | ||
| docFreq: Map<string, number> | ||
| } | ||
|
|
||
| /** | ||
| * Minimal BM25 index supporting build / getScores / save / load. | ||
| * | ||
| * Documents are passed pre-tokenized (callers use `tokenize(enrichForBm25(...))`). | ||
| * `getScores` returns a Float32Array of per-document scores in doc order, | ||
| * matching the bm25s.BM25.get_scores contract used by upstream. | ||
| */ | ||
| export class Bm25Index { | ||
| // Exposed only for save() — kept private to consumers. | ||
| readonly #state: Bm25State | ||
|
|
||
| private constructor(state: Bm25State) { | ||
| this.#state = state | ||
| } | ||
|
|
||
| /** Build an index from an array of pre-tokenized documents. */ | ||
| static build(documents: string[][]): Bm25Index { | ||
| const numDocs = documents.length | ||
| const docLengths = new Float32Array(numDocs) | ||
| const postings = new Map<string, Array<[number, number]>>() | ||
| const docFreq = new Map<string, number>() | ||
|
|
||
| let totalLen = 0 | ||
| for (let docId = 0; docId < numDocs; docId++) { | ||
| const tokens = documents[docId] ?? [] | ||
| docLengths[docId] = tokens.length | ||
| totalLen += tokens.length | ||
|
|
||
| // Term frequencies for this document. | ||
| const tf = new Map<string, number>() | ||
| for (const token of tokens) | ||
| tf.set(token, (tf.get(token) ?? 0) + 1) | ||
|
|
||
| for (const [term, freq] of tf) { | ||
| let list = postings.get(term) | ||
| if (list === undefined) { | ||
| list = [] | ||
| postings.set(term, list) | ||
| } | ||
| list.push([docId, freq]) | ||
| docFreq.set(term, (docFreq.get(term) ?? 0) + 1) | ||
| } | ||
| } | ||
|
|
||
| const avgDocLength = numDocs > 0 ? totalLen / numDocs : 0 | ||
|
|
||
| return new Bm25Index({ numDocs, docLengths, avgDocLength, postings, docFreq }) | ||
| } | ||
|
|
||
| /** | ||
| * Compute BM25 scores for the given query tokens. | ||
| * | ||
| * Returns a Float32Array of length numDocs, in document order. When | ||
| * `weightMask` is provided, documents with mask[i] === 0 receive a score | ||
| * of 0 (matching bm25s.BM25.get_scores(..., weight_mask=mask) semantics). | ||
| */ | ||
| getScores(queryTokens: string[], weightMask?: Uint8Array | null): Float32Array { | ||
| const { numDocs, docLengths, avgDocLength, postings, docFreq } = this.#state | ||
| const scores = new Float32Array(numDocs) | ||
| if (queryTokens.length === 0 || numDocs === 0) | ||
| return scores | ||
|
|
||
| // De-duplicate query tokens — repeated terms shouldn't compound BM25 scores. | ||
| const uniqueTerms = new Set(queryTokens) | ||
|
|
||
| for (const term of uniqueTerms) { | ||
| const list = postings.get(term) | ||
| if (list === undefined) | ||
| continue | ||
| const df = docFreq.get(term) ?? 0 | ||
| // Lucene/Robertson IDF: log(1 + (N - df + 0.5) / (df + 0.5)). | ||
| const idf = Math.log(1 + (numDocs - df + 0.5) / (df + 0.5)) | ||
|
|
||
| for (const [docId, freq] of list) { | ||
| // Skip masked-out documents inside the posting-list iteration so we | ||
| // avoid the work entirely; Float32Array entries default to 0 so the | ||
| // final scores match the post-loop zeroing approach. | ||
| if (weightMask && !weightMask[docId]) | ||
| continue | ||
| const dl = docLengths[docId] ?? 0 | ||
| const denom = freq + K1 * (1 - B + (B * dl) / (avgDocLength || 1)) | ||
| const contrib = (idf * (freq * (K1 + 1))) / (denom || 1) | ||
| scores[docId] = (scores[docId] ?? 0) + contrib | ||
| } | ||
| } | ||
|
|
||
| return scores | ||
| } | ||
|
amondnet marked this conversation as resolved.
|
||
|
|
||
| /** Persist the index to `dir`. Creates the directory if it doesn't exist. */ | ||
| async save(dir: string): Promise<void> { | ||
| await mkdir(dir, { recursive: true }) | ||
| const { numDocs, docLengths, avgDocLength, postings, docFreq } = this.#state | ||
| const serialized = { | ||
| version: 1, | ||
| numDocs, | ||
| avgDocLength, | ||
| docLengths: Array.from(docLengths), | ||
| postings: Array.from(postings.entries()), | ||
| docFreq: Array.from(docFreq.entries()), | ||
| } | ||
| await writeFile(path.join(dir, 'bm25.json'), JSON.stringify(serialized)) | ||
| } | ||
|
|
||
| /** Load an index previously persisted with `save`. */ | ||
| static async load(dir: string): Promise<Bm25Index> { | ||
| const raw = await readFile(path.join(dir, 'bm25.json'), 'utf8') | ||
| const parsed = JSON.parse(raw) as { | ||
| version: number | ||
| numDocs: number | ||
| avgDocLength: number | ||
| docLengths: number[] | ||
| postings: Array<[string, Array<[number, number]>]> | ||
| docFreq: Array<[string, number]> | ||
| } | ||
| return new Bm25Index({ | ||
| numDocs: parsed.numDocs, | ||
| docLengths: Float32Array.from(parsed.docLengths), | ||
| avgDocLength: parsed.avgDocLength, | ||
| postings: new Map(parsed.postings), | ||
| docFreq: new Map(parsed.docFreq), | ||
| }) | ||
| } | ||
| } | ||
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.