"""Sample Selector — queries the forensic sample index by musical criteria. Loads data/sample_index.json and provides scored, ranked queries: - Role matching (exact) - Key compatibility (exact, relative major/minor, dominant/subdominant) - BPM tolerance (±5%, half/double time) - Character similarity (grouped characters) - Tonal/atonal filtering Usage: selector = SampleSelector() results = selector.select(role="kick", bpm=95, limit=5) results = selector.select(role="bass", key="Am", bpm=92, character="deep") """ from __future__ import annotations import json import os import random from pathlib import Path from typing import Optional from dataclasses import dataclass, field # --------------------------------------------------------------------------- # Key Compatibility # --------------------------------------------------------------------------- CIRCLE_OF_FIFTHS = ["C", "G", "D", "A", "E", "B", "F#", "C#", "G#", "D#", "A#", "F"] # Relative major/minor pairs (each minor → its relative major) RELATIVE_MAJOR = { "Am": "C", "Em": "G", "Bm": "D", "F#m": "A", "C#m": "E", "G#m": "B", "D#m": "F#", "A#m": "C#", "Fm": "G#", "Cm": "Eb", "Gm": "Bb", "Dm": "F", # Enharmonic equivalents "Bbm": "Db", "Ebm": "Gb", "Abm": "B", "Bbm": "Cb", } # Build reverse: major → relative minor RELATIVE_MINOR = {v: k for k, v in RELATIVE_MAJOR.items()} # Dominant (V) and subdominant (IV) relationships DOMINANT = {"C": "G", "G": "D", "D": "A", "A": "E", "E": "B", "B": "F#", "F#": "C#", "C#": "G#", "G#": "D#", "D#": "A#", "A#": "F", "F": "C"} SUBDOMINANT = {v: k for k, v in DOMINANT.items()} # Character similarity groups CHARACTER_GROUPS = [ {"warm", "soft", "lush"}, {"boomy", "deep", "dark"}, {"sharp", "crisp", "bright"}, {"aggressive", "tight"}, {"ethereal", "neutral"}, {"impact", "short"}, {"hollow", "full"}, ] # All roles the classifier produces KNOWN_ROLES = { "kick", "snare", "hihat", "bass", "lead", "pad", "pluck", "vocal", "arp", "guitar", "keys", "synth", "brass", "perc", "drumloop", "fx", "fill", "oneshot", } # Roles that are typically atonal (key doesn't matter) ATONAL_ROLES = {"kick", "snare", "hihat", "perc", "fx", "fill", "oneshot"} def _normalize_key(key: str) -> str: """Normalize key names: Eb→D#, Bb→A#, Db→C#, Gb→F#, Ab→G#.""" enharmonics = {"Eb": "D#", "Bb": "A#", "Db": "C#", "Gb": "F#", "Ab": "G#", "Cb": "B"} return enharmonics.get(key, key) def _key_compatibility(query_key: str, sample_key: str) -> float: """Score how compatible a sample's key is with the query key. Returns: 1.0 = exact match 0.9 = same root, different mode (C ↔ Cm) 0.8 = relative major/minor (Am ↔ C) 0.7 = dominant/subdominant (C ↔ G or C ↔ F) 0.5 = compatible (nearby in circle of fifths) 0.0 = atonal or no match """ if query_key == "X" or sample_key == "X": return 0.0 # Atonal, no key compatibility q = _normalize_key(query_key) s = _normalize_key(sample_key) # Exact match if q == s: return 1.0 # Separate root and mode q_root = q.rstrip("m") q_minor = q.endswith("m") s_root = s.rstrip("m") s_minor = s.endswith("m") # Same root, different mode (C ↔ Cm) if q_root == s_root: return 0.9 # Relative major/minor (Am ↔ C) if q_minor and not s_minor: rel = RELATIVE_MAJOR.get(q, "") if s_root == _normalize_key(rel): return 0.8 if not q_minor and s_minor: rel = RELATIVE_MINOR.get(q, "") if s_root == _normalize_key(rel.rstrip("m")): return 0.8 # Dominant/subdominant q_root_norm = _normalize_key(q_root) s_root_norm = _normalize_key(s_root) if DOMINANT.get(q_root_norm) == s_root_norm or SUBDOMINANT.get(q_root_norm) == s_root_norm: return 0.7 # Circle of fifths proximity try: q_idx = CIRCLE_OF_FIFTHS.index(q_root_norm) s_idx = CIRCLE_OF_FIFTHS.index(s_root_norm) distance = min(abs(q_idx - s_idx), 12 - abs(q_idx - s_idx)) if distance <= 2: return 0.5 except ValueError: pass return 0.3 def _bpm_compatibility(query_bpm: float, sample_bpm: float) -> float: """Score BPM compatibility. Handles half/double time.""" if query_bpm <= 0 or sample_bpm <= 0: return 0.5 # Unknown BPM, neutral score ratio = sample_bpm / query_bpm tolerance = 0.05 # ±5% # Direct match if abs(ratio - 1.0) <= tolerance: return 1.0 # Half time if abs(ratio - 0.5) <= tolerance: return 0.8 # Double time if abs(ratio - 2.0) <= tolerance: return 0.8 # Near match (±10%) if abs(ratio - 1.0) <= 0.10: return 0.6 return 0.3 def _character_compatibility(query_char: Optional[str], sample_char: str) -> float: """Score character compatibility using similarity groups.""" if not query_char: return 0.5 # No preference if query_char == sample_char: return 1.0 # Check if in same group for group in CHARACTER_GROUPS: if query_char in group and sample_char in group: return 0.7 return 0.3 @dataclass class SampleMatch: """A scored sample match from the selector.""" score: float sample: dict score_breakdown: dict = field(default_factory=dict) class SampleSelector: """Query the forensic sample index with musical criteria.""" def __init__(self, index_path: Optional[str] = None): if index_path is None: project = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) index_path = os.path.join(project, "data", "sample_index.json") self.index_path = index_path self._samples: list[dict] = [] self._by_role: dict[str, list[dict]] = {} self._loaded = False def _load(self): """Lazy-load the index.""" if self._loaded: return with open(self.index_path, "r", encoding="utf-8") as f: data = json.load(f) self._samples = [s for s in data.get("samples", []) if "error" not in s] # Index by role for fast lookup self._by_role = {} for s in self._samples: role = s.get("role", "unknown") if role not in self._by_role: self._by_role[role] = [] self._by_role[role].append(s) self._loaded = True def select( self, role: str, key: Optional[str] = None, bpm: Optional[float] = None, character: Optional[str] = None, is_tonal: Optional[bool] = None, limit: int = 10, path_prefix: Optional[str] = None, ) -> list[SampleMatch]: """Select samples matching criteria, ranked by compatibility score. Args: role: Required. Production role (kick, bass, lead, etc.) key: Musical key for compatibility (e.g. "Am", "C") bpm: Target BPM for tempo matching character: Timbre character preference (e.g. "warm", "boomy") is_tonal: Filter by tonal/atonal status limit: Maximum results to return path_prefix: Filter by file path prefix Returns: List of SampleMatch objects sorted by score (descending) """ self._load() if role not in KNOWN_ROLES: # Try fuzzy match role_lower = role.lower() for known in KNOWN_ROLES: if known in role_lower: role = known break candidates = self._by_role.get(role, []) if not candidates: return [] # Score each candidate matches: list[SampleMatch] = [] for s in candidates: # Path prefix filter if path_prefix: if path_prefix.lower() not in s.get("original_path", "").lower(): continue # Tonal filter if is_tonal is not None: sample_tonal = s.get("musical", {}).get("is_tonal", False) if sample_tonal != is_tonal: continue breakdown = {} total = 0.0 # Role score (always 1.0 since we filtered by role) breakdown["role"] = 1.0 total += 1.0 # Key compatibility if key and role not in ATONAL_ROLES: sample_key = s.get("musical", {}).get("key", "X") kc = _key_compatibility(key, sample_key) breakdown["key"] = kc total += kc * 2.0 # Weight key heavily else: breakdown["key"] = 0.5 # BPM compatibility if bpm: sample_bpm = s.get("perceptual", {}).get("tempo", 0) bc = _bpm_compatibility(bpm, sample_bpm) breakdown["bpm"] = bc total += bc * 1.5 else: breakdown["bpm"] = 0.5 # Character compatibility cc = _character_compatibility(character, s.get("character", "")) breakdown["character"] = cc total += cc * 0.5 # Duration preference: shorter samples get slight bonus for flexibility dur = s.get("signal", {}).get("duration", 0) if dur > 0 and dur < 5.0: total += 0.1 # Short bonus breakdown["duration"] = dur matches.append(SampleMatch( score=round(total, 4), sample=s, score_breakdown=breakdown, )) # Sort by score descending matches.sort(key=lambda m: m.score, reverse=True) return matches[:limit] def select_one( self, role: str, seed: Optional[int] = None, **kwargs, ) -> Optional[dict]: """Select one sample using weighted random from top-5 candidates. The top-5 candidates are selected with weights [5, 4, 3, 2, 1], favoring higher-scored results while allowing variation across calls. Pass seed for reproducible output. """ if seed is not None: random.seed(seed) results = self.select(role=role, limit=5, **kwargs) if not results: return None candidates = results[:5] weights = [5, 4, 3, 2, 1][: len(candidates)] selected = random.choices(candidates, weights=weights, k=1)[0] return selected.sample def get_roles(self) -> list[str]: """Get all available roles and their counts.""" self._load() return sorted(self._by_role.keys()) def get_stats(self) -> dict[str, int]: """Get count per role.""" self._load() return {role: len(samples) for role, samples in sorted(self._by_role.items())} def random_sample(self, role: str, **kwargs) -> Optional[dict]: """Select a random sample from the top candidates for variation.""" import random results = self.select(role=role, limit=5, **kwargs) if not results: return None return random.choice(results).sample def select_diverse( self, role: str, n: int = 1, exclude: Optional[list[str]] = None, **kwargs, ) -> list[dict]: """Return n different samples for role, excluding known IDs. Uses randomized scoring to ensure diversity across calls. Returns fewer than n if not enough candidates available after exclusion. Args: role: Required. Production role (kick, bass, lead, etc.) n: Number of different samples to return exclude: List of sample IDs (file_hash) to exclude from results **kwargs: Passed to select() (key, bpm, character, etc.) Returns: List of sample dicts (length <= n, never includes excluded IDs) """ import random exclude = exclude or [] results: list[dict] = [] # Keep trying until we have n samples or run out of candidates remaining = self.select(role=role, limit=100, **kwargs) # Get enough candidates for match in remaining: sample = match.sample sample_id = sample.get("file_hash", "") if sample_id in exclude: continue # Add small random noise to score for diversity # This way repeated calls with same params can return different results scored_sample = (match.score + random.uniform(-0.05, 0.05), sample) results.append(scored_sample) if len(results) >= n: break # Sort by randomized score (descending) and extract samples results.sort(key=lambda x: x[0], reverse=True) return [sample for _, sample in results[:n]]