Cambios principales:

## Nuevos archivos

- services/ai/parallel_provider.py: Ejecución paralela de múltiples proveedores AI
- services/ai/prompt_manager.py: Gestión centralizada de prompts (resumen.md como fuente)
- latex/resumen.md: Template del prompt para resúmenes académicos LaTeX

## Mejoras en generación LaTeX (document/generators.py)

- Nueva función _sanitize_latex(): Corrige automáticamente errores comunes de AI
- Agrega align=center a nodos TikZ con saltos de línea (\\)
- Previene errores 'Not allowed in LR mode' antes de compilar
- Soporte para procesamiento paralelo de proveedores AI
- Conversión DOCX en paralelo con generación PDF
- Uploads a Notion en background (non-blocking)
- Callbacks de notificación para progreso en Telegram

## Mejoras en proveedores AI

- claude_provider.py: fix_latex() con instrucciones específicas para errores TikZ
- gemini_provider.py: fix_latex() mejorado + rate limiting + circuit breaker
- provider_factory.py: Soporte para parallel provider

## Otros cambios

- config/settings.py: Nuevas configuraciones para Gemini models
- services/webdav_service.py: Mejoras en manejo de conexión
- .gitignore: Ignora archivos LaTeX auxiliares (.aux, .toc, .out, .pdf)

## Archivos de ejemplo

- latex/imperio_romano.tex, latex/clase_revolucion_rusa_crisis_30.tex
- resumen_curiosidades.tex (corregido y compilado exitosamente)
347 lines
11 KiB
Python
347 lines
11 KiB
Python
"""
Parallel AI Provider - Race multiple providers for fastest response

Implements Strategy A: Parallel Generation with Consensus
"""
|
|
|
|
import asyncio
import logging
from concurrent.futures import ThreadPoolExecutor, TimeoutError as FuturesTimeoutError, as_completed
from dataclasses import dataclass
from datetime import datetime
from typing import Any, Dict, List, Optional

from core import AIProcessingError

from .base_provider import AIProvider
|
|
|
|
|
|
@dataclass
|
|
class ProviderResult:
|
|
"""Result from a single provider"""
|
|
provider_name: str
|
|
content: str
|
|
duration_ms: int
|
|
success: bool
|
|
error: Optional[str] = None
|
|
quality_score: float = 0.0
|
|
|
|
|
|
@dataclass
|
|
class ParallelResult:
|
|
"""Aggregated result from parallel execution"""
|
|
content: str
|
|
strategy: str
|
|
providers_used: List[str]
|
|
total_duration_ms: int
|
|
all_results: List[ProviderResult]
|
|
selected_provider: str
|
|
|
|
|
|
class ParallelAIProvider:
    """
    Orchestrates multiple AI providers in parallel for faster responses.

    Strategies:
    - "race": Use first successful response (fastest)
    - "consensus": Wait for all, select best quality
    - "majority": Select most common response
    """

    def __init__(self, providers: Dict[str, AIProvider], max_workers: int = 4):
        """
        Args:
            providers: Mapping of provider name -> provider instance.
            max_workers: Maximum number of provider calls run concurrently.
        """
        self.providers = providers
        self.max_workers = max_workers
        self.logger = logging.getLogger(__name__)
        self.executor = ThreadPoolExecutor(max_workers=max_workers)

    def _generate_sync(self, provider: AIProvider, prompt: str, **kwargs) -> ProviderResult:
        """Synchronous wrapper for provider generation.

        Never raises: provider failures are captured in the returned
        ProviderResult (success=False, error message set).
        """
        start_time = datetime.now()
        try:
            content = provider.generate_text(prompt, **kwargs)
            duration_ms = int((datetime.now() - start_time).total_seconds() * 1000)
            return ProviderResult(
                provider_name=provider.name,
                content=content,
                duration_ms=duration_ms,
                success=True,
                quality_score=self._calculate_quality_score(content),
            )
        except Exception as e:
            duration_ms = int((datetime.now() - start_time).total_seconds() * 1000)
            self.logger.error(f"{provider.name} failed: {e}")
            return ProviderResult(
                provider_name=provider.name,
                content="",
                duration_ms=duration_ms,
                success=False,
                error=str(e),
            )

    def _calculate_quality_score(self, content: str) -> float:
        """Calculate a heuristic quality score for generated content.

        Components: reasonable length (+0.2), presence of standard LaTeX
        markers (up to +0.4), balanced braces (+0.2), balanced environments
        (up to +0.3), non-trivial length (+0.1). The sum is capped at 1.0.
        """
        score = 0.0

        # Length check (comprehensive is better)
        if 500 < len(content) < 50000:
            score += 0.2

        # LaTeX structure validation
        latex_indicators = [
            r"\documentclass",
            r"\begin{document}",
            r"\section",
            r"\subsection",
            r"\begin{itemize}",
            r"\end{document}",
        ]
        found_indicators = sum(1 for ind in latex_indicators if ind in content)
        score += (found_indicators / len(latex_indicators)) * 0.4

        # Bracket matching
        if content.count("{") == content.count("}"):
            score += 0.2

        # Environment closure
        for env in ("document", "itemize", "enumerate"):
            if content.count(f"\\begin{{{env}}}") == content.count(f"\\end{{{env}}}"):
                score += 0.1

        # Has content beyond template
        if len(content) > 1000:
            score += 0.1

        return min(score, 1.0)

    def generate_parallel(
        self,
        prompt: str,
        strategy: str = "race",
        timeout_ms: int = 300000,  # 5 minutes default
        **kwargs,
    ) -> ParallelResult:
        """
        Execute prompt across multiple providers in parallel.

        Args:
            prompt: The prompt to send to all providers
            strategy: "race", "consensus", or "majority"
            timeout_ms: Maximum time to wait for results
            **kwargs: Additional arguments for providers

        Returns:
            ParallelResult with selected content and metadata

        Raises:
            AIProcessingError: if no providers are configured.
            ValueError: if strategy is not a known strategy name.
        """
        if not self.providers:
            raise AIProcessingError("No providers available for parallel execution")

        # Validate up front: previously an unknown strategy raised AFTER the
        # futures were submitted, leaking running provider calls.
        if strategy not in ("race", "consensus", "majority"):
            raise ValueError(f"Unknown strategy: {strategy}")

        start_time = datetime.now()

        # Submit one task per currently-available provider.
        futures = {
            self.executor.submit(self._generate_sync, provider, prompt, **kwargs): name
            for name, provider in self.providers.items()
            if provider.is_available()
        }

        # Wait for results based on strategy
        if strategy == "race":
            all_results = self._race_strategy(futures, timeout_ms)
        elif strategy == "consensus":
            all_results = self._consensus_strategy(futures, timeout_ms)
        else:  # "majority" (validated above)
            all_results = self._majority_strategy(futures, timeout_ms)

        # Select best result
        selected = self._select_result(all_results, strategy)
        total_duration_ms = int((datetime.now() - start_time).total_seconds() * 1000)

        self.logger.info(
            f"Parallel generation complete: {strategy} strategy, "
            f"{len(all_results)} providers, {selected.provider_name} selected, "
            f"{total_duration_ms}ms"
        )

        return ParallelResult(
            content=selected.content,
            strategy=strategy,
            providers_used=[r.provider_name for r in all_results if r.success],
            total_duration_ms=total_duration_ms,
            all_results=all_results,
            selected_provider=selected.provider_name,
        )

    def _collect_results(
        self,
        futures: dict,
        timeout_ms: int,
        stop_on_success: bool,
    ) -> List[ProviderResult]:
        """Drain futures as they complete, optionally stopping at first success.

        Catches the as_completed() timeout (previously unhandled, crashing the
        caller) and returns whatever completed in time.
        """
        results: List[ProviderResult] = []
        try:
            for future in as_completed(futures, timeout=timeout_ms / 1000):
                try:
                    result = future.result()
                    results.append(result)
                    if stop_on_success and result.success:
                        # Got a winner: cancel whatever has not started yet.
                        # (Already-running futures cannot be cancelled.)
                        for pending in futures:
                            pending.cancel()
                        break
                except Exception as e:
                    self.logger.error(f"Future failed: {e}")
        except FuturesTimeoutError:
            self.logger.warning(
                "Timed out after %sms; returning %d partial result(s)",
                timeout_ms,
                len(results),
            )
            for pending in futures:
                pending.cancel()
        return results

    def _race_strategy(self, futures: dict, timeout_ms: int) -> List[ProviderResult]:
        """Return as soon as one provider succeeds; cancel the rest."""
        return self._collect_results(futures, timeout_ms, stop_on_success=True)

    def _consensus_strategy(self, futures: dict, timeout_ms: int) -> List[ProviderResult]:
        """Wait for all providers, return all results."""
        return self._collect_results(futures, timeout_ms, stop_on_success=False)

    def _majority_strategy(self, futures: dict, timeout_ms: int) -> List[ProviderResult]:
        """Wait for all providers; majority selection happens in _select_result()."""
        return self._collect_results(futures, timeout_ms, stop_on_success=False)

    def _select_result(self, results: List[ProviderResult], strategy: str) -> ProviderResult:
        """Select the best result based on strategy.

        Falls back to a synthetic failure result when nothing succeeded,
        so callers always receive a ProviderResult.
        """
        successful = [r for r in results if r.success]

        if not successful:
            # Return first failed result with error info
            return results[0] if results else ProviderResult(
                provider_name="none",
                content="",
                duration_ms=0,
                success=False,
                error="All providers failed",
            )

        if strategy == "race" or len(successful) == 1:
            return successful[0]

        if strategy == "consensus":
            # Select by quality score
            return max(successful, key=lambda r: r.quality_score)

        if strategy == "majority":
            # Group by similar content (simplified - use longest)
            return max(successful, key=lambda r: len(r.content))

        return successful[0]

    def _fix_latex_sync(
        self,
        name: str,
        provider: AIProvider,
        latex_code: str,
        error_log: str,
        **kwargs,
    ) -> ProviderResult:
        """Run one provider's fix_latex(), capturing timing and errors."""
        start = datetime.now()
        try:
            fixed = provider.fix_latex(latex_code, error_log, **kwargs)
            duration_ms = int((datetime.now() - start).total_seconds() * 1000)
            return ProviderResult(
                provider_name=name,
                content=fixed,
                duration_ms=duration_ms,
                success=True,
                # Score by checking if error patterns are reduced
                quality_score=self._score_latex_fix(fixed, error_log),
            )
        except Exception as e:
            self.logger.error(f"{name} fix failed: {e}")
            return ProviderResult(
                provider_name=name,
                content="",
                duration_ms=int((datetime.now() - start).total_seconds() * 1000),
                success=False,
                error=str(e),
            )

    def fix_latex_parallel(
        self,
        latex_code: str,
        error_log: str,
        timeout_ms: int = 120000,
        **kwargs,
    ) -> ParallelResult:
        """Try to fix LaTeX across multiple providers in parallel.

        Previously this looped over providers sequentially despite its name;
        it now fans the fix attempts out on the executor, then selects the
        highest-scoring successful fix.

        Raises:
            AIProcessingError: when every provider fails or times out.
        """
        start_time = datetime.now()

        # Fan out one fix attempt per available provider.
        futures = {
            self.executor.submit(
                self._fix_latex_sync, name, provider, latex_code, error_log, **kwargs
            ): name
            for name, provider in self.providers.items()
            if provider.is_available()
        }

        results: List[ProviderResult] = []
        try:
            for future in as_completed(futures, timeout=timeout_ms / 1000):
                # _fix_latex_sync never raises; failures come back success=False.
                result = future.result()
                if result.success:
                    results.append(result)
        except FuturesTimeoutError:
            self.logger.warning("fix_latex_parallel timed out after %sms", timeout_ms)
            for pending in futures:
                pending.cancel()

        if not results:
            raise AIProcessingError("All providers failed to fix LaTeX")

        # Select best fix by heuristic score.
        selected = max(results, key=lambda r: r.quality_score)
        total_duration_ms = int((datetime.now() - start_time).total_seconds() * 1000)

        return ParallelResult(
            content=selected.content,
            strategy="consensus",
            providers_used=[r.provider_name for r in results],
            total_duration_ms=total_duration_ms,
            all_results=results,
            selected_provider=selected.provider_name,
        )

    def _score_latex_fix(self, fixed_latex: str, original_error: str) -> float:
        """Score a LaTeX fix attempt, in [0.0, 1.0].

        Starts from a 0.5 base, adds a small bonus per known error class seen
        in the original log, plus structural checks on the fixed source.
        """
        score = 0.5  # Base score

        # Known error classes; the regex is kept for future use but scoring
        # currently only checks that the error message class was present.
        error_patterns = [
            ("Undefined control sequence", r"\\[a-zA-Z]+"),
            ("Missing $ inserted", r"\$.*\$"),
            ("Runaway argument", r"\{.*\}"),
        ]
        for error_msg, _pattern in error_patterns:
            if error_msg in original_error:
                score += 0.1

        # Validate bracket matching
        if fixed_latex.count("{") == fixed_latex.count("}"):
            score += 0.2

        # Validate environment closure
        for env in ("document", "itemize", "enumerate"):
            begin_count = fixed_latex.count(f"\\begin{{{env}}}")
            end_count = fixed_latex.count(f"\\end{{{env}}}")
            if begin_count == end_count:
                score += 0.1

        return min(score, 1.0)

    def shutdown(self):
        """Shutdown the executor, waiting for in-flight tasks to finish."""
        self.executor.shutdown(wait=True)

    def __del__(self):
        # May run during interpreter teardown or after a failed __init__,
        # when attributes can be missing; never let destructor errors escape.
        try:
            self.shutdown()
        except Exception:
            pass
|