Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -563,19 +563,33 @@ export type {
} from './error-count-extractor'

export {
runReferenceReplay,
decideReferenceReplayRunPromotion,
inMemoryReferenceReplayStore,
jsonlReferenceReplayStore,
scoreReferenceReplay,
compareReferenceReplay,
decideReferenceReplayPromotion,
defaultReferenceReplayMatcher,
} from './reference-replay'
export type {
ReferenceReplayAggregate,
ReferenceReplayAdapter,
ReferenceReplayAdapterFn,
ReferenceReplayAdapterLike,
ReferenceReplayCase,
ReferenceReplayCaseRun,
ReferenceReplayCandidate,
ReferenceReplayExecutionScenario,
ReferenceReplayItem,
ReferenceReplayMatch,
ReferenceReplayMatcher,
ReferenceReplayPromotionDecision,
ReferenceReplayPromotionPolicy,
ReferenceReplayRun,
ReferenceReplayRunContext,
ReferenceReplayRunOptions,
ReferenceReplayRunStore,
ReferenceReplayScenario,
ReferenceReplayScenarioScore,
ReferenceReplayScore,
Expand Down
266 changes: 257 additions & 9 deletions src/reference-replay.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@
* across train/dev/test/holdout splits.
*/

import { appendFileSync, existsSync, mkdirSync, readFileSync } from 'node:fs'
import { dirname } from 'node:path'

export type ReferenceReplaySplit = 'train' | 'dev' | 'test' | 'holdout'

export interface ReferenceReplayItem {
Expand Down Expand Up @@ -35,6 +38,42 @@ export interface ReferenceReplayScenario {
metadata?: Record<string, unknown>
}

/**
 * One replayable case handed to `runReferenceReplay`: the input given to the
 * adapter plus the reference items its candidates are scored against.
 */
export interface ReferenceReplayCase<Input = unknown> {
/** Case identifier; reused as the execution/scoring scenario id and as `caseId` on the case run. */
id: string
/** Split the case belongs to; `runReferenceReplay` treats a missing split as `'train'`. */
split?: ReferenceReplaySplit
/** Payload forwarded unchanged to the adapter. */
input: Input
/** Expected items; used only for scoring and not included in the scenario passed to the adapter. */
references: ReferenceReplayItem[]
/** Optional free-form data copied onto the execution scenario, the scoring scenario and the case run. */
metadata?: Record<string, unknown>
}

/**
 * The view of a case that an adapter receives. It carries the case id, the
 * resolved split, the input and optional metadata, but no reference items.
 */
export interface ReferenceReplayExecutionScenario<Input = unknown> {
/** Id of the case being executed. */
id: string
/** Resolved split (always set here, unlike the optional `split` on the case). */
split: ReferenceReplaySplit
/** Case input, passed through unchanged. */
input: Input
/** Case metadata; the key is omitted entirely when the case has none. */
metadata?: Record<string, unknown>
}

/** Per-call context passed to the adapter alongside the execution scenario. */
export interface ReferenceReplayRunContext {
/** Id of the run this adapter call belongs to. */
runId: string
/** Position of the case in the `cases` array given to `runReferenceReplay` (skipped cases still consume an index). */
caseIndex: number
/** The run's abort signal, forwarded from the run options so the adapter can observe cancellation. */
abortSignal?: AbortSignal
}

/**
 * Object form of an adapter: something that executes one scenario and
 * resolves with the candidates to score against the case references.
 */
export interface ReferenceReplayAdapter<Input = unknown> {
/** Executes `scenario` and resolves with the candidates it produced. */
run(
scenario: ReferenceReplayExecutionScenario<Input>,
context: ReferenceReplayRunContext,
): Promise<ReferenceReplayCandidate[]>
}

/** Function form of an adapter, with the same parameters and result as `ReferenceReplayAdapter.run`. */
export type ReferenceReplayAdapterFn<Input = unknown> = (
scenario: ReferenceReplayExecutionScenario<Input>,
context: ReferenceReplayRunContext,
) => Promise<ReferenceReplayCandidate[]>

/**
 * Either adapter form. `runAdapter` tells them apart with a
 * `typeof adapter === 'function'` check and otherwise calls `.run(...)`.
 */
export type ReferenceReplayAdapterLike<Input = unknown> =
ReferenceReplayAdapter<Input> | ReferenceReplayAdapterFn<Input>

export interface ReferenceReplayMatch {
scenarioId: string
referenceId: string
Expand Down Expand Up @@ -124,9 +163,175 @@ export interface ReferenceReplayPromotionDecision {
regressions: ReferenceReplaySplitComparison[]
}

/** Result of executing and scoring a single case within a run. */
export interface ReferenceReplayCaseRun<Input = unknown> {
/** Id of the originating case. */
caseId: string
/** Resolved split of the case (`'train'` when the case did not specify one). */
split: ReferenceReplaySplit
/** The case input as it was passed to the adapter. */
input: Input
/** Case metadata, present only when the case defined it. */
metadata?: Record<string, unknown>
/** Reference items the candidates were scored against. */
references: ReferenceReplayItem[]
/** Candidates returned by the adapter; left empty when the adapter failed and the error was recorded. */
candidates: ReferenceReplayCandidate[]
/** Score of this case alone, obtained by scoring a one-scenario list. */
score: ReferenceReplayScenarioScore
/** Time from just before the adapter call until after per-case scoring, measured with the run's clock and clamped to be non-negative. */
durationMs: number
/** Message of the adapter failure; set only when `continueOnError` allowed the run to proceed. */
error?: string
}

/** A completed replay run: every executed case plus the score over all of them. */
export interface ReferenceReplayRun<Input = unknown> {
/** Run id: the `runId` option, or `<variantId | 'reference-replay'>-<startedAt>` when none was given. */
id: string
/** Variant label copied from the run options; omitted when not provided. */
variantId?: string
/** Clock reading taken when the run started (`Date.now()` unless a custom `now` was supplied). */
startedAt: number
/** Clock reading taken after the last case finished. */
completedAt: number
/** `completedAt - startedAt`, clamped to be non-negative. */
durationMs: number
/** One entry per executed case, in input order; cases filtered out by split are absent. */
cases: ReferenceReplayCaseRun<Input>[]
/** Score computed over all executed cases together. */
score: ReferenceReplayScore
/** Run-level metadata copied from the run options; omitted when not provided. */
metadata?: Record<string, unknown>
}

/**
 * Options for `runReferenceReplay`. The inherited score options are also read
 * by the runner: `splits` and `includeHoldout` select which cases execute,
 * while `matcher` and `matchThreshold` are forwarded to scoring.
 */
export interface ReferenceReplayRunOptions<Input = unknown> extends ReferenceReplayScoreOptions {
/** Adapter (object or function form) that produces candidates for each case. */
adapter: ReferenceReplayAdapterLike<Input>
/** Explicit run id; when omitted one is derived from `variantId` and the start time. */
runId?: string
/** Label for the variant under test; copied onto the run and used in the derived run id. */
variantId?: string
/** Free-form data copied onto the resulting run. */
metadata?: Record<string, unknown>
/** When set, the finished run is passed to `store.save` before `runReferenceReplay` returns. */
store?: ReferenceReplayRunStore<Input>
/** Checked before and after every adapter call; an aborted signal makes the run throw, even with `continueOnError`. */
abortSignal?: AbortSignal
/** When truthy, an adapter failure is recorded on the case run as `error` instead of being rethrown. */
continueOnError?: boolean
/** Clock used for all timestamps and durations; defaults to `Date.now`. */
now?: () => number
}

/** Persistence contract for replay runs. */
export interface ReferenceReplayRunStore<Input = unknown> {
/** Persists one finished run. */
save(run: ReferenceReplayRun<Input>): Promise<void>
/** Resolves with the runs held by the store. */
list(): Promise<ReferenceReplayRun<Input>[]>
}

// NOTE(review): presumably the fallback for `matchThreshold` when scoring; its use is outside this view, so confirm.
const DEFAULT_MATCH_THRESHOLD = 0.55
// Every known split. Used as the default set of allowed splits in `runReferenceReplay`
// and as the ordering key in `bySplitOrder`.
const ALL_SPLITS: ReferenceReplaySplit[] = ['train', 'dev', 'test', 'holdout']

/**
 * Runs each selected case through the adapter, scores the returned candidates
 * against the case references, and assembles the whole run.
 *
 * Case selection: a case without a split counts as `'train'`; holdout cases
 * are skipped unless `options.includeHoldout` is truthy; any case whose split
 * is not in `options.splits` (default: all splits) is skipped.
 *
 * @param cases - Cases to replay, executed sequentially in array order.
 * @param options - Adapter, scoring and run-level settings.
 * @returns The finished run, after it has been handed to `options.store` (if any).
 * @throws Whatever the adapter throws unless `options.continueOnError` is set,
 * and always when `options.abortSignal` is aborted.
 */
export async function runReferenceReplay<Input = unknown>(
  cases: ReferenceReplayCase<Input>[],
  options: ReferenceReplayRunOptions<Input>,
): Promise<ReferenceReplayRun<Input>> {
  const clock = options.now ?? Date.now
  const startedAt = clock()
  const runId = options.runId ?? `${options.variantId ?? 'reference-replay'}-${startedAt}`
  const selectedSplits = new Set(options.splits ?? ALL_SPLITS)

  // Split filtering already happened in the loop below, so scoring must never drop holdout again.
  const buildScoreOptions = (): ReferenceReplayScoreOptions => ({
    matcher: options.matcher,
    matchThreshold: options.matchThreshold,
    includeHoldout: true,
  })
  // Yields `{ metadata }` only when defined so the key is absent otherwise.
  const optionalMetadata = (metadata: Record<string, unknown> | undefined) =>
    metadata === undefined ? {} : { metadata }

  const completed: ReferenceReplayCaseRun<Input>[] = []

  for (const [caseIndex, current] of cases.entries()) {
    const split = current.split ?? 'train'
    const holdoutExcluded = split === 'holdout' && !options.includeHoldout
    if (holdoutExcluded || !selectedSplits.has(split)) continue

    const caseStartedAt = clock()
    const adapterInput: ReferenceReplayExecutionScenario<Input> = {
      id: current.id,
      split,
      input: current.input,
      ...optionalMetadata(current.metadata),
    }

    let candidates: ReferenceReplayCandidate[] = []
    let failure: string | undefined
    try {
      throwIfAborted(options.abortSignal)
      candidates = await runAdapter(options.adapter, adapterInput, {
        runId,
        caseIndex,
        abortSignal: options.abortSignal,
      })
      throwIfAborted(options.abortSignal)
    } catch (cause) {
      // An abort is never downgraded to a recorded per-case failure.
      const aborted = options.abortSignal?.aborted
      if (aborted || !options.continueOnError) throw cause
      failure = cause instanceof Error ? cause.message : String(cause)
    }

    const scoringScenario: ReferenceReplayScenario = {
      id: current.id,
      split,
      references: current.references,
      candidates,
      ...optionalMetadata(current.metadata),
    }
    const caseScore = scoreReferenceReplay([scoringScenario], buildScoreOptions()).scenarios[0]

    completed.push({
      caseId: current.id,
      split,
      input: current.input,
      references: current.references,
      candidates,
      score: caseScore,
      durationMs: Math.max(0, clock() - caseStartedAt),
      ...optionalMetadata(current.metadata),
      ...(failure === undefined ? {} : { error: failure }),
    })
  }

  const completedAt = clock()
  const overallOptions = buildScoreOptions()
  const overallScenarios = completed.map((entry) => ({
    id: entry.caseId,
    split: entry.split,
    references: entry.references,
    candidates: entry.candidates,
    ...optionalMetadata(entry.metadata),
  }))

  const run: ReferenceReplayRun<Input> = {
    id: runId,
    startedAt,
    completedAt,
    durationMs: Math.max(0, completedAt - startedAt),
    cases: completed,
    score: scoreReferenceReplay(overallScenarios, overallOptions),
    ...(options.variantId === undefined ? {} : { variantId: options.variantId }),
    ...optionalMetadata(options.metadata),
  }

  await options.store?.save(run)
  return run
}

/**
 * Run-level convenience wrapper: makes the promotion decision from the
 * `score` of two runs by delegating to `decideReferenceReplayPromotion`.
 *
 * @param baseline - Run whose score is the comparison baseline.
 * @param candidate - Run whose score is evaluated for promotion.
 * @param policy - Promotion policy; defaults to an empty policy object.
 * @returns The decision produced by `decideReferenceReplayPromotion`.
 */
export function decideReferenceReplayRunPromotion(
  baseline: ReferenceReplayRun,
  candidate: ReferenceReplayRun,
  policy: ReferenceReplayPromotionPolicy = {},
): ReferenceReplayPromotionDecision {
  const baselineScore = baseline.score
  const candidateScore = candidate.score
  return decideReferenceReplayPromotion(baselineScore, candidateScore, policy)
}

/**
 * Creates a run store backed by a private in-process array.
 *
 * The `initial` array is copied, so later `save` calls never touch it, and
 * every `list` call resolves with a fresh copy of the stored runs.
 *
 * @param initial - Runs to seed the store with; defaults to none.
 * @returns A store whose `save` appends and whose `list` returns a snapshot.
 */
export function inMemoryReferenceReplayStore<Input = unknown>(
  initial: ReferenceReplayRun<Input>[] = [],
): ReferenceReplayRunStore<Input> {
  const saved: ReferenceReplayRun<Input>[] = Array.from(initial)

  const save = async (run: ReferenceReplayRun<Input>): Promise<void> => {
    saved.push(run)
  }
  const list = async (): Promise<ReferenceReplayRun<Input>[]> => Array.from(saved)

  return { save, list }
}

/**
 * Creates a run store backed by a JSON Lines file at `path`.
 *
 * `save` creates the parent directory if needed and appends the run as one
 * JSON document followed by a newline. `list` resolves with an empty array
 * when the file does not exist and otherwise with the result of
 * `readJsonl(path)`. All file access uses the synchronous `node:fs` calls.
 *
 * @param path - Location of the JSONL file.
 * @returns A store that appends to and reads from that file.
 */
export function jsonlReferenceReplayStore<Input = unknown>(path: string): ReferenceReplayRunStore<Input> {
  const save = async (run: ReferenceReplayRun<Input>): Promise<void> => {
    mkdirSync(dirname(path), { recursive: true })
    appendFileSync(path, `${JSON.stringify(run)}\n`)
  }
  const list = async (): Promise<ReferenceReplayRun<Input>[]> =>
    existsSync(path) ? readJsonl<Input>(path) : []

  return { save, list }
}

export function scoreReferenceReplay(
scenarios: ReferenceReplayScenario[],
options: ReferenceReplayScoreOptions = {},
Expand Down Expand Up @@ -182,10 +387,21 @@ export function decideReferenceReplayPromotion(
const maxRegression = policy.maxRegression ?? 0
const requireHoldout = policy.requireHoldoutNonRegression ?? true
const comparisons = compareReferenceReplay(baseline, candidate)
const missingRequiredSplits = requiredSplits.filter((split) => !hasSplit(baseline, split) || !hasSplit(candidate, split))
const compared = comparisons.filter((item) => requiredSplits.includes(item.split))
const regressions = comparisons.filter((item) => item.f1Delta < -maxRegression)
const aggregateDelta = candidate.aggregate.f1 - baseline.aggregate.f1

if (missingRequiredSplits.length > 0) {
return {
promote: false,
reason: `Required split missing from baseline or candidate: ${missingRequiredSplits.join(', ')}`,
aggregateDelta,
comparisons,
regressions,
}
}

if (compared.length === 0) {
return {
promote: false,
Expand All @@ -206,7 +422,7 @@ export function decideReferenceReplayPromotion(
}
}

if (requireHoldout && !comparisons.some((item) => item.split === 'holdout')) {
if (requireHoldout && (!hasSplit(baseline, 'holdout') || !hasSplit(candidate, 'holdout'))) {
return {
promote: false,
reason: 'Holdout split is required for promotion',
Expand Down Expand Up @@ -256,24 +472,25 @@ function scoreScenario(
matcher: ReferenceReplayMatcher,
threshold: number,
): ReferenceReplayScenarioScore {
const candidatesLeft = new Map(scenario.candidates.map((candidate) => [candidate.id, candidate]))
const candidatesLeft = scenario.candidates.map((candidate, index) => ({ candidate, index }))
const matches: ReferenceReplayMatch[] = []

for (const reference of scenario.references) {
let best: { candidate: ReferenceReplayCandidate; score: number; reason: string } | null = null
for (const candidate of candidatesLeft.values()) {
const result = matcher(reference, candidate, scenario)
let best: { candidate: ReferenceReplayCandidate; index: number; score: number; reason: string } | null = null
for (const item of candidatesLeft) {
const result = matcher(reference, item.candidate, scenario)
if (!Number.isFinite(result.score)) {
throw new Error(`reference replay matcher returned non-finite score for ${scenario.id}:${reference.id}:${candidate.id}`)
throw new Error(`reference replay matcher returned non-finite score for ${scenario.id}:${reference.id}:${item.candidate.id}`)
}
if (!best || result.score > best.score) {
best = { candidate, score: clamp01(result.score), reason: result.reason ?? '' }
best = { ...item, score: clamp01(result.score), reason: result.reason ?? '' }
}
}

const weight = reference.weight ?? 1
if (best && best.score >= threshold) {
candidatesLeft.delete(best.candidate.id)
const matchIndex = candidatesLeft.findIndex((item) => item.index === best.index)
if (matchIndex >= 0) candidatesLeft.splice(matchIndex, 1)
matches.push({
scenarioId: scenario.id,
referenceId: reference.id,
Expand All @@ -298,7 +515,7 @@ function scoreScenario(

const matched = matches.filter((match) => match.matched).length
const total = scenario.references.length
const falsePositives = candidatesLeft.size
const falsePositives = candidatesLeft.length
const matchedWeight = matches.filter((match) => match.matched).reduce((sum, match) => sum + match.weight, 0)
const totalWeight = matches.reduce((sum, match) => sum + match.weight, 0)
const precision = ratio(matched, matched + falsePositives)
Expand Down Expand Up @@ -364,6 +581,10 @@ function emptyAggregate(): ReferenceReplayAggregate {
}
}

/**
 * Tells whether `score.bySplit` holds an entry for `split`.
 * An entry explicitly set to `undefined` counts as absent.
 */
function hasSplit(score: ReferenceReplayScore, split: ReferenceReplaySplit): boolean {
  const entry = score.bySplit[split]
  return entry !== undefined
}

/**
 * Harmonic mean of precision and recall.
 *
 * @returns `0` when `precision + recall` is exactly zero (avoiding a 0/0
 * division), otherwise `2 * precision * recall / (precision + recall)`.
 */
function f1(precision: number, recall: number): number {
  const denominator = precision + recall
  if (denominator === 0) return 0
  return 2 * precision * recall / denominator
}
Expand Down Expand Up @@ -425,6 +646,33 @@ function bySplitOrder(a: ReferenceReplaySplit, b: ReferenceReplaySplit): number
return ALL_SPLITS.indexOf(a) - ALL_SPLITS.indexOf(b)
}

/**
 * Invokes an adapter regardless of its form: a function adapter is called
 * directly, an object adapter through its `run` method.
 *
 * Deliberately not `async`, so a synchronous throw from the adapter reaches
 * the caller synchronously instead of being wrapped in a rejected promise.
 *
 * @returns The promise returned by the adapter.
 */
function runAdapter<Input>(
  adapter: ReferenceReplayAdapterLike<Input>,
  scenario: ReferenceReplayExecutionScenario<Input>,
  context: ReferenceReplayRunContext,
): Promise<ReferenceReplayCandidate[]> {
  if (typeof adapter === 'function') {
    return adapter(scenario, context)
  }
  return adapter.run(scenario, context)
}

/**
 * Throws when `signal` has been aborted; does nothing for a missing or
 * still-active signal.
 *
 * @throws The signal's `reason` itself when it is an `Error`; otherwise a new
 * `Error` whose message is the stringified reason, or
 * `'reference replay aborted'` when the reason is falsy.
 */
function throwIfAborted(signal: AbortSignal | undefined): void {
  if (signal == null || !signal.aborted) return

  const { reason } = signal
  if (reason instanceof Error) throw reason

  const message = reason ? String(reason) : 'reference replay aborted'
  throw new Error(message)
}

/**
 * Reads a JSON Lines file and parses every non-blank line as one run.
 *
 * Lines are split on `\n` and trimmed before parsing, so blank lines,
 * surrounding whitespace and `\r\n` line endings are tolerated. The parsed
 * values are cast to `ReferenceReplayRun<Input>` without validation.
 *
 * @param path - File to read; it must exist.
 * @returns The parsed runs in file order.
 * @throws SyntaxError naming the file and the 1-based line number when a line
 * is not valid JSON (for example a partially written trailing record).
 */
function readJsonl<Input>(path: string): ReferenceReplayRun<Input>[] {
  const raw = readFileSync(path, 'utf8')
  const out: ReferenceReplayRun<Input>[] = []
  for (const [index, line] of raw.split('\n').entries()) {
    const trimmed = line.trim()
    if (!trimmed) continue
    try {
      out.push(JSON.parse(trimmed) as ReferenceReplayRun<Input>)
    } catch (cause) {
      // Keep the SyntaxError type callers may already catch, but say where the bad record is.
      const detail = cause instanceof Error ? cause.message : String(cause)
      throw new SyntaxError(`invalid JSON at ${path}:${index + 1}: ${detail}`)
    }
  }
  return out
}

const STOP_WORDS = new Set([
'the',
'and',
Expand Down
Loading