""" role_matcher.py - Phase 4: Role validation and sample matching utilities This module provides enhanced role matching for sample selection with: - Role validation based on audio characteristics - Aggressive sample detection and filtering - Logging of matching decisions - Integration with reference_listener and sample_selector """ import logging from typing import Any, Dict, List, Optional logger = logging.getLogger("RoleMatcher") # ============================================================================ # CONSTANTS # ============================================================================ # Valid roles for sample matching with their expected characteristics VALID_ROLES = { # One-shot drums "kick": {"max_duration": 2.0, "min_onset": 0.3, "is_loop": False, "bus": "drums"}, "snare": {"max_duration": 2.0, "min_onset": 0.25, "is_loop": False, "bus": "drums"}, "hat": {"max_duration": 1.5, "min_onset": 0.2, "is_loop": False, "bus": "drums"}, "clap": {"max_duration": 2.0, "min_onset": 0.25, "is_loop": False, "bus": "drums"}, "ride": {"max_duration": 3.0, "min_onset": 0.15, "is_loop": False, "bus": "drums"}, "perc": {"max_duration": 2.5, "min_onset": 0.2, "is_loop": False, "bus": "drums"}, # Loops "bass_loop": {"min_duration": 2.0, "max_duration": 16.0, "is_loop": True, "bus": "bass"}, "perc_loop": {"min_duration": 2.0, "max_duration": 16.0, "is_loop": True, "bus": "drums"}, "top_loop": {"min_duration": 2.0, "max_duration": 16.0, "is_loop": True, "bus": "drums"}, "synth_loop": {"min_duration": 2.0, "max_duration": 16.0, "is_loop": True, "bus": "music"}, "vocal_loop": {"min_duration": 2.0, "max_duration": 16.0, "is_loop": True, "bus": "vocal"}, # FX "crash_fx": {"max_duration": 4.0, "is_loop": False, "bus": "fx"}, "fill_fx": {"max_duration": 8.0, "is_loop": False, "bus": "fx"}, "snare_roll": {"max_duration": 8.0, "is_loop": False, "bus": "drums"}, "atmos_fx": {"min_duration": 4.0, "is_loop": True, "bus": "fx"}, "vocal_shot": {"max_duration": 3.0, "is_loop": False, "bus": "vocal"}, # Resample layers "resample_reverse": {"is_loop": False, "bus": "fx"}, "resample_riser": {"is_loop": False, "bus": "fx"}, "resample_downlifter": {"is_loop": False, "bus": "fx"}, "resample_stutter": {"is_loop": False, "bus": "vocal"}, } # Keywords that indicate aggressive/hard samples that may be misclassified AGGRESSIVE_KEYWORDS = { # Very aggressive kick patterns "hard", "distorted", "industrial", "slam", "punch", "brutal", # Potentially misclassified "subdrop", "impact", "explosion", "destroy", } # Keywords that are acceptable for aggressive genres GENRE_APPROPRIATE_AGGRESSIVE = { "industrial-techno", "hard-techno", "raw-techno", "psytrance", "dark-techno" } # Role aliases for flexible matching ROLE_ALIASES = { "kick": ["kick", "bd", "bassdrum", "bass_drum"], "snare": ["snare", "sd", "snr"], "clap": ["clap", "cp", "handclap"], "hat": ["hat", "hihat", "hi_hat", "hhat", "closed_hat", "hat_closed"], "hat_open": ["open_hat", "hat_open", "ohat", "openhihat"], "ride": ["ride", "rd", "cymbal"], "perc": ["perc", "percussion", "percs"], "bass_loop": ["bass_loop", "bassloop", "bass loop", "sub_bass"], "perc_loop": ["perc_loop", "percloop", "percussion loop", "perc loop"], "top_loop": ["top_loop", "toploop", "top loop", "full_drum"], "synth_loop": ["synth_loop", "synthloop", "synth loop", "chord_loop", "stab"], "vocal_loop": ["vocal_loop", "vocalloop", "vocal loop", "vox_loop", "vox"], "crash_fx": ["crash", "crash_fx", "crashfx", "impact_fx"], "fill_fx": ["fill", "fill_fx", "fillfx", "tom_fill", "transition"], "snare_roll": ["snare_roll", "snareroll", "snare roll", "snr_roll"], "atmos_fx": ["atmos", "atmos_fx", "atmosfx", "drone", "pad_fx"], "vocal_shot": ["vocal_shot", "vocalshot", "vocal shot", "vocal_one_shot"], } # Minimum score thresholds for role matching ROLE_SCORE_THRESHOLDS = { "kick": 0.35, "snare": 0.32, "hat": 0.30, "clap": 0.32, "bass_loop": 0.38, "perc_loop": 0.35, "top_loop": 0.35, "synth_loop": 0.36, "vocal_loop": 0.38, "crash_fx": 0.30, "fill_fx": 0.32, "snare_roll": 0.30, "atmos_fx": 0.32, "vocal_shot": 0.34, } # ============================================================================ # VALIDATION FUNCTIONS # ============================================================================ def validate_role_for_sample( role: str, sample_data: Dict[str, Any], genre: Optional[str] = None, ) -> Dict[str, Any]: """ Validates if a sample is appropriate for a given role. Args: role: The role to validate for (e.g., 'kick', 'bass_loop') sample_data: Sample metadata with keys like 'duration', 'onset_mean', 'file_name', 'rms_mean' genre: Optional genre for context-aware aggressive sample handling Returns: Dict with keys: - 'valid' (bool): Whether the sample passes validation - 'score' (float): Raw validation score (0.0-1.0) - 'warnings' (list): List of warning messages - 'adjusted_score' (float): Score after penalties """ if role not in VALID_ROLES: return {"valid": True, "score": 0.5, "warnings": [f"Unknown role: {role}"], "adjusted_score": 0.5} role_config = VALID_ROLES[role] warnings: List[str] = [] score = 1.0 duration = float(sample_data.get("duration", 0.0) or 0.0) onset = float(sample_data.get("onset_mean", 0.0) or 0.0) file_name = str(sample_data.get("file_name", "") or "").lower() rms = float(sample_data.get("rms_mean", 0.0) or 0.0) # Duration validation if role_config.get("is_loop"): min_dur = role_config.get("min_duration", 2.0) max_dur = role_config.get("max_duration", 16.0) if duration < min_dur: warnings.append(f"Duration {duration:.1f}s too short for loop role (min {min_dur}s)") score *= 0.7 elif max_dur and duration > max_dur: warnings.append(f"Duration {duration:.1f}s too long for role (max {max_dur}s)") score *= 0.85 else: max_dur = role_config.get("max_duration", 3.0) if duration > max_dur: warnings.append(f"Duration {duration:.1f}s too long for one-shot role (max {max_dur}s)") score *= 0.75 if "loop" in file_name and role in ["kick", "snare", "hat", "clap"]: warnings.append("One-shot role has 'loop' in filename") score *= 0.65 # Onset validation for percussive elements min_onset = role_config.get("min_onset", 0.0) if min_onset > 0 and onset < min_onset: warnings.append(f"Onset {onset:.2f} below minimum {min_onset:.2f}") score *= 0.85 # Check for aggressive samples that might be misclassified aggressive_penalty = 1.0 is_aggressive_genre = genre and genre.lower() in GENRE_APPROPRIATE_AGGRESSIVE for keyword in AGGRESSIVE_KEYWORDS: if keyword in file_name: if not is_aggressive_genre: aggressive_penalty *= 0.88 warnings.append(f"Aggressive keyword '{keyword}' found for non-aggressive genre") score *= aggressive_penalty # RMS validation for certain roles if role in ["kick", "snare", "clap"] and rms > 0.4: warnings.append(f"High RMS {rms:.3f} for one-shot role") score *= 0.9 adjusted_score = max(0.1, min(1.0, score)) return { "valid": score >= 0.4, "score": score, "warnings": warnings, "adjusted_score": adjusted_score, } def resolve_role_from_alias(alias: str) -> Optional[str]: """ Resolves a role name from various aliases. Args: alias: A potential role alias (e.g., 'bd', 'hihat', 'bass loop') Returns: The canonical role name or None if not found """ alias_lower = alias.lower().strip().replace("-", "_").replace(" ", "_") # Direct match if alias_lower in VALID_ROLES: return alias_lower # Check aliases for role, aliases in ROLE_ALIASES.items(): normalized_aliases = [a.lower().replace("-", "_").replace(" ", "_") for a in aliases] if alias_lower in normalized_aliases: return role return None def get_bus_for_role(role: str) -> str: """ Gets the appropriate bus for a role. Args: role: The role name Returns: Bus name ('drums', 'bass', 'music', 'vocal', or 'fx') """ if role in VALID_ROLES: return VALID_ROLES[role].get("bus", "music") return "music" # ============================================================================ # LOGGING FUNCTIONS # ============================================================================ def log_matching_decision( role: str, selected_sample: Optional[Dict[str, Any]], candidates_count: int, final_score: float, validation_result: Optional[Dict[str, Any]] = None, ) -> None: """ Logs detailed matching decisions for debugging and analysis. Args: role: The role being matched selected_sample: The selected sample dict or None candidates_count: Number of candidates considered final_score: The final matching score validation_result: Optional validation result dict """ if not selected_sample: logger.info( f"[MATCH] Role '{role}': No sample selected (0/{candidates_count} candidates)" ) return sample_name = selected_sample.get("file_name", "unknown") sample_tempo = selected_sample.get("tempo", 0.0) sample_key = selected_sample.get("key", "N/A") sample_dur = selected_sample.get("duration", 0.0) log_parts = [ f"[MATCH] Role '{role}':", f"Sample: {sample_name}", f"Score: {final_score:.3f}", f"Tempo: {sample_tempo:.1f}", f"Key: {sample_key}", f"Duration: {sample_dur:.1f}s", f"Candidates: {candidates_count}", ] if validation_result: warnings = validation_result.get("warnings", []) if warnings: log_parts.append(f"Warnings: {', '.join(warnings)}") log_parts.append(f"Validated: {validation_result.get('valid', True)}") logger.info(" | ".join(log_parts)) # ============================================================================ # ENHANCEMENT FUNCTIONS # ============================================================================ def enhance_sample_matching( matches: Dict[str, List[Dict[str, Any]]], reference: Dict[str, Any], genre: Optional[str] = None, ) -> Dict[str, List[Dict[str, Any]]]: """ Enhances sample matching results with validation and filtering. This function takes raw matches from reference_listener and applies: 1. Role validation based on audio characteristics 2. Aggressive sample filtering 3. Score adjustment based on validation results Args: matches: Raw matches from reference_listener (role -> list of sample dicts) reference: Reference track analysis data genre: Target genre for context-aware filtering Returns: Enhanced matches with validation scores and filtering applied """ enhanced: Dict[str, List[Dict[str, Any]]] = {} for role, candidates in matches.items(): if not candidates: enhanced[role] = [] continue threshold = ROLE_SCORE_THRESHOLDS.get(role, 0.30) enhanced_candidates: List[Dict[str, Any]] = [] for candidate in candidates: # Create a copy to avoid modifying the original enhanced_candidate = dict(candidate) # Validate the sample for this role validation = validate_role_for_sample(role, candidate, genre) enhanced_candidate["validation"] = validation # Apply validation penalty to the score original_score = float(candidate.get("score", 0.0)) adjusted_score = original_score * validation["adjusted_score"] enhanced_candidate["adjusted_score"] = round(adjusted_score, 6) # Filter out samples below threshold if adjusted_score >= threshold: enhanced_candidates.append(enhanced_candidate) else: logger.debug( f"[FILTER] Role '{role}': Filtered out '{candidate.get('file_name', 'unknown')}' " f"(score {adjusted_score:.3f} < threshold {threshold})" ) # Re-sort by adjusted score enhanced_candidates.sort(key=lambda x: float(x.get("adjusted_score", 0.0)), reverse=True) enhanced[role] = enhanced_candidates # Log summary filtered_count = len(candidates) - len(enhanced_candidates) if filtered_count > 0: logger.info( f"[ENHANCE] Role '{role}': {len(enhanced_candidates)}/{len(candidates)} candidates passed validation " f"({filtered_count} filtered out)" ) return enhanced def filter_aggressive_samples( candidates: List[Dict[str, Any]], genre: Optional[str] = None, strict: bool = False, ) -> List[Dict[str, Any]]: """ Filters out samples with aggressive keywords unless appropriate for the genre. Args: candidates: List of sample candidate dicts genre: Target genre strict: If True, apply stricter filtering Returns: Filtered list of candidates """ is_aggressive_genre = genre and genre.lower() in GENRE_APPROPRIATE_AGGRESSIVE if is_aggressive_genre: # For aggressive genres, don't filter aggressive samples return candidates filtered = [] for candidate in candidates: file_name = str(candidate.get("file_name", "") or "").lower() aggressive_count = sum(1 for kw in AGGRESSIVE_KEYWORDS if kw in file_name) if strict and aggressive_count > 0: continue # Apply penalty instead of filtering completely if aggressive_count > 0: penalty = 0.85 ** aggressive_count candidate_copy = dict(candidate) original_score = float(candidate.get("score", 0.0)) candidate_copy["score"] = original_score * penalty filtered.append(candidate_copy) else: filtered.append(candidate) return filtered # ============================================================================ # INTEGRATION HELPERS # ============================================================================ def create_enhanced_match_report( role: str, selected_sample: Optional[Dict[str, Any]], all_candidates: List[Dict[str, Any]], validation_result: Optional[Dict[str, Any]] = None, ) -> Dict[str, Any]: """ Creates a detailed report for a matching decision. Args: role: The role being matched selected_sample: The selected sample all_candidates: All candidates that were considered validation_result: Validation result for the selected sample Returns: A dict with detailed matching report """ report = { "role": role, "selected": selected_sample is not None, "candidates_count": len(all_candidates), "threshold": ROLE_SCORE_THRESHOLDS.get(role, 0.30), } if selected_sample: report["selected_sample"] = { "name": selected_sample.get("file_name"), "path": selected_sample.get("path"), "score": selected_sample.get("score"), "adjusted_score": selected_sample.get("adjusted_score"), "tempo": selected_sample.get("tempo"), "key": selected_sample.get("key"), "duration": selected_sample.get("duration"), } if validation_result: report["validation"] = { "valid": validation_result.get("valid"), "score": validation_result.get("score"), "warnings": validation_result.get("warnings", []), } return report def get_role_info(role: str) -> Dict[str, Any]: """ Gets comprehensive information about a role. Args: role: The role name Returns: Dict with role information including valid samples count, thresholds, etc. """ if role not in VALID_ROLES: return {"error": f"Unknown role: {role}"} config = VALID_ROLES[role] aliases = ROLE_ALIASES.get(role, []) return { "role": role, "config": config, "aliases": aliases, "threshold": ROLE_SCORE_THRESHOLDS.get(role, 0.30), "bus": config.get("bus", "music"), "is_loop": config.get("is_loop", False), }