def _get_code_store():
    """Return the shared ``CodeStore``, building it on the first call.

    The instance is cached in the module-level ``_code_store_singleton`` and
    is constructed from ``settings.mongodb_uri`` and
    ``settings.mongodb_database``.
    """
    global _code_store_singleton
    if _code_store_singleton is not None:
        return _code_store_singleton

    # Imported inside the function, so CodeStore is only loaded once a store
    # is actually requested.
    from src.scanner.code_store import CodeStore

    _code_store_singleton = CodeStore(
        uri=settings.mongodb_uri,
        database=settings.mongodb_database,
    )
    return _code_store_singleton
═══════════════════════════════════════════════════════════════════════════ + def _parse_github_url(url: str) -> tuple: """Extract (org, repo) from a GitHub URL.""" url = url.strip().rstrip("/") @@ -70,12 +92,31 @@ def _parse_github_url(url: str) -> tuple: if m: return m.group(1), m.group(2) raise ValueError( - f"Invalid GitHub URL. Expected format: https://github.com/org/repo" + "Invalid GitHub URL. Expected format: https://github.com/org/repo" ) +def _github_headers(pat: str = "") -> Dict[str, str]: + h = { + "Accept": "application/vnd.github.v3+json", + "User-Agent": "XMem-Scanner/1.0", + } + if pat: + h["Authorization"] = f"token {pat}" + return h + + +def _github_get_json(path: str, pat: str = "") -> Dict[str, Any]: + """GET api.github.com path (leading slash optional).""" + path = path if path.startswith("/") else f"/{path}" + url = f"https://api.github.com{path}" + req = urllib.request.Request(url, headers=_github_headers(pat)) + with urllib.request.urlopen(req, timeout=15) as resp: + return json.loads(resp.read()) + + def _check_github_repo(org: str, repo: str, pat: str = "") -> dict: - """Hit the GitHub API to check repo accessibility.""" + """Hit the GitHub API to check repo accessibility and metadata.""" api_url = f"https://api.github.com/repos/{org}/{repo}" req = urllib.request.Request(api_url) req.add_header("Accept", "application/vnd.github.v3+json") @@ -85,12 +126,14 @@ def _check_github_repo(org: str, repo: str, pat: str = "") -> dict: try: with urllib.request.urlopen(req, timeout=10) as resp: data = json.loads(resp.read()) + size_kb = int(data.get("size") or 0) return { "accessible": True, "is_private": data.get("private", False), "default_branch": data.get("default_branch", "main"), "description": data.get("description") or "", "language": data.get("language") or "", + "size_kb": size_kb, } except urllib.error.HTTPError as e: if e.code == 404: @@ -102,10 +145,93 @@ def _check_github_repo(org: str, repo: str, pat: str = "") -> dict: return 
def _get_branch_tip_sha(
    org: str, repo: str, branch: str, pat: str = "",
) -> Optional[str]:
    """Look up the commit SHA that ``branch`` currently points at.

    The branch name is fully percent-encoded (including ``/``) before being
    placed in the ``/repos/{org}/{repo}/commits/{ref}`` path.

    Returns:
        The ``sha`` string from the response when it is a string of at least
        seven characters; ``None`` when the request fails or the response
        does not carry such a value. Failures are logged as warnings rather
        than raised.
    """
    encoded_branch = urllib.parse.quote(branch, safe="")
    try:
        commit = _github_get_json(
            f"/repos/{org}/{repo}/commits/{encoded_branch}", pat,
        )
        tip = commit.get("sha")
    except urllib.error.HTTPError as http_err:
        logger.warning(
            "Could not resolve branch tip %s/%s@%s: %s",
            org, repo, branch, http_err,
        )
        return None
    except Exception as err:
        logger.warning("Branch tip resolution failed: %s", err)
        return None

    if not isinstance(tip, str) or len(tip) < 7:
        return None
    return tip
+ ), + "branch_used_for_label": branch_label, + "repo_size_kb": size_kb, + "estimated_phase1_seconds": round(phase1_sec, 1), + "estimated_embedding_api_calls": int(emb_calls), + "estimated_embedding_tokens": emb_tokens, + "estimated_phase2_llm_tokens": llm_tokens, + "estimated_cost_usd": round(cost, 4) if cost is not None else None, + } + + +def _can_reuse_index( + org: str, repo: str, remote_sha: Optional[str], +) -> Tuple[bool, bool]: + """Return (fully_reusable, phase2_only). + + fully_reusable: Phase 1+2 done for this commit. + phase2_only: Phase 1 data exists at this commit but enrichment pending. + """ + if not remote_sha: + return False, False + + store = _get_code_store() + last = store.get_last_scan(org, repo) + if not last or last.get("status") != "completed": + return False, False + if last.get("last_commit_sha") != remote_sha: + return False, False + + pending = store.count_unenriched(org, repo) + if pending["symbols"] == 0 and pending["files"] == 0: + return True, False + return False, True + + # ═══════════════════════════════════════════════════════════════════════════ # Background scan execution # ═══════════════════════════════════════════════════════════════════════════ + def _run_phase1(org: str, repo: str, url: str, branch: str, pat: str) -> dict: from src.scanner.indexer import Indexer @@ -132,43 +258,137 @@ def _run_phase2(org: str, repo: str) -> dict: enricher.close() +def _persist_job( + job_id: str, + username: str, + org: str, + repo: str, + branch: str, + url: str, + started_at: float, + phase1_status: str, + phase2_status: str, + error: Optional[str] = None, + phase1_result: Optional[dict] = None, + phase2_result: Optional[dict] = None, +) -> None: + store = _get_code_store() + store.upsert_scanner_job( + job_id=job_id, + username=username, + org=org, + repo=repo, + branch=branch, + url=url, + phase1_status=phase1_status, + phase2_status=phase2_status, + started_at=started_at, + error=error, + phase1_result=phase1_result, + 
async def _run_scan_pipeline(
    job_id: str,
    username: str,
    org: str,
    repo: str,
    url: str,
    branch: str,
    pat: str,
) -> None:
    """Run Phase 1 then Phase 2 in background threads, persisting job state.

    Both phases and every job-store read/write run in the default executor,
    because the store calls are synchronous and would otherwise stall the
    event loop for all other requests.

    State transitions written through ``_persist_job``:

    * Phase 1 ok      -> phase1 ``complete``, phase2 ``running``
    * Phase 1 failed  -> phase1 ``failed``,   phase2 ``pending`` (and stop)
    * Phase 2 ok      -> phase1 ``complete``, phase2 ``complete``
    * Phase 2 failed  -> phase1 ``complete``, phase2 ``failed``

    Args:
        job_id: Key of the job row to update.
        username: Owner of the job row.
        org: GitHub organisation / owner.
        repo: Repository name.
        url: Clone URL handed to Phase 1.
        branch: Branch handed to Phase 1.
        pat: Personal access token handed to Phase 1 (may be empty).
    """
    loop = asyncio.get_running_loop()
    store = _get_code_store()

    # Blocking read: keep it off the event loop.
    started = await loop.run_in_executor(
        None, lambda: store.get_scanner_job(job_id),
    )
    # Tolerate a missing row or a row without a start time.
    started_at = (started or {}).get("started_at") or time.time()

    async def _save(phase1_status: str, phase2_status: str, **fields: Any) -> None:
        """Persist the job row from a worker thread."""
        await loop.run_in_executor(
            None,
            lambda: _persist_job(
                job_id, username, org, repo, branch, url, started_at,
                phase1_status=phase1_status,
                phase2_status=phase2_status,
                **fields,
            ),
        )

    try:
        result = await loop.run_in_executor(
            None, lambda: _run_phase1(org, repo, url, branch, pat),
        )
        await _save("complete", "running", phase1_result=result)
        logger.info("Phase 1 complete for %s/%s", org, repo)
    except Exception as e:
        # logger.exception keeps the traceback alongside the message.
        logger.exception("Phase 1 failed for %s/%s: %s", org, repo, e)
        await _save("failed", "pending", error=str(e))
        return

    try:
        enrich_result = await loop.run_in_executor(
            None, lambda: _run_phase2(org, repo),
        )
        await _save(
            "complete", "complete",
            phase1_result=result,
            phase2_result=enrich_result,
        )
        logger.info("Phase 2 complete for %s/%s", org, repo)
    except Exception as e:
        logger.exception("Phase 2 failed for %s/%s: %s", org, repo, e)
        await _save(
            "complete", "failed",
            error=str(e),
            phase1_result=result,
        )
f"{req.username}:{org}:{repo}" + store = _get_code_store() - existing = _scan_jobs.get(job_id) + existing = store.get_scanner_job(job_id) if existing and existing.get("phase1_status") == "running": return JSONResponse({ "status": "ok", @@ -209,29 +437,93 @@ async def start_scan(req: ScanRequest): "repo": repo, "message": "Scan already in progress", "phase1_status": "running", - "phase2_status": "pending", + "phase2_status": existing.get("phase2_status", "pending"), }) clone_url = req.github_url.strip().rstrip("/") if not clone_url.endswith(".git"): clone_url += ".git" - _scan_jobs[job_id] = { - "username": req.username, - "org": org, - "repo": repo, - "url": clone_url, - "branch": req.branch, - "phase1_status": "running", - "phase2_status": "pending", - "started_at": time.time(), - "phase1_result": None, - "phase2_result": None, - "error": None, - } + branch = (req.branch or "main").strip() + loop = asyncio.get_running_loop() + remote_sha = await loop.run_in_executor( + None, lambda: _get_branch_tip_sha(org, repo, branch, req.pat), + ) + + full_reuse, phase2_only = _can_reuse_index(org, repo, remote_sha) + + if full_reuse: + now = time.time() + store.upsert_scanner_job( + job_id=job_id, + username=req.username, + org=org, + repo=repo, + branch=branch, + url=clone_url, + phase1_status="complete", + phase2_status="complete", + started_at=now, + error=None, + phase1_result=None, + phase2_result=None, + ) + store.upsert_user_repo_entry( + req.username, org, repo, branch, + last_seen_commit=remote_sha, + ) + return JSONResponse({ + "status": "ok", + "job_id": job_id, + "org": org, + "repo": repo, + "reused": True, + "message": ( + "This revision is already indexed in the shared catalog. " + "Connected without re-scanning." 
+ ), + "commit_sha": remote_sha, + "phase1_status": "complete", + "phase2_status": "complete", + }) + + started_at = time.time() + store.upsert_scanner_job( + job_id=job_id, + username=req.username, + org=org, + repo=repo, + branch=branch, + url=clone_url, + phase1_status="running" if not phase2_only else "complete", + phase2_status="running" if phase2_only else "pending", + started_at=started_at, + error=None, + ) + store.upsert_user_repo_entry(req.username, org, repo, branch) + + if phase2_only: + asyncio.create_task( + _run_phase2_pipeline_only( + job_id, req.username, org, repo, clone_url, branch, + ), + ) + return JSONResponse({ + "status": "ok", + "job_id": job_id, + "org": org, + "repo": repo, + "reused": False, + "phase2_only": True, + "message": "Index exists; running Phase 2 (LLM enrichment) only.", + "phase1_status": "complete", + "phase2_status": "running", + }) asyncio.create_task( - _run_scan_pipeline(job_id, org, repo, clone_url, req.branch, req.pat), + _run_scan_pipeline( + job_id, req.username, org, repo, clone_url, branch, req.pat, + ), ) return JSONResponse({ @@ -239,6 +531,7 @@ async def start_scan(req: ScanRequest): "job_id": job_id, "org": org, "repo": repo, + "reused": False, "phase1_status": "running", "phase2_status": "pending", }) @@ -251,7 +544,8 @@ async def scan_status( repo: str = Query(...), ): job_id = f"{username}:{org_id}:{repo}" - job = _scan_jobs.get(job_id) + store = _get_code_store() + job = store.get_scanner_job(job_id) if not job: return JSONResponse({ @@ -260,32 +554,52 @@ async def scan_status( "phase2_status": "not_started", }) - elapsed = time.time() - job["started_at"] + elapsed = time.time() - float(job.get("started_at", time.time())) resp: Dict[str, Any] = { "status": "ok", - "phase1_status": job["phase1_status"], - "phase2_status": job["phase2_status"], + "phase1_status": job.get("phase1_status", "not_started"), + "phase2_status": job.get("phase2_status", "not_started"), "elapsed_seconds": round(elapsed, 1), "error": 
@router.get("/repos", summary="List scanned repositories for a user")
async def list_repos(username: str = Query(...)):
    """List the repositories known for ``username``.

    Rows from the user's persisted scanner jobs come first and carry their
    stored phase statuses. Bookmarked repos that have no job row follow and
    are reported as ``not_started``. Each ``(org, repo)`` pair appears at
    most once; rows missing either value are skipped.

    Returns:
        ``{"status": "ok", "repos": [...]}`` as a ``JSONResponse``.
    """
    store = _get_code_store()
    loop = asyncio.get_running_loop()

    # The store queries are synchronous; run both in the default executor so
    # the event loop stays free while they execute.
    jobs, bookmarks = await loop.run_in_executor(
        None,
        lambda: (
            store.list_scanner_jobs_for_user(username),
            store.list_user_repos_for_user(username),
        ),
    )

    seen: set[tuple[str, str]] = set()
    repos: list[Dict[str, Any]] = []

    def _add(org: Any, repo: Any, phase1_status: str, phase2_status: str) -> None:
        """Append one row unless it is incomplete or already listed."""
        if not org or not repo or (org, repo) in seen:
            return
        seen.add((org, repo))
        repos.append({
            "org": org,
            "repo": repo,
            "phase1_status": phase1_status,
            "phase2_status": phase2_status,
        })

    for job in jobs:
        _add(
            job.get("org"),
            job.get("repo"),
            job.get("phase1_status", "not_started"),
            job.get("phase2_status", "not_started"),
        )
    for row in bookmarks:
        _add(row.get("github_org"), row.get("repo"), "not_started", "not_started")

    return JSONResponse({"status": "ok", "repos": repos})
Set false only to debug " + "proxy or cert issues — never in production." + ), + ) + embedding_model: str = Field( default="gemini-embedding-001", description="Embedding model name (e.g. gemini-embedding-001, amazon.nova-2-multimodal-embeddings-v1:0)" @@ -159,6 +175,10 @@ class Settings(BaseSettings): ..., description="Neo4j password (required)" ) + neo4j_connection_timeout: float = Field( + default=60.0, + description="Neo4j driver TCP+handshake timeout in seconds (NEO4J_CONNECTION_TIMEOUT)", + ) api_host: str = Field( default="0.0.0.0", @@ -176,6 +196,36 @@ class Settings(BaseSettings): default=60, description="Rate limit (requests per minute)" ) + + # Scanner dashboard — heuristic pre-scan estimates (tunable) + scanner_estimate_phase1_base_seconds: float = Field( + default=45.0, + description="Base seconds added to Phase 1 time estimate", + ) + scanner_estimate_phase1_seconds_per_mb: float = Field( + default=8.0, + description="Additional Phase 1 seconds per MB of GitHub repo size", + ) + scanner_estimate_embedding_calls_per_mb: float = Field( + default=15.0, + description="Rough embedding API calls per MB for Phase 1", + ) + scanner_estimate_avg_tokens_per_embedding: int = Field( + default=256, + description="Assumed average tokens billed per embedding request", + ) + scanner_estimate_llm_tokens_per_mb: float = Field( + default=12000.0, + description="Rough upper-bound LLM tokens for Phase 2 per MB of repo", + ) + scanner_estimate_embedding_cost_per_1m_tokens: Optional[float] = Field( + default=None, + description="USD per 1M embedding tokens (optional; enables cost estimate)", + ) + scanner_estimate_llm_cost_per_1m_tokens: Optional[float] = Field( + default=None, + description="USD per 1M LLM tokens for Phase 2 (optional)", + ) max_request_body_bytes: int = Field( default=10 * 1024 * 1024, description="Maximum request body size in bytes (default 10MB)" diff --git a/src/scanner/code_store.py b/src/scanner/code_store.py index afbf35b..ddc73d2 100644 --- 
a/src/scanner/code_store.py +++ b/src/scanner/code_store.py @@ -5,6 +5,8 @@ raw_symbols — raw source code for each function/class (keyed by content hash) raw_files — raw content for each file scan_runs — tracks nightly scan state (last SHA, timestamps, stats) + scanner_jobs — dashboard scan job state (persists across API restarts) + scanner_user_repos — per-user repo rows for listing (shared index; key by user + org + repo) The raw code is stored here so the retrieval pipeline can fetch exact function bodies via ``get_symbol_code()`` without hitting the LLM or @@ -51,6 +53,8 @@ def __init__( self.symbols = self._db["raw_symbols"] self.files = self._db["raw_files"] self.scan_runs = self._db["scan_runs"] + self.scanner_jobs = self._db["scanner_jobs"] + self.scanner_user_repos = self._db["scanner_user_repos"] self._ensure_indexes() @@ -71,6 +75,109 @@ def _ensure_indexes(self) -> None: unique=True, ) + self.scanner_jobs.create_index([("job_id", 1)], unique=True) + self.scanner_jobs.create_index([("username", 1), ("updated_at", -1)]) + + self.scanner_user_repos.create_index( + [("username", 1), ("github_org", 1), ("repo", 1)], + unique=True, + ) + self.scanner_user_repos.create_index([("username", 1)]) + + # ====================================================================== + # SCANNER DASHBOARD — job + per-user repo listing + # ====================================================================== + + def upsert_scanner_job( + self, + job_id: str, + username: str, + org: str, + repo: str, + branch: str, + url: str, + phase1_status: str, + phase2_status: str, + started_at: float, + error: Optional[str] = None, + phase1_result: Optional[Dict[str, Any]] = None, + phase2_result: Optional[Dict[str, Any]] = None, + ) -> None: + """Persist or update scanner dashboard job state.""" + doc: Dict[str, Any] = { + "job_id": job_id, + "username": username, + "org": org, + "repo": repo, + "branch": branch, + "url": url, + "phase1_status": phase1_status, + "phase2_status": 
def upsert_user_repo_entry(
    self,
    username: str,
    github_org: str,
    repo: str,
    branch: str,
    last_seen_commit: Optional[str] = None,
) -> None:
    """Record that this user has this repo in their list (shared index).

    Upserts one row in ``scanner_user_repos`` keyed by
    ``(username, github_org, repo)``.

    Args:
        username: Owner of the listing row.
        github_org: GitHub organisation / owner of the repository.
        repo: Repository name.
        branch: Branch stored on the row (overwritten on every call).
        last_seen_commit: When truthy, stored as ``last_seen_commit``;
            otherwise any previously stored value is left untouched.
    """
    # A single timestamp for the whole write, so a freshly inserted row has
    # created_at == updated_at instead of created_at slightly later.
    now = datetime.now(timezone.utc)

    key: Dict[str, Any] = {
        "username": username,
        "github_org": github_org,
        "repo": repo,
    }
    set_doc: Dict[str, Any] = {**key, "branch": branch, "updated_at": now}
    if last_seen_commit:
        set_doc["last_seen_commit"] = last_seen_commit

    self.scanner_user_repos.update_one(
        key,
        {
            "$set": set_doc,
            # Only written when the upsert creates the row.
            "$setOnInsert": {"created_at": now},
        },
        upsert=True,
    )