""" reference_stem_builder.py - Rebuild an Ableton arrangement directly from a reference track. """ from __future__ import annotations import json import logging import socket from pathlib import Path from typing import Any, Dict, List, Tuple import soundfile as sf import torch from demucs.apply import apply_model from demucs.pretrained import get_model try: import librosa except ImportError: # pragma: no cover librosa = None try: from reference_listener import ReferenceAudioListener except ImportError: # pragma: no cover from .reference_listener import ReferenceAudioListener logger = logging.getLogger("ReferenceStemBuilder") HOST = "127.0.0.1" PORT = 9877 MESSAGE_TERMINATOR = b"\n" SCRIPT_DIR = Path(__file__).resolve().parent PACKAGE_DIR = SCRIPT_DIR.parent PROJECT_SAMPLES_DIR = PACKAGE_DIR.parent / "librerias" / "reggaeton" SAMPLES_DIR = str(PROJECT_SAMPLES_DIR) TRACK_LAYOUT = ( ("REFERENCE FULL", 59, 0.72, True), ("REF DRUMS", 10, 0.84, False), ("REF BASS", 30, 0.82, False), ("REF OTHER", 50, 0.68, False), ("REF VOCALS", 40, 0.70, False), ) SECTION_BLUEPRINTS = { "club": [ ("INTRO DJ", 16), ("GROOVE A", 16), ("VOCAL BUILD", 8), ("DROP A", 16), ("BREAKDOWN", 8), ("BUILD B", 8), ("DROP B", 16), ("PEAK", 8), ("OUTRO DJ", 16), ], "standard": [ ("INTRO", 8), ("BUILD", 8), ("DROP A", 16), ("BREAK", 8), ("DROP B", 16), ("OUTRO", 8), ], } class AbletonSocketClient: def __init__(self, host: str = HOST, port: int = PORT): self.host = host self.port = port def send(self, command_type: str, params: Dict[str, Any] | None = None, timeout: float = 30.0) -> Dict[str, Any]: payload = json.dumps({"type": command_type, "params": params or {}}, separators=(",", ":")).encode("utf-8") + MESSAGE_TERMINATOR with socket.create_connection((self.host, self.port), timeout=timeout) as sock: sock.sendall(payload) data = b"" while not data.endswith(MESSAGE_TERMINATOR): chunk = sock.recv(65536) if not chunk: break data += chunk if not data: raise RuntimeError(f"Sin respuesta para {command_type}") return json.loads(data.decode("utf-8", errors="replace").strip()) def _resolve_reference_profile(reference_path: Path) -> Dict[str, Any]: listener = ReferenceAudioListener(SAMPLES_DIR) analysis = listener.analyze_reference(str(reference_path)) structure = "club" if analysis.get("duration", 0.0) >= 180 else "standard" return { "tempo": float(analysis.get("tempo", 128.0) or 128.0), "key": str(analysis.get("key", "") or ""), "duration": float(analysis.get("duration", 0.0) or 0.0), "structure": structure, "listener_device": analysis.get("device", "cpu"), } def ensure_reference_wav(reference_path: Path) -> Path: if reference_path.suffix.lower() == ".wav": return reference_path if librosa is None: raise RuntimeError("librosa no está disponible para convertir la referencia a WAV") wav_path = reference_path.with_suffix(".wav") if wav_path.exists() and wav_path.stat().st_size > 0: return wav_path y, sr = librosa.load(str(reference_path), sr=44100, mono=False) if y.ndim == 1: y = y.reshape(1, -1) sf.write(str(wav_path), y.T, sr, subtype="PCM_16") return wav_path def separate_stems(reference_wav: Path, output_dir: Path) -> Dict[str, Path]: output_dir.mkdir(parents=True, exist_ok=True) stem_root = output_dir / reference_wav.stem expected = { "reference": reference_wav, "drums": stem_root / "drums.wav", "bass": stem_root / "bass.wav", "other": stem_root / "other.wav", "vocals": stem_root / "vocals.wav", } if all(path.exists() and path.stat().st_size > 0 for path in expected.values()): return expected audio, sr = sf.read(str(reference_wav), always_2d=True) if sr != 44100: raise RuntimeError(f"Sample rate inesperado en referencia WAV: {sr}") model = get_model("htdemucs") model.cpu() model.eval() waveform = torch.tensor(audio.T, dtype=torch.float32) separated = apply_model(model, waveform[None], device="cpu", progress=False)[0] stem_root.mkdir(parents=True, exist_ok=True) for stem_name, tensor in zip(model.sources, separated): stem_path = stem_root / f"{stem_name}.wav" sf.write(str(stem_path), tensor.detach().cpu().numpy().T, sr, subtype="PCM_16") return expected def _sections_for_structure(structure: str) -> List[Tuple[str, int]]: return list(SECTION_BLUEPRINTS.get(structure.lower(), SECTION_BLUEPRINTS["standard"])) def _create_track(client: AbletonSocketClient, name: str, color: int, volume: float) -> int: response = client.send("create_track", {"type": "audio", "index": -1}) if response.get("status") != "success": raise RuntimeError(response.get("message", f"No se pudo crear {name}")) track_index = int(response.get("result", {}).get("index")) client.send("set_track_name", {"index": track_index, "name": name}) client.send("set_track_color", {"index": track_index, "color": color}) client.send("set_track_volume", {"index": track_index, "volume": volume}) return track_index def _import_full_length_audio(client: AbletonSocketClient, track_index: int, file_path: Path, name: str) -> None: response = client.send("create_arrangement_audio_pattern", { "track_index": track_index, "file_path": str(file_path), "positions": [0.0], "name": name, }, timeout=120.0) if response.get("status") != "success": raise RuntimeError(response.get("message", f"No se pudo importar {name}")) def _prepare_navigation_scenes(client: AbletonSocketClient, structure: str) -> None: sections = _sections_for_structure(structure) session_info = client.send("get_session_info") if session_info.get("status") != "success": return scene_count = int(session_info.get("result", {}).get("num_scenes", 0) or 0) target_count = len(sections) while scene_count < target_count: create_response = client.send("create_scene", {"index": -1}) if create_response.get("status") != "success": break scene_count += 1 while scene_count > target_count and scene_count > 1: delete_response = client.send("delete_scene", {"index": scene_count - 1}) if delete_response.get("status") != "success": break scene_count -= 1 for scene_index, (section_name, _) in enumerate(sections): client.send("set_scene_name", {"index": scene_index, "name": section_name}) def rebuild_project_from_reference(reference_path: Path) -> Dict[str, Any]: reference_path = reference_path.resolve() if not reference_path.exists(): raise FileNotFoundError(reference_path) profile = _resolve_reference_profile(reference_path) reference_wav = ensure_reference_wav(reference_path) stems = separate_stems(reference_wav, reference_path.parent / "stems") client = AbletonSocketClient() clear_response = client.send("clear_project", {"keep_tracks": 0}, timeout=120.0) if clear_response.get("status") != "success": raise RuntimeError(clear_response.get("message", "No se pudo limpiar el proyecto")) client.send("stop", {}) client.send("set_tempo", {"tempo": round(profile["tempo"], 3)}) client.send("show_arrangement_view", {}) client.send("jump_to", {"time": 0}) created = [] for (track_name, color, volume, muted), stem_key in zip(TRACK_LAYOUT, ("reference", "drums", "bass", "other", "vocals")): track_index = _create_track(client, track_name, color, volume) _import_full_length_audio(client, track_index, stems[stem_key], track_name) if muted: client.send("set_track_mute", {"index": track_index, "mute": True}) created.append({ "track_index": track_index, "name": track_name, "file_path": str(stems[stem_key]), }) _prepare_navigation_scenes(client, profile["structure"]) client.send("loop_selection", {"start": 0, "length": max(32.0, round(profile["duration"] * profile["tempo"] / 60.0, 3)), "enable": False}) client.send("jump_to", {"time": 0}) client.send("show_arrangement_view", {}) session_info = client.send("get_session_info") return { "reference": str(reference_path), "tempo": profile["tempo"], "key": profile["key"], "structure": profile["structure"], "listener_device": profile["listener_device"], "stems": created, "session_info": session_info.get("result", {}), } def main() -> int: import argparse parser = argparse.ArgumentParser(description="Rebuild an Ableton project directly from a reference track.") parser.add_argument("reference_path", help="Absolute or relative path to the reference audio file") args = parser.parse_args() result = rebuild_project_from_reference(Path(args.reference_path)) print(json.dumps(result, indent=2, ensure_ascii=False)) return 0 if __name__ == "__main__": raise SystemExit(main())