Initial commit: AbletonMCP-AI complete system
- MCP Server with audio fallback, sample management - Song generator with bus routing - Reference listener and audio resampler - Vector-based sample search - Master chain with limiter and calibration - Fix: Audio fallback now works without M4L - Fix: Full song detection in sample loader Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
469
AbletonMCP_AI/MCP_Server/role_matcher.py
Normal file
469
AbletonMCP_AI/MCP_Server/role_matcher.py
Normal file
@@ -0,0 +1,469 @@
|
||||
"""
|
||||
role_matcher.py - Phase 4: Role validation and sample matching utilities
|
||||
|
||||
This module provides enhanced role matching for sample selection with:
|
||||
- Role validation based on audio characteristics
|
||||
- Aggressive sample detection and filtering
|
||||
- Logging of matching decisions
|
||||
- Integration with reference_listener and sample_selector
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
logger = logging.getLogger("RoleMatcher")
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# CONSTANTS
|
||||
# ============================================================================
|
||||
|
||||
# Valid roles for sample matching with their expected characteristics
|
||||
VALID_ROLES = {
|
||||
# One-shot drums
|
||||
"kick": {"max_duration": 2.0, "min_onset": 0.3, "is_loop": False, "bus": "drums"},
|
||||
"snare": {"max_duration": 2.0, "min_onset": 0.25, "is_loop": False, "bus": "drums"},
|
||||
"hat": {"max_duration": 1.5, "min_onset": 0.2, "is_loop": False, "bus": "drums"},
|
||||
"clap": {"max_duration": 2.0, "min_onset": 0.25, "is_loop": False, "bus": "drums"},
|
||||
"ride": {"max_duration": 3.0, "min_onset": 0.15, "is_loop": False, "bus": "drums"},
|
||||
"perc": {"max_duration": 2.5, "min_onset": 0.2, "is_loop": False, "bus": "drums"},
|
||||
# Loops
|
||||
"bass_loop": {"min_duration": 2.0, "max_duration": 16.0, "is_loop": True, "bus": "bass"},
|
||||
"perc_loop": {"min_duration": 2.0, "max_duration": 16.0, "is_loop": True, "bus": "drums"},
|
||||
"top_loop": {"min_duration": 2.0, "max_duration": 16.0, "is_loop": True, "bus": "drums"},
|
||||
"synth_loop": {"min_duration": 2.0, "max_duration": 16.0, "is_loop": True, "bus": "music"},
|
||||
"vocal_loop": {"min_duration": 2.0, "max_duration": 16.0, "is_loop": True, "bus": "vocal"},
|
||||
# FX
|
||||
"crash_fx": {"max_duration": 4.0, "is_loop": False, "bus": "fx"},
|
||||
"fill_fx": {"max_duration": 8.0, "is_loop": False, "bus": "fx"},
|
||||
"snare_roll": {"max_duration": 8.0, "is_loop": False, "bus": "drums"},
|
||||
"atmos_fx": {"min_duration": 4.0, "is_loop": True, "bus": "fx"},
|
||||
"vocal_shot": {"max_duration": 3.0, "is_loop": False, "bus": "vocal"},
|
||||
# Resample layers
|
||||
"resample_reverse": {"is_loop": False, "bus": "fx"},
|
||||
"resample_riser": {"is_loop": False, "bus": "fx"},
|
||||
"resample_downlifter": {"is_loop": False, "bus": "fx"},
|
||||
"resample_stutter": {"is_loop": False, "bus": "vocal"},
|
||||
}
|
||||
|
||||
# Keywords that indicate aggressive/hard samples that may be misclassified
|
||||
AGGRESSIVE_KEYWORDS = {
|
||||
# Very aggressive kick patterns
|
||||
"hard", "distorted", "industrial", "slam", "punch", "brutal",
|
||||
# Potentially misclassified
|
||||
"subdrop", "impact", "explosion", "destroy",
|
||||
}
|
||||
|
||||
# Keywords that are acceptable for aggressive genres
|
||||
GENRE_APPROPRIATE_AGGRESSIVE = {
|
||||
"industrial-techno", "hard-techno", "raw-techno", "psytrance", "dark-techno"
|
||||
}
|
||||
|
||||
# Role aliases for flexible matching
|
||||
ROLE_ALIASES = {
|
||||
"kick": ["kick", "bd", "bassdrum", "bass_drum"],
|
||||
"snare": ["snare", "sd", "snr"],
|
||||
"clap": ["clap", "cp", "handclap"],
|
||||
"hat": ["hat", "hihat", "hi_hat", "hhat", "closed_hat", "hat_closed"],
|
||||
"hat_open": ["open_hat", "hat_open", "ohat", "openhihat"],
|
||||
"ride": ["ride", "rd", "cymbal"],
|
||||
"perc": ["perc", "percussion", "percs"],
|
||||
"bass_loop": ["bass_loop", "bassloop", "bass loop", "sub_bass"],
|
||||
"perc_loop": ["perc_loop", "percloop", "percussion loop", "perc loop"],
|
||||
"top_loop": ["top_loop", "toploop", "top loop", "full_drum"],
|
||||
"synth_loop": ["synth_loop", "synthloop", "synth loop", "chord_loop", "stab"],
|
||||
"vocal_loop": ["vocal_loop", "vocalloop", "vocal loop", "vox_loop", "vox"],
|
||||
"crash_fx": ["crash", "crash_fx", "crashfx", "impact_fx"],
|
||||
"fill_fx": ["fill", "fill_fx", "fillfx", "tom_fill", "transition"],
|
||||
"snare_roll": ["snare_roll", "snareroll", "snare roll", "snr_roll"],
|
||||
"atmos_fx": ["atmos", "atmos_fx", "atmosfx", "drone", "pad_fx"],
|
||||
"vocal_shot": ["vocal_shot", "vocalshot", "vocal shot", "vocal_one_shot"],
|
||||
}
|
||||
|
||||
# Minimum score thresholds for role matching
|
||||
ROLE_SCORE_THRESHOLDS = {
|
||||
"kick": 0.35,
|
||||
"snare": 0.32,
|
||||
"hat": 0.30,
|
||||
"clap": 0.32,
|
||||
"bass_loop": 0.38,
|
||||
"perc_loop": 0.35,
|
||||
"top_loop": 0.35,
|
||||
"synth_loop": 0.36,
|
||||
"vocal_loop": 0.38,
|
||||
"crash_fx": 0.30,
|
||||
"fill_fx": 0.32,
|
||||
"snare_roll": 0.30,
|
||||
"atmos_fx": 0.32,
|
||||
"vocal_shot": 0.34,
|
||||
}
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# VALIDATION FUNCTIONS
|
||||
# ============================================================================
|
||||
|
||||
def validate_role_for_sample(
|
||||
role: str,
|
||||
sample_data: Dict[str, Any],
|
||||
genre: Optional[str] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Validates if a sample is appropriate for a given role.
|
||||
|
||||
Args:
|
||||
role: The role to validate for (e.g., 'kick', 'bass_loop')
|
||||
sample_data: Sample metadata with keys like 'duration', 'onset_mean', 'file_name', 'rms_mean'
|
||||
genre: Optional genre for context-aware aggressive sample handling
|
||||
|
||||
Returns:
|
||||
Dict with keys:
|
||||
- 'valid' (bool): Whether the sample passes validation
|
||||
- 'score' (float): Raw validation score (0.0-1.0)
|
||||
- 'warnings' (list): List of warning messages
|
||||
- 'adjusted_score' (float): Score after penalties
|
||||
"""
|
||||
if role not in VALID_ROLES:
|
||||
return {"valid": True, "score": 0.5, "warnings": [f"Unknown role: {role}"], "adjusted_score": 0.5}
|
||||
|
||||
role_config = VALID_ROLES[role]
|
||||
warnings: List[str] = []
|
||||
score = 1.0
|
||||
|
||||
duration = float(sample_data.get("duration", 0.0) or 0.0)
|
||||
onset = float(sample_data.get("onset_mean", 0.0) or 0.0)
|
||||
file_name = str(sample_data.get("file_name", "") or "").lower()
|
||||
rms = float(sample_data.get("rms_mean", 0.0) or 0.0)
|
||||
|
||||
# Duration validation
|
||||
if role_config.get("is_loop"):
|
||||
min_dur = role_config.get("min_duration", 2.0)
|
||||
max_dur = role_config.get("max_duration", 16.0)
|
||||
if duration < min_dur:
|
||||
warnings.append(f"Duration {duration:.1f}s too short for loop role (min {min_dur}s)")
|
||||
score *= 0.7
|
||||
elif max_dur and duration > max_dur:
|
||||
warnings.append(f"Duration {duration:.1f}s too long for role (max {max_dur}s)")
|
||||
score *= 0.85
|
||||
else:
|
||||
max_dur = role_config.get("max_duration", 3.0)
|
||||
if duration > max_dur:
|
||||
warnings.append(f"Duration {duration:.1f}s too long for one-shot role (max {max_dur}s)")
|
||||
score *= 0.75
|
||||
if "loop" in file_name and role in ["kick", "snare", "hat", "clap"]:
|
||||
warnings.append("One-shot role has 'loop' in filename")
|
||||
score *= 0.65
|
||||
|
||||
# Onset validation for percussive elements
|
||||
min_onset = role_config.get("min_onset", 0.0)
|
||||
if min_onset > 0 and onset < min_onset:
|
||||
warnings.append(f"Onset {onset:.2f} below minimum {min_onset:.2f}")
|
||||
score *= 0.85
|
||||
|
||||
# Check for aggressive samples that might be misclassified
|
||||
aggressive_penalty = 1.0
|
||||
is_aggressive_genre = genre and genre.lower() in GENRE_APPROPRIATE_AGGRESSIVE
|
||||
|
||||
for keyword in AGGRESSIVE_KEYWORDS:
|
||||
if keyword in file_name:
|
||||
if not is_aggressive_genre:
|
||||
aggressive_penalty *= 0.88
|
||||
warnings.append(f"Aggressive keyword '{keyword}' found for non-aggressive genre")
|
||||
|
||||
score *= aggressive_penalty
|
||||
|
||||
# RMS validation for certain roles
|
||||
if role in ["kick", "snare", "clap"] and rms > 0.4:
|
||||
warnings.append(f"High RMS {rms:.3f} for one-shot role")
|
||||
score *= 0.9
|
||||
|
||||
adjusted_score = max(0.1, min(1.0, score))
|
||||
|
||||
return {
|
||||
"valid": score >= 0.4,
|
||||
"score": score,
|
||||
"warnings": warnings,
|
||||
"adjusted_score": adjusted_score,
|
||||
}
|
||||
|
||||
|
||||
def resolve_role_from_alias(alias: str) -> Optional[str]:
|
||||
"""
|
||||
Resolves a role name from various aliases.
|
||||
|
||||
Args:
|
||||
alias: A potential role alias (e.g., 'bd', 'hihat', 'bass loop')
|
||||
|
||||
Returns:
|
||||
The canonical role name or None if not found
|
||||
"""
|
||||
alias_lower = alias.lower().strip().replace("-", "_").replace(" ", "_")
|
||||
|
||||
# Direct match
|
||||
if alias_lower in VALID_ROLES:
|
||||
return alias_lower
|
||||
|
||||
# Check aliases
|
||||
for role, aliases in ROLE_ALIASES.items():
|
||||
normalized_aliases = [a.lower().replace("-", "_").replace(" ", "_") for a in aliases]
|
||||
if alias_lower in normalized_aliases:
|
||||
return role
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def get_bus_for_role(role: str) -> str:
|
||||
"""
|
||||
Gets the appropriate bus for a role.
|
||||
|
||||
Args:
|
||||
role: The role name
|
||||
|
||||
Returns:
|
||||
Bus name ('drums', 'bass', 'music', 'vocal', or 'fx')
|
||||
"""
|
||||
if role in VALID_ROLES:
|
||||
return VALID_ROLES[role].get("bus", "music")
|
||||
return "music"
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# LOGGING FUNCTIONS
|
||||
# ============================================================================
|
||||
|
||||
def log_matching_decision(
|
||||
role: str,
|
||||
selected_sample: Optional[Dict[str, Any]],
|
||||
candidates_count: int,
|
||||
final_score: float,
|
||||
validation_result: Optional[Dict[str, Any]] = None,
|
||||
) -> None:
|
||||
"""
|
||||
Logs detailed matching decisions for debugging and analysis.
|
||||
|
||||
Args:
|
||||
role: The role being matched
|
||||
selected_sample: The selected sample dict or None
|
||||
candidates_count: Number of candidates considered
|
||||
final_score: The final matching score
|
||||
validation_result: Optional validation result dict
|
||||
"""
|
||||
if not selected_sample:
|
||||
logger.info(
|
||||
f"[MATCH] Role '{role}': No sample selected (0/{candidates_count} candidates)"
|
||||
)
|
||||
return
|
||||
|
||||
sample_name = selected_sample.get("file_name", "unknown")
|
||||
sample_tempo = selected_sample.get("tempo", 0.0)
|
||||
sample_key = selected_sample.get("key", "N/A")
|
||||
sample_dur = selected_sample.get("duration", 0.0)
|
||||
|
||||
log_parts = [
|
||||
f"[MATCH] Role '{role}':",
|
||||
f"Sample: {sample_name}",
|
||||
f"Score: {final_score:.3f}",
|
||||
f"Tempo: {sample_tempo:.1f}",
|
||||
f"Key: {sample_key}",
|
||||
f"Duration: {sample_dur:.1f}s",
|
||||
f"Candidates: {candidates_count}",
|
||||
]
|
||||
|
||||
if validation_result:
|
||||
warnings = validation_result.get("warnings", [])
|
||||
if warnings:
|
||||
log_parts.append(f"Warnings: {', '.join(warnings)}")
|
||||
log_parts.append(f"Validated: {validation_result.get('valid', True)}")
|
||||
|
||||
logger.info(" | ".join(log_parts))
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# ENHANCEMENT FUNCTIONS
|
||||
# ============================================================================
|
||||
|
||||
def enhance_sample_matching(
|
||||
matches: Dict[str, List[Dict[str, Any]]],
|
||||
reference: Dict[str, Any],
|
||||
genre: Optional[str] = None,
|
||||
) -> Dict[str, List[Dict[str, Any]]]:
|
||||
"""
|
||||
Enhances sample matching results with validation and filtering.
|
||||
|
||||
This function takes raw matches from reference_listener and applies:
|
||||
1. Role validation based on audio characteristics
|
||||
2. Aggressive sample filtering
|
||||
3. Score adjustment based on validation results
|
||||
|
||||
Args:
|
||||
matches: Raw matches from reference_listener (role -> list of sample dicts)
|
||||
reference: Reference track analysis data
|
||||
genre: Target genre for context-aware filtering
|
||||
|
||||
Returns:
|
||||
Enhanced matches with validation scores and filtering applied
|
||||
"""
|
||||
enhanced: Dict[str, List[Dict[str, Any]]] = {}
|
||||
|
||||
for role, candidates in matches.items():
|
||||
if not candidates:
|
||||
enhanced[role] = []
|
||||
continue
|
||||
|
||||
threshold = ROLE_SCORE_THRESHOLDS.get(role, 0.30)
|
||||
enhanced_candidates: List[Dict[str, Any]] = []
|
||||
|
||||
for candidate in candidates:
|
||||
# Create a copy to avoid modifying the original
|
||||
enhanced_candidate = dict(candidate)
|
||||
|
||||
# Validate the sample for this role
|
||||
validation = validate_role_for_sample(role, candidate, genre)
|
||||
enhanced_candidate["validation"] = validation
|
||||
|
||||
# Apply validation penalty to the score
|
||||
original_score = float(candidate.get("score", 0.0))
|
||||
adjusted_score = original_score * validation["adjusted_score"]
|
||||
enhanced_candidate["adjusted_score"] = round(adjusted_score, 6)
|
||||
|
||||
# Filter out samples below threshold
|
||||
if adjusted_score >= threshold:
|
||||
enhanced_candidates.append(enhanced_candidate)
|
||||
else:
|
||||
logger.debug(
|
||||
f"[FILTER] Role '{role}': Filtered out '{candidate.get('file_name', 'unknown')}' "
|
||||
f"(score {adjusted_score:.3f} < threshold {threshold})"
|
||||
)
|
||||
|
||||
# Re-sort by adjusted score
|
||||
enhanced_candidates.sort(key=lambda x: float(x.get("adjusted_score", 0.0)), reverse=True)
|
||||
enhanced[role] = enhanced_candidates
|
||||
|
||||
# Log summary
|
||||
filtered_count = len(candidates) - len(enhanced_candidates)
|
||||
if filtered_count > 0:
|
||||
logger.info(
|
||||
f"[ENHANCE] Role '{role}': {len(enhanced_candidates)}/{len(candidates)} candidates passed validation "
|
||||
f"({filtered_count} filtered out)"
|
||||
)
|
||||
|
||||
return enhanced
|
||||
|
||||
|
||||
def filter_aggressive_samples(
|
||||
candidates: List[Dict[str, Any]],
|
||||
genre: Optional[str] = None,
|
||||
strict: bool = False,
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Filters out samples with aggressive keywords unless appropriate for the genre.
|
||||
|
||||
Args:
|
||||
candidates: List of sample candidate dicts
|
||||
genre: Target genre
|
||||
strict: If True, apply stricter filtering
|
||||
|
||||
Returns:
|
||||
Filtered list of candidates
|
||||
"""
|
||||
is_aggressive_genre = genre and genre.lower() in GENRE_APPROPRIATE_AGGRESSIVE
|
||||
|
||||
if is_aggressive_genre:
|
||||
# For aggressive genres, don't filter aggressive samples
|
||||
return candidates
|
||||
|
||||
filtered = []
|
||||
for candidate in candidates:
|
||||
file_name = str(candidate.get("file_name", "") or "").lower()
|
||||
aggressive_count = sum(1 for kw in AGGRESSIVE_KEYWORDS if kw in file_name)
|
||||
|
||||
if strict and aggressive_count > 0:
|
||||
continue
|
||||
|
||||
# Apply penalty instead of filtering completely
|
||||
if aggressive_count > 0:
|
||||
penalty = 0.85 ** aggressive_count
|
||||
candidate_copy = dict(candidate)
|
||||
original_score = float(candidate.get("score", 0.0))
|
||||
candidate_copy["score"] = original_score * penalty
|
||||
filtered.append(candidate_copy)
|
||||
else:
|
||||
filtered.append(candidate)
|
||||
|
||||
return filtered
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# INTEGRATION HELPERS
|
||||
# ============================================================================
|
||||
|
||||
def create_enhanced_match_report(
|
||||
role: str,
|
||||
selected_sample: Optional[Dict[str, Any]],
|
||||
all_candidates: List[Dict[str, Any]],
|
||||
validation_result: Optional[Dict[str, Any]] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Creates a detailed report for a matching decision.
|
||||
|
||||
Args:
|
||||
role: The role being matched
|
||||
selected_sample: The selected sample
|
||||
all_candidates: All candidates that were considered
|
||||
validation_result: Validation result for the selected sample
|
||||
|
||||
Returns:
|
||||
A dict with detailed matching report
|
||||
"""
|
||||
report = {
|
||||
"role": role,
|
||||
"selected": selected_sample is not None,
|
||||
"candidates_count": len(all_candidates),
|
||||
"threshold": ROLE_SCORE_THRESHOLDS.get(role, 0.30),
|
||||
}
|
||||
|
||||
if selected_sample:
|
||||
report["selected_sample"] = {
|
||||
"name": selected_sample.get("file_name"),
|
||||
"path": selected_sample.get("path"),
|
||||
"score": selected_sample.get("score"),
|
||||
"adjusted_score": selected_sample.get("adjusted_score"),
|
||||
"tempo": selected_sample.get("tempo"),
|
||||
"key": selected_sample.get("key"),
|
||||
"duration": selected_sample.get("duration"),
|
||||
}
|
||||
|
||||
if validation_result:
|
||||
report["validation"] = {
|
||||
"valid": validation_result.get("valid"),
|
||||
"score": validation_result.get("score"),
|
||||
"warnings": validation_result.get("warnings", []),
|
||||
}
|
||||
|
||||
return report
|
||||
|
||||
|
||||
def get_role_info(role: str) -> Dict[str, Any]:
|
||||
"""
|
||||
Gets comprehensive information about a role.
|
||||
|
||||
Args:
|
||||
role: The role name
|
||||
|
||||
Returns:
|
||||
Dict with role information including valid samples count, thresholds, etc.
|
||||
"""
|
||||
if role not in VALID_ROLES:
|
||||
return {"error": f"Unknown role: {role}"}
|
||||
|
||||
config = VALID_ROLES[role]
|
||||
aliases = ROLE_ALIASES.get(role, [])
|
||||
|
||||
return {
|
||||
"role": role,
|
||||
"config": config,
|
||||
"aliases": aliases,
|
||||
"threshold": ROLE_SCORE_THRESHOLDS.get(role, 0.30),
|
||||
"bus": config.get("bus", "music"),
|
||||
"is_loop": config.get("is_loop", False),
|
||||
}
|
||||
Reference in New Issue
Block a user