diff --git a/backend/pyproject.toml b/backend/pyproject.toml
index 9aaeaea..afa4e98 100644
--- a/backend/pyproject.toml
+++ b/backend/pyproject.toml
@@ -18,7 +18,7 @@ dependencies = [
"psycopg[binary,pool]>=3.3.3",
"pydantic-ai>=1.70.0",
"pydantic>=2.12.5",
- "sqlalchemy[mypy]>=2.0.48",
+ "sqlalchemy[asyncio,mypy]>=2.0.48",
"torch>=2.10.0",
"torchaudio>=2.10.0",
"transformers>=5.3.0",
diff --git a/backend/src/samplespace/agents/sample_agent.py b/backend/src/samplespace/agents/sample_agent.py
index 40fcadd..fc09895 100644
--- a/backend/src/samplespace/agents/sample_agent.py
+++ b/backend/src/samplespace/agents/sample_agent.py
@@ -57,6 +57,15 @@
- Be concise but informative — mention key, BPM, and type when relevant
- If the user references a sample by name rather than ID, search for it first
- NEVER generate URLs or markdown links — just use plain text and bold for emphasis
+
+## One-Shots vs Loops
+
+Samples are classified as either **one-shots** (single hits like a kick, snare, or chord stab) or **loops** (repeating patterns like a drum loop, bassline, or melodic phrase). Any sample type can be either — a kick can be a single hit or a kick pattern loop.
+
+- One-shots do NOT have meaningful key or BPM — never mention key/BPM for one-shots
+- Do not use check_key_compatibility for one-shots
+- When suggesting complements for a one-shot, focus on sonic character rather than key compatibility
+- Check the is_loop field on each sample rather than assuming from the type
""".strip()
_model = OpenAIResponsesModel(config.AGENT_MODEL)
diff --git a/backend/src/samplespace/agents/tools/analysis_tools.py b/backend/src/samplespace/agents/tools/analysis_tools.py
index dac99f6..28b6b8b 100644
--- a/backend/src/samplespace/agents/tools/analysis_tools.py
+++ b/backend/src/samplespace/agents/tools/analysis_tools.py
@@ -61,10 +61,12 @@ async def analyze_sample(ctx: RunContext[AgentDeps], sample_id: str) -> str:
lines = [
f"**{sample.filename}**",
f"- Type: {sample.sample_type or 'unknown'}",
- f"- Key: {sample.key or 'unknown'}",
- f"- BPM: {sample.bpm or 'unknown'}",
- f"- Duration: {sample.duration:.1f}s" if sample.duration else "- Duration: unknown",
+ f"- Category: {'loop' if sample.is_loop else 'one-shot'}",
]
+ if sample.is_loop:
+ lines.append(f"- Key: {sample.key or 'unknown'}")
+ lines.append(f"- BPM: {sample.bpm or 'unknown'}")
+ lines.append(f"- Duration: {sample.duration:.1f}s" if sample.duration is not None else "- Duration: unknown")
return "\n".join(lines)
except Exception:
logger.exception("Error analyzing sample")
@@ -161,10 +163,14 @@ async def suggest_complement(
if not filtered:
return "No complementary samples found."
- lines = [f"Samples that complement **{source.filename}** (key: {source.key or 'unknown'}):\n"]
+ if source.is_loop:
+ header = f"Samples that complement **{source.filename}** (key: {source.key or 'unknown'}):\n"
+ else:
+ header = f"Samples that complement **{source.filename}** (one-shot):\n"
+ lines = [header]
for i, s in enumerate(filtered[:8], 1):
compat = ""
- if source.key and s.key:
+ if source.is_loop and source.key and s.key:
if source.key == s.key:
compat = " ✓ same key"
elif RELATIVE_PAIRS.get(source.key) == s.key:
@@ -172,7 +178,7 @@ async def suggest_complement(
parts = [f"{i}. **{s.filename}**"]
if s.sample_type:
parts.append(f"type={s.sample_type}")
- if s.key:
+ if s.is_loop and s.key:
parts.append(f"key={s.key}{compat}")
parts.append(f"id={s.id}")
lines.append(" | ".join(parts))
diff --git a/backend/src/samplespace/agents/tools/clap_tools.py b/backend/src/samplespace/agents/tools/clap_tools.py
index a1cc136..96ecbd2 100644
--- a/backend/src/samplespace/agents/tools/clap_tools.py
+++ b/backend/src/samplespace/agents/tools/clap_tools.py
@@ -41,10 +41,12 @@ def _format_results(results: list[SampleSchema], query: str) -> str:
parts = [f"{i}. **{s.filename}**"]
if s.sample_type:
parts.append(f"type={s.sample_type}")
- if s.key:
- parts.append(f"key={s.key}")
- if s.bpm and s.bpm > 0:
- parts.append(f"bpm={s.bpm}")
+ parts.append("loop" if s.is_loop else "one-shot")
+ if s.is_loop:
+ if s.key:
+ parts.append(f"key={s.key}")
+ if s.bpm and s.bpm > 0:
+ parts.append(f"bpm={s.bpm}")
if s.duration:
parts.append(f"duration={s.duration:.1f}s")
parts.append(f"id={s.id}")
diff --git a/backend/src/samplespace/agents/tools/cnn_tools.py b/backend/src/samplespace/agents/tools/cnn_tools.py
index 79b7112..664cddc 100644
--- a/backend/src/samplespace/agents/tools/cnn_tools.py
+++ b/backend/src/samplespace/agents/tools/cnn_tools.py
@@ -40,10 +40,12 @@ def _format_results(results: list[SampleSchema], source_name: str) -> str:
parts = [f"{i}. **{s.filename}**"]
if s.sample_type:
parts.append(f"type={s.sample_type}")
- if s.key:
- parts.append(f"key={s.key}")
- if s.bpm and s.bpm > 0:
- parts.append(f"bpm={s.bpm}")
+ parts.append("loop" if s.is_loop else "one-shot")
+ if s.is_loop:
+ if s.key:
+ parts.append(f"key={s.key}")
+ if s.bpm and s.bpm > 0:
+ parts.append(f"bpm={s.bpm}")
parts.append(f"id={s.id}")
lines.append(" | ".join(parts))
return "\n".join(lines)
diff --git a/backend/src/samplespace/core/config.py b/backend/src/samplespace/core/config.py
index 261bc99..b67ada8 100644
--- a/backend/src/samplespace/core/config.py
+++ b/backend/src/samplespace/core/config.py
@@ -41,8 +41,9 @@ class Settings(
"""Main application settings."""
model_config = SettingsConfigDict(
- env_file=str(pathlib.Path(__file__).parent.parent.parent.parent / ".env"),
+ env_file=str(pathlib.Path(__file__).parent.parent.parent.parent.parent / ".env"),
env_ignore_empty=True,
+ extra="ignore",
)
ENVIRONMENT: Literal["development", "production"] = "development"
diff --git a/backend/src/samplespace/migrations/versions/c1e3a6c19709_add_is_loop_column.py b/backend/src/samplespace/migrations/versions/c1e3a6c19709_add_is_loop_column.py
new file mode 100644
index 0000000..701c8c3
--- /dev/null
+++ b/backend/src/samplespace/migrations/versions/c1e3a6c19709_add_is_loop_column.py
@@ -0,0 +1,30 @@
+"""add_is_loop_column
+
+Revision ID: c1e3a6c19709
+Revises: f2483c5c35f2
+Create Date: 2026-03-20 18:10:05.440003
+
+"""
+
+from typing import Sequence, Union
+
+import sqlalchemy as sa
+from alembic import op
+
+# revision identifiers, used by Alembic.
+revision: str = "c1e3a6c19709"
+down_revision: Union[str, None] = "f2483c5c35f2"
+branch_labels: Union[str, Sequence[str], None] = None
+depends_on: Union[str, Sequence[str], None] = None
+
+
+def upgrade() -> None:
+ # ### commands auto generated by Alembic - please adjust! ###
+ op.add_column("samples", sa.Column("is_loop", sa.Boolean(), nullable=False, server_default=sa.text("false")))
+ # ### end Alembic commands ###
+
+
+def downgrade() -> None:
+ # ### commands auto generated by Alembic - please adjust! ###
+ op.drop_column("samples", "is_loop")
+ # ### end Alembic commands ###
diff --git a/backend/src/samplespace/models/sample.py b/backend/src/samplespace/models/sample.py
index 201e18e..5dd2353 100644
--- a/backend/src/samplespace/models/sample.py
+++ b/backend/src/samplespace/models/sample.py
@@ -7,7 +7,7 @@
from fastapi import HTTPException
from pgvector.sqlalchemy import Vector
-from sqlalchemy import DateTime, Float, String, cast, func, select
+from sqlalchemy import Boolean, DateTime, Float, String, cast, func, select, text
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import Mapped, mapped_column
@@ -33,6 +33,7 @@ class Sample(Base):
bpm: Mapped[float | None]
duration: Mapped[float | None]
sample_type: Mapped[str | None] = mapped_column(String(50))
+ is_loop: Mapped[bool] = mapped_column(Boolean, server_default=text("false"))
clap_embedding: Mapped[list[float] | None] = mapped_column(Vector(512))
cnn_embedding: Mapped[list[float] | None] = mapped_column(Vector(128))
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), server_default=func.now())
@@ -72,6 +73,7 @@ async def search_by_clap(
bpm_min: float | None = None,
bpm_max: float | None = None,
sample_type: str | None = None,
+ is_loop: bool | None = None,
limit: int = 20,
) -> Sequence[Sample]:
"""Search samples by CLAP embedding using pgvector cosine distance."""
@@ -87,6 +89,8 @@ async def search_by_clap(
stmt = stmt.where(cls.bpm <= bpm_max)
if sample_type is not None:
stmt = stmt.where(cls.sample_type == sample_type)
+ if is_loop is not None:
+ stmt = stmt.where(cls.is_loop == is_loop)
stmt = stmt.order_by(distance).limit(limit)
diff --git a/backend/src/samplespace/routers/samples.py b/backend/src/samplespace/routers/samples.py
index d38436b..d5c7356 100644
--- a/backend/src/samplespace/routers/samples.py
+++ b/backend/src/samplespace/routers/samples.py
@@ -46,6 +46,7 @@ async def search_samples(
bpm_min=body.bpm_min,
bpm_max=body.bpm_max,
sample_type=body.sample_type,
+ is_loop=body.is_loop,
limit=body.limit,
)
diff --git a/backend/src/samplespace/schemas/sample.py b/backend/src/samplespace/schemas/sample.py
index 8d3630a..cdb5f9f 100644
--- a/backend/src/samplespace/schemas/sample.py
+++ b/backend/src/samplespace/schemas/sample.py
@@ -16,6 +16,7 @@ class SampleSchema(BaseSchema):
bpm: float | None = None
duration: float | None = None
sample_type: str | None = Field(None, description="Category of the sample (e.g., kick, snare, pad, lead)")
+ is_loop: bool = False
created_at: datetime
@@ -34,4 +35,5 @@ class SampleSearchRequest(BaseModel):
bpm_min: float | None = Field(None, description="Minimum BPM filter")
bpm_max: float | None = Field(None, description="Maximum BPM filter")
sample_type: str | None = Field(None, description="Filter by sample type")
+ is_loop: bool | None = Field(None, description="Filter by loop (True) or one-shot (False)")
limit: int = Field(20, ge=1, le=100, description="Maximum number of results")
diff --git a/backend/src/samplespace/scripts/__init__.py b/backend/src/samplespace/scripts/__init__.py
index e69de29..47c55c2 100644
--- a/backend/src/samplespace/scripts/__init__.py
+++ b/backend/src/samplespace/scripts/__init__.py
@@ -0,0 +1,22 @@
+"""Shared utilities for CLI scripts."""
+
+from pathlib import Path
+
+from samplespace.core.config import get_settings
+
+
+def find_audio_file(filename: str, sample_type: str | None) -> Path | None:
+ """Locate an audio file by checking type subdirectory, root, then rglob."""
+ samples_dir = Path(get_settings().SAMPLES_DIR)
+
+ if sample_type:
+ candidate = samples_dir / sample_type / filename
+ if candidate.exists():
+ return candidate
+
+ candidate = samples_dir / filename
+ if candidate.exists():
+ return candidate
+
+ matches = list(samples_dir.rglob(filename))
+ return matches[0] if matches else None
diff --git a/backend/src/samplespace/scripts/embed_cnn.py b/backend/src/samplespace/scripts/embed_cnn.py
index bd4465a..b25f718 100644
--- a/backend/src/samplespace/scripts/embed_cnn.py
+++ b/backend/src/samplespace/scripts/embed_cnn.py
@@ -9,87 +9,69 @@
"""
import argparse
+import asyncio
import logging
-from pathlib import Path
-from sqlalchemy import create_engine, select, update
-from sqlalchemy.orm import Session
+from sqlalchemy import select, update
-from samplespace.core.config import get_settings
-from samplespace.ml.predict import load_model, predict
+from samplespace.dependencies.db import get_async_sqlalchemy_session
from samplespace.models.sample import Sample
+from samplespace.scripts import find_audio_file
logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s")
logger = logging.getLogger(__name__)
-def main() -> None:
- parser = argparse.ArgumentParser(description="Generate CNN embeddings for all samples")
- parser.add_argument(
- "--force",
- action="store_true",
- help="Re-embed samples that already have embeddings",
- )
- args = parser.parse_args()
-
- config = get_settings()
- sync_url = (
- f"postgresql+psycopg://{config.POSTGRES_USER}:{config.POSTGRES_PASSWORD}"
- f"@{config.POSTGRES_HOST}:{config.POSTGRES_PORT}/{config.POSTGRES_DB}"
- )
- engine = create_engine(sync_url)
- SAMPLES_DIR = Path(config.SAMPLES_DIR)
+async def generate_embeddings(*, force: bool = False) -> None:
+ from samplespace.ml.predict import load_model, predict
model = load_model()
- with Session(engine) as session:
+ async with get_async_sqlalchemy_session() as db:
stmt = select(Sample)
- if not args.force:
+ if not force:
stmt = stmt.where(Sample.cnn_embedding.is_(None))
- samples = session.execute(stmt).scalars().all()
+ result = await db.execute(stmt)
+ samples = result.scalars().all()
logger.info(f"Found {len(samples)} samples to embed")
embedded = 0
for sample in samples:
- # Find audio file
- file_path = None
- if sample.sample_type:
- candidate = SAMPLES_DIR / sample.sample_type / sample.filename
- if candidate.exists():
- file_path = candidate
-
- if file_path is None:
- candidate = SAMPLES_DIR / sample.filename
- if candidate.exists():
- file_path = candidate
-
- if file_path is None:
- matches = list(SAMPLES_DIR.rglob(sample.filename))
- if matches:
- file_path = matches[0]
-
+ file_path = find_audio_file(sample.filename, sample.sample_type)
if file_path is None:
logger.warning(f" Audio file not found: {sample.filename}")
continue
try:
- result = predict(str(file_path), model)
- session.execute(update(Sample).where(Sample.id == sample.id).values(cnn_embedding=result.embedding))
+ pred_result = predict(str(file_path), model)
+ await db.execute(
+ update(Sample).where(Sample.id == sample.id).values(cnn_embedding=pred_result.embedding)
+ )
embedded += 1
logger.info(
f" Embedded: {sample.filename} "
- f"(predicted: {result.predicted_type}, "
- f"confidence: {result.confidence:.2%})"
+ f"(predicted: {pred_result.predicted_type}, "
+ f"confidence: {pred_result.confidence:.2%})"
)
except Exception:
logger.warning(f" Failed to embed: {sample.filename}", exc_info=True)
continue
- session.commit()
-
logger.info(f"Embedded {embedded}/{len(samples)} samples")
+def main() -> None:
+ parser = argparse.ArgumentParser(description="Generate CNN embeddings for all samples")
+ parser.add_argument(
+ "--force",
+ action="store_true",
+ help="Re-embed samples that already have embeddings",
+ )
+ args = parser.parse_args()
+
+ asyncio.run(generate_embeddings(force=args.force))
+
+
if __name__ == "__main__":
main()
diff --git a/backend/src/samplespace/scripts/embed_samples.py b/backend/src/samplespace/scripts/embed_samples.py
index 31cd76a..b9e692c 100644
--- a/backend/src/samplespace/scripts/embed_samples.py
+++ b/backend/src/samplespace/scripts/embed_samples.py
@@ -9,86 +9,64 @@
"""
import argparse
+import asyncio
import logging
-from pathlib import Path
-from sqlalchemy import create_engine, select, update
-from sqlalchemy.orm import Session
+from sqlalchemy import select, update
-from samplespace.core.config import get_settings
+from samplespace.dependencies.db import get_async_sqlalchemy_session
from samplespace.models.sample import Sample
+from samplespace.scripts import find_audio_file
from samplespace.services.embedding import embed_audio, load_clap_model
logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s")
logger = logging.getLogger(__name__)
-def main() -> None:
- parser = argparse.ArgumentParser(description="Generate CLAP embeddings for all samples")
- parser.add_argument(
- "--force",
- action="store_true",
- help="Re-embed samples that already have embeddings",
- )
- args = parser.parse_args()
-
- config = get_settings()
- sync_url = (
- f"postgresql+psycopg://{config.POSTGRES_USER}:{config.POSTGRES_PASSWORD}"
- f"@{config.POSTGRES_HOST}:{config.POSTGRES_PORT}/{config.POSTGRES_DB}"
- )
- engine = create_engine(sync_url)
- SAMPLES_DIR = Path(config.SAMPLES_DIR)
-
+async def generate_embeddings(*, force: bool = False) -> None:
# Load CLAP model
model, processor = load_clap_model()
- with Session(engine) as session:
+ async with get_async_sqlalchemy_session() as db:
# Query samples that need embedding
stmt = select(Sample)
- if not args.force:
+ if not force:
stmt = stmt.where(Sample.clap_embedding.is_(None))
- samples = session.execute(stmt).scalars().all()
+ result = await db.execute(stmt)
+ samples = result.scalars().all()
logger.info(f"Found {len(samples)} samples to embed")
embedded = 0
for sample in samples:
- # Find audio file — check category subdirectories
- file_path = None
- if sample.sample_type:
- candidate = SAMPLES_DIR / sample.sample_type / sample.filename
- if candidate.exists():
- file_path = candidate
-
- if file_path is None:
- candidate = SAMPLES_DIR / sample.filename
- if candidate.exists():
- file_path = candidate
-
- if file_path is None:
- # Search all subdirectories
- matches = list(SAMPLES_DIR.rglob(sample.filename))
- if matches:
- file_path = matches[0]
-
+ file_path = find_audio_file(sample.filename, sample.sample_type)
if file_path is None:
logger.warning(f" Audio file not found: {sample.filename}")
continue
try:
embedding = embed_audio(str(file_path), model, processor)
- session.execute(update(Sample).where(Sample.id == sample.id).values(clap_embedding=embedding))
+ await db.execute(update(Sample).where(Sample.id == sample.id).values(clap_embedding=embedding))
embedded += 1
logger.info(f" Embedded: {sample.filename}")
except Exception:
logger.warning(f" Failed to embed: {sample.filename}", exc_info=True)
continue
- session.commit()
-
logger.info(f"Embedded {embedded}/{len(samples)} samples")
+def main() -> None:
+ parser = argparse.ArgumentParser(description="Generate CLAP embeddings for all samples")
+ parser.add_argument(
+ "--force",
+ action="store_true",
+ help="Re-embed samples that already have embeddings",
+ )
+ args = parser.parse_args()
+
+ asyncio.run(generate_embeddings(force=args.force))
+
+
if __name__ == "__main__":
main()
diff --git a/backend/src/samplespace/scripts/seed.py b/backend/src/samplespace/scripts/seed.py
index 8c59533..c69ef75 100644
--- a/backend/src/samplespace/scripts/seed.py
+++ b/backend/src/samplespace/scripts/seed.py
@@ -10,16 +10,17 @@
"""
import argparse
+import asyncio
import logging
import uuid
from pathlib import Path
-from sqlalchemy import create_engine, text
-from sqlalchemy.orm import Session
+from sqlalchemy import select
from samplespace.core.config import get_settings
-from samplespace.models import Base, Sample
-from samplespace.services.audio_analysis import analyze_audio
+from samplespace.dependencies.db import get_async_sqlalchemy_session
+from samplespace.models.sample import Sample
+from samplespace.services.audio_analysis import analyze_and_classify
logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s")
logger = logging.getLogger(__name__)
@@ -58,24 +59,13 @@ def scan_local_samples(
return found
-def seed_database(samples: list[tuple[Path, str | None]]) -> int:
+async def seed_database(samples: list[tuple[Path, str | None]]) -> int:
"""Analyze audio files and insert sample records into the database."""
- config = get_settings()
- sync_url = f"postgresql+psycopg://{config.POSTGRES_USER}:{config.POSTGRES_PASSWORD}@{config.POSTGRES_HOST}:{config.POSTGRES_PORT}/{config.POSTGRES_DB}"
-
- engine = create_engine(sync_url)
-
- # Ensure pgvector extension exists
- with engine.connect() as conn:
- conn.execute(text("CREATE EXTENSION IF NOT EXISTS vector"))
- conn.commit()
-
- Base.metadata.create_all(engine)
-
inserted = 0
- with Session(engine) as session:
+ async with get_async_sqlalchemy_session() as db:
# Get existing filenames to avoid duplicates
- existing = {row[0] for row in session.execute(text("SELECT filename FROM samples")).fetchall()}
+ result = await db.execute(select(Sample.filename))
+ existing = {row[0] for row in result.all()}
for file_path, sample_type in samples:
filename = file_path.name
@@ -84,7 +74,7 @@ def seed_database(samples: list[tuple[Path, str | None]]) -> int:
continue
try:
- metadata = analyze_audio(str(file_path))
+ analysis = analyze_and_classify(str(file_path))
except Exception:
logger.warning(f" Failed to analyze: {filename}")
continue
@@ -92,21 +82,21 @@ def seed_database(samples: list[tuple[Path, str | None]]) -> int:
sample = Sample(
id=str(uuid.uuid4()),
filename=filename,
- key=metadata.key,
- bpm=metadata.bpm,
- duration=metadata.duration,
+ key=analysis.metadata.key,
+ bpm=analysis.metadata.bpm,
+ duration=analysis.metadata.duration,
sample_type=sample_type,
+ is_loop=analysis.is_loop,
)
- session.add(sample)
+ db.add(sample)
inserted += 1
logger.info(
f" Inserted: {filename} "
- f"(key={metadata.key}, bpm={metadata.bpm}, "
- f"duration={metadata.duration:.1f}s, type={sample_type})"
+ f"(key={analysis.metadata.key}, bpm={analysis.metadata.bpm}, "
+ f"duration={analysis.metadata.duration:.1f}s, type={sample_type}, "
+ f"is_loop={analysis.is_loop})"
)
- session.commit()
-
logger.info(f"Seeded {inserted} new samples (skipped {len(samples) - inserted} duplicates)")
return inserted
@@ -130,7 +120,7 @@ def main() -> None:
)
return
- seed_database(samples)
+ asyncio.run(seed_database(samples))
if __name__ == "__main__":
diff --git a/backend/src/samplespace/services/audio_analysis.py b/backend/src/samplespace/services/audio_analysis.py
index e1f329e..4f33698 100644
--- a/backend/src/samplespace/services/audio_analysis.py
+++ b/backend/src/samplespace/services/audio_analysis.py
@@ -1,6 +1,8 @@
"""Audio analysis service using librosa for key, BPM, and duration extraction."""
import logging
+import re
+from dataclasses import dataclass
import librosa
import numpy as np
@@ -9,27 +11,112 @@
logger = logging.getLogger(__name__)
+_LOOP_PATTERN = re.compile(r"\b(?:loop|loops|looped)\b", re.IGNORECASE)
+_ONE_SHOT_PATTERN = re.compile(r"\b(?:one[-_ ]?shot|oneshot|hit|hits|single)\b", re.IGNORECASE)
-def analyze_audio(file_path: str) -> AudioMetadata:
- """Analyze an audio file and extract key metadata.
- Returns dict with keys: key, bpm, duration.
+@dataclass
+class AnalysisResult:
+ """Combined result of loop inference and audio analysis."""
+
+ is_loop: bool
+ metadata: AudioMetadata
+
+
+def analyze_and_classify(file_path: str) -> AnalysisResult:
+ """Infer one-shot/loop classification and extract audio metadata in a single pass.
+
+ Uses a tiered approach for loop inference:
+ 1. Check filepath for explicit keywords (loop, one-shot, hit, etc.)
+ 2. Fall back to audio heuristics (duration, onset density, onset regularity)
+
+ When classified as a one-shot, skips key/BPM extraction (meaningless for single hits).
"""
+ # Tier 1: Filepath keywords (word-boundary matching to avoid false positives)
+ loop_match = _LOOP_PATTERN.search(file_path)
+ if loop_match:
+ logger.info(f"Inferred is_loop=True for {file_path} (keyword: {loop_match.group()})")
+ return AnalysisResult(
+ is_loop=True,
+ metadata=_analyze_audio(file_path),
+ )
+ one_shot_match = _ONE_SHOT_PATTERN.search(file_path)
+ if one_shot_match:
+ logger.info(f"Inferred is_loop=False for {file_path} (keyword: {one_shot_match.group()})")
+ return AnalysisResult(
+ is_loop=False,
+ metadata=_analyze_duration_only(file_path),
+ )
+
+ # Tier 2: Audio heuristics (load once, reuse for analysis)
+ y, sr = librosa.load(file_path, sr=22050, mono=True)
+ is_loop = _infer_from_audio(file_path, y, int(sr))
+
+ if is_loop:
+ metadata = _extract_full_metadata(file_path, y, int(sr))
+ else:
+ duration = round(float(librosa.get_duration(y=y, sr=sr)), 2)
+ metadata = AudioMetadata(key=None, bpm=None, duration=duration)
+ logger.info(f"Analyzed {file_path}: one-shot, duration={duration:.1f}s (skipped key/BPM)")
+
+ return AnalysisResult(is_loop=is_loop, metadata=metadata)
+
+
+def _infer_from_audio(file_path: str, y: np.ndarray, sr: int) -> bool:
+ """Infer one-shot/loop from audio features (tier 2 heuristic)."""
+ duration = float(librosa.get_duration(y=y, sr=sr))
+ onsets = librosa.onset.onset_detect(y=y, sr=sr)
+ num_onsets = len(onsets)
+
+ is_short = duration < 2.0
+ is_long = duration > 4.0
+ few_onsets = num_onsets <= 2
+
+ # Check onset regularity for potential loops
+ has_regular_rhythm = False
+ if num_onsets >= 4:
+ onset_times = librosa.frames_to_time(onsets, sr=sr)
+ intervals = np.diff(onset_times)
+ if np.mean(intervals) > 0:
+ cv = float(np.std(intervals) / np.mean(intervals))
+ has_regular_rhythm = cv < 0.25
+
+ if is_short and few_onsets:
+ result = False
+ elif is_long or has_regular_rhythm:
+ result = True
+ else:
+ result = False # Default to one-shot for ambiguous cases
+
+ logger.info(f"Inferred is_loop={result} for {file_path} (duration={duration:.1f}s, onsets={num_onsets}, heuristic)")
+ return result
+
+
+def _analyze_audio(file_path: str) -> AudioMetadata:
+ """Full audio analysis (key, BPM, duration) — loads file from disk."""
+ y, sr = librosa.load(file_path, sr=22050, mono=True)
+ return _extract_full_metadata(file_path, y, int(sr))
+
+
+def _analyze_duration_only(file_path: str) -> AudioMetadata:
+ """Duration-only analysis for one-shots — loads file from disk."""
y, sr = librosa.load(file_path, sr=22050, mono=True)
+ duration = round(float(librosa.get_duration(y=y, sr=sr)), 2)
+ logger.info(f"Analyzed {file_path}: one-shot, duration={duration:.1f}s (skipped key/BPM)")
+ return AudioMetadata(key=None, bpm=None, duration=duration)
- # Duration
+
+def _extract_full_metadata(file_path: str, y: np.ndarray, sr: int) -> AudioMetadata:
+ """Extract key, BPM, and duration from already-loaded audio."""
duration = float(librosa.get_duration(y=y, sr=sr))
- # BPM
tempo, _ = librosa.beat.beat_track(y=y, sr=sr)
tempo_val = float(tempo[0]) if isinstance(tempo, np.ndarray) else float(tempo)
bpm = round(tempo_val, 1)
- # Key detection via chroma features
- key = _detect_key(y, int(sr))
+ key = _detect_key(y, sr)
logger.info(f"Analyzed {file_path}: key={key}, bpm={bpm}, duration={duration:.1f}s")
-
return AudioMetadata(key=key, bpm=bpm, duration=round(duration, 2))
diff --git a/backend/src/samplespace/services/sample.py b/backend/src/samplespace/services/sample.py
index 0ae360c..fabb28c 100644
--- a/backend/src/samplespace/services/sample.py
+++ b/backend/src/samplespace/services/sample.py
@@ -7,7 +7,7 @@
from samplespace.models.sample import Sample
from samplespace.schemas.sample import SampleListResponse, SampleSchema
-from samplespace.services.audio_analysis import analyze_audio
+from samplespace.services.audio_analysis import analyze_and_classify
logger = logging.getLogger(__name__)
@@ -20,15 +20,16 @@ async def create_sample(
sample_type: str | None = None,
) -> Sample:
"""Create a sample from an audio file: analyze metadata and persist."""
- metadata = analyze_audio(file_path)
+ result = analyze_and_classify(file_path)
sample = Sample(
id=str(uuid.uuid4()),
filename=filename,
- key=metadata.key,
- bpm=metadata.bpm,
- duration=metadata.duration,
+ key=result.metadata.key,
+ bpm=result.metadata.bpm,
+ duration=result.metadata.duration,
sample_type=sample_type,
+ is_loop=result.is_loop,
)
db.add(sample)
await db.flush()
@@ -65,6 +66,7 @@ async def search_by_text(
bpm_min: float | None = None,
bpm_max: float | None = None,
sample_type: str | None = None,
+ is_loop: bool | None = None,
limit: int = 20,
) -> list[SampleSchema]:
"""Search samples by CLAP text embedding with optional metadata filters.
@@ -78,6 +80,7 @@ async def search_by_text(
bpm_min=bpm_min,
bpm_max=bpm_max,
sample_type=sample_type,
+ is_loop=is_loop,
limit=limit,
)
diff --git a/backend/uv.lock b/backend/uv.lock
index f3c998b..052dc60 100644
--- a/backend/uv.lock
+++ b/backend/uv.lock
@@ -2932,7 +2932,7 @@ dependencies = [
{ name = "psycopg", extra = ["binary", "pool"] },
{ name = "pydantic" },
{ name = "pydantic-ai" },
- { name = "sqlalchemy", extra = ["mypy"] },
+ { name = "sqlalchemy", extra = ["asyncio", "mypy"] },
{ name = "torch" },
{ name = "torchaudio" },
{ name = "transformers" },
@@ -2958,7 +2958,7 @@ requires-dist = [
{ name = "psycopg", extras = ["binary", "pool"], specifier = ">=3.3.3" },
{ name = "pydantic", specifier = ">=2.12.5" },
{ name = "pydantic-ai", specifier = ">=1.70.0" },
- { name = "sqlalchemy", extras = ["mypy"], specifier = ">=2.0.48" },
+ { name = "sqlalchemy", extras = ["asyncio", "mypy"], specifier = ">=2.0.48" },
{ name = "torch", specifier = ">=2.10.0" },
{ name = "torchaudio", specifier = ">=2.10.0" },
{ name = "transformers", specifier = ">=5.3.0" },
@@ -3154,6 +3154,9 @@ wheels = [
]
[package.optional-dependencies]
+asyncio = [
+ { name = "greenlet" },
+]
mypy = [
{ name = "mypy" },
]
diff --git a/frontend/api/generated/types.gen.ts b/frontend/api/generated/types.gen.ts
index 605b2b6..ebd1876 100644
--- a/frontend/api/generated/types.gen.ts
+++ b/frontend/api/generated/types.gen.ts
@@ -72,6 +72,10 @@ export type SampleSchema = {
* Category of the sample (e.g., kick, snare, pad, lead)
*/
sample_type?: string | null;
+ /**
+ * Is Loop
+ */
+ is_loop?: boolean;
/**
* Created At
*/
@@ -114,6 +118,12 @@ export type SampleSearchRequest = {
* Filter by sample type
*/
sample_type?: string | null;
+ /**
+ * Is Loop
+ *
+ * Filter by loop (True) or one-shot (False)
+ */
+ is_loop?: boolean | null;
/**
* Limit
*
diff --git a/frontend/api/scripts/outputs/openapi.json b/frontend/api/scripts/outputs/openapi.json
index 772f20c..9f37766 100644
--- a/frontend/api/scripts/outputs/openapi.json
+++ b/frontend/api/scripts/outputs/openapi.json
@@ -394,6 +394,11 @@
"title": "Sample Type",
"description": "Category of the sample (e.g., kick, snare, pad, lead)"
},
+ "is_loop": {
+ "type": "boolean",
+ "title": "Is Loop",
+ "default": false
+ },
"created_at": {
"type": "string",
"format": "date-time",
@@ -471,6 +476,18 @@
"title": "Sample Type",
"description": "Filter by sample type"
},
+ "is_loop": {
+ "anyOf": [
+ {
+ "type": "boolean"
+ },
+ {
+ "type": "null"
+ }
+ ],
+ "title": "Is Loop",
+ "description": "Filter by loop (True) or one-shot (False)"
+ },
"limit": {
"type": "integer",
"maximum": 100,
diff --git a/frontend/components/sample-browser.tsx b/frontend/components/sample-browser.tsx
index c901a86..75a7f89 100644
--- a/frontend/components/sample-browser.tsx
+++ b/frontend/components/sample-browser.tsx
@@ -57,6 +57,9 @@ function SampleCard({
{sample.sample_type}
)}
+
+ {sample.is_loop ? "loop" : "one-shot"}
+
{sample.key && (
{sample.key}
@@ -86,6 +89,9 @@ function SampleCard({
}
export function SampleBrowser() {
+ const [activeCategory, setActiveCategory] = useState<
+ "all" | "one-shot" | "loop"
+ >("all");
const [activeType, setActiveType] = useState