"""
Parallel AI Provider - Race multiple providers for fastest response

Implements Strategy A: Parallel Generation with Consensus
"""

import asyncio
import logging
import time
from collections import Counter
from concurrent.futures import Future, ThreadPoolExecutor, as_completed
from concurrent.futures import TimeoutError as FuturesTimeoutError
from dataclasses import dataclass
from datetime import datetime
from typing import Any, Dict, List, Optional

from core import AIProcessingError

from .base_provider import AIProvider

# Strategies accepted by ParallelAIProvider.generate_parallel.
_STRATEGIES = ("race", "consensus", "majority")

# LaTeX environments whose \begin/\end counts are compared by the scorers.
_BALANCED_ENVS = ("document", "itemize", "enumerate")


def _elapsed_ms(started: float) -> int:
    """Return whole milliseconds elapsed since *started* (a ``time.monotonic()`` value)."""
    return int((time.monotonic() - started) * 1000)


@dataclass
class ProviderResult:
    """Result from a single provider.

    Attributes:
        provider_name: Name of the provider that produced this result.
        content: Generated text; empty string when the attempt failed.
        duration_ms: Time spent on the attempt, in milliseconds.
        success: False when the provider raised or timed out.
        error: Error description for failed attempts, otherwise None.
        quality_score: Heuristic score in [0.0, 1.0]; 0.0 for failures.
    """

    provider_name: str
    content: str
    duration_ms: int
    success: bool
    error: Optional[str] = None
    quality_score: float = 0.0


@dataclass
class ParallelResult:
    """Aggregated result from parallel execution.

    Attributes:
        content: Content of the selected provider result.
        strategy: Strategy that produced the selection.
        providers_used: Names of the providers that succeeded.
        total_duration_ms: Duration of the whole call, in milliseconds.
        all_results: Every collected ProviderResult, including failures.
        selected_provider: Name of the provider whose content was selected.
    """

    content: str
    strategy: str
    providers_used: List[str]
    total_duration_ms: int
    all_results: List[ProviderResult]
    selected_provider: str


class ParallelAIProvider:
    """
    Orchestrates multiple AI providers in parallel for faster responses.

    Strategies:
    - "race": Use first successful response (fastest)
    - "consensus": Wait for all, select best quality
    - "majority": Wait for all, select the most common response
      (ties are broken in favour of the longest response)
    """

    def __init__(self, providers: Dict[str, AIProvider], max_workers: int = 4):
        """Store the providers and create the worker thread pool.

        Args:
            providers: Mapping of provider name to provider instance.
            max_workers: Maximum number of concurrent provider calls.
        """
        self.providers = providers
        self.max_workers = max_workers
        self.logger = logging.getLogger(__name__)
        self.executor = ThreadPoolExecutor(max_workers=max_workers)

    def _generate_sync(self, provider: AIProvider, prompt: str, **kwargs) -> ProviderResult:
        """Run one provider and convert any exception into a failed result.

        Args:
            provider: Provider to call.
            prompt: Prompt passed to ``provider.generate_text``.
            **kwargs: Extra arguments forwarded to ``provider.generate_text``.

        Returns:
            A ProviderResult; ``success`` is False if the provider raised or
            returned something that is not a string.
        """
        started = time.monotonic()
        try:
            content = provider.generate_text(prompt, **kwargs)
            if not isinstance(content, str):
                # Non-string output cannot be scored or returned as content;
                # report it as an explicit failure.
                raise TypeError(
                    f"expected str from provider, got {type(content).__name__}"
                )
            return ProviderResult(
                provider_name=provider.name,
                content=content,
                duration_ms=_elapsed_ms(started),
                success=True,
                quality_score=self._calculate_quality_score(content),
            )
        except Exception as e:
            # Boundary for a worker task: record the failure instead of
            # letting it escape through the future.
            self.logger.error("%s failed: %s", provider.name, e)
            return ProviderResult(
                provider_name=provider.name,
                content="",
                duration_ms=_elapsed_ms(started),
                success=False,
                error=str(e),
            )

    def _calculate_quality_score(self, content: str) -> float:
        """Calculate a heuristic quality score for generated LaTeX content.

        Args:
            content: Generated text to score.

        Returns:
            A score in [0.0, 1.0]. Blank content scores 0.0, because its
            brace and environment counts would otherwise trivially "balance".
        """
        if not content.strip():
            return 0.0

        score = 0.0

        # Length check (comprehensive is better)
        if 500 < len(content) < 50000:
            score += 0.2

        # LaTeX structure validation
        latex_indicators = [
            r"\documentclass",
            r"\begin{document}",
            r"\section",
            r"\subsection",
            r"\begin{itemize}",
            r"\end{document}",
        ]
        found_indicators = sum(1 for ind in latex_indicators if ind in content)
        score += (found_indicators / len(latex_indicators)) * 0.4

        # Bracket matching
        if content.count("{") == content.count("}"):
            score += 0.2

        # Environment closure
        for env in _BALANCED_ENVS:
            if content.count(f"\\begin{{{env}}}") == content.count(f"\\end{{{env}}}"):
                score += 0.1

        # Has content beyond template
        if len(content) > 1000:
            score += 0.1

        return min(score, 1.0)

    def generate_parallel(
        self,
        prompt: str,
        strategy: str = "race",
        timeout_ms: int = 300000,  # 5 minutes default
        **kwargs
    ) -> ParallelResult:
        """
        Execute prompt across multiple providers in parallel.

        Args:
            prompt: The prompt to send to all providers
            strategy: "race", "consensus", or "majority"
            timeout_ms: Maximum time to wait for results; providers that have
                not answered by then are reported as failed results
            **kwargs: Additional arguments for providers

        Returns:
            ParallelResult with selected content and metadata. If no provider
            succeeds, the content is empty and ``all_results`` holds the
            failures.

        Raises:
            AIProcessingError: If no providers are configured.
            ValueError: If ``strategy`` is not a known strategy.
        """
        if not self.providers:
            raise AIProcessingError("No providers available for parallel execution")
        # Validate before submitting so a bad strategy does not trigger
        # provider calls whose results would be thrown away.
        if strategy not in _STRATEGIES:
            raise ValueError(f"Unknown strategy: {strategy}")

        started = time.monotonic()

        # Submit all available providers
        futures: Dict[Future, str] = {}
        for name, provider in self.providers.items():
            if provider.is_available():
                future = self.executor.submit(
                    self._generate_sync, provider, prompt, **kwargs
                )
                futures[future] = name

        if not futures:
            self.logger.warning("No configured provider is currently available")

        # Wait for results based on strategy
        if strategy == "race":
            all_results = self._race_strategy(futures, timeout_ms)
        elif strategy == "consensus":
            all_results = self._consensus_strategy(futures, timeout_ms)
        else:
            all_results = self._majority_strategy(futures, timeout_ms)

        # Select best result
        selected = self._select_result(all_results, strategy)
        total_duration_ms = _elapsed_ms(started)

        self.logger.info(
            "Parallel generation complete: %s strategy, %d providers, "
            "%s selected, %dms",
            strategy,
            len(all_results),
            selected.provider_name,
            total_duration_ms,
        )

        return ParallelResult(
            content=selected.content,
            strategy=strategy,
            providers_used=[r.provider_name for r in all_results if r.success],
            total_duration_ms=total_duration_ms,
            all_results=all_results,
            selected_provider=selected.provider_name,
        )

    def _collect_results(
        self,
        futures: Dict[Future, str],
        timeout_ms: Optional[int],
        stop_on_success: bool = False,
    ) -> List[ProviderResult]:
        """Gather provider results in completion order.

        Args:
            futures: Mapping of submitted future to provider name.
            timeout_ms: Overall wait limit in milliseconds; None waits forever.
            stop_on_success: Stop after the first successful result.

        Returns:
            Results collected so far. On timeout, every provider that had not
            been collected is appended as a failed result instead of the
            timeout escaping and discarding the results already gathered.
        """
        timeout_s = None if timeout_ms is None else timeout_ms / 1000
        results: List[ProviderResult] = []
        collected = set()

        try:
            for future in as_completed(futures, timeout=timeout_s):
                collected.add(future)
                try:
                    result = future.result()
                except Exception as e:
                    self.logger.error("Future failed: %s", e)
                    continue
                results.append(result)
                if stop_on_success and result.success:
                    break
        except FuturesTimeoutError:
            for future, name in futures.items():
                if future in collected:
                    continue
                self.logger.warning("%s timed out after %sms", name, timeout_ms)
                results.append(
                    ProviderResult(
                        provider_name=name,
                        content="",
                        duration_ms=int(timeout_ms),
                        success=False,
                        error=f"Timed out after {timeout_ms}ms",
                    )
                )

        # Drop queued work that is no longer needed. cancel() is a no-op for
        # finished futures and cannot stop calls that are already running.
        for future in futures:
            future.cancel()

        return results

    def _race_strategy(
        self,
        futures: dict,
        timeout_ms: int
    ) -> List[ProviderResult]:
        """Return results up to and including the first successful response."""
        return self._collect_results(futures, timeout_ms, stop_on_success=True)

    def _consensus_strategy(
        self,
        futures: dict,
        timeout_ms: int
    ) -> List[ProviderResult]:
        """Wait for all providers (or the timeout) and return all results."""
        return self._collect_results(futures, timeout_ms)

    def _majority_strategy(
        self,
        futures: dict,
        timeout_ms: int
    ) -> List[ProviderResult]:
        """Wait for all providers (or the timeout) and return all results.

        The most common response is chosen afterwards by ``_select_result``.
        """
        return self._collect_results(futures, timeout_ms)

    def _select_result(self, results: List[ProviderResult], strategy: str) -> ProviderResult:
        """Select best result based on strategy.

        Args:
            results: Collected results, in completion order.
            strategy: "race", "consensus" or "majority".

        Returns:
            The chosen successful result. If nothing succeeded, the first
            failed result, or a placeholder named "none" when ``results`` is
            empty.
        """
        successful = [r for r in results if r.success]

        if not successful:
            # Return first failed result with error info
            if results:
                return results[0]
            return ProviderResult(
                provider_name="none",
                content="",
                duration_ms=0,
                success=False,
                error="All providers failed",
            )

        if strategy == "race" or len(successful) == 1:
            return successful[0]

        if strategy == "consensus":
            # Select by quality score
            return max(successful, key=lambda r: r.quality_score)

        if strategy == "majority":
            # Group responses that are equal once whitespace is normalized and
            # prefer the largest group. Ties fall back to the longest content,
            # which is also the outcome when every response is different.
            normalized = [" ".join(r.content.split()) for r in successful]
            votes = Counter(normalized)
            best, _ = max(
                zip(successful, normalized),
                key=lambda pair: (votes[pair[1]], len(pair[0].content)),
            )
            return best

        return successful[0]

    def _fix_sync(
        self,
        name: str,
        provider: AIProvider,
        latex_code: str,
        error_log: str,
        fix_kwargs: Dict[str, Any],
    ) -> ProviderResult:
        """Run one provider's ``fix_latex`` and score its output.

        ``fix_kwargs`` is passed as a plain dict so that caller-supplied
        keyword names cannot collide with this method's own parameters.

        Returns:
            A ProviderResult; ``success`` is False if the provider raised or
            returned something that is not a string.
        """
        started = time.monotonic()
        try:
            fixed = provider.fix_latex(latex_code, error_log, **fix_kwargs)
            if not isinstance(fixed, str):
                raise TypeError(
                    f"expected str from provider, got {type(fixed).__name__}"
                )
            return ProviderResult(
                provider_name=name,
                content=fixed,
                duration_ms=_elapsed_ms(started),
                success=True,
                quality_score=self._score_latex_fix(fixed, error_log),
            )
        except Exception as e:
            self.logger.error("%s fix failed: %s", name, e)
            return ProviderResult(
                provider_name=name,
                content="",
                duration_ms=_elapsed_ms(started),
                success=False,
                error=str(e),
            )

    def fix_latex_parallel(
        self,
        latex_code: str,
        error_log: str,
        timeout_ms: int = 120000,
        **kwargs
    ) -> ParallelResult:
        """Try to fix LaTeX across multiple providers in parallel.

        Args:
            latex_code: LaTeX source to repair.
            error_log: Compiler error log describing the failure.
            timeout_ms: Maximum time to wait for the providers.
            **kwargs: Additional arguments forwarded to ``fix_latex``.

        Returns:
            ParallelResult holding the highest-scoring fix. ``all_results``
            also lists failed and timed-out attempts; ``providers_used``
            names only the providers that succeeded.

        Raises:
            AIProcessingError: If no provider produced a fix.
        """
        started = time.monotonic()

        futures: Dict[Future, str] = {}
        for name, provider in self.providers.items():
            if provider.is_available():
                future = self.executor.submit(
                    self._fix_sync, name, provider, latex_code, error_log, kwargs
                )
                futures[future] = name

        results = self._collect_results(futures, timeout_ms)
        successful = [r for r in results if r.success]

        if not successful:
            raise AIProcessingError("All providers failed to fix LaTeX")

        # Select best fix
        selected = max(successful, key=lambda r: r.quality_score)

        return ParallelResult(
            content=selected.content,
            strategy="consensus",
            providers_used=[r.provider_name for r in successful],
            total_duration_ms=_elapsed_ms(started),
            all_results=results,
            selected_provider=selected.provider_name,
        )

    def _score_latex_fix(self, fixed_latex: str, original_error: str) -> float:
        """Score a LaTeX fix attempt.

        Args:
            fixed_latex: Candidate fixed source.
            original_error: Error log of the failed compilation.

        Returns:
            A score in [0.0, 1.0]. A blank candidate scores 0.0, because its
            brace and environment counts would otherwise trivially "balance"
            and give it the maximum score.
        """
        if not fixed_latex.strip():
            return 0.0

        score = 0.5  # Base score

        # Bonus per recognised error message in the log. It depends only on
        # the log, so it is the same for every candidate; it is kept so the
        # absolute score values stay as they were.
        known_errors = (
            "Undefined control sequence",
            "Missing $ inserted",
            "Runaway argument",
        )
        for error_msg in known_errors:
            if error_msg in original_error:
                score += 0.1

        # Validate bracket matching
        if fixed_latex.count("{") == fixed_latex.count("}"):
            score += 0.2

        # Validate environment closure
        for env in _BALANCED_ENVS:
            begin_count = fixed_latex.count(f"\\begin{{{env}}}")
            end_count = fixed_latex.count(f"\\end{{{env}}}")
            if begin_count == end_count:
                score += 0.1

        return min(score, 1.0)

    def shutdown(self):
        """Shutdown the executor, waiting for running provider calls."""
        # The attribute is missing if __init__ never completed; shutdown is
        # also reached from __del__, so it must not raise in that case.
        executor = getattr(self, "executor", None)
        if executor is not None:
            executor.shutdown(wait=True)

    def __del__(self):
        """Release the worker threads when the instance is garbage collected."""
        self.shutdown()