diff --git a/src/embeddings/provider.ts b/src/embeddings/provider.ts index bb0480c..2d2c7af 100644 --- a/src/embeddings/provider.ts +++ b/src/embeddings/provider.ts @@ -270,6 +270,8 @@ class GoogleEmbeddingProvider implements EmbeddingProviderInterface { } class OllamaEmbeddingProvider implements EmbeddingProviderInterface { + private static readonly MIN_TRUNCATION_CHARS = 512; + constructor( private credentials: ProviderCredentials, private modelInfo: EmbeddingProviderModelInfo['ollama'] @@ -295,8 +297,7 @@ class OllamaEmbeddingProvider implements EmbeddingProviderInterface { return Math.ceil(text.length / 4); } - private truncateToTokenLimit(text: string, maxTokens: number): string { - const maxChars = Math.max(1, maxTokens * 4); + private truncateToCharLimit(text: string, maxChars: number): string { if (text.length <= maxChars) { return text; } @@ -304,6 +305,82 @@ class OllamaEmbeddingProvider implements EmbeddingProviderInterface { return `${text.slice(0, Math.max(0, maxChars - 17))}\n... [truncated]`; } + private isContextLengthError(error: unknown): boolean { + const message = (error instanceof Error ? error.message : String(error)).toLowerCase(); + return (message.includes("context length") && (message.includes("exceed") || message.includes("exceeded") || message.includes("too long"))) + || message.includes("input length exceeds the context length") + || message.includes("context length exceeded"); + } + + private buildTruncationCandidates(text: string): string[] { + const baseMaxChars = Math.max(1, this.modelInfo.maxTokens * 4); + const candidateLimits = new Set<number>(); + const baselineLimit = text.length > baseMaxChars + ? 
baseMaxChars + : Math.max( + OllamaEmbeddingProvider.MIN_TRUNCATION_CHARS, + Math.floor(text.length * 0.9) + ); + + if (baselineLimit < text.length) { + candidateLimits.add(baselineLimit); + } + + for (const factor of [0.75, 0.6, 0.45, 0.35, 0.25]) { + const scaledLimit = Math.max( + OllamaEmbeddingProvider.MIN_TRUNCATION_CHARS, + Math.floor(baselineLimit * factor) + ); + if (scaledLimit < text.length) { + candidateLimits.add(scaledLimit); + } + } + + candidateLimits.add(Math.min(text.length - 1, OllamaEmbeddingProvider.MIN_TRUNCATION_CHARS)); + + const candidates: string[] = []; + const seen = new Set<string>(); + for (const limit of [...candidateLimits].sort((a, b) => b - a)) { + if (limit <= 0 || limit >= text.length) { + continue; + } + + const truncated = this.truncateToCharLimit(text, limit); + if (truncated === text || seen.has(truncated)) { + continue; + } + + seen.add(truncated); + candidates.push(truncated); + } + + return candidates; + } + + private async embedSingleWithFallback(text: string): Promise<{ embedding: number[]; tokensUsed: number }> { + try { + return await this.embedSingle(text); + } catch (error) { + if (!this.isContextLengthError(error)) { + throw error; + } + + let lastError: unknown = error; + for (const truncated of this.buildTruncationCandidates(text)) { + try { + return await this.embedSingle(truncated); + } catch (retryError) { + if (!this.isContextLengthError(retryError)) { + throw retryError; + } + lastError = retryError; + } + } + + throw lastError; + } + } + private async embedSingle(text: string): Promise<{ embedding: number[]; tokensUsed: number }> { const response = await fetch(`${this.credentials.baseUrl}/api/embeddings`, { method: "POST", @@ -336,23 +413,7 @@ class OllamaEmbeddingProvider implements EmbeddingProviderInterface { const results: Array<{ embedding: number[]; tokensUsed: number }> = []; for (const text of texts) { - try { - results.push(await this.embedSingle(text)); - } catch (error) { - const message = error 
instanceof Error ? error.message : String(error); - const shouldRetryWithTruncation = message.includes("input length exceeds the context length"); - - if (!shouldRetryWithTruncation) { - throw error; - } - - const truncated = this.truncateToTokenLimit(text, this.modelInfo.maxTokens); - if (truncated === text) { - throw error; - } - - results.push(await this.embedSingle(truncated)); - } + results.push(await this.embedSingleWithFallback(text)); } return { diff --git a/tests/custom-provider.test.ts b/tests/custom-provider.test.ts index bd04566..441421d 100644 --- a/tests/custom-provider.test.ts +++ b/tests/custom-provider.test.ts @@ -476,6 +476,73 @@ describe("OllamaEmbeddingProvider", () => { expect(result.embeddings).toHaveLength(1); }); + it("backs off on context errors even when the prompt is below the estimated char limit", async () => { + const prompts: string[] = []; + fetchSpy.mockImplementation(async (_url, init) => { + const body = JSON.parse(String(init?.body ?? "{}")) as { prompt?: string; truncate?: boolean }; + prompts.push(body.prompt ?? ""); + + if (prompts.length === 1) { + return new Response(JSON.stringify({ error: "the input length exceeds the context length" }), { status: 500 }); + } + + return new Response(JSON.stringify({ embedding: new Array(768).fill(0.1) }), { status: 200 }); + }); + + const provider = createOllamaProvider(); + const nearLimit = "x".repeat(7000); + const result = await provider.embedBatch([nearLimit]); + + expect(prompts).toHaveLength(2); + expect(prompts[1].length).toBeLessThan(prompts[0].length); + expect(result.embeddings).toHaveLength(1); + }); + + it("keeps shrinking ollama prompts until a context-length retry succeeds", async () => { + const prompts: string[] = []; + fetchSpy.mockImplementation(async (_url, init) => { + const body = JSON.parse(String(init?.body ?? "{}")) as { prompt?: string; truncate?: boolean }; + prompts.push(body.prompt ?? 
""); + + if (prompts.length < 3) { + return new Response(JSON.stringify({ error: "the input length exceeds the context length" }), { status: 500 }); + } + + return new Response(JSON.stringify({ embedding: new Array(768).fill(0.1) }), { status: 200 }); + }); + + const provider = createOllamaProvider(); + const oversized = "x".repeat(9000); + const result = await provider.embedBatch([oversized]); + + expect(prompts).toHaveLength(3); + expect(prompts[1].length).toBeLessThan(prompts[0].length); + expect(prompts[2].length).toBeLessThan(prompts[1].length); + expect(result.embeddings).toHaveLength(1); + }); + + it("matches alternate Ollama context-length error wording", async () => { + const prompts: string[] = []; + fetchSpy.mockImplementation(async (_url, init) => { + const body = JSON.parse(String(init?.body ?? "{}")) as { prompt?: string; truncate?: boolean }; + prompts.push(body.prompt ?? ""); + + if (prompts.length === 1) { + return new Response(JSON.stringify({ error: "Context length exceeded for this embedding request" }), { status: 500 }); + } + + return new Response(JSON.stringify({ embedding: new Array(768).fill(0.1) }), { status: 200 }); + }); + + const provider = createOllamaProvider(); + const oversized = "x".repeat(9000); + const result = await provider.embedBatch([oversized]); + + expect(prompts).toHaveLength(2); + expect(prompts[1].length).toBeLessThan(prompts[0].length); + expect(result.embeddings).toHaveLength(1); + }); + it("processes ollama embedBatch requests sequentially", async () => { let activeRequests = 0; let maxActiveRequests = 0; diff --git a/tests/indexer-failed-batches.test.ts b/tests/indexer-failed-batches.test.ts index 4ec010a..095eb1a 100644 --- a/tests/indexer-failed-batches.test.ts +++ b/tests/indexer-failed-batches.test.ts @@ -304,6 +304,53 @@ describe("indexer failed batch recovery", () => { expect(status.failedBatchesCount).toBe(0); }); + it("recovers split ollama chunks when provider retries need extra truncation below the 
estimated limit", async () => { + const embedPrompts: string[] = []; + fetchSpy.mockImplementation(async (url: string | URL | Request, init?: RequestInit) => { + if (String(url).endsWith("/api/tags")) { + return new Response(JSON.stringify({ + models: [{ name: "nomic-embed-text" }], + }), { status: 200 }); + } + + const body = JSON.parse(String(init?.body ?? "{}")) as { prompt?: string }; + const prompt = body.prompt ?? ""; + embedPrompts.push(prompt); + + if (prompt.length > 6000) { + return new Response(JSON.stringify({ error: "the input length exceeds the context length" }), { status: 500 }); + } + + const seed = prompt.length % 19; + return new Response(JSON.stringify({ + embedding: Array.from({ length: 768 }, (_, idx) => seed + idx / 1000), + }), { status: 200 }); + }); + + fs.writeFileSync( + sourceFile, + [ + "export function denseOversizedChunk() {", + ` const blob = ${JSON.stringify("dense<>symbols{}[]() ".repeat(1400))};`, + " return blob.length;", + "}", + ].join("\n"), + "utf-8" + ); + + const indexer = createOllamaIndexer(); + const stats = await indexer.index(); + + expect(stats.failedChunks).toBe(0); + expect(stats.indexedChunks).toBeGreaterThan(0); + expect(embedPrompts.length).toBeGreaterThan(1); + expect(embedPrompts.some((prompt) => prompt.includes("Part 1/"))).toBe(true); + expect(embedPrompts.some((prompt, idx) => idx > 0 && prompt.length < embedPrompts[0]!.length)).toBe(true); + + const status = await indexer.getStatus(); + expect(status.failedBatchesCount).toBe(0); + }); + it("rebuilds legacy failed-batch prompts with the current split strategy", async () => { const indexer = createOllamaIndexer(); await indexer.initialize();