Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
200 changes: 200 additions & 0 deletions src/ranking/penalties.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,200 @@
// Tests for src/ranking/penalties.ts — parity checked against the Python source.
import { describe, expect, it } from 'bun:test'

import {
_filePathPenalty,
FILE_SATURATION_DECAY,
MILD_PENALTY,
MODERATE_PENALTY,
rerankTopK,
STRONG_PENALTY,
} from './penalties.ts'

/** Minimal chunk shape used by the fixtures in this test file. */
type Chunk = {
  content: string
  filePath: string
  startLine: number
  endLine: number
  language?: string
}

/**
 * Builds a throwaway chunk for `filePath`.
 *
 * `idx` (default 0) seeds the content label and the one-line span
 * `[idx, idx + 1]`, so chunks created with different indices are distinguishable.
 */
function makeChunk(filePath: string, idx = 0): Chunk {
  const label = `chunk ${idx}`
  const lastLine = idx + 1
  const fixture: Chunk = {
    content: label,
    filePath,
    startLine: idx,
    endLine: lastLine,
  }
  return fixture
}

// Unit tests for _filePathPenalty: each case pins the multiplicative factor
// returned for one category of path (or one combination of categories).
describe('_filePathPenalty', () => {
  it('penalises JS/TS test files with STRONG_PENALTY', () => {
    expect(_filePathPenalty('src/foo.test.ts')).toBe(STRONG_PENALTY)
  })

  it('penalises .spec.tsx files with STRONG_PENALTY', () => {
    expect(_filePathPenalty('src/foo.spec.tsx')).toBe(STRONG_PENALTY)
  })

  it('penalises __init__.py with MODERATE_PENALTY (re-export barrel)', () => {
    expect(_filePathPenalty('src/__init__.py')).toBe(MODERATE_PENALTY)
  })

  it('penalises .d.ts type stubs with MILD_PENALTY', () => {
    expect(_filePathPenalty('src/foo.d.ts')).toBe(MILD_PENALTY)
  })

  it('penalises files under tests/ — TEST_DIR + TEST_FILE share one STRONG branch', () => {
    // Python parity: only one STRONG_PENALTY multiplication regardless of how
    // many of {TEST_FILE_RE, TEST_DIR_RE} match (they are OR'd in one branch).
    expect(_filePathPenalty('tests/test_foo.py')).toBeCloseTo(STRONG_PENALTY, 10)
  })

  it('returns 1.0 for ordinary source files', () => {
    expect(_filePathPenalty('src/foo.ts')).toBe(1.0)
  })

  it('compounds STRONG (examples/) and STRONG (.test.ts) penalties', () => {
    // Python: examples/foo.test.ts -> 0.09
    expect(_filePathPenalty('examples/foo.test.ts')).toBeCloseTo(STRONG_PENALTY * STRONG_PENALTY, 10)
  })

  // Title corrected: this case shows that NO compounding happens, because the
  // re-export check compares the whole basename and "__init__.d.ts" is not in
  // REEXPORT_FILENAMES.
  it('applies only MILD (.d.ts) to __init__.d.ts — basename is not a re-export filename', () => {
    // Python: src/__init__.d.ts -> 0.7 (only .d.ts matches; basename is __init__.d.ts)
    expect(_filePathPenalty('src/__init__.d.ts')).toBe(MILD_PENALTY)
  })

  it('compounds STRONG (tests/) and MODERATE (__init__.py) penalties', () => {
    // The directory matches TEST_DIR_RE and the basename is a re-export barrel,
    // so both factors are multiplied together.
    expect(_filePathPenalty('tests/__init__.py')).toBeCloseTo(STRONG_PENALTY * MODERATE_PENALTY, 10)
  })

  it('penalises compat dirs with STRONG_PENALTY', () => {
    expect(_filePathPenalty('compat/foo.ts')).toBe(STRONG_PENALTY)
  })

  it('penalises examples dirs with STRONG_PENALTY', () => {
    expect(_filePathPenalty('examples/foo.ts')).toBe(STRONG_PENALTY)
  })

  it('normalises backslashes to forward slashes before matching', () => {
    expect(_filePathPenalty('src\\foo.test.ts')).toBe(STRONG_PENALTY)
  })

  it('handles bare __init__.py basename without path', () => {
    expect(_filePathPenalty('__init__.py')).toBe(MODERATE_PENALTY)
  })

  it('penalises Go _test.go files', () => {
    expect(_filePathPenalty('pkg/foo_test.go')).toBe(STRONG_PENALTY)
  })

  it('penalises Java FooTests.java files', () => {
    expect(_filePathPenalty('src/FooTests.java')).toBe(STRONG_PENALTY)
  })

  it('penalises legacy dirs with STRONG_PENALTY', () => {
    expect(_filePathPenalty('legacy/foo.ts')).toBe(STRONG_PENALTY)
  })
})

// Unit tests for rerankTopK: empty/degenerate inputs, per-file saturation
// decay, truncation to topK, and the optional path-penalty pass.
describe('rerankTopK', () => {
it('returns an empty list for empty input', () => {
expect(rerankTopK(new Map(), 5)).toEqual([])
})

it('returns an empty list for non-positive topK', () => {
const a = makeChunk('a.ts', 0)
const scores = new Map<Chunk, number>([[a, 1.0]])
expect(rerankTopK(scores, 0)).toEqual([])
expect(rerankTopK(scores, -1)).toEqual([])
expect(rerankTopK(scores, -5)).toEqual([])
})

it('applies saturation decay to chunks from the same file', () => {
// 4 chunks from the same file, all initial score 1.0, no path penalty.
const a = makeChunk('src/foo.ts', 0)
const b = makeChunk('src/foo.ts', 1)
const c = makeChunk('src/foo.ts', 2)
const d = makeChunk('src/foo.ts', 3)
const scores = new Map<Chunk, number>([
[a, 1.0],
[b, 1.0],
[c, 1.0],
[d, 1.0],
])
const result = rerankTopK(scores, 4, { penalisePaths: false })
expect(result).toHaveLength(4)
// The result is sorted by final (decayed) score, highest first. Only the
// scores are asserted here, not which chunk received which score.
const finalScores = result.map(([, s]) => s)
// First chunk picked: 1.0 (no decay)
// Second: 1.0 * 0.5 = 0.5
// Third: 1.0 * 0.25 = 0.25
// Fourth: 1.0 * 0.125 = 0.125
expect(finalScores[0]).toBeCloseTo(1.0, 10)
expect(finalScores[1]).toBeCloseTo(FILE_SATURATION_DECAY, 10)
expect(finalScores[2]).toBeCloseTo(FILE_SATURATION_DECAY ** 2, 10)
expect(finalScores[3]).toBeCloseTo(FILE_SATURATION_DECAY ** 3, 10)
})

it('truncates to topK after sorting', () => {
// Three distinct files, so no saturation decay; only the two best survive.
const a = makeChunk('a.ts', 0)
const b = makeChunk('b.ts', 1)
const c = makeChunk('c.ts', 2)
const scores = new Map<Chunk, number>([
[a, 0.5],
[b, 0.9],
[c, 0.1],
])
const result = rerankTopK(scores, 2, { penalisePaths: false })
expect(result).toHaveLength(2)
expect(result[0]![0]).toBe(b)
expect(result[1]![0]).toBe(a)
})

it('applies path penalties before sorting when enabled', () => {
// a is a test file (penalty 0.3), b is normal. a wins pre-penalty
// (0.9 > 0.5), b wins post-penalty (0.5 > 0.9 * 0.3 = 0.27).
const a = makeChunk('src/foo.test.ts', 0)
const b = makeChunk('src/foo.ts', 1)
const scores = new Map<Chunk, number>([
[a, 0.9],
[b, 0.5],
])
// No options passed: penalisePaths defaults to true.
const result = rerankTopK(scores, 2)
expect(result[0]![0]).toBe(b)
expect(result[1]![0]).toBe(a)
expect(result[0]![1]).toBeCloseTo(0.5, 10)
expect(result[1]![1]).toBeCloseTo(0.9 * STRONG_PENALTY, 10)
})

it('does not apply path penalties when penalisePaths is false', () => {
// Same fixture as the previous test; with penalties off the raw scores
// decide the order and are returned unchanged.
const a = makeChunk('src/foo.test.ts', 0)
const b = makeChunk('src/foo.ts', 1)
const scores = new Map<Chunk, number>([
[a, 0.9],
[b, 0.5],
])
const result = rerankTopK(scores, 2, { penalisePaths: false })
expect(result[0]![0]).toBe(a)
expect(result[0]![1]).toBeCloseTo(0.9, 10)
expect(result[1]![0]).toBe(b)
expect(result[1]![1]).toBeCloseTo(0.5, 10)
})

it('mixes saturation decay across multiple files', () => {
// Two files, two chunks each. All score 1.0. topK = 4.
const a1 = makeChunk('a.ts', 0)
const a2 = makeChunk('a.ts', 1)
const b1 = makeChunk('b.ts', 2)
const b2 = makeChunk('b.ts', 3)
const scores = new Map<Chunk, number>([
[a1, 1.0],
[a2, 1.0],
[b1, 1.0],
[b2, 1.0],
])
const result = rerankTopK(scores, 4, { penalisePaths: false })
expect(result).toHaveLength(4)
// First two picked at 1.0 (first of each file), next two at 0.5.
// Decay is tracked per file path, so each file's first chunk is undecayed.
const top = result.map(([, s]) => s)
expect(top[0]).toBeCloseTo(1.0, 10)
expect(top[1]).toBeCloseTo(1.0, 10)
expect(top[2]).toBeCloseTo(FILE_SATURATION_DECAY, 10)
expect(top[3]).toBeCloseTo(FILE_SATURATION_DECAY, 10)
})
})
187 changes: 187 additions & 0 deletions src/ranking/penalties.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,187 @@
// Port of src/semble/ranking/penalties.py
// Inlined Chunk type until src/types.ts lands (Unit 1).
// Shape of a ranked code chunk. Within this module only `filePath` is read:
// it drives the path penalties and the per-file saturation counter.
type Chunk = {
content: string
// Path of the file the chunk came from; may contain backslashes, which
// _filePathPenalty normalises before pattern matching.
filePath: string
startLine: number
endLine: number
// Optional; not consulted by the ranking code in this file.
language?: string
}

// Recognises test files by basename across common languages.
// The alternatives are listed one per line (grouped by language) and joined
// with `|` into a single regex anchored to the final path component:
// `(?:^|/)` pins the start of the basename, `$` pins the end of the path.
export const TEST_FILE_RE = new RegExp(
  `(?:^|/)(?:${[
    // Python
    'test_[^/]*\\.py', // test_foo.py
    '[^/]*_test\\.py', // foo_test.py
    // Go
    '[^/]*_test\\.go', // foo_test.go
    // Java
    '[^/]*Tests?\\.java', // FooTest.java / FooTests.java
    // PHP
    '[^/]*Test\\.php', // FooTest.php
    // Ruby
    '[^/]*_spec\\.rb', // foo_spec.rb
    '[^/]*_test\\.rb', // foo_test.rb
    // JavaScript / TypeScript
    '[^/]*\\.test\\.[jt]sx?', // foo.test.js/ts/jsx/tsx
    '[^/]*\\.spec\\.[jt]sx?', // foo.spec.js/ts/jsx/tsx
    // Kotlin
    '[^/]*Tests?\\.kt', // FooTest.kt / FooTests.kt
    '[^/]*Spec\\.kt', // FooSpec.kt (Kotest)
    // Swift
    '[^/]*Tests?\\.swift', // FooTests.swift (XCTest)
    '[^/]*Spec\\.swift', // FooSpec.swift (Quick)
    // C#
    '[^/]*Tests?\\.cs', // FooTest.cs / FooTests.cs
    // C / C++
    'test_[^/]*\\.cpp', // test_foo.cpp (Google Test)
    '[^/]*_test\\.cpp', // foo_test.cpp (Google Test)
    'test_[^/]*\\.c', // test_foo.c
    '[^/]*_test\\.c', // foo_test.c
    // Scala
    '[^/]*Spec\\.scala', // FooSpec.scala (ScalaTest)
    '[^/]*Suite\\.scala', // FooSuite.scala (MUnit)
    '[^/]*Test\\.scala', // FooTest.scala
    // Dart
    '[^/]*_test\\.dart', // foo_test.dart
    'test_[^/]*\\.dart', // test_foo.dart
    // Lua
    '[^/]*_spec\\.lua', // foo_spec.lua (busted)
    '[^/]*_test\\.lua', // foo_test.lua
    'test_[^/]*\\.lua', // test_foo.lua (luaunit)
    // Shared helper patterns (all languages)
    'test_helpers?[^/]*\\.\\w+', // test_helpers.go, test_helper.rb, etc.
  ].join('|')})$`,
)

// Test/spec directories: a path component named test, tests, __tests__, spec
// or testing (matched as a whole component, at any depth).
export const TEST_DIR_RE = /(?:^|\/)(?:tests?|__tests__|spec|testing)(?:\/|$)/

// Compat/legacy path components: compat, _compat or legacy.
export const COMPAT_DIR_RE = /(?:^|\/)(?:compat|_compat|legacy)(?:\/|$)/

// Examples/docs path components: example(s), _example(s), doc_src or docs_src.
export const EXAMPLES_DIR_RE = /(?:^|\/)(?:_?examples?|docs?_src)(?:\/|$)/

// TypeScript declaration files (.d.ts stubs), matched by path suffix.
export const TYPE_DEFS_RE = /\.d\.ts$/

// Multiplicative score factors (lower = stronger demotion). _filePathPenalty
// multiplies together every factor whose pattern matches a path.
export const STRONG_PENALTY = 0.3 // test files, compat shims, example/doc code
export const MODERATE_PENALTY = 0.5 // re-export / metadata files
export const MILD_PENALTY = 0.7 // .d.ts declaration stubs (still carry useful type info)

// Filenames that are re-export barrels or package-level metadata.
// Compared against the exact basename of the path.
export const REEXPORT_FILENAMES = new Set(['__init__.py', 'package-info.java'])

// Maximum chunks from the same file before a saturation penalty is applied.
// With the value 1, a file's first selected chunk keeps its score and every
// later chunk from that file is decayed.
export const FILE_SATURATION_THRESHOLD = 1

// Multiplicative penalty per extra chunk from the same file beyond the threshold.
// The n-th extra chunk is scaled by FILE_SATURATION_DECAY ** n.
export const FILE_SATURATION_DECAY = 0.5

/**
 * Select the top-k chunks, applying optional file-path penalties and
 * per-file saturation decay.
 *
 * Steps:
 * 1. When `options.penalisePaths` is true (the default), each score is
 *    multiplied by `_filePathPenalty(chunk.filePath)`.
 * 2. Candidates are sorted once by that penalised score, highest first.
 * 3. Candidates are visited in that order. Once a file has already contributed
 *    `FILE_SATURATION_THRESHOLD` chunks, each further chunk from it is scaled
 *    by `FILE_SATURATION_DECAY ** excess`.
 * 4. The visited candidates are re-sorted by their final score and cut to `topK`.
 *
 * Early exit: decay only ever lowers a score and candidates arrive in
 * descending order, so once `topK` items are held and the next penalised
 * score cannot exceed the lowest score already held, no later candidate can
 * enter the result.
 *
 * @param scores - Raw score per chunk. Iteration order breaks ties, because
 *   the sort is stable.
 * @param topK - Maximum number of results; a value `<= 0` yields `[]`.
 * @param options - `penalisePaths` toggles step 1 (default `true`).
 * @returns `[chunk, finalScore]` pairs sorted by final score, highest first.
 */
export function rerankTopK(
  scores: Map<Chunk, number>,
  topK: number,
  options: { penalisePaths?: boolean } = {},
): Array<[Chunk, number]> {
  const penalisePaths = options.penalisePaths ?? true

  if (scores.size === 0 || topK <= 0) {
    return []
  }

  // Apply file-path penalties, computing each distinct path's factor once.
  const penaltyByPath = new Map<string, number>()
  const candidates: Array<[Chunk, number]> = []
  for (const [chunk, score] of scores) {
    if (!penalisePaths) {
      candidates.push([chunk, score])
      continue
    }
    let factor = penaltyByPath.get(chunk.filePath)
    if (factor === undefined) {
      factor = _filePathPenalty(chunk.filePath)
      penaltyByPath.set(chunk.filePath, factor)
    }
    candidates.push([chunk, score * factor])
  }

  // Single stable sort by penalised score, highest first.
  candidates.sort((a, b) => b[1] - a[1])

  const pickedPerFile = new Map<string, number>()
  const selected: Array<[Chunk, number]> = []
  // Running minimum of every final score pushed so far. Maintaining it
  // incrementally yields the same value as rescanning `selected` after each
  // push, without the repeated O(selected.length) pass.
  let lowestSelected = Number.POSITIVE_INFINITY

  for (const [chunk, penScore] of candidates) {
    if (selected.length >= topK && penScore <= lowestSelected) {
      break
    }

    const alreadyPicked = pickedPerFile.get(chunk.filePath) ?? 0
    let effScore = penScore
    if (alreadyPicked >= FILE_SATURATION_THRESHOLD) {
      const excess = alreadyPicked - FILE_SATURATION_THRESHOLD + 1
      effScore *= FILE_SATURATION_DECAY ** excess
    }

    selected.push([chunk, effScore])
    pickedPerFile.set(chunk.filePath, alreadyPicked + 1)

    if (effScore < lowestSelected) {
      lowestSelected = effScore
    }
  }

  // Decay can reorder items relative to the penalised order, so sort again.
  selected.sort((a, b) => b[1] - a[1])
  return selected.slice(0, topK)
}
Comment thread
amondnet marked this conversation as resolved.

/**
 * Compute the combined multiplicative penalty for a file path.
 *
 * Every rule that applies contributes its factor to the product; a path that
 * triggers no rule yields 1.0. Test-file and test-directory matches share a
 * single rule, so a path matching both is penalised only once for being a test.
 */
export function _filePathPenalty(filePath: string): number {
  // Pattern matching runs on a forward-slash version of the path.
  const posixPath = filePath.replace(/\\/g, '/')

  // Match Python's Path(file_path).name (POSIX semantics): only forward-slash
  // is a separator, so the basename is deliberately taken from the RAW path and
  // any backslashes stay part of the filename.
  const lastSlash = filePath.lastIndexOf('/')
  const fileName = filePath.slice(lastSlash + 1)

  // [rule applies?, factor] — listed in the order the factors are multiplied.
  const rules: Array<[boolean, number]> = [
    [TEST_FILE_RE.test(posixPath) || TEST_DIR_RE.test(posixPath), STRONG_PENALTY],
    [REEXPORT_FILENAMES.has(fileName), MODERATE_PENALTY],
    [COMPAT_DIR_RE.test(posixPath), STRONG_PENALTY],
    [EXAMPLES_DIR_RE.test(posixPath), STRONG_PENALTY],
    [TYPE_DEFS_RE.test(posixPath), MILD_PENALTY],
  ]

  return rules.reduce(
    (product, [applies, factor]) => (applies ? product * factor : product),
    1.0,
  )
}