From 81f3bdae97ac939a900d30d93755b3e93e71a59e Mon Sep 17 00:00:00 2001 From: Luke Mainwaring Date: Fri, 20 Mar 2026 14:58:37 -0400 Subject: [PATCH 1/2] one shot vs. loop --- backend/pyproject.toml | 2 +- .../src/samplespace/agents/sample_agent.py | 9 +++ .../agents/tools/analysis_tools.py | 18 +++-- .../samplespace/agents/tools/clap_tools.py | 10 ++- .../src/samplespace/agents/tools/cnn_tools.py | 10 ++- backend/src/samplespace/core/config.py | 3 +- .../c1e3a6c19709_add_is_loop_column.py | 30 ++++++++ backend/src/samplespace/models/sample.py | 6 +- backend/src/samplespace/schemas/sample.py | 2 + backend/src/samplespace/scripts/__init__.py | 22 ++++++ backend/src/samplespace/scripts/embed_cnn.py | 76 +++++++------------ .../src/samplespace/scripts/embed_samples.py | 68 ++++++----------- backend/src/samplespace/scripts/seed.py | 41 ++++------ .../samplespace/services/audio_analysis.py | 62 ++++++++++++++- backend/src/samplespace/services/sample.py | 8 +- backend/uv.lock | 7 +- frontend/api/generated/types.gen.ts | 10 +++ frontend/api/scripts/outputs/openapi.json | 17 +++++ frontend/components/sample-browser.tsx | 30 +++++++- 19 files changed, 287 insertions(+), 144 deletions(-) create mode 100644 backend/src/samplespace/migrations/versions/c1e3a6c19709_add_is_loop_column.py diff --git a/backend/pyproject.toml b/backend/pyproject.toml index 9aaeaea..afa4e98 100644 --- a/backend/pyproject.toml +++ b/backend/pyproject.toml @@ -18,7 +18,7 @@ dependencies = [ "psycopg[binary,pool]>=3.3.3", "pydantic-ai>=1.70.0", "pydantic>=2.12.5", - "sqlalchemy[mypy]>=2.0.48", + "sqlalchemy[asyncio,mypy]>=2.0.48", "torch>=2.10.0", "torchaudio>=2.10.0", "transformers>=5.3.0", diff --git a/backend/src/samplespace/agents/sample_agent.py b/backend/src/samplespace/agents/sample_agent.py index 40fcadd..fc09895 100644 --- a/backend/src/samplespace/agents/sample_agent.py +++ b/backend/src/samplespace/agents/sample_agent.py @@ -57,6 +57,15 @@ - Be concise but informative — mention key, BPM, and type when relevant - If the user references a sample by name rather than ID, search for it first - NEVER generate URLs or markdown links — just use plain text and bold for emphasis + +## One-Shots vs Loops + +Samples are classified as either **one-shots** (single hits like a kick, snare, or chord stab) or **loops** (repeating patterns like a drum loop, bassline, or melodic phrase). Any sample type can be either — a kick can be a single hit or a kick pattern loop. + +- One-shots do NOT have meaningful key or BPM — never mention key/BPM for one-shots +- Do not use check_key_compatibility for one-shots +- When suggesting complements for a one-shot, focus on sonic character rather than key compatibility +- Check the is_loop field on each sample rather than assuming from the type """.strip() _model = OpenAIResponsesModel(config.AGENT_MODEL) diff --git a/backend/src/samplespace/agents/tools/analysis_tools.py b/backend/src/samplespace/agents/tools/analysis_tools.py index dac99f6..8fa4796 100644 --- a/backend/src/samplespace/agents/tools/analysis_tools.py +++ b/backend/src/samplespace/agents/tools/analysis_tools.py @@ -61,10 +61,12 @@ async def analyze_sample(ctx: RunContext[AgentDeps], sample_id: str) -> str: lines = [ f"**{sample.filename}**", f"- Type: {sample.sample_type or 'unknown'}", - f"- Key: {sample.key or 'unknown'}", - f"- BPM: {sample.bpm or 'unknown'}", - f"- Duration: {sample.duration:.1f}s" if sample.duration else "- Duration: unknown", + f"- Category: {'loop' if sample.is_loop else 'one-shot'}", ] + if sample.is_loop: + lines.append(f"- Key: {sample.key or 'unknown'}") + lines.append(f"- BPM: {sample.bpm or 'unknown'}") + lines.append(f"- Duration: {sample.duration:.1f}s" if sample.duration else "- Duration: unknown") return "\n".join(lines) except Exception: logger.exception("Error analyzing sample") @@ -161,10 +163,14 @@ async def suggest_complement( if not filtered: return "No complementary samples found." - lines = [f"Samples that complement **{source.filename}** (key: {source.key or 'unknown'}):\n"] + if source.is_loop: + header = f"Samples that complement **{source.filename}** (key: {source.key or 'unknown'}):\n" + else: + header = f"Samples that complement **{source.filename}** (one-shot):\n" + lines = [header] for i, s in enumerate(filtered[:8], 1): compat = "" - if source.key and s.key: + if source.is_loop and source.key and s.key: if source.key == s.key: compat = " ✓ same key" elif RELATIVE_PAIRS.get(source.key) == s.key: @@ -172,7 +178,7 @@ async def suggest_complement( parts = [f"{i}. **{s.filename}**"] if s.sample_type: parts.append(f"type={s.sample_type}") - if s.key: + if s.is_loop and s.key: parts.append(f"key={s.key}{compat}") parts.append(f"id={s.id}") lines.append(" | ".join(parts)) diff --git a/backend/src/samplespace/agents/tools/clap_tools.py b/backend/src/samplespace/agents/tools/clap_tools.py index a1cc136..96ecbd2 100644 --- a/backend/src/samplespace/agents/tools/clap_tools.py +++ b/backend/src/samplespace/agents/tools/clap_tools.py @@ -41,10 +41,12 @@ def _format_results(results: list[SampleSchema], query: str) -> str: parts = [f"{i}. **{s.filename}**"] if s.sample_type: parts.append(f"type={s.sample_type}") - if s.key: - parts.append(f"key={s.key}") - if s.bpm and s.bpm > 0: - parts.append(f"bpm={s.bpm}") + parts.append("loop" if s.is_loop else "one-shot") + if s.is_loop: + if s.key: + parts.append(f"key={s.key}") + if s.bpm and s.bpm > 0: + parts.append(f"bpm={s.bpm}") if s.duration: parts.append(f"duration={s.duration:.1f}s") parts.append(f"id={s.id}") diff --git a/backend/src/samplespace/agents/tools/cnn_tools.py b/backend/src/samplespace/agents/tools/cnn_tools.py index 79b7112..664cddc 100644 --- a/backend/src/samplespace/agents/tools/cnn_tools.py +++ b/backend/src/samplespace/agents/tools/cnn_tools.py @@ -40,10 +40,12 @@ def _format_results(results: list[SampleSchema], source_name: str) -> str: parts = [f"{i}. **{s.filename}**"] if s.sample_type: parts.append(f"type={s.sample_type}") - if s.key: - parts.append(f"key={s.key}") - if s.bpm and s.bpm > 0: - parts.append(f"bpm={s.bpm}") + parts.append("loop" if s.is_loop else "one-shot") + if s.is_loop: + if s.key: + parts.append(f"key={s.key}") + if s.bpm and s.bpm > 0: + parts.append(f"bpm={s.bpm}") parts.append(f"id={s.id}") lines.append(" | ".join(parts)) return "\n".join(lines) diff --git a/backend/src/samplespace/core/config.py b/backend/src/samplespace/core/config.py index 261bc99..b67ada8 100644 --- a/backend/src/samplespace/core/config.py +++ b/backend/src/samplespace/core/config.py @@ -41,8 +41,9 @@ class Settings( """Main application settings.""" model_config = SettingsConfigDict( - env_file=str(pathlib.Path(__file__).parent.parent.parent.parent / ".env"), + env_file=str(pathlib.Path(__file__).parent.parent.parent.parent.parent / ".env"), env_ignore_empty=True, + extra="ignore", ) ENVIRONMENT: Literal["development", "production"] = "development" diff --git a/backend/src/samplespace/migrations/versions/c1e3a6c19709_add_is_loop_column.py b/backend/src/samplespace/migrations/versions/c1e3a6c19709_add_is_loop_column.py new file mode 100644 index 0000000..701c8c3 --- /dev/null +++ b/backend/src/samplespace/migrations/versions/c1e3a6c19709_add_is_loop_column.py @@ -0,0 +1,30 @@ +"""add_is_loop_column + +Revision ID: c1e3a6c19709 +Revises: f2483c5c35f2 +Create Date: 2026-03-20 18:10:05.440003 + +""" + +from typing import Sequence, Union + +import sqlalchemy as sa +from alembic import op + +# revision identifiers, used by Alembic. +revision: str = "c1e3a6c19709" +down_revision: Union[str, None] = "f2483c5c35f2" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.add_column("samples", sa.Column("is_loop", sa.Boolean(), nullable=False, server_default=sa.text("false"))) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_column("samples", "is_loop") + # ### end Alembic commands ### diff --git a/backend/src/samplespace/models/sample.py b/backend/src/samplespace/models/sample.py index 201e18e..20952aa 100644 --- a/backend/src/samplespace/models/sample.py +++ b/backend/src/samplespace/models/sample.py @@ -7,7 +7,7 @@ from fastapi import HTTPException from pgvector.sqlalchemy import Vector -from sqlalchemy import DateTime, Float, String, cast, func, select +from sqlalchemy import Boolean, DateTime, Float, String, cast, func, select from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import Mapped, mapped_column @@ -33,6 +33,7 @@ class Sample(Base): bpm: Mapped[float | None] duration: Mapped[float | None] sample_type: Mapped[str | None] = mapped_column(String(50)) + is_loop: Mapped[bool] = mapped_column(Boolean, default=False) clap_embedding: Mapped[list[float] | None] = mapped_column(Vector(512)) cnn_embedding: Mapped[list[float] | None] = mapped_column(Vector(128)) created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), server_default=func.now()) @@ -72,6 +73,7 @@ async def search_by_clap( bpm_min: float | None = None, bpm_max: float | None = None, sample_type: str | None = None, + is_loop: bool | None = None, limit: int = 20, ) -> Sequence[Sample]: """Search samples by CLAP embedding using pgvector cosine distance.""" @@ -87,6 +89,8 @@ async def search_by_clap( stmt = stmt.where(cls.bpm <= bpm_max) if sample_type is not None: stmt = stmt.where(cls.sample_type == sample_type) + if is_loop is not None: + stmt = stmt.where(cls.is_loop == is_loop) stmt = stmt.order_by(distance).limit(limit) diff --git a/backend/src/samplespace/schemas/sample.py b/backend/src/samplespace/schemas/sample.py index 8d3630a..cdb5f9f 100644 --- a/backend/src/samplespace/schemas/sample.py +++ b/backend/src/samplespace/schemas/sample.py @@ -16,6 +16,7 @@ class SampleSchema(BaseSchema): bpm: float | None = None duration: float | None = None sample_type: str | None = Field(None, description="Category of the sample (e.g., kick, snare, pad, lead)") + is_loop: bool = False created_at: datetime @@ -34,4 +35,5 @@ class SampleSearchRequest(BaseModel): bpm_min: float | None = Field(None, description="Minimum BPM filter") bpm_max: float | None = Field(None, description="Maximum BPM filter") sample_type: str | None = Field(None, description="Filter by sample type") + is_loop: bool | None = Field(None, description="Filter by loop (True) or one-shot (False)") limit: int = Field(20, ge=1, le=100, description="Maximum number of results") diff --git a/backend/src/samplespace/scripts/__init__.py b/backend/src/samplespace/scripts/__init__.py index e69de29..47c55c2 100644 --- a/backend/src/samplespace/scripts/__init__.py +++ b/backend/src/samplespace/scripts/__init__.py @@ -0,0 +1,22 @@ +"""Shared utilities for CLI scripts.""" + +from pathlib import Path + +from samplespace.core.config import get_settings + + +def find_audio_file(filename: str, sample_type: str | None) -> Path | None: + """Locate an audio file by checking type subdirectory, root, then rglob.""" + samples_dir = Path(get_settings().SAMPLES_DIR) + + if sample_type: + candidate = samples_dir / sample_type / filename + if candidate.exists(): + return candidate + + candidate = samples_dir / filename + if candidate.exists(): + return candidate + + matches = list(samples_dir.rglob(filename)) + return matches[0] if matches else None diff --git a/backend/src/samplespace/scripts/embed_cnn.py b/backend/src/samplespace/scripts/embed_cnn.py index bd4465a..b25f718 100644 --- a/backend/src/samplespace/scripts/embed_cnn.py +++ b/backend/src/samplespace/scripts/embed_cnn.py @@ -9,87 +9,69 @@ """ import argparse +import asyncio import logging -from pathlib import Path -from sqlalchemy import create_engine, select, update -from sqlalchemy.orm import Session +from sqlalchemy import select, update -from samplespace.core.config import get_settings -from samplespace.ml.predict import load_model, predict +from samplespace.dependencies.db import get_async_sqlalchemy_session from samplespace.models.sample import Sample +from samplespace.scripts import find_audio_file logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s") logger = logging.getLogger(__name__) -def main() -> None: - parser = argparse.ArgumentParser(description="Generate CNN embeddings for all samples") - parser.add_argument( - "--force", - action="store_true", - help="Re-embed samples that already have embeddings", - ) - args = parser.parse_args() - - config = get_settings() - sync_url = ( - f"postgresql+psycopg://{config.POSTGRES_USER}:{config.POSTGRES_PASSWORD}" - f"@{config.POSTGRES_HOST}:{config.POSTGRES_PORT}/{config.POSTGRES_DB}" - ) - engine = create_engine(sync_url) - SAMPLES_DIR = Path(config.SAMPLES_DIR) +async def generate_embeddings(*, force: bool = False) -> None: + from samplespace.ml.predict import load_model, predict model = load_model() - with Session(engine) as session: + async with get_async_sqlalchemy_session() as db: stmt = select(Sample) - if not args.force: + if not force: stmt = stmt.where(Sample.cnn_embedding.is_(None)) - samples = session.execute(stmt).scalars().all() + result = await db.execute(stmt) + samples = result.scalars().all() logger.info(f"Found {len(samples)} samples to embed") embedded = 0 for sample in samples: - # Find audio file - file_path = None - if sample.sample_type: - candidate = SAMPLES_DIR / sample.sample_type / sample.filename - if candidate.exists(): - file_path = candidate - - if file_path is None: - candidate = SAMPLES_DIR / sample.filename - if candidate.exists(): - file_path = candidate - - if file_path is None: - matches = list(SAMPLES_DIR.rglob(sample.filename)) - if matches: - file_path = matches[0] - + file_path = find_audio_file(sample.filename, sample.sample_type) if file_path is None: logger.warning(f" Audio file not found: {sample.filename}") continue try: - result = predict(str(file_path), model) - session.execute(update(Sample).where(Sample.id == sample.id).values(cnn_embedding=result.embedding)) + pred_result = predict(str(file_path), model) + await db.execute( + update(Sample).where(Sample.id == sample.id).values(cnn_embedding=pred_result.embedding) + ) embedded += 1 logger.info( f" Embedded: {sample.filename} " - f"(predicted: {result.predicted_type}, " - f"confidence: {result.confidence:.2%})" + f"(predicted: {pred_result.predicted_type}, " + f"confidence: {pred_result.confidence:.2%})" ) except Exception: logger.warning(f" Failed to embed: {sample.filename}", exc_info=True) continue - session.commit() - logger.info(f"Embedded {embedded}/{len(samples)} samples") +def main() -> None: + parser = argparse.ArgumentParser(description="Generate CNN embeddings for all samples") + parser.add_argument( + "--force", + action="store_true", + help="Re-embed samples that already have embeddings", + ) + args = parser.parse_args() + + asyncio.run(generate_embeddings(force=args.force)) + + if __name__ == "__main__": main() diff --git a/backend/src/samplespace/scripts/embed_samples.py b/backend/src/samplespace/scripts/embed_samples.py index 31cd76a..b9e692c 100644 --- a/backend/src/samplespace/scripts/embed_samples.py +++ b/backend/src/samplespace/scripts/embed_samples.py @@ -9,86 +9,64 @@ """ import argparse +import asyncio import logging -from pathlib import Path -from sqlalchemy import create_engine, select, update -from sqlalchemy.orm import Session +from sqlalchemy import select, update -from samplespace.core.config import get_settings +from samplespace.dependencies.db import get_async_sqlalchemy_session from samplespace.models.sample import Sample +from samplespace.scripts import find_audio_file from samplespace.services.embedding import embed_audio, load_clap_model logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s") logger = logging.getLogger(__name__) -def main() -> None: - parser = argparse.ArgumentParser(description="Generate CLAP embeddings for all samples") - parser.add_argument( - "--force", - action="store_true", - help="Re-embed samples that already have embeddings", - ) - args = parser.parse_args() - - config = get_settings() - sync_url = ( - f"postgresql+psycopg://{config.POSTGRES_USER}:{config.POSTGRES_PASSWORD}" - f"@{config.POSTGRES_HOST}:{config.POSTGRES_PORT}/{config.POSTGRES_DB}" - ) - engine = create_engine(sync_url) - SAMPLES_DIR = Path(config.SAMPLES_DIR) - +async def generate_embeddings(*, force: bool = False) -> None: # Load CLAP model model, processor = load_clap_model() - with Session(engine) as session: + async with get_async_sqlalchemy_session() as db: # Query samples that need embedding stmt = select(Sample) - if not args.force: + if not force: stmt = stmt.where(Sample.clap_embedding.is_(None)) - samples = session.execute(stmt).scalars().all() + result = await db.execute(stmt) + samples = result.scalars().all() logger.info(f"Found {len(samples)} samples to embed") embedded = 0 for sample in samples: - # Find audio file — check category subdirectories - file_path = None - if sample.sample_type: - candidate = SAMPLES_DIR / sample.sample_type / sample.filename - if candidate.exists(): - file_path = candidate - - if file_path is None: - candidate = SAMPLES_DIR / sample.filename - if candidate.exists(): - file_path = candidate - - if file_path is None: - # Search all subdirectories - matches = list(SAMPLES_DIR.rglob(sample.filename)) - if matches: - file_path = matches[0] - + file_path = find_audio_file(sample.filename, sample.sample_type) if file_path is None: logger.warning(f" Audio file not found: {sample.filename}") continue try: embedding = embed_audio(str(file_path), model, processor) - session.execute(update(Sample).where(Sample.id == sample.id).values(clap_embedding=embedding)) + await db.execute(update(Sample).where(Sample.id == sample.id).values(clap_embedding=embedding)) embedded += 1 logger.info(f" Embedded: {sample.filename}") except Exception: logger.warning(f" Failed to embed: {sample.filename}", exc_info=True) continue - session.commit() - logger.info(f"Embedded {embedded}/{len(samples)} samples") +def main() -> None: + parser = argparse.ArgumentParser(description="Generate CLAP embeddings for all samples") + parser.add_argument( + "--force", + action="store_true", + help="Re-embed samples that already have embeddings", + ) + args = parser.parse_args() + + asyncio.run(generate_embeddings(force=args.force)) + + if __name__ == "__main__": main() diff --git a/backend/src/samplespace/scripts/seed.py b/backend/src/samplespace/scripts/seed.py index 8c59533..7cdc627 100644 --- a/backend/src/samplespace/scripts/seed.py +++ b/backend/src/samplespace/scripts/seed.py @@ -10,16 +10,17 @@ """ import argparse +import asyncio import logging import uuid from pathlib import Path -from sqlalchemy import create_engine, text -from sqlalchemy.orm import Session +from sqlalchemy import select from samplespace.core.config import get_settings -from samplespace.models import Base, Sample -from samplespace.services.audio_analysis import analyze_audio +from samplespace.dependencies.db import get_async_sqlalchemy_session +from samplespace.models.sample import Sample +from samplespace.services.audio_analysis import analyze_audio, infer_is_loop logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s") logger = logging.getLogger(__name__) @@ -58,24 +59,13 @@ def scan_local_samples( return found -def seed_database(samples: list[tuple[Path, str | None]]) -> int: +async def seed_database(samples: list[tuple[Path, str | None]]) -> int: """Analyze audio files and insert sample records into the database.""" - config = get_settings() - sync_url = f"postgresql+psycopg://{config.POSTGRES_USER}:{config.POSTGRES_PASSWORD}@{config.POSTGRES_HOST}:{config.POSTGRES_PORT}/{config.POSTGRES_DB}" - - engine = create_engine(sync_url) - - # Ensure pgvector extension exists - with engine.connect() as conn: - conn.execute(text("CREATE EXTENSION IF NOT EXISTS vector")) - conn.commit() - - Base.metadata.create_all(engine) - inserted = 0 - with Session(engine) as session: + async with get_async_sqlalchemy_session() as db: # Get existing filenames to avoid duplicates - existing = {row[0] for row in session.execute(text("SELECT filename FROM samples")).fetchall()} + result = await db.execute(select(Sample.filename)) + existing = {row[0] for row in result.all()} for file_path, sample_type in samples: filename = file_path.name @@ -84,7 +74,8 @@ def seed_database(samples: list[tuple[Path, str | None]]) -> int: continue try: - metadata = analyze_audio(str(file_path)) + is_loop = infer_is_loop(str(file_path)) + metadata = analyze_audio(str(file_path), is_loop=is_loop) except Exception: logger.warning(f" Failed to analyze: {filename}") continue @@ -96,17 +87,17 @@ def seed_database(samples: list[tuple[Path, str | None]]) -> int: bpm=metadata.bpm, duration=metadata.duration, sample_type=sample_type, + is_loop=is_loop, ) - session.add(sample) + db.add(sample) inserted += 1 logger.info( f" Inserted: {filename} " f"(key={metadata.key}, bpm={metadata.bpm}, " - f"duration={metadata.duration:.1f}s, type={sample_type})" + f"duration={metadata.duration:.1f}s, type={sample_type}, " + f"is_loop={is_loop})" ) - session.commit() - logger.info(f"Seeded {inserted} new samples (skipped {len(samples) - inserted} duplicates)") return inserted @@ -130,7 +121,7 @@ def main() -> None: ) return - seed_database(samples) + asyncio.run(seed_database(samples)) if __name__ == "__main__": diff --git a/backend/src/samplespace/services/audio_analysis.py b/backend/src/samplespace/services/audio_analysis.py index e1f329e..aa0926b 100644 --- a/backend/src/samplespace/services/audio_analysis.py +++ b/backend/src/samplespace/services/audio_analysis.py @@ -9,17 +9,73 @@ logger = logging.getLogger(__name__) +LOOP_KEYWORDS = {"loop", "loops", "looped"} +ONE_SHOT_KEYWORDS = {"one-shot", "oneshot", "one_shot", "one shot", "hit", "hits", "single"} -def analyze_audio(file_path: str) -> AudioMetadata: + +def infer_is_loop(file_path: str) -> bool: + """Infer whether an audio file is a loop or a one-shot. + + Uses a tiered approach: + 1. Check filepath for explicit keywords (loop, one-shot, hit, etc.) + 2. Fall back to audio heuristics (duration, onset density, onset regularity) + """ + path_lower = file_path.lower() + + # Tier 1: Filepath keywords + for keyword in LOOP_KEYWORDS: + if keyword in path_lower: + logger.info(f"Inferred is_loop=True for {file_path} (keyword: {keyword})") + return True + for keyword in ONE_SHOT_KEYWORDS: + if keyword in path_lower: + logger.info(f"Inferred is_loop=False for {file_path} (keyword: {keyword})") + return False + + # Tier 2: Audio heuristics + y, sr = librosa.load(file_path, sr=22050, mono=True) + duration = float(librosa.get_duration(y=y, sr=sr)) + onsets = librosa.onset.onset_detect(y=y, sr=sr) + num_onsets = len(onsets) + + is_short = duration < 2.0 + is_long = duration > 4.0 + few_onsets = num_onsets <= 2 + + # Check onset regularity for potential loops + has_regular_rhythm = False + if num_onsets >= 4: + onset_times = librosa.frames_to_time(onsets, sr=sr) + intervals = np.diff(onset_times) + if len(intervals) > 0 and np.mean(intervals) > 0: + cv = float(np.std(intervals) / np.mean(intervals)) + has_regular_rhythm = cv < 0.25 + + if is_short and few_onsets: + result = False + elif is_long or has_regular_rhythm: + result = True + else: + result = False # Default to one-shot for ambiguous cases + + logger.info(f"Inferred is_loop={result} for {file_path} (duration={duration:.1f}s, onsets={num_onsets}, heuristic)") + return result + + +def analyze_audio(file_path: str, *, is_loop: bool = True) -> AudioMetadata: """Analyze an audio file and extract key metadata. - Returns dict with keys: key, bpm, duration. + When is_loop is False, only computes duration (key/BPM are meaningless for one-shots). """ y, sr = librosa.load(file_path, sr=22050, mono=True) - # Duration + # Duration (always computed) duration = float(librosa.get_duration(y=y, sr=sr)) + if not is_loop: + logger.info(f"Analyzed {file_path}: one-shot, duration={duration:.1f}s (skipped key/BPM)") + return AudioMetadata(key=None, bpm=None, duration=round(duration, 2)) + # BPM tempo, _ = librosa.beat.beat_track(y=y, sr=sr) tempo_val = float(tempo[0]) if isinstance(tempo, np.ndarray) else float(tempo) diff --git a/backend/src/samplespace/services/sample.py b/backend/src/samplespace/services/sample.py index 0ae360c..39812b5 100644 --- a/backend/src/samplespace/services/sample.py +++ b/backend/src/samplespace/services/sample.py @@ -7,7 +7,7 @@ from samplespace.models.sample import Sample from samplespace.schemas.sample import SampleListResponse, SampleSchema -from samplespace.services.audio_analysis import analyze_audio +from samplespace.services.audio_analysis import analyze_audio, infer_is_loop logger = logging.getLogger(__name__) @@ -20,7 +20,8 @@ async def create_sample( sample_type: str | None = None, ) -> Sample: """Create a sample from an audio file: analyze metadata and persist.""" - metadata = analyze_audio(file_path) + is_loop = infer_is_loop(file_path) + metadata = analyze_audio(file_path, is_loop=is_loop) sample = Sample( id=str(uuid.uuid4()), @@ -29,6 +30,7 @@ async def create_sample( bpm=metadata.bpm, duration=metadata.duration, sample_type=sample_type, + is_loop=is_loop, ) db.add(sample) await db.flush() @@ -65,6 +67,7 @@ async def search_by_text( bpm_min: float | None = None, bpm_max: float | None = None, sample_type: str | None = None, + is_loop: bool | None = None, limit: int = 20, ) -> list[SampleSchema]: """Search samples by CLAP text embedding with optional metadata filters. @@ -78,6 +81,7 @@ async def search_by_text( bpm_min=bpm_min, bpm_max=bpm_max, sample_type=sample_type, + is_loop=is_loop, limit=limit, ) diff --git a/backend/uv.lock b/backend/uv.lock index f3c998b..052dc60 100644 --- a/backend/uv.lock +++ b/backend/uv.lock @@ -2932,7 +2932,7 @@ dependencies = [ { name = "psycopg", extra = ["binary", "pool"] }, { name = "pydantic" }, { name = "pydantic-ai" }, - { name = "sqlalchemy", extra = ["mypy"] }, + { name = "sqlalchemy", extra = ["asyncio", "mypy"] }, { name = "torch" }, { name = "torchaudio" }, { name = "transformers" }, @@ -2958,7 +2958,7 @@ requires-dist = [ { name = "psycopg", extras = ["binary", "pool"], specifier = ">=3.3.3" }, { name = "pydantic", specifier = ">=2.12.5" }, { name = "pydantic-ai", specifier = ">=1.70.0" }, - { name = "sqlalchemy", extras = ["mypy"], specifier = ">=2.0.48" }, + { name = "sqlalchemy", extras = ["asyncio", "mypy"], specifier = ">=2.0.48" }, { name = "torch", specifier = ">=2.10.0" }, { name = "torchaudio", specifier = ">=2.10.0" }, { name = "transformers", specifier = ">=5.3.0" }, @@ -3154,6 +3154,9 @@ wheels = [ ] [package.optional-dependencies] +asyncio = [ + { name = "greenlet" }, +] mypy = [ { name = "mypy" }, ] diff --git a/frontend/api/generated/types.gen.ts b/frontend/api/generated/types.gen.ts index 605b2b6..ebd1876 100644 --- a/frontend/api/generated/types.gen.ts +++ b/frontend/api/generated/types.gen.ts @@ -72,6 +72,10 @@ export type SampleSchema = { * Category of the sample (e.g., kick, snare, pad, lead) */ sample_type?: string | null; + /** + * Is Loop + */ + is_loop?: boolean; /** * Created At */ @@ -114,6 +118,12 @@ export type SampleSearchRequest = { * Filter by sample type */ sample_type?: string | null; + /** + * Is Loop + * + * Filter by loop (True) or one-shot (False) + */ + is_loop?: boolean | null; /** * Limit * diff --git a/frontend/api/scripts/outputs/openapi.json b/frontend/api/scripts/outputs/openapi.json index 772f20c..9f37766 100644 --- a/frontend/api/scripts/outputs/openapi.json +++ b/frontend/api/scripts/outputs/openapi.json @@ -394,6 +394,11 @@ "title": "Sample Type", "description": "Category of the sample (e.g., kick, snare, pad, lead)" }, + "is_loop": { + "type": "boolean", + "title": "Is Loop", + "default": false + }, "created_at": { "type": "string", "format": "date-time", @@ -471,6 +476,18 @@ "title": "Sample Type", "description": "Filter by sample type" }, + "is_loop": { + "anyOf": [ + { + "type": "boolean" + }, + { + "type": "null" + } + ], + "title": "Is Loop", + "description": "Filter by loop (True) or one-shot (False)" + }, "limit": { "type": "integer", "maximum": 100, diff --git a/frontend/components/sample-browser.tsx b/frontend/components/sample-browser.tsx index c901a86..75a7f89 100644 --- a/frontend/components/sample-browser.tsx +++ b/frontend/components/sample-browser.tsx @@ -57,6 +57,9 @@ function SampleCard({ {sample.sample_type} )} + + {sample.is_loop ? "loop" : "one-shot"} + {sample.key && ( {sample.key} @@ -86,6 +89,9 @@ function SampleCard({ } export function SampleBrowser() { + const [activeCategory, setActiveCategory] = useState< + "all" | "one-shot" | "loop" + >("all"); const [activeType, setActiveType] = useState(null); const [playingId, setPlayingId] = useState(null); @@ -96,9 +102,12 @@ export function SampleBrowser() { ); const samples = data?.samples ?? []; - const filteredSamples = activeType - ? samples.filter((s) => s.sample_type === activeType) - : samples; + const filteredSamples = samples.filter((s) => { + if (activeCategory === "one-shot" && s.is_loop) return false; + if (activeCategory === "loop" && !s.is_loop) return false; + if (activeType && s.sample_type !== activeType) return false; + return true; + }); const handleTogglePlay = useCallback((sample: SampleSchema) => { setPlayingId((prev) => (prev === sample.id ? null : sample.id)); @@ -119,6 +128,21 @@ export function SampleBrowser() {

+ {/* Category filter */} +
+ {(["all", "one-shot", "loop"] as const).map((cat) => ( + + ))} +
+ {/* Type filters */}