diff --git a/src/backend/als/als_generator.py b/src/backend/als/als_generator.py index 592eed3..b9a63bc 100644 --- a/src/backend/als/als_generator.py +++ b/src/backend/als/als_generator.py @@ -10,12 +10,16 @@ import shutil import uuid from collections import defaultdict from datetime import datetime -from pathlib import Path +from pathlib import Path, PurePosixPath from typing import Dict, List, Any, Optional from xml.etree.ElementTree import Element, SubElement, tostring, ElementTree import logging from decouple import config +try: + from .pattern_library import get_genre_pattern +except ImportError: + from pattern_library import get_genre_pattern logger = logging.getLogger(__name__) @@ -93,8 +97,17 @@ class ALSGenerator: """Copy referenced samples into the Ableton project and make paths relative.""" for track in config.get('tracks', []): prepared_samples: List[str] = [] - for sample_entry in track.get('samples', []) or []: + sample_entries = track.get('samples') or [] + if not sample_entries: + fallback = self._select_default_sample(track) + if fallback is not None: + sample_entries = [str(fallback)] + + for sample_entry in sample_entries: resolved = self._resolve_sample_path(sample_entry) + if resolved is None: + fallback = self._select_default_sample(track) + resolved = fallback if not resolved: logger.warning("Sample %s could not be resolved", sample_entry) continue @@ -125,6 +138,34 @@ class ALSGenerator: return None + def _select_default_sample(self, track: Dict[str, Any]) -> Optional[Path]: + if not self.sample_root or not self.sample_root.exists(): + return None + + name = (track.get('name') or track.get('type') or '').lower() + folder = None + if 'kick' in name: + folder = 'Kicks' + elif 'snare' in name or 'clap' in name: + folder = 'Snares' + elif 'hat' in name: + folder = 'Hi Hats' + elif 'tom' in name: + folder = 'Toms' + elif 'perc' in name: + folder = 'Percussions' + elif 'fx' in name or 'sfx' in name: + folder = 'FX & Rolls' + + search_root = self.sample_root / folder if folder else self.sample_root + candidates = list(search_root.glob('**/*.wav')) + if not candidates: + return None + return random.choice(candidates).resolve() + + def _infer_genre(self, config: Dict[str, Any]) -> Optional[str]: + return config.get('genre') or config.get('style') or config.get('name') + def _copy_sample(self, source: Path, samples_dir: Path) -> Path: samples_dir.mkdir(parents=True, exist_ok=True) destination = samples_dir / source.name @@ -204,14 +245,15 @@ class ALSGenerator: return requested_tracks = config.get('tracks', []) + genre = self._infer_genre(config) available_tracks = [ track for track in tracks_element - if track.tag in ('GroupTrack', 'MidiTrack', 'AudioTrack') + if track.tag == 'MidiTrack' ] for idx, track_element in enumerate(available_tracks): if idx < len(requested_tracks): - self._configure_template_track(track_element, requested_tracks[idx], idx) + self._configure_template_track(track_element, requested_tracks[idx], idx, genre) else: self._reset_template_track(track_element, idx) @@ -227,7 +269,7 @@ class ALSGenerator: if scene_names is not None: self._populate_scene_names(scene_names, config) - def _configure_template_track(self, track_element: Element, track_config: Dict[str, Any], index: int) -> None: + def _configure_template_track(self, track_element: Element, track_config: Dict[str, Any], index: int, genre: Optional[str]) -> None: track_name = track_config.get('name') or f"Track {index + 1}" self._set_track_name(track_element, track_name) @@ -237,8 +279,11 @@ class ALSGenerator: self._clear_arranger_automation(track_element) + if track_config.get('samples'): + self._assign_samples_to_simpler(track_element, track_config['samples']) + if track_element.tag == 'MidiTrack': - midi_config = track_config.get('midi') or {} + midi_config = track_config.get('midi') or self._build_sample_midi_config(track_config, genre) self._populate_template_midi_clip(track_element, midi_config, track_name, color) def _reset_template_track(self, track_element: Element, index: int) -> None: @@ -279,10 +324,86 @@ class ALSGenerator: events = arranger.find('Events') if events is not None: + for midi_clip in list(events.findall('MidiClip')): + events.remove(midi_clip) + else: + events = SubElement(arranger, 'Events') events.clear() - for midi_clip in list(arranger.findall('MidiClip')): - arranger.remove(midi_clip) + def _assign_samples_to_simpler(self, track_element: Element, samples: List[str]) -> None: + if not samples: + return + + simpler = track_element.find('.//OriginalSimpler') + if simpler is None: + return + + file_ref = simpler.find('.//SampleRef/FileRef') + if file_ref is None: + sample_ref = simpler.find('.//SampleRef') + if sample_ref is None: + return + file_ref = SubElement(sample_ref, 'FileRef') + + while list(file_ref): + file_ref.remove(list(file_ref)[0]) + + SubElement(file_ref, 'HasRelativePath', Value='true') + SubElement(file_ref, 'RelativePathType', Value='1') + relative_path = SubElement(file_ref, 'RelativePath') + + sample_path = PurePosixPath(samples[0]) + dirs = list(sample_path.parent.parts) + for idx, part in enumerate(dirs, start=1): + SubElement(relative_path, 'RelativePathElement', Id=str(idx), Dir=part) + + file_name = sample_path.name or os.path.basename(samples[0]) + SubElement(file_ref, 'Name', Value=file_name) + SubElement(file_ref, 'RefersToFolder', Value='false') + SubElement(file_ref, 'LivePackName', Value='') + SubElement(file_ref, 'LivePackId', Value='0') + + def _build_sample_midi_config(self, track_config: Dict[str, Any], genre: Optional[str]) -> Dict[str, Any]: + track_name = track_config.get('name') or '' + name = track_name.lower() + pattern = get_genre_pattern(genre, track_name) + if pattern: + return { + 'notes': pattern, + 'length': 4, + 'velocity': 110, + 'duration': 0.4, + 'spacing': 0.5, + 'offset': pattern[0]['time'] if pattern else 0.0 + } + + if 'kick' in name: + midi_note = 36 + hits = [0, 1, 2, 3] + elif any(keyword in name for keyword in ('snare', 'clap')): + midi_note = 38 + hits = [1, 3] + elif 'hat' in name or 'perc' in name: + midi_note = 42 + hits = [i * 0.5 for i in range(8)] + else: + midi_note = 48 + hits = [0, 0.5, 1.5, 2, 3] + + notes = [{ + 'note': midi_note, + 'time': t, + 'duration': 0.4, + 'velocity': 100 + } for t in hits] + + return { + 'notes': notes, + 'length': max(hits) + 1 if hits else 4, + 'velocity': 100, + 'duration': 0.4, + 'spacing': 0.5 + } def _populate_template_midi_clip( self, @@ -295,9 +416,13 @@ class ALSGenerator: if arranger is None: return + events_container = arranger.find('Events') + if events_container is None: + events_container = SubElement(arranger, 'Events') + clip_color = str(color) if color is not None else '36' midi_clip = self._create_template_midi_clip(track_name, midi_config, clip_color) - arranger.append(midi_clip) + events_container.append(midi_clip) def _create_template_midi_clip(self, track_name: str, midi_config: Dict[str, Any], clip_color: str) -> Element: clip = Element('MidiClip', { @@ -319,7 +444,7 @@ class ALSGenerator: clip_end = self._format_clip_value(clip_length) - SubElement(clip, 'CurrentStart', Value='0') + SubElement(clip, 'CurrentStart', Value=self._format_clip_value(midi_config.get('offset', 0))) SubElement(clip, 'CurrentEnd', Value=clip_end) loop = SubElement(clip, 'Loop') diff --git a/src/backend/als/pattern_library.py b/src/backend/als/pattern_library.py new file mode 100644 index 0000000..c9e34c2 --- /dev/null +++ b/src/backend/als/pattern_library.py @@ -0,0 +1,76 @@ +"""Pattern library for predefined groove templates.""" + +from typing import List, Dict, Optional + +SalsaPattern = List[Dict[str, float]] + +SALSA_PATTERNS: Dict[str, SalsaPattern] = { + 'kick': [ + {'note': 36, 'time': 0.0, 'duration': 0.25, 'velocity': 118}, + {'note': 36, 'time': 1.5, 'duration': 0.25, 'velocity': 112}, + {'note': 36, 'time': 2.5, 'duration': 0.25, 'velocity': 120}, + {'note': 36, 'time': 3.75, 'duration': 0.25, 'velocity': 115}, + ], + 'clap': [ + {'note': 39, 'time': 1.0, 'duration': 0.25, 'velocity': 105}, + {'note': 39, 'time': 2.5, 'duration': 0.25, 'velocity': 110}, + ], + 'snare': [ + {'note': 38, 'time': 0.75, 'duration': 0.25, 'velocity': 108}, + {'note': 38, 'time': 1.25, 'duration': 0.2, 'velocity': 116}, + {'note': 38, 'time': 2.25, 'duration': 0.25, 'velocity': 110}, + {'note': 38, 'time': 3.0, 'duration': 0.3, 'velocity': 118}, + ], + 'hihat': [ + {'note': 42, 'time': beat, 'duration': 0.15, 'velocity': 96 + (idx % 2) * 10} + for idx, beat in enumerate([i * 0.5 for i in range(8)]) + ], + 'perc': [ + {'note': 64, 'time': 0.5, 'duration': 0.25, 'velocity': 108}, + {'note': 65, 'time': 1.0, 'duration': 0.25, 'velocity': 120}, + {'note': 64, 'time': 1.75, 'duration': 0.25, 'velocity': 110}, + {'note': 65, 'time': 2.25, 'duration': 0.25, 'velocity': 118}, + {'note': 64, 'time': 3.0, 'duration': 0.25, 'velocity': 112}, + {'note': 65, 'time': 3.5, 'duration': 0.25, 'velocity': 124}, + ], + 'fx': [ + {'note': 81, 'time': 0.0, 'duration': 0.5, 'velocity': 90}, + {'note': 83, 'time': 3.75, 'duration': 0.5, 'velocity': 105}, + ], + 'bass': [ + {'note': 35, 'time': 0.0, 'duration': 0.4, 'velocity': 110}, + {'note': 42, 'time': 0.5, 'duration': 0.4, 'velocity': 100}, + {'note': 47, 'time': 1.25, 'duration': 0.4, 'velocity': 112}, + {'note': 42, 'time': 1.75, 'duration': 0.4, 'velocity': 108}, + {'note': 40, 'time': 2.5, 'duration': 0.4, 'velocity': 115}, + {'note': 47, 'time': 3.0, 'duration': 0.4, 'velocity': 118}, + {'note': 42, 'time': 3.5, 'duration': 0.4, 'velocity': 100}, + {'note': 35, 'time': 3.75, 'duration': 0.4, 'velocity': 120}, + ], +} + + +def get_genre_pattern(genre: Optional[str], track_name: str) -> Optional[SalsaPattern]: + if not genre: + return None + + genre = genre.lower() + if 'salsa' not in genre: + return None + + key = 'perc' + name = track_name.lower() + if 'kick' in name: + key = 'kick' + elif 'snare' in name or 'rim' in name: + key = 'snare' + elif 'hat' in name or 'ride' in name or 'bell' in name: + key = 'hihat' + elif 'clap' in name: + key = 'clap' + elif 'fx' in name or 'timb' in name: + key = 'fx' + elif 'bass' in name: + key = 'bass' + + return SALSA_PATTERNS.get(key, SALSA_PATTERNS['perc'])