Skip to content

Commit 6d64b0d

Browse files
committed
implemented hybrid retrieval
1 parent 7892ceb commit 6d64b0d

File tree

5 files changed

+359
-66
lines changed

5 files changed

+359
-66
lines changed

README.md

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,11 +39,35 @@ nl2sql chat
3939
- `/help` - Show commands
4040
- `/schema` - Display schema
4141
- `/models` - List models
42+
- `/explain <query>` - Explain similarity scores (e.g., /explain show users)
4243
- `/history` - View history
4344
- `/clear` - Clear chat
4445
- `/model <name>` - Switch AI model
4546
- `/exit` - Quit
4647

48+
## Improving Model Discovery
49+
50+
The CLI uses semantic search to find relevant tables. Help it work better by:
51+
52+
1. **Add comments to your models:**
53+
54+
```javascript
55+
// User accounts and authentication data
56+
const User = sequelize.define('User', { ... });
57+
```
58+
59+
2. **Use descriptive column names:**
60+
- Good: `email`, `phone_number`, `created_at`
61+
- Avoid: `col1`, `data`, `field_x`
62+
63+
3. **Document relationships clearly:**
64+
65+
```javascript
66+
User.hasMany(Order, { as: 'purchases', foreignKey: 'customer_id' });
67+
```
68+
69+
No special configuration needed - the system learns from your schema structure!
70+
4771
### Example Queries
4872

4973
- Show all users

src/cli/chat.ts

Lines changed: 26 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,11 @@ import chalk from 'chalk';
22
import { WatchableSchema } from '../models/analyzer';
33
import { VectorSearch } from '../rag/vector-search';
44
import { OpenRouterClient, ChatMessage } from '../llm/openrouter';
5+
import { HybridRetrieval } from '../rag/hybrid-retrieval';
56
import * as readline from 'readline';
67

78
const vectorSearch = new VectorSearch();
9+
let hybridRetrieval: HybridRetrieval;
810

911
export async function runChat(watchableSchema: WatchableSchema): Promise<void> {
1012
console.log(
@@ -17,6 +19,7 @@ export async function runChat(watchableSchema: WatchableSchema): Promise<void> {
1719
console.log();
1820

1921
await vectorSearch.initialize(watchableSchema.models);
22+
hybridRetrieval = new HybridRetrieval(vectorSearch, watchableSchema.models);
2023

2124
const apiKey = process.env.OPENROUTER_API_KEY;
2225
if (!apiKey) {
@@ -100,9 +103,17 @@ export async function runChat(watchableSchema: WatchableSchema): Promise<void> {
100103
} catch (error) {
101104
console.log(chalk.red('❌ Failed to reload schema\n'));
102105
}
103-
} else if (userInput.startsWith('/debug ')) {
104-
const query = userInput.substring(7).trim();
105-
await debugSimilarity(query, vectorSearch);
106+
} else if (userInput.startsWith('/explain ')) {
107+
const query = userInput.substring(9).trim();
108+
const explanation = await hybridRetrieval.explainSelection(query);
109+
110+
console.log(chalk.yellow('\n🔍 Selection Explanation:'));
111+
console.log(chalk.gray('─'.repeat(60)));
112+
explanation.slice(0, 10).forEach((item, i) => {
113+
console.log(chalk.cyan(`${i + 1}. ${item.model}`));
114+
console.log(chalk.gray(` ${item.reason}`));
115+
});
116+
console.log(chalk.gray('─'.repeat(60) + '\n'));
106117
} else if (userInput.startsWith('/model ')) {
107118
const modelName = userInput.substring(7).trim();
108119
try {
@@ -118,13 +129,7 @@ export async function runChat(watchableSchema: WatchableSchema): Promise<void> {
118129
console.log(chalk.red(`❌ Unknown command: ${userInput}`));
119130
console.log(chalk.gray('Type /help for available commands.\n'));
120131
} else {
121-
await handleQuery(
122-
userInput,
123-
llm,
124-
watchableSchema,
125-
chatHistory,
126-
vectorSearch
127-
);
132+
await handleQuery(userInput, llm, watchableSchema, chatHistory);
128133
}
129134

130135
rl.prompt();
@@ -147,34 +152,36 @@ async function handleQuery(
147152
query: string,
148153
llm: OpenRouterClient,
149154
watchableSchema: WatchableSchema,
150-
chatHistory: ChatMessage[],
151-
vectorSearch: VectorSearch
155+
chatHistory: ChatMessage[]
152156
): Promise<void> {
153157
try {
154158
process.stdout.write(chalk.blue('🤔 Thinking... '));
155159

156-
// Find relevant models using vector search
157-
const relevantModels = await vectorSearch.findRelevant(query, 5, 0.25);
160+
const relevantModels = await hybridRetrieval.findRelevant(query, {
161+
topK: 5,
162+
threshold: 0.25,
163+
includeRelated: true,
164+
});
158165

159166
readline.clearLine(process.stdout, 0);
160167
readline.cursorTo(process.stdout, 0);
161168

162-
// Show which models were selected
163169
if (relevantModels.length > 0) {
164170
console.log(
165171
chalk.magenta(
166172
`📊 Using ${relevantModels.length}/${watchableSchema.models.length} relevant table(s):`
167173
)
168174
);
169-
relevantModels.forEach((m) => console.log(chalk.gray(` • ${m.name}`)));
175+
relevantModels.forEach((m) =>
176+
console.log(chalk.gray(` • ${m.name} (${m.tableName})`))
177+
);
170178
console.log();
171179
} else {
172180
console.log(
173181
chalk.yellow('⚠️ No relevant tables found, using all models\n')
174182
);
175183
}
176184

177-
// Generate SQL with filtered models
178185
const modelsToUse =
179186
relevantModels.length > 0 ? relevantModels : watchableSchema.models;
180187
const response = await llm.generateSQL(query, modelsToUse, chatHistory);
@@ -203,33 +210,6 @@ async function handleQuery(
203210
}
204211
}
205212

206-
async function debugSimilarity(
207-
query: string,
208-
vectorSearch: VectorSearch
209-
): Promise<void> {
210-
try {
211-
const scores = await vectorSearch.getScores(query);
212-
console.log(
213-
chalk.yellow('\n🔍 Similarity scores for:'),
214-
chalk.white(query)
215-
);
216-
console.log(chalk.gray('─'.repeat(60)));
217-
scores.slice(0, 10).forEach((s, i) => {
218-
const percentage = (s.score * 100).toFixed(1);
219-
const bar = '█'.repeat(Math.floor(s.score * 20));
220-
console.log(
221-
`${i + 1}. ${chalk.cyan(s.name.padEnd(20))} ${bar} ${percentage}%`
222-
);
223-
});
224-
console.log(chalk.gray('─'.repeat(60) + '\n'));
225-
} catch (error) {
226-
console.log(
227-
chalk.red('❌ Error:'),
228-
error instanceof Error ? error.message : 'Failed to get scores'
229-
);
230-
}
231-
}
232-
233213
function displayHelp(): void {
234214
console.log(chalk.yellow('📚 Available commands:'));
235215
console.log(chalk.gray('─'.repeat(60)));
@@ -244,8 +224,8 @@ function displayHelp(): void {
244224
chalk.gray('- List all available models with details')
245225
);
246226
console.log(
247-
chalk.cyan(' /debug ') +
248-
chalk.gray('- Debug similarity scores (e.g., /debug show users)')
227+
chalk.cyan(' /explain ') +
228+
chalk.gray('- Explain similarity scores (e.g., /explain show users)')
249229
);
250230
console.log(chalk.cyan(' /history ') + chalk.gray('- Show chat history'));
251231
console.log(

src/models/analyzer.ts

Lines changed: 71 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@ import * as chokidar from 'chokidar';
55
import { parse } from '@babel/parser';
66
import traverse, { NodePath } from '@babel/traverse';
77

8-
// At the top of analyzer.ts (around line 7)
98
interface ColumnInfo {
109
name: string;
1110
type?: string;
@@ -23,6 +22,7 @@ interface ColumnInfo {
2322
export interface ModelInfo {
2423
name: string;
2524
tableName: string;
25+
description?: string;
2626
columns: ColumnInfo[];
2727
associations: Array<{
2828
type: string;
@@ -324,33 +324,99 @@ function mergeCentralizedAssociations(
324324
}
325325
}
326326

327+
/**
 * Collect free-form documentation for a model from the text of its source file.
 *
 * Five passes run over `content`, in this order:
 *   1. every JSDoc comment in the file (tag lines starting with `@` are dropped);
 *   2. the last `//` comment line of a run that sits at most 50 characters
 *      before a `class|const|export <modelName>` declaration;
 *   3. a trailing `//` comment after a `.define(...)` call;
 *   4. a trailing `//` comment after a `tableName: '...'` entry;
 *   5. plain block comments that mention the model name, "table" or "model".
 *
 * Fragments are de-duplicated and joined with ". ".
 *
 * @param content - Full source text of the model file.
 * @param modelName - Model name; it is matched literally (regex metacharacters
 *   are escaped), case-insensitively.
 * @returns The joined description, or `undefined` when nothing was found.
 */
function extractModelDescription(
  content: string,
  modelName: string
): string | undefined {
  const descriptions: string[] = [];

  // Record a fragment once; empty and already-seen fragments are ignored.
  const addDescription = (text: string | undefined): void => {
    const cleaned = text?.trim();
    if (cleaned && !descriptions.includes(cleaned)) {
      descriptions.push(cleaned);
    }
  };

  // 1. JSDoc comments
  const jsDocPattern = /\/\*\*\s*([\s\S]*?)\s*\*\//g;
  let jsDocMatch: RegExpExecArray | null;
  while ((jsDocMatch = jsDocPattern.exec(content)) !== null) {
    const comment = (jsDocMatch[1] ?? '')
      .split('\n')
      .map((line) => line.replace(/^\s*\*\s?/, '').trim())
      .filter((line) => line && !line.startsWith('@'))
      .join(' ');
    addDescription(comment);
  }

  // 2. Single-line comments above model.
  // The name is escaped so that characters such as "[", "(" or "." in it are
  // matched literally instead of being interpreted as regex syntax (an
  // unbalanced "[" would otherwise make the RegExp constructor throw).
  const escapedName = modelName.replace(/[.*+?^${}()|[\]\\]/g, '\\$&');
  const singleLinePattern = new RegExp(
    `(?://\\s*(.+?)\\s*\\n)+[\\s\\S]{0,50}(class|const|export)\\s+${escapedName}`,
    'i'
  );
  const singleLineMatch = content.match(singleLinePattern);
  if (singleLineMatch) {
    addDescription(singleLineMatch[1]);
  }

  // 3. Inline comments in define()
  const defineCommentPattern = /\.define\s*\([^)]*\)\s*\/\/\s*(.+?)$/m;
  const defineMatch = content.match(defineCommentPattern);
  if (defineMatch) {
    addDescription(defineMatch[1]);
  }

  // 4. Comments in table definition
  const tableNamePattern =
    /tableName:\s*['"`]([^'"`]+)['"`]\s*,?\s*(?:\/\/\s*(.+?)$)?/m;
  const tableMatch = content.match(tableNamePattern);
  if (tableMatch) {
    addDescription(tableMatch[2]);
  }

  // 5. Block comments anywhere in file mentioning the table
  const blockCommentPattern = /\/\*[\s\S]*?\*\//g;
  const lowerName = modelName.toLowerCase();
  let blockMatch: RegExpExecArray | null;
  while ((blockMatch = blockCommentPattern.exec(content)) !== null) {
    const comment = blockMatch[0];
    // JSDoc comments were already collected by pass 1; this pattern matches
    // them too, so skipping them here avoids emitting the same text twice.
    if (comment.startsWith('/**')) continue;

    const lowered = comment.toLowerCase();
    if (
      lowered.includes(lowerName) ||
      lowered.includes('table') ||
      lowered.includes('model')
    ) {
      addDescription(comment.replace(/\/\*|\*\//g, '').replace(/\*/g, ''));
    }
  }

  return descriptions.length > 0 ? descriptions.join('. ') : undefined;
}
393+
327394
/**
328395
* Parse a Sequelize model file to extract model information
329396
*/
330397
function parseModelFile(content: string, filename: string): ModelInfo | null {
331-
// Extract model name from filename
332398
const modelName = path.basename(filename, path.extname(filename));
333399

334-
// Try to extract table name
335400
let tableName = modelName.toLowerCase();
336401
const tableNameMatch = content.match(/tableName:\s*['"`]([^'"`]+)['"`]/);
337402
if (tableNameMatch) {
338403
tableName = tableNameMatch[1];
339404
}
340405

341-
// Extract columns from the model definition
406+
const description = extractModelDescription(content, modelName);
407+
342408
const columns = parseColumns(content);
343409

344410
if (columns.length === 0) {
345411
return null;
346412
}
347413

348-
// Extract in-file associations
349414
const associations = parseAssociations(content);
350415

351416
return {
352417
name: modelName,
353418
tableName,
419+
description,
354420
columns,
355421
associations,
356422
};

0 commit comments

Comments
 (0)