Files
ableton-mcp-ai/AbletonMCP_AI/MCP_Server/role_matcher.py
renato97 6ec8663954 Initial commit: AbletonMCP-AI complete system
- MCP Server with audio fallback, sample management
- Song generator with bus routing
- Reference listener and audio resampler
- Vector-based sample search
- Master chain with limiter and calibration
- Fix: Audio fallback now works without M4L
- Fix: Full song detection in sample loader

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-28 22:53:10 -03:00

469 lines
16 KiB
Python

"""
role_matcher.py - Phase 4: Role validation and sample matching utilities
This module provides enhanced role matching for sample selection with:
- Role validation based on audio characteristics
- Aggressive sample detection and filtering
- Logging of matching decisions
- Integration with reference_listener and sample_selector
"""
import logging
from typing import Any, Dict, List, Optional
# Dedicated logger for matching decisions; the explicit name groups every
# record emitted by this module under "RoleMatcher".
logger = logging.getLogger("RoleMatcher")
# ============================================================================
# CONSTANTS
# ============================================================================
# Expected characteristics per role, read by the validation and lookup
# helpers below. Keys per entry:
#   min_duration / max_duration : acceptable length window, in seconds
#   min_onset                   : lowest acceptable 'onset_mean' value
#   is_loop                     : True for loop roles, False for one-shots
#   bus                         : mixer bus the role is routed to
VALID_ROLES: Dict[str, Dict[str, Any]] = {
    # ---- one-shot drums ----------------------------------------------------
    "kick": {"max_duration": 2.0, "min_onset": 0.3, "is_loop": False, "bus": "drums"},
    "snare": {"max_duration": 2.0, "min_onset": 0.25, "is_loop": False, "bus": "drums"},
    "hat": {"max_duration": 1.5, "min_onset": 0.2, "is_loop": False, "bus": "drums"},
    "clap": {"max_duration": 2.0, "min_onset": 0.25, "is_loop": False, "bus": "drums"},
    "ride": {"max_duration": 3.0, "min_onset": 0.15, "is_loop": False, "bus": "drums"},
    "perc": {"max_duration": 2.5, "min_onset": 0.2, "is_loop": False, "bus": "drums"},
    # ---- loops (2-16 s window) ---------------------------------------------
    "bass_loop": {"min_duration": 2.0, "max_duration": 16.0, "is_loop": True, "bus": "bass"},
    "perc_loop": {"min_duration": 2.0, "max_duration": 16.0, "is_loop": True, "bus": "drums"},
    "top_loop": {"min_duration": 2.0, "max_duration": 16.0, "is_loop": True, "bus": "drums"},
    "synth_loop": {"min_duration": 2.0, "max_duration": 16.0, "is_loop": True, "bus": "music"},
    "vocal_loop": {"min_duration": 2.0, "max_duration": 16.0, "is_loop": True, "bus": "vocal"},
    # ---- FX and transitional one-shots -------------------------------------
    "crash_fx": {"max_duration": 4.0, "is_loop": False, "bus": "fx"},
    "fill_fx": {"max_duration": 8.0, "is_loop": False, "bus": "fx"},
    "snare_roll": {"max_duration": 8.0, "is_loop": False, "bus": "drums"},
    "atmos_fx": {"min_duration": 4.0, "is_loop": True, "bus": "fx"},
    "vocal_shot": {"max_duration": 3.0, "is_loop": False, "bus": "vocal"},
    # ---- resample layers (no duration constraints) -------------------------
    "resample_reverse": {"is_loop": False, "bus": "fx"},
    "resample_riser": {"is_loop": False, "bus": "fx"},
    "resample_downlifter": {"is_loop": False, "bus": "fx"},
    "resample_stutter": {"is_loop": False, "bus": "vocal"},
}
# Filename fragments that flag hard/aggressive material. Callers test them as
# plain substrings of the lower-cased file name.
AGGRESSIVE_KEYWORDS = {
    # very aggressive kick character
    "hard",
    "distorted",
    "industrial",
    "slam",
    "punch",
    "brutal",
    # names that are frequently misclassified
    "subdrop",
    "impact",
    "explosion",
    "destroy",
}
# Genres in which aggressive samples are expected, so the aggressive-keyword
# penalty / filtering is skipped for them.
GENRE_APPROPRIATE_AGGRESSIVE = {
    "industrial-techno",
    "hard-techno",
    "raw-techno",
    "psytrance",
    "dark-techno",
}
# Canonical role -> alternative spellings accepted by resolve_role_from_alias().
# Spellings are compared after lower-casing and mapping '-' and ' ' to '_'.
# NOTE(review): "hat_open" has no VALID_ROLES entry (and no score threshold), so
# a resolved "open hat" is routed to the "music" bus by get_bus_for_role() and
# treated as an unknown role by validate_role_for_sample() — confirm intent.
ROLE_ALIASES: Dict[str, List[str]] = {
    # drums
    "kick": ["kick", "bd", "bassdrum", "bass_drum"],
    "snare": ["snare", "sd", "snr"],
    "clap": ["clap", "cp", "handclap"],
    "hat": ["hat", "hihat", "hi_hat", "hhat", "closed_hat", "hat_closed"],
    "hat_open": ["open_hat", "hat_open", "ohat", "openhihat"],
    "ride": ["ride", "rd", "cymbal"],
    "perc": ["perc", "percussion", "percs"],
    # loops
    "bass_loop": ["bass_loop", "bassloop", "bass loop", "sub_bass"],
    "perc_loop": ["perc_loop", "percloop", "percussion loop", "perc loop"],
    "top_loop": ["top_loop", "toploop", "top loop", "full_drum"],
    "synth_loop": ["synth_loop", "synthloop", "synth loop", "chord_loop", "stab"],
    "vocal_loop": ["vocal_loop", "vocalloop", "vocal loop", "vox_loop", "vox"],
    # fx
    "crash_fx": ["crash", "crash_fx", "crashfx", "impact_fx"],
    "fill_fx": ["fill", "fill_fx", "fillfx", "tom_fill", "transition"],
    "snare_roll": ["snare_roll", "snareroll", "snare roll", "snr_roll"],
    "atmos_fx": ["atmos", "atmos_fx", "atmosfx", "drone", "pad_fx"],
    "vocal_shot": ["vocal_shot", "vocalshot", "vocal shot", "vocal_one_shot"],
}
# Minimum adjusted score a candidate needs to be kept for a role.
# Roles without an entry here (e.g. "ride", "perc", the resample_* layers) are
# looked up with .get(role, 0.30) by the helpers below, i.e. they use 0.30.
ROLE_SCORE_THRESHOLDS: Dict[str, float] = {
    # one-shot drums
    "kick": 0.35,
    "snare": 0.32,
    "hat": 0.30,
    "clap": 0.32,
    # loops
    "bass_loop": 0.38,
    "perc_loop": 0.35,
    "top_loop": 0.35,
    "synth_loop": 0.36,
    "vocal_loop": 0.38,
    # fx
    "crash_fx": 0.30,
    "fill_fx": 0.32,
    "snare_roll": 0.30,
    "atmos_fx": 0.32,
    "vocal_shot": 0.34,
}
# ============================================================================
# VALIDATION FUNCTIONS
# ============================================================================
def validate_role_for_sample(
    role: str,
    sample_data: Dict[str, Any],
    genre: Optional[str] = None,
) -> Dict[str, Any]:
    """
    Validates if a sample is appropriate for a given role.

    Starting from 1.0, the score is multiplied by a penalty for every
    expectation in VALID_ROLES[role] that the sample violates: the duration
    window, 'loop' in the filename of a kick/snare/hat/clap, the minimum
    onset, aggressive filename keywords (unless the genre is an aggressive
    one) and a high RMS for kick/snare/clap.

    Args:
        role: The role to validate for (e.g., 'kick', 'bass_loop').
        sample_data: Sample metadata with keys like 'duration', 'onset_mean',
            'file_name', 'rms_mean'. Numeric fields that are missing, None or
            not convertible to float are treated as 0.0.
        genre: Optional genre for context-aware aggressive sample handling.
            Compared case-insensitively, with runs of spaces/underscores
            treated as a single hyphen ('Hard Techno' matches 'hard-techno').

    Returns:
        Dict with keys:
        - 'valid' (bool): True when the raw score is at least 0.4
        - 'score' (float): Raw validation score (0.0-1.0)
        - 'warnings' (list): Human-readable warning messages
        - 'adjusted_score' (float): Raw score clamped to [0.1, 1.0]
        Unknown roles are not rejected; they get a neutral 0.5 score.
    """
    if role not in VALID_ROLES:
        return {"valid": True, "score": 0.5, "warnings": [f"Unknown role: {role}"], "adjusted_score": 0.5}

    def _number(key: str) -> float:
        # Metadata may carry None or unparsable strings; degrade to 0.0
        # instead of aborting the whole matching pass with a ValueError.
        try:
            return float(sample_data.get(key, 0.0) or 0.0)
        except (TypeError, ValueError):
            return 0.0

    role_config = VALID_ROLES[role]
    warnings: List[str] = []
    score = 1.0
    duration = _number("duration")
    onset = _number("onset_mean")
    rms = _number("rms_mean")
    file_name = str(sample_data.get("file_name", "") or "").lower()

    # Duration validation
    if role_config.get("is_loop"):
        min_dur = role_config.get("min_duration", 2.0)
        max_dur = role_config.get("max_duration", 16.0)
        if duration < min_dur:
            warnings.append(f"Duration {duration:.1f}s too short for loop role (min {min_dur}s)")
            score *= 0.7
        elif max_dur and duration > max_dur:
            warnings.append(f"Duration {duration:.1f}s too long for role (max {max_dur}s)")
            score *= 0.85
    else:
        max_dur = role_config.get("max_duration", 3.0)
        if duration > max_dur:
            warnings.append(f"Duration {duration:.1f}s too long for one-shot role (max {max_dur}s)")
            score *= 0.75
        # A drum one-shot whose file name says "loop" is probably mislabelled.
        if "loop" in file_name and role in ["kick", "snare", "hat", "clap"]:
            warnings.append("One-shot role has 'loop' in filename")
            score *= 0.65

    # Onset validation for percussive elements
    min_onset = role_config.get("min_onset", 0.0)
    if min_onset > 0 and onset < min_onset:
        warnings.append(f"Onset {onset:.2f} below minimum {min_onset:.2f}")
        score *= 0.85

    # Aggressive keywords are only penalised outside aggressive genres.
    normalized_genre = "-".join((genre or "").replace("_", " ").split()).lower()
    is_aggressive_genre = normalized_genre in GENRE_APPROPRIATE_AGGRESSIVE
    aggressive_penalty = 1.0
    # sorted(): set iteration order depends on the string hash seed, which
    # would make the warning order differ from run to run.
    for keyword in sorted(AGGRESSIVE_KEYWORDS):
        if keyword in file_name and not is_aggressive_genre:
            aggressive_penalty *= 0.88
            warnings.append(f"Aggressive keyword '{keyword}' found for non-aggressive genre")
    score *= aggressive_penalty

    # RMS validation for certain roles
    if role in ["kick", "snare", "clap"] and rms > 0.4:
        warnings.append(f"High RMS {rms:.3f} for one-shot role")
        score *= 0.9

    adjusted_score = max(0.1, min(1.0, score))
    return {
        "valid": score >= 0.4,
        "score": score,
        "warnings": warnings,
        "adjusted_score": adjusted_score,
    }
def resolve_role_from_alias(alias: str) -> Optional[str]:
    """
    Map a free-form role name onto its canonical role key.

    The input is lower-cased and stripped, and hyphens/spaces become
    underscores. A name that is already a VALID_ROLES key is returned as-is;
    otherwise ROLE_ALIASES is scanned in order and the first role whose
    (identically normalised) spellings contain the name wins.

    Args:
        alias: Candidate name such as 'bd', 'hihat' or 'bass loop'.

    Returns:
        The canonical role name, or None when nothing matches.
    """
    def _canon(text: str) -> str:
        return text.lower().replace("-", "_").replace(" ", "_")

    needle = alias.lower().strip().replace("-", "_").replace(" ", "_")
    if needle in VALID_ROLES:
        return needle
    return next(
        (
            canonical
            for canonical, spellings in ROLE_ALIASES.items()
            if any(_canon(spelling) == needle for spelling in spellings)
        ),
        None,
    )
def get_bus_for_role(role: str) -> str:
    """
    Return the mixer bus a role is routed to.

    Args:
        role: The role name.

    Returns:
        The role's 'bus' from VALID_ROLES ('drums', 'bass', 'music', 'vocal'
        or 'fx'); 'music' for unknown roles or entries without a bus.
    """
    config = VALID_ROLES.get(role)
    if config is None:
        return "music"
    return config.get("bus", "music")
# ============================================================================
# LOGGING FUNCTIONS
# ============================================================================
def log_matching_decision(
    role: str,
    selected_sample: Optional[Dict[str, Any]],
    candidates_count: int,
    final_score: float,
    validation_result: Optional[Dict[str, Any]] = None,
) -> None:
    """
    Logs detailed matching decisions for debugging and analysis.

    Emits a single INFO record on the module logger. The numeric fields that
    are rendered with a precision spec (score, tempo, duration) are coerced
    to float first; missing, None or non-numeric values are logged as 0.0, so
    sparse sample metadata can never make this logging call raise.

    Args:
        role: The role being matched
        selected_sample: The selected sample dict, or None/empty if nothing
            was selected
        candidates_count: Number of candidates considered
        final_score: The final matching score
        validation_result: Optional validate_role_for_sample() result; its
            'warnings' and 'valid' entries are appended to the record
    """
    if not selected_sample:
        logger.info(
            f"[MATCH] Role '{role}': No sample selected (0/{candidates_count} candidates)"
        )
        return

    def _as_float(value: Any) -> float:
        # f"{None:.1f}" raises TypeError; log 0.0 rather than crash the caller.
        try:
            return float(value or 0.0)
        except (TypeError, ValueError):
            return 0.0

    sample_name = selected_sample.get("file_name", "unknown")
    sample_tempo = _as_float(selected_sample.get("tempo", 0.0))
    sample_key = selected_sample.get("key", "N/A")
    sample_dur = _as_float(selected_sample.get("duration", 0.0))
    score = _as_float(final_score)
    log_parts = [
        f"[MATCH] Role '{role}':",
        f"Sample: {sample_name}",
        f"Score: {score:.3f}",
        f"Tempo: {sample_tempo:.1f}",
        f"Key: {sample_key}",
        f"Duration: {sample_dur:.1f}s",
        f"Candidates: {candidates_count}",
    ]
    if validation_result:
        warnings = validation_result.get("warnings", [])
        if warnings:
            log_parts.append(f"Warnings: {', '.join(warnings)}")
        log_parts.append(f"Validated: {validation_result.get('valid', True)}")
    logger.info(" | ".join(log_parts))
# ============================================================================
# ENHANCEMENT FUNCTIONS
# ============================================================================
def enhance_sample_matching(
    matches: Dict[str, List[Dict[str, Any]]],
    reference: Dict[str, Any],
    genre: Optional[str] = None,
) -> Dict[str, List[Dict[str, Any]]]:
    """
    Enhances sample matching results with validation and filtering.

    This function takes raw matches from reference_listener and applies:
    1. Role validation based on audio characteristics
    2. Aggressive sample filtering
    3. Score adjustment based on validation results

    Args:
        matches: Raw matches (role -> list of sample dicts)
        reference: Reference track analysis data. Not read by this function
            at present; kept for interface compatibility with callers.
        genre: Target genre for context-aware filtering

    Returns:
        A new dict with the same role keys. Each surviving candidate is a
        shallow copy of the input dict with two extra keys: 'validation' (the
        validate_role_for_sample result) and 'adjusted_score' (the candidate's
        'score' times the validation 'adjusted_score', rounded to 6 places).
        Candidates below the role's threshold are dropped and the rest are
        sorted by 'adjusted_score', highest first. A missing, None or
        non-numeric 'score' counts as 0.0. Input dicts are not modified.
    """
    enhanced: Dict[str, List[Dict[str, Any]]] = {}
    for role, candidates in matches.items():
        if not candidates:
            enhanced[role] = []
            continue
        threshold = ROLE_SCORE_THRESHOLDS.get(role, 0.30)
        enhanced_candidates: List[Dict[str, Any]] = []
        for candidate in candidates:
            # Create a copy to avoid modifying the original
            enhanced_candidate = dict(candidate)
            # Validate the sample for this role
            validation = validate_role_for_sample(role, candidate, genre)
            enhanced_candidate["validation"] = validation
            # Apply validation penalty to the score; tolerate None/garbage
            # scores instead of aborting the whole pass.
            try:
                original_score = float(candidate.get("score", 0.0) or 0.0)
            except (TypeError, ValueError):
                original_score = 0.0
            adjusted_score = original_score * validation["adjusted_score"]
            enhanced_candidate["adjusted_score"] = round(adjusted_score, 6)
            # Filter out samples below threshold
            if adjusted_score >= threshold:
                enhanced_candidates.append(enhanced_candidate)
            else:
                # Lazy %-args: the message is only formatted if DEBUG is enabled.
                logger.debug(
                    "[FILTER] Role '%s': Filtered out '%s' (score %.3f < threshold %s)",
                    role,
                    candidate.get("file_name", "unknown"),
                    adjusted_score,
                    threshold,
                )
        # Re-sort by adjusted score
        enhanced_candidates.sort(key=lambda x: float(x.get("adjusted_score", 0.0)), reverse=True)
        enhanced[role] = enhanced_candidates
        # Log summary
        filtered_count = len(candidates) - len(enhanced_candidates)
        if filtered_count > 0:
            logger.info(
                "[ENHANCE] Role '%s': %d/%d candidates passed validation (%d filtered out)",
                role,
                len(enhanced_candidates),
                len(candidates),
                filtered_count,
            )
    return enhanced
def filter_aggressive_samples(
    candidates: List[Dict[str, Any]],
    genre: Optional[str] = None,
    strict: bool = False,
) -> List[Dict[str, Any]]:
    """
    Filters out samples with aggressive keywords unless appropriate for the genre.

    For a genre listed in GENRE_APPROPRIATE_AGGRESSIVE every candidate is
    kept unchanged. Otherwise, a candidate whose lower-cased 'file_name'
    contains N aggressive keywords (N > 0) is dropped when ``strict`` is True,
    or kept as a copy whose 'score' is multiplied by 0.85 ** N. Input dicts
    are never modified.

    Args:
        candidates: List of sample candidate dicts
        genre: Target genre; compared case-insensitively, with runs of
            spaces/underscores treated as a single hyphen ('Hard Techno'
            matches 'hard-techno').
        strict: If True, drop aggressive candidates instead of penalising them

    Returns:
        A new list of candidates (never the input list object itself). A
        missing, None or non-numeric 'score' on a penalised candidate counts
        as 0.0.
    """
    normalized_genre = "-".join((genre or "").replace("_", " ").split()).lower()
    if normalized_genre in GENRE_APPROPRIATE_AGGRESSIVE:
        # For aggressive genres, don't filter aggressive samples. Return a new
        # list so both paths have the same aliasing behaviour.
        return list(candidates)

    filtered: List[Dict[str, Any]] = []
    for candidate in candidates:
        file_name = str(candidate.get("file_name", "") or "").lower()
        aggressive_count = sum(1 for kw in AGGRESSIVE_KEYWORDS if kw in file_name)
        if aggressive_count == 0:
            filtered.append(candidate)
            continue
        if strict:
            continue
        # Apply penalty instead of filtering completely
        try:
            original_score = float(candidate.get("score", 0.0) or 0.0)
        except (TypeError, ValueError):
            original_score = 0.0
        candidate_copy = dict(candidate)
        candidate_copy["score"] = original_score * (0.85 ** aggressive_count)
        filtered.append(candidate_copy)
    return filtered
# ============================================================================
# INTEGRATION HELPERS
# ============================================================================
def create_enhanced_match_report(
    role: str,
    selected_sample: Optional[Dict[str, Any]],
    all_candidates: List[Dict[str, Any]],
    validation_result: Optional[Dict[str, Any]] = None,
) -> Dict[str, Any]:
    """
    Build a plain-dict summary of one matching decision.

    Args:
        role: The role being matched.
        selected_sample: The chosen sample dict, or None.
        all_candidates: Every candidate that was considered.
        validation_result: Optional validation result for the chosen sample.

    Returns:
        A dict with 'role', 'selected' (selected_sample is not None),
        'candidates_count', 'threshold' (default 0.30), plus a
        'selected_sample' section when a non-empty sample was given and a
        'validation' section when a non-empty validation result was given.
    """
    # (report key, sample-dict key) pairs, in the order they appear in output.
    sample_fields = (
        ("name", "file_name"),
        ("path", "path"),
        ("score", "score"),
        ("adjusted_score", "adjusted_score"),
        ("tempo", "tempo"),
        ("key", "key"),
        ("duration", "duration"),
    )
    report: Dict[str, Any] = dict(
        role=role,
        selected=selected_sample is not None,
        candidates_count=len(all_candidates),
        threshold=ROLE_SCORE_THRESHOLDS.get(role, 0.30),
    )
    if selected_sample:
        report["selected_sample"] = {
            out_key: selected_sample.get(src_key) for out_key, src_key in sample_fields
        }
    if validation_result:
        report["validation"] = {
            "valid": validation_result.get("valid"),
            "score": validation_result.get("score"),
            "warnings": validation_result.get("warnings", []),
        }
    return report
def get_role_info(role: str) -> Dict[str, Any]:
    """
    Gets comprehensive information about a role.

    Args:
        role: The role name

    Returns:
        For a known role, a dict with:
        - 'role': the role name
        - 'config': a copy of the role's VALID_ROLES entry
        - 'aliases': a copy of the role's ROLE_ALIASES list ([] if none)
        - 'threshold': the ROLE_SCORE_THRESHOLDS entry (default 0.30)
        - 'bus': the configured bus (default 'music')
        - 'is_loop': the configured loop flag (default False)
        For an unknown role: {"error": "Unknown role: <role>"}.
    """
    if role not in VALID_ROLES:
        return {"error": f"Unknown role: {role}"}
    # Copy the shared module tables: handing out the live objects would let a
    # caller that edits the result silently change later validations.
    config = dict(VALID_ROLES[role])
    aliases = list(ROLE_ALIASES.get(role, []))
    return {
        "role": role,
        "config": config,
        "aliases": aliases,
        "threshold": ROLE_SCORE_THRESHOLDS.get(role, 0.30),
        "bus": config.get("bus", "music"),
        "is_loop": config.get("is_loop", False),
    }