feat: Sistema LaTeX mejorado con sanitización automática y corrección de TikZ
Cambios principales: ## Nuevos archivos - services/ai/parallel_provider.py: Ejecución paralela de múltiples proveedores AI - services/ai/prompt_manager.py: Gestión centralizada de prompts (resumen.md como fuente) - latex/resumen.md: Template del prompt para resúmenes académicos LaTeX ## Mejoras en generación LaTeX (document/generators.py) - Nueva función _sanitize_latex(): Corrige automáticamente errores comunes de AI - Agrega align=center a nodos TikZ con saltos de línea (\\) - Previene errores 'Not allowed in LR mode' antes de compilar - Soporte para procesamiento paralelo de proveedores AI - Conversión DOCX en paralelo con generación PDF - Uploads a Notion en background (non-blocking) - Callbacks de notificación para progreso en Telegram ## Mejoras en proveedores AI - claude_provider.py: fix_latex() con instrucciones específicas para errores TikZ - gemini_provider.py: fix_latex() mejorado + rate limiting + circuit breaker - provider_factory.py: Soporte para parallel provider ## Otros cambios - config/settings.py: Nuevas configuraciones para Gemini models - services/webdav_service.py: Mejoras en manejo de conexión - .gitignore: Ignora archivos LaTeX auxiliares (.aux, .toc, .out, .pdf) ## Archivos de ejemplo - latex/imperio_romano.tex, latex/clase_revolucion_rusa_crisis_30.tex - resumen_curiosidades.tex (corregido y compilado exitosamente)
This commit is contained in:
@@ -28,6 +28,11 @@ class AIProvider(ABC):
|
||||
"""Generate text from prompt"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def fix_latex(self, latex_code: str, error_log: str, **kwargs) -> str:
|
||||
"""Fix broken LaTeX code based on compiler error log"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def is_available(self) -> bool:
|
||||
"""Check if provider is available and configured"""
|
||||
|
||||
@@ -52,13 +52,14 @@ class ClaudeProvider(AIProvider):
|
||||
|
||||
return env
|
||||
|
||||
def _run_cli(self, prompt: str, timeout: int = 300) -> str:
|
||||
"""Run Claude CLI with prompt"""
|
||||
def _run_cli(self, prompt: str, timeout: int = 600) -> str:
|
||||
"""Run Claude CLI with prompt using -p flag for stdin input"""
|
||||
if not self.is_available():
|
||||
raise AIProcessingError("Claude CLI not available or not configured")
|
||||
|
||||
try:
|
||||
cmd = [self._cli_path]
|
||||
# Use -p flag to read prompt from stdin, --dangerously-skip-permissions for automation
|
||||
cmd = [self._cli_path, "--dangerously-skip-permissions", "-p", "-"]
|
||||
process = subprocess.run(
|
||||
cmd,
|
||||
input=prompt,
|
||||
@@ -126,3 +127,32 @@ Return only the category name, nothing else."""
|
||||
def generate_text(self, prompt: str, **kwargs) -> str:
|
||||
"""Generate text using Claude"""
|
||||
return self._run_cli(prompt)
|
||||
|
||||
def fix_latex(self, latex_code: str, error_log: str, **kwargs) -> str:
|
||||
"""Fix broken LaTeX code using Claude"""
|
||||
prompt = f"""I have a LaTeX file that failed to compile. Please fix the code.
|
||||
|
||||
COMPILER ERROR LOG:
|
||||
{error_log[-3000:]}
|
||||
|
||||
BROKEN LATEX CODE:
|
||||
{latex_code}
|
||||
|
||||
INSTRUCTIONS:
|
||||
1. Analyze the error log to find the specific syntax error.
|
||||
2. Fix the LaTeX code.
|
||||
3. Return ONLY the full corrected LaTeX code.
|
||||
4. Do not include markdown blocks or explanations.
|
||||
5. Start immediately with \\documentclass.
|
||||
|
||||
COMMON LATEX ERRORS TO CHECK:
|
||||
- TikZ nodes with line breaks (\\\\) MUST have "align=center" in their style.
|
||||
WRONG: \\node[box] (n) {{Text\\\\More}};
|
||||
CORRECT: \\node[box, align=center] (n) {{Text\\\\More}};
|
||||
- All \\begin{{env}} must have matching \\end{{env}}
|
||||
- All braces {{ }} must be balanced
|
||||
- Math mode $ must be paired
|
||||
- Special characters need escaping: % & # _
|
||||
- tcolorbox environments need proper titles: [Title] not {{Title}}
|
||||
"""
|
||||
return self._run_cli(prompt, timeout=180)
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
"""
|
||||
Gemini AI Provider - Optimized version with rate limiting and retry
|
||||
"""
|
||||
|
||||
import logging
|
||||
import subprocess
|
||||
import shutil
|
||||
@@ -16,31 +17,32 @@ from .base_provider import AIProvider
|
||||
|
||||
class TokenBucket:
|
||||
"""Token bucket rate limiter"""
|
||||
|
||||
|
||||
def __init__(self, rate: float = 10, capacity: int = 20):
|
||||
self.rate = rate # tokens per second
|
||||
self.capacity = capacity
|
||||
self.tokens = capacity
|
||||
self.last_update = time.time()
|
||||
self._lock = None # Lazy initialization
|
||||
|
||||
|
||||
def _get_lock(self):
|
||||
if self._lock is None:
|
||||
import threading
|
||||
|
||||
self._lock = threading.Lock()
|
||||
return self._lock
|
||||
|
||||
|
||||
def acquire(self, tokens: int = 1) -> float:
|
||||
with self._get_lock():
|
||||
now = time.time()
|
||||
elapsed = now - self.last_update
|
||||
self.last_update = now
|
||||
self.tokens = min(self.capacity, self.tokens + elapsed * self.rate)
|
||||
|
||||
|
||||
if self.tokens >= tokens:
|
||||
self.tokens -= tokens
|
||||
return 0.0
|
||||
|
||||
|
||||
wait_time = (tokens - self.tokens) / self.rate
|
||||
self.tokens = 0
|
||||
return wait_time
|
||||
@@ -48,7 +50,7 @@ class TokenBucket:
|
||||
|
||||
class CircuitBreaker:
|
||||
"""Circuit breaker for API calls"""
|
||||
|
||||
|
||||
def __init__(self, failure_threshold: int = 5, recovery_timeout: int = 60):
|
||||
self.failure_threshold = failure_threshold
|
||||
self.recovery_timeout = recovery_timeout
|
||||
@@ -56,21 +58,26 @@ class CircuitBreaker:
|
||||
self.last_failure: Optional[datetime] = None
|
||||
self.state = "closed" # closed, open, half-open
|
||||
self._lock = None
|
||||
|
||||
|
||||
def _get_lock(self):
|
||||
if self._lock is None:
|
||||
import threading
|
||||
|
||||
self._lock = threading.Lock()
|
||||
return self._lock
|
||||
|
||||
|
||||
def call(self, func, *args, **kwargs):
|
||||
with self._get_lock():
|
||||
if self.state == "open":
|
||||
if self.last_failure and (datetime.utcnow() - self.last_failure).total_seconds() > self.recovery_timeout:
|
||||
if (
|
||||
self.last_failure
|
||||
and (datetime.utcnow() - self.last_failure).total_seconds()
|
||||
> self.recovery_timeout
|
||||
):
|
||||
self.state = "half-open"
|
||||
else:
|
||||
raise AIProcessingError("Circuit breaker is open")
|
||||
|
||||
|
||||
try:
|
||||
result = func(*args, **kwargs)
|
||||
if self.state == "half-open":
|
||||
@@ -87,7 +94,7 @@ class CircuitBreaker:
|
||||
|
||||
class GeminiProvider(AIProvider):
|
||||
"""Gemini AI provider with rate limiting and retry"""
|
||||
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.logger = logging.getLogger(__name__)
|
||||
@@ -102,17 +109,17 @@ class GeminiProvider(AIProvider):
|
||||
"max_attempts": 3,
|
||||
"base_delay": 1.0,
|
||||
"max_delay": 30.0,
|
||||
"exponential_base": 2
|
||||
"exponential_base": 2,
|
||||
}
|
||||
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "Gemini"
|
||||
|
||||
|
||||
def is_available(self) -> bool:
|
||||
"""Check if Gemini CLI or API is available"""
|
||||
return bool(self._cli_path or self._api_key)
|
||||
|
||||
|
||||
def _init_session(self) -> None:
|
||||
"""Initialize HTTP session with connection pooling"""
|
||||
if self._session is None:
|
||||
@@ -120,17 +127,17 @@ class GeminiProvider(AIProvider):
|
||||
adapter = requests.adapters.HTTPAdapter(
|
||||
pool_connections=10,
|
||||
pool_maxsize=20,
|
||||
max_retries=0 # We handle retries manually
|
||||
max_retries=0, # We handle retries manually
|
||||
)
|
||||
self._session.mount('https://', adapter)
|
||||
|
||||
self._session.mount("https://", adapter)
|
||||
|
||||
def _run_with_retry(self, func, *args, **kwargs):
|
||||
"""Execute function with exponential backoff retry"""
|
||||
max_attempts = self._retry_config["max_attempts"]
|
||||
base_delay = self._retry_config["base_delay"]
|
||||
|
||||
|
||||
last_exception = None
|
||||
|
||||
|
||||
for attempt in range(max_attempts):
|
||||
try:
|
||||
return self._circuit_breaker.call(func, *args, **kwargs)
|
||||
@@ -138,94 +145,84 @@ class GeminiProvider(AIProvider):
|
||||
last_exception = e
|
||||
if attempt < max_attempts - 1:
|
||||
delay = min(
|
||||
base_delay * (2 ** attempt),
|
||||
self._retry_config["max_delay"]
|
||||
base_delay * (2**attempt), self._retry_config["max_delay"]
|
||||
)
|
||||
# Add jitter
|
||||
delay += delay * 0.1 * (time.time() % 1)
|
||||
self.logger.warning(f"Attempt {attempt + 1} failed: {e}, retrying in {delay:.2f}s")
|
||||
self.logger.warning(
|
||||
f"Attempt {attempt + 1} failed: {e}, retrying in {delay:.2f}s"
|
||||
)
|
||||
time.sleep(delay)
|
||||
|
||||
|
||||
raise AIProcessingError(f"Max retries exceeded: {last_exception}")
|
||||
|
||||
|
||||
def _run_cli(self, prompt: str, use_flash: bool = True, timeout: int = 300) -> str:
|
||||
"""Run Gemini CLI with prompt"""
|
||||
if not self._cli_path:
|
||||
raise AIProcessingError("Gemini CLI not available")
|
||||
|
||||
|
||||
model = self._flash_model if use_flash else self._pro_model
|
||||
cmd = [self._cli_path, model, prompt]
|
||||
|
||||
|
||||
try:
|
||||
# Apply rate limiting
|
||||
wait_time = self._rate_limiter.acquire()
|
||||
if wait_time > 0:
|
||||
time.sleep(wait_time)
|
||||
|
||||
|
||||
process = subprocess.run(
|
||||
cmd,
|
||||
text=True,
|
||||
capture_output=True,
|
||||
timeout=timeout,
|
||||
shell=False
|
||||
cmd, text=True, capture_output=True, timeout=timeout, shell=False
|
||||
)
|
||||
|
||||
|
||||
if process.returncode != 0:
|
||||
error_msg = process.stderr or "Unknown error"
|
||||
raise AIProcessingError(f"Gemini CLI failed: {error_msg}")
|
||||
|
||||
|
||||
return process.stdout.strip()
|
||||
except subprocess.TimeoutExpired:
|
||||
raise AIProcessingError(f"Gemini CLI timed out after {timeout}s")
|
||||
except Exception as e:
|
||||
raise AIProcessingError(f"Gemini CLI error: {e}")
|
||||
|
||||
|
||||
def _call_api(self, prompt: str, use_flash: bool = True, timeout: int = 180) -> str:
|
||||
"""Call Gemini API with rate limiting and retry"""
|
||||
if not self._api_key:
|
||||
raise AIProcessingError("Gemini API key not configured")
|
||||
|
||||
|
||||
self._init_session()
|
||||
|
||||
|
||||
model = self._flash_model if use_flash else self._pro_model
|
||||
url = f"https://generativelanguage.googleapis.com/v1beta/models/{model}:generateContent"
|
||||
|
||||
payload = {
|
||||
"contents": [{
|
||||
"parts": [{"text": prompt}]
|
||||
}]
|
||||
}
|
||||
|
||||
|
||||
payload = {"contents": [{"parts": [{"text": prompt}]}]}
|
||||
|
||||
params = {"key": self._api_key}
|
||||
|
||||
|
||||
def api_call():
|
||||
# Apply rate limiting
|
||||
wait_time = self._rate_limiter.acquire()
|
||||
if wait_time > 0:
|
||||
time.sleep(wait_time)
|
||||
|
||||
|
||||
response = self._session.post(
|
||||
url,
|
||||
json=payload,
|
||||
params=params,
|
||||
timeout=timeout
|
||||
url, json=payload, params=params, timeout=timeout
|
||||
)
|
||||
response.raise_for_status()
|
||||
return response
|
||||
|
||||
|
||||
response = self._run_with_retry(api_call)
|
||||
data = response.json()
|
||||
|
||||
|
||||
if "candidates" not in data or not data["candidates"]:
|
||||
raise AIProcessingError("Empty response from Gemini API")
|
||||
|
||||
|
||||
candidate = data["candidates"][0]
|
||||
if "content" not in candidate or "parts" not in candidate["content"]:
|
||||
raise AIProcessingError("Invalid response format from Gemini API")
|
||||
|
||||
|
||||
result = candidate["content"]["parts"][0]["text"]
|
||||
return result.strip()
|
||||
|
||||
|
||||
def _run(self, prompt: str, use_flash: bool = True, timeout: int = 300) -> str:
|
||||
"""Run Gemini with fallback between CLI and API"""
|
||||
# Try CLI first if available
|
||||
@@ -234,14 +231,14 @@ class GeminiProvider(AIProvider):
|
||||
return self._run_cli(prompt, use_flash, timeout)
|
||||
except Exception as e:
|
||||
self.logger.warning(f"Gemini CLI failed, trying API: {e}")
|
||||
|
||||
|
||||
# Fallback to API
|
||||
if self._api_key:
|
||||
api_timeout = min(timeout, 180)
|
||||
return self._call_api(prompt, use_flash, api_timeout)
|
||||
|
||||
|
||||
raise AIProcessingError("No Gemini provider available (CLI or API)")
|
||||
|
||||
|
||||
def summarize(self, text: str, **kwargs) -> str:
|
||||
"""Generate summary using Gemini"""
|
||||
prompt = f"""Summarize the following text:
|
||||
@@ -250,7 +247,7 @@ class GeminiProvider(AIProvider):
|
||||
|
||||
Provide a clear, concise summary in Spanish."""
|
||||
return self._run(prompt, use_flash=True)
|
||||
|
||||
|
||||
def correct_text(self, text: str, **kwargs) -> str:
|
||||
"""Correct text using Gemini"""
|
||||
prompt = f"""Correct the following text for grammar, spelling, and clarity:
|
||||
@@ -259,11 +256,16 @@ Provide a clear, concise summary in Spanish."""
|
||||
|
||||
Return only the corrected text, nothing else."""
|
||||
return self._run(prompt, use_flash=True)
|
||||
|
||||
|
||||
def classify_content(self, text: str, **kwargs) -> Dict[str, Any]:
|
||||
"""Classify content using Gemini"""
|
||||
categories = ["historia", "analisis_contable", "instituciones_gobierno", "otras_clases"]
|
||||
|
||||
categories = [
|
||||
"historia",
|
||||
"analisis_contable",
|
||||
"instituciones_gobierno",
|
||||
"otras_clases",
|
||||
]
|
||||
|
||||
prompt = f"""Classify the following text into one of these categories:
|
||||
- historia
|
||||
- analisis_contable
|
||||
@@ -274,39 +276,61 @@ Text: {text}
|
||||
|
||||
Return only the category name, nothing else."""
|
||||
result = self._run(prompt, use_flash=True).lower()
|
||||
|
||||
|
||||
# Validate result
|
||||
if result not in categories:
|
||||
result = "otras_clases"
|
||||
|
||||
return {
|
||||
"category": result,
|
||||
"confidence": 0.9,
|
||||
"provider": self.name
|
||||
}
|
||||
|
||||
return {"category": result, "confidence": 0.9, "provider": self.name}
|
||||
|
||||
def generate_text(self, prompt: str, **kwargs) -> str:
|
||||
"""Generate text using Gemini"""
|
||||
use_flash = kwargs.get('use_flash', True)
|
||||
use_flash = kwargs.get("use_flash", True)
|
||||
if self._api_key:
|
||||
return self._call_api(prompt, use_flash=use_flash)
|
||||
return self._call_cli(prompt, use_yolo=True)
|
||||
|
||||
return self._run_cli(prompt, use_flash=use_flash)
|
||||
|
||||
def fix_latex(self, latex_code: str, error_log: str, **kwargs) -> str:
|
||||
"""Fix broken LaTeX code using Gemini"""
|
||||
prompt = f"""Fix the following LaTeX code which failed to compile.
|
||||
|
||||
Error Log:
|
||||
{error_log[-3000:]}
|
||||
|
||||
Broken Code:
|
||||
{latex_code}
|
||||
|
||||
INSTRUCTIONS:
|
||||
1. Return ONLY the corrected LaTeX code. No explanations.
|
||||
2. Start immediately with \\documentclass.
|
||||
|
||||
COMMON LATEX ERRORS TO FIX:
|
||||
- TikZ nodes with line breaks (\\\\) MUST have "align=center" in their style.
|
||||
WRONG: \\node[box] (n) {{Text\\\\More}};
|
||||
CORRECT: \\node[box, align=center] (n) {{Text\\\\More}};
|
||||
- All \\begin{{env}} must have matching \\end{{env}}
|
||||
- All braces {{ }} must be balanced
|
||||
- Math mode $ must be paired
|
||||
- Special characters need escaping: % & # _
|
||||
- tcolorbox environments need proper titles: [Title] not {{Title}}
|
||||
"""
|
||||
return self._run(prompt, use_flash=False) # Use Pro model for coding fixes
|
||||
|
||||
def get_stats(self) -> Dict[str, Any]:
|
||||
"""Get provider statistics"""
|
||||
return {
|
||||
"rate_limiter": {
|
||||
"tokens": round(self._rate_limiter.tokens, 2),
|
||||
"capacity": self._rate_limiter.capacity,
|
||||
"rate": self._rate_limiter.rate
|
||||
"rate": self._rate_limiter.rate,
|
||||
},
|
||||
"circuit_breaker": {
|
||||
"state": self._circuit_breaker.state,
|
||||
"failures": self._circuit_breaker.failures,
|
||||
"failure_threshold": self._circuit_breaker.failure_threshold
|
||||
"failure_threshold": self._circuit_breaker.failure_threshold,
|
||||
},
|
||||
"cli_available": bool(self._cli_path),
|
||||
"api_available": bool(self._api_key)
|
||||
"api_available": bool(self._api_key),
|
||||
}
|
||||
|
||||
|
||||
|
||||
346
services/ai/parallel_provider.py
Normal file
346
services/ai/parallel_provider.py
Normal file
@@ -0,0 +1,346 @@
|
||||
"""
|
||||
Parallel AI Provider - Race multiple providers for fastest response
|
||||
Implements Strategy A: Parallel Generation with Consensus
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, List, Optional, Any
|
||||
from datetime import datetime
|
||||
|
||||
from core import AIProcessingError
|
||||
from .base_provider import AIProvider
|
||||
|
||||
|
||||
@dataclass
|
||||
class ProviderResult:
|
||||
"""Result from a single provider"""
|
||||
provider_name: str
|
||||
content: str
|
||||
duration_ms: int
|
||||
success: bool
|
||||
error: Optional[str] = None
|
||||
quality_score: float = 0.0
|
||||
|
||||
|
||||
@dataclass
|
||||
class ParallelResult:
|
||||
"""Aggregated result from parallel execution"""
|
||||
content: str
|
||||
strategy: str
|
||||
providers_used: List[str]
|
||||
total_duration_ms: int
|
||||
all_results: List[ProviderResult]
|
||||
selected_provider: str
|
||||
|
||||
|
||||
class ParallelAIProvider:
|
||||
"""
|
||||
Orchestrates multiple AI providers in parallel for faster responses.
|
||||
|
||||
Strategies:
|
||||
- "race": Use first successful response (fastest)
|
||||
- "consensus": Wait for all, select best quality
|
||||
- "majority": Select most common response
|
||||
"""
|
||||
|
||||
def __init__(self, providers: Dict[str, AIProvider], max_workers: int = 4):
|
||||
self.providers = providers
|
||||
self.max_workers = max_workers
|
||||
self.logger = logging.getLogger(__name__)
|
||||
self.executor = ThreadPoolExecutor(max_workers=max_workers)
|
||||
|
||||
def _generate_sync(self, provider: AIProvider, prompt: str, **kwargs) -> ProviderResult:
|
||||
"""Synchronous wrapper for provider generation"""
|
||||
start_time = datetime.now()
|
||||
try:
|
||||
content = provider.generate_text(prompt, **kwargs)
|
||||
duration_ms = int((datetime.now() - start_time).total_seconds() * 1000)
|
||||
|
||||
# Calculate quality score
|
||||
quality_score = self._calculate_quality_score(content)
|
||||
|
||||
return ProviderResult(
|
||||
provider_name=provider.name,
|
||||
content=content,
|
||||
duration_ms=duration_ms,
|
||||
success=True,
|
||||
quality_score=quality_score
|
||||
)
|
||||
except Exception as e:
|
||||
duration_ms = int((datetime.now() - start_time).total_seconds() * 1000)
|
||||
self.logger.error(f"{provider.name} failed: {e}")
|
||||
return ProviderResult(
|
||||
provider_name=provider.name,
|
||||
content="",
|
||||
duration_ms=duration_ms,
|
||||
success=False,
|
||||
error=str(e)
|
||||
)
|
||||
|
||||
def _calculate_quality_score(self, content: str) -> float:
|
||||
"""Calculate quality score for generated content"""
|
||||
score = 0.0
|
||||
|
||||
# Length check (comprehensive is better)
|
||||
if 500 < len(content) < 50000:
|
||||
score += 0.2
|
||||
|
||||
# LaTeX structure validation
|
||||
latex_indicators = [
|
||||
r"\documentclass",
|
||||
r"\begin{document}",
|
||||
r"\section",
|
||||
r"\subsection",
|
||||
r"\begin{itemize}",
|
||||
r"\end{document}"
|
||||
]
|
||||
found_indicators = sum(1 for ind in latex_indicators if ind in content)
|
||||
score += (found_indicators / len(latex_indicators)) * 0.4
|
||||
|
||||
# Bracket matching
|
||||
if content.count("{") == content.count("}"):
|
||||
score += 0.2
|
||||
|
||||
# Environment closure
|
||||
envs = ["document", "itemize", "enumerate"]
|
||||
for env in envs:
|
||||
if content.count(f"\\begin{{{env}}}") == content.count(f"\\end{{{env}}}"):
|
||||
score += 0.1
|
||||
|
||||
# Has content beyond template
|
||||
if len(content) > 1000:
|
||||
score += 0.1
|
||||
|
||||
return min(score, 1.0)
|
||||
|
||||
def generate_parallel(
|
||||
self,
|
||||
prompt: str,
|
||||
strategy: str = "race",
|
||||
timeout_ms: int = 300000, # 5 minutes default
|
||||
**kwargs
|
||||
) -> ParallelResult:
|
||||
"""
|
||||
Execute prompt across multiple providers in parallel.
|
||||
|
||||
Args:
|
||||
prompt: The prompt to send to all providers
|
||||
strategy: "race", "consensus", or "majority"
|
||||
timeout_ms: Maximum time to wait for results
|
||||
**kwargs: Additional arguments for providers
|
||||
|
||||
Returns:
|
||||
ParallelResult with selected content and metadata
|
||||
"""
|
||||
if not self.providers:
|
||||
raise AIProcessingError("No providers available for parallel execution")
|
||||
|
||||
start_time = datetime.now()
|
||||
all_results: List[ProviderResult] = []
|
||||
|
||||
# Submit all providers
|
||||
futures = {}
|
||||
for name, provider in self.providers.items():
|
||||
if provider.is_available():
|
||||
future = self.executor.submit(
|
||||
self._generate_sync,
|
||||
provider,
|
||||
prompt,
|
||||
**kwargs
|
||||
)
|
||||
futures[future] = name
|
||||
|
||||
# Wait for results based on strategy
|
||||
if strategy == "race":
|
||||
all_results = self._race_strategy(futures, timeout_ms)
|
||||
elif strategy == "consensus":
|
||||
all_results = self._consensus_strategy(futures, timeout_ms)
|
||||
elif strategy == "majority":
|
||||
all_results = self._majority_strategy(futures, timeout_ms)
|
||||
else:
|
||||
raise ValueError(f"Unknown strategy: {strategy}")
|
||||
|
||||
# Select best result
|
||||
selected = self._select_result(all_results, strategy)
|
||||
|
||||
total_duration_ms = int((datetime.now() - start_time).total_seconds() * 1000)
|
||||
|
||||
self.logger.info(
|
||||
f"Parallel generation complete: {strategy} strategy, "
|
||||
f"{len(all_results)} providers, {selected.provider_name} selected, "
|
||||
f"{total_duration_ms}ms"
|
||||
)
|
||||
|
||||
return ParallelResult(
|
||||
content=selected.content,
|
||||
strategy=strategy,
|
||||
providers_used=[r.provider_name for r in all_results if r.success],
|
||||
total_duration_ms=total_duration_ms,
|
||||
all_results=all_results,
|
||||
selected_provider=selected.provider_name
|
||||
)
|
||||
|
||||
def _race_strategy(
|
||||
self,
|
||||
futures: dict,
|
||||
timeout_ms: int
|
||||
) -> List[ProviderResult]:
|
||||
"""Return first successful response"""
|
||||
results = []
|
||||
for future in as_completed(futures, timeout=timeout_ms / 1000):
|
||||
try:
|
||||
result = future.result()
|
||||
results.append(result)
|
||||
if result.success:
|
||||
# Got a successful response, cancel remaining
|
||||
for f in futures:
|
||||
f.cancel()
|
||||
break
|
||||
except Exception as e:
|
||||
self.logger.error(f"Future failed: {e}")
|
||||
return results
|
||||
|
||||
def _consensus_strategy(
|
||||
self,
|
||||
futures: dict,
|
||||
timeout_ms: int
|
||||
) -> List[ProviderResult]:
|
||||
"""Wait for all, return all results"""
|
||||
results = []
|
||||
for future in as_completed(futures, timeout=timeout_ms / 1000):
|
||||
try:
|
||||
result = future.result()
|
||||
results.append(result)
|
||||
except Exception as e:
|
||||
self.logger.error(f"Future failed: {e}")
|
||||
return results
|
||||
|
||||
def _majority_strategy(
|
||||
self,
|
||||
futures: dict,
|
||||
timeout_ms: int
|
||||
) -> List[ProviderResult]:
|
||||
"""Wait for majority, select most common response"""
|
||||
results = []
|
||||
for future in as_completed(futures, timeout=timeout_ms / 1000):
|
||||
try:
|
||||
result = future.result()
|
||||
results.append(result)
|
||||
except Exception as e:
|
||||
self.logger.error(f"Future failed: {e}")
|
||||
return results
|
||||
|
||||
def _select_result(self, results: List[ProviderResult], strategy: str) -> ProviderResult:
|
||||
"""Select best result based on strategy"""
|
||||
successful = [r for r in results if r.success]
|
||||
|
||||
if not successful:
|
||||
# Return first failed result with error info
|
||||
return results[0] if results else ProviderResult(
|
||||
provider_name="none",
|
||||
content="",
|
||||
duration_ms=0,
|
||||
success=False,
|
||||
error="All providers failed"
|
||||
)
|
||||
|
||||
if strategy == "race" or len(successful) == 1:
|
||||
return successful[0]
|
||||
|
||||
if strategy == "consensus":
|
||||
# Select by quality score
|
||||
return max(successful, key=lambda r: r.quality_score)
|
||||
|
||||
if strategy == "majority":
|
||||
# Group by similar content (simplified - use longest)
|
||||
return max(successful, key=lambda r: len(r.content))
|
||||
|
||||
return successful[0]
|
||||
|
||||
def fix_latex_parallel(
|
||||
self,
|
||||
latex_code: str,
|
||||
error_log: str,
|
||||
timeout_ms: int = 120000,
|
||||
**kwargs
|
||||
) -> ParallelResult:
|
||||
"""Try to fix LaTeX across multiple providers in parallel"""
|
||||
# Build fix prompt for each provider
|
||||
results = []
|
||||
start_time = datetime.now()
|
||||
|
||||
for name, provider in self.providers.items():
|
||||
if provider.is_available():
|
||||
try:
|
||||
start = datetime.now()
|
||||
fixed = provider.fix_latex(latex_code, error_log, **kwargs)
|
||||
duration_ms = int((datetime.now() - start).total_seconds() * 1000)
|
||||
|
||||
# Score by checking if error patterns are reduced
|
||||
quality = self._score_latex_fix(fixed, error_log)
|
||||
|
||||
results.append(ProviderResult(
|
||||
provider_name=name,
|
||||
content=fixed,
|
||||
duration_ms=duration_ms,
|
||||
success=True,
|
||||
quality_score=quality
|
||||
))
|
||||
except Exception as e:
|
||||
self.logger.error(f"{name} fix failed: {e}")
|
||||
|
||||
# Select best fix
|
||||
if results:
|
||||
selected = max(results, key=lambda r: r.quality_score)
|
||||
total_duration_ms = int((datetime.now() - start_time).total_seconds() * 1000)
|
||||
|
||||
return ParallelResult(
|
||||
content=selected.content,
|
||||
strategy="consensus",
|
||||
providers_used=[r.provider_name for r in results],
|
||||
total_duration_ms=total_duration_ms,
|
||||
all_results=results,
|
||||
selected_provider=selected.provider_name
|
||||
)
|
||||
|
||||
raise AIProcessingError("All providers failed to fix LaTeX")
|
||||
|
||||
def _score_latex_fix(self, fixed_latex: str, original_error: str) -> float:
|
||||
"""Score a LaTeX fix attempt"""
|
||||
score = 0.5 # Base score
|
||||
|
||||
# Check if common error patterns are addressed
|
||||
error_patterns = [
|
||||
("Undefined control sequence", r"\\[a-zA-Z]+"),
|
||||
("Missing $ inserted", r"\$.*\$"),
|
||||
("Runaway argument", r"\{.*\}"),
|
||||
]
|
||||
|
||||
for error_msg, pattern in error_patterns:
|
||||
if error_msg in original_error:
|
||||
# If error was in original, check if pattern appears better
|
||||
score += 0.1
|
||||
|
||||
# Validate bracket matching
|
||||
if fixed_latex.count("{") == fixed_latex.count("}"):
|
||||
score += 0.2
|
||||
|
||||
# Validate environment closure
|
||||
envs = ["document", "itemize", "enumerate"]
|
||||
for env in envs:
|
||||
begin_count = fixed_latex.count(f"\\begin{{{env}}}")
|
||||
end_count = fixed_latex.count(f"\\end{{{env}}}")
|
||||
if begin_count == end_count:
|
||||
score += 0.1
|
||||
|
||||
return min(score, 1.0)
|
||||
|
||||
def shutdown(self):
|
||||
"""Shutdown the executor"""
|
||||
self.executor.shutdown(wait=True)
|
||||
|
||||
def __del__(self):
|
||||
self.shutdown()
|
||||
343
services/ai/prompt_manager.py
Normal file
343
services/ai/prompt_manager.py
Normal file
@@ -0,0 +1,343 @@
|
||||
"""
|
||||
Prompt Manager - Centralized prompt management using resumen.md as source of truth
|
||||
"""
|
||||
|
||||
import re
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Optional, Dict, Any
|
||||
from config import settings
|
||||
|
||||
|
||||
class PromptManager:
|
||||
"""
|
||||
Manages prompts for AI services, loading templates from latex/resumen.md
|
||||
This is the SINGLE SOURCE OF TRUTH for academic summary generation.
|
||||
"""
|
||||
|
||||
_instance = None
|
||||
_prompt_cache: Optional[str] = None
|
||||
_latex_preamble_cache: Optional[str] = None
|
||||
|
||||
# Path to the prompt template file
|
||||
PROMPT_FILE_PATH = Path("latex/resumen.md")
|
||||
|
||||
def __new__(cls):
|
||||
if cls._instance is None:
|
||||
cls._instance = super(PromptManager, cls).__new__(cls)
|
||||
return cls._instance
|
||||
|
||||
def _load_prompt_template(self) -> str:
|
||||
"""Load the complete prompt template from resumen.md"""
|
||||
if self._prompt_cache:
|
||||
return self._prompt_cache
|
||||
|
||||
try:
|
||||
file_path = self.PROMPT_FILE_PATH.resolve()
|
||||
|
||||
if not file_path.exists():
|
||||
self._prompt_cache = self._get_fallback_prompt()
|
||||
return self._prompt_cache
|
||||
|
||||
content = file_path.read_text(encoding="utf-8")
|
||||
|
||||
# The file has a markdown code block after "## Prompt Template"
|
||||
# We need to find the content from "## Prompt Template" to the LAST ```
|
||||
# (because there's a ```latex...``` block INSIDE the template)
|
||||
|
||||
# First, find where "## Prompt Template" starts
|
||||
template_start = content.find("## Prompt Template")
|
||||
if template_start == -1:
|
||||
self._prompt_cache = self._get_fallback_prompt()
|
||||
return self._prompt_cache
|
||||
|
||||
# Find the opening ``` after the header
|
||||
after_header = content[template_start:]
|
||||
code_block_start = after_header.find("```")
|
||||
if code_block_start == -1:
|
||||
self._prompt_cache = self._get_fallback_prompt()
|
||||
return self._prompt_cache
|
||||
|
||||
# Skip the opening ``` and any language specifier
|
||||
after_code_start = after_header[code_block_start + 3:]
|
||||
first_newline = after_code_start.find("\n")
|
||||
if first_newline != -1:
|
||||
actual_content_start = template_start + code_block_start + 3 + first_newline + 1
|
||||
else:
|
||||
actual_content_start = template_start + code_block_start + 3
|
||||
|
||||
# Now find the LAST ``` that closes the main block
|
||||
# We look for ``` followed by optional space and then newline or end
|
||||
remaining = content[actual_content_start:]
|
||||
|
||||
# Find all positions of ``` in the remaining content
|
||||
positions = []
|
||||
pos = 0
|
||||
while True:
|
||||
found = remaining.find("```", pos)
|
||||
if found == -1:
|
||||
break
|
||||
positions.append(found)
|
||||
pos = found + 3
|
||||
|
||||
if not positions:
|
||||
self._prompt_cache = self._get_fallback_prompt()
|
||||
return self._prompt_cache
|
||||
|
||||
# The LAST ``` is the closing of the main block
|
||||
# (all previous ``` are the latex block inside the template)
|
||||
last_backtick_pos = positions[-1]
|
||||
|
||||
# Extract the content
|
||||
template_content = content[actual_content_start:actual_content_start + last_backtick_pos]
|
||||
|
||||
# Remove leading newline if present
|
||||
template_content = template_content.lstrip("\n")
|
||||
|
||||
self._prompt_cache = template_content
|
||||
return self._prompt_cache
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error loading prompt file: {e}")
|
||||
self._prompt_cache = self._get_fallback_prompt()
|
||||
return self._prompt_cache
|
||||
|
||||
def _get_fallback_prompt(self) -> str:
|
||||
"""Fallback prompt if resumen.md is not found"""
|
||||
return """Sos un asistente académico experto. Creá un resumen extenso en LaTeX basado en la transcripción de clase.
|
||||
|
||||
## Transcripción de clase:
|
||||
[PEGAR TRANSCRIPCIÓN AQUÍ]
|
||||
|
||||
## Material bibliográfico:
|
||||
[PEGAR TEXTO DEL LIBRO/APUNTE O INDICAR QUE LO SUBISTE COMO ARCHIVO]
|
||||
|
||||
Generá un archivo LaTeX completo con:
|
||||
- Estructura académica formal
|
||||
- Mínimo 10 páginas de contenido
|
||||
- Fórmulas matemáticas en LaTeX
|
||||
- Tablas y diagramas cuando corresponda
|
||||
"""
|
||||
|
||||
def _load_latex_preamble(self) -> str:
    """Return the LaTeX preamble, extracting it from resumen.md on first use.

    The extracted preamble is memoized in ``self._latex_preamble_cache``.
    If the template file is missing, unreadable, or contains no fenced
    ``latex`` block, the built-in default preamble is returned instead.
    """
    if self._latex_preamble_cache:
        return self._latex_preamble_cache

    try:
        template_path = self.PROMPT_FILE_PATH.resolve()

        if not template_path.exists():
            # Missing file: fall back WITHOUT caching, so a later call can
            # still pick the template up once the file appears.
            return self._get_default_preamble()

        raw_template = template_path.read_text(encoding="utf-8")

        # The preamble lives inside the first ```latex ... ``` fence.
        fence = re.search(r"```latex\s*\n([\s\S]*?)\n```", raw_template)

        if fence:
            self._latex_preamble_cache = fence.group(1).strip()
        else:
            self._latex_preamble_cache = self._get_default_preamble()

        return self._latex_preamble_cache

    except Exception as e:
        print(f"Error loading LaTeX preamble: {e}")
        return self._get_default_preamble()
|
||||
|
||||
def _get_default_preamble(self) -> str:
    """Return the hard-coded default LaTeX preamble.

    Used whenever resumen.md does not provide one.  Contains the
    ``[MATERIA]`` and ``[N]`` placeholders that get_latex_preamble()
    substitutes later.
    """
    default_preamble = r"""\documentclass[11pt,a4paper]{article}
\usepackage[utf8]{inputenc}
\usepackage[spanish,provide=*]{babel}
\usepackage{amsmath,amssymb}
\usepackage{geometry}
\usepackage{graphicx}
\usepackage{tikz}
\usetikzlibrary{arrows.meta,positioning,shapes.geometric,calc}
\usepackage{booktabs}
\usepackage{enumitem}
\usepackage{fancyhdr}
\usepackage{titlesec}
\usepackage{tcolorbox}
\usepackage{array}
\usepackage{multirow}

\geometry{margin=2.5cm}
\pagestyle{fancy}
\fancyhf{}
\fancyhead[L]{[MATERIA] - CBC}
\fancyhead[R]{Clase [N]}
\fancyfoot[C]{\thepage}

% Cajas para destacar contenido
\newtcolorbox{definicion}[1][]{
colback=blue!5!white,
colframe=blue!75!black,
fonttitle=\bfseries,
title=#1
}

\newtcolorbox{importante}[1][]{
colback=red!5!white,
colframe=red!75!black,
fonttitle=\bfseries,
title=#1
}

\newtcolorbox{ejemplo}[1][]{
colback=green!5!white,
colframe=green!50!black,
fonttitle=\bfseries,
title=#1
}
"""
    return default_preamble
|
||||
|
||||
def get_latex_summary_prompt(
    self,
    transcription: str,
    materia: str = "Economía",
    bibliographic_text: Optional[str] = None,
    class_number: Optional[int] = None
) -> str:
    """
    Build the complete AI prompt for a LaTeX academic summary.

    Starts from the resumen.md template (or its fallback), prefixes
    strict "LaTeX only" output instructions, then substitutes the
    subject, transcription, bibliography and class-number placeholders.

    Args:
        transcription: The class transcription text
        materia: Subject name (default: "Economía")
        bibliographic_text: Optional supporting text from books/notes
        class_number: Optional class number for header

    Returns:
        Complete prompt string ready to send to AI
    """
    # Instructions forcing direct LaTeX output, prepended without
    # modifying resumen.md itself.
    output_contract = """CRITICAL: Tu respuesta debe ser ÚNICAMENTE código LaTeX.

INSTRUCCIONES OBLIGATORIAS:
1. NO incluyas explicaciones previas
2. NO describas lo que vas a hacer
3. Comienza INMEDIATAMENTE con \\documentclass
4. Tu respuesta debe ser SOLO el código LaTeX fuente
5. Termina con \\end{document}

---

"""
    full_prompt = output_contract + self._load_prompt_template()

    # Substitute the subject FIRST so any "[MATERIA]" text occurring
    # inside the transcription/bibliography is left untouched, exactly
    # as before.
    full_prompt = full_prompt.replace("[MATERIA]", materia)

    transcript_slot = "[PEGAR TRANSCRIPCIÓN AQUÍ]"
    if transcript_slot in full_prompt:
        full_prompt = full_prompt.replace(transcript_slot, transcription)
    else:
        full_prompt += f"\n\n## Transcripción de clase:\n{transcription}"

    material = bibliographic_text or "No se proporcionó material bibliográfico adicional."
    material_slot = "[PEGAR TEXTO DEL LIBRO/APUNTE O INDICAR QUE LO SUBISTE COMO ARCHIVO]"
    if material_slot in full_prompt:
        full_prompt = full_prompt.replace(material_slot, material)
    else:
        full_prompt += f"\n\n## Material bibliográfico:\n{material}"

    if class_number is not None:
        full_prompt = full_prompt.replace("[N]", str(class_number))

    return full_prompt
|
||||
|
||||
def get_latex_preamble(
    self,
    materia: str = "Economía",
    class_number: Optional[int] = None
) -> str:
    """
    Get the LaTeX preamble with placeholders replaced.

    Args:
        materia: Subject name
        class_number: Optional class number (the "[N]" placeholder is
            left as-is when this is None)

    Returns:
        Complete LaTeX preamble as string
    """
    rendered = self._load_latex_preamble().replace("[MATERIA]", materia)

    if class_number is not None:
        rendered = rendered.replace("[N]", str(class_number))

    return rendered
|
||||
|
||||
def get_latex_fix_prompt(self, latex_code: str, error_log: str) -> str:
    """Build the repair prompt sent to an AI provider after a failed compile.

    Only the last 3000 characters of the compiler log are included,
    since the relevant error is almost always at the end of the log.
    """
    log_tail = error_log[-3000:]
    return (
        "I have a LaTeX file that failed to compile. Please fix the code.\n"
        "\n"
        "COMPILER ERROR LOG:\n"
        f"{log_tail}\n"
        "\n"
        "BROKEN LATEX CODE:\n"
        f"{latex_code}\n"
        "\n"
        "INSTRUCTIONS:\n"
        "1. Analyze the error log to find the specific syntax error.\n"
        "2. Fix the LaTeX code.\n"
        "3. Return ONLY the full corrected LaTeX code.\n"
        "4. Do not include markdown blocks or explanations.\n"
        "5. Start immediately with \\documentclass.\n"
        "6. Ensure all braces {} are properly balanced.\n"
        "7. Ensure all environments \\begin{...} have matching \\end{...}.\n"
        "8. Ensure all packages are properly declared.\n"
    )
|
||||
|
||||
def extract_latex_from_response(self, response: str) -> Optional[str]:
    """
    Extract clean LaTeX code from an AI response.

    Handles responses that wrap LaTeX in ```latex ... ``` fences.  When
    the response contains several fenced blocks (e.g. an explanation
    block followed by the actual document), the first block containing
    \\documentclass wins; previously only the first block was inspected,
    so valid LaTeX in a later block was silently discarded.

    Returns:
        The LaTeX source trimmed to \\documentclass ... \\end{document},
        or None when no LaTeX document can be found.
    """
    if not response:
        return None

    code_block_pattern = r"```(?:latex|tex)?\s*([\s\S]*?)\s*```"

    # Scan every fenced block; prefer the first one that holds a real
    # document, otherwise keep the first block (original behavior).
    latex = None
    for block_match in re.finditer(code_block_pattern, response, re.IGNORECASE):
        candidate = block_match.group(1).strip()
        if "\\documentclass" in candidate:
            latex = candidate
            break
        if latex is None:
            latex = candidate

    if latex is None:
        # No fenced block at all: treat the whole response as LaTeX.
        latex = response.strip()

    # Verify it looks like LaTeX
    if "\\documentclass" not in latex:
        return None

    # Clean up: remove anything before \documentclass
    latex = latex[latex.find("\\documentclass"):]

    # Clean up: remove anything after the final \end{document}
    if "\\end{document}" in latex:
        end_idx = latex.rfind("\\end{document}")
        latex = latex[:end_idx + len("\\end{document}")]

    return latex.strip()
|
||||
|
||||
|
||||
# Module-level singleton: importing code should use this shared instance
# rather than constructing its own PromptManager.
prompt_manager = PromptManager()
|
||||
@@ -3,16 +3,17 @@ AI Provider Factory (Factory Pattern)
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Dict, Type
|
||||
from typing import Dict, Type, Optional
|
||||
|
||||
from core import AIProcessingError
|
||||
from .base_provider import AIProvider
|
||||
from .claude_provider import ClaudeProvider
|
||||
from .gemini_provider import GeminiProvider
|
||||
from .parallel_provider import ParallelAIProvider
|
||||
|
||||
|
||||
class AIProviderFactory:
|
||||
"""Factory for creating AI providers with fallback"""
|
||||
"""Factory for creating AI providers with fallback and parallel execution"""
|
||||
|
||||
def __init__(self):
|
||||
self.logger = logging.getLogger(__name__)
|
||||
@@ -20,6 +21,7 @@ class AIProviderFactory:
|
||||
"claude": ClaudeProvider(),
|
||||
"gemini": GeminiProvider(),
|
||||
}
|
||||
self._parallel_provider: Optional[ParallelAIProvider] = None
|
||||
|
||||
def get_provider(self, preferred: str = "gemini") -> AIProvider:
|
||||
"""Get available provider with fallback"""
|
||||
@@ -50,6 +52,29 @@ class AIProviderFactory:
|
||||
"""Get the best available provider (Claude > Gemini)"""
|
||||
return self.get_provider("claude")
|
||||
|
||||
def get_parallel_provider(self, max_workers: int = 4) -> ParallelAIProvider:
    """Return a memoized provider that races all available AI providers.

    Args:
        max_workers: Thread-pool size for the parallel provider.  Only
            honored on the first call; subsequent calls return the cached
            instance and ignore this argument.

    Returns:
        The shared ParallelAIProvider instance.

    Raises:
        AIProcessingError: If no underlying provider is available.
    """
    available = self.get_all_available()

    if not available:
        raise AIProcessingError("No providers available for parallel execution")

    if self._parallel_provider is None:
        self._parallel_provider = ParallelAIProvider(
            providers=available,
            max_workers=max_workers
        )
        # Fixed: the old message reported the provider count as "workers",
        # conflating it with the unrelated max_workers pool size.
        self.logger.info(
            f"Created parallel provider with {len(available)} providers: "
            f"{', '.join(available.keys())}"
        )

    return self._parallel_provider
|
||||
|
||||
def use_parallel(self) -> bool:
    """Return True when more than one AI provider is configured and usable,
    i.e. when parallel (racing) execution is worthwhile."""
    available_count = len(self.get_all_available())
    return available_count > 1
|
||||
|
||||
|
||||
# Process-wide factory instance; import this rather than instantiating
# AIProviderFactory anew, so the parallel-provider cache is shared.
ai_provider_factory = AIProviderFactory()
|
||||
|
||||
@@ -7,8 +7,9 @@ import time
|
||||
import unicodedata
|
||||
import re
|
||||
from pathlib import Path
|
||||
from typing import Optional, List, Dict
|
||||
from typing import Optional, List, Dict, Tuple
|
||||
from contextlib import contextmanager
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
import requests
|
||||
from requests.auth import HTTPBasicAuth
|
||||
from requests.adapters import HTTPAdapter
|
||||
@@ -107,7 +108,7 @@ class WebDAVService:
|
||||
return self._parse_propfind_response(response.text)
|
||||
|
||||
def _parse_propfind_response(self, xml_response: str) -> List[str]:
|
||||
"""Parse PROPFIND XML response"""
|
||||
"""Parse PROPFIND XML response and return only files (not directories)"""
|
||||
# Simple parser for PROPFIND response
|
||||
files = []
|
||||
try:
|
||||
@@ -119,20 +120,41 @@ class WebDAVService:
|
||||
parsed_url = urlparse(settings.NEXTCLOUD_URL)
|
||||
webdav_path = parsed_url.path.rstrip('/') # e.g. /remote.php/webdav
|
||||
|
||||
# Find all href elements
|
||||
for href in root.findall('.//{DAV:}href'):
|
||||
href_text = href.text or ""
|
||||
href_text = unquote(href_text) # Decode URL encoding
|
||||
|
||||
# Find all response elements
|
||||
for response in root.findall('.//{DAV:}response'):
|
||||
href = response.find('.//{DAV:}href')
|
||||
if href is None or href.text is None:
|
||||
continue
|
||||
|
||||
href_text = unquote(href.text) # Decode URL encoding
|
||||
|
||||
# Check if this is a directory (has collection resourcetype)
|
||||
propstat = response.find('.//{DAV:}propstat')
|
||||
is_directory = False
|
||||
if propstat is not None:
|
||||
prop = propstat.find('.//{DAV:}prop')
|
||||
if prop is not None:
|
||||
resourcetype = prop.find('.//{DAV:}resourcetype')
|
||||
if resourcetype is not None and resourcetype.find('.//{DAV:}collection') is not None:
|
||||
is_directory = True
|
||||
|
||||
# Skip directories
|
||||
if is_directory:
|
||||
continue
|
||||
|
||||
# Also skip paths ending with / (another way to detect directories)
|
||||
if href_text.endswith('/'):
|
||||
continue
|
||||
|
||||
# Remove base URL from href
|
||||
base_url = settings.NEXTCLOUD_URL.rstrip('/')
|
||||
if href_text.startswith(base_url):
|
||||
href_text = href_text[len(base_url):]
|
||||
|
||||
|
||||
# Also strip the webdav path if it's there
|
||||
if href_text.startswith(webdav_path):
|
||||
href_text = href_text[len(webdav_path):]
|
||||
|
||||
|
||||
# Clean up the path
|
||||
href_text = href_text.lstrip('/')
|
||||
if href_text: # Skip empty paths (root directory)
|
||||
@@ -210,6 +232,59 @@ class WebDAVService:
|
||||
except WebDAVError:
|
||||
return False
|
||||
|
||||
def upload_batch(
    self,
    files: List[Tuple[Path, str]],
    max_workers: int = 4,
    timeout: int = 120
) -> Dict[str, bool]:
    """
    Upload multiple files concurrently.

    Args:
        files: List of (local_path, remote_path) tuples.  Remote paths
            should be unique; duplicates overwrite each other's status.
        max_workers: Maximum concurrent uploads
        timeout: Overall deadline in seconds for the WHOLE batch
            (as_completed applies it to the full iteration, not to each
            individual upload)

    Returns:
        Dict mapping remote_path to success status.  Uploads still
        pending when the deadline expires are reported as failures
        instead of raising, so partial results are never lost.
    """
    # Local import keeps compatibility with Python < 3.11, where
    # concurrent.futures.TimeoutError is not the builtin TimeoutError.
    from concurrent.futures import TimeoutError as _FuturesTimeout

    if not files:
        return {}

    results: Dict[str, bool] = {}

    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        # Submit all upload tasks
        future_to_path = {
            executor.submit(self.upload, local, remote): remote
            for local, remote in files
        }

        try:
            # Collect results as they complete
            for future in as_completed(future_to_path, timeout=timeout):
                remote_path = future_to_path[future]
                try:
                    future.result()
                    results[remote_path] = True
                    self.logger.info(f"Successfully uploaded: {remote_path}")
                except Exception as e:
                    results[remote_path] = False
                    self.logger.error(f"Failed to upload {remote_path}: {e}")
        except _FuturesTimeout:
            # Bug fix: previously the batch deadline escaped as an
            # exception, discarding all partial results despite the
            # documented Dict return.  Record pending uploads as failed.
            for future, remote_path in future_to_path.items():
                if remote_path not in results:
                    future.cancel()
                    results[remote_path] = False
                    self.logger.error(f"Upload timed out: {remote_path}")

    failed_count = sum(1 for success in results.values() if not success)
    if failed_count > 0:
        self.logger.warning(
            f"Batch upload completed with {failed_count} failures "
            f"({len(results) - failed_count}/{len(results)} successful)"
        )
    else:
        self.logger.info(
            f"Batch upload completed: {len(results)} files uploaded successfully"
        )

    return results
|
||||
|
||||
|
||||
# Process-wide service instance; import this rather than constructing a
# new WebDAVService per caller.
webdav_service = WebDAVService()
|
||||
|
||||
Reference in New Issue
Block a user