"""Unit tests for src/composer/melody_engine.py — hook-based melody generation.""" import pytest from src.composer.melody_engine import ( build_motif, apply_variation, build_call_response, _resolve_chord_tones, _resolve_tension_notes, _resolve_tonic, _get_pentatonic, _get_diatonic, ) # --------------------------------------------------------------------------- # Phase 3.1 — Determinism # --------------------------------------------------------------------------- class TestMotifDeterministic: """Same seed → identical output.""" def test_motif_deterministic_hook(self): a = build_motif("A", True, "hook", 4, 42) b = build_motif("A", True, "hook", 4, 42) assert a == b, "Same seed must produce identical motfs" assert len(a) > 0, "Hook should produce notes" def test_motif_deterministic_stabs(self): a = build_motif("A", True, "stabs", 2, 1) b = build_motif("A", True, "stabs", 2, 1) assert a == b def test_motif_deterministic_smooth(self): a = build_motif("A", True, "smooth", 4, 7) b = build_motif("A", True, "smooth", 4, 7) assert a == b def test_motif_deterministic_different_style(self): hook = build_motif("A", True, "hook", 4, 42) stabs = build_motif("A", True, "stabs", 4, 42) assert hook != stabs, "Different styles should produce different output" # --------------------------------------------------------------------------- # Phase 3.2 — Different seeds # --------------------------------------------------------------------------- class TestMotifDifferentSeeds: """Different seeds → different output.""" def test_motif_different_seeds_hook(self): a = build_motif("A", True, "hook", 4, 42) b = build_motif("A", True, "hook", 4, 99) assert a != b, "Different seeds must produce different output" def test_motif_different_seeds_stabs(self): a = build_motif("A", True, "stabs", 4, 1) b = build_motif("A", True, "stabs", 4, 2) assert a != b def test_motif_different_seeds_smooth(self): a = build_motif("A", True, "smooth", 4, 1) b = build_motif("A", True, "smooth", 4, 2) assert a != b # 
# ---------------------------------------------------------------------------
# Phase 3.3 — Invalid style
# ---------------------------------------------------------------------------


class TestInvalidStyle:
    """Invalid style raises ValueError."""

    def test_invalid_style_raises_value_error(self):
        with pytest.raises(ValueError, match="Invalid style"):
            build_motif("A", True, "invalid", 4, 42)

    def test_value_error_mentions_valid_styles(self):
        with pytest.raises(ValueError) as exc:
            build_motif("A", True, "xyz", 4, 42)
        msg = str(exc.value)
        assert "hook" in msg
        assert "stabs" in msg
        assert "smooth" in msg


# ---------------------------------------------------------------------------
# Phase 3.4 — Chord tones on strong beats (hook)
# ---------------------------------------------------------------------------


class TestHookChordTones:
    """Hook style: ≥70% of quarter-position notes are chord tones."""

    @staticmethod
    def _quarter_position_notes(notes, tol=0.001):
        """Return notes whose start time is on a whole beat (quarter position).

        ``start % 1.0`` is compared against both 0 and 1 so that a start a
        hair *below* an integer beat (float accumulation, e.g. 3.9999999) is
        accepted just like one a hair above it — the same two-sided
        tolerance the stabs grid test uses.
        """
        on_beat = []
        for n in notes:
            frac = n.start % 1.0
            if min(frac, 1.0 - frac) < tol:
                on_beat.append(n)
        return on_beat

    @staticmethod
    def _is_chord_tone(pitch, key_root, key_minor, bar, bar_offset=0):
        """Check if pitch belongs to the active chord at the given bar.

        Comparison is by pitch class (mod 12), so the octave returned by
        ``_resolve_chord_tones`` does not matter. ``bar_offset`` is not
        used; it is kept only so existing call sites remain valid.
        """
        chord_tones = _resolve_chord_tones(key_root, key_minor, bar)
        return any((pitch - ct) % 12 == 0 for ct in chord_tones)

    def test_hook_chord_tones_on_strong_beats(self):
        """≥70% of notes on quarter positions are chord tones."""
        notes = build_motif("A", True, "hook", 8, 42)
        quarter_notes = self._quarter_position_notes(notes)
        assert len(quarter_notes) > 0, "Hook must have notes on quarter positions"

        chord_count = 0
        for note in quarter_notes:
            # Round first so a start just below a bar line (accepted by the
            # tolerant filter above) is bucketed into the correct 4-beat bar.
            bar = int(round(note.start) // 4)
            if self._is_chord_tone(note.pitch, "A", True, bar):
                chord_count += 1

        ratio = chord_count / len(quarter_notes)
        assert ratio >= 0.70, (
            f"Chord tone ratio on strong beats: {ratio:.1%}, need ≥ 70%\n"
            f"Pitches: {[n.pitch for n in quarter_notes]}"
        )

    def test_hook_produces_notes(self):
        """Hook should produce a reasonable number of notes."""
        notes = build_motif("A", True, "hook", 4, 42)
        assert 16 <= len(notes) <= 24, (
            f"Expected 16-24 notes for 4-bar hook, got {len(notes)}"
        )


# ---------------------------------------------------------------------------
# Phase 3.5 — Stabs grid alignment
# ---------------------------------------------------------------------------


class TestStabsGridAlignment:
    """Stabs: all notes on dembow positions [1.0, 2.5, 3.0, 3.5] per bar."""

    DEMBOW = {1.0, 2.5, 3.0, 3.5}

    def test_stabs_grid_alignment(self):
        notes = build_motif("A", True, "stabs", 4, 1)
        assert len(notes) > 0, "Stabs must produce notes"
        for note in notes:
            bar_start = int(note.start // 4.0) * 4.0
            pos_in_bar = note.start - bar_start
            # Allow tiny floating-point tolerance
            assert any(
                abs(pos_in_bar - dp) < 0.001 for dp in self.DEMBOW
            ), f"Note at {note.start} (pos_in_bar={pos_in_bar}) not on dembow grid"

    def test_stabs_duration_16th(self):
        """All stabs should be 16th notes (≤ 0.25 beats)."""
        notes = build_motif("A", True, "stabs", 4, 1)
        for note in notes:
            assert note.duration <= 0.25, (
                f"Stab duration {note.duration} > 16th note"
            )


# ---------------------------------------------------------------------------
# Phase 3.6 — Smooth stepwise motion
# ---------------------------------------------------------------------------


class TestSmoothStepwise:
    """Smooth style: consecutive notes differ by ≤ 2 semitones."""

    def test_smooth_stepwise_motion(self):
        notes = build_motif("A", True, "smooth", 4, 7)
        assert len(notes) >= 8, f"Expected at least 8 notes, got {len(notes)}"
        sorted_notes = sorted(notes, key=lambda n: n.start)
        for prev, curr in zip(sorted_notes, sorted_notes[1:]):
            diff = abs(curr.pitch - prev.pitch)
            assert diff <= 2, (
                f"Step from pitch {prev.pitch} to {curr.pitch} "
                f"at beat {curr.start}: diff={diff} > 2"
            )

    def test_smooth_eighth_note_density(self):
        """Smooth style should produce notes at roughly eighth-note spacing."""
        notes = build_motif("A", True, "smooth", 4, 7)
        sorted_notes = sorted(notes, key=lambda n: n.start)
        # Each note should be ~0.5 beats apart (eighth note). This checks the
        # *average* inter-onset gap, not every individual gap.
        gaps = [b.start - a.start for a, b in zip(sorted_notes, sorted_notes[1:])]
        assert gaps, (
            f"Need at least 2 notes to measure spacing, got {len(sorted_notes)}"
        )
        avg_gap = sum(gaps) / len(gaps)
        assert 0.4 < avg_gap < 0.6, (
            f"Expected eighth-note spacing (~0.5), got avg gap {avg_gap:.3f}"
        )


# ---------------------------------------------------------------------------
# Phase 3.7 — Variation preserves structure
# ---------------------------------------------------------------------------


class TestVariation:
    """apply_variation() preserves note count, durations, and IOIs.

    The pairwise checks below use ``zip``, which silently stops at the
    shorter input; each one therefore asserts equal lengths first so a
    truncated (or empty) variant cannot pass vacuously.
    """

    def test_variation_preserves_note_count(self):
        motif = build_motif("A", True, "hook", 4, 42)
        variant = apply_variation(motif, shift_beats=0.25)
        assert len(variant) == len(motif)

    def test_variation_preserves_durations(self):
        motif = build_motif("A", True, "hook", 4, 42)
        variant = apply_variation(motif, shift_beats=0.25, transpose_semitones=3)
        assert len(variant) == len(motif)
        for orig, var in zip(motif, variant):
            assert var.duration == orig.duration, (
                f"Duration mismatch: {var.duration} != {orig.duration}"
            )

    def test_variation_preserves_iois(self):
        """Inter-onset intervals are preserved after shift."""
        motif = sorted(build_motif("A", True, "hook", 4, 42), key=lambda n: n.start)
        variant = sorted(
            apply_variation(motif, shift_beats=0.25),
            key=lambda n: n.start,
        )
        assert len(variant) == len(motif)
        for i in range(len(motif) - 1):
            orig_ioi = motif[i + 1].start - motif[i].start
            var_ioi = variant[i + 1].start - variant[i].start
            assert abs(orig_ioi - var_ioi) < 0.001, (
                f"IOI mismatch at index {i}: {var_ioi:.4f} != {orig_ioi:.4f}"
            )

    def test_variation_shifts_start_times(self):
        motif = build_motif("A", True, "hook", 4, 42)
        variant = apply_variation(motif, shift_beats=0.25)
        assert len(variant) == len(motif)
        for orig, var in zip(
            sorted(motif, key=lambda n: n.start),
            sorted(variant, key=lambda n: n.start),
        ):
            assert abs(var.start - orig.start - 0.25) < 0.001

    def test_variation_transposes_pitches(self):
        motif = build_motif("A", True, "hook", 4, 42)
        variant = apply_variation(motif, transpose_semitones=3)
        assert len(variant) == len(motif)
        for orig, var in zip(motif, variant):
            assert var.pitch == orig.pitch + 3

    def test_variation_empty_motif(self):
        result = apply_variation([], shift_beats=1.0)
        assert result == []

    def test_variation_defaults(self):
        motif = build_motif("A", True, "hook", 4, 42)
        variant = apply_variation(motif)
        assert len(variant) == len(motif)


# ---------------------------------------------------------------------------
# Phase 3.8 — Call ends on tension, response on tonic
# ---------------------------------------------------------------------------


class TestCallResponseResolution:
    """build_call_response(): call → V/VII, response → tonic."""

    def test_call_ends_on_tension_response_ends_on_tonic(self):
        """Call (first half) last note = V or VII; response last note = tonic."""
        # Am: tonic = A(69), V = E(76), VII = G(79)
        bars = 8
        motif = build_motif("A", True, "hook", 4, 42)
        result = build_call_response(
            motif, bars=bars, key_root="A", key_minor=True, seed=42
        )
        assert len(result) > 0

        sorted_notes = sorted(result, key=lambda n: n.start)

        # The call is the first half of the section (4 beats per bar).
        call_cutoff = (bars // 2) * 4.0
        call_notes = [n for n in sorted_notes if n.start < call_cutoff]
        assert len(call_notes) > 0, "No notes in call half"

        last_call_pitch = call_notes[-1].pitch % 12
        # V of Am is E (pitch%12=4), VII is G (pitch%12=7)
        assert last_call_pitch in (4, 7), (
            f"Last call note pitch class {last_call_pitch} must be V(4=E) or VII(7=G)"
        )

        # Last note overall (response) must be tonic A (pitch%12=9)
        last_note = sorted_notes[-1].pitch % 12
        assert last_note == 9, (
            f"Last note pitch class {last_note} must be tonic A(9)"
        )


# ---------------------------------------------------------------------------
# Phase 3.9 — Call-response fills bars with motif repetition
# ---------------------------------------------------------------------------


class TestCallResponseFillsBars:
    """build_call_response() fills section with motif repetition."""

    def test_call_response_fills_bars(self):
        """A 2-bar motif repeated to fill 8 bars."""
        motif = build_motif("A", True, "hook", 2, 42)
        result = build_call_response(
            motif, bars=8, key_root="A", key_minor=True, seed=42
        )
        assert result, "build_call_response returned no notes for a non-empty motif"

        # Total span should be ~8 bars (32 beats)
        max_end = max(n.start + n.duration for n in result)
        min_start = min(n.start for n in result)
        span = max_end - min_start
        assert span >= 28, f"Notes should span ~32 beats (8 bars), got {span}"

        # Motif content should repeat at least 2 times within 8 bars
        assert len(result) >= len(motif) * 2, (
            f"Motif repeats: expected ≥{len(motif)*2} notes, got {len(result)}"
        )

    def test_call_response_empty_motif(self):
        result = build_call_response([], bars=8, key_root="A", key_minor=True)
        assert result == []

    def test_call_response_length_matches_bars(self):
        """Result should not exceed `bars` worth of material."""
        motif = build_motif("A", True, "hook", 4, 42)
        for test_bars in (2, 4, 8):
            result = build_call_response(
                motif, bars=test_bars, key_root="A", key_minor=True, seed=42
            )
            max_end = max((n.start + n.duration for n in result), default=0)
            assert max_end <= test_bars * 4.0 + 0.001, (
                f"For {test_bars} bars, max_end={max_end} exceeds {test_bars * 4.0}"
            )


# ---------------------------------------------------------------------------
# Internal helpers tests
# ---------------------------------------------------------------------------


class TestResolveChordTones:
    """_resolve_chord_tones returns correct pitches."""

    def test_chord_tones_am_bar0(self):
        """Bar 0 of Am should return Am chord tones (A, C, E)."""
        tones = _resolve_chord_tones("A", True, 0, 4)
        # Check for pitch classes 9(A), 0(C), 4(E)
        pitch_classes = {p % 12 for p in tones}
        assert 9 in pitch_classes, "A missing"
        assert 0 in pitch_classes, "C missing"
        assert 4 in pitch_classes, "E missing"

    def test_chord_tones_am_bar2(self):
        """Bar 2 of Am should return F major (F, A, C) — the VI chord."""
        tones = _resolve_chord_tones("A", True, 2, 4)
        pitch_classes = {p % 12 for p in tones}
        assert 5 in pitch_classes, "F missing"  # F = 5
        assert 9 in pitch_classes, "A missing"  # A = 9
        assert 0 in pitch_classes, "C missing"  # C = 0

    def test_chord_tones_wraps(self):
        """Bar 8 wraps back to chord i."""
        tones0 = _resolve_chord_tones("A", True, 0, 4)
        tones8 = _resolve_chord_tones("A", True, 8, 4)
        p0 = {p % 12 for p in tones0}
        p8 = {p % 12 for p in tones8}
        assert p0 == p8, "Bar 0 and bar 8 should have same chord tones (wrapped)"


class TestResolveTensionNotes:
    """_resolve_tension_notes returns correct V and VII."""

    def test_tension_notes_am(self):
        v_pitch, vii_pitch = _resolve_tension_notes("A", True, 4)
        # V of A = E (MIDI 69 + 7 = 76)
        assert v_pitch == 76, f"V of Am should be E (76), got {v_pitch}"
        # VII of A minor = G (MIDI 69 + 10 = 79)
        assert vii_pitch == 79, f"VII of Am should be G (79), got {vii_pitch}"

    def test_tension_notes_cm(self):
        v_pitch, vii_pitch = _resolve_tension_notes("C", True, 4)
        # C4=60, V=G4=60+7=67
        assert v_pitch == 67, f"V of Cm should be G4 (67), got {v_pitch}"
        # VII of C minor = Bb4 = 60+10=70
        assert vii_pitch == 70, f"VII of Cm should be Bb4 (70), got {vii_pitch}"


class TestResolveTonic:
    """_resolve_tonic returns correct pitch."""

    def test_tonic_am(self):
        assert _resolve_tonic("A", 4) == 69  # A4

    def test_tonic_dm(self):
        assert _resolve_tonic("D", 4) == 62  # D4


class TestScaleHelpers:
    """_get_pentatonic and _get_diatonic."""

    def test_pentatonic_am(self):
        notes = _get_pentatonic("A", True, 4)
        pitch_classes = {n % 12 for n in notes}
        assert pitch_classes == {9, 0, 2, 4, 7}, (
            f"Am pentatonic: A C D E G, got {pitch_classes}"
        )

    def test_pentatonic_c_major(self):
        notes = _get_pentatonic("C", False, 4)
        pitch_classes = {n % 12 for n in notes}
        assert pitch_classes == {0, 2, 4, 7, 9}, (
            f"C major pentatonic: C D E G A, got {pitch_classes}"
        )

    def test_diatonic_am(self):
        notes = _get_diatonic("A", True, 4)
        pitch_classes = {n % 12 for n in notes}
        assert pitch_classes == {9, 11, 0, 2, 4, 5, 7}, (
            f"Am natural minor: A B C D E F G, got {pitch_classes}"
        )

    def test_diatonic_c_major(self):
        notes = _get_diatonic("C", False, 4)
        pitch_classes = {n % 12 for n in notes}
        assert pitch_classes == {0, 2, 4, 5, 7, 9, 11}, (
            f"C major: C D E F G A B, got {pitch_classes}"
        )


# ---------------------------------------------------------------------------
# Cross-style tests
# ---------------------------------------------------------------------------


class TestCrossStyle:
    """Tests covering all three styles."""

    def test_all_styles_return_midi_notes(self):
        for style in ("hook", "stabs", "smooth"):
            notes = build_motif("A", True, style, 4, 0)
            assert isinstance(notes, list)
            assert len(notes) > 0, f"Style '{style}' returned empty list"
            assert all(hasattr(n, "pitch") for n in notes)
            assert all(hasattr(n, "start") for n in notes)

    def test_different_key_produces_different_output(self):
        am = build_motif("A", True, "hook", 4, 42)
        dm = build_motif("D", True, "hook", 4, 42)
        assert am != dm, "Different keys should produce different motifs"

    def test_major_key_produces_notes(self):
        notes = build_motif("C", False, "hook", 4, 42)
        assert len(notes) > 0