feat: Sistema CBCFacil completo con cola secuencial
- Implementa ProcessingMonitor singleton para procesamiento secuencial de archivos - Agrega AI summary service con soporte para MiniMax API - Agrega PDF generator para resúmenes - Agrega watchers para monitoreo de carpeta remota - Mejora sistema de notificaciones Telegram - Implementa gestión de VRAM para GPU - Configuración mediante variables de entorno (sin hardcoded secrets) - .env y transcriptions/ agregados a .gitignore Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -1,17 +1,4 @@
|
||||
"""
|
||||
Services package for CBCFacil
|
||||
"""
|
||||
from .webdav_service import WebDAVService, webdav_service
|
||||
from .vram_manager import VRAMManager, vram_manager
|
||||
from .telegram_service import TelegramService, telegram_service
|
||||
from .gpu_detector import GPUDetector, GPUType, gpu_detector
|
||||
from .ai import ai_service
|
||||
|
||||
__all__ = [
|
||||
'WebDAVService', 'webdav_service',
|
||||
'VRAMManager', 'vram_manager',
|
||||
'TelegramService', 'telegram_service',
|
||||
'GPUDetector', 'GPUType', 'gpu_detector',
|
||||
'ai_service'
|
||||
]
|
||||
"""Export de servicios."""
|
||||
from .webdav_service import WebDAVService
|
||||
|
||||
__all__ = ["WebDAVService"]
|
||||
|
||||
@@ -1,20 +0,0 @@
|
||||
"""
|
||||
AI Providers package for CBCFacil
|
||||
"""
|
||||
|
||||
from .base_provider import AIProvider
|
||||
from .claude_provider import ClaudeProvider
|
||||
from .gemini_provider import GeminiProvider
|
||||
from .provider_factory import AIProviderFactory, ai_provider_factory
|
||||
|
||||
# Alias for backwards compatibility
|
||||
ai_service = ai_provider_factory
|
||||
|
||||
__all__ = [
|
||||
'AIProvider',
|
||||
'ClaudeProvider',
|
||||
'GeminiProvider',
|
||||
'AIProviderFactory',
|
||||
'ai_provider_factory',
|
||||
'ai_service'
|
||||
]
|
||||
@@ -1,45 +0,0 @@
|
||||
"""
|
||||
Base AI Provider interface (Strategy pattern)
|
||||
"""
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Optional, Dict, Any
|
||||
|
||||
|
||||
class AIProvider(ABC):
    """Strategy-pattern interface that every AI backend must implement.

    Concrete providers (CLI- or API-backed) supply the text operations
    below; callers select a provider at runtime without caring which
    backend is behind it.
    """

    @property
    @abstractmethod
    def name(self) -> str:
        """Human-readable provider name."""
        ...

    @abstractmethod
    def is_available(self) -> bool:
        """Return True when the provider is installed and configured."""
        ...

    @abstractmethod
    def summarize(self, text: str, **kwargs) -> str:
        """Return a summary of *text*."""
        ...

    @abstractmethod
    def correct_text(self, text: str, **kwargs) -> str:
        """Return *text* with grammar and spelling corrected."""
        ...

    @abstractmethod
    def classify_content(self, text: str, **kwargs) -> Dict[str, Any]:
        """Classify *text* into a category; return a result dict."""
        ...

    @abstractmethod
    def generate_text(self, prompt: str, **kwargs) -> str:
        """Return free-form text generated from *prompt*."""
        ...

    @abstractmethod
    def fix_latex(self, latex_code: str, error_log: str, **kwargs) -> str:
        """Return repaired LaTeX given broken code and the compiler log."""
        ...
|
||||
@@ -1,158 +0,0 @@
|
||||
"""
|
||||
Claude AI Provider implementation
|
||||
"""
|
||||
|
||||
import logging
|
||||
import subprocess
|
||||
import shutil
|
||||
from typing import Dict, Any, Optional
|
||||
|
||||
from config import settings
|
||||
from core import AIProcessingError
|
||||
from .base_provider import AIProvider
|
||||
|
||||
|
||||
class ClaudeProvider(AIProvider):
    """Claude AI provider that shells out to the Claude CLI.

    Requires a CLI binary (settings.CLAUDE_CLI_PATH, or `claude` found on
    PATH) and an auth token (settings.ZAI_AUTH_TOKEN).
    """

    def __init__(self):
        self.logger = logging.getLogger(__name__)
        # Explicit setting wins; otherwise fall back to a PATH lookup.
        self._cli_path = settings.CLAUDE_CLI_PATH or shutil.which("claude")
        self._token = settings.ZAI_AUTH_TOKEN
        self._base_url = settings.ZAI_BASE_URL

    @property
    def name(self) -> str:
        """Provider name."""
        return "Claude"

    def is_available(self) -> bool:
        """Return True when both the CLI binary and the auth token are set."""
        return bool(self._cli_path and self._token)

    def _get_env(self) -> Dict[str, str]:
        """Build the environment for the CLI subprocess.

        Starts from a copy of the full parent environment (so model
        variables loaded via .env / load_dotenv are inherited) and overlays
        the Anthropic auth token / base URL when configured.
        """
        import os

        env = os.environ.copy()

        if self._token:
            env["ANTHROPIC_AUTH_TOKEN"] = self._token
        if self._base_url:
            env["ANTHROPIC_BASE_URL"] = self._base_url

        # Avoid stdout buffering in the child so output arrives promptly.
        env["PYTHONUNBUFFERED"] = "1"

        return env

    def _run_cli(self, prompt: str, timeout: int = 600) -> str:
        """Run the Claude CLI, feeding *prompt* on stdin.

        Args:
            prompt: Text sent to the CLI via stdin.
            timeout: Maximum seconds to wait for the subprocess.

        Returns:
            The CLI's stdout, stripped.

        Raises:
            AIProcessingError: if the CLI is unavailable, exits non-zero,
                times out, or fails for any other reason.
        """
        if not self.is_available():
            raise AIProcessingError("Claude CLI not available or not configured")

        try:
            # "-p -" reads the prompt from stdin;
            # --dangerously-skip-permissions is needed for unattended runs.
            cmd = [self._cli_path, "--dangerously-skip-permissions", "-p", "-"]
            process = subprocess.run(
                cmd,
                input=prompt,
                env=self._get_env(),
                text=True,
                capture_output=True,
                timeout=timeout,
                shell=False,
            )

            if process.returncode != 0:
                error_msg = process.stderr or "Unknown error"
                raise AIProcessingError(f"Claude CLI failed: {error_msg}")

            return process.stdout.strip()
        except AIProcessingError:
            # Bug fix: previously the broad `except Exception` below re-wrapped
            # our own "Claude CLI failed: ..." error as "Claude CLI error: ...".
            raise
        except subprocess.TimeoutExpired:
            raise AIProcessingError(f"Claude CLI timed out after {timeout}s")
        except Exception as e:
            raise AIProcessingError(f"Claude CLI error: {e}") from e

    def summarize(self, text: str, **kwargs) -> str:
        """Generate summary using Claude."""
        prompt = f"""Summarize the following text:

{text}

Provide a clear, concise summary in Spanish."""
        return self._run_cli(prompt)

    def correct_text(self, text: str, **kwargs) -> str:
        """Correct text using Claude."""
        prompt = f"""Correct the following text for grammar, spelling, and clarity:

{text}

Return only the corrected text, nothing else."""
        return self._run_cli(prompt)

    def classify_content(self, text: str, **kwargs) -> Dict[str, Any]:
        """Classify content into one of the fixed course categories.

        Returns:
            Dict with "category" (validated against the known set, falling
            back to "otras_clases"), a fixed "confidence", and "provider".
        """
        categories = [
            "historia",
            "analisis_contable",
            "instituciones_gobierno",
            "otras_clases",
        ]

        prompt = f"""Classify the following text into one of these categories:
- historia
- analisis_contable
- instituciones_gobierno
- otras_clases

Text: {text}

Return only the category name, nothing else."""
        result = self._run_cli(prompt).lower()

        # Guard against the model answering outside the allowed set.
        if result not in categories:
            result = "otras_clases"

        return {"category": result, "confidence": 0.9, "provider": self.name}

    def generate_text(self, prompt: str, **kwargs) -> str:
        """Generate text using Claude."""
        return self._run_cli(prompt)

    def fix_latex(self, latex_code: str, error_log: str, **kwargs) -> str:
        """Fix broken LaTeX code using Claude.

        Only the tail of the error log is sent to keep the prompt small.
        Uses a shorter timeout since fixes should be quick.
        """
        prompt = f"""I have a LaTeX file that failed to compile. Please fix the code.

COMPILER ERROR LOG:
{error_log[-3000:]}

BROKEN LATEX CODE:
{latex_code}

INSTRUCTIONS:
1. Analyze the error log to find the specific syntax error.
2. Fix the LaTeX code.
3. Return ONLY the full corrected LaTeX code.
4. Do not include markdown blocks or explanations.
5. Start immediately with \\documentclass.

COMMON LATEX ERRORS TO CHECK:
- TikZ nodes with line breaks (\\\\) MUST have "align=center" in their style.
  WRONG: \\node[box] (n) {{Text\\\\More}};
  CORRECT: \\node[box, align=center] (n) {{Text\\\\More}};
- All \\begin{{env}} must have matching \\end{{env}}
- All braces {{ }} must be balanced
- Math mode $ must be paired
- Special characters need escaping: % & # _
- tcolorbox environments need proper titles: [Title] not {{Title}}
"""
        return self._run_cli(prompt, timeout=180)
|
||||
@@ -1,337 +0,0 @@
|
||||
"""
|
||||
Gemini AI Provider - Optimized version with rate limiting and retry
|
||||
"""
|
||||
|
||||
import logging
|
||||
import subprocess
|
||||
import shutil
|
||||
import requests
|
||||
import time
|
||||
from typing import Dict, Any, Optional
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
from config import settings
|
||||
from core import AIProcessingError
|
||||
from .base_provider import AIProvider
|
||||
|
||||
|
||||
class TokenBucket:
    """Thread-safe token-bucket rate limiter.

    acquire() never blocks; it returns the number of seconds the caller
    should sleep before proceeding.
    """

    def __init__(self, rate: float = 10, capacity: int = 20):
        """
        Args:
            rate: Tokens replenished per second.
            capacity: Maximum number of tokens the bucket can hold.
        """
        self.rate = rate  # tokens per second
        self.capacity = capacity
        self.tokens = capacity
        self.last_update = time.time()
        # Bug fix: create the lock eagerly. The previous lazy _get_lock()
        # was itself racy — two threads calling it concurrently could each
        # construct (and use) a different lock.
        import threading  # local import kept to match the module's lazy style

        self._lock = threading.Lock()

    def acquire(self, tokens: int = 1) -> float:
        """Take *tokens* from the bucket.

        Returns:
            0.0 when the tokens were available immediately; otherwise the
            seconds the caller should wait (the bucket is drained to 0).
        """
        with self._lock:
            now = time.time()
            elapsed = now - self.last_update
            self.last_update = now
            # Refill proportionally to elapsed time, capped at capacity.
            self.tokens = min(self.capacity, self.tokens + elapsed * self.rate)

            if self.tokens >= tokens:
                self.tokens -= tokens
                return 0.0

            wait_time = (tokens - self.tokens) / self.rate
            self.tokens = 0
            return wait_time
|
||||
|
||||
|
||||
class CircuitBreaker:
    """Circuit breaker for API calls.

    States: "closed" (normal operation), "open" (calls rejected),
    "half-open" (one probe call allowed after recovery_timeout).
    """

    def __init__(self, failure_threshold: int = 5, recovery_timeout: int = 60):
        """
        Args:
            failure_threshold: Consecutive failures before the circuit opens.
            recovery_timeout: Seconds the circuit stays open before probing.
        """
        self.failure_threshold = failure_threshold
        self.recovery_timeout = recovery_timeout
        self.failures = 0
        self.last_failure: Optional[datetime] = None
        self.state = "closed"  # closed, open, half-open
        # Bug fix: build the lock eagerly; the previous lazy _get_lock()
        # could race and hand different threads different locks.
        import threading  # local import kept to match the module's lazy style

        self._lock = threading.Lock()

    def call(self, func, *args, **kwargs):
        """Invoke *func* through the breaker.

        Raises:
            AIProcessingError: when the circuit is open.
            Exception: whatever *func* raises (after failure bookkeeping).

        NOTE(review): the lock is held for the full duration of func(),
        which serializes all calls through this breaker. Preserved as-is
        since the state updates depend on it — confirm before relaxing.
        """
        with self._lock:
            if self.state == "open":
                if (
                    self.last_failure
                    and (datetime.utcnow() - self.last_failure).total_seconds()
                    > self.recovery_timeout
                ):
                    self.state = "half-open"
                else:
                    raise AIProcessingError("Circuit breaker is open")

            try:
                result = func(*args, **kwargs)
            except Exception:
                self.failures += 1
                self.last_failure = datetime.utcnow()
                # Bug fix: a failed half-open probe must reopen the circuit
                # immediately, not wait to reach the threshold again.
                if self.state == "half-open" or self.failures >= self.failure_threshold:
                    self.state = "open"
                raise
            else:
                if self.state == "half-open":
                    self.state = "closed"
                    self.failures = 0
                return result
|
||||
|
||||
|
||||
class GeminiProvider(AIProvider):
    """Gemini AI provider with rate limiting and retry.

    Prefers the local CLI when present and falls back to the REST API
    (see _run). API calls are throttled by a TokenBucket, guarded by a
    CircuitBreaker, and retried with exponential backoff.
    """

    def __init__(self):
        super().__init__()
        self.logger = logging.getLogger(__name__)
        # Explicit setting wins; otherwise look the binary up on PATH.
        self._cli_path = settings.GEMINI_CLI_PATH or shutil.which("gemini")
        self._api_key = settings.GEMINI_API_KEY
        self._flash_model = settings.GEMINI_FLASH_MODEL
        self._pro_model = settings.GEMINI_PRO_MODEL
        # requests.Session, created lazily by _init_session().
        self._session = None
        self._rate_limiter = TokenBucket(rate=15, capacity=30)
        self._circuit_breaker = CircuitBreaker(failure_threshold=5, recovery_timeout=60)
        # Exponential-backoff parameters consumed by _run_with_retry.
        self._retry_config = {
            "max_attempts": 3,
            "base_delay": 1.0,
            "max_delay": 30.0,
            "exponential_base": 2,
        }

    @property
    def name(self) -> str:
        """Provider name."""
        return "Gemini"

    def is_available(self) -> bool:
        """Check if Gemini CLI or API is available."""
        return bool(self._cli_path or self._api_key)

    def _init_session(self) -> None:
        """Initialize the HTTP session with connection pooling (idempotent)."""
        if self._session is None:
            self._session = requests.Session()
            adapter = requests.adapters.HTTPAdapter(
                pool_connections=10,
                pool_maxsize=20,
                max_retries=0,  # We handle retries manually
            )
            self._session.mount("https://", adapter)

    def _run_with_retry(self, func, *args, **kwargs):
        """Execute *func* with exponential backoff retry.

        Every attempt goes through the circuit breaker. NOTE(review): only
        requests.exceptions.RequestException is retried; any other error
        (including a breaker-open AIProcessingError) propagates immediately.

        Raises:
            AIProcessingError: when all attempts are exhausted.
        """
        max_attempts = self._retry_config["max_attempts"]
        base_delay = self._retry_config["base_delay"]

        last_exception = None

        for attempt in range(max_attempts):
            try:
                return self._circuit_breaker.call(func, *args, **kwargs)
            except requests.exceptions.RequestException as e:
                last_exception = e
                if attempt < max_attempts - 1:
                    # Exponential backoff capped at max_delay.
                    delay = min(
                        base_delay * (2**attempt), self._retry_config["max_delay"]
                    )
                    # Add jitter (up to +10%) so concurrent clients don't retry in lockstep
                    delay += delay * 0.1 * (time.time() % 1)
                    self.logger.warning(
                        f"Attempt {attempt + 1} failed: {e}, retrying in {delay:.2f}s"
                    )
                    time.sleep(delay)

        raise AIProcessingError(f"Max retries exceeded: {last_exception}")

    def _run_cli(self, prompt: str, use_flash: bool = True, timeout: int = 300) -> str:
        """Run the Gemini CLI with *prompt* as a positional argument.

        Args:
            use_flash: Select the flash model when True, else the pro model.

        Raises:
            AIProcessingError: on missing CLI, non-zero exit, timeout,
                or any other subprocess failure.
        """
        if not self._cli_path:
            raise AIProcessingError("Gemini CLI not available")

        model = self._flash_model if use_flash else self._pro_model
        cmd = [self._cli_path, model, prompt]

        try:
            # Apply rate limiting before spawning the subprocess.
            wait_time = self._rate_limiter.acquire()
            if wait_time > 0:
                time.sleep(wait_time)

            process = subprocess.run(
                cmd, text=True, capture_output=True, timeout=timeout, shell=False
            )

            if process.returncode != 0:
                error_msg = process.stderr or "Unknown error"
                raise AIProcessingError(f"Gemini CLI failed: {error_msg}")

            return process.stdout.strip()
        except subprocess.TimeoutExpired:
            raise AIProcessingError(f"Gemini CLI timed out after {timeout}s")
        except Exception as e:
            # NOTE(review): this also re-wraps the AIProcessingError raised
            # above, changing its message — confirm whether that is intended.
            raise AIProcessingError(f"Gemini CLI error: {e}")

    def _call_api(self, prompt: str, use_flash: bool = True, timeout: int = 180) -> str:
        """Call the Gemini REST API with rate limiting and retry.

        Raises:
            AIProcessingError: on missing key, empty/malformed response,
                or retry exhaustion (via _run_with_retry).
        """
        if not self._api_key:
            raise AIProcessingError("Gemini API key not configured")

        self._init_session()

        model = self._flash_model if use_flash else self._pro_model
        url = f"https://generativelanguage.googleapis.com/v1beta/models/{model}:generateContent"

        payload = {"contents": [{"parts": [{"text": prompt}]}]}

        params = {"key": self._api_key}

        def api_call():
            # Apply rate limiting per attempt.
            wait_time = self._rate_limiter.acquire()
            if wait_time > 0:
                time.sleep(wait_time)

            response = self._session.post(
                url, json=payload, params=params, timeout=timeout
            )
            response.raise_for_status()
            return response

        response = self._run_with_retry(api_call)
        data = response.json()

        # Validate the expected generateContent response shape.
        if "candidates" not in data or not data["candidates"]:
            raise AIProcessingError("Empty response from Gemini API")

        candidate = data["candidates"][0]
        if "content" not in candidate or "parts" not in candidate["content"]:
            raise AIProcessingError("Invalid response format from Gemini API")

        result = candidate["content"]["parts"][0]["text"]
        return result.strip()

    def _run(self, prompt: str, use_flash: bool = True, timeout: int = 300) -> str:
        """Run Gemini with fallback: CLI first, then REST API.

        Raises:
            AIProcessingError: when neither transport is available.
        """
        # Try CLI first if available.
        if self._cli_path:
            try:
                return self._run_cli(prompt, use_flash, timeout)
            except Exception as e:
                self.logger.warning(f"Gemini CLI failed, trying API: {e}")

        # Fallback to API (its timeout is capped lower than the CLI's).
        if self._api_key:
            api_timeout = min(timeout, 180)
            return self._call_api(prompt, use_flash, api_timeout)

        raise AIProcessingError("No Gemini provider available (CLI or API)")

    def summarize(self, text: str, **kwargs) -> str:
        """Generate summary using Gemini (flash model)."""
        prompt = f"""Summarize the following text:

{text}

Provide a clear, concise summary in Spanish."""
        return self._run(prompt, use_flash=True)

    def correct_text(self, text: str, **kwargs) -> str:
        """Correct text using Gemini (flash model)."""
        prompt = f"""Correct the following text for grammar, spelling, and clarity:

{text}

Return only the corrected text, nothing else."""
        return self._run(prompt, use_flash=True)

    def classify_content(self, text: str, **kwargs) -> Dict[str, Any]:
        """Classify content into one of the fixed course categories.

        Returns:
            Dict with "category" (validated against the known set, falling
            back to "otras_clases"), a fixed "confidence", and "provider".
        """
        categories = [
            "historia",
            "analisis_contable",
            "instituciones_gobierno",
            "otras_clases",
        ]

        prompt = f"""Classify the following text into one of these categories:
- historia
- analisis_contable
- instituciones_gobierno
- otras_clases

Text: {text}

Return only the category name, nothing else."""
        result = self._run(prompt, use_flash=True).lower()

        # Guard against the model answering outside the allowed set.
        if result not in categories:
            result = "otras_clases"

        return {"category": result, "confidence": 0.9, "provider": self.name}

    def generate_text(self, prompt: str, **kwargs) -> str:
        """Generate text using Gemini.

        NOTE(review): unlike _run(), this prefers the API over the CLI and
        does not fall back on failure — confirm this asymmetry is intended.
        """
        use_flash = kwargs.get("use_flash", True)
        if self._api_key:
            return self._call_api(prompt, use_flash=use_flash)
        return self._run_cli(prompt, use_flash=use_flash)

    def fix_latex(self, latex_code: str, error_log: str, **kwargs) -> str:
        """Fix broken LaTeX code using Gemini.

        Only the tail of the error log is sent to keep the prompt small.
        """
        prompt = f"""Fix the following LaTeX code which failed to compile.

Error Log:
{error_log[-3000:]}

Broken Code:
{latex_code}

INSTRUCTIONS:
1. Return ONLY the corrected LaTeX code. No explanations.
2. Start immediately with \\documentclass.

COMMON LATEX ERRORS TO FIX:
- TikZ nodes with line breaks (\\\\) MUST have "align=center" in their style.
  WRONG: \\node[box] (n) {{Text\\\\More}};
  CORRECT: \\node[box, align=center] (n) {{Text\\\\More}};
- All \\begin{{env}} must have matching \\end{{env}}
- All braces {{ }} must be balanced
- Math mode $ must be paired
- Special characters need escaping: % & # _
- tcolorbox environments need proper titles: [Title] not {{Title}}
"""
        return self._run(prompt, use_flash=False)  # Use Pro model for coding fixes

    def get_stats(self) -> Dict[str, Any]:
        """Get provider statistics.

        NOTE(review): reads rate-limiter/breaker counters without their
        locks — values are best-effort snapshots, fine for diagnostics.
        """
        return {
            "rate_limiter": {
                "tokens": round(self._rate_limiter.tokens, 2),
                "capacity": self._rate_limiter.capacity,
                "rate": self._rate_limiter.rate,
            },
            "circuit_breaker": {
                "state": self._circuit_breaker.state,
                "failures": self._circuit_breaker.failures,
                "failure_threshold": self._circuit_breaker.failure_threshold,
            },
            "cli_available": bool(self._cli_path),
            "api_available": bool(self._api_key),
        }
|
||||
|
||||
|
||||
# Global instance is created in __init__.py
|
||||
@@ -1,346 +0,0 @@
|
||||
"""
|
||||
Parallel AI Provider - Race multiple providers for fastest response
|
||||
Implements Strategy A: Parallel Generation with Consensus
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, List, Optional, Any
|
||||
from datetime import datetime
|
||||
|
||||
from core import AIProcessingError
|
||||
from .base_provider import AIProvider
|
||||
|
||||
|
||||
@dataclass
class ProviderResult:
    """Result from a single provider call."""
    provider_name: str           # e.g. "Claude", "Gemini"
    content: str                 # generated text; empty string on failure
    duration_ms: int             # wall-clock duration of the call
    success: bool                # True when the provider returned normally
    error: Optional[str] = None  # error message when success is False
    quality_score: float = 0.0   # heuristic 0.0-1.0 quality estimate


@dataclass
class ParallelResult:
    """Aggregated result from parallel execution."""
    content: str                       # content of the selected result
    strategy: str                      # "race", "consensus", or "majority"
    providers_used: List[str]          # names of providers that succeeded
    total_duration_ms: int             # wall-clock time for the whole run
    all_results: List[ProviderResult]  # every collected per-provider result
    selected_provider: str             # name of the winning provider


class ParallelAIProvider:
    """
    Orchestrates multiple AI providers in parallel for faster responses.

    Strategies:
    - "race": Use first successful response (fastest)
    - "consensus": Wait for all, select best quality
    - "majority": Select most common response
    """

    def __init__(self, providers: Dict[str, AIProvider], max_workers: int = 4):
        """
        Args:
            providers: Mapping of provider name -> provider instance.
            max_workers: Thread-pool size for concurrent provider calls.
        """
        self.providers = providers
        self.max_workers = max_workers
        self.logger = logging.getLogger(__name__)
        self.executor = ThreadPoolExecutor(max_workers=max_workers)

    def _generate_sync(self, provider: AIProvider, prompt: str, **kwargs) -> ProviderResult:
        """Call provider.generate_text, capturing timing, errors, and score."""
        start_time = datetime.now()
        try:
            content = provider.generate_text(prompt, **kwargs)
            duration_ms = int((datetime.now() - start_time).total_seconds() * 1000)

            quality_score = self._calculate_quality_score(content)

            return ProviderResult(
                provider_name=provider.name,
                content=content,
                duration_ms=duration_ms,
                success=True,
                quality_score=quality_score,
            )
        except Exception as e:
            duration_ms = int((datetime.now() - start_time).total_seconds() * 1000)
            self.logger.error(f"{provider.name} failed: {e}")
            return ProviderResult(
                provider_name=provider.name,
                content="",
                duration_ms=duration_ms,
                success=False,
                error=str(e),
            )

    def _calculate_quality_score(self, content: str) -> float:
        """Heuristic 0.0-1.0 score rewarding LaTeX structure and balance."""
        score = 0.0

        # Length check (comprehensive is better)
        if 500 < len(content) < 50000:
            score += 0.2

        # LaTeX structure validation
        latex_indicators = [
            r"\documentclass",
            r"\begin{document}",
            r"\section",
            r"\subsection",
            r"\begin{itemize}",
            r"\end{document}",
        ]
        found_indicators = sum(1 for ind in latex_indicators if ind in content)
        score += (found_indicators / len(latex_indicators)) * 0.4

        # Bracket matching
        if content.count("{") == content.count("}"):
            score += 0.2

        # Environment closure
        envs = ["document", "itemize", "enumerate"]
        for env in envs:
            if content.count(f"\\begin{{{env}}}") == content.count(f"\\end{{{env}}}"):
                score += 0.1

        # Has content beyond template
        if len(content) > 1000:
            score += 0.1

        return min(score, 1.0)

    def generate_parallel(
        self,
        prompt: str,
        strategy: str = "race",
        timeout_ms: int = 300000,  # 5 minutes default
        **kwargs
    ) -> ParallelResult:
        """
        Execute prompt across multiple providers in parallel.

        Args:
            prompt: The prompt to send to all providers
            strategy: "race", "consensus", or "majority"
            timeout_ms: Maximum time to wait for results
            **kwargs: Additional arguments for providers

        Returns:
            ParallelResult with selected content and metadata

        Raises:
            AIProcessingError: when no providers are configured.
            ValueError: for an unknown strategy.
        """
        if not self.providers:
            raise AIProcessingError("No providers available for parallel execution")

        start_time = datetime.now()

        # Fan out one task per available provider.
        futures = {}
        for name, provider in self.providers.items():
            if provider.is_available():
                future = self.executor.submit(
                    self._generate_sync, provider, prompt, **kwargs
                )
                futures[future] = name

        # Collect results according to the requested strategy.
        if strategy == "race":
            all_results = self._race_strategy(futures, timeout_ms)
        elif strategy == "consensus":
            all_results = self._consensus_strategy(futures, timeout_ms)
        elif strategy == "majority":
            all_results = self._majority_strategy(futures, timeout_ms)
        else:
            raise ValueError(f"Unknown strategy: {strategy}")

        selected = self._select_result(all_results, strategy)

        total_duration_ms = int((datetime.now() - start_time).total_seconds() * 1000)

        self.logger.info(
            f"Parallel generation complete: {strategy} strategy, "
            f"{len(all_results)} providers, {selected.provider_name} selected, "
            f"{total_duration_ms}ms"
        )

        return ParallelResult(
            content=selected.content,
            strategy=strategy,
            providers_used=[r.provider_name for r in all_results if r.success],
            total_duration_ms=total_duration_ms,
            all_results=all_results,
            selected_provider=selected.provider_name,
        )

    def _collect_results(
        self, futures: dict, timeout_ms: int, stop_on_success: bool
    ) -> List[ProviderResult]:
        """Drain futures as they finish; optionally stop at the first success.

        Shared implementation for all three strategies (the consensus and
        majority collection phases were previously duplicated verbatim).
        """
        results: List[ProviderResult] = []
        for future in as_completed(futures, timeout=timeout_ms / 1000):
            try:
                result = future.result()
                results.append(result)
                if stop_on_success and result.success:
                    # Best-effort: cancel() only stops futures still queued;
                    # tasks already running will finish in the background.
                    for f in futures:
                        f.cancel()
                    break
            except Exception as e:
                self.logger.error(f"Future failed: {e}")
        return results

    def _race_strategy(self, futures: dict, timeout_ms: int) -> List[ProviderResult]:
        """Return as soon as one provider succeeds."""
        return self._collect_results(futures, timeout_ms, stop_on_success=True)

    def _consensus_strategy(self, futures: dict, timeout_ms: int) -> List[ProviderResult]:
        """Wait for every provider; caller selects by quality score."""
        return self._collect_results(futures, timeout_ms, stop_on_success=False)

    def _majority_strategy(self, futures: dict, timeout_ms: int) -> List[ProviderResult]:
        """Wait for every provider; caller selects the most common answer."""
        return self._collect_results(futures, timeout_ms, stop_on_success=False)

    def _select_result(self, results: List[ProviderResult], strategy: str) -> ProviderResult:
        """Select the best result for *strategy*; never raises on failure."""
        successful = [r for r in results if r.success]

        if not successful:
            # Surface the first failure's error info, or a synthetic result
            # when nothing completed at all.
            return results[0] if results else ProviderResult(
                provider_name="none",
                content="",
                duration_ms=0,
                success=False,
                error="All providers failed",
            )

        if strategy == "race" or len(successful) == 1:
            return successful[0]

        if strategy == "consensus":
            # Select by quality score.
            return max(successful, key=lambda r: r.quality_score)

        if strategy == "majority":
            # Group by similar content (simplified - use longest).
            return max(successful, key=lambda r: len(r.content))

        return successful[0]

    def _fix_sync(
        self, name: str, provider: AIProvider, latex_code: str, error_log: str, **kwargs
    ) -> ProviderResult:
        """Run one provider's fix_latex, timing and scoring the result."""
        start = datetime.now()
        fixed = provider.fix_latex(latex_code, error_log, **kwargs)
        duration_ms = int((datetime.now() - start).total_seconds() * 1000)
        quality = self._score_latex_fix(fixed, error_log)
        return ProviderResult(
            provider_name=name,
            content=fixed,
            duration_ms=duration_ms,
            success=True,
            quality_score=quality,
        )

    def fix_latex_parallel(
        self,
        latex_code: str,
        error_log: str,
        timeout_ms: int = 120000,
        **kwargs
    ) -> ParallelResult:
        """Try to fix LaTeX across multiple providers in parallel.

        Bug fix: providers were previously called one after another despite
        the method's name; they are now submitted to the thread pool and
        run concurrently, then the highest-scoring fix is selected.

        Raises:
            AIProcessingError: when every provider fails.
        """
        start_time = datetime.now()

        futures = {}
        for name, provider in self.providers.items():
            if provider.is_available():
                future = self.executor.submit(
                    self._fix_sync, name, provider, latex_code, error_log, **kwargs
                )
                futures[future] = name

        results: List[ProviderResult] = []
        for future in as_completed(futures, timeout=timeout_ms / 1000):
            name = futures[future]
            try:
                results.append(future.result())
            except Exception as e:
                self.logger.error(f"{name} fix failed: {e}")

        if results:
            selected = max(results, key=lambda r: r.quality_score)
            total_duration_ms = int((datetime.now() - start_time).total_seconds() * 1000)

            return ParallelResult(
                content=selected.content,
                strategy="consensus",
                providers_used=[r.provider_name for r in results],
                total_duration_ms=total_duration_ms,
                all_results=results,
                selected_provider=selected.provider_name,
            )

        raise AIProcessingError("All providers failed to fix LaTeX")

    def _score_latex_fix(self, fixed_latex: str, original_error: str) -> float:
        """Score a LaTeX fix attempt (0.0-1.0, higher is better)."""
        score = 0.5  # Base score

        # Check if common error patterns are addressed
        error_patterns = [
            ("Undefined control sequence", r"\\[a-zA-Z]+"),
            ("Missing $ inserted", r"\$.*\$"),
            ("Runaway argument", r"\{.*\}"),
        ]

        for error_msg, pattern in error_patterns:
            if error_msg in original_error:
                # If error was in original, check if pattern appears better
                score += 0.1

        # Validate bracket matching
        if fixed_latex.count("{") == fixed_latex.count("}"):
            score += 0.2

        # Validate environment closure
        envs = ["document", "itemize", "enumerate"]
        for env in envs:
            begin_count = fixed_latex.count(f"\\begin{{{env}}}")
            end_count = fixed_latex.count(f"\\end{{{env}}}")
            if begin_count == end_count:
                score += 0.1

        return min(score, 1.0)

    def shutdown(self):
        """Shutdown the executor, waiting for in-flight tasks."""
        self.executor.shutdown(wait=True)

    def __del__(self):
        # Bug fix: guard the implicit shutdown — __del__ may run during
        # interpreter teardown when attributes/modules are already gone,
        # and exceptions here would be noisy and unrecoverable.
        try:
            self.shutdown()
        except Exception:
            pass
|
||||
@@ -1,343 +0,0 @@
|
||||
"""
|
||||
Prompt Manager - Centralized prompt management using resumen.md as source of truth
|
||||
"""
|
||||
|
||||
import re
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Optional, Dict, Any
|
||||
from config import settings
|
||||
|
||||
|
||||
class PromptManager:
    """
    Manages prompts for AI services, loading templates from latex/resumen.md.

    This is the SINGLE SOURCE OF TRUTH for academic summary generation.
    Implemented as a singleton; the parsed template and LaTeX preamble are
    cached on first load and reused for the process lifetime.
    """

    # Singleton instance created by __new__.
    _instance = None
    # Cached full prompt template; None means "not loaded yet".
    _prompt_cache: Optional[str] = None
    # Cached LaTeX preamble extracted from the template; None = not loaded.
    _latex_preamble_cache: Optional[str] = None

    # Path to the prompt template file (relative to the working directory).
    PROMPT_FILE_PATH = Path("latex/resumen.md")

    def __new__(cls):
        if cls._instance is None:
            cls._instance = super(PromptManager, cls).__new__(cls)
        return cls._instance

    def _load_prompt_template(self) -> str:
        """Load the complete prompt template from resumen.md.

        Falls back to a built-in prompt when the file is missing or cannot
        be parsed. The result is cached; parsing happens at most once.
        """
        # Bug fix: the previous guard was `if self._prompt_cache:` — a
        # legitimately empty template was treated as a cache miss, causing
        # the file to be re-read and re-parsed on every call.
        if self._prompt_cache is not None:
            return self._prompt_cache

        try:
            file_path = self.PROMPT_FILE_PATH.resolve()

            if not file_path.exists():
                self._prompt_cache = self._get_fallback_prompt()
                return self._prompt_cache

            content = file_path.read_text(encoding="utf-8")

            # The file has a markdown code block after "## Prompt Template".
            # We need the content from "## Prompt Template" to the LAST ```
            # (because there's a ```latex...``` block INSIDE the template).

            # First, find where "## Prompt Template" starts.
            template_start = content.find("## Prompt Template")
            if template_start == -1:
                self._prompt_cache = self._get_fallback_prompt()
                return self._prompt_cache

            # Find the opening ``` after the header.
            after_header = content[template_start:]
            code_block_start = after_header.find("```")
            if code_block_start == -1:
                self._prompt_cache = self._get_fallback_prompt()
                return self._prompt_cache

            # Skip the opening ``` and any language specifier on its line.
            after_code_start = after_header[code_block_start + 3:]
            first_newline = after_code_start.find("\n")
            if first_newline != -1:
                actual_content_start = template_start + code_block_start + 3 + first_newline + 1
            else:
                actual_content_start = template_start + code_block_start + 3

            # Now find the LAST ``` — it closes the main block; all earlier
            # occurrences belong to the nested latex block.
            remaining = content[actual_content_start:]

            positions = []
            pos = 0
            while True:
                found = remaining.find("```", pos)
                if found == -1:
                    break
                positions.append(found)
                pos = found + 3

            if not positions:
                self._prompt_cache = self._get_fallback_prompt()
                return self._prompt_cache

            last_backtick_pos = positions[-1]

            # positions are relative to `remaining`, i.e. to
            # actual_content_start — slice accordingly.
            template_content = content[actual_content_start:actual_content_start + last_backtick_pos]

            # Remove leading newline if present.
            template_content = template_content.lstrip("\n")

            self._prompt_cache = template_content
            return self._prompt_cache

        except Exception as e:
            # Best effort: any parse/IO failure degrades to the fallback.
            # (This module has no logger configured; print is the existing
            # convention here.)
            print(f"Error loading prompt file: {e}")
            self._prompt_cache = self._get_fallback_prompt()
            return self._prompt_cache

    def _get_fallback_prompt(self) -> str:
        """Fallback prompt if resumen.md is not found."""
        return """Sos un asistente académico experto. Creá un resumen extenso en LaTeX basado en la transcripción de clase.

## Transcripción de clase:
[PEGAR TRANSCRIPCIÓN AQUÍ]

## Material bibliográfico:
[PEGAR TEXTO DEL LIBRO/APUNTE O INDICAR QUE LO SUBISTE COMO ARCHIVO]

Generá un archivo LaTeX completo con:
- Estructura académica formal
- Mínimo 10 páginas de contenido
- Fórmulas matemáticas en LaTeX
- Tablas y diagramas cuando corresponda
"""

    def _load_latex_preamble(self) -> str:
        """Extract the LaTeX preamble from resumen.md (cached)."""
        # Same truthiness fix as _load_prompt_template: only reload when
        # the cache has never been populated.
        if self._latex_preamble_cache is not None:
            return self._latex_preamble_cache

        try:
            file_path = self.PROMPT_FILE_PATH.resolve()

            if not file_path.exists():
                return self._get_default_preamble()

            content = file_path.read_text(encoding="utf-8")

            # Extract the ```latex ... ``` code block inside the template.
            match = re.search(
                r"```latex\s*\n([\s\S]*?)\n```",
                content
            )

            if match:
                self._latex_preamble_cache = match.group(1).strip()
            else:
                self._latex_preamble_cache = self._get_default_preamble()

            return self._latex_preamble_cache

        except Exception as e:
            print(f"Error loading LaTeX preamble: {e}")
            return self._get_default_preamble()

    def _get_default_preamble(self) -> str:
        """Default LaTeX preamble used when no template file is available."""
        return r"""\documentclass[11pt,a4paper]{article}
\usepackage[utf8]{inputenc}
\usepackage[spanish,provide=*]{babel}
\usepackage{amsmath,amssymb}
\usepackage{geometry}
\usepackage{graphicx}
\usepackage{tikz}
\usetikzlibrary{arrows.meta,positioning,shapes.geometric,calc}
\usepackage{booktabs}
\usepackage{enumitem}
\usepackage{fancyhdr}
\usepackage{titlesec}
\usepackage{tcolorbox}
\usepackage{array}
\usepackage{multirow}

\geometry{margin=2.5cm}
\pagestyle{fancy}
\fancyhf{}
\fancyhead[L]{[MATERIA] - CBC}
\fancyhead[R]{Clase [N]}
\fancyfoot[C]{\thepage}

% Cajas para destacar contenido
\newtcolorbox{definicion}[1][]{
colback=blue!5!white,
colframe=blue!75!black,
fonttitle=\bfseries,
title=#1
}

\newtcolorbox{importante}[1][]{
colback=red!5!white,
colframe=red!75!black,
fonttitle=\bfseries,
title=#1
}

\newtcolorbox{ejemplo}[1][]{
colback=green!5!white,
colframe=green!50!black,
fonttitle=\bfseries,
title=#1
}
"""

    def get_latex_summary_prompt(
        self,
        transcription: str,
        materia: str = "Economía",
        bibliographic_text: Optional[str] = None,
        class_number: Optional[int] = None
    ) -> str:
        """
        Generate the complete prompt for a LaTeX academic summary based on
        the resumen.md template.

        Args:
            transcription: The class transcription text
            materia: Subject name (default: "Economía")
            bibliographic_text: Optional supporting text from books/notes
            class_number: Optional class number for the header

        Returns:
            Complete prompt string ready to send to AI
        """
        template = self._load_prompt_template()

        # CRITICAL: Prepend explicit instructions to force direct LaTeX
        # generation (this doesn't modify resumen.md, just adds context).
        explicit_instructions = """CRITICAL: Tu respuesta debe ser ÚNICAMENTE código LaTeX.

INSTRUCCIONES OBLIGATORIAS:
1. NO incluyas explicaciones previas
2. NO describas lo que vas a hacer
3. Comienza INMEDIATAMENTE con \\documentclass
4. Tu respuesta debe ser SOLO el código LaTeX fuente
5. Termina con \\end{document}

---

"""

        prompt = explicit_instructions + template

        # Replace placeholders.
        prompt = prompt.replace("[MATERIA]", materia)

        # Insert transcription (append a section when the placeholder is
        # absent from a custom template).
        if "[PEGAR TRANSCRIPCIÓN AQUÍ]" in prompt:
            prompt = prompt.replace("[PEGAR TRANSCRIPCIÓN AQUÍ]", transcription)
        else:
            prompt += f"\n\n## Transcripción de clase:\n{transcription}"

        # Insert bibliographic material (same placeholder-or-append logic).
        bib_text = bibliographic_text or "No se proporcionó material bibliográfico adicional."
        if "[PEGAR TEXTO DEL LIBRO/APUNTE O INDICAR QUE LO SUBISTE COMO ARCHIVO]" in prompt:
            prompt = prompt.replace(
                "[PEGAR TEXTO DEL LIBRO/APUNTE O INDICAR QUE LO SUBISTE COMO ARCHIVO]",
                bib_text
            )
        else:
            prompt += f"\n\n## Material bibliográfico:\n{bib_text}"

        # Add class number if provided.
        if class_number is not None:
            prompt = prompt.replace("[N]", str(class_number))

        return prompt

    def get_latex_preamble(
        self,
        materia: str = "Economía",
        class_number: Optional[int] = None
    ) -> str:
        """
        Get the LaTeX preamble with placeholders replaced.

        Args:
            materia: Subject name
            class_number: Optional class number

        Returns:
            Complete LaTeX preamble as string
        """
        preamble = self._load_latex_preamble()

        # Replace placeholders.
        preamble = preamble.replace("[MATERIA]", materia)
        if class_number is not None:
            preamble = preamble.replace("[N]", str(class_number))

        return preamble

    def get_latex_fix_prompt(self, latex_code: str, error_log: str) -> str:
        """Get the prompt for fixing broken LaTeX code.

        The error log is truncated to its last 3000 characters — LaTeX
        compilers put the relevant failure at the end of the log.
        """
        return f"""I have a LaTeX file that failed to compile. Please fix the code.

COMPILER ERROR LOG:
{error_log[-3000:]}

BROKEN LATEX CODE:
{latex_code}

INSTRUCTIONS:
1. Analyze the error log to find the specific syntax error.
2. Fix the LaTeX code.
3. Return ONLY the full corrected LaTeX code.
4. Do not include markdown blocks or explanations.
5. Start immediately with \\documentclass.
6. Ensure all braces {{}} are properly balanced.
7. Ensure all environments \\begin{{...}} have matching \\end{{...}}.
8. Ensure all packages are properly declared.
"""

    def extract_latex_from_response(self, response: str) -> Optional[str]:
        """
        Extract clean LaTeX code from an AI response.

        Handles responses that wrap LaTeX in ```latex ... ``` blocks, and
        trims any chatter before \\documentclass or after \\end{document}.
        Returns None when the response does not look like LaTeX.
        """
        if not response:
            return None

        # Try to find content inside ```latex ... ``` blocks first.
        code_block_pattern = r"```(?:latex|tex)?\s*([\s\S]*?)\s*```"
        match = re.search(code_block_pattern, response, re.IGNORECASE)

        if match:
            latex = match.group(1).strip()
        else:
            latex = response.strip()

        # Verify it looks like LaTeX at all.
        if "\\documentclass" not in latex:
            return None

        # Clean up: remove anything before \documentclass.
        start_idx = latex.find("\\documentclass")
        latex = latex[start_idx:]

        # Clean up: remove anything after the LAST \end{document}.
        if "\\end{document}" in latex:
            end_idx = latex.rfind("\\end{document}")
            latex = latex[:end_idx + len("\\end{document}")]

        return latex.strip()
|
||||
|
||||
|
||||
# Singleton instance for easy import.
# Process-wide: every importer shares the same cached template/preamble.
prompt_manager = PromptManager()
|
||||
@@ -1,80 +0,0 @@
|
||||
"""
|
||||
AI Provider Factory (Factory Pattern)
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Dict, Type, Optional
|
||||
|
||||
from core import AIProcessingError
|
||||
from .base_provider import AIProvider
|
||||
from .claude_provider import ClaudeProvider
|
||||
from .gemini_provider import GeminiProvider
|
||||
from .parallel_provider import ParallelAIProvider
|
||||
|
||||
|
||||
class AIProviderFactory:
    """Factory for creating AI providers with fallback and parallel execution"""

    def __init__(self):
        self.logger = logging.getLogger(__name__)
        # Registry of concrete providers, keyed by their short name.
        self._providers: Dict[str, AIProvider] = {
            "claude": ClaudeProvider(),
            "gemini": GeminiProvider(),
        }
        # Racing wrapper, built lazily on first request.
        self._parallel_provider: Optional[ParallelAIProvider] = None

    def get_provider(self, preferred: str = "gemini") -> AIProvider:
        """Get available provider with fallback"""
        # Honour the caller's preference when that provider is online.
        candidate = self._providers.get(preferred)
        if candidate is not None and candidate.is_available():
            self.logger.info(f"Using {preferred} provider")
            return candidate

        # Otherwise take the first provider that answers.
        for name, provider in self._providers.items():
            if provider.is_available():
                self.logger.info(f"Falling back to {name} provider")
                return provider

        raise AIProcessingError("No AI providers available")

    def get_all_available(self) -> Dict[str, AIProvider]:
        """Get all available providers"""
        available: Dict[str, AIProvider] = {}
        for name, provider in self._providers.items():
            if provider.is_available():
                available[name] = provider
        return available

    def get_best_provider(self) -> AIProvider:
        """Get the best available provider (Claude > Gemini)"""
        return self.get_provider("claude")

    def get_parallel_provider(self, max_workers: int = 4) -> ParallelAIProvider:
        """Get parallel provider for racing multiple AI providers"""
        available = self.get_all_available()
        if not available:
            raise AIProcessingError("No providers available for parallel execution")

        # Build once; later calls reuse the same wrapper (and ignore a
        # different max_workers, matching the original behavior).
        if self._parallel_provider is None:
            self._parallel_provider = ParallelAIProvider(
                providers=available,
                max_workers=max_workers
            )
            self.logger.info(
                f"Created parallel provider with {len(available)} workers: "
                f"{', '.join(available.keys())}"
            )

        return self._parallel_provider

    def use_parallel(self) -> bool:
        """Check if parallel execution should be used (multiple providers available)"""
        return len(self.get_all_available()) > 1
|
||||
|
||||
|
||||
# Global instance
# Module-level singleton shared by all importers of this package.
ai_provider_factory = AIProviderFactory()
|
||||
@@ -1,256 +0,0 @@
|
||||
"""
|
||||
AI Service - Unified interface for AI providers with caching
|
||||
"""
|
||||
import logging
|
||||
import hashlib
|
||||
import time
|
||||
from typing import Optional, Dict, Any
|
||||
from threading import Lock
|
||||
|
||||
from config import settings
|
||||
from core import AIProcessingError
|
||||
from .ai.provider_factory import AIProviderFactory, ai_provider_factory
|
||||
|
||||
|
||||
class LRUCache:
    """Thread-safe LRU cache with per-entry TTL expiry.

    Entries older than ``ttl`` seconds are reaped lazily on access; when
    the cache is full, the least-recently-used entry is evicted on insert.
    """

    def __init__(self, max_size: int = 100, ttl: int = 3600):
        self.max_size = max_size
        self.ttl = ttl  # seconds an entry stays valid
        # key -> (value, insertion timestamp)
        self._cache: Dict[str, tuple[str, float]] = {}
        # Usage order, least-recently-used first.
        self._order: list[str] = []
        self._lock = Lock()

    def _is_expired(self, timestamp: float) -> bool:
        """Return True once *timestamp* is older than the TTL."""
        age = time.time() - timestamp
        return age > self.ttl

    def get(self, key: str) -> Optional[str]:
        """Return the cached value, or None when absent or expired."""
        with self._lock:
            entry = self._cache.get(key)
            if entry is None:
                return None
            value, stamp = entry
            if self._is_expired(stamp):
                # Expired: drop it now and report a miss.
                del self._cache[key]
                self._order.remove(key)
                return None
            # Promote to most-recently-used.
            self._order.remove(key)
            self._order.append(key)
            return value

    def set(self, key: str, value: str) -> None:
        """Insert or refresh *key*, evicting the LRU entry when full."""
        with self._lock:
            if key in self._cache:
                self._order.remove(key)
            elif len(self._order) >= self.max_size:
                victim = self._order.pop(0)
                del self._cache[victim]
            self._cache[key] = (value, time.time())
            self._order.append(key)

    def stats(self) -> Dict[str, int]:
        """Snapshot of occupancy.

        NOTE: "hits" counts currently-live (non-expired) entries, not
        lookup hits — the key name predates this implementation.
        """
        with self._lock:
            live = sum(1 for _, stamp in self._cache.values() if not self._is_expired(stamp))
            return {
                "size": len(self._cache),
                "max_size": self.max_size,
                "hits": live
            }
|
||||
|
||||
|
||||
class RateLimiter:
    """Token-bucket rate limiter: callers are told how long to sleep."""

    def __init__(self, rate: float = 10, capacity: int = 20):
        self.rate = rate          # refill speed, tokens per second
        self.capacity = capacity  # bucket ceiling
        self.tokens = capacity    # the bucket starts full
        self.last_update = time.time()
        self._lock = Lock()

    def acquire(self, tokens: int = 1) -> float:
        """Take *tokens* from the bucket.

        Returns 0.0 when the request is satisfied immediately; otherwise
        the number of seconds the caller should wait (and the bucket is
        drained to zero).
        """
        with self._lock:
            # Refill proportionally to the time elapsed since last call.
            now = time.time()
            refill = (now - self.last_update) * self.rate
            self.last_update = now
            self.tokens = min(self.capacity, self.tokens + refill)

            if self.tokens < tokens:
                deficit = tokens - self.tokens
                self.tokens = 0
                return deficit / self.rate

            self.tokens -= tokens
            return 0.0
|
||||
|
||||
|
||||
class AIService:
    """Unified service for AI operations with caching and rate limiting.

    Wraps the provider factory with:
      * an LRU + TTL cache keyed on (operation, full input text),
      * a token-bucket rate limiter applied before every API call,
      * request/hit/call counters exposed via get_stats().
    """

    def __init__(self):
        self.logger = logging.getLogger(__name__)
        self._factory: Optional[AIProviderFactory] = None
        self._prompt_cache = LRUCache(max_size=100, ttl=3600)  # 1 hour TTL
        self._rate_limiter = RateLimiter(rate=15, capacity=30)
        # Simple counters surfaced through get_stats().
        self._stats = {
            "total_requests": 0,
            "cache_hits": 0,
            "api_calls": 0
        }

    @property
    def factory(self) -> AIProviderFactory:
        """Lazy initialization of provider factory"""
        if self._factory is None:
            self._factory = ai_provider_factory
        return self._factory

    def _get_cache_key(self, prompt: str, operation: str) -> str:
        """Generate a cache key from the operation and the FULL prompt.

        Bug fix: the previous implementation hashed only the first 500
        characters of the prompt, so long inputs sharing a prefix silently
        collided and returned each other's cached results. SHA-256 over
        the whole text is cheap and collision-safe.
        """
        content = f"{operation}:{prompt}"
        return hashlib.sha256(content.encode()).hexdigest()

    def _wait_for_rate_limit(self) -> None:
        """Sleep exactly as long as the token bucket demands (possibly 0)."""
        wait_time = self._rate_limiter.acquire()
        if wait_time > 0:
            time.sleep(wait_time)

    def generate_text(
        self,
        prompt: str,
        provider: Optional[str] = None,
        max_tokens: int = 4096
    ) -> str:
        """Generate text using an AI provider, with caching.

        Returns the generated text, or an "Error: ..." string on provider
        failure (errors are never cached).
        """
        self._stats["total_requests"] += 1

        cache_key = self._get_cache_key(prompt, f"generate:{provider or 'default'}")

        # Check cache first.
        cached_result = self._prompt_cache.get(cache_key)
        if cached_result:
            self._stats["cache_hits"] += 1
            self.logger.debug(f"Cache hit for generate_text ({len(cached_result)} chars)")
            return cached_result

        self._wait_for_rate_limit()

        try:
            self._stats["api_calls"] += 1
            ai_provider = self.factory.get_provider(provider or 'gemini')
            result = ai_provider.generate(prompt, max_tokens=max_tokens)

            # Cache successful results only.
            self._prompt_cache.set(cache_key, result)

            return result
        except AIProcessingError as e:
            self.logger.error(f"AI generation failed: {e}")
            return f"Error: {str(e)}"

    def summarize(self, text: str, **kwargs) -> str:
        """Generate a summary of *text*, with caching.

        Returns the summary, or an "Error: ..." string on provider failure.
        """
        self._stats["total_requests"] += 1

        cache_key = self._get_cache_key(text, "summarize")

        cached_result = self._prompt_cache.get(cache_key)
        if cached_result:
            self._stats["cache_hits"] += 1
            self.logger.debug(f"Cache hit for summarize ({len(cached_result)} chars)")
            return cached_result

        self._wait_for_rate_limit()

        try:
            self._stats["api_calls"] += 1
            provider = self.factory.get_best_provider()
            result = provider.summarize(text, **kwargs)

            self._prompt_cache.set(cache_key, result)
            return result
        except AIProcessingError as e:
            self.logger.error(f"Summarization failed: {e}")
            return f"Error: {str(e)}"

    def correct_text(self, text: str, **kwargs) -> str:
        """Correct grammar and spelling, with caching.

        Best effort: on provider failure the original text is returned
        unchanged.
        """
        self._stats["total_requests"] += 1

        cache_key = self._get_cache_key(text, "correct")

        cached_result = self._prompt_cache.get(cache_key)
        if cached_result:
            self._stats["cache_hits"] += 1
            return cached_result

        self._wait_for_rate_limit()

        try:
            self._stats["api_calls"] += 1
            provider = self.factory.get_best_provider()
            result = provider.correct_text(text, **kwargs)

            self._prompt_cache.set(cache_key, result)
            return result
        except AIProcessingError as e:
            self.logger.error(f"Text correction failed: {e}")
            return text

    def classify_content(self, text: str, **kwargs) -> Dict[str, Any]:
        """Classify content into categories, with caching.

        Returns the provider's classification dict; on failure a
        conservative default ({"category": "otras_clases", "confidence": 0.0}).
        """
        import json  # stdlib; needed only for the cache round-trip here

        self._stats["total_requests"] += 1

        # Bug fix: keying on text[:200] made distinct documents with a
        # shared prefix return each other's cached classifications — the
        # full text is hashed now (cheap with sha256).
        cache_key = self._get_cache_key(text, "classify")

        cached_result = self._prompt_cache.get(cache_key)
        if cached_result:
            self._stats["cache_hits"] += 1
            # The cache stores strings only; classifications are JSON-encoded.
            return json.loads(cached_result)

        self._wait_for_rate_limit()

        try:
            self._stats["api_calls"] += 1
            provider = self.factory.get_best_provider()
            result = provider.classify_content(text, **kwargs)

            self._prompt_cache.set(cache_key, json.dumps(result))
            return result
        except AIProcessingError as e:
            self.logger.error(f"Classification failed: {e}")
            return {"category": "otras_clases", "confidence": 0.0}

    def get_stats(self) -> Dict[str, Any]:
        """Return counters plus cache and rate-limiter snapshots."""
        cache_stats = self._prompt_cache.stats()
        total = self._stats["total_requests"]
        hit_rate = (self._stats["cache_hits"] / total * 100) if total > 0 else 0

        return {
            **self._stats,
            "cache_size": cache_stats["size"],
            "cache_max_size": cache_stats["max_size"],
            "cache_hit_rate": round(hit_rate, 2),
            "rate_limiter": {
                "tokens": self._rate_limiter.tokens,
                "capacity": self._rate_limiter.capacity
            }
        }

    def clear_cache(self) -> None:
        """Drop every cached result (counters are kept)."""
        self._prompt_cache = LRUCache(max_size=100, ttl=3600)
        self.logger.info("AI service cache cleared")
|
||||
|
||||
|
||||
# Global instance
# Shared process-wide so cache, counters and rate limiter are unified.
ai_service = AIService()
|
||||
158
services/ai_summary_service.py
Normal file
158
services/ai_summary_service.py
Normal file
@@ -0,0 +1,158 @@
|
||||
"""AI Summary Service using Anthropic/Z.AI API (GLM)."""
|
||||
|
||||
import logging
|
||||
import os
|
||||
from typing import Optional
|
||||
|
||||
import requests
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AISummaryService:
    """Service for AI-powered text summarization using Anthropic/Z.AI API.

    All network access goes through the private ``_post_chat`` helper, so
    the request/response plumbing (endpoint URL, headers, payload shape,
    response parsing) lives in one place instead of being duplicated in
    every public method.
    """

    def __init__(
        self,
        auth_token: Optional[str] = None,
        base_url: Optional[str] = None,
        model: Optional[str] = None,
        timeout: int = 120,
    ) -> None:
        """Initialize the AI Summary Service.

        Args:
            auth_token: API authentication token. Defaults to ANTHROPIC_AUTH_TOKEN env var.
            base_url: API base URL. Defaults to ANTHROPIC_BASE_URL env var.
            model: Model identifier. Defaults to ANTHROPIC_MODEL env var.
            timeout: Request timeout in seconds. Defaults to 120.
        """
        self.auth_token = auth_token or os.getenv("ANTHROPIC_AUTH_TOKEN")
        # Normalize base_url: remove /anthropic suffix if present
        raw_base_url = base_url or os.getenv("ANTHROPIC_BASE_URL")
        if raw_base_url and raw_base_url.endswith("/anthropic"):
            raw_base_url = raw_base_url[:-len("/anthropic")]
        self.base_url = raw_base_url
        self.model = model or os.getenv("ANTHROPIC_MODEL", "glm-4")
        self.timeout = timeout
        # Both a token and a URL are required; otherwise run in silent mode
        # (public methods return their input unchanged).
        self._available = bool(self.auth_token and self.base_url)

        if self._available:
            logger.info(
                "AISummaryService initialized with model=%s, base_url=%s",
                self.model,
                self.base_url,
            )
        else:
            logger.debug("AISummaryService: no configuration found, running in silent mode")

    @property
    def is_available(self) -> bool:
        """Check if the service is properly configured."""
        return self._available

    def _post_chat(self, prompt: str, *, max_tokens: int, temperature: float) -> str:
        """POST a single-user-message chat completion; return its content.

        Args:
            prompt: Full user prompt to send.
            max_tokens: Completion token budget.
            temperature: Sampling temperature.

        Returns:
            The assistant message content ("" when the response is malformed).

        Raises:
            requests.RequestException: On HTTP errors, timeouts, etc.
        """
        payload = {
            "model": self.model,
            "messages": [{"role": "user", "content": prompt}],
            "max_tokens": max_tokens,
            "temperature": temperature,
        }
        headers = {
            "Authorization": f"Bearer {self.auth_token}",
            "Content-Type": "application/json",
        }
        response = requests.post(
            f"{self.base_url}/v1/chat/completions",
            json=payload,
            headers=headers,
            timeout=self.timeout,
        )
        response.raise_for_status()
        result = response.json()
        # Defensive navigation: a missing/empty choices list yields "".
        return result.get("choices", [{}])[0].get("message", {}).get("content", "")

    def summarize(self, text: str, prompt_template: Optional[str] = None) -> str:
        """Summarize the given text using the AI API.

        Args:
            text: The text to summarize.
            prompt_template: Optional custom prompt template (formatted with
                ``text=``). If None, a default Spanish prompt is used.

        Returns:
            The summarized text, or the original text when unconfigured.

        Raises:
            requests.RequestException: If the API call fails.
        """
        if not self._available:
            logger.debug("AISummaryService not configured, returning original text")
            return text

        default_prompt = "Resume el siguiente texto de manera clara y concisa:"
        prompt = prompt_template.format(text=text) if prompt_template else f"{default_prompt}\n\n{text}"

        try:
            logger.debug("Calling AI API for summarization (text length: %d)", len(text))
            summary = self._post_chat(prompt, max_tokens=2048, temperature=0.7)
            logger.info("Summarization completed successfully (output length: %d)", len(summary))
            return summary

        except requests.Timeout:
            # Surface timeouts as a generic RequestException with a clearer
            # message (chain suppressed deliberately).
            logger.error("AI API request timed out after %d seconds", self.timeout)
            raise requests.RequestException(f"Request timed out after {self.timeout}s") from None

        except requests.RequestException as e:
            logger.error("AI API request failed: %s", str(e))
            raise

    def fix_latex(self, text: str) -> str:
        """Fix LaTeX formatting issues in the given text.

        Best effort: any API failure (including timeouts) returns the
        original text unchanged rather than raising.

        Args:
            text: The text containing LaTeX to fix.

        Returns:
            The text with corrected LaTeX formatting.
        """
        if not self._available:
            logger.debug("AISummaryService not configured, returning original text")
            return text

        prompt = (
            "Corrige los errores de formato LaTeX en el siguiente texto. "
            "Mantén el contenido pero corrige la sintaxis de LaTeX:\n\n"
            f"{text}"
        )

        try:
            logger.debug("Calling AI API for LaTeX fixing (text length: %d)", len(text))
            fixed = self._post_chat(prompt, max_tokens=4096, temperature=0.3)
            logger.info("LaTeX fixing completed successfully")
            return fixed

        except requests.RequestException as e:
            logger.error("LaTeX fixing failed: %s", str(e))
            return text
|
||||
@@ -1,247 +0,0 @@
|
||||
"""
|
||||
GPU Detection and Management Service
|
||||
|
||||
Provides unified interface for detecting and using NVIDIA (CUDA), AMD (ROCm), or CPU.
|
||||
Fallback order: NVIDIA -> AMD -> CPU
|
||||
"""
|
||||
import logging
|
||||
import os
|
||||
import subprocess
|
||||
import shutil
|
||||
from enum import Enum
|
||||
from typing import Dict, Any, Optional
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Try to import torch.
# torch is an optional dependency: when it is missing, detection degrades
# to CPU (see the TORCH_AVAILABLE checks in GPUDetector below).
try:
    import torch
    TORCH_AVAILABLE = True
except ImportError:
    TORCH_AVAILABLE = False
|
||||
|
||||
|
||||
class GPUType(Enum):
    """Supported GPU types"""
    NVIDIA = "nvidia"  # CUDA-capable device (probed via nvidia-smi)
    AMD = "amd"        # ROCm device (probed via rocm-smi)
    CPU = "cpu"        # fallback when no usable GPU is found
|
||||
|
||||
|
||||
class GPUDetector:
|
||||
"""
|
||||
Service for detecting and managing GPU resources.
|
||||
|
||||
Detects GPU type with fallback order: NVIDIA -> AMD -> CPU
|
||||
Provides unified interface regardless of GPU vendor.
|
||||
"""
|
||||
|
||||
def __init__(self):
    # Detected GPU vendor; populated lazily by initialize().
    self._gpu_type: Optional[GPUType] = None
    # PyTorch device string ("cuda" or "cpu") derived from the GPU type.
    self._device: Optional[str] = None
    # Guards against re-running detection on repeated initialize() calls.
    self._initialized: bool = False
|
||||
|
||||
def initialize(self) -> None:
    """Run GPU detection once and cache the result.

    Idempotent: subsequent calls return immediately. The statement order
    matters — _get_device_string() and _setup_environment() both read
    self._gpu_type, and the _initialized flag is only set once all steps
    have completed.
    """
    if self._initialized:
        return

    self._gpu_type = self._detect_gpu_type()
    self._device = self._get_device_string()
    self._setup_environment()
    self._initialized = True

    logger.info(f"GPU Detector initialized: {self._gpu_type.value} -> {self._device}")
|
||||
|
||||
def _detect_gpu_type(self) -> GPUType:
    """
    Detect the available GPU type.

    Checks run in a fixed priority order: explicit user preference,
    torch availability, nvidia-smi, rocm-smi, and finally a generic
    torch.cuda probe. Order: NVIDIA -> AMD -> CPU.
    """
    # Check user preference first (GPU_PREFERENCE: auto|cpu|nvidia|amd).
    preference = os.getenv("GPU_PREFERENCE", "auto").lower()
    if preference == "cpu":
        logger.info("GPU preference set to CPU, skipping GPU detection")
        return GPUType.CPU

    # Without torch no GPU backend can be used regardless of hardware.
    if not TORCH_AVAILABLE:
        logger.warning("PyTorch not available, using CPU")
        return GPUType.CPU

    # Check NVIDIA first
    if preference in ("auto", "nvidia"):
        if self._check_nvidia():
            logger.info("NVIDIA GPU detected via nvidia-smi")
            return GPUType.NVIDIA

    # Check AMD second
    if preference in ("auto", "amd"):
        if self._check_amd():
            logger.info("AMD GPU detected via ROCm")
            return GPUType.AMD

    # Fallback to checking torch.cuda (works for both NVIDIA and ROCm
    # builds of torch, even when the vendor CLI tools are missing).
    if torch.cuda.is_available():
        # Vendor is guessed from substrings of the reported device name;
        # NOTE(review): short tokens like "rx" can match unrelated names —
        # confirm against the hardware actually deployed.
        device_name = torch.cuda.get_device_name(0).lower()
        if "nvidia" in device_name or "geforce" in device_name or "rtx" in device_name or "gtx" in device_name:
            return GPUType.NVIDIA
        elif "amd" in device_name or "radeon" in device_name or "rx" in device_name:
            return GPUType.AMD
        else:
            # Unknown GPU vendor but CUDA works
            logger.warning(f"Unknown GPU vendor: {device_name}, treating as NVIDIA-compatible")
            return GPUType.NVIDIA

    logger.info("No GPU detected, using CPU")
    return GPUType.CPU
|
||||
|
||||
def _check_nvidia(self) -> bool:
    """Check if an NVIDIA GPU is available using nvidia-smi.

    Returns:
        True when nvidia-smi exists, exits cleanly, and reports at least
        one GPU name; False otherwise (including any subprocess failure).
    """
    nvidia_smi = shutil.which("nvidia-smi")
    if not nvidia_smi:
        return False

    try:
        result = subprocess.run(
            [nvidia_smi, "--query-gpu=name", "--format=csv,noheader"],
            capture_output=True,
            text=True,
            timeout=5
        )
        # Bug fix: the original expression returned the raw stdout string
        # when truthy, violating the declared -> bool contract and leaking
        # subprocess output to callers that compare with `is True`.
        return bool(result.returncode == 0 and result.stdout.strip())
    except Exception as e:
        logger.debug(f"nvidia-smi check failed: {e}")
        return False
|
||||
|
||||
def _check_amd(self) -> bool:
    """Check if AMD GPU is available using rocm-smi"""
    tool = shutil.which("rocm-smi")
    if tool is None:
        return False

    try:
        probe = subprocess.run(
            [tool, "--showproductname"],
            capture_output=True,
            text=True,
            timeout=5
        )
    except Exception as e:
        logger.debug(f"rocm-smi check failed: {e}")
        return False

    # Zero exit plus a "GPU" line in the report means ROCm sees a card.
    return probe.returncode == 0 and "GPU" in probe.stdout
|
||||
|
||||
def _setup_environment(self) -> None:
    """Export vendor-specific environment variables for the detected GPU."""
    if self._gpu_type != GPUType.AMD:
        return

    # RDNA2 cards (RX 6000 series / gfx1030) need an HSA override for
    # ROCm; an explicit value already in the environment always wins.
    hsa_version = os.getenv("HSA_OVERRIDE_GFX_VERSION", "10.3.0")
    os.environ.setdefault("HSA_OVERRIDE_GFX_VERSION", hsa_version)
    logger.info(f"Set HSA_OVERRIDE_GFX_VERSION={hsa_version}")
|
||||
|
||||
def _get_device_string(self) -> str:
|
||||
"""Get PyTorch device string"""
|
||||
if self._gpu_type in (GPUType.NVIDIA, GPUType.AMD):
|
||||
return "cuda"
|
||||
return "cpu"
|
||||
|
||||
@property
|
||||
def gpu_type(self) -> GPUType:
|
||||
"""Get detected GPU type"""
|
||||
if not self._initialized:
|
||||
self.initialize()
|
||||
return self._gpu_type
|
||||
|
||||
@property
|
||||
def device(self) -> str:
|
||||
"""Get device string for PyTorch"""
|
||||
if not self._initialized:
|
||||
self.initialize()
|
||||
return self._device
|
||||
|
||||
def get_device(self) -> "torch.device":
|
||||
"""Get PyTorch device object"""
|
||||
if not TORCH_AVAILABLE:
|
||||
raise RuntimeError("PyTorch not available")
|
||||
if not self._initialized:
|
||||
self.initialize()
|
||||
return torch.device(self._device)
|
||||
|
||||
def is_available(self) -> bool:
|
||||
"""Check if GPU is available"""
|
||||
if not self._initialized:
|
||||
self.initialize()
|
||||
return self._gpu_type in (GPUType.NVIDIA, GPUType.AMD)
|
||||
|
||||
def is_nvidia(self) -> bool:
|
||||
"""Check if NVIDIA GPU is being used"""
|
||||
if not self._initialized:
|
||||
self.initialize()
|
||||
return self._gpu_type == GPUType.NVIDIA
|
||||
|
||||
def is_amd(self) -> bool:
|
||||
"""Check if AMD GPU is being used"""
|
||||
if not self._initialized:
|
||||
self.initialize()
|
||||
return self._gpu_type == GPUType.AMD
|
||||
|
||||
def is_cpu(self) -> bool:
|
||||
"""Check if CPU is being used"""
|
||||
if not self._initialized:
|
||||
self.initialize()
|
||||
return self._gpu_type == GPUType.CPU
|
||||
|
||||
def get_device_name(self) -> str:
|
||||
"""Get GPU device name"""
|
||||
if not self._initialized:
|
||||
self.initialize()
|
||||
|
||||
if self._gpu_type == GPUType.CPU:
|
||||
return "CPU"
|
||||
|
||||
if TORCH_AVAILABLE and torch.cuda.is_available():
|
||||
return torch.cuda.get_device_name(0)
|
||||
|
||||
return "Unknown"
|
||||
|
||||
def get_memory_info(self) -> Dict[str, Any]:
    """Report GPU memory usage in GiB, or an error dict when unavailable."""
    if not self._initialized:
        self.initialize()

    # CPU-only: there is no device memory to report
    if self._gpu_type == GPUType.CPU:
        return {"type": "cpu", "error": "No GPU available"}

    if not TORCH_AVAILABLE or not torch.cuda.is_available():
        return {"type": self._gpu_type.value, "error": "CUDA not available"}

    gib = 1024 ** 3
    try:
        props = torch.cuda.get_device_properties(0)
        total = props.total_memory / gib
        allocated = torch.cuda.memory_allocated(0) / gib
        reserved = torch.cuda.memory_reserved(0) / gib

        return {
            "type": self._gpu_type.value,
            "device_name": props.name,
            "total_gb": round(total, 2),
            "allocated_gb": round(allocated, 2),
            "reserved_gb": round(reserved, 2),
            "free_gb": round(total - allocated, 2),
            "usage_percent": round((allocated / total) * 100, 1),
        }
    except Exception as e:
        # Surface the failure to the caller instead of raising
        return {"type": self._gpu_type.value, "error": str(e)}
|
||||
|
||||
def empty_cache(self) -> None:
|
||||
"""Clear GPU memory cache"""
|
||||
if not self._initialized:
|
||||
self.initialize()
|
||||
|
||||
if TORCH_AVAILABLE and torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
logger.debug("GPU cache cleared")
|
||||
|
||||
|
||||
# Global singleton instance
|
||||
gpu_detector = GPUDetector()
|
||||
@@ -1,137 +0,0 @@
|
||||
"""
|
||||
Performance metrics collector for CBCFacil
|
||||
"""
|
||||
import time
|
||||
import threading
|
||||
import psutil
|
||||
import logging
|
||||
from typing import Dict, Any, Optional
|
||||
from datetime import datetime, timedelta
|
||||
from contextlib import contextmanager
|
||||
|
||||
|
||||
class MetricsCollector:
    """Collect and aggregate request, latency, and system resource metrics.

    Thread-safe: all mutable counters are guarded by a single lock.

    BUG FIX: ``get_summary`` used to call ``get_latency_percentiles`` while
    already holding ``self._lock``; ``threading.Lock`` is not reentrant, so
    every call to ``get_summary`` deadlocked. Percentile math now lives in
    a lock-free helper that both public methods use.
    """

    def __init__(self):
        self.logger = logging.getLogger(__name__)
        self._start_time = time.time()
        self._request_count = 0
        self._error_count = 0
        self._total_latency = 0.0
        self._latencies = []  # rolling window of recent latencies (seconds)
        self._lock = threading.Lock()
        self._process = psutil.Process()

    def record_request(self, latency: float, success: bool = True) -> None:
        """Record a single request's latency (seconds) and outcome."""
        with self._lock:
            self._request_count += 1
            self._total_latency += latency
            self._latencies.append(latency)

            # Keep only last 1000 latencies for memory efficiency
            if len(self._latencies) > 1000:
                self._latencies = self._latencies[-1000:]

            if not success:
                self._error_count += 1

    def _percentiles_locked(self) -> Dict[str, float]:
        """Compute p50/p95/p99 latencies. Caller must hold self._lock."""
        if not self._latencies:
            return {"p50": 0, "p95": 0, "p99": 0}

        ordered = sorted(self._latencies)
        n = len(ordered)
        return {
            "p50": ordered[int(n * 0.50)],
            "p95": ordered[int(n * 0.95)],
            "p99": ordered[int(n * 0.99)],
        }

    def get_latency_percentiles(self) -> Dict[str, float]:
        """Calculate latency percentiles (seconds)."""
        with self._lock:
            return self._percentiles_locked()

    def get_system_metrics(self) -> Dict[str, Any]:
        """Get process-level resource metrics via psutil (best effort)."""
        try:
            memory = self._process.memory_info()
            cpu_percent = self._process.cpu_percent(interval=0.1)

            return {
                "cpu_percent": cpu_percent,
                "memory_rss_mb": memory.rss / 1024 / 1024,
                "memory_vms_mb": memory.vms / 1024 / 1024,
                "thread_count": self._process.num_threads(),
                "open_files": self._process.open_files(),
            }
        except Exception as e:
            self.logger.warning(f"Error getting system metrics: {e}")
            return {}

    def get_summary(self) -> Dict[str, Any]:
        """Get an aggregated metrics summary (counts, rates, latencies)."""
        with self._lock:
            uptime = time.time() - self._start_time
            # Lock-free helper: we already hold the lock here
            latency_pcts = self._percentiles_locked()

            return {
                "uptime_seconds": round(uptime, 2),
                "total_requests": self._request_count,
                "error_count": self._error_count,
                "error_rate": round(self._error_count / max(1, self._request_count) * 100, 2),
                "requests_per_second": round(self._request_count / max(1, uptime), 2),
                "average_latency_ms": round(self._total_latency / max(1, self._request_count) * 1000, 2),
                "latency_p50_ms": round(latency_pcts["p50"] * 1000, 2),
                "latency_p95_ms": round(latency_pcts["p95"] * 1000, 2),
                "latency_p99_ms": round(latency_pcts["p99"] * 1000, 2),
            }

    def reset(self) -> None:
        """Reset all counters and restart the uptime clock."""
        with self._lock:
            self._request_count = 0
            self._error_count = 0
            self._total_latency = 0.0
            self._latencies = []
            self._start_time = time.time()
|
||||
|
||||
|
||||
class LatencyTracker:
    """Context manager that times a code block and reports it to a collector."""

    def __init__(self, collector: MetricsCollector, operation: str):
        self.collector = collector
        self.operation = operation
        self.start_time: Optional[float] = None
        self.success = True

    def __enter__(self):
        # Mark the moment the tracked block begins
        self.start_time = time.time()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        elapsed = time.time() - self.start_time
        # Success is defined as "no exception escaped the block"
        self.collector.record_request(elapsed, exc_type is None)
        return False  # never suppress the exception
|
||||
|
||||
|
||||
# Global metrics collector
|
||||
metrics_collector = MetricsCollector()
|
||||
|
||||
|
||||
@contextmanager
def track_latency(operation: str = "unknown"):
    """Convenience wrapper: time the enclosed block via the global collector."""
    tracker = LatencyTracker(metrics_collector, operation)
    with tracker:
        yield
|
||||
|
||||
|
||||
def get_performance_report() -> Dict[str, Any]:
    """Build a full performance snapshot: metrics, system stats, timestamp."""
    report: Dict[str, Any] = {
        "metrics": metrics_collector.get_summary(),
        "system": metrics_collector.get_system_metrics(),
    }
    report["timestamp"] = datetime.utcnow().isoformat()
    return report
|
||||
@@ -1,353 +0,0 @@
|
||||
"""
|
||||
Notion integration service with official SDK
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Optional, Dict, Any, List
|
||||
from pathlib import Path
|
||||
from datetime import datetime
|
||||
import time
|
||||
|
||||
try:
|
||||
from notion_client import Client
|
||||
from notion_client.errors import APIResponseError
|
||||
|
||||
NOTION_AVAILABLE = True
|
||||
except ImportError:
|
||||
NOTION_AVAILABLE = False
|
||||
Client = None
|
||||
APIResponseError = Exception
|
||||
|
||||
from config import settings
|
||||
|
||||
|
||||
class NotionService:
|
||||
"""Enhanced Notion API integration service"""
|
||||
|
||||
def __init__(self):
|
||||
self.logger = logging.getLogger(__name__)
|
||||
self._client: Optional[Client] = None
|
||||
self._database_id: Optional[str] = None
|
||||
|
||||
def configure(self, token: str, database_id: str) -> None:
|
||||
"""Configure Notion with official SDK"""
|
||||
if not NOTION_AVAILABLE:
|
||||
self.logger.error(
|
||||
"notion-client not installed. Install with: pip install notion-client"
|
||||
)
|
||||
return
|
||||
|
||||
self._client = Client(auth=token)
|
||||
self._database_id = database_id
|
||||
self.logger.info("Notion service configured with official SDK")
|
||||
|
||||
@property
|
||||
def is_configured(self) -> bool:
|
||||
"""Check if Notion is configured"""
|
||||
return bool(self._client and self._database_id and NOTION_AVAILABLE)
|
||||
|
||||
def _rate_limited_request(self, func, *args, **kwargs):
|
||||
"""Execute request with rate limiting and retry"""
|
||||
max_retries = 3
|
||||
base_delay = 1
|
||||
|
||||
for attempt in range(max_retries):
|
||||
try:
|
||||
return func(*args, **kwargs)
|
||||
except APIResponseError as e:
|
||||
if hasattr(e, "code") and e.code == "rate_limited":
|
||||
delay = base_delay * (2**attempt)
|
||||
self.logger.warning(f"Rate limited by Notion, waiting {delay}s")
|
||||
time.sleep(delay)
|
||||
else:
|
||||
raise
|
||||
|
||||
raise Exception("Max retries exceeded for Notion API")
|
||||
|
||||
def create_page_with_summary(
    self, title: str, summary: str, metadata: Dict[str, Any]
) -> Optional[str]:
    """Create a Notion page (in a database or under a parent page) with a summary.

    Args:
        title: Page title; truncated to Notion's 100-char property limit.
        summary: Markdown summary rendered into Notion blocks.
        metadata: Options: ``use_as_page`` (bool) creates the page under a
            parent page instead of a database; ``add_status`` (bool) adds
            Status/Tipo select properties; ``file_type`` (str); ``pdf_path``
            (Path) adds a PDF callout.

    Returns:
        The created page id, or None when not configured or on API error.

    BUG FIXES:
        - Removed a full unreachable duplicate of this logic that sat after
          the ``return`` statements (dead code from a bad merge).
        - The Tipo select payload used the key ``" name"`` (leading space),
          which the Notion API rejects; corrected to ``"name"``.
    """
    if not self.is_configured:
        self.logger.warning("Notion not configured, skipping upload")
        return None

    try:
        # Decide whether the configured id is a parent page or a database
        use_as_page = metadata.get("use_as_page", False)

        if use_as_page:
            # Create a page nested under another page
            page = self._rate_limited_request(
                self._client.pages.create,
                parent={"page_id": self._database_id},
                properties={"title": [{"text": {"content": title[:100]}}]},
            )
        else:
            # Create a page in a database (original behaviour)
            properties = {"Name": {"title": [{"text": {"content": title[:100]}}]}}

            # Add Status only when the database supports it
            if metadata.get("add_status", True):
                properties["Status"] = {"select": {"name": "Procesado"}}

            # Add file type when available AND status properties are enabled
            if metadata.get("add_status", False) and metadata.get("file_type"):
                properties["Tipo"] = {
                    "select": {"name": metadata["file_type"].upper()}
                }

            page = self._rate_limited_request(
                self._client.pages.create,
                parent={"database_id": self._database_id},
                properties=properties,
            )

        page_id = page["id"]
        self.logger.info(f"✅ Notion page created: {page_id}")

        # Append the summary content as child blocks
        self._add_summary_content(page_id, summary, metadata.get("pdf_path"))

        return page_id

    except Exception as e:
        self.logger.error(f"❌ Error creating Notion page: {e}")
        return None
|
||||
|
||||
def _add_summary_content(
    self, page_id: str, summary: str, pdf_path: Optional[Path] = None
) -> bool:
    """Append the summary (plus optional PDF callout and footer) as blocks."""
    try:
        children: List[Dict] = []

        # Callout noting the generated PDF, when one exists on disk
        if pdf_path and pdf_path.exists():
            children.append(
                {
                    "object": "block",
                    "type": "callout",
                    "callout": {
                        "rich_text": [
                            {
                                "type": "text",
                                "text": {
                                    "content": f"📄 Documento generado automáticamente: {pdf_path.name}"
                                },
                            }
                        ],
                        "icon": {"emoji": "📄"},
                    },
                }
            )

        # Summary body converted from markdown
        children.extend(self._parse_markdown_to_blocks(summary))

        # Footer: divider plus a generation timestamp line
        children.append({"object": "block", "type": "divider", "divider": {}})
        children.append(
            {
                "object": "block",
                "type": "paragraph",
                "paragraph": {
                    "rich_text": [
                        {
                            "type": "text",
                            "text": {
                                "content": f"Generado por CBCFacil el {datetime.now().strftime('%d/%m/%Y %H:%M')}"
                            },
                            "annotations": {"italic": True, "color": "gray"},
                        }
                    ]
                },
            }
        )

        # Notion caps children at 100 blocks per request -> send in batches
        if children:
            for start in range(0, len(children), 100):
                self._rate_limited_request(
                    self._client.blocks.children.append,
                    block_id=page_id,
                    children=children[start : start + 100],
                )
            self.logger.info(f"✅ Added {len(children)} blocks to Notion page")

        return True

    except Exception as e:
        self.logger.error(f"❌ Error adding content blocks: {e}")
        return False
|
||||
|
||||
def _parse_markdown_to_blocks(self, markdown: str) -> List[Dict]:
|
||||
"""Convert markdown to Notion blocks"""
|
||||
blocks = []
|
||||
lines = markdown.split("\n")
|
||||
|
||||
for line in lines:
|
||||
line = line.strip()
|
||||
|
||||
if not line:
|
||||
continue
|
||||
|
||||
# Headings
|
||||
if line.startswith("# "):
|
||||
text = line[2:].strip()[:2000]
|
||||
if text:
|
||||
blocks.append(
|
||||
{
|
||||
"object": "block",
|
||||
"type": "heading_1",
|
||||
"heading_1": {
|
||||
"rich_text": [
|
||||
{"type": "text", "text": {"content": text}}
|
||||
]
|
||||
},
|
||||
}
|
||||
)
|
||||
elif line.startswith("## "):
|
||||
text = line[3:].strip()[:2000]
|
||||
if text:
|
||||
blocks.append(
|
||||
{
|
||||
"object": "block",
|
||||
"type": "heading_2",
|
||||
"heading_2": {
|
||||
"rich_text": [
|
||||
{"type": "text", "text": {"content": text}}
|
||||
]
|
||||
},
|
||||
}
|
||||
)
|
||||
elif line.startswith("### "):
|
||||
text = line[4:].strip()[:2000]
|
||||
if text:
|
||||
blocks.append(
|
||||
{
|
||||
"object": "block",
|
||||
"type": "heading_3",
|
||||
"heading_3": {
|
||||
"rich_text": [
|
||||
{"type": "text", "text": {"content": text}}
|
||||
]
|
||||
},
|
||||
}
|
||||
)
|
||||
# Bullet points
|
||||
elif line.startswith("- ") or line.startswith("* "):
|
||||
text = line[2:].strip()[:2000]
|
||||
if text:
|
||||
blocks.append(
|
||||
{
|
||||
"object": "block",
|
||||
"type": "bulleted_list_item",
|
||||
"bulleted_list_item": {
|
||||
"rich_text": [
|
||||
{"type": "text", "text": {"content": text}}
|
||||
]
|
||||
},
|
||||
}
|
||||
)
|
||||
# Divider
|
||||
elif line.strip() == "---":
|
||||
blocks.append({"object": "block", "type": "divider", "divider": {}})
|
||||
# Paragraph (skip footer lines)
|
||||
elif not line.startswith("*Generado por"):
|
||||
text = line[:2000]
|
||||
if text:
|
||||
blocks.append(
|
||||
{
|
||||
"object": "block",
|
||||
"type": "paragraph",
|
||||
"paragraph": {
|
||||
"rich_text": [
|
||||
{"type": "text", "text": {"content": text}}
|
||||
]
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
return blocks
|
||||
|
||||
def upload_pdf_legacy(self, pdf_path: Path, title: str) -> bool:
    """Legacy helper: create a simple summary page for a processed PDF."""
    if not self.is_configured:
        self.logger.warning("Notion not configured, skipping upload")
        return False

    try:
        # Delegate to the richer page-creation path with minimal metadata
        created_id = self.create_page_with_summary(
            title=title,
            summary=f"Documento procesado: {title}",
            metadata={"file_type": "PDF", "pdf_path": pdf_path},
        )
    except Exception as e:
        self.logger.error(f"Error uploading PDF to Notion: {e}")
        return False

    return bool(created_id)
|
||||
|
||||
# Alias para backward compatibility
|
||||
def upload_pdf(self, pdf_path: Path, title: str) -> bool:
|
||||
"""Upload PDF info to Notion (alias for backward compatibility)"""
|
||||
return self.upload_pdf_legacy(pdf_path, title)
|
||||
|
||||
def upload_pdf_as_file(self, pdf_path: Path, title: str) -> bool:
|
||||
"""Upload PDF info as file (alias for backward compatibility)"""
|
||||
return self.upload_pdf_legacy(pdf_path, title)
|
||||
|
||||
|
||||
# Global instance
|
||||
notion_service = NotionService()
|
||||
|
||||
|
||||
def upload_to_notion(pdf_path: Path, title: str) -> bool:
|
||||
"""Legacy function for backward compatibility"""
|
||||
return notion_service.upload_pdf(pdf_path, title)
|
||||
@@ -1,203 +0,0 @@
|
||||
"""
|
||||
Notion integration service
|
||||
"""
|
||||
import logging
|
||||
import base64
|
||||
from typing import Optional
|
||||
from pathlib import Path
|
||||
|
||||
try:
|
||||
import requests
|
||||
REQUESTS_AVAILABLE = True
|
||||
except ImportError:
|
||||
REQUESTS_AVAILABLE = False
|
||||
requests = None
|
||||
|
||||
from config import settings
|
||||
|
||||
|
||||
class NotionService:
|
||||
"""Service for Notion API integration"""
|
||||
|
||||
def __init__(self):
|
||||
self.logger = logging.getLogger(__name__)
|
||||
self._token: Optional[str] = None
|
||||
self._database_id: Optional[str] = None
|
||||
self._base_url = "https://api.notion.com/v1"
|
||||
|
||||
def configure(self, token: str, database_id: str) -> None:
|
||||
"""Configure Notion credentials"""
|
||||
self._token = token
|
||||
self._database_id = database_id
|
||||
self.logger.info("Notion service configured")
|
||||
|
||||
@property
|
||||
def is_configured(self) -> bool:
|
||||
"""Check if Notion is configured"""
|
||||
return bool(self._token and self._database_id)
|
||||
|
||||
def _get_headers(self) -> dict:
|
||||
"""Get headers for Notion API requests"""
|
||||
return {
|
||||
"Authorization": f"Bearer {self._token}",
|
||||
"Content-Type": "application/json",
|
||||
"Notion-Version": "2022-06-28"
|
||||
}
|
||||
|
||||
def upload_pdf(self, pdf_path: Path, title: str) -> bool:
|
||||
"""Upload PDF to Notion database"""
|
||||
if not self.is_configured:
|
||||
self.logger.warning("Notion not configured, skipping upload")
|
||||
return False
|
||||
|
||||
if not REQUESTS_AVAILABLE:
|
||||
self.logger.error("requests library not available for Notion upload")
|
||||
return False
|
||||
|
||||
if not pdf_path.exists():
|
||||
self.logger.error(f"PDF file not found: {pdf_path}")
|
||||
return False
|
||||
|
||||
try:
|
||||
# Read and encode PDF
|
||||
with open(pdf_path, 'rb') as f:
|
||||
pdf_data = base64.b64encode(f.read()).decode('utf-8')
|
||||
|
||||
# Prepare the page data
|
||||
page_data = {
|
||||
"parent": {"database_id": self._database_id},
|
||||
"properties": {
|
||||
"Name": {
|
||||
"title": [
|
||||
{
|
||||
"text": {
|
||||
"content": title
|
||||
}
|
||||
}
|
||||
]
|
||||
},
|
||||
"Status": {
|
||||
"select": {
|
||||
"name": "Procesado"
|
||||
}
|
||||
}
|
||||
},
|
||||
"children": [
|
||||
{
|
||||
"object": "block",
|
||||
"type": "paragraph",
|
||||
"paragraph": {
|
||||
"rich_text": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": {
|
||||
"content": f"Documento generado automáticamente: {title}"
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
},
|
||||
{
|
||||
"object": "block",
|
||||
"type": "file",
|
||||
"file": {
|
||||
"type": "external",
|
||||
"external": {
|
||||
"url": f"data:application/pdf;base64,{pdf_data}"
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
# Create page in database
|
||||
response = requests.post(
|
||||
f"{self._base_url}/pages",
|
||||
headers=self._get_headers(),
|
||||
json=page_data,
|
||||
timeout=30
|
||||
)
|
||||
|
||||
if response.status_code == 200:
|
||||
self.logger.info(f"PDF uploaded to Notion successfully: {title}")
|
||||
return True
|
||||
else:
|
||||
self.logger.error(f"Notion API error: {response.status_code} - {response.text}")
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error uploading PDF to Notion: {e}")
|
||||
return False
|
||||
|
||||
def upload_pdf_as_file(self, pdf_path: Path, title: str) -> bool:
|
||||
"""Upload PDF as a file block (alternative method)"""
|
||||
if not self.is_configured:
|
||||
self.logger.warning("Notion not configured, skipping upload")
|
||||
return False
|
||||
|
||||
if not REQUESTS_AVAILABLE:
|
||||
self.logger.error("requests library not available for Notion upload")
|
||||
return False
|
||||
|
||||
if not pdf_path.exists():
|
||||
self.logger.error(f"PDF file not found: {pdf_path}")
|
||||
return False
|
||||
|
||||
try:
|
||||
# For simplicity, we'll create a page with just the title and a link placeholder
|
||||
# In a real implementation, you'd need to upload the file to Notion's file storage
|
||||
page_data = {
|
||||
"parent": {"database_id": self._database_id},
|
||||
"properties": {
|
||||
"Name": {
|
||||
"title": [
|
||||
{
|
||||
"text": {
|
||||
"content": title
|
||||
}
|
||||
}
|
||||
]
|
||||
},
|
||||
"Status": {
|
||||
"select": {
|
||||
"name": "Procesado"
|
||||
}
|
||||
},
|
||||
"File Path": {
|
||||
"rich_text": [
|
||||
{
|
||||
"text": {
|
||||
"content": str(pdf_path)
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
response = requests.post(
|
||||
f"{self._base_url}/pages",
|
||||
headers=self._get_headers(),
|
||||
json=page_data,
|
||||
timeout=30
|
||||
)
|
||||
|
||||
if response.status_code == 200:
|
||||
self.logger.info(f"PDF uploaded to Notion successfully: {title}")
|
||||
return True
|
||||
else:
|
||||
self.logger.error(f"Notion API error: {response.status_code} - {response.text}")
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error uploading PDF to Notion: {e}")
|
||||
return False
|
||||
|
||||
|
||||
# Global instance
|
||||
notion_service = NotionService()
|
||||
|
||||
|
||||
def upload_to_notion(pdf_path: Path, title: str) -> bool:
|
||||
"""Legacy function for backward compatibility"""
|
||||
return notion_service.upload_pdf(pdf_path, title)
|
||||
270
services/pdf_generator.py
Normal file
270
services/pdf_generator.py
Normal file
@@ -0,0 +1,270 @@
|
||||
"""
|
||||
Generador de PDFs desde texto y markdown.
|
||||
|
||||
Utiliza reportlab para la generación de PDFs con soporte UTF-8.
|
||||
"""
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Union
|
||||
|
||||
from reportlab.lib import colors
|
||||
from reportlab.lib.pagesizes import A4
|
||||
from reportlab.lib.styles import ParagraphStyle, getSampleStyleSheet
|
||||
from reportlab.lib.units import cm
|
||||
from reportlab.platypus import Paragraph, SimpleDocTemplate, Spacer
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class PDFGenerator:
|
||||
"""Generador de PDFs desde texto plano o markdown."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""Inicializa el generador de PDFs."""
|
||||
self._styles = getSampleStyleSheet()
|
||||
self._setup_styles()
|
||||
logger.info("PDFGenerator inicializado")
|
||||
|
||||
def _setup_styles(self) -> None:
|
||||
"""Configura los estilos personalizados para el documento."""
|
||||
self._styles.add(
|
||||
ParagraphStyle(
|
||||
name="CustomNormal",
|
||||
parent=self._styles["Normal"],
|
||||
fontSize=11,
|
||||
leading=14,
|
||||
spaceAfter=6,
|
||||
)
|
||||
)
|
||||
self._styles.add(
|
||||
ParagraphStyle(
|
||||
name="CustomHeading1",
|
||||
parent=self._styles["Heading1"],
|
||||
fontSize=18,
|
||||
leading=22,
|
||||
spaceAfter=12,
|
||||
)
|
||||
)
|
||||
self._styles.add(
|
||||
ParagraphStyle(
|
||||
name="CustomHeading2",
|
||||
parent=self._styles["Heading2"],
|
||||
fontSize=14,
|
||||
leading=18,
|
||||
spaceAfter=10,
|
||||
)
|
||||
)
|
||||
|
||||
def _escape_xml(self, text: str) -> str:
|
||||
"""Escapa caracteres especiales para XML/HTML."""
|
||||
return (
|
||||
text.replace("&", "&")
|
||||
.replace("<", "<")
|
||||
.replace(">", ">")
|
||||
.replace("\n", "<br/>")
|
||||
)
|
||||
|
||||
def _parse_markdown_basic(self, markdown: str) -> list:
    """Convert basic markdown to a list of reportlab flowables.

    Handles headings, bold/italics, horizontal rules, bullet and numbered
    lists, and blank-line spacing.

    Args:
        markdown: Markdown-formatted source text.

    Returns:
        list: Paragraph and Spacer flowables ready for SimpleDocTemplate.

    BUG FIX: the return annotation claimed ``list[Paragraph]``, but Spacer
    objects are appended for blank lines and rules; the unused ``in_list``
    local was also removed.
    """
    elements: list = []

    for line in markdown.split("\n"):
        line = line.strip()

        if not line:
            # Blank line -> vertical gap
            elements.append(Spacer(1, 0.3 * cm))
            continue

        # Headings (check the longest prefix first)
        if line.startswith("### "):
            text = self._escape_xml(line[4:])
            elements.append(
                Paragraph(f"<b>{text}</b>", self._styles["CustomHeading2"])
            )
        elif line.startswith("## "):
            text = self._escape_xml(line[3:])
            elements.append(
                Paragraph(f"<b>{text}</b>", self._styles["CustomHeading1"])
            )
        elif line.startswith("# "):
            text = self._escape_xml(line[2:])
            elements.append(
                Paragraph(f"<b><i>{text}</i></b>", self._styles["CustomHeading1"])
            )
        # Horizontal rule -> small vertical gap
        elif line == "---" or line == "***":
            elements.append(Spacer(1, 0.2 * cm))
        # Dash/star bullet list
        elif line.startswith("- ") or line.startswith("* "):
            text = self._escape_xml(line[2:])
            text = f"• {self._format_inline_markdown(text)}"
            elements.append(Paragraph(text, self._styles["CustomNormal"]))
        # Numbered list ("1. item")
        elif line[0].isdigit() and ". " in line:
            idx = line.index(". ")
            text = self._escape_xml(line[idx + 2 :])
            text = self._format_inline_markdown(text)
            elements.append(Paragraph(text, self._styles["CustomNormal"]))
        # Plain paragraph
        else:
            text = self._escape_xml(line)
            text = self._format_inline_markdown(text)
            elements.append(Paragraph(text, self._styles["CustomNormal"]))

    return elements
|
||||
|
||||
def _format_inline_markdown(self, text: str) -> str:
|
||||
"""Convierte formato inline de markdown a HTML."""
|
||||
# Negritas: **texto** -> <b>texto</b>
|
||||
while "**" in text:
|
||||
start = text.find("**")
|
||||
end = text.find("**", start + 2)
|
||||
if end == -1:
|
||||
break
|
||||
text = (
|
||||
text[:start]
|
||||
+ f"<b>{text[start+2:end]}</b>"
|
||||
+ text[end + 2 :]
|
||||
)
|
||||
# Italicas: *texto* -> <i>texto</i>
|
||||
while "*" in text:
|
||||
start = text.find("*")
|
||||
end = text.find("*", start + 1)
|
||||
if end == -1:
|
||||
break
|
||||
text = (
|
||||
text[:start]
|
||||
+ f"<i>{text[start+1:end]}</i>"
|
||||
+ text[end + 1 :]
|
||||
)
|
||||
return text
|
||||
|
||||
def markdown_to_pdf(self, markdown_text: str, output_path: Path) -> Path:
|
||||
"""
|
||||
Convierte markdown a PDF.
|
||||
|
||||
Args:
|
||||
markdown_text: Contenido en formato markdown.
|
||||
output_path: Ruta donde se guardará el PDF.
|
||||
|
||||
Returns:
|
||||
Path: Ruta del archivo PDF generado.
|
||||
|
||||
Raises:
|
||||
ValueError: Si el contenido está vacío.
|
||||
IOError: Si hay error al escribir el archivo.
|
||||
"""
|
||||
if not markdown_text or not markdown_text.strip():
|
||||
logger.warning("markdown_to_pdf llamado con contenido vacío")
|
||||
raise ValueError("El contenido markdown no puede estar vacío")
|
||||
|
||||
logger.info(
|
||||
"Convirtiendo markdown a PDF",
|
||||
extra={
|
||||
"content_length": len(markdown_text),
|
||||
"output_path": str(output_path),
|
||||
},
|
||||
)
|
||||
|
||||
try:
|
||||
# Crear documento
|
||||
doc = SimpleDocTemplate(
|
||||
str(output_path),
|
||||
pagesize=A4,
|
||||
leftMargin=2 * cm,
|
||||
rightMargin=2 * cm,
|
||||
topMargin=2 * cm,
|
||||
bottomMargin=2 * cm,
|
||||
)
|
||||
|
||||
# Convertir markdown a elementos
|
||||
elements = self._parse_markdown_basic(markdown_text)
|
||||
|
||||
# Generar PDF
|
||||
doc.build(elements)
|
||||
|
||||
logger.info(
|
||||
"PDF generado exitosamente",
|
||||
extra={"output_path": str(output_path), "pages": "unknown"},
|
||||
)
|
||||
|
||||
return output_path
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error al generar PDF desde markdown: {e}")
|
||||
raise IOError(f"Error al generar PDF: {e}") from e
|
||||
|
||||
def text_to_pdf(self, text: str, output_path: Path) -> Path:
|
||||
"""
|
||||
Convierte texto plano a PDF.
|
||||
|
||||
Args:
|
||||
text: Contenido de texto plano.
|
||||
output_path: Ruta donde se guardará el PDF.
|
||||
|
||||
Returns:
|
||||
Path: Ruta del archivo PDF generado.
|
||||
|
||||
Raises:
|
||||
ValueError: Si el contenido está vacío.
|
||||
IOError: Si hay error al escribir el archivo.
|
||||
"""
|
||||
if not text or not text.strip():
|
||||
logger.warning("text_to_pdf llamado con contenido vacío")
|
||||
raise ValueError("El contenido de texto no puede estar vacío")
|
||||
|
||||
logger.info(
|
||||
"Convirtiendo texto a PDF",
|
||||
extra={
|
||||
"content_length": len(text),
|
||||
"output_path": str(output_path),
|
||||
},
|
||||
)
|
||||
|
||||
try:
|
||||
# Crear documento
|
||||
doc = SimpleDocTemplate(
|
||||
str(output_path),
|
||||
pagesize=A4,
|
||||
leftMargin=2 * cm,
|
||||
rightMargin=2 * cm,
|
||||
topMargin=2 * cm,
|
||||
bottomMargin=2 * cm,
|
||||
)
|
||||
|
||||
# Convertir texto a párrafos (uno por línea)
|
||||
elements: list[Union[Paragraph, Spacer]] = []
|
||||
lines = text.split("\n")
|
||||
|
||||
for line in lines:
|
||||
line = line.strip()
|
||||
if not line:
|
||||
elements.append(Spacer(1, 0.3 * cm))
|
||||
else:
|
||||
escaped = self._escape_xml(line)
|
||||
elements.append(Paragraph(escaped, self._styles["CustomNormal"]))
|
||||
|
||||
# Generar PDF
|
||||
doc.build(elements)
|
||||
|
||||
logger.info(
|
||||
"PDF generado exitosamente",
|
||||
extra={"output_path": str(output_path), "pages": "unknown"},
|
||||
)
|
||||
|
||||
return output_path
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error al generar PDF desde texto: {e}")
|
||||
raise IOError(f"Error al generar PDF: {e}") from e
|
||||
|
||||
|
||||
# Instancia global del generador
|
||||
pdf_generator = PDFGenerator()
|
||||
@@ -1,91 +1,447 @@
|
||||
"""
|
||||
Telegram notification service
|
||||
Servicio de notificaciones Telegram.
|
||||
|
||||
Envía mensajes al chat configurado mediante la API de Telegram Bot.
|
||||
Silencioso si no está configurado (TELEGRAM_TOKEN y TELEGRAM_CHAT_ID).
|
||||
"""
|
||||
import logging
|
||||
import time
|
||||
from typing import Optional
|
||||
from datetime import datetime
|
||||
from config import settings
|
||||
|
||||
try:
|
||||
import requests
|
||||
REQUESTS_AVAILABLE = True
|
||||
except ImportError:
|
||||
REQUESTS_AVAILABLE = False
|
||||
import requests
|
||||
|
||||
from config.settings import settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _truncate_safely(text: str, max_length: int) -> str:
|
||||
"""
|
||||
Trunca texto sin romper entidades de formato HTML.
|
||||
|
||||
Args:
|
||||
text: Texto a truncar.
|
||||
max_length: Longitud máxima.
|
||||
|
||||
Returns:
|
||||
Texto truncado de forma segura.
|
||||
"""
|
||||
if len(text) <= max_length:
|
||||
return text
|
||||
|
||||
# Dejar margen para el sufijo "..."
|
||||
safe_length = max_length - 10
|
||||
|
||||
# Buscar el último espacio o salto de línea antes del límite
|
||||
cut_point = text.rfind("\n", 0, safe_length)
|
||||
if cut_point == -1 or cut_point < safe_length - 100:
|
||||
cut_point = text.rfind(" ", 0, safe_length)
|
||||
if cut_point == -1 or cut_point < safe_length - 50:
|
||||
cut_point = safe_length
|
||||
|
||||
return text[:cut_point] + "..."
|
||||
|
||||
|
||||
class TelegramService:
|
||||
"""Service for sending Telegram notifications"""
|
||||
"""Servicio para enviar notificaciones a Telegram."""
|
||||
|
||||
def __init__(self):
|
||||
self.logger = logging.getLogger(__name__)
|
||||
self._token: Optional[str] = None
|
||||
self._chat_id: Optional[str] = None
|
||||
self._last_error_cache: dict = {}
|
||||
def __init__(self) -> None:
|
||||
"""Inicializa el servicio si hay configuración de Telegram."""
|
||||
self._token: Optional[str] = settings.TELEGRAM_TOKEN
|
||||
self._chat_id: Optional[str] = settings.TELEGRAM_CHAT_ID
|
||||
self._configured: bool = settings.has_telegram_config
|
||||
|
||||
def configure(self, token: str, chat_id: str) -> None:
|
||||
"""Configure Telegram credentials"""
|
||||
self._token = token
|
||||
self._chat_id = chat_id
|
||||
self.logger.info("Telegram service configured")
|
||||
# Rate limiting: mínimo tiempo entre mensajes (segundos)
|
||||
self._min_interval: float = 1.0
|
||||
self._last_send_time: float = 0.0
|
||||
|
||||
@property
|
||||
def is_configured(self) -> bool:
|
||||
"""Check if Telegram is configured"""
|
||||
return bool(self._token and self._chat_id)
|
||||
if self._configured:
|
||||
logger.info(
|
||||
"TelegramService inicializado",
|
||||
extra={"chat_id": self._mask_chat_id()},
|
||||
)
|
||||
else:
|
||||
logger.debug("TelegramService deshabilitado (sin configuración)")
|
||||
|
||||
def _send_request(self, endpoint: str, data: dict, retries: int = 3, delay: int = 2) -> bool:
|
||||
"""Make API request to Telegram"""
|
||||
if not REQUESTS_AVAILABLE:
|
||||
self.logger.warning("requests library not available")
|
||||
def _mask_chat_id(self) -> str:
|
||||
"""Oculta el chat_id para logging seguro."""
|
||||
if self._chat_id and len(self._chat_id) > 4:
|
||||
return f"***{self._chat_id[-4:]}"
|
||||
return "****"
|
||||
|
||||
def _wait_for_rate_limit(self) -> None:
|
||||
"""Espera si es necesario para cumplir el rate limiting."""
|
||||
now = time.monotonic()
|
||||
elapsed = now - self._last_send_time
|
||||
if elapsed < self._min_interval:
|
||||
sleep_time = self._min_interval - elapsed
|
||||
logger.debug(f"Rate limiting: esperando {sleep_time:.2f}s")
|
||||
time.sleep(sleep_time)
|
||||
self._last_send_time = time.monotonic()
|
||||
|
||||
def _send_request(self, method: str, data: dict) -> bool:
|
||||
"""Envía una request a la API de Telegram."""
|
||||
if not self._configured:
|
||||
return False
|
||||
|
||||
url = f"https://api.telegram.org/bot{self._token}/{endpoint}"
|
||||
url = f"https://api.telegram.org/bot{self._token}/{method}"
|
||||
|
||||
for attempt in range(retries):
|
||||
try:
|
||||
self._wait_for_rate_limit()
|
||||
|
||||
response = requests.post(url, json=data, timeout=10)
|
||||
|
||||
# Intentar parsear JSON para obtener detalles del error
|
||||
try:
|
||||
resp = requests.post(url, data=data, timeout=10)
|
||||
if resp.status_code == 200:
|
||||
return True
|
||||
else:
|
||||
self.logger.error(f"Telegram API error: {resp.status_code}")
|
||||
except Exception as e:
|
||||
self.logger.error(f"Telegram request failed (attempt {attempt+1}/{retries}): {e}")
|
||||
time.sleep(delay)
|
||||
return False
|
||||
result = response.json()
|
||||
except ValueError:
|
||||
result = {"raw": response.text}
|
||||
|
||||
def send_message(self, message: str) -> bool:
|
||||
"""Send a text message to Telegram"""
|
||||
if not self.is_configured:
|
||||
self.logger.warning("Telegram not configured, skipping notification")
|
||||
if response.status_code == 200 and result.get("ok"):
|
||||
logger.debug(
|
||||
"Mensaje enviado exitosamente",
|
||||
extra={"message_id": result.get("result", {}).get("message_id")},
|
||||
)
|
||||
return True
|
||||
|
||||
# Error detallado
|
||||
error_code = result.get("error_code", response.status_code)
|
||||
description = result.get("description", response.text)
|
||||
|
||||
logger.error(
|
||||
f"Error de Telegram API: HTTP {response.status_code}",
|
||||
extra={
|
||||
"method": method,
|
||||
"error_code": error_code,
|
||||
"description": description,
|
||||
"response_data": result,
|
||||
"request_data": {
|
||||
k: v if k != "text" else f"<{len(str(v))} chars>"
|
||||
for k, v in data.items()
|
||||
},
|
||||
},
|
||||
)
|
||||
return False
|
||||
data = {"chat_id": self._chat_id, "text": message}
|
||||
|
||||
except requests.RequestException as e:
|
||||
logger.error(
|
||||
f"Error de conexión con Telegram: {e}",
|
||||
extra={"method": method, "data_keys": list(data.keys())},
|
||||
)
|
||||
return False
|
||||
|
||||
def send_message(self, text: str, parse_mode: str = "HTML") -> bool:
|
||||
"""
|
||||
Envía un mensaje de texto al chat configurado.
|
||||
|
||||
Args:
|
||||
text: Contenido del mensaje.
|
||||
parse_mode: Modo de parseo (HTML, Markdown o MarkdownV2).
|
||||
|
||||
Returns:
|
||||
True si se envió correctamente, False en caso contrario.
|
||||
"""
|
||||
if not self._configured:
|
||||
logger.debug(f"Mensaje ignorado (sin configuración): {text[:50]}...")
|
||||
return False
|
||||
|
||||
# Validar que el texto no esté vacío
|
||||
if not text or not text.strip():
|
||||
logger.warning("Intento de enviar mensaje vacío, ignorando")
|
||||
return False
|
||||
|
||||
# Eliminar espacios en blanco al inicio y final
|
||||
text = text.strip()
|
||||
|
||||
# Telegram limita a 4096 caracteres
|
||||
MAX_LENGTH = 4096
|
||||
text = _truncate_safely(text, MAX_LENGTH)
|
||||
|
||||
data = {
|
||||
"chat_id": self._chat_id,
|
||||
"text": text,
|
||||
}
|
||||
|
||||
# Solo incluir parse_mode si hay texto y no está vacío
|
||||
if parse_mode:
|
||||
data["parse_mode"] = parse_mode
|
||||
|
||||
logger.info("Enviando mensaje a Telegram", extra={"length": len(text)})
|
||||
return self._send_request("sendMessage", data)
|
||||
|
||||
def send_start_notification(self) -> bool:
|
||||
"""Send service start notification"""
|
||||
message = "CBCFacil Service Started - AI document processing active"
|
||||
return self.send_message(message)
|
||||
def send_start_notification(self, filename: str) -> bool:
|
||||
"""
|
||||
Envía notificación de inicio de procesamiento.
|
||||
|
||||
def send_error_notification(self, error_key: str, error_message: str) -> bool:
|
||||
"""Send error notification with throttling"""
|
||||
now = datetime.utcnow()
|
||||
prev = self._last_error_cache.get(error_key)
|
||||
if prev is None:
|
||||
self._last_error_cache[error_key] = (error_message, now)
|
||||
Args:
|
||||
filename: Nombre del archivo que se está procesando.
|
||||
|
||||
Returns:
|
||||
True si se envió correctamente.
|
||||
"""
|
||||
if not filename:
|
||||
filename = "(desconocido)"
|
||||
|
||||
# Usar HTML para evitar problemas de escaping
|
||||
safe_filename = filename.replace("&", "&").replace("<", "<").replace(">", ">")
|
||||
|
||||
text = f"▶️ <b>Inicio de procesamiento</b>\n\n📄 Archivo: <code>{safe_filename}</code>"
|
||||
return self.send_message(text, parse_mode="HTML")
|
||||
|
||||
def send_error_notification(self, filename: str, error: str) -> bool:
|
||||
"""
|
||||
Envía notificación de error en procesamiento.
|
||||
|
||||
Args:
|
||||
filename: Nombre del archivo que falló.
|
||||
error: Descripción del error.
|
||||
|
||||
Returns:
|
||||
True si se envió correctamente.
|
||||
"""
|
||||
if not filename:
|
||||
filename = "(desconocido)"
|
||||
if not error:
|
||||
error = "(error desconocido)"
|
||||
|
||||
# Usar HTML para evitar problemas de escaping
|
||||
safe_filename = filename.replace("&", "&").replace("<", "<").replace(">", ">")
|
||||
safe_error = error.replace("&", "&").replace("<", "<").replace(">", ">")
|
||||
|
||||
text = f"❌ <b>Error de procesamiento</b>\n\n📄 Archivo: <code>{safe_filename}</code>\n⚠️ Error: {safe_error}"
|
||||
return self.send_message(text, parse_mode="HTML")
|
||||
|
||||
def send_completion_notification(
|
||||
self,
|
||||
filename: str,
|
||||
duration: Optional[float] = None,
|
||||
output_path: Optional[str] = None,
|
||||
) -> bool:
|
||||
"""
|
||||
Envía notificación de completado exitoso.
|
||||
|
||||
Args:
|
||||
filename: Nombre del archivo procesado.
|
||||
duration: Duración del procesamiento en segundos (opcional).
|
||||
output_path: Ruta del archivo de salida (opcional).
|
||||
|
||||
Returns:
|
||||
True si se envió correctamente.
|
||||
"""
|
||||
if not filename:
|
||||
filename = "(desconocido)"
|
||||
|
||||
# Usar HTML para evitar problemas de escaping
|
||||
safe_filename = filename.replace("&", "&").replace("<", "<").replace(">", ">")
|
||||
|
||||
duration_text = ""
|
||||
if duration is not None:
|
||||
minutes = int(duration // 60)
|
||||
seconds = int(duration % 60)
|
||||
duration_text = f"\n⏱️ Duración: {minutes}m {seconds}s"
|
||||
|
||||
output_text = ""
|
||||
if output_path:
|
||||
safe_output = output_path.replace("&", "&").replace("<", "<").replace(">", ">")
|
||||
output_text = f"\n📁 Salida: <code>{safe_output}</code>"
|
||||
|
||||
text = f"✅ <b>Procesamiento completado</b>\n\n📄 Archivo: <code>{safe_filename}</code>{duration_text}{output_text}"
|
||||
return self.send_message(text, parse_mode="HTML")
|
||||
|
||||
def send_download_complete(self, filename: str) -> bool:
|
||||
"""
|
||||
Envía notificación de descarga completada.
|
||||
|
||||
Args:
|
||||
filename: Nombre del archivo descargado.
|
||||
|
||||
Returns:
|
||||
True si se envió correctamente.
|
||||
"""
|
||||
if not filename:
|
||||
filename = "(desconocido)"
|
||||
|
||||
safe_filename = filename.replace("&", "&").replace("<", "<").replace(">", ">")
|
||||
text = f"📥 <b>Archivo descargado</b>\n\n📄 <code>{safe_filename}</code>"
|
||||
return self.send_message(text, parse_mode="HTML")
|
||||
|
||||
def send_transcription_start(self, filename: str) -> bool:
|
||||
"""
|
||||
Envía notificación de inicio de transcripción.
|
||||
|
||||
Args:
|
||||
filename: Nombre del archivo a transcribir.
|
||||
|
||||
Returns:
|
||||
True si se envió correctamente.
|
||||
"""
|
||||
if not filename:
|
||||
filename = "(desconocido)"
|
||||
|
||||
safe_filename = filename.replace("&", "&").replace("<", "<").replace(">", ">")
|
||||
text = f"🎙️ <b>Iniciando transcripción...</b>\n\n📄 <code>{safe_filename}</code>"
|
||||
return self.send_message(text, parse_mode="HTML")
|
||||
|
||||
def send_transcription_progress(
|
||||
self,
|
||||
filename: str,
|
||||
progress_percent: int,
|
||||
) -> bool:
|
||||
"""
|
||||
Envía notificación de progreso de transcripción.
|
||||
|
||||
Args:
|
||||
filename: Nombre del archivo.
|
||||
progress_percent: Porcentaje de progreso (0-100).
|
||||
|
||||
Returns:
|
||||
True si se envió correctamente.
|
||||
"""
|
||||
if not filename:
|
||||
filename = "(desconocido)"
|
||||
|
||||
safe_filename = filename.replace("&", "&").replace("<", "<").replace(">", ">")
|
||||
text = f"⏳ <b>Transcribiendo...</b>\n\n📄 <code>{safe_filename}</code>\n📊 Progreso: {progress_percent}%"
|
||||
return self.send_message(text, parse_mode="HTML")
|
||||
|
||||
def send_transcription_complete(
|
||||
self,
|
||||
filename: str,
|
||||
text_length: int,
|
||||
) -> bool:
|
||||
"""
|
||||
Envía notificación de transcripción completada.
|
||||
|
||||
Args:
|
||||
filename: Nombre del archivo.
|
||||
text_length: Longitud del texto transcrito.
|
||||
|
||||
Returns:
|
||||
True si se envió correctamente.
|
||||
"""
|
||||
if not filename:
|
||||
filename = "(desconocido)"
|
||||
|
||||
safe_filename = filename.replace("&", "&").replace("<", "<").replace(">", ">")
|
||||
# Formatear longitud del texto
|
||||
if text_length >= 1000:
|
||||
length_text = f"{text_length // 1000}k caracteres"
|
||||
else:
|
||||
prev_msg, prev_time = prev
|
||||
if error_message != prev_msg or (now - prev_time).total_seconds() > settings.ERROR_THROTTLE_SECONDS:
|
||||
self._last_error_cache[error_key] = (error_message, now)
|
||||
else:
|
||||
return False
|
||||
return self.send_message(f"Error: {error_message}")
|
||||
length_text = f"{text_length} caracteres"
|
||||
|
||||
text = f"✅ <b>Transcripción completada</b>\n\n📄 <code>{safe_filename}</code>\n📝 {length_text}"
|
||||
return self.send_message(text, parse_mode="HTML")
|
||||
|
||||
def send_summary_start(self, filename: str) -> bool:
|
||||
"""
|
||||
Envía notificación de inicio de resumen con IA.
|
||||
|
||||
Args:
|
||||
filename: Nombre del archivo.
|
||||
|
||||
Returns:
|
||||
True si se envió correctamente.
|
||||
"""
|
||||
if not filename:
|
||||
filename = "(desconocido)"
|
||||
|
||||
safe_filename = filename.replace("&", "&").replace("<", "<").replace(">", ">")
|
||||
text = f"🤖 <b>Generando resumen con IA...</b>\n\n📄 <code>{safe_filename}</code>"
|
||||
return self.send_message(text, parse_mode="HTML")
|
||||
|
||||
def send_summary_complete(self, filename: str, has_markdown: bool = True) -> bool:
|
||||
"""
|
||||
Envía notificación de resumen completado.
|
||||
|
||||
Args:
|
||||
filename: Nombre del archivo.
|
||||
has_markdown: Si se creó el archivo markdown.
|
||||
|
||||
Returns:
|
||||
True si se envió correctamente.
|
||||
"""
|
||||
if not filename:
|
||||
filename = "(desconocido)"
|
||||
|
||||
safe_filename = filename.replace("&", "&").replace("<", "<").replace(">", ">")
|
||||
status = "✅" if has_markdown else "⚠️"
|
||||
text = f"{status} <b>Resumen completado</b>\n\n📄 <code>{safe_filename}</code>"
|
||||
return self.send_message(text, parse_mode="HTML")
|
||||
|
||||
def send_pdf_start(self, filename: str) -> bool:
|
||||
"""
|
||||
Envía notificación de inicio de generación de PDF.
|
||||
|
||||
Args:
|
||||
filename: Nombre del archivo.
|
||||
|
||||
Returns:
|
||||
True si se envió correctamente.
|
||||
"""
|
||||
if not filename:
|
||||
filename = "(desconocido)"
|
||||
|
||||
safe_filename = filename.replace("&", "&").replace("<", "<").replace(">", ">")
|
||||
text = f"📄 <b>Creando PDF...</b>\n\n📄 <code>{safe_filename}</code>"
|
||||
return self.send_message(text, parse_mode="HTML")
|
||||
|
||||
def send_pdf_complete(self, filename: str, pdf_path: str) -> bool:
|
||||
"""
|
||||
Envía notificación de PDF completado.
|
||||
|
||||
Args:
|
||||
filename: Nombre del archivo.
|
||||
pdf_path: Ruta del PDF generado.
|
||||
|
||||
Returns:
|
||||
True si se envió correctamente.
|
||||
"""
|
||||
if not filename:
|
||||
filename = "(desconocido)"
|
||||
|
||||
safe_filename = filename.replace("&", "&").replace("<", "<").replace(">", ">")
|
||||
safe_path = pdf_path.replace("&", "&").replace("<", "<").replace(">", ">")
|
||||
text = f"📄 <b>PDF creado</b>\n\n📄 <code>{safe_filename}</code>\n📁 <code>{safe_path}</code>"
|
||||
return self.send_message(text, parse_mode="HTML")
|
||||
|
||||
def send_all_complete(
|
||||
self,
|
||||
filename: str,
|
||||
txt_path: Optional[str] = None,
|
||||
md_path: Optional[str] = None,
|
||||
pdf_path: Optional[str] = None,
|
||||
) -> bool:
|
||||
"""
|
||||
Envía notificación final con todos los archivos generados.
|
||||
|
||||
Args:
|
||||
filename: Nombre del archivo original.
|
||||
txt_path: Ruta del archivo de texto (opcional).
|
||||
md_path: Ruta del markdown (opcional).
|
||||
pdf_path: Ruta del PDF (opcional).
|
||||
|
||||
Returns:
|
||||
True si se envió correctamente.
|
||||
"""
|
||||
if not filename:
|
||||
filename = "(desconocido)"
|
||||
|
||||
safe_filename = filename.replace("&", "&").replace("<", "<").replace(">", ">")
|
||||
|
||||
files_text = ""
|
||||
if txt_path:
|
||||
safe_txt = txt_path.replace("&", "&").replace("<", "<").replace(">", ">")
|
||||
files_text += f"\n📝 <code>{safe_txt}</code>"
|
||||
if md_path:
|
||||
safe_md = md_path.replace("&", "&").replace("<", "<").replace(">", ">")
|
||||
files_text += f"\n📋 <code>{safe_md}</code>"
|
||||
if pdf_path:
|
||||
safe_pdf = pdf_path.replace("&", "&").replace("<", "<").replace(">", ">")
|
||||
files_text += f"\n📄 <code>{safe_pdf}</code>"
|
||||
|
||||
text = f"✅ <b>¡Proceso completado!</b>\n\n📄 <code>{safe_filename}</code>\n📁 Archivos:{files_text}"
|
||||
return self.send_message(text, parse_mode="HTML")
|
||||
|
||||
|
||||
# Global instance
|
||||
# Instancia global del servicio
|
||||
telegram_service = TelegramService()
|
||||
|
||||
|
||||
def send_telegram_message(message: str, retries: int = 3, delay: int = 2) -> bool:
|
||||
"""Legacy function for backward compatibility"""
|
||||
return telegram_service.send_message(message)
|
||||
|
||||
@@ -1,172 +1,307 @@
|
||||
"""
|
||||
VRAM/GPU memory management service
|
||||
Gestor de VRAM para descargar modelos de ML inactivos.
|
||||
|
||||
Proporciona limpieza automática de modelos (como Whisper) que no han sido
|
||||
usados durante un tiempo configurable para liberar memoria VRAM.
|
||||
|
||||
OPTIMIZACIONES:
|
||||
- Integración con cache global de modelos
|
||||
- Limpieza agresiva de cache CUDA
|
||||
- Monitoreo de memoria en tiempo real
|
||||
"""
|
||||
import gc
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Optional, Dict, Any
|
||||
from core import BaseService
|
||||
from config import settings
|
||||
from typing import Callable, Dict, Optional
|
||||
|
||||
try:
|
||||
import torch
|
||||
TORCH_AVAILABLE = True
|
||||
except ImportError:
|
||||
TORCH_AVAILABLE = False
|
||||
from config.settings import settings
|
||||
|
||||
# Import gpu_detector after torch check
|
||||
from .gpu_detector import gpu_detector, GPUType
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class VRAMManager(BaseService):
|
||||
"""Service for managing GPU VRAM usage"""
|
||||
def get_gpu_memory_mb() -> Dict[str, float]:
|
||||
"""
|
||||
Obtiene uso de memoria GPU en MB.
|
||||
|
||||
def __init__(self):
|
||||
super().__init__("VRAMManager")
|
||||
self._whisper_model = None
|
||||
self._ocr_models = None
|
||||
self._trocr_models = None
|
||||
self._models_last_used: Optional[datetime] = None
|
||||
self._cleanup_threshold = 0.7
|
||||
self._cleanup_interval = 300
|
||||
self._last_cleanup: Optional[datetime] = None
|
||||
Returns:
|
||||
Dict con 'total', 'used', 'free' en MB.
|
||||
"""
|
||||
try:
|
||||
import torch
|
||||
|
||||
def initialize(self) -> None:
|
||||
"""Initialize VRAM manager"""
|
||||
# Initialize GPU detector first
|
||||
gpu_detector.initialize()
|
||||
|
||||
if not TORCH_AVAILABLE:
|
||||
self.logger.warning("PyTorch not available - VRAM management disabled")
|
||||
return
|
||||
if torch.cuda.is_available():
|
||||
props = torch.cuda.get_device_properties(0)
|
||||
total = props.total_memory / (1024 ** 2)
|
||||
allocated = torch.cuda.memory_allocated(0) / (1024 ** 2)
|
||||
reserved = torch.cuda.memory_reserved(0) / (1024 ** 2)
|
||||
|
||||
if gpu_detector.is_available():
|
||||
gpu_type = gpu_detector.gpu_type
|
||||
device_name = gpu_detector.get_device_name()
|
||||
|
||||
if gpu_type == GPUType.AMD:
|
||||
self.logger.info(f"VRAM Manager initialized with AMD ROCm: {device_name}")
|
||||
elif gpu_type == GPUType.NVIDIA:
|
||||
os.environ['CUDA_VISIBLE_DEVICES'] = settings.CUDA_VISIBLE_DEVICES
|
||||
if settings.PYTORCH_CUDA_ALLOC_CONF:
|
||||
torch.backends.cuda.max_split_size_mb = int(settings.PYTORCH_CUDA_ALLOC_CONF.split(':')[1])
|
||||
self.logger.info(f"VRAM Manager initialized with NVIDIA CUDA: {device_name}")
|
||||
else:
|
||||
self.logger.warning("No GPU available - GPU acceleration disabled")
|
||||
return {
|
||||
"total": total,
|
||||
"used": allocated,
|
||||
"free": total - reserved,
|
||||
"reserved": reserved,
|
||||
}
|
||||
except ImportError:
|
||||
pass
|
||||
except Exception as e:
|
||||
logger.debug(f"Error obteniendo memoria GPU: {e}")
|
||||
|
||||
def cleanup(self) -> None:
|
||||
"""Cleanup all GPU models"""
|
||||
if not TORCH_AVAILABLE or not torch.cuda.is_available():
|
||||
return
|
||||
return {"total": 0, "used": 0, "free": 0, "reserved": 0}
|
||||
|
||||
models_freed = []
|
||||
|
||||
if self._whisper_model is not None:
|
||||
try:
|
||||
del self._whisper_model
|
||||
self._whisper_model = None
|
||||
models_freed.append("Whisper")
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error freeing Whisper VRAM: {e}")
|
||||
def clear_cuda_cache(aggressive: bool = False) -> None:
|
||||
"""
|
||||
Limpia el cache de CUDA.
|
||||
|
||||
if self._ocr_models is not None:
|
||||
try:
|
||||
self._ocr_models = None
|
||||
models_freed.append("OCR")
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error freeing OCR VRAM: {e}")
|
||||
Args:
|
||||
aggressive: Si True, ejecuta gc.collect() múltiples veces.
|
||||
"""
|
||||
try:
|
||||
import torch
|
||||
|
||||
if self._trocr_models is not None:
|
||||
try:
|
||||
if isinstance(self._trocr_models, dict):
|
||||
model = self._trocr_models.get('model')
|
||||
if model is not None:
|
||||
model.to('cpu')
|
||||
models_freed.append("TrOCR")
|
||||
torch.cuda.empty_cache()
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error freeing TrOCR VRAM: {e}")
|
||||
|
||||
self._whisper_model = None
|
||||
self._ocr_models = None
|
||||
self._trocr_models = None
|
||||
self._models_last_used = None
|
||||
|
||||
if models_freed:
|
||||
self.logger.info(f"Freed VRAM for models: {', '.join(models_freed)}")
|
||||
|
||||
self._force_aggressive_cleanup()
|
||||
|
||||
def update_usage(self) -> None:
|
||||
"""Update usage timestamp"""
|
||||
self._models_last_used = datetime.utcnow()
|
||||
self.logger.debug(f"VRAM usage timestamp updated")
|
||||
|
||||
def should_cleanup(self) -> bool:
|
||||
"""Check if cleanup should be performed"""
|
||||
if not TORCH_AVAILABLE or not torch.cuda.is_available():
|
||||
return False
|
||||
if self._last_cleanup is None:
|
||||
return True
|
||||
if (datetime.utcnow() - self._last_cleanup).total_seconds() < self._cleanup_interval:
|
||||
return False
|
||||
allocated = torch.cuda.memory_allocated(0)
|
||||
total = torch.cuda.get_device_properties(0).total_memory
|
||||
return allocated / total > self._cleanup_threshold
|
||||
|
||||
def lazy_cleanup(self) -> None:
|
||||
"""Perform cleanup if needed"""
|
||||
if self.should_cleanup():
|
||||
self.cleanup()
|
||||
self._last_cleanup = datetime.utcnow()
|
||||
|
||||
def _force_aggressive_cleanup(self) -> None:
|
||||
"""Force aggressive VRAM cleanup"""
|
||||
if not TORCH_AVAILABLE or not torch.cuda.is_available():
|
||||
return
|
||||
try:
|
||||
before_allocated = torch.cuda.memory_allocated(0) / 1024**3
|
||||
before_reserved = torch.cuda.memory_reserved(0) / 1024**3
|
||||
self.logger.debug(f"Before cleanup - Allocated: {before_allocated:.2f}GB, Reserved: {before_reserved:.2f}GB")
|
||||
gc.collect(0)
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
after_allocated = torch.cuda.memory_allocated(0) / 1024**3
|
||||
after_reserved = torch.cuda.memory_reserved(0) / 1024**3
|
||||
self.logger.debug(f"After cleanup - Allocated: {after_allocated:.2f}GB, Reserved: {after_reserved:.2f}GB")
|
||||
if after_reserved < before_reserved:
|
||||
self.logger.info(f"VRAM freed: {(before_reserved - after_reserved):.2f}GB")
|
||||
|
||||
if aggressive:
|
||||
for _ in range(3):
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
logger.debug(
|
||||
"CUDA cache limpiada",
|
||||
extra={"aggressive": aggressive, "memory_mb": get_gpu_memory_mb()},
|
||||
)
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
|
||||
class VRAMManager:
|
||||
"""
|
||||
Gestor singleton para administrar la descarga automática de modelos.
|
||||
|
||||
Mantiene registro del último uso de cada modelo y proporciona métodos
|
||||
para verificar y limpiar modelos inactivos.
|
||||
|
||||
NOTA: Con el nuevo cache global de modelos, este gestor ya no fuerza
|
||||
la descarga del modelo en sí, solo coordina los tiempos de cleanup.
|
||||
"""
|
||||
|
||||
_instance: Optional["VRAMManager"] = None
|
||||
|
||||
def __new__(cls) -> "VRAMManager":
|
||||
"""Implementación del patrón Singleton."""
|
||||
if cls._instance is None:
|
||||
cls._instance = super().__new__(cls)
|
||||
cls._instance._initialized = False
|
||||
return cls._instance
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""Inicializa el gestor si no ha sido inicializado."""
|
||||
if self._initialized:
|
||||
return
|
||||
|
||||
self._last_usage: Dict[str, float] = {}
|
||||
self._unload_callbacks: Dict[str, Callable[[], None]] = {}
|
||||
self._auto_unload_seconds = settings.WHISPER_AUTO_UNLOAD_SECONDS
|
||||
self._initialized = True
|
||||
|
||||
logger.info(
|
||||
"VRAMManager inicializado",
|
||||
extra={"auto_unload_seconds": self._auto_unload_seconds},
|
||||
)
|
||||
|
||||
def register_model(
|
||||
self, model_id: str, unload_callback: Callable[[], None]
|
||||
) -> None:
|
||||
"""
|
||||
Registra un modelo con su callback de descarga.
|
||||
|
||||
Args:
|
||||
model_id: Identificador único del modelo.
|
||||
unload_callback: Función a llamar para descargar el modelo.
|
||||
"""
|
||||
self._unload_callbacks[model_id] = unload_callback
|
||||
self._last_usage[model_id] = time.time()
|
||||
|
||||
logger.debug(
|
||||
"Modelo registrado en VRAMManager",
|
||||
extra={"model_id": model_id},
|
||||
)
|
||||
|
||||
def update_usage(self, model_id: str) -> None:
|
||||
"""
|
||||
Actualiza el timestamp del último uso del modelo.
|
||||
|
||||
Args:
|
||||
model_id: Identificador del modelo.
|
||||
"""
|
||||
self._last_usage[model_id] = time.time()
|
||||
|
||||
logger.debug(
|
||||
"Uso actualizado",
|
||||
extra={"model_id": model_id, "memory_mb": get_gpu_memory_mb()},
|
||||
)
|
||||
|
||||
def mark_used(self, model_id: str = "default") -> None:
|
||||
"""
|
||||
Marca el modelo como usado (alias simple para update_usage).
|
||||
|
||||
Args:
|
||||
model_id: Identificador del modelo. Default: "default".
|
||||
"""
|
||||
self.update_usage(model_id)
|
||||
|
||||
def check_and_cleanup(
|
||||
self, model_id: str, timeout_seconds: Optional[int] = None
|
||||
) -> bool:
|
||||
"""
|
||||
Verifica si el modelo debe ser descargado y lo limpia si es necesario.
|
||||
|
||||
NOTA: Con el cache global, la descarga solo elimina la referencia
|
||||
local. El modelo puede permanecer en cache para otras instancias.
|
||||
|
||||
Args:
|
||||
model_id: Identificador del modelo a verificar.
|
||||
timeout_seconds: Tiempo máximo de inactividad en segundos.
|
||||
|
||||
Returns:
|
||||
True si el modelo fue descargado, False si no necesitaba descarga.
|
||||
"""
|
||||
if model_id not in self._unload_callbacks:
|
||||
logger.warning(
|
||||
"Modelo no registrado en VRAMManager",
|
||||
extra={"model_id": model_id},
|
||||
)
|
||||
return False
|
||||
|
||||
threshold = timeout_seconds or self._auto_unload_seconds
|
||||
last_used = self._last_usage.get(model_id, 0)
|
||||
elapsed = time.time() - last_used
|
||||
|
||||
logger.debug(
|
||||
"Verificando modelo",
|
||||
extra={
|
||||
"model_id": model_id,
|
||||
"elapsed_seconds": elapsed,
|
||||
"threshold_seconds": threshold,
|
||||
},
|
||||
)
|
||||
|
||||
if elapsed >= threshold:
|
||||
return self._unload_model(model_id)
|
||||
|
||||
return False
|
||||
|
||||
def _unload_model(self, model_id: str) -> bool:
|
||||
"""
|
||||
Descarga el modelo invocando su callback.
|
||||
|
||||
Args:
|
||||
model_id: Identificador del modelo a descargar.
|
||||
|
||||
Returns:
|
||||
True si la descarga fue exitosa.
|
||||
"""
|
||||
callback = self._unload_callbacks.get(model_id)
|
||||
if callback is None:
|
||||
return False
|
||||
|
||||
try:
|
||||
callback()
|
||||
|
||||
# Limpiar cache de CUDA después de descargar
|
||||
clear_cuda_cache(aggressive=True)
|
||||
|
||||
# Limpiar registro después de descarga exitosa
|
||||
self._unload_callbacks.pop(model_id, None)
|
||||
self._last_usage.pop(model_id, None)
|
||||
|
||||
logger.info(
|
||||
"Modelo descargado por VRAMManager",
|
||||
extra={
|
||||
"model_id": model_id,
|
||||
"reason": "inactive",
|
||||
"memory_mb_after": get_gpu_memory_mb(),
|
||||
},
|
||||
)
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error in aggressive VRAM cleanup: {e}")
|
||||
logger.error(
|
||||
"Error al descargar modelo",
|
||||
extra={"model_id": model_id, "error": str(e)},
|
||||
)
|
||||
return False
|
||||
|
||||
def get_usage(self) -> Dict[str, Any]:
|
||||
"""Get VRAM usage information"""
|
||||
if not TORCH_AVAILABLE:
|
||||
return {'error': 'PyTorch not available'}
|
||||
if not torch.cuda.is_available():
|
||||
return {'error': 'CUDA not available'}
|
||||
total = torch.cuda.get_device_properties(0).total_memory / 1024**3
|
||||
allocated = torch.cuda.memory_allocated(0) / 1024**3
|
||||
cached = torch.cuda.memory_reserved(0) / 1024**3
|
||||
free = total - allocated
|
||||
return {
|
||||
'total_gb': round(total, 2),
|
||||
'allocated_gb': round(allocated, 2),
|
||||
'cached_gb': round(cached, 2),
|
||||
'free_gb': round(free, 2),
|
||||
'whisper_loaded': self._whisper_model is not None,
|
||||
'ocr_models_loaded': self._ocr_models is not None,
|
||||
'trocr_models_loaded': self._trocr_models is not None,
|
||||
'last_used': self._models_last_used.isoformat() if self._models_last_used else None,
|
||||
'timeout_seconds': settings.MODEL_TIMEOUT_SECONDS
|
||||
}
|
||||
def force_unload(self, model_id: str) -> bool:
|
||||
"""
|
||||
Fuerza la descarga inmediata de un modelo.
|
||||
|
||||
def force_free(self) -> str:
|
||||
"""Force immediate VRAM free"""
|
||||
self.cleanup()
|
||||
return "VRAM freed successfully"
|
||||
Args:
|
||||
model_id: Identificador del modelo a descargar.
|
||||
|
||||
Returns:
|
||||
True si la descarga fue exitosa.
|
||||
"""
|
||||
return self._unload_model(model_id)
|
||||
|
||||
def get_memory_info(self) -> Dict[str, float]:
|
||||
"""
|
||||
Obtiene información actual de memoria GPU.
|
||||
|
||||
Returns:
|
||||
Dict con 'total', 'used', 'free', 'reserved' en MB.
|
||||
"""
|
||||
return get_gpu_memory_mb()
|
||||
|
||||
def get_last_usage(self, model_id: str) -> Optional[float]:
|
||||
"""
|
||||
Obtiene el timestamp del último uso del modelo.
|
||||
|
||||
Args:
|
||||
model_id: Identificador del modelo.
|
||||
|
||||
Returns:
|
||||
Timestamp del último uso o None si no existe.
|
||||
"""
|
||||
return self._last_usage.get(model_id)
|
||||
|
||||
def get_seconds_since_last_use(self, model_id: str) -> Optional[float]:
|
||||
"""
|
||||
Obtiene los segundos transcurridos desde el último uso.
|
||||
|
||||
Args:
|
||||
model_id: Identificador del modelo.
|
||||
|
||||
Returns:
|
||||
Segundos transcurridos o None si no existe.
|
||||
"""
|
||||
last_used = self._last_usage.get(model_id)
|
||||
if last_used is None:
|
||||
return None
|
||||
return time.time() - last_used
|
||||
|
||||
def unregister_model(self, model_id: str) -> None:
|
||||
"""
|
||||
Elimina el registro de un modelo.
|
||||
|
||||
Args:
|
||||
model_id: Identificador del modelo a eliminar.
|
||||
"""
|
||||
self._unload_callbacks.pop(model_id, None)
|
||||
self._last_usage.pop(model_id, None)
|
||||
|
||||
logger.debug(
|
||||
"Modelo eliminado de VRAMManager",
|
||||
extra={"model_id": model_id},
|
||||
)
|
||||
|
||||
def clear_all(self) -> None:
|
||||
"""Limpia todos los registros del gestor."""
|
||||
self._unload_callbacks.clear()
|
||||
self._last_usage.clear()
|
||||
logger.info("VRAMManager limpiado")
|
||||
|
||||
|
||||
# Global instance
|
||||
# Instancia global singleton
|
||||
vram_manager = VRAMManager()
|
||||
|
||||
@@ -1,290 +1,102 @@
|
||||
"""
|
||||
WebDAV service for Nextcloud integration
|
||||
Cliente WebDAV para Nextcloud.
|
||||
Provee métodos para interactuar con Nextcloud via WebDAV.
|
||||
"""
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
import unicodedata
|
||||
import re
|
||||
from pathlib import Path
|
||||
from typing import Optional, List, Dict, Tuple
|
||||
from contextlib import contextmanager
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
import requests
|
||||
from requests.auth import HTTPBasicAuth
|
||||
from requests.adapters import HTTPAdapter
|
||||
from typing import Optional
|
||||
from webdav3.client import Client
|
||||
|
||||
from config import settings
|
||||
from core import WebDAVError
|
||||
|
||||
|
||||
class WebDAVService:
|
||||
"""Service for WebDAV operations with Nextcloud"""
|
||||
"""Cliente WebDAV para Nextcloud."""
|
||||
|
||||
def __init__(self):
|
||||
self.session: Optional[requests.Session] = None
|
||||
def __init__(self) -> None:
|
||||
self.logger = logging.getLogger(__name__)
|
||||
self._retry_delay = 1
|
||||
self._max_retries = settings.WEBDAV_MAX_RETRIES
|
||||
self._client: Optional[Client] = None
|
||||
|
||||
def initialize(self) -> None:
|
||||
"""Initialize WebDAV session"""
|
||||
if not settings.has_webdav_config:
|
||||
raise WebDAVError("WebDAV credentials not configured")
|
||||
def _get_client(self) -> Client:
|
||||
"""Obtiene o crea el cliente WebDAV."""
|
||||
if self._client is None:
|
||||
if not settings.has_webdav_config:
|
||||
raise RuntimeError("WebDAV configuration missing")
|
||||
|
||||
self.session = requests.Session()
|
||||
self.session.auth = HTTPBasicAuth(settings.NEXTCLOUD_USER, settings.NEXTCLOUD_PASSWORD)
|
||||
options = {
|
||||
"webdav_hostname": settings.NEXTCLOUD_URL,
|
||||
"webdav_login": settings.NEXTCLOUD_USER,
|
||||
"webdav_password": settings.NEXTCLOUD_PASSWORD,
|
||||
}
|
||||
self._client = Client(options)
|
||||
self._client.verify = True # Verificar SSL
|
||||
|
||||
# Configure HTTP adapter with retry strategy
|
||||
adapter = HTTPAdapter(
|
||||
max_retries=0, # We'll handle retries manually
|
||||
pool_connections=10,
|
||||
pool_maxsize=20
|
||||
)
|
||||
self.session.mount('https://', adapter)
|
||||
self.session.mount('http://', adapter)
|
||||
return self._client
|
||||
|
||||
# Test connection
|
||||
def test_connection(self) -> bool:
|
||||
"""Prueba la conexión con Nextcloud."""
|
||||
try:
|
||||
self._request('GET', '', timeout=5)
|
||||
self.logger.info("WebDAV connection established")
|
||||
client = self._get_client()
|
||||
return client.check()
|
||||
except Exception as e:
|
||||
raise WebDAVError(f"Failed to connect to WebDAV: {e}")
|
||||
|
||||
def cleanup(self) -> None:
|
||||
"""Cleanup WebDAV session"""
|
||||
if self.session:
|
||||
self.session.close()
|
||||
self.session = None
|
||||
|
||||
@staticmethod
|
||||
def normalize_path(path: str) -> str:
|
||||
"""Normalize remote paths to a consistent representation"""
|
||||
if not path:
|
||||
return ""
|
||||
normalized = unicodedata.normalize("NFC", str(path)).strip()
|
||||
if not normalized:
|
||||
return ""
|
||||
normalized = normalized.replace("\\", "/")
|
||||
normalized = re.sub(r"/+", "/", normalized)
|
||||
return normalized.lstrip("/")
|
||||
|
||||
def _build_url(self, remote_path: str) -> str:
|
||||
"""Build WebDAV URL"""
|
||||
path = self.normalize_path(remote_path)
|
||||
base_url = settings.WEBDAV_ENDPOINT.rstrip('/')
|
||||
return f"{base_url}/{path}"
|
||||
|
||||
def _request(self, method: str, remote_path: str, **kwargs) -> requests.Response:
|
||||
"""Make HTTP request to WebDAV with retries"""
|
||||
if not self.session:
|
||||
raise WebDAVError("WebDAV session not initialized")
|
||||
|
||||
url = self._build_url(remote_path)
|
||||
timeout = kwargs.pop('timeout', settings.HTTP_TIMEOUT)
|
||||
|
||||
for attempt in range(self._max_retries):
|
||||
try:
|
||||
response = self.session.request(method, url, timeout=timeout, **kwargs)
|
||||
if response.status_code < 400:
|
||||
return response
|
||||
elif response.status_code == 404:
|
||||
raise WebDAVError(f"Resource not found: {remote_path}")
|
||||
else:
|
||||
raise WebDAVError(f"HTTP {response.status_code}: {response.text}")
|
||||
except (requests.RequestException, requests.Timeout) as e:
|
||||
if attempt == self._max_retries - 1:
|
||||
raise WebDAVError(f"Request failed after {self._max_retries} retries: {e}")
|
||||
delay = self._retry_delay * (2 ** attempt)
|
||||
self.logger.warning(f"Request failed (attempt {attempt + 1}/{self._max_retries}), retrying in {delay}s...")
|
||||
time.sleep(delay)
|
||||
|
||||
raise WebDAVError("Max retries exceeded")
|
||||
|
||||
def list(self, remote_path: str = "") -> List[str]:
|
||||
"""List files in remote directory"""
|
||||
self.logger.debug(f"Listing remote directory: {remote_path}")
|
||||
response = self._request('PROPFIND', remote_path, headers={'Depth': '1'})
|
||||
return self._parse_propfind_response(response.text)
|
||||
|
||||
def _parse_propfind_response(self, xml_response: str) -> List[str]:
|
||||
"""Parse PROPFIND XML response and return only files (not directories)"""
|
||||
# Simple parser for PROPFIND response
|
||||
files = []
|
||||
try:
|
||||
import xml.etree.ElementTree as ET
|
||||
from urllib.parse import urlparse, unquote
|
||||
root = ET.fromstring(xml_response)
|
||||
|
||||
# Get the WebDAV path from settings
|
||||
parsed_url = urlparse(settings.NEXTCLOUD_URL)
|
||||
webdav_path = parsed_url.path.rstrip('/') # e.g. /remote.php/webdav
|
||||
|
||||
# Find all response elements
|
||||
for response in root.findall('.//{DAV:}response'):
|
||||
href = response.find('.//{DAV:}href')
|
||||
if href is None or href.text is None:
|
||||
continue
|
||||
|
||||
href_text = unquote(href.text) # Decode URL encoding
|
||||
|
||||
# Check if this is a directory (has collection resourcetype)
|
||||
propstat = response.find('.//{DAV:}propstat')
|
||||
is_directory = False
|
||||
if propstat is not None:
|
||||
prop = propstat.find('.//{DAV:}prop')
|
||||
if prop is not None:
|
||||
resourcetype = prop.find('.//{DAV:}resourcetype')
|
||||
if resourcetype is not None and resourcetype.find('.//{DAV:}collection') is not None:
|
||||
is_directory = True
|
||||
|
||||
# Skip directories
|
||||
if is_directory:
|
||||
continue
|
||||
|
||||
# Also skip paths ending with / (another way to detect directories)
|
||||
if href_text.endswith('/'):
|
||||
continue
|
||||
|
||||
# Remove base URL from href
|
||||
base_url = settings.NEXTCLOUD_URL.rstrip('/')
|
||||
if href_text.startswith(base_url):
|
||||
href_text = href_text[len(base_url):]
|
||||
|
||||
# Also strip the webdav path if it's there
|
||||
if href_text.startswith(webdav_path):
|
||||
href_text = href_text[len(webdav_path):]
|
||||
|
||||
# Clean up the path
|
||||
href_text = href_text.lstrip('/')
|
||||
if href_text: # Skip empty paths (root directory)
|
||||
files.append(href_text)
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error parsing PROPFIND response: {e}")
|
||||
|
||||
return files
|
||||
|
||||
def download(self, remote_path: str, local_path: Path) -> None:
|
||||
"""Download file from WebDAV"""
|
||||
self.logger.info(f"Downloading {remote_path} to {local_path}")
|
||||
|
||||
# Ensure local directory exists
|
||||
local_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
response = self._request('GET', remote_path, stream=True)
|
||||
|
||||
# Use larger buffer size for better performance
|
||||
with open(local_path, 'wb', buffering=65536) as f:
|
||||
for chunk in response.iter_content(chunk_size=settings.DOWNLOAD_CHUNK_SIZE):
|
||||
if chunk:
|
||||
f.write(chunk)
|
||||
|
||||
self.logger.debug(f"Download completed: {local_path}")
|
||||
|
||||
def upload(self, local_path: Path, remote_path: str) -> None:
|
||||
"""Upload file to WebDAV"""
|
||||
self.logger.info(f"Uploading {local_path} to {remote_path}")
|
||||
|
||||
# Ensure remote directory exists
|
||||
remote_dir = self.normalize_path(remote_path)
|
||||
if '/' in remote_dir:
|
||||
dir_path = '/'.join(remote_dir.split('/')[:-1])
|
||||
self.makedirs(dir_path)
|
||||
|
||||
with open(local_path, 'rb') as f:
|
||||
self._request('PUT', remote_path, data=f)
|
||||
|
||||
self.logger.debug(f"Upload completed: {remote_path}")
|
||||
|
||||
def mkdir(self, remote_path: str) -> None:
|
||||
"""Create directory on WebDAV"""
|
||||
self.makedirs(remote_path)
|
||||
|
||||
def makedirs(self, remote_path: str) -> None:
|
||||
"""Create directory and parent directories on WebDAV"""
|
||||
path = self.normalize_path(remote_path)
|
||||
if not path:
|
||||
return
|
||||
|
||||
parts = path.split('/')
|
||||
current = ""
|
||||
|
||||
for part in parts:
|
||||
current = f"{current}/{part}" if current else part
|
||||
try:
|
||||
self._request('MKCOL', current)
|
||||
self.logger.debug(f"Created directory: {current}")
|
||||
except WebDAVError as e:
|
||||
# Directory might already exist (409 Conflict or 405 MethodNotAllowed is OK)
|
||||
if '409' not in str(e) and '405' not in str(e):
|
||||
raise
|
||||
|
||||
def delete(self, remote_path: str) -> None:
|
||||
"""Delete file or directory from WebDAV"""
|
||||
self.logger.info(f"Deleting remote path: {remote_path}")
|
||||
self._request('DELETE', remote_path)
|
||||
|
||||
def exists(self, remote_path: str) -> bool:
|
||||
"""Check if remote path exists"""
|
||||
try:
|
||||
self._request('HEAD', remote_path)
|
||||
return True
|
||||
except WebDAVError:
|
||||
self.logger.error(f"WebDAV connection failed: {e}")
|
||||
return False
|
||||
|
||||
def upload_batch(
|
||||
self,
|
||||
files: List[Tuple[Path, str]],
|
||||
max_workers: int = 4,
|
||||
timeout: int = 120
|
||||
) -> Dict[str, bool]:
|
||||
"""
|
||||
Upload multiple files concurrently.
|
||||
def list_files(self, remote_path: str = "/") -> list[str]:
|
||||
"""Lista archivos en una ruta remota."""
|
||||
try:
|
||||
client = self._get_client()
|
||||
# Asegurar que la ruta empieza con /
|
||||
if not remote_path.startswith("/"):
|
||||
remote_path = "/" + remote_path
|
||||
|
||||
Args:
|
||||
files: List of (local_path, remote_path) tuples
|
||||
max_workers: Maximum concurrent uploads
|
||||
timeout: Timeout per upload in seconds
|
||||
files = client.list(remote_path)
|
||||
return files if files else []
|
||||
except Exception as e:
|
||||
self.logger.error(f"Failed to list files: {e}")
|
||||
return []
|
||||
|
||||
Returns:
|
||||
Dict mapping remote_path to success status
|
||||
"""
|
||||
if not files:
|
||||
def download_file(self, remote_path: str, local_path: Path) -> bool:
|
||||
"""Descarga un archivo desde Nextcloud."""
|
||||
try:
|
||||
client = self._get_client()
|
||||
local_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
client.download_sync(remote_path=str(remote_path), local_path=str(local_path))
|
||||
self.logger.info(f"Downloaded: {remote_path} -> {local_path}")
|
||||
return True
|
||||
except Exception as e:
|
||||
self.logger.error(f"Failed to download {remote_path}: {e}")
|
||||
return False
|
||||
|
||||
def get_file_info(self, remote_path: str) -> dict:
|
||||
"""Obtiene información de un archivo."""
|
||||
try:
|
||||
client = self._get_client()
|
||||
info = client.info(remote_path)
|
||||
return {
|
||||
"name": info.get("name", ""),
|
||||
"size": info.get("size", 0),
|
||||
"modified": info.get("modified", ""),
|
||||
}
|
||||
except Exception as e:
|
||||
self.logger.error(f"Failed to get file info: {e}")
|
||||
return {}
|
||||
|
||||
results: Dict[str, bool] = {}
|
||||
def file_exists(self, remote_path: str) -> bool:
|
||||
"""Verifica si un archivo existe en remoto."""
|
||||
try:
|
||||
client = self._get_client()
|
||||
return client.check(remote_path)
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
with ThreadPoolExecutor(max_workers=max_workers) as executor:
|
||||
# Submit all upload tasks
|
||||
future_to_path = {
|
||||
executor.submit(self.upload, local, remote): remote
|
||||
for local, remote in files
|
||||
}
|
||||
|
||||
# Collect results as they complete
|
||||
for future in as_completed(future_to_path, timeout=timeout):
|
||||
remote_path = future_to_path[future]
|
||||
try:
|
||||
future.result()
|
||||
results[remote_path] = True
|
||||
self.logger.info(f"Successfully uploaded: {remote_path}")
|
||||
except Exception as e:
|
||||
results[remote_path] = False
|
||||
self.logger.error(f"Failed to upload {remote_path}: {e}")
|
||||
|
||||
failed_count = sum(1 for success in results.values() if not success)
|
||||
if failed_count > 0:
|
||||
self.logger.warning(
|
||||
f"Batch upload completed with {failed_count} failures "
|
||||
f"({len(results) - failed_count}/{len(results)} successful)"
|
||||
)
|
||||
else:
|
||||
self.logger.info(
|
||||
f"Batch upload completed: {len(results)} files uploaded successfully"
|
||||
)
|
||||
|
||||
return results
|
||||
|
||||
|
||||
# Global instance
|
||||
webdav_service = WebDAVService()
|
||||
def upload_file(self, local_path: Path, remote_path: str) -> bool:
|
||||
"""Sube un archivo a Nextcloud."""
|
||||
try:
|
||||
client = self._get_client()
|
||||
client.upload_sync(local_path=str(local_path), remote_path=str(remote_path))
|
||||
self.logger.info(f"Uploaded: {local_path} -> {remote_path}")
|
||||
return True
|
||||
except Exception as e:
|
||||
self.logger.error(f"Failed to upload {local_path}: {e}")
|
||||
return False
|
||||
|
||||
Reference in New Issue
Block a user