""" retrieval_benchmark.py - Offline benchmark harness for retrieval quality inspection. Analyzes reference tracks and outputs top-N candidates per role to help spot role contamination and evaluate retrieval quality. Usage: python retrieval_benchmark.py --reference "path/to/track.mp3" python retrieval_benchmark.py --reference "track1.mp3" "track2.mp3" --top-n 10 python retrieval_benchmark.py --reference "track.mp3" --output results.json --format json python retrieval_benchmark.py --reference "track.mp3" --output results.md --format markdown """ from __future__ import annotations import argparse import json import logging import sys import time from collections import defaultdict from pathlib import Path from typing import Any, Dict, List, Optional # Add parent directory to path for imports when running as script sys.path.insert(0, str(Path(__file__).parent)) from reference_listener import ReferenceAudioListener, ROLE_SEGMENT_SETTINGS logger = logging.getLogger(__name__) def _default_library_dir() -> Path: """Get the default library directory.""" return Path(__file__).resolve().parents[2] / "librerias" / "all_tracks" def run_benchmark( reference_paths: List[str], library_dir: Path, top_n: int = 10, roles: Optional[List[str]] = None, duration_limit: Optional[float] = None, ) -> Dict[str, Any]: """ Run retrieval benchmark on one or more reference tracks. Args: reference_paths: List of paths to reference audio files library_dir: Path to the sample library top_n: Number of top candidates to show per role roles: Optional list of specific roles to analyze duration_limit: Optional duration limit for analysis Returns: Dict containing benchmark results for each reference """ listener = ReferenceAudioListener(str(library_dir)) all_roles = list(ROLE_SEGMENT_SETTINGS.keys()) target_roles = [r for r in (roles or all_roles) if r in all_roles] results = { "benchmark_info": { "library_dir": str(library_dir), "top_n": top_n, "roles": target_roles, "timestamp": time.strftime("%Y-%m-%dT%H:%M:%S"), "device": listener.device_name, }, "references": [], } for ref_path in reference_paths: ref_path = Path(ref_path) if not ref_path.exists(): logger.warning("Reference file not found: %s", ref_path) continue logger.info("Analyzing reference: %s", ref_path.name) try: start_time = time.time() # Run match_assets to get candidates per role match_result = listener.match_assets(str(ref_path)) reference_info = match_result.get("reference", {}) matches = match_result.get("matches", {}) elapsed = time.time() - start_time ref_result = { "file_name": ref_path.name, "path": str(ref_path), "analysis_time_seconds": round(elapsed, 2), "reference_info": { "tempo": reference_info.get("tempo"), "key": reference_info.get("key"), "duration": reference_info.get("duration"), "rms_mean": reference_info.get("rms_mean"), "onset_mean": reference_info.get("onset_mean"), "spectral_centroid": reference_info.get("spectral_centroid"), }, "sections": [ { "kind": s.get("kind"), "start": s.get("start"), "end": s.get("end"), "bars": s.get("bars"), } for s in match_result.get("reference_sections", []) ], "role_candidates": {}, } # Process each role for role in target_roles: role_matches = matches.get(role, []) top_candidates = role_matches[:top_n] ref_result["role_candidates"][role] = { "total_available": len(role_matches), "top_candidates": [ { "rank": i + 1, "file_name": c.get("file_name"), "path": c.get("path"), "score": c.get("score"), "cosine": c.get("cosine"), "segment_score": c.get("segment_score"), "catalog_score": c.get("catalog_score"), 
"tempo": c.get("tempo"), "key": c.get("key"), "duration": c.get("duration"), } for i, c in enumerate(top_candidates) ], } results["references"].append(ref_result) logger.info("Completed analysis in %.2fs", elapsed) except Exception as e: logger.error("Failed to analyze %s: %s", ref_path, e, exc_info=True) results["references"].append({ "file_name": ref_path.name, "path": str(ref_path), "error": str(e), }) return results def analyze_role_contamination(results: Dict[str, Any]) -> Dict[str, Any]: """ Analyze results for potential role contamination issues. Returns a dict with contamination analysis: - files appearing in multiple roles - misnamed files (e.g., "bass" appearing in "kick" role) - score distribution anomalies """ contamination = { "cross_role_files": [], "potential_mismatches": [], "role_score_stats": {}, } # Track files appearing in multiple roles file_to_roles: Dict[str, List[Dict[str, Any]]] = defaultdict(list) for ref in results.get("references", []): ref_name = ref.get("file_name", "unknown") for role, role_data in ref.get("role_candidates", {}).items(): for candidate in role_data.get("top_candidates", []): file_name = candidate.get("file_name", "") if file_name: file_to_roles[file_name].append({ "reference": ref_name, "role": role, "rank": candidate.get("rank"), "score": candidate.get("score"), }) # Find files appearing in multiple roles for file_name, appearances in file_to_roles.items(): unique_roles = set(a["role"] for a in appearances) if len(unique_roles) > 1: contamination["cross_role_files"].append({ "file_name": file_name, "roles": list(unique_roles), "appearances": appearances, }) # Check for potential mismatches (filename suggests different role) role_keywords = { "kick": ["kick"], "snare": ["snare", "clap"], "hat": ["hat", "hihat", "hi-hat"], "bass_loop": ["bass", "sub", "808"], "perc_loop": ["perc", "percussion", "conga", "bongo"], "top_loop": ["top", "drum loop", "full drum"], "synth_loop": ["synth", "lead", "pad", "chord", "arp"], "vocal_loop": ["vocal", "vox", "acapella"], "crash_fx": ["crash", "cymbal", "impact"], "fill_fx": ["fill", "transition", "tom"], "snare_roll": ["roll", "snareroll"], "atmos_fx": ["atmos", "drone", "ambient", "texture"], "vocal_shot": ["shot", "vocal shot", "chop"], } for ref in results.get("references", []): for role, role_data in ref.get("role_candidates", {}).items(): for candidate in role_data.get("top_candidates", []): file_name = candidate.get("file_name", "").lower() if not file_name: continue # Check if file name suggests a different role expected_keywords = role_keywords.get(role, []) other_role_matches = [] for other_role, keywords in role_keywords.items(): if other_role == role: continue if any(kw in file_name for kw in keywords): other_role_matches.append(other_role) if other_role_matches and expected_keywords: # File name matches another role but not this one if not any(kw in file_name for kw in expected_keywords): contamination["potential_mismatches"].append({ "file_name": candidate.get("file_name"), "assigned_role": role, "rank": candidate.get("rank"), "score": candidate.get("score"), "suggested_roles": other_role_matches, }) # Calculate score distribution per role for ref in results.get("references", []): for role, role_data in ref.get("role_candidates", {}).items(): scores = [ c.get("score", 0) for c in role_data.get("top_candidates", []) if c.get("score") is not None ] if scores: contamination["role_score_stats"][role] = { "min": round(min(scores), 4), "max": round(max(scores), 4), "avg": round(sum(scores) / 


def format_output_json(results: Dict[str, Any]) -> str:
    """Format results as JSON string."""
    return json.dumps(results, indent=2, ensure_ascii=False)


def format_output_markdown(results: Dict[str, Any]) -> str:
    """Format results as markdown string."""
    lines = []

    # Header
    lines.append("# Retrieval Benchmark Report")
    lines.append("")
    lines.append(f"**Generated:** {results['benchmark_info']['timestamp']}")
    lines.append(f"**Library:** `{results['benchmark_info']['library_dir']}`")
    lines.append(f"**Top N:** {results['benchmark_info']['top_n']}")
    lines.append(f"**Device:** {results['benchmark_info']['device']}")
    lines.append("")

    # Process each reference
    for ref in results.get("references", []):
        lines.append(f"## Reference: {ref.get('file_name', 'unknown')}")
        lines.append("")

        # Error case
        if "error" in ref:
            lines.append(f"**Error:** {ref['error']}")
            lines.append("")
            continue

        # Reference info
        ref_info = ref.get("reference_info", {})
        lines.append("### Reference Analysis")
        lines.append("")
        lines.append("| Property | Value |")
        lines.append("|----------|-------|")
        lines.append(f"| Tempo | {ref_info.get('tempo', 'N/A')} BPM |")
        lines.append(f"| Key | {ref_info.get('key', 'N/A')} |")
        lines.append(f"| Duration | {ref_info.get('duration', 'N/A')}s |")
        lines.append(f"| RMS Mean | {ref_info.get('rms_mean', 'N/A')} |")
        lines.append(f"| Onset Mean | {ref_info.get('onset_mean', 'N/A')} |")
        lines.append(f"| Spectral Centroid | {ref_info.get('spectral_centroid', 'N/A')} Hz |")
        lines.append("")

        # Sections
        sections = ref.get("sections", [])
        if sections:
            lines.append("### Detected Sections")
            lines.append("")
            lines.append("| Type | Start | End | Bars |")
            lines.append("|------|-------|-----|------|")
            for s in sections:
                lines.append(f"| {s.get('kind', 'N/A')} | {s.get('start', 'N/A')}s | {s.get('end', 'N/A')}s | {s.get('bars', 'N/A')} |")
            lines.append("")

        # Role candidates
        lines.append("### Top Candidates per Role")
        lines.append("")
        for role, role_data in ref.get("role_candidates", {}).items():
            total = role_data.get("total_available", 0)
            lines.append(f"#### {role} ({total} available)")
            lines.append("")

            candidates = role_data.get("top_candidates", [])
            if not candidates:
                lines.append("*No candidates found*")
                lines.append("")
                continue

            lines.append("| Rank | File | Score | Cosine | Seg | Catalog | Tempo | Key | Duration |")
            lines.append("|------|------|-------|--------|-----|---------|-------|-----|----------|")
            for c in candidates:
                # Duration may be missing or None; format it defensively so a
                # non-numeric value does not break the report.
                duration = c.get("duration")
                duration_str = f"{duration:.2f}s" if duration is not None else "N/A"
                lines.append(
                    f"| {c.get('rank', 'N/A')} | "
                    f"`{c.get('file_name', 'N/A')[:40]}` | "
                    f"{c.get('score', 0):.4f} | "
                    f"{c.get('cosine', 0):.4f} | "
                    f"{c.get('segment_score', 0):.4f} | "
                    f"{c.get('catalog_score', 0):.4f} | "
                    f"{c.get('tempo', 'N/A')} | "
                    f"{c.get('key', 'N/A')} | "
                    f"{duration_str} |"
                )
            lines.append("")

    # Contamination analysis
    if "contamination_analysis" in results:
        contam = results["contamination_analysis"]
        lines.append("## Role Contamination Analysis")
        lines.append("")

        # Cross-role files
        cross_role = contam.get("cross_role_files", [])
        if cross_role:
            lines.append("### Files Appearing in Multiple Roles")
            lines.append("")
            for item in cross_role:
                lines.append(f"- **{item['file_name']}**")
                lines.append(f"  - Roles: {', '.join(item['roles'])}")
                for app in item["appearances"]:
                    lines.append(f"  - {app['role']}: rank {app['rank']}, score {app['score']:.4f}")
            lines.append("")

        # Potential mismatches
        mismatches = contam.get("potential_mismatches", [])
        if mismatches:
            lines.append("### Potential Role Mismatches")
lines.append("") lines.append("Files whose names suggest a different role than assigned:") lines.append("") for item in mismatches: lines.append(f"- **{item['file_name']}**") lines.append(f" - Assigned: {item['assigned_role']} (rank {item['rank']}, score {item['score']:.4f})") lines.append(f" - Suggested: {', '.join(item['suggested_roles'])}") lines.append("") # Score stats score_stats = contam.get("role_score_stats", {}) if score_stats: lines.append("### Score Distribution per Role") lines.append("") lines.append("| Role | Min | Max | Avg | Count |") lines.append("|------|-----|-----|-----|-------|") for role, stats in sorted(score_stats.items()): lines.append( f"| {role} | {stats['min']:.4f} | {stats['max']:.4f} | " f"{stats['avg']:.4f} | {stats['count']} |" ) lines.append("") return "\n".join(lines) def main() -> int: parser = argparse.ArgumentParser( description="Offline benchmark harness for retrieval quality inspection.", formatter_class=argparse.RawDescriptionHelpFormatter, epilog=""" Examples: %(prog)s --reference "track.mp3" %(prog)s --reference "track1.mp3" "track2.mp3" --top-n 15 %(prog)s --reference "track.mp3" --output results.md --format markdown %(prog)s --reference "track.mp3" --roles kick snare hat --top-n 20 """, ) parser.add_argument( "--reference", "-r", nargs="+", required=True, help="One or more reference audio files to analyze", ) parser.add_argument( "--library-dir", default=str(_default_library_dir()), help="Audio library directory (default: ../librerias/all_tracks)", ) parser.add_argument( "--top-n", "-n", type=int, default=10, help="Number of top candidates to show per role (default: 10)", ) parser.add_argument( "--roles", nargs="*", default=None, help="Specific roles to analyze (default: all roles)", ) parser.add_argument( "--output", "-o", type=str, default=None, help="Output file path for results", ) parser.add_argument( "--format", "-f", choices=["json", "markdown", "md"], default=None, help="Output format (json or markdown). 
    parser.add_argument(
        "--analyze-contamination",
        action="store_true",
        help="Include role contamination analysis in output",
    )
    parser.add_argument(
        "--verbose", "-v",
        action="store_true",
        help="Enable verbose logging",
    )
    parser.add_argument(
        "--duration-limit",
        type=float,
        default=None,
        help="Optional duration limit for audio analysis",
    )

    args = parser.parse_args()

    # Configure logging
    if args.verbose:
        logging.basicConfig(level=logging.DEBUG, format="%(levelname)s: %(message)s")
    else:
        logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s")

    # Validate reference files
    reference_paths = []
    for ref in args.reference:
        ref_path = Path(ref)
        if ref_path.exists():
            reference_paths.append(str(ref_path))
        else:
            logger.warning("Reference file not found: %s", ref)

    if not reference_paths:
        logger.error("No valid reference files provided")
        return 1

    # Run benchmark
    logger.info("Running retrieval benchmark on %d reference(s)", len(reference_paths))
    results = run_benchmark(
        reference_paths=reference_paths,
        library_dir=Path(args.library_dir),
        top_n=args.top_n,
        roles=args.roles,
        duration_limit=args.duration_limit,
    )

    # Add contamination analysis if requested
    if args.analyze_contamination:
        logger.info("Analyzing role contamination...")
        results["contamination_analysis"] = analyze_role_contamination(results)

    # Determine output format
    output_format = args.format
    if output_format is None and args.output:
        output_format = "markdown" if args.output.endswith(".md") else "json"
    output_format = output_format or "text"

    # Format output
    if output_format in ("markdown", "md"):
        output_text = format_output_markdown(results)
    elif output_format == "json":
        output_text = format_output_json(results)
    else:
        # Plain text summary (the markdown report is reused for stdout)
        output_text = format_output_markdown(results)

    # Write to file or stdout
    if args.output:
        output_path = Path(args.output)
        output_path.parent.mkdir(parents=True, exist_ok=True)
        output_path.write_text(output_text, encoding="utf-8")
        logger.info("Results written to: %s", output_path)
    else:
        print(output_text)

    return 0


if __name__ == "__main__":
    raise SystemExit(main())