Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion backend/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ dependencies = [
"psycopg[binary,pool]>=3.3.3",
"pydantic-ai>=1.70.0",
"pydantic>=2.12.5",
"sqlalchemy[mypy]>=2.0.48",
"sqlalchemy[asyncio,mypy]>=2.0.48",
"torch>=2.10.0",
"torchaudio>=2.10.0",
"transformers>=5.3.0",
Expand Down
9 changes: 9 additions & 0 deletions backend/src/samplespace/agents/sample_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,15 @@
- Be concise but informative — mention key, BPM, and type when relevant
- If the user references a sample by name rather than ID, search for it first
- NEVER generate URLs or markdown links — just use plain text and bold for emphasis

## One-Shots vs Loops

Samples are classified as either **one-shots** (single hits like a kick, snare, or chord stab) or **loops** (repeating patterns like a drum loop, bassline, or melodic phrase). Any sample type can be either — a kick can be a single hit or a kick pattern loop.

- One-shots do NOT have meaningful key or BPM — never mention key/BPM for one-shots
- Do not use check_key_compatibility for one-shots
- When suggesting complements for a one-shot, focus on sonic character rather than key compatibility
- Check the is_loop field on each sample rather than assuming from the type
""".strip()

_model = OpenAIResponsesModel(config.AGENT_MODEL)
Expand Down
18 changes: 12 additions & 6 deletions backend/src/samplespace/agents/tools/analysis_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,10 +61,12 @@ async def analyze_sample(ctx: RunContext[AgentDeps], sample_id: str) -> str:
lines = [
f"**{sample.filename}**",
f"- Type: {sample.sample_type or 'unknown'}",
f"- Key: {sample.key or 'unknown'}",
f"- BPM: {sample.bpm or 'unknown'}",
f"- Duration: {sample.duration:.1f}s" if sample.duration else "- Duration: unknown",
f"- Category: {'loop' if sample.is_loop else 'one-shot'}",
]
if sample.is_loop:
lines.append(f"- Key: {sample.key or 'unknown'}")
lines.append(f"- BPM: {sample.bpm or 'unknown'}")
lines.append(f"- Duration: {sample.duration:.1f}s" if sample.duration is not None else "- Duration: unknown")
return "\n".join(lines)
except Exception:
logger.exception("Error analyzing sample")
Expand Down Expand Up @@ -161,18 +163,22 @@ async def suggest_complement(
if not filtered:
return "No complementary samples found."

lines = [f"Samples that complement **{source.filename}** (key: {source.key or 'unknown'}):\n"]
if source.is_loop:
header = f"Samples that complement **{source.filename}** (key: {source.key or 'unknown'}):\n"
else:
header = f"Samples that complement **{source.filename}** (one-shot):\n"
lines = [header]
for i, s in enumerate(filtered[:8], 1):
compat = ""
if source.key and s.key:
if source.is_loop and source.key and s.key:
if source.key == s.key:
compat = " ✓ same key"
elif RELATIVE_PAIRS.get(source.key) == s.key:
compat = " ✓ relative key"
parts = [f"{i}. **{s.filename}**"]
if s.sample_type:
parts.append(f"type={s.sample_type}")
if s.key:
if s.is_loop and s.key:
parts.append(f"key={s.key}{compat}")
parts.append(f"id={s.id}")
lines.append(" | ".join(parts))
Expand Down
10 changes: 6 additions & 4 deletions backend/src/samplespace/agents/tools/clap_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,10 +41,12 @@ def _format_results(results: list[SampleSchema], query: str) -> str:
parts = [f"{i}. **{s.filename}**"]
if s.sample_type:
parts.append(f"type={s.sample_type}")
if s.key:
parts.append(f"key={s.key}")
if s.bpm and s.bpm > 0:
parts.append(f"bpm={s.bpm}")
parts.append("loop" if s.is_loop else "one-shot")
if s.is_loop:
if s.key:
parts.append(f"key={s.key}")
if s.bpm and s.bpm > 0:
parts.append(f"bpm={s.bpm}")
if s.duration:
parts.append(f"duration={s.duration:.1f}s")
parts.append(f"id={s.id}")
Expand Down
10 changes: 6 additions & 4 deletions backend/src/samplespace/agents/tools/cnn_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,10 +40,12 @@ def _format_results(results: list[SampleSchema], source_name: str) -> str:
parts = [f"{i}. **{s.filename}**"]
if s.sample_type:
parts.append(f"type={s.sample_type}")
if s.key:
parts.append(f"key={s.key}")
if s.bpm and s.bpm > 0:
parts.append(f"bpm={s.bpm}")
parts.append("loop" if s.is_loop else "one-shot")
if s.is_loop:
if s.key:
parts.append(f"key={s.key}")
if s.bpm and s.bpm > 0:
parts.append(f"bpm={s.bpm}")
parts.append(f"id={s.id}")
lines.append(" | ".join(parts))
return "\n".join(lines)
Expand Down
3 changes: 2 additions & 1 deletion backend/src/samplespace/core/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,9 @@ class Settings(
"""Main application settings."""

model_config = SettingsConfigDict(
env_file=str(pathlib.Path(__file__).parent.parent.parent.parent / ".env"),
env_file=str(pathlib.Path(__file__).parent.parent.parent.parent.parent / ".env"),
env_ignore_empty=True,
extra="ignore",
)

ENVIRONMENT: Literal["development", "production"] = "development"
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
"""add_is_loop_column

Revision ID: c1e3a6c19709
Revises: f2483c5c35f2
Create Date: 2026-03-20 18:10:05.440003

"""

from typing import Sequence, Union

import sqlalchemy as sa
from alembic import op

# revision identifiers, used by Alembic.
# `revision` is this migration's own ID; `down_revision` is the migration it
# builds on (the "Revises" entry in the module docstring).
revision: str = "c1e3a6c19709"
down_revision: Union[str, None] = "f2483c5c35f2"
# No branch labels or cross-revision dependencies are declared for this migration.
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
    """Add the ``is_loop`` boolean column to the ``samples`` table.

    The column is declared NOT NULL with a server-side default of ``false``.
    """
    is_loop_column = sa.Column(
        "is_loop",
        sa.Boolean(),
        nullable=False,
        server_default=sa.text("false"),
    )
    op.add_column("samples", is_loop_column)


def downgrade() -> None:
    """Revert this migration by dropping ``is_loop`` from the ``samples`` table."""
    table_name, column_name = "samples", "is_loop"
    op.drop_column(table_name, column_name)
6 changes: 5 additions & 1 deletion backend/src/samplespace/models/sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from fastapi import HTTPException
from pgvector.sqlalchemy import Vector
from sqlalchemy import DateTime, Float, String, cast, func, select
from sqlalchemy import Boolean, DateTime, Float, String, cast, func, select, text
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import Mapped, mapped_column

Expand All @@ -33,6 +33,7 @@ class Sample(Base):
bpm: Mapped[float | None]
duration: Mapped[float | None]
sample_type: Mapped[str | None] = mapped_column(String(50))
is_loop: Mapped[bool] = mapped_column(Boolean, server_default=text("false"))
clap_embedding: Mapped[list[float] | None] = mapped_column(Vector(512))
cnn_embedding: Mapped[list[float] | None] = mapped_column(Vector(128))
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), server_default=func.now())
Expand Down Expand Up @@ -72,6 +73,7 @@ async def search_by_clap(
bpm_min: float | None = None,
bpm_max: float | None = None,
sample_type: str | None = None,
is_loop: bool | None = None,
limit: int = 20,
) -> Sequence[Sample]:
"""Search samples by CLAP embedding using pgvector cosine distance."""
Expand All @@ -87,6 +89,8 @@ async def search_by_clap(
stmt = stmt.where(cls.bpm <= bpm_max)
if sample_type is not None:
stmt = stmt.where(cls.sample_type == sample_type)
if is_loop is not None:
stmt = stmt.where(cls.is_loop == is_loop)

stmt = stmt.order_by(distance).limit(limit)

Expand Down
1 change: 1 addition & 0 deletions backend/src/samplespace/routers/samples.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ async def search_samples(
bpm_min=body.bpm_min,
bpm_max=body.bpm_max,
sample_type=body.sample_type,
is_loop=body.is_loop,
limit=body.limit,
)

Expand Down
2 changes: 2 additions & 0 deletions backend/src/samplespace/schemas/sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ class SampleSchema(BaseSchema):
bpm: float | None = None
duration: float | None = None
sample_type: str | None = Field(None, description="Category of the sample (e.g., kick, snare, pad, lead)")
is_loop: bool = False
created_at: datetime


Expand All @@ -34,4 +35,5 @@ class SampleSearchRequest(BaseModel):
bpm_min: float | None = Field(None, description="Minimum BPM filter")
bpm_max: float | None = Field(None, description="Maximum BPM filter")
sample_type: str | None = Field(None, description="Filter by sample type")
is_loop: bool | None = Field(None, description="Filter by loop (True) or one-shot (False)")
limit: int = Field(20, ge=1, le=100, description="Maximum number of results")
22 changes: 22 additions & 0 deletions backend/src/samplespace/scripts/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
"""Shared utilities for CLI scripts."""

from pathlib import Path

from samplespace.core.config import get_settings


def find_audio_file(filename: str, sample_type: str | None) -> Path | None:
    """Locate a sample's audio file under the configured samples directory.

    Lookup order:

    1. ``<SAMPLES_DIR>/<sample_type>/<filename>`` (only when ``sample_type`` is truthy)
    2. ``<SAMPLES_DIR>/<filename>``
    3. The first regular file with exactly that name found by a recursive search.

    Args:
        filename: Name of the audio file to look for.
        sample_type: Optional subdirectory name to try first.

    Returns:
        The path of the first matching regular file, or ``None`` when nothing matches.
    """
    # Function-scope import: the module's top-level imports stay untouched.
    from glob import escape as escape_glob

    if not filename:
        # An empty name would resolve to the samples directory itself, never an audio file.
        return None

    samples_dir = Path(get_settings().SAMPLES_DIR)

    direct_candidates: list[Path] = []
    if sample_type:
        direct_candidates.append(samples_dir / sample_type / filename)
    direct_candidates.append(samples_dir / filename)

    for candidate in direct_candidates:
        # is_file() rather than exists(): a directory sharing the name is not a usable sample.
        if candidate.is_file():
            return candidate

    # rglob() treats its argument as a glob pattern, so metacharacters in a real
    # filename (e.g. "kick[1].wav") must be escaped to be matched literally.
    # The generator stops the directory walk at the first hit instead of
    # collecting every match.
    recursive_hits = (path for path in samples_dir.rglob(escape_glob(filename)) if path.is_file())
    return next(recursive_hits, None)
76 changes: 29 additions & 47 deletions backend/src/samplespace/scripts/embed_cnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,87 +9,69 @@
"""

import argparse
import asyncio
import logging
from pathlib import Path

from sqlalchemy import create_engine, select, update
from sqlalchemy.orm import Session
from sqlalchemy import select, update

from samplespace.core.config import get_settings
from samplespace.ml.predict import load_model, predict
from samplespace.dependencies.db import get_async_sqlalchemy_session
from samplespace.models.sample import Sample
from samplespace.scripts import find_audio_file

logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s")
logger = logging.getLogger(__name__)


def main() -> None:
parser = argparse.ArgumentParser(description="Generate CNN embeddings for all samples")
parser.add_argument(
"--force",
action="store_true",
help="Re-embed samples that already have embeddings",
)
args = parser.parse_args()

config = get_settings()
sync_url = (
f"postgresql+psycopg://{config.POSTGRES_USER}:{config.POSTGRES_PASSWORD}"
f"@{config.POSTGRES_HOST}:{config.POSTGRES_PORT}/{config.POSTGRES_DB}"
)
engine = create_engine(sync_url)
SAMPLES_DIR = Path(config.SAMPLES_DIR)
async def generate_embeddings(*, force: bool = False) -> None:
from samplespace.ml.predict import load_model, predict

model = load_model()

with Session(engine) as session:
async with get_async_sqlalchemy_session() as db:
stmt = select(Sample)
if not args.force:
if not force:
stmt = stmt.where(Sample.cnn_embedding.is_(None))

samples = session.execute(stmt).scalars().all()
result = await db.execute(stmt)
samples = result.scalars().all()
logger.info(f"Found {len(samples)} samples to embed")

embedded = 0
for sample in samples:
# Find audio file
file_path = None
if sample.sample_type:
candidate = SAMPLES_DIR / sample.sample_type / sample.filename
if candidate.exists():
file_path = candidate

if file_path is None:
candidate = SAMPLES_DIR / sample.filename
if candidate.exists():
file_path = candidate

if file_path is None:
matches = list(SAMPLES_DIR.rglob(sample.filename))
if matches:
file_path = matches[0]

file_path = find_audio_file(sample.filename, sample.sample_type)
if file_path is None:
logger.warning(f" Audio file not found: {sample.filename}")
continue

try:
result = predict(str(file_path), model)
session.execute(update(Sample).where(Sample.id == sample.id).values(cnn_embedding=result.embedding))
pred_result = predict(str(file_path), model)
await db.execute(
update(Sample).where(Sample.id == sample.id).values(cnn_embedding=pred_result.embedding)
)
embedded += 1
logger.info(
f" Embedded: {sample.filename} "
f"(predicted: {result.predicted_type}, "
f"confidence: {result.confidence:.2%})"
f"(predicted: {pred_result.predicted_type}, "
f"confidence: {pred_result.confidence:.2%})"
)
except Exception:
logger.warning(f" Failed to embed: {sample.filename}", exc_info=True)
continue

session.commit()

logger.info(f"Embedded {embedded}/{len(samples)} samples")


def main() -> None:
    """CLI entry point: read the ``--force`` flag and run the embedding job to completion."""
    cli = argparse.ArgumentParser(description="Generate CNN embeddings for all samples")
    cli.add_argument("--force", action="store_true", help="Re-embed samples that already have embeddings")
    options = cli.parse_args()

    force_reembed = options.force
    asyncio.run(generate_embeddings(force=force_reembed))


if __name__ == "__main__":
main()
Loading
Loading