diff --git a/src/backend/ai/ai_clients.py b/src/backend/ai/ai_clients.py index 496ac2e..a6ee0f0 100644 --- a/src/backend/ai/ai_clients.py +++ b/src/backend/ai/ai_clients.py @@ -1,25 +1,93 @@ -""" -AI Client Integrations for GLM4.6 and Minimax M2 -Handles communication with AI APIs for chat and music generation -""" +"""AI client integrations that route chat + project generation to GLM/Claude.""" -import os +import asyncio import json import logging +import shutil import aiohttp from typing import Dict, List, Optional, Any from decouple import config +from als.sample_library import SampleLibrary + + +def _clean_base_url(url: str) -> str: + """Ensure base URLs don't end with trailing slashes.""" + return url.rstrip('/') if url else url + logger = logging.getLogger(__name__) +class ClaudeCLIClient: + """Proxy that talks to the local `claude` CLI so we reuse user's setup.""" + + def __init__(self): + binary = config('CLAUDE_CLI_BIN', default='claude') + self.binary = shutil.which(binary) + self.model = config('CLAUDE_CLI_MODEL', default=config('GLM46_MODEL', default='glm-4.6')) + self.available = bool(self.binary) + if not self.available: + logger.warning("Claude CLI binary '%s' not found in PATH", binary) + + async def complete(self, prompt: str, system_prompt: Optional[str] = None) -> str: + if not self.available: + return "Error: Claude CLI not available" + + cmd = [ + self.binary, + '--print', + '--output-format', 'json', + '--model', self.model, + '--dangerously-skip-permissions', + ] + if system_prompt: + cmd += ['--system-prompt', system_prompt] + cmd.append(prompt) + + proc = await asyncio.create_subprocess_exec( + *cmd, + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.PIPE, + ) + stdout, stderr = await proc.communicate() + if proc.returncode != 0: + logger.error("Claude CLI failed (%s): %s", proc.returncode, stderr.decode().strip()) + return f"Error: Claude CLI exited with code {proc.returncode}" + + output = stdout.decode().strip() + json_line = None + for line in reversed(output.splitlines()): + if line.strip(): + json_line = line.strip() + break + + if not json_line: + logger.error("Claude CLI produced no JSON output") + return "Error: Empty response from Claude CLI" + + try: + payload = json.loads(json_line) + except json.JSONDecodeError as exc: + logger.error("Failed to parse Claude CLI output: %s", exc) + return output + + result = payload.get('result') or payload.get('output') + if not result: + logger.warning("Claude CLI JSON missing 'result': %s", payload) + return "Error: Invalid Claude CLI response" + return result + + class GLM46Client: """Client for GLM4.6 API - Optimized for structured generation""" def __init__(self): self.api_key = config('GLM46_API_KEY', default='') - self.base_url = config('GLM46_BASE_URL', default='https://api.z.ai/api/paas/v4') + self.base_url = _clean_base_url(config('GLM46_BASE_URL', default='https://api.z.ai/api/paas/v4')) self.model = config('GLM46_MODEL', default='glm-4.6') + self.anthropic_token = config('ANTHROPIC_AUTH_TOKEN', default='') + anthropic_base = config('ANTHROPIC_BASE_URL', default='').strip() + self.anthropic_base_url = _clean_base_url(anthropic_base or 'https://api.z.ai/api/anthropic') async def complete(self, prompt: str, **kwargs) -> str: """ @@ -33,6 +101,9 @@ class GLM46Client: str: AI response """ if not self.api_key: + if self.anthropic_token: + logger.info("GLM46 API key missing, using Anthropic-compatible endpoint") + return await self._anthropic_complete(prompt, **kwargs) logger.warning("GLM46_API_KEY not configured") return "Error: GLM46 API key not configured" @@ -66,6 +137,53 @@ class GLM46Client: return f"Error: API request failed with status {response.status}" except Exception as e: logger.error(f"GLM46 request failed: {e}") + if self.anthropic_token: + logger.info("Falling back to Anthropic-compatible endpoint for GLM4.6") + return await self._anthropic_complete(prompt, **kwargs) + return f"Error: {str(e)}" + + async def _anthropic_complete(self, prompt: str, **kwargs) -> str: + """Call GLM4.6 through the Anthropic-compatible proxy the user configured.""" + headers = { + 'Authorization': f'Bearer {self.anthropic_token}', + 'Content-Type': 'application/json', + 'anthropic-version': '2023-06-01' + } + + if isinstance(prompt, str) and 'messages' not in kwargs: + messages = [{'role': 'user', 'content': prompt}] + else: + messages = kwargs.get('messages', [{'role': 'user', 'content': prompt}]) + + data = { + 'model': self.model, + 'max_tokens': kwargs.get('max_tokens', 1024), + 'messages': messages, + } + + for tuning_key in ('temperature', 'top_p', 'top_k'): + if tuning_key in kwargs and kwargs[tuning_key] is not None: + data[tuning_key] = kwargs[tuning_key] + + try: + async with aiohttp.ClientSession() as session: + async with session.post( + f'{self.anthropic_base_url}/messages', + headers=headers, + json=data, + timeout=60 + ) as response: + if response.status == 200: + result = await response.json() + for content_block in result.get('content', []): + if content_block.get('type') == 'text': + return content_block.get('text', '') + return "Error: No text content in response" + error_text = await response.text() + logger.error(f"Anthropic GLM4.6 error: {response.status} - {error_text}") + return f"Error: API request failed with status {response.status}" + except Exception as e: + logger.error(f"Anthropic GLM4.6 request failed: {e}") return f"Error: {str(e)}" async def analyze_music_request(self, user_message: str) -> Dict[str, Any]: @@ -118,8 +236,11 @@ class MinimaxM2Client: def __init__(self): self.api_key = config('ANTHROPIC_AUTH_TOKEN', default='') - self.base_url = config('MINIMAX_BASE_URL', default='https://api.minimax.io/anthropic') - self.model = config('MINIMAX_MODEL', default='MiniMax-M2') + base_override = config('ANTHROPIC_BASE_URL', default='').strip() + default_base = config('MINIMAX_BASE_URL', default='https://api.minimax.io/anthropic') + self.base_url = _clean_base_url(base_override or default_base) + default_model = config('ANTHROPIC_CHAT_MODEL', default='glm-4.6') + self.model = config('MINIMAX_MODEL', default=default_model) async def complete(self, prompt: str, **kwargs) -> str: """ @@ -170,7 +291,7 @@ class MinimaxM2Client: for content_block in result.get('content', []): if content_block.get('type') == 'text': return content_block.get('text', '') - return "No text content in response" + return "Error: No text content in response" else: error_text = await response.text() logger.error(f"Minimax API error: {response.status} - {error_text}") @@ -212,6 +333,8 @@ class AIOrchestrator: def __init__(self): self.glm_client = GLM46Client() self.minimax_client = MinimaxM2Client() + self.claude_cli = ClaudeCLIClient() + self.sample_library = SampleLibrary() self.mock_mode = config('MOCK_MODE', default='false').lower() == 'true' async def process_request(self, message: str, request_type: str = 'chat') -> str: @@ -226,9 +349,15 @@ class AIOrchestrator: str: AI response """ if request_type == 'generate' or request_type == 'analyze': - # Use GLM4.6 for structured tasks + # Use GLM4.6 for structured tasks and fall back to CLI if needed logger.info("Using GLM4.6 for structured generation") - return await self.glm_client.complete(message) + response = await self.glm_client.complete(message) + if response.startswith("Error:") and self.claude_cli.available: + logger.info("GLM4.6 HTTP failed, trying Claude CLI") + cli_response = await self.claude_cli.complete(message) + if not cli_response.startswith("Error:"): + return cli_response + return response else: # Try Minimax M2 first, fall back to GLM4.6 try: @@ -344,13 +473,15 @@ class AIOrchestrator: 'color': 21 }) - return { + config = { 'name': project_name, 'bpm': bpm, 'key': key, 'tracks': tracks } + return self.sample_library.populate_project(config) + async def generate_music_project(self, user_message: str) -> Dict[str, Any]: """ Generate complete music project configuration with mock mode fallback. @@ -409,9 +540,18 @@ class AIOrchestrator: try: config = json.loads(response) logger.info(f"Generated project config: {config['name']}") - return config + return self.sample_library.populate_project(config) except json.JSONDecodeError as e: logger.error(f"Failed to parse project config: {e}") + elif self.claude_cli.available: + logger.info("Retrying project generation through Claude CLI") + cli_response = await self.claude_cli.complete(prompt) + if not cli_response.startswith("Error:"): + try: + config = json.loads(cli_response) + return self.sample_library.populate_project(config) + except json.JSONDecodeError as e: + logger.error(f"Claude CLI project JSON parse error: {e}") except Exception as e: logger.warning(f"GLM4.6 project generation failed: {e}") @@ -455,6 +595,8 @@ class AIOrchestrator: if self.mock_mode: return self._get_mock_chat_response(message) + system_prompt = """You are MusiaIA, an AI assistant specialized in music creation.\nYou help users generate Ableton Live projects through natural conversation.\nBe friendly, helpful, and creative. Keep responses concise but informative.""" + # Try using the minimax chat method, but fall back if it fails try: response = await self.minimax_client.chat(message, history) @@ -479,6 +621,20 @@ class AIOrchestrator: except Exception as e: logger.warning(f"GLM4.6 error: {e}") + if self.claude_cli.available: + try: + context_str = "" + if history: + context_str = "\n".join([f"{msg['role']}: {msg['content']}" for msg in history[-5:]]) + context_str += "\n" + prompt = f"{context_str}User: {message}\n\nAssistant:" + cli_response = await self.claude_cli.complete(prompt, system_prompt=system_prompt) + if not cli_response.startswith("Error:"): + return cli_response + logger.warning(f"Claude CLI chat failed: {cli_response}") + except Exception as exc: + logger.warning(f"Claude CLI error: {exc}") + # Final fallback to mock logger.info("All APIs failed, using mock response") return self._get_mock_chat_response(message) diff --git a/src/backend/als/als_generator.py b/src/backend/als/als_generator.py index c721fcc..1a8723e 100644 --- a/src/backend/als/als_generator.py +++ b/src/backend/als/als_generator.py @@ -5,6 +5,7 @@ ALS Generator - Core component for creating Ableton Live Set files import gzip import os import random +import shutil import uuid from datetime import datetime from pathlib import Path @@ -24,6 +25,7 @@ class ALSGenerator: self.output_dir = Path(output_dir or "/home/ren/musia/output/als") self.output_dir.mkdir(parents=True, exist_ok=True) self.next_id = 1000 + self.sample_root = Path(os.environ.get('SAMPLE_LIBRARY_PATH', '/home/ren/musia/source')) def generate_project(self, config: Dict[str, Any]) -> str: """ @@ -56,8 +58,11 @@ class ALSGenerator: samples_dir = als_folder / "Samples" / "Imported" samples_dir.mkdir(parents=True, exist_ok=True) + # Resolve and copy samples into the project folder + config = self._prepare_samples(config, samples_dir, als_folder) + # Generate XML content - xml_content = self._build_als_xml(config, samples_dir) + xml_content = self._build_als_xml(config) # Write ALS file (gzip compressed XML) als_file_path = als_folder / f"{project_name}.als" @@ -70,7 +75,57 @@ class ALSGenerator: logger.info(f"ALS project generated: {als_file_path}") return str(als_file_path) - def _build_als_xml(self, config: Dict[str, Any], samples_dir: Path) -> str: + def _prepare_samples(self, config: Dict[str, Any], samples_dir: Path, project_root: Path) -> Dict[str, Any]: + """Copy referenced samples into the Ableton project and make paths relative.""" + for track in config.get('tracks', []): + prepared_samples: List[str] = [] + for sample_entry in track.get('samples', []) or []: + resolved = self._resolve_sample_path(sample_entry) + if not resolved: + logger.warning("Sample %s could not be resolved", sample_entry) + continue + + copied = self._copy_sample(resolved, samples_dir) + try: + relative_path = copied.relative_to(project_root) + except ValueError: + relative_path = copied.name + prepared_samples.append(str(relative_path)) + + track['samples'] = prepared_samples + + return config + + def _resolve_sample_path(self, sample_entry: str) -> Optional[Path]: + if not sample_entry: + return None + + candidate = Path(sample_entry) + if candidate.is_absolute() and candidate.exists(): + return candidate + + if candidate.exists(): + return candidate.resolve() + + if self.sample_root: + potential = self.sample_root / sample_entry + if potential.exists(): + return potential.resolve() + + return None + + def _copy_sample(self, source: Path, samples_dir: Path) -> Path: + samples_dir.mkdir(parents=True, exist_ok=True) + destination = samples_dir / source.name + counter = 1 + while destination.exists(): + destination = samples_dir / f"{source.stem}_{counter}{source.suffix}" + counter += 1 + + shutil.copy2(source, destination) + return destination + + def _build_als_xml(self, config: Dict[str, Any]) -> str: """Build the complete XML structure for ALS file.""" # Create root element root = self._create_root_element() diff --git a/src/backend/als/sample_library.py b/src/backend/als/sample_library.py new file mode 100644 index 0000000..cc3958b --- /dev/null +++ b/src/backend/als/sample_library.py @@ -0,0 +1,132 @@ +"""Utility helpers to attach real audio samples from the local library.""" + +import logging +import os +import random +from pathlib import Path +from typing import Dict, List, Optional, Any + +logger = logging.getLogger(__name__) + + +class SampleLibrary: + """Loads audio files from /source and serves suggestions per track.""" + + SUPPORTED_EXTENSIONS = {'.wav', '.aiff', '.aif', '.flac', '.mp3'} + + TRACK_HINTS: Dict[str, List[str]] = { + 'drum': ['kicks', 'snares', 'hats'], + 'percussion': ['percussion', 'hats'], + 'perc': ['percussion', 'hats'], + 'kick': ['kicks'], + 'snare': ['snares'], + 'hat': ['hats'], + 'bass': ['bass'], + 'lead': ['leads'], + 'synth': ['leads'], + 'pad': ['pads'], + 'fx': ['fx'], + 'vocal': ['vox'], + 'vox': ['vox'], + } + + def __init__(self, root_dir: Optional[str] = None): + default_root = os.environ.get('SAMPLE_LIBRARY_PATH', '/home/ren/musia/source') + self.root_dir = Path(root_dir or default_root) + self.samples_by_category = self._scan_library() + + def populate_project(self, config: Dict[str, Any]) -> Dict[str, Any]: + """Ensure each track in the config references valid sample files.""" + if not self.samples_by_category: + return config + + for track in config.get('tracks', []): + self._assign_samples_to_track(track) + return config + + def _scan_library(self) -> Dict[str, List[Path]]: + """Scan the source folder once and cache files per category.""" + samples: Dict[str, List[Path]] = {} + + if not self.root_dir.exists(): + logger.warning("Sample library not found at %s", self.root_dir) + return samples + + for category_dir in self.root_dir.iterdir(): + if not category_dir.is_dir(): + continue + files = [ + path for path in category_dir.rglob('*') + if path.is_file() and path.suffix.lower() in self.SUPPORTED_EXTENSIONS + ] + if files: + samples[category_dir.name.lower()] = files + + if not samples: + logger.warning("No audio files found under %s", self.root_dir) + + return samples + + def _assign_samples_to_track(self, track: Dict[str, Any]) -> None: + resolved: List[Path] = [] + + for hint in track.get('samples', []) or []: + sample_path = self._resolve_hint(hint) + if sample_path: + resolved.append(sample_path) + + if not resolved: + categories = self._infer_categories(track) + for category in categories: + sample = self._pick_from_category(category) + if sample: + resolved.append(sample) + + track['samples'] = [str(path) for path in resolved] + + def _resolve_hint(self, hint: str) -> Optional[Path]: + if not hint: + return None + + candidate = Path(hint) + if candidate.is_absolute() and candidate.exists(): + return candidate + + if candidate.exists(): + return candidate.resolve() + + normalized = hint.replace('\\', '/').lower() + category = normalized.split('/')[0] + return self._pick_from_category(category) + + def _infer_categories(self, track: Dict[str, Any]) -> List[str]: + name = (track.get('name') or '').lower() + categories: List[str] = [] + + for token, mapped in self.TRACK_HINTS.items(): + if token in name: + categories.extend(mapped) + + if not categories: + track_type = (track.get('type') or '').lower() + if 'bass' in name or track_type == 'bass': + categories.append('bass') + elif 'midi' in track_type: + categories.append('leads') + else: + categories.append('fx') + + # Remove duplicates while preserving order + seen = set() + unique_categories = [] + for category in categories: + if category not in seen: + seen.add(category) + unique_categories.append(category) + return unique_categories + + def _pick_from_category(self, category: str) -> Optional[Path]: + files = self.samples_by_category.get(category.lower()) + if not files: + return None + return random.choice(files)