"""Tests for DrumLoopAnalyzer.""" from __future__ import annotations import json import sys from pathlib import Path import numpy as np import pytest import soundfile as sf sys.path.insert(0, str(Path(__file__).resolve().parent.parent)) from src.composer.drum_analyzer import BeatGrid, DrumLoopAnalyzer, DrumLoopAnalysis, Transient @pytest.fixture def synthetic_kick(tmp_path): sr = 44100 dur = 2.0 t = np.linspace(0, dur, int(sr * dur), endpoint=False) y = np.zeros_like(t) for pos in [0.0, 0.5, 1.0, 1.5]: idx = int(pos * sr) freq_sweep = np.exp(-np.linspace(0, 8, 800)) * np.sin( 2 * np.pi * np.linspace(150, 40, 800) * np.linspace(0, 0.02, 800) ) end = min(idx + len(freq_sweep), len(y)) y[idx:end] += freq_sweep[: end - idx] path = tmp_path / "synth_kick.wav" sf.write(str(path), y, sr) return str(path) @pytest.fixture def synthetic_drumloop(tmp_path): sr = 44100 bpm = 120 dur = 4.0 t = np.linspace(0, dur, int(sr * dur), endpoint=False) y = np.zeros_like(t) beat = 60.0 / bpm for bar in range(2): off = bar * 4 * beat for p in [0.0, 2.0 * beat, 3.5 * beat]: idx = int((off + p) * sr) n = 600 kick = np.exp(-np.linspace(0, 10, n)) * np.sin( 2 * np.pi * np.linspace(160, 35, n) * np.linspace(0, 0.03, n) ) end = min(idx + n, len(y)) y[idx:end] += kick[: end - idx] * 0.8 for p in [1.0 * beat, 2.5 * beat]: idx = int((off + p) * sr) n = 1200 noise = np.random.RandomState(42).randn(n) * np.exp(-np.linspace(0, 6, n)) snare = np.sin(2 * np.pi * 200 * np.linspace(0, 0.05, n)) * np.exp(-np.linspace(0, 5, n)) end = min(idx + n, len(y)) y[idx:end] += (noise + snare)[: end - idx] * 0.5 for i in range(8): p = i * beat / 2 idx = int((off + p) * sr) n = 200 hh = np.random.RandomState(i).randn(n) * np.exp(-np.linspace(0, 20, n)) end = min(idx + n, len(y)) y[idx:end] += hh[: end - idx] * 0.15 y = y / (np.max(np.abs(y)) + 1e-10) * 0.9 path = tmp_path / "synth_drumloop.wav" sf.write(str(path), y, sr) return str(path) class TestDrumLoopAnalyzer: def test_analyze_returns_result(self, synthetic_drumloop): analyzer = DrumLoopAnalyzer(synthetic_drumloop) result = analyzer.analyze() assert isinstance(result, DrumLoopAnalysis) assert result.bpm > 0 assert result.duration > 0 assert len(result.beats) > 0 assert len(result.transients) > 0 assert isinstance(result.beat_grid, BeatGrid) assert len(result.beat_grid.quarter) > 0 def test_bpm_reasonable(self, synthetic_drumloop): result = DrumLoopAnalyzer(synthetic_drumloop).analyze() assert 60 <= result.bpm <= 200, f"BPM {result.bpm} out of range" def test_transient_classification(self, synthetic_drumloop): result = DrumLoopAnalyzer(synthetic_drumloop).analyze() types = {t.type for t in result.transients} valid = {"kick", "snare", "hihat", "other"} assert types <= valid, f"Unexpected types: {types - valid}" def test_beat_grid_populated(self, synthetic_drumloop): result = DrumLoopAnalyzer(synthetic_drumloop).analyze() grid = result.beat_grid assert len(grid.quarter) > 0 assert len(grid.eighth) >= len(grid.quarter) assert len(grid.sixteenth) >= len(grid.eighth) def test_key_detection(self, synthetic_drumloop): result = DrumLoopAnalyzer(synthetic_drumloop).analyze() assert result.key is not None assert result.key_confidence >= 0 def test_energy_profile(self, synthetic_drumloop): result = DrumLoopAnalyzer(synthetic_drumloop).analyze() assert len(result.energy_profile) > 0 assert all(e >= 0 for e in result.energy_profile) def test_to_dict_roundtrip(self, synthetic_drumloop): result = DrumLoopAnalyzer(synthetic_drumloop).analyze() d = result.to_dict() assert d["bpm"] == round(result.bpm, 2) assert d["duration"] == round(result.duration, 4) assert len(d["transients"]) == len(result.transients) assert "summary" in d json.dumps(d) def test_kick_free_zones(self, synthetic_drumloop): result = DrumLoopAnalyzer(synthetic_drumloop).analyze() zones = result.kick_free_zones(margin_beats=0.2) assert isinstance(zones, list) for start, end in zones: assert end > start def test_transient_positions(self, synthetic_drumloop): result = DrumLoopAnalyzer(synthetic_drumloop).analyze() all_pos = result.transient_positions() kick_pos = result.transient_positions("kick") assert len(all_pos) >= len(kick_pos) def test_real_drumloop_if_exists(self): path = Path( r"C:\Users\Administrator\Documents\fl_control\libreria\samples\drumloop" r"\drumloop_E3_120_boomy_accb48.wav" ) if not path.exists(): pytest.skip("Real drumloop not available") result = DrumLoopAnalyzer(str(path)).analyze() assert 100 <= result.bpm <= 140, f"BPM {result.bpm} unexpected" assert result.bar_count > 0 kicks = result.transients_of_type("kick") snares = result.transients_of_type("snare") assert len(kicks) > 0, "No kicks detected" assert len(snares) >= 0 class TestTransient: def test_transient_creation(self): t = Transient(time=0.5, type="kick", energy=0.8, spectral_centroid=120.0) assert t.time == 0.5 assert t.type == "kick"