Skip to content
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Fix linting and type checking issues
  • Loading branch information
openhands-agent committed Dec 22, 2024
commit 4ff41e8a64222efd02c46518417ce55c7aa39880
43 changes: 43 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
# Pre-commit hook configuration (https://pre-commit.com).
# Run `pre-commit install` once, then hooks run on every commit.
repos:
  # Generic hygiene checks shipped with pre-commit itself.
  - repo: https://github.com/pre-commit/pre-commit-hooks
    rev: v4.5.0
    hooks:
      - id: trailing-whitespace
        exclude: docs/modules/python
      - id: end-of-file-fixer
        exclude: docs/modules/python
      - id: check-yaml
      - id: debug-statements

  # Keep pyproject.toml canonically formatted and schema-valid.
  - repo: https://github.com/tox-dev/pyproject-fmt
    rev: 1.7.0
    hooks:
      - id: pyproject-fmt
  - repo: https://github.com/abravalheri/validate-pyproject
    rev: v0.16
    hooks:
      - id: validate-pyproject

  - repo: https://github.com/astral-sh/ruff-pre-commit
    # Ruff version.
    rev: v0.4.1
    hooks:
      # Run the linter with the project-local config.
      - id: ruff
        entry: ruff check --config dev_config/python/ruff.toml
        types_or: [python, pyi, jupyter]
        args: [--fix]
      # Run the formatter with the same config.
      - id: ruff-format
        entry: ruff format --config dev_config/python/ruff.toml
        types_or: [python, pyi, jupyter]

  - repo: https://github.com/pre-commit/mirrors-mypy
    rev: v1.9.0
    hooks:
      - id: mypy
        additional_dependencies:
          [types-requests, types-setuptools, types-pyyaml, types-toml]
        # Type-check the whole package, not just staged files.
        entry: mypy --config-file dev_config/python/mypy.ini openhands_aci/
        always_run: true
        pass_filenames: false
Binary file added code_search_index/documents.pkl
Binary file not shown.
Binary file added code_search_index/index.faiss
Binary file not shown.
2 changes: 1 addition & 1 deletion openhands_aci/code_search/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from .core import CodeSearchIndex
from .tools import initialize_code_search, search_code

__all__ = ['CodeSearchIndex', 'initialize_code_search', 'search_code']
__all__ = ['CodeSearchIndex', 'initialize_code_search', 'search_code']
81 changes: 45 additions & 36 deletions openhands_aci/code_search/core.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,29 @@
import os
import pickle
from pathlib import Path
from typing import List, Dict, Any, Optional, Tuple
from typing import Any, Dict, List, Optional

import faiss
import numpy as np
import torch
from sentence_transformers import SentenceTransformer


class CodeSearchIndex:
def __init__(self, embedding_model: Optional[str] = None):
    """Create an empty code search index.

    Args:
        embedding_model: Name or path of the sentence-transformer model.
            When None (or empty), falls back to the EMBEDDING_MODEL
            environment variable, then to 'BAAI/bge-base-en-v1.5'.
    """
    resolved = embedding_model or os.getenv(
        'EMBEDDING_MODEL', 'BAAI/bge-base-en-v1.5'
    )
    self.embedding_model = resolved
    self.model = SentenceTransformer(resolved)
    # The Faiss index is created lazily on the first add_documents() call,
    # once the embedding dimensionality is known.
    self.index: Optional[faiss.IndexFlatIP] = None
    self.documents: List[Dict[str, Any]] = []
    self.doc_ids: List[str] = []

def _embed_text(self, text: str) -> np.ndarray:
"""Embed a single text string."""
def _embed_batch(self, texts: List[str], batch_size: int = 32) -> np.ndarray:
    """Embed a list of strings, `batch_size` texts per model call.

    Returns a single 2-D array with one embedding row per input text.
    """
    chunks = []
    start = 0
    while start < len(texts):
        window = texts[start : start + batch_size]
        # Inference only: no autograd bookkeeping needed.
        with torch.no_grad():
            encoded = self.model.encode(window, convert_to_tensor=True)
        chunks.append(encoded.cpu().numpy())
        start += batch_size
    return np.vstack(chunks)

def add_documents(self, documents: List[Dict[str, Any]], batch_size: int = 32):
    """Embed and index a batch of documents.

    Args:
        documents: List of dicts, each with at least 'id' and 'content' keys.
        batch_size: Batch size passed through to embedding generation.
    """
    contents = [entry['content'] for entry in documents]
    vectors = self._embed_batch(contents, batch_size)

    # First call: build the inner-product index with the model's dimension.
    if self.index is None:
        self.index = faiss.IndexFlatIP(vectors.shape[1])

    self.index.add(vectors)
    for entry in documents:
        self.documents.append(entry)
        self.doc_ids.append(entry['id'])

def search(self, query: str, k: int = 5) -> List[Dict[str, Any]]:
"""Search the index with a query string.

Args:
query: The search query
k: Number of results to return

Returns:
List of document dictionaries with scores
"""
query_embedding = self._embed_text(query)
if self.index is None:
raise ValueError('Index is not initialized. Add documents first.')
scores, indices = self.index.search(query_embedding.reshape(1, -1), k)

results = []
for score, idx in zip(scores[0], indices[0]):
if idx < 0 or idx >= len(self.documents):
continue
doc = self.documents[idx].copy()
doc['score'] = float(score)
results.append(doc)

return results

def save(self, directory: str):
"""Save the index and documents to disk.

Args:
directory: Directory to save the index in
"""
directory = Path(directory)
directory.mkdir(parents=True, exist_ok=True)
dir_path = Path(directory)
dir_path.mkdir(parents=True, exist_ok=True)

# Save the Faiss index
faiss.write_index(self.index, str(directory / 'index.faiss'))

if self.index is not None:
faiss.write_index(self.index, str(dir_path / 'index.faiss'))

# Save documents and metadata
with open(directory / 'documents.pkl', 'wb') as f:
pickle.dump({
'documents': self.documents,
'doc_ids': self.doc_ids,
'embedding_model': self.embedding_model
}, f)
with open(dir_path / 'documents.pkl', 'wb') as f:
pickle.dump(
{
'documents': self.documents,
'doc_ids': self.doc_ids,
'embedding_model': self.embedding_model,
},
f,
)

@classmethod
def load(cls, directory: str) -> 'CodeSearchIndex':
    """Restore an index previously written by save().

    Args:
        directory: Directory containing 'documents.pkl' and 'index.faiss'.

    Returns:
        A fully populated CodeSearchIndex instance.
    """
    src = Path(directory)

    # NOTE(review): documents.pkl is deserialized with pickle — only load
    # directories produced by trusted runs of save().
    with open(src / 'documents.pkl', 'rb') as f:
        payload = pickle.load(f)

    # Rebuild with the same embedding model the index was created with.
    instance = cls(embedding_model=payload['embedding_model'])
    instance.documents = payload['documents']
    instance.doc_ids = payload['doc_ids']
    instance.index = faiss.read_index(str(src / 'index.faiss'))

    return instance
76 changes: 37 additions & 39 deletions openhands_aci/code_search/tools.py
Original file line number Diff line number Diff line change
@@ -1,32 +1,36 @@
import os
from pathlib import Path
from typing import List, Dict, Any, Optional
from typing import Any, Dict, List, Optional

from git import Repo
from git.exc import InvalidGitRepositoryError

from .core import CodeSearchIndex

def _get_files_from_repo(repo_path: str, extensions: Optional[List[str]] = None) -> List[Dict[str, Any]]:

def _get_files_from_repo(
repo_path: str, extensions: Optional[List[str]] = None
) -> List[Dict[str, Any]]:
"""Get all files from a git repository.

Args:
repo_path: Path to the git repository
extensions: List of file extensions to include (e.g. ['.py', '.js'])
If None, includes all files

Returns:
List of document dictionaries with 'id' and 'content' keys
"""
try:
repo = Repo(repo_path)
# Verify it's a git repo
_ = Repo(repo_path)
except InvalidGitRepositoryError:
raise ValueError(f"{repo_path} is not a valid git repository")
raise ValueError(f'{repo_path} is not a valid git repository')

documents = []
repo_path = Path(repo_path)
repo_path_obj = Path(repo_path)

for root, _, files in os.walk(repo_path):
for root, _, files in os.walk(repo_path_obj):
if '.git' in root:
continue

Expand All @@ -41,29 +45,30 @@ def _get_files_from_repo(repo_path: str, extensions: Optional[List[str]] = None)
except (UnicodeDecodeError, IOError):
continue

rel_path = file_path.relative_to(repo_path)
documents.append({
'id': str(rel_path),
'content': content,
'path': str(rel_path)
})
rel_path = file_path.relative_to(repo_path_obj)
documents.append(
{'id': str(rel_path), 'content': content, 'path': str(rel_path)}
)

return documents

def initialize_code_search(repo_path: str,
save_dir: str,
extensions: Optional[List[str]] = None,
embedding_model: Optional[str] = None,
batch_size: int = 32) -> Dict[str, Any]:

def initialize_code_search(
repo_path: str,
save_dir: str,
extensions: Optional[List[str]] = None,
embedding_model: Optional[str] = None,
batch_size: int = 32,
) -> Dict[str, Any]:
"""Initialize code search for a repository.

Args:
repo_path: Path to the git repository
save_dir: Directory to save the search index
extensions: List of file extensions to include
embedding_model: Name or path of the embedding model to use
batch_size: Batch size for embedding generation

Returns:
Dictionary with status and message
"""
Expand All @@ -73,7 +78,7 @@ def initialize_code_search(repo_path: str,
if not documents:
return {
'status': 'error',
'message': f'No files found in repository {repo_path}'
'message': f'No files found in repository {repo_path}',
}

# Create and save index
Expand All @@ -84,42 +89,35 @@ def initialize_code_search(repo_path: str,
return {
'status': 'success',
'message': f'Successfully indexed {len(documents)} files from {repo_path}',
'num_documents': len(documents)
'num_documents': len(documents),
}

except Exception as e:
return {
'status': 'error',
'message': f'Error initializing code search: {str(e)}'
'message': f'Error initializing code search: {str(e)}',
}

def search_code(save_dir: str, query: str, k: int = 5) -> Dict[str, Any]:
    """Search code in an indexed repository.

    Args:
        save_dir: Directory containing the search index.
        query: Search query.
        k: Number of results to return.

    Returns:
        On success: {'status': 'success', 'results': [...]}.
        On any failure: {'status': 'error', 'message': ...} — errors are
        reported in the payload instead of being raised to the caller.
    """
    try:
        index = CodeSearchIndex.load(save_dir)
        hits = index.search(query, k=k)
    except Exception as e:
        return {'status': 'error', 'message': f'Error searching code: {str(e)}'}

    return {'status': 'success', 'results': hits}
Loading